// imp_core/tools/mod.rs

1pub mod ask;
2pub mod bash;
3pub mod edit;
4pub mod git;
5pub mod lua;
6pub mod mana;
7pub mod memory;
8pub mod multi_edit;
9pub mod query;
10pub mod read;
11pub mod scan;
12pub mod session_search;
13pub mod shell;
14pub mod web;
15pub mod worktree;
16pub mod write;
17
18use std::collections::hash_map::DefaultHasher;
19use std::collections::HashMap;
20use std::hash::{Hash, Hasher};
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23
24use async_trait::async_trait;
25use imp_llm::provider::ToolDefinition;
26use imp_llm::{ContentBlock, ToolResultMessage};
27
28use crate::agent::AgentCommand;
29use crate::config::AgentMode;
30use crate::config::LuaCapabilityPolicy;
31use crate::error::Result;
32use crate::mana_review::TurnManaReviewAccumulator;
33use crate::reference_monitor::ToolMetadata;
34use crate::trust::Provenance;
35use crate::ui::UserInterface;
36
/// Turn a user-supplied path string into a concrete path.
///
/// A bare `~` or a `~/` prefix is expanded using the `HOME` environment
/// variable (left untouched when `HOME` is unset or not valid Unicode).
/// Absolute paths are returned as-is; anything else is joined onto `cwd`.
pub fn resolve_path(cwd: &Path, raw: &str) -> PathBuf {
    let home_dir = || std::env::var("HOME").ok().map(PathBuf::from);

    let expanded = match raw {
        "~" => home_dir(),
        _ => raw
            .strip_prefix("~/")
            .and_then(|rest| home_dir().map(|home| home.join(rest))),
    };
    if let Some(path) = expanded {
        return path;
    }

    let candidate = Path::new(raw);
    match candidate.is_absolute() {
        true => candidate.to_path_buf(),
        false => cwd.join(candidate),
    }
}
55
56/// A tool that can be invoked by the agent.
57#[async_trait]
58pub trait Tool: Send + Sync {
59    /// Tool name (used in LLM tool calls).
60    fn name(&self) -> &str;
61
62    /// Human-readable label.
63    fn label(&self) -> &str;
64
65    /// Description shown to the LLM.
66    fn description(&self) -> &str;
67
68    /// JSON Schema for parameters.
69    fn parameters(&self) -> serde_json::Value;
70
71    /// Whether this tool only reads (no side effects).
72    fn is_readonly(&self) -> bool;
73
74    /// Metadata used by the runtime reference monitor.
75    fn policy_metadata(&self) -> ToolMetadata {
76        ToolMetadata::for_tool_name(self.name(), self.is_readonly())
77    }
78
79    /// Execute the tool.
80    async fn execute(
81        &self,
82        call_id: &str,
83        params: serde_json::Value,
84        ctx: ToolContext,
85    ) -> Result<ToolOutput>;
86}
87
/// Tracks which files have been read in the current session.
///
/// For every read it remembers the file's modification time *as observed at
/// read time*. That lets edit tools warn about edits to files that were never
/// read, and detect files that changed on disk after they were last read.
/// Paths are used as given (no canonicalization), so callers should pass
/// paths resolved the same way every time.
#[derive(Default)]
pub struct FileTracker {
    /// Path -> mtime recorded by the most recent `record_read` for that path.
    reads: HashMap<PathBuf, std::time::SystemTime>,
}

impl FileTracker {
    /// Create an empty tracker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Record that `path` was read, remembering the file's current mtime.
    ///
    /// If the mtime cannot be determined (missing file, unsupported platform),
    /// `UNIX_EPOCH` is stored as a placeholder. A later [`FileTracker::is_stale`]
    /// then reports the file as stale once a real mtime can be read.
    pub fn record_read(&mut self, path: &Path) {
        let mtime = std::fs::metadata(path)
            .and_then(|m| m.modified())
            .unwrap_or(std::time::UNIX_EPOCH);
        self.reads.insert(path.to_path_buf(), mtime);
    }

    /// Returns true if `path` has been recorded as read in this session.
    pub fn was_read(&self, path: &Path) -> bool {
        self.reads.contains_key(path)
    }

    /// Returns true if the file's current mtime differs from the one recorded
    /// at its last read, which indicates an external modification.
    ///
    /// Returns false if the file was never read, or if its current mtime
    /// cannot be determined (e.g. the file has since been deleted).
    pub fn is_stale(&self, path: &Path) -> bool {
        let Some(&recorded_mtime) = self.reads.get(path) else {
            return false;
        };
        let Ok(current_mtime) = std::fs::metadata(path).and_then(|m| m.modified()) else {
            return false;
        };
        current_mtime != recorded_mtime
    }
}
134
/// A reference to one line of a file as it looked when it was emitted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LineAnchor {
    /// Opaque id: `a`, then the file hash (16 hex digits), the line number
    /// (at least 8 hex digits) and the line's content hash (16 hex digits).
    pub id: String,
    /// Line number the anchor was emitted for (`start_line` + offset).
    pub line: usize,
    /// [`stable_hash`] of the line's text.
    pub content_hash: u64,
}

/// Registry of emitted line anchors, grouped by file path and then by id.
#[derive(Debug, Default)]
pub struct AnchorStore {
    files: std::sync::Mutex<HashMap<PathBuf, HashMap<String, LineAnchor>>>,
}

impl AnchorStore {
    pub fn new() -> Self {
        Self::default()
    }

    /// Create anchors for `lines` and remember them under `path`.
    ///
    /// The first element of `lines` is assigned line number `start_line`. The
    /// anchors are returned even if they could not be stored because the
    /// internal lock is poisoned (storage is best-effort).
    pub fn record_lines(
        &self,
        path: &Path,
        file_hash: u64,
        start_line: usize,
        lines: &[&str],
    ) -> Vec<LineAnchor> {
        let mut anchors = Vec::with_capacity(lines.len());
        for (offset, text) in lines.iter().enumerate() {
            let line = start_line + offset;
            let content_hash = stable_hash(text);
            let id = format!("a{file_hash:016x}{line:08x}{content_hash:016x}");
            anchors.push(LineAnchor {
                id,
                line,
                content_hash,
            });
        }

        if let Ok(mut files) = self.files.lock() {
            let per_file = files.entry(path.to_path_buf()).or_default();
            per_file.extend(anchors.iter().map(|anchor| (anchor.id.clone(), anchor.clone())));
        }

        anchors
    }

    /// Look up a previously recorded anchor by file and id.
    ///
    /// Returns `None` for unknown paths or ids, or when the lock is poisoned.
    pub fn get(&self, path: &Path, id: &str) -> Option<LineAnchor> {
        let files = self.files.lock().ok()?;
        files.get(path)?.get(id).cloned()
    }
}

/// Hash `value` with `DefaultHasher::new()`.
///
/// Deterministic within a build (fixed hasher keys). `DefaultHasher`'s
/// algorithm is not guaranteed across Rust releases, so the result should not
/// be persisted long-term.
pub fn stable_hash<T: Hash>(value: T) -> u64 {
    let mut state = DefaultHasher::new();
    value.hash(&mut state);
    state.finish()
}
200
201/// Cloneable runtime hook for loading Lua extension tools into a registry.
202pub type LuaToolLoader = Arc<dyn Fn(&LuaCapabilityPolicy, &mut ToolRegistry) + Send + Sync>;
203
204/// Context provided to tools during execution.
205#[derive(Clone)]
206pub struct ToolContext {
207    pub cwd: PathBuf,
208    pub cancelled: Arc<std::sync::atomic::AtomicBool>,
209    pub update_tx: tokio::sync::mpsc::Sender<ToolUpdate>,
210    pub command_tx: tokio::sync::mpsc::Sender<AgentCommand>,
211    pub ui: Arc<dyn UserInterface>,
212    pub file_cache: Arc<FileCache>,
213    /// Shared checkpoint/file-history state for destructive tool operations.
214    pub checkpoint_state: Arc<CheckpointState>,
215    /// Tracks file reads for staleness detection and unread-edit warnings.
216    pub file_tracker: Arc<std::sync::Mutex<FileTracker>>,
217    /// Session-local anchors emitted by read and consumed by anchored edit mode.
218    pub anchor_store: Arc<AnchorStore>,
219    /// Cloneable Lua extension loader inherited from the parent runtime.
220    pub lua_tool_loader: Option<LuaToolLoader>,
221    /// Active agent mode — determines which actions are permitted.
222    pub mode: AgentMode,
223    /// Max lines the read tool may return before truncating. 0 means unlimited.
224    pub read_max_lines: usize,
225    /// Turn-scoped runtime accumulator for between-turn mana review packets.
226    pub turn_mana_review: Arc<std::sync::Mutex<TurnManaReviewAccumulator>>,
227    /// Resolved runtime config for tool-specific policy checks.
228    pub config: Arc<crate::config::Config>,
229    /// Per-run tool/write policy layered on top of AgentMode.
230    pub run_policy: crate::policy::RunPolicy,
231    /// Supporting provenance for content that motivates durable writes in this tool call.
232    pub supporting_provenance: Vec<Provenance>,
233}
234
/// In-session file content cache. Avoids re-reading files that haven't changed.
///
/// A cached entry is served only while the file's current modification time
/// *and* length both match the values observed when the entry was stored.
/// Checking the length as well catches rewrites that land within the same
/// mtime tick on filesystems with coarse timestamps. Call
/// [`FileCache::invalidate`] after writing a file to drop its entry eagerly.
pub struct FileCache {
    entries: std::sync::Mutex<std::collections::HashMap<PathBuf, FileCacheEntry>>,
}

/// Cached content plus the metadata it is validated against.
struct FileCacheEntry {
    mtime: std::time::SystemTime,
    len: u64,
    content: String,
}

impl Default for FileCache {
    fn default() -> Self {
        Self::new()
    }
}

impl FileCache {
    pub fn new() -> Self {
        Self {
            entries: std::sync::Mutex::new(std::collections::HashMap::new()),
        }
    }

    /// Lock the entry map.
    ///
    /// The map only holds data that can be re-read from disk, so a lock
    /// poisoned by a panic elsewhere is recovered instead of panicking again.
    fn lock_entries(
        &self,
    ) -> std::sync::MutexGuard<'_, std::collections::HashMap<PathBuf, FileCacheEntry>> {
        self.entries
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Read a file as UTF-8, returning cached content if it hasn't changed.
    ///
    /// # Errors
    /// Propagates errors from `std::fs::metadata` and `std::fs::read_to_string`
    /// (missing file, permission denied, invalid UTF-8, ...).
    pub fn read(&self, path: &Path) -> std::io::Result<String> {
        let metadata = std::fs::metadata(path)?;
        let len = metadata.len();
        let Ok(mtime) = metadata.modified() else {
            // Without an mtime, changes cannot be detected, so bypass the
            // cache entirely rather than risk serving stale content.
            return std::fs::read_to_string(path);
        };

        {
            let cache = self.lock_entries();
            if let Some(entry) = cache.get(path) {
                if entry.mtime == mtime && entry.len == len {
                    return Ok(entry.content.clone());
                }
            }
        }

        // Metadata is sampled before reading. If the file changes mid-read,
        // the stored (mtime, len) predates the content, so the next call will
        // normally see a mismatch and re-read.
        let content = std::fs::read_to_string(path)?;

        self.lock_entries().insert(
            path.to_path_buf(),
            FileCacheEntry {
                mtime,
                len,
                content: content.clone(),
            },
        );

        Ok(content)
    }

    /// Invalidate a cache entry (call after write/edit).
    pub fn invalidate(&self, path: &Path) {
        self.lock_entries().remove(path);
    }
}
294
/// Pre-edit file snapshots used for rollback.
///
/// For each path, holds the content captured by the first
/// `snapshot_before_edit` call for it (first edit wins). A path that did not
/// exist when it was snapshotted gets no entry, so newly created files have
/// nothing to roll back to. Rolling back restores the pre-session content.
pub struct FileHistory {
    /// Path (as passed by the caller, not canonicalized) -> original content.
    originals: std::sync::Mutex<HashMap<PathBuf, String>>,
}
303
/// Metadata describing one checkpoint taken by `CheckpointState`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointRecord {
    /// Unique checkpoint identifier.
    pub id: String,
    /// Optional human-readable label supplied when the checkpoint was taken.
    pub label: Option<String>,
    /// Creation timestamp as produced by the runtime clock (unit defined by
    /// the producer; NOTE(review): confirm whether seconds or milliseconds).
    pub created_at: u64,
    /// Files whose original contents were available when the checkpoint was taken.
    pub files: Vec<PathBuf>,
}
311
312/// Shared session-scoped checkpoint state built on top of FileHistory.
313///
314/// This keeps the existing rollback primitive as the source of truth for file
315/// contents while adding lightweight checkpoint metadata that later layers can
316/// persist or surface in the UI.
317pub struct CheckpointState {
318    history: FileHistory,
319    records: std::sync::Mutex<Vec<CheckpointRecord>>,
320}
321
322impl Default for CheckpointState {
323    fn default() -> Self {
324        Self::new()
325    }
326}
327
328impl CheckpointState {
329    pub fn new() -> Self {
330        Self {
331            history: FileHistory::new(),
332            records: std::sync::Mutex::new(Vec::new()),
333        }
334    }
335
336    /// Snapshot a set of existing files and record a checkpoint if any were captured.
337    pub fn snapshot_paths(
338        &self,
339        paths: &[PathBuf],
340        label: Option<String>,
341    ) -> std::io::Result<Option<CheckpointRecord>> {
342        let mut unique = Vec::new();
343        for path in paths {
344            if !unique.iter().any(|existing: &PathBuf| existing == path) {
345                unique.push(path.clone());
346            }
347        }
348
349        let mut captured = Vec::new();
350        for path in unique {
351            self.history.snapshot_before_edit(&path)?;
352            if self.history.original(&path).is_some() {
353                captured.push(path);
354            }
355        }
356
357        if captured.is_empty() {
358            return Ok(None);
359        }
360
361        let record = CheckpointRecord {
362            id: uuid::Uuid::new_v4().to_string(),
363            label,
364            created_at: imp_llm::now(),
365            files: captured,
366        };
367        self.records.lock().unwrap().push(record.clone());
368        Ok(Some(record))
369    }
370
371    pub fn checkpoints(&self) -> Vec<CheckpointRecord> {
372        self.records.lock().unwrap().clone()
373    }
374
375    pub fn checkpoint(&self, id: &str) -> Option<CheckpointRecord> {
376        self.records
377            .lock()
378            .unwrap()
379            .iter()
380            .find(|record| record.id == id)
381            .cloned()
382    }
383
384    pub fn restore_checkpoint(&self, id: &str) -> std::io::Result<Vec<PathBuf>> {
385        let Some(record) = self.checkpoint(id) else {
386            return Ok(Vec::new());
387        };
388
389        let mut restored = Vec::new();
390        for path in &record.files {
391            self.history.rollback(path)?;
392            restored.push(path.clone());
393        }
394        Ok(restored)
395    }
396
397    pub fn rollback(&self, path: &Path) -> std::io::Result<()> {
398        self.history.rollback(path)
399    }
400
401    pub fn tracked_files(&self) -> Vec<PathBuf> {
402        self.history.tracked_files()
403    }
404
405    pub fn original(&self, path: &Path) -> Option<String> {
406        self.history.original(path)
407    }
408}
409
410impl Default for FileHistory {
411    fn default() -> Self {
412        Self::new()
413    }
414}
415
416impl FileHistory {
417    pub fn new() -> Self {
418        Self {
419            originals: std::sync::Mutex::new(HashMap::new()),
420        }
421    }
422
423    /// Store original content if not already stored for this path (first edit wins).
424    /// Does nothing if the file doesn't exist (new file creation).
425    pub fn snapshot_before_edit(&self, path: &Path) -> std::io::Result<()> {
426        let canonical = path.to_path_buf();
427
428        let mut originals = self.originals.lock().unwrap();
429        if originals.contains_key(&canonical) {
430            return Ok(()); // first edit wins
431        }
432        if canonical.exists() {
433            let content = std::fs::read_to_string(&canonical)?;
434            originals.insert(canonical, content);
435        }
436        Ok(())
437    }
438
439    /// Get the original content of a file (before any edits in this session).
440    pub fn original(&self, path: &Path) -> Option<String> {
441        self.originals.lock().unwrap().get(path).cloned()
442    }
443
444    /// Rollback a file to its original content.
445    pub fn rollback(&self, path: &Path) -> std::io::Result<()> {
446        let originals = self.originals.lock().unwrap();
447        if let Some(content) = originals.get(path) {
448            std::fs::write(path, content)?;
449        }
450        Ok(())
451    }
452
453    /// List all files with snapshots.
454    pub fn tracked_files(&self) -> Vec<PathBuf> {
455        self.originals.lock().unwrap().keys().cloned().collect()
456    }
457}
458
459impl ToolContext {
460    pub fn is_cancelled(&self) -> bool {
461        self.cancelled.load(std::sync::atomic::Ordering::Relaxed)
462    }
463
464    pub fn check_write_path(&self, path: &Path) -> std::result::Result<(), String> {
465        match self.run_policy.check_write_path(&self.cwd, path) {
466            crate::policy::WritePolicyDecision::Allowed => Ok(()),
467            crate::policy::WritePolicyDecision::Denied(reason) => Err(reason),
468        }
469    }
470}
471
472/// Result of a tool execution.
473pub struct ToolOutput {
474    pub content: Vec<ContentBlock>,
475    pub details: serde_json::Value,
476    pub is_error: bool,
477}
478
479impl ToolOutput {
480    pub fn text(text: impl Into<String>) -> Self {
481        Self {
482            content: vec![ContentBlock::Text { text: text.into() }],
483            details: serde_json::Value::Null,
484            is_error: false,
485        }
486    }
487
488    pub fn error(text: impl Into<String>) -> Self {
489        Self {
490            content: vec![ContentBlock::Text { text: text.into() }],
491            details: serde_json::Value::Null,
492            is_error: true,
493        }
494    }
495
496    /// Extract the first text block, if any. Useful for tests.
497    pub fn text_content(&self) -> Option<&str> {
498        self.content.iter().find_map(|b| match b {
499            ContentBlock::Text { text } => Some(text.as_str()),
500            _ => None,
501        })
502    }
503
504    pub fn into_tool_result(self, call_id: &str, tool_name: &str) -> ToolResultMessage {
505        ToolResultMessage {
506            tool_call_id: call_id.to_string(),
507            tool_name: tool_name.to_string(),
508            content: self.content,
509            is_error: self.is_error,
510            details: self.details,
511            timestamp: imp_llm::now(),
512        }
513    }
514}
515
516/// Partial update from a running tool (for streaming output).
517pub struct ToolUpdate {
518    pub content: Vec<ContentBlock>,
519    pub details: serde_json::Value,
520}
521
522/// Registry of available tools.
523pub struct ToolRegistry {
524    tools: HashMap<String, Arc<dyn Tool>>,
525    aliases: HashMap<String, String>,
526}
527
528impl ToolRegistry {
529    pub fn new() -> Self {
530        Self {
531            tools: HashMap::new(),
532            aliases: HashMap::new(),
533        }
534    }
535
536    /// Register a native Rust tool.
537    pub fn register(&mut self, tool: Arc<dyn Tool>) {
538        self.tools.insert(tool.name().to_string(), tool);
539    }
540
541    /// Register a compatibility alias for an existing canonical tool name.
542    ///
543    /// Aliases resolve at execution time but are intentionally omitted from
544    /// tool definitions so models see the canonical surface only.
545    pub fn register_alias(&mut self, alias: impl Into<String>, canonical: impl Into<String>) {
546        self.aliases.insert(alias.into(), canonical.into());
547    }
548
549    pub fn extend(&mut self, other: ToolRegistry) {
550        for tool in other.tools.into_values() {
551            self.register(tool);
552        }
553        for (alias, canonical) in other.aliases {
554            self.register_alias(alias, canonical);
555        }
556    }
557
558    /// Get a tool by canonical name or compatibility alias.
559    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
560        if let Some(tool) = self.tools.get(name) {
561            return Some(tool);
562        }
563
564        self.aliases
565            .get(name)
566            .and_then(|canonical| self.tools.get(canonical))
567    }
568
569    /// Get a cloned map of all tools, including compatibility aliases.
570    pub fn tools_map(&self) -> HashMap<String, Arc<dyn Tool>> {
571        let mut map = self.tools.clone();
572        for (alias, canonical) in &self.aliases {
573            if let Some(tool) = self.tools.get(canonical) {
574                map.insert(alias.clone(), Arc::clone(tool));
575            }
576        }
577        map
578    }
579
580    /// Get all canonical tool definitions (for LLM context).
581    /// Compatibility aliases such as legacy `multi_edit` are intentionally hidden
582    /// so models learn one canonical edit surface.
583    pub fn definitions(&self) -> Vec<ToolDefinition> {
584        let mut defs: Vec<_> = self
585            .tools
586            .values()
587            .map(|t| ToolDefinition {
588                name: t.name().to_string(),
589                description: t.description().to_string(),
590                parameters: t.parameters(),
591            })
592            .collect();
593        defs.sort_by(|a, b| a.name.cmp(&b.name));
594        defs
595    }
596
597    /// Get only readonly tool definitions (for readonly roles).
598    pub fn readonly_definitions(&self) -> Vec<ToolDefinition> {
599        let mut defs: Vec<_> = self
600            .tools
601            .values()
602            .filter(|t| t.is_readonly())
603            .map(|t| ToolDefinition {
604                name: t.name().to_string(),
605                description: t.description().to_string(),
606                parameters: t.parameters(),
607            })
608            .collect();
609        defs.sort_by(|a, b| a.name.cmp(&b.name));
610        defs
611    }
612
613    /// List all tool names.
614    pub fn names(&self) -> Vec<String> {
615        self.tools.keys().cloned().collect()
616    }
617
618    /// Retain only tools whose names satisfy the predicate.
619    ///
620    /// Used by `AgentBuilder` to filter tools based on agent mode before the
621    /// agent is handed out to callers.
622    pub fn retain<F>(&mut self, predicate: F)
623    where
624        F: Fn(&str) -> bool,
625    {
626        self.tools.retain(|name, _| predicate(name));
627        self.aliases
628            .retain(|_, canonical| self.tools.contains_key(canonical));
629    }
630
631    /// Get tool definitions filtered to those allowed by an agent mode.
632    ///
633    /// For `Full` mode (empty allow-list), returns all definitions.
634    /// For all other modes, returns only the intersection.
635    pub fn definitions_for_mode(
636        &self,
637        mode: &crate::config::AgentMode,
638    ) -> Vec<imp_llm::provider::ToolDefinition> {
639        let mut defs: Vec<_> = self
640            .tools
641            .values()
642            .filter(|t| mode.allows_tool(t.name()))
643            .map(|t| imp_llm::provider::ToolDefinition {
644                name: t.name().to_string(),
645                description: t.description().to_string(),
646                parameters: t.parameters(),
647            })
648            .collect();
649        defs.sort_by(|a, b| a.name.cmp(&b.name));
650        defs
651    }
652
653    /// Lookup reference monitor metadata by canonical name or alias.
654    pub fn policy_metadata(&self, name: &str) -> Option<ToolMetadata> {
655        self.get(name).map(|tool| tool.policy_metadata())
656    }
657
658    /// Number of registered tools.
659    pub fn len(&self) -> usize {
660        self.tools.len()
661    }
662
663    pub fn is_empty(&self) -> bool {
664        self.tools.is_empty()
665    }
666}
667
668impl Default for ToolRegistry {
669    fn default() -> Self {
670        Self::new()
671    }
672}
673
674// ── Truncation helpers ──────────────────────────────────────────────
675
/// Outcome of [`truncate_head`] / [`truncate_tail`].
pub struct TruncationResult {
    /// Text that was kept (the whole input when nothing was truncated).
    pub content: String,
    /// Whether any input was dropped.
    pub truncated: bool,
    /// Number of lines in `content`.
    pub output_lines: usize,
    /// Number of lines in the input.
    pub total_lines: usize,
    /// Byte length of `content`.
    pub output_bytes: usize,
    /// Byte length of the input.
    pub total_bytes: usize,
    /// Temp file holding the full input, when truncated and the write succeeded.
    pub temp_file: Option<PathBuf>,
}
685
/// Cut `line` down to at most `max_bytes` bytes, appending `…` when shortened.
///
/// The cut backs off to the nearest UTF-8 character boundary at or below
/// `max_bytes`, so a multi-byte character is never split. The ellipsis is
/// not counted against `max_bytes`.
pub fn truncate_line(line: &str, max_bytes: usize) -> String {
    if line.len() <= max_bytes {
        return line.to_string();
    }

    // Index 0 is always a char boundary, so this search always succeeds.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| line.is_char_boundary(idx))
        .unwrap_or(0);

    let mut shortened = String::with_capacity(cut + '…'.len_utf8());
    shortened.push_str(&line[..cut]);
    shortened.push('…');
    shortened
}
697
698/// Write full content to a temp file, returning the path.
699fn write_temp_file(content: &str) -> Option<PathBuf> {
700    let dir = std::env::temp_dir().join("imp-tools");
701    std::fs::create_dir_all(&dir).ok()?;
702    let name = format!("truncated-{}.txt", uuid::Uuid::new_v4());
703    let path = dir.join(name);
704    std::fs::write(&path, content).ok()?;
705    Some(path)
706}
707
708/// Truncate keeping the head (first N lines/bytes).
709/// When truncated, writes full output to a temp file.
710pub fn truncate_head(input: &str, max_lines: usize, max_bytes: usize) -> TruncationResult {
711    let lines: Vec<&str> = input.lines().collect();
712    let total_lines = lines.len();
713    let total_bytes = input.len();
714
715    if total_lines <= max_lines && total_bytes <= max_bytes {
716        return TruncationResult {
717            content: input.to_string(),
718            truncated: false,
719            output_lines: total_lines,
720            total_lines,
721            output_bytes: total_bytes,
722            total_bytes,
723            temp_file: None,
724        };
725    }
726
727    let mut result = String::new();
728    let mut byte_count = 0;
729    let mut line_count = 0;
730
731    for line in &lines {
732        let line_with_newline = format!("{line}\n");
733        if line_count >= max_lines || byte_count + line_with_newline.len() > max_bytes {
734            break;
735        }
736        result.push_str(&line_with_newline);
737        byte_count += line_with_newline.len();
738        line_count += 1;
739    }
740
741    let temp_file = write_temp_file(input);
742
743    TruncationResult {
744        content: result,
745        truncated: true,
746        output_lines: line_count,
747        total_lines,
748        output_bytes: byte_count,
749        total_bytes,
750        temp_file,
751    }
752}
753
754/// Truncate keeping the tail (last N lines/bytes).
755/// When truncated, writes full output to a temp file.
756pub fn truncate_tail(input: &str, max_lines: usize, max_bytes: usize) -> TruncationResult {
757    let lines: Vec<&str> = input.lines().collect();
758    let total_lines = lines.len();
759    let total_bytes = input.len();
760
761    if total_lines <= max_lines && total_bytes <= max_bytes {
762        return TruncationResult {
763            content: input.to_string(),
764            truncated: false,
765            output_lines: total_lines,
766            total_lines,
767            output_bytes: total_bytes,
768            total_bytes,
769            temp_file: None,
770        };
771    }
772
773    // Walk backwards from the end, collecting lines that fit.
774    let start = total_lines.saturating_sub(max_lines);
775    let mut actual_start = start;
776    let mut remaining_bytes = max_bytes;
777
778    for (i, line) in lines[start..].iter().enumerate() {
779        let line_with_newline = format!("{line}\n");
780        if line_with_newline.len() > remaining_bytes {
781            actual_start = start + i + 1;
782            remaining_bytes = max_bytes;
783            // Recalculate from new start
784            for line2 in &lines[actual_start..] {
785                let l = format!("{line2}\n");
786                if l.len() > remaining_bytes {
787                    break;
788                }
789                remaining_bytes -= l.len();
790            }
791            break;
792        }
793        remaining_bytes -= line_with_newline.len();
794    }
795
796    let mut result = String::new();
797    for line in &lines[actual_start..] {
798        result.push_str(&format!("{line}\n"));
799    }
800
801    let output_lines = total_lines - actual_start;
802    let output_bytes = result.len();
803    let temp_file = write_temp_file(input);
804
805    TruncationResult {
806        content: result,
807        truncated: true,
808        output_lines,
809        total_lines,
810        output_bytes,
811        total_bytes,
812        temp_file,
813    }
814}
815
816// ── Fuzzy matching for edit tools ───────────────────────────────────
817
pub(crate) mod fuzzy {
    /// Normalize text for fuzzy matching.
    ///
    /// Strips trailing whitespace per line, maps curly single/double quotes,
    /// en/em dashes, and a few Unicode spaces (NBSP, en, em, thin) to ASCII,
    /// and re-joins the lines with `\n`.
    pub fn normalize_for_matching(text: &str) -> String {
        text.lines()
            .map(|line| normalize_unicode(line.trim_end()))
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Map the typographic characters listed above to ASCII equivalents.
    fn normalize_unicode(s: &str) -> String {
        s.chars()
            .map(|c| match c {
                '\u{2018}' | '\u{2019}' => '\'',
                '\u{201C}' | '\u{201D}' => '"',
                '\u{2013}' | '\u{2014}' => '-',
                '\u{00A0}' | '\u{2003}' | '\u{2002}' | '\u{2009}' => ' ',
                other => other,
            })
            .collect()
    }

    /// Result of a fuzzy find: byte range `start..end` in the original content.
    ///
    /// The range covers whole lines and excludes the final line's terminator.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct FuzzyMatch {
        pub start: usize,
        pub end: usize,
    }

    /// Split `content` into lines exactly like `str::lines` does (`\n` or
    /// `\r\n` terminators removed), pairing each line with its starting byte
    /// offset in `content`.
    fn lines_with_offsets(content: &str) -> Vec<(usize, &str)> {
        let mut offset = 0;
        content
            .split_inclusive('\n')
            .map(|raw| {
                let start = offset;
                offset += raw.len();
                let line = raw
                    .strip_suffix('\n')
                    .map_or(raw, |l| l.strip_suffix('\r').unwrap_or(l));
                (start, line)
            })
            .collect()
    }

    /// Try to find `old_text` in `content` using fuzzy matching.
    ///
    /// Both sides are normalized line-by-line (see [`normalize_for_matching`])
    /// and compared with a sliding window over lines. The first matching
    /// window is returned as a byte range in the *original* content. Offsets
    /// come from the real line positions, so they stay correct for `\r\n`
    /// line endings. Returns `None` if `old_text` has no lines or nothing matches.
    pub fn fuzzy_find(content: &str, old_text: &str) -> Option<FuzzyMatch> {
        let search_norm = normalize_for_matching(old_text);
        let search_lines: Vec<&str> = search_norm.lines().collect();
        if search_lines.is_empty() {
            return None;
        }

        let content_lines = lines_with_offsets(content);
        if search_lines.len() > content_lines.len() {
            return None;
        }

        let content_norm_lines: Vec<String> = content_lines
            .iter()
            .map(|(_, line)| normalize_unicode(line.trim_end()))
            .collect();

        let window_size = search_lines.len();
        let start_line = content_norm_lines.windows(window_size).position(|window| {
            window
                .iter()
                .zip(search_lines.iter())
                .all(|(content_line, search_line)| content_line == search_line)
        })?;

        let (byte_start, _) = content_lines[start_line];
        let (last_start, last_line) = content_lines[start_line + window_size - 1];
        Some(FuzzyMatch {
            start: byte_start,
            end: last_start + last_line.len(),
        })
    }
}
899
900// ── File-not-found suggestions ──────────────────────────────────────
901
/// Levenshtein edit distance between `a` and `b`, counted in `char`s.
///
/// Classic dynamic programming kept in a single rolling row:
/// O(len(a) * len(b)) time, O(len(b)) extra space.
pub fn levenshtein(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();
    // row[j] = distance between the prefix of `a` consumed so far and target[..j].
    let mut row: Vec<usize> = (0..=target.len()).collect();

    for (i, ca) in a.chars().enumerate() {
        // `diag` carries the previous row's value at column j (the diagonal cell).
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &cb) in target.iter().enumerate() {
            let above = row[j + 1];
            let substitute = diag + usize::from(ca != cb);
            row[j + 1] = (above + 1).min(row[j] + 1).min(substitute);
            diag = above;
        }
    }

    row[target.len()]
}
929
930/// Search for files with names similar to the missing `target` path.
931///
932/// Extracts the filename component, walks up to 4 directory levels from `cwd`,
933/// and returns up to 3 candidates ranked by Levenshtein distance (closest first).
934/// Only files with distance ≤ 3 from the target filename are included.
935pub fn suggest_similar_files(cwd: &Path, target: &str) -> Vec<String> {
936    let target_name = Path::new(target)
937        .file_name()
938        .and_then(|n: &std::ffi::OsStr| n.to_str())
939        .unwrap_or(target);
940
941    let mut candidates: Vec<(usize, String)> = Vec::new();
942
943    // Skip directories that are typically huge and irrelevant for suggestions
944    const SKIP_DIRS: &[&str] = &[
945        "target",
946        "node_modules",
947        ".git",
948        "vendor",
949        "dist",
950        "build",
951        "__pycache__",
952        ".mypy_cache",
953        ".tox",
954        ".venv",
955    ];
956
957    let walker = walkdir::WalkDir::new(cwd)
958        .max_depth(3)
959        .follow_links(false)
960        .into_iter()
961        .filter_entry(|e| {
962            if e.file_type().is_dir() {
963                if let Some(name) = e.file_name().to_str() {
964                    return !SKIP_DIRS.contains(&name);
965                }
966            }
967            true
968        })
969        .filter_map(|e| e.ok());
970
971    for entry in walker {
972        if entry.file_type().is_file() {
973            if let Some(name) = entry.file_name().to_str() {
974                let dist = levenshtein(target_name, name);
975                if dist <= 3 {
976                    let rel = entry
977                        .path()
978                        .strip_prefix(cwd)
979                        .unwrap_or(entry.path())
980                        .display()
981                        .to_string();
982                    candidates.push((dist, rel));
983                }
984            }
985        }
986    }
987
988    candidates.sort_by_key(|(d, _)| *d);
989    candidates.truncate(3);
990    candidates.into_iter().map(|(_, p)| p).collect()
991}
992
993// ── Diff generation ─────────────────────────────────────────────────
994
995/// Generate a unified diff between old and new content.
996pub fn generate_diff(file_path: &str, old: &str, new: &str) -> String {
997    use similar::TextDiff;
998
999    let diff = TextDiff::from_lines(old, new);
1000    let mut output = String::new();
1001    output.push_str(&format!("--- {file_path}\n"));
1002    output.push_str(&format!("+++ {file_path}\n"));
1003
1004    for hunk in diff.unified_diff().context_radius(3).iter_hunks() {
1005        output.push_str(&format!("{hunk}"));
1006    }
1007
1008    output
1009}
1010
1011// ── Tool argument validation ─────────────────────────────────────────
1012
1013/// Validate tool arguments against a JSON Schema.
1014///
1015/// Returns `Ok(())` if args are valid, or `Err` with a human-readable
1016/// description of what failed. Extra/unknown fields are permitted — LLMs often
1017/// include them and tools should be lenient on input.
1018pub fn validate_tool_args(schema: &serde_json::Value, args: &serde_json::Value) -> Result<()> {
1019    use jsonschema::Validator;
1020
1021    let validator = Validator::new(schema)
1022        .map_err(|e| crate::error::Error::Tool(format!("Invalid tool schema: {e}")))?;
1023
1024    let errors: Vec<String> = validator
1025        .iter_errors(args)
1026        .map(|e| format!("{e}"))
1027        .collect();
1028
1029    if errors.is_empty() {
1030        Ok(())
1031    } else {
1032        Err(crate::error::Error::Tool(format!(
1033            "Tool argument validation failed:\n{}",
1034            errors.join("\n")
1035        )))
1036    }
1037}
1038
#[cfg(test)]
mod tests {
    use super::*;

    // ── levenshtein ───────────────────────────────────────────────────

    #[test]
    fn suggest_similar_levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    #[test]
    fn suggest_similar_levenshtein_one_substitution() {
        assert_eq!(levenshtein("auth", "aath"), 1);
    }

    #[test]
    fn suggest_similar_levenshtein_one_insertion() {
        assert_eq!(levenshtein("helo", "hello"), 1);
    }

    #[test]
    fn suggest_similar_levenshtein_one_deletion() {
        assert_eq!(levenshtein("hello", "helo"), 1);
    }

    #[test]
    fn suggest_similar_levenshtein_empty_strings() {
        assert_eq!(levenshtein("", ""), 0);
        assert_eq!(levenshtein("abc", ""), 3);
        assert_eq!(levenshtein("", "abc"), 3);
    }

    #[test]
    fn suggest_similar_levenshtein_completely_different() {
        // "abc" vs "xyz": 3 substitutions
        assert_eq!(levenshtein("abc", "xyz"), 3);
    }

    #[test]
    fn suggest_similar_levenshtein_transposition() {
        // "atuh" vs "auth": swap two adjacent chars = distance 2
        assert_eq!(levenshtein("atuh", "auth"), 2);
    }

    #[test]
    fn suggest_similar_levenshtein_counts_chars_not_bytes() {
        // 'é' is two bytes in UTF-8 but a single edit.
        assert_eq!(levenshtein("café", "cafe"), 1);
        assert_eq!(levenshtein("日本", "日本語"), 1);
    }

    // ── suggest_similar_files ─────────────────────────────────────────

    #[test]
    fn suggest_similar_finds_close_match() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("middleware.rs"), "").unwrap();
        std::fs::write(dir.path().join("unrelated.rs"), "").unwrap();

        let suggestions = suggest_similar_files(dir.path(), "middlewar.rs");
        assert!(
            suggestions.iter().any(|s| s.contains("middleware.rs")),
            "expected middleware.rs in suggestions, got: {suggestions:?}"
        );
    }

    #[test]
    fn suggest_similar_returns_at_most_three() {
        let dir = tempfile::tempdir().unwrap();
        // Create five files each 1 edit away from "xod.rs"
        for name in &["mod.rs", "rod.rs", "cod.rs", "nod.rs", "pod.rs"] {
            std::fs::write(dir.path().join(name), "").unwrap();
        }

        let suggestions = suggest_similar_files(dir.path(), "xod.rs");
        // All five qualify, so the cap (not a shortage of matches) limits the result.
        assert_eq!(
            suggestions.len(),
            3,
            "expected exactly 3 suggestions, got: {suggestions:?}"
        );
    }

    #[test]
    fn suggest_similar_nothing_close_returns_empty() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("completely_different.rs"), "").unwrap();

        // "a.rs" is far from "completely_different.rs"
        let suggestions = suggest_similar_files(dir.path(), "a.rs");
        assert!(
            suggestions.is_empty(),
            "expected no suggestions, got: {suggestions:?}"
        );
    }

    #[test]
    fn suggest_similar_ranks_closer_matches_first() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("auth.rs"), "").unwrap();
        std::fs::write(dir.path().join("autho.rs"), "").unwrap();

        // "atuh.rs" -> "auth.rs" is distance 2, "atuh.rs" -> "autho.rs" is 3.
        let suggestions = suggest_similar_files(dir.path(), "atuh.rs");
        assert_eq!(
            suggestions.first().map(String::as_str),
            Some("auth.rs"),
            "expected auth.rs ranked first, got: {suggestions:?}"
        );
    }

    #[test]
    fn suggest_similar_skips_vendor_dirs_and_reports_relative_paths() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(dir.path().join("src")).unwrap();
        std::fs::create_dir_all(dir.path().join("node_modules")).unwrap();
        std::fs::write(dir.path().join("src").join("config.rs"), "").unwrap();
        std::fs::write(dir.path().join("node_modules").join("confg.rs"), "").unwrap();

        let suggestions = suggest_similar_files(dir.path(), "confg.rs");
        let expected = std::path::Path::new("src")
            .join("config.rs")
            .display()
            .to_string();
        assert_eq!(suggestions, vec![expected]);
    }

    // ── validate_tool_args ────────────────────────────────────────────

    fn simple_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" },
                "count": { "type": "integer" }
            },
            "required": ["path"]
        })
    }

    #[test]
    fn validate_tool_args_valid_passes() {
        let schema = simple_schema();
        let args = serde_json::json!({ "path": "/tmp/foo.txt" });
        assert!(validate_tool_args(&schema, &args).is_ok());
    }

    #[test]
    fn validate_tool_args_valid_with_optional_passes() {
        let schema = simple_schema();
        let args = serde_json::json!({ "path": "/tmp/foo.txt", "count": 5 });
        assert!(validate_tool_args(&schema, &args).is_ok());
    }

    #[test]
    fn validate_tool_args_missing_required_returns_error() {
        let schema = simple_schema();
        // Missing the required "path" field
        let args = serde_json::json!({ "count": 5 });
        let result = validate_tool_args(&schema, &args);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("path") || msg.contains("required"),
            "expected error mentioning 'path' or 'required', got: {msg}"
        );
    }

    #[test]
    fn validate_tool_args_wrong_type_returns_error() {
        let schema = simple_schema();
        // "count" must be integer, not string
        let args = serde_json::json!({ "path": "/tmp/foo.txt", "count": "not-a-number" });
        let result = validate_tool_args(&schema, &args);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("integer") || msg.contains("type"),
            "expected type error, got: {msg}"
        );
    }

    #[test]
    fn validate_tool_args_extra_fields_allowed() {
        // LLMs often add extra fields — we should not reject them
        let schema = simple_schema();
        let args = serde_json::json!({
            "path": "/tmp/foo.txt",
            "llm_added_extra": "some value",
            "another_unknown": 42
        });
        assert!(
            validate_tool_args(&schema, &args).is_ok(),
            "extra/unknown fields should be allowed"
        );
    }

    // ── FileTracker ───────────────────────────────────────────────────

    #[test]
    fn file_track_was_read_false_for_unread_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let tracker = FileTracker::new();
        assert!(!tracker.was_read(&file), "unread file should return false");
    }

    #[test]
    fn file_track_was_read_true_after_recording() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let mut tracker = FileTracker::new();
        tracker.record_read(&file);
        assert!(
            tracker.was_read(&file),
            "file should be marked as read after recording"
        );
    }

    #[test]
    fn file_track_is_stale_false_for_unread_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let tracker = FileTracker::new();
        // Unread file is never stale (no baseline to compare against)
        assert!(!tracker.is_stale(&file));
    }

    #[test]
    fn file_track_is_stale_false_immediately_after_read() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let mut tracker = FileTracker::new();
        tracker.record_read(&file);
        // No modification since read — should not be stale
        assert!(!tracker.is_stale(&file));
    }

    #[test]
    fn file_track_is_stale_detects_external_modification() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "original content").unwrap();

        let mut tracker = FileTracker::new();
        tracker.record_read(&file);

        // Set the file's mtime to 2 seconds in the future to guarantee a detectable change.
        // std::fs::File::set_modified is stable since Rust 1.75 and needs no extra crate.
        // Fail loudly here rather than surfacing as a misleading staleness assertion.
        let future = std::time::SystemTime::now() + std::time::Duration::from_secs(2);
        let f = std::fs::OpenOptions::new()
            .write(true)
            .open(&file)
            .expect("test file should be writable");
        f.set_modified(future)
            .expect("setting mtime should succeed on the test filesystem");

        assert!(
            tracker.is_stale(&file),
            "file with advanced mtime should be stale"
        );
    }

    // ── FileHistory tests ─────────────────────────────────────

    #[test]
    fn file_history_snapshot_stores_original() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        assert_eq!(history.original(&file).unwrap(), "fn main() {}");
    }

    #[test]
    fn file_history_second_snapshot_is_noop() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.rs");
        std::fs::write(&file, "original").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        // Modify the file and snapshot again — should keep original
        std::fs::write(&file, "modified").unwrap();
        history.snapshot_before_edit(&file).unwrap();

        assert_eq!(history.original(&file).unwrap(), "original");
    }

    #[test]
    fn file_history_rollback_restores_original() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.rs");
        std::fs::write(&file, "original content").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        std::fs::write(&file, "agent wrote this").unwrap();
        history.rollback(&file).unwrap();

        assert_eq!(std::fs::read_to_string(&file).unwrap(), "original content");
    }

    #[test]
    fn file_history_skips_nonexistent_files() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("does_not_exist.rs");

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        assert!(history.original(&file).is_none());
    }

    #[test]
    fn file_history_tracked_files_lists_all() {
        let dir = tempfile::tempdir().unwrap();
        let f1 = dir.path().join("a.rs");
        let f2 = dir.path().join("b.rs");
        std::fs::write(&f1, "a").unwrap();
        std::fs::write(&f2, "b").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&f1).unwrap();
        history.snapshot_before_edit(&f2).unwrap();

        let tracked = history.tracked_files();
        assert_eq!(tracked.len(), 2);
        assert!(tracked.contains(&f1));
        assert!(tracked.contains(&f2));
    }
}