Skip to main content

aft/commands/
extract_function.rs

//! Handler for the `extract_function` command: extract a range of code into
//! a new function with auto-detected parameters and return value.
//!
//! Follows the edit_symbol.rs pattern: validate → parse → compute →
//! auto_backup → write_format_validate → respond.
6
7use std::path::Path;
8
9use tree_sitter::Parser;
10
11use crate::context::AppContext;
12use crate::edit;
13use crate::extract::{
14    detect_free_variables, detect_return_value, generate_call_site, generate_extracted_function,
15    ReturnKind,
16};
17use crate::indent::detect_indent;
18use crate::parser::{detect_language, grammar_for, LangId};
19use crate::protocol::{RawRequest, Response};
20
21/// Handle an `extract_function` request.
22///
23/// Params:
24///   - `file` (string, required) — target file path
25///   - `name` (string, required) — name for the new function
26///   - `start_line` (u32, required) — first line of the range to extract (1-based)
27///   - `end_line` (u32, required) — last line (exclusive, 1-based) of the range to extract
28///
29/// Returns on success:
30///   `{ file, name, parameters, return_type, extracted_range, call_site_range, syntax_valid?, backup_id }`
31///
32/// `syntax_valid` is absent when syntax validation could not run.
33///
34/// Error codes:
35///   - `unsupported_language` — file is not TS/JS/TSX/Python
36///   - `this_reference_in_range` — range contains `this`/`self`
37pub fn handle_extract_function(req: &RawRequest, ctx: &AppContext) -> Response {
38    let op_id = crate::backup::new_op_id();
39    // --- Extract params ---
40    let file = match req.params.get("file").and_then(|v| v.as_str()) {
41        Some(f) => f,
42        None => {
43            return Response::error(
44                &req.id,
45                "invalid_request",
46                "extract_function: missing required param 'file'",
47            );
48        }
49    };
50
51    let name = match req.params.get("name").and_then(|v| v.as_str()) {
52        Some(n) => n,
53        None => {
54            return Response::error(
55                &req.id,
56                "invalid_request",
57                "extract_function: missing required param 'name'",
58            );
59        }
60    };
61
62    let start_line_1based = match req.params.get("start_line").and_then(|v| v.as_u64()) {
63        Some(l) if l >= 1 => l as u32,
64        Some(_) => {
65            return Response::error(
66                &req.id,
67                "invalid_request",
68                "extract_function: 'start_line' must be >= 1 (1-based)",
69            );
70        }
71        None => {
72            return Response::error(
73                &req.id,
74                "invalid_request",
75                "extract_function: missing required param 'start_line'",
76            );
77        }
78    };
79    let start_line = start_line_1based - 1;
80
81    let end_line_1based = match req.params.get("end_line").and_then(|v| v.as_u64()) {
82        Some(l) if l >= 1 => l as u32,
83        Some(_) => {
84            return Response::error(
85                &req.id,
86                "invalid_request",
87                "extract_function: 'end_line' must be >= 1 (1-based)",
88            );
89        }
90        None => {
91            return Response::error(
92                &req.id,
93                "invalid_request",
94                "extract_function: missing required param 'end_line'",
95            );
96        }
97    };
98    let end_line = end_line_1based - 1;
99
100    if start_line >= end_line {
101        return Response::error(
102            &req.id,
103            "invalid_request",
104            format!(
105                "extract_function: start_line ({}) must be less than end_line ({})",
106                start_line, end_line
107            ),
108        );
109    }
110
111    // --- Validate file ---
112    let path = match ctx.validate_path(&req.id, Path::new(file)) {
113        Ok(path) => path,
114        Err(resp) => return resp,
115    };
116    if !path.exists() {
117        return Response::error(
118            &req.id,
119            "file_not_found",
120            format!("extract_function: file not found: {}", file),
121        );
122    }
123
124    // --- Language guard (D101) ---
125    let lang = match detect_language(&path) {
126        Some(l) => l,
127        None => {
128            return Response::error(
129                &req.id,
130                "unsupported_language",
131                "extract_function: unsupported file type",
132            );
133        }
134    };
135
136    if !matches!(
137        lang,
138        LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python
139    ) {
140        return Response::error(
141            &req.id,
142            "unsupported_language",
143            format!(
144                "extract_function: only TypeScript/JavaScript/Python files are supported, got {:?}",
145                lang
146            ),
147        );
148    }
149
150    // --- Read and parse ---
151    let source = match std::fs::read_to_string(&path) {
152        Ok(s) => s,
153        Err(e) => {
154            return Response::error(
155                &req.id,
156                "file_not_found",
157                format!("extract_function: {}: {}", file, e),
158            );
159        }
160    };
161
162    let grammar = grammar_for(lang);
163    let mut parser = Parser::new();
164    if parser.set_language(&grammar).is_err() {
165        return Response::error(
166            &req.id,
167            "parse_error",
168            "extract_function: failed to initialize parser",
169        );
170    }
171    let tree = match parser.parse(source.as_bytes(), None) {
172        Some(t) => t,
173        None => {
174            return Response::error(
175                &req.id,
176                "parse_error",
177                "extract_function: failed to parse file",
178            );
179        }
180    };
181
182    // --- Convert line range to byte range ---
183    let start_byte = edit::line_col_to_byte(&source, start_line, 0);
184    let end_byte = edit::line_col_to_byte(&source, end_line, 0);
185
186    if start_byte >= source.len() {
187        return Response::error(
188            &req.id,
189            "invalid_request",
190            format!(
191                "extract_function: start_line {} is beyond end of file",
192                start_line
193            ),
194        );
195    }
196
197    // --- Detect free variables ---
198    let free_vars = detect_free_variables(&source, &tree, start_byte, end_byte, lang);
199
200    // Check for this/self
201    if free_vars.has_this_or_self {
202        let keyword = match lang {
203            LangId::Python => "self",
204            _ => "this",
205        };
206        return Response::error(
207            &req.id,
208            "this_reference_in_range",
209            format!(
210                "extract_function: selected range contains '{}' reference. Consider extracting as a method instead, or move the {} usage outside the extracted range.",
211                keyword, keyword
212            ),
213        );
214    }
215
216    // --- Find enclosing function for return value detection ---
217    let root = tree.root_node();
218    let enclosing_fn = find_enclosing_function_node(&root, start_byte, lang);
219    let enclosing_fn_end_byte = enclosing_fn.map(|n| n.end_byte());
220
221    // --- Detect return value ---
222    let return_kind = detect_return_value(
223        &source,
224        &tree,
225        start_byte,
226        end_byte,
227        enclosing_fn_end_byte,
228        lang,
229    );
230
231    // --- Detect indentation ---
232    let indent_style = detect_indent(&source, lang);
233
234    // Determine base indent (indentation of the line where the enclosing function starts,
235    // or no indent if at module level)
236    let base_indent = if let Some(fn_node) = enclosing_fn {
237        let fn_start_line = fn_node.start_position().row;
238        get_line_indent(&source, fn_start_line as usize)
239    } else {
240        String::new()
241    };
242
243    // Determine the indent of the extracted range (for the call site)
244    let range_indent = get_line_indent(&source, start_line as usize);
245
246    // --- Extract body text ---
247    let body_text = &source[start_byte..end_byte];
248    let body_text = body_text.trim_end_matches('\n');
249
250    // --- Generate function and call site ---
251    let extracted_fn = generate_extracted_function(
252        name,
253        &free_vars.parameters,
254        &return_kind,
255        body_text,
256        &base_indent,
257        lang,
258        indent_style,
259    );
260
261    let call_site = generate_call_site(
262        name,
263        &free_vars.parameters,
264        &return_kind,
265        &range_indent,
266        lang,
267    );
268
269    // --- Compute new file content ---
270    // Insert the extracted function before the enclosing function (or at the range position
271    // if there's no enclosing function).
272    //
273    // For TS/JS, when the enclosing function is wrapped in `export` (or
274    // `export default`), the parser reports the function_declaration as
275    // a child of an export_statement. If we use fn_node.start_byte() as
276    // the insertion point, the `export` keyword stays attached to the
277    // start of the file content, and the extracted function gets inserted
278    // BETWEEN `export ` and `function`, producing:
279    //   `export function newFn(...) {} \n\n function originalFn(...) {}`
280    // The export keyword silently jumps from the original function to the
281    // extracted one. Fix: if the parent is an export_statement, insert
282    // before the export_statement instead.
283    let insert_pos = if let Some(fn_node) = enclosing_fn {
284        let mut anchor = fn_node;
285        if matches!(lang, LangId::TypeScript | LangId::Tsx | LangId::JavaScript) {
286            if let Some(parent) = fn_node.parent() {
287                if parent.kind() == "export_statement" {
288                    anchor = parent;
289                }
290            }
291        }
292        anchor.start_byte()
293    } else {
294        start_byte
295    };
296
297    let new_source = build_new_source(
298        &source,
299        insert_pos,
300        start_byte,
301        end_byte,
302        &extracted_fn,
303        &call_site,
304    );
305
306    // --- Return type string for the response ---
307    let return_type = match &return_kind {
308        ReturnKind::Expression(_) => "expression",
309        ReturnKind::Variable(_) => "variable",
310        ReturnKind::Void => "void",
311    };
312
313    // --- Auto-backup before mutation ---
314    let backup_id = match edit::auto_backup(
315        ctx,
316        req.session(),
317        &path,
318        &format!("extract_function: {}", name),
319        Some(&op_id),
320    ) {
321        Ok(id) => id,
322        Err(e) => {
323            return Response::error(&req.id, e.code(), e.to_string());
324        }
325    };
326
327    // --- Write, format, validate ---
328    let mut write_result =
329        match edit::write_format_validate(&path, &new_source, &ctx.config(), &req.params) {
330            Ok(r) => r,
331            Err(e) => {
332                return Response::error(&req.id, e.code(), e.to_string());
333            }
334        };
335
336    if let Ok(final_content) = std::fs::read_to_string(&path) {
337        write_result.lsp_outcome = ctx.lsp_post_write(&path, &final_content, &req.params);
338    }
339
340    let param_count = free_vars.parameters.len();
341    log::debug!(
342        "extract_function: {} from {}:{}-{} ({} params)",
343        name,
344        file,
345        start_line,
346        end_line,
347        param_count
348    );
349
350    // --- Build response ---
351    let mut result = serde_json::json!({
352        "file": file,
353        "name": name,
354        "parameters": free_vars.parameters,
355        "return_type": return_type,
356        "formatted": write_result.formatted,
357    });
358
359    if let Some(valid) = write_result.syntax_valid {
360        result["syntax_valid"] = serde_json::json!(valid);
361    }
362
363    if let Some(ref reason) = write_result.format_skipped_reason {
364        result["format_skipped_reason"] = serde_json::json!(reason);
365    }
366
367    if write_result.validate_requested {
368        result["validation_errors"] = serde_json::json!(write_result.validation_errors);
369    }
370    if let Some(ref reason) = write_result.validate_skipped_reason {
371        result["validate_skipped_reason"] = serde_json::json!(reason);
372    }
373
374    if let Some(ref id) = backup_id {
375        result["backup_id"] = serde_json::json!(id);
376    }
377
378    write_result.append_lsp_diagnostics_to(&mut result);
379    Response::success(&req.id, result)
380}
381
382/// Find the enclosing function node for a byte position.
383fn find_enclosing_function_node<'a>(
384    root: &'a tree_sitter::Node<'a>,
385    byte_pos: usize,
386    lang: LangId,
387) -> Option<tree_sitter::Node<'a>> {
388    let fn_kinds: &[&str] = match lang {
389        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => &[
390            "function_declaration",
391            "method_definition",
392            "arrow_function",
393            "lexical_declaration",
394        ],
395        LangId::Python => &["function_definition"],
396        _ => &[],
397    };
398
399    find_deepest_ancestor(root, byte_pos, fn_kinds)
400}
401
402/// Find the deepest ancestor node of given kinds containing byte_pos.
403fn find_deepest_ancestor<'a>(
404    node: &tree_sitter::Node<'a>,
405    byte_pos: usize,
406    kinds: &[&str],
407) -> Option<tree_sitter::Node<'a>> {
408    let mut result: Option<tree_sitter::Node<'a>> = None;
409    if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
410        result = Some(*node);
411    }
412
413    let child_count = node.child_count();
414    for i in 0..child_count {
415        if let Some(child) = node.child(i as u32) {
416            if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
417                if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
418                    result = Some(deeper);
419                }
420            }
421        }
422    }
423
424    result
425}
426
/// Return the leading whitespace of the zero-based `line` of `source`.
///
/// Returns an empty string when `source` has no such line.
fn get_line_indent(source: &str, line: usize) -> String {
    let Some(text) = source.lines().nth(line) else {
        return String::new();
    };
    let indent_len = text.len() - text.trim_start().len();
    text[..indent_len].to_owned()
}
438
/// Assemble the rewritten file text.
///
/// The output is, in order:
/// 1. `source` up to `insert_pos`;
/// 2. `extracted_fn` followed by a blank line;
/// 3. the original text from `insert_pos` up to `range_start`;
/// 4. `call_site` and a newline, replacing `range_start..range_end`;
/// 5. everything from `range_end` onward.
///
/// Callers must ensure `insert_pos <= range_start <= range_end <= source.len()`, with each
/// offset on a char boundary; otherwise slicing panics.
fn build_new_source(
    source: &str,
    insert_pos: usize,
    range_start: usize,
    range_end: usize,
    extracted_fn: &str,
    call_site: &str,
) -> String {
    let head = &source[..insert_pos];
    let between = &source[insert_pos..range_start];
    let tail = &source[range_end..];

    [head, extracted_fn, "\n\n", between, call_site, "\n", tail].concat()
}
470
471// ---------------------------------------------------------------------------
472// Tests
473// ---------------------------------------------------------------------------
474
475#[cfg(test)]
476mod tests {
477    use super::*;
478    use crate::protocol::RawRequest;
479
480    fn make_request(id: &str, command: &str, params: serde_json::Value) -> RawRequest {
481        RawRequest {
482            id: id.to_string(),
483            command: command.to_string(),
484            params,
485            lsp_hints: None,
486            session_id: None,
487        }
488    }
489
490    // --- Param validation ---
491
492    #[test]
493    fn extract_function_missing_file() {
494        let req = make_request("1", "extract_function", serde_json::json!({}));
495        let ctx = crate::context::AppContext::new(
496            Box::new(crate::parser::TreeSitterProvider::new()),
497            crate::config::Config::default(),
498        );
499        let resp = handle_extract_function(&req, &ctx);
500        let json = serde_json::to_value(&resp).unwrap();
501        assert_eq!(json["success"], false);
502        assert_eq!(json["code"], "invalid_request");
503        let msg = json["message"].as_str().unwrap();
504        assert!(
505            msg.contains("file"),
506            "message should mention 'file': {}",
507            msg
508        );
509    }
510
511    #[test]
512    fn extract_function_missing_name() {
513        let req = make_request(
514            "2",
515            "extract_function",
516            serde_json::json!({"file": "/tmp/test.ts"}),
517        );
518        let ctx = crate::context::AppContext::new(
519            Box::new(crate::parser::TreeSitterProvider::new()),
520            crate::config::Config::default(),
521        );
522        let resp = handle_extract_function(&req, &ctx);
523        let json = serde_json::to_value(&resp).unwrap();
524        assert_eq!(json["success"], false);
525        assert_eq!(json["code"], "invalid_request");
526        let msg = json["message"].as_str().unwrap();
527        assert!(
528            msg.contains("name"),
529            "message should mention 'name': {}",
530            msg
531        );
532    }
533
534    #[test]
535    fn extract_function_missing_start_line() {
536        let req = make_request(
537            "3",
538            "extract_function",
539            serde_json::json!({"file": "/tmp/test.ts", "name": "foo"}),
540        );
541        let ctx = crate::context::AppContext::new(
542            Box::new(crate::parser::TreeSitterProvider::new()),
543            crate::config::Config::default(),
544        );
545        let resp = handle_extract_function(&req, &ctx);
546        let json = serde_json::to_value(&resp).unwrap();
547        assert_eq!(json["success"], false);
548        assert_eq!(json["code"], "invalid_request");
549    }
550
551    #[test]
552    fn extract_function_unsupported_language() {
553        // Create a temp .rs file (Rust is not supported for extract_function)
554        let dir = std::env::temp_dir().join("aft_test_extract");
555        std::fs::create_dir_all(&dir).ok();
556        let file = dir.join("test.rs");
557        std::fs::write(&file, "fn main() {}").unwrap();
558
559        let req = make_request(
560            "4",
561            "extract_function",
562            serde_json::json!({
563                "file": file.display().to_string(),
564                "name": "foo",
565                "start_line": 1,
566                "end_line": 2,
567            }),
568        );
569        let ctx = crate::context::AppContext::new(
570            Box::new(crate::parser::TreeSitterProvider::new()),
571            crate::config::Config::default(),
572        );
573        let resp = handle_extract_function(&req, &ctx);
574        let json = serde_json::to_value(&resp).unwrap();
575        assert_eq!(json["success"], false);
576        assert_eq!(json["code"], "unsupported_language");
577
578        std::fs::remove_dir_all(&dir).ok();
579    }
580
581    #[test]
582    fn extract_function_invalid_line_range() {
583        let dir = std::env::temp_dir().join("aft_test_extract_range");
584        std::fs::create_dir_all(&dir).ok();
585        let file = dir.join("test.ts");
586        std::fs::write(&file, "const x = 1;\n").unwrap();
587
588        let req = make_request(
589            "5",
590            "extract_function",
591            serde_json::json!({
592                "file": file.display().to_string(),
593                "name": "foo",
594                "start_line": 6,
595                "end_line": 4,
596            }),
597        );
598        let ctx = crate::context::AppContext::new(
599            Box::new(crate::parser::TreeSitterProvider::new()),
600            crate::config::Config::default(),
601        );
602        let resp = handle_extract_function(&req, &ctx);
603        let json = serde_json::to_value(&resp).unwrap();
604        assert_eq!(json["success"], false);
605        assert_eq!(json["code"], "invalid_request");
606
607        std::fs::remove_dir_all(&dir).ok();
608    }
609
610    #[test]
611    fn extract_function_this_reference_error() {
612        let fixture = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
613            .join("tests/fixtures/extract_function/sample_this.ts");
614
615        let req = make_request(
616            "6",
617            "extract_function",
618            serde_json::json!({
619                "file": fixture.display().to_string(),
620                "name": "extracted",
621                "start_line": 5,
622                "end_line": 8,
623            }),
624        );
625        let ctx = crate::context::AppContext::new(
626            Box::new(crate::parser::TreeSitterProvider::new()),
627            crate::config::Config::default(),
628        );
629        let resp = handle_extract_function(&req, &ctx);
630        let json = serde_json::to_value(&resp).unwrap();
631        assert_eq!(json["success"], false);
632        assert_eq!(json["code"], "this_reference_in_range");
633    }
634}