Skip to main content

roder_context/
code_index.rs

1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use roder_api::code_index::{CodeIndexSearchRequest, CodeIndexStats, CodeIndexStatus};
5use roder_api::context::{
6    ContextBlock, ContextBlockKind, ContextProvider, ContextProviderId, ContextQuery,
7};
8use roder_code_index::sqlite::SqliteCodeIndexStore;
9use serde::Serialize;
10use serde_json::json;
11
/// Upper bound on the number of hits requested from the code index for one query.
const MAX_RESULTS: usize = 5;
/// Byte budget for the rendered context block text; `truncate_block` cuts anything longer.
const MAX_BLOCK_BYTES: usize = 4 * 1024;
/// Largest number of source bytes `bounded_snippet` quotes for a single hit.
const MAX_SNIPPET_BYTES: usize = 480;
15
16#[derive(Clone)]
17pub struct CodeIndexContextProvider {
18    workspace: PathBuf,
19    store: Arc<SqliteCodeIndexStore>,
20}
21
22impl CodeIndexContextProvider {
23    pub fn new(workspace: impl Into<PathBuf>, store: Arc<SqliteCodeIndexStore>) -> Self {
24        Self {
25            workspace: workspace.into(),
26            store,
27        }
28    }
29}
30
31#[async_trait::async_trait]
32impl ContextProvider for CodeIndexContextProvider {
33    fn id(&self) -> ContextProviderId {
34        "code-index-context-provider".to_string()
35    }
36
37    async fn blocks(&self, query: &ContextQuery) -> anyhow::Result<Vec<ContextBlock>> {
38        if query
39            .workspace
40            .as_deref()
41            .is_some_and(|workspace| Path::new(workspace) != self.workspace)
42        {
43            return Ok(Vec::new());
44        }
45        let status = self.store.status(&self.workspace)?;
46        if status.status != CodeIndexStatus::Ready {
47            return Ok(Vec::new());
48        }
49
50        let response = self.store.search(CodeIndexSearchRequest {
51            query_id: format!("{}:{}", query.thread_id, query.turn_id),
52            query: query.prompt.clone(),
53            workspace_root: self.workspace.clone(),
54            limit: MAX_RESULTS,
55        })?;
56        if response.results.is_empty() {
57            return Ok(Vec::new());
58        }
59
60        let mut rows = Vec::new();
61        for result in response.results {
62            let snippet = bounded_snippet(&self.workspace, &result.chunk).unwrap_or_default();
63            rows.push(RenderedCodeResult {
64                path: result.chunk.path.to_string_lossy().replace('\\', "/"),
65                start_line: result.chunk.line_range.start,
66                end_line: result.chunk.line_range.end,
67                score: result.score,
68                chunk_hash: result.chunk.chunk_hash,
69                proof_verified: result.proof_verified,
70                snippet,
71            });
72        }
73
74        let text = render_block_text(&rows);
75        let text = truncate_block(text);
76        Ok(vec![ContextBlock {
77            id: "code-index-context-provider".to_string(),
78            kind: ContextBlockKind::RetrievedDocument,
79            text,
80            priority: 86,
81            token_estimate: None,
82            metadata: json!({
83                "provider": "code-index-context-provider",
84                "source": "indexed_semantic_code_search",
85                "query": query.prompt,
86                "result_count": rows.len(),
87                "proof_filtered_drop_count": response.dropped_results.len(),
88                "stale_index_fallback": false,
89                "index_status": "ready",
90                "generation_id": response.generation.id,
91                "root_hash": response.generation.root_hash,
92                "stats": stats_metadata(&response.generation.stats),
93                "results": rows,
94            }),
95        }])
96    }
97}
98
99#[derive(Debug, Clone, Serialize, PartialEq)]
100#[serde(rename_all = "camelCase")]
101struct RenderedCodeResult {
102    path: String,
103    start_line: u32,
104    end_line: u32,
105    score: f32,
106    chunk_hash: String,
107    proof_verified: bool,
108    snippet: String,
109}
110
111fn render_block_text(results: &[RenderedCodeResult]) -> String {
112    let mut text = String::from("Indexed semantic code context:");
113    for result in results {
114        text.push_str(&format!(
115            "\n- {}:{}-{} score {:.2} proof {} chunk {}",
116            result.path,
117            result.start_line,
118            result.end_line,
119            result.score,
120            if result.proof_verified {
121                "verified"
122            } else {
123                "unverified"
124            },
125            &result.chunk_hash[..12.min(result.chunk_hash.len())]
126        ));
127        if !result.snippet.is_empty() {
128            text.push_str("\n  ```\n");
129            text.push_str(&result.snippet);
130            if !result.snippet.ends_with('\n') {
131                text.push('\n');
132            }
133            text.push_str("  ```");
134        }
135    }
136    text
137}
138
139fn bounded_snippet(
140    workspace: &Path,
141    chunk: &roder_api::code_index::CodeChunk,
142) -> anyhow::Result<String> {
143    let workspace = std::fs::canonicalize(workspace)?;
144    let path = std::fs::canonicalize(workspace.join(&chunk.path))?;
145    if !path.starts_with(&workspace) {
146        return Ok(String::new());
147    }
148    let text = std::fs::read_to_string(path)?;
149    let start = chunk.byte_range.start as usize;
150    let end = chunk
151        .byte_range
152        .end
153        .min(start as u64 + MAX_SNIPPET_BYTES as u64) as usize;
154    if start >= text.len() || end > text.len() || start >= end {
155        return Ok(String::new());
156    }
157    let mut start = start;
158    while start < text.len() && !text.is_char_boundary(start) {
159        start += 1;
160    }
161    let mut end = end;
162    while end > start && !text.is_char_boundary(end) {
163        end -= 1;
164    }
165    Ok(text[start..end].to_string())
166}
167
168fn truncate_block(mut text: String) -> String {
169    if text.len() <= MAX_BLOCK_BYTES {
170        return text;
171    }
172    text.truncate(MAX_BLOCK_BYTES);
173    while !text.is_char_boundary(text.len()) {
174        text.pop();
175    }
176    text.push_str("\n...");
177    text
178}
179
180fn stats_metadata(stats: &CodeIndexStats) -> serde_json::Value {
181    json!({
182        "fileCount": stats.file_count,
183        "chunkCount": stats.chunk_count,
184        "embeddedChunkCount": stats.embedded_chunk_count,
185        "cachedEmbeddingCount": stats.cached_embedding_count,
186        "indexBytes": stats.index_bytes,
187    })
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use roder_api::context::ContextQuery;

    /// A rebuilt index yields exactly one block whose text and metadata describe the hit.
    #[tokio::test]
    async fn code_index_provider_returns_bounded_verified_snippets() {
        let root = test_workspace("provider-bounded-snippets");
        write(
            &root,
            "src/auth.rs",
            "pub fn oauth_refresh_token() {\n    let token = \"refresh\";\n}\n",
        );
        let store = Arc::new(SqliteCodeIndexStore::open(index_db_path(&root)).unwrap());
        store.rebuild_workspace(&root).unwrap();
        let provider = CodeIndexContextProvider::new(root.clone(), store);

        let blocks = provider
            .blocks(&query("oauth refresh token"))
            .await
            .unwrap();

        assert_eq!(blocks.len(), 1);
        let block = &blocks[0];
        assert_eq!(block.kind, ContextBlockKind::RetrievedDocument);
        assert!(block.text.contains("src/auth.rs:"));
        assert!(block.text.contains("proof verified"));
        // The truncation marker ("\n...") may push the text four bytes past the budget.
        assert!(block.text.len() <= MAX_BLOCK_BYTES + 4);
        assert_eq!(block.metadata["source"], "indexed_semantic_code_search");
        assert!(block.metadata["result_count"].as_u64().unwrap() >= 1);
        assert_eq!(block.metadata["results"][0]["proofVerified"], true);

        // Release the store handle before deleting its database file.
        drop(provider);
        cleanup(&root);
    }

    /// Without a built index the provider returns no blocks rather than an error.
    #[tokio::test]
    async fn code_index_provider_degrades_to_empty_when_index_missing() {
        let root = test_workspace("provider-missing-index");
        let store = Arc::new(SqliteCodeIndexStore::open(index_db_path(&root)).unwrap());
        let provider = CodeIndexContextProvider::new(root.clone(), store);

        let blocks = provider.blocks(&query("anything")).await.unwrap();

        assert!(blocks.is_empty());

        drop(provider);
        cleanup(&root);
    }

    /// Builds a query with fixed thread/turn ids and no workspace restriction.
    fn query(prompt: &str) -> ContextQuery {
        ContextQuery {
            thread_id: "thread-a".to_string(),
            turn_id: "turn-a".to_string(),
            prompt: prompt.to_string(),
            workspace: None,
            token_budget: None,
        }
    }

    /// Writes `text` to `root/path`, creating parent directories as needed.
    fn write(root: &Path, path: &str, text: &str) {
        let path = root.join(path);
        std::fs::create_dir_all(path.parent().unwrap()).unwrap();
        std::fs::write(path, text).unwrap();
    }

    /// Creates a fresh, uniquely named workspace directory under the system temp dir.
    fn test_workspace(name: &str) -> PathBuf {
        let stamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let path = std::env::temp_dir().join(format!("roder-code-context-{name}-{stamp}"));
        let _ = std::fs::remove_dir_all(&path);
        std::fs::create_dir_all(&path).unwrap();
        path
    }

    /// Location of the index database that belongs to the workspace at `root`.
    fn index_db_path(root: &Path) -> PathBuf {
        root.with_extension("sqlite3")
    }

    /// Best-effort removal of the workspace directory and its index database.
    ///
    /// The database sits next to the workspace directory, not inside it, so deleting the
    /// directory alone would leave a file behind in the temp dir on every run.
    fn cleanup(root: &Path) {
        let _ = std::fs::remove_dir_all(root);
        let db = index_db_path(root);
        // Also sweep the sidecar files SQLite may create beside a database, in case the
        // store uses a rollback journal or WAL mode.
        for suffix in ["", "-wal", "-shm", "-journal"] {
            let mut file = db.clone().into_os_string();
            file.push(suffix);
            let _ = std::fs::remove_file(file);
        }
    }
}
263}