rexis_rag/agent/memory/
episodic.rs

1//! Episodic memory - summarized conversation history
2//!
3//! Episodic memory stores summarized versions of past conversations and important
4//! events. It's agent-scoped and provides long-term context without storing full
5//! conversation transcripts.
6
7use crate::error::RragResult;
8use crate::storage::{Memory, MemoryValue};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11
12#[cfg(feature = "rexis-llm-client")]
13use rexis_llm::{ChatMessage, Client, MessageRole};
14
15/// An episode (summarized interaction or event)
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Episode {
18    /// Unique identifier
19    pub id: String,
20
21    /// When this episode occurred
22    pub timestamp: chrono::DateTime<chrono::Utc>,
23
24    /// Summary of the interaction/event
25    pub summary: String,
26
27    /// Key topics discussed
28    pub topics: Vec<String>,
29
30    /// Importance score (0.0 to 1.0)
31    pub importance: f64,
32
33    /// Optional session ID this episode is from
34    pub session_id: Option<String>,
35
36    /// Extracted facts or insights
37    pub insights: Vec<String>,
38
39    /// Optional metadata
40    pub metadata: std::collections::HashMap<String, String>,
41}
42
43impl Episode {
44    /// Create a new episode
45    pub fn new(summary: impl Into<String>) -> Self {
46        Self {
47            id: uuid::Uuid::new_v4().to_string(),
48            timestamp: chrono::Utc::now(),
49            summary: summary.into(),
50            topics: Vec::new(),
51            importance: 0.5,
52            session_id: None,
53            insights: Vec::new(),
54            metadata: std::collections::HashMap::new(),
55        }
56    }
57
58    /// Set topics
59    pub fn with_topics(mut self, topics: Vec<String>) -> Self {
60        self.topics = topics;
61        self
62    }
63
64    /// Set importance
65    pub fn with_importance(mut self, importance: f64) -> Self {
66        self.importance = importance.clamp(0.0, 1.0);
67        self
68    }
69
70    /// Set session ID
71    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
72        self.session_id = Some(session_id.into());
73        self
74    }
75
76    /// Add insights
77    pub fn with_insights(mut self, insights: Vec<String>) -> Self {
78        self.insights = insights;
79        self
80    }
81
82    /// Add metadata
83    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
84        self.metadata.insert(key.into(), value.into());
85        self
86    }
87}
88
89/// Episodic memory for long-term context
90pub struct EpisodicMemory {
91    /// Storage backend
92    storage: Arc<dyn Memory>,
93
94    /// Namespace for this episodic memory (agent::{agent_id}::episodic)
95    namespace: String,
96
97    /// Maximum number of episodes to retain
98    max_episodes: usize,
99}
100
101impl EpisodicMemory {
102    /// Create a new episodic memory
103    pub fn new(storage: Arc<dyn Memory>, agent_id: String) -> Self {
104        let namespace = format!("agent::{}::episodic", agent_id);
105
106        Self {
107            storage,
108            namespace,
109            max_episodes: 1000,
110        }
111    }
112
113    /// Create episodic memory with custom max episodes
114    pub fn with_max_episodes(mut self, max: usize) -> Self {
115        self.max_episodes = max;
116        self
117    }
118
119    /// Store an episode
120    pub async fn store_episode(&self, episode: Episode) -> RragResult<()> {
121        let key = self.episode_key(&episode.id);
122        let value = serde_json::to_value(&episode).map_err(|e| {
123            crate::error::RragError::storage(
124                "serialize_episode",
125                std::io::Error::new(std::io::ErrorKind::Other, e),
126            )
127        })?;
128
129        self.storage.set(&key, MemoryValue::Json(value)).await?;
130
131        // Prune old episodes if exceeded max
132        self.prune_if_needed().await?;
133
134        Ok(())
135    }
136
137    /// Retrieve an episode by ID
138    pub async fn get_episode(&self, episode_id: &str) -> RragResult<Option<Episode>> {
139        let key = self.episode_key(episode_id);
140        if let Some(value) = self.storage.get(&key).await? {
141            if let Some(json) = value.as_json() {
142                let episode = serde_json::from_value(json.clone()).map_err(|e| {
143                    crate::error::RragError::storage(
144                        "deserialize_episode",
145                        std::io::Error::new(std::io::ErrorKind::Other, e),
146                    )
147                })?;
148                return Ok(Some(episode));
149            }
150        }
151        Ok(None)
152    }
153
154    /// Get recent episodes
155    pub async fn get_recent_episodes(&self, limit: usize) -> RragResult<Vec<Episode>> {
156        let mut all_episodes = self.get_all_episodes().await?;
157
158        // Sort by timestamp descending
159        all_episodes.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
160
161        // Take limit
162        all_episodes.truncate(limit);
163
164        Ok(all_episodes)
165    }
166
167    /// Find episodes by topic
168    pub async fn find_by_topic(&self, topic: &str) -> RragResult<Vec<Episode>> {
169        let all_episodes = self.get_all_episodes().await?;
170
171        let matching = all_episodes
172            .into_iter()
173            .filter(|e| e.topics.iter().any(|t| t.contains(topic)))
174            .collect();
175
176        Ok(matching)
177    }
178
179    /// Find episodes by importance threshold
180    pub async fn find_by_importance(&self, min_importance: f64) -> RragResult<Vec<Episode>> {
181        let all_episodes = self.get_all_episodes().await?;
182
183        let important = all_episodes
184            .into_iter()
185            .filter(|e| e.importance >= min_importance)
186            .collect();
187
188        Ok(important)
189    }
190
191    /// Find episodes within a date range
192    pub async fn find_by_date_range(
193        &self,
194        start: chrono::DateTime<chrono::Utc>,
195        end: chrono::DateTime<chrono::Utc>,
196    ) -> RragResult<Vec<Episode>> {
197        let all_episodes = self.get_all_episodes().await?;
198
199        let in_range = all_episodes
200            .into_iter()
201            .filter(|e| e.timestamp >= start && e.timestamp <= end)
202            .collect();
203
204        Ok(in_range)
205    }
206
207    /// Get all episodes
208    pub async fn get_all_episodes(&self) -> RragResult<Vec<Episode>> {
209        let all_keys = self.list_episode_keys().await?;
210        let mut episodes = Vec::new();
211
212        for key in all_keys {
213            if let Some(episode) = self.get_episode(&key).await? {
214                episodes.push(episode);
215            }
216        }
217
218        Ok(episodes)
219    }
220
221    /// Delete an episode
222    pub async fn delete_episode(&self, episode_id: &str) -> RragResult<bool> {
223        let key = self.episode_key(episode_id);
224        self.storage.delete(&key).await
225    }
226
227    /// Count episodes
228    pub async fn count(&self) -> RragResult<usize> {
229        self.storage.count(Some(&self.namespace)).await
230    }
231
232    /// Clear all episodes
233    pub async fn clear(&self) -> RragResult<()> {
234        self.storage.clear(Some(&self.namespace)).await
235    }
236
237    /// Generate a summary from recent episodes
238    pub async fn generate_context_summary(&self, num_episodes: usize) -> RragResult<String> {
239        let recent = self.get_recent_episodes(num_episodes).await?;
240
241        if recent.is_empty() {
242            return Ok(String::new());
243        }
244
245        let mut summary = String::from("Recent interaction history:\n");
246
247        for episode in recent.iter() {
248            summary.push_str(&format!(
249                "- [{}] {}\n",
250                episode.timestamp.format("%Y-%m-%d"),
251                episode.summary
252            ));
253
254            if !episode.topics.is_empty() {
255                summary.push_str(&format!("  Topics: {}\n", episode.topics.join(", ")));
256            }
257        }
258
259        Ok(summary)
260    }
261
262    /// Prune old episodes if exceeded max_episodes
263    async fn prune_if_needed(&self) -> RragResult<()> {
264        let count = self.count().await?;
265
266        if count <= self.max_episodes {
267            return Ok(());
268        }
269
270        // Get all episodes and sort by importance and timestamp
271        let mut all_episodes = self.get_all_episodes().await?;
272
273        // Sort by importance (ascending) then timestamp (oldest first)
274        all_episodes.sort_by(|a, b| {
275            a.importance
276                .partial_cmp(&b.importance)
277                .unwrap()
278                .then(a.timestamp.cmp(&b.timestamp))
279        });
280
281        // Delete least important/oldest episodes
282        let to_delete = count - self.max_episodes;
283        for episode in all_episodes.iter().take(to_delete) {
284            self.delete_episode(&episode.id).await?;
285        }
286
287        Ok(())
288    }
289
290    /// Generate episode key
291    fn episode_key(&self, episode_id: &str) -> String {
292        format!("{}::episode::{}", self.namespace, episode_id)
293    }
294
295    /// List all episode keys (IDs)
296    async fn list_episode_keys(&self) -> RragResult<Vec<String>> {
297        use crate::storage::MemoryQuery;
298
299        let query = MemoryQuery::new().with_namespace(self.namespace.clone());
300        let all_keys = self.storage.keys(&query).await?;
301
302        // Extract episode IDs from keys
303        let prefix = format!("{}::episode::", self.namespace);
304        let ids = all_keys
305            .into_iter()
306            .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
307            .collect();
308
309        Ok(ids)
310    }
311
312    /// Create an episode from conversation messages using LLM summarization (requires 'rsllm-client' feature)
313    #[cfg(feature = "rexis-llm-client")]
314    pub async fn create_episode_from_messages(
315        &self,
316        messages: &[ChatMessage],
317        llm_client: &Client,
318    ) -> RragResult<Episode> {
319        if messages.is_empty() {
320            return Err(crate::error::RragError::validation(
321                "messages",
322                "must not be empty",
323                "0 messages provided".to_string(),
324            ));
325        }
326
327        // Build conversation text
328        let mut conversation = String::new();
329        for msg in messages {
330            let content_text = match &msg.content {
331                rexis_llm::MessageContent::Text(text) => text.clone(),
332                rexis_llm::MessageContent::MultiModal { text, .. } => {
333                    text.clone().unwrap_or_default()
334                }
335            };
336
337            conversation.push_str(&format!(
338                "{}: {}\n",
339                match msg.role {
340                    MessageRole::User => "User",
341                    MessageRole::Assistant => "Assistant",
342                    MessageRole::System => "System",
343                    MessageRole::Tool => "Tool",
344                },
345                content_text
346            ));
347        }
348
349        // Create summarization prompt
350        let summary_prompt = format!(
351            "Summarize this conversation in 2-3 sentences, focusing on key topics and outcomes:\n\n{}",
352            conversation
353        );
354
355        let summary_msg = ChatMessage::user(summary_prompt);
356
357        // Generate summary using LLM
358        let response = llm_client
359            .chat_completion(vec![summary_msg])
360            .await
361            .map_err(|e| crate::error::RragError::rsllm_client("summarization", e))?;
362
363        let summary = response.content.trim().to_string();
364
365        // Extract topics (simple keyword extraction from summary)
366        let topics = self.extract_topics_from_text(&summary);
367
368        // Calculate importance (based on message count and engagement)
369        let importance = self.calculate_importance(messages.len(), &conversation);
370
371        let episode = Episode::new(summary)
372            .with_topics(topics)
373            .with_importance(importance);
374
375        Ok(episode)
376    }
377
378    /// Generate a comprehensive summary of recent episodes using LLM (requires 'rsllm-client' feature)
379    #[cfg(feature = "rexis-llm-client")]
380    pub async fn generate_llm_summary(
381        &self,
382        num_episodes: usize,
383        llm_client: &Client,
384    ) -> RragResult<String> {
385        let recent = self.get_recent_episodes(num_episodes).await?;
386
387        if recent.is_empty() {
388            return Ok(String::from("No recent episodes to summarize."));
389        }
390
391        // Build context from episodes
392        let mut episode_text = String::new();
393        for (i, episode) in recent.iter().enumerate() {
394            episode_text.push_str(&format!(
395                "{}. [{}] {}\n",
396                i + 1,
397                episode.timestamp.format("%Y-%m-%d"),
398                episode.summary
399            ));
400        }
401
402        // Create summarization prompt
403        let summary_prompt = format!(
404            "Provide a coherent summary of these conversation episodes, highlighting key themes and progression:\n\n{}",
405            episode_text
406        );
407
408        let msg = ChatMessage::user(summary_prompt);
409
410        // Generate comprehensive summary
411        let response = llm_client
412            .chat_completion(vec![msg])
413            .await
414            .map_err(|e| crate::error::RragError::rsllm_client("episode_summary", e))?;
415
416        Ok(response.content.trim().to_string())
417    }
418
419    /// Extract insights from an episode using LLM analysis (requires 'rsllm-client' feature)
420    #[cfg(feature = "rexis-llm-client")]
421    pub async fn extract_insights(
422        &self,
423        episode: &Episode,
424        llm_client: &Client,
425    ) -> RragResult<Vec<String>> {
426        let insight_prompt = format!(
427            "Extract 3-5 key insights or learnings from this conversation summary:\n\n{}",
428            episode.summary
429        );
430
431        let msg = ChatMessage::user(insight_prompt);
432
433        let response = llm_client
434            .chat_completion(vec![msg])
435            .await
436            .map_err(|e| crate::error::RragError::rsllm_client("insight_extraction", e))?;
437
438        // Parse insights (assuming one per line)
439        let insights: Vec<String> = response
440            .content
441            .lines()
442            .filter(|line| !line.trim().is_empty())
443            .map(|line| {
444                // Remove leading numbers/bullets
445                line.trim()
446                    .trim_start_matches(|c: char| {
447                        c.is_numeric() || c == '.' || c == '-' || c == '*'
448                    })
449                    .trim()
450                    .to_string()
451            })
452            .filter(|s| !s.is_empty())
453            .collect();
454
455        Ok(insights)
456    }
457
458    /// Simple topic extraction from text (fallback when LLM not available)
459    fn extract_topics_from_text(&self, text: &str) -> Vec<String> {
460        // Simple keyword extraction - look for capitalized words and common programming terms
461        let common_topics = [
462            "rust",
463            "python",
464            "javascript",
465            "programming",
466            "coding",
467            "algorithm",
468            "database",
469            "api",
470            "frontend",
471            "backend",
472            "testing",
473            "deployment",
474            "performance",
475            "security",
476            "design",
477            "architecture",
478            "error",
479            "debugging",
480        ];
481
482        let text_lower = text.to_lowercase();
483        let mut topics = Vec::new();
484
485        for topic in common_topics {
486            if text_lower.contains(topic) {
487                topics.push(topic.to_string());
488            }
489        }
490
491        // Limit to top 5 topics
492        topics.truncate(5);
493
494        topics
495    }
496
497    /// Calculate importance score based on conversation characteristics
498    fn calculate_importance(&self, message_count: usize, conversation: &str) -> f64 {
499        let mut importance: f64 = 0.5; // Base importance
500
501        // More messages = potentially more important
502        if message_count > 10 {
503            importance += 0.2;
504        } else if message_count > 5 {
505            importance += 0.1;
506        }
507
508        // Longer conversations = potentially more important
509        let word_count = conversation.split_whitespace().count();
510        if word_count > 500 {
511            importance += 0.2;
512        } else if word_count > 200 {
513            importance += 0.1;
514        }
515
516        // Presence of key terms indicates importance
517        let important_terms = [
518            "important",
519            "critical",
520            "urgent",
521            "key",
522            "essential",
523            "decision",
524        ];
525        let conv_lower = conversation.to_lowercase();
526        for term in important_terms {
527            if conv_lower.contains(term) {
528                importance += 0.1;
529                break;
530            }
531        }
532
533        // Clamp to [0.0, 1.0]
534        importance.clamp(0.0, 1.0)
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;

    /// Episodic memory for a test agent, backed by a brand-new in-memory store.
    fn fresh_memory() -> EpisodicMemory {
        EpisodicMemory::new(Arc::new(InMemoryStorage::new()), "test-agent".to_string())
    }

    #[tokio::test]
    async fn test_episodic_memory_store_and_retrieve() {
        let episodic = fresh_memory();

        let original = Episode::new("User asked about Rust programming")
            .with_topics(vec!["rust".to_string(), "programming".to_string()])
            .with_importance(0.8);
        let id = original.id.clone();
        episodic.store_episode(original).await.unwrap();

        let fetched = episodic
            .get_episode(&id)
            .await
            .unwrap()
            .expect("a freshly stored episode must be retrievable");

        assert_eq!(fetched.summary, "User asked about Rust programming");
        assert_eq!(fetched.topics.len(), 2);
        assert_eq!(fetched.importance, 0.8);
    }

    #[tokio::test]
    async fn test_episodic_memory_recent_episodes() {
        let episodic = fresh_memory();

        for n in 1..=5 {
            episodic
                .store_episode(Episode::new(format!("Episode {}", n)))
                .await
                .unwrap();
            // Space the timestamps out so recency ordering is unambiguous.
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        }

        // Newest first: the last stored episode must lead.
        let recent = episodic.get_recent_episodes(3).await.unwrap();
        assert_eq!(recent.len(), 3);
        assert!(recent[0].summary.contains("Episode 5"));
    }

    #[tokio::test]
    async fn test_episodic_memory_find_by_topic() {
        let episodic = fresh_memory();

        let fixtures = [
            Episode::new("Discussed Rust").with_topics(vec!["rust".to_string()]),
            Episode::new("Talked about Python").with_topics(vec!["python".to_string()]),
            Episode::new("Rust performance")
                .with_topics(vec!["rust".to_string(), "performance".to_string()]),
        ];
        for episode in fixtures {
            episodic.store_episode(episode).await.unwrap();
        }

        let rust_episodes = episodic.find_by_topic("rust").await.unwrap();
        assert_eq!(rust_episodes.len(), 2);
    }

    #[tokio::test]
    async fn test_episodic_memory_context_summary() {
        let episodic = fresh_memory();

        for summary in ["User asked about Rust", "Discussed error handling"] {
            episodic
                .store_episode(Episode::new(summary))
                .await
                .unwrap();
        }

        let digest = episodic.generate_context_summary(5).await.unwrap();
        assert!(digest.contains("Recent interaction history"));
        assert!(digest.contains("User asked about Rust"));
    }
}
630}