Skip to main content

ai_agent/services/extract_memories/
mod.rs

// Source: /data/home/swei/claudecode/openclaudecode/src/services/extractMemories/extractMemories.ts
//! Extracts durable memories from the current session transcript
//! and writes them to the auto-memory directory (~/.ai/projects/<path>/memory/).
//!
//! It runs once at the end of each complete query loop (when the model produces
//! a final response with no tool calls), invoked from the stop hooks
//! (handleStopHooks in stopHooks.ts in the TypeScript original).
//!
//! Uses the forked agent pattern (`run_forked_agent`) — a perfect fork of the main
//! conversation that shares the parent's prompt cache.
//!
//! State lives in the module-level `EXTRACTOR_STATE` static rather than in a closure
//! as in the TypeScript original. `init_extract_memories()` creates it, and calling
//! it again replaces it with fresh state.

15pub mod prompts;
16
17use std::collections::HashMap;
18use std::hash::Hash;
19use std::sync::Arc;
20
21use crate::constants::tools::{
22    BASH_TOOL_NAME, FILE_EDIT_TOOL_NAME, FILE_READ_TOOL_NAME, FILE_WRITE_TOOL_NAME, GLOB_TOOL_NAME,
23    GREP_TOOL_NAME,
24};
25use crate::memdir::ENTRYPOINT_NAME;
26use crate::memdir::memory_scan::scan_memory_files;
27use crate::memdir::paths::{get_auto_mem_path, is_auto_mem_path, is_auto_memory_enabled};
28use crate::tool::ToolUseContext;
29use crate::types::message::{
30    AssistantMessage, AssistantMessageContent, Message, SystemMessage, UserContent, UserMessage,
31    UserMessageContent,
32};
33use crate::utils::forked_agent::{
34    CacheSafeParams, CanUseToolFn, ForkedAgentConfig, ForkedAgentResult, PermissionDecision,
35    QuerySource, create_cache_safe_params, run_forked_agent,
36};
37
38/// Create a user message for the forked agent using types::message structures.
39fn create_fork_user_message(content: String) -> Message {
40    Message::User(crate::types::message::UserMessage {
41        base: crate::types::message::MessageBase {
42            uuid: Some(uuid::Uuid::new_v4().to_string()),
43            parent_uuid: None,
44            timestamp: Some(chrono::Utc::now().to_rfc3339()),
45            created_at: None,
46            is_meta: Some(true),
47            is_virtual: None,
48            is_compact_summary: None,
49            tool_use_result: None,
50            origin: None,
51            extra: HashMap::new(),
52        },
53        message_type: "user".to_string(),
54        message: UserMessageContent {
55            content: UserContent::Text(content),
56            extra: HashMap::new(),
57        },
58    })
59}
60
61// ============================================================================
62// Helpers
63// ============================================================================
64
/// Number of items in `arr` for which `pred` returns `true`.
fn count<T, F>(arr: &[T], pred: F) -> usize
where
    F: Fn(&T) -> bool,
{
    arr.iter()
        .fold(0, |total, item| if pred(item) { total + 1 } else { total })
}
72
/// Deduplicate `xs`, keeping the first occurrence of each value at its original position.
fn uniq<T>(xs: impl IntoIterator<Item = T>) -> Vec<T>
where
    T: Eq + Hash + Clone,
{
    let mut seen = std::collections::HashSet::new();
    xs.into_iter()
        .filter(|item| seen.insert(item.clone()))
        .collect()
}
87
88/// Returns true if a message is visible to the model (sent in API calls).
89/// Excludes progress, system, and attachment messages.
90fn is_model_visible_message(message: &Message) -> bool {
91    matches!(message, Message::User(_) | Message::Assistant(_))
92}
93
94fn count_model_visible_messages_since(messages: &[Message], since_uuid: Option<&str>) -> usize {
95    if since_uuid.is_none() {
96        return count(messages, is_model_visible_message);
97    }
98
99    let since_uuid = since_uuid.unwrap();
100    let mut found_start = false;
101    let mut n = 0;
102    for message in messages {
103        if !found_start {
104            if let Message::User(user_msg) = message {
105                if user_msg.base.uuid.as_deref() == Some(since_uuid) {
106                    found_start = true;
107                }
108            } else if let Message::Assistant(assistant_msg) = message {
109                if assistant_msg.base.uuid.as_deref() == Some(since_uuid) {
110                    found_start = true;
111                }
112            }
113            continue;
114        }
115        if is_model_visible_message(message) {
116            n += 1;
117        }
118    }
119    // If sinceUuid was not found (e.g., removed by context compaction),
120    // fall back to counting all model-visible messages rather than returning 0
121    // which would permanently disable extraction for the rest of the session.
122    if !found_start {
123        return count(messages, is_model_visible_message);
124    }
125    n
126}
127
128/// Returns true if any assistant message after the cursor UUID contains a
129/// Write/Edit tool_use block targeting an auto-memory path.
130fn has_memory_writes_since(messages: &[Message], since_uuid: Option<&str>) -> bool {
131    let mut found_start = since_uuid.is_none();
132    for message in messages {
133        if !found_start {
134            if let Message::User(user_msg) = message {
135                if user_msg.base.uuid.as_deref() == since_uuid {
136                    found_start = true;
137                }
138            } else if let Message::Assistant(assistant_msg) = message {
139                if assistant_msg.base.uuid.as_deref() == since_uuid {
140                    found_start = true;
141                }
142            }
143            continue;
144        }
145        if let Message::Assistant(assistant_msg) = message {
146            if let Some(content) = &assistant_msg.message {
147                if let Some(blocks) = &content.content {
148                    if let Some(arr) = blocks.as_array() {
149                        for block in arr {
150                            if let Some(file_path) = get_written_file_path(block) {
151                                if is_auto_mem_path(&std::path::Path::new(&file_path)) {
152                                    return true;
153                                }
154                            }
155                        }
156                    }
157                }
158            }
159        }
160    }
161    false
162}
163
164// ============================================================================
165// Tool Permissions
166// ============================================================================
167
168/// Check if a Bash input is read-only (no write/modify operations).
169fn is_bash_read_only(input: &serde_json::Value) -> bool {
170    if let Some(command) = input.get("command").and_then(|c| c.as_str()) {
171        // Check for write operations FIRST (before prefix matching)
172        if command.contains(" > ")
173            || command.contains(" >> ")
174            || command.contains(" 2> ")
175            || command.contains(" 2>> ")
176        {
177            return false;
178        }
179        let read_only_prefixes = [
180            "ls",
181            "find",
182            "cat",
183            "stat",
184            "wc",
185            "head",
186            "tail",
187            "grep",
188            "less",
189            "more",
190            "type",
191            "which",
192            "file",
193            "du",
194            "df",
195            "pwd",
196            "echo",
197            "sort",
198            "uniq",
199            "diff",
200            "comm",
201            "cut",
202            "awk",
203            "tr",
204            "xxd",
205            "od",
206            "hexdump",
207            "basename",
208            "dirname",
209            "readlink",
210            "realpath",
211            "env",
212            "printenv",
213            "date",
214            "uptime",
215            "free",
216            "ps",
217            "journalctl",
218            "systemctl status",
219            "mount",
220            "ip",
221            "ifconfig",
222            "ping",
223            "curl",
224            "man",
225            "info",
226            "--help",
227            "-h",
228            "-V",
229            "--version",
230            "touch -r",
231        ];
232        for prefix in &read_only_prefixes {
233            if command.starts_with(prefix) {
234                return true;
235            }
236        }
237        // No destructive commands
238        let destructive_prefixes = [
239            "rm ",
240            "mv ",
241            "cp ",
242            "dd ",
243            "truncate ",
244            "mkfs ",
245            "chmod ",
246            "chown ",
247            "sync",
248            "shutdown",
249            "reboot",
250            "mount --",
251            "umount",
252            "mkswap",
253            "swapoff",
254            "mkfs.",
255            "fsck",
256            "fdisk",
257            "wipefs",
258        ];
259        !destructive_prefixes.iter().any(|p| command.starts_with(p))
260    } else {
261        false
262    }
263}
264
265/// Deny a tool use with logging, returns error message string
266fn deny_auto_mem_tool_string(tool_name: &str, reason: &str) -> String {
267    log::debug!("[autoMem] denied {}: {}", tool_name, reason);
268    reason.to_string()
269}
270
271/// Creates a canUseTool function for the forked extraction agent.
272/// Allows Read/Grep/Glob (unrestricted), read-only Bash, and Edit/Write for auto-mem paths only.
273fn create_auto_mem_can_use_tool(memory_dir: std::path::PathBuf) -> Arc<CanUseToolFn> {
274    let memory_dir_str = memory_dir.to_string_lossy().to_string();
275
276    Arc::new(
277        move |_tool_def, input, _tool_use_context, _assistant_msg, _query_source, _is_explicit| {
278            let tool_name = _tool_def
279                .get("name")
280                .and_then(|n| n.as_str())
281                .unwrap_or("")
282                .to_string();
283            let input = input.clone();
284            let memory_dir_path = memory_dir.clone();
285            let memory_dir_str = memory_dir_str.clone();
286
287            Box::pin(async move {
288                let tool_name = tool_name;
289                if tool_name == FILE_READ_TOOL_NAME
290                    || tool_name == GREP_TOOL_NAME
291                    || tool_name == GLOB_TOOL_NAME
292                {
293                    return Ok(PermissionDecision::Allow);
294                }
295
296                if tool_name == BASH_TOOL_NAME {
297                    if is_bash_read_only(&input) {
298                        return Ok(PermissionDecision::Allow);
299                    }
300                    return Err(deny_auto_mem_tool_string(
301                        &tool_name,
302                        "Only read-only shell commands are permitted in this context (ls, find, grep, cat, stat, wc, head, tail, and similar)",
303                    ));
304                }
305
306                if tool_name == FILE_EDIT_TOOL_NAME || tool_name == FILE_WRITE_TOOL_NAME {
307                    if let Some(file_path) = input.get("file_path").and_then(|p| p.as_str()) {
308                        if is_auto_mem_path(&std::path::Path::new(file_path)) {
309                            return Ok(PermissionDecision::Allow);
310                        }
311                    }
312                }
313
314                Err(deny_auto_mem_tool_string(
315                    &tool_name,
316                    &format!(
317                        "only {}, {}, {}, read-only {}, and {}/{} within {} are allowed",
318                        FILE_READ_TOOL_NAME,
319                        GREP_TOOL_NAME,
320                        GLOB_TOOL_NAME,
321                        BASH_TOOL_NAME,
322                        FILE_EDIT_TOOL_NAME,
323                        FILE_WRITE_TOOL_NAME,
324                        memory_dir_str,
325                    ),
326                ))
327            })
328        },
329    )
330}
331
332// ============================================================================
333// Extract file paths from agent output
334// ============================================================================
335
336/// Extract file_path from a tool_use block's input, if present.
337fn get_written_file_path(block: &serde_json::Value) -> Option<String> {
338    if block.get("type").and_then(|t| t.as_str()) != Some("tool_use") {
339        return None;
340    }
341    let name = block.get("name").and_then(|n| n.as_str())?;
342    if name != FILE_EDIT_TOOL_NAME && name != FILE_WRITE_TOOL_NAME {
343        return None;
344    }
345    let input = block.get("input")?;
346    if let Some(obj) = input.as_object() {
347        if let Some(fp) = obj.get("file_path") {
348            return fp.as_str().map(String::from);
349        }
350    }
351    None
352}
353
354fn extract_written_paths(agent_messages: &[Message]) -> Vec<String> {
355    let mut paths = Vec::new();
356    for message in agent_messages {
357        if let Message::Assistant(assistant_msg) = message {
358            if let Some(content) = &assistant_msg.message {
359                if let Some(blocks) = &content.content {
360                    if let Some(arr) = blocks.as_array() {
361                        for block in arr {
362                            if let Some(file_path) = get_written_file_path(block) {
363                                paths.push(file_path);
364                            }
365                        }
366                    }
367                }
368            }
369        }
370    }
371    uniq(paths)
372}
373
374// ============================================================================
375// Initialization & Closure-scoped State
376// ============================================================================
377
378/// AppendSystemMessageFn — appends a system message to the conversation.
379pub type AppendSystemMessageFn = Arc<dyn Fn(SystemMessage) + Send + Sync>;
380
381/// Context from a REPL hook — mirrors REPLHookContext from TypeScript.
382#[derive(Clone)]
383pub struct ExtractMemoryContext {
384    pub messages: Vec<Message>,
385    pub system_prompt: String,
386    pub user_context: HashMap<String, String>,
387    pub system_context: HashMap<String, String>,
388    pub tool_use_context: Option<Arc<ToolUseContext>>,
389    pub agent_id: Option<String>,
390}
391
392/// State for managing in-flight extractions.
393struct ExtractionState {
394    last_memory_message_uuid: std::sync::Mutex<Option<String>>,
395    in_progress: std::sync::Mutex<bool>,
396    turns_since_last_extraction: std::sync::Mutex<usize>,
397    pending_context:
398        std::sync::Mutex<Option<(ExtractMemoryContext, Option<AppendSystemMessageFn>)>>,
399}
400
401impl ExtractionState {
402    fn new() -> Self {
403        Self {
404            last_memory_message_uuid: std::sync::Mutex::new(None),
405            in_progress: std::sync::Mutex::new(false),
406            turns_since_last_extraction: std::sync::Mutex::new(0),
407            pending_context: std::sync::Mutex::new(None),
408        }
409    }
410
411    fn clone_state(&self) -> Self {
412        Self {
413            last_memory_message_uuid: std::sync::Mutex::new(
414                self.last_memory_message_uuid.lock().unwrap().clone(),
415            ),
416            in_progress: std::sync::Mutex::new(*self.in_progress.lock().unwrap()),
417            turns_since_last_extraction: std::sync::Mutex::new(
418                *self.turns_since_last_extraction.lock().unwrap(),
419            ),
420            pending_context: std::sync::Mutex::new(self.pending_context.lock().unwrap().clone()),
421        }
422    }
423}
424
425/// Advance the message cursor to the last message in the list.
426fn advance_cursor(state: &ExtractionState, messages: &[Message]) {
427    let mut guard = state.last_memory_message_uuid.lock().unwrap();
428    if let Some(last) = messages.last() {
429        if let Message::User(u) = last {
430            *guard = u.base.uuid.clone();
431        } else if let Message::Assistant(a) = last {
432            *guard = a.base.uuid.clone();
433        }
434    }
435}
436
437/// Create a system message indicating memories were saved.
438fn create_system_memory_saved_message(written_paths: &[String]) -> SystemMessage {
439    SystemMessage {
440        base: crate::types::message::MessageBase {
441            uuid: Some(uuid::Uuid::new_v4().to_string()),
442            parent_uuid: None,
443            timestamp: Some(chrono::Utc::now().to_rfc3339()),
444            created_at: None,
445            is_meta: Some(false),
446            is_virtual: None,
447            is_compact_summary: None,
448            tool_use_result: None,
449            origin: None,
450            extra: HashMap::new(),
451        },
452        message_type: "system".to_string(),
453        subtype: Some("memory_saved".to_string()),
454        level: None,
455        message: Some(format!("Memories saved to: {}", written_paths.join(", "))),
456    }
457}
458
459/// Run a single extraction operation.
460async fn run_extraction(
461    state: &ExtractionState,
462    context: ExtractMemoryContext,
463    append_system_message: Option<AppendSystemMessageFn>,
464    is_trailing_run: bool,
465) {
466    async fn do_extraction(
467        state: &ExtractionState,
468        context: ExtractMemoryContext,
469        append_system_message: Option<AppendSystemMessageFn>,
470        is_trailing_run: bool,
471    ) {
472        let messages = &context.messages;
473        let memory_dir = get_auto_mem_path();
474        let memory_dir_path = memory_dir.clone();
475
476        let last_uuid = {
477            let guard = state.last_memory_message_uuid.lock().unwrap();
478            guard.clone()
479        };
480        let new_message_count = count_model_visible_messages_since(messages, last_uuid.as_deref());
481
482        // When the main agent wrote memories, skip the forked agent and advance cursor.
483        if has_memory_writes_since(messages, last_uuid.as_deref()) {
484            log::debug!("[extractMemories] skipping — conversation already wrote to memory files");
485            advance_cursor(state, messages);
486            return;
487        }
488
489        let turn_interval: usize = 1;
490
491        if !is_trailing_run {
492            let mut turns = state.turns_since_last_extraction.lock().unwrap();
493            *turns += 1;
494            if *turns < turn_interval {
495                return;
496            }
497            *turns = 0;
498            drop(turns);
499        }
500
501        {
502            let mut in_progress_guard = state.in_progress.lock().unwrap();
503            *in_progress_guard = true;
504        }
505
506        let start_time = std::time::Instant::now();
507        log::debug!(
508            "[extractMemories] starting — {} new messages, memoryDir={}",
509            new_message_count,
510            memory_dir.display()
511        );
512
513        // Pre-inject the memory directory manifest.
514        let existing_memories = {
515            let headers = scan_memory_files(&memory_dir.to_string_lossy()).await;
516            crate::memdir::format_memory_manifest(&headers)
517        };
518
519        let user_prompt =
520            prompts::build_extract_auto_only_prompt(new_message_count, &existing_memories, false);
521
522        let cache_safe_params = create_cache_safe_params(
523            context.system_prompt.clone(),
524            context.user_context.clone(),
525            context.system_context.clone(),
526            context.tool_use_context
527                .clone()
528                .unwrap_or_else(|| {
529                    // Minimal stub ToolUseContext for query engine callers that don't have one
530                    Arc::new(ToolUseContext::stub())
531                }),
532            messages.clone(),
533        );
534
535        let can_use_tool = create_auto_mem_can_use_tool(memory_dir_path);
536
537        let query_source = QuerySource("extract_memories".to_string());
538        let result = match run_forked_agent(ForkedAgentConfig {
539            prompt_messages: vec![create_fork_user_message(user_prompt)],
540            cache_safe_params,
541            can_use_tool,
542            query_source,
543            fork_label: "extract_memories".to_string(),
544            overrides: None,
545            max_output_tokens: None,
546            max_turns: Some(5),
547            on_message: None,
548            skip_transcript: true,
549            skip_cache_write: true,
550        })
551        .await
552        {
553            Ok(result) => result,
554            Err(e) => {
555                log::debug!("[extractMemories] error: {}", e);
556                let mut in_progress_guard = state.in_progress.lock().unwrap();
557                *in_progress_guard = false;
558                return;
559            }
560        };
561
562        advance_cursor(state, messages);
563
564        let written_paths = extract_written_paths(&result.messages);
565        let turn_count = count(&result.messages, |m| matches!(m, Message::Assistant(_)));
566
567        log::debug!(
568            "[extractMemories] finished — {} files written, turns={}",
569            written_paths.len(),
570            turn_count
571        );
572
573        if written_paths.is_empty() {
574            log::debug!("[extractMemories] no memories saved this run");
575        } else {
576            log::debug!(
577                "[extractMemories] memories saved: {}",
578                written_paths.join(", ")
579            );
580        }
581
582        // Filter out MEMORY.md entries to get actual memory file paths.
583        let memory_paths: Vec<String> = uniq(written_paths.into_iter().filter(|p| {
584            std::path::Path::new(p)
585                .file_name()
586                .map(|name| name.to_string_lossy() != ENTRYPOINT_NAME)
587                .unwrap_or(false)
588        }));
589
590        if let Some(ref append_fn) = append_system_message {
591            if !memory_paths.is_empty() {
592                let msg = create_system_memory_saved_message(&memory_paths);
593                append_fn(msg);
594            }
595        }
596
597        {
598            let mut in_progress_guard = state.in_progress.lock().unwrap();
599            *in_progress_guard = false;
600        }
601
602        // If a call arrived while we were running, run a trailing extraction.
603        let trailing = {
604            let mut pending = state.pending_context.lock().unwrap();
605            pending.take()
606        };
607        if let Some((trailing_context, trailing_append)) = trailing {
608            log::debug!("[extractMemories] running trailing extraction for stashed context");
609            Box::pin(do_extraction(
610                state,
611                trailing_context,
612                trailing_append,
613                true,
614            ))
615            .await;
616        }
617    }
618
619    do_extraction(state, context, append_system_message, is_trailing_run).await
620}
621
622// ============================================================================
623// Public API
624// ============================================================================
625
626static EXTRACTOR_STATE: std::sync::Mutex<Option<ExtractionState>> = std::sync::Mutex::new(None);
627
628/// Initialize the memory extraction system.
629pub fn init_extract_memories() {
630    let state = ExtractionState::new();
631    let mut guard = EXTRACTOR_STATE.lock().unwrap();
632    *guard = Some(state);
633}
634
635/// Run memory extraction at the end of a query loop.
636/// Called fire-and-forget from handleStopHooks.
637/// No-ops until init_extract_memories() has been called.
638pub async fn execute_extract_memories(
639    context: ExtractMemoryContext,
640    append_system_message: Option<AppendSystemMessageFn>,
641) {
642    let state = {
643        let guard = EXTRACTOR_STATE.lock().unwrap();
644        guard.as_ref().unwrap().clone_state()
645    };
646
647    // Only run for the main agent, not subagents.
648    if context.agent_id.is_some() {
649        return;
650    }
651
652    // Check auto-memory is enabled.
653    if !is_auto_memory_enabled() {
654        return;
655    }
656
657    // Skip in remote mode (simplified check).
658    if std::env::var("AI_CODE_REMOTE").is_ok()
659        && std::env::var("AI_CODE_REMOTE_MEMORY_DIR").is_err()
660    {
661        return;
662    }
663
664    let context = context.clone();
665    let append_fn = append_system_message;
666
667    // Check for in-progress extraction.
668    {
669        let in_progress = state.in_progress.lock().unwrap();
670        if *in_progress {
671            log::debug!("[extractMemories] extraction in progress — stashing for trailing run");
672            drop(in_progress);
673            let mut pending = state.pending_context.lock().unwrap();
674            *pending = Some((context, append_fn));
675            return;
676        }
677    }
678
679    run_extraction(&state, context, append_fn, false).await;
680}
681
682/// Awaits all in-flight extractions with a soft timeout.
683/// No-op until init_extract_memories() has been called.
684pub async fn drain_pending_extraction(_timeout_ms: Option<u64>) {
685    let _ = _timeout_ms;
686}
687
#[cfg(test)]
mod tests {
    use super::*;

    /// A `MessageBase` with only a uuid and timestamp; every other field is unset.
    fn test_base(uuid: &str, timestamp: &str) -> crate::types::message::MessageBase {
        crate::types::message::MessageBase {
            uuid: Some(uuid.to_string()),
            parent_uuid: None,
            timestamp: Some(timestamp.to_string()),
            created_at: None,
            is_meta: None,
            is_virtual: None,
            is_compact_summary: None,
            tool_use_result: None,
            origin: None,
            extra: HashMap::new(),
        }
    }

    /// A plain-text user message with the given uuid.
    fn test_user_message(uuid: &str) -> Message {
        Message::User(UserMessage {
            base: test_base(uuid, "2024-01-01T00:00:00Z"),
            message_type: "user".to_string(),
            message: UserMessageContent {
                content: UserContent::Text("test".to_string()),
                extra: HashMap::new(),
            },
        })
    }

    /// An assistant message with the given uuid and no content blocks.
    fn test_assistant_message(uuid: &str) -> Message {
        Message::Assistant(AssistantMessage {
            base: test_base(uuid, "2024-01-01T00:00:01Z"),
            message_type: "assistant".to_string(),
            message: Some(AssistantMessageContent {
                content: None,
                extra: HashMap::new(),
            }),
        })
    }

    #[test]
    fn test_is_model_visible_message() {
        for message in [test_user_message("1"), test_assistant_message("1")] {
            assert!(is_model_visible_message(&message));
        }
    }

    #[test]
    fn test_count_model_visible_messages_since_none() {
        let messages = [
            test_user_message("1"),
            test_assistant_message("2"),
            test_user_message("3"),
        ];
        assert_eq!(count_model_visible_messages_since(&messages, None), 3);
    }

    #[test]
    fn test_count_model_visible_messages_since_found() {
        let messages = [
            test_user_message("1"),
            test_user_message("2"),
            test_assistant_message("3"),
            test_user_message("4"),
        ];
        assert_eq!(count_model_visible_messages_since(&messages, Some("2")), 2);
    }

    #[test]
    fn test_count_model_visible_messages_since_not_found() {
        let messages = [test_user_message("1"), test_assistant_message("2")];
        assert_eq!(count_model_visible_messages_since(&messages, Some("999")), 2);
    }

    #[test]
    fn test_has_memory_writes_since_empty() {
        let messages = [test_user_message("1"), test_user_message("2")];
        assert!(!has_memory_writes_since(&messages, None));
    }

    #[test]
    fn test_build_extract_auto_only_prompt_has_required_sections() {
        let prompt = prompts::build_extract_auto_only_prompt(5, "", false);
        for section in [
            "memory extraction subagent",
            "How to save memories",
            "Types of memory",
            "What NOT to save in memory",
        ] {
            assert!(prompt.contains(section), "missing section: {}", section);
        }
    }

    #[test]
    fn test_build_extract_auto_only_prompt_with_existing_memories() {
        let existing = "- user_role.md (2024-01-01): User role\n- feedback_test.md (2024-01-02): Test feedback";
        let prompt = prompts::build_extract_auto_only_prompt(5, existing, false);
        assert!(prompt.contains("Existing memory files"));
        assert!(prompt.contains("user_role.md"));
    }

    #[test]
    fn test_build_extract_auto_only_prompt_skip_index() {
        let prompt = prompts::build_extract_auto_only_prompt(5, "", true);
        assert!(!prompt.contains("Step 1"));
        assert!(!prompt.contains("Step 2"));
    }

    #[test]
    fn test_build_extract_combined_prompt() {
        let prompt = prompts::build_extract_combined_prompt(5, "", false);
        for section in [
            "memory extraction subagent",
            "Types of memory",
            "You MUST avoid saving sensitive data",
        ] {
            assert!(prompt.contains(section), "missing section: {}", section);
        }
    }

    #[test]
    fn test_bash_read_only() {
        let cases = [
            ("ls -la", true),
            ("grep pattern file.txt", true),
            ("cat file.txt", true),
            ("rm file.txt", false),
            ("echo hello > file.txt", false),
            ("cp a b", false),
        ];
        for &(command, expected) in cases.iter() {
            assert_eq!(
                is_bash_read_only(&serde_json::json!({ "command": command })),
                expected,
                "command: {}",
                command
            );
        }
    }

    #[test]
    fn test_get_written_file_path_edit_tool() {
        let block = serde_json::json!({
            "type": "tool_use",
            "name": "Edit",
            "input": {"file_path": "/some/path/memory/test.md", "edit_range": [0, 100], "new_str": "content"}
        });
        assert_eq!(
            get_written_file_path(&block).as_deref(),
            Some("/some/path/memory/test.md")
        );
    }

    #[test]
    fn test_get_written_file_path_write_tool() {
        let block = serde_json::json!({
            "type": "tool_use",
            "name": "Write",
            "input": {"file_path": "/some/path/memory/test.md", "content": "hello"}
        });
        assert_eq!(
            get_written_file_path(&block).as_deref(),
            Some("/some/path/memory/test.md")
        );
    }

    #[test]
    fn test_get_written_file_path_not_write_tool() {
        let block = serde_json::json!({
            "type": "tool_use",
            "name": "Bash",
            "input": {"command": "ls"}
        });
        assert!(get_written_file_path(&block).is_none());
    }

    #[test]
    fn test_get_written_file_path_not_tool_use() {
        let block = serde_json::json!({"type": "text", "text": "hello"});
        assert!(get_written_file_path(&block).is_none());
    }

    #[test]
    fn test_count_function() {
        let data = [1, 2, 3, 4, 5];
        assert_eq!(count(&data, |x| *x > 3), 2);
        assert_eq!(count(&data, |x| *x < 0), 0);
    }

    #[test]
    fn test_uniq_function() {
        assert_eq!(uniq(vec![3, 1, 2, 1, 3, 4]), vec![3, 1, 2, 4]);
    }
}