Skip to main content

aft/commands/
extract_function.rs

1//! Handler for the `extract_function` command: extract a range of code into
2//! a new function with auto-detected parameters and return value.
3//!
4//! Follows the edit_symbol.rs pattern: validate → parse → compute → dry_run
5//! check → auto_backup → write_format_validate → respond.
6
7use std::path::Path;
8
9use tree_sitter::Parser;
10
11use crate::context::AppContext;
12use crate::edit;
13use crate::extract::{
14    detect_free_variables, detect_return_value, generate_call_site, generate_extracted_function,
15    ReturnKind,
16};
17use crate::indent::detect_indent;
18use crate::parser::{detect_language, grammar_for, LangId};
19use crate::protocol::{RawRequest, Response};
20
21/// Handle an `extract_function` request.
22///
23/// Params:
24///   - `file` (string, required) — target file path
25///   - `name` (string, required) — name for the new function
26///   - `start_line` (u32, required) — first line of the range to extract (1-based)
27///   - `end_line` (u32, required) — last line (exclusive, 1-based) of the range to extract
28///   - `dry_run` (bool, optional) — if true, return diff without writing
29///
30/// Returns on success:
31///   `{ file, name, parameters, return_type, extracted_range, call_site_range, syntax_valid, backup_id }`
32///
33/// Error codes:
34///   - `unsupported_language` — file is not TS/JS/TSX/Python
35///   - `this_reference_in_range` — range contains `this`/`self`
36pub fn handle_extract_function(req: &RawRequest, ctx: &AppContext) -> Response {
37    // --- Extract params ---
38    let file = match req.params.get("file").and_then(|v| v.as_str()) {
39        Some(f) => f,
40        None => {
41            return Response::error(
42                &req.id,
43                "invalid_request",
44                "extract_function: missing required param 'file'",
45            );
46        }
47    };
48
49    let name = match req.params.get("name").and_then(|v| v.as_str()) {
50        Some(n) => n,
51        None => {
52            return Response::error(
53                &req.id,
54                "invalid_request",
55                "extract_function: missing required param 'name'",
56            );
57        }
58    };
59
60    let start_line_1based = match req.params.get("start_line").and_then(|v| v.as_u64()) {
61        Some(l) if l >= 1 => l as u32,
62        Some(_) => {
63            return Response::error(
64                &req.id,
65                "invalid_request",
66                "extract_function: 'start_line' must be >= 1 (1-based)",
67            );
68        }
69        None => {
70            return Response::error(
71                &req.id,
72                "invalid_request",
73                "extract_function: missing required param 'start_line'",
74            );
75        }
76    };
77    let start_line = start_line_1based - 1;
78
79    let end_line_1based = match req.params.get("end_line").and_then(|v| v.as_u64()) {
80        Some(l) if l >= 1 => l as u32,
81        Some(_) => {
82            return Response::error(
83                &req.id,
84                "invalid_request",
85                "extract_function: 'end_line' must be >= 1 (1-based)",
86            );
87        }
88        None => {
89            return Response::error(
90                &req.id,
91                "invalid_request",
92                "extract_function: missing required param 'end_line'",
93            );
94        }
95    };
96    let end_line = end_line_1based - 1;
97
98    if start_line >= end_line {
99        return Response::error(
100            &req.id,
101            "invalid_request",
102            format!(
103                "extract_function: start_line ({}) must be less than end_line ({})",
104                start_line, end_line
105            ),
106        );
107    }
108
109    // --- Validate file ---
110    let path = match ctx.validate_path(&req.id, Path::new(file)) {
111        Ok(path) => path,
112        Err(resp) => return resp,
113    };
114    if !path.exists() {
115        return Response::error(
116            &req.id,
117            "file_not_found",
118            format!("extract_function: file not found: {}", file),
119        );
120    }
121
122    // --- Language guard (D101) ---
123    let lang = match detect_language(&path) {
124        Some(l) => l,
125        None => {
126            return Response::error(
127                &req.id,
128                "unsupported_language",
129                "extract_function: unsupported file type",
130            );
131        }
132    };
133
134    if !matches!(
135        lang,
136        LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python
137    ) {
138        return Response::error(
139            &req.id,
140            "unsupported_language",
141            format!(
142                "extract_function: only TypeScript/JavaScript/Python files are supported, got {:?}",
143                lang
144            ),
145        );
146    }
147
148    // --- Read and parse ---
149    let source = match std::fs::read_to_string(&path) {
150        Ok(s) => s,
151        Err(e) => {
152            return Response::error(
153                &req.id,
154                "file_not_found",
155                format!("extract_function: {}: {}", file, e),
156            );
157        }
158    };
159
160    let grammar = grammar_for(lang);
161    let mut parser = Parser::new();
162    if parser.set_language(&grammar).is_err() {
163        return Response::error(
164            &req.id,
165            "parse_error",
166            "extract_function: failed to initialize parser",
167        );
168    }
169    let tree = match parser.parse(source.as_bytes(), None) {
170        Some(t) => t,
171        None => {
172            return Response::error(
173                &req.id,
174                "parse_error",
175                "extract_function: failed to parse file",
176            );
177        }
178    };
179
180    // --- Convert line range to byte range ---
181    let start_byte = edit::line_col_to_byte(&source, start_line, 0);
182    let end_byte = edit::line_col_to_byte(&source, end_line, 0);
183
184    if start_byte >= source.len() {
185        return Response::error(
186            &req.id,
187            "invalid_request",
188            format!(
189                "extract_function: start_line {} is beyond end of file",
190                start_line
191            ),
192        );
193    }
194
195    // --- Detect free variables ---
196    let free_vars = detect_free_variables(&source, &tree, start_byte, end_byte, lang);
197
198    // Check for this/self
199    if free_vars.has_this_or_self {
200        let keyword = match lang {
201            LangId::Python => "self",
202            _ => "this",
203        };
204        return Response::error(
205            &req.id,
206            "this_reference_in_range",
207            format!(
208                "extract_function: selected range contains '{}' reference. Consider extracting as a method instead, or move the {} usage outside the extracted range.",
209                keyword, keyword
210            ),
211        );
212    }
213
214    // --- Find enclosing function for return value detection ---
215    let root = tree.root_node();
216    let enclosing_fn = find_enclosing_function_node(&root, start_byte, lang);
217    let enclosing_fn_end_byte = enclosing_fn.map(|n| n.end_byte());
218
219    // --- Detect return value ---
220    let return_kind = detect_return_value(
221        &source,
222        &tree,
223        start_byte,
224        end_byte,
225        enclosing_fn_end_byte,
226        lang,
227    );
228
229    // --- Detect indentation ---
230    let indent_style = detect_indent(&source, lang);
231
232    // Determine base indent (indentation of the line where the enclosing function starts,
233    // or no indent if at module level)
234    let base_indent = if let Some(fn_node) = enclosing_fn {
235        let fn_start_line = fn_node.start_position().row;
236        get_line_indent(&source, fn_start_line as usize)
237    } else {
238        String::new()
239    };
240
241    // Determine the indent of the extracted range (for the call site)
242    let range_indent = get_line_indent(&source, start_line as usize);
243
244    // --- Extract body text ---
245    let body_text = &source[start_byte..end_byte];
246    let body_text = body_text.trim_end_matches('\n');
247
248    // --- Generate function and call site ---
249    let extracted_fn = generate_extracted_function(
250        name,
251        &free_vars.parameters,
252        &return_kind,
253        body_text,
254        &base_indent,
255        lang,
256        indent_style,
257    );
258
259    let call_site = generate_call_site(
260        name,
261        &free_vars.parameters,
262        &return_kind,
263        &range_indent,
264        lang,
265    );
266
267    // --- Compute new file content ---
268    // Insert the extracted function before the enclosing function (or at the range position
269    // if there's no enclosing function).
270    //
271    // For TS/JS, when the enclosing function is wrapped in `export` (or
272    // `export default`), the parser reports the function_declaration as
273    // a child of an export_statement. If we use fn_node.start_byte() as
274    // the insertion point, the `export` keyword stays attached to the
275    // start of the file content, and the extracted function gets inserted
276    // BETWEEN `export ` and `function`, producing:
277    //   `export function newFn(...) {} \n\n function originalFn(...) {}`
278    // The export keyword silently jumps from the original function to the
279    // extracted one. Fix: if the parent is an export_statement, insert
280    // before the export_statement instead.
281    let insert_pos = if let Some(fn_node) = enclosing_fn {
282        let mut anchor = fn_node;
283        if matches!(lang, LangId::TypeScript | LangId::Tsx | LangId::JavaScript) {
284            if let Some(parent) = fn_node.parent() {
285                if parent.kind() == "export_statement" {
286                    anchor = parent;
287                }
288            }
289        }
290        anchor.start_byte()
291    } else {
292        start_byte
293    };
294
295    let new_source = build_new_source(
296        &source,
297        insert_pos,
298        start_byte,
299        end_byte,
300        &extracted_fn,
301        &call_site,
302    );
303
304    // --- Return type string for the response ---
305    let return_type = match &return_kind {
306        ReturnKind::Expression(_) => "expression",
307        ReturnKind::Variable(_) => "variable",
308        ReturnKind::Void => "void",
309    };
310
311    // --- Dry-run check ---
312    if edit::is_dry_run(&req.params) {
313        let dr = edit::dry_run_diff(&source, &new_source, &path);
314        return Response::success(
315            &req.id,
316            serde_json::json!({
317                "ok": true,
318                "dry_run": true,
319                "diff": dr.diff,
320                "syntax_valid": dr.syntax_valid,
321                "parameters": free_vars.parameters,
322                "return_type": return_type,
323            }),
324        );
325    }
326
327    // --- Auto-backup before mutation ---
328    let backup_id = match edit::auto_backup(
329        ctx,
330        req.session(),
331        &path,
332        &format!("extract_function: {}", name),
333    ) {
334        Ok(id) => id,
335        Err(e) => {
336            return Response::error(&req.id, e.code(), e.to_string());
337        }
338    };
339
340    // --- Write, format, validate ---
341    let mut write_result =
342        match edit::write_format_validate(&path, &new_source, &ctx.config(), &req.params) {
343            Ok(r) => r,
344            Err(e) => {
345                return Response::error(&req.id, e.code(), e.to_string());
346            }
347        };
348
349    if let Ok(final_content) = std::fs::read_to_string(&path) {
350        write_result.lsp_outcome = ctx.lsp_post_write(&path, &final_content, &req.params);
351    }
352
353    let param_count = free_vars.parameters.len();
354    log::debug!(
355        "extract_function: {} from {}:{}-{} ({} params)",
356        name,
357        file,
358        start_line,
359        end_line,
360        param_count
361    );
362
363    // --- Build response ---
364    let syntax_valid = write_result.syntax_valid.unwrap_or(true);
365
366    let mut result = serde_json::json!({
367        "file": file,
368        "name": name,
369        "parameters": free_vars.parameters,
370        "return_type": return_type,
371        "syntax_valid": syntax_valid,
372        "formatted": write_result.formatted,
373    });
374
375    if let Some(ref reason) = write_result.format_skipped_reason {
376        result["format_skipped_reason"] = serde_json::json!(reason);
377    }
378
379    if write_result.validate_requested {
380        result["validation_errors"] = serde_json::json!(write_result.validation_errors);
381    }
382    if let Some(ref reason) = write_result.validate_skipped_reason {
383        result["validate_skipped_reason"] = serde_json::json!(reason);
384    }
385
386    if let Some(ref id) = backup_id {
387        result["backup_id"] = serde_json::json!(id);
388    }
389
390    write_result.append_lsp_diagnostics_to(&mut result);
391    Response::success(&req.id, result)
392}
393
394/// Find the enclosing function node for a byte position.
395fn find_enclosing_function_node<'a>(
396    root: &'a tree_sitter::Node<'a>,
397    byte_pos: usize,
398    lang: LangId,
399) -> Option<tree_sitter::Node<'a>> {
400    let fn_kinds: &[&str] = match lang {
401        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => &[
402            "function_declaration",
403            "method_definition",
404            "arrow_function",
405            "lexical_declaration",
406        ],
407        LangId::Python => &["function_definition"],
408        _ => &[],
409    };
410
411    find_deepest_ancestor(root, byte_pos, fn_kinds)
412}
413
414/// Find the deepest ancestor node of given kinds containing byte_pos.
415fn find_deepest_ancestor<'a>(
416    node: &tree_sitter::Node<'a>,
417    byte_pos: usize,
418    kinds: &[&str],
419) -> Option<tree_sitter::Node<'a>> {
420    let mut result: Option<tree_sitter::Node<'a>> = None;
421    if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
422        result = Some(*node);
423    }
424
425    let child_count = node.child_count();
426    for i in 0..child_count {
427        if let Some(child) = node.child(i as u32) {
428            if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
429                if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
430                    result = Some(deeper);
431                }
432            }
433        }
434    }
435
436    result
437}
438
/// Return the leading whitespace of the 0-based `line` in `source`.
///
/// "Whitespace" follows `char::is_whitespace` (spaces, tabs and other Unicode
/// whitespace). An empty string is returned when the line does not exist.
fn get_line_indent(source: &str, line: usize) -> String {
    match source.lines().nth(line) {
        Some(text) => text.chars().take_while(|c| c.is_whitespace()).collect(),
        None => String::new(),
    }
}
450
/// Assemble the rewritten source: the extracted function is inserted at
/// `insert_pos` and the byte range `range_start..range_end` is replaced by the
/// call site.
///
/// The result is the concatenation of:
///   1. `source[..insert_pos]` — everything before the insertion point
///   2. `extracted_fn` followed by a blank line (`"\n\n"`)
///   3. `source[insert_pos..range_start]` — text between the insertion point
///      and the extracted range (e.g. the original function header)
///   4. `call_site` followed by `"\n"`
///   5. `source[range_end..]` — everything after the extracted range
///
/// Offsets must satisfy `insert_pos <= range_start` and all lie on char
/// boundaries within `source`; otherwise slicing panics.
fn build_new_source(
    source: &str,
    insert_pos: usize,
    range_start: usize,
    range_end: usize,
    extracted_fn: &str,
    call_site: &str,
) -> String {
    let pieces: [&str; 7] = [
        &source[..insert_pos],
        extracted_fn,
        "\n\n",
        &source[insert_pos..range_start],
        call_site,
        "\n",
        &source[range_end..],
    ];
    pieces.concat()
}
482
483// ---------------------------------------------------------------------------
484// Tests
485// ---------------------------------------------------------------------------
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RawRequest;

    /// Build a request with no LSP hints and no session attached.
    fn make_request(id: &str, command: &str, params: serde_json::Value) -> RawRequest {
        RawRequest {
            id: id.to_string(),
            command: command.to_string(),
            params,
            lsp_hints: None,
            session_id: None,
        }
    }

    /// Run `handle_extract_function` with `params` against a fresh default
    /// context and return the serialized response for assertions.
    fn run_extract(id: &str, params: serde_json::Value) -> serde_json::Value {
        let req = make_request(id, "extract_function", params);
        let ctx = crate::context::AppContext::new(
            Box::new(crate::parser::TreeSitterProvider::new()),
            crate::config::Config::default(),
        );
        let resp = handle_extract_function(&req, &ctx);
        serde_json::to_value(&resp).unwrap()
    }

    /// Absolute path of a checked-in `extract_function` fixture file.
    fn fixture(file_name: &str) -> std::path::PathBuf {
        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join(format!("tests/fixtures/extract_function/{}", file_name))
    }

    // --- Param validation ---

    #[test]
    fn extract_function_missing_file() {
        let json = run_extract("1", serde_json::json!({}));
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(
            msg.contains("file"),
            "message should mention 'file': {}",
            msg
        );
    }

    #[test]
    fn extract_function_missing_name() {
        let json = run_extract("2", serde_json::json!({"file": "/tmp/test.ts"}));
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(
            msg.contains("name"),
            "message should mention 'name': {}",
            msg
        );
    }

    #[test]
    fn extract_function_missing_start_line() {
        let json = run_extract(
            "3",
            serde_json::json!({"file": "/tmp/test.ts", "name": "foo"}),
        );
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
    }

    #[test]
    fn extract_function_unsupported_language() {
        // A .rs file: Rust is rejected by the extract_function language guard.
        let dir = std::env::temp_dir().join("aft_test_extract");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let json = run_extract(
            "4",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 1,
                "end_line": 2,
            }),
        );
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "unsupported_language");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_invalid_line_range() {
        let dir = std::env::temp_dir().join("aft_test_extract_range");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.ts");
        std::fs::write(&file, "const x = 1;\n").unwrap();

        // start_line after end_line must be rejected.
        let json = run_extract(
            "5",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 6,
                "end_line": 4,
            }),
        );
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_this_reference_error() {
        let json = run_extract(
            "6",
            serde_json::json!({
                "file": fixture("sample_this.ts").display().to_string(),
                "name": "extracted",
                "start_line": 5,
                "end_line": 8,
            }),
        );
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "this_reference_in_range");
    }

    #[test]
    fn extract_function_dry_run_returns_diff() {
        let json = run_extract(
            "7",
            serde_json::json!({
                "file": fixture("sample.ts").display().to_string(),
                "name": "computeResult",
                "start_line": 15,
                "end_line": 17,
                "dry_run": true,
            }),
        );
        assert_eq!(json["success"], true);
        assert_eq!(json["dry_run"], true);
        assert!(json["diff"].as_str().is_some(), "should have diff");
        assert!(json["parameters"].is_array(), "should have parameters");
    }
}