1use std::collections::HashSet;
12use std::fs;
13use std::io;
14use std::path::{Path, PathBuf};
15use std::sync::{Arc, Mutex};
16use std::time::Instant;
17
18use async_trait::async_trait;
19use tracing::{debug, info, warn};
20
21use nexo_driver_types::{ExtractMemoriesConfig, GoalId, MemoryExtractor};
22
23use crate::events::{DriverEvent, ExtractSkipReason};
24use crate::extract_memories_prompt::build_extract_prompt;
25
26#[async_trait]
32pub trait ExtractMemoriesLlm: Send + Sync + 'static {
33 async fn chat(
35 &self,
36 system_prompt: &str,
37 user_messages: &str,
38 max_tokens: u32,
39 ) -> Result<String, String>;
40}
41
42#[derive(Debug, serde::Deserialize)]
46struct MemoryFile {
47 file_path: String,
48 content: String,
49}
50
/// Summary of an extraction run: how many memory files were saved and how
/// long the run took.
#[derive(Debug)]
pub struct ExtractMemoriesOutcome {
    /// Number of memory files written.
    pub memories_saved: u32,
    /// Elapsed wall-clock time in milliseconds.
    pub duration_ms: u64,
}
57
/// A request that arrived while an extraction was already running, kept so
/// the running task can process it afterwards.
struct PendingExtraction {
    /// Transcript text to extract memories from.
    messages_text: String,
    /// Directory the memories should be written into.
    memory_dir: PathBuf,
}

/// Mutable run bookkeeping, kept behind the extractor's `Mutex`.
struct ExtractMemoriesState {
    /// Identifier stored by the last successful run.
    last_message_uuid: Option<String>,
    /// Set while a background extraction task owns the run.
    in_progress: bool,
    /// Ticks counted since the last successful extraction (throttle input).
    turns_since_last: u32,
    /// Latest request coalesced while `in_progress` was set.
    pending: Option<PendingExtraction>,
    /// Failures since the last success (circuit-breaker input).
    consecutive_failures: u32,
}

impl ExtractMemoriesState {
    /// Idle state: nothing running, nothing queued, all counters at zero.
    fn new() -> Self {
        Self {
            in_progress: false,
            pending: None,
            last_message_uuid: None,
            turns_since_last: 0,
            consecutive_failures: 0,
        }
    }
}
90
91pub struct ExtractMemories {
94 config: ExtractMemoriesConfig,
95 state: Mutex<ExtractMemoriesState>,
96 llm: Arc<dyn ExtractMemoriesLlm>,
97 new_message_count: u32,
99 guard: Option<nexo_memory::SecretGuard>,
102}
103
104impl ExtractMemories {
105 pub fn new(config: ExtractMemoriesConfig, llm: Arc<dyn ExtractMemoriesLlm>) -> Self {
106 Self {
107 config,
108 state: Mutex::new(ExtractMemoriesState::new()),
109 llm,
110 new_message_count: 20,
111 guard: None,
112 }
113 }
114
115 pub fn with_message_count(mut self, n: u32) -> Self {
117 self.new_message_count = n;
118 self
119 }
120
121 pub fn with_guard(mut self, guard: nexo_memory::SecretGuard) -> Self {
124 self.guard = Some(guard);
125 self
126 }
127
128 pub fn check_gates(&self) -> Result<(), ExtractSkipReason> {
133 if !self.config.enabled {
134 return Err(ExtractSkipReason::Disabled);
135 }
136
137 let state = self.state.lock().unwrap();
138
139 if state.turns_since_last < self.config.turns_throttle.saturating_sub(1) {
141 return Err(ExtractSkipReason::Throttled);
142 }
143
144 if state.in_progress {
146 return Err(ExtractSkipReason::InProgress);
147 }
148
149 if self.config.max_consecutive_failures > 0
151 && state.consecutive_failures >= self.config.max_consecutive_failures
152 {
153 return Err(ExtractSkipReason::CircuitBreakerOpen);
154 }
155
156 Ok(())
157 }
158
159 pub fn tick(&self) {
162 let mut state = self.state.lock().unwrap();
163 state.turns_since_last = state.turns_since_last.saturating_add(1);
164 }
165
166 fn mark_started(&self) {
169 let mut state = self.state.lock().unwrap();
170 state.in_progress = true;
171 }
172
173 fn record_success(&self, last_message_uuid: Option<String>) {
175 let mut state = self.state.lock().unwrap();
176 state.in_progress = false;
177 state.turns_since_last = 0;
178 state.consecutive_failures = 0;
179 state.last_message_uuid = last_message_uuid;
180 }
181
182 fn record_failure(&self) {
184 let mut state = self.state.lock().unwrap();
185 state.in_progress = false;
186 state.consecutive_failures = state.consecutive_failures.saturating_add(1);
187 }
188
189 pub fn stash_pending(&self, messages_text: String, memory_dir: PathBuf) {
192 let mut state = self.state.lock().unwrap();
193 state.pending = Some(PendingExtraction {
194 messages_text,
195 memory_dir,
196 });
197 }
198
199 fn take_pending(&self) -> Option<PendingExtraction> {
201 let mut state = self.state.lock().unwrap();
202 state.pending.take()
203 }
204
205 pub fn extract(
215 self: &Arc<Self>,
216 goal_id: GoalId,
217 turn_index: u32,
218 messages_text: String,
219 memory_dir: PathBuf,
220 ) {
221 let skip_reason = self.check_gates().err();
223
224 if let Some(reason) = skip_reason {
225 if matches!(reason, ExtractSkipReason::InProgress) {
227 self.stash_pending(messages_text, memory_dir);
228 }
229 debug!(
230 goal_id = %goal_id.0,
231 reason = ?reason,
232 "ExtractMemories skipped"
233 );
234 return;
235 }
236
237 self.mark_started();
238
239 let this = Arc::clone(self);
240 tokio::spawn(async move {
241 let start = Instant::now();
242 match this.run_extraction(&messages_text, &memory_dir).await {
243 Ok(memories_saved) => {
244 let duration_ms = start.elapsed().as_millis() as u64;
245 info!(
246 goal_id = %goal_id.0,
247 memories_saved,
248 duration_ms,
249 "ExtractMemories completed"
250 );
251 this.record_success(None);
252 let _ = (goal_id, turn_index, memories_saved, duration_ms);
256 }
257 Err(e) => {
258 warn!(
259 goal_id = %goal_id.0,
260 error = %e,
261 "ExtractMemories failed"
262 );
263 this.record_failure();
264 }
265 }
266
267 if let Some(pending) = this.take_pending() {
269 debug!("ExtractMemories: draining coalesced extraction");
270 let start = Instant::now();
271 match this
272 .run_extraction(&pending.messages_text, &pending.memory_dir)
273 .await
274 {
275 Ok(n) => {
276 this.record_success(None);
277 info!(memories_saved = n, "ExtractMemories coalesced ok");
278 }
279 Err(e) => {
280 warn!(error = %e, "ExtractMemories coalesced failed");
281 this.record_failure();
282 }
283 }
284 let _ = start; }
286 });
287 }
288
289 async fn run_extraction(&self, messages_text: &str, memory_dir: &Path) -> Result<u32, String> {
292 let manifest = scan_memory_manifest(memory_dir).unwrap_or_default();
294
295 let system_prompt = build_extract_prompt(self.new_message_count, &manifest);
297
298 let response_text = self
300 .llm
301 .chat(&system_prompt, messages_text, self.config.max_turns * 1024)
302 .await?;
303
304 let files: Vec<MemoryFile> = parse_extraction_response(&response_text)?;
306
307 if files.is_empty() {
308 return Ok(0);
309 }
310
311 for f in &files {
313 let resolved = resolve_memory_path(memory_dir, &f.file_path)?;
314 if !resolved.starts_with(memory_dir) {
315 return Err(format!(
316 "path escape attempt: {} -> {}",
317 f.file_path,
318 resolved.display()
319 ));
320 }
321 }
322
323 let mut written = 0u32;
325 for f in &files {
326 let content_to_write = if let Some(ref guard) = self.guard {
327 match guard.check(&f.content) {
328 Ok(redacted) => redacted,
329 Err(e) => {
330 tracing::warn!(
331 target = "memory.secret.blocked",
332 rule_ids = ?e.rule_ids,
333 content_hash = %e.content_hash,
334 file = %f.file_path,
335 "extract_memories: secret blocked, skipping file"
336 );
337 continue; }
339 }
340 } else {
341 f.content.clone()
342 };
343
344 let dest = resolve_memory_path(memory_dir, &f.file_path)?;
345 if let Some(parent) = dest.parent() {
346 fs::create_dir_all(parent)
347 .map_err(|e| format!("mkdir {}: {e}", parent.display()))?;
348 }
349 fs::write(&dest, &content_to_write)
350 .map_err(|e| format!("write {}: {e}", dest.display()))?;
351 written += 1;
352 }
353
354 update_memory_index(memory_dir, &files)?;
356
357 Ok(written)
358 }
359
360 pub fn has_memory_writes(&self, messages_text: &str, memory_dir: &Path) -> bool {
366 has_memory_writes_in_text(messages_text, memory_dir)
367 }
368}
369
370pub fn scan_memory_manifest(memory_dir: &Path) -> Result<String, io::Error> {
377 if !memory_dir.exists() {
378 return Ok(String::new());
379 }
380
381 let mut lines: Vec<String> = Vec::new();
382 let entries = fs::read_dir(memory_dir)?;
383
384 for entry in entries {
385 let entry = entry?;
386 let path = entry.path();
387 if path.extension().map_or(true, |e| e != "md") {
388 continue;
389 }
390 if path.file_name().map_or(false, |n| n == "MEMORY.md") {
392 continue;
393 }
394
395 let Some(file_name) = path.file_name().and_then(|n| n.to_str()) else {
396 continue;
397 };
398
399 match read_frontmatter(&path) {
400 Ok(Some(fm)) => {
401 let mem_type = fm.get("type").and_then(|s| s.as_str()).unwrap_or("unknown");
402 let _name = fm
403 .get("name")
404 .and_then(|s| s.as_str())
405 .unwrap_or(file_name.trim_end_matches(".md"));
406 let desc = fm.get("description").and_then(|s| s.as_str()).unwrap_or("");
407 lines.push(format!("- [{mem_type}] {file_name}: {desc}"));
408 }
409 Ok(None) => {
410 lines.push(format!("- [unknown] {file_name}"));
412 }
413 Err(_) => {
414 lines.push(format!("- [unknown] {file_name} (unreadable)"));
415 }
416 }
417 }
418
419 Ok(lines.join("\n"))
420}
421
422fn read_frontmatter(
425 path: &Path,
426) -> Result<Option<serde_json::Map<String, serde_json::Value>>, io::Error> {
427 let content = fs::read_to_string(path)?;
428 let mut lines = content.lines();
429
430 if lines.next() != Some("---") {
432 return Ok(None);
433 }
434
435 let mut yaml_lines: Vec<&str> = Vec::new();
436 for line in &mut lines {
437 if line == "---" {
438 break;
439 }
440 yaml_lines.push(line);
441 }
442
443 if yaml_lines.is_empty() {
444 return Ok(None);
445 }
446
447 let yaml_str = yaml_lines.join("\n");
448 let map: serde_json::Map<String, serde_json::Value> = serde_yaml::from_str(&yaml_str)
449 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("yaml parse: {e}")))?;
450
451 Ok(Some(map))
452}
453
454fn parse_extraction_response(text: &str) -> Result<Vec<MemoryFile>, String> {
457 let json_str = text
459 .trim()
460 .strip_prefix("```json")
461 .and_then(|s| s.strip_suffix("```"))
462 .map(|s| s.trim())
463 .unwrap_or(text.trim());
464
465 serde_json::from_str::<Vec<MemoryFile>>(json_str)
466 .map_err(|e| format!("parse extraction response: {e}"))
467}
468
469fn resolve_memory_path(memory_dir: &Path, file_path: &str) -> Result<PathBuf, String> {
475 if file_path.contains('\0') {
477 return Err(format!("null byte in path: {file_path}"));
478 }
479
480 let lower = file_path.to_lowercase();
484 if lower.contains("%2e") || lower.contains("%2f") || lower.contains("%5c") {
485 if let Ok(decoded) = urlencoding_maybe(file_path) {
487 if decoded.contains("..") {
488 return Err(format!("URL-encoded traversal rejected: {file_path}"));
489 }
490 }
491 }
492
493 if file_path.contains('\u{FF0E}')
497 || file_path.contains('\u{FF0F}')
498 || file_path.contains('\u{FF3C}')
499 || file_path.contains('\u{2215}')
500 {
502 return Err(format!("unicode traversal rejected: {file_path}"));
503 }
504
505 let p = Path::new(file_path);
506 if p.is_absolute() {
507 return Err(format!("absolute path rejected: {file_path}"));
508 }
509
510 let mut normalized = PathBuf::new();
512 for component in p.components() {
513 match component {
514 std::path::Component::ParentDir => {
515 return Err(format!("path traversal rejected: {file_path}"));
516 }
517 std::path::Component::Normal(c) => normalized.push(c),
518 std::path::Component::CurDir => {}
519 _ => return Err(format!("invalid path component in: {file_path}")),
520 }
521 }
522
523 let resolved = memory_dir.join(&normalized);
524
525 if memory_dir.exists() {
528 match resolved.canonicalize() {
529 Ok(real) => {
530 if !real.starts_with(memory_dir) {
531 return Err(format!("symlink escape rejected: {file_path}"));
532 }
533 }
534 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
535 let mut current = resolved.clone();
537 while let Some(parent) = current.parent() {
538 if parent.as_os_str().is_empty() {
539 break;
540 }
541 match parent.canonicalize() {
542 Ok(real_parent) => {
543 if !real_parent.starts_with(memory_dir) {
544 return Err(format!("symlink escape in parent of: {file_path}"));
545 }
546 break; }
548 Err(_) => {
549 current = parent.to_path_buf();
550 continue; }
552 }
553 }
554 }
555 Err(_) => {
556 }
558 }
559 }
560
561 Ok(resolved)
562}
563
/// Percent-decode `input` for path-traversal screening.
///
/// Decoding is lenient: a `%` not followed by two ASCII hex digits is kept
/// literally rather than aborting. Otherwise one malformed escape could hide
/// a valid `%2e%2e` elsewhere in the string from the caller's checks. Sign
/// characters such as `%+1` are not treated as hex.
///
/// Decoded bytes are reassembled as UTF-8 (lossily), so multi-byte escapes
/// like `%EF%BC%8E` yield the real character (U+FF0E).
///
/// The `Result` is kept for API compatibility; this never returns `Err`.
fn urlencoding_maybe(input: &str) -> Result<String, ()> {
    if !input.contains('%') {
        return Ok(input.to_string());
    }

    let hex = |b: Option<&u8>| b.and_then(|b| char::from(*b).to_digit(16));
    let bytes = input.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            if let (Some(hi), Some(lo)) = (hex(bytes.get(i + 1)), hex(bytes.get(i + 2))) {
                // hi, lo <= 15, so the value always fits in a byte.
                out.push((hi * 16 + lo) as u8);
                i += 3;
                continue;
            }
        }
        out.push(bytes[i]);
        i += 1;
    }
    Ok(String::from_utf8_lossy(&out).into_owned())
}
584
585fn update_memory_index(memory_dir: &Path, files: &[MemoryFile]) -> Result<(), String> {
590 let index_path = memory_dir.join("MEMORY.md");
591 let existing = if index_path.exists() {
592 fs::read_to_string(&index_path).map_err(|e| format!("read MEMORY.md: {e}"))?
593 } else {
594 String::from("# Memory index\n\n")
595 };
596
597 let existing_paths: HashSet<&str> = existing
598 .lines()
599 .filter_map(|line| {
600 line.trim()
601 .strip_prefix("- [")
602 .and_then(|rest| rest.split_once("]("))
603 .and_then(|(_, rest)| rest.split_once(')').map(|(path, _)| path))
604 })
605 .collect();
606
607 let mut new_lines: Vec<String> = Vec::new();
608 for f in files {
609 if existing_paths.contains(f.file_path.as_str()) {
610 continue;
611 }
612 let mut in_frontmatter = false;
615 let mut closed_frontmatter = false;
616 let hook = f
617 .content
618 .lines()
619 .find(|l| {
620 if l.trim() == "---" {
621 if !in_frontmatter {
622 in_frontmatter = true;
623 } else if in_frontmatter && !closed_frontmatter {
624 closed_frontmatter = true;
625 }
626 return false;
627 }
628 if in_frontmatter && !closed_frontmatter {
630 return false;
631 }
632 !l.is_empty()
634 })
635 .map(|l| {
636 let trimmed = l.trim();
637 if trimmed.len() > 80 {
639 format!("{}…", &trimmed[..80])
640 } else {
641 trimmed.to_string()
642 }
643 })
644 .unwrap_or_default();
645 new_lines.push(format!("- [{}]({}) — {}", f.file_path, f.file_path, hook));
646 }
647
648 if new_lines.is_empty() {
649 return Ok(());
650 }
651
652 let mut updated = existing;
653 while updated.ends_with('\n') {
655 updated.pop();
656 }
657 updated.push('\n');
658 for line in &new_lines {
659 updated.push_str(line);
660 updated.push('\n');
661 }
662
663 fs::write(&index_path, updated).map_err(|e| format!("write MEMORY.md: {e}"))?;
664 Ok(())
665}
666
/// Heuristically decide whether a transcript wrote into `memory_dir`.
///
/// Returns true only if the text mentions the memory directory path *and*
/// contains at least one write/edit tool marker anywhere. It is a substring
/// heuristic, so it can produce false positives and false negatives.
fn has_memory_writes_in_text(messages_text: &str, memory_dir: &Path) -> bool {
    const WRITE_MARKERS: [&str; 9] = [
        "Write",
        "\"name\": \"Write\"",
        "\"name\":\"Write\"",
        "Edit",
        "\"name\": \"Edit\"",
        "\"name\":\"Edit\"",
        "file_write",
        "file_edit",
        "write_to_file",
    ];

    let dir_text = memory_dir.to_string_lossy();
    let mentions_memory_dir = messages_text.contains(&*dir_text);
    mentions_memory_dir
        && WRITE_MARKERS
            .iter()
            .copied()
            .any(|marker| messages_text.contains(marker))
}
694
695impl ExtractSkipReason {
698 pub fn to_event(self, goal_id: GoalId) -> DriverEvent {
700 DriverEvent::ExtractMemoriesSkipped {
701 goal_id,
702 reason: self,
703 }
704 }
705}
706
/// Stub `ExtractMemoriesLlm` that serves an optional canned reply.
pub struct NoopExtractMemoriesLlm {
    /// Reply handed out (and consumed) by the next `chat` call, if any.
    pub canned_response: Mutex<Option<String>>,
}

impl NoopExtractMemoriesLlm {
    /// Build a stub with no canned reply.
    pub fn new() -> Self {
        Self::from_slot(None)
    }

    /// Build a stub preloaded with `response`.
    pub fn with_response(response: impl Into<String>) -> Self {
        Self::from_slot(Some(response.into()))
    }

    /// Shared constructor: wrap the initial reply slot in its mutex.
    fn from_slot(slot: Option<String>) -> Self {
        Self {
            canned_response: Mutex::new(slot),
        }
    }
}

impl Default for NoopExtractMemoriesLlm {
    /// Same as [`NoopExtractMemoriesLlm::new`]: no canned reply.
    fn default() -> Self {
        Self::new()
    }
}
733
734#[async_trait]
735impl ExtractMemoriesLlm for NoopExtractMemoriesLlm {
736 async fn chat(
737 &self,
738 _system_prompt: &str,
739 _user_messages: &str,
740 _max_tokens: u32,
741 ) -> Result<String, String> {
742 self.canned_response
743 .lock()
744 .unwrap()
745 .take()
746 .ok_or_else(|| "NoopExtractMemoriesLlm: no canned response set".to_string())
747 }
748}
749
750pub struct LlmClientAdapter {
773 llm: Arc<dyn nexo_llm::LlmClient>,
774 model: String,
775}
776
777impl LlmClientAdapter {
778 pub fn new(llm: Arc<dyn nexo_llm::LlmClient>, model: impl Into<String>) -> Self {
779 Self {
780 llm,
781 model: model.into(),
782 }
783 }
784}
785
786#[async_trait]
787impl ExtractMemoriesLlm for LlmClientAdapter {
788 async fn chat(
789 &self,
790 system_prompt: &str,
791 user_messages: &str,
792 max_tokens: u32,
793 ) -> Result<String, String> {
794 let mut req = nexo_llm::ChatRequest::new(
795 self.model.clone(),
796 vec![nexo_llm::ChatMessage::user(user_messages.to_string())],
797 );
798 req.system_prompt = Some(system_prompt.to_string());
799 req.max_tokens = max_tokens;
800 let resp = self
801 .llm
802 .chat(req)
803 .await
804 .map_err(|e| format!("LlmClientAdapter chat error: {e}"))?;
805 match resp.content {
806 nexo_llm::ResponseContent::Text(text) => Ok(text),
807 nexo_llm::ResponseContent::ToolCalls(_) => {
808 Err("LlmClientAdapter: response is tool_calls, expected text".into())
809 }
810 }
811 }
812}
813
814impl MemoryExtractor for ExtractMemories {
826 fn tick(&self) {
827 ExtractMemories::tick(self);
828 }
829
830 fn extract(
831 self: Arc<Self>,
832 goal_id: GoalId,
833 turn_index: u32,
834 messages_text: String,
835 memory_dir: PathBuf,
836 ) {
837 ExtractMemories::extract(&self, goal_id, turn_index, messages_text, memory_dir);
840 }
841}
842
#[cfg(test)]
mod tests {
    //! Unit tests for the memory extractor. They cover manifest scanning,
    //! write detection, path sanitising, LLM-response parsing, gating, index
    //! maintenance and the `LlmClientAdapter`.

    use super::*;
    use tempfile::TempDir;

    // ---- scan_memory_manifest ----

    #[test]
    fn scan_manifest_empty_dir() {
        let dir = TempDir::new().unwrap();
        let manifest = scan_memory_manifest(dir.path()).unwrap();
        assert!(manifest.is_empty());
    }

    #[test]
    fn scan_manifest_nonexistent_dir() {
        // A missing directory yields an empty manifest rather than an error.
        let manifest = scan_memory_manifest(Path::new("/tmp/nonexistent-memdir-77-5")).unwrap();
        assert!(manifest.is_empty());
    }

    #[test]
    fn scan_manifest_reads_frontmatter() {
        let dir = TempDir::new().unwrap();
        fs::write(
            dir.path().join("preferences.md"),
            "---\nname: user preferences\ndescription: likes dark mode\ntype: user\n---\n\nUser prefers dark mode.",
        )
        .unwrap();
        fs::write(
            dir.path().join("deploy.md"),
            "---\nname: deploy notes\ndescription: deploy process\ntype: project\n---\n\nDeploy on Fridays.",
        )
        .unwrap();

        // Both files appear with their `type` tag and `description`.
        let manifest = scan_memory_manifest(dir.path()).unwrap();
        assert!(
            manifest.contains("preferences.md"),
            "missing preferences: {manifest}"
        );
        assert!(
            manifest.contains("dark mode"),
            "missing description: {manifest}"
        );
        assert!(manifest.contains("[user]"), "missing type tag: {manifest}");
        assert!(manifest.contains("deploy.md"), "missing deploy: {manifest}");
        assert!(
            manifest.contains("[project]"),
            "missing project type: {manifest}"
        );
    }

    #[test]
    fn scan_manifest_skips_memory_index() {
        let dir = TempDir::new().unwrap();
        fs::write(
            dir.path().join("MEMORY.md"),
            "# Memory index\n\n- [prefs](preferences.md)\n",
        )
        .unwrap();
        fs::write(
            dir.path().join("preferences.md"),
            "---\nname: prefs\ndescription: x\ntype: user\n---\n\nContent.",
        )
        .unwrap();

        // The index file itself is not listed as a memory.
        let manifest = scan_memory_manifest(dir.path()).unwrap();
        assert!(
            !manifest.contains("MEMORY.md"),
            "MEMORY.md should be excluded: {manifest}"
        );
        assert!(
            manifest.contains("preferences.md"),
            "should list preferences: {manifest}"
        );
    }

    #[test]
    fn scan_manifest_file_without_frontmatter() {
        let dir = TempDir::new().unwrap();
        fs::write(
            dir.path().join("notes.md"),
            "Just some notes.\nNo frontmatter here.",
        )
        .unwrap();

        // Files without frontmatter are still listed, tagged `[unknown]`.
        let manifest = scan_memory_manifest(dir.path()).unwrap();
        assert!(
            manifest.contains("[unknown]"),
            "should tag as unknown: {manifest}"
        );
        assert!(
            manifest.contains("notes.md"),
            "should list the file: {manifest}"
        );
    }

    // ---- has_memory_writes_in_text ----

    #[test]
    fn has_memory_writes_detects_write_tool() {
        let text = r#"Tool: Write
Arguments: {"file_path": "/home/user/.claude/projects/test/memory/foo.md", "content": "..."}"#;
        assert!(has_memory_writes_in_text(
            text,
            Path::new("/home/user/.claude/projects/test/memory")
        ));
    }

    #[test]
    fn has_memory_writes_detects_file_write_tool() {
        let text = r#"I'll use file_write to save this memory.
{"tool": "file_write", "path": "/home/user/.claude/projects/x/memory/bar.md"}"#;
        assert!(has_memory_writes_in_text(
            text,
            Path::new("/home/user/.claude/projects/x/memory")
        ));
    }

    #[test]
    fn has_memory_writes_no_write() {
        let text = "Just a normal conversation.\nNo tool calls here.";
        assert!(!has_memory_writes_in_text(
            text,
            Path::new("/home/user/.claude/projects/test/memory")
        ));
    }

    #[test]
    fn has_memory_writes_write_outside_memory_dir() {
        // A write marker alone is not enough: the memory dir must be mentioned.
        let text = r#"Write to /tmp/some-other-file.txt"#;
        assert!(!has_memory_writes_in_text(
            text,
            Path::new("/home/user/memory")
        ));
    }

    // ---- resolve_memory_path: basic lexical checks ----

    #[test]
    fn resolve_memory_path_rejects_absolute() {
        assert!(resolve_memory_path(Path::new("/mem"), "/etc/passwd").is_err());
    }

    #[test]
    fn resolve_memory_path_rejects_parent_traversal() {
        assert!(resolve_memory_path(Path::new("/mem"), "../outside.md").is_err());
        assert!(resolve_memory_path(Path::new("/mem"), "sub/../../outside.md").is_err());
    }

    #[test]
    fn resolve_memory_path_accepts_normal() {
        let result = resolve_memory_path(Path::new("/mem"), "user_role.md").unwrap();
        assert_eq!(result, PathBuf::from("/mem/user_role.md"));
    }

    #[test]
    fn resolve_memory_path_accepts_subdir() {
        let result = resolve_memory_path(Path::new("/mem"), "sub/dir/file.md").unwrap();
        assert_eq!(result, PathBuf::from("/mem/sub/dir/file.md"));
    }

    // ---- parse_extraction_response ----

    #[test]
    fn parse_response_bare_json() {
        let json = r#"[{"file_path": "test.md", "content": "hello"}]"#;
        let files = parse_extraction_response(json).unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].file_path, "test.md");
        assert_eq!(files[0].content, "hello");
    }

    #[test]
    fn parse_response_json_fenced() {
        let json = "```json\n[{\"file_path\": \"x.md\", \"content\": \"y\"}]\n```";
        let files = parse_extraction_response(json).unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].file_path, "x.md");
    }

    #[test]
    fn parse_response_empty_array() {
        let files = parse_extraction_response("[]").unwrap();
        assert!(files.is_empty());
    }

    #[test]
    fn parse_response_invalid_json() {
        assert!(parse_extraction_response("not json").is_err());
    }

    // ---- gating ----

    /// Baseline config: enabled, no throttling (`turns_throttle: 1`),
    /// breaker trips after 3 failures.
    fn make_config() -> ExtractMemoriesConfig {
        ExtractMemoriesConfig {
            enabled: true,
            turns_throttle: 1,
            max_turns: 5,
            max_consecutive_failures: 3,
        }
    }

    /// Extractor backed by a `NoopExtractMemoriesLlm` with no canned reply.
    fn make_extractor(config: ExtractMemoriesConfig) -> Arc<ExtractMemories> {
        Arc::new(ExtractMemories::new(
            config,
            Arc::new(NoopExtractMemoriesLlm::new()),
        ))
    }

    #[test]
    fn gate_disabled_when_enabled_false() {
        let mut cfg = make_config();
        cfg.enabled = false;
        let ext = make_extractor(cfg);
        assert!(matches!(
            ext.check_gates(),
            Err(ExtractSkipReason::Disabled)
        ));
    }

    #[test]
    fn gate_throttled_when_not_enough_turns() {
        // With a throttle of 3, a fresh extractor (0 ticks) is below the
        // `turns_throttle - 1` threshold.
        let mut cfg = make_config();
        cfg.turns_throttle = 3;
        let ext = make_extractor(cfg);
        assert!(matches!(
            ext.check_gates(),
            Err(ExtractSkipReason::Throttled)
        ));
    }

    #[test]
    fn gate_passes_when_throttle_satisfied() {
        // A throttle of 1 needs zero ticks.
        let cfg = make_config();
        let ext = make_extractor(cfg);
        assert!(ext.check_gates().is_ok());
    }

    #[test]
    fn gate_passes_after_tick_accumulates() {
        // A throttle of 2 needs exactly one tick.
        let mut cfg = make_config();
        cfg.turns_throttle = 2;
        let ext = make_extractor(cfg);
        assert!(ext.check_gates().is_err());
        ext.tick();
        assert!(ext.check_gates().is_ok());
    }

    #[test]
    fn gate_circuit_breaker_trips_after_n_failures() {
        let ext = make_extractor(make_config());
        ext.record_failure();
        ext.record_failure();
        ext.record_failure();
        assert!(matches!(
            ext.check_gates(),
            Err(ExtractSkipReason::CircuitBreakerOpen)
        ));
    }

    #[test]
    fn gate_circuit_breaker_disabled_when_max_zero() {
        // `max_consecutive_failures == 0` disables the breaker entirely.
        let mut cfg = make_config();
        cfg.max_consecutive_failures = 0;
        let ext = make_extractor(cfg);
        ext.record_failure();
        ext.record_failure();
        ext.record_failure();
        assert!(ext.check_gates().is_ok());
    }

    #[test]
    fn record_success_resets_failures_and_turns() {
        let ext = make_extractor(make_config());
        ext.record_failure();
        ext.record_failure();
        ext.record_success(None);
        assert!(ext.check_gates().is_ok());
    }

    // ---- update_memory_index ----

    #[test]
    fn update_index_creates_file_when_missing() {
        let dir = TempDir::new().unwrap();
        let files = vec![MemoryFile {
            file_path: "new_memory.md".to_string(),
            content: "---\nname: test\ntype: user\n---\n\nSome content here.".to_string(),
        }];
        update_memory_index(dir.path(), &files).unwrap();

        // The hook is the first body line after the frontmatter.
        let index = fs::read_to_string(dir.path().join("MEMORY.md")).unwrap();
        assert!(index.contains("new_memory.md"), "should list new file");
        assert!(index.contains("Some content here"), "should include hook");
    }

    #[test]
    fn update_index_skips_duplicates() {
        let dir = TempDir::new().unwrap();
        fs::write(
            dir.path().join("MEMORY.md"),
            "# Memory index\n\n- [existing](existing.md) — already there\n",
        )
        .unwrap();

        let files = vec![
            MemoryFile {
                file_path: "existing.md".to_string(),
                content: "duplicate".to_string(),
            },
            MemoryFile {
                file_path: "new_one.md".to_string(),
                content: "new content here".to_string(),
            },
        ];
        update_memory_index(dir.path(), &files).unwrap();

        let index = fs::read_to_string(dir.path().join("MEMORY.md")).unwrap();
        let existing_count = index.matches("existing.md").count();
        assert_eq!(existing_count, 1, "duplicate should not be appended");
        assert!(index.contains("new_one.md"), "new file should be appended");
    }

    // ---- resolve_memory_path: encoded / Unicode traversal ----

    #[test]
    fn resolve_memory_path_rejects_null_byte() {
        assert!(resolve_memory_path(Path::new("/mem"), "foo\0bar.md").is_err());
    }

    #[test]
    fn resolve_memory_path_rejects_url_encoded_traversal() {
        assert!(resolve_memory_path(Path::new("/mem"), "%2e%2e%2foutside.md").is_err());
        assert!(resolve_memory_path(Path::new("/mem"), "%2e%2e%2Foutside.md").is_err());
        assert!(resolve_memory_path(Path::new("/mem"), "sub/%2e%2e/outside.md").is_err());
    }

    #[test]
    fn resolve_memory_path_rejects_unicode_fullwidth_dots() {
        assert!(resolve_memory_path(Path::new("/mem"), "foo/\u{FF0E}\u{FF0E}/bar.md").is_err());
    }

    #[test]
    fn resolve_memory_path_rejects_unicode_fullwidth_slash() {
        assert!(resolve_memory_path(Path::new("/mem"), "foo\u{FF0F}bar.md").is_err());
    }

    #[test]
    fn resolve_memory_path_accepts_normal_path_phase77_7() {
        let result = resolve_memory_path(Path::new("/mem"), "notes.md").unwrap();
        assert_eq!(result, Path::new("/mem/notes.md"));
    }

    // ---- LlmClientAdapter ----

    #[tokio::test]
    async fn llm_client_adapter_chat_round_trips() {
        use nexo_llm::{ChatRequest, ChatResponse, FinishReason, ResponseContent, TokenUsage};
        use std::sync::Mutex;

        /// Mock client that records the request it receives and replies with fixed text.
        struct MockLlm {
            captured: Mutex<Option<ChatRequest>>,
        }
        #[async_trait]
        impl nexo_llm::LlmClient for MockLlm {
            async fn chat(&self, req: ChatRequest) -> anyhow::Result<ChatResponse> {
                *self.captured.lock().unwrap() = Some(req);
                Ok(ChatResponse {
                    content: ResponseContent::Text("extracted-content".into()),
                    usage: TokenUsage::default(),
                    finish_reason: FinishReason::Stop,
                    cache_usage: None,
                })
            }
            fn model_id(&self) -> &str {
                "mock-model"
            }
        }

        let mock = Arc::new(MockLlm {
            captured: Mutex::new(None),
        });
        let adapter = LlmClientAdapter::new(
            Arc::clone(&mock) as Arc<dyn nexo_llm::LlmClient>,
            "test-model",
        );
        let result = adapter
            .chat("system prompt", "user msg", 1024)
            .await
            .unwrap();
        assert_eq!(result, "extracted-content");

        // The adapter forwards model, system prompt and token budget, and
        // sends a single user message.
        let captured = mock.captured.lock().unwrap().take().unwrap();
        assert_eq!(captured.model, "test-model");
        assert_eq!(captured.system_prompt.as_deref(), Some("system prompt"));
        assert_eq!(captured.max_tokens, 1024);
        assert_eq!(captured.messages.len(), 1);
    }

    #[tokio::test]
    async fn llm_client_adapter_errors_on_tool_call_response() {
        use nexo_llm::{
            ChatRequest, ChatResponse, FinishReason, ResponseContent, TokenUsage, ToolCall,
        };

        /// Mock client that always answers with a tool call instead of text.
        struct ToolCallLlm;
        #[async_trait]
        impl nexo_llm::LlmClient for ToolCallLlm {
            async fn chat(&self, _req: ChatRequest) -> anyhow::Result<ChatResponse> {
                Ok(ChatResponse {
                    content: ResponseContent::ToolCalls(vec![ToolCall {
                        id: "id".into(),
                        name: "noop".into(),
                        arguments: serde_json::json!({}),
                    }]),
                    usage: TokenUsage::default(),
                    finish_reason: FinishReason::ToolUse,
                    cache_usage: None,
                })
            }
            fn model_id(&self) -> &str {
                "mock-tool"
            }
        }

        let adapter = LlmClientAdapter::new(
            Arc::new(ToolCallLlm) as Arc<dyn nexo_llm::LlmClient>,
            "tool-model",
        );
        let err = adapter
            .chat("sys", "user", 256)
            .await
            .expect_err("tool_calls response must surface an error");
        assert!(err.contains("tool_calls"));
    }
}