Skip to main content

aft/commands/
extract_function.rs

1//! Handler for the `extract_function` command: extract a range of code into
2//! a new function with auto-detected parameters and return value.
3//!
4//! Follows the edit_symbol.rs pattern: validate → parse → compute →
5//! auto_backup → write_format_validate → respond.
6
7use std::path::Path;
8
9use tree_sitter::Parser;
10
11use crate::context::AppContext;
12use crate::edit;
13use crate::extract::{
14    detect_free_variables, detect_return_value, generate_call_site, generate_extracted_function,
15    ReturnKind,
16};
17use crate::indent::detect_indent;
18use crate::parser::{detect_language, grammar_for, LangId};
19use crate::protocol::{RawRequest, Response};
20
21/// Handle an `extract_function` request.
22///
23/// Params:
24///   - `file` (string, required) — target file path
25///   - `name` (string, required) — name for the new function
26///   - `start_line` (u32, required) — first line of the range to extract (1-based)
27///   - `end_line` (u32, required) — last line (exclusive, 1-based) of the range to extract
28///
29/// Returns on success:
30///   `{ file, name, parameters, return_type, extracted_range, call_site_range, syntax_valid?, backup_id }`
31///
32/// `syntax_valid` is absent when syntax validation could not run.
33///
34/// Error codes:
35///   - `unsupported_language` — file is not TS/JS/TSX/Python
36///   - `this_reference_in_range` — range contains `this`/`self`
37pub fn handle_extract_function(req: &RawRequest, ctx: &AppContext) -> Response {
38    let op_id = crate::backup::new_op_id();
39    // --- Extract params ---
40    let file = match req.params.get("file").and_then(|v| v.as_str()) {
41        Some(f) => f,
42        None => {
43            return Response::error(
44                &req.id,
45                "invalid_request",
46                "extract_function: missing required param 'file'",
47            );
48        }
49    };
50
51    let name = match req.params.get("name").and_then(|v| v.as_str()) {
52        Some(n) => n,
53        None => {
54            return Response::error(
55                &req.id,
56                "invalid_request",
57                "extract_function: missing required param 'name'",
58            );
59        }
60    };
61
62    let start_line_1based = match req.params.get("start_line").and_then(|v| v.as_u64()) {
63        Some(l) if l >= 1 => l as u32,
64        Some(_) => {
65            return Response::error(
66                &req.id,
67                "invalid_request",
68                "extract_function: 'start_line' must be >= 1 (1-based)",
69            );
70        }
71        None => {
72            return Response::error(
73                &req.id,
74                "invalid_request",
75                "extract_function: missing required param 'start_line'",
76            );
77        }
78    };
79    let start_line = start_line_1based - 1;
80
81    let end_line_1based = match req.params.get("end_line").and_then(|v| v.as_u64()) {
82        Some(l) if l >= 1 => l as u32,
83        Some(_) => {
84            return Response::error(
85                &req.id,
86                "invalid_request",
87                "extract_function: 'end_line' must be >= 1 (1-based)",
88            );
89        }
90        None => {
91            return Response::error(
92                &req.id,
93                "invalid_request",
94                "extract_function: missing required param 'end_line'",
95            );
96        }
97    };
98    let end_line = end_line_1based - 1;
99
100    if start_line >= end_line {
101        return Response::error(
102            &req.id,
103            "invalid_request",
104            format!(
105                "extract_function: start_line ({}) must be less than end_line ({})",
106                start_line, end_line
107            ),
108        );
109    }
110
111    // --- Validate file ---
112    let path = match ctx.validate_path(&req.id, Path::new(file)) {
113        Ok(path) => path,
114        Err(resp) => return resp,
115    };
116    if !path.exists() {
117        return Response::error(
118            &req.id,
119            "file_not_found",
120            format!("extract_function: file not found: {}", file),
121        );
122    }
123
124    // --- Language guard (D101) ---
125    let lang = match detect_language(&path) {
126        Some(l) => l,
127        None => {
128            return Response::error(
129                &req.id,
130                "unsupported_language",
131                "extract_function: unsupported file type",
132            );
133        }
134    };
135
136    if !matches!(
137        lang,
138        LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python
139    ) {
140        return Response::error(
141            &req.id,
142            "unsupported_language",
143            format!(
144                "extract_function: only TypeScript/JavaScript/Python files are supported, got {:?}",
145                lang
146            ),
147        );
148    }
149
150    // --- Read and parse ---
151    let source = match std::fs::read_to_string(&path) {
152        Ok(s) => s,
153        Err(e) => {
154            return Response::error(
155                &req.id,
156                "file_not_found",
157                format!("extract_function: {}: {}", file, e),
158            );
159        }
160    };
161
162    let grammar = grammar_for(lang);
163    let mut parser = Parser::new();
164    if parser.set_language(&grammar).is_err() {
165        return Response::error(
166            &req.id,
167            "parse_error",
168            "extract_function: failed to initialize parser",
169        );
170    }
171    let tree = match parser.parse(source.as_bytes(), None) {
172        Some(t) => t,
173        None => {
174            return Response::error(
175                &req.id,
176                "parse_error",
177                "extract_function: failed to parse file",
178            );
179        }
180    };
181
182    // --- Convert line range to byte range ---
183    let start_byte = edit::line_col_to_byte(&source, start_line, 0);
184    let end_byte = edit::line_col_to_byte(&source, end_line, 0);
185
186    if start_byte >= source.len() {
187        return Response::error(
188            &req.id,
189            "invalid_request",
190            format!(
191                "extract_function: start_line {} is beyond end of file",
192                start_line
193            ),
194        );
195    }
196
197    // --- Detect free variables ---
198    let free_vars = detect_free_variables(&source, &tree, start_byte, end_byte, lang);
199
200    // Check for this/self
201    if free_vars.has_this_or_self {
202        let keyword = match lang {
203            LangId::Python => "self",
204            _ => "this",
205        };
206        return Response::error(
207            &req.id,
208            "this_reference_in_range",
209            format!(
210                "extract_function: selected range contains '{}' reference. Consider extracting as a method instead, or move the {} usage outside the extracted range.",
211                keyword, keyword
212            ),
213        );
214    }
215
216    // --- Find enclosing function for return value detection ---
217    let root = tree.root_node();
218    let enclosing_fn = find_enclosing_function_node(&root, start_byte, lang);
219    let enclosing_fn_end_byte = enclosing_fn.map(|n| n.end_byte());
220
221    // --- Detect return value ---
222    let return_kind = detect_return_value(
223        &source,
224        &tree,
225        start_byte,
226        end_byte,
227        enclosing_fn_end_byte,
228        lang,
229    );
230
231    // --- Detect indentation ---
232    let indent_style = detect_indent(&source, lang);
233
234    // Determine base indent (indentation of the line where the enclosing function starts,
235    // or no indent if at module level)
236    let base_indent = if let Some(fn_node) = enclosing_fn {
237        let fn_start_line = fn_node.start_position().row;
238        get_line_indent(&source, fn_start_line as usize)
239    } else {
240        String::new()
241    };
242
243    // Determine the indent of the extracted range (for the call site)
244    let range_indent = get_line_indent(&source, start_line as usize);
245
246    // --- Extract body text ---
247    let body_text = &source[start_byte..end_byte];
248    let body_text = body_text.trim_end_matches('\n');
249
250    // --- Generate function and call site ---
251    let extracted_fn = generate_extracted_function(
252        name,
253        &free_vars.parameters,
254        &return_kind,
255        body_text,
256        &base_indent,
257        lang,
258        indent_style,
259    );
260
261    let call_site = generate_call_site(
262        name,
263        &free_vars.parameters,
264        &return_kind,
265        &range_indent,
266        lang,
267    );
268
269    // --- Compute new file content ---
270    // Insert the extracted function before the enclosing function (or at the range position
271    // if there's no enclosing function).
272    //
273    // For TS/JS, when the enclosing function is wrapped in `export` (or
274    // `export default`), the parser reports the function_declaration as
275    // a child of an export_statement. If we use fn_node.start_byte() as
276    // the insertion point, the `export` keyword stays attached to the
277    // start of the file content, and the extracted function gets inserted
278    // BETWEEN `export ` and `function`, producing:
279    //   `export function newFn(...) {} \n\n function originalFn(...) {}`
280    // The export keyword silently jumps from the original function to the
281    // extracted one. Fix: if the parent is an export_statement, insert
282    // before the export_statement instead.
283    let insert_pos = if let Some(fn_node) = enclosing_fn {
284        let mut anchor = fn_node;
285        if matches!(lang, LangId::TypeScript | LangId::Tsx | LangId::JavaScript) {
286            if let Some(parent) = fn_node.parent() {
287                if parent.kind() == "export_statement" {
288                    anchor = parent;
289                }
290            }
291        }
292        anchor.start_byte()
293    } else {
294        start_byte
295    };
296
297    let new_source = build_new_source(
298        &source,
299        insert_pos,
300        start_byte,
301        end_byte,
302        &extracted_fn,
303        &call_site,
304    );
305
306    // --- Return type string for the response ---
307    let return_type = match &return_kind {
308        ReturnKind::Expression(_) => "expression",
309        ReturnKind::Variable(_) => "variable",
310        ReturnKind::Void => "void",
311    };
312
313    // --- Auto-backup before mutation ---
314    let backup_id = match edit::auto_backup(
315        ctx,
316        req.session(),
317        &path,
318        &format!("extract_function: {}", name),
319        Some(&op_id),
320    ) {
321        Ok(id) => id,
322        Err(e) => {
323            return Response::error(&req.id, e.code(), e.to_string());
324        }
325    };
326
327    // --- Write, format, validate ---
328    let mut write_result =
329        match edit::write_format_validate(&path, &new_source, &ctx.config(), &req.params) {
330            Ok(r) => r,
331            Err(e) => {
332                return Response::error(&req.id, e.code(), e.to_string());
333            }
334        };
335
336    // Honesty: write_format_validate reverts the file when the extraction
337    // produced invalid syntax (e.g. extracting a class method emits a bare
338    // `function helper` into a class body). The file is unchanged, so returning
339    // a success response with the new symbol metadata would be a lie. Fail with
340    // the validation errors so the agent retries. (edit_match/batch surface this
341    // via rolled_back already; the refactor handlers did not.)
342    if write_result.rolled_back {
343        return Response::error(
344            &req.id,
345            "generated_invalid_syntax",
346            format!(
347                "extract_function produced invalid syntax; the file was left unchanged. {}",
348                edit::format_validation_errors(&write_result.validation_errors)
349            ),
350        );
351    }
352
353    if let Ok(final_content) = std::fs::read_to_string(&path) {
354        write_result.lsp_outcome = ctx.lsp_post_write(&path, &final_content, &req.params);
355    }
356
357    let param_count = free_vars.parameters.len();
358    log::debug!(
359        "extract_function: {} from {}:{}-{} ({} params)",
360        name,
361        file,
362        start_line,
363        end_line,
364        param_count
365    );
366
367    // --- Build response ---
368    let mut result = serde_json::json!({
369        "file": file,
370        "name": name,
371        "parameters": free_vars.parameters,
372        "return_type": return_type,
373        "formatted": write_result.formatted,
374    });
375
376    if let Some(valid) = write_result.syntax_valid {
377        result["syntax_valid"] = serde_json::json!(valid);
378    }
379
380    if let Some(ref reason) = write_result.format_skipped_reason {
381        result["format_skipped_reason"] = serde_json::json!(reason);
382    }
383
384    if write_result.validate_requested {
385        result["validation_errors"] = serde_json::json!(write_result.validation_errors);
386    }
387    if let Some(ref reason) = write_result.validate_skipped_reason {
388        result["validate_skipped_reason"] = serde_json::json!(reason);
389    }
390
391    if let Some(ref id) = backup_id {
392        result["backup_id"] = serde_json::json!(id);
393    }
394
395    write_result.append_lsp_diagnostics_to(&mut result);
396    Response::success(&req.id, result)
397}
398
399/// Find the enclosing function node for a byte position.
400fn find_enclosing_function_node<'a>(
401    root: &'a tree_sitter::Node<'a>,
402    byte_pos: usize,
403    lang: LangId,
404) -> Option<tree_sitter::Node<'a>> {
405    let fn_kinds: &[&str] = match lang {
406        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => &[
407            "function_declaration",
408            "method_definition",
409            "arrow_function",
410            "lexical_declaration",
411        ],
412        LangId::Python => &["function_definition"],
413        _ => &[],
414    };
415
416    find_deepest_ancestor(root, byte_pos, fn_kinds)
417}
418
419/// Find the deepest ancestor node of given kinds containing byte_pos.
420fn find_deepest_ancestor<'a>(
421    node: &tree_sitter::Node<'a>,
422    byte_pos: usize,
423    kinds: &[&str],
424) -> Option<tree_sitter::Node<'a>> {
425    let mut result: Option<tree_sitter::Node<'a>> = None;
426    if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
427        result = Some(*node);
428    }
429
430    let child_count = node.child_count();
431    for i in 0..child_count {
432        if let Some(child) = node.child(i as u32) {
433            if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
434                if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
435                    result = Some(deeper);
436                }
437            }
438        }
439    }
440
441    result
442}
443
/// Return the leading whitespace of the `line`-th line of `source`
/// (0-based, as split by `str::lines`).
///
/// Yields an empty string when `line` is past the last line. A line made up
/// entirely of whitespace is returned in full.
fn get_line_indent(source: &str, line: usize) -> String {
    match source.lines().nth(line) {
        Some(text) => {
            // Byte offset of the first non-whitespace char, or the whole
            // line when there is none.
            let indent_len = text
                .find(|c: char| !c.is_whitespace())
                .unwrap_or(text.len());
            text[..indent_len].to_owned()
        }
        None => String::new(),
    }
}
455
/// Assemble the rewritten file: the extracted function is spliced in at
/// `insert_pos` and the bytes `range_start..range_end` are replaced by the
/// call site.
///
/// Output layout, in order:
///   1. `source[..insert_pos]`
///   2. `extracted_fn` followed by a blank line (`"\n\n"`)
///   3. `source[insert_pos..range_start]` — the original text between the
///      insertion point and the extracted range
///   4. `call_site` followed by `"\n"`
///   5. `source[range_end..]`
///
/// All offsets are byte offsets into `source`; slicing panics if they are
/// out of order, out of bounds, or not on a char boundary.
fn build_new_source(
    source: &str,
    insert_pos: usize,
    range_start: usize,
    range_end: usize,
    extracted_fn: &str,
    call_site: &str,
) -> String {
    let head = &source[..insert_pos];
    let between = &source[insert_pos..range_start];
    let tail = &source[range_end..];

    [head, extracted_fn, "\n\n", between, call_site, "\n", tail].concat()
}
487
488// ---------------------------------------------------------------------------
489// Tests
490// ---------------------------------------------------------------------------
491
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RawRequest;

    /// Build a request with the given id/command/params and no LSP hints or session.
    fn make_request(id: &str, command: &str, params: serde_json::Value) -> RawRequest {
        RawRequest {
            id: id.to_string(),
            command: command.to_string(),
            params,
            lsp_hints: None,
            session_id: None,
        }
    }

    /// Send an `extract_function` request through the handler using a
    /// default-configured context, and return the response as JSON.
    fn dispatch(id: &str, params: serde_json::Value) -> serde_json::Value {
        let request = make_request(id, "extract_function", params);
        let context = crate::context::AppContext::new(
            Box::new(crate::parser::TreeSitterProvider::new()),
            crate::config::Config::default(),
        );
        let response = handle_extract_function(&request, &context);
        serde_json::to_value(&response).unwrap()
    }

    /// Assert that `reply` is a failure response carrying `code`.
    fn assert_rejected(reply: &serde_json::Value, code: &str) {
        assert_eq!(reply["success"], false);
        assert_eq!(reply["code"], code);
    }

    // --- Param validation ---

    #[test]
    fn extract_function_missing_file() {
        let reply = dispatch("1", serde_json::json!({}));
        assert_rejected(&reply, "invalid_request");

        let msg = reply["message"].as_str().unwrap();
        assert!(
            msg.contains("file"),
            "message should mention 'file': {}",
            msg
        );
    }

    #[test]
    fn extract_function_missing_name() {
        let reply = dispatch("2", serde_json::json!({"file": "/tmp/test.ts"}));
        assert_rejected(&reply, "invalid_request");

        let msg = reply["message"].as_str().unwrap();
        assert!(
            msg.contains("name"),
            "message should mention 'name': {}",
            msg
        );
    }

    #[test]
    fn extract_function_missing_start_line() {
        let reply = dispatch(
            "3",
            serde_json::json!({"file": "/tmp/test.ts", "name": "foo"}),
        );
        assert_rejected(&reply, "invalid_request");
    }

    #[test]
    fn extract_function_unsupported_language() {
        // Rust sources are outside the languages extract_function accepts,
        // so a temporary .rs file must be turned away.
        let scratch = std::env::temp_dir().join("aft_test_extract");
        std::fs::create_dir_all(&scratch).ok();
        let target = scratch.join("test.rs");
        std::fs::write(&target, "fn main() {}").unwrap();

        let reply = dispatch(
            "4",
            serde_json::json!({
                "file": target.display().to_string(),
                "name": "foo",
                "start_line": 1,
                "end_line": 2,
            }),
        );
        assert_rejected(&reply, "unsupported_language");

        std::fs::remove_dir_all(&scratch).ok();
    }

    #[test]
    fn extract_function_invalid_line_range() {
        // start_line after end_line is an invalid range.
        let scratch = std::env::temp_dir().join("aft_test_extract_range");
        std::fs::create_dir_all(&scratch).ok();
        let target = scratch.join("test.ts");
        std::fs::write(&target, "const x = 1;\n").unwrap();

        let reply = dispatch(
            "5",
            serde_json::json!({
                "file": target.display().to_string(),
                "name": "foo",
                "start_line": 6,
                "end_line": 4,
            }),
        );
        assert_rejected(&reply, "invalid_request");

        std::fs::remove_dir_all(&scratch).ok();
    }

    #[test]
    fn extract_function_this_reference_error() {
        // Uses a checked-in fixture; the handler is expected to refuse
        // before touching the file.
        let fixture = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/extract_function/sample_this.ts");

        let reply = dispatch(
            "6",
            serde_json::json!({
                "file": fixture.display().to_string(),
                "name": "extracted",
                "start_line": 5,
                "end_line": 8,
            }),
        );
        assert_rejected(&reply, "this_reference_in_range");
    }
}