1use std::path::Path;
8
9use tree_sitter::Parser;
10
11use crate::context::AppContext;
12use crate::edit;
13use crate::extract::{
14 detect_free_variables, detect_return_value, generate_call_site, generate_extracted_function,
15 ReturnKind,
16};
17use crate::indent::detect_indent;
18use crate::parser::{detect_language, grammar_for, LangId};
19use crate::protocol::{RawRequest, Response};
20
21pub fn handle_extract_function(req: &RawRequest, ctx: &AppContext) -> Response {
37 let file = match req.params.get("file").and_then(|v| v.as_str()) {
39 Some(f) => f,
40 None => {
41 return Response::error(
42 &req.id,
43 "invalid_request",
44 "extract_function: missing required param 'file'",
45 );
46 }
47 };
48
49 let name = match req.params.get("name").and_then(|v| v.as_str()) {
50 Some(n) => n,
51 None => {
52 return Response::error(
53 &req.id,
54 "invalid_request",
55 "extract_function: missing required param 'name'",
56 );
57 }
58 };
59
60 let start_line_1based = match req.params.get("start_line").and_then(|v| v.as_u64()) {
61 Some(l) if l >= 1 => l as u32,
62 Some(_) => {
63 return Response::error(
64 &req.id,
65 "invalid_request",
66 "extract_function: 'start_line' must be >= 1 (1-based)",
67 );
68 }
69 None => {
70 return Response::error(
71 &req.id,
72 "invalid_request",
73 "extract_function: missing required param 'start_line'",
74 );
75 }
76 };
77 let start_line = start_line_1based - 1;
78
79 let end_line_1based = match req.params.get("end_line").and_then(|v| v.as_u64()) {
80 Some(l) if l >= 1 => l as u32,
81 Some(_) => {
82 return Response::error(
83 &req.id,
84 "invalid_request",
85 "extract_function: 'end_line' must be >= 1 (1-based)",
86 );
87 }
88 None => {
89 return Response::error(
90 &req.id,
91 "invalid_request",
92 "extract_function: missing required param 'end_line'",
93 );
94 }
95 };
96 let end_line = end_line_1based - 1;
97
98 if start_line >= end_line {
99 return Response::error(
100 &req.id,
101 "invalid_request",
102 format!(
103 "extract_function: start_line ({}) must be less than end_line ({})",
104 start_line, end_line
105 ),
106 );
107 }
108
109 let path = Path::new(file);
111 if !path.exists() {
112 return Response::error(
113 &req.id,
114 "file_not_found",
115 format!("extract_function: file not found: {}", file),
116 );
117 }
118
119 let lang = match detect_language(path) {
121 Some(l) => l,
122 None => {
123 return Response::error(
124 &req.id,
125 "unsupported_language",
126 "extract_function: unsupported file type",
127 );
128 }
129 };
130
131 if !matches!(
132 lang,
133 LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python
134 ) {
135 return Response::error(
136 &req.id,
137 "unsupported_language",
138 format!(
139 "extract_function: only TypeScript/JavaScript/Python files are supported, got {:?}",
140 lang
141 ),
142 );
143 }
144
145 let source = match std::fs::read_to_string(path) {
147 Ok(s) => s,
148 Err(e) => {
149 return Response::error(
150 &req.id,
151 "file_not_found",
152 format!("extract_function: {}: {}", file, e),
153 );
154 }
155 };
156
157 let grammar = grammar_for(lang);
158 let mut parser = Parser::new();
159 if parser.set_language(&grammar).is_err() {
160 return Response::error(
161 &req.id,
162 "parse_error",
163 "extract_function: failed to initialize parser",
164 );
165 }
166 let tree = match parser.parse(source.as_bytes(), None) {
167 Some(t) => t,
168 None => {
169 return Response::error(
170 &req.id,
171 "parse_error",
172 "extract_function: failed to parse file",
173 );
174 }
175 };
176
177 let start_byte = edit::line_col_to_byte(&source, start_line, 0);
179 let end_byte = edit::line_col_to_byte(&source, end_line, 0);
180
181 if start_byte >= source.len() {
182 return Response::error(
183 &req.id,
184 "invalid_request",
185 format!(
186 "extract_function: start_line {} is beyond end of file",
187 start_line
188 ),
189 );
190 }
191
192 let free_vars = detect_free_variables(&source, &tree, start_byte, end_byte, lang);
194
195 if free_vars.has_this_or_self {
197 let keyword = match lang {
198 LangId::Python => "self",
199 _ => "this",
200 };
201 return Response::error(
202 &req.id,
203 "this_reference_in_range",
204 format!(
205 "extract_function: selected range contains '{}' reference. Consider extracting as a method instead, or move the {} usage outside the extracted range.",
206 keyword, keyword
207 ),
208 );
209 }
210
211 let root = tree.root_node();
213 let enclosing_fn = find_enclosing_function_node(&root, start_byte, lang);
214 let enclosing_fn_end_byte = enclosing_fn.map(|n| n.end_byte());
215
216 let return_kind = detect_return_value(
218 &source,
219 &tree,
220 start_byte,
221 end_byte,
222 enclosing_fn_end_byte,
223 lang,
224 );
225
226 let indent_style = detect_indent(&source, lang);
228
229 let base_indent = if let Some(fn_node) = enclosing_fn {
232 let fn_start_line = fn_node.start_position().row;
233 get_line_indent(&source, fn_start_line as usize)
234 } else {
235 String::new()
236 };
237
238 let range_indent = get_line_indent(&source, start_line as usize);
240
241 let body_text = &source[start_byte..end_byte];
243 let body_text = body_text.trim_end_matches('\n');
244
245 let extracted_fn = generate_extracted_function(
247 name,
248 &free_vars.parameters,
249 &return_kind,
250 body_text,
251 &base_indent,
252 lang,
253 indent_style,
254 );
255
256 let call_site = generate_call_site(
257 name,
258 &free_vars.parameters,
259 &return_kind,
260 &range_indent,
261 lang,
262 );
263
264 let insert_pos = if let Some(fn_node) = enclosing_fn {
268 fn_node.start_byte()
269 } else {
270 start_byte
271 };
272
273 let new_source = build_new_source(
274 &source,
275 insert_pos,
276 start_byte,
277 end_byte,
278 &extracted_fn,
279 &call_site,
280 );
281
282 let return_type = match &return_kind {
284 ReturnKind::Expression(_) => "expression",
285 ReturnKind::Variable(_) => "variable",
286 ReturnKind::Void => "void",
287 };
288
289 if edit::is_dry_run(&req.params) {
291 let dr = edit::dry_run_diff(&source, &new_source, path);
292 return Response::success(
293 &req.id,
294 serde_json::json!({
295 "ok": true,
296 "dry_run": true,
297 "diff": dr.diff,
298 "syntax_valid": dr.syntax_valid,
299 "parameters": free_vars.parameters,
300 "return_type": return_type,
301 }),
302 );
303 }
304
305 let backup_id = match edit::auto_backup(ctx, path, &format!("extract_function: {}", name)) {
307 Ok(id) => id,
308 Err(e) => {
309 return Response::error(&req.id, e.code(), e.to_string());
310 }
311 };
312
313 let mut write_result =
315 match edit::write_format_validate(path, &new_source, &ctx.config(), &req.params) {
316 Ok(r) => r,
317 Err(e) => {
318 return Response::error(&req.id, e.code(), e.to_string());
319 }
320 };
321
322 if let Ok(final_content) = std::fs::read_to_string(path) {
323 write_result.lsp_diagnostics = ctx.lsp_post_write(path, &final_content, &req.params);
324 }
325
326 let param_count = free_vars.parameters.len();
327 eprintln!(
328 "[aft] extract_function: {} from {}:{}-{} ({} params)",
329 name, file, start_line, end_line, param_count
330 );
331
332 let syntax_valid = write_result.syntax_valid.unwrap_or(true);
334
335 let mut result = serde_json::json!({
336 "file": file,
337 "name": name,
338 "parameters": free_vars.parameters,
339 "return_type": return_type,
340 "syntax_valid": syntax_valid,
341 "formatted": write_result.formatted,
342 });
343
344 if let Some(ref reason) = write_result.format_skipped_reason {
345 result["format_skipped_reason"] = serde_json::json!(reason);
346 }
347
348 if write_result.validate_requested {
349 result["validation_errors"] = serde_json::json!(write_result.validation_errors);
350 }
351 if let Some(ref reason) = write_result.validate_skipped_reason {
352 result["validate_skipped_reason"] = serde_json::json!(reason);
353 }
354
355 if let Some(ref id) = backup_id {
356 result["backup_id"] = serde_json::json!(id);
357 }
358
359 write_result.append_lsp_diagnostics_to(&mut result);
360 Response::success(&req.id, result)
361}
362
363fn find_enclosing_function_node<'a>(
365 root: &'a tree_sitter::Node<'a>,
366 byte_pos: usize,
367 lang: LangId,
368) -> Option<tree_sitter::Node<'a>> {
369 let fn_kinds: &[&str] = match lang {
370 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => &[
371 "function_declaration",
372 "method_definition",
373 "arrow_function",
374 "lexical_declaration",
375 ],
376 LangId::Python => &["function_definition"],
377 _ => &[],
378 };
379
380 find_deepest_ancestor(root, byte_pos, fn_kinds)
381}
382
383fn find_deepest_ancestor<'a>(
385 node: &tree_sitter::Node<'a>,
386 byte_pos: usize,
387 kinds: &[&str],
388) -> Option<tree_sitter::Node<'a>> {
389 let mut result: Option<tree_sitter::Node<'a>> = None;
390 if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
391 result = Some(*node);
392 }
393
394 let child_count = node.child_count();
395 for i in 0..child_count {
396 if let Some(child) = node.child(i as u32) {
397 if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
398 if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
399 result = Some(deeper);
400 }
401 }
402 }
403 }
404
405 result
406}
407
/// Returns the leading whitespace of the 0-based `line` of `source` (as split
/// by `str::lines`), or an empty string when `line` is past the last line.
/// A line made only of whitespace is returned whole.
fn get_line_indent(source: &str, line: usize) -> String {
    match source.lines().nth(line) {
        Some(text) => {
            let indent_len = text
                .find(|c: char| !c.is_whitespace())
                .unwrap_or(text.len());
            text[..indent_len].to_owned()
        }
        None => String::new(),
    }
}
419
/// Assembles the rewritten file text.
///
/// In order, the output contains:
/// 1. everything before `insert_pos`;
/// 2. the extracted function followed by a blank line;
/// 3. the untouched text between `insert_pos` and `range_start`;
/// 4. the call site plus a newline, in place of `range_start..range_end`;
/// 5. the rest of the file.
///
/// Callers must pass `insert_pos <= range_start <= range_end`, all on char
/// boundaries, otherwise slicing panics.
fn build_new_source(
    source: &str,
    insert_pos: usize,
    range_start: usize,
    range_end: usize,
    extracted_fn: &str,
    call_site: &str,
) -> String {
    let pieces: [&str; 7] = [
        &source[..insert_pos],
        extracted_fn,
        "\n\n",
        &source[insert_pos..range_start],
        call_site,
        "\n",
        &source[range_end..],
    ];
    pieces.concat()
}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RawRequest;

    /// Builds a raw protocol request with no LSP hints.
    fn make_request(id: &str, command: &str, params: serde_json::Value) -> RawRequest {
        RawRequest {
            id: id.into(),
            command: command.into(),
            params,
            lsp_hints: None,
        }
    }

    /// Handler context shared by every test: tree-sitter provider plus the
    /// default config.
    fn test_context() -> crate::context::AppContext {
        crate::context::AppContext::new(
            Box::new(crate::parser::TreeSitterProvider::new()),
            crate::config::Config::default(),
        )
    }

    /// Runs `extract_function` with `params` and returns the serialized response.
    fn run_extract(id: &str, params: serde_json::Value) -> serde_json::Value {
        let req = make_request(id, "extract_function", params);
        let ctx = test_context();
        serde_json::to_value(&handle_extract_function(&req, &ctx)).unwrap()
    }

    /// Absolute path of a fixture under `tests/fixtures/extract_function`.
    fn fixture(file_name: &str) -> std::path::PathBuf {
        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/extract_function")
            .join(file_name)
    }

    #[test]
    fn extract_function_missing_file() {
        let json = run_extract("1", serde_json::json!({}));
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(msg.contains("file"), "message should mention 'file': {}", msg);
    }

    #[test]
    fn extract_function_missing_name() {
        let json = run_extract("2", serde_json::json!({"file": "/tmp/test.ts"}));
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(msg.contains("name"), "message should mention 'name': {}", msg);
    }

    #[test]
    fn extract_function_missing_start_line() {
        let json = run_extract(
            "3",
            serde_json::json!({"file": "/tmp/test.ts", "name": "foo"}),
        );
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "invalid_request");
    }

    #[test]
    fn extract_function_unsupported_language() {
        let dir = std::env::temp_dir().join("aft_test_extract");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let json = run_extract(
            "4",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 1,
                "end_line": 2,
            }),
        );
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "unsupported_language");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_invalid_line_range() {
        let dir = std::env::temp_dir().join("aft_test_extract_range");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.ts");
        std::fs::write(&file, "const x = 1;\n").unwrap();

        let json = run_extract(
            "5",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 6,
                "end_line": 4,
            }),
        );
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "invalid_request");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_this_reference_error() {
        let json = run_extract(
            "6",
            serde_json::json!({
                "file": fixture("sample_this.ts").display().to_string(),
                "name": "extracted",
                "start_line": 5,
                "end_line": 8,
            }),
        );
        assert_eq!(json["ok"], false);
        assert_eq!(json["code"], "this_reference_in_range");
    }

    #[test]
    fn extract_function_dry_run_returns_diff() {
        let json = run_extract(
            "7",
            serde_json::json!({
                "file": fixture("sample.ts").display().to_string(),
                "name": "computeResult",
                "start_line": 15,
                "end_line": 17,
                "dry_run": true,
            }),
        );
        assert_eq!(json["ok"], true);
        assert_eq!(json["dry_run"], true);
        assert!(json["diff"].as_str().is_some(), "should have diff");
        assert!(json["parameters"].is_array(), "should have parameters");
    }
}