1use std::collections::HashSet;
14use std::fs;
15use std::io;
16use std::path::{Path, PathBuf};
17use std::sync::{Arc, Mutex};
18use std::time::Instant;
19
20use async_trait::async_trait;
21use tracing::{debug, info, warn};
22
23use nexo_driver_types::{ExtractMemoriesConfig, GoalId, MemoryExtractor};
24
25use crate::events::{DriverEvent, ExtractSkipReason};
26use crate::extract_memories_prompt::build_extract_prompt;
27
/// Minimal chat abstraction the memory extractor uses to talk to an LLM.
///
/// Implementations must be `Send + Sync + 'static`; the extractor stores them
/// as `Arc<dyn ExtractMemoriesLlm>`.
#[async_trait]
pub trait ExtractMemoriesLlm: Send + Sync + 'static {
    /// Sends a single chat request and returns the model's text reply.
    ///
    /// * `system_prompt` - system prompt for the request.
    /// * `user_messages` - conversation text sent as the user content.
    /// * `max_tokens` - token budget for the reply.
    ///
    /// # Errors
    /// Returns a human-readable error string when the call fails.
    async fn chat(
        &self,
        system_prompt: &str,
        user_messages: &str,
        max_tokens: u32,
    ) -> Result<String, String>;
}
43
/// One memory file proposed by the LLM, deserialized from its JSON reply.
#[derive(Debug, serde::Deserialize)]
struct MemoryFile {
    /// Destination of the file, relative to the memory directory
    /// (absolute paths are rejected by `resolve_memory_path`).
    file_path: String,
    /// Full text to write to the file.
    content: String,
}
52
/// Summary of one extraction run.
///
/// NOTE(review): this type is not constructed anywhere in this file —
/// confirm whether external callers still use it.
#[derive(Debug)]
pub struct ExtractMemoriesOutcome {
    /// Number of memory files saved by the run.
    pub memories_saved: u32,
    /// Duration of the run, in milliseconds.
    pub duration_ms: u64,
}
59
/// Inputs of an extraction request that arrived while another run was in
/// progress; kept so it can be executed once the current run finishes.
struct PendingExtraction {
    /// Conversation text to hand to the LLM.
    messages_text: String,
    /// Directory the extracted memory files are written into.
    memory_dir: PathBuf,
}
67
/// Mutable bookkeeping shared by all callers of one `ExtractMemories`.
struct ExtractMemoriesState {
    /// Value recorded by the last successful run (always `None` in this file).
    last_message_uuid: Option<String>,
    /// `true` while an extraction is running.
    in_progress: bool,
    /// Number of `tick()` calls since the last successful extraction.
    turns_since_last: u32,
    /// Most recent request stashed while a run was in progress, if any.
    pending: Option<PendingExtraction>,
    /// Failures since the last success; drives the circuit breaker gate.
    consecutive_failures: u32,
}
80
81impl ExtractMemoriesState {
82 fn new() -> Self {
83 Self {
84 last_message_uuid: None,
85 in_progress: false,
86 turns_since_last: 0,
87 pending: None,
88 consecutive_failures: 0,
89 }
90 }
91}
92
/// Background memory extractor: asks an LLM which memory files to create from
/// a conversation and writes them into a memory directory.
pub struct ExtractMemories {
    /// Feature switches and limits (enabled flag, throttle, failure cap, ...).
    config: ExtractMemoriesConfig,
    /// Gate/bookkeeping state shared between callers and the spawned task.
    state: Mutex<ExtractMemoriesState>,
    /// LLM backend used for the extraction request.
    llm: Arc<dyn ExtractMemoriesLlm>,
    /// Message count passed to `build_extract_prompt`.
    new_message_count: u32,
    /// Optional secret guard applied to file content before it is written.
    guard: Option<nexo_memory::SecretGuard>,
}
105
106impl ExtractMemories {
107 pub fn new(config: ExtractMemoriesConfig, llm: Arc<dyn ExtractMemoriesLlm>) -> Self {
108 Self {
109 config,
110 state: Mutex::new(ExtractMemoriesState::new()),
111 llm,
112 new_message_count: 20,
113 guard: None,
114 }
115 }
116
117 pub fn with_message_count(mut self, n: u32) -> Self {
119 self.new_message_count = n;
120 self
121 }
122
123 pub fn with_guard(mut self, guard: nexo_memory::SecretGuard) -> Self {
126 self.guard = Some(guard);
127 self
128 }
129
130 pub fn check_gates(&self) -> Result<(), ExtractSkipReason> {
135 if !self.config.enabled {
136 return Err(ExtractSkipReason::Disabled);
137 }
138
139 let state = self.state.lock().unwrap();
140
141 if state.turns_since_last < self.config.turns_throttle.saturating_sub(1) {
143 return Err(ExtractSkipReason::Throttled);
144 }
145
146 if state.in_progress {
148 return Err(ExtractSkipReason::InProgress);
149 }
150
151 if self.config.max_consecutive_failures > 0
153 && state.consecutive_failures >= self.config.max_consecutive_failures
154 {
155 return Err(ExtractSkipReason::CircuitBreakerOpen);
156 }
157
158 Ok(())
159 }
160
161 pub fn tick(&self) {
164 let mut state = self.state.lock().unwrap();
165 state.turns_since_last = state.turns_since_last.saturating_add(1);
166 }
167
168 fn mark_started(&self) {
171 let mut state = self.state.lock().unwrap();
172 state.in_progress = true;
173 }
174
175 fn record_success(&self, last_message_uuid: Option<String>) {
177 let mut state = self.state.lock().unwrap();
178 state.in_progress = false;
179 state.turns_since_last = 0;
180 state.consecutive_failures = 0;
181 state.last_message_uuid = last_message_uuid;
182 }
183
184 fn record_failure(&self) {
186 let mut state = self.state.lock().unwrap();
187 state.in_progress = false;
188 state.consecutive_failures = state.consecutive_failures.saturating_add(1);
189 }
190
191 pub fn stash_pending(&self, messages_text: String, memory_dir: PathBuf) {
194 let mut state = self.state.lock().unwrap();
195 state.pending = Some(PendingExtraction {
196 messages_text,
197 memory_dir,
198 });
199 }
200
201 fn take_pending(&self) -> Option<PendingExtraction> {
203 let mut state = self.state.lock().unwrap();
204 state.pending.take()
205 }
206
207 pub fn extract(
217 self: &Arc<Self>,
218 goal_id: GoalId,
219 turn_index: u32,
220 messages_text: String,
221 memory_dir: PathBuf,
222 ) {
223 let skip_reason = self.check_gates().err();
225
226 if let Some(reason) = skip_reason {
227 if matches!(reason, ExtractSkipReason::InProgress) {
229 self.stash_pending(messages_text, memory_dir);
230 }
231 debug!(
232 goal_id = %goal_id.0,
233 reason = ?reason,
234 "ExtractMemories skipped"
235 );
236 return;
237 }
238
239 self.mark_started();
240
241 let this = Arc::clone(self);
242 tokio::spawn(async move {
243 let start = Instant::now();
244 match this.run_extraction(&messages_text, &memory_dir).await {
245 Ok(memories_saved) => {
246 let duration_ms = start.elapsed().as_millis() as u64;
247 info!(
248 goal_id = %goal_id.0,
249 memories_saved,
250 duration_ms,
251 "ExtractMemories completed"
252 );
253 this.record_success(None);
254 let _ = (goal_id, turn_index, memories_saved, duration_ms);
258 }
259 Err(e) => {
260 warn!(
261 goal_id = %goal_id.0,
262 error = %e,
263 "ExtractMemories failed"
264 );
265 this.record_failure();
266 }
267 }
268
269 if let Some(pending) = this.take_pending() {
271 debug!("ExtractMemories: draining coalesced extraction");
272 let start = Instant::now();
273 match this
274 .run_extraction(&pending.messages_text, &pending.memory_dir)
275 .await
276 {
277 Ok(n) => {
278 this.record_success(None);
279 info!(memories_saved = n, "ExtractMemories coalesced ok");
280 }
281 Err(e) => {
282 warn!(error = %e, "ExtractMemories coalesced failed");
283 this.record_failure();
284 }
285 }
286 let _ = start; }
288 });
289 }
290
291 async fn run_extraction(&self, messages_text: &str, memory_dir: &Path) -> Result<u32, String> {
294 let manifest = scan_memory_manifest(memory_dir).unwrap_or_default();
296
297 let system_prompt = build_extract_prompt(self.new_message_count, &manifest);
299
300 let response_text = self
302 .llm
303 .chat(&system_prompt, messages_text, self.config.max_turns * 1024)
304 .await?;
305
306 let files: Vec<MemoryFile> = parse_extraction_response(&response_text)?;
308
309 if files.is_empty() {
310 return Ok(0);
311 }
312
313 for f in &files {
315 let resolved = resolve_memory_path(memory_dir, &f.file_path)?;
316 if !resolved.starts_with(memory_dir) {
317 return Err(format!(
318 "path escape attempt: {} -> {}",
319 f.file_path,
320 resolved.display()
321 ));
322 }
323 }
324
325 let mut written = 0u32;
327 for f in &files {
328 let content_to_write = if let Some(ref guard) = self.guard {
329 match guard.check(&f.content) {
330 Ok(redacted) => redacted,
331 Err(e) => {
332 tracing::warn!(
333 target = "memory.secret.blocked",
334 rule_ids = ?e.rule_ids,
335 content_hash = %e.content_hash,
336 file = %f.file_path,
337 "extract_memories: secret blocked, skipping file"
338 );
339 continue; }
341 }
342 } else {
343 f.content.clone()
344 };
345
346 let dest = resolve_memory_path(memory_dir, &f.file_path)?;
347 if let Some(parent) = dest.parent() {
348 fs::create_dir_all(parent)
349 .map_err(|e| format!("mkdir {}: {e}", parent.display()))?;
350 }
351 fs::write(&dest, &content_to_write)
352 .map_err(|e| format!("write {}: {e}", dest.display()))?;
353 written += 1;
354 }
355
356 update_memory_index(memory_dir, &files)?;
358
359 Ok(written)
360 }
361
362 pub fn has_memory_writes(&self, messages_text: &str, memory_dir: &Path) -> bool {
368 has_memory_writes_in_text(messages_text, memory_dir)
369 }
370}
371
372pub fn scan_memory_manifest(memory_dir: &Path) -> Result<String, io::Error> {
379 if !memory_dir.exists() {
380 return Ok(String::new());
381 }
382
383 let mut lines: Vec<String> = Vec::new();
384 let entries = fs::read_dir(memory_dir)?;
385
386 for entry in entries {
387 let entry = entry?;
388 let path = entry.path();
389 if path.extension().map_or(true, |e| e != "md") {
390 continue;
391 }
392 if path.file_name().map_or(false, |n| n == "MEMORY.md") {
394 continue;
395 }
396
397 let Some(file_name) = path.file_name().and_then(|n| n.to_str()) else {
398 continue;
399 };
400
401 match read_frontmatter(&path) {
402 Ok(Some(fm)) => {
403 let mem_type = fm.get("type").and_then(|s| s.as_str()).unwrap_or("unknown");
404 let _name = fm
405 .get("name")
406 .and_then(|s| s.as_str())
407 .unwrap_or(file_name.trim_end_matches(".md"));
408 let desc = fm.get("description").and_then(|s| s.as_str()).unwrap_or("");
409 lines.push(format!("- [{mem_type}] {file_name}: {desc}"));
410 }
411 Ok(None) => {
412 lines.push(format!("- [unknown] {file_name}"));
414 }
415 Err(_) => {
416 lines.push(format!("- [unknown] {file_name} (unreadable)"));
417 }
418 }
419 }
420
421 Ok(lines.join("\n"))
422}
423
424fn read_frontmatter(
427 path: &Path,
428) -> Result<Option<serde_json::Map<String, serde_json::Value>>, io::Error> {
429 let content = fs::read_to_string(path)?;
430 let mut lines = content.lines();
431
432 if lines.next() != Some("---") {
434 return Ok(None);
435 }
436
437 let mut yaml_lines: Vec<&str> = Vec::new();
438 for line in &mut lines {
439 if line == "---" {
440 break;
441 }
442 yaml_lines.push(line);
443 }
444
445 if yaml_lines.is_empty() {
446 return Ok(None);
447 }
448
449 let yaml_str = yaml_lines.join("\n");
450 let map: serde_json::Map<String, serde_json::Value> = serde_yaml::from_str(&yaml_str)
451 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("yaml parse: {e}")))?;
452
453 Ok(Some(map))
454}
455
456fn parse_extraction_response(text: &str) -> Result<Vec<MemoryFile>, String> {
459 let json_str = text
461 .trim()
462 .strip_prefix("```json")
463 .and_then(|s| s.strip_suffix("```"))
464 .map(|s| s.trim())
465 .unwrap_or(text.trim());
466
467 serde_json::from_str::<Vec<MemoryFile>>(json_str)
468 .map_err(|e| format!("parse extraction response: {e}"))
469}
470
471fn resolve_memory_path(memory_dir: &Path, file_path: &str) -> Result<PathBuf, String> {
477 if file_path.contains('\0') {
479 return Err(format!("null byte in path: {file_path}"));
480 }
481
482 let lower = file_path.to_lowercase();
486 if lower.contains("%2e") || lower.contains("%2f") || lower.contains("%5c") {
487 if let Ok(decoded) = urlencoding_maybe(file_path) {
489 if decoded.contains("..") {
490 return Err(format!("URL-encoded traversal rejected: {file_path}"));
491 }
492 }
493 }
494
495 if file_path.contains('\u{FF0E}')
499 || file_path.contains('\u{FF0F}')
500 || file_path.contains('\u{FF3C}')
501 || file_path.contains('\u{2215}')
502 {
504 return Err(format!("unicode traversal rejected: {file_path}"));
505 }
506
507 let p = Path::new(file_path);
508 if p.is_absolute() {
509 return Err(format!("absolute path rejected: {file_path}"));
510 }
511
512 let mut normalized = PathBuf::new();
514 for component in p.components() {
515 match component {
516 std::path::Component::ParentDir => {
517 return Err(format!("path traversal rejected: {file_path}"));
518 }
519 std::path::Component::Normal(c) => normalized.push(c),
520 std::path::Component::CurDir => {}
521 _ => return Err(format!("invalid path component in: {file_path}")),
522 }
523 }
524
525 let resolved = memory_dir.join(&normalized);
526
527 if memory_dir.exists() {
530 match resolved.canonicalize() {
531 Ok(real) => {
532 if !real.starts_with(memory_dir) {
533 return Err(format!("symlink escape rejected: {file_path}"));
534 }
535 }
536 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
537 let mut current = resolved.clone();
539 while let Some(parent) = current.parent() {
540 if parent.as_os_str().is_empty() {
541 break;
542 }
543 match parent.canonicalize() {
544 Ok(real_parent) => {
545 if !real_parent.starts_with(memory_dir) {
546 return Err(format!("symlink escape in parent of: {file_path}"));
547 }
548 break; }
550 Err(_) => {
551 current = parent.to_path_buf();
552 continue; }
554 }
555 }
556 }
557 Err(_) => {
558 }
560 }
561 }
562
563 Ok(resolved)
564}
565
/// Percent-decodes `input` (a single pass, `%XX` escapes only).
///
/// Decoded bytes are interpreted as UTF-8; invalid sequences are replaced
/// with U+FFFD. Input without any `%` is returned unchanged.
///
/// # Errors
/// Returns `Err(())` when a `%` is not followed by exactly two hexadecimal
/// digits (signs such as `%+f` are not accepted).
fn urlencoding_maybe(input: &str) -> Result<String, ()> {
    if !input.contains('%') {
        return Ok(input.to_string());
    }

    // Strict hex digit parsing; `to_digit(16)` yields values below 16.
    let hex_value = |b: u8| char::from(b).to_digit(16).map(|d| d as u8);

    let mut decoded: Vec<u8> = Vec::with_capacity(input.len());
    let mut bytes = input.bytes();
    while let Some(b) = bytes.next() {
        if b != b'%' {
            decoded.push(b);
            continue;
        }
        let hi = bytes.next().and_then(hex_value).ok_or(())?;
        let lo = bytes.next().and_then(hex_value).ok_or(())?;
        decoded.push((hi << 4) | lo);
    }

    Ok(String::from_utf8_lossy(&decoded).into_owned())
}
586
587fn update_memory_index(memory_dir: &Path, files: &[MemoryFile]) -> Result<(), String> {
592 let index_path = memory_dir.join("MEMORY.md");
593 let existing = if index_path.exists() {
594 fs::read_to_string(&index_path).map_err(|e| format!("read MEMORY.md: {e}"))?
595 } else {
596 String::from("# Memory index\n\n")
597 };
598
599 let existing_paths: HashSet<&str> = existing
600 .lines()
601 .filter_map(|line| {
602 line.trim()
603 .strip_prefix("- [")
604 .and_then(|rest| rest.split_once("]("))
605 .and_then(|(_, rest)| rest.split_once(')').map(|(path, _)| path))
606 })
607 .collect();
608
609 let mut new_lines: Vec<String> = Vec::new();
610 for f in files {
611 if existing_paths.contains(f.file_path.as_str()) {
612 continue;
613 }
614 let mut in_frontmatter = false;
617 let mut closed_frontmatter = false;
618 let hook = f
619 .content
620 .lines()
621 .find(|l| {
622 if l.trim() == "---" {
623 if !in_frontmatter {
624 in_frontmatter = true;
625 } else if in_frontmatter && !closed_frontmatter {
626 closed_frontmatter = true;
627 }
628 return false;
629 }
630 if in_frontmatter && !closed_frontmatter {
632 return false;
633 }
634 !l.is_empty()
636 })
637 .map(|l| {
638 let trimmed = l.trim();
639 if trimmed.len() > 80 {
641 format!("{}…", &trimmed[..80])
642 } else {
643 trimmed.to_string()
644 }
645 })
646 .unwrap_or_default();
647 new_lines.push(format!("- [{}]({}) — {}", f.file_path, f.file_path, hook));
648 }
649
650 if new_lines.is_empty() {
651 return Ok(());
652 }
653
654 let mut updated = existing;
655 while updated.ends_with('\n') {
657 updated.pop();
658 }
659 updated.push('\n');
660 for line in &new_lines {
661 updated.push_str(line);
662 updated.push('\n');
663 }
664
665 fs::write(&index_path, updated).map_err(|e| format!("write MEMORY.md: {e}"))?;
666 Ok(())
667}
668
/// Heuristic: does `messages_text` look like it already wrote into
/// `memory_dir`?
///
/// Returns `true` only when the text contains the memory directory path
/// (lossily converted to UTF-8) *and* at least one of the known write/edit
/// tool markers. This is plain substring matching, not parsing.
fn has_memory_writes_in_text(messages_text: &str, memory_dir: &Path) -> bool {
    const WRITE_MARKERS: [&str; 9] = [
        "Write",
        "\"name\": \"Write\"",
        "\"name\":\"Write\"",
        "Edit",
        "\"name\": \"Edit\"",
        "\"name\":\"Edit\"",
        "file_write",
        "file_edit",
        "write_to_file",
    ];

    let dir_text = memory_dir.to_string_lossy();
    let mentions_dir = messages_text.contains(&*dir_text);

    mentions_dir
        && WRITE_MARKERS
            .iter()
            .any(|marker| messages_text.contains(marker))
}
696
697impl ExtractSkipReason {
700 pub fn to_event(self, goal_id: GoalId) -> DriverEvent {
702 DriverEvent::ExtractMemoriesSkipped {
703 goal_id,
704 reason: self,
705 }
706 }
707}
708
/// Stand-in LLM that performs no real request: it hands out a single
/// pre-configured reply.
pub struct NoopExtractMemoriesLlm {
    /// Reply returned by the next `chat` call; taken (left `None`) once used.
    pub canned_response: Mutex<Option<String>>,
}
715
716impl NoopExtractMemoriesLlm {
717 pub fn new() -> Self {
718 Self {
719 canned_response: Mutex::new(None),
720 }
721 }
722
723 pub fn with_response(response: impl Into<String>) -> Self {
724 Self {
725 canned_response: Mutex::new(Some(response.into())),
726 }
727 }
728}
729
730impl Default for NoopExtractMemoriesLlm {
731 fn default() -> Self {
732 Self::new()
733 }
734}
735
736#[async_trait]
737impl ExtractMemoriesLlm for NoopExtractMemoriesLlm {
738 async fn chat(
739 &self,
740 _system_prompt: &str,
741 _user_messages: &str,
742 _max_tokens: u32,
743 ) -> Result<String, String> {
744 self.canned_response
745 .lock()
746 .unwrap()
747 .take()
748 .ok_or_else(|| "NoopExtractMemoriesLlm: no canned response set".to_string())
749 }
750}
751
/// Adapter that exposes a `nexo_llm::LlmClient` through the
/// `ExtractMemoriesLlm` interface.
pub struct LlmClientAdapter {
    /// Underlying LLM client that performs the request.
    llm: Arc<dyn nexo_llm::LlmClient>,
    /// Model name placed in every chat request.
    model: String,
}
784
785impl LlmClientAdapter {
786 pub fn new(llm: Arc<dyn nexo_llm::LlmClient>, model: impl Into<String>) -> Self {
787 Self {
788 llm,
789 model: model.into(),
790 }
791 }
792}
793
794#[async_trait]
795impl ExtractMemoriesLlm for LlmClientAdapter {
796 async fn chat(
797 &self,
798 system_prompt: &str,
799 user_messages: &str,
800 max_tokens: u32,
801 ) -> Result<String, String> {
802 let mut req = nexo_llm::ChatRequest::new(
803 self.model.clone(),
804 vec![nexo_llm::ChatMessage::user(user_messages.to_string())],
805 );
806 req.system_prompt = Some(system_prompt.to_string());
807 req.max_tokens = max_tokens;
808 let resp = self
809 .llm
810 .chat(req)
811 .await
812 .map_err(|e| format!("LlmClientAdapter chat error: {e}"))?;
813 match resp.content {
814 nexo_llm::ResponseContent::Text(text) => Ok(text),
815 nexo_llm::ResponseContent::ToolCalls(_) => {
816 Err("LlmClientAdapter: response is tool_calls, expected text".into())
817 }
818 }
819 }
820}
821
/// Exposes `ExtractMemories` through the `MemoryExtractor` trait by
/// forwarding to the inherent methods of the same names.
impl MemoryExtractor for ExtractMemories {
    /// Forwards to the inherent `ExtractMemories::tick`.
    fn tick(&self) {
        // The type-qualified path is intended to select the inherent method
        // rather than this trait method (which would recurse).
        ExtractMemories::tick(self);
    }

    /// Forwards to the inherent `ExtractMemories::extract`, which takes
    /// `&Arc<Self>` instead of consuming the `Arc`.
    fn extract(
        self: Arc<Self>,
        goal_id: GoalId,
        turn_index: u32,
        messages_text: String,
        memory_dir: PathBuf,
    ) {
        ExtractMemories::extract(&self, goal_id, turn_index, messages_text, memory_dir);
    }
}
851
/// Unit tests for manifest scanning, write detection, path validation,
/// response parsing, gating, index maintenance and the LLM adapter.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // ---- scan_memory_manifest ----

    /// An existing but empty directory yields an empty manifest.
    #[test]
    fn scan_manifest_empty_dir() {
        let dir = TempDir::new().unwrap();
        let manifest = scan_memory_manifest(dir.path()).unwrap();
        assert!(manifest.is_empty());
    }

    /// A missing directory is not an error; the manifest is just empty.
    #[test]
    fn scan_manifest_nonexistent_dir() {
        let manifest = scan_memory_manifest(Path::new("/tmp/nonexistent-memdir-77-5")).unwrap();
        assert!(manifest.is_empty());
    }

    /// File name, type tag and description from the frontmatter all appear.
    #[test]
    fn scan_manifest_reads_frontmatter() {
        let dir = TempDir::new().unwrap();
        fs::write(
            dir.path().join("preferences.md"),
            "---\nname: user preferences\ndescription: likes dark mode\ntype: user\n---\n\nUser prefers dark mode.",
        )
        .unwrap();
        fs::write(
            dir.path().join("deploy.md"),
            "---\nname: deploy notes\ndescription: deploy process\ntype: project\n---\n\nDeploy on Fridays.",
        )
        .unwrap();

        let manifest = scan_memory_manifest(dir.path()).unwrap();
        assert!(
            manifest.contains("preferences.md"),
            "missing preferences: {manifest}"
        );
        assert!(
            manifest.contains("dark mode"),
            "missing description: {manifest}"
        );
        assert!(manifest.contains("[user]"), "missing type tag: {manifest}");
        assert!(manifest.contains("deploy.md"), "missing deploy: {manifest}");
        assert!(
            manifest.contains("[project]"),
            "missing project type: {manifest}"
        );
    }

    /// The MEMORY.md index itself is never listed in the manifest.
    #[test]
    fn scan_manifest_skips_memory_index() {
        let dir = TempDir::new().unwrap();
        fs::write(
            dir.path().join("MEMORY.md"),
            "# Memory index\n\n- [prefs](preferences.md)\n",
        )
        .unwrap();
        fs::write(
            dir.path().join("preferences.md"),
            "---\nname: prefs\ndescription: x\ntype: user\n---\n\nContent.",
        )
        .unwrap();

        let manifest = scan_memory_manifest(dir.path()).unwrap();
        assert!(
            !manifest.contains("MEMORY.md"),
            "MEMORY.md should be excluded: {manifest}"
        );
        assert!(
            manifest.contains("preferences.md"),
            "should list preferences: {manifest}"
        );
    }

    /// Files without frontmatter are still listed, tagged `[unknown]`.
    #[test]
    fn scan_manifest_file_without_frontmatter() {
        let dir = TempDir::new().unwrap();
        fs::write(
            dir.path().join("notes.md"),
            "Just some notes.\nNo frontmatter here.",
        )
        .unwrap();

        let manifest = scan_memory_manifest(dir.path()).unwrap();
        assert!(
            manifest.contains("[unknown]"),
            "should tag as unknown: {manifest}"
        );
        assert!(
            manifest.contains("notes.md"),
            "should list the file: {manifest}"
        );
    }

    // ---- has_memory_writes_in_text ----

    /// A `Write` tool call targeting the memory dir is detected.
    #[test]
    fn has_memory_writes_detects_write_tool() {
        let text = r#"Tool: Write
Arguments: {"file_path": "/home/user/.claude/projects/test/memory/foo.md", "content": "..."}"#;
        assert!(has_memory_writes_in_text(
            text,
            Path::new("/home/user/.claude/projects/test/memory")
        ));
    }

    /// A `file_write` tool call targeting the memory dir is detected.
    #[test]
    fn has_memory_writes_detects_file_write_tool() {
        let text = r#"I'll use file_write to save this memory.
{"tool": "file_write", "path": "/home/user/.claude/projects/x/memory/bar.md"}"#;
        assert!(has_memory_writes_in_text(
            text,
            Path::new("/home/user/.claude/projects/x/memory")
        ));
    }

    /// Plain conversation without tool calls is not a memory write.
    #[test]
    fn has_memory_writes_no_write() {
        let text = "Just a normal conversation.\nNo tool calls here.";
        assert!(!has_memory_writes_in_text(
            text,
            Path::new("/home/user/.claude/projects/test/memory")
        ));
    }

    /// A write marker without the memory dir path does not count.
    #[test]
    fn has_memory_writes_write_outside_memory_dir() {
        let text = r#"Write to /tmp/some-other-file.txt"#;
        assert!(!has_memory_writes_in_text(
            text,
            Path::new("/home/user/memory")
        ));
    }

    // ---- resolve_memory_path: basic validation ----

    /// Absolute paths are rejected.
    #[test]
    fn resolve_memory_path_rejects_absolute() {
        assert!(resolve_memory_path(Path::new("/mem"), "/etc/passwd").is_err());
    }

    /// `..` components are rejected wherever they appear.
    #[test]
    fn resolve_memory_path_rejects_parent_traversal() {
        assert!(resolve_memory_path(Path::new("/mem"), "../outside.md").is_err());
        assert!(resolve_memory_path(Path::new("/mem"), "sub/../../outside.md").is_err());
    }

    /// A plain file name resolves directly under the memory dir.
    #[test]
    fn resolve_memory_path_accepts_normal() {
        let result = resolve_memory_path(Path::new("/mem"), "user_role.md").unwrap();
        assert_eq!(result, PathBuf::from("/mem/user_role.md"));
    }

    /// Nested relative paths are accepted.
    #[test]
    fn resolve_memory_path_accepts_subdir() {
        let result = resolve_memory_path(Path::new("/mem"), "sub/dir/file.md").unwrap();
        assert_eq!(result, PathBuf::from("/mem/sub/dir/file.md"));
    }

    // ---- parse_extraction_response ----

    /// A bare JSON array is parsed.
    #[test]
    fn parse_response_bare_json() {
        let json = r#"[{"file_path": "test.md", "content": "hello"}]"#;
        let files = parse_extraction_response(json).unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].file_path, "test.md");
        assert_eq!(files[0].content, "hello");
    }

    /// A JSON array inside a ```json fence is parsed.
    #[test]
    fn parse_response_json_fenced() {
        let json = "```json\n[{\"file_path\": \"x.md\", \"content\": \"y\"}]\n```";
        let files = parse_extraction_response(json).unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].file_path, "x.md");
    }

    /// An empty array parses to an empty list.
    #[test]
    fn parse_response_empty_array() {
        let files = parse_extraction_response("[]").unwrap();
        assert!(files.is_empty());
    }

    /// Non-JSON text is an error.
    #[test]
    fn parse_response_invalid_json() {
        assert!(parse_extraction_response("not json").is_err());
    }

    // ---- gate helpers ----

    /// Baseline config: enabled, no effective throttle, breaker after 3 failures.
    fn make_config() -> ExtractMemoriesConfig {
        ExtractMemoriesConfig {
            enabled: true,
            turns_throttle: 1,
            max_turns: 5,
            max_consecutive_failures: 3,
        }
    }

    /// Builds an extractor backed by the no-op LLM stub.
    fn make_extractor(config: ExtractMemoriesConfig) -> Arc<ExtractMemories> {
        Arc::new(ExtractMemories::new(
            config,
            Arc::new(NoopExtractMemoriesLlm::new()),
        ))
    }

    // ---- check_gates ----

    /// `enabled = false` always reports `Disabled`.
    #[test]
    fn gate_disabled_when_enabled_false() {
        let mut cfg = make_config();
        cfg.enabled = false;
        let ext = make_extractor(cfg);
        assert!(matches!(
            ext.check_gates(),
            Err(ExtractSkipReason::Disabled)
        ));
    }

    /// With a throttle of 3 and no ticks yet, the gate reports `Throttled`.
    #[test]
    fn gate_throttled_when_not_enough_turns() {
        let mut cfg = make_config();
        cfg.turns_throttle = 3;
        let ext = make_extractor(cfg);
        assert!(matches!(
            ext.check_gates(),
            Err(ExtractSkipReason::Throttled)
        ));
    }

    /// A throttle of 1 lets the very first check pass.
    #[test]
    fn gate_passes_when_throttle_satisfied() {
        let cfg = make_config();
        let ext = make_extractor(cfg);
        assert!(ext.check_gates().is_ok());
    }

    /// With a throttle of 2, one tick is enough to open the gate.
    #[test]
    fn gate_passes_after_tick_accumulates() {
        let mut cfg = make_config();
        cfg.turns_throttle = 2;
        let ext = make_extractor(cfg);
        assert!(ext.check_gates().is_err());
        ext.tick();
        assert!(ext.check_gates().is_ok());
    }

    /// Reaching `max_consecutive_failures` opens the circuit breaker.
    #[test]
    fn gate_circuit_breaker_trips_after_n_failures() {
        let ext = make_extractor(make_config());
        ext.record_failure();
        ext.record_failure();
        ext.record_failure();
        assert!(matches!(
            ext.check_gates(),
            Err(ExtractSkipReason::CircuitBreakerOpen)
        ));
    }

    /// `max_consecutive_failures = 0` disables the circuit breaker.
    #[test]
    fn gate_circuit_breaker_disabled_when_max_zero() {
        let mut cfg = make_config();
        cfg.max_consecutive_failures = 0;
        let ext = make_extractor(cfg);
        ext.record_failure();
        ext.record_failure();
        ext.record_failure();
        assert!(ext.check_gates().is_ok());
    }

    /// A success clears the accumulated failures.
    #[test]
    fn record_success_resets_failures_and_turns() {
        let ext = make_extractor(make_config());
        ext.record_failure();
        ext.record_failure();
        ext.record_success(None);
        assert!(ext.check_gates().is_ok());
    }

    // ---- update_memory_index ----

    /// The index is created on first use and contains path and hook.
    #[test]
    fn update_index_creates_file_when_missing() {
        let dir = TempDir::new().unwrap();
        let files = vec![MemoryFile {
            file_path: "new_memory.md".to_string(),
            content: "---\nname: test\ntype: user\n---\n\nSome content here.".to_string(),
        }];
        update_memory_index(dir.path(), &files).unwrap();

        let index = fs::read_to_string(dir.path().join("MEMORY.md")).unwrap();
        assert!(index.contains("new_memory.md"), "should list new file");
        assert!(index.contains("Some content here"), "should include hook");
    }

    /// Paths already present in the index are not appended again.
    #[test]
    fn update_index_skips_duplicates() {
        let dir = TempDir::new().unwrap();
        fs::write(
            dir.path().join("MEMORY.md"),
            "# Memory index\n\n- [existing](existing.md) — already there\n",
        )
        .unwrap();

        let files = vec![
            MemoryFile {
                file_path: "existing.md".to_string(),
                content: "duplicate".to_string(),
            },
            MemoryFile {
                file_path: "new_one.md".to_string(),
                content: "new content here".to_string(),
            },
        ];
        update_memory_index(dir.path(), &files).unwrap();

        let index = fs::read_to_string(dir.path().join("MEMORY.md")).unwrap();
        let existing_count = index.matches("existing.md").count();
        assert_eq!(existing_count, 1, "duplicate should not be appended");
        assert!(index.contains("new_one.md"), "new file should be appended");
    }

    // ---- resolve_memory_path: hardening ----

    /// Embedded NUL bytes are rejected.
    #[test]
    fn resolve_memory_path_rejects_null_byte() {
        assert!(resolve_memory_path(Path::new("/mem"), "foo\0bar.md").is_err());
    }

    /// Percent-encoded `..` is rejected regardless of hex-digit case.
    #[test]
    fn resolve_memory_path_rejects_url_encoded_traversal() {
        assert!(resolve_memory_path(Path::new("/mem"), "%2e%2e%2foutside.md").is_err());
        assert!(resolve_memory_path(Path::new("/mem"), "%2e%2e%2Foutside.md").is_err());
        assert!(resolve_memory_path(Path::new("/mem"), "sub/%2e%2e/outside.md").is_err());
    }

    /// Fullwidth dots (U+FF0E) are rejected.
    #[test]
    fn resolve_memory_path_rejects_unicode_fullwidth_dots() {
        assert!(resolve_memory_path(Path::new("/mem"), "foo/\u{FF0E}\u{FF0E}/bar.md").is_err());
    }

    /// Fullwidth slash (U+FF0F) is rejected.
    #[test]
    fn resolve_memory_path_rejects_unicode_fullwidth_slash() {
        assert!(resolve_memory_path(Path::new("/mem"), "foo\u{FF0F}bar.md").is_err());
    }

    /// The hardening checks do not reject an ordinary file name.
    #[test]
    fn resolve_memory_path_accepts_normal_path_phase77_7() {
        let result = resolve_memory_path(Path::new("/mem"), "notes.md").unwrap();
        assert_eq!(result, Path::new("/mem/notes.md"));
    }

    // ---- LlmClientAdapter ----

    /// The adapter forwards model, system prompt, token budget and a single
    /// message, and returns the text reply.
    #[tokio::test]
    async fn llm_client_adapter_chat_round_trips() {
        use nexo_llm::{ChatRequest, ChatResponse, FinishReason, ResponseContent, TokenUsage};
        use std::sync::Mutex;

        // Mock client that records the request it receives.
        struct MockLlm {
            captured: Mutex<Option<ChatRequest>>,
        }
        #[async_trait]
        impl nexo_llm::LlmClient for MockLlm {
            async fn chat(&self, req: ChatRequest) -> anyhow::Result<ChatResponse> {
                *self.captured.lock().unwrap() = Some(req);
                Ok(ChatResponse {
                    content: ResponseContent::Text("extracted-content".into()),
                    usage: TokenUsage::default(),
                    finish_reason: FinishReason::Stop,
                    cache_usage: None,
                })
            }
            fn model_id(&self) -> &str {
                "mock-model"
            }
        }

        let mock = Arc::new(MockLlm {
            captured: Mutex::new(None),
        });
        let adapter = LlmClientAdapter::new(
            Arc::clone(&mock) as Arc<dyn nexo_llm::LlmClient>,
            "test-model",
        );
        let result = adapter
            .chat("system prompt", "user msg", 1024)
            .await
            .unwrap();
        assert_eq!(result, "extracted-content");

        let captured = mock.captured.lock().unwrap().take().unwrap();
        assert_eq!(captured.model, "test-model");
        assert_eq!(captured.system_prompt.as_deref(), Some("system prompt"));
        assert_eq!(captured.max_tokens, 1024);
        assert_eq!(captured.messages.len(), 1);
    }

    /// A tool-call reply is surfaced as an error mentioning `tool_calls`.
    #[tokio::test]
    async fn llm_client_adapter_errors_on_tool_call_response() {
        use nexo_llm::{
            ChatRequest, ChatResponse, FinishReason, ResponseContent, TokenUsage, ToolCall,
        };

        // Mock client that always answers with a tool call.
        struct ToolCallLlm;
        #[async_trait]
        impl nexo_llm::LlmClient for ToolCallLlm {
            async fn chat(&self, _req: ChatRequest) -> anyhow::Result<ChatResponse> {
                Ok(ChatResponse {
                    content: ResponseContent::ToolCalls(vec![ToolCall {
                        id: "id".into(),
                        name: "noop".into(),
                        arguments: serde_json::json!({}),
                    }]),
                    usage: TokenUsage::default(),
                    finish_reason: FinishReason::ToolUse,
                    cache_usage: None,
                })
            }
            fn model_id(&self) -> &str {
                "mock-tool"
            }
        }

        let adapter = LlmClientAdapter::new(
            Arc::new(ToolCallLlm) as Arc<dyn nexo_llm::LlmClient>,
            "tool-model",
        );
        let err = adapter
            .chat("sys", "user", 256)
            .await
            .expect_err("tool_calls response must surface an error");
        assert!(err.contains("tool_calls"));
    }
}