brainwires_rag/rag/client/
code_analysis.rs1use super::RagClient;
6use crate::code_analysis::{DefinitionResult, ReferenceResult, RelationsProvider};
7use crate::rag::types::*;
8use anyhow::{Context, Result};
9use std::time::Instant;
10
11impl RagClient {
12 pub async fn find_definition(
25 &self,
26 request: FindDefinitionRequest,
27 ) -> Result<FindDefinitionResponse> {
28 let start = Instant::now();
29
30 request.validate().map_err(|e| anyhow::anyhow!(e))?;
32
33 let file_info = self.create_file_info(&request.file_path, request.project.clone())?;
35
36 let language = file_info.language.as_deref().unwrap_or("Unknown");
38 let precision = self.relations_provider.precision_level(language);
39
40 let definitions = self
42 .relations_provider
43 .extract_definitions(&file_info)
44 .context("Failed to extract definitions")?;
45
46 let definition = definitions.into_iter().find(|def| {
48 request.line >= def.symbol_id.start_line
49 && request.line <= def.end_line
50 && (request.column == 0 || request.column >= def.symbol_id.start_col)
51 });
52
53 let result = definition.map(|def| DefinitionResult::from(&def));
54
55 Ok(FindDefinitionResponse {
56 definition: result,
57 precision: format!("{:?}", precision).to_lowercase(),
58 duration_ms: start.elapsed().as_millis() as u64,
59 })
60 }
61
62 pub async fn find_references(
75 &self,
76 request: FindReferencesRequest,
77 ) -> Result<FindReferencesResponse> {
78 let start = Instant::now();
79
80 request.validate().map_err(|e| anyhow::anyhow!(e))?;
82
83 let file_info = self.create_file_info(&request.file_path, request.project.clone())?;
85
86 let language = file_info.language.as_deref().unwrap_or("Unknown");
88 let precision = self.relations_provider.precision_level(language);
89
90 let definitions = self
92 .relations_provider
93 .extract_definitions(&file_info)
94 .context("Failed to extract definitions")?;
95
96 let target_symbol = definitions.iter().find(|def| {
98 request.line >= def.symbol_id.start_line
99 && request.line <= def.end_line
100 && (request.column == 0 || request.column >= def.symbol_id.start_col)
101 });
102
103 let symbol_name = target_symbol.map(|def| def.symbol_id.name.clone());
104
105 if symbol_name.is_none() {
107 return Ok(FindReferencesResponse {
108 symbol_name: None,
109 references: Vec::new(),
110 total_count: 0,
111 precision: format!("{:?}", precision).to_lowercase(),
112 duration_ms: start.elapsed().as_millis() as u64,
113 });
114 }
115
116 let symbol_name_str = symbol_name
117 .as_ref()
118 .expect("checked is_none above and returned early");
119
120 let mut symbol_index: std::collections::HashMap<
122 String,
123 Vec<crate::code_analysis::Definition>,
124 > = std::collections::HashMap::new();
125 for def in definitions {
126 symbol_index
127 .entry(def.symbol_id.name.clone())
128 .or_default()
129 .push(def);
130 }
131
132 let references = self
134 .relations_provider
135 .extract_references(&file_info, &symbol_index)
136 .context("Failed to extract references")?;
137
138 let matching_refs: Vec<ReferenceResult> = references
140 .iter()
141 .filter(|r| {
142 r.target_symbol_id.contains(symbol_name_str.as_str())
144 })
145 .take(request.limit)
146 .map(ReferenceResult::from)
147 .collect();
148
149 let total_count = matching_refs.len();
150
151 Ok(FindReferencesResponse {
152 symbol_name,
153 references: matching_refs,
154 total_count,
155 precision: format!("{:?}", precision).to_lowercase(),
156 duration_ms: start.elapsed().as_millis() as u64,
157 })
158 }
159
160 pub async fn get_call_graph(
173 &self,
174 request: GetCallGraphRequest,
175 ) -> Result<GetCallGraphResponse> {
176 let start = Instant::now();
177
178 request.validate().map_err(|e| anyhow::anyhow!(e))?;
180
181 let file_info = self.create_file_info(&request.file_path, request.project.clone())?;
183
184 let language = file_info.language.as_deref().unwrap_or("Unknown");
186 let precision = self.relations_provider.precision_level(language);
187
188 let definitions = self
190 .relations_provider
191 .extract_definitions(&file_info)
192 .context("Failed to extract definitions")?;
193
194 let target_function = definitions.iter().find(|def| {
196 matches!(
198 def.symbol_id.kind,
199 crate::code_analysis::SymbolKind::Function
200 | crate::code_analysis::SymbolKind::Method
201 ) && request.line >= def.symbol_id.start_line
202 && request.line <= def.end_line
203 && (request.column == 0 || request.column >= def.symbol_id.start_col)
204 });
205
206 let root_symbol = match target_function {
208 Some(func) => crate::code_analysis::SymbolInfo {
209 name: func.symbol_id.name.clone(),
210 kind: func.symbol_id.kind,
211 file_path: request.file_path.clone(),
212 start_line: func.symbol_id.start_line,
213 end_line: func.end_line,
214 signature: func.signature.clone(),
215 },
216 None => {
217 return Ok(GetCallGraphResponse {
218 root_symbol: None,
219 callers: Vec::new(),
220 callees: Vec::new(),
221 precision: format!("{:?}", precision).to_lowercase(),
222 duration_ms: start.elapsed().as_millis() as u64,
223 });
224 }
225 };
226
227 let function_name = root_symbol.name.clone();
228
229 let mut symbol_index: std::collections::HashMap<
231 String,
232 Vec<crate::code_analysis::Definition>,
233 > = std::collections::HashMap::new();
234 for def in &definitions {
235 symbol_index
236 .entry(def.symbol_id.name.clone())
237 .or_default()
238 .push(def.clone());
239 }
240
241 let references = self
243 .relations_provider
244 .extract_references(&file_info, &symbol_index)
245 .context("Failed to extract references")?;
246
247 let mut seen_callers = std::collections::HashSet::new();
249 let callers: Vec<crate::code_analysis::CallGraphNode> = references
250 .iter()
251 .filter(|r| {
252 r.reference_kind == crate::code_analysis::ReferenceKind::Call
253 && r.target_symbol_id.contains(&function_name)
254 })
255 .filter_map(|r| {
256 definitions.iter().find(|def| {
258 matches!(
259 def.symbol_id.kind,
260 crate::code_analysis::SymbolKind::Function
261 | crate::code_analysis::SymbolKind::Method
262 ) && r.start_line >= def.symbol_id.start_line
263 && r.start_line <= def.end_line
264 })
265 })
266 .filter(|def| seen_callers.insert(def.symbol_id.name.clone()))
267 .map(|def| crate::code_analysis::CallGraphNode {
268 name: def.symbol_id.name.clone(),
269 kind: def.symbol_id.kind,
270 file_path: request.file_path.clone(),
271 line: def.symbol_id.start_line,
272 children: Vec::new(),
273 })
274 .collect();
275
276 let target_func = target_function.expect("early return on None above guarantees Some");
278 let mut seen_callees = std::collections::HashSet::new();
279 let callees: Vec<crate::code_analysis::CallGraphNode> = references
280 .iter()
281 .filter(|r| {
282 r.reference_kind == crate::code_analysis::ReferenceKind::Call
283 && r.start_line >= target_func.symbol_id.start_line
284 && r.start_line <= target_func.end_line
285 })
286 .filter_map(|r| {
287 let parts: Vec<&str> = r.target_symbol_id.split(':').collect();
289 if parts.len() >= 2 {
290 Some(parts[1].to_string())
291 } else {
292 None
293 }
294 })
295 .filter(|name| seen_callees.insert(name.clone()))
296 .filter_map(|name| {
297 symbol_index
299 .get(&name)
300 .and_then(|defs| defs.first())
301 .cloned()
302 })
303 .map(|def| crate::code_analysis::CallGraphNode {
304 name: def.symbol_id.name.clone(),
305 kind: def.symbol_id.kind,
306 file_path: request.file_path.clone(),
307 line: def.symbol_id.start_line,
308 children: Vec::new(),
309 })
310 .collect();
311
312 Ok(GetCallGraphResponse {
313 root_symbol: Some(root_symbol),
314 callers,
315 callees,
316 precision: format!("{:?}", precision).to_lowercase(),
317 duration_ms: start.elapsed().as_millis() as u64,
318 })
319 }
320}