1use crate::types::{Memory, SearchResult};
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashSet;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum SuggestionType {
15 TopicMatch,
17 FrequentlyUsed,
19 SemanticallySimilar,
21 NeedsReview,
23 RelatedContext,
25 PotentialConflict,
27 RecentlyAdded,
29 SuggestCreate,
31}
32
/// A single proactive suggestion produced by [`SuggestionEngine`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Suggestion {
    /// The suggested memory; `None` for [`SuggestionType::SuggestCreate`] suggestions,
    /// which propose creating a new memory instead.
    pub memory: Option<Memory>,
    /// Why this suggestion was made.
    pub suggestion_type: SuggestionType,
    /// Ranking score. The engine caps memory scores at 1.0 and uses a fixed 0.6 for
    /// create suggestions.
    pub relevance: f32,
    /// Human-readable explanation of the signals that produced this suggestion.
    pub reason: String,
    /// Conversation keywords that matched the suggested memory (or, for create
    /// suggestions, all conversation keywords).
    pub trigger_keywords: Vec<String>,
    /// Confidence in the suggestion; initialised to `relevance` by [`Suggestion::new`].
    pub confidence: f32,
    /// Proposed content for a new memory (only set for create suggestions by the engine).
    pub suggested_content: Option<String>,
    /// UTC time at which the suggestion was constructed.
    pub generated_at: DateTime<Utc>,
}
53
54impl Suggestion {
55 pub fn new(
57 memory: Option<Memory>,
58 suggestion_type: SuggestionType,
59 relevance: f32,
60 reason: impl Into<String>,
61 ) -> Self {
62 Self {
63 memory,
64 suggestion_type,
65 relevance,
66 reason: reason.into(),
67 trigger_keywords: vec![],
68 confidence: relevance,
69 suggested_content: None,
70 generated_at: Utc::now(),
71 }
72 }
73
74 pub fn with_keywords(mut self, keywords: Vec<String>) -> Self {
76 self.trigger_keywords = keywords;
77 self
78 }
79
80 pub fn with_confidence(mut self, confidence: f32) -> Self {
82 self.confidence = confidence;
83 self
84 }
85
86 pub fn with_suggested_content(mut self, content: impl Into<String>) -> Self {
88 self.suggested_content = Some(content.into());
89 self
90 }
91}
92
/// Tunable weights, thresholds and limits for [`SuggestionEngine`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuggestionConfig {
    /// Maximum number of memory-backed suggestions returned; a create suggestion, when
    /// enabled, may be appended on top of this limit.
    pub max_suggestions: usize,
    /// Memories whose combined score is below this value are not suggested.
    pub min_relevance: f32,
    /// Multiplier applied to the recency signal (derived from the memory's `updated_at`).
    pub recency_weight: f32,
    /// Multiplier applied to the access-frequency signal.
    pub frequency_weight: f32,
    /// Multiplier applied to the semantic-similarity score taken from search results.
    pub semantic_weight: f32,
    /// Multiplier applied to the fraction of conversation keywords the memory matches.
    pub keyword_weight: f32,
    /// Age in days at which the recency signal has decayed linearly to zero.
    pub recency_window_days: i64,
    /// Whether to propose creating a new memory when the conversation contains a trigger.
    pub enable_create_suggestions: bool,
}
113
114impl Default for SuggestionConfig {
115 fn default() -> Self {
116 Self {
117 max_suggestions: 5,
118 min_relevance: 0.3,
119 recency_weight: 0.2,
120 frequency_weight: 0.15,
121 semantic_weight: 0.4,
122 keyword_weight: 0.25,
123 recency_window_days: 30,
124 enable_create_suggestions: true,
125 }
126 }
127}
128
/// Snapshot of the conversation that suggestions are generated for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationContext {
    /// Raw conversation messages, oldest first (presumably — verify with callers).
    pub messages: Vec<String>,
    /// Keywords extracted from `messages` by the `from_message`/`from_messages` constructors.
    pub keywords: Vec<String>,
    /// Optional topic label (not consulted by `SuggestionEngine` scoring in this module).
    pub topic: Option<String>,
    /// IDs of memories already referenced in the conversation; the engine never suggests them.
    pub referenced_memories: Vec<i64>,
    /// Optional intent label (not consulted by `SuggestionEngine` scoring in this module).
    pub intent: Option<String>,
}
143
144impl ConversationContext {
145 pub fn from_message(message: impl Into<String>) -> Self {
147 let msg = message.into();
148 let keywords = Self::extract_keywords(&msg);
149 Self {
150 messages: vec![msg],
151 keywords,
152 topic: None,
153 referenced_memories: vec![],
154 intent: None,
155 }
156 }
157
158 pub fn from_messages(messages: Vec<String>) -> Self {
160 let all_text = messages.join(" ");
161 let keywords = Self::extract_keywords(&all_text);
162 Self {
163 messages,
164 keywords,
165 topic: None,
166 referenced_memories: vec![],
167 intent: None,
168 }
169 }
170
171 fn extract_keywords(text: &str) -> Vec<String> {
173 let stop_words: HashSet<&str> = [
175 "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
176 "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "can",
177 "this", "that", "these", "those", "i", "you", "he", "she", "it", "we", "they", "what",
178 "which", "who", "when", "where", "why", "how", "all", "each", "every", "both", "few",
179 "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same",
180 "so", "than", "too", "very", "just", "and", "but", "or", "if", "because", "as",
181 "until", "while", "of", "at", "by", "for", "with", "about", "against", "between",
182 "into", "through", "during", "before", "after", "above", "below", "to", "from", "up",
183 "down", "in", "out", "on", "off", "over", "under", "again", "further", "then", "once",
184 "here", "there", "any", "your", "my", "his", "her", "its", "our", "their", "need",
185 "want", "like", "know", "think", "make",
186 ]
187 .iter()
188 .cloned()
189 .collect();
190
191 text.to_lowercase()
192 .split(|c: char| !c.is_alphanumeric())
193 .filter(|word| word.len() > 2 && !stop_words.contains(word))
194 .map(String::from)
195 .collect::<HashSet<_>>()
196 .into_iter()
197 .collect()
198 }
199
200 pub fn with_topic(mut self, topic: impl Into<String>) -> Self {
202 self.topic = Some(topic.into());
203 self
204 }
205
206 pub fn with_referenced_memories(mut self, ids: Vec<i64>) -> Self {
208 self.referenced_memories = ids;
209 self
210 }
211
212 pub fn with_intent(mut self, intent: impl Into<String>) -> Self {
214 self.intent = Some(intent.into());
215 self
216 }
217}
218
/// Scores stored memories against a conversation and turns the best matches into
/// [`Suggestion`]s, as configured by a [`SuggestionConfig`].
pub struct SuggestionEngine {
    /// Weights, thresholds and limits used when scoring and selecting suggestions.
    config: SuggestionConfig,
}
223
224impl Default for SuggestionEngine {
225 fn default() -> Self {
226 Self::new(SuggestionConfig::default())
227 }
228}
229
230impl SuggestionEngine {
231 pub fn new(config: SuggestionConfig) -> Self {
233 Self { config }
234 }
235
236 pub fn generate_suggestions(
238 &self,
239 context: &ConversationContext,
240 memories: &[Memory],
241 search_results: Option<&[SearchResult]>,
242 ) -> Vec<Suggestion> {
243 let mut suggestions = Vec::new();
244
245 let mut scored_memories: Vec<(f32, &Memory, SuggestionType, String)> = memories
247 .iter()
248 .filter(|m| !context.referenced_memories.contains(&m.id))
249 .filter_map(|memory| {
250 let (score, suggestion_type, reason) =
251 self.score_memory(memory, context, search_results);
252 if score >= self.config.min_relevance {
253 Some((score, memory, suggestion_type, reason))
254 } else {
255 None
256 }
257 })
258 .collect();
259
260 scored_memories.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
262
263 for (score, memory, suggestion_type, reason) in scored_memories
265 .into_iter()
266 .take(self.config.max_suggestions)
267 {
268 let keywords: Vec<String> = context
269 .keywords
270 .iter()
271 .filter(|kw| memory.content.to_lowercase().contains(&kw.to_lowercase()))
272 .cloned()
273 .collect();
274
275 suggestions.push(
276 Suggestion::new(Some(memory.clone()), suggestion_type, score, reason)
277 .with_keywords(keywords),
278 );
279 }
280
281 if self.config.enable_create_suggestions {
283 if let Some(create_suggestion) = self.suggest_create(context) {
284 suggestions.push(create_suggestion);
285 }
286 }
287
288 suggestions
289 }
290
291 fn score_memory(
293 &self,
294 memory: &Memory,
295 context: &ConversationContext,
296 search_results: Option<&[SearchResult]>,
297 ) -> (f32, SuggestionType, String) {
298 let mut total_score = 0.0;
299 let mut suggestion_type = SuggestionType::TopicMatch;
300 let mut reasons = Vec::new();
301
302 let keyword_score = self.calculate_keyword_score(memory, context);
304 if keyword_score > 0.0 {
305 total_score += keyword_score * self.config.keyword_weight;
306 reasons.push(format!(
307 "matches keywords ({}%)",
308 (keyword_score * 100.0) as i32
309 ));
310 }
311
312 if let Some(results) = search_results {
314 if let Some(result) = results.iter().find(|r| r.memory.id == memory.id) {
315 let semantic_score = result.match_info.semantic_score.unwrap_or(0.0);
316 total_score += semantic_score * self.config.semantic_weight;
317 if semantic_score > 0.5 {
318 suggestion_type = SuggestionType::SemanticallySimilar;
319 reasons.push(format!(
320 "semantically similar ({}%)",
321 (semantic_score * 100.0) as i32
322 ));
323 }
324 }
325 }
326
327 let recency_score = self.calculate_recency_score(memory);
329 total_score += recency_score * self.config.recency_weight;
330 if recency_score > 0.8 {
331 if total_score > 0.5 {
332 suggestion_type = SuggestionType::RecentlyAdded;
333 }
334 reasons.push("recently updated".to_string());
335 }
336
337 let frequency_score = self.calculate_frequency_score(memory);
339 total_score += frequency_score * self.config.frequency_weight;
340 if frequency_score > 0.7 {
341 suggestion_type = SuggestionType::FrequentlyUsed;
342 reasons.push("frequently accessed".to_string());
343 }
344
345 if self.might_conflict(memory, context) {
347 suggestion_type = SuggestionType::PotentialConflict;
348 reasons.push("might contain conflicting information".to_string());
349 }
350
351 if self.needs_review(memory) {
353 suggestion_type = SuggestionType::NeedsReview;
354 reasons.push("may need review (outdated)".to_string());
355 }
356
357 let reason = if reasons.is_empty() {
358 "Related to conversation".to_string()
359 } else {
360 reasons.join(", ")
361 };
362
363 (total_score.min(1.0), suggestion_type, reason)
364 }
365
366 fn calculate_keyword_score(&self, memory: &Memory, context: &ConversationContext) -> f32 {
368 if context.keywords.is_empty() {
369 return 0.0;
370 }
371
372 let content_lower = memory.content.to_lowercase();
373 let tags_lower: Vec<String> = memory.tags.iter().map(|t| t.to_lowercase()).collect();
374
375 let matches: usize = context
376 .keywords
377 .iter()
378 .filter(|kw| {
379 let kw_lower = kw.to_lowercase();
380 content_lower.contains(&kw_lower)
381 || tags_lower.iter().any(|t| t.contains(&kw_lower))
382 })
383 .count();
384
385 (matches as f32 / context.keywords.len() as f32).min(1.0)
386 }
387
388 fn calculate_recency_score(&self, memory: &Memory) -> f32 {
390 let age_days = (Utc::now() - memory.updated_at).num_days() as f32;
391 let window = self.config.recency_window_days as f32;
392
393 if age_days <= 0.0 {
394 1.0
395 } else if age_days >= window {
396 0.0
397 } else {
398 1.0 - (age_days / window)
399 }
400 }
401
402 fn calculate_frequency_score(&self, memory: &Memory) -> f32 {
404 (memory.access_count as f32 / 100.0).min(1.0)
406 }
407
408 fn might_conflict(&self, memory: &Memory, context: &ConversationContext) -> bool {
410 let contradiction_pairs = [
412 ("true", "false"),
413 ("yes", "no"),
414 ("enable", "disable"),
415 ("start", "stop"),
416 ("add", "remove"),
417 ("create", "delete"),
418 ];
419
420 let content_lower = memory.content.to_lowercase();
421 let context_text = context.messages.join(" ").to_lowercase();
422
423 for (word1, word2) in contradiction_pairs {
424 if (content_lower.contains(word1) && context_text.contains(word2))
425 || (content_lower.contains(word2) && context_text.contains(word1))
426 {
427 return true;
428 }
429 }
430
431 false
432 }
433
434 fn needs_review(&self, memory: &Memory) -> bool {
436 let age_days = (Utc::now() - memory.updated_at).num_days();
437 let last_access_days = memory
438 .last_accessed_at
439 .map(|dt| (Utc::now() - dt).num_days())
440 .unwrap_or(age_days);
441
442 age_days > 90 && last_access_days > 30
444 }
445
446 fn suggest_create(&self, context: &ConversationContext) -> Option<Suggestion> {
448 let context_text = context.messages.join(" ").to_lowercase();
450
451 let create_triggers = [
452 ("decide", "Decision detected in conversation"),
453 ("agreed", "Agreement detected in conversation"),
454 ("remember", "User wants to remember something"),
455 ("important", "Important information mentioned"),
456 ("todo", "Task or todo mentioned"),
457 ("deadline", "Deadline mentioned"),
458 ("bug", "Bug or issue mentioned"),
459 ("fix", "Fix or solution mentioned"),
460 ("learn", "Learning opportunity detected"),
461 ];
462
463 for (trigger, reason) in create_triggers {
464 if context_text.contains(trigger) {
465 let suggested_content = context
467 .messages
468 .last()
469 .cloned()
470 .unwrap_or_else(|| context.keywords.join(" "));
471
472 return Some(
473 Suggestion::new(None, SuggestionType::SuggestCreate, 0.6, reason)
474 .with_suggested_content(suggested_content)
475 .with_keywords(context.keywords.clone()),
476 );
477 }
478 }
479
480 None
481 }
482
483 pub fn config(&self) -> &SuggestionConfig {
485 &self.config
486 }
487
488 pub fn set_config(&mut self, config: SuggestionConfig) {
490 self.config = config;
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{MemoryType, Visibility};
    use std::collections::HashMap;

    /// Fixture: a private, permanent, active `Note` in the `default` workspace.
    ///
    /// It has 10 accesses and was created 5 days ago, updated 1 day ago and last accessed
    /// 2 hours ago.
    fn fixture_memory(id: i64, content: &str, tags: Vec<&str>) -> Memory {
        Memory {
            id,
            content: content.to_owned(),
            memory_type: MemoryType::Note,
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            metadata: HashMap::new(),
            importance: 0.5,
            access_count: 10,
            created_at: Utc::now() - chrono::Duration::days(5),
            updated_at: Utc::now() - chrono::Duration::days(1),
            last_accessed_at: Some(Utc::now() - chrono::Duration::hours(2)),
            owner_id: None,
            visibility: Visibility::Private,
            scope: crate::types::MemoryScope::Global,
            workspace: String::from("default"),
            tier: crate::types::MemoryTier::Permanent,
            version: 1,
            has_embedding: false,
            expires_at: None,
            content_hash: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            procedure_success_count: 0,
            procedure_failure_count: 0,
            summary_of_id: None,
            lifecycle_state: crate::types::LifecycleState::Active,
        }
    }

    #[test]
    fn test_conversation_context_keyword_extraction() {
        let context =
            ConversationContext::from_message("I need to fix the bug in the authentication system");
        let has = |word: &str| context.keywords.iter().any(|kw| kw == word);

        for expected in ["fix", "bug", "authentication", "system"] {
            assert!(has(expected));
        }
        // Stop words must be filtered out.
        for stop_word in ["the", "in"] {
            assert!(!has(stop_word));
        }
    }

    #[test]
    fn test_suggestion_generation() {
        let engine = SuggestionEngine::default();
        let memories = [
            fixture_memory(1, "Authentication bug fix for OAuth", vec!["bug", "auth"]),
            fixture_memory(
                2,
                "Database optimization notes",
                vec!["database", "performance"],
            ),
            fixture_memory(3, "OAuth configuration guide", vec!["oauth", "config"]),
        ];
        let context = ConversationContext::from_message("How do I fix the OAuth authentication?");

        let suggestions = engine.generate_suggestions(&context, &memories, None);
        assert!(!suggestions.is_empty());

        // The top-ranked suggestion must be backed by an auth-related memory.
        let top_content = suggestions[0]
            .memory
            .as_ref()
            .map(|m| m.content.to_lowercase())
            .unwrap_or_default();
        assert!(top_content.contains("auth") || top_content.contains("oauth"));
    }

    #[test]
    fn test_create_suggestion() {
        let engine = SuggestionEngine::default();
        let context = ConversationContext::from_message("We decided to use JWT for authentication");

        let suggestions = engine.generate_suggestions(&context, &[], None);

        assert!(suggestions
            .iter()
            .any(|s| s.suggestion_type == SuggestionType::SuggestCreate));
    }

    #[test]
    fn test_keyword_score() {
        let engine = SuggestionEngine::default();
        let memory = fixture_memory(
            1,
            "Rust programming best practices",
            vec!["rust", "programming"],
        );
        let context = ConversationContext::from_message("What are the best practices for Rust?");

        assert!(engine.calculate_keyword_score(&memory, &context) > 0.0);
    }

    #[test]
    fn test_recency_score() {
        let engine = SuggestionEngine::default();

        let mut fresh = fixture_memory(1, "Recent note", vec![]);
        fresh.updated_at = Utc::now();
        let mut stale = fixture_memory(2, "Old note", vec![]);
        stale.updated_at = Utc::now() - chrono::Duration::days(60);

        let fresh_score = engine.calculate_recency_score(&fresh);
        let stale_score = engine.calculate_recency_score(&stale);

        assert!(fresh_score > stale_score);
        assert!(fresh_score > 0.9);
    }

    #[test]
    fn test_needs_review() {
        let engine = SuggestionEngine::default();

        let mut neglected = fixture_memory(1, "Old content", vec![]);
        neglected.updated_at = Utc::now() - chrono::Duration::days(100);
        neglected.last_accessed_at = Some(Utc::now() - chrono::Duration::days(40));
        assert!(engine.needs_review(&neglected));

        let mut current = fixture_memory(2, "Recent content", vec![]);
        current.updated_at = Utc::now() - chrono::Duration::days(10);
        assert!(!engine.needs_review(&current));
    }
}
640}