// ucp_agent/rag.rs

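//! RAG (Retrieval-Augmented Generation) provider interface.
//!
//! Defines the [`RagProvider`] trait, the option/result/capability types it exchanges,
//! and two test doubles: [`NullRagProvider`] and [`MockRagProvider`].
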
3use crate::error::Result;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7use ucm_core::BlockId;
8
9/// Options for RAG search.
10#[derive(Debug, Clone, Default)]
11pub struct RagSearchOptions {
12    /// Maximum results to return.
13    pub limit: usize,
14    /// Minimum similarity threshold (0.0 - 1.0).
15    pub min_similarity: f32,
16    /// Filter by block IDs (search only within these).
17    pub filter_block_ids: Option<HashSet<BlockId>>,
18    /// Filter by semantic roles.
19    pub filter_roles: Option<HashSet<String>>,
20    /// Filter by tags.
21    pub filter_tags: Option<HashSet<String>>,
22    /// Include content in results.
23    pub include_content: bool,
24}
25
26impl RagSearchOptions {
27    pub fn new() -> Self {
28        Self {
29            limit: 10,
30            min_similarity: 0.0,
31            filter_block_ids: None,
32            filter_roles: None,
33            filter_tags: None,
34            include_content: true,
35        }
36    }
37
38    pub fn with_limit(mut self, limit: usize) -> Self {
39        self.limit = limit;
40        self
41    }
42
43    pub fn with_min_similarity(mut self, threshold: f32) -> Self {
44        self.min_similarity = threshold;
45        self
46    }
47
48    pub fn with_roles(mut self, roles: impl IntoIterator<Item = String>) -> Self {
49        self.filter_roles = Some(roles.into_iter().collect());
50        self
51    }
52
53    pub fn with_tags(mut self, tags: impl IntoIterator<Item = String>) -> Self {
54        self.filter_tags = Some(tags.into_iter().collect());
55        self
56    }
57
58    pub fn with_block_ids(mut self, ids: impl IntoIterator<Item = BlockId>) -> Self {
59        self.filter_block_ids = Some(ids.into_iter().collect());
60        self
61    }
62}
63
64/// A single match from semantic search.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct RagMatch {
67    /// Block ID of the match.
68    pub block_id: BlockId,
69    /// Similarity score (0.0 - 1.0).
70    pub similarity: f32,
71    /// Content preview (if requested).
72    pub content_preview: Option<String>,
73    /// Semantic role (if available).
74    pub semantic_role: Option<String>,
75    /// Highlight spans in content (character ranges).
76    pub highlight_spans: Vec<(usize, usize)>,
77}
78
79/// Results from semantic search.
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct RagSearchResults {
82    /// Matching blocks with similarity scores.
83    pub matches: Vec<RagMatch>,
84    /// Query that was executed.
85    pub query: String,
86    /// Total blocks searched.
87    pub total_searched: usize,
88    /// Search execution time in milliseconds.
89    pub execution_time_ms: u64,
90}
91
92impl RagSearchResults {
93    pub fn empty(query: String) -> Self {
94        Self {
95            matches: Vec::new(),
96            query,
97            total_searched: 0,
98            execution_time_ms: 0,
99        }
100    }
101
102    pub fn block_ids(&self) -> Vec<BlockId> {
103        self.matches.iter().map(|m| m.block_id).collect()
104    }
105}
106
107/// Capabilities of a RAG provider.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct RagCapabilities {
110    /// Supports semantic search.
111    pub supports_search: bool,
112    /// Supports embedding generation.
113    pub supports_embedding: bool,
114    /// Supports filtering.
115    pub supports_filtering: bool,
116    /// Maximum query length.
117    pub max_query_length: usize,
118    /// Maximum results per query.
119    pub max_results: usize,
120}
121
122impl Default for RagCapabilities {
123    fn default() -> Self {
124        Self {
125            supports_search: true,
126            supports_embedding: false,
127            supports_filtering: true,
128            max_query_length: 1000,
129            max_results: 100,
130        }
131    }
132}
133
134/// Abstract interface for semantic search providers.
135#[async_trait]
136pub trait RagProvider: Send + Sync {
137    /// Search for semantically similar content.
138    async fn search(&self, query: &str, options: RagSearchOptions) -> Result<RagSearchResults>;
139
140    /// Get embeddings for content (optional).
141    async fn embed(&self, content: &str) -> Result<Vec<f32>> {
142        let _ = content;
143        Ok(Vec::new())
144    }
145
146    /// Provider capabilities.
147    fn capabilities(&self) -> RagCapabilities;
148
149    /// Provider name for identification.
150    fn name(&self) -> &str;
151}
152
/// No-op RAG provider for testing.
///
/// A stateless unit struct, so it derives the common traits (`Debug`, `Clone`, `Copy`,
/// `Default`) and can be printed, copied and default-constructed like any value type.
#[derive(Debug, Clone, Copy, Default)]
pub struct NullRagProvider;

156#[async_trait]
157impl RagProvider for NullRagProvider {
158    async fn search(&self, query: &str, _options: RagSearchOptions) -> Result<RagSearchResults> {
159        Ok(RagSearchResults::empty(query.to_string()))
160    }
161
162    fn capabilities(&self) -> RagCapabilities {
163        RagCapabilities {
164            supports_search: false,
165            supports_embedding: false,
166            supports_filtering: false,
167            max_query_length: 0,
168            max_results: 0,
169        }
170    }
171
172    fn name(&self) -> &str {
173        "null"
174    }
175}
176
177/// In-memory RAG provider for testing with mock data.
178pub struct MockRagProvider {
179    results: Vec<RagMatch>,
180}
181
182impl MockRagProvider {
183    pub fn new() -> Self {
184        Self {
185            results: Vec::new(),
186        }
187    }
188
189    pub fn with_results(mut self, results: Vec<RagMatch>) -> Self {
190        self.results = results;
191        self
192    }
193
194    pub fn add_result(&mut self, block_id: BlockId, similarity: f32, preview: Option<&str>) {
195        self.results.push(RagMatch {
196            block_id,
197            similarity,
198            content_preview: preview.map(String::from),
199            semantic_role: None,
200            highlight_spans: Vec::new(),
201        });
202    }
203}
204
205impl Default for MockRagProvider {
206    fn default() -> Self {
207        Self::new()
208    }
209}
210
211#[async_trait]
212impl RagProvider for MockRagProvider {
213    async fn search(&self, query: &str, options: RagSearchOptions) -> Result<RagSearchResults> {
214        let matches: Vec<_> = self
215            .results
216            .iter()
217            .filter(|m| m.similarity >= options.min_similarity)
218            .take(options.limit)
219            .cloned()
220            .collect();
221
222        Ok(RagSearchResults {
223            matches,
224            query: query.to_string(),
225            total_searched: self.results.len(),
226            execution_time_ms: 1,
227        })
228    }
229
230    fn capabilities(&self) -> RagCapabilities {
231        RagCapabilities::default()
232    }
233
234    fn name(&self) -> &str {
235        "mock"
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Test helper: parses `s` as a `BlockId` when possible; otherwise derives a
    /// deterministic ID by XOR-folding the bytes of `s` into a 12-byte array.
    fn block_id(s: &str) -> BlockId {
        if let Ok(parsed) = s.parse::<BlockId>() {
            return parsed;
        }
        let mut folded = [0u8; 12];
        for (pos, byte) in s.bytes().enumerate() {
            folded[pos % 12] ^= byte;
        }
        BlockId::from_bytes(folded)
    }

    #[tokio::test]
    async fn test_null_provider() {
        let null = NullRagProvider;
        let options = RagSearchOptions::new();
        let outcome = null.search("test query", options).await.unwrap();

        assert!(outcome.matches.is_empty());
        assert_eq!(outcome.query, "test query");
    }

    #[tokio::test]
    async fn test_mock_provider() {
        let mut mock = MockRagProvider::new();
        mock.add_result(block_id("blk_000000000001"), 0.9, Some("test content"));
        mock.add_result(block_id("blk_000000000002"), 0.8, None);

        let options = RagSearchOptions::new().with_limit(5);
        let outcome = mock.search("test", options).await.unwrap();

        assert_eq!(outcome.matches.len(), 2);
        assert_eq!(outcome.matches[0].similarity, 0.9);
    }

    #[tokio::test]
    async fn test_mock_provider_filtering() {
        let mut mock = MockRagProvider::new();
        mock.add_result(block_id("blk_000000000001"), 0.9, None);
        mock.add_result(block_id("blk_000000000002"), 0.5, None);

        let options = RagSearchOptions::new().with_min_similarity(0.7);
        let outcome = mock.search("test", options).await.unwrap();

        assert_eq!(outcome.matches.len(), 1);
        assert_eq!(outcome.matches[0].similarity, 0.9);
    }
}