1use std::path::Path;
8
9use tree_sitter::Parser;
10
11use crate::context::AppContext;
12use crate::edit;
13use crate::extract::{
14 detect_free_variables, detect_return_value, generate_call_site, generate_extracted_function,
15 ReturnKind,
16};
17use crate::indent::detect_indent;
18use crate::parser::{detect_language, grammar_for, LangId};
19use crate::protocol::{RawRequest, Response};
20
21pub fn handle_extract_function(req: &RawRequest, ctx: &AppContext) -> Response {
37 let file = match req.params.get("file").and_then(|v| v.as_str()) {
39 Some(f) => f,
40 None => {
41 return Response::error(
42 &req.id,
43 "invalid_request",
44 "extract_function: missing required param 'file'",
45 );
46 }
47 };
48
49 let name = match req.params.get("name").and_then(|v| v.as_str()) {
50 Some(n) => n,
51 None => {
52 return Response::error(
53 &req.id,
54 "invalid_request",
55 "extract_function: missing required param 'name'",
56 );
57 }
58 };
59
60 let start_line_1based = match req.params.get("start_line").and_then(|v| v.as_u64()) {
61 Some(l) if l >= 1 => l as u32,
62 Some(_) => {
63 return Response::error(
64 &req.id,
65 "invalid_request",
66 "extract_function: 'start_line' must be >= 1 (1-based)",
67 );
68 }
69 None => {
70 return Response::error(
71 &req.id,
72 "invalid_request",
73 "extract_function: missing required param 'start_line'",
74 );
75 }
76 };
77 let start_line = start_line_1based - 1;
78
79 let end_line_1based = match req.params.get("end_line").and_then(|v| v.as_u64()) {
80 Some(l) if l >= 1 => l as u32,
81 Some(_) => {
82 return Response::error(
83 &req.id,
84 "invalid_request",
85 "extract_function: 'end_line' must be >= 1 (1-based)",
86 );
87 }
88 None => {
89 return Response::error(
90 &req.id,
91 "invalid_request",
92 "extract_function: missing required param 'end_line'",
93 );
94 }
95 };
96 let end_line = end_line_1based - 1;
97
98 if start_line >= end_line {
99 return Response::error(
100 &req.id,
101 "invalid_request",
102 format!(
103 "extract_function: start_line ({}) must be less than end_line ({})",
104 start_line, end_line
105 ),
106 );
107 }
108
109 let path = match ctx.validate_path(&req.id, Path::new(file)) {
111 Ok(path) => path,
112 Err(resp) => return resp,
113 };
114 if !path.exists() {
115 return Response::error(
116 &req.id,
117 "file_not_found",
118 format!("extract_function: file not found: {}", file),
119 );
120 }
121
122 let lang = match detect_language(&path) {
124 Some(l) => l,
125 None => {
126 return Response::error(
127 &req.id,
128 "unsupported_language",
129 "extract_function: unsupported file type",
130 );
131 }
132 };
133
134 if !matches!(
135 lang,
136 LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python
137 ) {
138 return Response::error(
139 &req.id,
140 "unsupported_language",
141 format!(
142 "extract_function: only TypeScript/JavaScript/Python files are supported, got {:?}",
143 lang
144 ),
145 );
146 }
147
148 let source = match std::fs::read_to_string(&path) {
150 Ok(s) => s,
151 Err(e) => {
152 return Response::error(
153 &req.id,
154 "file_not_found",
155 format!("extract_function: {}: {}", file, e),
156 );
157 }
158 };
159
160 let grammar = grammar_for(lang);
161 let mut parser = Parser::new();
162 if parser.set_language(&grammar).is_err() {
163 return Response::error(
164 &req.id,
165 "parse_error",
166 "extract_function: failed to initialize parser",
167 );
168 }
169 let tree = match parser.parse(source.as_bytes(), None) {
170 Some(t) => t,
171 None => {
172 return Response::error(
173 &req.id,
174 "parse_error",
175 "extract_function: failed to parse file",
176 );
177 }
178 };
179
180 let start_byte = edit::line_col_to_byte(&source, start_line, 0);
182 let end_byte = edit::line_col_to_byte(&source, end_line, 0);
183
184 if start_byte >= source.len() {
185 return Response::error(
186 &req.id,
187 "invalid_request",
188 format!(
189 "extract_function: start_line {} is beyond end of file",
190 start_line
191 ),
192 );
193 }
194
195 let free_vars = detect_free_variables(&source, &tree, start_byte, end_byte, lang);
197
198 if free_vars.has_this_or_self {
200 let keyword = match lang {
201 LangId::Python => "self",
202 _ => "this",
203 };
204 return Response::error(
205 &req.id,
206 "this_reference_in_range",
207 format!(
208 "extract_function: selected range contains '{}' reference. Consider extracting as a method instead, or move the {} usage outside the extracted range.",
209 keyword, keyword
210 ),
211 );
212 }
213
214 let root = tree.root_node();
216 let enclosing_fn = find_enclosing_function_node(&root, start_byte, lang);
217 let enclosing_fn_end_byte = enclosing_fn.map(|n| n.end_byte());
218
219 let return_kind = detect_return_value(
221 &source,
222 &tree,
223 start_byte,
224 end_byte,
225 enclosing_fn_end_byte,
226 lang,
227 );
228
229 let indent_style = detect_indent(&source, lang);
231
232 let base_indent = if let Some(fn_node) = enclosing_fn {
235 let fn_start_line = fn_node.start_position().row;
236 get_line_indent(&source, fn_start_line as usize)
237 } else {
238 String::new()
239 };
240
241 let range_indent = get_line_indent(&source, start_line as usize);
243
244 let body_text = &source[start_byte..end_byte];
246 let body_text = body_text.trim_end_matches('\n');
247
248 let extracted_fn = generate_extracted_function(
250 name,
251 &free_vars.parameters,
252 &return_kind,
253 body_text,
254 &base_indent,
255 lang,
256 indent_style,
257 );
258
259 let call_site = generate_call_site(
260 name,
261 &free_vars.parameters,
262 &return_kind,
263 &range_indent,
264 lang,
265 );
266
267 let insert_pos = if let Some(fn_node) = enclosing_fn {
282 let mut anchor = fn_node;
283 if matches!(lang, LangId::TypeScript | LangId::Tsx | LangId::JavaScript) {
284 if let Some(parent) = fn_node.parent() {
285 if parent.kind() == "export_statement" {
286 anchor = parent;
287 }
288 }
289 }
290 anchor.start_byte()
291 } else {
292 start_byte
293 };
294
295 let new_source = build_new_source(
296 &source,
297 insert_pos,
298 start_byte,
299 end_byte,
300 &extracted_fn,
301 &call_site,
302 );
303
304 let return_type = match &return_kind {
306 ReturnKind::Expression(_) => "expression",
307 ReturnKind::Variable(_) => "variable",
308 ReturnKind::Void => "void",
309 };
310
311 if edit::is_dry_run(&req.params) {
313 let dr = edit::dry_run_diff(&source, &new_source, &path);
314 return Response::success(
315 &req.id,
316 serde_json::json!({
317 "ok": true,
318 "dry_run": true,
319 "diff": dr.diff,
320 "syntax_valid": dr.syntax_valid,
321 "parameters": free_vars.parameters,
322 "return_type": return_type,
323 }),
324 );
325 }
326
327 let backup_id = match edit::auto_backup(
329 ctx,
330 req.session(),
331 &path,
332 &format!("extract_function: {}", name),
333 ) {
334 Ok(id) => id,
335 Err(e) => {
336 return Response::error(&req.id, e.code(), e.to_string());
337 }
338 };
339
340 let mut write_result =
342 match edit::write_format_validate(&path, &new_source, &ctx.config(), &req.params) {
343 Ok(r) => r,
344 Err(e) => {
345 return Response::error(&req.id, e.code(), e.to_string());
346 }
347 };
348
349 if let Ok(final_content) = std::fs::read_to_string(&path) {
350 write_result.lsp_outcome = ctx.lsp_post_write(&path, &final_content, &req.params);
351 }
352
353 let param_count = free_vars.parameters.len();
354 log::debug!(
355 "extract_function: {} from {}:{}-{} ({} params)",
356 name,
357 file,
358 start_line,
359 end_line,
360 param_count
361 );
362
363 let syntax_valid = write_result.syntax_valid.unwrap_or(true);
365
366 let mut result = serde_json::json!({
367 "file": file,
368 "name": name,
369 "parameters": free_vars.parameters,
370 "return_type": return_type,
371 "syntax_valid": syntax_valid,
372 "formatted": write_result.formatted,
373 });
374
375 if let Some(ref reason) = write_result.format_skipped_reason {
376 result["format_skipped_reason"] = serde_json::json!(reason);
377 }
378
379 if write_result.validate_requested {
380 result["validation_errors"] = serde_json::json!(write_result.validation_errors);
381 }
382 if let Some(ref reason) = write_result.validate_skipped_reason {
383 result["validate_skipped_reason"] = serde_json::json!(reason);
384 }
385
386 if let Some(ref id) = backup_id {
387 result["backup_id"] = serde_json::json!(id);
388 }
389
390 write_result.append_lsp_diagnostics_to(&mut result);
391 Response::success(&req.id, result)
392}
393
394fn find_enclosing_function_node<'a>(
396 root: &'a tree_sitter::Node<'a>,
397 byte_pos: usize,
398 lang: LangId,
399) -> Option<tree_sitter::Node<'a>> {
400 let fn_kinds: &[&str] = match lang {
401 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => &[
402 "function_declaration",
403 "method_definition",
404 "arrow_function",
405 "lexical_declaration",
406 ],
407 LangId::Python => &["function_definition"],
408 _ => &[],
409 };
410
411 find_deepest_ancestor(root, byte_pos, fn_kinds)
412}
413
414fn find_deepest_ancestor<'a>(
416 node: &tree_sitter::Node<'a>,
417 byte_pos: usize,
418 kinds: &[&str],
419) -> Option<tree_sitter::Node<'a>> {
420 let mut result: Option<tree_sitter::Node<'a>> = None;
421 if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
422 result = Some(*node);
423 }
424
425 let child_count = node.child_count();
426 for i in 0..child_count {
427 if let Some(child) = node.child(i as u32) {
428 if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
429 if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
430 result = Some(deeper);
431 }
432 }
433 }
434 }
435
436 result
437}
438
/// Returns the leading whitespace of the 0-based `line` of `source`, using the
/// same line splitting as `str::lines`.
///
/// A line consisting only of whitespace is returned whole. A line index past the
/// end yields an empty string.
fn get_line_indent(source: &str, line: usize) -> String {
    match source.lines().nth(line) {
        Some(text) => {
            let indent_len = text
                .char_indices()
                .find(|&(_, ch)| !ch.is_whitespace())
                .map_or(text.len(), |(idx, _)| idx);
            text[..indent_len].to_owned()
        }
        None => String::new(),
    }
}
450
/// Produces the rewritten file text.
///
/// `extracted_fn` followed by a blank line is placed at `insert_pos`. The bytes
/// `range_start..range_end` are replaced by `call_site` plus a newline. Everything
/// else is copied unchanged.
///
/// Expects `insert_pos <= range_start <= range_end <= source.len()`, all on char
/// boundaries; otherwise slicing panics.
fn build_new_source(
    source: &str,
    insert_pos: usize,
    range_start: usize,
    range_end: usize,
    extracted_fn: &str,
    call_site: &str,
) -> String {
    let before_insert = &source[..insert_pos];
    let before_range = &source[insert_pos..range_start];
    let after_range = &source[range_end..];

    [
        before_insert,
        extracted_fn,
        "\n\n",
        before_range,
        call_site,
        "\n",
        after_range,
    ]
    .concat()
}
482
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RawRequest;

    /// Builds a bare request (no LSP hints, no session) for `command`.
    fn make_request(id: &str, command: &str, params: serde_json::Value) -> RawRequest {
        RawRequest {
            id: id.to_string(),
            command: command.to_string(),
            params,
            lsp_hints: None,
            session_id: None,
        }
    }

    /// Runs the handler against a fresh tree-sitter-backed context with the
    /// default config and returns the serialized response.
    fn run(req: &RawRequest) -> serde_json::Value {
        let ctx = crate::context::AppContext::new(
            Box::new(crate::parser::TreeSitterProvider::new()),
            crate::config::Config::default(),
        );
        let resp = handle_extract_function(req, &ctx);
        serde_json::to_value(&resp).unwrap()
    }

    /// Asserts that `json` is a failed response with error code `code`.
    fn assert_failure(json: &serde_json::Value, code: &str) {
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], code);
    }

    /// Path of a checked-in extract_function fixture.
    fn fixture(file_name: &str) -> std::path::PathBuf {
        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/extract_function")
            .join(file_name)
    }

    #[test]
    fn extract_function_missing_file() {
        let json = run(&make_request("1", "extract_function", serde_json::json!({})));
        assert_failure(&json, "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(
            msg.contains("file"),
            "message should mention 'file': {}",
            msg
        );
    }

    #[test]
    fn extract_function_missing_name() {
        let json = run(&make_request(
            "2",
            "extract_function",
            serde_json::json!({"file": "/tmp/test.ts"}),
        ));
        assert_failure(&json, "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(
            msg.contains("name"),
            "message should mention 'name': {}",
            msg
        );
    }

    #[test]
    fn extract_function_missing_start_line() {
        let json = run(&make_request(
            "3",
            "extract_function",
            serde_json::json!({"file": "/tmp/test.ts", "name": "foo"}),
        ));
        assert_failure(&json, "invalid_request");
    }

    #[test]
    fn extract_function_unsupported_language() {
        let dir = std::env::temp_dir().join("aft_test_extract");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let json = run(&make_request(
            "4",
            "extract_function",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 1,
                "end_line": 2,
            }),
        ));
        assert_failure(&json, "unsupported_language");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_invalid_line_range() {
        let dir = std::env::temp_dir().join("aft_test_extract_range");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.ts");
        std::fs::write(&file, "const x = 1;\n").unwrap();

        let json = run(&make_request(
            "5",
            "extract_function",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 6,
                "end_line": 4,
            }),
        ));
        assert_failure(&json, "invalid_request");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_this_reference_error() {
        let json = run(&make_request(
            "6",
            "extract_function",
            serde_json::json!({
                "file": fixture("sample_this.ts").display().to_string(),
                "name": "extracted",
                "start_line": 5,
                "end_line": 8,
            }),
        ));
        assert_failure(&json, "this_reference_in_range");
    }

    #[test]
    fn extract_function_dry_run_returns_diff() {
        let json = run(&make_request(
            "7",
            "extract_function",
            serde_json::json!({
                "file": fixture("sample.ts").display().to_string(),
                "name": "computeResult",
                "start_line": 15,
                "end_line": 17,
                "dry_run": true,
            }),
        ));
        assert_eq!(json["success"], true);
        assert_eq!(json["dry_run"], true);
        assert!(json["diff"].as_str().is_some(), "should have diff");
        assert!(json["parameters"].is_array(), "should have parameters");
    }
}