Skip to main content

aft/commands/
zoom.rs

1use std::path::Path;
2
3use serde::Serialize;
4
5use crate::context::AppContext;
6use crate::lsp_hints;
7use crate::parser::{FileParser, LangId};
8use crate::protocol::{RawRequest, Response};
9use crate::symbols::Range;
10
11/// A reference to a called/calling function.
12#[derive(Debug, Clone, Serialize)]
13pub struct CallRef {
14    pub name: String,
15    /// 1-based line number of the call reference.
16    pub line: u32,
17}
18
19/// Annotations describing file-scoped call relationships.
20#[derive(Debug, Clone, Serialize)]
21pub struct Annotations {
22    pub calls_out: Vec<CallRef>,
23    pub called_by: Vec<CallRef>,
24}
25
26/// Response payload for the zoom command.
27#[derive(Debug, Clone, Serialize)]
28pub struct ZoomResponse {
29    pub name: String,
30    pub kind: String,
31    pub range: Range,
32    pub content: String,
33    pub context_before: Vec<String>,
34    pub context_after: Vec<String>,
35    pub annotations: Annotations,
36}
37
38/// Handle a `zoom` request.
39///
40/// Expects `file`, `symbol` in request params, optional `context_lines` (default 3).
41/// Resolves the symbol, extracts body + context, walks AST for call annotations.
42pub fn handle_zoom(req: &RawRequest, ctx: &AppContext) -> Response {
43    let file = match req.params.get("file").and_then(|v| v.as_str()) {
44        Some(f) => f,
45        None => {
46            return Response::error(
47                &req.id,
48                "invalid_request",
49                "zoom: missing required param 'file'",
50            );
51        }
52    };
53
54    let context_lines = req
55        .params
56        .get("context_lines")
57        .and_then(|v| v.as_u64())
58        .unwrap_or(3) as usize;
59
60    let start_line = req
61        .params
62        .get("start_line")
63        .and_then(|v| v.as_u64())
64        .map(|v| v as usize);
65    let end_line = req
66        .params
67        .get("end_line")
68        .and_then(|v| v.as_u64())
69        .map(|v| v as usize);
70
71    let path = Path::new(file);
72    if !path.exists() {
73        return Response::error(
74            &req.id,
75            "file_not_found",
76            format!("file not found: {}", file),
77        );
78    }
79
80    // Read source file early because both symbol mode and line-range mode need it.
81    let source = match std::fs::read_to_string(path) {
82        Ok(s) => s,
83        Err(e) => {
84            return Response::error(&req.id, "file_not_found", format!("{}: {}", file, e));
85        }
86    };
87
88    let lines: Vec<String> = source.lines().map(|l| l.to_string()).collect();
89
90    // Line-range mode: read arbitrary lines without requiring a symbol.
91    match (start_line, end_line) {
92        (Some(start), Some(end)) => {
93            if req.params.get("symbol").is_some() {
94                return Response::error(
95                    &req.id,
96                    "invalid_request",
97                    "zoom: provide either 'symbol' OR ('start_line' and 'end_line'), not both",
98                );
99            }
100            if start == 0 || end == 0 {
101                return Response::error(
102                    &req.id,
103                    "invalid_request",
104                    "zoom: 'start_line' and 'end_line' are 1-based and must be >= 1",
105                );
106            }
107            if end < start {
108                return Response::error(
109                    &req.id,
110                    "invalid_request",
111                    format!("zoom: end_line {} must be >= start_line {}", end, start),
112                );
113            }
114            if lines.is_empty() {
115                return Response::error(
116                    &req.id,
117                    "invalid_request",
118                    format!("zoom: {} is empty", file),
119                );
120            }
121
122            let start_idx = start - 1;
123            // Clamp end_line to file length (same as batch edits)
124            let end_idx = (end - 1).min(lines.len() - 1);
125            if start_idx >= lines.len() {
126                return Response::error(
127                    &req.id,
128                    "invalid_request",
129                    format!(
130                        "zoom: start_line {} is past end of {} ({} lines)",
131                        start,
132                        file,
133                        lines.len()
134                    ),
135                );
136            }
137
138            let content = lines[start_idx..=end_idx].join("\n");
139            let ctx_start = start_idx.saturating_sub(context_lines);
140            let context_before: Vec<String> = if ctx_start < start_idx {
141                lines[ctx_start..start_idx]
142                    .iter()
143                    .map(|l| l.to_string())
144                    .collect()
145            } else {
146                vec![]
147            };
148            let ctx_end = (end_idx + 1 + context_lines).min(lines.len());
149            let context_after: Vec<String> = if end_idx + 1 < lines.len() {
150                lines[(end_idx + 1)..ctx_end]
151                    .iter()
152                    .map(|l| l.to_string())
153                    .collect()
154            } else {
155                vec![]
156            };
157            let end_col = lines[end_idx].chars().count() as u32;
158
159            return Response::success(
160                &req.id,
161                serde_json::json!({
162                    "name": format!("lines {}-{}", start, end),
163                    "kind": "lines",
164                    "range": {
165                        "start_line": start,  // already 1-based from user input
166                        "start_col": 1,
167                        "end_line": end,      // already 1-based from user input
168                        "end_col": end_col + 1,
169                    },
170                    "content": content,
171                    "context_before": context_before,
172                    "context_after": context_after,
173                    "annotations": {
174                        "calls_out": [],
175                        "called_by": [],
176                    },
177                }),
178            );
179        }
180        (Some(_), None) | (None, Some(_)) => {
181            return Response::error(
182                &req.id,
183                "invalid_request",
184                "zoom: provide both 'start_line' and 'end_line' for line-range mode",
185            );
186        }
187        (None, None) => {}
188    }
189
190    let symbol_name = match req.params.get("symbol").and_then(|v| v.as_str()) {
191        Some(s) => s,
192        None => {
193            return Response::error(
194                &req.id,
195                "invalid_request",
196                "zoom: missing required param 'symbol' (or use 'start_line' and 'end_line')",
197            );
198        }
199    };
200
201    // Resolve the target symbol
202    let matches = match ctx.provider().resolve_symbol(path, symbol_name) {
203        Ok(m) => m,
204        Err(e) => {
205            return Response::error(&req.id, e.code(), e.to_string());
206        }
207    };
208
209    // LSP-enhanced disambiguation (S03)
210    let matches = if let Some(hints) = lsp_hints::parse_lsp_hints(req) {
211        lsp_hints::apply_lsp_disambiguation(matches, &hints)
212    } else {
213        matches
214    };
215
216    if matches.len() > 1 {
217        // Ambiguous — return qualified candidates
218        let candidates: Vec<String> = matches
219            .iter()
220            .map(|m| {
221                let sym = &m.symbol;
222                if sym.scope_chain.is_empty() {
223                    format!("{}:{}", sym.name, sym.range.start_line)
224                } else {
225                    format!(
226                        "{}::{}:{}",
227                        sym.scope_chain.join("::"),
228                        sym.name,
229                        sym.range.start_line
230                    )
231                }
232            })
233            .collect();
234        return Response::error(
235            &req.id,
236            "ambiguous_symbol",
237            format!(
238                "symbol '{}' is ambiguous, candidates: [{}]",
239                symbol_name,
240                candidates.join(", ")
241            ),
242        );
243    }
244
245    let target = &matches[0].symbol;
246    let start = target.range.start_line as usize;
247    let end = target.range.end_line as usize;
248
249    // When re-export following resolved to a different file, re-read that file's lines
250    let resolved_file_path = std::path::Path::new(&matches[0].file);
251    let resolved_lines: Vec<String>;
252    let effective_lines: &[String] = if resolved_file_path != path {
253        resolved_lines = match std::fs::read_to_string(resolved_file_path) {
254            Ok(src) => src.lines().map(|l| l.to_string()).collect(),
255            Err(_) => lines.clone(),
256        };
257        &resolved_lines
258    } else {
259        &lines
260    };
261
262    // Extract symbol body (0-based line indices)
263    let content = if end < effective_lines.len() {
264        effective_lines[start..=end].join("\n")
265    } else {
266        effective_lines[start..].join("\n")
267    };
268
269    // Context before
270    let ctx_start = start.saturating_sub(context_lines);
271    let context_before: Vec<String> = if ctx_start < start {
272        effective_lines[ctx_start..start]
273            .iter()
274            .map(|l| l.to_string())
275            .collect()
276    } else {
277        vec![]
278    };
279
280    // Context after
281    let ctx_end = (end + 1 + context_lines).min(effective_lines.len());
282    let context_after: Vec<String> = if end + 1 < effective_lines.len() {
283        effective_lines[(end + 1)..ctx_end]
284            .iter()
285            .map(|l| l.to_string())
286            .collect()
287    } else {
288        vec![]
289    };
290
291    // Get all symbols in the resolved file for call matching
292    let all_symbols = match ctx.provider().list_symbols(resolved_file_path) {
293        Ok(s) => s,
294        Err(e) => {
295            return Response::error(&req.id, e.code(), e.to_string());
296        }
297    };
298
299    let known_names: Vec<&str> = all_symbols.iter().map(|s| s.name.as_str()).collect();
300
301    // Parse AST for call extraction (use resolved file for cross-file re-exports)
302    let mut parser = FileParser::new();
303    let (tree, lang) = match parser.parse(resolved_file_path) {
304        Ok(r) => r,
305        Err(e) => {
306            return Response::error(&req.id, e.code(), e.to_string());
307        }
308    };
309
310    // calls_out: calls within the target symbol's byte range
311    let resolved_source = if resolved_file_path != path {
312        std::fs::read_to_string(resolved_file_path).unwrap_or_else(|_| source.clone())
313    } else {
314        source.clone()
315    };
316    let target_byte_start = line_col_to_byte(
317        &resolved_source,
318        target.range.start_line,
319        target.range.start_col,
320    );
321    let target_byte_end = line_col_to_byte(
322        &resolved_source,
323        target.range.end_line,
324        target.range.end_col,
325    );
326
327    let raw_calls = extract_calls_in_range(
328        &resolved_source,
329        tree.root_node(),
330        target_byte_start,
331        target_byte_end,
332        lang,
333    );
334    let calls_out: Vec<CallRef> = raw_calls
335        .into_iter()
336        .filter(|(name, _)| known_names.contains(&name.as_str()) && *name != target.name)
337        .map(|(name, line)| CallRef { name, line })
338        .collect();
339
340    // called_by: scan all other symbols for calls to this symbol
341    let mut called_by: Vec<CallRef> = Vec::new();
342    for sym in &all_symbols {
343        if sym.name == target.name && sym.range.start_line == target.range.start_line {
344            continue; // skip self
345        }
346        let sym_byte_start =
347            line_col_to_byte(&resolved_source, sym.range.start_line, sym.range.start_col);
348        let sym_byte_end =
349            line_col_to_byte(&resolved_source, sym.range.end_line, sym.range.end_col);
350        let calls = extract_calls_in_range(
351            &resolved_source,
352            tree.root_node(),
353            sym_byte_start,
354            sym_byte_end,
355            lang,
356        );
357        for (name, line) in calls {
358            if name == target.name {
359                called_by.push(CallRef {
360                    name: sym.name.clone(),
361                    line,
362                });
363            }
364        }
365    }
366
367    // Dedup called_by by (name, line)
368    called_by.sort_by(|a, b| a.name.cmp(&b.name).then(a.line.cmp(&b.line)));
369    called_by.dedup_by(|a, b| a.name == b.name && a.line == b.line);
370
371    let kind_str = serde_json::to_value(&target.kind)
372        .ok()
373        .and_then(|v| v.as_str().map(String::from))
374        .unwrap_or_else(|| format!("{:?}", target.kind).to_lowercase());
375
376    let resp = ZoomResponse {
377        name: target.name.clone(),
378        kind: kind_str,
379        range: target.range.clone(),
380        content,
381        context_before,
382        context_after,
383        annotations: Annotations {
384            calls_out,
385            called_by,
386        },
387    };
388
389    match serde_json::to_value(&resp) {
390        Ok(resp_json) => Response::success(&req.id, resp_json),
391        Err(err) => Response::error(
392            &req.id,
393            "internal_error",
394            format!("zoom: failed to serialize response: {err}"),
395        ),
396    }
397}
398
/// Convert a 0-based line + column to a byte offset in the source.
///
/// `col` is treated as a byte column and is clamped to the length of the
/// line's text (its terminator excluded). Each preceding line contributes its
/// real byte length including the terminator, so offsets are correct for both
/// `\n` and `\r\n` line endings. A `line` past the last line maps to
/// `source.len()`.
fn line_col_to_byte(source: &str, line: u32, col: u32) -> usize {
    let wanted = line as usize;
    let mut offset = 0;
    for (idx, raw) in source.split_inclusive('\n').enumerate() {
        if idx == wanted {
            // Strip the terminator ("\n" or "\r\n") before clamping the column.
            let text = raw
                .strip_suffix('\n')
                .map_or(raw, |t| t.strip_suffix('\r').unwrap_or(t));
            return offset + (col as usize).min(text.len());
        }
        offset += raw.len();
    }
    source.len()
}
410
411/// Extract call expression names within a byte range of the AST.
412///
413/// Delegates to `crate::calls::extract_calls_in_range`.
414fn extract_calls_in_range(
415    source: &str,
416    root: tree_sitter::Node,
417    byte_start: usize,
418    byte_end: usize,
419    lang: LangId,
420) -> Vec<(String, u32)> {
421    crate::calls::extract_calls_in_range(source, root, byte_start, byte_end, lang)
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::context::AppContext;
    use crate::parser::TreeSitterProvider;
    use std::path::PathBuf;

    // --- Helpers ---

    /// Absolute path of a file under `tests/fixtures`.
    fn fixture_path(name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join("fixtures")
            .join(name)
    }

    fn make_ctx() -> AppContext {
        AppContext::new(Box::new(TreeSitterProvider::new()), Config::default())
    }

    /// Build a request from a JSON value. Going through `serde_json::json!`
    /// (instead of formatting a JSON string by hand) keeps paths containing
    /// backslashes or quotes correctly escaped.
    fn request_from(json: serde_json::Value) -> RawRequest {
        serde_json::from_value(json).unwrap()
    }

    /// Build a symbol-mode zoom request.
    fn make_zoom_request(
        id: &str,
        file: &str,
        symbol: &str,
        context_lines: Option<u64>,
    ) -> RawRequest {
        let mut json = serde_json::json!({
            "id": id,
            "command": "zoom",
            "file": file,
            "symbol": symbol,
        });
        if let Some(cl) = context_lines {
            json["context_lines"] = serde_json::json!(cl);
        }
        request_from(json)
    }

    /// Build a line-range zoom request with `context_lines` fixed to 0.
    fn make_range_request(
        id: &str,
        file: &str,
        start_line: Option<u64>,
        end_line: Option<u64>,
    ) -> RawRequest {
        let mut json = serde_json::json!({
            "id": id,
            "command": "zoom",
            "file": file,
            "context_lines": 0,
        });
        if let Some(s) = start_line {
            json["start_line"] = serde_json::json!(s);
        }
        if let Some(e) = end_line {
            json["end_line"] = serde_json::json!(e);
        }
        request_from(json)
    }

    /// Run `handle_zoom` with a fresh context and return the serialized response.
    fn zoom_json(req: &RawRequest) -> serde_json::Value {
        let ctx = make_ctx();
        let resp = handle_zoom(req, &ctx);
        serde_json::to_value(&resp).unwrap()
    }

    /// Names of all calls extracted from the byte range of `symbol` in `calls.ts`.
    fn calls_within(symbol: &str) -> Vec<String> {
        let path = fixture_path("calls.ts");
        let source = std::fs::read_to_string(&path).unwrap();
        let mut parser = FileParser::new();
        let (tree, lang) = parser.parse(&path).unwrap();

        let ctx = make_ctx();
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let sym = symbols.iter().find(|s| s.name == symbol).unwrap();

        let byte_start = line_col_to_byte(&source, sym.range.start_line, sym.range.start_col);
        let byte_end = line_col_to_byte(&source, sym.range.end_line, sym.range.end_col);

        extract_calls_in_range(&source, tree.root_node(), byte_start, byte_end, lang)
            .into_iter()
            .map(|(name, _)| name)
            .collect()
    }

    // --- Call extraction tests ---

    #[test]
    fn extract_calls_finds_direct_calls() {
        // `compute` calls `helper`
        let names = calls_within("compute");
        assert!(
            names.iter().any(|n| n == "helper"),
            "compute should call helper, got: {:?}",
            names
        );
    }

    #[test]
    fn extract_calls_finds_member_calls() {
        let names = calls_within("runAll");
        assert!(
            names.iter().any(|n| n == "add"),
            "runAll should call this.add, got: {:?}",
            names
        );
        assert!(
            names.iter().any(|n| n == "helper"),
            "runAll should call helper, got: {:?}",
            names
        );
    }

    #[test]
    fn extract_calls_unused_function_has_no_calls() {
        // console.log is the only call, but "log" or "console" aren't known symbols
        let known_names = [
            "helper",
            "compute",
            "orchestrate",
            "unused",
            "format",
            "display",
        ];
        let filtered: Vec<String> = calls_within("unused")
            .into_iter()
            .filter(|n| known_names.contains(&n.as_str()))
            .collect();
        assert!(
            filtered.is_empty(),
            "unused should not call known symbols, got: {:?}",
            filtered
        );
    }

    // --- Context line tests ---

    #[test]
    fn context_lines_clamp_at_file_start() {
        // helper() is at the top of the file (line 2) — context_before should be clamped
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let helper = symbols.iter().find(|s| s.name == "helper").unwrap();

        let source = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = source.lines().collect();
        let start = helper.range.start_line as usize;

        // With context_lines=5, ctx_start should clamp to 0
        let ctx_start = start.saturating_sub(5);
        let context_before: Vec<&str> = lines[ctx_start..start].to_vec();
        // Should have at most `start` lines (not panic)
        assert!(context_before.len() <= start);
    }

    #[test]
    fn context_lines_clamp_at_file_end() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let display = symbols.iter().find(|s| s.name == "display").unwrap();

        let source = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = source.lines().collect();
        let end = display.range.end_line as usize;

        // With context_lines=20, should clamp to file length
        let ctx_end = (end + 1 + 20).min(lines.len());
        let context_after: Vec<&str> = if end + 1 < lines.len() {
            lines[(end + 1)..ctx_end].to_vec()
        } else {
            vec![]
        };
        // Should not panic regardless of context_lines size
        assert!(context_after.len() <= 20);
    }

    // --- Body extraction test ---

    #[test]
    fn body_extraction_matches_source() {
        let ctx = make_ctx();
        let path = fixture_path("calls.ts");
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let compute = symbols.iter().find(|s| s.name == "compute").unwrap();

        let source = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = source.lines().collect();
        let start = compute.range.start_line as usize;
        let end = compute.range.end_line as usize;
        let body = lines[start..=end].join("\n");

        assert!(
            body.contains("function compute"),
            "body should contain function declaration"
        );
        assert!(
            body.contains("helper(a)"),
            "body should contain call to helper"
        );
        assert!(
            body.contains("doubled + b"),
            "body should contain return expression"
        );
    }

    // --- Full zoom response tests ---

    #[test]
    fn zoom_response_has_calls_out_and_called_by() {
        let path = fixture_path("calls.ts");
        let json = zoom_json(&make_zoom_request(
            "z-1",
            path.to_str().unwrap(),
            "compute",
            None,
        ));
        assert_eq!(json["success"], true, "zoom should succeed: {:?}", json);

        let calls_out = json["annotations"]["calls_out"]
            .as_array()
            .expect("calls_out array");
        let out_names: Vec<&str> = calls_out
            .iter()
            .map(|c| c["name"].as_str().unwrap())
            .collect();
        assert!(
            out_names.contains(&"helper"),
            "compute calls helper: {:?}",
            out_names
        );

        let called_by = json["annotations"]["called_by"]
            .as_array()
            .expect("called_by array");
        let by_names: Vec<&str> = called_by
            .iter()
            .map(|c| c["name"].as_str().unwrap())
            .collect();
        assert!(
            by_names.contains(&"orchestrate"),
            "orchestrate calls compute: {:?}",
            by_names
        );
    }

    #[test]
    fn zoom_response_empty_annotations_for_unused() {
        let path = fixture_path("calls.ts");
        let json = zoom_json(&make_zoom_request(
            "z-2",
            path.to_str().unwrap(),
            "unused",
            None,
        ));
        assert_eq!(json["success"], true);

        let _calls_out = json["annotations"]["calls_out"].as_array().unwrap();
        let called_by = json["annotations"]["called_by"].as_array().unwrap();

        // calls_out exists (may contain console.log but no known symbols)
        // called_by should be empty — nobody calls unused
        assert!(
            called_by.is_empty(),
            "unused should not be called by anyone: {:?}",
            called_by
        );
    }

    #[test]
    fn zoom_symbol_not_found() {
        let path = fixture_path("calls.ts");
        let json = zoom_json(&make_zoom_request(
            "z-3",
            path.to_str().unwrap(),
            "nonexistent",
            None,
        ));
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "symbol_not_found");
    }

    #[test]
    fn zoom_custom_context_lines() {
        let path = fixture_path("calls.ts");
        let json = zoom_json(&make_zoom_request(
            "z-4",
            path.to_str().unwrap(),
            "compute",
            Some(1),
        ));
        assert_eq!(json["success"], true);

        let ctx_before = json["context_before"].as_array().unwrap();
        let ctx_after = json["context_after"].as_array().unwrap();
        // With context_lines=1, we get at most 1 line before and after
        assert!(
            ctx_before.len() <= 1,
            "context_before should be ≤1: {:?}",
            ctx_before
        );
        assert!(
            ctx_after.len() <= 1,
            "context_after should be ≤1: {:?}",
            ctx_after
        );
    }

    #[test]
    fn zoom_missing_file_param() {
        let json = zoom_json(&request_from(serde_json::json!({
            "id": "z-5",
            "command": "zoom",
            "symbol": "foo",
        })));
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
    }

    #[test]
    fn zoom_missing_symbol_param() {
        let path = fixture_path("calls.ts");
        let json = zoom_json(&request_from(serde_json::json!({
            "id": "z-6",
            "command": "zoom",
            "file": path.to_str().unwrap(),
        })));
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
    }

    // --- Line-range mode tests ---

    #[test]
    fn zoom_line_range_returns_requested_lines() {
        let path = fixture_path("calls.ts");
        let source = std::fs::read_to_string(&path).unwrap();
        let expected = source.lines().take(2).collect::<Vec<_>>().join("\n");

        let json = zoom_json(&make_range_request(
            "z-7",
            path.to_str().unwrap(),
            Some(1),
            Some(2),
        ));
        assert_eq!(json["success"], true, "zoom should succeed: {:?}", json);
        assert_eq!(json["kind"], "lines");
        assert_eq!(json["content"].as_str(), Some(expected.as_str()));
        // Nothing precedes line 1, and context_lines is 0 anyway.
        assert!(json["context_before"].as_array().unwrap().is_empty());
        assert!(json["context_after"].as_array().unwrap().is_empty());
        assert!(json["annotations"]["calls_out"]
            .as_array()
            .unwrap()
            .is_empty());
    }

    #[test]
    fn zoom_line_range_requires_both_bounds() {
        let path = fixture_path("calls.ts");
        let json = zoom_json(&make_range_request(
            "z-8",
            path.to_str().unwrap(),
            Some(1),
            None,
        ));
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
    }

    #[test]
    fn zoom_line_range_rejects_zero_based_lines() {
        let path = fixture_path("calls.ts");
        let json = zoom_json(&make_range_request(
            "z-9",
            path.to_str().unwrap(),
            Some(0),
            Some(1),
        ));
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
    }

    #[test]
    fn zoom_line_range_rejects_symbol_combined_with_range() {
        let path = fixture_path("calls.ts");
        let json = zoom_json(&request_from(serde_json::json!({
            "id": "z-10",
            "command": "zoom",
            "file": path.to_str().unwrap(),
            "symbol": "compute",
            "start_line": 1,
            "end_line": 2,
        })));
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "invalid_request");
    }
}