Skip to main content

aft/commands/
zoom.rs

1use std::path::Path;
2
3use serde::Serialize;
4
5use crate::context::AppContext;
6use crate::edit::line_col_to_byte;
7use crate::lsp_hints;
8use crate::parser::{FileParser, LangId};
9use crate::protocol::{RawRequest, Response};
10use crate::symbols::Range;
11
12/// A reference to a called/calling function.
13#[derive(Debug, Clone, Serialize)]
14pub struct CallRef {
15    pub name: String,
16    /// 1-based line number of the call reference.
17    pub line: u32,
18}
19
20/// Annotations describing file-scoped call relationships.
21#[derive(Debug, Clone, Serialize)]
22pub struct Annotations {
23    pub calls_out: Vec<CallRef>,
24    pub called_by: Vec<CallRef>,
25}
26
27/// Response payload for the zoom command.
28#[derive(Debug, Clone, Serialize)]
29pub struct ZoomResponse {
30    pub name: String,
31    pub kind: String,
32    pub range: Range,
33    pub content: String,
34    pub context_before: Vec<String>,
35    pub context_after: Vec<String>,
36    pub annotations: Annotations,
37}
38
39/// Handle a `zoom` request.
40///
41/// Expects `file`, `symbol` in request params, optional `context_lines` (default 3).
42/// Resolves the symbol, extracts body + context, walks AST for call annotations.
43pub fn handle_zoom(req: &RawRequest, ctx: &AppContext) -> Response {
44    let file = match req.params.get("file").and_then(|v| v.as_str()) {
45        Some(f) => f,
46        None => {
47            return Response::error(
48                &req.id,
49                "invalid_request",
50                "zoom: missing required param 'file'",
51            );
52        }
53    };
54
55    let context_lines = req
56        .params
57        .get("context_lines")
58        .and_then(|v| v.as_u64())
59        .unwrap_or(3) as usize;
60
61    let start_line = req
62        .params
63        .get("start_line")
64        .and_then(|v| v.as_u64())
65        .map(|v| v as usize);
66    let end_line = req
67        .params
68        .get("end_line")
69        .and_then(|v| v.as_u64())
70        .map(|v| v as usize);
71
72    let path = match ctx.validate_path(&req.id, Path::new(file)) {
73        Ok(path) => path,
74        Err(resp) => return resp,
75    };
76    if !path.exists() {
77        return Response::error(
78            &req.id,
79            "file_not_found",
80            format!("file not found: {}", file),
81        );
82    }
83
84    // Read source file early because both symbol mode and line-range mode need it.
85    let source = match std::fs::read_to_string(&path) {
86        Ok(s) => s,
87        Err(e) => {
88            return Response::error(&req.id, "file_not_found", format!("{}: {}", file, e));
89        }
90    };
91
92    let lines: Vec<String> = source.lines().map(|l| l.to_string()).collect();
93
94    // Line-range mode: read arbitrary lines without requiring a symbol.
95    match (start_line, end_line) {
96        (Some(start), Some(end)) => {
97            if req.params.get("symbol").is_some() {
98                return Response::error(
99                    &req.id,
100                    "invalid_request",
101                    "zoom: provide either 'symbol' OR ('start_line' and 'end_line'), not both",
102                );
103            }
104            if start == 0 || end == 0 {
105                return Response::error(
106                    &req.id,
107                    "invalid_request",
108                    "zoom: 'start_line' and 'end_line' are 1-based and must be >= 1",
109                );
110            }
111            if end < start {
112                return Response::error(
113                    &req.id,
114                    "invalid_request",
115                    format!("zoom: end_line {} must be >= start_line {}", end, start),
116                );
117            }
118            if lines.is_empty() {
119                return Response::error(
120                    &req.id,
121                    "invalid_request",
122                    format!("zoom: {} is empty", file),
123                );
124            }
125
126            let start_idx = start - 1;
127            // Clamp end_line to file length (same as batch edits)
128            let end_idx = (end - 1).min(lines.len() - 1);
129            if start_idx >= lines.len() {
130                return Response::error(
131                    &req.id,
132                    "invalid_request",
133                    format!(
134                        "zoom: start_line {} is past end of {} ({} lines)",
135                        start,
136                        file,
137                        lines.len()
138                    ),
139                );
140            }
141
142            let content = lines[start_idx..=end_idx].join("\n");
143            let ctx_start = start_idx.saturating_sub(context_lines);
144            let context_before: Vec<String> = if ctx_start < start_idx {
145                lines[ctx_start..start_idx]
146                    .iter()
147                    .map(|l| l.to_string())
148                    .collect()
149            } else {
150                vec![]
151            };
152            let ctx_end = (end_idx + 1 + context_lines).min(lines.len());
153            let context_after: Vec<String> = if end_idx + 1 < lines.len() {
154                lines[(end_idx + 1)..ctx_end]
155                    .iter()
156                    .map(|l| l.to_string())
157                    .collect()
158            } else {
159                vec![]
160            };
161            let end_col = lines[end_idx].chars().count() as u32;
162
163            return Response::success(
164                &req.id,
165                serde_json::json!({
166                    "name": format!("lines {}-{}", start, end),
167                    "kind": "lines",
168                    "range": {
169                        "start_line": start,  // already 1-based from user input
170                        "start_col": 1,
171                        "end_line": end,      // already 1-based from user input
172                        "end_col": end_col + 1,
173                    },
174                    "content": content,
175                    "context_before": context_before,
176                    "context_after": context_after,
177                    "annotations": {
178                        "calls_out": [],
179                        "called_by": [],
180                    },
181                }),
182            );
183        }
184        (Some(_), None) | (None, Some(_)) => {
185            return Response::error(
186                &req.id,
187                "invalid_request",
188                "zoom: provide both 'start_line' and 'end_line' for line-range mode",
189            );
190        }
191        (None, None) => {}
192    }
193
194    let symbol_name = match req.params.get("symbol").and_then(|v| v.as_str()) {
195        Some(s) => s,
196        None => {
197            return Response::error(
198                &req.id,
199                "invalid_request",
200                "zoom: missing required param 'symbol' (or use 'start_line' and 'end_line')",
201            );
202        }
203    };
204
205    // Resolve the target symbol
206    let matches = match ctx.provider().resolve_symbol(&path, symbol_name) {
207        Ok(m) => m,
208        Err(e) => {
209            return Response::error(&req.id, e.code(), e.to_string());
210        }
211    };
212
213    // LSP-enhanced disambiguation (S03)
214    let matches = if let Some(hints) = lsp_hints::parse_lsp_hints(req) {
215        lsp_hints::apply_lsp_disambiguation(matches, &hints)
216    } else {
217        matches
218    };
219
220    if matches.len() > 1 {
221        // Ambiguous — return qualified candidates
222        let candidates: Vec<String> = matches
223            .iter()
224            .map(|m| {
225                let sym = &m.symbol;
226                if sym.scope_chain.is_empty() {
227                    format!("{}:{}", sym.name, sym.range.start_line)
228                } else {
229                    format!(
230                        "{}::{}:{}",
231                        sym.scope_chain.join("::"),
232                        sym.name,
233                        sym.range.start_line
234                    )
235                }
236            })
237            .collect();
238        return Response::error(
239            &req.id,
240            "ambiguous_symbol",
241            format!(
242                "symbol '{}' is ambiguous, candidates: [{}]",
243                symbol_name,
244                candidates.join(", ")
245            ),
246        );
247    }
248
249    let target = &matches[0].symbol;
250    let start = target.range.start_line as usize;
251    let end = target.range.end_line as usize;
252
253    // When re-export following resolved to a different file, re-read that file's lines
254    let resolved_file_path = std::path::Path::new(&matches[0].file);
255    let resolved_lines: Vec<String>;
256    let effective_lines: &[String] = if resolved_file_path != path {
257        resolved_lines = match std::fs::read_to_string(resolved_file_path) {
258            Ok(src) => src.lines().map(|l| l.to_string()).collect(),
259            Err(_) => lines.clone(),
260        };
261        &resolved_lines
262    } else {
263        &lines
264    };
265
266    // Extract symbol body (0-based line indices)
267    let content = if end < effective_lines.len() {
268        effective_lines[start..=end].join("\n")
269    } else {
270        effective_lines[start..].join("\n")
271    };
272
273    // Context before
274    let ctx_start = start.saturating_sub(context_lines);
275    let context_before: Vec<String> = if ctx_start < start {
276        effective_lines[ctx_start..start]
277            .iter()
278            .map(|l| l.to_string())
279            .collect()
280    } else {
281        vec![]
282    };
283
284    // Context after
285    let ctx_end = (end + 1 + context_lines).min(effective_lines.len());
286    let context_after: Vec<String> = if end + 1 < effective_lines.len() {
287        effective_lines[(end + 1)..ctx_end]
288            .iter()
289            .map(|l| l.to_string())
290            .collect()
291    } else {
292        vec![]
293    };
294
295    // Get all symbols in the resolved file for call matching
296    let all_symbols = match ctx.provider().list_symbols(resolved_file_path) {
297        Ok(s) => s,
298        Err(e) => {
299            return Response::error(&req.id, e.code(), e.to_string());
300        }
301    };
302
303    let known_names: Vec<&str> = all_symbols.iter().map(|s| s.name.as_str()).collect();
304
305    // Parse AST for call extraction (use resolved file for cross-file re-exports)
306    let mut parser = FileParser::new();
307    let (tree, lang) = match parser.parse(resolved_file_path) {
308        Ok(r) => r,
309        Err(e) => {
310            return Response::error(&req.id, e.code(), e.to_string());
311        }
312    };
313
314    // calls_out: calls within the target symbol's byte range
315    let resolved_source = if resolved_file_path != path {
316        std::fs::read_to_string(resolved_file_path).unwrap_or_else(|_| source.clone())
317    } else {
318        source.clone()
319    };
320    let target_byte_start = line_col_to_byte(
321        &resolved_source,
322        target.range.start_line,
323        target.range.start_col,
324    );
325    let target_byte_end = line_col_to_byte(
326        &resolved_source,
327        target.range.end_line,
328        target.range.end_col,
329    );
330
331    let raw_calls = extract_calls_in_range(
332        &resolved_source,
333        tree.root_node(),
334        target_byte_start,
335        target_byte_end,
336        lang,
337    );
338    let calls_out: Vec<CallRef> = raw_calls
339        .into_iter()
340        .filter(|(name, _)| known_names.contains(&name.as_str()) && *name != target.name)
341        .map(|(name, line)| CallRef { name, line })
342        .collect();
343
344    // called_by: scan all other symbols for calls to this symbol
345    let mut called_by: Vec<CallRef> = Vec::new();
346    for sym in &all_symbols {
347        if sym.name == target.name && sym.range.start_line == target.range.start_line {
348            continue; // skip self
349        }
350        let sym_byte_start =
351            line_col_to_byte(&resolved_source, sym.range.start_line, sym.range.start_col);
352        let sym_byte_end =
353            line_col_to_byte(&resolved_source, sym.range.end_line, sym.range.end_col);
354        let calls = extract_calls_in_range(
355            &resolved_source,
356            tree.root_node(),
357            sym_byte_start,
358            sym_byte_end,
359            lang,
360        );
361        for (name, line) in calls {
362            if name == target.name {
363                called_by.push(CallRef {
364                    name: sym.name.clone(),
365                    line,
366                });
367            }
368        }
369    }
370
371    // Dedup called_by by (name, line)
372    called_by.sort_by(|a, b| a.name.cmp(&b.name).then(a.line.cmp(&b.line)));
373    called_by.dedup_by(|a, b| a.name == b.name && a.line == b.line);
374
375    let kind_str = serde_json::to_value(&target.kind)
376        .ok()
377        .and_then(|v| v.as_str().map(String::from))
378        .unwrap_or_else(|| format!("{:?}", target.kind).to_lowercase());
379
380    let resp = ZoomResponse {
381        name: target.name.clone(),
382        kind: kind_str,
383        range: target.range.clone(),
384        content,
385        context_before,
386        context_after,
387        annotations: Annotations {
388            calls_out,
389            called_by,
390        },
391    };
392
393    match serde_json::to_value(&resp) {
394        Ok(resp_json) => Response::success(&req.id, resp_json),
395        Err(err) => Response::error(
396            &req.id,
397            "internal_error",
398            format!("zoom: failed to serialize response: {err}"),
399        ),
400    }
401}
402
403/// Extract call expression names within a byte range of the AST.
404///
405/// Delegates to `crate::calls::extract_calls_in_range`.
406fn extract_calls_in_range(
407    source: &str,
408    root: tree_sitter::Node,
409    byte_start: usize,
410    byte_end: usize,
411    lang: LangId,
412) -> Vec<(String, u32)> {
413    crate::calls::extract_calls_in_range(source, root, byte_start, byte_end, lang)
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::context::AppContext;
    use crate::parser::TreeSitterProvider;
    use std::path::PathBuf;

    // --- Helpers ---

    fn fixture_path(name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join("fixtures")
            .join(name)
    }

    fn make_ctx() -> AppContext {
        AppContext::new(Box::new(TreeSitterProvider::new()), Config::default())
    }

    /// Raw `(name, line)` calls extracted from the body of `symbol` in `calls.ts`.
    fn calls_in_symbol(symbol: &str) -> Vec<(String, u32)> {
        let path = fixture_path("calls.ts");
        let source = std::fs::read_to_string(&path).unwrap();
        let mut parser = FileParser::new();
        let (tree, lang) = parser.parse(&path).unwrap();

        let ctx = make_ctx();
        let symbols = ctx.provider().list_symbols(&path).unwrap();
        let sym = symbols
            .iter()
            .find(|s| s.name == symbol)
            .unwrap_or_else(|| panic!("fixture symbol '{symbol}' not found"));

        let byte_start = line_col_to_byte(&source, sym.range.start_line, sym.range.start_col);
        let byte_end = line_col_to_byte(&source, sym.range.end_line, sym.range.end_col);
        extract_calls_in_range(&source, tree.root_node(), byte_start, byte_end, lang)
    }

    fn call_names(calls: &[(String, u32)]) -> Vec<&str> {
        calls.iter().map(|(n, _)| n.as_str()).collect()
    }

    /// `(start_line, end_line)` of `symbol` in `calls.ts`, as stored in the symbol table.
    fn fixture_symbol_lines(symbol: &str) -> (usize, usize) {
        let ctx = make_ctx();
        let symbols = ctx
            .provider()
            .list_symbols(&fixture_path("calls.ts"))
            .unwrap();
        let sym = symbols
            .iter()
            .find(|s| s.name == symbol)
            .unwrap_or_else(|| panic!("fixture symbol '{symbol}' not found"));
        (sym.range.start_line as usize, sym.range.end_line as usize)
    }

    fn make_zoom_request(
        id: &str,
        file: &str,
        symbol: &str,
        context_lines: Option<u64>,
    ) -> RawRequest {
        let mut json = serde_json::json!({
            "id": id,
            "command": "zoom",
            "file": file,
            "symbol": symbol,
        });
        if let Some(cl) = context_lines {
            json["context_lines"] = serde_json::json!(cl);
        }
        serde_json::from_value(json).unwrap()
    }

    /// Zoom on `symbol` in `calls.ts` and return the serialized response.
    fn zoom_symbol(id: &str, symbol: &str, context_lines: Option<u64>) -> serde_json::Value {
        let path = fixture_path("calls.ts");
        let req = make_zoom_request(id, path.to_str().unwrap(), symbol, context_lines);
        serde_json::to_value(&handle_zoom(&req, &make_ctx())).unwrap()
    }

    /// Run `handle_zoom` on an arbitrary request and return the serialized response.
    /// Building requests with `json!` keeps paths correctly escaped on every platform.
    fn zoom_raw(request: serde_json::Value) -> serde_json::Value {
        let req: RawRequest = serde_json::from_value(request).unwrap();
        serde_json::to_value(&handle_zoom(&req, &make_ctx())).unwrap()
    }

    fn annotation_names<'a>(json: &'a serde_json::Value, key: &str) -> Vec<&'a str> {
        json["annotations"][key]
            .as_array()
            .unwrap_or_else(|| panic!("{key} array"))
            .iter()
            .map(|c| c["name"].as_str().unwrap())
            .collect()
    }

    fn assert_invalid_request(json: &serde_json::Value) {
        assert_eq!(json["success"], false, "expected failure: {:?}", json);
        assert_eq!(json["code"], "invalid_request", "wrong code: {:?}", json);
    }

    // --- Call extraction tests ---

    #[test]
    fn extract_calls_finds_direct_calls() {
        // `compute` calls `helper`
        let calls = calls_in_symbol("compute");
        let names = call_names(&calls);
        assert!(
            names.contains(&"helper"),
            "compute should call helper, got: {:?}",
            names
        );
    }

    #[test]
    fn extract_calls_finds_member_calls() {
        let calls = calls_in_symbol("runAll");
        let names = call_names(&calls);
        assert!(
            names.contains(&"add"),
            "runAll should call this.add, got: {:?}",
            names
        );
        assert!(
            names.contains(&"helper"),
            "runAll should call helper, got: {:?}",
            names
        );
    }

    #[test]
    fn extract_calls_unused_function_has_no_calls() {
        let calls = calls_in_symbol("unused");
        // console.log is the only call, but "log" or "console" aren't known symbols
        let known_names = [
            "helper",
            "compute",
            "orchestrate",
            "unused",
            "format",
            "display",
        ];
        let filtered: Vec<&str> = call_names(&calls)
            .into_iter()
            .filter(|n| known_names.contains(n))
            .collect();
        assert!(
            filtered.is_empty(),
            "unused should not call known symbols, got: {:?}",
            filtered
        );
    }

    // --- Context line tests (through the handler) ---

    #[test]
    fn context_lines_clamp_at_file_start() {
        // helper() is at the top of the file, so fewer than 5 lines precede it:
        // context_before must be clamped rather than panic.
        let (start, _) = fixture_symbol_lines("helper");
        let json = zoom_symbol("z-clamp-start", "helper", Some(5));
        assert_eq!(json["success"], true, "zoom should succeed: {:?}", json);

        let before = json["context_before"].as_array().unwrap();
        assert_eq!(before.len(), start.min(5), "context_before: {:?}", before);
    }

    #[test]
    fn context_lines_clamp_at_file_end() {
        // Zoom on the last line with more context than the file has after it.
        let path = fixture_path("calls.ts");
        let line_count = std::fs::read_to_string(&path).unwrap().lines().count();
        let json = zoom_raw(serde_json::json!({
            "id": "z-clamp-end",
            "command": "zoom",
            "file": path.to_str().unwrap(),
            "start_line": line_count,
            "end_line": line_count,
            "context_lines": 20,
        }));
        assert_eq!(json["success"], true, "zoom should succeed: {:?}", json);

        let before = json["context_before"].as_array().unwrap();
        let after = json["context_after"].as_array().unwrap();
        assert!(after.is_empty(), "nothing follows the last line: {:?}", after);
        assert_eq!(before.len(), (line_count - 1).min(20));
    }

    // --- Body extraction test ---

    #[test]
    fn body_extraction_matches_source() {
        let json = zoom_symbol("z-body", "compute", None);
        assert_eq!(json["success"], true, "zoom should succeed: {:?}", json);

        let body = json["content"].as_str().expect("content string");
        assert!(
            body.contains("function compute"),
            "body should contain function declaration"
        );
        assert!(
            body.contains("helper(a)"),
            "body should contain call to helper"
        );
        assert!(
            body.contains("doubled + b"),
            "body should contain return expression"
        );
    }

    // --- Full zoom response tests ---

    #[test]
    fn zoom_response_has_calls_out_and_called_by() {
        let json = zoom_symbol("z-1", "compute", None);
        assert_eq!(json["success"], true, "zoom should succeed: {:?}", json);

        let out_names = annotation_names(&json, "calls_out");
        assert!(
            out_names.contains(&"helper"),
            "compute calls helper: {:?}",
            out_names
        );

        let by_names = annotation_names(&json, "called_by");
        assert!(
            by_names.contains(&"orchestrate"),
            "orchestrate calls compute: {:?}",
            by_names
        );
    }

    #[test]
    fn zoom_response_empty_annotations_for_unused() {
        let json = zoom_symbol("z-2", "unused", None);
        assert_eq!(json["success"], true);

        // calls_out exists (may contain console.log but no known symbols)
        assert!(json["annotations"]["calls_out"].is_array());
        // called_by should be empty — nobody calls unused
        let called_by = json["annotations"]["called_by"].as_array().unwrap();
        assert!(
            called_by.is_empty(),
            "unused should not be called by anyone: {:?}",
            called_by
        );
    }

    #[test]
    fn zoom_symbol_not_found() {
        let json = zoom_symbol("z-3", "nonexistent", None);
        assert_eq!(json["success"], false);
        assert_eq!(json["code"], "symbol_not_found");
    }

    #[test]
    fn zoom_custom_context_lines() {
        let json = zoom_symbol("z-4", "compute", Some(1));
        assert_eq!(json["success"], true);

        let ctx_before = json["context_before"].as_array().unwrap();
        let ctx_after = json["context_after"].as_array().unwrap();
        // With context_lines=1, we get at most 1 line before and after
        assert!(
            ctx_before.len() <= 1,
            "context_before should be ≤1: {:?}",
            ctx_before
        );
        assert!(
            ctx_after.len() <= 1,
            "context_after should be ≤1: {:?}",
            ctx_after
        );
    }

    #[test]
    fn zoom_missing_file_param() {
        let json = zoom_raw(serde_json::json!({
            "id": "z-5",
            "command": "zoom",
            "symbol": "foo",
        }));
        assert_invalid_request(&json);
    }

    #[test]
    fn zoom_missing_symbol_param() {
        let path = fixture_path("calls.ts");
        let json = zoom_raw(serde_json::json!({
            "id": "z-6",
            "command": "zoom",
            "file": path.to_str().unwrap(),
        }));
        assert_invalid_request(&json);
    }

    // --- Line-range mode tests ---

    #[test]
    fn zoom_line_range_returns_requested_lines() {
        let path = fixture_path("calls.ts");
        let source = std::fs::read_to_string(&path).unwrap();
        let expected: Vec<&str> = source.lines().take(2).collect();

        let json = zoom_raw(serde_json::json!({
            "id": "z-7",
            "command": "zoom",
            "file": path.to_str().unwrap(),
            "start_line": 1,
            "end_line": 2,
            "context_lines": 0,
        }));
        assert_eq!(json["success"], true, "zoom should succeed: {:?}", json);
        assert_eq!(json["kind"], "lines");
        assert_eq!(json["content"], expected.join("\n"));
        assert_eq!(json["range"]["start_line"], 1);
        assert!(json["context_before"].as_array().unwrap().is_empty());
        assert!(json["context_after"].as_array().unwrap().is_empty());
    }

    #[test]
    fn zoom_line_range_rejects_invalid_bounds() {
        let path = fixture_path("calls.ts");
        let file = path.to_str().unwrap();
        // Zero is not a valid 1-based line; end before start is inverted.
        for (id, start, end) in [("z-8", 0, 2), ("z-9", 3, 2)] {
            let json = zoom_raw(serde_json::json!({
                "id": id,
                "command": "zoom",
                "file": file,
                "start_line": start,
                "end_line": end,
            }));
            assert_invalid_request(&json);
        }
    }

    #[test]
    fn zoom_line_range_requires_both_bounds() {
        let path = fixture_path("calls.ts");
        let json = zoom_raw(serde_json::json!({
            "id": "z-10",
            "command": "zoom",
            "file": path.to_str().unwrap(),
            "start_line": 1,
        }));
        assert_invalid_request(&json);
    }

    #[test]
    fn zoom_line_range_rejects_symbol_with_lines() {
        let path = fixture_path("calls.ts");
        let json = zoom_raw(serde_json::json!({
            "id": "z-11",
            "command": "zoom",
            "file": path.to_str().unwrap(),
            "symbol": "compute",
            "start_line": 1,
            "end_line": 2,
        }));
        assert_invalid_request(&json);
    }
}