Skip to main content

aft/commands/
extract_function.rs

1//! Handler for the `extract_function` command: extract a range of code into
2//! a new function with auto-detected parameters and return value.
3//!
4//! Follows the edit_symbol.rs pattern: validate → parse → compute → dry_run
5//! check → auto_backup → write_format_validate → respond.
6
7use std::path::Path;
8
9use tree_sitter::Parser;
10
11use crate::context::AppContext;
12use crate::edit;
13use crate::extract::{
14    detect_free_variables, detect_return_value, generate_call_site, generate_extracted_function,
15    ReturnKind,
16};
17use crate::indent::detect_indent;
18use crate::parser::{detect_language, grammar_for, LangId};
19use crate::protocol::{RawRequest, Response};
20
21/// Handle an `extract_function` request.
22///
23/// Params:
24///   - `file` (string, required) — target file path
25///   - `name` (string, required) — name for the new function
26///   - `start_line` (u32, required) — first line of the range to extract (1-based)
27///   - `end_line` (u32, required) — last line (exclusive, 1-based) of the range to extract
28///   - `dry_run` (bool, optional) — if true, return diff without writing
29///
30/// Returns on success:
31///   `{ file, name, parameters, return_type, extracted_range, call_site_range, syntax_valid, backup_id }`
32///
33/// Error codes:
34///   - `unsupported_language` — file is not TS/JS/TSX/Python
35///   - `this_reference_in_range` — range contains `this`/`self`
36pub fn handle_extract_function(req: &RawRequest, ctx: &AppContext) -> Response {
37    // --- Extract params ---
38    let file = match req.params.get("file").and_then(|v| v.as_str()) {
39        Some(f) => f,
40        None => {
41            return Response::error(
42                &req.id,
43                "invalid_request",
44                "extract_function: missing required param 'file'",
45            );
46        }
47    };
48
49    let name = match req.params.get("name").and_then(|v| v.as_str()) {
50        Some(n) => n,
51        None => {
52            return Response::error(
53                &req.id,
54                "invalid_request",
55                "extract_function: missing required param 'name'",
56            );
57        }
58    };
59
60    let start_line_1based = match req.params.get("start_line").and_then(|v| v.as_u64()) {
61        Some(l) if l >= 1 => l as u32,
62        Some(_) => {
63            return Response::error(
64                &req.id,
65                "invalid_request",
66                "extract_function: 'start_line' must be >= 1 (1-based)",
67            );
68        }
69        None => {
70            return Response::error(
71                &req.id,
72                "invalid_request",
73                "extract_function: missing required param 'start_line'",
74            );
75        }
76    };
77    let start_line = start_line_1based - 1;
78
79    let end_line_1based = match req.params.get("end_line").and_then(|v| v.as_u64()) {
80        Some(l) if l >= 1 => l as u32,
81        Some(_) => {
82            return Response::error(
83                &req.id,
84                "invalid_request",
85                "extract_function: 'end_line' must be >= 1 (1-based)",
86            );
87        }
88        None => {
89            return Response::error(
90                &req.id,
91                "invalid_request",
92                "extract_function: missing required param 'end_line'",
93            );
94        }
95    };
96    let end_line = end_line_1based - 1;
97
98    if start_line >= end_line {
99        return Response::error(
100            &req.id,
101            "invalid_request",
102            format!(
103                "extract_function: start_line ({}) must be less than end_line ({})",
104                start_line, end_line
105            ),
106        );
107    }
108
109    // --- Validate file ---
110    let path = Path::new(file);
111    if !path.exists() {
112        return Response::error(
113            &req.id,
114            "file_not_found",
115            format!("extract_function: file not found: {}", file),
116        );
117    }
118
119    // --- Language guard (D101) ---
120    let lang = match detect_language(path) {
121        Some(l) => l,
122        None => {
123            return Response::error(
124                &req.id,
125                "unsupported_language",
126                "extract_function: unsupported file type",
127            );
128        }
129    };
130
131    if !matches!(
132        lang,
133        LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python
134    ) {
135        return Response::error(
136            &req.id,
137            "unsupported_language",
138            format!(
139                "extract_function: only TypeScript/JavaScript/Python files are supported, got {:?}",
140                lang
141            ),
142        );
143    }
144
145    // --- Read and parse ---
146    let source = match std::fs::read_to_string(path) {
147        Ok(s) => s,
148        Err(e) => {
149            return Response::error(
150                &req.id,
151                "file_not_found",
152                format!("extract_function: {}: {}", file, e),
153            );
154        }
155    };
156
157    let grammar = grammar_for(lang);
158    let mut parser = Parser::new();
159    if parser.set_language(&grammar).is_err() {
160        return Response::error(
161            &req.id,
162            "parse_error",
163            "extract_function: failed to initialize parser",
164        );
165    }
166    let tree = match parser.parse(source.as_bytes(), None) {
167        Some(t) => t,
168        None => {
169            return Response::error(
170                &req.id,
171                "parse_error",
172                "extract_function: failed to parse file",
173            );
174        }
175    };
176
177    // --- Convert line range to byte range ---
178    let start_byte = edit::line_col_to_byte(&source, start_line, 0);
179    let end_byte = edit::line_col_to_byte(&source, end_line, 0);
180
181    if start_byte >= source.len() {
182        return Response::error(
183            &req.id,
184            "invalid_request",
185            format!(
186                "extract_function: start_line {} is beyond end of file",
187                start_line
188            ),
189        );
190    }
191
192    // --- Detect free variables ---
193    let free_vars = detect_free_variables(&source, &tree, start_byte, end_byte, lang);
194
195    // Check for this/self
196    if free_vars.has_this_or_self {
197        let keyword = match lang {
198            LangId::Python => "self",
199            _ => "this",
200        };
201        return Response::error(
202            &req.id,
203            "this_reference_in_range",
204            format!(
205                "extract_function: selected range contains '{}' reference. Consider extracting as a method instead, or move the {} usage outside the extracted range.",
206                keyword, keyword
207            ),
208        );
209    }
210
211    // --- Find enclosing function for return value detection ---
212    let root = tree.root_node();
213    let enclosing_fn = find_enclosing_function_node(&root, start_byte, lang);
214    let enclosing_fn_end_byte = enclosing_fn.map(|n| n.end_byte());
215
216    // --- Detect return value ---
217    let return_kind = detect_return_value(
218        &source,
219        &tree,
220        start_byte,
221        end_byte,
222        enclosing_fn_end_byte,
223        lang,
224    );
225
226    // --- Detect indentation ---
227    let indent_style = detect_indent(&source, lang);
228
229    // Determine base indent (indentation of the line where the enclosing function starts,
230    // or no indent if at module level)
231    let base_indent = if let Some(fn_node) = enclosing_fn {
232        let fn_start_line = fn_node.start_position().row;
233        get_line_indent(&source, fn_start_line as usize)
234    } else {
235        String::new()
236    };
237
238    // Determine the indent of the extracted range (for the call site)
239    let range_indent = get_line_indent(&source, start_line as usize);
240
241    // --- Extract body text ---
242    let body_text = &source[start_byte..end_byte];
243    let body_text = body_text.trim_end_matches('\n');
244
245    // --- Generate function and call site ---
246    let extracted_fn = generate_extracted_function(
247        name,
248        &free_vars.parameters,
249        &return_kind,
250        body_text,
251        &base_indent,
252        lang,
253        indent_style,
254    );
255
256    let call_site = generate_call_site(
257        name,
258        &free_vars.parameters,
259        &return_kind,
260        &range_indent,
261        lang,
262    );
263
264    // --- Compute new file content ---
265    // Insert the extracted function before the enclosing function (or at the range position
266    // if there's no enclosing function).
267    let insert_pos = if let Some(fn_node) = enclosing_fn {
268        fn_node.start_byte()
269    } else {
270        start_byte
271    };
272
273    let new_source = build_new_source(
274        &source,
275        insert_pos,
276        start_byte,
277        end_byte,
278        &extracted_fn,
279        &call_site,
280    );
281
282    // --- Return type string for the response ---
283    let return_type = match &return_kind {
284        ReturnKind::Expression(_) => "expression",
285        ReturnKind::Variable(_) => "variable",
286        ReturnKind::Void => "void",
287    };
288
289    // --- Dry-run check ---
290    if edit::is_dry_run(&req.params) {
291        let dr = edit::dry_run_diff(&source, &new_source, path);
292        return Response::success(
293            &req.id,
294            serde_json::json!({
295                "ok": true,
296                "dry_run": true,
297                "diff": dr.diff,
298                "syntax_valid": dr.syntax_valid,
299                "parameters": free_vars.parameters,
300                "return_type": return_type,
301            }),
302        );
303    }
304
305    // --- Auto-backup before mutation ---
306    let backup_id = match edit::auto_backup(ctx, path, &format!("extract_function: {}", name)) {
307        Ok(id) => id,
308        Err(e) => {
309            return Response::error(&req.id, e.code(), e.to_string());
310        }
311    };
312
313    // --- Write, format, validate ---
314    let mut write_result =
315        match edit::write_format_validate(path, &new_source, &ctx.config(), &req.params) {
316            Ok(r) => r,
317            Err(e) => {
318                return Response::error(&req.id, e.code(), e.to_string());
319            }
320        };
321
322    if let Ok(final_content) = std::fs::read_to_string(path) {
323        write_result.lsp_diagnostics = ctx.lsp_post_write(path, &final_content, &req.params);
324    }
325
326    let param_count = free_vars.parameters.len();
327    eprintln!(
328        "[aft] extract_function: {} from {}:{}-{} ({} params)",
329        name, file, start_line, end_line, param_count
330    );
331
332    // --- Build response ---
333    let syntax_valid = write_result.syntax_valid.unwrap_or(true);
334
335    let mut result = serde_json::json!({
336        "file": file,
337        "name": name,
338        "parameters": free_vars.parameters,
339        "return_type": return_type,
340        "syntax_valid": syntax_valid,
341        "formatted": write_result.formatted,
342    });
343
344    if let Some(ref reason) = write_result.format_skipped_reason {
345        result["format_skipped_reason"] = serde_json::json!(reason);
346    }
347
348    if write_result.validate_requested {
349        result["validation_errors"] = serde_json::json!(write_result.validation_errors);
350    }
351    if let Some(ref reason) = write_result.validate_skipped_reason {
352        result["validate_skipped_reason"] = serde_json::json!(reason);
353    }
354
355    if let Some(ref id) = backup_id {
356        result["backup_id"] = serde_json::json!(id);
357    }
358
359    write_result.append_lsp_diagnostics_to(&mut result);
360    Response::success(&req.id, result)
361}
362
363/// Find the enclosing function node for a byte position.
364fn find_enclosing_function_node<'a>(
365    root: &'a tree_sitter::Node<'a>,
366    byte_pos: usize,
367    lang: LangId,
368) -> Option<tree_sitter::Node<'a>> {
369    let fn_kinds: &[&str] = match lang {
370        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => &[
371            "function_declaration",
372            "method_definition",
373            "arrow_function",
374            "lexical_declaration",
375        ],
376        LangId::Python => &["function_definition"],
377        _ => &[],
378    };
379
380    find_deepest_ancestor(root, byte_pos, fn_kinds)
381}
382
383/// Find the deepest ancestor node of given kinds containing byte_pos.
384fn find_deepest_ancestor<'a>(
385    node: &tree_sitter::Node<'a>,
386    byte_pos: usize,
387    kinds: &[&str],
388) -> Option<tree_sitter::Node<'a>> {
389    let mut result: Option<tree_sitter::Node<'a>> = None;
390    if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
391        result = Some(*node);
392    }
393
394    let child_count = node.child_count();
395    for i in 0..child_count {
396        if let Some(child) = node.child(i as u32) {
397            if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
398                if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
399                    result = Some(deeper);
400                }
401            }
402        }
403    }
404
405    result
406}
407
/// Return the leading whitespace of line `line` (0-based) of `source`.
///
/// Lines are split with `str::lines`, so a trailing `\r` is not part of a line.
/// Whitespace is whatever `char::is_whitespace` accepts. A line made only of
/// whitespace is returned whole, and an index past the last line yields an
/// empty string.
fn get_line_indent(source: &str, line: usize) -> String {
    match source.lines().nth(line) {
        Some(text) => {
            let indent_len = text
                .find(|c: char| !c.is_whitespace())
                .unwrap_or(text.len());
            text[..indent_len].to_owned()
        }
        None => String::new(),
    }
}
419
/// Assemble the rewritten file text.
///
/// Output, in order: `source[..insert_pos]`, `extracted_fn`, a blank line
/// (`"\n\n"`), `source[insert_pos..range_start]`, then `call_site` plus `'\n'`
/// in place of `source[range_start..range_end]`, and finally
/// `source[range_end..]`.
///
/// Panics if an index is out of bounds or not on a char boundary, or if
/// `insert_pos > range_start`.
fn build_new_source(
    source: &str,
    insert_pos: usize,
    range_start: usize,
    range_end: usize,
    extracted_fn: &str,
    call_site: &str,
) -> String {
    let head = &source[..insert_pos];
    let between = &source[insert_pos..range_start];
    let tail = &source[range_end..];

    [head, extracted_fn, "\n\n", between, call_site, "\n", tail].concat()
}
451
452// ---------------------------------------------------------------------------
453// Tests
454// ---------------------------------------------------------------------------
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RawRequest;

    /// Build a request envelope for the handler under test.
    fn make_request(id: &str, command: &str, params: serde_json::Value) -> RawRequest {
        RawRequest {
            id: id.to_string(),
            command: command.to_string(),
            params,
            lsp_hints: None,
        }
    }

    /// A fresh context with the tree-sitter provider and default config.
    fn new_ctx() -> crate::context::AppContext {
        crate::context::AppContext::new(
            Box::new(crate::parser::TreeSitterProvider::new()),
            crate::config::Config::default(),
        )
    }

    /// Run `extract_function` with `params` and return the serialized response.
    fn run(id: &str, params: serde_json::Value) -> serde_json::Value {
        let req = make_request(id, "extract_function", params);
        let ctx = new_ctx();
        let resp = handle_extract_function(&req, &ctx);
        serde_json::to_value(&resp).unwrap()
    }

    /// Absolute path of a checked-in fixture in `tests/fixtures/extract_function/`.
    fn fixture(file_name: &str) -> std::path::PathBuf {
        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join(format!("tests/fixtures/extract_function/{}", file_name))
    }

    // --- Param validation ---

    #[test]
    fn extract_function_missing_file() {
        let json = run("1", serde_json::json!({}));
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(msg.contains("file"), "message should mention 'file': {}", msg);
    }

    #[test]
    fn extract_function_missing_name() {
        let json = run("2", serde_json::json!({"file": "/tmp/test.ts"}));
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(msg.contains("name"), "message should mention 'name': {}", msg);
    }

    #[test]
    fn extract_function_missing_start_line() {
        let json = run("3", serde_json::json!({"file": "/tmp/test.ts", "name": "foo"}));
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "invalid_request");
    }

    #[test]
    fn extract_function_unsupported_language() {
        // Rust is not supported by extract_function, so a .rs file is rejected.
        let dir = std::env::temp_dir().join("aft_test_extract");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let json = run(
            "4",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 1,
                "end_line": 2,
            }),
        );
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "unsupported_language");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_invalid_line_range() {
        let dir = std::env::temp_dir().join("aft_test_extract_range");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.ts");
        std::fs::write(&file, "const x = 1;\n").unwrap();

        // start_line after end_line must be rejected before any parsing work.
        let json = run(
            "5",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 6,
                "end_line": 4,
            }),
        );
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "invalid_request");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_this_reference_error() {
        let json = run(
            "6",
            serde_json::json!({
                "file": fixture("sample_this.ts").display().to_string(),
                "name": "extracted",
                "start_line": 5,
                "end_line": 8,
            }),
        );
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "this_reference_in_range");
    }

    #[test]
    fn extract_function_dry_run_returns_diff() {
        let json = run(
            "7",
            serde_json::json!({
                "file": fixture("sample.ts").display().to_string(),
                "name": "computeResult",
                "start_line": 15,
                "end_line": 17,
                "dry_run": true,
            }),
        );
        assert_eq!(json["ok"], true);
        assert_eq!(json["dry_run"], true);
        assert!(json["diff"].as_str().is_some(), "should have diff");
        assert!(json["parameters"].is_array(), "should have parameters");
    }
}