use crate::core::Result;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use async_trait::async_trait;
/// Optional string key/value metadata that accompanies a vector.
pub type VectorMetadata = Option<HashMap<String, String>>;
/// A batch of `(id, vector, metadata)` entries, in the same order as the
/// arguments of `add_vector`.
pub type VectorBatch = Vec<(String, Vec<f32>, VectorMetadata)>;
/// Synchronous storage interface for entities, documents and chunks.
///
/// Implementors choose the concrete payload types through the associated
/// types. Every fallible method returns the crate-wide `Result` alias.
pub trait Storage {
/// Entity type handled by this storage.
type Entity;
/// Document type handled by this storage.
type Document;
/// Chunk type handled by this storage.
type Chunk;
/// Implementor-specific error type. No method signature in this trait refers to it.
type Error: std::error::Error + Send + Sync + 'static;
/// Stores `entity` and returns a `String` on success
/// (presumably the id it was stored under — confirm with implementors).
fn store_entity(&mut self, entity: Self::Entity) -> Result<String>;
/// Looks up the entity identified by `id`; the `Option` lets a missing
/// entity be reported as `Ok(None)`.
fn retrieve_entity(&self, id: &str) -> Result<Option<Self::Entity>>;
/// Stores `document` and returns a `String` on success
/// (presumably the id it was stored under — confirm with implementors).
fn store_document(&mut self, document: Self::Document) -> Result<String>;
/// Looks up the document identified by `id`; the `Option` lets a missing
/// document be reported as `Ok(None)`.
fn retrieve_document(&self, id: &str) -> Result<Option<Self::Document>>;
/// Stores `chunk` and returns a `String` on success
/// (presumably the id it was stored under — confirm with implementors).
fn store_chunk(&mut self, chunk: Self::Chunk) -> Result<String>;
/// Looks up the chunk identified by `id`; the `Option` lets a missing
/// chunk be reported as `Ok(None)`.
fn retrieve_chunk(&self, id: &str) -> Result<Option<Self::Chunk>>;
/// Returns a list of strings for the stored entities
/// (presumably their ids — confirm with implementors).
fn list_entities(&self) -> Result<Vec<String>>;
/// Stores several entities in one call and returns a list of strings
/// (presumably one id per entity, in input order — confirm with implementors).
fn store_entities_batch(&mut self, entities: Vec<Self::Entity>) -> Result<Vec<String>>;
}
/// Asynchronous storage interface for entities, documents and chunks.
///
/// Declared with `#[async_trait]`; implementors and all payload types must be
/// `Send + Sync`. `health_check` and `flush` have default implementations,
/// every other method must be provided.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncStorage: Send + Sync {
/// Entity type handled by this storage.
type Entity: Send + Sync;
/// Document type handled by this storage.
type Document: Send + Sync;
/// Chunk type handled by this storage.
type Chunk: Send + Sync;
/// Implementor-specific error type. No method signature in this trait refers to it.
type Error: std::error::Error + Send + Sync + 'static;
/// Stores `entity` and returns a `String` on success
/// (presumably the id it was stored under — confirm with implementors).
async fn store_entity(&mut self, entity: Self::Entity) -> Result<String>;
/// Looks up the entity identified by `id`; a missing entity can be reported as `Ok(None)`.
async fn retrieve_entity(&self, id: &str) -> Result<Option<Self::Entity>>;
/// Stores `document` and returns a `String` on success
/// (presumably the id it was stored under — confirm with implementors).
async fn store_document(&mut self, document: Self::Document) -> Result<String>;
/// Looks up the document identified by `id`; a missing document can be reported as `Ok(None)`.
async fn retrieve_document(&self, id: &str) -> Result<Option<Self::Document>>;
/// Stores `chunk` and returns a `String` on success
/// (presumably the id it was stored under — confirm with implementors).
async fn store_chunk(&mut self, chunk: Self::Chunk) -> Result<String>;
/// Looks up the chunk identified by `id`; a missing chunk can be reported as `Ok(None)`.
async fn retrieve_chunk(&self, id: &str) -> Result<Option<Self::Chunk>>;
/// Returns a list of strings for the stored entities
/// (presumably their ids — confirm with implementors).
async fn list_entities(&self) -> Result<Vec<String>>;
/// Stores several entities in one call and returns a list of strings
/// (presumably one id per entity, in input order — confirm with implementors).
async fn store_entities_batch(&mut self, entities: Vec<Self::Entity>) -> Result<Vec<String>>;
/// Health probe. The default implementation performs no check and always returns `Ok(true)`.
async fn health_check(&self) -> Result<bool> {
Ok(true)
}
/// Flush hook. The default implementation does nothing and returns `Ok(())`.
async fn flush(&mut self) -> Result<()> {
Ok(())
}
}
/// Synchronous text-embedding interface.
pub trait Embedder {
/// Implementor-specific error type. No method signature in this trait refers to it.
type Error: std::error::Error + Send + Sync + 'static;
/// Produces an embedding vector for `text`.
fn embed(&self, text: &str) -> Result<Vec<f32>>;
/// Embeds several texts in one call
/// (presumably one vector per input, in input order — confirm with implementors).
fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
/// Returns the embedding dimension reported by the implementor
/// (presumably the length of the vectors returned by `embed` — confirm).
fn dimension(&self) -> usize;
/// Reports whether the embedder is ready for use.
fn is_ready(&self) -> bool;
}
/// Asynchronous text-embedding interface.
///
/// Implementors supply `embed`, `embed_batch`, `dimension` and `is_ready`;
/// `embed_batch_concurrent` and `health_check` come with default
/// implementations layered on top of those.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncEmbedder: Send + Sync {
    /// Implementor-specific error type. No method signature in this trait refers to it.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Produces an embedding vector for `text`.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds several texts in one call.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// Embeds `texts` in groups of at most `max_concurrent` items.
    ///
    /// The default implementation hands the groups to `embed_batch` one after
    /// another (it does not run them in parallel) and concatenates the
    /// results in order. The first failing group aborts the call. With
    /// `max_concurrent` of 0 or 1 the whole slice goes to `embed_batch` in a
    /// single call.
    async fn embed_batch_concurrent(
        &self,
        texts: &[&str],
        max_concurrent: usize,
    ) -> Result<Vec<Vec<f32>>> {
        if max_concurrent <= 1 {
            return self.embed_batch(texts).await;
        }

        let mut embeddings = Vec::with_capacity(texts.len());
        for group in texts.chunks(max_concurrent) {
            let group_embeddings = self.embed_batch(group).await?;
            embeddings.extend(group_embeddings);
        }
        Ok(embeddings)
    }

    /// Returns the embedding dimension reported by the implementor.
    fn dimension(&self) -> usize;

    /// Reports whether the embedder is ready for use.
    async fn is_ready(&self) -> bool;

    /// Health probe built on `is_ready`.
    ///
    /// Returns `Ok(true)` when `is_ready` reports `true`; otherwise returns a
    /// `GraphRAGError::Retrieval` error. The default implementation never
    /// returns `Ok(false)`.
    async fn health_check(&self) -> Result<bool> {
        if self.is_ready().await {
            Ok(true)
        } else {
            Err(crate::core::GraphRAGError::Retrieval {
                message: "Embedding service health check failed".to_string(),
            })
        }
    }
}
/// Synchronous vector-store interface: insert, search and remove vectors by id.
pub trait VectorStore {
    /// Implementor-specific error type. No method signature in this trait refers to it.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Adds one vector under `id`, with optional metadata.
    fn add_vector(&mut self, id: String, vector: Vec<f32>, metadata: VectorMetadata) -> Result<()>;

    /// Adds a batch of `(id, vector, metadata)` entries in one call.
    fn add_vectors_batch(&mut self, vectors: VectorBatch) -> Result<()>;

    /// Searches with `query_vector` and the limit `k`
    /// (presumably returns at most `k` hits — confirm with implementors).
    fn search(&self, query_vector: &[f32], k: usize) -> Result<Vec<SearchResult>>;

    /// Like `search`, with an additional `threshold` argument
    /// (how the threshold is compared against a hit is implementor-defined — confirm).
    fn search_with_threshold(
        &self,
        query_vector: &[f32],
        k: usize,
        threshold: f32,
    ) -> Result<Vec<SearchResult>>;

    /// Removes the vector stored under `id`; returns a `bool` on success
    /// (presumably whether something was removed — confirm with implementors).
    fn remove_vector(&mut self, id: &str) -> Result<bool>;

    /// Number of vectors in the store.
    fn len(&self) -> usize;

    /// Returns `true` when the store holds no vectors.
    ///
    /// Defaults to `self.len() == 0`, mirroring the default of
    /// `AsyncVectorStore::is_empty`; implementors may still override it.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Asynchronous vector-store interface: insert, search and remove vectors by id.
///
/// Implementors supply the single-item operations plus `add_vectors_batch`
/// and `len`; the batch helpers, `is_empty`, `health_check` and `build_index`
/// have default implementations.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncVectorStore: Send + Sync {
    /// Implementor-specific error type. No method signature in this trait refers to it.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Adds one vector under `id`, with optional metadata.
    async fn add_vector(
        &mut self,
        id: String,
        vector: Vec<f32>,
        metadata: VectorMetadata,
    ) -> Result<()>;

    /// Adds a batch of `(id, vector, metadata)` entries in one call.
    async fn add_vectors_batch(&mut self, vectors: VectorBatch) -> Result<()>;

    /// Adds `vectors` in groups of at most `max_concurrent` entries.
    ///
    /// The default implementation forwards the groups to `add_vectors_batch`
    /// one after another (not in parallel) and stops at the first error. With
    /// `max_concurrent` of 0 or 1 the whole batch is forwarded in one call.
    async fn add_vectors_batch_concurrent(
        &mut self,
        vectors: VectorBatch,
        max_concurrent: usize,
    ) -> Result<()> {
        if max_concurrent <= 1 {
            return self.add_vectors_batch(vectors).await;
        }

        // The batch is owned, so move entries into each group instead of
        // cloning every id, vector and metadata map.
        let mut remaining = vectors.into_iter();
        loop {
            let group: VectorBatch = remaining.by_ref().take(max_concurrent).collect();
            if group.is_empty() {
                break;
            }
            self.add_vectors_batch(group).await?;
        }
        Ok(())
    }

    /// Searches with `query_vector` and the limit `k`.
    async fn search(&self, query_vector: &[f32], k: usize) -> Result<Vec<SearchResult>>;

    /// Like `search`, with an additional `threshold` argument
    /// (how the threshold is applied is implementor-defined — confirm).
    async fn search_with_threshold(
        &self,
        query_vector: &[f32],
        k: usize,
        threshold: f32,
    ) -> Result<Vec<SearchResult>>;

    /// Runs `search` for every query vector, in order.
    ///
    /// The default implementation issues the searches sequentially and
    /// returns one result list per query; the first error aborts the call.
    async fn search_batch(
        &self,
        query_vectors: &[Vec<f32>],
        k: usize,
    ) -> Result<Vec<Vec<SearchResult>>> {
        let mut results = Vec::with_capacity(query_vectors.len());
        for query in query_vectors {
            let search_results = self.search(query, k).await?;
            results.push(search_results);
        }
        Ok(results)
    }

    /// Removes the vector stored under `id`; returns a `bool` on success
    /// (presumably whether something was removed — confirm with implementors).
    async fn remove_vector(&mut self, id: &str) -> Result<bool>;

    /// Calls `remove_vector` for every id, in order, collecting the returned
    /// flags. The first error aborts the call; removals already performed are
    /// not undone by this default implementation.
    async fn remove_vectors_batch(&mut self, ids: &[&str]) -> Result<Vec<bool>> {
        let mut results = Vec::with_capacity(ids.len());
        for id in ids {
            let removed = self.remove_vector(id).await?;
            results.push(removed);
        }
        Ok(results)
    }

    /// Number of vectors in the store.
    async fn len(&self) -> usize;

    /// Returns `true` when `len` reports zero.
    async fn is_empty(&self) -> bool {
        self.len().await == 0
    }

    /// Health probe. The default implementation performs no check and returns `Ok(true)`.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }

    /// Index-building hook. The default implementation does nothing and returns `Ok(())`.
    async fn build_index(&mut self) -> Result<()> {
        Ok(())
    }
}
/// A single hit returned by a vector-store search.
#[derive(Debug, Clone)]
pub struct SearchResult {
/// Id of the matched vector.
pub id: String,
/// Distance reported for this hit (metric is implementor-defined;
/// presumably smaller means closer — confirm with the store in use).
pub distance: f32,
/// Optional string key/value metadata for the hit.
pub metadata: Option<HashMap<String, String>>,
}
/// Synchronous interface for extracting entities from text.
pub trait EntityExtractor {
/// Entity type produced by the extractor.
type Entity;
/// Implementor-specific error type. No method signature in this trait refers to it.
type Error: std::error::Error + Send + Sync + 'static;
/// Extracts entities from `text`.
fn extract(&self, text: &str) -> Result<Vec<Self::Entity>>;
/// Extracts entities from `text`, pairing each with an `f32`
/// (presumably its confidence score — confirm with implementors).
fn extract_with_confidence(&self, text: &str) -> Result<Vec<(Self::Entity, f32)>>;
/// Sets the confidence threshold; how it is applied is up to the implementor.
fn set_confidence_threshold(&mut self, threshold: f32);
}
/// Asynchronous interface for extracting entities from text.
///
/// Implementors supply `extract`, `extract_with_confidence` and the
/// threshold accessors; the batch helpers and `health_check` have defaults.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncEntityExtractor: Send + Sync {
    /// Entity type produced by the extractor.
    type Entity: Send + Sync;
    /// Implementor-specific error type. No method signature in this trait refers to it.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Extracts entities from `text`.
    async fn extract(&self, text: &str) -> Result<Vec<Self::Entity>>;

    /// Extracts entities from `text`, pairing each with an `f32`
    /// (presumably its confidence score — confirm with implementors).
    async fn extract_with_confidence(&self, text: &str) -> Result<Vec<(Self::Entity, f32)>>;

    /// Runs `extract` on every text, in order, one after another.
    ///
    /// Returns one entity list per input text; the first error aborts the call.
    async fn extract_batch(&self, texts: &[&str]) -> Result<Vec<Vec<Self::Entity>>> {
        let mut per_text = Vec::with_capacity(texts.len());
        for &text in texts {
            per_text.push(self.extract(text).await?);
        }
        Ok(per_text)
    }

    /// Processes `texts` in groups of at most `max_concurrent` items.
    ///
    /// The default implementation passes the groups to `extract_batch`
    /// sequentially (not in parallel) and concatenates the results. With
    /// `max_concurrent` of 0 or 1 the whole slice is passed in one call.
    async fn extract_batch_concurrent(
        &self,
        texts: &[&str],
        max_concurrent: usize,
    ) -> Result<Vec<Vec<Self::Entity>>> {
        if max_concurrent <= 1 {
            return self.extract_batch(texts).await;
        }

        let mut per_text = Vec::with_capacity(texts.len());
        for group in texts.chunks(max_concurrent) {
            let group_entities = self.extract_batch(group).await?;
            per_text.extend(group_entities);
        }
        Ok(per_text)
    }

    /// Sets the confidence threshold; how it is applied is up to the implementor.
    async fn set_confidence_threshold(&mut self, threshold: f32);

    /// Returns the current confidence threshold.
    async fn get_confidence_threshold(&self) -> f32;

    /// Health probe. The default implementation performs no check and returns `Ok(true)`.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }
}
/// Synchronous retrieval interface over an updatable body of content.
pub trait Retriever {
/// Query type accepted by the retriever.
type Query;
/// Result item type returned by searches.
type Result;
/// Implementor-specific error type. No method signature in this trait refers to it.
type Error: std::error::Error + Send + Sync + 'static;
/// Searches with `query` and the limit `k`
/// (presumably returns at most `k` items — confirm with implementors).
fn search(&self, query: Self::Query, k: usize) -> Result<Vec<Self::Result>>;
/// Like `search`, with an additional free-text `context` argument.
fn search_with_context(
&self,
query: Self::Query,
context: &str,
k: usize,
) -> Result<Vec<Self::Result>>;
/// Feeds new `content` strings to the retriever.
fn update(&mut self, content: Vec<String>) -> Result<()>;
}
/// Asynchronous retrieval interface over an updatable body of content.
///
/// Implementors supply `search`, `search_with_context` and `update`; the
/// batch helpers, `refresh_index`, `health_check` and `get_stats` have
/// default implementations.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncRetriever: Send + Sync {
    /// Query type accepted by the retriever.
    type Query: Send + Sync;
    /// Result item type returned by searches.
    type Result: Send + Sync;
    /// Implementor-specific error type. No method signature in this trait refers to it.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Searches with `query` and the limit `k`.
    async fn search(&self, query: Self::Query, k: usize) -> Result<Vec<Self::Result>>;

    /// Like `search`, with an additional free-text `context` argument.
    async fn search_with_context(
        &self,
        query: Self::Query,
        context: &str,
        k: usize,
    ) -> Result<Vec<Self::Result>>;

    /// Runs `search` for each query, in order, one after another.
    ///
    /// Returns one result list per query; the first error aborts the call.
    async fn search_batch(
        &self,
        queries: Vec<Self::Query>,
        k: usize,
    ) -> Result<Vec<Vec<Self::Result>>> {
        let mut per_query = Vec::with_capacity(queries.len());
        for query in queries {
            per_query.push(self.search(query, k).await?);
        }
        Ok(per_query)
    }

    /// Feeds new `content` strings to the retriever.
    async fn update(&mut self, content: Vec<String>) -> Result<()>;

    /// Applies `update` to each batch in order, stopping at the first error.
    async fn update_batch(&mut self, content_batches: Vec<Vec<String>>) -> Result<()> {
        for content in content_batches {
            self.update(content).await?;
        }
        Ok(())
    }

    /// Index-refresh hook. The default implementation does nothing and returns `Ok(())`.
    async fn refresh_index(&mut self) -> Result<()> {
        Ok(())
    }

    /// Health probe. The default implementation performs no check and returns `Ok(true)`.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }

    /// Retrieval statistics. The default implementation returns
    /// `RetrievalStats::default()`, i.e. no real measurements.
    async fn get_stats(&self) -> Result<RetrievalStats> {
        Ok(RetrievalStats::default())
    }
}
/// Statistics reported by `AsyncRetriever::get_stats`.
///
/// `Default` yields all-zero values.
#[derive(Debug, Clone, Default)]
pub struct RetrievalStats {
/// Total number of queries counted by the retriever.
pub total_queries: u64,
/// Average response time, in milliseconds (per the field name).
pub average_response_time_ms: f64,
/// Size of the index (unit is implementor-defined — confirm).
pub index_size: usize,
/// Cache hit rate (presumably a fraction in `0.0..=1.0` — confirm with implementors).
pub cache_hit_rate: f64,
}
/// Synchronous text-completion interface.
pub trait LanguageModel {
/// Implementor-specific error type. No method signature in this trait refers to it.
type Error: std::error::Error + Send + Sync + 'static;
/// Produces a completion for `prompt`.
fn complete(&self, prompt: &str) -> Result<String>;
/// Produces a completion for `prompt` using the supplied generation parameters.
fn complete_with_params(&self, prompt: &str, params: GenerationParams) -> Result<String>;
/// Reports whether the model is available.
fn is_available(&self) -> bool;
/// Returns descriptive information about the model.
fn model_info(&self) -> ModelInfo;
}
/// Asynchronous text-completion interface.
///
/// Implementors supply `complete`, `complete_with_params`, `is_available`
/// and `model_info`; everything else has a default implementation.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncLanguageModel: Send + Sync {
    /// Implementor-specific error type. No method signature in this trait refers to it.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Produces a completion for `prompt`.
    async fn complete(&self, prompt: &str) -> Result<String>;

    /// Produces a completion for `prompt` using the supplied generation parameters.
    async fn complete_with_params(&self, prompt: &str, params: GenerationParams) -> Result<String>;

    /// Runs `complete` for each prompt, in order, one after another.
    ///
    /// Returns one completion per prompt; the first error aborts the call.
    async fn complete_batch(&self, prompts: &[&str]) -> Result<Vec<String>> {
        let mut completions = Vec::with_capacity(prompts.len());
        for &prompt in prompts {
            completions.push(self.complete(prompt).await?);
        }
        Ok(completions)
    }

    /// Processes `prompts` in groups of at most `max_concurrent` items.
    ///
    /// The default implementation passes the groups to `complete_batch`
    /// sequentially (not in parallel) and concatenates the results. With
    /// `max_concurrent` of 0 or 1 the whole slice is passed in one call.
    async fn complete_batch_concurrent(
        &self,
        prompts: &[&str],
        max_concurrent: usize,
    ) -> Result<Vec<String>> {
        if max_concurrent <= 1 {
            return self.complete_batch(prompts).await;
        }

        let mut completions = Vec::with_capacity(prompts.len());
        for group in prompts.chunks(max_concurrent) {
            let group_completions = self.complete_batch(group).await?;
            completions.extend(group_completions);
        }
        Ok(completions)
    }

    /// Streaming variant of `complete`.
    ///
    /// The default implementation does not stream incrementally: it awaits
    /// the full completion and returns a stream that yields it as a single
    /// item. A failure of `complete` is returned directly, before any stream
    /// is created.
    async fn complete_streaming(
        &self,
        prompt: &str,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<String>> + Send>>> {
        let full_completion = self.complete(prompt).await?;
        let single_item = futures::stream::once(async move { Ok(full_completion) });
        Ok(Box::pin(single_item))
    }

    /// Reports whether the model is available.
    async fn is_available(&self) -> bool;

    /// Returns descriptive information about the model.
    async fn model_info(&self) -> ModelInfo;

    /// Health probe built on `is_available`.
    ///
    /// Returns `Ok(true)` when the model is available; otherwise returns a
    /// `GraphRAGError::Generation` error. The default implementation never
    /// returns `Ok(false)`.
    async fn health_check(&self) -> Result<bool> {
        if self.is_available().await {
            Ok(true)
        } else {
            Err(crate::core::GraphRAGError::Generation {
                message: "Language model health check failed".to_string(),
            })
        }
    }

    /// Usage statistics. The default implementation returns
    /// `ModelUsageStats::default()`, i.e. no real measurements.
    async fn get_usage_stats(&self) -> Result<ModelUsageStats> {
        Ok(ModelUsageStats::default())
    }

    /// Rough token estimate for `prompt`.
    ///
    /// The default implementation divides the prompt's length in bytes by
    /// four (integer division); it does not run a tokenizer.
    async fn estimate_tokens(&self, prompt: &str) -> Result<usize> {
        const BYTES_PER_TOKEN: usize = 4;
        let byte_len = prompt.len();
        Ok(byte_len / BYTES_PER_TOKEN)
    }
}
/// Usage statistics reported by `AsyncLanguageModel::get_usage_stats`.
///
/// `Default` yields all-zero values.
#[derive(Debug, Clone, Default)]
pub struct ModelUsageStats {
/// Total number of requests counted.
pub total_requests: u64,
/// Total number of tokens processed.
pub total_tokens_processed: u64,
/// Average response time, in milliseconds (per the field name).
pub average_response_time_ms: f64,
/// Error rate (presumably a fraction in `0.0..=1.0` — confirm with implementors).
pub error_rate: f64,
}
/// Optional knobs passed to `complete_with_params`.
///
/// Every field is optional; `None` leaves the choice to the model
/// implementation (how `None` is interpreted is implementor-defined).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct GenerationParams {
    /// Upper bound on the number of generated tokens.
    pub max_tokens: Option<usize>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus-sampling (`top_p`) value.
    pub top_p: Option<f32>,
    /// Sequences at which generation should stop.
    pub stop_sequences: Option<Vec<String>>,
}

impl Default for GenerationParams {
    /// Defaults: 1000 max tokens, temperature 0.7, top-p 0.9, no stop sequences.
    fn default() -> Self {
        GenerationParams {
            stop_sequences: None,
            top_p: Some(0.9),
            temperature: Some(0.7),
            max_tokens: Some(1000),
        }
    }
}
/// Descriptive information about a language model, as returned by `model_info`.
#[derive(Debug, Clone)]
pub struct ModelInfo {
/// Model name.
pub name: String,
/// Model version, if known.
pub version: Option<String>,
/// Maximum context length, if known (presumably in tokens — confirm).
pub max_context_length: Option<usize>,
/// Whether the model supports streaming output.
pub supports_streaming: bool,
}
/// Synchronous graph-store interface: add nodes and edges, query and traverse.
pub trait GraphStore {
/// Node type stored in the graph.
type Node;
/// Edge payload type stored in the graph.
type Edge;
/// Implementor-specific error type. No method signature in this trait refers to it.
type Error: std::error::Error + Send + Sync + 'static;
/// Adds `node` and returns a `String` on success
/// (presumably the node's id — confirm with implementors).
fn add_node(&mut self, node: Self::Node) -> Result<String>;
/// Adds `edge` between the nodes identified by `from_id` and `to_id`; returns a
/// `String` on success (presumably the edge's id — confirm with implementors).
fn add_edge(&mut self, from_id: &str, to_id: &str, edge: Self::Edge) -> Result<String>;
/// Finds nodes matching `criteria` (the criteria syntax is implementor-defined).
fn find_nodes(&self, criteria: &str) -> Result<Vec<Self::Node>>;
/// Returns the neighbors of the node identified by `node_id`.
fn get_neighbors(&self, node_id: &str) -> Result<Vec<Self::Node>>;
/// Traverses the graph from `start_id`, bounded by `max_depth`
/// (traversal order is implementor-defined).
fn traverse(&self, start_id: &str, max_depth: usize) -> Result<Vec<Self::Node>>;
/// Returns summary statistics for the graph.
fn stats(&self) -> GraphStats;
}
/// Asynchronous graph-store interface: add nodes and edges, query and traverse.
///
/// Implementors supply the single-item operations and `stats`. The `*_batch`
/// helpers default to calling the single-item operation for each element in
/// order, sequentially, stopping at the first error.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncGraphStore: Send + Sync {
    /// Node type stored in the graph.
    type Node: Send + Sync;
    /// Edge payload type stored in the graph.
    type Edge: Send + Sync;
    /// Implementor-specific error type. No method signature in this trait refers to it.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Adds `node` and returns a `String` on success
    /// (presumably the node's id — confirm with implementors).
    async fn add_node(&mut self, node: Self::Node) -> Result<String>;

    /// Calls `add_node` for each node in order and collects the returned strings.
    async fn add_nodes_batch(&mut self, nodes: Vec<Self::Node>) -> Result<Vec<String>> {
        let mut node_ids = Vec::with_capacity(nodes.len());
        for node in nodes {
            node_ids.push(self.add_node(node).await?);
        }
        Ok(node_ids)
    }

    /// Adds `edge` between the nodes identified by `from_id` and `to_id`; returns a
    /// `String` on success (presumably the edge's id — confirm with implementors).
    async fn add_edge(&mut self, from_id: &str, to_id: &str, edge: Self::Edge) -> Result<String>;

    /// Calls `add_edge` for each `(from_id, to_id, edge)` triple in order and
    /// collects the returned strings.
    async fn add_edges_batch(
        &mut self,
        edges: Vec<(String, String, Self::Edge)>,
    ) -> Result<Vec<String>> {
        let mut edge_ids = Vec::with_capacity(edges.len());
        for (source, target, payload) in edges {
            edge_ids.push(self.add_edge(&source, &target, payload).await?);
        }
        Ok(edge_ids)
    }

    /// Finds nodes matching `criteria` (the criteria syntax is implementor-defined).
    async fn find_nodes(&self, criteria: &str) -> Result<Vec<Self::Node>>;

    /// Calls `find_nodes` for each criteria string in order; one result list per input.
    async fn find_nodes_batch(&self, criteria_list: &[&str]) -> Result<Vec<Vec<Self::Node>>> {
        let mut matches = Vec::with_capacity(criteria_list.len());
        for &criteria in criteria_list {
            matches.push(self.find_nodes(criteria).await?);
        }
        Ok(matches)
    }

    /// Returns the neighbors of the node identified by `node_id`.
    async fn get_neighbors(&self, node_id: &str) -> Result<Vec<Self::Node>>;

    /// Calls `get_neighbors` for each node id in order; one result list per input.
    async fn get_neighbors_batch(&self, node_ids: &[&str]) -> Result<Vec<Vec<Self::Node>>> {
        let mut neighborhoods = Vec::with_capacity(node_ids.len());
        for &node_id in node_ids {
            neighborhoods.push(self.get_neighbors(node_id).await?);
        }
        Ok(neighborhoods)
    }

    /// Traverses the graph from `start_id`, bounded by `max_depth`
    /// (traversal order is implementor-defined).
    async fn traverse(&self, start_id: &str, max_depth: usize) -> Result<Vec<Self::Node>>;

    /// Calls `traverse` for each start id in order, using the same `max_depth`;
    /// one result list per input.
    async fn traverse_batch(
        &self,
        start_ids: &[&str],
        max_depth: usize,
    ) -> Result<Vec<Vec<Self::Node>>> {
        let mut traversals = Vec::with_capacity(start_ids.len());
        for &start_id in start_ids {
            traversals.push(self.traverse(start_id, max_depth).await?);
        }
        Ok(traversals)
    }

    /// Returns summary statistics for the graph.
    async fn stats(&self) -> GraphStats;

    /// Health probe. The default implementation performs no check and returns `Ok(true)`.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }

    /// Optimization hook. The default implementation does nothing and returns `Ok(())`.
    async fn optimize(&mut self) -> Result<()> {
        Ok(())
    }

    /// Exports the graph as bytes. The default implementation exports nothing
    /// and returns an empty vector.
    async fn export(&self) -> Result<Vec<u8>> {
        Ok(Vec::new())
    }

    /// Imports a graph from bytes. The default implementation ignores `data`
    /// and returns `Ok(())`.
    #[allow(clippy::disallowed_names)]
    async fn import(&mut self, data: &[u8]) -> Result<()> {
        // Touch the argument so the unused parameter does not warn.
        let _ = data;
        Ok(())
    }
}
/// Summary statistics for a graph, as returned by `stats`.
#[derive(Debug, Clone)]
pub struct GraphStats {
/// Number of nodes reported by the store.
pub node_count: usize,
/// Number of edges reported by the store.
pub edge_count: usize,
/// Average node degree reported by the store.
pub average_degree: f32,
/// Maximum depth reported by the store (exact definition is implementor-defined — confirm).
pub max_depth: usize,
}
/// Synchronous registry of named, callable functions.
pub trait FunctionRegistry {
/// Type of the functions stored in the registry.
type Function;
/// Type returned by a function call.
type CallResult;
/// Implementor-specific error type. No method signature in this trait refers to it.
type Error: std::error::Error + Send + Sync + 'static;
/// Registers `function` under `name`.
fn register(&mut self, name: String, function: Self::Function) -> Result<()>;
/// Calls the function registered under `name` with `args`
/// (the argument encoding is implementor-defined — confirm).
fn call(&self, name: &str, args: &str) -> Result<Self::CallResult>;
/// Returns the names of the registered functions.
fn list_functions(&self) -> Vec<String>;
/// Reports whether a function is registered under `name`.
fn has_function(&self, name: &str) -> bool;
}
/// Asynchronous registry of named, callable functions.
///
/// Implementors supply `register`, `call`, `list_functions`, `has_function`
/// and `get_function_info`; `call_batch`, `health_check` and `validate_args`
/// have default implementations.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncFunctionRegistry: Send + Sync {
    /// Type of the functions stored in the registry.
    type Function: Send + Sync;
    /// Type returned by a function call.
    type CallResult: Send + Sync;
    /// Implementor-specific error type. No method signature in this trait refers to it.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Registers `function` under `name`.
    async fn register(&mut self, name: String, function: Self::Function) -> Result<()>;

    /// Calls the function registered under `name` with `args`
    /// (the argument encoding is implementor-defined — confirm).
    async fn call(&self, name: &str, args: &str) -> Result<Self::CallResult>;

    /// Performs each `(name, args)` call in order, one after another.
    ///
    /// Returns one result per call; the first error aborts the batch.
    async fn call_batch(&self, calls: &[(&str, &str)]) -> Result<Vec<Self::CallResult>> {
        let mut outcomes = Vec::with_capacity(calls.len());
        for &(function_name, function_args) in calls {
            outcomes.push(self.call(function_name, function_args).await?);
        }
        Ok(outcomes)
    }

    /// Returns the names of the registered functions.
    async fn list_functions(&self) -> Vec<String>;

    /// Reports whether a function is registered under `name`.
    async fn has_function(&self, name: &str) -> bool;

    /// Returns descriptive information for `name`; the `Option` lets an
    /// unknown function be reported as `Ok(None)`.
    async fn get_function_info(&self, name: &str) -> Result<Option<FunctionInfo>>;

    /// Health probe. The default implementation performs no check and returns `Ok(true)`.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }

    /// Argument validation hook. The default implementation ignores both
    /// arguments and accepts everything (`Ok(true)`).
    async fn validate_args(&self, name: &str, args: &str) -> Result<bool> {
        // Touch the arguments so the unused parameters do not warn.
        let _ = (name, args);
        Ok(true)
    }
}
/// Descriptive information about a registered function.
#[derive(Debug, Clone)]
pub struct FunctionInfo {
/// Function name.
pub name: String,
/// Human-readable description, if any.
pub description: Option<String>,
/// Descriptions of the function's parameters.
pub parameters: Vec<ParameterInfo>,
/// Name of the return type, if known.
pub return_type: Option<String>,
}
/// Descriptive information about one parameter of a registered function.
#[derive(Debug, Clone)]
pub struct ParameterInfo {
/// Parameter name.
pub name: String,
/// Name of the parameter's type, as a string.
pub param_type: String,
/// Human-readable description, if any.
pub description: Option<String>,
/// Whether the parameter is required.
pub required: bool,
}
/// Synchronous interface for loading, saving and validating configuration.
pub trait ConfigProvider {
/// Configuration type managed by the provider.
type Config;
/// Implementor-specific error type. No method signature in this trait refers to it.
type Error: std::error::Error + Send + Sync + 'static;
/// Loads the configuration.
fn load(&self) -> Result<Self::Config>;
/// Saves `config`.
fn save(&self, config: &Self::Config) -> Result<()>;
/// Validates `config`, returning an error when it is rejected.
fn validate(&self, config: &Self::Config) -> Result<()>;
/// Returns the provider's default configuration.
fn default_config(&self) -> Self::Config;
}
/// Asynchronous interface for loading, saving and validating configuration.
///
/// Implementors supply `load`, `save`, `validate` and `default_config`;
/// `watch_changes`, `reload` and `health_check` have default implementations.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncConfigProvider: Send + Sync {
    /// Configuration type managed by the provider.
    type Config: Send + Sync;
    /// Implementor-specific error type. No method signature in this trait refers to it.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Loads the configuration.
    async fn load(&self) -> Result<Self::Config>;

    /// Saves `config`.
    async fn save(&self, config: &Self::Config) -> Result<()>;

    /// Validates `config`, returning an error when it is rejected.
    async fn validate(&self, config: &Self::Config) -> Result<()>;

    /// Returns the provider's default configuration.
    async fn default_config(&self) -> Self::Config;

    /// Returns a stream of configuration updates.
    ///
    /// The default implementation does no watching: it returns an empty
    /// stream that ends immediately without yielding an item.
    async fn watch_changes(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<Self::Config>> + Send + 'static>>>
    where
        Self::Config: 'static,
    {
        let no_updates = futures::stream::empty::<Result<Self::Config>>();
        Ok(Box::pin(no_updates))
    }

    /// Reloads the configuration. The default implementation simply calls `load`.
    async fn reload(&self) -> Result<Self::Config> {
        let reloaded = self.load().await;
        reloaded
    }

    /// Health probe. The default implementation performs no check and returns `Ok(true)`.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }
}
/// Synchronous metrics sink for counters, gauges, histograms and timers.
///
/// `tags` is an optional slice of `(key, value)` string pairs attached to the sample.
pub trait MetricsCollector {
/// Records a counter sample `value` for the metric `name`.
fn counter(&self, name: &str, value: u64, tags: Option<&[(&str, &str)]>);
/// Records a gauge sample `value` for the metric `name`.
fn gauge(&self, name: &str, value: f64, tags: Option<&[(&str, &str)]>);
/// Records a histogram sample `value` for the metric `name`.
fn histogram(&self, name: &str, value: f64, tags: Option<&[(&str, &str)]>);
/// Returns a `Timer` for the metric `name`.
fn timer(&self, name: &str) -> Timer;
}
/// Asynchronous metrics sink for counters, gauges, histograms and timers.
///
/// `tags` is an optional slice of `(key, value)` string pairs attached to the
/// sample. Implementors supply `counter`, `gauge`, `histogram` and `timer`;
/// `record_batch`, `health_check` and `flush` have default implementations.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncMetricsCollector: Send + Sync {
    /// Records a counter sample `value` for the metric `name`.
    async fn counter(&self, name: &str, value: u64, tags: Option<&[(&str, &str)]>);

    /// Records a gauge sample `value` for the metric `name`.
    async fn gauge(&self, name: &str, value: f64, tags: Option<&[(&str, &str)]>);

    /// Records a histogram sample `value` for the metric `name`.
    async fn histogram(&self, name: &str, value: f64, tags: Option<&[(&str, &str)]>);

    /// Records every entry of `metrics`, in order, one after another.
    ///
    /// Each record is dispatched to `counter`, `gauge` or `histogram`
    /// according to its variant, with its owned tags borrowed as
    /// `(&str, &str)` pairs.
    async fn record_batch(&self, metrics: &[MetricRecord]) {
        // Shared by all three variants: view owned `(String, String)` tag
        // pairs as the borrowed pairs the recording methods expect.
        fn borrow_tags(tags: Option<&[(String, String)]>) -> Option<Vec<(&str, &str)>> {
            tags.map(|pairs| {
                pairs
                    .iter()
                    .map(|(key, value)| (key.as_str(), value.as_str()))
                    .collect()
            })
        }

        for metric in metrics {
            match metric {
                MetricRecord::Counter { name, value, tags } => {
                    let tag_refs = borrow_tags(tags.as_deref());
                    self.counter(name, *value, tag_refs.as_deref()).await;
                },
                MetricRecord::Gauge { name, value, tags } => {
                    let tag_refs = borrow_tags(tags.as_deref());
                    self.gauge(name, *value, tag_refs.as_deref()).await;
                },
                MetricRecord::Histogram { name, value, tags } => {
                    let tag_refs = borrow_tags(tags.as_deref());
                    self.histogram(name, *value, tag_refs.as_deref()).await;
                },
            }
        }
    }

    /// Returns an `AsyncTimer` for the metric `name`.
    async fn timer(&self, name: &str) -> AsyncTimer;

    /// Health probe. The default implementation performs no check and returns `Ok(true)`.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }

    /// Flush hook. The default implementation does nothing and returns `Ok(())`.
    async fn flush(&self) -> Result<()> {
        Ok(())
    }
}
/// An owned metric sample, suitable for batching via `record_batch`.
///
/// Each variant carries the metric name, the sample value and optional
/// `(key, value)` tag pairs.
#[derive(Debug, Clone)]
pub enum MetricRecord {
/// A counter sample with an unsigned integer value.
Counter {
name: String,
value: u64,
tags: Option<Vec<(String, String)>>,
},
/// A gauge sample with a floating-point value.
Gauge {
name: String,
value: f64,
tags: Option<Vec<(String, String)>>,
},
/// A histogram sample with a floating-point value.
Histogram {
name: String,
value: f64,
tags: Option<Vec<(String, String)>>,
},
}
/// A named stopwatch that starts when it is created.
pub struct AsyncTimer {
    /// Name given to the timer at construction.
    name: String,
    /// Instant captured in `new`.
    start: std::time::Instant,
}

impl AsyncTimer {
    /// Creates a timer called `name` and starts measuring immediately.
    pub fn new(name: String) -> Self {
        let started_at = std::time::Instant::now();
        AsyncTimer {
            name,
            start: started_at,
        }
    }

    /// Consumes the timer and returns the time elapsed since `new`.
    ///
    /// The method is `async` but performs no asynchronous work.
    pub async fn finish(self) -> std::time::Duration {
        let started_at = self.start;
        started_at.elapsed()
    }

    /// Returns the timer's name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// A named stopwatch that starts when it is created.
pub struct Timer {
    /// Name given to the timer at construction.
    name: String,
    /// Instant captured in `new`.
    start: std::time::Instant,
}

impl Timer {
    /// Creates a timer called `name` and starts measuring immediately.
    pub fn new(name: String) -> Self {
        Self {
            name,
            start: std::time::Instant::now(),
        }
    }

    /// Consumes the timer and returns the time elapsed since `new`.
    pub fn finish(self) -> std::time::Duration {
        self.start.elapsed()
    }

    /// Returns the timer's name.
    ///
    /// Mirrors `AsyncTimer::name`, so the stored name is readable and the
    /// field no longer needs a `dead_code` allowance.
    pub fn name(&self) -> &str {
        &self.name
    }
}
/// Synchronous string-based serialization interface.
pub trait Serializer {
/// Implementor-specific error type. No method signature in this trait refers to it.
type Error: std::error::Error + Send + Sync + 'static;
/// Serializes `data` into a `String`.
fn serialize<T: serde::Serialize>(&self, data: &T) -> Result<String>;
/// Deserializes a value of type `T` from `data`.
fn deserialize<T: serde::de::DeserializeOwned>(&self, data: &str) -> Result<T>;
/// Returns a static string for this format
/// (presumably a file extension such as `"json"` — confirm with implementors).
fn extension(&self) -> &'static str;
}
/// Asynchronous string-based serialization interface.
///
/// Implementors supply `serialize`, `deserialize` and `extension`; the byte
/// and batch helpers and `health_check` have default implementations.
#[allow(async_fn_in_trait)]
#[async_trait]
pub trait AsyncSerializer: Send + Sync {
    /// Implementor-specific error type. No method signature in this trait refers to it.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Serializes `data` into a `String`.
    async fn serialize<T: serde::Serialize + Send + Sync>(&self, data: &T) -> Result<String>;

    /// Deserializes a value of type `T` from `data`.
    async fn deserialize<T: serde::de::DeserializeOwned + Send + Sync>(
        &self,
        data: &str,
    ) -> Result<T>;

    /// Serializes `data` and returns the UTF-8 bytes of the resulting string.
    #[allow(clippy::disallowed_names)]
    async fn serialize_bytes<T: serde::Serialize + Send + Sync>(
        &self,
        data: &T,
    ) -> Result<Vec<u8>> {
        let string = self.serialize(data).await?;
        Ok(string.into_bytes())
    }

    /// Deserializes a value of type `T` from UTF-8 encoded bytes.
    ///
    /// # Errors
    ///
    /// Returns `GraphRAGError::Serialization` when `data` is not valid UTF-8;
    /// otherwise forwards whatever `deserialize` returns.
    #[allow(clippy::disallowed_names)]
    async fn deserialize_bytes<T: serde::de::DeserializeOwned + Send + Sync>(
        &self,
        data: &[u8],
    ) -> Result<T> {
        // Validate in place: `deserialize` only needs a `&str`, so there is
        // no reason to copy the whole input into an owned `String` first.
        let text = std::str::from_utf8(data).map_err(|e| {
            crate::core::GraphRAGError::Serialization {
                message: format!("Invalid UTF-8 data: {e}"),
            }
        })?;
        self.deserialize(text).await
    }

    /// Serializes each item in order, one after another.
    ///
    /// Returns one string per item; the first error aborts the call.
    #[allow(clippy::disallowed_names)]
    async fn serialize_batch<T: serde::Serialize + Send + Sync>(
        &self,
        data: &[T],
    ) -> Result<Vec<String>> {
        let mut results = Vec::with_capacity(data.len());
        for item in data {
            let serialized = self.serialize(item).await?;
            results.push(serialized);
        }
        Ok(results)
    }

    /// Returns a static string for this format
    /// (presumably a file extension such as `"json"` — confirm with implementors).
    fn extension(&self) -> &'static str;

    /// Health probe. The default implementation performs no check and returns `Ok(true)`.
    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }
}
/// Adapters that expose synchronous trait implementations through the
/// asynchronous traits of this file.
///
/// The wrapped synchronous methods are called directly inside the async
/// methods, so a long-running synchronous call occupies the task that
/// awaits it for its whole duration.
pub mod sync_to_async {
    use super::*;
    use std::sync::Arc;

    /// Wraps a synchronous `Storage` behind a shared `tokio::sync::Mutex`.
    ///
    /// Every method, including the read-only ones, acquires the mutex for
    /// the duration of the delegated call.
    pub struct StorageAdapter<T>(pub Arc<tokio::sync::Mutex<T>>);

    #[async_trait]
    impl<T> AsyncStorage for StorageAdapter<T>
    where
        T: Storage + Send + Sync + 'static,
        T::Entity: Send + Sync,
        T::Document: Send + Sync,
        T::Chunk: Send + Sync,
    {
        type Entity = T::Entity;
        type Document = T::Document;
        type Chunk = T::Chunk;
        type Error = T::Error;

        /// Locks the inner storage and delegates to `Storage::store_entity`.
        async fn store_entity(&mut self, entity: Self::Entity) -> Result<String> {
            self.0.lock().await.store_entity(entity)
        }

        /// Locks the inner storage and delegates to `Storage::retrieve_entity`.
        async fn retrieve_entity(&self, id: &str) -> Result<Option<Self::Entity>> {
            self.0.lock().await.retrieve_entity(id)
        }

        /// Locks the inner storage and delegates to `Storage::store_document`.
        async fn store_document(&mut self, document: Self::Document) -> Result<String> {
            self.0.lock().await.store_document(document)
        }

        /// Locks the inner storage and delegates to `Storage::retrieve_document`.
        async fn retrieve_document(&self, id: &str) -> Result<Option<Self::Document>> {
            self.0.lock().await.retrieve_document(id)
        }

        /// Locks the inner storage and delegates to `Storage::store_chunk`.
        async fn store_chunk(&mut self, chunk: Self::Chunk) -> Result<String> {
            self.0.lock().await.store_chunk(chunk)
        }

        /// Locks the inner storage and delegates to `Storage::retrieve_chunk`.
        async fn retrieve_chunk(&self, id: &str) -> Result<Option<Self::Chunk>> {
            self.0.lock().await.retrieve_chunk(id)
        }

        /// Locks the inner storage and delegates to `Storage::list_entities`.
        async fn list_entities(&self) -> Result<Vec<String>> {
            self.0.lock().await.list_entities()
        }

        /// Locks the inner storage and delegates to `Storage::store_entities_batch`.
        async fn store_entities_batch(
            &mut self,
            entities: Vec<Self::Entity>,
        ) -> Result<Vec<String>> {
            self.0.lock().await.store_entities_batch(entities)
        }
    }

    /// Wraps a shared synchronous `LanguageModel` as an `AsyncLanguageModel`.
    ///
    /// Only the required methods are forwarded; the remaining trait methods
    /// keep their default implementations.
    pub struct LanguageModelAdapter<T>(pub Arc<T>);

    #[async_trait]
    impl<T> AsyncLanguageModel for LanguageModelAdapter<T>
    where
        T: LanguageModel + Send + Sync + 'static,
    {
        type Error = T::Error;

        /// Delegates to `LanguageModel::complete`.
        async fn complete(&self, prompt: &str) -> Result<String> {
            let model: &T = &self.0;
            model.complete(prompt)
        }

        /// Delegates to `LanguageModel::complete_with_params`.
        async fn complete_with_params(
            &self,
            prompt: &str,
            params: GenerationParams,
        ) -> Result<String> {
            let model: &T = &self.0;
            model.complete_with_params(prompt, params)
        }

        /// Delegates to `LanguageModel::is_available`.
        async fn is_available(&self) -> bool {
            let model: &T = &self.0;
            model.is_available()
        }

        /// Delegates to `LanguageModel::model_info`.
        async fn model_info(&self) -> ModelInfo {
            let model: &T = &self.0;
            model.model_info()
        }
    }
}
/// Helpers for running fallible futures with timeouts, retries and bounded
/// concurrency.
pub mod async_utils {
    use super::*;
    use std::time::Duration;

    /// Awaits `future`, giving up after `timeout`.
    ///
    /// Returns the future's own result when it completes in time.
    ///
    /// # Errors
    ///
    /// Returns `GraphRAGError::Timeout` (with the fixed operation label
    /// `"async operation"`) when the deadline passes first; otherwise
    /// forwards any error produced by `future`.
    pub async fn with_timeout<F, T>(future: F, timeout: Duration) -> Result<T>
    where
        F: Future<Output = Result<T>>,
    {
        match tokio::time::timeout(timeout, future).await {
            Ok(outcome) => outcome,
            Err(_elapsed) => Err(crate::core::GraphRAGError::Timeout {
                operation: "async operation".to_string(),
                duration: timeout,
            }),
        }
    }

    /// Runs `operation` until it succeeds or the attempt budget is used up.
    ///
    /// `operation` is called to create a fresh future for every attempt. The
    /// call gives up once the number of failed attempts reaches
    /// `max_retries`, so the operation runs at most `max_retries` times and
    /// always at least once (values 0 and 1 both mean a single attempt).
    /// Between attempts the task sleeps for the fixed `delay`; there is no
    /// back-off.
    ///
    /// # Errors
    ///
    /// Returns the error of the last failed attempt.
    pub async fn with_retry<F, T, E>(
        mut operation: F,
        max_retries: usize,
        delay: Duration,
    ) -> std::result::Result<T, E>
    where
        F: FnMut() -> Pin<Box<dyn Future<Output = std::result::Result<T, E>> + Send>>,
        E: std::fmt::Debug,
    {
        let mut failed_attempts = 0;
        loop {
            match operation().await {
                Ok(value) => return Ok(value),
                Err(err) => {
                    failed_attempts += 1;
                    if failed_attempts >= max_retries {
                        return Err(err);
                    }
                    tokio::time::sleep(delay).await;
                },
            }
        }
    }

    /// Runs `processor` over `items` with at most `max_concurrent` tasks in
    /// flight and returns one result per item, **in input order**.
    ///
    /// A `max_concurrent` of 0 behaves like 1. When `rate_limit` is set,
    /// every task sleeps for that duration before calling `processor`; tasks
    /// that are in flight together sleep concurrently, so this is a per-task
    /// start delay rather than a global rate limiter.
    ///
    /// Failures do not abort the batch: each item's `Result` is reported in
    /// its own slot of the returned vector.
    pub async fn process_batch_with_rate_limit<T, F, R>(
        items: Vec<T>,
        processor: F,
        max_concurrent: usize,
        rate_limit: Option<Duration>,
    ) -> Vec<Result<R>>
    where
        T: Send + 'static,
        F: Fn(T) -> Pin<Box<dyn Future<Output = Result<R>> + Send>> + Send + Sync + 'static,
        R: Send + 'static,
    {
        use futures::stream::{FuturesUnordered, StreamExt};
        use std::sync::Arc;

        let processor = Arc::new(processor);
        let mut in_flight = FuturesUnordered::new();
        // Tasks finish in arbitrary order, so every result is tagged with
        // the index of its input and the order is restored at the end.
        let mut indexed: Vec<(usize, Result<R>)> = Vec::with_capacity(items.len());

        for (index, item) in items.into_iter().enumerate() {
            // At capacity: wait for one task to finish before adding another.
            if in_flight.len() >= max_concurrent {
                if let Some(finished) = in_flight.next().await {
                    indexed.push(finished);
                }
            }

            let task_processor = Arc::clone(&processor);
            in_flight.push(async move {
                if let Some(delay) = rate_limit {
                    tokio::time::sleep(delay).await;
                }
                (index, task_processor(item).await)
            });
        }

        // Drain whatever is still running.
        while let Some(finished) = in_flight.next().await {
            indexed.push(finished);
        }

        indexed.sort_unstable_by_key(|(index, _)| *index);
        indexed.into_iter().map(|(_, outcome)| outcome).collect()
    }
}
/// Boxed, thread-safe `AsyncLanguageModel` trait object whose `Error` is `GraphRAGError`.
pub type BoxedAsyncLanguageModel =
Box<dyn AsyncLanguageModel<Error = crate::core::GraphRAGError> + Send + Sync>;
/// Boxed, thread-safe `AsyncEmbedder` trait object whose `Error` is `GraphRAGError`.
pub type BoxedAsyncEmbedder =
Box<dyn AsyncEmbedder<Error = crate::core::GraphRAGError> + Send + Sync>;
/// Boxed, thread-safe `AsyncVectorStore` trait object whose `Error` is `GraphRAGError`.
pub type BoxedAsyncVectorStore =
Box<dyn AsyncVectorStore<Error = crate::core::GraphRAGError> + Send + Sync>;
/// Boxed, thread-safe `AsyncRetriever` trait object taking `String` queries.
///
/// Note that its result type is `crate::retrieval::SearchResult`, not the
/// `SearchResult` struct defined in this file.
pub type BoxedAsyncRetriever = Box<
dyn AsyncRetriever<
Query = String,
Result = crate::retrieval::SearchResult,
Error = crate::core::GraphRAGError,
> + Send
+ Sync,
>;