use std::collections::HashSet;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::Instant;
use async_trait::async_trait;
use tracing::{debug, info, warn};
use nexo_driver_types::{ExtractMemoriesConfig, GoalId, MemoryExtractor};
use crate::events::{DriverEvent, ExtractSkipReason};
use crate::extract_memories_prompt::build_extract_prompt;
/// Chat backend used by [`ExtractMemories`] to ask an LLM which memories to save.
///
/// The reply text is handed to `parse_extraction_response`, which expects a JSON
/// array of `{ "file_path", "content" }` objects, optionally wrapped in a
/// Markdown `json` code fence. An `Err` string is logged by the extractor and
/// counted as a failed run toward its circuit breaker.
#[async_trait]
pub trait ExtractMemoriesLlm: Send + Sync + 'static {
/// Sends `system_prompt` plus `user_messages` and returns the reply text.
///
/// `max_tokens` is the output budget; the extractor derives it from
/// `max_turns * 1024`.
async fn chat(
&self,
system_prompt: &str,
user_messages: &str,
max_tokens: u32,
) -> Result<String, String>;
}
/// One memory file proposed by the LLM; `parse_extraction_response`
/// deserializes the reply into a `Vec` of these.
#[derive(Debug, serde::Deserialize)]
struct MemoryFile {
/// Destination relative to the memory directory; checked by
/// `resolve_memory_path` before anything is written.
file_path: String,
/// Text written to the file. A leading `---` frontmatter block is skipped
/// when `update_memory_index` picks the index hook line.
content: String,
}
/// Summary of a completed extraction run.
///
/// NOTE(review): nothing in this file constructs this type — the spawned
/// extraction task only logs these values. Confirm whether external callers
/// still rely on it.
#[derive(Debug)]
pub struct ExtractMemoriesOutcome {
/// Presumably the number of memory files written — confirm with producers.
pub memories_saved: u32,
/// Presumably the run's wall-clock duration in milliseconds — confirm.
pub duration_ms: u64,
}
/// A request that arrived while an extraction was already running; it is
/// kept so the running task can process it afterwards.
struct PendingExtraction {
    messages_text: String,
    memory_dir: PathBuf,
}

/// Mutable bookkeeping behind the extractor's gates.
///
/// The derived `Default` yields the idle state: no last message id, not in
/// progress, zero turns since the last success, nothing stashed and zero
/// consecutive failures.
#[derive(Default)]
struct ExtractMemoriesState {
    last_message_uuid: Option<String>,
    in_progress: bool,
    turns_since_last: u32,
    pending: Option<PendingExtraction>,
    consecutive_failures: u32,
}

impl ExtractMemoriesState {
    /// Fresh, idle state (identical to `Default::default()`).
    fn new() -> Self {
        Self::default()
    }
}
/// Background memory extractor. When its gates allow, it asks an LLM which
/// files to persist, writes them into a memory directory and lists them in
/// that directory's `MEMORY.md` index.
pub struct ExtractMemories {
/// Gate settings (`enabled`, `turns_throttle`, `max_consecutive_failures`)
/// plus `max_turns`, which sizes the LLM token budget.
config: ExtractMemoriesConfig,
/// Throttle / in-progress / circuit-breaker bookkeeping and the coalesced
/// pending request.
state: Mutex<ExtractMemoriesState>,
/// Backend that answers the extraction prompt.
llm: Arc<dyn ExtractMemoriesLlm>,
/// Message count handed to `build_extract_prompt` (20 unless overridden).
new_message_count: u32,
/// Optional secret scanner applied to each file before it is written.
guard: Option<nexo_memory::SecretGuard>,
}
impl ExtractMemories {
pub fn new(config: ExtractMemoriesConfig, llm: Arc<dyn ExtractMemoriesLlm>) -> Self {
Self {
config,
state: Mutex::new(ExtractMemoriesState::new()),
llm,
new_message_count: 20,
guard: None,
}
}
pub fn with_message_count(mut self, n: u32) -> Self {
self.new_message_count = n;
self
}
pub fn with_guard(mut self, guard: nexo_memory::SecretGuard) -> Self {
self.guard = Some(guard);
self
}
pub fn check_gates(&self) -> Result<(), ExtractSkipReason> {
if !self.config.enabled {
return Err(ExtractSkipReason::Disabled);
}
let state = self.state.lock().unwrap();
if state.turns_since_last < self.config.turns_throttle.saturating_sub(1) {
return Err(ExtractSkipReason::Throttled);
}
if state.in_progress {
return Err(ExtractSkipReason::InProgress);
}
if self.config.max_consecutive_failures > 0
&& state.consecutive_failures >= self.config.max_consecutive_failures
{
return Err(ExtractSkipReason::CircuitBreakerOpen);
}
Ok(())
}
pub fn tick(&self) {
let mut state = self.state.lock().unwrap();
state.turns_since_last = state.turns_since_last.saturating_add(1);
}
fn mark_started(&self) {
let mut state = self.state.lock().unwrap();
state.in_progress = true;
}
fn record_success(&self, last_message_uuid: Option<String>) {
let mut state = self.state.lock().unwrap();
state.in_progress = false;
state.turns_since_last = 0;
state.consecutive_failures = 0;
state.last_message_uuid = last_message_uuid;
}
fn record_failure(&self) {
let mut state = self.state.lock().unwrap();
state.in_progress = false;
state.consecutive_failures = state.consecutive_failures.saturating_add(1);
}
pub fn stash_pending(&self, messages_text: String, memory_dir: PathBuf) {
let mut state = self.state.lock().unwrap();
state.pending = Some(PendingExtraction {
messages_text,
memory_dir,
});
}
fn take_pending(&self) -> Option<PendingExtraction> {
let mut state = self.state.lock().unwrap();
state.pending.take()
}
pub fn extract(
self: &Arc<Self>,
goal_id: GoalId,
turn_index: u32,
messages_text: String,
memory_dir: PathBuf,
) {
let skip_reason = self.check_gates().err();
if let Some(reason) = skip_reason {
if matches!(reason, ExtractSkipReason::InProgress) {
self.stash_pending(messages_text, memory_dir);
}
debug!(
goal_id = %goal_id.0,
reason = ?reason,
"ExtractMemories skipped"
);
return;
}
self.mark_started();
let this = Arc::clone(self);
tokio::spawn(async move {
let start = Instant::now();
match this.run_extraction(&messages_text, &memory_dir).await {
Ok(memories_saved) => {
let duration_ms = start.elapsed().as_millis() as u64;
info!(
goal_id = %goal_id.0,
memories_saved,
duration_ms,
"ExtractMemories completed"
);
this.record_success(None);
let _ = (goal_id, turn_index, memories_saved, duration_ms);
}
Err(e) => {
warn!(
goal_id = %goal_id.0,
error = %e,
"ExtractMemories failed"
);
this.record_failure();
}
}
if let Some(pending) = this.take_pending() {
debug!("ExtractMemories: draining coalesced extraction");
let start = Instant::now();
match this
.run_extraction(&pending.messages_text, &pending.memory_dir)
.await
{
Ok(n) => {
this.record_success(None);
info!(memories_saved = n, "ExtractMemories coalesced ok");
}
Err(e) => {
warn!(error = %e, "ExtractMemories coalesced failed");
this.record_failure();
}
}
let _ = start; }
});
}
async fn run_extraction(&self, messages_text: &str, memory_dir: &Path) -> Result<u32, String> {
let manifest = scan_memory_manifest(memory_dir).unwrap_or_default();
let system_prompt = build_extract_prompt(self.new_message_count, &manifest);
let response_text = self
.llm
.chat(&system_prompt, messages_text, self.config.max_turns * 1024)
.await?;
let files: Vec<MemoryFile> = parse_extraction_response(&response_text)?;
if files.is_empty() {
return Ok(0);
}
for f in &files {
let resolved = resolve_memory_path(memory_dir, &f.file_path)?;
if !resolved.starts_with(memory_dir) {
return Err(format!(
"path escape attempt: {} -> {}",
f.file_path,
resolved.display()
));
}
}
let mut written = 0u32;
for f in &files {
let content_to_write = if let Some(ref guard) = self.guard {
match guard.check(&f.content) {
Ok(redacted) => redacted,
Err(e) => {
tracing::warn!(
target = "memory.secret.blocked",
rule_ids = ?e.rule_ids,
content_hash = %e.content_hash,
file = %f.file_path,
"extract_memories: secret blocked, skipping file"
);
continue; }
}
} else {
f.content.clone()
};
let dest = resolve_memory_path(memory_dir, &f.file_path)?;
if let Some(parent) = dest.parent() {
fs::create_dir_all(parent)
.map_err(|e| format!("mkdir {}: {e}", parent.display()))?;
}
fs::write(&dest, &content_to_write)
.map_err(|e| format!("write {}: {e}", dest.display()))?;
written += 1;
}
update_memory_index(memory_dir, &files)?;
Ok(written)
}
pub fn has_memory_writes(&self, messages_text: &str, memory_dir: &Path) -> bool {
has_memory_writes_in_text(messages_text, memory_dir)
}
}
pub fn scan_memory_manifest(memory_dir: &Path) -> Result<String, io::Error> {
if !memory_dir.exists() {
return Ok(String::new());
}
let mut lines: Vec<String> = Vec::new();
let entries = fs::read_dir(memory_dir)?;
for entry in entries {
let entry = entry?;
let path = entry.path();
if path.extension().map_or(true, |e| e != "md") {
continue;
}
if path.file_name().map_or(false, |n| n == "MEMORY.md") {
continue;
}
let Some(file_name) = path.file_name().and_then(|n| n.to_str()) else {
continue;
};
match read_frontmatter(&path) {
Ok(Some(fm)) => {
let mem_type = fm.get("type").and_then(|s| s.as_str()).unwrap_or("unknown");
let _name = fm
.get("name")
.and_then(|s| s.as_str())
.unwrap_or(file_name.trim_end_matches(".md"));
let desc = fm.get("description").and_then(|s| s.as_str()).unwrap_or("");
lines.push(format!("- [{mem_type}] {file_name}: {desc}"));
}
Ok(None) => {
lines.push(format!("- [unknown] {file_name}"));
}
Err(_) => {
lines.push(format!("- [unknown] {file_name} (unreadable)"));
}
}
}
Ok(lines.join("\n"))
}
/// Reads the YAML frontmatter at the top of a Markdown file.
///
/// The first line must be `---` and the block runs to the next `---` line or
/// end of file. Delimiters are compared after trimming whitespace, matching
/// how `update_memory_index` recognizes them. A UTF-8 BOM before the opening
/// delimiter is ignored.
///
/// Returns `Ok(None)` when there is no frontmatter or it is empty.
///
/// # Errors
///
/// I/O errors reading the file, or `InvalidData` when the block is not a YAML
/// mapping.
fn read_frontmatter(
    path: &Path,
) -> Result<Option<serde_json::Map<String, serde_json::Value>>, io::Error> {
    let content = fs::read_to_string(path)?;
    let mut lines = content.lines();
    let opening = lines.next().map(|l| l.trim_start_matches('\u{feff}').trim());
    if opening != Some("---") {
        return Ok(None);
    }
    let yaml_lines: Vec<&str> = lines.take_while(|l| l.trim() != "---").collect();
    if yaml_lines.is_empty() {
        return Ok(None);
    }
    let yaml_str = yaml_lines.join("\n");
    let map: serde_json::Map<String, serde_json::Value> = serde_yaml::from_str(&yaml_str)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("yaml parse: {e}")))?;
    Ok(Some(map))
}
/// Parses the LLM reply into the list of memory files to write.
///
/// Accepts a bare JSON array, or one wrapped in a Markdown code fence tagged
/// `json` or untagged. A fence without a closing marker is not stripped, so
/// parsing then fails.
///
/// # Errors
///
/// A `"parse extraction response: ..."` message when the payload is not a
/// JSON array of `{file_path, content}` objects.
fn parse_extraction_response(text: &str) -> Result<Vec<MemoryFile>, String> {
    let trimmed = text.trim();
    let json_str = trimmed
        .strip_prefix("```json")
        .or_else(|| trimmed.strip_prefix("```"))
        .and_then(|s| s.strip_suffix("```"))
        .map(str::trim)
        .unwrap_or(trimmed);
    serde_json::from_str::<Vec<MemoryFile>>(json_str)
        .map_err(|e| format!("parse extraction response: {e}"))
}
fn resolve_memory_path(memory_dir: &Path, file_path: &str) -> Result<PathBuf, String> {
if file_path.contains('\0') {
return Err(format!("null byte in path: {file_path}"));
}
let lower = file_path.to_lowercase();
if lower.contains("%2e") || lower.contains("%2f") || lower.contains("%5c") {
if let Ok(decoded) = urlencoding_maybe(file_path) {
if decoded.contains("..") {
return Err(format!("URL-encoded traversal rejected: {file_path}"));
}
}
}
if file_path.contains('\u{FF0E}')
|| file_path.contains('\u{FF0F}')
|| file_path.contains('\u{FF3C}')
|| file_path.contains('\u{2215}')
{
return Err(format!("unicode traversal rejected: {file_path}"));
}
let p = Path::new(file_path);
if p.is_absolute() {
return Err(format!("absolute path rejected: {file_path}"));
}
let mut normalized = PathBuf::new();
for component in p.components() {
match component {
std::path::Component::ParentDir => {
return Err(format!("path traversal rejected: {file_path}"));
}
std::path::Component::Normal(c) => normalized.push(c),
std::path::Component::CurDir => {}
_ => return Err(format!("invalid path component in: {file_path}")),
}
}
let resolved = memory_dir.join(&normalized);
if memory_dir.exists() {
match resolved.canonicalize() {
Ok(real) => {
if !real.starts_with(memory_dir) {
return Err(format!("symlink escape rejected: {file_path}"));
}
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
let mut current = resolved.clone();
while let Some(parent) = current.parent() {
if parent.as_os_str().is_empty() {
break;
}
match parent.canonicalize() {
Ok(real_parent) => {
if !real_parent.starts_with(memory_dir) {
return Err(format!("symlink escape in parent of: {file_path}"));
}
break; }
Err(_) => {
current = parent.to_path_buf();
continue; }
}
}
}
Err(_) => {
}
}
}
Ok(resolved)
}
/// Percent-decodes `input` so encoded traversal (`%2e%2e`) can be detected.
///
/// Each `%XY` (two hex digits, either case) becomes the char whose code point
/// is 0xXY. Bytes are mapped one-to-one onto chars rather than reassembled as
/// UTF-8, which is enough to spot an encoded `..`. Input without `%` is
/// returned unchanged.
///
/// # Errors
///
/// `Err(())` when a `%` is not followed by exactly two hex digits. This
/// includes a sign such as `%+1`, which is not a valid escape.
fn urlencoding_maybe(input: &str) -> Result<String, ()> {
    if !input.contains('%') {
        return Ok(input.to_string());
    }
    let mut out = String::with_capacity(input.len());
    let mut chars = input.chars();
    while let Some(c) = chars.next() {
        if c != '%' {
            out.push(c);
            continue;
        }
        let hi = chars.next().and_then(|h| h.to_digit(16)).ok_or(())?;
        let lo = chars.next().and_then(|h| h.to_digit(16)).ok_or(())?;
        // hi and lo are both < 16, so the value is at most 0xFF and fits a u8.
        out.push(char::from((hi * 16 + lo) as u8));
    }
    Ok(out)
}
/// Appends `- [path](path) — hook` lines to `memory_dir/MEMORY.md` for each
/// file not already indexed.
///
/// A missing index starts from a `# Memory index` header. A path counts as
/// indexed when an existing line reads `- [..](path)`. Duplicate paths within
/// `files` are added only once. Nothing is written when there is nothing
/// new. Trailing newlines of the existing text are collapsed to one before
/// appending.
///
/// # Errors
///
/// A message when MEMORY.md cannot be read or written.
fn update_memory_index(memory_dir: &Path, files: &[MemoryFile]) -> Result<(), String> {
    let index_path = memory_dir.join("MEMORY.md");
    let existing = if index_path.exists() {
        fs::read_to_string(&index_path).map_err(|e| format!("read MEMORY.md: {e}"))?
    } else {
        String::from("# Memory index\n\n")
    };
    let new_lines: Vec<String> = {
        // Paths already linked from the index. Each path we add is inserted
        // too, so repeats inside this batch collapse.
        let mut indexed: HashSet<&str> = existing
            .lines()
            .filter_map(|line| {
                line.trim()
                    .strip_prefix("- [")
                    .and_then(|rest| rest.split_once("]("))
                    .and_then(|(_, rest)| rest.split_once(')').map(|(path, _)| path))
            })
            .collect();
        files
            .iter()
            .filter(|f| indexed.insert(f.file_path.as_str()))
            .map(|f| {
                format!(
                    "- [{}]({}) — {}",
                    f.file_path,
                    f.file_path,
                    memory_index_hook(&f.content)
                )
            })
            .collect()
    };
    if new_lines.is_empty() {
        return Ok(());
    }
    let mut updated = existing;
    while updated.ends_with('\n') {
        updated.pop();
    }
    updated.push('\n');
    for line in &new_lines {
        updated.push_str(line);
        updated.push('\n');
    }
    fs::write(&index_path, updated).map_err(|e| format!("write MEMORY.md: {e}"))?;
    Ok(())
}

/// Picks the index "hook" for a memory file.
///
/// This is the first line that is not blank (whitespace-only counts as blank),
/// not inside a leading `---` frontmatter block and not itself a `---` line.
/// The line is trimmed and capped at 80 characters plus `…`. The cap counts
/// chars, not bytes, so multi-byte text is never split mid-character (byte
/// slicing would panic there). Returns an empty string when no such line
/// exists.
fn memory_index_hook(content: &str) -> String {
    const MAX_HOOK_CHARS: usize = 80;
    let mut in_frontmatter = false;
    let mut closed_frontmatter = false;
    let line = content.lines().map(str::trim).find(|l| {
        if *l == "---" {
            if !in_frontmatter {
                in_frontmatter = true;
            } else if !closed_frontmatter {
                closed_frontmatter = true;
            }
            return false;
        }
        if in_frontmatter && !closed_frontmatter {
            return false;
        }
        !l.is_empty()
    });
    let Some(line) = line else {
        return String::new();
    };
    match line.char_indices().nth(MAX_HOOK_CHARS) {
        Some((cut, _)) => format!("{}…", &line[..cut]),
        None => line.to_string(),
    }
}
/// Heuristically decides whether a transcript already wrote into the memory
/// directory on its own.
///
/// Two conditions must both hold:
/// - `messages_text` mentions `memory_dir` (lossily stringified);
/// - it contains at least one write-tool marker anywhere.
///
/// The two need not appear near each other, and matching is case-sensitive.
fn has_memory_writes_in_text(messages_text: &str, memory_dir: &Path) -> bool {
    const WRITE_MARKERS: [&str; 9] = [
        "Write",
        "\"name\": \"Write\"",
        "\"name\":\"Write\"",
        "Edit",
        "\"name\": \"Edit\"",
        "\"name\":\"Edit\"",
        "file_write",
        "file_edit",
        "write_to_file",
    ];
    let dir_text = memory_dir.to_string_lossy();
    if !messages_text.contains(dir_text.as_ref()) {
        return false;
    }
    for marker in WRITE_MARKERS.iter() {
        if messages_text.contains(*marker) {
            return true;
        }
    }
    false
}
impl ExtractSkipReason {
    /// Wraps this skip reason in a `DriverEvent::ExtractMemoriesSkipped`
    /// event for `goal_id`.
    pub fn to_event(self, goal_id: GoalId) -> DriverEvent {
        let reason = self;
        DriverEvent::ExtractMemoriesSkipped { reason, goal_id }
    }
}
/// Test double for [`ExtractMemoriesLlm`] that replays one canned reply.
///
/// The `Default` value (also returned by `new`) has no reply loaded.
#[derive(Default)]
pub struct NoopExtractMemoriesLlm {
    /// Reply handed out by the next `chat` call, then cleared.
    pub canned_response: Mutex<Option<String>>,
}

impl NoopExtractMemoriesLlm {
    /// Creates a double with no canned reply.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a double whose first `chat` call returns `response`.
    pub fn with_response(response: impl Into<String>) -> Self {
        let canned = response.into();
        Self {
            canned_response: Mutex::new(Some(canned)),
        }
    }
}
#[async_trait]
impl ExtractMemoriesLlm for NoopExtractMemoriesLlm {
    /// Returns the canned reply once, leaving the slot empty; ignores all
    /// arguments. Errors when no reply is loaded.
    async fn chat(
        &self,
        _system_prompt: &str,
        _user_messages: &str,
        _max_tokens: u32,
    ) -> Result<String, String> {
        let taken = self.canned_response.lock().unwrap().take();
        match taken {
            Some(response) => Ok(response),
            None => Err("NoopExtractMemoriesLlm: no canned response set".to_string()),
        }
    }
}
/// Bridges a `nexo_llm::LlmClient` to [`ExtractMemoriesLlm`], sending every
/// request with a fixed model name.
pub struct LlmClientAdapter {
    llm: Arc<dyn nexo_llm::LlmClient>,
    model: String,
}

impl LlmClientAdapter {
    /// Wraps `llm`; `model` is stored and used for every request.
    pub fn new(llm: Arc<dyn nexo_llm::LlmClient>, model: impl Into<String>) -> Self {
        let model: String = model.into();
        Self { llm, model }
    }
}
#[async_trait]
impl ExtractMemoriesLlm for LlmClientAdapter {
    /// Sends one user message with `system_prompt` and `max_tokens` to the
    /// wrapped client and returns its text reply.
    ///
    /// Transport errors are prefixed with `LlmClientAdapter chat error:`. A
    /// tool-call reply is an error because extraction expects plain text.
    async fn chat(
        &self,
        system_prompt: &str,
        user_messages: &str,
        max_tokens: u32,
    ) -> Result<String, String> {
        let messages = vec![nexo_llm::ChatMessage::user(user_messages.to_string())];
        let mut request = nexo_llm::ChatRequest::new(self.model.clone(), messages);
        request.system_prompt = Some(system_prompt.to_string());
        request.max_tokens = max_tokens;
        let response = match self.llm.chat(request).await {
            Ok(response) => response,
            Err(e) => return Err(format!("LlmClientAdapter chat error: {e}")),
        };
        match response.content {
            nexo_llm::ResponseContent::ToolCalls(_) => {
                Err("LlmClientAdapter: response is tool_calls, expected text".into())
            }
            nexo_llm::ResponseContent::Text(text) => Ok(text),
        }
    }
}
/// Exposes [`ExtractMemories`] through the driver-facing `MemoryExtractor` trait.
///
/// Both methods forward to the inherent methods of the same name. The
/// `ExtractMemories::` paths resolve to the inherent items (inherent items take
/// precedence over trait items), so this does not recurse into itself.
impl MemoryExtractor for ExtractMemories {
/// Forwards to the inherent `tick`.
fn tick(&self) {
ExtractMemories::tick(self);
}
/// Forwards to the inherent `extract`, which takes `&Arc<Self>`.
fn extract(
self: Arc<Self>,
goal_id: GoalId,
turn_index: u32,
messages_text: String,
memory_dir: PathBuf,
) {
ExtractMemories::extract(&self, goal_id, turn_index, messages_text, memory_dir);
}
}
// Unit tests: manifest scanning, write detection, path hardening, response
// parsing, gate logic, index maintenance and the LlmClient adapter.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
// --- scan_memory_manifest ---
#[test]
fn scan_manifest_empty_dir() {
let dir = TempDir::new().unwrap();
let manifest = scan_memory_manifest(dir.path()).unwrap();
assert!(manifest.is_empty());
}
// Assumes this path does not exist on the test machine.
#[test]
fn scan_manifest_nonexistent_dir() {
let manifest = scan_memory_manifest(Path::new("/tmp/nonexistent-memdir-77-5")).unwrap();
assert!(manifest.is_empty());
}
#[test]
fn scan_manifest_reads_frontmatter() {
let dir = TempDir::new().unwrap();
fs::write(
dir.path().join("preferences.md"),
"---\nname: user preferences\ndescription: likes dark mode\ntype: user\n---\n\nUser prefers dark mode.",
)
.unwrap();
fs::write(
dir.path().join("deploy.md"),
"---\nname: deploy notes\ndescription: deploy process\ntype: project\n---\n\nDeploy on Fridays.",
)
.unwrap();
let manifest = scan_memory_manifest(dir.path()).unwrap();
assert!(
manifest.contains("preferences.md"),
"missing preferences: {manifest}"
);
assert!(
manifest.contains("dark mode"),
"missing description: {manifest}"
);
assert!(manifest.contains("[user]"), "missing type tag: {manifest}");
assert!(manifest.contains("deploy.md"), "missing deploy: {manifest}");
assert!(
manifest.contains("[project]"),
"missing project type: {manifest}"
);
}
#[test]
fn scan_manifest_skips_memory_index() {
let dir = TempDir::new().unwrap();
fs::write(
dir.path().join("MEMORY.md"),
"# Memory index\n\n- [prefs](preferences.md)\n",
)
.unwrap();
fs::write(
dir.path().join("preferences.md"),
"---\nname: prefs\ndescription: x\ntype: user\n---\n\nContent.",
)
.unwrap();
let manifest = scan_memory_manifest(dir.path()).unwrap();
assert!(
!manifest.contains("MEMORY.md"),
"MEMORY.md should be excluded: {manifest}"
);
assert!(
manifest.contains("preferences.md"),
"should list preferences: {manifest}"
);
}
#[test]
fn scan_manifest_file_without_frontmatter() {
let dir = TempDir::new().unwrap();
fs::write(
dir.path().join("notes.md"),
"Just some notes.\nNo frontmatter here.",
)
.unwrap();
let manifest = scan_memory_manifest(dir.path()).unwrap();
assert!(
manifest.contains("[unknown]"),
"should tag as unknown: {manifest}"
);
assert!(
manifest.contains("notes.md"),
"should list the file: {manifest}"
);
}
// --- has_memory_writes_in_text ---
#[test]
fn has_memory_writes_detects_write_tool() {
let text = r#"Tool: Write
Arguments: {"file_path": "/home/user/.claude/projects/test/memory/foo.md", "content": "..."}"#;
assert!(has_memory_writes_in_text(
text,
Path::new("/home/user/.claude/projects/test/memory")
));
}
#[test]
fn has_memory_writes_detects_file_write_tool() {
let text = r#"I'll use file_write to save this memory.
{"tool": "file_write", "path": "/home/user/.claude/projects/x/memory/bar.md"}"#;
assert!(has_memory_writes_in_text(
text,
Path::new("/home/user/.claude/projects/x/memory")
));
}
#[test]
fn has_memory_writes_no_write() {
let text = "Just a normal conversation.\nNo tool calls here.";
assert!(!has_memory_writes_in_text(
text,
Path::new("/home/user/.claude/projects/test/memory")
));
}
#[test]
fn has_memory_writes_write_outside_memory_dir() {
let text = r#"Write to /tmp/some-other-file.txt"#;
assert!(!has_memory_writes_in_text(
text,
Path::new("/home/user/memory")
));
}
// --- resolve_memory_path: basic acceptance/rejection ("/mem" need not exist) ---
#[test]
fn resolve_memory_path_rejects_absolute() {
assert!(resolve_memory_path(Path::new("/mem"), "/etc/passwd").is_err());
}
#[test]
fn resolve_memory_path_rejects_parent_traversal() {
assert!(resolve_memory_path(Path::new("/mem"), "../outside.md").is_err());
assert!(resolve_memory_path(Path::new("/mem"), "sub/../../outside.md").is_err());
}
#[test]
fn resolve_memory_path_accepts_normal() {
let result = resolve_memory_path(Path::new("/mem"), "user_role.md").unwrap();
assert_eq!(result, PathBuf::from("/mem/user_role.md"));
}
#[test]
fn resolve_memory_path_accepts_subdir() {
let result = resolve_memory_path(Path::new("/mem"), "sub/dir/file.md").unwrap();
assert_eq!(result, PathBuf::from("/mem/sub/dir/file.md"));
}
// --- parse_extraction_response ---
#[test]
fn parse_response_bare_json() {
let json = r#"[{"file_path": "test.md", "content": "hello"}]"#;
let files = parse_extraction_response(json).unwrap();
assert_eq!(files.len(), 1);
assert_eq!(files[0].file_path, "test.md");
assert_eq!(files[0].content, "hello");
}
#[test]
fn parse_response_json_fenced() {
let json = "```json\n[{\"file_path\": \"x.md\", \"content\": \"y\"}]\n```";
let files = parse_extraction_response(json).unwrap();
assert_eq!(files.len(), 1);
assert_eq!(files[0].file_path, "x.md");
}
#[test]
fn parse_response_empty_array() {
let files = parse_extraction_response("[]").unwrap();
assert!(files.is_empty());
}
#[test]
fn parse_response_invalid_json() {
assert!(parse_extraction_response("not json").is_err());
}
// --- gate logic ---
// Baseline: enabled, turns_throttle 1 (threshold 1 - 1 = 0, so never
// throttled), breaker trips after 3 consecutive failures.
fn make_config() -> ExtractMemoriesConfig {
ExtractMemoriesConfig {
enabled: true,
turns_throttle: 1,
max_turns: 5,
max_consecutive_failures: 3,
}
}
// Extractor backed by an LLM double with no canned reply.
fn make_extractor(config: ExtractMemoriesConfig) -> Arc<ExtractMemories> {
Arc::new(ExtractMemories::new(
config,
Arc::new(NoopExtractMemoriesLlm::new()),
))
}
#[test]
fn gate_disabled_when_enabled_false() {
let mut cfg = make_config();
cfg.enabled = false;
let ext = make_extractor(cfg);
assert!(matches!(
ext.check_gates(),
Err(ExtractSkipReason::Disabled)
));
}
#[test]
fn gate_throttled_when_not_enough_turns() {
let mut cfg = make_config();
cfg.turns_throttle = 3;
let ext = make_extractor(cfg);
assert!(matches!(
ext.check_gates(),
Err(ExtractSkipReason::Throttled)
));
}
#[test]
fn gate_passes_when_throttle_satisfied() {
let cfg = make_config(); let ext = make_extractor(cfg);
assert!(ext.check_gates().is_ok());
}
// turns_throttle 2 needs one tick (threshold 2 - 1 = 1).
#[test]
fn gate_passes_after_tick_accumulates() {
let mut cfg = make_config();
cfg.turns_throttle = 2;
let ext = make_extractor(cfg);
assert!(ext.check_gates().is_err());
ext.tick();
assert!(ext.check_gates().is_ok());
}
#[test]
fn gate_circuit_breaker_trips_after_n_failures() {
let ext = make_extractor(make_config());
ext.record_failure();
ext.record_failure();
ext.record_failure();
assert!(matches!(
ext.check_gates(),
Err(ExtractSkipReason::CircuitBreakerOpen)
));
}
#[test]
fn gate_circuit_breaker_disabled_when_max_zero() {
let mut cfg = make_config();
cfg.max_consecutive_failures = 0;
let ext = make_extractor(cfg);
ext.record_failure();
ext.record_failure();
ext.record_failure();
assert!(ext.check_gates().is_ok());
}
#[test]
fn record_success_resets_failures_and_turns() {
let ext = make_extractor(make_config());
ext.record_failure();
ext.record_failure();
ext.record_success(None);
assert!(ext.check_gates().is_ok());
}
// --- update_memory_index ---
#[test]
fn update_index_creates_file_when_missing() {
let dir = TempDir::new().unwrap();
let files = vec![MemoryFile {
file_path: "new_memory.md".to_string(),
content: "---\nname: test\ntype: user\n---\n\nSome content here.".to_string(),
}];
update_memory_index(dir.path(), &files).unwrap();
let index = fs::read_to_string(dir.path().join("MEMORY.md")).unwrap();
assert!(index.contains("new_memory.md"), "should list new file");
assert!(index.contains("Some content here"), "should include hook");
}
#[test]
fn update_index_skips_duplicates() {
let dir = TempDir::new().unwrap();
fs::write(
dir.path().join("MEMORY.md"),
"# Memory index\n\n- [existing](existing.md) — already there\n",
)
.unwrap();
let files = vec![
MemoryFile {
file_path: "existing.md".to_string(),
content: "duplicate".to_string(),
},
MemoryFile {
file_path: "new_one.md".to_string(),
content: "new content here".to_string(),
},
];
update_memory_index(dir.path(), &files).unwrap();
let index = fs::read_to_string(dir.path().join("MEMORY.md")).unwrap();
let existing_count = index.matches("existing.md").count();
assert_eq!(existing_count, 1, "duplicate should not be appended");
assert!(index.contains("new_one.md"), "new file should be appended");
}
// --- resolve_memory_path: encoding / look-alike hardening ---
#[test]
fn resolve_memory_path_rejects_null_byte() {
assert!(resolve_memory_path(Path::new("/mem"), "foo\0bar.md").is_err());
}
#[test]
fn resolve_memory_path_rejects_url_encoded_traversal() {
assert!(resolve_memory_path(Path::new("/mem"), "%2e%2e%2foutside.md").is_err());
assert!(resolve_memory_path(Path::new("/mem"), "%2e%2e%2Foutside.md").is_err());
assert!(resolve_memory_path(Path::new("/mem"), "sub/%2e%2e/outside.md").is_err());
}
#[test]
fn resolve_memory_path_rejects_unicode_fullwidth_dots() {
assert!(resolve_memory_path(Path::new("/mem"), "foo/\u{FF0E}\u{FF0E}/bar.md").is_err());
}
#[test]
fn resolve_memory_path_rejects_unicode_fullwidth_slash() {
assert!(resolve_memory_path(Path::new("/mem"), "foo\u{FF0F}bar.md").is_err());
}
#[test]
fn resolve_memory_path_accepts_normal_path_phase77_7() {
let result = resolve_memory_path(Path::new("/mem"), "notes.md").unwrap();
assert_eq!(result, Path::new("/mem/notes.md"));
}
// --- LlmClientAdapter ---
// Checks that model, system prompt, token budget and the single user message
// reach the wrapped client, and that a text reply is returned verbatim.
#[tokio::test]
async fn llm_client_adapter_chat_round_trips() {
use nexo_llm::{ChatRequest, ChatResponse, FinishReason, ResponseContent, TokenUsage};
use std::sync::Mutex;
struct MockLlm {
captured: Mutex<Option<ChatRequest>>,
}
#[async_trait]
impl nexo_llm::LlmClient for MockLlm {
async fn chat(&self, req: ChatRequest) -> anyhow::Result<ChatResponse> {
*self.captured.lock().unwrap() = Some(req);
Ok(ChatResponse {
content: ResponseContent::Text("extracted-content".into()),
usage: TokenUsage::default(),
finish_reason: FinishReason::Stop,
cache_usage: None,
})
}
fn model_id(&self) -> &str {
"mock-model"
}
}
let mock = Arc::new(MockLlm {
captured: Mutex::new(None),
});
let adapter = LlmClientAdapter::new(
Arc::clone(&mock) as Arc<dyn nexo_llm::LlmClient>,
"test-model",
);
let result = adapter
.chat("system prompt", "user msg", 1024)
.await
.unwrap();
assert_eq!(result, "extracted-content");
let captured = mock.captured.lock().unwrap().take().unwrap();
assert_eq!(captured.model, "test-model");
assert_eq!(captured.system_prompt.as_deref(), Some("system prompt"));
assert_eq!(captured.max_tokens, 1024);
assert_eq!(captured.messages.len(), 1);
}
// A tool-call reply must surface as an error mentioning tool_calls.
#[tokio::test]
async fn llm_client_adapter_errors_on_tool_call_response() {
use nexo_llm::{
ChatRequest, ChatResponse, FinishReason, ResponseContent, TokenUsage, ToolCall,
};
struct ToolCallLlm;
#[async_trait]
impl nexo_llm::LlmClient for ToolCallLlm {
async fn chat(&self, _req: ChatRequest) -> anyhow::Result<ChatResponse> {
Ok(ChatResponse {
content: ResponseContent::ToolCalls(vec![ToolCall {
id: "id".into(),
name: "noop".into(),
arguments: serde_json::json!({}),
}]),
usage: TokenUsage::default(),
finish_reason: FinishReason::ToolUse,
cache_usage: None,
})
}
fn model_id(&self) -> &str {
"mock-tool"
}
}
let adapter = LlmClientAdapter::new(
Arc::new(ToolCallLlm) as Arc<dyn nexo_llm::LlmClient>,
"tool-model",
);
let err = adapter
.chat("sys", "user", 256)
.await
.expect_err("tool_calls response must surface an error");
assert!(err.contains("tool_calls"));
}
}