Skip to main content

aft/commands/
zoom.rs

1use std::path::Path;
2
3use serde::Serialize;
4
5use crate::context::AppContext;
6use crate::edit::line_col_to_byte;
7use crate::lsp_hints;
8use crate::parser::{FileParser, LangId};
9use crate::protocol::{RawRequest, Response};
10use crate::symbols::Range;
11
12/// A reference to a called/calling function.
13#[derive(Debug, Clone, Serialize)]
14pub struct CallRef {
15    pub name: String,
16    /// 1-based line number of the call reference.
17    pub line: u32,
18}
19
20/// Annotations describing file-scoped call relationships.
21#[derive(Debug, Clone, Serialize)]
22pub struct Annotations {
23    pub calls_out: Vec<CallRef>,
24    pub called_by: Vec<CallRef>,
25}
26
27/// Response payload for the zoom command.
28#[derive(Debug, Clone, Serialize)]
29pub struct ZoomResponse {
30    pub name: String,
31    pub kind: String,
32    pub range: Range,
33    pub content: String,
34    pub context_before: Vec<String>,
35    pub context_after: Vec<String>,
36    pub annotations: Annotations,
37}
38
39/// Handle a `zoom` request.
40///
41/// Expects `file`, `symbol` in request params, optional `context_lines` (default 3).
42/// Resolves the symbol, extracts body + context, walks AST for call annotations.
43pub fn handle_zoom(req: &RawRequest, ctx: &AppContext) -> Response {
44    let file = match req.params.get("file").and_then(|v| v.as_str()) {
45        Some(f) => f,
46        None => {
47            return Response::error(
48                &req.id,
49                "invalid_request",
50                "zoom: missing required param 'file'",
51            );
52        }
53    };
54
55    let context_lines = req
56        .params
57        .get("context_lines")
58        .and_then(|v| v.as_u64())
59        .unwrap_or(3) as usize;
60
61    let start_line = req
62        .params
63        .get("start_line")
64        .and_then(|v| v.as_u64())
65        .map(|v| v as usize);
66    let end_line = req
67        .params
68        .get("end_line")
69        .and_then(|v| v.as_u64())
70        .map(|v| v as usize);
71
72    let path = Path::new(file);
73    if !path.exists() {
74        return Response::error(
75            &req.id,
76            "file_not_found",
77            format!("file not found: {}", file),
78        );
79    }
80
81    // Read source file early because both symbol mode and line-range mode need it.
82    let source = match std::fs::read_to_string(path) {
83        Ok(s) => s,
84        Err(e) => {
85            return Response::error(&req.id, "file_not_found", format!("{}: {}", file, e));
86        }
87    };
88
89    let lines: Vec<String> = source.lines().map(|l| l.to_string()).collect();
90
91    // Line-range mode: read arbitrary lines without requiring a symbol.
92    match (start_line, end_line) {
93        (Some(start), Some(end)) => {
94            if req.params.get("symbol").is_some() {
95                return Response::error(
96                    &req.id,
97                    "invalid_request",
98                    "zoom: provide either 'symbol' OR ('start_line' and 'end_line'), not both",
99                );
100            }
101            if start == 0 || end == 0 {
102                return Response::error(
103                    &req.id,
104                    "invalid_request",
105                    "zoom: 'start_line' and 'end_line' are 1-based and must be >= 1",
106                );
107            }
108            if end < start {
109                return Response::error(
110                    &req.id,
111                    "invalid_request",
112                    format!("zoom: end_line {} must be >= start_line {}", end, start),
113                );
114            }
115            if lines.is_empty() {
116                return Response::error(
117                    &req.id,
118                    "invalid_request",
119                    format!("zoom: {} is empty", file),
120                );
121            }
122
123            let start_idx = start - 1;
124            // Clamp end_line to file length (same as batch edits)
125            let end_idx = (end - 1).min(lines.len() - 1);
126            if start_idx >= lines.len() {
127                return Response::error(
128                    &req.id,
129                    "invalid_request",
130                    format!(
131                        "zoom: start_line {} is past end of {} ({} lines)",
132                        start,
133                        file,
134                        lines.len()
135                    ),
136                );
137            }
138
139            let content = lines[start_idx..=end_idx].join("\n");
140            let ctx_start = start_idx.saturating_sub(context_lines);
141            let context_before: Vec<String> = if ctx_start < start_idx {
142                lines[ctx_start..start_idx]
143                    .iter()
144                    .map(|l| l.to_string())
145                    .collect()
146            } else {
147                vec![]
148            };
149            let ctx_end = (end_idx + 1 + context_lines).min(lines.len());
150            let context_after: Vec<String> = if end_idx + 1 < lines.len() {
151                lines[(end_idx + 1)..ctx_end]
152                    .iter()
153                    .map(|l| l.to_string())
154                    .collect()
155            } else {
156                vec![]
157            };
158            let end_col = lines[end_idx].chars().count() as u32;
159
160            return Response::success(
161                &req.id,
162                serde_json::json!({
163                    "name": format!("lines {}-{}", start, end),
164                    "kind": "lines",
165                    "range": {
166                        "start_line": start,  // already 1-based from user input
167                        "start_col": 1,
168                        "end_line": end,      // already 1-based from user input
169                        "end_col": end_col + 1,
170                    },
171                    "content": content,
172                    "context_before": context_before,
173                    "context_after": context_after,
174                    "annotations": {
175                        "calls_out": [],
176                        "called_by": [],
177                    },
178                }),
179            );
180        }
181        (Some(_), None) | (None, Some(_)) => {
182            return Response::error(
183                &req.id,
184                "invalid_request",
185                "zoom: provide both 'start_line' and 'end_line' for line-range mode",
186            );
187        }
188        (None, None) => {}
189    }
190
191    let symbol_name = match req.params.get("symbol").and_then(|v| v.as_str()) {
192        Some(s) => s,
193        None => {
194            return Response::error(
195                &req.id,
196                "invalid_request",
197                "zoom: missing required param 'symbol' (or use 'start_line' and 'end_line')",
198            );
199        }
200    };
201
202    // Resolve the target symbol
203    let matches = match ctx.provider().resolve_symbol(path, symbol_name) {
204        Ok(m) => m,
205        Err(e) => {
206            return Response::error(&req.id, e.code(), e.to_string());
207        }
208    };
209
210    // LSP-enhanced disambiguation (S03)
211    let matches = if let Some(hints) = lsp_hints::parse_lsp_hints(req) {
212        lsp_hints::apply_lsp_disambiguation(matches, &hints)
213    } else {
214        matches
215    };
216
217    if matches.len() > 1 {
218        // Ambiguous — return qualified candidates
219        let candidates: Vec<String> = matches
220            .iter()
221            .map(|m| {
222                let sym = &m.symbol;
223                if sym.scope_chain.is_empty() {
224                    format!("{}:{}", sym.name, sym.range.start_line)
225                } else {
226                    format!(
227                        "{}::{}:{}",
228                        sym.scope_chain.join("::"),
229                        sym.name,
230                        sym.range.start_line
231                    )
232                }
233            })
234            .collect();
235        return Response::error(
236            &req.id,
237            "ambiguous_symbol",
238            format!(
239                "symbol '{}' is ambiguous, candidates: [{}]",
240                symbol_name,
241                candidates.join(", ")
242            ),
243        );
244    }
245
246    let target = &matches[0].symbol;
247    let start = target.range.start_line as usize;
248    let end = target.range.end_line as usize;
249
250    // When re-export following resolved to a different file, re-read that file's lines
251    let resolved_file_path = std::path::Path::new(&matches[0].file);
252    let resolved_lines: Vec<String>;
253    let effective_lines: &[String] = if resolved_file_path != path {
254        resolved_lines = match std::fs::read_to_string(resolved_file_path) {
255            Ok(src) => src.lines().map(|l| l.to_string()).collect(),
256            Err(_) => lines.clone(),
257        };
258        &resolved_lines
259    } else {
260        &lines
261    };
262
263    // Extract symbol body (0-based line indices)
264    let content = if end < effective_lines.len() {
265        effective_lines[start..=end].join("\n")
266    } else {
267        effective_lines[start..].join("\n")
268    };
269
270    // Context before
271    let ctx_start = start.saturating_sub(context_lines);
272    let context_before: Vec<String> = if ctx_start < start {
273        effective_lines[ctx_start..start]
274            .iter()
275            .map(|l| l.to_string())
276            .collect()
277    } else {
278        vec![]
279    };
280
281    // Context after
282    let ctx_end = (end + 1 + context_lines).min(effective_lines.len());
283    let context_after: Vec<String> = if end + 1 < effective_lines.len() {
284        effective_lines[(end + 1)..ctx_end]
285            .iter()
286            .map(|l| l.to_string())
287            .collect()
288    } else {
289        vec![]
290    };
291
292    // Get all symbols in the resolved file for call matching
293    let all_symbols = match ctx.provider().list_symbols(resolved_file_path) {
294        Ok(s) => s,
295        Err(e) => {
296            return Response::error(&req.id, e.code(), e.to_string());
297        }
298    };
299
300    let known_names: Vec<&str> = all_symbols.iter().map(|s| s.name.as_str()).collect();
301
302    // Parse AST for call extraction (use resolved file for cross-file re-exports)
303    let mut parser = FileParser::new();
304    let (tree, lang) = match parser.parse(resolved_file_path) {
305        Ok(r) => r,
306        Err(e) => {
307            return Response::error(&req.id, e.code(), e.to_string());
308        }
309    };
310
311    // calls_out: calls within the target symbol's byte range
312    let resolved_source = if resolved_file_path != path {
313        std::fs::read_to_string(resolved_file_path).unwrap_or_else(|_| source.clone())
314    } else {
315        source.clone()
316    };
317    let target_byte_start = line_col_to_byte(
318        &resolved_source,
319        target.range.start_line,
320        target.range.start_col,
321    );
322    let target_byte_end = line_col_to_byte(
323        &resolved_source,
324        target.range.end_line,
325        target.range.end_col,
326    );
327
328    let raw_calls = extract_calls_in_range(
329        &resolved_source,
330        tree.root_node(),
331        target_byte_start,
332        target_byte_end,
333        lang,
334    );
335    let calls_out: Vec<CallRef> = raw_calls
336        .into_iter()
337        .filter(|(name, _)| known_names.contains(&name.as_str()) && *name != target.name)
338        .map(|(name, line)| CallRef { name, line })
339        .collect();
340
341    // called_by: scan all other symbols for calls to this symbol
342    let mut called_by: Vec<CallRef> = Vec::new();
343    for sym in &all_symbols {
344        if sym.name == target.name && sym.range.start_line == target.range.start_line {
345            continue; // skip self
346        }
347        let sym_byte_start =
348            line_col_to_byte(&resolved_source, sym.range.start_line, sym.range.start_col);
349        let sym_byte_end =
350            line_col_to_byte(&resolved_source, sym.range.end_line, sym.range.end_col);
351        let calls = extract_calls_in_range(
352            &resolved_source,
353            tree.root_node(),
354            sym_byte_start,
355            sym_byte_end,
356            lang,
357        );
358        for (name, line) in calls {
359            if name == target.name {
360                called_by.push(CallRef {
361                    name: sym.name.clone(),
362                    line,
363                });
364            }
365        }
366    }
367
368    // Dedup called_by by (name, line)
369    called_by.sort_by(|a, b| a.name.cmp(&b.name).then(a.line.cmp(&b.line)));
370    called_by.dedup_by(|a, b| a.name == b.name && a.line == b.line);
371
372    let kind_str = serde_json::to_value(&target.kind)
373        .ok()
374        .and_then(|v| v.as_str().map(String::from))
375        .unwrap_or_else(|| format!("{:?}", target.kind).to_lowercase());
376
377    let resp = ZoomResponse {
378        name: target.name.clone(),
379        kind: kind_str,
380        range: target.range.clone(),
381        content,
382        context_before,
383        context_after,
384        annotations: Annotations {
385            calls_out,
386            called_by,
387        },
388    };
389
390    match serde_json::to_value(&resp) {
391        Ok(resp_json) => Response::success(&req.id, resp_json),
392        Err(err) => Response::error(
393            &req.id,
394            "internal_error",
395            format!("zoom: failed to serialize response: {err}"),
396        ),
397    }
398}
399
400/// Extract call expression names within a byte range of the AST.
401///
402/// Delegates to `crate::calls::extract_calls_in_range`.
403fn extract_calls_in_range(
404    source: &str,
405    root: tree_sitter::Node,
406    byte_start: usize,
407    byte_end: usize,
408    lang: LangId,
409) -> Vec<(String, u32)> {
410    crate::calls::extract_calls_in_range(source, root, byte_start, byte_end, lang)
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::context::AppContext;
    use crate::parser::TreeSitterProvider;
    use std::path::PathBuf;

    fn fixture_path(name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join("fixtures")
            .join(name)
    }

    fn make_ctx() -> AppContext {
        AppContext::new(Box::new(TreeSitterProvider::new()), Config::default())
    }

    // --- Call extraction tests ---

    #[test]
    fn extract_calls_finds_direct_calls() {
        let source = std::fs::read_to_string(fixture_path("calls.ts")).unwrap();
        let mut parser = FileParser::new();
        let path = fixture_path("calls.ts");
        let (tree, lang) = parser.parse(&path).unwrap();

        // `compute` calls `helper` — find compute's range from symbols
        let ctx = make_ctx();
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let compute = symbols.iter().find(|s| s.name == "compute").unwrap();

        let byte_start =
            line_col_to_byte(&source, compute.range.start_line, compute.range.start_col);
        let byte_end = line_col_to_byte(&source, compute.range.end_line, compute.range.end_col);

        let calls = extract_calls_in_range(&source, tree.root_node(), byte_start, byte_end, lang);
        let names: Vec<&str> = calls.iter().map(|(n, _)| n.as_str()).collect();

        assert!(
            names.contains(&"helper"),
            "compute should call helper, got: {:?}",
            names
        );
    }

    #[test]
    fn extract_calls_finds_member_calls() {
        let source = std::fs::read_to_string(fixture_path("calls.ts")).unwrap();
        let mut parser = FileParser::new();
        let path = fixture_path("calls.ts");
        let (tree, lang) = parser.parse(&path).unwrap();

        let ctx = make_ctx();
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let run_all = symbols.iter().find(|s| s.name == "runAll").unwrap();

        let byte_start =
            line_col_to_byte(&source, run_all.range.start_line, run_all.range.start_col);
        let byte_end = line_col_to_byte(&source, run_all.range.end_line, run_all.range.end_col);

        let calls = extract_calls_in_range(&source, tree.root_node(), byte_start, byte_end, lang);
        let names: Vec<&str> = calls.iter().map(|(n, _)| n.as_str()).collect();

        assert!(
            names.contains(&"add"),
            "runAll should call this.add, got: {:?}",
            names
        );
        assert!(
            names.contains(&"helper"),
            "runAll should call helper, got: {:?}",
            names
        );
    }

    #[test]
    fn extract_calls_unused_function_has_no_calls() {
        let source = std::fs::read_to_string(fixture_path("calls.ts")).unwrap();
        let mut parser = FileParser::new();
        let path = fixture_path("calls.ts");
        let (tree, lang) = parser.parse(&path).unwrap();

        let ctx = make_ctx();
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let unused = symbols.iter().find(|s| s.name == "unused").unwrap();

        let byte_start = line_col_to_byte(&source, unused.range.start_line, unused.range.start_col);
        let byte_end = line_col_to_byte(&source, unused.range.end_line, unused.range.end_col);

        let calls = extract_calls_in_range(&source, tree.root_node(), byte_start, byte_end, lang);
        // console.log is the only call, but "log" or "console" aren't known symbols
        let known_names = vec![
            "helper",
            "compute",
            "orchestrate",
            "unused",
            "format",
            "display",
        ];
        let filtered: Vec<&str> = calls
            .iter()
            .map(|(n, _)| n.as_str())
            .filter(|n| known_names.contains(n))
            .collect();
        assert!(
            filtered.is_empty(),
            "unused should not call known symbols, got: {:?}",
            filtered
        );
    }

    // --- Context line tests ---

    #[test]
    fn context_lines_clamp_at_file_start() {
        // helper() is at the top of the file (line 2) — context_before should be clamped
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let helper = symbols.iter().find(|s| s.name == "helper").unwrap();

        let source = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = source.lines().collect();
        let start = helper.range.start_line as usize;

        // With context_lines=5, ctx_start should clamp to 0
        let ctx_start = start.saturating_sub(5);
        let context_before: Vec<&str> = lines[ctx_start..start].to_vec();
        // Should have at most `start` lines (not panic)
        assert!(context_before.len() <= start);
    }

    #[test]
    fn context_lines_clamp_at_file_end() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let display = symbols.iter().find(|s| s.name == "display").unwrap();

        let source = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = source.lines().collect();
        let end = display.range.end_line as usize;

        // With context_lines=20, should clamp to file length
        let ctx_end = (end + 1 + 20).min(lines.len());
        let context_after: Vec<&str> = if end + 1 < lines.len() {
            lines[(end + 1)..ctx_end].to_vec()
        } else {
            vec![]
        };
        // Should not panic regardless of context_lines size
        assert!(context_after.len() <= 20);
    }

    // --- Body extraction test ---

    #[test]
    fn body_extraction_matches_source() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let compute = symbols.iter().find(|s| s.name == "compute").unwrap();

        let source = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = source.lines().collect();
        let start = compute.range.start_line as usize;
        let end = compute.range.end_line as usize;
        let body = lines[start..=end].join("\n");

        assert!(
            body.contains("function compute"),
            "body should contain function declaration"
        );
        assert!(
            body.contains("helper(a)"),
            "body should contain call to helper"
        );
        assert!(
            body.contains("doubled + b"),
            "body should contain return expression"
        );
    }

    // --- Full zoom response tests ---

    #[test]
    fn zoom_response_has_calls_out_and_called_by() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");

        let req = make_zoom_request("z-1", path.to_str().unwrap(), "compute", None);
        let resp = handle_zoom(&req, &ctx);

        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["success"], true, "zoom should succeed: {:?}", json);

        let calls_out = json["annotations"]["calls_out"]
            .as_array()
            .expect("calls_out array");
        let out_names: Vec<&str> = calls_out
            .iter()
            .map(|c| c["name"].as_str().unwrap())
            .collect();
        assert!(
            out_names.contains(&"helper"),
            "compute calls helper: {:?}",
            out_names
        );

        let called_by = json["annotations"]["called_by"]
            .as_array()
            .expect("called_by array");
        let by_names: Vec<&str> = called_by
            .iter()
            .map(|c| c["name"].as_str().unwrap())
            .collect();
        assert!(
            by_names.contains(&"orchestrate"),
            "orchestrate calls compute: {:?}",
            by_names
        );
    }

    #[test]
    fn zoom_response_empty_annotations_for_unused() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");

        let req = make_zoom_request("z-2", path.to_str().unwrap(), "unused", None);
        let resp = handle_zoom(&req, &ctx);

        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["success"], true);

        let _calls_out = json["annotations"]["calls_out"].as_array().unwrap();
        let called_by = json["annotations"]["called_by"].as_array().unwrap();

        // calls_out exists (may contain console.log but no known symbols)
        // called_by should be empty — nobody calls unused
        assert!(
            called_by.is_empty(),
            "unused should not be called by anyone: {:?}",
            called_by
        );
    }

    #[test]
    fn zoom_symbol_not_found() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");

        let req = make_zoom_request("z-3", path.to_str().unwrap(), "nonexistent", None);
        let resp = handle_zoom(&req, &ctx);

        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "symbol_not_found");
    }

    #[test]
    fn zoom_custom_context_lines() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");

        let req = make_zoom_request("z-4", path.to_str().unwrap(), "compute", Some(1));
        let resp = handle_zoom(&req, &ctx);

        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["success"], true);

        let ctx_before = json["context_before"].as_array().unwrap();
        let ctx_after = json["context_after"].as_array().unwrap();
        // With context_lines=1, we get at most 1 line before and after
        assert!(
            ctx_before.len() <= 1,
            "context_before should be ≤1: {:?}",
            ctx_before
        );
        assert!(
            ctx_after.len() <= 1,
            "context_after should be ≤1: {:?}",
            ctx_after
        );
    }

    #[test]
    fn zoom_missing_file_param() {
        let ctx = make_ctx();
        let req = make_raw_request("z-5", r#"{"id":"z-5","command":"zoom","symbol":"foo"}"#);
        let resp = handle_zoom(&req, &ctx);

        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
    }

    #[test]
    fn zoom_missing_symbol_param() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");
        // Build via json! rather than format!: a Windows path contains `\`,
        // which would produce invalid JSON escapes if spliced into a string.
        let req: RawRequest = serde_json::from_value(serde_json::json!({
            "id": "z-6",
            "command": "zoom",
            "file": path.to_str().unwrap(),
        }))
        .unwrap();
        let resp = handle_zoom(&req, &ctx);

        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
    }

    // --- Line-range mode tests ---

    #[test]
    fn zoom_line_range_returns_requested_lines() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");
        let source = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = source.lines().collect();
        assert!(lines.len() >= 2, "fixture should have at least two lines");

        let req = make_line_range_request("z-7", path.to_str().unwrap(), Some(1), Some(2));
        let resp = handle_zoom(&req, &ctx);

        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["success"], true, "line-range zoom should succeed: {:?}", json);
        assert_eq!(json["kind"], "lines");
        assert_eq!(json["content"], lines[..2].join("\n"));
        assert_eq!(json["range"]["start_line"], 1);
        assert_eq!(json["range"]["end_line"], 2);
        assert!(json["context_before"].as_array().unwrap().is_empty());
    }

    #[test]
    fn zoom_line_range_with_symbol_is_rejected() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");
        let req: RawRequest = serde_json::from_value(serde_json::json!({
            "id": "z-8",
            "command": "zoom",
            "file": path.to_str().unwrap(),
            "symbol": "compute",
            "start_line": 1,
            "end_line": 2,
        }))
        .unwrap();
        let resp = handle_zoom(&req, &ctx);

        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
    }

    #[test]
    fn zoom_line_range_rejects_invalid_bounds() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");
        let file = path.to_str().unwrap();

        // zero (lines are 1-based), reversed, start past EOF, only one bound
        let cases = [
            (Some(0), Some(2)),
            (Some(3), Some(2)),
            (Some(100_000), Some(100_001)),
            (Some(1), None),
            (None, Some(2)),
        ];
        for (start, end) in cases {
            let req = make_line_range_request("z-9", file, start, end);
            let resp = handle_zoom(&req, &ctx);
            let json = serde_json::to_value(&resp).unwrap();
            assert_eq!(json["success"], false, "{:?}-{:?}: {:?}", start, end, json);
            assert_eq!(json["code"], "invalid_request", "{:?}-{:?}", start, end);
        }
    }

    // --- Helpers ---

    fn make_zoom_request(
        id: &str,
        file: &str,
        symbol: &str,
        context_lines: Option<u64>,
    ) -> RawRequest {
        let mut json = serde_json::json!({
            "id": id,
            "command": "zoom",
            "file": file,
            "symbol": symbol,
        });
        if let Some(cl) = context_lines {
            json["context_lines"] = serde_json::json!(cl);
        }
        serde_json::from_value(json).unwrap()
    }

    /// Build a line-range-mode request; `None` bounds are omitted entirely.
    fn make_line_range_request(
        id: &str,
        file: &str,
        start_line: Option<u64>,
        end_line: Option<u64>,
    ) -> RawRequest {
        let mut json = serde_json::json!({
            "id": id,
            "command": "zoom",
            "file": file,
        });
        if let Some(start) = start_line {
            json["start_line"] = serde_json::json!(start);
        }
        if let Some(end) = end_line {
            json["end_line"] = serde_json::json!(end);
        }
        serde_json::from_value(json).unwrap()
    }

    fn make_raw_request(_id: &str, json_str: &str) -> RawRequest {
        serde_json::from_str(json_str).unwrap()
    }
}