Skip to main content

dk_protocol/
context.rs

1use tonic::{Response, Status};
2use tracing::info;
3
4use crate::server::ProtocolServer;
5use crate::{
6    CallEdgeRef, ContextDepth, ContextRequest, ContextResponse, SymbolRef, SymbolResult,
7};
8
9/// Handle a CONTEXT RPC.
10///
11/// 1. Validates the session (and keeps it alive).
12/// 2. Runs a full-text search for the given query.
13/// 3. Depending on depth:
14///    - `SIGNATURES` -- return symbol metadata only.
15///    - `FULL`       -- also include source code.
16///    - `CALL_GRAPH` -- include source code + caller/callee edges.
17/// 4. Estimates token usage and truncates if `max_tokens` is set.
18///
19/// File reads now go through the session workspace overlay, so agents
20/// see their own in-progress modifications reflected in CONTEXT results.
21pub async fn handle_context(
22    server: &ProtocolServer,
23    req: ContextRequest,
24) -> Result<Response<ContextResponse>, Status> {
25    // 1. Validate session
26    let session = server.validate_session(&req.session_id)?;
27
28    // Touch session to keep it alive
29    let sid = req
30        .session_id
31        .parse::<uuid::Uuid>()
32        .map_err(|_| Status::invalid_argument("Invalid session ID"))?;
33    server.session_mgr().touch_session(&sid);
34
35    // 2-3. Query symbols and build results.
36    //      `GitRepository` is !Sync, so all usage must stay within a single
37    //      block that does not hold it across `.await` points.
38    let depth = req.depth();
39    let include_source = depth == ContextDepth::Full || depth == ContextDepth::CallGraph;
40    let include_call_graph = depth == ContextDepth::CallGraph;
41
42    let max_results = if req.max_tokens > 0 {
43        ((req.max_tokens / 100) as usize).max(10)
44    } else {
45        50
46    };
47
48    let engine = server.engine();
49
50    let (symbol_results, call_edges) = {
51        let (repo_id, git_repo) = engine
52            .get_repo(&session.codebase)
53            .await
54            .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
55
56        let symbols = engine
57            .query_symbols(repo_id, &req.query, max_results)
58            .await
59            .map_err(|e| Status::internal(format!("Query error: {e}")))?;
60
61        // Try to get the workspace for session-aware reads.
62        // If no workspace exists (shouldn't happen after CONNECT, but be defensive),
63        // fall through to direct file reads.
64        let maybe_ws = engine.workspace_manager().get_workspace(&sid);
65
66        let mut results = Vec::with_capacity(symbols.len());
67        let mut edges = Vec::new();
68
69        for sym in &symbols {
70            let mut result = SymbolResult {
71                symbol: Some(symbol_to_ref(sym)),
72                source: None,
73                caller_ids: vec![],
74                callee_ids: vec![],
75                test_symbol_ids: vec![],
76            };
77
78            // FULL / CALL_GRAPH depth: include source code (symbol span only)
79            if include_source {
80                let source_bytes = if let Some(ref ws) = maybe_ws {
81                    // Read through workspace overlay (sees session modifications)
82                    ws.read_file(
83                        &sym.file_path.to_string_lossy(),
84                        &git_repo,
85                    )
86                    .ok()
87                    .map(|r| r.content)
88                } else {
89                    // Fallback: read directly from working directory
90                    let file_path = git_repo.path().join(&sym.file_path);
91                    std::fs::read(&file_path).ok()
92                };
93
94                if let Some(source) = source_bytes {
95                    let start = sym.span.start_byte as usize;
96                    let end = sym.span.end_byte as usize;
97                    if end <= source.len() {
98                        result.source = Some(
99                            String::from_utf8_lossy(&source[start..end]).to_string(),
100                        );
101                    }
102                }
103            }
104
105            // CALL_GRAPH depth: include callers/callees
106            if include_call_graph {
107                // Drop git_repo borrow concern: get_call_graph doesn't need it
108                if let Ok((callers, callees)) = engine.get_call_graph(repo_id, sym.id).await {
109                    result.caller_ids = callers.iter().map(|s| s.id.to_string()).collect();
110                    result.callee_ids = callees.iter().map(|s| s.id.to_string()).collect();
111
112                    for caller in &callers {
113                        edges.push(CallEdgeRef {
114                            caller_id: caller.id.to_string(),
115                            callee_id: sym.id.to_string(),
116                            kind: "direct_call".to_string(),
117                        });
118                    }
119                }
120            }
121
122            results.push(result);
123        }
124
125        (results, edges)
126    };
127
128    // 4. Estimate tokens (~4 chars per token, rough heuristic)
129    let total_chars: usize = symbol_results
130        .iter()
131        .map(|r| {
132            let sym_size = r
133                .symbol
134                .as_ref()
135                .map(|s| s.name.len() + s.signature.len())
136                .unwrap_or(0);
137            let source_size = r.source.as_ref().map(|s| s.len()).unwrap_or(0);
138            sym_size + source_size
139        })
140        .sum();
141    let mut estimated_tokens = (total_chars / 4) as u32;
142
143    // 5. Truncate if max_tokens is set
144    let mut symbol_results = symbol_results;
145    if req.max_tokens > 0 && estimated_tokens > req.max_tokens {
146        let mut remaining = req.max_tokens;
147
148        for result in &mut symbol_results {
149            let sym_tokens = result
150                .symbol
151                .as_ref()
152                .map(|s| ((s.name.len() + s.signature.len()) / 4) as u32)
153                .unwrap_or(0);
154
155            if remaining < sym_tokens {
156                // Can't fit even the symbol header -- drop source.
157                result.source = None;
158                continue;
159            }
160            remaining -= sym_tokens;
161
162            if let Some(ref source) = result.source {
163                let source_tokens = (source.len() / 4) as u32;
164                if remaining < source_tokens {
165                    let max_chars = (remaining as usize) * 4;
166                    result.source = Some(source[..max_chars.min(source.len())].to_string());
167                    remaining = 0;
168                } else {
169                    remaining -= source_tokens;
170                }
171            }
172        }
173
174        estimated_tokens = req.max_tokens - remaining;
175    }
176
177    info!(
178        session_id = %req.session_id,
179        query = %req.query,
180        results = symbol_results.len(),
181        estimated_tokens,
182        "CONTEXT: query served"
183    );
184
185    Ok(Response::new(ContextResponse {
186        symbols: symbol_results,
187        call_graph: call_edges,
188        dependencies: if req.include_dependencies {
189            let (repo_id, _git_repo) = engine
190                .get_repo(&session.codebase)
191                .await
192                .map_err(|e| Status::internal(format!("Repo error: {e}")))?;
193
194            let deps = engine
195                .dep_store()
196                .find_by_repo(repo_id)
197                .await
198                .unwrap_or_default();
199
200            let mut dep_refs = Vec::with_capacity(deps.len());
201            for dep in &deps {
202                let symbol_ids = engine
203                    .dep_store()
204                    .find_symbols_for_dep(dep.id)
205                    .await
206                    .unwrap_or_default();
207
208                dep_refs.push(crate::DependencyRef {
209                    package: dep.package.clone(),
210                    version_req: dep.version_req.clone(),
211                    used_by_symbol_ids: symbol_ids.iter().map(|id| id.to_string()).collect(),
212                });
213            }
214            dep_refs
215        } else {
216            vec![]
217        },
218        estimated_tokens,
219    }))
220}
221
222/// Convert a `dk_core::Symbol` into the protobuf `SymbolRef` message.
223fn symbol_to_ref(sym: &dk_core::Symbol) -> SymbolRef {
224    SymbolRef {
225        id: sym.id.to_string(),
226        name: sym.name.clone(),
227        qualified_name: sym.qualified_name.clone(),
228        kind: sym.kind.to_string(),
229        visibility: format!("{:?}", sym.visibility),
230        file_path: sym.file_path.to_string_lossy().to_string(),
231        start_byte: sym.span.start_byte,
232        end_byte: sym.span.end_byte,
233        signature: sym.signature.clone().unwrap_or_default(),
234        doc_comment: sym.doc_comment.clone(),
235        parent_id: sym.parent.map(|p| p.to_string()),
236    }
237}