Skip to main content

cognis_rag/retrievers/
query_translator.rs

//! Translate the query before retrieval.
//!
//! Common shape: take a free-text query, run it through a translator
//! (rephraser, language converter, structurer), then retrieve. This is already
//! expressible via `Runnable::pipe`; the named type exists to aid discovery.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10
11use cognis_core::{Result, Runnable, RunnableConfig};
12
13use crate::document::Document;
14
15/// `query → translator → retriever`.
16pub struct QueryTranslatorRetriever {
17    translator: Arc<dyn Runnable<String, String>>,
18    inner: Arc<dyn Runnable<String, Vec<Document>>>,
19}
20
21impl QueryTranslatorRetriever {
22    /// Build with a translator and an inner retriever.
23    pub fn new(
24        translator: Arc<dyn Runnable<String, String>>,
25        inner: Arc<dyn Runnable<String, Vec<Document>>>,
26    ) -> Self {
27        Self { translator, inner }
28    }
29}
30
31#[async_trait]
32impl Runnable<String, Vec<Document>> for QueryTranslatorRetriever {
33    async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
34        let translated = self.translator.invoke(query, config.clone()).await?;
35        self.inner.invoke(translated, config).await
36    }
37    fn name(&self) -> &str {
38        "QueryTranslatorRetriever"
39    }
40}
41
#[cfg(test)]
mod tests {
    use super::*;

    /// Test translator: upper-cases the incoming query.
    struct UpperCase;
    #[async_trait]
    impl Runnable<String, String> for UpperCase {
        async fn invoke(&self, q: String, _: RunnableConfig) -> Result<String> {
            Ok(q.to_uppercase())
        }
    }

    /// Test retriever: returns exactly one document whose content is the query it received.
    struct EchoRetriever;
    #[async_trait]
    impl Runnable<String, Vec<Document>> for EchoRetriever {
        async fn invoke(&self, q: String, _: RunnableConfig) -> Result<Vec<Document>> {
            Ok(vec![Document::new(q)])
        }
    }

    fn build() -> QueryTranslatorRetriever {
        QueryTranslatorRetriever::new(Arc::new(UpperCase), Arc::new(EchoRetriever))
    }

    #[tokio::test]
    async fn translates_then_retrieves() {
        let out = build()
            .invoke("hello".into(), RunnableConfig::default())
            .await
            .unwrap();
        // Check the length first so a regression fails with a clear assertion
        // message rather than an index-out-of-bounds panic.
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].content, "HELLO");
    }

    #[test]
    fn reports_its_name() {
        assert_eq!(build().name(), "QueryTranslatorRetriever");
    }
}