use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context, Props};
use tokio::sync::oneshot;
use atomr_accel_cuda::error::GpuError;
use crate::embedding_cache::EmbeddingCacheMsg;
use crate::vector_index::VectorIndexMsg;
/// A single retrieval-augmented-generation request.
#[derive(Debug, Clone)]
pub struct RagQuery {
    /// Query text; it is embedded, used as the embedding-cache key, and
    /// forwarded to the LLM callback.
    pub text: String,
    /// Requested number of search results. The pipeline raises this to at
    /// least 1 before searching.
    pub top_k: usize,
}
/// Result of a successful RAG query.
#[derive(Debug, Clone)]
pub struct RagAnswer {
    /// Text produced by the LLM callback.
    pub answer: String,
    /// Ids of the vector-index hits that were handed to the LLM callback,
    /// in the order the index returned them.
    pub sources: Vec<u64>,
    /// `true` when the query embedding came from the embedding cache rather
    /// than being computed for this request.
    pub embedding_was_cached: bool,
}
/// Synchronous callback that turns text into an embedding vector.
///
/// Implementors must be `Send + Sync + 'static` so one instance can be
/// shared behind an `Arc` across queries.
pub trait EmbeddingFn: Send + Sync + 'static {
    /// Computes the embedding for `text`.
    ///
    /// # Errors
    /// Returns a [`GpuError`] when the embedding cannot be produced.
    fn embed(&self, text: &str) -> Result<Vec<f32>, GpuError>;
}
/// Lets any thread-safe `Fn(&str) -> Result<Vec<f32>, GpuError>` closure or
/// function be used directly as an [`EmbeddingFn`].
impl<F> EmbeddingFn for F
where
    F: Fn(&str) -> Result<Vec<f32>, GpuError> + Send + Sync + 'static,
{
    /// Forwards to the wrapped callable.
    fn embed(&self, text: &str) -> Result<Vec<f32>, GpuError> {
        (self)(text)
    }
}
/// Synchronous callback that produces the final answer for a query.
///
/// Implementors must be `Send + Sync + 'static` so one instance can be
/// shared behind an `Arc` across queries.
pub trait LlmFn: Send + Sync + 'static {
    /// Generates an answer for `query`, given `contexts`, the ids of the
    /// retrieved vector-index entries.
    ///
    /// # Errors
    /// Returns a [`GpuError`] when no answer can be produced.
    fn answer(&self, query: &str, contexts: &[u64]) -> Result<String, GpuError>;
}
/// Lets any thread-safe `Fn(&str, &[u64]) -> Result<String, GpuError>`
/// closure or function be used directly as an [`LlmFn`].
impl<F> LlmFn for F
where
    F: Fn(&str, &[u64]) -> Result<String, GpuError> + Send + Sync + 'static,
{
    /// Forwards to the wrapped callable.
    fn answer(&self, query: &str, contexts: &[u64]) -> Result<String, GpuError> {
        (self)(query, contexts)
    }
}
/// Collaborators and settings used by the RAG pipeline.
///
/// The config is cloned once per actor instance and once per handled query,
/// so every field is a handle or a `Copy` value. `Clone` is derived: it
/// clones each field in turn, which is exactly what the previous
/// hand-written implementation did, and it cannot fall out of sync when a
/// field is added.
#[derive(Clone)]
pub struct RagConfig {
    /// Computes an embedding for the query text on a cache miss.
    pub embedding: Arc<dyn EmbeddingFn>,
    /// Actor queried (`Get`) and updated (`Insert`) with embeddings keyed by
    /// the bytes of the query text.
    pub embedding_cache: ActorRef<EmbeddingCacheMsg>,
    /// Actor that answers `Search` requests for a query embedding.
    pub vector_index: ActorRef<VectorIndexMsg>,
    /// Produces the final answer from the query text and the retrieved ids.
    pub llm: Arc<dyn LlmFn>,
    /// Timeout passed to each individual request sent to the cache and index
    /// actors (it is not a deadline for the whole query).
    pub timeout: Duration,
}
/// Messages accepted by the [`RagPipeline`] actor.
pub enum RagMsg {
    /// Run a query and deliver its outcome on `reply`.
    Query {
        /// The query to run.
        q: RagQuery,
        /// One-shot channel that receives the answer or the error that
        /// stopped the pipeline.
        reply: oneshot::Sender<Result<RagAnswer, GpuError>>,
    },
}
/// Actor that serves [`RagMsg::Query`] requests using a fixed [`RagConfig`].
pub struct RagPipeline {
    /// Configuration cloned into the task spawned for each query.
    cfg: RagConfig,
}
impl RagPipeline {
    /// Builds the `Props` used to create a `RagPipeline` actor.
    ///
    /// The factory closure owns `cfg` and hands a fresh clone to every
    /// actor instance it constructs.
    pub fn props(cfg: RagConfig) -> Props<Self> {
        Props::create(move || Self { cfg: cfg.clone() })
    }
}
#[async_trait]
impl Actor for RagPipeline {
    type Msg = RagMsg;

    /// Handles one message by running the query on its own tokio task.
    ///
    /// The handler returns as soon as the task is spawned; the task's join
    /// handle is dropped, so the task runs detached.
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: RagMsg) {
        // `RagMsg` has a single variant, so this pattern is irrefutable. Adding
        // a variant turns this line into a compile error that forces an update.
        let RagMsg::Query { q, reply } = msg;
        let cfg = self.cfg.clone();
        tokio::spawn(async move {
            // A send error only means the requester dropped its receiver;
            // there is nobody left to notify, so it is ignored.
            let _ = reply.send(run_rag(cfg, q).await);
        });
    }
}
async fn run_rag(cfg: RagConfig, q: RagQuery) -> Result<RagAnswer, GpuError> {
let key = q.text.as_bytes().to_vec();
let cached: Option<Vec<f32>> = cfg
.embedding_cache
.ask_with(
move |tx| EmbeddingCacheMsg::Get { key, reply: tx },
cfg.timeout,
)
.await
.map_err(|e| GpuError::Unrecoverable(format!("rag: embed cache get: {e}")))?;
let (embedding, was_cached) = match cached {
Some(v) => (v, true),
None => {
let v = cfg.embedding.embed(&q.text)?;
let key = q.text.as_bytes().to_vec();
let v_clone = v.clone();
let _ = cfg
.embedding_cache
.ask_with(
move |tx| EmbeddingCacheMsg::Insert {
key,
value: v_clone,
reply: tx,
},
cfg.timeout,
)
.await;
(v, false)
}
};
let top_k = q.top_k.max(1);
let scored: Vec<(u64, f32)> = cfg
.vector_index
.ask_with(
move |tx| VectorIndexMsg::Search {
query: embedding,
top_k,
reply: tx,
},
cfg.timeout,
)
.await
.map_err(|e| GpuError::Unrecoverable(format!("rag: vec search: {e}")))??;
let sources: Vec<u64> = scored.into_iter().map(|(id, _)| id).collect();
let answer = cfg.llm.answer(&q.text, &sources)?;
Ok(RagAnswer {
answer,
sources,
embedding_was_cached: was_cached,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding_cache::{EmbeddingCache, EmbeddingCacheConfig};
    use crate::vector_index::{CpuVectorIndex, VectorEntry, VectorIndexMsg};
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;

    /// Sends one query to the pipeline and unwraps its answer, waiting at
    /// most five seconds for the reply.
    async fn query(rag: &ActorRef<RagMsg>, text: &str, top_k: usize) -> RagAnswer {
        let (reply, pending) = oneshot::channel();
        rag.tell(RagMsg::Query {
            q: RagQuery {
                text: text.into(),
                top_k,
            },
            reply,
        });
        tokio::time::timeout(Duration::from_secs(5), pending)
            .await
            .unwrap()
            .unwrap()
            .unwrap()
    }

    /// Drives the whole pipeline twice with the same text: the first run
    /// must compute the embedding, the second must find it in the cache.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn end_to_end_rag_query() {
        let sys = ActorSystem::create("rag-test", Config::empty())
            .await
            .unwrap();

        // Collaborating actors: a small embedding cache and a 3-dimensional
        // CPU vector index.
        let cache_props = EmbeddingCache::props(EmbeddingCacheConfig {
            capacity_entries: 8,
            embedding_dim: 3,
        });
        let cache = sys.actor_of(cache_props, "cache").unwrap();
        let index = sys.actor_of(CpuVectorIndex::props(3), "idx").unwrap();

        // Seed the index, waiting for each insert to be acknowledged.
        let seed = [
            (1u64, vec![1.0, 0.0, 0.0]),
            (2, vec![0.0, 1.0, 0.0]),
            (3, vec![0.7, 0.7, 0.0]),
        ];
        for (id, embedding) in seed {
            let (ack_tx, ack_rx) = oneshot::channel();
            index.tell(VectorIndexMsg::Insert {
                entry: VectorEntry { id, embedding },
                reply: ack_tx,
            });
            tokio::time::timeout(Duration::from_secs(2), ack_rx)
                .await
                .unwrap()
                .unwrap()
                .unwrap();
        }

        // Deterministic stand-ins for the embedding model and the LLM.
        let embed_fn: Arc<dyn EmbeddingFn> = Arc::new(|text: &str| {
            let vector = match text {
                "alpha" => vec![1.0, 0.0, 0.0],
                "beta" => vec![0.0, 1.0, 0.0],
                _ => vec![0.5, 0.5, 0.5],
            };
            Ok(vector)
        });
        let llm: Arc<dyn LlmFn> =
            Arc::new(|q: &str, ctx: &[u64]| Ok(format!("answered '{q}' from {ctx:?}")));

        let rag_props = RagPipeline::props(RagConfig {
            embedding: embed_fn,
            embedding_cache: cache,
            vector_index: index,
            llm,
            timeout: Duration::from_secs(2),
        });
        let rag = sys.actor_of(rag_props, "rag").unwrap();

        // First query: cache miss, best hit is entry 1.
        let first = query(&rag, "alpha", 2).await;
        assert!(first.answer.contains("alpha"));
        assert_eq!(first.sources[0], 1);
        assert!(!first.embedding_was_cached);

        // Same text again: the embedding must now come from the cache.
        let second = query(&rag, "alpha", 2).await;
        assert!(second.embedding_was_cached);

        sys.terminate().await;
    }
}