1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use roder_api::code_index::{CodeIndexSearchRequest, CodeIndexStats, CodeIndexStatus};
5use roder_api::context::{
6 ContextBlock, ContextBlockKind, ContextProvider, ContextProviderId, ContextQuery,
7};
8use roder_code_index::sqlite::SqliteCodeIndexStore;
9use serde::Serialize;
10use serde_json::json;
11
/// Maximum number of search hits requested from the code index per query.
const MAX_RESULTS: usize = 5;
/// Byte length above which the rendered context block text is truncated.
const MAX_BLOCK_BYTES: usize = 4 * 1024;
/// Maximum number of source bytes read from a file for a single result snippet.
const MAX_SNIPPET_BYTES: usize = 480;
15
16#[derive(Clone)]
17pub struct CodeIndexContextProvider {
18 workspace: PathBuf,
19 store: Arc<SqliteCodeIndexStore>,
20}
21
22impl CodeIndexContextProvider {
23 pub fn new(workspace: impl Into<PathBuf>, store: Arc<SqliteCodeIndexStore>) -> Self {
24 Self {
25 workspace: workspace.into(),
26 store,
27 }
28 }
29}
30
31#[async_trait::async_trait]
32impl ContextProvider for CodeIndexContextProvider {
33 fn id(&self) -> ContextProviderId {
34 "code-index-context-provider".to_string()
35 }
36
37 async fn blocks(&self, query: &ContextQuery) -> anyhow::Result<Vec<ContextBlock>> {
38 if query
39 .workspace
40 .as_deref()
41 .is_some_and(|workspace| Path::new(workspace) != self.workspace)
42 {
43 return Ok(Vec::new());
44 }
45 let status = self.store.status(&self.workspace)?;
46 if status.status != CodeIndexStatus::Ready {
47 return Ok(Vec::new());
48 }
49
50 let response = self.store.search(CodeIndexSearchRequest {
51 query_id: format!("{}:{}", query.thread_id, query.turn_id),
52 query: query.prompt.clone(),
53 workspace_root: self.workspace.clone(),
54 limit: MAX_RESULTS,
55 })?;
56 if response.results.is_empty() {
57 return Ok(Vec::new());
58 }
59
60 let mut rows = Vec::new();
61 for result in response.results {
62 let snippet = bounded_snippet(&self.workspace, &result.chunk).unwrap_or_default();
63 rows.push(RenderedCodeResult {
64 path: result.chunk.path.to_string_lossy().replace('\\', "/"),
65 start_line: result.chunk.line_range.start,
66 end_line: result.chunk.line_range.end,
67 score: result.score,
68 chunk_hash: result.chunk.chunk_hash,
69 proof_verified: result.proof_verified,
70 snippet,
71 });
72 }
73
74 let text = render_block_text(&rows);
75 let text = truncate_block(text);
76 Ok(vec![ContextBlock {
77 id: "code-index-context-provider".to_string(),
78 kind: ContextBlockKind::RetrievedDocument,
79 text,
80 priority: 86,
81 token_estimate: None,
82 metadata: json!({
83 "provider": "code-index-context-provider",
84 "source": "indexed_semantic_code_search",
85 "query": query.prompt,
86 "result_count": rows.len(),
87 "proof_filtered_drop_count": response.dropped_results.len(),
88 "stale_index_fallback": false,
89 "index_status": "ready",
90 "generation_id": response.generation.id,
91 "root_hash": response.generation.root_hash,
92 "stats": stats_metadata(&response.generation.stats),
93 "results": rows,
94 }),
95 }])
96 }
97}
98
/// One search hit prepared for output: rendered into the block text and
/// serialized (with camelCase keys) into the block metadata under `results`.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct RenderedCodeResult {
    /// Chunk path as a lossy string with backslashes replaced by `/`.
    path: String,
    /// Start of the chunk's line range as reported by the index.
    start_line: u32,
    /// End of the chunk's line range as reported by the index.
    end_line: u32,
    /// Score assigned to this hit by the search.
    score: f32,
    /// Hash of the chunk; only a short prefix is shown in the rendered text.
    chunk_hash: String,
    /// Whether the search reported this hit as proof-verified.
    proof_verified: bool,
    /// Bounded source excerpt; empty when the snippet could not be read.
    snippet: String,
}
110
111fn render_block_text(results: &[RenderedCodeResult]) -> String {
112 let mut text = String::from("Indexed semantic code context:");
113 for result in results {
114 text.push_str(&format!(
115 "\n- {}:{}-{} score {:.2} proof {} chunk {}",
116 result.path,
117 result.start_line,
118 result.end_line,
119 result.score,
120 if result.proof_verified {
121 "verified"
122 } else {
123 "unverified"
124 },
125 &result.chunk_hash[..12.min(result.chunk_hash.len())]
126 ));
127 if !result.snippet.is_empty() {
128 text.push_str("\n ```\n");
129 text.push_str(&result.snippet);
130 if !result.snippet.ends_with('\n') {
131 text.push('\n');
132 }
133 text.push_str(" ```");
134 }
135 }
136 text
137}
138
139fn bounded_snippet(
140 workspace: &Path,
141 chunk: &roder_api::code_index::CodeChunk,
142) -> anyhow::Result<String> {
143 let workspace = std::fs::canonicalize(workspace)?;
144 let path = std::fs::canonicalize(workspace.join(&chunk.path))?;
145 if !path.starts_with(&workspace) {
146 return Ok(String::new());
147 }
148 let text = std::fs::read_to_string(path)?;
149 let start = chunk.byte_range.start as usize;
150 let end = chunk
151 .byte_range
152 .end
153 .min(start as u64 + MAX_SNIPPET_BYTES as u64) as usize;
154 if start >= text.len() || end > text.len() || start >= end {
155 return Ok(String::new());
156 }
157 let mut start = start;
158 while start < text.len() && !text.is_char_boundary(start) {
159 start += 1;
160 }
161 let mut end = end;
162 while end > start && !text.is_char_boundary(end) {
163 end -= 1;
164 }
165 Ok(text[start..end].to_string())
166}
167
168fn truncate_block(mut text: String) -> String {
169 if text.len() <= MAX_BLOCK_BYTES {
170 return text;
171 }
172 text.truncate(MAX_BLOCK_BYTES);
173 while !text.is_char_boundary(text.len()) {
174 text.pop();
175 }
176 text.push_str("\n...");
177 text
178}
179
180fn stats_metadata(stats: &CodeIndexStats) -> serde_json::Value {
181 json!({
182 "fileCount": stats.file_count,
183 "chunkCount": stats.chunk_count,
184 "embeddedChunkCount": stats.embedded_chunk_count,
185 "cachedEmbeddingCount": stats.cached_embedding_count,
186 "indexBytes": stats.index_bytes,
187 })
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use roder_api::context::ContextQuery;

    /// An indexed workspace yields exactly one bounded block whose first
    /// result is proof-verified.
    #[tokio::test]
    async fn code_index_provider_returns_bounded_verified_snippets() {
        let root = test_workspace("provider-bounded-snippets");
        write(
            &root,
            "src/auth.rs",
            "pub fn oauth_refresh_token() {\n let token = \"refresh\";\n}\n",
        );
        let store = Arc::new(SqliteCodeIndexStore::open(index_path(&root)).unwrap());
        store.rebuild_workspace(&root).unwrap();
        let provider = CodeIndexContextProvider::new(root.clone(), store);

        let blocks = provider
            .blocks(&query("oauth refresh token"))
            .await
            .unwrap();

        assert_eq!(blocks.len(), 1);
        let block = &blocks[0];
        assert_eq!(block.kind, ContextBlockKind::RetrievedDocument);
        assert!(block.text.contains("src/auth.rs:"));
        assert!(block.text.contains("proof verified"));
        // Truncation may append the 4-byte "\n..." marker past the limit.
        assert!(block.text.len() <= MAX_BLOCK_BYTES + 4);
        assert_eq!(block.metadata["source"], "indexed_semantic_code_search");
        assert!(block.metadata["result_count"].as_u64().unwrap() >= 1);
        assert_eq!(block.metadata["results"][0]["proofVerified"], true);

        // Release the store before deleting its database file.
        drop(provider);
        cleanup(&root);
    }

    /// A store that was never built for the workspace produces no blocks
    /// rather than an error.
    #[tokio::test]
    async fn code_index_provider_degrades_to_empty_when_index_missing() {
        let root = test_workspace("provider-missing-index");
        let store = Arc::new(SqliteCodeIndexStore::open(index_path(&root)).unwrap());
        let provider = CodeIndexContextProvider::new(root.clone(), store);

        let blocks = provider.blocks(&query("anything")).await.unwrap();

        assert!(blocks.is_empty());

        // Release the store before deleting its database file.
        drop(provider);
        cleanup(&root);
    }

    /// Builds a query for `prompt` with fixed ids and no workspace filter.
    fn query(prompt: &str) -> ContextQuery {
        ContextQuery {
            thread_id: "thread-a".to_string(),
            turn_id: "turn-a".to_string(),
            prompt: prompt.to_string(),
            workspace: None,
            token_budget: None,
        }
    }

    /// Writes `text` to `path` below `root`, creating parent directories.
    fn write(root: &Path, path: &str, text: &str) {
        let path = root.join(path);
        std::fs::create_dir_all(path.parent().unwrap()).unwrap();
        std::fs::write(path, text).unwrap();
    }

    /// Creates a fresh, uniquely named workspace directory under the system
    /// temp dir.
    fn test_workspace(name: &str) -> PathBuf {
        let stamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let path = std::env::temp_dir().join(format!("roder-code-context-{name}-{stamp}"));
        let _ = std::fs::remove_dir_all(&path);
        std::fs::create_dir_all(&path).unwrap();
        path
    }

    /// Location of the SQLite index file that sits next to the workspace
    /// directory.
    fn index_path(root: &Path) -> PathBuf {
        root.with_extension("sqlite3")
    }

    /// Best-effort removal of the workspace directory and its index file, so
    /// test runs do not accumulate files in the temp dir.
    fn cleanup(root: &Path) {
        let _ = std::fs::remove_dir_all(root);
        let _ = std::fs::remove_file(index_path(root));
    }
}