Skip to main content

atomcode_core/tool/
blast_radius.rs

1use std::collections::HashSet;
2use std::path::Path;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::Deserialize;
7use serde_json::json;
8
9use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};
10
/// Tool that estimates the blast radius of changing a file: its direct
/// dependents, its indirect dependents, and the total impacted file count.
///
/// The type carries no state, so it is cheap to copy and can be created with
/// [`Default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct BlastRadiusTool;
12
13#[derive(Deserialize)]
14struct BlastRadiusArgs {
15    file: String,
16}
17
/// Renders `path` compactly for display.
///
/// Paths with at most three components are returned as-is (via
/// [`Path::display`]). Longer paths are reduced to their final three
/// components, joined with `/` and prefixed with `.../`.
fn shorten_path(path: &Path) -> String {
    // Number of trailing components kept when a path is abbreviated.
    const KEPT: usize = 3;

    let total = path.components().count();
    if total <= KEPT {
        return path.display().to_string();
    }

    let tail: Vec<_> = path
        .components()
        .skip(total - KEPT)
        .map(|component| component.as_os_str().to_string_lossy())
        .collect();
    format!(".../{}", tail.join("/"))
}
36
37#[async_trait]
38impl Tool for BlastRadiusTool {
39    fn definition(&self) -> ToolDef {
40        ToolDef {
41            name: "blast_radius",
42            description: "Estimate the blast radius of changing a file. Shows direct dependents \
43                (depth 1), indirect dependents (depth 2-3), and total impacted file count.\n\
44                Use before refactoring to understand the scope of changes.\n\
45                Example: {\"file\": \"src/tool/mod.rs\"}"
46                .to_string(),
47            parameters: json!({
48                "type": "object",
49                "properties": {
50                    "file": { "type": "string", "description": "File path (relative to working dir or absolute)" }
51                },
52                "required": ["file"]
53            }),
54        }
55    }
56
57    fn approval(&self, _args: &str) -> ApprovalRequirement {
58        ApprovalRequirement::AutoApprove
59    }
60
61    fn approval_with_context(&self, args: &str, ctx: &ToolContext) -> ApprovalRequirement {
62        let parsed = match serde_json::from_str::<BlastRadiusArgs>(args) {
63            Ok(parsed) => parsed,
64            Err(_) => return self.approval(args),
65        };
66        let working_dir = match ctx.working_dir.try_read() {
67            Ok(wd) => wd.clone(),
68            Err(_) => return self.approval(args),
69        };
70        match super::approval_for_path(
71            &parsed.file,
72            &working_dir,
73            super::ExternalPathAction::Enumerate,
74        ) {
75            Ok(approval) => approval,
76            Err(_) => self.approval(args),
77        }
78    }
79
80    async fn execute(&self, args: &str, ctx: &ToolContext) -> Result<ToolResult> {
81        let parsed: BlastRadiusArgs = serde_json::from_str(args)?;
82        let wd = ctx.working_dir.read().await.clone();
83        let file_path = match super::inspect_path_access(&parsed.file, &wd) {
84            Ok(access) => access.path,
85            Err(err) => {
86                return Ok(ToolResult {
87                    call_id: String::new(),
88                    output: err.to_string(),
89                    success: false,
90                });
91            }
92        };
93
94        let graph = ctx.graph.read().await;
95
96        if !graph.is_ready() {
97            return Ok(ToolResult {
98                call_id: String::new(),
99                output: "Code graph is not yet indexed. The graph will be available after the \
100                    background indexer completes. Try again shortly."
101                    .to_string(),
102                success: false,
103            });
104        }
105
106        let symbols = match graph.symbols_in_file(&file_path) {
107            Some(ids) => ids.clone(),
108            None => {
109                return Ok(ToolResult {
110                    call_id: String::new(),
111                    output: format!(
112                        "File '{}' not found in code graph. Check the path or wait for indexing.",
113                        parsed.file
114                    ),
115                    success: false,
116                });
117            }
118        };
119
120        // Direct dependents (depth 1): files whose symbols directly call this file's symbols
121        let mut direct = HashSet::new();
122        for &sym_id in &symbols {
123            if let Some(edges) = graph.callers(sym_id) {
124                for edge in edges {
125                    if let Some(node) = graph.node(edge.to) {
126                        if node.file != file_path {
127                            direct.insert(node.file.clone());
128                        }
129                    }
130                }
131            }
132        }
133
134        // Indirect dependents (depth 2-3): use file_dependents with depth 3
135        let all_dependents = graph.file_dependents(&file_path, 3);
136        let mut indirect = HashSet::new();
137        for dep in &all_dependents {
138            if !direct.contains(dep) {
139                indirect.insert(dep.clone());
140            }
141        }
142
143        let total = direct.len() + indirect.len();
144
145        let mut out = format!("Blast radius for {}:\n\n", shorten_path(&file_path));
146
147        out.push_str(&format!("DIRECT DEPENDENTS ({} files):\n", direct.len()));
148        if direct.is_empty() {
149            out.push_str("  (none)\n");
150        } else {
151            let mut sorted: Vec<_> = direct.iter().collect();
152            sorted.sort();
153            for f in sorted {
154                out.push_str(&format!("  {}\n", shorten_path(f)));
155            }
156        }
157
158        out.push_str(&format!(
159            "\nINDIRECT DEPENDENTS ({} files):\n",
160            indirect.len()
161        ));
162        if indirect.is_empty() {
163            out.push_str("  (none)\n");
164        } else {
165            let mut sorted: Vec<_> = indirect.iter().collect();
166            sorted.sort();
167            for f in sorted {
168                out.push_str(&format!("  {}\n", shorten_path(f)));
169            }
170        }
171
172        out.push_str(&format!("\nTOTAL IMPACT: {} files\n", total));
173
174        Ok(ToolResult {
175            call_id: String::new(),
176            output: out,
177            success: true,
178        })
179    }
180}