Skip to main content

rig_vectorize/
lib.rs

//! Cloudflare Vectorize integration for the Rig framework.
//!
//! This crate provides a vector store implementation using Cloudflare Vectorize,
//! a globally distributed vector database built for AI applications.
//!
//! # Example
//!
//! ```ignore
//! use rig::client::ProviderClient;
//! use rig::providers::openai;
//! use rig_vectorize::VectorizeVectorStore;
//!
//! let openai = openai::Client::from_env()?;
//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_SMALL);
//!
//! let vector_store = VectorizeVectorStore::new(
//!     embedding_model,
//!     "your-account-id",
//!     "your-index-name",
//!     std::env::var("CLOUDFLARE_API_TOKEN").unwrap(),
//! );
//! ```
23
24mod client;
25
26// Re-export client types
27pub use client::{
28    DeleteByIdsRequest, DeleteResult, ListVectorsResult, QueryRequest, QueryResult, ReturnMetadata,
29    UpsertRequest, UpsertResult, VectorIdEntry, VectorInput, VectorMatch, VectorizeClient,
30    VectorizeError, VectorizeFilter,
31};
32
33use client::{QueryRequest as ApiQueryRequest, VectorInput as ApiVectorInput};
34use rig::embeddings::EmbeddingModel;
35use rig::vector_store::request::VectorSearchRequest;
36use rig::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
37use rig::{Embed, OneOrMany, embeddings::Embedding};
38use serde::{Deserialize, Serialize};
39use uuid::Uuid;
40
41impl From<VectorizeError> for VectorStoreError {
42    fn from(err: VectorizeError) -> Self {
43        VectorStoreError::DatastoreError(Box::new(err))
44    }
45}
46
47/// A vector store backed by Cloudflare Vectorize.
48///
49/// This struct implements [`VectorStoreIndex`] to provide vector similarity search
50/// using Cloudflare's globally distributed Vectorize service.
51#[derive(Debug, Clone)]
52pub struct VectorizeVectorStore<M> {
53    /// The embedding model used to generate query embeddings.
54    model: M,
55    /// The HTTP client for Vectorize API.
56    client: VectorizeClient,
57}
58
59impl<M> VectorizeVectorStore<M> {
60    /// Creates a new Vectorize vector store.
61    ///
62    /// # Arguments
63    /// * `model` - The embedding model to use for query embedding
64    /// * `account_id` - Cloudflare account ID
65    /// * `index_name` - Name of the Vectorize index
66    /// * `api_token` - Cloudflare API token with Vectorize read permissions
67    pub fn new(
68        model: M,
69        account_id: impl Into<String>,
70        index_name: impl Into<String>,
71        api_token: impl Into<String>,
72    ) -> Self {
73        Self {
74            model,
75            client: VectorizeClient::new(account_id, index_name, api_token),
76        }
77    }
78}
79
80impl<M> VectorStoreIndex for VectorizeVectorStore<M>
81where
82    M: EmbeddingModel + Sync + Send,
83{
84    type Filter = VectorizeFilter;
85
86    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
87        &self,
88        req: VectorSearchRequest<Self::Filter>,
89    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
90        if let Some(filter) = req.filter() {
91            filter.validate()?;
92        }
93
94        let embedding = self.model.embed_text(req.query()).await?;
95
96        let query_request = ApiQueryRequest {
97            vector: embedding.vec,
98            top_k: req.samples(),
99            return_values: Some(false),
100            return_metadata: Some(ReturnMetadata::All),
101            filter: req.filter().as_ref().map(|f| f.clone().into_inner()),
102        };
103
104        let result = self.client.query(query_request).await?;
105
106        // Convert results to the expected format
107        let results = result
108            .matches
109            .into_iter()
110            .filter(|m| req.threshold().is_none_or(|t| m.score >= t))
111            .map(|m| {
112                let metadata = m.metadata.unwrap_or(serde_json::Value::Null);
113                let doc: T = serde_json::from_value(metadata)?;
114                Ok((m.score, m.id, doc))
115            })
116            .collect::<Result<Vec<_>, serde_json::Error>>()?;
117
118        Ok(results)
119    }
120
121    async fn top_n_ids(
122        &self,
123        req: VectorSearchRequest<Self::Filter>,
124    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
125        if let Some(filter) = req.filter() {
126            filter.validate()?;
127        }
128
129        let embedding = self.model.embed_text(req.query()).await?;
130
131        let query_request = ApiQueryRequest {
132            vector: embedding.vec,
133            top_k: req.samples(),
134            return_values: Some(false),
135            return_metadata: Some(ReturnMetadata::None),
136            filter: req.filter().as_ref().map(|f| f.clone().into_inner()),
137        };
138
139        let result = self.client.query(query_request).await?;
140
141        // Convert results to (score, id) tuples
142        let results = result
143            .matches
144            .into_iter()
145            .filter(|m| req.threshold().is_none_or(|t| m.score >= t))
146            .map(|m| (m.score, m.id))
147            .collect();
148
149        Ok(results)
150    }
151}
152
153impl<M> InsertDocuments for VectorizeVectorStore<M>
154where
155    M: EmbeddingModel + Sync + Send,
156{
157    async fn insert_documents<Doc: Serialize + Embed + Send>(
158        &self,
159        documents: Vec<(Doc, OneOrMany<Embedding>)>,
160    ) -> Result<(), VectorStoreError> {
161        let mut vectors: Vec<ApiVectorInput> = Vec::new();
162
163        for (doc, embeddings) in documents {
164            let metadata = serde_json::to_value(&doc)?;
165
166            for embedding in embeddings {
167                vectors.push(ApiVectorInput {
168                    id: Uuid::new_v4().to_string(),
169                    values: embedding.vec,
170                    metadata: Some(metadata.clone()),
171                    namespace: None,
172                });
173            }
174        }
175
176        if vectors.is_empty() {
177            return Ok(());
178        }
179
180        tracing::debug!("Upserting {} vectors to Vectorize", vectors.len());
181
182        const BATCH_SIZE: usize = 1000;
183
184        for batch in vectors.chunks(BATCH_SIZE) {
185            let request = UpsertRequest {
186                vectors: batch.to_vec(),
187            };
188
189            self.client.upsert(request).await?;
190        }
191
192        Ok(())
193    }
194}