1mod client;
25
26pub use client::{
28 DeleteByIdsRequest, DeleteResult, ListVectorsResult, QueryRequest, QueryResult, ReturnMetadata,
29 UpsertRequest, UpsertResult, VectorIdEntry, VectorInput, VectorMatch, VectorizeClient,
30 VectorizeError, VectorizeFilter,
31};
32
33use client::{QueryRequest as ApiQueryRequest, VectorInput as ApiVectorInput};
34use rig::embeddings::EmbeddingModel;
35use rig::vector_store::request::VectorSearchRequest;
36use rig::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
37use rig::{Embed, OneOrMany, embeddings::Embedding};
38use serde::{Deserialize, Serialize};
39use uuid::Uuid;
40
41impl From<VectorizeError> for VectorStoreError {
42 fn from(err: VectorizeError) -> Self {
43 VectorStoreError::DatastoreError(Box::new(err))
44 }
45}
46
47#[derive(Debug, Clone)]
52pub struct VectorizeVectorStore<M> {
53 model: M,
55 client: VectorizeClient,
57}
58
59impl<M> VectorizeVectorStore<M> {
60 pub fn new(
68 model: M,
69 account_id: impl Into<String>,
70 index_name: impl Into<String>,
71 api_token: impl Into<String>,
72 ) -> Self {
73 Self {
74 model,
75 client: VectorizeClient::new(account_id, index_name, api_token),
76 }
77 }
78}
79
80impl<M> VectorStoreIndex for VectorizeVectorStore<M>
81where
82 M: EmbeddingModel + Sync + Send,
83{
84 type Filter = VectorizeFilter;
85
86 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
87 &self,
88 req: VectorSearchRequest<Self::Filter>,
89 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
90 if let Some(filter) = req.filter() {
91 filter.validate()?;
92 }
93
94 let embedding = self.model.embed_text(req.query()).await?;
95
96 let query_request = ApiQueryRequest {
97 vector: embedding.vec,
98 top_k: req.samples(),
99 return_values: Some(false),
100 return_metadata: Some(ReturnMetadata::All),
101 filter: req.filter().as_ref().map(|f| f.clone().into_inner()),
102 };
103
104 let result = self.client.query(query_request).await?;
105
106 let results = result
108 .matches
109 .into_iter()
110 .filter(|m| req.threshold().is_none_or(|t| m.score >= t))
111 .map(|m| {
112 let metadata = m.metadata.unwrap_or(serde_json::Value::Null);
113 let doc: T = serde_json::from_value(metadata)?;
114 Ok((m.score, m.id, doc))
115 })
116 .collect::<Result<Vec<_>, serde_json::Error>>()?;
117
118 Ok(results)
119 }
120
121 async fn top_n_ids(
122 &self,
123 req: VectorSearchRequest<Self::Filter>,
124 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
125 if let Some(filter) = req.filter() {
126 filter.validate()?;
127 }
128
129 let embedding = self.model.embed_text(req.query()).await?;
130
131 let query_request = ApiQueryRequest {
132 vector: embedding.vec,
133 top_k: req.samples(),
134 return_values: Some(false),
135 return_metadata: Some(ReturnMetadata::None),
136 filter: req.filter().as_ref().map(|f| f.clone().into_inner()),
137 };
138
139 let result = self.client.query(query_request).await?;
140
141 let results = result
143 .matches
144 .into_iter()
145 .filter(|m| req.threshold().is_none_or(|t| m.score >= t))
146 .map(|m| (m.score, m.id))
147 .collect();
148
149 Ok(results)
150 }
151}
152
153impl<M> InsertDocuments for VectorizeVectorStore<M>
154where
155 M: EmbeddingModel + Sync + Send,
156{
157 async fn insert_documents<Doc: Serialize + Embed + Send>(
158 &self,
159 documents: Vec<(Doc, OneOrMany<Embedding>)>,
160 ) -> Result<(), VectorStoreError> {
161 let mut vectors: Vec<ApiVectorInput> = Vec::new();
162
163 for (doc, embeddings) in documents {
164 let metadata = serde_json::to_value(&doc)?;
165
166 for embedding in embeddings {
167 vectors.push(ApiVectorInput {
168 id: Uuid::new_v4().to_string(),
169 values: embedding.vec,
170 metadata: Some(metadata.clone()),
171 namespace: None,
172 });
173 }
174 }
175
176 if vectors.is_empty() {
177 return Ok(());
178 }
179
180 tracing::debug!("Upserting {} vectors to Vectorize", vectors.len());
181
182 const BATCH_SIZE: usize = 1000;
183
184 for batch in vectors.chunks(BATCH_SIZE) {
185 let request = UpsertRequest {
186 vectors: batch.to_vec(),
187 };
188
189 self.client.upsert(request).await?;
190 }
191
192 Ok(())
193 }
194}