Skip to main content

aft/commands/
extract_function.rs

1//! Handler for the `extract_function` command: extract a range of code into
2//! a new function with auto-detected parameters and return value.
3//!
4//! Follows the edit_symbol.rs pattern: validate → parse → compute →
5//! auto_backup → write_format_validate → respond.
6
7use std::path::Path;
8
9use tree_sitter::Parser;
10
11use crate::context::AppContext;
12use crate::edit;
13use crate::extract::{
14    detect_free_variables, detect_return_value, generate_call_site, generate_extracted_function,
15    ReturnKind,
16};
17use crate::indent::detect_indent;
18use crate::parser::{detect_language, grammar_for, LangId};
19use crate::protocol::{RawRequest, Response};
20
21/// Handle an `extract_function` request.
22///
23/// Params:
24///   - `file` (string, required) — target file path
25///   - `name` (string, required) — name for the new function
26///   - `start_line` (u32, required) — first line of the range to extract (1-based)
27///   - `end_line` (u32, required) — last line (exclusive, 1-based) of the range to extract
28///
29/// Returns on success:
30///   `{ file, name, parameters, return_type, extracted_range, call_site_range, syntax_valid, backup_id }`
31///
32/// Error codes:
33///   - `unsupported_language` — file is not TS/JS/TSX/Python
34///   - `this_reference_in_range` — range contains `this`/`self`
35pub fn handle_extract_function(req: &RawRequest, ctx: &AppContext) -> Response {
36    // --- Extract params ---
37    let file = match req.params.get("file").and_then(|v| v.as_str()) {
38        Some(f) => f,
39        None => {
40            return Response::error(
41                &req.id,
42                "invalid_request",
43                "extract_function: missing required param 'file'",
44            );
45        }
46    };
47
48    let name = match req.params.get("name").and_then(|v| v.as_str()) {
49        Some(n) => n,
50        None => {
51            return Response::error(
52                &req.id,
53                "invalid_request",
54                "extract_function: missing required param 'name'",
55            );
56        }
57    };
58
59    let start_line_1based = match req.params.get("start_line").and_then(|v| v.as_u64()) {
60        Some(l) if l >= 1 => l as u32,
61        Some(_) => {
62            return Response::error(
63                &req.id,
64                "invalid_request",
65                "extract_function: 'start_line' must be >= 1 (1-based)",
66            );
67        }
68        None => {
69            return Response::error(
70                &req.id,
71                "invalid_request",
72                "extract_function: missing required param 'start_line'",
73            );
74        }
75    };
76    let start_line = start_line_1based - 1;
77
78    let end_line_1based = match req.params.get("end_line").and_then(|v| v.as_u64()) {
79        Some(l) if l >= 1 => l as u32,
80        Some(_) => {
81            return Response::error(
82                &req.id,
83                "invalid_request",
84                "extract_function: 'end_line' must be >= 1 (1-based)",
85            );
86        }
87        None => {
88            return Response::error(
89                &req.id,
90                "invalid_request",
91                "extract_function: missing required param 'end_line'",
92            );
93        }
94    };
95    let end_line = end_line_1based - 1;
96
97    if start_line >= end_line {
98        return Response::error(
99            &req.id,
100            "invalid_request",
101            format!(
102                "extract_function: start_line ({}) must be less than end_line ({})",
103                start_line, end_line
104            ),
105        );
106    }
107
108    // --- Validate file ---
109    let path = match ctx.validate_path(&req.id, Path::new(file)) {
110        Ok(path) => path,
111        Err(resp) => return resp,
112    };
113    if !path.exists() {
114        return Response::error(
115            &req.id,
116            "file_not_found",
117            format!("extract_function: file not found: {}", file),
118        );
119    }
120
121    // --- Language guard (D101) ---
122    let lang = match detect_language(&path) {
123        Some(l) => l,
124        None => {
125            return Response::error(
126                &req.id,
127                "unsupported_language",
128                "extract_function: unsupported file type",
129            );
130        }
131    };
132
133    if !matches!(
134        lang,
135        LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python
136    ) {
137        return Response::error(
138            &req.id,
139            "unsupported_language",
140            format!(
141                "extract_function: only TypeScript/JavaScript/Python files are supported, got {:?}",
142                lang
143            ),
144        );
145    }
146
147    // --- Read and parse ---
148    let source = match std::fs::read_to_string(&path) {
149        Ok(s) => s,
150        Err(e) => {
151            return Response::error(
152                &req.id,
153                "file_not_found",
154                format!("extract_function: {}: {}", file, e),
155            );
156        }
157    };
158
159    let grammar = grammar_for(lang);
160    let mut parser = Parser::new();
161    if parser.set_language(&grammar).is_err() {
162        return Response::error(
163            &req.id,
164            "parse_error",
165            "extract_function: failed to initialize parser",
166        );
167    }
168    let tree = match parser.parse(source.as_bytes(), None) {
169        Some(t) => t,
170        None => {
171            return Response::error(
172                &req.id,
173                "parse_error",
174                "extract_function: failed to parse file",
175            );
176        }
177    };
178
179    // --- Convert line range to byte range ---
180    let start_byte = edit::line_col_to_byte(&source, start_line, 0);
181    let end_byte = edit::line_col_to_byte(&source, end_line, 0);
182
183    if start_byte >= source.len() {
184        return Response::error(
185            &req.id,
186            "invalid_request",
187            format!(
188                "extract_function: start_line {} is beyond end of file",
189                start_line
190            ),
191        );
192    }
193
194    // --- Detect free variables ---
195    let free_vars = detect_free_variables(&source, &tree, start_byte, end_byte, lang);
196
197    // Check for this/self
198    if free_vars.has_this_or_self {
199        let keyword = match lang {
200            LangId::Python => "self",
201            _ => "this",
202        };
203        return Response::error(
204            &req.id,
205            "this_reference_in_range",
206            format!(
207                "extract_function: selected range contains '{}' reference. Consider extracting as a method instead, or move the {} usage outside the extracted range.",
208                keyword, keyword
209            ),
210        );
211    }
212
213    // --- Find enclosing function for return value detection ---
214    let root = tree.root_node();
215    let enclosing_fn = find_enclosing_function_node(&root, start_byte, lang);
216    let enclosing_fn_end_byte = enclosing_fn.map(|n| n.end_byte());
217
218    // --- Detect return value ---
219    let return_kind = detect_return_value(
220        &source,
221        &tree,
222        start_byte,
223        end_byte,
224        enclosing_fn_end_byte,
225        lang,
226    );
227
228    // --- Detect indentation ---
229    let indent_style = detect_indent(&source, lang);
230
231    // Determine base indent (indentation of the line where the enclosing function starts,
232    // or no indent if at module level)
233    let base_indent = if let Some(fn_node) = enclosing_fn {
234        let fn_start_line = fn_node.start_position().row;
235        get_line_indent(&source, fn_start_line as usize)
236    } else {
237        String::new()
238    };
239
240    // Determine the indent of the extracted range (for the call site)
241    let range_indent = get_line_indent(&source, start_line as usize);
242
243    // --- Extract body text ---
244    let body_text = &source[start_byte..end_byte];
245    let body_text = body_text.trim_end_matches('\n');
246
247    // --- Generate function and call site ---
248    let extracted_fn = generate_extracted_function(
249        name,
250        &free_vars.parameters,
251        &return_kind,
252        body_text,
253        &base_indent,
254        lang,
255        indent_style,
256    );
257
258    let call_site = generate_call_site(
259        name,
260        &free_vars.parameters,
261        &return_kind,
262        &range_indent,
263        lang,
264    );
265
266    // --- Compute new file content ---
267    // Insert the extracted function before the enclosing function (or at the range position
268    // if there's no enclosing function).
269    //
270    // For TS/JS, when the enclosing function is wrapped in `export` (or
271    // `export default`), the parser reports the function_declaration as
272    // a child of an export_statement. If we use fn_node.start_byte() as
273    // the insertion point, the `export` keyword stays attached to the
274    // start of the file content, and the extracted function gets inserted
275    // BETWEEN `export ` and `function`, producing:
276    //   `export function newFn(...) {} \n\n function originalFn(...) {}`
277    // The export keyword silently jumps from the original function to the
278    // extracted one. Fix: if the parent is an export_statement, insert
279    // before the export_statement instead.
280    let insert_pos = if let Some(fn_node) = enclosing_fn {
281        let mut anchor = fn_node;
282        if matches!(lang, LangId::TypeScript | LangId::Tsx | LangId::JavaScript) {
283            if let Some(parent) = fn_node.parent() {
284                if parent.kind() == "export_statement" {
285                    anchor = parent;
286                }
287            }
288        }
289        anchor.start_byte()
290    } else {
291        start_byte
292    };
293
294    let new_source = build_new_source(
295        &source,
296        insert_pos,
297        start_byte,
298        end_byte,
299        &extracted_fn,
300        &call_site,
301    );
302
303    // --- Return type string for the response ---
304    let return_type = match &return_kind {
305        ReturnKind::Expression(_) => "expression",
306        ReturnKind::Variable(_) => "variable",
307        ReturnKind::Void => "void",
308    };
309
310    // --- Auto-backup before mutation ---
311    let backup_id = match edit::auto_backup(
312        ctx,
313        req.session(),
314        &path,
315        &format!("extract_function: {}", name),
316    ) {
317        Ok(id) => id,
318        Err(e) => {
319            return Response::error(&req.id, e.code(), e.to_string());
320        }
321    };
322
323    // --- Write, format, validate ---
324    let mut write_result =
325        match edit::write_format_validate(&path, &new_source, &ctx.config(), &req.params) {
326            Ok(r) => r,
327            Err(e) => {
328                return Response::error(&req.id, e.code(), e.to_string());
329            }
330        };
331
332    if let Ok(final_content) = std::fs::read_to_string(&path) {
333        write_result.lsp_outcome = ctx.lsp_post_write(&path, &final_content, &req.params);
334    }
335
336    let param_count = free_vars.parameters.len();
337    log::debug!(
338        "extract_function: {} from {}:{}-{} ({} params)",
339        name,
340        file,
341        start_line,
342        end_line,
343        param_count
344    );
345
346    // --- Build response ---
347    let syntax_valid = write_result.syntax_valid.unwrap_or(true);
348
349    let mut result = serde_json::json!({
350        "file": file,
351        "name": name,
352        "parameters": free_vars.parameters,
353        "return_type": return_type,
354        "syntax_valid": syntax_valid,
355        "formatted": write_result.formatted,
356    });
357
358    if let Some(ref reason) = write_result.format_skipped_reason {
359        result["format_skipped_reason"] = serde_json::json!(reason);
360    }
361
362    if write_result.validate_requested {
363        result["validation_errors"] = serde_json::json!(write_result.validation_errors);
364    }
365    if let Some(ref reason) = write_result.validate_skipped_reason {
366        result["validate_skipped_reason"] = serde_json::json!(reason);
367    }
368
369    if let Some(ref id) = backup_id {
370        result["backup_id"] = serde_json::json!(id);
371    }
372
373    write_result.append_lsp_diagnostics_to(&mut result);
374    Response::success(&req.id, result)
375}
376
377/// Find the enclosing function node for a byte position.
378fn find_enclosing_function_node<'a>(
379    root: &'a tree_sitter::Node<'a>,
380    byte_pos: usize,
381    lang: LangId,
382) -> Option<tree_sitter::Node<'a>> {
383    let fn_kinds: &[&str] = match lang {
384        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => &[
385            "function_declaration",
386            "method_definition",
387            "arrow_function",
388            "lexical_declaration",
389        ],
390        LangId::Python => &["function_definition"],
391        _ => &[],
392    };
393
394    find_deepest_ancestor(root, byte_pos, fn_kinds)
395}
396
397/// Find the deepest ancestor node of given kinds containing byte_pos.
398fn find_deepest_ancestor<'a>(
399    node: &tree_sitter::Node<'a>,
400    byte_pos: usize,
401    kinds: &[&str],
402) -> Option<tree_sitter::Node<'a>> {
403    let mut result: Option<tree_sitter::Node<'a>> = None;
404    if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
405        result = Some(*node);
406    }
407
408    let child_count = node.child_count();
409    for i in 0..child_count {
410        if let Some(child) = node.child(i as u32) {
411            if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
412                if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
413                    result = Some(deeper);
414                }
415            }
416        }
417    }
418
419    result
420}
421
/// Return the leading whitespace of the 0-based `line` of `source`.
///
/// Lines come from `str::lines`, so a trailing `\r` is never part of a line.
/// A line made only of whitespace yields the whole line. A `line` past the
/// end of `source` yields an empty string.
fn get_line_indent(source: &str, line: usize) -> String {
    match source.lines().nth(line) {
        Some(text) => {
            let indent_len = text.len() - text.trim_start().len();
            text[..indent_len].to_owned()
        }
        None => String::new(),
    }
}
433
/// Splice the extracted function and its call site into `source`.
///
/// The result is, in order:
///   1. `source[..insert_pos]`
///   2. `extracted_fn` followed by a blank line (`"\n\n"`)
///   3. `source[insert_pos..range_start]` — e.g. the enclosing statement up to the range
///   4. `call_site` followed by `'\n'`, replacing `source[range_start..range_end]`
///   5. `source[range_end..]`
///
/// # Panics
///
/// Panics if the offsets are not ordered
/// `insert_pos <= range_start <= range_end <= source.len()` or do not fall on
/// UTF-8 character boundaries.
fn build_new_source(
    source: &str,
    insert_pos: usize,
    range_start: usize,
    range_end: usize,
    extracted_fn: &str,
    call_site: &str,
) -> String {
    debug_assert!(
        insert_pos <= range_start && range_start <= range_end,
        "build_new_source: offsets out of order ({} / {}..{})",
        insert_pos,
        range_start,
        range_end
    );

    // Upper bound on the final length: all of the source plus everything
    // added ("\n\n" after the function, '\n' after the call site). The old
    // fixed `+ 64` slack ignored `call_site` and under-reserved for long calls.
    let capacity = source.len() + extracted_fn.len() + call_site.len() + 3;
    let mut result = String::with_capacity(capacity);

    // Everything before the insertion point
    result.push_str(&source[..insert_pos]);

    // The extracted function + blank line
    result.push_str(extracted_fn);
    result.push_str("\n\n");

    // Everything between the insertion point and the range start (the
    // original enclosing code up to where extraction begins)
    result.push_str(&source[insert_pos..range_start]);

    // The call site replacing the original range
    result.push_str(call_site);
    result.push('\n');

    // Everything after the range
    result.push_str(&source[range_end..]);

    result
}
465
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RawRequest;

    fn make_request(id: &str, command: &str, params: serde_json::Value) -> RawRequest {
        RawRequest {
            id: id.to_string(),
            command: command.to_string(),
            params,
            lsp_hints: None,
            session_id: None,
        }
    }

    /// Context backed by the tree-sitter provider and the default config.
    fn test_ctx() -> crate::context::AppContext {
        crate::context::AppContext::new(
            Box::new(crate::parser::TreeSitterProvider::new()),
            crate::config::Config::default(),
        )
    }

    /// Send an `extract_function` request and return the response as JSON.
    fn send(id: &str, params: serde_json::Value) -> serde_json::Value {
        let req = make_request(id, "extract_function", params);
        let ctx = test_ctx();
        let resp = handle_extract_function(&req, &ctx);
        serde_json::to_value(&resp).unwrap()
    }

    /// Assert that `json` is a failed response carrying `code`.
    fn assert_error_code(json: &serde_json::Value, code: &str) {
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], code);
    }

    /// Create `dir_name` under the system temp dir and write `file_name` into it.
    /// Returns `(dir, file)`.
    fn scratch_file(
        dir_name: &str,
        file_name: &str,
        contents: &str,
    ) -> (std::path::PathBuf, std::path::PathBuf) {
        let dir = std::env::temp_dir().join(dir_name);
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join(file_name);
        std::fs::write(&file, contents).unwrap();
        (dir, file)
    }

    // --- Param validation ---

    #[test]
    fn extract_function_missing_file() {
        let json = send("1", serde_json::json!({}));
        assert_error_code(&json, "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(
            msg.contains("file"),
            "message should mention 'file': {}",
            msg
        );
    }

    #[test]
    fn extract_function_missing_name() {
        let json = send("2", serde_json::json!({"file": "/tmp/test.ts"}));
        assert_error_code(&json, "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(
            msg.contains("name"),
            "message should mention 'name': {}",
            msg
        );
    }

    #[test]
    fn extract_function_missing_start_line() {
        let json = send("3", serde_json::json!({"file": "/tmp/test.ts", "name": "foo"}));
        assert_error_code(&json, "invalid_request");
    }

    #[test]
    fn extract_function_unsupported_language() {
        // Rust is not supported for extract_function.
        let (dir, file) = scratch_file("aft_test_extract", "test.rs", "fn main() {}");

        let json = send(
            "4",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 1,
                "end_line": 2,
            }),
        );
        assert_error_code(&json, "unsupported_language");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_invalid_line_range() {
        let (dir, file) = scratch_file("aft_test_extract_range", "test.ts", "const x = 1;\n");

        let json = send(
            "5",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 6,
                "end_line": 4,
            }),
        );
        assert_error_code(&json, "invalid_request");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_this_reference_error() {
        let fixture = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/extract_function/sample_this.ts");

        let json = send(
            "6",
            serde_json::json!({
                "file": fixture.display().to_string(),
                "name": "extracted",
                "start_line": 5,
                "end_line": 8,
            }),
        );
        assert_error_code(&json, "this_reference_in_range");
    }
}