anchor_chain/vector/
opensearch_retriever.rs1#![allow(dead_code)]
8
9use std::fmt;
10
11use async_trait::async_trait;
12use aws_config::meta::region::RegionProviderChain;
13use opensearch::http::transport::{SingleNodeConnectionPool, TransportBuilder};
14use opensearch::http::Url;
15use opensearch::{OpenSearch, SearchParts};
16use serde_json::json;
17#[cfg(feature = "tracing")]
18use tracing::instrument;
19
20use crate::error::AnchorChainError;
21use crate::models::embedding_model::EmbeddingModel;
22use crate::node::Node;
23use crate::vector::document::Document;
24
25#[derive(Debug)]
27pub struct OpenSearchRetriever<M: EmbeddingModel> {
28 client: OpenSearch,
29 embedding_model: M,
30 indexes: Vec<String>,
31 vector_field: String,
32 top_k: usize,
33}
34
35impl<M: EmbeddingModel + fmt::Debug> OpenSearchRetriever<M> {
36 pub async fn new(
38 client: OpenSearch,
39 embedding_model: M,
40 indexes: &[&str],
41 vector_field: &str,
42 top_k: usize,
43 ) -> Self {
44 Self {
45 client,
46 embedding_model,
47 indexes: indexes.iter().map(|s| s.to_string()).collect(),
48 vector_field: vector_field.to_string(),
49 top_k,
50 }
51 }
52
53 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
55 pub async fn vector_query(
56 &self,
57 indexes: &[String],
58 vector_field: &str,
59 top_k: usize,
60 vector: Vec<f32>,
61 ) -> Result<serde_json::Value, AnchorChainError> {
62 let indexes = indexes.iter().map(|s| s.as_str()).collect::<Vec<&str>>();
63 let response = self
64 .client
65 .search(SearchParts::Index(&indexes))
66 .from(0)
67 .size(10)
68 .body(json!({
69 "query": {
70 "knn": {
71 vector_field: {
72 "vector": vector,
73 "k": top_k,
74 }
75 }
76 }
77 }))
78 .send()
79 .await?;
80
81 Ok(response.json::<serde_json::Value>().await?)
82 }
83
84 #[cfg_attr(feature = "tracing", instrument)]
89 pub async fn retrieve(&self, input: &str) -> Result<Vec<Document>, AnchorChainError> {
90 let embedding = self.embedding_model.embed(input.to_string()).await?;
91 let response = self
92 .vector_query(&self.indexes, &self.vector_field, self.top_k, embedding)
93 .await?;
94 let empty_vec = Vec::new();
95 let hits = response["hits"]["hits"].as_array().unwrap_or(&empty_vec);
96 let docs = hits
97 .iter()
98 .filter_map(|doc| serde_json::from_value(doc["_source"].clone()).ok())
99 .collect();
100
101 Ok(docs)
102 }
103}
104
105#[async_trait]
106impl<M: EmbeddingModel + fmt::Debug + Send + Sync> Node for OpenSearchRetriever<M> {
107 type Input = String;
108 type Output = Vec<Document>;
109
110 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
112 async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
113 self.retrieve(&input).await
114 }
115}
116
117struct OpenSearchRetrieverBuilder<M: EmbeddingModel> {
119 client: Option<OpenSearch>,
120 embedding_model: Option<M>,
121 vector_field: Option<String>,
122 indexes: Option<Vec<String>>,
123 top_k: usize,
124}
125
126impl<M: EmbeddingModel> OpenSearchRetrieverBuilder<M> {
127 pub fn new() -> Self {
129 Self {
130 client: None,
131 embedding_model: None,
132 vector_field: None,
133 indexes: None,
134 top_k: 5usize,
135 }
136 }
137
138 pub fn with_embedding_model(mut self, embedding_model: M) -> Self {
140 self.embedding_model = Some(embedding_model);
141 self
142 }
143
144 pub fn with_vector_field(mut self, vector_field: String) -> Self {
146 self.vector_field = Some(vector_field);
147 self
148 }
149
150 pub fn with_indexes(mut self, indexes: Vec<String>) -> Self {
152 self.indexes = Some(indexes);
153 self
154 }
155
156 pub fn with_top_k(mut self, top_k: usize) -> Self {
158 self.top_k = top_k;
159 self
160 }
161
162 pub fn with_client(mut self, client: OpenSearch) -> Self {
164 self.client = Some(client);
165 self
166 }
167
168 pub async fn build(self, base_url: &str) -> Result<OpenSearchRetriever<M>, AnchorChainError> {
170 let embedding_model = self
171 .embedding_model
172 .ok_or(AnchorChainError::InvalidInputError(
173 "No embedding model provided".to_string(),
174 ))?;
175 let vector_field = self
176 .vector_field
177 .ok_or(AnchorChainError::InvalidInputError(
178 "No vector field provided".to_string(),
179 ))?;
180 let indexes = self.indexes.ok_or(AnchorChainError::InvalidInputError(
181 "No indexes provided".to_string(),
182 ))?;
183
184 let url = Url::parse(base_url).map_err(|e| AnchorChainError::ParseError(e.to_string()))?;
185 let service_name = "es";
186 let conn_pool = SingleNodeConnectionPool::new(url);
187 let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
188 let aws_config = aws_config::from_env()
189 .region(region_provider)
190 .load()
191 .await
192 .clone();
193 let transport = TransportBuilder::new(conn_pool)
194 .auth(aws_config.clone().try_into()?)
195 .service_name(service_name)
196 .build()
197 .map_err(|e| AnchorChainError::OpenSearchError(e.into()))?;
198 let client = OpenSearch::new(transport);
199
200 Ok(OpenSearchRetriever {
201 client,
202 embedding_model,
203 vector_field,
204 indexes,
205 top_k: self.top_k,
206 })
207 }
208}