Skip to main content

brainwires_rag/rag/client/
code_analysis.rs

1//! Code-navigation methods for [`RagClient`]: find definition, references, call graph.
2//!
3//! All items in this file are gated on the `code-analysis` feature.
4
5use super::RagClient;
6use crate::code_analysis::{DefinitionResult, ReferenceResult, RelationsProvider};
7use crate::rag::types::*;
8use anyhow::{Context, Result};
9use std::time::Instant;
10
11impl RagClient {
12    /// Find the definition of a symbol at a given file location
13    ///
14    /// This method looks up the symbol at the specified location and returns
15    /// its definition information if found.
16    ///
17    /// # Arguments
18    ///
19    /// * `request` - The find definition request containing file path, line, and column
20    ///
21    /// # Returns
22    ///
23    /// A response containing the definition if found, along with precision info
24    pub async fn find_definition(
25        &self,
26        request: FindDefinitionRequest,
27    ) -> Result<FindDefinitionResponse> {
28        let start = Instant::now();
29
30        // Validate request
31        request.validate().map_err(|e| anyhow::anyhow!(e))?;
32
33        // Create FileInfo for the file
34        let file_info = self.create_file_info(&request.file_path, request.project.clone())?;
35
36        // Get precision level for this language
37        let language = file_info.language.as_deref().unwrap_or("Unknown");
38        let precision = self.relations_provider.precision_level(language);
39
40        // Extract definitions from the file
41        let definitions = self
42            .relations_provider
43            .extract_definitions(&file_info)
44            .context("Failed to extract definitions")?;
45
46        // Find the definition at the requested position
47        let definition = definitions.into_iter().find(|def| {
48            request.line >= def.symbol_id.start_line
49                && request.line <= def.end_line
50                && (request.column == 0 || request.column >= def.symbol_id.start_col)
51        });
52
53        let result = definition.map(|def| DefinitionResult::from(&def));
54
55        Ok(FindDefinitionResponse {
56            definition: result,
57            precision: format!("{:?}", precision).to_lowercase(),
58            duration_ms: start.elapsed().as_millis() as u64,
59        })
60    }
61
62    /// Find all references to a symbol at a given file location
63    ///
64    /// This method finds all locations where the symbol at the given position
65    /// is referenced throughout the indexed codebase.
66    ///
67    /// # Arguments
68    ///
69    /// * `request` - The find references request containing file path, line, column, and limit
70    ///
71    /// # Returns
72    ///
73    /// A response containing the list of references found
74    pub async fn find_references(
75        &self,
76        request: FindReferencesRequest,
77    ) -> Result<FindReferencesResponse> {
78        let start = Instant::now();
79
80        // Validate request
81        request.validate().map_err(|e| anyhow::anyhow!(e))?;
82
83        // Create FileInfo for the file
84        let file_info = self.create_file_info(&request.file_path, request.project.clone())?;
85
86        // Get precision level for this language
87        let language = file_info.language.as_deref().unwrap_or("Unknown");
88        let precision = self.relations_provider.precision_level(language);
89
90        // Extract definitions from the file to find the symbol at the position
91        let definitions = self
92            .relations_provider
93            .extract_definitions(&file_info)
94            .context("Failed to extract definitions")?;
95
96        // Find the symbol at the requested position
97        let target_symbol = definitions.iter().find(|def| {
98            request.line >= def.symbol_id.start_line
99                && request.line <= def.end_line
100                && (request.column == 0 || request.column >= def.symbol_id.start_col)
101        });
102
103        let symbol_name = target_symbol.map(|def| def.symbol_id.name.clone());
104
105        // If no symbol found at position, return empty result
106        if symbol_name.is_none() {
107            return Ok(FindReferencesResponse {
108                symbol_name: None,
109                references: Vec::new(),
110                total_count: 0,
111                precision: format!("{:?}", precision).to_lowercase(),
112                duration_ms: start.elapsed().as_millis() as u64,
113            });
114        }
115
116        let symbol_name_str = symbol_name
117            .as_ref()
118            .expect("checked is_none above and returned early");
119
120        // Build symbol index from definitions
121        let mut symbol_index: std::collections::HashMap<
122            String,
123            Vec<crate::code_analysis::Definition>,
124        > = std::collections::HashMap::new();
125        for def in definitions {
126            symbol_index
127                .entry(def.symbol_id.name.clone())
128                .or_default()
129                .push(def);
130        }
131
132        // Find references in the same file
133        let references = self
134            .relations_provider
135            .extract_references(&file_info, &symbol_index)
136            .context("Failed to extract references")?;
137
138        // Filter to references matching our target symbol
139        let matching_refs: Vec<ReferenceResult> = references
140            .iter()
141            .filter(|r| {
142                // Check if this reference points to our target symbol
143                r.target_symbol_id.contains(symbol_name_str.as_str())
144            })
145            .take(request.limit)
146            .map(ReferenceResult::from)
147            .collect();
148
149        let total_count = matching_refs.len();
150
151        Ok(FindReferencesResponse {
152            symbol_name,
153            references: matching_refs,
154            total_count,
155            precision: format!("{:?}", precision).to_lowercase(),
156            duration_ms: start.elapsed().as_millis() as u64,
157        })
158    }
159
160    /// Get the call graph for a function at a given file location
161    ///
162    /// This method returns the callers (incoming calls) and callees (outgoing calls)
163    /// for the function at the specified location.
164    ///
165    /// # Arguments
166    ///
167    /// * `request` - The call graph request containing file path, line, column, and depth
168    ///
169    /// # Returns
170    ///
171    /// A response containing the root symbol and its call graph
172    pub async fn get_call_graph(
173        &self,
174        request: GetCallGraphRequest,
175    ) -> Result<GetCallGraphResponse> {
176        let start = Instant::now();
177
178        // Validate request
179        request.validate().map_err(|e| anyhow::anyhow!(e))?;
180
181        // Create FileInfo for the file
182        let file_info = self.create_file_info(&request.file_path, request.project.clone())?;
183
184        // Get precision level for this language
185        let language = file_info.language.as_deref().unwrap_or("Unknown");
186        let precision = self.relations_provider.precision_level(language);
187
188        // Extract definitions from the file to find the function at the position
189        let definitions = self
190            .relations_provider
191            .extract_definitions(&file_info)
192            .context("Failed to extract definitions")?;
193
194        // Find the function at the requested position
195        let target_function = definitions.iter().find(|def| {
196            // Only consider functions/methods
197            matches!(
198                def.symbol_id.kind,
199                crate::code_analysis::SymbolKind::Function
200                    | crate::code_analysis::SymbolKind::Method
201            ) && request.line >= def.symbol_id.start_line
202                && request.line <= def.end_line
203                && (request.column == 0 || request.column >= def.symbol_id.start_col)
204        });
205
206        // If no function found at position, return empty result
207        let root_symbol = match target_function {
208            Some(func) => crate::code_analysis::SymbolInfo {
209                name: func.symbol_id.name.clone(),
210                kind: func.symbol_id.kind,
211                file_path: request.file_path.clone(),
212                start_line: func.symbol_id.start_line,
213                end_line: func.end_line,
214                signature: func.signature.clone(),
215            },
216            None => {
217                return Ok(GetCallGraphResponse {
218                    root_symbol: None,
219                    callers: Vec::new(),
220                    callees: Vec::new(),
221                    precision: format!("{:?}", precision).to_lowercase(),
222                    duration_ms: start.elapsed().as_millis() as u64,
223                });
224            }
225        };
226
227        let function_name = root_symbol.name.clone();
228
229        // Build symbol index from definitions
230        let mut symbol_index: std::collections::HashMap<
231            String,
232            Vec<crate::code_analysis::Definition>,
233        > = std::collections::HashMap::new();
234        for def in &definitions {
235            symbol_index
236                .entry(def.symbol_id.name.clone())
237                .or_default()
238                .push(def.clone());
239        }
240
241        // Find references in the same file to identify callers
242        let references = self
243            .relations_provider
244            .extract_references(&file_info, &symbol_index)
245            .context("Failed to extract references")?;
246
247        // Find callers (references with Call kind pointing to our function)
248        let mut seen_callers = std::collections::HashSet::new();
249        let callers: Vec<crate::code_analysis::CallGraphNode> = references
250            .iter()
251            .filter(|r| {
252                r.reference_kind == crate::code_analysis::ReferenceKind::Call
253                    && r.target_symbol_id.contains(&function_name)
254            })
255            .filter_map(|r| {
256                // Try to find which function contains this call
257                definitions.iter().find(|def| {
258                    matches!(
259                        def.symbol_id.kind,
260                        crate::code_analysis::SymbolKind::Function
261                            | crate::code_analysis::SymbolKind::Method
262                    ) && r.start_line >= def.symbol_id.start_line
263                        && r.start_line <= def.end_line
264                })
265            })
266            .filter(|def| seen_callers.insert(def.symbol_id.name.clone()))
267            .map(|def| crate::code_analysis::CallGraphNode {
268                name: def.symbol_id.name.clone(),
269                kind: def.symbol_id.kind,
270                file_path: request.file_path.clone(),
271                line: def.symbol_id.start_line,
272                children: Vec::new(),
273            })
274            .collect();
275
276        // Find callees (calls made from within our function)
277        let target_func = target_function.expect("early return on None above guarantees Some");
278        let mut seen_callees = std::collections::HashSet::new();
279        let callees: Vec<crate::code_analysis::CallGraphNode> = references
280            .iter()
281            .filter(|r| {
282                r.reference_kind == crate::code_analysis::ReferenceKind::Call
283                    && r.start_line >= target_func.symbol_id.start_line
284                    && r.start_line <= target_func.end_line
285            })
286            .filter_map(|r| {
287                // Extract the called function name from target_symbol_id
288                let parts: Vec<&str> = r.target_symbol_id.split(':').collect();
289                if parts.len() >= 2 {
290                    Some(parts[1].to_string())
291                } else {
292                    None
293                }
294            })
295            .filter(|name| seen_callees.insert(name.clone()))
296            .filter_map(|name| {
297                // Find the definition of the called function
298                symbol_index
299                    .get(&name)
300                    .and_then(|defs| defs.first())
301                    .cloned()
302            })
303            .map(|def| crate::code_analysis::CallGraphNode {
304                name: def.symbol_id.name.clone(),
305                kind: def.symbol_id.kind,
306                file_path: request.file_path.clone(),
307                line: def.symbol_id.start_line,
308                children: Vec::new(),
309            })
310            .collect();
311
312        Ok(GetCallGraphResponse {
313            root_symbol: Some(root_symbol),
314            callers,
315            callees,
316            precision: format!("{:?}", precision).to_lowercase(),
317            duration_ms: start.elapsed().as_millis() as u64,
318        })
319    }
320}