1use crate::error::RragResult;
8use crate::storage::{Memory, MemoryValue};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11
12#[cfg(feature = "rexis-llm-client")]
13use rexis_llm::{ChatMessage, Client, MessageRole};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Episode {
18 pub id: String,
20
21 pub timestamp: chrono::DateTime<chrono::Utc>,
23
24 pub summary: String,
26
27 pub topics: Vec<String>,
29
30 pub importance: f64,
32
33 pub session_id: Option<String>,
35
36 pub insights: Vec<String>,
38
39 pub metadata: std::collections::HashMap<String, String>,
41}
42
43impl Episode {
44 pub fn new(summary: impl Into<String>) -> Self {
46 Self {
47 id: uuid::Uuid::new_v4().to_string(),
48 timestamp: chrono::Utc::now(),
49 summary: summary.into(),
50 topics: Vec::new(),
51 importance: 0.5,
52 session_id: None,
53 insights: Vec::new(),
54 metadata: std::collections::HashMap::new(),
55 }
56 }
57
58 pub fn with_topics(mut self, topics: Vec<String>) -> Self {
60 self.topics = topics;
61 self
62 }
63
64 pub fn with_importance(mut self, importance: f64) -> Self {
66 self.importance = importance.clamp(0.0, 1.0);
67 self
68 }
69
70 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
72 self.session_id = Some(session_id.into());
73 self
74 }
75
76 pub fn with_insights(mut self, insights: Vec<String>) -> Self {
78 self.insights = insights;
79 self
80 }
81
82 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
84 self.metadata.insert(key.into(), value.into());
85 self
86 }
87}
88
89pub struct EpisodicMemory {
91 storage: Arc<dyn Memory>,
93
94 namespace: String,
96
97 max_episodes: usize,
99}
100
101impl EpisodicMemory {
102 pub fn new(storage: Arc<dyn Memory>, agent_id: String) -> Self {
104 let namespace = format!("agent::{}::episodic", agent_id);
105
106 Self {
107 storage,
108 namespace,
109 max_episodes: 1000,
110 }
111 }
112
113 pub fn with_max_episodes(mut self, max: usize) -> Self {
115 self.max_episodes = max;
116 self
117 }
118
119 pub async fn store_episode(&self, episode: Episode) -> RragResult<()> {
121 let key = self.episode_key(&episode.id);
122 let value = serde_json::to_value(&episode).map_err(|e| {
123 crate::error::RragError::storage(
124 "serialize_episode",
125 std::io::Error::new(std::io::ErrorKind::Other, e),
126 )
127 })?;
128
129 self.storage.set(&key, MemoryValue::Json(value)).await?;
130
131 self.prune_if_needed().await?;
133
134 Ok(())
135 }
136
137 pub async fn get_episode(&self, episode_id: &str) -> RragResult<Option<Episode>> {
139 let key = self.episode_key(episode_id);
140 if let Some(value) = self.storage.get(&key).await? {
141 if let Some(json) = value.as_json() {
142 let episode = serde_json::from_value(json.clone()).map_err(|e| {
143 crate::error::RragError::storage(
144 "deserialize_episode",
145 std::io::Error::new(std::io::ErrorKind::Other, e),
146 )
147 })?;
148 return Ok(Some(episode));
149 }
150 }
151 Ok(None)
152 }
153
154 pub async fn get_recent_episodes(&self, limit: usize) -> RragResult<Vec<Episode>> {
156 let mut all_episodes = self.get_all_episodes().await?;
157
158 all_episodes.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
160
161 all_episodes.truncate(limit);
163
164 Ok(all_episodes)
165 }
166
167 pub async fn find_by_topic(&self, topic: &str) -> RragResult<Vec<Episode>> {
169 let all_episodes = self.get_all_episodes().await?;
170
171 let matching = all_episodes
172 .into_iter()
173 .filter(|e| e.topics.iter().any(|t| t.contains(topic)))
174 .collect();
175
176 Ok(matching)
177 }
178
179 pub async fn find_by_importance(&self, min_importance: f64) -> RragResult<Vec<Episode>> {
181 let all_episodes = self.get_all_episodes().await?;
182
183 let important = all_episodes
184 .into_iter()
185 .filter(|e| e.importance >= min_importance)
186 .collect();
187
188 Ok(important)
189 }
190
191 pub async fn find_by_date_range(
193 &self,
194 start: chrono::DateTime<chrono::Utc>,
195 end: chrono::DateTime<chrono::Utc>,
196 ) -> RragResult<Vec<Episode>> {
197 let all_episodes = self.get_all_episodes().await?;
198
199 let in_range = all_episodes
200 .into_iter()
201 .filter(|e| e.timestamp >= start && e.timestamp <= end)
202 .collect();
203
204 Ok(in_range)
205 }
206
207 pub async fn get_all_episodes(&self) -> RragResult<Vec<Episode>> {
209 let all_keys = self.list_episode_keys().await?;
210 let mut episodes = Vec::new();
211
212 for key in all_keys {
213 if let Some(episode) = self.get_episode(&key).await? {
214 episodes.push(episode);
215 }
216 }
217
218 Ok(episodes)
219 }
220
221 pub async fn delete_episode(&self, episode_id: &str) -> RragResult<bool> {
223 let key = self.episode_key(episode_id);
224 self.storage.delete(&key).await
225 }
226
227 pub async fn count(&self) -> RragResult<usize> {
229 self.storage.count(Some(&self.namespace)).await
230 }
231
232 pub async fn clear(&self) -> RragResult<()> {
234 self.storage.clear(Some(&self.namespace)).await
235 }
236
237 pub async fn generate_context_summary(&self, num_episodes: usize) -> RragResult<String> {
239 let recent = self.get_recent_episodes(num_episodes).await?;
240
241 if recent.is_empty() {
242 return Ok(String::new());
243 }
244
245 let mut summary = String::from("Recent interaction history:\n");
246
247 for episode in recent.iter() {
248 summary.push_str(&format!(
249 "- [{}] {}\n",
250 episode.timestamp.format("%Y-%m-%d"),
251 episode.summary
252 ));
253
254 if !episode.topics.is_empty() {
255 summary.push_str(&format!(" Topics: {}\n", episode.topics.join(", ")));
256 }
257 }
258
259 Ok(summary)
260 }
261
262 async fn prune_if_needed(&self) -> RragResult<()> {
264 let count = self.count().await?;
265
266 if count <= self.max_episodes {
267 return Ok(());
268 }
269
270 let mut all_episodes = self.get_all_episodes().await?;
272
273 all_episodes.sort_by(|a, b| {
275 a.importance
276 .partial_cmp(&b.importance)
277 .unwrap()
278 .then(a.timestamp.cmp(&b.timestamp))
279 });
280
281 let to_delete = count - self.max_episodes;
283 for episode in all_episodes.iter().take(to_delete) {
284 self.delete_episode(&episode.id).await?;
285 }
286
287 Ok(())
288 }
289
290 fn episode_key(&self, episode_id: &str) -> String {
292 format!("{}::episode::{}", self.namespace, episode_id)
293 }
294
295 async fn list_episode_keys(&self) -> RragResult<Vec<String>> {
297 use crate::storage::MemoryQuery;
298
299 let query = MemoryQuery::new().with_namespace(self.namespace.clone());
300 let all_keys = self.storage.keys(&query).await?;
301
302 let prefix = format!("{}::episode::", self.namespace);
304 let ids = all_keys
305 .into_iter()
306 .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
307 .collect();
308
309 Ok(ids)
310 }
311
312 #[cfg(feature = "rexis-llm-client")]
314 pub async fn create_episode_from_messages(
315 &self,
316 messages: &[ChatMessage],
317 llm_client: &Client,
318 ) -> RragResult<Episode> {
319 if messages.is_empty() {
320 return Err(crate::error::RragError::validation(
321 "messages",
322 "must not be empty",
323 "0 messages provided".to_string(),
324 ));
325 }
326
327 let mut conversation = String::new();
329 for msg in messages {
330 let content_text = match &msg.content {
331 rexis_llm::MessageContent::Text(text) => text.clone(),
332 rexis_llm::MessageContent::MultiModal { text, .. } => {
333 text.clone().unwrap_or_default()
334 }
335 };
336
337 conversation.push_str(&format!(
338 "{}: {}\n",
339 match msg.role {
340 MessageRole::User => "User",
341 MessageRole::Assistant => "Assistant",
342 MessageRole::System => "System",
343 MessageRole::Tool => "Tool",
344 },
345 content_text
346 ));
347 }
348
349 let summary_prompt = format!(
351 "Summarize this conversation in 2-3 sentences, focusing on key topics and outcomes:\n\n{}",
352 conversation
353 );
354
355 let summary_msg = ChatMessage::user(summary_prompt);
356
357 let response = llm_client
359 .chat_completion(vec![summary_msg])
360 .await
361 .map_err(|e| crate::error::RragError::rsllm_client("summarization", e))?;
362
363 let summary = response.content.trim().to_string();
364
365 let topics = self.extract_topics_from_text(&summary);
367
368 let importance = self.calculate_importance(messages.len(), &conversation);
370
371 let episode = Episode::new(summary)
372 .with_topics(topics)
373 .with_importance(importance);
374
375 Ok(episode)
376 }
377
378 #[cfg(feature = "rexis-llm-client")]
380 pub async fn generate_llm_summary(
381 &self,
382 num_episodes: usize,
383 llm_client: &Client,
384 ) -> RragResult<String> {
385 let recent = self.get_recent_episodes(num_episodes).await?;
386
387 if recent.is_empty() {
388 return Ok(String::from("No recent episodes to summarize."));
389 }
390
391 let mut episode_text = String::new();
393 for (i, episode) in recent.iter().enumerate() {
394 episode_text.push_str(&format!(
395 "{}. [{}] {}\n",
396 i + 1,
397 episode.timestamp.format("%Y-%m-%d"),
398 episode.summary
399 ));
400 }
401
402 let summary_prompt = format!(
404 "Provide a coherent summary of these conversation episodes, highlighting key themes and progression:\n\n{}",
405 episode_text
406 );
407
408 let msg = ChatMessage::user(summary_prompt);
409
410 let response = llm_client
412 .chat_completion(vec![msg])
413 .await
414 .map_err(|e| crate::error::RragError::rsllm_client("episode_summary", e))?;
415
416 Ok(response.content.trim().to_string())
417 }
418
419 #[cfg(feature = "rexis-llm-client")]
421 pub async fn extract_insights(
422 &self,
423 episode: &Episode,
424 llm_client: &Client,
425 ) -> RragResult<Vec<String>> {
426 let insight_prompt = format!(
427 "Extract 3-5 key insights or learnings from this conversation summary:\n\n{}",
428 episode.summary
429 );
430
431 let msg = ChatMessage::user(insight_prompt);
432
433 let response = llm_client
434 .chat_completion(vec![msg])
435 .await
436 .map_err(|e| crate::error::RragError::rsllm_client("insight_extraction", e))?;
437
438 let insights: Vec<String> = response
440 .content
441 .lines()
442 .filter(|line| !line.trim().is_empty())
443 .map(|line| {
444 line.trim()
446 .trim_start_matches(|c: char| {
447 c.is_numeric() || c == '.' || c == '-' || c == '*'
448 })
449 .trim()
450 .to_string()
451 })
452 .filter(|s| !s.is_empty())
453 .collect();
454
455 Ok(insights)
456 }
457
458 fn extract_topics_from_text(&self, text: &str) -> Vec<String> {
460 let common_topics = [
462 "rust",
463 "python",
464 "javascript",
465 "programming",
466 "coding",
467 "algorithm",
468 "database",
469 "api",
470 "frontend",
471 "backend",
472 "testing",
473 "deployment",
474 "performance",
475 "security",
476 "design",
477 "architecture",
478 "error",
479 "debugging",
480 ];
481
482 let text_lower = text.to_lowercase();
483 let mut topics = Vec::new();
484
485 for topic in common_topics {
486 if text_lower.contains(topic) {
487 topics.push(topic.to_string());
488 }
489 }
490
491 topics.truncate(5);
493
494 topics
495 }
496
497 fn calculate_importance(&self, message_count: usize, conversation: &str) -> f64 {
499 let mut importance: f64 = 0.5; if message_count > 10 {
503 importance += 0.2;
504 } else if message_count > 5 {
505 importance += 0.1;
506 }
507
508 let word_count = conversation.split_whitespace().count();
510 if word_count > 500 {
511 importance += 0.2;
512 } else if word_count > 200 {
513 importance += 0.1;
514 }
515
516 let important_terms = [
518 "important",
519 "critical",
520 "urgent",
521 "key",
522 "essential",
523 "decision",
524 ];
525 let conv_lower = conversation.to_lowercase();
526 for term in important_terms {
527 if conv_lower.contains(term) {
528 importance += 0.1;
529 break;
530 }
531 }
532
533 importance.clamp(0.0, 1.0)
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;

    /// Builds an `EpisodicMemory` for the agent id `test-agent` over a fresh in-memory store.
    fn fresh_memory() -> EpisodicMemory {
        EpisodicMemory::new(Arc::new(InMemoryStorage::new()), "test-agent".to_string())
    }

    /// A stored episode can be read back by id with its fields intact.
    #[tokio::test]
    async fn test_episodic_memory_store_and_retrieve() {
        let episodic = fresh_memory();

        let episode = Episode::new("User asked about Rust programming")
            .with_topics(vec!["rust".to_string(), "programming".to_string()])
            .with_importance(0.8);
        let episode_id = episode.id.clone();

        episodic.store_episode(episode).await.unwrap();

        let retrieved = episodic.get_episode(&episode_id).await.unwrap().unwrap();
        assert_eq!(retrieved.summary, "User asked about Rust programming");
        assert_eq!(retrieved.topics.len(), 2);
        assert_eq!(retrieved.importance, 0.8);
    }

    /// `get_recent_episodes` honours the limit and returns the newest episode first.
    #[tokio::test]
    async fn test_episodic_memory_recent_episodes() {
        let episodic = fresh_memory();

        for i in 1..=5 {
            episodic
                .store_episode(Episode::new(format!("Episode {}", i)))
                .await
                .unwrap();
            // Space the writes out so every episode gets a distinct timestamp.
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        }

        let recent = episodic.get_recent_episodes(3).await.unwrap();
        assert_eq!(recent.len(), 3);
        assert!(recent[0].summary.contains("Episode 5"));
    }

    /// `find_by_topic` returns only the episodes tagged with the requested topic.
    #[tokio::test]
    async fn test_episodic_memory_find_by_topic() {
        let episodic = fresh_memory();

        let fixtures = vec![
            Episode::new("Discussed Rust").with_topics(vec!["rust".to_string()]),
            Episode::new("Talked about Python").with_topics(vec!["python".to_string()]),
            Episode::new("Rust performance")
                .with_topics(vec!["rust".to_string(), "performance".to_string()]),
        ];
        for episode in fixtures {
            episodic.store_episode(episode).await.unwrap();
        }

        let rust_episodes = episodic.find_by_topic("rust").await.unwrap();
        assert_eq!(rust_episodes.len(), 2);
    }

    /// The context summary carries the header and the stored episode summaries.
    #[tokio::test]
    async fn test_episodic_memory_context_summary() {
        let episodic = fresh_memory();

        for text in ["User asked about Rust", "Discussed error handling"] {
            episodic.store_episode(Episode::new(text)).await.unwrap();
        }

        let summary = episodic.generate_context_summary(5).await.unwrap();
        assert!(summary.contains("Recent interaction history"));
        assert!(summary.contains("User asked about Rust"));
    }
}