1use std::path::Path;
8
9use tree_sitter::Parser;
10
11use crate::context::AppContext;
12use crate::edit;
13use crate::extract::{
14 detect_free_variables, detect_return_value, generate_call_site, generate_extracted_function,
15 ReturnKind,
16};
17use crate::indent::detect_indent;
18use crate::parser::{detect_language, grammar_for, LangId};
19use crate::protocol::{RawRequest, Response};
20
21pub fn handle_extract_function(req: &RawRequest, ctx: &AppContext) -> Response {
36 let file = match req.params.get("file").and_then(|v| v.as_str()) {
38 Some(f) => f,
39 None => {
40 return Response::error(
41 &req.id,
42 "invalid_request",
43 "extract_function: missing required param 'file'",
44 );
45 }
46 };
47
48 let name = match req.params.get("name").and_then(|v| v.as_str()) {
49 Some(n) => n,
50 None => {
51 return Response::error(
52 &req.id,
53 "invalid_request",
54 "extract_function: missing required param 'name'",
55 );
56 }
57 };
58
59 let start_line_1based = match req.params.get("start_line").and_then(|v| v.as_u64()) {
60 Some(l) if l >= 1 => l as u32,
61 Some(_) => {
62 return Response::error(
63 &req.id,
64 "invalid_request",
65 "extract_function: 'start_line' must be >= 1 (1-based)",
66 );
67 }
68 None => {
69 return Response::error(
70 &req.id,
71 "invalid_request",
72 "extract_function: missing required param 'start_line'",
73 );
74 }
75 };
76 let start_line = start_line_1based - 1;
77
78 let end_line_1based = match req.params.get("end_line").and_then(|v| v.as_u64()) {
79 Some(l) if l >= 1 => l as u32,
80 Some(_) => {
81 return Response::error(
82 &req.id,
83 "invalid_request",
84 "extract_function: 'end_line' must be >= 1 (1-based)",
85 );
86 }
87 None => {
88 return Response::error(
89 &req.id,
90 "invalid_request",
91 "extract_function: missing required param 'end_line'",
92 );
93 }
94 };
95 let end_line = end_line_1based - 1;
96
97 if start_line >= end_line {
98 return Response::error(
99 &req.id,
100 "invalid_request",
101 format!(
102 "extract_function: start_line ({}) must be less than end_line ({})",
103 start_line, end_line
104 ),
105 );
106 }
107
108 let path = match ctx.validate_path(&req.id, Path::new(file)) {
110 Ok(path) => path,
111 Err(resp) => return resp,
112 };
113 if !path.exists() {
114 return Response::error(
115 &req.id,
116 "file_not_found",
117 format!("extract_function: file not found: {}", file),
118 );
119 }
120
121 let lang = match detect_language(&path) {
123 Some(l) => l,
124 None => {
125 return Response::error(
126 &req.id,
127 "unsupported_language",
128 "extract_function: unsupported file type",
129 );
130 }
131 };
132
133 if !matches!(
134 lang,
135 LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python
136 ) {
137 return Response::error(
138 &req.id,
139 "unsupported_language",
140 format!(
141 "extract_function: only TypeScript/JavaScript/Python files are supported, got {:?}",
142 lang
143 ),
144 );
145 }
146
147 let source = match std::fs::read_to_string(&path) {
149 Ok(s) => s,
150 Err(e) => {
151 return Response::error(
152 &req.id,
153 "file_not_found",
154 format!("extract_function: {}: {}", file, e),
155 );
156 }
157 };
158
159 let grammar = grammar_for(lang);
160 let mut parser = Parser::new();
161 if parser.set_language(&grammar).is_err() {
162 return Response::error(
163 &req.id,
164 "parse_error",
165 "extract_function: failed to initialize parser",
166 );
167 }
168 let tree = match parser.parse(source.as_bytes(), None) {
169 Some(t) => t,
170 None => {
171 return Response::error(
172 &req.id,
173 "parse_error",
174 "extract_function: failed to parse file",
175 );
176 }
177 };
178
179 let start_byte = edit::line_col_to_byte(&source, start_line, 0);
181 let end_byte = edit::line_col_to_byte(&source, end_line, 0);
182
183 if start_byte >= source.len() {
184 return Response::error(
185 &req.id,
186 "invalid_request",
187 format!(
188 "extract_function: start_line {} is beyond end of file",
189 start_line
190 ),
191 );
192 }
193
194 let free_vars = detect_free_variables(&source, &tree, start_byte, end_byte, lang);
196
197 if free_vars.has_this_or_self {
199 let keyword = match lang {
200 LangId::Python => "self",
201 _ => "this",
202 };
203 return Response::error(
204 &req.id,
205 "this_reference_in_range",
206 format!(
207 "extract_function: selected range contains '{}' reference. Consider extracting as a method instead, or move the {} usage outside the extracted range.",
208 keyword, keyword
209 ),
210 );
211 }
212
213 let root = tree.root_node();
215 let enclosing_fn = find_enclosing_function_node(&root, start_byte, lang);
216 let enclosing_fn_end_byte = enclosing_fn.map(|n| n.end_byte());
217
218 let return_kind = detect_return_value(
220 &source,
221 &tree,
222 start_byte,
223 end_byte,
224 enclosing_fn_end_byte,
225 lang,
226 );
227
228 let indent_style = detect_indent(&source, lang);
230
231 let base_indent = if let Some(fn_node) = enclosing_fn {
234 let fn_start_line = fn_node.start_position().row;
235 get_line_indent(&source, fn_start_line as usize)
236 } else {
237 String::new()
238 };
239
240 let range_indent = get_line_indent(&source, start_line as usize);
242
243 let body_text = &source[start_byte..end_byte];
245 let body_text = body_text.trim_end_matches('\n');
246
247 let extracted_fn = generate_extracted_function(
249 name,
250 &free_vars.parameters,
251 &return_kind,
252 body_text,
253 &base_indent,
254 lang,
255 indent_style,
256 );
257
258 let call_site = generate_call_site(
259 name,
260 &free_vars.parameters,
261 &return_kind,
262 &range_indent,
263 lang,
264 );
265
266 let insert_pos = if let Some(fn_node) = enclosing_fn {
281 let mut anchor = fn_node;
282 if matches!(lang, LangId::TypeScript | LangId::Tsx | LangId::JavaScript) {
283 if let Some(parent) = fn_node.parent() {
284 if parent.kind() == "export_statement" {
285 anchor = parent;
286 }
287 }
288 }
289 anchor.start_byte()
290 } else {
291 start_byte
292 };
293
294 let new_source = build_new_source(
295 &source,
296 insert_pos,
297 start_byte,
298 end_byte,
299 &extracted_fn,
300 &call_site,
301 );
302
303 let return_type = match &return_kind {
305 ReturnKind::Expression(_) => "expression",
306 ReturnKind::Variable(_) => "variable",
307 ReturnKind::Void => "void",
308 };
309
310 let backup_id = match edit::auto_backup(
312 ctx,
313 req.session(),
314 &path,
315 &format!("extract_function: {}", name),
316 ) {
317 Ok(id) => id,
318 Err(e) => {
319 return Response::error(&req.id, e.code(), e.to_string());
320 }
321 };
322
323 let mut write_result =
325 match edit::write_format_validate(&path, &new_source, &ctx.config(), &req.params) {
326 Ok(r) => r,
327 Err(e) => {
328 return Response::error(&req.id, e.code(), e.to_string());
329 }
330 };
331
332 if let Ok(final_content) = std::fs::read_to_string(&path) {
333 write_result.lsp_outcome = ctx.lsp_post_write(&path, &final_content, &req.params);
334 }
335
336 let param_count = free_vars.parameters.len();
337 log::debug!(
338 "extract_function: {} from {}:{}-{} ({} params)",
339 name,
340 file,
341 start_line,
342 end_line,
343 param_count
344 );
345
346 let syntax_valid = write_result.syntax_valid.unwrap_or(true);
348
349 let mut result = serde_json::json!({
350 "file": file,
351 "name": name,
352 "parameters": free_vars.parameters,
353 "return_type": return_type,
354 "syntax_valid": syntax_valid,
355 "formatted": write_result.formatted,
356 });
357
358 if let Some(ref reason) = write_result.format_skipped_reason {
359 result["format_skipped_reason"] = serde_json::json!(reason);
360 }
361
362 if write_result.validate_requested {
363 result["validation_errors"] = serde_json::json!(write_result.validation_errors);
364 }
365 if let Some(ref reason) = write_result.validate_skipped_reason {
366 result["validate_skipped_reason"] = serde_json::json!(reason);
367 }
368
369 if let Some(ref id) = backup_id {
370 result["backup_id"] = serde_json::json!(id);
371 }
372
373 write_result.append_lsp_diagnostics_to(&mut result);
374 Response::success(&req.id, result)
375}
376
377fn find_enclosing_function_node<'a>(
379 root: &'a tree_sitter::Node<'a>,
380 byte_pos: usize,
381 lang: LangId,
382) -> Option<tree_sitter::Node<'a>> {
383 let fn_kinds: &[&str] = match lang {
384 LangId::TypeScript | LangId::Tsx | LangId::JavaScript => &[
385 "function_declaration",
386 "method_definition",
387 "arrow_function",
388 "lexical_declaration",
389 ],
390 LangId::Python => &["function_definition"],
391 _ => &[],
392 };
393
394 find_deepest_ancestor(root, byte_pos, fn_kinds)
395}
396
397fn find_deepest_ancestor<'a>(
399 node: &tree_sitter::Node<'a>,
400 byte_pos: usize,
401 kinds: &[&str],
402) -> Option<tree_sitter::Node<'a>> {
403 let mut result: Option<tree_sitter::Node<'a>> = None;
404 if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
405 result = Some(*node);
406 }
407
408 let child_count = node.child_count();
409 for i in 0..child_count {
410 if let Some(child) = node.child(i as u32) {
411 if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
412 if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
413 result = Some(deeper);
414 }
415 }
416 }
417 }
418
419 result
420}
421
/// Returns the leading whitespace of the 0-based `line` of `source`, or an
/// empty string when `source` has no such line.
fn get_line_indent(source: &str, line: usize) -> String {
    match source.lines().nth(line) {
        Some(text) => text.chars().take_while(|c| c.is_whitespace()).collect(),
        None => String::new(),
    }
}
433
/// Splices an extracted function and its call site into `source`.
///
/// The output is, in order: `source[..insert_pos]`, `extracted_fn`, a blank
/// line, `source[insert_pos..range_start]`, `call_site` plus a newline, and
/// finally `source[range_end..]`, which drops the original
/// `range_start..range_end` text. Requires
/// `insert_pos <= range_start <= range_end <= source.len()`, all on char
/// boundaries; otherwise slicing panics.
fn build_new_source(
    source: &str,
    insert_pos: usize,
    range_start: usize,
    range_end: usize,
    extracted_fn: &str,
    call_site: &str,
) -> String {
    let (before_anchor, rest) = source.split_at(insert_pos);
    let between = &rest[..range_start - insert_pos];
    let after_range = &source[range_end..];

    [
        before_anchor,
        extracted_fn,
        "\n\n",
        between,
        call_site,
        "\n",
        after_range,
    ]
    .concat()
}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::RawRequest;

    /// Builds a request without LSP hints or a session.
    fn make_request(id: &str, command: &str, params: serde_json::Value) -> RawRequest {
        RawRequest {
            id: id.to_string(),
            command: command.to_string(),
            params,
            lsp_hints: None,
            session_id: None,
        }
    }

    /// Runs `extract_function` against a fresh context (tree-sitter provider,
    /// default config) and returns the serialized response.
    fn run(req: &RawRequest) -> serde_json::Value {
        let ctx = crate::context::AppContext::new(
            Box::new(crate::parser::TreeSitterProvider::new()),
            crate::config::Config::default(),
        );
        let resp = handle_extract_function(req, &ctx);
        serde_json::to_value(&resp).unwrap()
    }

    /// Asserts that `json` is a failed response carrying error `code`.
    fn assert_error_code(json: &serde_json::Value, code: &str) {
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], code);
    }

    #[test]
    fn extract_function_missing_file() {
        let json = run(&make_request("1", "extract_function", serde_json::json!({})));
        assert_error_code(&json, "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(msg.contains("file"), "message should mention 'file': {}", msg);
    }

    #[test]
    fn extract_function_missing_name() {
        let params = serde_json::json!({"file": "/tmp/test.ts"});
        let json = run(&make_request("2", "extract_function", params));
        assert_error_code(&json, "invalid_request");
        let msg = json["message"].as_str().unwrap();
        assert!(msg.contains("name"), "message should mention 'name': {}", msg);
    }

    #[test]
    fn extract_function_missing_start_line() {
        let params = serde_json::json!({"file": "/tmp/test.ts", "name": "foo"});
        let json = run(&make_request("3", "extract_function", params));
        assert_error_code(&json, "invalid_request");
    }

    #[test]
    fn extract_function_unsupported_language() {
        let dir = std::env::temp_dir().join("aft_test_extract");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.rs");
        std::fs::write(&file, "fn main() {}").unwrap();

        let params = serde_json::json!({
            "file": file.display().to_string(),
            "name": "foo",
            "start_line": 1,
            "end_line": 2,
        });
        let json = run(&make_request("4", "extract_function", params));
        assert_error_code(&json, "unsupported_language");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_invalid_line_range() {
        let dir = std::env::temp_dir().join("aft_test_extract_range");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("test.ts");
        std::fs::write(&file, "const x = 1;\n").unwrap();

        let params = serde_json::json!({
            "file": file.display().to_string(),
            "name": "foo",
            "start_line": 6,
            "end_line": 4,
        });
        let json = run(&make_request("5", "extract_function", params));
        assert_error_code(&json, "invalid_request");

        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn extract_function_this_reference_error() {
        let fixture = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/fixtures/extract_function/sample_this.ts");

        let params = serde_json::json!({
            "file": fixture.display().to_string(),
            "name": "extracted",
            "start_line": 5,
            "end_line": 8,
        });
        let json = run(&make_request("6", "extract_function", params));
        assert_error_code(&json, "this_reference_in_range");
    }
}