1use tonic::{Response, Status};
2use tracing::info;
3
4use crate::server::ProtocolServer;
5use crate::{
6 CallEdgeRef, ContextDepth, ContextRequest, ContextResponse, SymbolRef, SymbolResult,
7};
8
9pub async fn handle_context(
22 server: &ProtocolServer,
23 req: ContextRequest,
24) -> Result<Response<ContextResponse>, Status> {
25 let session = server.validate_session(&req.session_id)?;
27
28 let sid = req
30 .session_id
31 .parse::<uuid::Uuid>()
32 .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
33 server.session_mgr().touch_session(&sid);
34
35 let depth = req.depth();
39 let include_source = depth == ContextDepth::Full || depth == ContextDepth::CallGraph;
40 let include_call_graph = depth == ContextDepth::CallGraph;
41
42 let max_results = if req.max_tokens > 0 {
43 ((req.max_tokens / 100) as usize).max(10)
44 } else {
45 50
46 };
47
48 let engine = server.engine();
49
50 let (symbol_results, call_edges) = {
51 let (repo_id, git_repo) = engine
52 .get_repo(&session.codebase)
53 .await
54 .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
55
56 let symbols = engine
57 .query_symbols(repo_id, &req.query, max_results)
58 .await
59 .map_err(|e| Status::internal(format!("Query error: {e}")))?;
60
61 let maybe_ws = engine.workspace_manager().get_workspace(&sid);
65
66 let mut results = Vec::with_capacity(symbols.len());
67 let mut edges = Vec::new();
68
69 for sym in &symbols {
70 let mut result = SymbolResult {
71 symbol: Some(symbol_to_ref(sym)),
72 source: None,
73 caller_ids: vec![],
74 callee_ids: vec![],
75 test_symbol_ids: vec![],
76 };
77
78 if include_source {
80 let source_bytes = if let Some(ref ws) = maybe_ws {
81 ws.read_file(
83 &sym.file_path.to_string_lossy(),
84 &git_repo,
85 )
86 .ok()
87 .map(|r| r.content)
88 } else {
89 let file_path = git_repo.path().join(&sym.file_path);
91 std::fs::read(&file_path).ok()
92 };
93
94 if let Some(source) = source_bytes {
95 let start = sym.span.start_byte as usize;
96 let end = sym.span.end_byte as usize;
97 if end <= source.len() {
98 result.source = Some(
99 String::from_utf8_lossy(&source[start..end]).to_string(),
100 );
101 }
102 }
103 }
104
105 if include_call_graph {
107 if let Ok((callers, callees)) = engine.get_call_graph(repo_id, sym.id).await {
109 result.caller_ids = callers.iter().map(|s| s.id.to_string()).collect();
110 result.callee_ids = callees.iter().map(|s| s.id.to_string()).collect();
111
112 for caller in &callers {
113 edges.push(CallEdgeRef {
114 caller_id: caller.id.to_string(),
115 callee_id: sym.id.to_string(),
116 kind: "direct_call".to_string(),
117 });
118 }
119 }
120 }
121
122 results.push(result);
123 }
124
125 (results, edges)
126 };
127
128 let total_chars: usize = symbol_results
130 .iter()
131 .map(|r| {
132 let sym_size = r
133 .symbol
134 .as_ref()
135 .map(|s| s.name.len() + s.signature.len())
136 .unwrap_or(0);
137 let source_size = r.source.as_ref().map(|s| s.len()).unwrap_or(0);
138 sym_size + source_size
139 })
140 .sum();
141 let mut estimated_tokens = (total_chars / 4) as u32;
142
143 let mut symbol_results = symbol_results;
145 if req.max_tokens > 0 && estimated_tokens > req.max_tokens {
146 let mut remaining = req.max_tokens;
147
148 for result in &mut symbol_results {
149 let sym_tokens = result
150 .symbol
151 .as_ref()
152 .map(|s| ((s.name.len() + s.signature.len()) / 4) as u32)
153 .unwrap_or(0);
154
155 if remaining < sym_tokens {
156 result.source = None;
158 continue;
159 }
160 remaining -= sym_tokens;
161
162 if let Some(ref source) = result.source {
163 let source_tokens = (source.len() / 4) as u32;
164 if remaining < source_tokens {
165 let max_chars = (remaining as usize) * 4;
166 result.source = Some(source[..max_chars.min(source.len())].to_string());
167 remaining = 0;
168 } else {
169 remaining -= source_tokens;
170 }
171 }
172 }
173
174 estimated_tokens = req.max_tokens - remaining;
175 }
176
177 info!(
178 session_id = %req.session_id,
179 query = %req.query,
180 results = symbol_results.len(),
181 estimated_tokens,
182 "CONTEXT: query served"
183 );
184
185 Ok(Response::new(ContextResponse {
186 symbols: symbol_results,
187 call_graph: call_edges,
188 dependencies: if req.include_dependencies {
189 let (repo_id, _git_repo) = engine
190 .get_repo(&session.codebase)
191 .await
192 .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
193
194 let deps = engine
195 .dep_store()
196 .find_by_repo(repo_id)
197 .await
198 .unwrap_or_default();
199
200 let mut dep_refs = Vec::with_capacity(deps.len());
201 for dep in &deps {
202 let symbol_ids = engine
203 .dep_store()
204 .find_symbols_for_dep(dep.id)
205 .await
206 .unwrap_or_default();
207
208 dep_refs.push(crate::DependencyRef {
209 package: dep.package.clone(),
210 version_req: dep.version_req.clone(),
211 used_by_symbol_ids: symbol_ids.iter().map(|id| id.to_string()).collect(),
212 });
213 }
214 dep_refs
215 } else {
216 vec![]
217 },
218 estimated_tokens,
219 }))
220}
221
222fn symbol_to_ref(sym: &dk_core::Symbol) -> SymbolRef {
224 SymbolRef {
225 id: sym.id.to_string(),
226 name: sym.name.clone(),
227 qualified_name: sym.qualified_name.clone(),
228 kind: sym.kind.to_string(),
229 visibility: format!("{:?}", sym.visibility),
230 file_path: sym.file_path.to_string_lossy().to_string(),
231 start_byte: sym.span.start_byte,
232 end_byte: sym.span.end_byte,
233 signature: sym.signature.clone().unwrap_or_default(),
234 doc_comment: sym.doc_comment.clone(),
235 parent_id: sym.parent.map(|p| p.to_string()),
236 }
237}