use crate::core::error::{GraphRAGError, Result};
use crate::core::traits::AsyncRetriever;
use crate::retrieval::{RetrievalSystem, SearchResult as RetrievalSearchResult};
use async_trait::async_trait;
/// A single hit returned by the [`AsyncRetriever`] implementation on
/// [`RetrievalSystemAdapter`] (it is that trait impl's associated `Result` type).
///
/// Can be built from the retrieval module's `SearchResult` via `From`, which
/// copies `id`, `content`, `score` and `entities` across unchanged.
#[derive(Debug, Clone)]
pub struct RetrievalResult {
/// Identifier of the hit, taken verbatim from the source search result.
pub id: String,
/// Text content of the hit.
pub content: String,
/// Score reported by the underlying search.
/// NOTE(review): presumably higher means more relevant — confirm against `retrieval`.
pub score: f32,
/// Entities associated with the hit (presumably entity names or ids — confirm).
pub entities: Vec<String>,
}
impl From<RetrievalSearchResult> for RetrievalResult {
fn from(result: RetrievalSearchResult) -> Self {
Self {
id: result.id,
content: result.content,
score: result.score,
entities: result.entities,
}
}
}
/// Adapter that exposes a [`RetrievalSystem`] through the [`AsyncRetriever`] trait.
///
/// NOTE(review): the `AsyncRetriever` methods in this file do not yet consult the
/// wrapped system (`search` returns an empty list, `update` is a no-op and
/// `health_check` always reports `true`). TODO wire them to `RetrievalSystem`.
pub struct RetrievalSystemAdapter {
/// The wrapped retrieval system, reachable via `system()` / `system_mut()`.
system: RetrievalSystem,
}
impl RetrievalSystemAdapter {
    /// Wraps `system` so it can be driven through [`AsyncRetriever`].
    pub fn new(system: RetrievalSystem) -> Self {
        Self { system }
    }

    /// Returns a shared reference to the wrapped [`RetrievalSystem`].
    pub fn system(&self) -> &RetrievalSystem {
        &self.system
    }

    /// Returns a mutable reference to the wrapped [`RetrievalSystem`].
    pub fn system_mut(&mut self) -> &mut RetrievalSystem {
        &mut self.system
    }

    /// Consumes the adapter and returns ownership of the wrapped [`RetrievalSystem`].
    ///
    /// Without this method, a system moved into the adapter could only be
    /// borrowed afterwards and never reclaimed. Examples are handing it to code
    /// that needs an owned value or re-wrapping it in a different adapter.
    pub fn into_inner(self) -> RetrievalSystem {
        self.system
    }
}
/// [`AsyncRetriever`] implementation for [`RetrievalSystemAdapter`].
///
/// NOTE(review): every method below is currently a placeholder that never
/// touches `self.system`:
/// - `search` always succeeds with no hits.
/// - `update` discards its input.
/// - `health_check` unconditionally reports healthy.
///
/// TODO delegate to the wrapped `RetrievalSystem` before relying on results.
#[async_trait]
impl AsyncRetriever for RetrievalSystemAdapter {
    type Query = String;
    type Result = RetrievalResult;
    type Error = GraphRAGError;

    /// Returns up to `k` hits for `query`.
    ///
    /// Placeholder: always returns `Ok` with an empty list, whatever `query`
    /// and `k` are.
    async fn search(&self, _query: Self::Query, _k: usize) -> Result<Vec<Self::Result>> {
        Ok(Vec::new())
    }

    /// Context-aware search. The context string is ignored and the call is
    /// forwarded unchanged to `search`.
    async fn search_with_context(
        &self,
        query: Self::Query,
        _context: &str,
        k: usize,
    ) -> Result<Vec<Self::Result>> {
        let forwarded = self.search(query, k);
        forwarded.await
    }

    /// Runs `search` once per query, one after another and in input order.
    ///
    /// Returns one hit list per query. The first error aborts the batch and is
    /// propagated.
    async fn search_batch(
        &self,
        queries: Vec<Self::Query>,
        k: usize,
    ) -> Result<Vec<Vec<Self::Result>>> {
        let mut per_query = Vec::with_capacity(queries.len());
        for q in queries {
            let hits = self.search(q, k).await?;
            per_query.push(hits);
        }
        Ok(per_query)
    }

    /// Placeholder: accepts and discards `content`; nothing is indexed.
    async fn update(&mut self, _content: Vec<String>) -> Result<()> {
        Ok(())
    }

    /// Placeholder: always reports healthy without probing the wrapped system.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// Builds an adapter around a `RetrievalSystem` created from the default config.
    fn default_adapter() -> RetrievalSystemAdapter {
        let cfg = Config::default();
        RetrievalSystemAdapter::new(RetrievalSystem::new(&cfg).unwrap())
    }

    #[tokio::test]
    async fn test_retrieval_adapter_creation() {
        let adapter = default_adapter();
        let healthy = adapter.health_check().await.unwrap();
        assert!(healthy);
    }

    #[tokio::test]
    async fn test_search_batch() {
        let adapter = default_adapter();
        let queries: Vec<String> = ["What is GraphRAG?", "How does retrieval work?"]
            .iter()
            .map(|q| q.to_string())
            .collect();
        let batches = adapter.search_batch(queries, 5).await.unwrap();
        // One result list per query, regardless of how many hits each holds.
        assert_eq!(batches.len(), 2);
    }
}