1pub mod ask;
2pub mod bash;
3pub mod edit;
4pub mod extend;
5pub mod git;
6pub mod imp;
7pub mod lua;
8pub mod mana;
9pub mod memory;
10pub mod multi_edit;
11pub mod query;
12pub mod read;
13pub mod scan;
14pub mod session_search;
15pub mod shell;
16pub mod web;
17pub mod write;
18
19use std::collections::hash_map::DefaultHasher;
20use std::collections::HashMap;
21use std::hash::{Hash, Hasher};
22use std::path::{Path, PathBuf};
23use std::sync::Arc;
24
25use async_trait::async_trait;
26use imp_llm::provider::ToolDefinition;
27use imp_llm::{ContentBlock, ToolResultMessage};
28
29use crate::agent::AgentCommand;
30use crate::config::AgentMode;
31use crate::config::LuaCapabilityPolicy;
32use crate::error::Result;
33use crate::mana_review::TurnManaReviewAccumulator;
34use crate::ui::UserInterface;
35
/// Resolves a user-supplied path string against `cwd`.
///
/// `~` and a leading `~/` are expanded using `$HOME` when that variable is set and
/// valid Unicode. Otherwise the string is treated literally: absolute paths are
/// returned as-is and relative ones are joined onto `cwd`.
pub fn resolve_path(cwd: &Path, raw: &str) -> PathBuf {
    let home_expanded = match raw {
        "~" => std::env::var("HOME").ok().map(PathBuf::from),
        _ => raw.strip_prefix("~/").and_then(|rest| {
            std::env::var("HOME")
                .ok()
                .map(|home| PathBuf::from(home).join(rest))
        }),
    };
    if let Some(expanded) = home_expanded {
        return expanded;
    }

    let candidate = Path::new(raw);
    match candidate.is_absolute() {
        true => candidate.to_path_buf(),
        false => cwd.join(candidate),
    }
}
54
/// A tool the agent can invoke.
///
/// Tools are stored as `Arc<dyn Tool>` in [`ToolRegistry`], which is why the trait
/// requires `Send + Sync`.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Canonical name: the registry key and the value copied into `ToolDefinition::name`.
    fn name(&self) -> &str;

    /// Human-facing label. NOTE(review): not read anywhere in this file; presumably used
    /// for UI display — confirm.
    fn label(&self) -> &str;

    /// Description copied into `ToolDefinition::description`.
    fn description(&self) -> &str;

    /// Parameter schema copied into `ToolDefinition::parameters`. Presumably a JSON
    /// Schema object (compare `validate_tool_args`) — confirm.
    fn parameters(&self) -> serde_json::Value;

    /// Whether the tool is read-only. `ToolRegistry::readonly_definitions` keeps only
    /// tools that return `true`.
    fn is_readonly(&self) -> bool;

    /// Runs the tool for the call `call_id` with the JSON `params` and the per-call
    /// [`ToolContext`].
    async fn execute(
        &self,
        call_id: &str,
        params: serde_json::Value,
        ctx: ToolContext,
    ) -> Result<ToolOutput>;
}
81
/// Remembers which files were read and the modification time observed at that read.
pub struct FileTracker {
    reads: HashMap<PathBuf, std::time::SystemTime>,
}

impl Default for FileTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl FileTracker {
    /// Creates a tracker with no recorded reads.
    pub fn new() -> Self {
        let reads = HashMap::new();
        Self { reads }
    }

    /// Records that `path` was read, storing its current mtime.
    ///
    /// If the mtime cannot be obtained, `UNIX_EPOCH` is stored instead.
    pub fn record_read(&mut self, path: &Path) {
        let observed = match std::fs::metadata(path).and_then(|meta| meta.modified()) {
            Ok(mtime) => mtime,
            Err(_) => std::time::UNIX_EPOCH,
        };
        self.reads.insert(path.to_path_buf(), observed);
    }

    /// Returns `true` if [`FileTracker::record_read`] was called for exactly this path.
    pub fn was_read(&self, path: &Path) -> bool {
        self.reads.contains_key(path)
    }

    /// Returns `true` when `path` was read before and its mtime now differs from the
    /// recorded one.
    ///
    /// Unread paths, and paths whose metadata cannot be read (e.g. deleted), report
    /// `false`.
    pub fn is_stale(&self, path: &Path) -> bool {
        self.reads.get(path).is_some_and(|&recorded| {
            std::fs::metadata(path)
                .and_then(|meta| meta.modified())
                .is_ok_and(|current| current != recorded)
        })
    }
}
128
/// A handle for one line of a file as it looked when it was recorded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LineAnchor {
    /// `a` + file hash (16 hex digits) + line number (at least 8 hex digits) + content
    /// hash (16 hex digits).
    pub id: String,
    /// Line number as supplied by the caller (`start_line + offset`).
    pub line: usize,
    /// [`stable_hash`] of the line's text.
    pub content_hash: u64,
}

/// Per-file map of anchor id to [`LineAnchor`], guarded by a mutex.
#[derive(Debug, Default)]
pub struct AnchorStore {
    files: std::sync::Mutex<HashMap<PathBuf, HashMap<String, LineAnchor>>>,
}

impl AnchorStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds one anchor per entry in `lines`, numbering them from `start_line`, and
    /// stores them under `path`.
    ///
    /// Anchors from earlier calls are kept; an identical id is overwritten. If the lock
    /// is poisoned the anchors are still returned but not stored.
    pub fn record_lines(
        &self,
        path: &Path,
        file_hash: u64,
        start_line: usize,
        lines: &[&str],
    ) -> Vec<LineAnchor> {
        let mut anchors = Vec::with_capacity(lines.len());
        for (offset, text) in lines.iter().enumerate() {
            let line = start_line + offset;
            let content_hash = stable_hash(text);
            let id = format!("a{file_hash:016x}{line:08x}{content_hash:016x}");
            anchors.push(LineAnchor {
                id,
                line,
                content_hash,
            });
        }

        if let Ok(mut files) = self.files.lock() {
            files
                .entry(path.to_path_buf())
                .or_default()
                .extend(anchors.iter().map(|anchor| (anchor.id.clone(), anchor.clone())));
        }

        anchors
    }

    /// Looks up a previously recorded anchor by `path` and `id`.
    ///
    /// Returns `None` when it is unknown or the lock is poisoned.
    pub fn get(&self, path: &Path, id: &str) -> Option<LineAnchor> {
        let files = self.files.lock().ok()?;
        files.get(path)?.get(id).cloned()
    }
}

/// Hashes `value` with a freshly constructed `DefaultHasher`.
///
/// Results are deterministic within one build. The standard library does not
/// guarantee `DefaultHasher`'s algorithm across Rust releases, so these values should
/// not be persisted.
pub fn stable_hash<T: Hash>(value: T) -> u64 {
    let mut state = DefaultHasher::new();
    value.hash(&mut state);
    state.finish()
}
194
/// Shared callback that receives the active [`LuaCapabilityPolicy`] and a mutable
/// [`ToolRegistry`]. It presumably registers Lua-defined tools into the registry —
/// confirm against the `lua` module.
pub type LuaToolLoader = Arc<dyn Fn(&LuaCapabilityPolicy, &mut ToolRegistry) + Send + Sync>;
197
/// Per-call environment passed by value to [`Tool::execute`].
///
/// Most shared state sits behind `Arc`, so clones of a context refer to the same
/// caches, trackers and configuration.
#[derive(Clone)]
pub struct ToolContext {
    /// Working directory for the call. Presumably used as the base for `resolve_path`.
    pub cwd: PathBuf,
    /// Shared cancellation flag, read by [`ToolContext::is_cancelled`].
    pub cancelled: Arc<std::sync::atomic::AtomicBool>,
    /// Sender for [`ToolUpdate`]s emitted by the tool.
    pub update_tx: tokio::sync::mpsc::Sender<ToolUpdate>,
    /// Sender for [`AgentCommand`]s directed at the agent.
    pub command_tx: tokio::sync::mpsc::Sender<AgentCommand>,
    /// User-interface handle.
    pub ui: Arc<dyn UserInterface>,
    /// Shared mtime-validated cache of file contents.
    pub file_cache: Arc<FileCache>,
    /// Shared original-content history and checkpoint records.
    pub checkpoint_state: Arc<CheckpointState>,
    /// Which files were read, with the mtime observed at read time.
    pub file_tracker: Arc<std::sync::Mutex<FileTracker>>,
    /// Line anchors recorded for files.
    pub anchor_store: Arc<AnchorStore>,
    /// Optional loader for Lua tools.
    pub lua_tool_loader: Option<LuaToolLoader>,
    /// Active agent mode.
    pub mode: AgentMode,
    /// Line limit for reads. NOTE(review): presumably consumed by the `read` tool — confirm.
    pub read_max_lines: usize,
    /// Per-turn mana review accumulator.
    pub turn_mana_review: Arc<std::sync::Mutex<TurnManaReviewAccumulator>>,
    /// Full agent configuration.
    pub config: Arc<crate::config::Config>,
}
224
/// Process-wide cache of file contents.
///
/// A cached entry is reused only while the file's mtime and byte length both match
/// what was seen when the entry was stored.
pub struct FileCache {
    entries: std::sync::Mutex<std::collections::HashMap<PathBuf, FileCacheEntry>>,
}

/// Cached content plus the mtime observed just before it was read.
struct FileCacheEntry {
    mtime: std::time::SystemTime,
    content: String,
}

impl Default for FileCache {
    fn default() -> Self {
        Self::new()
    }
}

impl FileCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            entries: std::sync::Mutex::new(std::collections::HashMap::new()),
        }
    }

    /// Returns the UTF-8 contents of `path`, served from the cache when still valid.
    ///
    /// An entry is valid when both the file's current mtime and its length equal those
    /// recorded for the entry. Comparing the length as well catches rewrites that land
    /// within the same mtime tick on coarse-timestamp filesystems. When the platform
    /// cannot report an mtime, the file is always read fresh and nothing is cached.
    ///
    /// # Errors
    /// Propagates errors from `std::fs::metadata` and `std::fs::read_to_string`
    /// (missing file, permission denied, non-UTF-8 content, ...).
    pub fn read(&self, path: &Path) -> std::io::Result<String> {
        let metadata = std::fs::metadata(path)?;
        let Ok(mtime) = metadata.modified() else {
            // Without an mtime we cannot tell a stale entry from a fresh one.
            return std::fs::read_to_string(path);
        };
        let size = metadata.len();

        {
            let cache = self
                .entries
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if let Some(entry) = cache.get(path) {
                // `content` came from `read_to_string`, so its byte length is the
                // on-disk size at the time it was cached.
                if entry.mtime == mtime && entry.content.len() as u64 == size {
                    return Ok(entry.content.clone());
                }
            }
        }

        let content = std::fs::read_to_string(path)?;
        self.entries
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(
                path.to_path_buf(),
                FileCacheEntry {
                    mtime,
                    content: content.clone(),
                },
            );
        Ok(content)
    }

    /// Drops any cached entry for `path`.
    pub fn invalidate(&self, path: &Path) {
        self.entries
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(path);
    }
}
284
/// Stores the content each file had the first time it was snapshotted, so later
/// writes can be rolled back.
///
/// Keys are the paths exactly as passed in; no canonicalization is performed, so
/// `a.rs` and `./a.rs` are tracked separately.
pub struct FileHistory {
    originals: std::sync::Mutex<HashMap<PathBuf, String>>,
}
293
/// Metadata for one checkpoint created by `CheckpointState::snapshot_paths`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointRecord {
    /// Random UUID v4 string identifying the checkpoint.
    pub id: String,
    /// Optional caller-supplied label.
    pub label: Option<String>,
    /// Creation time from `imp_llm::now()` (unit not visible here — TODO confirm).
    pub created_at: u64,
    /// Files whose original content was captured for this checkpoint.
    pub files: Vec<PathBuf>,
}
301
302pub struct CheckpointState {
308 history: FileHistory,
309 records: std::sync::Mutex<Vec<CheckpointRecord>>,
310}
311
312impl Default for CheckpointState {
313 fn default() -> Self {
314 Self::new()
315 }
316}
317
318impl CheckpointState {
319 pub fn new() -> Self {
320 Self {
321 history: FileHistory::new(),
322 records: std::sync::Mutex::new(Vec::new()),
323 }
324 }
325
326 pub fn snapshot_paths(
328 &self,
329 paths: &[PathBuf],
330 label: Option<String>,
331 ) -> std::io::Result<Option<CheckpointRecord>> {
332 let mut unique = Vec::new();
333 for path in paths {
334 if !unique.iter().any(|existing: &PathBuf| existing == path) {
335 unique.push(path.clone());
336 }
337 }
338
339 let mut captured = Vec::new();
340 for path in unique {
341 self.history.snapshot_before_edit(&path)?;
342 if self.history.original(&path).is_some() {
343 captured.push(path);
344 }
345 }
346
347 if captured.is_empty() {
348 return Ok(None);
349 }
350
351 let record = CheckpointRecord {
352 id: uuid::Uuid::new_v4().to_string(),
353 label,
354 created_at: imp_llm::now(),
355 files: captured,
356 };
357 self.records.lock().unwrap().push(record.clone());
358 Ok(Some(record))
359 }
360
361 pub fn checkpoints(&self) -> Vec<CheckpointRecord> {
362 self.records.lock().unwrap().clone()
363 }
364
365 pub fn checkpoint(&self, id: &str) -> Option<CheckpointRecord> {
366 self.records
367 .lock()
368 .unwrap()
369 .iter()
370 .find(|record| record.id == id)
371 .cloned()
372 }
373
374 pub fn restore_checkpoint(&self, id: &str) -> std::io::Result<Vec<PathBuf>> {
375 let Some(record) = self.checkpoint(id) else {
376 return Ok(Vec::new());
377 };
378
379 let mut restored = Vec::new();
380 for path in &record.files {
381 self.history.rollback(path)?;
382 restored.push(path.clone());
383 }
384 Ok(restored)
385 }
386
387 pub fn rollback(&self, path: &Path) -> std::io::Result<()> {
388 self.history.rollback(path)
389 }
390
391 pub fn tracked_files(&self) -> Vec<PathBuf> {
392 self.history.tracked_files()
393 }
394
395 pub fn original(&self, path: &Path) -> Option<String> {
396 self.history.original(path)
397 }
398}
399
400impl Default for FileHistory {
401 fn default() -> Self {
402 Self::new()
403 }
404}
405
406impl FileHistory {
407 pub fn new() -> Self {
408 Self {
409 originals: std::sync::Mutex::new(HashMap::new()),
410 }
411 }
412
413 pub fn snapshot_before_edit(&self, path: &Path) -> std::io::Result<()> {
416 let canonical = path.to_path_buf();
417
418 let mut originals = self.originals.lock().unwrap();
419 if originals.contains_key(&canonical) {
420 return Ok(()); }
422 if canonical.exists() {
423 let content = std::fs::read_to_string(&canonical)?;
424 originals.insert(canonical, content);
425 }
426 Ok(())
427 }
428
429 pub fn original(&self, path: &Path) -> Option<String> {
431 self.originals.lock().unwrap().get(path).cloned()
432 }
433
434 pub fn rollback(&self, path: &Path) -> std::io::Result<()> {
436 let originals = self.originals.lock().unwrap();
437 if let Some(content) = originals.get(path) {
438 std::fs::write(path, content)?;
439 }
440 Ok(())
441 }
442
443 pub fn tracked_files(&self) -> Vec<PathBuf> {
445 self.originals.lock().unwrap().keys().cloned().collect()
446 }
447}
448
449impl ToolContext {
450 pub fn is_cancelled(&self) -> bool {
451 self.cancelled.load(std::sync::atomic::Ordering::Relaxed)
452 }
453}
454
455pub struct ToolOutput {
457 pub content: Vec<ContentBlock>,
458 pub details: serde_json::Value,
459 pub is_error: bool,
460}
461
462impl ToolOutput {
463 pub fn text(text: impl Into<String>) -> Self {
464 Self {
465 content: vec![ContentBlock::Text { text: text.into() }],
466 details: serde_json::Value::Null,
467 is_error: false,
468 }
469 }
470
471 pub fn error(text: impl Into<String>) -> Self {
472 Self {
473 content: vec![ContentBlock::Text { text: text.into() }],
474 details: serde_json::Value::Null,
475 is_error: true,
476 }
477 }
478
479 pub fn text_content(&self) -> Option<&str> {
481 self.content.iter().find_map(|b| match b {
482 ContentBlock::Text { text } => Some(text.as_str()),
483 _ => None,
484 })
485 }
486
487 pub fn into_tool_result(self, call_id: &str, tool_name: &str) -> ToolResultMessage {
488 ToolResultMessage {
489 tool_call_id: call_id.to_string(),
490 tool_name: tool_name.to_string(),
491 content: self.content,
492 is_error: self.is_error,
493 details: self.details,
494 timestamp: imp_llm::now(),
495 }
496 }
497}
498
/// Intermediate update a tool sends over `ToolContext::update_tx`.
pub struct ToolUpdate {
    /// Content blocks for this update.
    pub content: Vec<ContentBlock>,
    /// Structured, tool-specific details.
    pub details: serde_json::Value,
}
504
505pub struct ToolRegistry {
507 tools: HashMap<String, Arc<dyn Tool>>,
508 aliases: HashMap<String, String>,
509}
510
511impl ToolRegistry {
512 pub fn new() -> Self {
513 Self {
514 tools: HashMap::new(),
515 aliases: HashMap::new(),
516 }
517 }
518
519 pub fn register(&mut self, tool: Arc<dyn Tool>) {
521 self.tools.insert(tool.name().to_string(), tool);
522 }
523
524 pub fn register_alias(&mut self, alias: impl Into<String>, canonical: impl Into<String>) {
529 self.aliases.insert(alias.into(), canonical.into());
530 }
531
532 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
534 if let Some(tool) = self.tools.get(name) {
535 return Some(tool);
536 }
537
538 self.aliases
539 .get(name)
540 .and_then(|canonical| self.tools.get(canonical))
541 }
542
543 pub fn tools_map(&self) -> HashMap<String, Arc<dyn Tool>> {
545 let mut map = self.tools.clone();
546 for (alias, canonical) in &self.aliases {
547 if let Some(tool) = self.tools.get(canonical) {
548 map.insert(alias.clone(), Arc::clone(tool));
549 }
550 }
551 map
552 }
553
554 pub fn definitions(&self) -> Vec<ToolDefinition> {
558 let mut defs: Vec<_> = self
559 .tools
560 .values()
561 .map(|t| ToolDefinition {
562 name: t.name().to_string(),
563 description: t.description().to_string(),
564 parameters: t.parameters(),
565 })
566 .collect();
567 defs.sort_by(|a, b| a.name.cmp(&b.name));
568 defs
569 }
570
571 pub fn readonly_definitions(&self) -> Vec<ToolDefinition> {
573 let mut defs: Vec<_> = self
574 .tools
575 .values()
576 .filter(|t| t.is_readonly())
577 .map(|t| ToolDefinition {
578 name: t.name().to_string(),
579 description: t.description().to_string(),
580 parameters: t.parameters(),
581 })
582 .collect();
583 defs.sort_by(|a, b| a.name.cmp(&b.name));
584 defs
585 }
586
587 pub fn names(&self) -> Vec<String> {
589 self.tools.keys().cloned().collect()
590 }
591
592 pub fn retain<F>(&mut self, predicate: F)
597 where
598 F: Fn(&str) -> bool,
599 {
600 self.tools.retain(|name, _| predicate(name));
601 self.aliases
602 .retain(|_, canonical| self.tools.contains_key(canonical));
603 }
604
605 pub fn definitions_for_mode(
610 &self,
611 mode: &crate::config::AgentMode,
612 ) -> Vec<imp_llm::provider::ToolDefinition> {
613 let mut defs: Vec<_> = self
614 .tools
615 .values()
616 .filter(|t| mode.allows_tool(t.name()))
617 .map(|t| imp_llm::provider::ToolDefinition {
618 name: t.name().to_string(),
619 description: t.description().to_string(),
620 parameters: t.parameters(),
621 })
622 .collect();
623 defs.sort_by(|a, b| a.name.cmp(&b.name));
624 defs
625 }
626
627 pub fn len(&self) -> usize {
629 self.tools.len()
630 }
631
632 pub fn is_empty(&self) -> bool {
633 self.tools.is_empty()
634 }
635}
636
637impl Default for ToolRegistry {
638 fn default() -> Self {
639 Self::new()
640 }
641}
642
/// Outcome of `truncate_head` / `truncate_tail`.
pub struct TruncationResult {
    /// Resulting text. When truncated, every kept line is re-terminated with `\n`.
    pub content: String,
    /// `true` when the input exceeded a limit and only part of it was kept.
    pub truncated: bool,
    /// Number of lines in `content`.
    pub output_lines: usize,
    /// Number of lines in the input, as counted by `str::lines`.
    pub total_lines: usize,
    /// Byte length of `content`.
    pub output_bytes: usize,
    /// Byte length of the input.
    pub total_bytes: usize,
    /// Temp file holding the full input. Set only when truncation happened and the
    /// best-effort write succeeded.
    pub temp_file: Option<PathBuf>,
}
654
/// Shortens `line` to at most `max_bytes` bytes of its content, appending `…` when cut.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so the
/// result is always valid. The ellipsis itself is not counted against `max_bytes`.
pub fn truncate_line(line: &str, max_bytes: usize) -> String {
    if line.len() <= max_bytes {
        return line.to_string();
    }
    // Index 0 is always a boundary, so `find` cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| line.is_char_boundary(idx))
        .unwrap_or(0);
    let mut shortened = String::with_capacity(cut + '…'.len_utf8());
    shortened.push_str(&line[..cut]);
    shortened.push('…');
    shortened
}
666
667fn write_temp_file(content: &str) -> Option<PathBuf> {
669 let dir = std::env::temp_dir().join("imp-tools");
670 std::fs::create_dir_all(&dir).ok()?;
671 let name = format!("truncated-{}.txt", uuid::Uuid::new_v4());
672 let path = dir.join(name);
673 std::fs::write(&path, content).ok()?;
674 Some(path)
675}
676
677pub fn truncate_head(input: &str, max_lines: usize, max_bytes: usize) -> TruncationResult {
680 let lines: Vec<&str> = input.lines().collect();
681 let total_lines = lines.len();
682 let total_bytes = input.len();
683
684 if total_lines <= max_lines && total_bytes <= max_bytes {
685 return TruncationResult {
686 content: input.to_string(),
687 truncated: false,
688 output_lines: total_lines,
689 total_lines,
690 output_bytes: total_bytes,
691 total_bytes,
692 temp_file: None,
693 };
694 }
695
696 let mut result = String::new();
697 let mut byte_count = 0;
698 let mut line_count = 0;
699
700 for line in &lines {
701 let line_with_newline = format!("{line}\n");
702 if line_count >= max_lines || byte_count + line_with_newline.len() > max_bytes {
703 break;
704 }
705 result.push_str(&line_with_newline);
706 byte_count += line_with_newline.len();
707 line_count += 1;
708 }
709
710 let temp_file = write_temp_file(input);
711
712 TruncationResult {
713 content: result,
714 truncated: true,
715 output_lines: line_count,
716 total_lines,
717 output_bytes: byte_count,
718 total_bytes,
719 temp_file,
720 }
721}
722
723pub fn truncate_tail(input: &str, max_lines: usize, max_bytes: usize) -> TruncationResult {
726 let lines: Vec<&str> = input.lines().collect();
727 let total_lines = lines.len();
728 let total_bytes = input.len();
729
730 if total_lines <= max_lines && total_bytes <= max_bytes {
731 return TruncationResult {
732 content: input.to_string(),
733 truncated: false,
734 output_lines: total_lines,
735 total_lines,
736 output_bytes: total_bytes,
737 total_bytes,
738 temp_file: None,
739 };
740 }
741
742 let start = total_lines.saturating_sub(max_lines);
744 let mut actual_start = start;
745 let mut remaining_bytes = max_bytes;
746
747 for (i, line) in lines[start..].iter().enumerate() {
748 let line_with_newline = format!("{line}\n");
749 if line_with_newline.len() > remaining_bytes {
750 actual_start = start + i + 1;
751 remaining_bytes = max_bytes;
752 for line2 in &lines[actual_start..] {
754 let l = format!("{line2}\n");
755 if l.len() > remaining_bytes {
756 break;
757 }
758 remaining_bytes -= l.len();
759 }
760 break;
761 }
762 remaining_bytes -= line_with_newline.len();
763 }
764
765 let mut result = String::new();
766 for line in &lines[actual_start..] {
767 result.push_str(&format!("{line}\n"));
768 }
769
770 let output_lines = total_lines - actual_start;
771 let output_bytes = result.len();
772 let temp_file = write_temp_file(input);
773
774 TruncationResult {
775 content: result,
776 truncated: true,
777 output_lines,
778 total_lines,
779 output_bytes,
780 total_bytes,
781 temp_file,
782 }
783}
784
pub(crate) mod fuzzy {
    //! Line-based text search that ignores trailing whitespace and folds a few Unicode
    //! quote, dash and space characters to ASCII.

    /// Normalizes `text` for comparison.
    ///
    /// Trailing whitespace is trimmed from every line and typographic characters are
    /// folded via `normalize_unicode`. Lines are re-joined with `\n`, so `\r\n`
    /// terminators and a final newline disappear.
    pub fn normalize_for_matching(text: &str) -> String {
        text.lines()
            .map(|line| normalize_unicode(line.trim_end()))
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Maps curly quotes, en/em dashes, and NBSP / em / en / thin spaces to ASCII.
    fn normalize_unicode(s: &str) -> String {
        s.chars()
            .map(|c| match c {
                '\u{2018}' | '\u{2019}' => '\'',
                '\u{201C}' | '\u{201D}' => '"',
                '\u{2013}' | '\u{2014}' => '-',
                '\u{00A0}' | '\u{2003}' | '\u{2002}' | '\u{2009}' => ' ',
                other => other,
            })
            .collect()
    }

    /// Byte range `[start, end)` of a match inside the searched content.
    pub struct FuzzyMatch {
        /// Byte offset where the first matched line begins.
        pub start: usize,
        /// Byte offset just past the last matched line's text (terminator excluded).
        pub end: usize,
    }

    /// Splits `content` like `str::lines` (on `\n`, stripping a preceding `\r`) and
    /// pairs each line with the byte offset where it starts in `content`.
    fn lines_with_offsets(content: &str) -> Vec<(usize, &str)> {
        let mut out = Vec::new();
        let mut offset = 0;
        for raw in content.split_inclusive('\n') {
            let text = match raw.strip_suffix('\n') {
                Some(body) => body.strip_suffix('\r').unwrap_or(body),
                None => raw,
            };
            out.push((offset, text));
            offset += raw.len();
        }
        out
    }

    /// Finds the first run of lines in `content` that equals `old_text` line-for-line,
    /// after normalizing both sides with [`normalize_for_matching`] rules.
    ///
    /// The offsets index the original `content`, so `&content[m.start..m.end]` is the
    /// matched text verbatim. That slice includes trailing whitespace on the last
    /// matched line but not its terminator. Offsets come from the real line
    /// terminators, so `\r\n` content is handled correctly.
    ///
    /// Returns `None` when `old_text` normalizes to no lines, is longer than `content`,
    /// or does not occur.
    pub fn fuzzy_find(content: &str, old_text: &str) -> Option<FuzzyMatch> {
        let content_lines = lines_with_offsets(content);
        let search_norm = normalize_for_matching(old_text);
        let search_lines: Vec<&str> = search_norm.lines().collect();

        if search_lines.is_empty() || search_lines.len() > content_lines.len() {
            return None;
        }

        let content_norm_lines: Vec<String> = content_lines
            .iter()
            .map(|(_, line)| normalize_unicode(line.trim_end()))
            .collect();

        let window_size = search_lines.len();
        let start_line = content_norm_lines
            .windows(window_size)
            .position(|candidate| {
                candidate
                    .iter()
                    .zip(search_lines.iter())
                    .all(|(content_line, search_line)| content_line == search_line)
            })?;

        let end_line = start_line + window_size - 1;
        let (start, _) = content_lines[start_line];
        let (end_offset, end_text) = content_lines[end_line];
        Some(FuzzyMatch {
            start,
            end: end_offset + end_text.len(),
        })
    }
}
868
/// Levenshtein edit distance between `a` and `b`, counted in `char`s.
///
/// Insertions, deletions and substitutions each cost 1; there is no transposition
/// operation. Uses a single DP row, so memory is O(len(b)).
pub fn levenshtein(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();
    // row[j] = distance between the processed prefix of `a` and target[..j].
    let mut row: Vec<usize> = (0..=target.len()).collect();

    for (i, ca) in a.chars().enumerate() {
        // `diag` holds the previous row's value at column j (before it is overwritten).
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &cb) in target.iter().enumerate() {
            let above = row[j + 1];
            let substitution = diag + usize::from(ca != cb);
            row[j + 1] = substitution.min(above + 1).min(row[j] + 1);
            diag = above;
        }
    }

    row[target.len()]
}
898
899pub fn suggest_similar_files(cwd: &Path, target: &str) -> Vec<String> {
905 let target_name = Path::new(target)
906 .file_name()
907 .and_then(|n: &std::ffi::OsStr| n.to_str())
908 .unwrap_or(target);
909
910 let mut candidates: Vec<(usize, String)> = Vec::new();
911
912 const SKIP_DIRS: &[&str] = &[
914 "target",
915 "node_modules",
916 ".git",
917 "vendor",
918 "dist",
919 "build",
920 "__pycache__",
921 ".mypy_cache",
922 ".tox",
923 ".venv",
924 ];
925
926 let walker = walkdir::WalkDir::new(cwd)
927 .max_depth(3)
928 .follow_links(false)
929 .into_iter()
930 .filter_entry(|e| {
931 if e.file_type().is_dir() {
932 if let Some(name) = e.file_name().to_str() {
933 return !SKIP_DIRS.contains(&name);
934 }
935 }
936 true
937 })
938 .filter_map(|e| e.ok());
939
940 for entry in walker {
941 if entry.file_type().is_file() {
942 if let Some(name) = entry.file_name().to_str() {
943 let dist = levenshtein(target_name, name);
944 if dist <= 3 {
945 let rel = entry
946 .path()
947 .strip_prefix(cwd)
948 .unwrap_or(entry.path())
949 .display()
950 .to_string();
951 candidates.push((dist, rel));
952 }
953 }
954 }
955 }
956
957 candidates.sort_by_key(|(d, _)| *d);
958 candidates.truncate(3);
959 candidates.into_iter().map(|(_, p)| p).collect()
960}
961
962pub fn generate_diff(file_path: &str, old: &str, new: &str) -> String {
966 use similar::TextDiff;
967
968 let diff = TextDiff::from_lines(old, new);
969 let mut output = String::new();
970 output.push_str(&format!("--- {file_path}\n"));
971 output.push_str(&format!("+++ {file_path}\n"));
972
973 for hunk in diff.unified_diff().context_radius(3).iter_hunks() {
974 output.push_str(&format!("{hunk}"));
975 }
976
977 output
978}
979
980pub fn validate_tool_args(schema: &serde_json::Value, args: &serde_json::Value) -> Result<()> {
988 use jsonschema::Validator;
989
990 let validator = Validator::new(schema)
991 .map_err(|e| crate::error::Error::Tool(format!("Invalid tool schema: {e}")))?;
992
993 let errors: Vec<String> = validator
994 .iter_errors(args)
995 .map(|e| format!("{e}"))
996 .collect();
997
998 if errors.is_empty() {
999 Ok(())
1000 } else {
1001 Err(crate::error::Error::Tool(format!(
1002 "Tool argument validation failed:\n{}",
1003 errors.join("\n")
1004 )))
1005 }
1006}
1007
// Unit tests for the shared tool helpers: edit distance and file suggestions,
// argument-schema validation, read tracking, and file history.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- levenshtein (ranking metric used by suggest_similar_files) ----

    #[test]
    fn suggest_similar_levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    #[test]
    fn suggest_similar_levenshtein_one_substitution() {
        assert_eq!(levenshtein("auth", "aath"), 1);
    }

    #[test]
    fn suggest_similar_levenshtein_one_insertion() {
        assert_eq!(levenshtein("helo", "hello"), 1);
    }

    #[test]
    fn suggest_similar_levenshtein_one_deletion() {
        assert_eq!(levenshtein("hello", "helo"), 1);
    }

    #[test]
    fn suggest_similar_levenshtein_empty_strings() {
        assert_eq!(levenshtein("", ""), 0);
        assert_eq!(levenshtein("abc", ""), 3);
        assert_eq!(levenshtein("", "abc"), 3);
    }

    #[test]
    fn suggest_similar_levenshtein_completely_different() {
        // Every position differs: one substitution per character.
        assert_eq!(levenshtein("abc", "xyz"), 3);
    }

    #[test]
    fn suggest_similar_levenshtein_transposition() {
        // Plain Levenshtein has no transposition operation, so a swap costs two edits.
        assert_eq!(levenshtein("atuh", "auth"), 2);
    }

    // ---- suggest_similar_files ----

    #[test]
    fn suggest_similar_finds_close_match() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("middleware.rs"), "").unwrap();
        std::fs::write(dir.path().join("unrelated.rs"), "").unwrap();

        let suggestions = suggest_similar_files(dir.path(), "middlewar.rs");
        assert!(
            suggestions.iter().any(|s| s.contains("middleware.rs")),
            "expected middleware.rs in suggestions, got: {suggestions:?}"
        );
    }

    #[test]
    fn suggest_similar_returns_at_most_three() {
        let dir = tempfile::tempdir().unwrap();
        // Five files, each one substitution away from the target; output is capped.
        for name in &["mod.rs", "rod.rs", "cod.rs", "nod.rs", "pod.rs"] {
            std::fs::write(dir.path().join(name), "").unwrap();
        }

        let suggestions = suggest_similar_files(dir.path(), "xod.rs");
        assert!(suggestions.len() <= 3);
    }

    #[test]
    fn suggest_similar_nothing_close_returns_empty() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("completely_different.rs"), "").unwrap();

        let suggestions = suggest_similar_files(dir.path(), "a.rs");
        // The only file is far beyond the distance threshold of 3.
        assert!(
            suggestions.is_empty(),
            "expected no suggestions, got: {suggestions:?}"
        );
    }

    #[test]
    fn suggest_similar_ranks_closer_matches_first() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("auth.rs"), "").unwrap();
        std::fs::write(dir.path().join("autho.rs"), "").unwrap();

        let suggestions = suggest_similar_files(dir.path(), "atuh.rs");
        assert!(!suggestions.is_empty());
        assert!(
            suggestions.iter().any(|s| s.contains("auth.rs")),
            "expected auth.rs, got: {suggestions:?}"
        );
    }

    // ---- validate_tool_args ----

    // Object schema with one required string (`path`) and one optional integer (`count`).
    fn simple_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "path": { "type": "string" },
                "count": { "type": "integer" }
            },
            "required": ["path"]
        })
    }

    #[test]
    fn validate_tool_args_valid_passes() {
        let schema = simple_schema();
        let args = serde_json::json!({ "path": "/tmp/foo.txt" });
        assert!(validate_tool_args(&schema, &args).is_ok());
    }

    #[test]
    fn validate_tool_args_valid_with_optional_passes() {
        let schema = simple_schema();
        let args = serde_json::json!({ "path": "/tmp/foo.txt", "count": 5 });
        assert!(validate_tool_args(&schema, &args).is_ok());
    }

    #[test]
    fn validate_tool_args_missing_required_returns_error() {
        let schema = simple_schema();
        let args = serde_json::json!({ "count": 5 });
        // The exact wording comes from jsonschema, so accept either keyword.
        let result = validate_tool_args(&schema, &args);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("path") || msg.contains("required"),
            "expected error mentioning 'path' or 'required', got: {msg}"
        );
    }

    #[test]
    fn validate_tool_args_wrong_type_returns_error() {
        let schema = simple_schema();
        let args = serde_json::json!({ "path": "/tmp/foo.txt", "count": "not-a-number" });
        // `count` is declared as an integer but a string is supplied.
        let result = validate_tool_args(&schema, &args);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("integer") || msg.contains("type"),
            "expected type error, got: {msg}"
        );
    }

    #[test]
    fn validate_tool_args_extra_fields_allowed() {
        let schema = simple_schema();
        // The schema does not set `additionalProperties`, so unknown keys are accepted.
        let args = serde_json::json!({
            "path": "/tmp/foo.txt",
            "llm_added_extra": "some value",
            "another_unknown": 42
        });
        assert!(
            validate_tool_args(&schema, &args).is_ok(),
            "extra/unknown fields should be allowed"
        );
    }

    // ---- FileTracker ----

    #[test]
    fn file_track_was_read_false_for_unread_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let tracker = FileTracker::new();
        assert!(!tracker.was_read(&file), "unread file should return false");
    }

    #[test]
    fn file_track_was_read_true_after_recording() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let mut tracker = FileTracker::new();
        tracker.record_read(&file);
        assert!(
            tracker.was_read(&file),
            "file should be marked as read after recording"
        );
    }

    #[test]
    fn file_track_is_stale_false_for_unread_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let tracker = FileTracker::new();
        // Nothing recorded, so there is no baseline to be stale against.
        assert!(!tracker.is_stale(&file));
    }

    #[test]
    fn file_track_is_stale_false_immediately_after_read() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "content").unwrap();

        let mut tracker = FileTracker::new();
        tracker.record_read(&file);
        // The mtime has not changed since it was recorded.
        assert!(!tracker.is_stale(&file));
    }

    #[test]
    fn file_track_is_stale_detects_external_modification() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "original content").unwrap();

        let mut tracker = FileTracker::new();
        tracker.record_read(&file);

        let future = std::time::SystemTime::now() + std::time::Duration::from_secs(2);
        // Advance the mtime explicitly instead of sleeping. If this fails, the
        // assertion below fails rather than passing silently.
        if let Ok(f) = std::fs::OpenOptions::new().write(true).open(&file) {
            let _ = f.set_modified(future);
        }

        assert!(
            tracker.is_stale(&file),
            "file with advanced mtime should be stale"
        );
    }

    // ---- FileHistory ----

    #[test]
    fn file_history_snapshot_stores_original() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        assert_eq!(history.original(&file).unwrap(), "fn main() {}");
    }

    #[test]
    fn file_history_second_snapshot_is_noop() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.rs");
        std::fs::write(&file, "original").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        std::fs::write(&file, "modified").unwrap();
        // Only the first snapshot is kept; this call must not overwrite it.
        history.snapshot_before_edit(&file).unwrap();

        assert_eq!(history.original(&file).unwrap(), "original");
    }

    #[test]
    fn file_history_rollback_restores_original() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.rs");
        std::fs::write(&file, "original content").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        std::fs::write(&file, "agent wrote this").unwrap();
        history.rollback(&file).unwrap();

        assert_eq!(std::fs::read_to_string(&file).unwrap(), "original content");
    }

    #[test]
    fn file_history_skips_nonexistent_files() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("does_not_exist.rs");

        let history = FileHistory::new();
        history.snapshot_before_edit(&file).unwrap();

        assert!(history.original(&file).is_none());
    }

    #[test]
    fn file_history_tracked_files_lists_all() {
        let dir = tempfile::tempdir().unwrap();
        let f1 = dir.path().join("a.rs");
        let f2 = dir.path().join("b.rs");
        std::fs::write(&f1, "a").unwrap();
        std::fs::write(&f2, "b").unwrap();

        let history = FileHistory::new();
        history.snapshot_before_edit(&f1).unwrap();
        history.snapshot_before_edit(&f2).unwrap();

        let tracked = history.tracked_files();
        assert_eq!(tracked.len(), 2);
        assert!(tracked.contains(&f1));
        assert!(tracked.contains(&f2));
    }
}