1use std::path::Path;
8
9use tree_sitter::Parser;
10
11use crate::context::AppContext;
12use crate::edit;
13use crate::extract::{
14 detect_free_variables, detect_return_value, generate_call_site, generate_extracted_function,
15 ReturnKind,
16};
17use crate::indent::detect_indent;
18use crate::parser::{detect_language, grammar_for, LangId};
19use crate::protocol::{RawRequest, Response};
20
21pub fn handle_extract_function(req: &RawRequest, ctx: &AppContext) -> Response {
38 let op_id = crate::backup::new_op_id();
39 let file = match req.params.get("file").and_then(|v| v.as_str()) {
41 Some(f) => f,
42 None => {
43 return Response::error(
44 &req.id,
45 "invalid_request",
46 "extract_function: missing required param 'file'",
47 );
48 }
49 };
50
51 let name = match req.params.get("name").and_then(|v| v.as_str()) {
52 Some(n) => n,
53 None => {
54 return Response::error(
55 &req.id,
56 "invalid_request",
57 "extract_function: missing required param 'name'",
58 );
59 }
60 };
61
62 let start_line_1based = match req.params.get("start_line").and_then(|v| v.as_u64()) {
63 Some(l) if l >= 1 => l as u32,
64 Some(_) => {
65 return Response::error(
66 &req.id,
67 "invalid_request",
68 "extract_function: 'start_line' must be >= 1 (1-based)",
69 );
70 }
71 None => {
72 return Response::error(
73 &req.id,
74 "invalid_request",
75 "extract_function: missing required param 'start_line'",
76 );
77 }
78 };
79 let start_line = start_line_1based - 1;
80
81 let end_line_1based = match req.params.get("end_line").and_then(|v| v.as_u64()) {
82 Some(l) if l >= 1 => l as u32,
83 Some(_) => {
84 return Response::error(
85 &req.id,
86 "invalid_request",
87 "extract_function: 'end_line' must be >= 1 (1-based)",
88 );
89 }
90 None => {
91 return Response::error(
92 &req.id,
93 "invalid_request",
94 "extract_function: missing required param 'end_line'",
95 );
96 }
97 };
98 let end_line = end_line_1based - 1;
99
100 if start_line >= end_line {
101 return Response::error(
102 &req.id,
103 "invalid_request",
104 format!(
105 "extract_function: start_line ({}) must be less than end_line ({})",
106 start_line, end_line
107 ),
108 );
109 }
110
111 let path = match ctx.validate_path(&req.id, Path::new(file)) {
113 Ok(path) => path,
114 Err(resp) => return resp,
115 };
116 if !path.exists() {
117 return Response::error(
118 &req.id,
119 "file_not_found",
120 format!("extract_function: file not found: {}", file),
121 );
122 }
123
124 let lang = match detect_language(&path) {
126 Some(l) => l,
127 None => {
128 return Response::error(
129 &req.id,
130 "unsupported_language",
131 "extract_function: unsupported file type",
132 );
133 }
134 };
135
136 if !matches!(
137 lang,
138 LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python
139 ) {
140 return Response::error(
141 &req.id,
142 "unsupported_language",
143 format!(
144 "extract_function: only TypeScript/JavaScript/Python files are supported, got {:?}",
145 lang
146 ),
147 );
148 }
149
150 let source = match std::fs::read_to_string(&path) {
152 Ok(s) => s,
153 Err(e) => {
154 return Response::error(
155 &req.id,
156 "file_not_found",
157 format!("extract_function: {}: {}", file, e),
158 );
159 }
160 };
161
162 let grammar = grammar_for(lang);
163 let mut parser = Parser::new();
164 if parser.set_language(&grammar).is_err() {
165 return Response::error(
166 &req.id,
167 "parse_error",
168 "extract_function: failed to initialize parser",
169 );
170 }
171 let tree = match parser.parse(source.as_bytes(), None) {
172 Some(t) => t,
173 None => {
174 return Response::error(
175 &req.id,
176 "parse_error",
177 "extract_function: failed to parse file",
178 );
179 }
180 };
181
182 let start_byte = edit::line_col_to_byte(&source, start_line, 0);
184 let end_byte = edit::line_col_to_byte(&source, end_line, 0);
185
186 if start_byte >= source.len() {
187 return Response::error(
188 &req.id,
189 "invalid_request",
190 format!(
191 "extract_function: start_line {} is beyond end of file",
192 start_line
193 ),
194 );
195 }
196
197 let free_vars = detect_free_variables(&source, &tree, start_byte, end_byte, lang);
199
200 if free_vars.has_this_or_self {
202 let keyword = match lang {
203 LangId::Python => "self",
204 _ => "this",
205 };
206 return Response::error(
207 &req.id,
208 "this_reference_in_range",
209 format!(
210 "extract_function: selected range contains '{}' reference. Consider extracting as a method instead, or move the {} usage outside the extracted range.",
211 keyword, keyword
212 ),
213 );
214 }
215
216 let root = tree.root_node();
218 let enclosing_fn = find_enclosing_function_node(&root, start_byte, lang);
219 let enclosing_fn_end_byte = enclosing_fn.map(|n| n.end_byte());
220
221 let return_kind = detect_return_value(
223 &source,
224 &tree,
225 start_byte,
226 end_byte,
227 enclosing_fn_end_byte,
228 lang,
229 );
230
231 let indent_style = detect_indent(&source, lang);
233
234 let base_indent = if let Some(fn_node) = enclosing_fn {
237 let fn_start_line = fn_node.start_position().row;
238 get_line_indent(&source, fn_start_line as usize)
239 } else {
240 String::new()
241 };
242
243 let range_indent = get_line_indent(&source, start_line as usize);
245
246 let body_text = &source[start_byte..end_byte];
248 let body_text = body_text.trim_end_matches('\n');
249
250 let extracted_fn = generate_extracted_function(
252 name,
253 &free_vars.parameters,
254 &return_kind,
255 body_text,
256 &base_indent,
257 lang,
258 indent_style,
259 );
260
261 let call_site = generate_call_site(
262 name,
263 &free_vars.parameters,
264 &return_kind,
265 &range_indent,
266 lang,
267 );
268
269 let insert_pos = if let Some(fn_node) = enclosing_fn {
284 let mut anchor = fn_node;
285 if matches!(lang, LangId::TypeScript | LangId::Tsx | LangId::JavaScript) {
286 if let Some(parent) = fn_node.parent() {
287 if parent.kind() == "export_statement" {
288 anchor = parent;
289 }
290 }
291 }
292 anchor.start_byte()
293 } else {
294 start_byte
295 };
296
297 let new_source = build_new_source(
298 &source,
299 insert_pos,
300 start_byte,
301 end_byte,
302 &extracted_fn,
303 &call_site,
304 );
305
306 let return_type = match &return_kind {
308 ReturnKind::Expression(_) => "expression",
309 ReturnKind::Variable(_) => "variable",
310 ReturnKind::Void => "void",
311 };
312
313 let backup_id = match edit::auto_backup(
315 ctx,
316 req.session(),
317 &path,
318 &format!("extract_function: {}", name),
319 Some(&op_id),
320 ) {
321 Ok(id) => id,
322 Err(e) => {
323 return Response::error(&req.id, e.code(), e.to_string());
324 }
325 };
326
327 let mut write_result =
329 match edit::write_format_validate(&path, &new_source, &ctx.config(), &req.params) {
330 Ok(r) => r,
331 Err(e) => {
332 return Response::error(&req.id, e.code(), e.to_string());
333 }
334 };
335
336 if write_result.rolled_back {
343 return Response::error(
344 &req.id,
345 "generated_invalid_syntax",
346 format!(
347 "extract_function produced invalid syntax; the file was left unchanged. {}",
348 edit::format_validation_errors(&write_result.validation_errors)
349 ),
350 );
351 }
352
353 if let Ok(final_content) = std::fs::read_to_string(&path) {
354 write_result.lsp_outcome = ctx.lsp_post_write(&path, &final_content, &req.params);
355 }
356
357 let param_count = free_vars.parameters.len();
358 log::debug!(
359 "extract_function: {} from {}:{}-{} ({} params)",
360 name,
361 file,
362 start_line,
363 end_line,
364 param_count
365 );
366
367 let mut result = serde_json::json!({
369 "file": file,
370 "name": name,
371 "parameters": free_vars.parameters,
372 "return_type": return_type,
373 "formatted": write_result.formatted,
374 });
375
376 if let Some(valid) = write_result.syntax_valid {
377 result["syntax_valid"] = serde_json::json!(valid);
378 }
379
380 if let Some(ref reason) = write_result.format_skipped_reason {
381 result["format_skipped_reason"] = serde_json::json!(reason);
382 }
383
384 if write_result.validate_requested {
385 result["validation_errors"] = serde_json::json!(write_result.validation_errors);
386 }
387 if let Some(ref reason) = write_result.validate_skipped_reason {
388 result["validate_skipped_reason"] = serde_json::json!(reason);
389 }
390
391 if let Some(ref id) = backup_id {
392 result["backup_id"] = serde_json::json!(id);
393 }
394
395 write_result.append_lsp_diagnostics_to(&mut result);
396 write_result.append_reformatted_excerpt_to(&mut result);
397 Response::success(&req.id, result)
398}
399
400fn find_enclosing_function_node<'a>(
402 root: &'a tree_sitter::Node<'a>,
403 byte_pos: usize,
404 lang: LangId,
405) -> Option<tree_sitter::Node<'a>> {
406 let fn_kinds: &[&str] = match lang {
407 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => &[
408 "function_declaration",
409 "method_definition",
410 "arrow_function",
411 "lexical_declaration",
412 ],
413 LangId::Python => &["function_definition"],
414 _ => &[],
415 };
416
417 find_deepest_ancestor(root, byte_pos, fn_kinds)
418}
419
420fn find_deepest_ancestor<'a>(
422 node: &tree_sitter::Node<'a>,
423 byte_pos: usize,
424 kinds: &[&str],
425) -> Option<tree_sitter::Node<'a>> {
426 let mut result: Option<tree_sitter::Node<'a>> = None;
427 if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
428 result = Some(*node);
429 }
430
431 let child_count = node.child_count();
432 for i in 0..child_count {
433 if let Some(child) = node.child(i as u32) {
434 if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
435 if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
436 result = Some(deeper);
437 }
438 }
439 }
440 }
441
442 result
443}
444
/// Return the leading whitespace of the 0-based line `line` of `source`.
///
/// Lines are split with `str::lines`, so a trailing `\r` of a CRLF ending is not part of
/// the line. An out-of-range `line` yields an empty string.
fn get_line_indent(source: &str, line: usize) -> String {
    match source.lines().nth(line) {
        Some(text) => text.chars().take_while(|ch| ch.is_whitespace()).collect(),
        None => String::new(),
    }
}
456
/// Assemble the edited file text.
///
/// The extracted function (followed by a blank line) is inserted at `insert_pos`. The
/// bytes `range_start..range_end` are replaced by `call_site` plus a newline, and all
/// other text is kept verbatim.
///
/// Requires `insert_pos <= range_start <= range_end <= source.len()`, all on char
/// boundaries; otherwise slicing panics.
fn build_new_source(
    source: &str,
    insert_pos: usize,
    range_start: usize,
    range_end: usize,
    extracted_fn: &str,
    call_site: &str,
) -> String {
    let before_insert = &source[..insert_pos];
    let between = &source[insert_pos..range_start];
    let after_range = &source[range_end..];

    [
        before_insert,
        extracted_fn,
        "\n\n",
        between,
        call_site,
        "\n",
        after_range,
    ]
    .concat()
}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RawRequest;

    /// Build a request with no LSP hints and no session attached.
    fn make_request(id: &str, command: &str, params: serde_json::Value) -> RawRequest {
        RawRequest {
            id: id.to_string(),
            command: command.to_string(),
            params,
            lsp_hints: None,
            session_id: None,
        }
    }

    /// Run `handle_extract_function` against a fresh default context and return the
    /// serialized response.
    fn run_extract(id: &str, params: serde_json::Value) -> serde_json::Value {
        let req = make_request(id, "extract_function", params);
        let ctx = crate::context::AppContext::new(
            Box::new(crate::parser::TreeSitterProvider::new()),
            crate::config::Config::default(),
        );
        let resp = handle_extract_function(&req, &ctx);
        serde_json::to_value(&resp).unwrap()
    }

    /// Assert that `json` is a failed response carrying the error `code`.
    #[track_caller]
    fn assert_error_code(json: &serde_json::Value, code: &str) {
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], code);
    }

    #[test]
    fn extract_function_missing_file() {
        let json = run_extract("1", serde_json::json!({}));
        assert_error_code(&json, "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(
            msg.contains("file"),
            "message should mention 'file': {}",
            msg
        );
    }

    #[test]
    fn extract_function_missing_name() {
        let json = run_extract("2", serde_json::json!({"file": "/tmp/test.ts"}));
        assert_error_code(&json, "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(
            msg.contains("name"),
            "message should mention 'name': {}",
            msg
        );
    }

    #[test]
    fn extract_function_missing_start_line() {
        let json = run_extract(
            "3",
            serde_json::json!({"file": "/tmp/test.ts", "name": "foo"}),
        );
        assert_error_code(&json, "invalid_request");
    }

    #[test]
    fn extract_function_unsupported_language() {
        let dir = std::env::temp_dir().join("aft_test_extract");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let json = run_extract(
            "4",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 1,
                "end_line": 2,
            }),
        );
        assert_error_code(&json, "unsupported_language");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_invalid_line_range() {
        let dir = std::env::temp_dir().join("aft_test_extract_range");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.ts");
        std::fs::write(&file, "const x = 1;\n").unwrap();

        // start_line after end_line must be rejected before the file is touched.
        let json = run_extract(
            "5",
            serde_json::json!({
                "file": file.display().to_string(),
                "name": "foo",
                "start_line": 6,
                "end_line": 4,
            }),
        );
        assert_error_code(&json, "invalid_request");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_this_reference_error() {
        let fixture = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/extract_function/sample_this.ts");

        let json = run_extract(
            "6",
            serde_json::json!({
                "file": fixture.display().to_string(),
                "name": "extracted",
                "start_line": 5,
                "end_line": 8,
            }),
        );
        assert_error_code(&json, "this_reference_in_range");
    }
}