1pub mod ask;
2pub mod bash;
3pub mod edit;
4pub mod git;
5pub mod lua;
6pub mod mana;
7pub mod memory;
8pub mod multi_edit;
9pub mod query;
10pub mod read;
11pub mod scan;
12pub mod session_search;
13pub mod shell;
14pub mod web;
15pub mod worktree;
16pub mod write;
17
18use std::collections::hash_map::DefaultHasher;
19use std::collections::HashMap;
20use std::hash::{Hash, Hasher};
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23
24use async_trait::async_trait;
25use imp_llm::provider::ToolDefinition;
26use imp_llm::{ContentBlock, ToolResultMessage};
27
28use crate::agent::AgentCommand;
29use crate::config::AgentMode;
30use crate::config::LuaCapabilityPolicy;
31use crate::error::Result;
32use crate::mana_review::TurnManaReviewAccumulator;
33use crate::reference_monitor::ToolMetadata;
34use crate::trust::Provenance;
35use crate::ui::UserInterface;
36
/// Resolve a user-supplied path string against `cwd`.
///
/// A bare `~` or a leading `~/` is expanded using the `HOME` environment
/// variable when it is set (and valid Unicode). If `HOME` is unavailable the
/// string falls through to the ordinary handling, so `~/x` becomes `cwd/~/x`.
/// Absolute paths are returned unchanged; relative paths are joined onto
/// `cwd`. No canonicalization or existence check is performed.
pub fn resolve_path(cwd: &Path, raw: &str) -> PathBuf {
    // Only consulted for `~` forms, so other inputs never touch the environment.
    let home_dir = || std::env::var("HOME").ok().map(PathBuf::from);

    if raw == "~" {
        if let Some(home) = home_dir() {
            return home;
        }
    } else if let Some(rest) = raw.strip_prefix("~/") {
        if let Some(home) = home_dir() {
            return home.join(rest);
        }
    }

    let candidate = Path::new(raw);
    if candidate.is_absolute() {
        return candidate.to_path_buf();
    }
    cwd.join(candidate)
}
55
/// A capability the agent can invoke by name.
///
/// Implementations are stored as `Arc<dyn Tool>` in a `ToolRegistry`, which
/// keys them by `name()` and builds provider `ToolDefinition`s from `name()`,
/// `description()`, and `parameters()`.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Canonical tool name; used as the registry key and as the definition name.
    fn name(&self) -> &str;

    /// Display label for the tool.
    /// NOTE(review): presumably shown in the UI rather than to the model — confirm.
    fn label(&self) -> &str;

    /// Description placed in the tool's `ToolDefinition`.
    fn description(&self) -> &str;

    /// Parameter description placed in the tool's `ToolDefinition`.
    /// Presumably a JSON Schema object of the kind `validate_tool_args` accepts —
    /// confirm with callers.
    fn parameters(&self) -> serde_json::Value;

    /// Whether the tool is read-only. Only read-only tools are returned by
    /// `ToolRegistry::readonly_definitions`, and the flag feeds the default
    /// `policy_metadata`.
    fn is_readonly(&self) -> bool;

    /// Metadata for the reference monitor (`crate::reference_monitor`).
    /// The default derives it from `name()` and `is_readonly()`; override to
    /// declare something more specific.
    fn policy_metadata(&self) -> ToolMetadata {
        ToolMetadata::for_tool_name(self.name(), self.is_readonly())
    }

    /// Run the tool for the call identified by `call_id`, using the model-supplied
    /// `params` and the per-call `ctx` (working directory, cancellation flag,
    /// shared caches and trackers).
    async fn execute(
        &self,
        call_id: &str,
        params: serde_json::Value,
        ctx: ToolContext,
    ) -> Result<ToolOutput>;
}
87
/// Remembers which files have been read and the modification time observed at
/// that read. Later writes can use it to check that a file was read first and
/// has not changed on disk since.
pub struct FileTracker {
    /// Path (exactly as passed in) -> mtime captured by the latest `record_read`.
    reads: HashMap<PathBuf, std::time::SystemTime>,
}

impl Default for FileTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl FileTracker {
    /// Create a tracker with no recorded reads.
    pub fn new() -> Self {
        Self {
            reads: HashMap::default(),
        }
    }

    /// Record that `path` was read, capturing its current mtime.
    ///
    /// If the mtime cannot be obtained (missing file, unsupported platform) the
    /// Unix epoch is stored instead.
    pub fn record_read(&mut self, path: &Path) {
        let observed = match std::fs::metadata(path).and_then(|meta| meta.modified()) {
            Ok(time) => time,
            Err(_) => std::time::UNIX_EPOCH,
        };
        self.reads.insert(path.to_path_buf(), observed);
    }

    /// Whether `record_read` has been called for exactly this path.
    /// Paths are compared as given; nothing is canonicalized.
    pub fn was_read(&self, path: &Path) -> bool {
        self.reads.get(path).is_some()
    }

    /// Whether the file's current mtime differs from the one recorded at read time.
    ///
    /// Returns `false` when the path was never read or when its current mtime
    /// cannot be determined.
    pub fn is_stale(&self, path: &Path) -> bool {
        self.reads.get(path).is_some_and(|recorded| {
            matches!(
                std::fs::metadata(path).and_then(|meta| meta.modified()),
                Ok(current) if current != *recorded
            )
        })
    }
}
134
/// A handle to one line of a file as it looked when it was recorded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LineAnchor {
    /// `a` + hex(file_hash, 16 digits) + hex(line, 8 digits) + hex(content_hash, 16 digits).
    pub id: String,
    /// Line number as supplied by the caller (`start_line + offset`).
    pub line: usize,
    /// `stable_hash` of the line's text.
    pub content_hash: u64,
}

/// Mutex-guarded registry of line anchors, grouped by file path.
#[derive(Debug, Default)]
pub struct AnchorStore {
    files: std::sync::Mutex<HashMap<PathBuf, HashMap<String, LineAnchor>>>,
}

impl AnchorStore {
    /// Create an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create one anchor per entry of `lines`, numbered from `start_line`, and
    /// remember them under `path`.
    ///
    /// The anchors are always returned. If the store's mutex is poisoned they
    /// are not remembered, and `get` will not find them.
    pub fn record_lines(
        &self,
        path: &Path,
        file_hash: u64,
        start_line: usize,
        lines: &[&str],
    ) -> Vec<LineAnchor> {
        let make_anchor = |offset: usize, text: &&str| {
            let line = start_line + offset;
            let content_hash = stable_hash(text);
            LineAnchor {
                id: format!("a{file_hash:016x}{line:08x}{content_hash:016x}"),
                line,
                content_hash,
            }
        };

        let anchors: Vec<LineAnchor> = lines
            .iter()
            .enumerate()
            .map(|(offset, text)| make_anchor(offset, text))
            .collect();

        if let Ok(mut guard) = self.files.lock() {
            guard
                .entry(path.to_path_buf())
                .or_default()
                .extend(anchors.iter().map(|anchor| (anchor.id.clone(), anchor.clone())));
        }

        anchors
    }

    /// Look up a previously recorded anchor. Returns `None` if the path or id is
    /// unknown, or if the mutex is poisoned.
    pub fn get(&self, path: &Path, id: &str) -> Option<LineAnchor> {
        let guard = self.files.lock().ok()?;
        guard.get(path)?.get(id).cloned()
    }
}

/// Hash `value` with the standard library's `DefaultHasher`.
///
/// Hashers created by `DefaultHasher::new()` are documented to behave
/// identically, so results are repeatable for the same binary. The algorithm
/// itself is unspecified and may change between Rust releases, so these values
/// should not be persisted across toolchain upgrades.
pub fn stable_hash<T: Hash>(value: T) -> u64 {
    let mut state = DefaultHasher::new();
    value.hash(&mut state);
    state.finish()
}
200
/// Callback that receives the Lua capability policy and a mutable
/// `ToolRegistry`. It presumably registers Lua-defined tools into the
/// registry — confirm with whoever constructs it.
pub type LuaToolLoader = Arc<dyn Fn(&LuaCapabilityPolicy, &mut ToolRegistry) + Send + Sync>;

/// Everything a `Tool::execute` call receives from the agent.
///
/// The struct is `Clone`, and its shared state sits behind `Arc`s, so clones
/// observe the same caches, trackers, and flags.
#[derive(Clone)]
pub struct ToolContext {
    /// Working directory; passed to the write policy by `check_write_path`.
    pub cwd: PathBuf,
    /// Cancellation flag polled by `is_cancelled`.
    pub cancelled: Arc<std::sync::atomic::AtomicBool>,
    /// Channel carrying `ToolUpdate`s from the running tool
    /// (presumably streamed progress — confirm with the receiver).
    pub update_tx: tokio::sync::mpsc::Sender<ToolUpdate>,
    /// Channel for sending `AgentCommand`s back to the agent.
    pub command_tx: tokio::sync::mpsc::Sender<AgentCommand>,
    /// User-interface handle available to tools.
    pub ui: Arc<dyn UserInterface>,
    /// Shared file-content cache.
    pub file_cache: Arc<FileCache>,
    /// Shared snapshot/checkpoint state used for rollback.
    pub checkpoint_state: Arc<CheckpointState>,
    /// Which files were read, and their mtimes at read time.
    pub file_tracker: Arc<std::sync::Mutex<FileTracker>>,
    /// Line anchors recorded for files.
    pub anchor_store: Arc<AnchorStore>,
    /// Optional loader for Lua tools.
    pub lua_tool_loader: Option<LuaToolLoader>,
    /// Active agent mode.
    pub mode: AgentMode,
    /// NOTE(review): presumably the default line limit for the read tool — confirm.
    pub read_max_lines: usize,
    /// Per-turn mana review accumulator; its use is not visible in this module.
    pub turn_mana_review: Arc<std::sync::Mutex<TurnManaReviewAccumulator>>,
    /// Full agent configuration.
    pub config: Arc<crate::config::Config>,
    /// Run policy consulted by `check_write_path`.
    pub run_policy: crate::policy::RunPolicy,
    /// NOTE(review): provenance attached to this call; semantics not visible here.
    pub supporting_provenance: Vec<Provenance>,
}
234
/// Read-through cache of UTF-8 file contents keyed by path.
///
/// A cached body is reused only while the file's modification time *and*
/// length still match what was observed when it was cached. Edits made outside
/// the cache are therefore picked up on the next `read`. Callers that write a
/// file themselves can also drop its entry eagerly with `invalidate`.
pub struct FileCache {
    entries: std::sync::Mutex<std::collections::HashMap<PathBuf, FileCacheEntry>>,
}

/// A cached file body plus the metadata it was validated against.
struct FileCacheEntry {
    mtime: std::time::SystemTime,
    /// Size in bytes at caching time. This catches rewrites that land within the
    /// same mtime tick on filesystems with coarse timestamps.
    len: u64,
    content: String,
}

impl Default for FileCache {
    fn default() -> Self {
        Self::new()
    }
}

impl FileCache {
    /// Create an empty cache.
    pub fn new() -> Self {
        Self {
            entries: std::sync::Mutex::new(std::collections::HashMap::new()),
        }
    }

    /// Return the contents of `path`, served from the cache while it is still valid.
    ///
    /// When the platform cannot report a modification time, the cache is
    /// bypassed. Without an mtime there is no way to tell whether a cached copy
    /// is current.
    ///
    /// # Errors
    /// Propagates errors from `std::fs::metadata` and `std::fs::read_to_string`
    /// (missing file, permission denied, invalid UTF-8, ...).
    pub fn read(&self, path: &Path) -> std::io::Result<String> {
        let metadata = std::fs::metadata(path)?;
        let len = metadata.len();
        let Ok(mtime) = metadata.modified() else {
            // Previously a fixed `UNIX_EPOCH` stand-in was cached, which then
            // matched forever and served stale content after any edit.
            self.invalidate(path);
            return std::fs::read_to_string(path);
        };

        {
            let entries = self.lock_entries();
            if let Some(entry) = entries.get(path) {
                if entry.mtime == mtime && entry.len == len {
                    return Ok(entry.content.clone());
                }
            }
        }

        // If the file changes between `metadata` and this read, the entry stores
        // newer content under older metadata. The next `read` sees new metadata
        // and re-reads, so the cache self-heals.
        let content = std::fs::read_to_string(path)?;
        self.lock_entries().insert(
            path.to_path_buf(),
            FileCacheEntry {
                mtime,
                len,
                content: content.clone(),
            },
        );
        Ok(content)
    }

    /// Drop any cached contents for `path`.
    pub fn invalidate(&self, path: &Path) {
        self.lock_entries().remove(path);
    }

    /// Lock the entry map, recovering from poisoning instead of panicking.
    /// Entries are independent plain data, so a panic elsewhere cannot leave
    /// the map logically inconsistent.
    fn lock_entries(
        &self,
    ) -> std::sync::MutexGuard<'_, std::collections::HashMap<PathBuf, FileCacheEntry>> {
        self.entries
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
294
/// Per-path copies of file contents taken before the first edit. They are
/// used to roll a file back to how it looked at that first snapshot.
pub struct FileHistory {
    /// Path (as passed in; not canonicalized) -> content at the first snapshot.
    originals: std::sync::Mutex<HashMap<PathBuf, String>>,
}
303
/// Metadata for one checkpoint created by `CheckpointState::snapshot_paths`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointRecord {
    /// Random (v4) UUID string identifying the checkpoint.
    pub id: String,
    /// Optional caller-supplied label.
    pub label: Option<String>,
    /// Creation time from `imp_llm::now()`.
    /// NOTE(review): units not visible here; presumably a Unix timestamp — confirm.
    pub created_at: u64,
    /// Distinct requested paths that had a stored original snapshot, in request order.
    pub files: Vec<PathBuf>,
}
311
312pub struct CheckpointState {
318 history: FileHistory,
319 records: std::sync::Mutex<Vec<CheckpointRecord>>,
320}
321
322impl Default for CheckpointState {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
328impl CheckpointState {
329 pub fn new() -> Self {
330 Self {
331 history: FileHistory::new(),
332 records: std::sync::Mutex::new(Vec::new()),
333 }
334 }
335
336 pub fn snapshot_paths(
338 &self,
339 paths: &[PathBuf],
340 label: Option<String>,
341 ) -> std::io::Result<Option<CheckpointRecord>> {
342 let mut unique = Vec::new();
343 for path in paths {
344 if !unique.iter().any(|existing: &PathBuf| existing == path) {
345 unique.push(path.clone());
346 }
347 }
348
349 let mut captured = Vec::new();
350 for path in unique {
351 self.history.snapshot_before_edit(&path)?;
352 if self.history.original(&path).is_some() {
353 captured.push(path);
354 }
355 }
356
357 if captured.is_empty() {
358 return Ok(None);
359 }
360
361 let record = CheckpointRecord {
362 id: uuid::Uuid::new_v4().to_string(),
363 label,
364 created_at: imp_llm::now(),
365 files: captured,
366 };
367 self.records.lock().unwrap().push(record.clone());
368 Ok(Some(record))
369 }
370
371 pub fn checkpoints(&self) -> Vec<CheckpointRecord> {
372 self.records.lock().unwrap().clone()
373 }
374
375 pub fn checkpoint(&self, id: &str) -> Option<CheckpointRecord> {
376 self.records
377 .lock()
378 .unwrap()
379 .iter()
380 .find(|record| record.id == id)
381 .cloned()
382 }
383
384 pub fn restore_checkpoint(&self, id: &str) -> std::io::Result<Vec<PathBuf>> {
385 let Some(record) = self.checkpoint(id) else {
386 return Ok(Vec::new());
387 };
388
389 let mut restored = Vec::new();
390 for path in &record.files {
391 self.history.rollback(path)?;
392 restored.push(path.clone());
393 }
394 Ok(restored)
395 }
396
397 pub fn rollback(&self, path: &Path) -> std::io::Result<()> {
398 self.history.rollback(path)
399 }
400
401 pub fn tracked_files(&self) -> Vec<PathBuf> {
402 self.history.tracked_files()
403 }
404
405 pub fn original(&self, path: &Path) -> Option<String> {
406 self.history.original(path)
407 }
408}
409
410impl Default for FileHistory {
411 fn default() -> Self {
412 Self::new()
413 }
414}
415
416impl FileHistory {
417 pub fn new() -> Self {
418 Self {
419 originals: std::sync::Mutex::new(HashMap::new()),
420 }
421 }
422
423 pub fn snapshot_before_edit(&self, path: &Path) -> std::io::Result<()> {
426 let canonical = path.to_path_buf();
427
428 let mut originals = self.originals.lock().unwrap();
429 if originals.contains_key(&canonical) {
430 return Ok(()); }
432 if canonical.exists() {
433 let content = std::fs::read_to_string(&canonical)?;
434 originals.insert(canonical, content);
435 }
436 Ok(())
437 }
438
439 pub fn original(&self, path: &Path) -> Option<String> {
441 self.originals.lock().unwrap().get(path).cloned()
442 }
443
444 pub fn rollback(&self, path: &Path) -> std::io::Result<()> {
446 let originals = self.originals.lock().unwrap();
447 if let Some(content) = originals.get(path) {
448 std::fs::write(path, content)?;
449 }
450 Ok(())
451 }
452
453 pub fn tracked_files(&self) -> Vec<PathBuf> {
455 self.originals.lock().unwrap().keys().cloned().collect()
456 }
457}
458
459impl ToolContext {
460 pub fn is_cancelled(&self) -> bool {
461 self.cancelled.load(std::sync::atomic::Ordering::Relaxed)
462 }
463
464 pub fn check_write_path(&self, path: &Path) -> std::result::Result<(), String> {
465 match self.run_policy.check_write_path(&self.cwd, path) {
466 crate::policy::WritePolicyDecision::Allowed => Ok(()),
467 crate::policy::WritePolicyDecision::Denied(reason) => Err(reason),
468 }
469 }
470}
471
472pub struct ToolOutput {
474 pub content: Vec<ContentBlock>,
475 pub details: serde_json::Value,
476 pub is_error: bool,
477}
478
479impl ToolOutput {
480 pub fn text(text: impl Into<String>) -> Self {
481 Self {
482 content: vec![ContentBlock::Text { text: text.into() }],
483 details: serde_json::Value::Null,
484 is_error: false,
485 }
486 }
487
488 pub fn error(text: impl Into<String>) -> Self {
489 Self {
490 content: vec![ContentBlock::Text { text: text.into() }],
491 details: serde_json::Value::Null,
492 is_error: true,
493 }
494 }
495
496 pub fn text_content(&self) -> Option<&str> {
498 self.content.iter().find_map(|b| match b {
499 ContentBlock::Text { text } => Some(text.as_str()),
500 _ => None,
501 })
502 }
503
504 pub fn into_tool_result(self, call_id: &str, tool_name: &str) -> ToolResultMessage {
505 ToolResultMessage {
506 tool_call_id: call_id.to_string(),
507 tool_name: tool_name.to_string(),
508 content: self.content,
509 is_error: self.is_error,
510 details: self.details,
511 timestamp: imp_llm::now(),
512 }
513 }
514}
515
/// Item type of `ToolContext::update_tx`.
///
/// It has the same shape as `ToolOutput` minus the error flag.
/// NOTE(review): presumably used for intermediate progress from a running
/// tool — confirm with the receiver.
pub struct ToolUpdate {
    pub content: Vec<ContentBlock>,
    pub details: serde_json::Value,
}
521
522pub struct ToolRegistry {
524 tools: HashMap<String, Arc<dyn Tool>>,
525 aliases: HashMap<String, String>,
526}
527
528impl ToolRegistry {
529 pub fn new() -> Self {
530 Self {
531 tools: HashMap::new(),
532 aliases: HashMap::new(),
533 }
534 }
535
536 pub fn register(&mut self, tool: Arc<dyn Tool>) {
538 self.tools.insert(tool.name().to_string(), tool);
539 }
540
541 pub fn register_alias(&mut self, alias: impl Into<String>, canonical: impl Into<String>) {
546 self.aliases.insert(alias.into(), canonical.into());
547 }
548
549 pub fn extend(&mut self, other: ToolRegistry) {
550 for tool in other.tools.into_values() {
551 self.register(tool);
552 }
553 for (alias, canonical) in other.aliases {
554 self.register_alias(alias, canonical);
555 }
556 }
557
558 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
560 if let Some(tool) = self.tools.get(name) {
561 return Some(tool);
562 }
563
564 self.aliases
565 .get(name)
566 .and_then(|canonical| self.tools.get(canonical))
567 }
568
569 pub fn tools_map(&self) -> HashMap<String, Arc<dyn Tool>> {
571 let mut map = self.tools.clone();
572 for (alias, canonical) in &self.aliases {
573 if let Some(tool) = self.tools.get(canonical) {
574 map.insert(alias.clone(), Arc::clone(tool));
575 }
576 }
577 map
578 }
579
580 pub fn definitions(&self) -> Vec<ToolDefinition> {
584 let mut defs: Vec<_> = self
585 .tools
586 .values()
587 .map(|t| ToolDefinition {
588 name: t.name().to_string(),
589 description: t.description().to_string(),
590 parameters: t.parameters(),
591 })
592 .collect();
593 defs.sort_by(|a, b| a.name.cmp(&b.name));
594 defs
595 }
596
597 pub fn readonly_definitions(&self) -> Vec<ToolDefinition> {
599 let mut defs: Vec<_> = self
600 .tools
601 .values()
602 .filter(|t| t.is_readonly())
603 .map(|t| ToolDefinition {
604 name: t.name().to_string(),
605 description: t.description().to_string(),
606 parameters: t.parameters(),
607 })
608 .collect();
609 defs.sort_by(|a, b| a.name.cmp(&b.name));
610 defs
611 }
612
613 pub fn names(&self) -> Vec<String> {
615 self.tools.keys().cloned().collect()
616 }
617
618 pub fn retain<F>(&mut self, predicate: F)
623 where
624 F: Fn(&str) -> bool,
625 {
626 self.tools.retain(|name, _| predicate(name));
627 self.aliases
628 .retain(|_, canonical| self.tools.contains_key(canonical));
629 }
630
631 pub fn definitions_for_mode(
636 &self,
637 mode: &crate::config::AgentMode,
638 ) -> Vec<imp_llm::provider::ToolDefinition> {
639 let mut defs: Vec<_> = self
640 .tools
641 .values()
642 .filter(|t| mode.allows_tool(t.name()))
643 .map(|t| imp_llm::provider::ToolDefinition {
644 name: t.name().to_string(),
645 description: t.description().to_string(),
646 parameters: t.parameters(),
647 })
648 .collect();
649 defs.sort_by(|a, b| a.name.cmp(&b.name));
650 defs
651 }
652
653 pub fn policy_metadata(&self, name: &str) -> Option<ToolMetadata> {
655 self.get(name).map(|tool| tool.policy_metadata())
656 }
657
658 pub fn len(&self) -> usize {
660 self.tools.len()
661 }
662
663 pub fn is_empty(&self) -> bool {
664 self.tools.is_empty()
665 }
666}
667
668impl Default for ToolRegistry {
669 fn default() -> Self {
670 Self::new()
671 }
672}
673
/// Outcome of `truncate_head` / `truncate_tail`.
pub struct TruncationResult {
    /// The kept text. When `truncated` is false this is the input verbatim.
    /// Otherwise every kept line is terminated with `\n`.
    pub content: String,
    /// `true` when the input exceeded a limit and `content` was rebuilt from
    /// selected lines.
    pub truncated: bool,
    /// Number of lines in `content`.
    pub output_lines: usize,
    /// Number of lines in the input, as counted by `str::lines`.
    pub total_lines: usize,
    /// Byte length of `content`.
    pub output_bytes: usize,
    /// Byte length of the input.
    pub total_bytes: usize,
    /// When truncated, path of a temp file holding the full input. `None` when
    /// nothing was truncated or the best-effort write failed.
    pub temp_file: Option<PathBuf>,
}
685
/// Truncate `line` to at most `max_bytes` bytes, appending `…` when it was cut.
///
/// The cut point moves back to the nearest char boundary, so the result is
/// always valid UTF-8. The ellipsis itself is not counted against `max_bytes`.
pub fn truncate_line(line: &str, max_bytes: usize) -> String {
    if line.len() > max_bytes {
        // Index 0 is always a char boundary, so the search cannot come up empty.
        let cut = (0..=max_bytes)
            .rev()
            .find(|&idx| line.is_char_boundary(idx))
            .unwrap_or(0);
        let mut shortened = String::with_capacity(cut + '…'.len_utf8());
        shortened.push_str(&line[..cut]);
        shortened.push('…');
        shortened
    } else {
        line.to_owned()
    }
}
697
698fn write_temp_file(content: &str) -> Option<PathBuf> {
700 let dir = std::env::temp_dir().join("imp-tools");
701 std::fs::create_dir_all(&dir).ok()?;
702 let name = format!("truncated-{}.txt", uuid::Uuid::new_v4());
703 let path = dir.join(name);
704 std::fs::write(&path, content).ok()?;
705 Some(path)
706}
707
708pub fn truncate_head(input: &str, max_lines: usize, max_bytes: usize) -> TruncationResult {
711 let lines: Vec<&str> = input.lines().collect();
712 let total_lines = lines.len();
713 let total_bytes = input.len();
714
715 if total_lines <= max_lines && total_bytes <= max_bytes {
716 return TruncationResult {
717 content: input.to_string(),
718 truncated: false,
719 output_lines: total_lines,
720 total_lines,
721 output_bytes: total_bytes,
722 total_bytes,
723 temp_file: None,
724 };
725 }
726
727 let mut result = String::new();
728 let mut byte_count = 0;
729 let mut line_count = 0;
730
731 for line in &lines {
732 let line_with_newline = format!("{line}\n");
733 if line_count >= max_lines || byte_count + line_with_newline.len() > max_bytes {
734 break;
735 }
736 result.push_str(&line_with_newline);
737 byte_count += line_with_newline.len();
738 line_count += 1;
739 }
740
741 let temp_file = write_temp_file(input);
742
743 TruncationResult {
744 content: result,
745 truncated: true,
746 output_lines: line_count,
747 total_lines,
748 output_bytes: byte_count,
749 total_bytes,
750 temp_file,
751 }
752}
753
754pub fn truncate_tail(input: &str, max_lines: usize, max_bytes: usize) -> TruncationResult {
757 let lines: Vec<&str> = input.lines().collect();
758 let total_lines = lines.len();
759 let total_bytes = input.len();
760
761 if total_lines <= max_lines && total_bytes <= max_bytes {
762 return TruncationResult {
763 content: input.to_string(),
764 truncated: false,
765 output_lines: total_lines,
766 total_lines,
767 output_bytes: total_bytes,
768 total_bytes,
769 temp_file: None,
770 };
771 }
772
773 let start = total_lines.saturating_sub(max_lines);
775 let mut actual_start = start;
776 let mut remaining_bytes = max_bytes;
777
778 for (i, line) in lines[start..].iter().enumerate() {
779 let line_with_newline = format!("{line}\n");
780 if line_with_newline.len() > remaining_bytes {
781 actual_start = start + i + 1;
782 remaining_bytes = max_bytes;
783 for line2 in &lines[actual_start..] {
785 let l = format!("{line2}\n");
786 if l.len() > remaining_bytes {
787 break;
788 }
789 remaining_bytes -= l.len();
790 }
791 break;
792 }
793 remaining_bytes -= line_with_newline.len();
794 }
795
796 let mut result = String::new();
797 for line in &lines[actual_start..] {
798 result.push_str(&format!("{line}\n"));
799 }
800
801 let output_lines = total_lines - actual_start;
802 let output_bytes = result.len();
803 let temp_file = write_temp_file(input);
804
805 TruncationResult {
806 content: result,
807 truncated: true,
808 output_lines,
809 total_lines,
810 output_bytes,
811 total_bytes,
812 temp_file,
813 }
814}
815
/// Tolerant text matching: trailing whitespace and a few typographic
/// characters are ignored when comparing lines.
pub(crate) mod fuzzy {
    /// Normalize text for comparison.
    ///
    /// Trims trailing whitespace on every line, maps common typographic
    /// characters to ASCII, and joins lines with `\n`. As a result, `\r\n`
    /// input compares equal to `\n` input.
    pub fn normalize_for_matching(text: &str) -> String {
        text.lines()
            .map(|line| normalize_unicode(line.trim_end()))
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Map characters to ASCII as follows; all other characters pass through:
    /// - curly single quotes (U+2018/2019) become `'`;
    /// - curly double quotes (U+201C/201D) become `"`;
    /// - en/em dashes (U+2013/2014) become `-`;
    /// - NBSP, en, em, and thin spaces become a plain space.
    fn normalize_unicode(s: &str) -> String {
        s.chars()
            .map(|c| match c {
                '\u{2018}' | '\u{2019}' => '\'',
                '\u{201C}' | '\u{201D}' => '"',
                '\u{2013}' | '\u{2014}' => '-',
                '\u{00A0}' | '\u{2003}' | '\u{2002}' | '\u{2009}' => ' ',
                other => other,
            })
            .collect()
    }

    /// Byte range `[start, end)` of a match within the searched content.
    pub struct FuzzyMatch {
        pub start: usize,
        pub end: usize,
    }

    /// Find the first run of lines in `content` that equals `old_text` after
    /// normalizing both sides.
    ///
    /// The range starts at the first matched line and ends after the last
    /// matched line's text, excluding its line terminator. Offsets are exact
    /// byte positions in `content` for both `\n` and `\r\n` line endings, so
    /// `&content[m.start..m.end]` is always a valid slice. Returns `None` if
    /// `old_text` has no lines, is longer than `content`, or does not match.
    pub fn fuzzy_find(content: &str, old_text: &str) -> Option<FuzzyMatch> {
        let content_lines: Vec<&str> = content.lines().collect();
        let search_norm = normalize_for_matching(old_text);
        let search_lines: Vec<&str> = search_norm.lines().collect();

        if search_lines.is_empty() || search_lines.len() > content_lines.len() {
            return None;
        }

        let content_norm_lines: Vec<String> = content_lines
            .iter()
            .map(|l| normalize_unicode(l.trim_end()))
            .collect();

        // Byte offset at which each line starts. `str::lines` is
        // `split_inclusive('\n')` with the terminator stripped, so both yield
        // the same number of items. Summing the *unstripped* lengths gives
        // exact offsets; assuming a 1-byte terminator (`len + 1`) drifted by
        // one byte per CRLF line.
        let mut line_starts = Vec::with_capacity(content_lines.len());
        let mut offset = 0;
        for raw_line in content.split_inclusive('\n') {
            line_starts.push(offset);
            offset += raw_line.len();
        }

        let window_size = search_lines.len();
        let start_line = content_norm_lines.windows(window_size).position(|window| {
            window
                .iter()
                .zip(search_lines.iter())
                .all(|(content_line, search_line)| content_line == search_line)
        })?;
        let end_line = start_line + window_size - 1;

        Some(FuzzyMatch {
            start: line_starts[start_line],
            end: line_starts[end_line] + content_lines[end_line].len(),
        })
    }
}
899
/// Levenshtein edit distance between `a` and `b`, counted in `char`s.
///
/// Insertions, deletions, and substitutions each cost 1. A transposition costs
/// 2, since it is not a primitive operation. Uses a single DP row of
/// `b.chars().count() + 1` entries.
pub fn levenshtein(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();
    // row[j] = distance between the prefix of `a` processed so far and target[..j].
    let mut row: Vec<usize> = (0..=target.len()).collect();

    for (i, source_char) in a.chars().enumerate() {
        // `diagonal` holds the previous row's value at column j (prev[j-1]
        // from the inner update's point of view).
        let mut diagonal = row[0];
        row[0] = i + 1;
        for (j, &target_char) in target.iter().enumerate() {
            let above = row[j + 1];
            let substitution = diagonal + usize::from(source_char != target_char);
            row[j + 1] = (above + 1).min(row[j] + 1).min(substitution);
            diagonal = above;
        }
    }

    row[target.len()]
}
929
930pub fn suggest_similar_files(cwd: &Path, target: &str) -> Vec<String> {
936 let target_name = Path::new(target)
937 .file_name()
938 .and_then(|n: &std::ffi::OsStr| n.to_str())
939 .unwrap_or(target);
940
941 let mut candidates: Vec<(usize, String)> = Vec::new();
942
943 const SKIP_DIRS: &[&str] = &[
945 "target",
946 "node_modules",
947 ".git",
948 "vendor",
949 "dist",
950 "build",
951 "__pycache__",
952 ".mypy_cache",
953 ".tox",
954 ".venv",
955 ];
956
957 let walker = walkdir::WalkDir::new(cwd)
958 .max_depth(3)
959 .follow_links(false)
960 .into_iter()
961 .filter_entry(|e| {
962 if e.file_type().is_dir() {
963 if let Some(name) = e.file_name().to_str() {
964 return !SKIP_DIRS.contains(&name);
965 }
966 }
967 true
968 })
969 .filter_map(|e| e.ok());
970
971 for entry in walker {
972 if entry.file_type().is_file() {
973 if let Some(name) = entry.file_name().to_str() {
974 let dist = levenshtein(target_name, name);
975 if dist <= 3 {
976 let rel = entry
977 .path()
978 .strip_prefix(cwd)
979 .unwrap_or(entry.path())
980 .display()
981 .to_string();
982 candidates.push((dist, rel));
983 }
984 }
985 }
986 }
987
988 candidates.sort_by_key(|(d, _)| *d);
989 candidates.truncate(3);
990 candidates.into_iter().map(|(_, p)| p).collect()
991}
992
993pub fn generate_diff(file_path: &str, old: &str, new: &str) -> String {
997 use similar::TextDiff;
998
999 let diff = TextDiff::from_lines(old, new);
1000 let mut output = String::new();
1001 output.push_str(&format!("--- {file_path}\n"));
1002 output.push_str(&format!("+++ {file_path}\n"));
1003
1004 for hunk in diff.unified_diff().context_radius(3).iter_hunks() {
1005 output.push_str(&format!("{hunk}"));
1006 }
1007
1008 output
1009}
1010
1011pub fn validate_tool_args(schema: &serde_json::Value, args: &serde_json::Value) -> Result<()> {
1019 use jsonschema::Validator;
1020
1021 let validator = Validator::new(schema)
1022 .map_err(|e| crate::error::Error::Tool(format!("Invalid tool schema: {e}")))?;
1023
1024 let errors: Vec<String> = validator
1025 .iter_errors(args)
1026 .map(|e| format!("{e}"))
1027 .collect();
1028
1029 if errors.is_empty() {
1030 Ok(())
1031 } else {
1032 Err(crate::error::Error::Tool(format!(
1033 "Tool argument validation failed:\n{}",
1034 errors.join("\n")
1035 )))
1036 }
1037}
1038
#[cfg(test)]
mod tests {
    use super::*;

    // ---- levenshtein -----------------------------------------------------

    #[test]
    fn suggest_similar_levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    #[test]
    fn suggest_similar_levenshtein_one_substitution() {
        assert_eq!(levenshtein("auth", "aath"), 1);
    }

    #[test]
    fn suggest_similar_levenshtein_one_insertion() {
        assert_eq!(levenshtein("helo", "hello"), 1);
    }

    #[test]
    fn suggest_similar_levenshtein_one_deletion() {
        assert_eq!(levenshtein("hello", "helo"), 1);
    }

    #[test]
    fn suggest_similar_levenshtein_empty_strings() {
        assert_eq!(levenshtein("", ""), 0);
        assert_eq!(levenshtein("abc", ""), 3);
        assert_eq!(levenshtein("", "abc"), 3);
    }

    #[test]
    fn suggest_similar_levenshtein_completely_different() {
        assert_eq!(levenshtein("abc", "xyz"), 3);
    }

    #[test]
    fn suggest_similar_levenshtein_transposition() {
        // Plain Levenshtein has no transposition operation: a swap costs two substitutions.
        assert_eq!(levenshtein("atuh", "auth"), 2);
    }

    // ---- suggest_similar_files -------------------------------------------

    #[test]
    fn suggest_similar_finds_close_match() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("middleware.rs"), "").unwrap();
        std::fs::write(dir.path().join("unrelated.rs"), "").unwrap();

        let suggestions = suggest_similar_files(dir.path(), "middlewar.rs");
        assert!(
            suggestions.iter().any(|s| s.contains("middleware.rs")),
            "expected middleware.rs in suggestions, got: {suggestions:?}"
        );
    }

    #[test]
    fn suggest_similar_returns_at_most_three() {
        let dir = tempfile::tempdir().unwrap();
        // Five files, each a single substitution away from the target.
        for name in &["mod.rs", "rod.rs", "cod.rs", "nod.rs", "pod.rs"] {
            std::fs::write(dir.path().join(name), "").unwrap();
        }

        let suggestions = suggest_similar_files(dir.path(), "xod.rs");
        // All five qualify, so the cap must be hit exactly, not merely respected.
        assert_eq!(suggestions.len(), 3, "got: {suggestions:?}");
    }

    #[test]
    fn suggest_similar_nothing_close_returns_empty() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("completely_different.rs"), "").unwrap();

        let suggestions = suggest_similar_files(dir.path(), "a.rs");
        assert!(
            suggestions.is_empty(),
            "expected no suggestions, got: {suggestions:?}"
        );
    }

    #[test]
    fn suggest_similar_ranks_closer_matches_first() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("auth.rs"), "").unwrap();
        std::fs::write(dir.path().join("autho.rs"), "").unwrap();

        // "atuh.rs" is distance 2 from "auth.rs" and distance 3 from "autho.rs".
        let suggestions = suggest_similar_files(dir.path(), "atuh.rs");
        assert_eq!(
            suggestions.first().map(String::as_str),
            Some("auth.rs"),
            "expected auth.rs first, got: {suggestions:?}"
        );
    }

    // ---- validate_tool_args ----------------------------------------------

    fn simple_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" },
                "count": { "type": "integer" }
            },
            "required": ["path"]
        })
    }

    #[test]
    fn validate_tool_args_valid_passes() {
        let schema = simple_schema();
        let args = serde_json::json!({ "path": "/tmp/foo.txt" });
        assert!(validate_tool_args(&schema, &args).is_ok());
    }

    #[test]
    fn validate_tool_args_valid_with_optional_passes() {
        let schema = simple_schema();
        let args = serde_json::json!({ "path": "/tmp/foo.txt", "count": 5 });
        assert!(validate_tool_args(&schema, &args).is_ok());
    }

    #[test]
    fn validate_tool_args_missing_required_returns_error() {
        let schema = simple_schema();
        let args = serde_json::json!({ "count": 5 });
        let result = validate_tool_args(&schema, &args);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("path") || msg.contains("required"),
            "expected error mentioning 'path' or 'required', got: {msg}"
        );
    }

    #[test]
    fn validate_tool_args_wrong_type_returns_error() {
        let schema = simple_schema();
        let args = serde_json::json!({ "path": "/tmp/foo.txt", "count": "not-a-number" });
        let result = validate_tool_args(&schema, &args);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("integer") || msg.contains("type"),
            "expected type error, got: {msg}"
        );
    }

    #[test]
    fn validate_tool_args_extra_fields_allowed() {
        let schema = simple_schema();
        // Models often add unrequested fields; the schema does not forbid them.
        let args = serde_json::json!({
            "path": "/tmp/foo.txt",
            "llm_added_extra": "some value",
            "another_unknown": 42
        });
        assert!(
            validate_tool_args(&schema, &args).is_ok(),
            "extra/unknown fields should be allowed"
        );
    }

    // ---- FileTracker -----------------------------------------------------

    #[test]
    fn file_track_was_read_false_for_unread_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let tracker = FileTracker::new();
        assert!(!tracker.was_read(&file), "unread file should return false");
    }

    #[test]
    fn file_track_was_read_true_after_recording() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let mut tracker = FileTracker::new();
        tracker.record_read(&file);
        assert!(
            tracker.was_read(&file),
            "file should be marked as read after recording"
        );
    }

    #[test]
    fn file_track_is_stale_false_for_unread_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let tracker = FileTracker::new();
        assert!(!tracker.is_stale(&file));
    }

    #[test]
    fn file_track_is_stale_false_immediately_after_read() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let mut tracker = FileTracker::new();
        tracker.record_read(&file);
        assert!(!tracker.is_stale(&file));
    }

    #[test]
    fn file_track_is_stale_detects_external_modification() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "original content").unwrap();

        let mut tracker = FileTracker::new();
        tracker.record_read(&file);

        // Push the mtime forward explicitly instead of relying on timestamp
        // granularity. Failures here are surfaced rather than swallowed, so
        // they are not misreported as a staleness bug.
        let future = std::time::SystemTime::now() + std::time::Duration::from_secs(2);
        let handle = std::fs::OpenOptions::new()
            .write(true)
            .open(&file)
            .expect("open test file for writing");
        handle.set_modified(future).expect("set mtime on test file");

        assert!(
            tracker.is_stale(&file),
            "file with advanced mtime should be stale"
        );
    }

    // ---- FileHistory -----------------------------------------------------

    #[test]
    fn file_history_snapshot_stores_original() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        assert_eq!(history.original(&file).unwrap(), "fn main() {}");
    }

    #[test]
    fn file_history_second_snapshot_is_noop() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.rs");
        std::fs::write(&file, "original").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        std::fs::write(&file, "modified").unwrap();
        // The first snapshot must survive later snapshot attempts.
        history.snapshot_before_edit(&file).unwrap();

        assert_eq!(history.original(&file).unwrap(), "original");
    }

    #[test]
    fn file_history_rollback_restores_original() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.rs");
        std::fs::write(&file, "original content").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        std::fs::write(&file, "agent wrote this").unwrap();
        history.rollback(&file).unwrap();

        assert_eq!(std::fs::read_to_string(&file).unwrap(), "original content");
    }

    #[test]
    fn file_history_skips_nonexistent_files() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("does_not_exist.rs");

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        assert!(history.original(&file).is_none());
    }

    #[test]
    fn file_history_tracked_files_lists_all() {
        let dir = tempfile::tempdir().unwrap();
        let f1 = dir.path().join("a.rs");
        let f2 = dir.path().join("b.rs");
        std::fs::write(&f1, "a").unwrap();
        std::fs::write(&f2, "b").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&f1).unwrap();
        history.snapshot_before_edit(&f2).unwrap();

        let tracked = history.tracked_files();
        assert_eq!(tracked.len(), 2);
        assert!(tracked.contains(&f1));
        assert!(tracked.contains(&f2));
    }

    // ---- helpers without dedicated coverage ------------------------------

    #[test]
    fn truncate_line_respects_char_boundaries() {
        assert_eq!(truncate_line("short", 10), "short");
        assert_eq!(truncate_line("abcdef", 3), "abc…");
        // 'é' spans bytes 1..3, so a 2-byte budget must back off to byte 1.
        assert_eq!(truncate_line("héllo", 2), "h…");
    }

    #[test]
    fn anchor_store_round_trips_recorded_lines() {
        let store = AnchorStore::new();
        let path = std::path::Path::new("/virtual/file.rs");
        let anchors = store.record_lines(path, 7, 1, &["first", "second"]);
        assert_eq!(anchors.len(), 2);
        assert_eq!(anchors[1].line, 2);
        assert_eq!(store.get(path, &anchors[0].id), Some(anchors[0].clone()));
        assert_eq!(store.get(path, "missing"), None);
    }

    #[test]
    fn fuzzy_find_tolerates_trailing_whitespace_and_smart_quotes() {
        let content = "let a = 1;   \nlet s = \u{201C}hi\u{201D};\n";
        let m = fuzzy::fuzzy_find(content, "let a = 1;\nlet s = \"hi\";").unwrap();
        assert_eq!(m.start, 0);
        assert_eq!(
            &content[m.start..m.end],
            "let a = 1;   \nlet s = \u{201C}hi\u{201D};"
        );
    }
}