anchor_chain/vector/
opensearch_indexer.rs

1//! A node for indexing documents into OpenSearch.
2//!
3//! This node indexes a list of documents into OpenSearch. If the index doesn't exist, it is created
4//! with the default settings. Otherwise, the documents are indexed into the existing index.
5use std::fmt;
6
7use async_trait::async_trait;
8use opensearch::http::request::JsonBody;
9use opensearch::indices::{IndicesCreateParts, IndicesExistsParts};
10use opensearch::{BulkParts, OpenSearch};
11use serde_json::json;
12#[cfg(feature = "tracing")]
13use tracing::instrument;
14
15use crate::error::AnchorChainError;
16use crate::models::embedding_model::EmbeddingModel;
17use crate::node::Node;
18use crate::vector::document::Document;
19
20/// A node for indexing documents into OpenSearch.
21#[derive(Debug, Clone)]
22pub struct OpenSearchIndexer<M: EmbeddingModel> {
23    client: OpenSearch,
24    embedding_model: M,
25    index: String,
26    vector_field: String,
27}
28
29impl<M: EmbeddingModel + fmt::Debug> OpenSearchIndexer<M> {
30    /// Creates a new `OpenSearchIndexer` with the specified OpenSearch client, embedding model,
31    #[allow(dead_code)]
32    pub fn new(client: OpenSearch, embedding_model: M, index: &str, vector_field: &str) -> Self {
33        Self {
34            client,
35            embedding_model,
36            index: index.to_string(),
37            vector_field: vector_field.to_string(),
38        }
39    }
40
41    /// Checks if an index with the specified name exists in OpenSearch.
42    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
43    async fn does_index_exist(&self, index: &str) -> Result<bool, AnchorChainError> {
44        let response = self
45            .client
46            .indices()
47            .exists(IndicesExistsParts::Index(&[index]))
48            .send()
49            .await?;
50        Ok(response.status_code().is_success())
51    }
52
53    /// Creates a vector index in OpenSearch with the specified name using default settings.
54    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
55    pub async fn create_index(
56        &self,
57        index: &str,
58        vector_field_name: &str,
59    ) -> Result<(), AnchorChainError> {
60        let body = json!({
61            "settings": {
62                "index.knn": true
63            },
64            "mappings": {
65                "properties": {
66                    vector_field_name: {
67                        "type": "knn_vector",
68                        "dimension": self.embedding_model.dimensions(),
69                        "method": {
70                            "name": "hnsw",
71                            "space_type": "cosinesimil",
72                            "engine": "nmslib",
73                            "parameters": {
74                                "ef_construction": 128,
75                                "m": 16
76                            }
77                        }
78                    }
79                }
80            }
81        });
82
83        let response = self
84            .client
85            .indices()
86            .create(IndicesCreateParts::Index(index))
87            .body(body)
88            .send()
89            .await?;
90
91        if response.status_code().is_success() {
92            Ok(())
93        } else {
94            Err(AnchorChainError::OpenSearchInternalError(
95                response.text().await?,
96            ))
97        }
98    }
99
100    /// Automatically indexes a list of documents. It embeds the text into a vector if not already done,
101    /// then indexes the entire document into OpenSearch.
102    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
103    pub async fn index_documents(
104        &self,
105        mut docs: Vec<Document>,
106        index: &str,
107    ) -> Result<Vec<Document>, AnchorChainError> {
108        let mut operations: Vec<JsonBody<_>> = Vec::with_capacity(docs.len() * 2);
109
110        for doc in &mut docs {
111            if doc.embedding.is_none() {
112                doc.embedding = Some(
113                    self.embedding_model
114                        .embed(doc.text.clone())
115                        .await
116                        .map_err(|e| AnchorChainError::ModelError(e.to_string()))?,
117                );
118            }
119
120            operations.push(
121                json!({
122                    "index": {
123                        "_index": index,
124                        "_id": doc.id,
125                    }
126                })
127                .into(),
128            );
129
130            let doc_json = serde_json::to_value(&doc)?;
131            operations.push(doc_json.into());
132        }
133
134        let response = self
135            .client
136            .bulk(BulkParts::Index(index))
137            .body(operations)
138            .send()
139            .await
140            .map_err(AnchorChainError::OpenSearchError)?;
141
142        if response.status_code().is_success() {
143            Ok(docs)
144        } else {
145            Err(AnchorChainError::OpenSearchInternalError(
146                response.text().await?,
147            ))
148        }
149    }
150}
151
152#[async_trait]
153impl<M: EmbeddingModel + fmt::Debug + Send + Sync> Node for OpenSearchIndexer<M> {
154    type Input = Vec<Document>;
155    type Output = Vec<Document>;
156
157    /// Indexes a list of documents into OpenSearch.
158    ///
159    /// If the index doesn't exist, it is created with the default settings. Otherwise,
160    /// the documents are indexed into the existing index.
161
162    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
163    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
164        if !self.does_index_exist(&self.index).await? {
165            self.create_index(&self.index, &self.vector_field).await?;
166        }
167
168        self.index_documents(input, &self.index).await
169    }
170}