1use crate::checkpoint::CheckpointManager;
8use crate::registry::Tool;
9use async_trait::async_trait;
10use rustant_core::error::ToolError;
11use rustant_core::types::{Artifact, RiskLevel, ToolOutput};
12use similar::TextDiff;
13use std::path::{Path, PathBuf};
14use tokio::sync::Mutex;
15use tracing::debug;
16
17pub struct SmartEditTool {
20 workspace: PathBuf,
21 checkpoint_mgr: Mutex<CheckpointManager>,
22}
23
24impl SmartEditTool {
25 pub fn new(workspace: PathBuf) -> Self {
26 let checkpoint_mgr = CheckpointManager::new(workspace.clone());
27 Self {
28 workspace,
29 checkpoint_mgr: Mutex::new(checkpoint_mgr),
30 }
31 }
32}
33
/// The kind of modification applied at a resolved location.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EditType {
    /// Swap the matched span for the new text.
    Replace,
    /// Put the new text after the matched span.
    InsertAfter,
    /// Put the new text before the matched span.
    InsertBefore,
    /// Remove the matched span.
    Delete,
}

impl EditType {
    /// Parses an edit-type name case-insensitively.
    ///
    /// `insert_after`/`insert-after` and `insert_before`/`insert-before` are both accepted,
    /// and `remove` is an alias for `delete`. Anything else yields `None`.
    fn from_str(s: &str) -> Option<Self> {
        // Folding '-' into '_' lets one arm cover both spellings of the insert variants.
        let canonical = s.to_lowercase().replace('-', "_");
        let parsed = match canonical.as_str() {
            "replace" => Self::Replace,
            "insert_after" => Self::InsertAfter,
            "insert_before" => Self::InsertBefore,
            "delete" | "remove" => Self::Delete,
            _ => return None,
        };
        Some(parsed)
    }
}
58
/// A span of the file content that a location description resolved to.
///
/// `start..end` is a half-open byte range into the content the edit is applied to.
#[derive(Debug)]
// `context_preview` is filled in by every locator but is not read by the code visible here.
#[allow(dead_code)]
struct LocationMatch {
    /// Byte offset where the matched span begins.
    start: usize,
    /// Byte offset one past the end of the matched span (exclusive).
    end: usize,
    /// Text of the matched span (for exact matches, the pattern itself).
    matched_text: String,
    /// 1-based line number where the match starts.
    line_number: usize,
    /// Line-numbered excerpt around the match, as produced by `extract_context`.
    context_preview: String,
}
74
75fn find_location(content: &str, pattern: &str) -> Result<LocationMatch, String> {
78 if let Some(start) = content.find(pattern) {
80 let end = start + pattern.len();
81 let line_number = content[..start].matches('\n').count() + 1;
82 let preview = extract_context(content, start, end, 2);
83 return Ok(LocationMatch {
84 start,
85 end,
86 matched_text: pattern.to_string(),
87 line_number,
88 context_preview: preview,
89 });
90 }
91
92 if let Some(m) = parse_line_pattern(pattern) {
94 return find_by_line_range(content, m.0, m.1);
95 }
96
97 if let Some(m) = find_by_function_pattern(content, pattern) {
99 return Ok(m);
100 }
101
102 if let Some(m) = find_by_fuzzy_match(content, pattern) {
104 return Ok(m);
105 }
106
107 Err(format!(
108 "Could not locate '{}' in the file. Try using exact text, a line number (e.g., 'line 42'), or a function name.",
109 truncate(pattern, 80)
110 ))
111}
112
/// Parses `line N` or `lines N-M` (case-insensitive, surrounding whitespace ignored)
/// into an inclusive `(first, last)` pair of 1-based line numbers.
///
/// `line N` yields `(N, N)`. A range must contain exactly one `-`. Returns `None`
/// for anything else, including numbers that fail to parse.
fn parse_line_pattern(pattern: &str) -> Option<(usize, usize)> {
    let spec = pattern.trim().to_lowercase();

    if let Some(n) = spec
        .strip_prefix("line ")
        .and_then(|tail| tail.trim().parse::<usize>().ok())
    {
        return Some((n, n));
    }

    let tail = spec.strip_prefix("lines ")?;
    let mut bounds = tail.split('-');
    // Require exactly two pieces: `a-b`, not `a` or `a-b-c`.
    let (Some(lo), Some(hi), None) = (bounds.next(), bounds.next(), bounds.next()) else {
        return None;
    };
    let lo = lo.trim().parse::<usize>().ok()?;
    let hi = hi.trim().parse::<usize>().ok()?;
    Some((lo, hi))
}
139
140fn find_by_line_range(
142 content: &str,
143 start_line: usize,
144 end_line: usize,
145) -> Result<LocationMatch, String> {
146 let lines: Vec<&str> = content.lines().collect();
147 let total = lines.len();
148
149 if start_line == 0 || start_line > total {
150 return Err(format!(
151 "Line {} is out of range (file has {} lines)",
152 start_line, total
153 ));
154 }
155
156 let end_line = end_line.min(total);
157 let start_idx = start_line - 1;
158 let end_idx = end_line;
159
160 let mut byte_offset = 0;
162 let mut start_byte = 0;
163 let mut end_byte = content.len();
164
165 for (i, line) in content.lines().enumerate() {
166 if i == start_idx {
167 start_byte = byte_offset;
168 }
169 byte_offset += line.len() + 1; if i + 1 == end_idx {
171 end_byte = byte_offset.min(content.len());
172 }
173 }
174
175 let matched = &content[start_byte..end_byte];
176 let preview = extract_context(content, start_byte, end_byte, 1);
177
178 Ok(LocationMatch {
179 start: start_byte,
180 end: end_byte,
181 matched_text: matched.to_string(),
182 line_number: start_line,
183 context_preview: preview,
184 })
185}
186
187fn find_by_function_pattern(content: &str, pattern: &str) -> Option<LocationMatch> {
189 let pattern_lower = pattern.to_lowercase();
190
191 let fn_prefixes = [
193 "fn ",
194 "def ",
195 "func ",
196 "function ",
197 "pub fn ",
198 "async fn ",
199 "pub async fn ",
200 "impl ",
201 "class ",
202 "struct ",
203 "enum ",
204 ];
205
206 let is_fn_pattern = fn_prefixes.iter().any(|p| pattern_lower.starts_with(p))
208 || pattern_lower.starts_with("the ")
209 || pattern_lower.contains(" function")
210 || pattern_lower.contains(" method");
211
212 if !is_fn_pattern {
213 return None;
214 }
215
216 let name = extract_identifier_from_pattern(&pattern_lower);
218 if name.is_empty() {
219 return None;
220 }
221
222 for (i, line) in content.lines().enumerate() {
224 let line_lower = line.to_lowercase();
225 let has_fn_keyword = fn_prefixes.iter().any(|p| line_lower.contains(p));
226
227 if has_fn_keyword && line_lower.contains(&name) {
228 let byte_start: usize = content.lines().take(i).map(|l| l.len() + 1).sum();
229
230 let block_end = find_block_end(content, byte_start);
232
233 let matched = &content[byte_start..block_end];
234 let preview = extract_context(content, byte_start, block_end, 0);
235
236 return Some(LocationMatch {
237 start: byte_start,
238 end: block_end,
239 matched_text: matched.to_string(),
240 line_number: i + 1,
241 context_preview: preview,
242 });
243 }
244 }
245
246 None
247}
248
/// Picks the most identifier-like word out of a natural-language location description.
///
/// Filler such as "the ", " function", " method", " that " and "called " is removed
/// first. The first remaining word (trimmed of surrounding punctuation other than `_`)
/// that is at least two bytes long, is not a stop word, and looks like an identifier is
/// returned. Failing that, the last sufficiently long non-stop word is returned, or an
/// empty string if none exists.
fn extract_identifier_from_pattern(pattern: &str) -> String {
    // Language keywords and English filler words that are never the identifier we want.
    const STOP_WORDS: &[&str] = &[
        "fn", "def", "func", "function", "pub", "async", "impl", "class", "struct", "enum",
        "let", "const", "var", "type", "trait", "interface", "the", "a", "an", "in", "of",
        "for", "with", "from", "to",
    ];

    fn strip_punct(word: &str) -> &str {
        word.trim_matches(|c: char| !c.is_alphanumeric() && c != '_')
    }

    let cleaned = pattern
        .replace("the ", "")
        .replace(" function", "")
        .replace(" method", "")
        .replace(" that ", " ")
        .replace("called ", "");

    let is_candidate = |w: &str| w.len() >= 2 && !STOP_WORDS.contains(&w);
    let looks_like_ident = |w: &str| {
        w.contains('_')
            || w.chars().any(char::is_uppercase)
            || w.chars().all(|c| c.is_alphanumeric() || c == '_')
    };

    let preferred = cleaned
        .split_whitespace()
        .map(strip_punct)
        .find(|&w| is_candidate(w) && looks_like_ident(w));
    if let Some(hit) = preferred {
        return hit.to_string();
    }

    // Fallback: scan from the end and accept any long-enough non-stop word.
    cleaned
        .split_whitespace()
        .rev()
        .map(strip_punct)
        .find(|&w| is_candidate(w))
        .unwrap_or("")
        .to_string()
}
310
/// Returns the byte offset just past the code block that begins at byte `start`.
///
/// If the block's first line contains `{`, braces are counted from `start`. The offset
/// just after the matching `}` is returned; braces inside strings or comments are not
/// special-cased. Otherwise the block is delimited by indentation (Python-style): it
/// extends over every following line that is blank or indented deeper than the first
/// line, including those lines' terminators. Returns `content.len()` if the block never
/// closes or `start` is at the end. `start` must be a char boundary.
fn find_block_end(content: &str, start: usize) -> usize {
    /// Drops a trailing "\n" or "\r\n" exactly the way `str::lines` does.
    fn strip_terminator(raw: &str) -> &str {
        raw.strip_suffix('\n')
            .map(|l| l.strip_suffix('\r').unwrap_or(l))
            .unwrap_or(raw)
    }

    let rest = &content[start..];
    // Physical lines *with* terminators, so the byte arithmetic below is exact for "\r\n".
    let mut raw_lines = rest.split_inclusive('\n');
    let Some(first_raw) = raw_lines.next() else {
        return content.len();
    };
    let first_line = strip_terminator(first_raw);

    if first_line.contains('{') {
        let mut depth = 0i32;
        for (offset, ch) in rest.char_indices() {
            match ch {
                '{' => depth += 1,
                '}' => {
                    depth -= 1;
                    if depth == 0 {
                        return start + offset + ch.len_utf8();
                    }
                }
                _ => {}
            }
        }
        return content.len();
    }

    let indent_of = |line: &str| line.len() - line.trim_start().len();
    let base_indent = indent_of(first_line);
    let mut end = start + first_raw.len();

    for raw in raw_lines {
        let line = strip_terminator(raw);
        // Blank lines never end an indented block.
        if line.trim().is_empty() {
            end += raw.len();
            continue;
        }
        if indent_of(line) <= base_indent {
            break;
        }
        end += raw.len();
    }

    end.min(content.len())
}
363
364fn find_by_fuzzy_match(content: &str, pattern: &str) -> Option<LocationMatch> {
367 let pattern_lower = pattern.to_lowercase();
368 let pattern_words: Vec<&str> = pattern_lower.split_whitespace().collect();
369 let mut best_score = 0.0_f64;
370 let mut best_line_idx = None;
371
372 for (i, line) in content.lines().enumerate() {
373 let line_lower = line.to_lowercase();
374 let line_trimmed = line_lower.trim();
375 if line_trimmed.is_empty() {
376 continue;
377 }
378
379 let bigram_score = similarity_score(line_trimmed, &pattern_lower);
381 let word_score = pattern_words
382 .iter()
383 .filter(|w| line_trimmed.contains(**w))
384 .count() as f64
385 / pattern_words.len().max(1) as f64;
386
387 let score = (bigram_score + word_score) / 2.0;
388 if score > best_score && score > 0.25 {
389 best_score = score;
390 best_line_idx = Some(i);
391 }
392 }
393
394 let line_idx = best_line_idx?;
395 let byte_start: usize = content.lines().take(line_idx).map(|l| l.len() + 1).sum();
396 let line_text = content.lines().nth(line_idx)?;
397 let byte_end = byte_start + line_text.len();
398 let preview = extract_context(content, byte_start, byte_end, 2);
399
400 Some(LocationMatch {
401 start: byte_start,
402 end: byte_end,
403 matched_text: line_text.to_string(),
404 line_number: line_idx + 1,
405 context_preview: preview,
406 })
407}
408
/// Jaccard similarity of the sets of adjacent character pairs (bigrams) in `a` and `b`.
///
/// Yields a value in `[0.0, 1.0]`. Returns `0.0` when either string has fewer than two
/// characters, because it then has no bigrams.
fn similarity_score(a: &str, b: &str) -> f64 {
    use std::collections::HashSet;

    fn bigrams(s: &str) -> HashSet<(char, char)> {
        s.chars().zip(s.chars().skip(1)).collect()
    }

    let (left, right) = (bigrams(a), bigrams(b));
    if left.is_empty() || right.is_empty() {
        return 0.0;
    }

    let shared = left.intersection(&right).count();
    let combined = left.union(&right).count();
    shared as f64 / combined as f64
}
429
/// Renders the lines touched by the byte span `start..end`, plus `context_lines` lines
/// on each side. Each line is prefixed with its 1-based number (`"  42 | text"`), and
/// lines are joined with `\n`.
///
/// An empty span still shows the line containing `start`. `start` and `end` must be
/// char-boundary byte offsets with `start <= end`.
fn extract_context(content: &str, start: usize, end: usize, context_lines: usize) -> String {
    let lines: Vec<&str> = content.lines().collect();

    // Zero-based index of the line containing `start` is the number of newlines before it.
    // (`lines().count() - 1` is one too small when `start` is exactly at a line start.)
    let start_line = content[..start].matches('\n').count();
    // Exclusive index past the last line touched by the span.
    let end_line = content[..end].lines().count().max(start_line + 1);

    let to = (end_line + context_lines).min(lines.len());
    let from = start_line.saturating_sub(context_lines).min(to);

    lines[from..to]
        .iter()
        .enumerate()
        .map(|(i, line)| format!("{:4} | {}", from + i + 1, line))
        .collect::<Vec<_>>()
        .join("\n")
}
446
447fn generate_diff(path: &str, old: &str, new: &str) -> String {
449 let diff = TextDiff::from_lines(old, new);
450 let mut output = String::new();
451
452 output.push_str(&format!("--- a/{}\n", path));
453 output.push_str(&format!("+++ b/{}\n", path));
454
455 for hunk in diff.unified_diff().context_radius(3).iter_hunks() {
456 output.push_str(&format!("{}", hunk));
457 }
458
459 output
460}
461
462fn apply_edit(
464 content: &str,
465 location: &LocationMatch,
466 edit_type: EditType,
467 new_text: &str,
468) -> String {
469 match edit_type {
470 EditType::Replace => {
471 let mut result = String::with_capacity(content.len());
472 result.push_str(&content[..location.start]);
473 result.push_str(new_text);
474 result.push_str(&content[location.end..]);
475 result
476 }
477 EditType::InsertAfter => {
478 let mut result = String::with_capacity(content.len() + new_text.len());
479 result.push_str(&content[..location.end]);
480 if !new_text.starts_with('\n') && !content[..location.end].ends_with('\n') {
481 result.push('\n');
482 }
483 result.push_str(new_text);
484 result.push_str(&content[location.end..]);
485 result
486 }
487 EditType::InsertBefore => {
488 let mut result = String::with_capacity(content.len() + new_text.len());
489 result.push_str(&content[..location.start]);
490 result.push_str(new_text);
491 if !new_text.ends_with('\n') && !content[location.start..].starts_with('\n') {
492 result.push('\n');
493 }
494 result.push_str(&content[location.start..]);
495 result
496 }
497 EditType::Delete => {
498 let mut result = String::with_capacity(content.len());
499 result.push_str(&content[..location.start]);
500 result.push_str(&content[location.end..]);
501 result
502 }
503 }
504}
505
/// Shortens `s` to at most roughly `max_len` bytes for display, appending `...` when cut.
///
/// Strings of `max_len` bytes or fewer are returned unchanged. Otherwise the kept prefix
/// is at most `max_len - 3` bytes, rounded down to a UTF-8 character boundary so that
/// multi-byte input never causes a slicing panic.
fn truncate(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    let mut cut = max_len.saturating_sub(3);
    // Byte 0 is always a boundary, so this terminates.
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}
514
515fn validate_workspace_path(workspace: &Path, path_str: &str) -> Result<PathBuf, ToolError> {
517 let workspace_canonical = workspace
518 .canonicalize()
519 .unwrap_or_else(|_| workspace.to_path_buf());
520
521 let resolved = if Path::new(path_str).is_absolute() {
522 PathBuf::from(path_str)
523 } else {
524 workspace_canonical.join(path_str)
525 };
526
527 if resolved.exists() {
528 let canonical = resolved
529 .canonicalize()
530 .map_err(|e| ToolError::ExecutionFailed {
531 name: "smart_edit".into(),
532 message: format!("Path resolution failed: {}", e),
533 })?;
534
535 if !canonical.starts_with(&workspace_canonical) {
536 return Err(ToolError::PermissionDenied {
537 name: "smart_edit".into(),
538 reason: format!("Path '{}' is outside the workspace", path_str),
539 });
540 }
541 return Ok(canonical);
542 }
543
544 let mut normalized = Vec::new();
546 for component in resolved.components() {
547 match component {
548 std::path::Component::ParentDir => {
549 if normalized.pop().is_none() {
550 return Err(ToolError::PermissionDenied {
551 name: "smart_edit".into(),
552 reason: format!("Path '{}' escapes the workspace", path_str),
553 });
554 }
555 }
556 std::path::Component::CurDir => {}
557 other => normalized.push(other),
558 }
559 }
560 let normalized_path: PathBuf = normalized.iter().collect();
561
562 if !normalized_path.starts_with(&workspace_canonical) {
563 return Err(ToolError::PermissionDenied {
564 name: "smart_edit".into(),
565 reason: format!("Path '{}' is outside the workspace", path_str),
566 });
567 }
568
569 Ok(resolved)
570}
571
572#[async_trait]
573impl Tool for SmartEditTool {
574 fn name(&self) -> &str {
575 "smart_edit"
576 }
577
578 fn description(&self) -> &str {
579 "Smart code editor that accepts fuzzy location descriptions (function names, \
580 line numbers, search patterns) and edit types (replace, insert_after, \
581 insert_before, delete). Creates an auto-checkpoint before writing and \
582 returns a unified diff preview."
583 }
584
585 fn parameters_schema(&self) -> serde_json::Value {
586 serde_json::json!({
587 "type": "object",
588 "properties": {
589 "path": {
590 "type": "string",
591 "description": "Path to the file to edit (relative to workspace)"
592 },
593 "location": {
594 "type": "string",
595 "description": "Where to apply the edit. Supports: exact text to match, \
596 'line N' or 'lines N-M', function/method names (e.g. 'fn handle_request'), \
597 or fuzzy descriptions."
598 },
599 "edit_type": {
600 "type": "string",
601 "enum": ["replace", "insert_after", "insert_before", "delete"],
602 "description": "Type of edit to perform"
603 },
604 "new_text": {
605 "type": "string",
606 "description": "The new text (required for replace, insert_after, insert_before; \
607 omit for delete)"
608 }
609 },
610 "required": ["path", "location", "edit_type"]
611 })
612 }
613
614 async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
615 let path_str = args["path"]
616 .as_str()
617 .ok_or_else(|| ToolError::InvalidArguments {
618 name: "smart_edit".into(),
619 reason: "'path' parameter is required".into(),
620 })?;
621
622 let location_str =
623 args["location"]
624 .as_str()
625 .ok_or_else(|| ToolError::InvalidArguments {
626 name: "smart_edit".into(),
627 reason: "'location' parameter is required".into(),
628 })?;
629
630 let edit_type_str =
631 args["edit_type"]
632 .as_str()
633 .ok_or_else(|| ToolError::InvalidArguments {
634 name: "smart_edit".into(),
635 reason: "'edit_type' parameter is required".into(),
636 })?;
637
638 let edit_type = EditType::from_str(edit_type_str).ok_or_else(|| {
639 ToolError::InvalidArguments {
640 name: "smart_edit".into(),
641 reason: format!(
642 "Invalid edit_type '{}'. Must be one of: replace, insert_after, insert_before, delete",
643 edit_type_str
644 ),
645 }
646 })?;
647
648 let new_text = args["new_text"].as_str().unwrap_or("");
649
650 if edit_type != EditType::Delete && new_text.is_empty() {
651 return Err(ToolError::InvalidArguments {
652 name: "smart_edit".into(),
653 reason: "'new_text' is required for replace and insert operations".into(),
654 });
655 }
656
657 let _ = validate_workspace_path(&self.workspace, path_str)?;
659 let path = self.workspace.join(path_str);
660
661 let content =
663 tokio::fs::read_to_string(&path)
664 .await
665 .map_err(|e| ToolError::ExecutionFailed {
666 name: "smart_edit".into(),
667 message: format!("Failed to read '{}': {}", path_str, e),
668 })?;
669
670 let location =
672 find_location(&content, location_str).map_err(|e| ToolError::ExecutionFailed {
673 name: "smart_edit".into(),
674 message: e,
675 })?;
676
677 debug!(
678 "smart_edit: matched at line {} ({} bytes)",
679 location.line_number,
680 location.matched_text.len()
681 );
682
683 let new_content = apply_edit(&content, &location, edit_type, new_text);
685
686 let diff = generate_diff(path_str, &content, &new_content);
688
689 let checkpoint_result = {
691 let mut mgr = self.checkpoint_mgr.lock().await;
692 mgr.create_checkpoint(&format!("before smart_edit on {}", path_str))
693 };
694
695 if let Err(e) = &checkpoint_result {
696 debug!("Checkpoint creation failed (non-fatal): {}", e);
697 }
698
699 tokio::fs::write(&path, &new_content)
701 .await
702 .map_err(|e| ToolError::ExecutionFailed {
703 name: "smart_edit".into(),
704 message: format!("Failed to write '{}': {}", path_str, e),
705 })?;
706
707 let edit_desc = match edit_type {
709 EditType::Replace => "replaced",
710 EditType::InsertAfter => "inserted after",
711 EditType::InsertBefore => "inserted before",
712 EditType::Delete => "deleted",
713 };
714
715 let checkpoint_note = if checkpoint_result.is_ok() {
716 " (checkpoint created, use /undo to revert)"
717 } else {
718 ""
719 };
720
721 let summary = format!(
722 "Edited '{}': {} at line {}{}\n\nDiff:\n{}",
723 path_str, edit_desc, location.line_number, checkpoint_note, diff
724 );
725
726 let mut output = ToolOutput::text(summary);
727 output.artifacts.push(Artifact::FileModified {
728 path: PathBuf::from(path_str),
729 diff,
730 });
731
732 Ok(output)
733 }
734
735 fn risk_level(&self) -> RiskLevel {
736 RiskLevel::Write
737 }
738}
739
// Unit tests for the locators, edit application, diffing and parsing helpers, plus
// end-to-end `execute` tests against a temporary workspace.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    // Exact text wins and reports the 1-based line where it occurs.
    #[test]
    fn test_find_location_exact() {
        let content = "fn main() {\n println!(\"hello\");\n}\n";
        let loc = find_location(content, "println!(\"hello\")").unwrap();
        assert_eq!(loc.line_number, 2);
        assert_eq!(loc.matched_text, "println!(\"hello\")");
    }

    // "line N" selects a whole line when the pattern is not found verbatim.
    #[test]
    fn test_find_location_line_number() {
        let content = "line one\nline two\nline three\n";
        let loc = find_location(content, "line 2").unwrap();
        assert_eq!(loc.line_number, 2);
        assert!(loc.matched_text.contains("line two"));
    }

    // "lines N-M" selects an inclusive range of lines.
    #[test]
    fn test_find_location_line_range() {
        let content = "a\nb\nc\nd\ne\n";
        let loc = find_location(content, "lines 2-4").unwrap();
        assert_eq!(loc.line_number, 2);
        assert!(loc.matched_text.contains('b'));
        assert!(loc.matched_text.contains('c'));
        assert!(loc.matched_text.contains('d'));
    }

    // A definition reference resolves to the line declaring it.
    #[test]
    fn test_find_location_function_pattern() {
        let content = "use std::io;\n\nfn handle_request(req: Request) {\n process(req);\n}\n\nfn main() {}\n";
        let loc = find_location(content, "fn handle_request").unwrap();
        assert_eq!(loc.line_number, 3);
        assert!(loc.matched_text.contains("handle_request"));
    }

    // Free text that matches nothing exactly falls through to fuzzy line scoring.
    #[test]
    fn test_find_location_fuzzy() {
        let content = "struct Config {\n timeout: u64,\n retries: usize,\n}\n";
        let loc = find_location(content, "timeout field").unwrap();
        assert!(loc.matched_text.contains("timeout"));
    }

    // No strategy matching yields an error rather than an arbitrary location.
    #[test]
    fn test_find_location_not_found() {
        let content = "hello world\n";
        let result = find_location(content, "nonexistent_xyz_123");
        assert!(result.is_err());
    }

    #[test]
    fn test_apply_edit_replace() {
        let content = "fn old_name() {}\n";
        let loc = find_location(content, "old_name").unwrap();
        let result = apply_edit(content, &loc, EditType::Replace, "new_name");
        assert!(result.contains("new_name"));
        assert!(!result.contains("old_name"));
    }

    // Inserting after a mid-line match places the new text on the following line.
    #[test]
    fn test_apply_edit_insert_after() {
        let content = "use std::io;\n\nfn main() {}\n";
        let loc = find_location(content, "use std::io;").unwrap();
        let result = apply_edit(content, &loc, EditType::InsertAfter, "use std::fs;");
        assert!(result.contains("use std::io;\nuse std::fs;"));
    }

    #[test]
    fn test_apply_edit_insert_before() {
        let content = "fn main() {}\n";
        let loc = find_location(content, "fn main").unwrap();
        let result = apply_edit(content, &loc, EditType::InsertBefore, "// Entry point\n");
        assert!(result.starts_with("// Entry point\n"));
    }

    #[test]
    fn test_apply_edit_delete() {
        let content = "line1\nline2\nline3\n";
        let loc = find_location(content, "line2").unwrap();
        let result = apply_edit(content, &loc, EditType::Delete, "");
        assert!(!result.contains("line2"));
        assert!(result.contains("line1"));
        assert!(result.contains("line3"));
    }

    // The diff carries a/ and b/ headers and -/+ lines for the changed line.
    #[test]
    fn test_generate_diff() {
        let old = "line1\nline2\nline3\n";
        let new = "line1\nmodified\nline3\n";
        let diff = generate_diff("test.rs", old, new);
        assert!(diff.contains("--- a/test.rs"));
        assert!(diff.contains("+++ b/test.rs"));
        assert!(diff.contains("-line2"));
        assert!(diff.contains("+modified"));
    }

    #[test]
    fn test_similarity_score() {
        // Identical strings score 1.0.
        let a = "handle_request";
        let b = "handle_request";
        assert!((similarity_score(a, b) - 1.0).abs() < 0.01);

        // Strings sharing a prefix score fairly high...
        let c = "handle_response";
        let score = similarity_score(a, c);
        assert!(score > 0.3);
        // ...and an unrelated string scores lower still.
        let d = "totally_different_thing";
        let score2 = similarity_score(a, d);
        assert!(score2 < score);
    }

    // Case-insensitive parsing, hyphen/underscore spellings, and the "remove" alias.
    #[test]
    fn test_edit_type_from_str() {
        assert_eq!(EditType::from_str("replace"), Some(EditType::Replace));
        assert_eq!(
            EditType::from_str("insert_after"),
            Some(EditType::InsertAfter)
        );
        assert_eq!(
            EditType::from_str("insert-before"),
            Some(EditType::InsertBefore)
        );
        assert_eq!(EditType::from_str("delete"), Some(EditType::Delete));
        assert_eq!(EditType::from_str("remove"), Some(EditType::Delete));
        assert_eq!(EditType::from_str("unknown"), None);
    }

    #[test]
    fn test_parse_line_pattern() {
        assert_eq!(parse_line_pattern("line 42"), Some((42, 42)));
        assert_eq!(parse_line_pattern("lines 10-20"), Some((10, 20)));
        assert_eq!(parse_line_pattern("not a line pattern"), None);
    }

    // Keywords and filler words are skipped in favour of the identifier.
    #[test]
    fn test_extract_identifier() {
        assert_eq!(
            extract_identifier_from_pattern("fn handle_request"),
            "handle_request"
        );
        assert_eq!(
            extract_identifier_from_pattern("the process_data function"),
            "process_data"
        );
    }

    // A brace-delimited block ends at its matching closing brace, not at the next item.
    #[test]
    fn test_find_block_end_braces() {
        let content = "fn foo() {\n bar();\n baz();\n}\nfn next() {}";
        let end = find_block_end(content, 0);
        let block = &content[0..end];
        assert!(block.contains("baz();"));
        assert!(block.ends_with('}'));
    }

    #[test]
    fn test_truncate() {
        assert_eq!(truncate("short", 10), "short");
        assert_eq!(truncate("a long string here", 10), "a long ...");
    }

    // End-to-end replace through `execute`, checking both the summary and the file on disk.
    #[tokio::test]
    async fn test_smart_edit_tool_execute_replace() {
        let dir = TempDir::new().unwrap();
        let workspace = dir.path().to_path_buf();

        // The workspace is made a git repository with an initial commit. NOTE(review):
        // presumably so CheckpointManager can create a checkpoint; checkpoint failure is
        // non-fatal in `execute`, so the assertions below do not depend on it.
        git2::Repository::init(&workspace).unwrap();

        // Seed the file that will be edited.
        fs::write(
            workspace.join("test.rs"),
            "fn old_name() {\n // body\n}\n",
        )
        .unwrap();

        // Stage everything and create the initial commit.
        let repo = git2::Repository::open(&workspace).unwrap();
        let mut index = repo.index().unwrap();
        index
            .add_all(["*"].iter(), git2::IndexAddOption::DEFAULT, None)
            .unwrap();
        index.write().unwrap();
        let tree_oid = index.write_tree().unwrap();
        let tree = repo.find_tree(tree_oid).unwrap();
        let sig = git2::Signature::now("test", "test@test.com").unwrap();
        repo.commit(Some("HEAD"), &sig, &sig, "init", &tree, &[])
            .unwrap();

        let tool = SmartEditTool::new(workspace.clone());

        let args = serde_json::json!({
            "path": "test.rs",
            "location": "old_name",
            "edit_type": "replace",
            "new_text": "new_name"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.content.contains("Edited"));
        assert!(result.content.contains("replaced"));

        // The edit must actually have been written to disk.
        let content = fs::read_to_string(workspace.join("test.rs")).unwrap();
        assert!(content.contains("new_name"));
        assert!(!content.contains("old_name"));
    }

    // End-to-end delete of an exactly matched line; `new_text` may be omitted for delete.
    #[tokio::test]
    async fn test_smart_edit_tool_execute_delete() {
        let dir = TempDir::new().unwrap();
        let workspace = dir.path().to_path_buf();

        git2::Repository::init(&workspace).unwrap();
        fs::write(
            workspace.join("test.txt"),
            "keep this\ndelete this line\nkeep this too\n",
        )
        .unwrap();

        let repo = git2::Repository::open(&workspace).unwrap();
        let mut index = repo.index().unwrap();
        index
            .add_all(["*"].iter(), git2::IndexAddOption::DEFAULT, None)
            .unwrap();
        index.write().unwrap();
        let tree_oid = index.write_tree().unwrap();
        let tree = repo.find_tree(tree_oid).unwrap();
        let sig = git2::Signature::now("test", "test@test.com").unwrap();
        repo.commit(Some("HEAD"), &sig, &sig, "init", &tree, &[])
            .unwrap();

        let tool = SmartEditTool::new(workspace.clone());

        let args = serde_json::json!({
            "path": "test.txt",
            "location": "delete this line",
            "edit_type": "delete"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.content.contains("deleted"));

        let content = fs::read_to_string(workspace.join("test.txt")).unwrap();
        assert!(!content.contains("delete this line"));
        assert!(content.contains("keep this"));
    }

    // "line 2" also occurs verbatim in this file, so the exact-text strategy resolves it
    // before the line-number strategy is tried; either way the second line is replaced.
    #[tokio::test]
    async fn test_smart_edit_tool_line_number() {
        let dir = TempDir::new().unwrap();
        let workspace = dir.path().to_path_buf();

        git2::Repository::init(&workspace).unwrap();
        fs::write(workspace.join("test.txt"), "line 1\nline 2\nline 3\n").unwrap();

        let repo = git2::Repository::open(&workspace).unwrap();
        let mut index = repo.index().unwrap();
        index
            .add_all(["*"].iter(), git2::IndexAddOption::DEFAULT, None)
            .unwrap();
        index.write().unwrap();
        let tree_oid = index.write_tree().unwrap();
        let tree = repo.find_tree(tree_oid).unwrap();
        let sig = git2::Signature::now("test", "test@test.com").unwrap();
        repo.commit(Some("HEAD"), &sig, &sig, "init", &tree, &[])
            .unwrap();

        let tool = SmartEditTool::new(workspace.clone());

        let args = serde_json::json!({
            "path": "test.txt",
            "location": "line 2",
            "edit_type": "replace",
            "new_text": "replaced line\n"
        });

        let result = tool.execute(args).await.unwrap();
        assert!(result.content.contains("replaced"));

        let content = fs::read_to_string(workspace.join("test.txt")).unwrap();
        assert!(content.contains("replaced line"));
        assert!(!content.contains("line 2"));
    }

    // Argument validation rejects a replace without `new_text` before touching the filesystem.
    #[tokio::test]
    async fn test_smart_edit_tool_missing_new_text() {
        let dir = TempDir::new().unwrap();
        let workspace = dir.path().to_path_buf();
        let tool = SmartEditTool::new(workspace);

        let args = serde_json::json!({
            "path": "test.txt",
            "location": "something",
            "edit_type": "replace"
        });

        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    // An unrecognised edit_type is rejected during argument validation.
    #[tokio::test]
    async fn test_smart_edit_tool_invalid_edit_type() {
        let dir = TempDir::new().unwrap();
        let workspace = dir.path().to_path_buf();
        let tool = SmartEditTool::new(workspace);

        let args = serde_json::json!({
            "path": "test.txt",
            "location": "something",
            "edit_type": "invalid_op",
            "new_text": "hello"
        });

        let result = tool.execute(args).await;
        assert!(result.is_err());
    }
}