Skip to main content

st/mcp/
smart_edit.rs

1//! Smart Edit Tools - Revolutionary token-efficient code editing
2//! By Aye, with inspiration from Omni's wave patterns
3//!
4//! "Why send entire diffs when you can send intentions?" - Aye
5
6use anyhow::{Context, Result};
7use serde::{Deserialize, Serialize};
8use serde_json::{json, Value};
9use std::path::Path;
10use tree_sitter::{Node, Parser};
11
/// Supported languages with their tree-sitter parsers.
///
/// A variant is chosen from a file extension by `SupportedLanguage::from_extension`
/// and turned into a configured parser by `SupportedLanguage::get_parser`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SupportedLanguage {
    /// Rust (`tree_sitter_rust`).
    Rust,
    /// Python (`tree_sitter_python`).
    Python,
    /// JavaScript (`tree_sitter_javascript`).
    JavaScript,
    /// TypeScript (`tree_sitter_typescript`, TypeScript dialect).
    TypeScript,
    /// Go (`tree_sitter_go`).
    Go,
    /// Java (`tree_sitter_java`).
    Java,
    /// C# (`tree_sitter_c_sharp`).
    CSharp,
    /// C++ (`tree_sitter_cpp`); `.h` headers are parsed with this grammar too.
    Cpp,
    /// Ruby (`tree_sitter_ruby`).
    Ruby,
}
25
26impl SupportedLanguage {
27    fn from_extension(ext: &str) -> Option<Self> {
28        match ext {
29            "rs" => Some(Self::Rust),
30            "py" => Some(Self::Python),
31            "js" | "mjs" => Some(Self::JavaScript),
32            "ts" | "tsx" => Some(Self::TypeScript),
33            "go" => Some(Self::Go),
34            "java" => Some(Self::Java),
35            "cs" => Some(Self::CSharp),
36            "cpp" | "cc" | "cxx" | "hpp" | "h" => Some(Self::Cpp),
37            "rb" => Some(Self::Ruby),
38            _ => None,
39        }
40    }
41
42    fn get_parser(&self) -> Result<Parser> {
43        use tree_sitter_language::LanguageFn;
44
45        let mut parser = Parser::new();
46        let language_fn: LanguageFn = match self {
47            Self::Rust => tree_sitter_rust::LANGUAGE,
48            Self::Python => tree_sitter_python::LANGUAGE,
49            Self::JavaScript => tree_sitter_javascript::LANGUAGE,
50            Self::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT,
51            Self::Go => tree_sitter_go::LANGUAGE,
52            Self::Java => tree_sitter_java::LANGUAGE,
53            Self::CSharp => tree_sitter_c_sharp::LANGUAGE,
54            Self::Cpp => tree_sitter_cpp::LANGUAGE,
55            Self::Ruby => tree_sitter_ruby::LANGUAGE,
56        };
57        let language = language_fn.into();
58        parser.set_language(&language)?;
59        Ok(parser)
60    }
61}
62
/// Smart edit operations that use minimal tokens.
///
/// Deserialized from JSON objects whose `"operation"` field names the variant
/// (`#[serde(tag = "operation")]`), for example
/// `{"operation": "AddImport", "import": "std::fmt"}`.
///
/// NOTE(review): `SmartEditor::apply_edit` handles `InsertFunction`,
/// `ReplaceFunction`, `AddImport`, `SmartAppend` and `RemoveFunction`; every
/// other variant is rejected with "Operation not yet implemented".
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "operation")]
pub enum SmartEdit {
    /// Insert a function at the appropriate location.
    InsertFunction {
        /// Name of the new function.
        name: String,
        /// Class to insert into when adding a method; also scopes the
        /// `after` / `before` lookups to that class.
        #[serde(skip_serializing_if = "Option::is_none")]
        class_name: Option<String>,
        /// Namespace hint; ignored by `SmartEditor::apply_edit`.
        #[serde(skip_serializing_if = "Option::is_none")]
        namespace: Option<String>,
        /// Source placed directly after `name` by the formatter, so it starts
        /// with the parameter list.
        body: String,
        /// Insert after this existing function; checked before `before`.
        #[serde(skip_serializing_if = "Option::is_none")]
        after: Option<String>,
        /// Insert before this existing function.
        #[serde(skip_serializing_if = "Option::is_none")]
        before: Option<String>,
        /// Visibility keyword; an empty string when omitted.
        #[serde(default)]
        visibility: String, // public, private, protected
    },

    /// Replace a function body (keeps signature).
    ReplaceFunction {
        /// Name of the function whose body is replaced.
        name: String,
        /// Owning class, used to tell methods with the same name apart.
        #[serde(skip_serializing_if = "Option::is_none")]
        class_name: Option<String>,
        /// Replacement body text.
        new_body: String,
    },

    /// Add imports/use statements intelligently.
    AddImport {
        /// Module / path to import (rendered in the target language's syntax).
        import: String,
        /// Optional local alias for the import.
        #[serde(skip_serializing_if = "Option::is_none")]
        alias: Option<String>,
    },

    /// Insert a class/struct (not yet implemented by `apply_edit`).
    InsertClass {
        /// Name of the new class.
        name: String,
        /// Namespace to place the class in.
        #[serde(skip_serializing_if = "Option::is_none")]
        namespace: Option<String>,
        /// Class body source.
        body: String,
        /// Base class, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        extends: Option<String>,
        /// Implemented interfaces/traits; empty when omitted.
        #[serde(default)]
        implements: Vec<String>,
    },

    /// Add a method to a class (not yet implemented by `apply_edit`).
    AddMethod {
        /// Class receiving the method.
        class_name: String,
        /// Name of the new method.
        method_name: String,
        /// Method source.
        body: String,
        /// Visibility keyword; an empty string when omitted.
        #[serde(default)]
        visibility: String,
    },

    /// Wrap code in a construct (try-catch, if statement, etc.); not yet
    /// implemented by `apply_edit`.
    WrapCode {
        /// First line to wrap (numbering base not fixed by visible code —
        /// presumably 1-based like `FunctionInfo`; confirm).
        start_line: usize,
        /// Last line to wrap.
        end_line: usize,
        wrapper_type: String, // "try", "if", "while", "for"
        /// Condition for `if` / `while` wrappers.
        #[serde(skip_serializing_if = "Option::is_none")]
        condition: Option<String>,
    },

    /// Delete a named element (not yet implemented by `apply_edit`).
    DeleteElement {
        element_type: String, // "function", "class", "method"
        /// Name of the element to delete.
        name: String,
        /// Enclosing element, used to disambiguate.
        #[serde(skip_serializing_if = "Option::is_none")]
        parent: Option<String>,
    },

    /// Rename across the file (not yet implemented by `apply_edit`).
    Rename {
        /// Current identifier.
        old_name: String,
        /// Replacement identifier.
        new_name: String,
        /// Scope of the rename; an empty string when omitted.
        #[serde(default)]
        scope: String, // "global", "class", "function"
    },

    /// Add documentation comment (not yet implemented by `apply_edit`).
    AddDocumentation {
        target_type: String, // "function", "class", "method"
        /// Name of the documented element.
        target_name: String,
        /// Documentation text.
        documentation: String,
    },

    /// Smart append - adds to the end of a logical section.
    SmartAppend {
        /// Any other value appends at the end of the file.
        section: String, // "imports", "functions", "classes", "main"
        /// Source text to insert.
        content: String,
    },

    /// Remove a function with dependency awareness.
    RemoveFunction {
        /// Name of the function to remove.
        name: String,
        /// Owning class, used to tell methods with the same name apart.
        #[serde(skip_serializing_if = "Option::is_none")]
        class_name: Option<String>,
        /// Skip the "still has callers" check.
        #[serde(default)]
        force: bool, // Remove even if it would break dependencies
        /// Also remove callees whose only caller is this function.
        #[serde(default)]
        cascade: bool, // Also remove functions that only this one calls
    },
}
168
/// Function information for the function tree.
///
/// One entry per function or method found while walking the parse tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionInfo {
    /// Declared function name.
    pub name: String,
    /// First line of the definition (1-based).
    pub start_line: usize,
    /// Last line of the definition (1-based, inclusive).
    pub end_line: usize,
    /// Trimmed source text from the start of the definition up to its body
    /// (the whole definition when no body node is found).
    pub signature: String,
    /// Enclosing class name when the function is a method.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub class_name: Option<String>,
    /// Enclosing namespace; the extractor does not fill this in yet (always `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub namespace: Option<String>,
    /// Visibility reported by the extractor (`"private"` when no modifier is found).
    pub visibility: String,
    /// Names of functions this one calls (not yet populated by the extractor).
    #[serde(default)]
    pub calls: Vec<String>,
    /// Names of functions calling this one (not yet populated by the extractor).
    #[serde(default)]
    pub called_by: Vec<String>,
}
186
/// Code structure representation.
///
/// Summary of one parsed file, rebuilt by `SmartEditor` after every edit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeStructure {
    /// Language name (the `Debug` form of `SupportedLanguage`).
    pub language: String,
    /// Source text of each import/use statement found in the tree.
    pub imports: Vec<String>,
    /// Every function and method, in tree (pre-order) order.
    pub functions: Vec<FunctionInfo>,
    /// Classes (and Rust structs) with their methods.
    pub classes: Vec<ClassInfo>,
    /// `Some("main")` when a function named `main` exists.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub main_function: Option<String>,
    /// Number of lines in the analysed content.
    pub line_count: usize,
    /// Call relationships between functions.
    /// NOTE(review): no visible code populates this yet, so it stays empty.
    #[serde(default)]
    pub dependencies: DependencyGraph,
}
200
/// Dependency graph for tracking function relationships.
///
/// Both maps are keyed by function name; `RemoveFunction` consults them for its
/// `force` and `cascade` behaviour.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DependencyGraph {
    /// Map from function name to functions it calls
    pub calls: std::collections::HashMap<String, Vec<String>>,
    /// Map from function name to functions that call it
    pub called_by: std::collections::HashMap<String, Vec<String>>,
}
209
/// Class (or Rust struct) information extracted from the parse tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassInfo {
    /// Declared class name.
    pub name: String,
    /// First line of the definition (1-based).
    pub start_line: usize,
    /// Last line of the definition (1-based, inclusive).
    pub end_line: usize,
    /// Base class; not extracted yet (always `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub extends: Option<String>,
    /// Implemented interfaces; not extracted yet (always empty).
    #[serde(default)]
    pub implements: Vec<String>,
    /// Methods found inside the class body.
    pub methods: Vec<FunctionInfo>,
}
221
222/// Smart editor that understands code structure
223pub struct SmartEditor {
224    content: String,
225    language: SupportedLanguage,
226    parser: Parser,
227    tree: Option<tree_sitter::Tree>,
228    structure: Option<CodeStructure>,
229}
230
231impl SmartEditor {
232    pub fn new(content: String, language: SupportedLanguage) -> Result<Self> {
233        let mut parser = language.get_parser()?;
234        let tree = parser.parse(&content, None);
235
236        let mut editor = Self {
237            content,
238            language,
239            parser,
240            tree,
241            structure: None,
242        };
243
244        editor.analyze_structure()?;
245        Ok(editor)
246    }
247
248    /// Analyze code structure to build a map
249    fn analyze_structure(&mut self) -> Result<()> {
250        let tree = self.tree.as_ref().context("No parse tree available")?;
251        let root = tree.root_node();
252
253        let mut structure = CodeStructure {
254            language: format!("{:?}", self.language),
255            imports: Vec::new(),
256            functions: Vec::new(),
257            classes: Vec::new(),
258            main_function: None,
259            line_count: self.content.lines().count(),
260            dependencies: DependencyGraph::default(),
261        };
262
263        // Walk the tree and extract structure
264        self.walk_node(&root, &mut structure, None)?;
265
266        self.structure = Some(structure);
267        Ok(())
268    }
269
270    fn walk_node(
271        &self,
272        node: &Node,
273        structure: &mut CodeStructure,
274        current_class: Option<&str>,
275    ) -> Result<()> {
276        match node.kind() {
277            // Rust patterns
278            "use_declaration" => {
279                if let Some(text) = self.node_text(node) {
280                    structure.imports.push(text);
281                }
282            }
283            "function_item"
284            | "method_definition"
285            | "function_definition"
286            | "function_declaration" => {
287                if let Some(func_info) = self.extract_function_info(node, current_class) {
288                    if func_info.name == "main" {
289                        structure.main_function = Some(func_info.name.clone());
290                    }
291                    structure.functions.push(func_info);
292                }
293            }
294            "struct_item" | "class_definition" | "class_declaration" => {
295                if let Some(class_info) = self.extract_class_info(node) {
296                    structure.classes.push(class_info);
297                }
298            }
299            // Python patterns
300            "import_statement" | "import_from_statement" => {
301                if let Some(text) = self.node_text(node) {
302                    structure.imports.push(text);
303                }
304            }
305            _ => {}
306        }
307
308        // Handle class context for methods
309        let class_name = match node.kind() {
310            "class_definition" | "class_declaration" => {
311                // Extract class name for method context
312                self.find_child_by_kind(node, "identifier")
313                    .and_then(|n| self.node_text(&n))
314            }
315            _ => None,
316        };
317
318        let class_context = class_name.as_deref().or(current_class);
319
320        // Recurse through children
321        for child in node.children(&mut node.walk()) {
322            self.walk_node(&child, structure, class_context)?;
323        }
324
325        Ok(())
326    }
327
328    fn node_text(&self, node: &Node) -> Option<String> {
329        node.utf8_text(self.content.as_bytes())
330            .ok()
331            .map(|s| s.to_string())
332    }
333
334    fn extract_function_info(&self, node: &Node, class_name: Option<&str>) -> Option<FunctionInfo> {
335        let name = self
336            .find_child_by_kind(node, "identifier")
337            .or_else(|| self.find_child_by_kind(node, "property_identifier"))
338            .and_then(|n| self.node_text(&n))?;
339
340        let start_line = node.start_position().row + 1;
341        let end_line = node.end_position().row + 1;
342
343        let signature = self.extract_signature(node)?;
344
345        Some(FunctionInfo {
346            name,
347            start_line,
348            end_line,
349            signature,
350            class_name: class_name.map(String::from),
351            namespace: None, // TODO: Extract namespace
352            visibility: self.extract_visibility(node),
353            calls: Vec::new(), // TODO: Extract function calls
354            called_by: Vec::new(),
355        })
356    }
357
358    fn extract_class_info(&self, node: &Node) -> Option<ClassInfo> {
359        let name = self
360            .find_child_by_kind(node, "identifier")
361            .or_else(|| self.find_child_by_kind(node, "type_identifier"))
362            .and_then(|n| self.node_text(&n))?;
363
364        let start_line = node.start_position().row + 1;
365        let end_line = node.end_position().row + 1;
366
367        let mut methods = Vec::new();
368        self.extract_methods(node, &name, &mut methods);
369
370        Some(ClassInfo {
371            name,
372            start_line,
373            end_line,
374            extends: None, // TODO: Extract inheritance
375            implements: Vec::new(),
376            methods,
377        })
378    }
379
380    fn extract_methods(&self, node: &Node, class_name: &str, methods: &mut Vec<FunctionInfo>) {
381        for child in node.children(&mut node.walk()) {
382            if matches!(child.kind(), "method_definition" | "function_item") {
383                if let Some(method_info) = self.extract_function_info(&child, Some(class_name)) {
384                    methods.push(method_info);
385                }
386            } else if child.kind().contains("body") {
387                self.extract_methods(&child, class_name, methods);
388            }
389        }
390    }
391
392    fn find_child_by_kind<'a>(&self, node: &'a Node, kind: &str) -> Option<Node<'a>> {
393        node.children(&mut node.walk()).find(|n| n.kind() == kind)
394    }
395
396    fn extract_signature(&self, node: &Node) -> Option<String> {
397        // Simple extraction - can be enhanced per language
398        let start = node.start_byte();
399        let body_start = self
400            .find_child_by_kind(node, "block")
401            .or_else(|| self.find_child_by_kind(node, "body"))
402            .map(|n| n.start_byte())
403            .unwrap_or(node.end_byte());
404
405        self.content
406            .as_bytes()
407            .get(start..body_start)
408            .and_then(|bytes| std::str::from_utf8(bytes).ok())
409            .map(|s| s.trim().to_string())
410    }
411
412    fn extract_visibility(&self, node: &Node) -> String {
413        // Look for visibility modifiers
414        for child in node.children(&mut node.walk()) {
415            match child.kind() {
416                "visibility_modifier" => {
417                    if let Some(text) = self.node_text(&child) {
418                        return text;
419                    }
420                }
421                "pub" => return "public".to_string(),
422                "private" => return "private".to_string(),
423                "protected" => return "protected".to_string(),
424                _ => {}
425            }
426        }
427        "private".to_string() // Default
428    }
429
430    /// Apply a smart edit operation
431    pub fn apply_edit(&mut self, edit: &SmartEdit) -> Result<String> {
432        match edit {
433            SmartEdit::InsertFunction {
434                name,
435                class_name,
436                body,
437                after,
438                before,
439                visibility,
440                ..
441            } => {
442                self.insert_function(
443                    name,
444                    class_name.as_deref(),
445                    body,
446                    after.as_deref(),
447                    before.as_deref(),
448                    visibility,
449                )?;
450            }
451            SmartEdit::ReplaceFunction {
452                name,
453                class_name,
454                new_body,
455            } => {
456                self.replace_function(name, class_name.as_deref(), new_body)?;
457            }
458            SmartEdit::AddImport { import, alias } => {
459                self.add_import(import, alias.as_deref())?;
460            }
461            SmartEdit::SmartAppend { section, content } => {
462                self.smart_append(section, content)?;
463            }
464            SmartEdit::RemoveFunction {
465                name,
466                class_name,
467                force,
468                cascade,
469            } => {
470                self.remove_function(name, class_name.as_deref(), *force, *cascade)?;
471            }
472            _ => {
473                return Err(anyhow::anyhow!("Operation not yet implemented"));
474            }
475        }
476
477        // Re-analyze structure after edit
478        self.tree = self.parser.parse(&self.content, None);
479        self.analyze_structure()?;
480
481        Ok(self.content.clone())
482    }
483
484    fn insert_function(
485        &mut self,
486        name: &str,
487        class_name: Option<&str>,
488        body: &str,
489        after: Option<&str>,
490        before: Option<&str>,
491        visibility: &str,
492    ) -> Result<()> {
493        let structure = self.structure.as_ref().context("No structure analyzed")?;
494
495        // Find insertion point
496        let insert_line = if let Some(after_name) = after {
497            // Insert after specified function
498            structure
499                .functions
500                .iter()
501                .find(|f| f.name == after_name && f.class_name.as_deref() == class_name)
502                .map(|f| f.end_line + 1)
503                .with_context(|| format!("Function not found: {}", after_name))?
504        } else if let Some(before_name) = before {
505            // Insert before specified function
506            structure
507                .functions
508                .iter()
509                .find(|f| f.name == before_name && f.class_name.as_deref() == class_name)
510                .map(|f| f.start_line.saturating_sub(1))
511                .with_context(|| format!("Function not found: {}", before_name))?
512        } else if let Some(class) = class_name {
513            // Insert at end of class
514            structure
515                .classes
516                .iter()
517                .find(|c| c.name == class)
518                .map(|c| {
519                    // Find last method or class end
520                    c.methods
521                        .iter()
522                        .map(|m| m.end_line)
523                        .max()
524                        .unwrap_or(c.start_line)
525                        + 1
526                })
527                .context("Class not found: {class}")?
528        } else {
529            // Insert at end of file functions
530            structure
531                .functions
532                .iter()
533                .filter(|f| f.class_name.is_none())
534                .map(|f| f.end_line)
535                .max()
536                .unwrap_or(structure.imports.len() + 1)
537                + 1
538        };
539
540        // Format function based on language
541        let formatted_function =
542            self.format_function(name, body, visibility, class_name.is_some())?;
543
544        // Insert at the calculated position
545        let lines: Vec<&str> = self.content.lines().collect();
546        let mut new_lines: Vec<String> = Vec::new();
547
548        for (i, line) in lines.iter().enumerate() {
549            new_lines.push(line.to_string());
550            if i + 1 == insert_line {
551                new_lines.push(String::new());
552                new_lines.push(formatted_function.clone());
553            }
554        }
555
556        // Handle case where we want to insert at the very end
557        if insert_line > lines.len() {
558            new_lines.push(String::new());
559            new_lines.push(formatted_function);
560        }
561
562        self.content = new_lines.join("\n");
563        Ok(())
564    }
565
566    fn format_function(
567        &self,
568        name: &str,
569        body: &str,
570        visibility: &str,
571        is_method: bool,
572    ) -> Result<String> {
573        // Format based on language
574        let formatted = match self.language {
575            SupportedLanguage::Rust => {
576                let vis = if visibility == "public" { "pub " } else { "" };
577                let indent = if is_method { "    " } else { "" };
578                format!("{indent}{vis}fn {name}{body}")
579            }
580            SupportedLanguage::Python => {
581                let indent = if is_method { "    " } else { "" };
582                format!("{indent}def {name}{body}")
583            }
584            SupportedLanguage::JavaScript | SupportedLanguage::TypeScript => {
585                let indent = if is_method { "  " } else { "" };
586                format!("{indent}function {name}{body}")
587            }
588            _ => {
589                format!("{visibility} function {name}{body}")
590            }
591        };
592
593        Ok(formatted)
594    }
595
596    fn replace_function(
597        &mut self,
598        name: &str,
599        class_name: Option<&str>,
600        new_body: &str,
601    ) -> Result<()> {
602        let structure = self.structure.as_ref().context("No structure analyzed")?;
603
604        let function = structure
605            .functions
606            .iter()
607            .find(|f| f.name == name && f.class_name.as_deref() == class_name)
608            .context("Function not found")?;
609
610        // Find the function body start (after signature)
611        let lines: Vec<&str> = self.content.lines().collect();
612        let signature_line = function.start_line - 1;
613
614        // TODO: More robust body detection
615        let body_start_line = signature_line + 1;
616        let body_end_line = function.end_line - 1;
617
618        // Replace the body
619        let mut new_lines: Vec<String> = Vec::new();
620        for (i, line) in lines.iter().enumerate() {
621            if i < body_start_line || i > body_end_line {
622                new_lines.push(line.to_string());
623            } else if i == body_start_line {
624                new_lines.push(new_body.to_string());
625            }
626        }
627
628        self.content = new_lines.join("\n");
629        Ok(())
630    }
631
632    fn add_import(&mut self, import: &str, alias: Option<&str>) -> Result<()> {
633        let structure = self.structure.as_ref().context("No structure analyzed")?;
634
635        // Format import based on language
636        let formatted_import = match self.language {
637            SupportedLanguage::Rust => {
638                if let Some(alias) = alias {
639                    format!("use {import} as {alias};")
640                } else {
641                    format!("use {import};")
642                }
643            }
644            SupportedLanguage::Python => {
645                if let Some(alias) = alias {
646                    format!("import {import} as {alias}")
647                } else {
648                    format!("import {import}")
649                }
650            }
651            SupportedLanguage::JavaScript | SupportedLanguage::TypeScript => {
652                // For now, use CommonJS style which is more common
653                if let Some(a) = alias {
654                    format!("const {} = require('{}');", a, import)
655                } else {
656                    format!("const {} = require('{}');", import, import)
657                }
658            }
659            _ => format!("import {import};"),
660        };
661
662        // Find where to insert (after last import or at top)
663        let insert_line = if structure.imports.is_empty() {
664            1
665        } else {
666            structure.imports.len() + 1
667        };
668
669        let lines: Vec<&str> = self.content.lines().collect();
670        let mut new_lines: Vec<String> = Vec::new();
671
672        for (i, line) in lines.iter().enumerate() {
673            if i + 1 == insert_line {
674                new_lines.push(formatted_import.clone());
675            }
676            new_lines.push(line.to_string());
677        }
678
679        self.content = new_lines.join("\n");
680        Ok(())
681    }
682
683    fn smart_append(&mut self, section: &str, content: &str) -> Result<()> {
684        let structure = self.structure.as_ref().context("No structure analyzed")?;
685
686        let insert_line = match section {
687            "imports" => structure.imports.len() + 1,
688            "functions" => {
689                structure
690                    .functions
691                    .iter()
692                    .filter(|f| f.class_name.is_none())
693                    .map(|f| f.end_line)
694                    .max()
695                    .unwrap_or(structure.imports.len() + 1)
696                    + 1
697            }
698            "classes" => {
699                structure
700                    .classes
701                    .iter()
702                    .map(|c| c.end_line)
703                    .max()
704                    .unwrap_or_else(|| {
705                        structure
706                            .functions
707                            .iter()
708                            .map(|f| f.end_line)
709                            .max()
710                            .unwrap_or(structure.imports.len() + 1)
711                    })
712                    + 1
713            }
714            "main" => {
715                if let Some(main_fn) = &structure.main_function {
716                    structure
717                        .functions
718                        .iter()
719                        .find(|f| &f.name == main_fn)
720                        .map(|f| f.end_line - 1)
721                        .unwrap_or(structure.line_count)
722                } else {
723                    structure.line_count
724                }
725            }
726            _ => structure.line_count,
727        };
728
729        let lines: Vec<&str> = self.content.lines().collect();
730        let mut new_lines: Vec<String> = Vec::new();
731
732        for (i, line) in lines.iter().enumerate() {
733            new_lines.push(line.to_string());
734            if i + 1 == insert_line {
735                new_lines.push(String::new());
736                new_lines.push(content.to_string());
737            }
738        }
739
740        self.content = new_lines.join("\n");
741        Ok(())
742    }
743
744    /// Get the current code structure
745    pub fn get_structure(&self) -> Option<&CodeStructure> {
746        self.structure.as_ref()
747    }
748
749    fn remove_function(
750        &mut self,
751        name: &str,
752        class_name: Option<&str>,
753        force: bool,
754        cascade: bool,
755    ) -> Result<()> {
756        // Extract data we need before borrowing self mutably
757        let (function_start, function_end, functions_to_cascade) = {
758            let structure = self.structure.as_ref().context("No structure analyzed")?;
759
760            // Find the function to remove
761            let function = structure
762                .functions
763                .iter()
764                .find(|f| f.name == name && f.class_name.as_deref() == class_name)
765                .context("Function not found")?;
766
767            // Check dependencies unless force is set
768            if !force {
769                let dependents = structure
770                    .dependencies
771                    .called_by
772                    .get(name)
773                    .map(|v| v.as_slice())
774                    .unwrap_or(&[]);
775
776                if !dependents.is_empty() {
777                    return Err(anyhow::anyhow!(
778                        "Function '{}' is called by: {}. Use force=true to remove anyway.",
779                        name,
780                        dependents.join(", ")
781                    ));
782                }
783            }
784
785            let mut functions_to_cascade = Vec::new();
786
787            // Collect functions to cascade
788            if cascade {
789                if let Some(calls) = structure.dependencies.calls.get(name) {
790                    for called_func in calls {
791                        // Check if this is the only caller
792                        if let Some(callers) = structure.dependencies.called_by.get(called_func) {
793                            if callers.len() == 1 && callers[0] == name {
794                                functions_to_cascade.push(called_func.clone());
795                            }
796                        }
797                    }
798                }
799            }
800
801            (function.start_line, function.end_line, functions_to_cascade)
802        };
803
804        // Remove the function lines
805        let lines: Vec<&str> = self.content.lines().collect();
806        let mut new_lines: Vec<String> = Vec::new();
807        let mut skip_lines = false;
808
809        for (i, line) in lines.iter().enumerate() {
810            let line_num = i + 1;
811
812            if line_num == function_start {
813                skip_lines = true;
814            }
815
816            if !skip_lines {
817                new_lines.push(line.to_string());
818            }
819
820            if line_num == function_end {
821                skip_lines = false;
822            }
823        }
824
825        self.content = new_lines.join("\n");
826
827        // Re-analyze structure after modification
828        self.tree = self.parser.parse(&self.content, None);
829        self.analyze_structure()?;
830
831        // Handle cascade removal
832        for func_to_remove in functions_to_cascade {
833            self.remove_function(&func_to_remove, None, true, cascade)?;
834        }
835
836        Ok(())
837    }
838
839    /// Get function tree with relationships
840    pub fn get_function_tree(&self) -> Result<Value> {
841        let structure = self.structure.as_ref().context("No structure analyzed")?;
842
843        // Build call graph (simplified for now)
844        let tree = json!({
845            "language": format!("{:?}", self.language),
846            "file_structure": {
847                "imports": structure.imports,
848                "line_count": structure.line_count,
849                "main_function": structure.main_function,
850            },
851            "functions": structure.functions.iter().map(|f| {
852                json!({
853                    "name": f.name,
854                    "lines": format!("{}-{}", f.start_line, f.end_line),
855                    "class": f.class_name,
856                    "visibility": f.visibility,
857                    "signature": f.signature,
858                    "calls": f.calls,
859                    "called_by": f.called_by,
860                })
861            }).collect::<Vec<_>>(),
862            "classes": structure.classes.iter().map(|c| {
863                json!({
864                    "name": c.name,
865                    "lines": format!("{}-{}", c.start_line, c.end_line),
866                    "extends": c.extends,
867                    "implements": c.implements,
868                    "methods": c.methods.iter().map(|m| {
869                        json!({
870                            "name": m.name,
871                            "lines": format!("{}-{}", m.start_line, m.end_line),
872                            "visibility": m.visibility,
873                        })
874                    }).collect::<Vec<_>>(),
875                })
876            }).collect::<Vec<_>>(),
877        });
878
879        Ok(tree)
880    }
881}
882
883/// MCP tool handler for smart edit operations
884pub async fn handle_smart_edit(params: Option<Value>) -> Result<Value> {
885    let params = params.context("Parameters required")?;
886
887    let file_path = params["file_path"].as_str().context("file_path required")?;
888
889    let edits = params["edits"].as_array().context("edits array required")?;
890
891    // Read file
892    let content = std::fs::read_to_string(file_path)?;
893    let original_content = content.clone(); // Clone for diff storage
894    let extension = Path::new(file_path)
895        .extension()
896        .and_then(|e| e.to_str())
897        .context("Could not determine file extension")?;
898
899    let language = SupportedLanguage::from_extension(extension).context("Unsupported language")?;
900
901    // Create smart editor
902    let mut editor = SmartEditor::new(content, language)?;
903
904    // Get initial structure
905    let initial_structure = editor.get_function_tree()?;
906
907    // Apply edits
908    let mut results = Vec::new();
909    for edit in edits {
910        let smart_edit: SmartEdit = serde_json::from_value(edit.clone())?;
911        match editor.apply_edit(&smart_edit) {
912            Ok(_) => {
913                results.push(json!({
914                    "status": "success",
915                    "operation": edit["operation"],
916                }));
917            }
918            Err(e) => {
919                results.push(json!({
920                    "status": "error",
921                    "operation": edit["operation"],
922                    "error": e.to_string(),
923                }));
924            }
925        }
926    }
927
928    // Get final structure
929    let final_structure = editor.get_function_tree()?;
930
931    // Store diff before writing
932    if let Ok(project_root) = std::env::current_dir() {
933        if let Ok(storage) = crate::smart_edit_diff::DiffStorage::new(&project_root) {
934            // Store the diff
935            let _ = storage.store_diff(
936                Path::new(file_path),
937                &original_content, // original content
938                &editor.content,   // new content
939            );
940
941            // Also store original if this is the first edit
942            let _ = storage.store_original(Path::new(file_path), &original_content);
943        }
944    }
945
946    // Write back to file
947    std::fs::write(file_path, &editor.content)?;
948
949    let result = json!({
950        "file_path": file_path,
951        "language": format!("{:?}", language),
952        "edits_applied": results,
953        "initial_structure": initial_structure,
954        "final_structure": final_structure,
955        "content_preview": editor.content.lines().take(20).collect::<Vec<_>>().join("\n"),
956    });
957
958    // Wrap in MCP content format
959    Ok(json!({
960        "content": [{
961            "type": "text",
962            "text": serde_json::to_string_pretty(&result)?
963        }]
964    }))
965}
966
967/// Get function tree without making changes
968pub async fn handle_get_function_tree(params: Option<Value>) -> Result<Value> {
969    let params = params.context("Parameters required")?;
970    let file_path = params["file_path"].as_str().context("file_path required")?;
971
972    let content = std::fs::read_to_string(file_path)?;
973    let extension = Path::new(file_path)
974        .extension()
975        .and_then(|e| e.to_str())
976        .context("Could not determine file extension")?;
977
978    let language = SupportedLanguage::from_extension(extension).context("Unsupported language")?;
979
980    let editor = SmartEditor::new(content, language)?;
981    let function_tree = editor.get_function_tree()?;
982
983    // Wrap in MCP content format
984    Ok(json!({
985        "content": [{
986            "type": "text",
987            "text": serde_json::to_string_pretty(&function_tree)?
988        }]
989    }))
990}
991
992/// Insert a single function using minimal tokens
993pub async fn handle_insert_function(params: Option<Value>) -> Result<Value> {
994    let params = params.context("Parameters required")?;
995
996    let edit = SmartEdit::InsertFunction {
997        name: params["name"]
998            .as_str()
999            .context("name required")?
1000            .to_string(),
1001        class_name: params["class_name"].as_str().map(String::from),
1002        namespace: params["namespace"].as_str().map(String::from),
1003        body: params["body"]
1004            .as_str()
1005            .context("body required")?
1006            .to_string(),
1007        after: params["after"].as_str().map(String::from),
1008        before: params["before"].as_str().map(String::from),
1009        visibility: params["visibility"]
1010            .as_str()
1011            .unwrap_or("private")
1012            .to_string(),
1013    };
1014
1015    handle_smart_edit(Some(json!({
1016        "file_path": params["file_path"],
1017        "edits": [edit],
1018    })))
1019    .await
1020}
1021
1022/// Remove a function with dependency checking
1023pub async fn handle_remove_function(params: Option<Value>) -> Result<Value> {
1024    let params = params.context("Parameters required")?;
1025
1026    let edit = SmartEdit::RemoveFunction {
1027        name: params["name"]
1028            .as_str()
1029            .context("name required")?
1030            .to_string(),
1031        class_name: params["class_name"].as_str().map(String::from),
1032        force: params["force"].as_bool().unwrap_or(false),
1033        cascade: params["cascade"].as_bool().unwrap_or(false),
1034    };
1035
1036    handle_smart_edit(Some(json!({
1037        "file_path": params["file_path"],
1038        "edits": [edit],
1039    })))
1040    .await
1041}
1042
1043/// Create a new file with initial content
1044pub async fn handle_create_file(params: Option<Value>) -> Result<Value> {
1045    let params = params.context("Parameters required")?;
1046
1047    let file_path = params["file_path"]
1048        .as_str()
1049        .context("file_path required")?;
1050    
1051    let content = params["content"]
1052        .as_str()
1053        .unwrap_or("");  // Empty file if no content provided
1054
1055    // Check if file already exists
1056    if Path::new(file_path).exists() {
1057        return Err(anyhow::anyhow!("File already exists: {}. Use edit operations to modify existing files.", file_path));
1058    }
1059
1060    // Create parent directories if they don't exist
1061    if let Some(parent) = Path::new(file_path).parent() {
1062        if !parent.exists() {
1063            std::fs::create_dir_all(parent)
1064                .with_context(|| format!("Failed to create parent directories for: {}", file_path))?;
1065        }
1066    }
1067
1068    // Write the file
1069    std::fs::write(file_path, content)
1070        .with_context(|| format!("Failed to create file: {}", file_path))?;
1071
1072    // Build the standard tool result and wrap it in the MCP content envelope
1073    let result = json!({
1074        "status": "success",
1075        "file_path": file_path,
1076        "message": format!("File created: {}", file_path),
1077        "size": content.len(),
1078    });
1079
1080    let pretty = serde_json::to_string_pretty(&result).unwrap_or_else(|_| result.to_string());
1081
1082    Ok(json!({
1083        "content": [
1084            {
1085                "type": "text",
1086                "text": pretty
1087            }
1088        ]
1089    }))
1090}
1091
// Unit tests for `SmartEditor` edits and the MCP file handlers.
// Editor tests feed a small source fixture (a raw string that begins with a
// newline), apply one or more `SmartEdit`s, and mostly verify the resulting text
// with substring checks rather than comparing the exact output.
#[cfg(test)]
mod tests {
    use super::*;

    // InsertFunction anchored `after: main`: "public" visibility must render as
    // `pub fn`, and the new function must appear after `main` in the output.
    #[test]
    fn test_rust_function_insertion() {
        let content = r#"
use std::io;

fn main() {
    println!("Hello, world!");
}

fn helper() {
    println!("Helper");
}
"#
        .to_string();

        let mut editor = SmartEditor::new(content, SupportedLanguage::Rust).unwrap();
        let edit = SmartEdit::InsertFunction {
            name: "new_function".to_string(),
            class_name: None,
            namespace: None,
            body: r#"() -> Result<()> {
    println!("New function!");
    Ok(())
}"#
            .to_string(),
            after: Some("main".to_string()),
            before: None,
            visibility: "public".to_string(),
        };

        editor.apply_edit(&edit).unwrap();
        assert!(editor.content.contains("pub fn new_function"));
        // Ordering check: byte offset of the new function exceeds that of `fn main`.
        assert!(
            editor.content.find("pub fn new_function").unwrap()
                > editor.content.find("fn main").unwrap()
        );
    }

    // Python insertion: `body` supplies the parameter list plus the indented
    // suite, and is expected to follow a generated `def <name>` header.
    #[test]
    fn test_python_function_insertion() {
        let content = r#"
import os

def main():
    print("Hello, world!")

def helper():
    print("Helper")
"#
        .to_string();

        let mut editor = SmartEditor::new(content, SupportedLanguage::Python).unwrap();
        let edit = SmartEdit::InsertFunction {
            name: "process_data".to_string(),
            class_name: None,
            namespace: None,
            body: r#"(data):
    """Process the data."""
    return data * 2"#
                .to_string(),
            after: Some("main".to_string()),
            before: None,
            visibility: "public".to_string(),
        };

        editor.apply_edit(&edit).unwrap();
        assert!(editor.content.contains("def process_data(data):"));
        assert!(editor.content.contains("return data * 2"));
    }

    // JavaScript insertion anchored with `before: helper`. Only the presence of
    // the new function is asserted, not its position relative to `helper`.
    #[test]
    fn test_javascript_function_insertion() {
        let content = r#"
function main() {
    console.log("Hello, world!");
}

function helper() {
    console.log("Helper");
}
"#
        .to_string();

        let mut editor = SmartEditor::new(content, SupportedLanguage::JavaScript).unwrap();
        let edit = SmartEdit::InsertFunction {
            name: "processData".to_string(),
            class_name: None,
            namespace: None,
            body: r#"(data) {
    return data.map(x => x * 2);
}"#
            .to_string(),
            before: Some("helper".to_string()),
            after: None,
            visibility: "public".to_string(),
        };

        editor.apply_edit(&edit).unwrap();
        assert!(editor.content.contains("function processData(data)"));
        assert!(editor.content.contains("return data.map(x => x * 2)"));
    }

    // AddImport for Rust renders `use <path>;`, and `alias` becomes `as <alias>`.
    #[test]
    fn test_add_import() {
        let content = r#"
use std::io;

fn main() {
    println!("Hello");
}
"#
        .to_string();

        let mut editor = SmartEditor::new(content, SupportedLanguage::Rust).unwrap();
        let edit = SmartEdit::AddImport {
            import: "std::collections::HashMap".to_string(),
            alias: None,
        };

        editor.apply_edit(&edit).unwrap();
        assert!(editor.content.contains("use std::collections::HashMap;"));

        // Test with alias
        let edit_with_alias = SmartEdit::AddImport {
            import: "std::sync::Arc".to_string(),
            alias: Some("MyArc".to_string()),
        };

        editor.apply_edit(&edit_with_alias).unwrap();
        assert!(editor.content.contains("use std::sync::Arc as MyArc;"));
    }

    // ReplaceFunction: `new_body` starts at the opening brace, so presumably the
    // original signature is kept. Only the body text is asserted here.
    #[test]
    fn test_replace_function() {
        let content = r#"
fn calculate(x: i32) -> i32 {
    x + 1
}

fn main() {
    let result = calculate(5);
}
"#
        .to_string();

        let mut editor = SmartEditor::new(content, SupportedLanguage::Rust).unwrap();

        // First analyze to build structure
        // (the result is discarded; a failure here would surface in apply_edit below)
        let _ = editor.analyze_structure();

        let edit = SmartEdit::ReplaceFunction {
            name: "calculate".to_string(),
            class_name: None,
            new_body: r#"{
    // Improved calculation with logging
    println!("Calculating for: {}", x);
    x * 2
}"#
            .to_string(),
        };

        editor.apply_edit(&edit).unwrap();
        assert!(editor.content.contains("x * 2"));
        assert!(editor.content.contains("Improved calculation"));
        assert!(!editor.content.contains("x + 1")); // Old body should be gone
    }

    // SmartAppend targets named sections ("imports", "functions"); only the
    // presence of the appended text is asserted, not where it lands.
    #[test]
    fn test_smart_append() {
        let content = r#"
import os

def main():
    pass

class MyClass:
    pass
"#
        .to_string();

        let mut editor = SmartEditor::new(content, SupportedLanguage::Python).unwrap();

        // Append to imports section
        let import_edit = SmartEdit::SmartAppend {
            section: "imports".to_string(),
            content: "import sys".to_string(),
        };

        editor.apply_edit(&import_edit).unwrap();
        assert!(editor.content.contains("import sys"));

        // Append to functions section
        let func_edit = SmartEdit::SmartAppend {
            section: "functions".to_string(),
            content: "def helper():\n    return True".to_string(),
        };

        editor.apply_edit(&func_edit).unwrap();
        assert!(editor.content.contains("def helper():"));
    }

    // RemoveFunction dependency check. The structure is injected by hand so the
    // call graph is fixed (caller -> helper -> orphan). Note: the fixture's
    // `helper` body does NOT actually call `orphan`; that edge exists only in the
    // injected structure. The start/end line numbers count the raw string's leading
    // newline as line 1 (e.g. `fn caller() {` is line 2). They must be updated if
    // the fixture text changes.
    #[test]
    fn test_remove_function_with_dependencies() {
        let content = r#"
fn caller() {
    helper();
}

fn helper() {
    println!("I'm helping!");
}

fn orphan() {
    // Only called by helper
}
"#
        .to_string();

        let mut editor = SmartEditor::new(content, SupportedLanguage::Rust).unwrap();

        // Build dependency graph (simplified for test)
        editor.structure = Some(CodeStructure {
            language: "Rust".to_string(),
            imports: vec![],
            functions: vec![
                FunctionInfo {
                    name: "caller".to_string(),
                    class_name: None,
                    namespace: None,
                    start_line: 2,
                    end_line: 4,
                    signature: "fn caller()".to_string(),
                    visibility: "private".to_string(),
                    calls: vec!["helper".to_string()],
                    called_by: vec![],
                },
                FunctionInfo {
                    name: "helper".to_string(),
                    class_name: None,
                    namespace: None,
                    start_line: 6,
                    end_line: 8,
                    signature: "fn helper()".to_string(),
                    visibility: "private".to_string(),
                    calls: vec!["orphan".to_string()],
                    called_by: vec!["caller".to_string()],
                },
                FunctionInfo {
                    name: "orphan".to_string(),
                    class_name: None,
                    namespace: None,
                    start_line: 10,
                    end_line: 12,
                    signature: "fn orphan()".to_string(),
                    visibility: "private".to_string(),
                    calls: vec![],
                    called_by: vec!["helper".to_string()],
                },
            ],
            classes: vec![],
            main_function: None,
            line_count: 12,
            // Edge maps mirror the per-function `calls` / `called_by` lists above.
            dependencies: DependencyGraph {
                calls: [
                    ("caller".to_string(), vec!["helper".to_string()]),
                    ("helper".to_string(), vec!["orphan".to_string()]),
                ]
                .into_iter()
                .collect(),
                called_by: [
                    ("helper".to_string(), vec!["caller".to_string()]),
                    ("orphan".to_string(), vec!["helper".to_string()]),
                ]
                .into_iter()
                .collect(),
            },
        });

        // Try to remove helper without force - should fail
        let remove_edit = SmartEdit::RemoveFunction {
            name: "helper".to_string(),
            class_name: None,
            force: false,
            cascade: false,
        };

        let result = editor.apply_edit(&remove_edit);
        assert!(result.is_err());
        // The error is expected to name the blocking caller.
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("called by: caller"));

        // Remove with force
        let force_remove = SmartEdit::RemoveFunction {
            name: "helper".to_string(),
            class_name: None,
            force: true,
            cascade: false,
        };

        editor.apply_edit(&force_remove).unwrap();
        assert!(!editor.content.contains("fn helper()"));
        assert!(editor.content.contains("fn orphan()")); // Orphan still there without cascade
    }

    // get_function_tree on a Python fixture with a class and a top-level function:
    // checks the JSON shape and that both `main` and `Calculator` were discovered.
    #[test]
    fn test_get_function_tree() {
        let content = r#"
class Calculator:
    def add(self, a, b):
        return a + b
    
    def multiply(self, a, b):
        return self.add(a, b) * b

def main():
    calc = Calculator()
    result = calc.add(5, 3)
"#
        .to_string();

        let editor = SmartEditor::new(content, SupportedLanguage::Python).unwrap();
        let tree = editor.get_function_tree().unwrap();

        // Check tree structure
        assert!(tree["language"].as_str().unwrap().contains("Python"));
        assert!(tree["functions"].is_array());
        assert!(tree["classes"].is_array());

        // Verify it found the functions and classes
        let functions = tree["functions"].as_array().unwrap();
        assert!(functions.iter().any(|f| f["name"] == "main"));

        let classes = tree["classes"].as_array().unwrap();
        assert!(classes.iter().any(|c| c["name"] == "Calculator"));
    }

    // Several edits applied sequentially to one editor: the second edit must still
    // find its `after: main` anchor after the first edit changed the content.
    #[test]
    fn test_multiple_edits() {
        let content = r#"
fn main() {
    println!("Start");
}
"#
        .to_string();

        let mut editor = SmartEditor::new(content, SupportedLanguage::Rust).unwrap();

        // Apply multiple edits
        let edits = vec![
            SmartEdit::AddImport {
                import: "std::thread".to_string(),
                alias: None,
            },
            SmartEdit::InsertFunction {
                name: "worker".to_string(),
                class_name: None,
                namespace: None,
                body: r#"() {
    thread::sleep(std::time::Duration::from_secs(1));
}"#
                .to_string(),
                after: Some("main".to_string()),
                before: None,
                visibility: "private".to_string(),
            },
        ];

        for edit in edits {
            editor.apply_edit(&edit).unwrap();
        }

        assert!(editor.content.contains("use std::thread;"));
        assert!(editor.content.contains("fn worker()"));
    }

    // Method insertion into a Python class via `class_name`. The body carries its
    // own 8-space indentation for the method suite.
    #[test]
    fn test_class_method_insertion() {
        let content = r#"
class MyClass:
    def __init__(self):
        self.value = 0
    
    def get_value(self):
        return self.value
"#
        .to_string();

        let mut editor = SmartEditor::new(content, SupportedLanguage::Python).unwrap();

        let edit = SmartEdit::InsertFunction {
            name: "set_value".to_string(),
            class_name: Some("MyClass".to_string()),
            namespace: None,
            body: r#"(self, value):
        self.value = value"#
                .to_string(),
            after: Some("get_value".to_string()),
            before: None,
            visibility: "public".to_string(),
        };

        editor.apply_edit(&edit).unwrap();
        assert!(editor.content.contains("def set_value(self, value):"));
        assert!(editor.content.contains("self.value = value"));
    }

    // handle_create_file: creates the file with the given content, then refuses to
    // create it a second time. Runs under tokio because the handler is an async fn;
    // all files live in a temporary directory.
    #[tokio::test]
    async fn test_create_file() {
        use tempfile::tempdir;
        
        let dir = tempdir().unwrap();
        let test_file = dir.path().join("new_test.rs");
        
        let params = json!({
            "file_path": test_file.to_str().unwrap(),
            "content": "// Test file\npub fn hello() {\n    println!(\"Hello!\");\n}\n"
        });
        
        // Test successful creation
        let result = handle_create_file(Some(params.clone())).await;
        assert!(result.is_ok(), "Failed to create file: {:?}", result.err());
        
        // Verify file exists
        assert!(test_file.exists(), "File was not created");
        
        // Verify content
        let content = std::fs::read_to_string(&test_file).unwrap();
        assert!(content.contains("pub fn hello()"));
        assert!(content.contains("println!"));
        
        // Test that creating existing file fails
        let result2 = handle_create_file(Some(params)).await;
        assert!(result2.is_err(), "Should fail when file already exists");
        assert!(result2.unwrap_err().to_string().contains("already exists"));
    }

    // handle_create_file must create missing intermediate directories.
    #[tokio::test]
    async fn test_create_file_with_parent_dirs() {
        use tempfile::tempdir;
        
        let dir = tempdir().unwrap();
        let test_file = dir.path().join("subdir/nested/test.py");
        
        let params = json!({
            "file_path": test_file.to_str().unwrap(),
            "content": "def main():\n    print('Hello')\n"
        });
        
        // Should create parent directories
        let result = handle_create_file(Some(params)).await;
        assert!(result.is_ok(), "Failed to create file with parent dirs: {:?}", result.err());
        
        // Verify file and parents exist
        assert!(test_file.exists(), "File was not created");
        assert!(test_file.parent().unwrap().exists(), "Parent directory was not created");
        
        // Verify content
        let content = std::fs::read_to_string(&test_file).unwrap();
        assert!(content.contains("def main()"));
    }

    // Omitting `content` must produce an empty file.
    #[tokio::test]
    async fn test_create_empty_file() {
        use tempfile::tempdir;
        
        let dir = tempdir().unwrap();
        let test_file = dir.path().join("empty.txt");
        
        let params = json!({
            "file_path": test_file.to_str().unwrap()
            // No content field - should create empty file
        });
        
        let result = handle_create_file(Some(params)).await;
        assert!(result.is_ok(), "Failed to create empty file: {:?}", result.err());
        
        // Verify file exists and is empty
        assert!(test_file.exists());
        let content = std::fs::read_to_string(&test_file).unwrap();
        assert_eq!(content, "");
    }
}