1use crate::AgentError;
7use crate::constants::env::ai;
8use crate::types::*;
9
/// Thresholds and limits that control when memory extraction is considered.
#[derive(Debug, Clone)]
pub struct ExtractMemoriesConfig {
    /// Minimum number of user/assistant messages before extraction is considered.
    pub min_messages: u32,
    /// Minimum number of tool calls before extraction is considered.
    pub min_tool_calls: u32,
    /// NOTE(review): not read anywhere in this module; presumably limits extraction to
    /// auto-memory only — confirm with callers.
    pub auto_only: bool,
    /// Upper bound on extracted entries. NOTE(review): not enforced in this module.
    pub max_entries: u32,
}

impl Default for ExtractMemoriesConfig {
    /// 10 messages and 3 tool calls minimum, `auto_only` off, at most 50 entries.
    fn default() -> Self {
        ExtractMemoriesConfig {
            max_entries: 50,
            auto_only: false,
            min_tool_calls: 3,
            min_messages: 10,
        }
    }
}
33
/// One memory item extracted from a conversation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MemoryEntry {
    /// Storage key for the entry; `parse_extracted_content` derives it from the
    /// section header and appends `.md`.
    pub key: String,
    /// Trimmed markdown body of the section.
    pub content: String,
    /// Category inferred from the section header.
    pub entry_type: MemoryEntryType,
    /// Flag passed through from the `is_auto` argument of `parse_extracted_content`.
    /// NOTE(review): presumably marks entries produced by automatic extraction — confirm.
    pub is_auto: bool,
}

/// Category of a [`MemoryEntry`]. Serialized in snake_case
/// (`key_points`, `decisions`, `open_items`, `context`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryEntryType {
    KeyPoints,
    Decisions,
    OpenItems,
    Context,
}

/// Outcome of one extraction attempt.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExtractMemoriesResult {
    /// `false` when extraction was attempted but could not run; `error` then explains why.
    pub success: bool,
    /// Entries produced by the attempt (possibly empty).
    pub entries: Vec<MemoryEntry>,
    /// Human-readable failure reason, if any.
    pub error: Option<String>,
    /// Number of input messages (all roles).
    pub messages_processed: u32,
    /// Tool-call count as computed by `count_tool_calls`.
    pub tool_calls_count: u32,
}
65
66#[derive(Debug, Clone)]
68pub struct PendingExtraction {
69 pub session_id: String,
70 pub messages: Vec<Message>,
71 pub timestamp: u64,
72}
73
74impl PendingExtraction {
75 pub fn new(session_id: String, messages: Vec<Message>) -> Self {
76 Self {
77 session_id,
78 messages,
79 timestamp: std::time::SystemTime::now()
80 .duration_since(std::time::UNIX_EPOCH)
81 .unwrap_or_default()
82 .as_secs(),
83 }
84 }
85}
86
87#[derive(Debug, Clone)]
89pub struct ExtractMemoriesState {
90 config: ExtractMemoriesConfig,
91 pending_extractions: Vec<PendingExtraction>,
92 is_extracting: bool,
93 last_extraction_time: Option<u64>,
94}
95
96impl ExtractMemoriesState {
97 pub fn new() -> Self {
98 Self {
99 config: ExtractMemoriesConfig::default(),
100 pending_extractions: Vec::new(),
101 is_extracting: false,
102 last_extraction_time: None,
103 }
104 }
105
106 pub fn with_config(config: ExtractMemoriesConfig) -> Self {
107 Self {
108 config,
109 pending_extractions: Vec::new(),
110 is_extracting: false,
111 last_extraction_time: None,
112 }
113 }
114
115 pub fn is_extracting(&self) -> bool {
116 self.is_extracting
117 }
118
119 pub fn set_extracting(&mut self, extracting: bool) {
120 self.is_extracting = extracting;
121 }
122
123 pub fn add_pending(&mut self, extraction: PendingExtraction) {
124 self.pending_extractions.push(extraction);
125 }
126
127 pub fn pop_pending(&mut self) -> Option<PendingExtraction> {
128 self.pending_extractions.pop()
129 }
130
131 pub fn pending_count(&self) -> usize {
132 self.pending_extractions.len()
133 }
134
135 pub fn update_extraction_time(&mut self) {
136 self.last_extraction_time = Some(
137 std::time::SystemTime::now()
138 .duration_since(std::time::UNIX_EPOCH)
139 .unwrap_or_default()
140 .as_secs(),
141 );
142 }
143
144 pub fn get_config(&self) -> &ExtractMemoriesConfig {
145 &self.config
146 }
147}
148
149impl Default for ExtractMemoriesState {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155fn is_model_visible_message(message: &Message) -> bool {
159 matches!(message.role, MessageRole::User | MessageRole::Assistant)
160}
161
162pub fn count_model_visible_messages_since(
164 messages: &[Message],
165 since_index: Option<usize>,
166) -> usize {
167 let start = since_index.unwrap_or(0);
168 messages
169 .iter()
170 .skip(start)
171 .filter(|m| is_model_visible_message(m))
172 .count()
173}
174
175pub fn count_tool_calls(messages: &[Message]) -> usize {
177 let mut count = 0;
178 for message in messages {
179 if message.role == MessageRole::Assistant {
180 if let Some(ref tool_calls) = message.tool_calls {
181 count += tool_calls.len();
182 }
183 if message.content.contains("tool_use") {
185 count += 1;
186 }
187 }
188 }
189 count
190}
191
192pub fn should_extract_memories(messages: &[Message], config: &ExtractMemoriesConfig) -> bool {
194 let visible_count = messages
195 .iter()
196 .filter(|m| is_model_visible_message(m))
197 .count();
198 let tool_call_count = count_tool_calls(messages);
199
200 (visible_count as u32) >= config.min_messages
201 && (tool_call_count as u32) >= config.min_tool_calls
202}
203
/// Returns the extraction prompt covering key points, decisions and open items.
///
/// The text ends with `Current conversation:` and has no trailing newline.
pub fn build_extract_auto_only_prompt() -> String {
    const PROMPT: &str = r#"Extract key information from this conversation for memory.

Focus on:
1. Key Points - Important facts, findings, or conclusions
2. Decisions Made - Any decisions or commitments
3. Open Items - Tasks or questions still pending

Provide your output as markdown that can be saved to memory files.
Keep it concise but informative.

Current conversation:"#;
    String::from(PROMPT)
}

/// Returns the extraction prompt that also asks for background context, with one
/// markdown header per category.
///
/// The text ends with `Current conversation:` and has no trailing newline.
pub fn build_extract_combined_prompt() -> String {
    const PROMPT: &str = r#"Extract key information from this conversation for memory.

Focus on:
1. Key Points - Important facts, findings, or conclusions
2. Decisions Made - Any decisions or commitments
3. Open Items - Tasks or questions still pending
4. Context - Important background information that would help in future sessions

Provide your output as markdown files with clear headers for each category.
Keep it concise but informative.

Current conversation:"#;
    String::from(PROMPT)
}
238
239pub fn parse_extracted_content(content: &str, is_auto: bool) -> Vec<MemoryEntry> {
243 let mut entries = Vec::new();
244
245 let mut current_section = String::new();
247 let mut current_content = String::new();
248
249 for line in content.lines() {
250 if line.starts_with("## ") {
251 if !current_content.trim().is_empty() {
253 let entry_type = match current_section.to_lowercase().as_str() {
254 s if s.contains("key") => MemoryEntryType::KeyPoints,
255 s if s.contains("decision") => MemoryEntryType::Decisions,
256 s if s.contains("open") => MemoryEntryType::OpenItems,
257 s if s.contains("context") => MemoryEntryType::Context,
258 _ => MemoryEntryType::Context,
259 };
260 entries.push(MemoryEntry {
261 key: format!("{}.md", current_section.to_lowercase().replace(' ', "_")),
262 content: current_content.trim().to_string(),
263 entry_type,
264 is_auto,
265 });
266 }
267 current_section = line.trim_start_matches("## ").to_string();
268 current_content = String::new();
269 } else {
270 current_content.push_str(line);
271 current_content.push('\n');
272 }
273 }
274
275 if !current_content.trim().is_empty() {
277 let entry_type = match current_section.to_lowercase().as_str() {
278 s if s.contains("key") => MemoryEntryType::KeyPoints,
279 s if s.contains("decision") => MemoryEntryType::Decisions,
280 s if s.contains("open") => MemoryEntryType::OpenItems,
281 s if s.contains("context") => MemoryEntryType::Context,
282 _ => MemoryEntryType::Context,
283 };
284 entries.push(MemoryEntry {
285 key: format!("{}.md", current_section.to_lowercase().replace(' ', "_")),
286 content: current_content.trim().to_string(),
287 entry_type,
288 is_auto,
289 });
290 }
291
292 entries
293}
294
295pub async fn execute_extract_memories(
299 messages: Vec<Message>,
300 config: ExtractMemoriesConfig,
301) -> Result<ExtractMemoriesResult, AgentError> {
302 if !should_extract_memories(&messages, &config) {
304 return Ok(ExtractMemoriesResult {
305 success: true,
306 entries: Vec::new(),
307 error: None,
308 messages_processed: messages.len() as u32,
309 tool_calls_count: count_tool_calls(&messages) as u32,
310 });
311 }
312
313 Ok(ExtractMemoriesResult {
320 success: false,
321 entries: Vec::new(),
322 error: Some("Memory extraction requires agent integration".to_string()),
323 messages_processed: messages.len() as u32,
324 tool_calls_count: count_tool_calls(&messages) as u32,
325 })
326}
327
328pub async fn drain_pending_extractions(
330 state: &mut ExtractMemoriesState,
331) -> Result<Vec<ExtractMemoriesResult>, AgentError> {
332 let mut results = Vec::new();
333
334 while let Some(pending) = state.pop_pending() {
335 let result = execute_extract_memories(pending.messages, state.get_config().clone()).await?;
336 results.push(result);
337 state.update_extraction_time();
338 }
339
340 Ok(results)
341}
342
// Tool names matched by the permission callback and the write-detection helpers in
// this module.

/// File read tool.
pub const TOOL_NAME_FILE_READ: &str = "Read";
/// File write tool; its `file_path` argument is inspected for memory writes.
pub const TOOL_NAME_FILE_WRITE: &str = "Write";
/// File edit tool; its `file_path` argument is inspected for memory writes.
pub const TOOL_NAME_FILE_EDIT: &str = "Edit";
/// Glob file-search tool.
pub const TOOL_NAME_GLOB: &str = "Glob";
/// Grep content-search tool.
pub const TOOL_NAME_GREP: &str = "Grep";
/// Shell tool.
pub const TOOL_NAME_BASH: &str = "Bash";
/// REPL tool.
pub const TOOL_NAME_REPL: &str = "REPL";
353
/// Verdict for a single tool invocation, as returned by the callback built in
/// `create_auto_mem_can_use_tool`.
#[derive(Debug, Clone)]
pub struct ToolPermission {
    /// Whether the invocation may proceed.
    pub behavior: PermissionBehavior,
    /// Explanation shown on denial; `create_auto_mem_can_use_tool` leaves it `None`
    /// when allowing.
    pub message: Option<String>,
}

/// Allow/deny outcome of a permission check.
#[derive(Debug, Clone, PartialEq)]
pub enum PermissionBehavior {
    Allow,
    Deny,
}
366
367pub fn create_auto_mem_can_use_tool(
370 memory_dir: &str,
371) -> impl Fn(&str, Option<&str>) -> ToolPermission + '_ {
372 move |tool_name: &str, file_path: Option<&str>| {
373 if tool_name == TOOL_NAME_REPL {
375 return ToolPermission {
376 behavior: PermissionBehavior::Allow,
377 message: None,
378 };
379 }
380
381 if matches!(
383 tool_name,
384 TOOL_NAME_FILE_READ | TOOL_NAME_GREP | TOOL_NAME_GLOB
385 ) {
386 return ToolPermission {
387 behavior: PermissionBehavior::Allow,
388 message: None,
389 };
390 }
391
392 if tool_name == TOOL_NAME_BASH {
395 return ToolPermission {
396 behavior: PermissionBehavior::Deny,
397 message: Some("Only read-only shell commands are permitted in this context (ls, find, grep, cat, stat, wc, head, tail, and similar)".to_string()),
398 };
399 }
400
401 if tool_name == TOOL_NAME_FILE_EDIT || tool_name == TOOL_NAME_FILE_WRITE {
403 if let Some(path) = file_path {
404 if is_auto_mem_path_str(path, memory_dir) {
405 return ToolPermission {
406 behavior: PermissionBehavior::Allow,
407 message: None,
408 };
409 }
410 }
411 }
412
413 ToolPermission {
414 behavior: PermissionBehavior::Deny,
415 message: Some(format!(
416 "only {}, {}, {}, read-only {}, and {}/{} within {} are allowed",
417 TOOL_NAME_FILE_READ,
418 TOOL_NAME_GREP,
419 TOOL_NAME_GLOB,
420 TOOL_NAME_BASH,
421 TOOL_NAME_FILE_EDIT,
422 TOOL_NAME_FILE_WRITE,
423 memory_dir
424 )),
425 }
426 }
427}
428
/// Returns `true` if `absolute_path` lies inside `memory_dir`.
///
/// Compares whole path components, so `/mem-evil/x` is not inside `/mem`. Any path
/// containing a `..` component is rejected, since it could climb out of `memory_dir`
/// while still sharing its leading components. An empty `memory_dir` matches nothing
/// rather than everything.
fn is_auto_mem_path_str(absolute_path: &str, memory_dir: &str) -> bool {
    if memory_dir.is_empty() {
        return false;
    }
    let path = std::path::Path::new(absolute_path);
    // The path is not normalised here, and `/mem/../etc/passwd` still starts with the
    // components of `/mem`, so parent-directory hops are refused outright.
    if path
        .components()
        .any(|component| matches!(component, std::path::Component::ParentDir))
    {
        return false;
    }
    path.starts_with(memory_dir)
}
433
434#[allow(dead_code)]
438pub fn get_message_uuid(_message: &Message) -> Option<&str> {
439 None
442}
443
444pub fn count_model_visible_messages_since_uuid(
447 messages: &[Message],
448 since_uuid: Option<&str>,
449) -> usize {
450 if since_uuid.is_none() {
451 return messages
452 .iter()
453 .filter(|m| is_model_visible_message(m))
454 .count();
455 }
456
457 let since_uuid = since_uuid.unwrap();
458 let mut found_start = false;
459 let mut n = 0;
460
461 for message in messages {
462 if !found_start {
463 if get_message_uuid(message) == Some(since_uuid) {
465 found_start = true;
466 }
467 continue;
468 }
469 if is_model_visible_message(message) {
470 n += 1;
471 }
472 }
473
474 if !found_start {
476 return messages
477 .iter()
478 .filter(|m| is_model_visible_message(m))
479 .count();
480 }
481
482 n
483}
484
485pub fn has_memory_writes_since(
490 messages: &[Message],
491 since_uuid: Option<&str>,
492 memory_dir: &str,
493) -> bool {
494 let mut found_start = since_uuid.is_none();
495
496 for message in messages {
497 if !found_start {
498 if let Some(uuid) = get_message_uuid(message) {
499 if uuid == since_uuid.unwrap() {
500 found_start = true;
501 }
502 }
503 continue;
504 }
505
506 if message.role != MessageRole::Assistant {
507 continue;
508 }
509
510 if let Some(ref tool_calls) = message.tool_calls {
512 for tool_call in tool_calls {
513 let name = &tool_call.name;
514 if name == TOOL_NAME_FILE_WRITE || name == TOOL_NAME_FILE_EDIT {
515 if let Some(file_path) = extract_file_path_from_args(&tool_call.arguments) {
517 if is_auto_mem_path_str(&file_path, memory_dir) {
518 return true;
519 }
520 }
521 }
522 }
523 }
524 }
525
526 false
527}
528
529fn extract_file_path_from_args(args: &serde_json::Value) -> Option<String> {
531 if let Some(obj) = args.as_object() {
532 if let Some(fp) = obj.get("file_path") {
533 return fp.as_str().map(|s| s.to_string());
534 }
535 }
536 None
537}
538
539pub fn get_written_file_path(block: &serde_json::Value) -> Option<String> {
541 let obj = block.as_object()?;
542
543 if obj.get("type")?.as_str()? != "tool_use" {
545 return None;
546 }
547
548 let name = obj.get("name")?.as_str()?;
549 if name != TOOL_NAME_FILE_WRITE && name != TOOL_NAME_FILE_EDIT {
550 return None;
551 }
552
553 let input = obj.get("input")?;
554 let input_obj = input.as_object()?;
555
556 let fp = input_obj.get("file_path")?;
557 fp.as_str().map(|s| s.to_string())
558}
559
560pub fn extract_written_paths(agent_messages: &[Message]) -> Vec<String> {
562 let mut paths = Vec::new();
563
564 for message in agent_messages {
565 if message.role != MessageRole::Assistant {
566 continue;
567 }
568
569 if let Some(ref tool_calls) = message.tool_calls {
571 for tool_call in tool_calls {
572 let name = &tool_call.name;
573 if name == TOOL_NAME_FILE_WRITE || name == TOOL_NAME_FILE_EDIT {
574 if let Some(fp) = extract_file_path_from_args(&tool_call.arguments) {
575 paths.push(fp);
576 }
577 }
578 }
579 }
580 }
581
582 paths.sort();
584 paths.dedup();
585 paths
586}
587
/// Coordinator for background memory extraction; every piece of state sits behind an
/// `Arc<Mutex<_>>`.
// Some fields (e.g. `has_logged_gate_failure`) are never read in this module.
#[allow(dead_code)]
pub struct ExtractMemories {
    /// Task handles awaited (with a timeout) by `drain`.
    in_flight: std::sync::Arc<std::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>>,
    /// UUID of the last message already covered by memory writes; the cursor passed to
    /// `has_memory_writes_since`.
    last_memory_message_uuid: std::sync::Arc<std::sync::Mutex<Option<String>>>,
    /// Not read or written by any code in this module.
    has_logged_gate_failure: std::sync::Arc<std::sync::Mutex<bool>>,
    /// Set while an extraction run is active.
    in_progress: std::sync::Arc<std::sync::Mutex<bool>>,
    /// Turns counted since the last extraction; used as a throttle.
    turns_since_last_extraction: std::sync::Arc<std::sync::Mutex<u32>>,
    /// Most recent context that arrived while a run was in progress.
    pending_context: std::sync::Arc<std::sync::Mutex<Option<ExtractMemoriesContext>>>,
}

/// Input to one extraction run.
#[derive(Debug, Clone)]
pub struct ExtractMemoriesContext {
    /// Conversation messages to inspect.
    pub messages: Vec<Message>,
}
607
608impl ExtractMemories {
609 pub fn new() -> Self {
610 Self {
611 in_flight: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
612 last_memory_message_uuid: std::sync::Arc::new(std::sync::Mutex::new(None)),
613 has_logged_gate_failure: std::sync::Arc::new(std::sync::Mutex::new(false)),
614 in_progress: std::sync::Arc::new(std::sync::Mutex::new(false)),
615 turns_since_last_extraction: std::sync::Arc::new(std::sync::Mutex::new(0)),
616 pending_context: std::sync::Arc::new(std::sync::Mutex::new(None)),
617 }
618 }
619
620 pub fn is_gate_enabled() -> bool {
622 std::env::var(ai::DISABLE_EXTRACT_MEMORIES).ok() != Some("1".to_string())
625 }
626
627 pub fn is_auto_memory_enabled() -> bool {
629 crate::memdir::paths::is_auto_memory_enabled()
630 }
631
632 pub fn is_remote_mode() -> bool {
634 std::env::var(ai::REMOTE).ok() == Some("1".to_string())
635 }
636
637 pub async fn execute(&self, context: ExtractMemoriesContext) {
639 if !Self::is_gate_enabled() {
643 return;
644 }
645
646 if !Self::is_auto_memory_enabled() {
647 return;
648 }
649
650 if Self::is_remote_mode() {
651 return;
652 }
653
654 {
656 let in_progress = self.in_progress.lock().unwrap();
657 if *in_progress {
658 let mut pending = self.pending_context.lock().unwrap();
660 *pending = Some(context);
661 return;
662 }
663 }
664
665 self.run_extraction(context).await;
667 }
668
669 async fn run_extraction(&self, context: ExtractMemoriesContext) {
670 {
672 let mut in_progress = self.in_progress.lock().unwrap();
673 *in_progress = true;
674 }
675
676 let memory_dir = crate::memdir::paths::get_auto_mem_path();
678 let memory_dir_str = memory_dir.to_string_lossy().to_string();
679
680 let last_uuid = {
682 let guard = self.last_memory_message_uuid.lock().unwrap();
683 guard.clone()
684 };
685 let _new_message_count =
686 count_model_visible_messages_since_uuid(&context.messages, last_uuid.as_deref());
687
688 if has_memory_writes_since(&context.messages, last_uuid.as_deref(), &memory_dir_str) {
690 if let Some(last_msg) = context.messages.last() {
692 if let Some(uuid) = get_message_uuid(last_msg) {
693 let mut guard = self.last_memory_message_uuid.lock().unwrap();
694 *guard = Some(uuid.to_string());
695 }
696 }
697 }
698
699 {
701 let mut turns = self.turns_since_last_extraction.lock().unwrap();
702 *turns += 1;
703 if *turns < 1 {
704 {
706 let mut in_progress = self.in_progress.lock().unwrap();
707 *in_progress = false;
708 }
709 return;
710 }
711 *turns = 0;
712 }
713
714 {
724 let mut in_progress = self.in_progress.lock().unwrap();
725 *in_progress = false;
726 }
727 }
728
729 pub async fn drain(&self, timeout_ms: Option<u64>) {
731 let handles = {
732 let mut guard = self.in_flight.lock().unwrap();
733 std::mem::take(&mut *guard)
734 };
735
736 let timeout = timeout_ms.unwrap_or(60_000);
737 let timeout_duration = std::time::Duration::from_millis(timeout);
738
739 for handle in handles {
740 let _ = tokio::time::timeout(timeout_duration, handle).await;
741 }
742 }
743}
744
745impl Default for ExtractMemories {
746 fn default() -> Self {
747 Self::new()
748 }
749}
750
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain message with the given role and text; every optional field is `None`.
    fn message(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            attachments: None,
            tool_call_id: None,
            tool_calls: None,
            is_error: None,
            is_meta: None,
            is_api_error_message: None,
            error_details: None,
            uuid: None,
        }
    }

    #[test]
    fn test_is_model_visible_message() {
        let tool_msg = Message {
            tool_call_id: Some("call_1".to_string()),
            ..message(MessageRole::Tool, "result")
        };

        assert!(is_model_visible_message(&message(MessageRole::User, "hello")));
        assert!(is_model_visible_message(&message(MessageRole::Assistant, "hi")));
        assert!(!is_model_visible_message(&tool_msg));
    }

    #[test]
    fn test_count_model_visible_messages_since() {
        let messages = [
            message(MessageRole::User, "hello"),
            message(MessageRole::Assistant, "hi"),
            message(MessageRole::User, "question"),
        ];

        assert_eq!(count_model_visible_messages_since(&messages, None), 3);
    }

    #[test]
    fn test_should_extract_memories() {
        let config = ExtractMemoriesConfig::default();

        assert!(!should_extract_memories(
            &[message(MessageRole::User, "hello")],
            &config
        ));

        let enough_messages: Vec<Message> = (0..15)
            .map(|i| {
                let role = if i % 2 == 0 {
                    MessageRole::User
                } else {
                    MessageRole::Assistant
                };
                // Messages 1, 4, 7, 10, 13 mention `tool_use`; the odd ones among them
                // (1, 7, 13) are assistant messages, giving 3 tool calls.
                let text = if i % 3 == 1 {
                    format!("message {} tool_use", i)
                } else {
                    format!("message {}", i)
                };
                message(role, &text)
            })
            .collect();

        assert!(should_extract_memories(&enough_messages, &config));
    }

    #[test]
    fn test_extract_memories_state() {
        let mut state = ExtractMemoriesState::new();
        assert!(!state.is_extracting());

        state.set_extracting(true);
        assert!(state.is_extracting());

        state.add_pending(PendingExtraction::new("session_1".to_string(), vec![]));
        assert_eq!(state.pending_count(), 1);

        assert!(state.pop_pending().is_some());
        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_parse_extracted_content() {
        let content = r#"## Key Points
- First important point
- Second important point

## Decisions Made
- Decision one
- Decision two

## Open Items
- Task one
"#;

        let entries = parse_extracted_content(content, true);
        assert!(!entries.is_empty());

        let key_points = entries
            .iter()
            .find(|e| e.key.contains("key_points"))
            .expect("a key_points entry should be parsed");
        assert!(key_points.content.contains("First important point"));
    }
}