use crate::error::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use ucm_core::BlockId;
#[derive(Debug, Clone, Default)]
pub struct RagSearchOptions {
pub limit: usize,
pub min_similarity: f32,
pub filter_block_ids: Option<HashSet<BlockId>>,
pub filter_roles: Option<HashSet<String>>,
pub filter_tags: Option<HashSet<String>>,
pub include_content: bool,
}
impl RagSearchOptions {
    /// Creates options with the standard settings: at most 10 matches, a
    /// minimum similarity of `0.0`, no filters, and content previews included.
    #[must_use]
    pub fn new() -> Self {
        Self {
            limit: 10,
            min_similarity: 0.0,
            filter_block_ids: None,
            filter_roles: None,
            filter_tags: None,
            include_content: true,
        }
    }

    /// Sets the maximum number of matches to return.
    #[must_use]
    pub fn with_limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

    /// Sets the minimum similarity a match must reach to be returned.
    #[must_use]
    pub fn with_min_similarity(mut self, threshold: f32) -> Self {
        self.min_similarity = threshold;
        self
    }

    /// Restricts results to the given semantic roles, replacing any previous
    /// role filter.
    #[must_use]
    pub fn with_roles(mut self, roles: impl IntoIterator<Item = String>) -> Self {
        self.filter_roles = Some(roles.into_iter().collect());
        self
    }

    /// Restricts results by the given tags, replacing any previous tag filter.
    #[must_use]
    pub fn with_tags(mut self, tags: impl IntoIterator<Item = String>) -> Self {
        self.filter_tags = Some(tags.into_iter().collect());
        self
    }

    /// Restricts results to the given block ids, replacing any previous id
    /// filter.
    #[must_use]
    pub fn with_block_ids(mut self, ids: impl IntoIterator<Item = BlockId>) -> Self {
        self.filter_block_ids = Some(ids.into_iter().collect());
        self
    }

    /// Sets whether matches should carry a content preview.
    ///
    /// Every other field already had a builder; without this one, callers had
    /// to mutate `include_content` directly.
    #[must_use]
    pub fn with_include_content(mut self, include: bool) -> Self {
        self.include_content = include;
        self
    }
}
/// A single block returned by a RAG search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagMatch {
    /// Id of the matched block.
    pub block_id: BlockId,
    /// Similarity between the query and the block.
    /// NOTE(review): the score's scale/range is provider-defined — confirm.
    pub similarity: f32,
    /// Optional preview of the block's content.
    pub content_preview: Option<String>,
    /// Semantic role of the block, if known.
    pub semantic_role: Option<String>,
    /// Spans to highlight.
    /// NOTE(review): presumably `(start, end)` offsets into `content_preview`;
    /// byte vs. char offsets are not established here — confirm.
    pub highlight_spans: Vec<(usize, usize)>,
}
/// The outcome of a [`RagProvider::search`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSearchResults {
    /// Matches returned by the search, in the order the provider produced them.
    pub matches: Vec<RagMatch>,
    /// The query text that was searched for.
    pub query: String,
    /// Number of candidates the provider reports having examined.
    /// NOTE(review): whether this counts before or after filtering is
    /// provider-defined — confirm.
    pub total_searched: usize,
    /// Reported duration of the search, in milliseconds.
    pub execution_time_ms: u64,
}
impl RagSearchResults {
    /// Builds a result set for `query` with no matches, zero candidates
    /// searched and zero elapsed time.
    pub fn empty(query: String) -> Self {
        let matches = Vec::new();
        Self {
            query,
            matches,
            execution_time_ms: 0,
            total_searched: 0,
        }
    }

    /// Returns the ids of all matched blocks, in match order.
    pub fn block_ids(&self) -> Vec<BlockId> {
        let mut ids = Vec::with_capacity(self.matches.len());
        for hit in &self.matches {
            ids.push(hit.block_id);
        }
        ids
    }
}
/// What a [`RagProvider`] advertises that it supports, and its limits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagCapabilities {
    /// Whether `search` is supported.
    pub supports_search: bool,
    /// Whether `embed` is supported (the trait's default `embed` returns an
    /// empty vector).
    pub supports_embedding: bool,
    /// Presumably whether the filter fields of `RagSearchOptions` are honored.
    pub supports_filtering: bool,
    /// Advertised maximum query length.
    /// NOTE(review): unit (bytes vs. chars) is not established here, and nothing
    /// in this file enforces it — confirm.
    pub max_query_length: usize,
    /// Advertised maximum number of results; nothing in this file enforces it.
    pub max_results: usize,
}
impl Default for RagCapabilities {
    /// Search and filtering enabled, embedding disabled, with a query-length
    /// limit of 1000 and a result limit of 100.
    fn default() -> Self {
        const DEFAULT_MAX_QUERY_LENGTH: usize = 1000;
        const DEFAULT_MAX_RESULTS: usize = 100;

        Self {
            max_query_length: DEFAULT_MAX_QUERY_LENGTH,
            max_results: DEFAULT_MAX_RESULTS,
            supports_embedding: false,
            supports_filtering: true,
            supports_search: true,
        }
    }
}
/// A retrieval backend that answers similarity searches over blocks.
#[async_trait]
pub trait RagProvider: Send + Sync {
    /// Searches for blocks relevant to `query`, honoring `options` as far as
    /// the provider supports them.
    async fn search(&self, query: &str, options: RagSearchOptions) -> Result<RagSearchResults>;

    /// Computes an embedding vector for `content`.
    ///
    /// The default implementation ignores its input and returns an empty
    /// vector. Presumably it should be overridden by providers whose
    /// capabilities set `supports_embedding`.
    async fn embed(&self, _content: &str) -> Result<Vec<f32>> {
        Ok(Vec::new())
    }

    /// Describes what this provider supports.
    fn capabilities(&self) -> RagCapabilities;

    /// A short identifier for this provider.
    fn name(&self) -> &str;
}
/// A [`RagProvider`] that never finds anything and advertises no capabilities.
///
/// It is a stateless unit type, so the common traits are derived to let
/// callers debug-print, copy and default-construct it.
#[derive(Debug, Clone, Copy, Default)]
pub struct NullRagProvider;
#[async_trait]
impl RagProvider for NullRagProvider {
    /// Always succeeds with an empty result set that echoes `query`;
    /// `_options` is ignored.
    async fn search(&self, query: &str, _options: RagSearchOptions) -> Result<RagSearchResults> {
        let echoed = String::from(query);
        Ok(RagSearchResults::empty(echoed))
    }

    /// Reports that nothing is supported and that every limit is zero.
    fn capabilities(&self) -> RagCapabilities {
        let (supports_search, supports_embedding, supports_filtering) = (false, false, false);
        RagCapabilities {
            supports_search,
            supports_embedding,
            supports_filtering,
            max_query_length: 0,
            max_results: 0,
        }
    }

    /// Returns `"null"`.
    fn name(&self) -> &str {
        "null"
    }
}
/// An in-memory [`RagProvider`] that serves a fixed list of canned matches,
/// presumably for tests.
///
/// It derives `Debug` and `Clone` (both possible because [`RagMatch`] is
/// `Debug + Clone`) so test fixtures can be printed and duplicated.
/// `Default` is implemented by hand elsewhere, so it is not derived here.
#[derive(Debug, Clone)]
pub struct MockRagProvider {
    /// Canned matches, considered in insertion order by `search`.
    results: Vec<RagMatch>,
}
impl MockRagProvider {
    /// Creates a provider with no canned results.
    #[must_use]
    pub fn new() -> Self {
        Self {
            results: Vec::new(),
        }
    }

    /// Replaces all canned results with `results`.
    ///
    /// This consumes and returns `self`; the result must be used, or the
    /// configured results are silently dropped.
    #[must_use]
    pub fn with_results(mut self, results: Vec<RagMatch>) -> Self {
        self.results = results;
        self
    }

    /// Appends a canned match for `block_id` with the given `similarity` and
    /// optional content `preview`. The match has no semantic role and no
    /// highlight spans.
    pub fn add_result(&mut self, block_id: BlockId, similarity: f32, preview: Option<&str>) {
        self.results.push(RagMatch {
            block_id,
            similarity,
            content_preview: preview.map(String::from),
            semantic_role: None,
            highlight_spans: Vec::new(),
        });
    }
}
impl Default for MockRagProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl RagProvider for MockRagProvider {
async fn search(&self, query: &str, options: RagSearchOptions) -> Result<RagSearchResults> {
let matches: Vec<_> = self
.results
.iter()
.filter(|m| m.similarity >= options.min_similarity)
.take(options.limit)
.cloned()
.collect();
Ok(RagSearchResults {
matches,
query: query.to_string(),
total_searched: self.results.len(),
execution_time_ms: 1,
})
}
fn capabilities(&self) -> RagCapabilities {
RagCapabilities::default()
}
fn name(&self) -> &str {
"mock"
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `s` as a [`BlockId`]. If parsing fails, it builds a deterministic
    /// id by XOR-folding the bytes of `s` into a 12-byte buffer.
    fn block_id(s: &str) -> BlockId {
        match s.parse() {
            Ok(id) => id,
            Err(_) => {
                let mut folded = [0u8; 12];
                s.bytes()
                    .enumerate()
                    .for_each(|(idx, byte)| folded[idx % 12] ^= byte);
                BlockId::from_bytes(folded)
            }
        }
    }

    #[tokio::test]
    async fn test_null_provider() {
        let provider = NullRagProvider;
        let options = RagSearchOptions::new();

        let outcome = provider.search("test query", options).await.unwrap();

        assert!(outcome.matches.is_empty());
        assert_eq!(outcome.query, "test query");
    }

    #[tokio::test]
    async fn test_mock_provider() {
        let mut provider = MockRagProvider::new();
        provider.add_result(block_id("blk_000000000001"), 0.9, Some("test content"));
        provider.add_result(block_id("blk_000000000002"), 0.8, None);
        let options = RagSearchOptions::new().with_limit(5);

        let outcome = provider.search("test", options).await.unwrap();

        assert_eq!(outcome.matches.len(), 2);
        assert_eq!(outcome.matches[0].similarity, 0.9);
    }

    #[tokio::test]
    async fn test_mock_provider_filtering() {
        let mut provider = MockRagProvider::new();
        provider.add_result(block_id("blk_000000000001"), 0.9, None);
        provider.add_result(block_id("blk_000000000002"), 0.5, None);
        let options = RagSearchOptions::new().with_min_similarity(0.7);

        let outcome = provider.search("test", options).await.unwrap();

        assert_eq!(outcome.matches.len(), 1);
        assert_eq!(outcome.matches[0].similarity, 0.9);
    }
}