1mod client;
24
25pub use client::{
27 DeleteByIdsRequest, DeleteResult, ListVectorsResult, QueryRequest, QueryResult, ReturnMetadata,
28 UpsertRequest, UpsertResult, VectorIdEntry, VectorInput, VectorMatch, VectorizeClient,
29 VectorizeError, VectorizeFilter,
30};
31
32use client::{QueryRequest as ApiQueryRequest, VectorInput as ApiVectorInput};
33use rig::embeddings::EmbeddingModel;
34use rig::vector_store::request::VectorSearchRequest;
35use rig::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
36use rig::{Embed, OneOrMany, embeddings::Embedding};
37use serde::{Deserialize, Serialize};
38use uuid::Uuid;
39
40impl From<VectorizeError> for VectorStoreError {
41 fn from(err: VectorizeError) -> Self {
42 VectorStoreError::DatastoreError(Box::new(err))
43 }
44}
45
/// Vector store that pairs a model `M` with a [`VectorizeClient`].
///
/// The trait implementations in this file use `model` to embed query text and
/// `client` to send query and upsert requests.
#[derive(Debug, Clone)]
pub struct VectorizeVectorStore<M> {
    /// Model whose `embed_text` turns a search query into a vector.
    model: M,
    /// Client through which query and upsert requests are issued.
    client: VectorizeClient,
}
57
58impl<M> VectorizeVectorStore<M> {
59 pub fn new(
67 model: M,
68 account_id: impl Into<String>,
69 index_name: impl Into<String>,
70 api_token: impl Into<String>,
71 ) -> Self {
72 Self {
73 model,
74 client: VectorizeClient::new(account_id, index_name, api_token),
75 }
76 }
77}
78
79impl<M> VectorStoreIndex for VectorizeVectorStore<M>
80where
81 M: EmbeddingModel + Sync + Send,
82{
83 type Filter = VectorizeFilter;
84
85 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
86 &self,
87 req: VectorSearchRequest<Self::Filter>,
88 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
89 if let Some(filter) = req.filter() {
90 filter.validate()?;
91 }
92
93 let embedding = self.model.embed_text(req.query()).await?;
94
95 let query_request = ApiQueryRequest {
96 vector: embedding.vec,
97 top_k: req.samples(),
98 return_values: Some(false),
99 return_metadata: Some(ReturnMetadata::All),
100 filter: req.filter().as_ref().map(|f| f.clone().into_inner()),
101 };
102
103 let result = self.client.query(query_request).await?;
104
105 let results = result
107 .matches
108 .into_iter()
109 .filter(|m| req.threshold().is_none_or(|t| m.score >= t))
110 .map(|m| {
111 let metadata = m.metadata.unwrap_or(serde_json::Value::Null);
112 let doc: T = serde_json::from_value(metadata)?;
113 Ok((m.score, m.id, doc))
114 })
115 .collect::<Result<Vec<_>, serde_json::Error>>()?;
116
117 Ok(results)
118 }
119
120 async fn top_n_ids(
121 &self,
122 req: VectorSearchRequest<Self::Filter>,
123 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
124 if let Some(filter) = req.filter() {
125 filter.validate()?;
126 }
127
128 let embedding = self.model.embed_text(req.query()).await?;
129
130 let query_request = ApiQueryRequest {
131 vector: embedding.vec,
132 top_k: req.samples(),
133 return_values: Some(false),
134 return_metadata: Some(ReturnMetadata::None),
135 filter: req.filter().as_ref().map(|f| f.clone().into_inner()),
136 };
137
138 let result = self.client.query(query_request).await?;
139
140 let results = result
142 .matches
143 .into_iter()
144 .filter(|m| req.threshold().is_none_or(|t| m.score >= t))
145 .map(|m| (m.score, m.id))
146 .collect();
147
148 Ok(results)
149 }
150}
151
152impl<M> InsertDocuments for VectorizeVectorStore<M>
153where
154 M: EmbeddingModel + Sync + Send,
155{
156 async fn insert_documents<Doc: Serialize + Embed + Send>(
157 &self,
158 documents: Vec<(Doc, OneOrMany<Embedding>)>,
159 ) -> Result<(), VectorStoreError> {
160 let mut vectors: Vec<ApiVectorInput> = Vec::new();
161
162 for (doc, embeddings) in documents {
163 let metadata = serde_json::to_value(&doc)?;
164
165 for embedding in embeddings {
166 vectors.push(ApiVectorInput {
167 id: Uuid::new_v4().to_string(),
168 values: embedding.vec,
169 metadata: Some(metadata.clone()),
170 namespace: None,
171 });
172 }
173 }
174
175 if vectors.is_empty() {
176 return Ok(());
177 }
178
179 tracing::debug!("Upserting {} vectors to Vectorize", vectors.len());
180
181 const BATCH_SIZE: usize = 1000;
182
183 for batch in vectors.chunks(BATCH_SIZE) {
184 let request = UpsertRequest {
185 vectors: batch.to_vec(),
186 };
187
188 self.client.upsert(request).await?;
189 }
190
191 Ok(())
192 }
193}