anchor_chain/vector/
opensearch_indexer.rs1use std::fmt;
6
7use async_trait::async_trait;
8use opensearch::http::request::JsonBody;
9use opensearch::indices::{IndicesCreateParts, IndicesExistsParts};
10use opensearch::{BulkParts, OpenSearch};
11use serde_json::json;
12#[cfg(feature = "tracing")]
13use tracing::instrument;
14
15use crate::error::AnchorChainError;
16use crate::models::embedding_model::EmbeddingModel;
17use crate::node::Node;
18use crate::vector::document::Document;
19
20#[derive(Debug, Clone)]
22pub struct OpenSearchIndexer<M: EmbeddingModel> {
23 client: OpenSearch,
24 embedding_model: M,
25 index: String,
26 vector_field: String,
27}
28
29impl<M: EmbeddingModel + fmt::Debug> OpenSearchIndexer<M> {
30 #[allow(dead_code)]
32 pub fn new(client: OpenSearch, embedding_model: M, index: &str, vector_field: &str) -> Self {
33 Self {
34 client,
35 embedding_model,
36 index: index.to_string(),
37 vector_field: vector_field.to_string(),
38 }
39 }
40
41 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
43 async fn does_index_exist(&self, index: &str) -> Result<bool, AnchorChainError> {
44 let response = self
45 .client
46 .indices()
47 .exists(IndicesExistsParts::Index(&[index]))
48 .send()
49 .await?;
50 Ok(response.status_code().is_success())
51 }
52
53 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
55 pub async fn create_index(
56 &self,
57 index: &str,
58 vector_field_name: &str,
59 ) -> Result<(), AnchorChainError> {
60 let body = json!({
61 "settings": {
62 "index.knn": true
63 },
64 "mappings": {
65 "properties": {
66 vector_field_name: {
67 "type": "knn_vector",
68 "dimension": self.embedding_model.dimensions(),
69 "method": {
70 "name": "hnsw",
71 "space_type": "cosinesimil",
72 "engine": "nmslib",
73 "parameters": {
74 "ef_construction": 128,
75 "m": 16
76 }
77 }
78 }
79 }
80 }
81 });
82
83 let response = self
84 .client
85 .indices()
86 .create(IndicesCreateParts::Index(index))
87 .body(body)
88 .send()
89 .await?;
90
91 if response.status_code().is_success() {
92 Ok(())
93 } else {
94 Err(AnchorChainError::OpenSearchInternalError(
95 response.text().await?,
96 ))
97 }
98 }
99
100 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
103 pub async fn index_documents(
104 &self,
105 mut docs: Vec<Document>,
106 index: &str,
107 ) -> Result<Vec<Document>, AnchorChainError> {
108 let mut operations: Vec<JsonBody<_>> = Vec::with_capacity(docs.len() * 2);
109
110 for doc in &mut docs {
111 if doc.embedding.is_none() {
112 doc.embedding = Some(
113 self.embedding_model
114 .embed(doc.text.clone())
115 .await
116 .map_err(|e| AnchorChainError::ModelError(e.to_string()))?,
117 );
118 }
119
120 operations.push(
121 json!({
122 "index": {
123 "_index": index,
124 "_id": doc.id,
125 }
126 })
127 .into(),
128 );
129
130 let doc_json = serde_json::to_value(&doc)?;
131 operations.push(doc_json.into());
132 }
133
134 let response = self
135 .client
136 .bulk(BulkParts::Index(index))
137 .body(operations)
138 .send()
139 .await
140 .map_err(AnchorChainError::OpenSearchError)?;
141
142 if response.status_code().is_success() {
143 Ok(docs)
144 } else {
145 Err(AnchorChainError::OpenSearchInternalError(
146 response.text().await?,
147 ))
148 }
149 }
150}
151
152#[async_trait]
153impl<M: EmbeddingModel + fmt::Debug + Send + Sync> Node for OpenSearchIndexer<M> {
154 type Input = Vec<Document>;
155 type Output = Vec<Document>;
156
157 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
163 async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
164 if !self.does_index_exist(&self.index).await? {
165 self.create_index(&self.index, &self.vector_field).await?;
166 }
167
168 self.index_documents(input, &self.index).await
169 }
170}