// anchor_chain/vector/opensearch_retriever.rs

1//! OpenSearchRetriever is a Node that retrieves documents from OpenSearch based on input text.
2//!
3//! The OpenSearchRetriever struct is a Node that retrieves documents from OpenSearch based on
4//! input text. It uses an embedding model to embed the input text into a vector, then queries
5//! OpenSearch using the vector. The top k documents that are most similar to the input text are
6//! returned.
7#![allow(dead_code)]
8
9use std::fmt;
10
11use async_trait::async_trait;
12use aws_config::meta::region::RegionProviderChain;
13use opensearch::http::transport::{SingleNodeConnectionPool, TransportBuilder};
14use opensearch::http::Url;
15use opensearch::{OpenSearch, SearchParts};
16use serde_json::json;
17#[cfg(feature = "tracing")]
18use tracing::instrument;
19
20use crate::error::AnchorChainError;
21use crate::models::embedding_model::EmbeddingModel;
22use crate::node::Node;
23use crate::vector::document::Document;
24
/// A Node that retrieves documents from OpenSearch based on input text.
///
/// The input text is embedded via `embedding_model`, and the resulting vector
/// is used in a k-NN query against `vector_field` across `indexes`.
#[derive(Debug)]
pub struct OpenSearchRetriever<M: EmbeddingModel> {
    // Pre-configured OpenSearch client used for all search requests.
    client: OpenSearch,
    // Model that turns input text into a query vector.
    embedding_model: M,
    // Names of the OpenSearch indexes to search.
    indexes: Vec<String>,
    // Document field that stores the k-NN vector.
    vector_field: String,
    // Number of most-similar documents to return.
    top_k: usize,
}
34
35impl<M: EmbeddingModel + fmt::Debug> OpenSearchRetriever<M> {
36    /// Creates a new OpenSearchRetrieverBuilder using default AWS credentials from the environment.
37    pub async fn new(
38        client: OpenSearch,
39        embedding_model: M,
40        indexes: &[&str],
41        vector_field: &str,
42        top_k: usize,
43    ) -> Self {
44        Self {
45            client,
46            embedding_model,
47            indexes: indexes.iter().map(|s| s.to_string()).collect(),
48            vector_field: vector_field.to_string(),
49            top_k,
50        }
51    }
52
53    /// Queries OpenSearch for the top k documents that are most similar to the input vector.
54    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
55    pub async fn vector_query(
56        &self,
57        indexes: &[String],
58        vector_field: &str,
59        top_k: usize,
60        vector: Vec<f32>,
61    ) -> Result<serde_json::Value, AnchorChainError> {
62        let indexes = indexes.iter().map(|s| s.as_str()).collect::<Vec<&str>>();
63        let response = self
64            .client
65            .search(SearchParts::Index(&indexes))
66            .from(0)
67            .size(10)
68            .body(json!({
69                "query": {
70                    "knn": {
71                        vector_field: {
72                            "vector": vector,
73                            "k": top_k,
74                        }
75                    }
76                }
77            }))
78            .send()
79            .await?;
80
81        Ok(response.json::<serde_json::Value>().await?)
82    }
83
84    /// Retrieves the top k documents from OpenSearch that are most similar to the input text.
85    ///
86    /// Uses the embedding model to embed the input text into a vector, then queries OpenSearch
87    /// using the vector.
88    #[cfg_attr(feature = "tracing", instrument)]
89    pub async fn retrieve(&self, input: &str) -> Result<Vec<Document>, AnchorChainError> {
90        let embedding = self.embedding_model.embed(input.to_string()).await?;
91        let response = self
92            .vector_query(&self.indexes, &self.vector_field, self.top_k, embedding)
93            .await?;
94        let empty_vec = Vec::new();
95        let hits = response["hits"]["hits"].as_array().unwrap_or(&empty_vec);
96        let docs = hits
97            .iter()
98            .filter_map(|doc| serde_json::from_value(doc["_source"].clone()).ok())
99            .collect();
100
101        Ok(docs)
102    }
103}
104
105#[async_trait]
106impl<M: EmbeddingModel + fmt::Debug + Send + Sync> Node for OpenSearchRetriever<M> {
107    type Input = String;
108    type Output = Vec<Document>;
109
110    /// Retrieves the top k documents from OpenSearch that are most similar to the input text.
111    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
112    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
113        self.retrieve(&input).await
114    }
115}
116
/// A builder struct for creating an OpenSearchRetriever.
///
/// The embedding model, vector field, and indexes must be supplied before
/// building; `top_k` defaults to 5.
struct OpenSearchRetrieverBuilder<M: EmbeddingModel> {
    // Optional pre-configured OpenSearch client, set via `with_client`.
    client: Option<OpenSearch>,
    // Model used to embed query text; required at build time.
    embedding_model: Option<M>,
    // Document field holding the k-NN vector; required at build time.
    vector_field: Option<String>,
    // Indexes to search; required at build time.
    indexes: Option<Vec<String>>,
    // Number of documents to retrieve (defaults to 5).
    top_k: usize,
}
125
126impl<M: EmbeddingModel> OpenSearchRetrieverBuilder<M> {
127    /// Creates a new OpenSearchRetrieverBuilder.
128    pub fn new() -> Self {
129        Self {
130            client: None,
131            embedding_model: None,
132            vector_field: None,
133            indexes: None,
134            top_k: 5usize,
135        }
136    }
137
138    /// Sets the embedding model for the OpenSearchRetriever.
139    pub fn with_embedding_model(mut self, embedding_model: M) -> Self {
140        self.embedding_model = Some(embedding_model);
141        self
142    }
143
144    /// Sets the vector field for the OpenSearchRetriever.
145    pub fn with_vector_field(mut self, vector_field: String) -> Self {
146        self.vector_field = Some(vector_field);
147        self
148    }
149
150    /// Sets the indexes for the OpenSearchRetriever.
151    pub fn with_indexes(mut self, indexes: Vec<String>) -> Self {
152        self.indexes = Some(indexes);
153        self
154    }
155
156    /// Sets the top k for the OpenSearchRetriever.
157    pub fn with_top_k(mut self, top_k: usize) -> Self {
158        self.top_k = top_k;
159        self
160    }
161
162    /// Sets the OpenSearch client for the OpenSearchRetriever.
163    pub fn with_client(mut self, client: OpenSearch) -> Self {
164        self.client = Some(client);
165        self
166    }
167
168    /// Builds an OpenSearchRetriever from the provided configuration.
169    pub async fn build(self, base_url: &str) -> Result<OpenSearchRetriever<M>, AnchorChainError> {
170        let embedding_model = self
171            .embedding_model
172            .ok_or(AnchorChainError::InvalidInputError(
173                "No embedding model provided".to_string(),
174            ))?;
175        let vector_field = self
176            .vector_field
177            .ok_or(AnchorChainError::InvalidInputError(
178                "No vector field provided".to_string(),
179            ))?;
180        let indexes = self.indexes.ok_or(AnchorChainError::InvalidInputError(
181            "No indexes provided".to_string(),
182        ))?;
183
184        let url = Url::parse(base_url).map_err(|e| AnchorChainError::ParseError(e.to_string()))?;
185        let service_name = "es";
186        let conn_pool = SingleNodeConnectionPool::new(url);
187        let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
188        let aws_config = aws_config::from_env()
189            .region(region_provider)
190            .load()
191            .await
192            .clone();
193        let transport = TransportBuilder::new(conn_pool)
194            .auth(aws_config.clone().try_into()?)
195            .service_name(service_name)
196            .build()
197            .map_err(|e| AnchorChainError::OpenSearchError(e.into()))?;
198        let client = OpenSearch::new(transport);
199
200        Ok(OpenSearchRetriever {
201            client,
202            embedding_model,
203            vector_field,
204            indexes,
205            top_k: self.top_k,
206        })
207    }
208}