Skip to main content

rig_vectorize/
lib.rs

1//! Cloudflare Vectorize integration for the Rig framework.
2//!
3//! This crate provides a vector store implementation using Cloudflare Vectorize,
4//! a globally distributed vector database built for AI applications.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use rig::providers::openai;
10//! use rig_vectorize::VectorizeVectorStore;
11//!
12//! let openai = openai::Client::from_env();
13//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_SMALL);
14//!
15//! let vector_store = VectorizeVectorStore::new(
16//!     embedding_model,
17//!     "your-account-id",
18//!     "your-index-name",
19//!     std::env::var("CLOUDFLARE_API_TOKEN").unwrap(),
20//! );
21//! ```
22
23mod client;
24
25// Re-export client types
26pub use client::{
27    DeleteByIdsRequest, DeleteResult, ListVectorsResult, QueryRequest, QueryResult, ReturnMetadata,
28    UpsertRequest, UpsertResult, VectorIdEntry, VectorInput, VectorMatch, VectorizeClient,
29    VectorizeError, VectorizeFilter,
30};
31
32use client::{QueryRequest as ApiQueryRequest, VectorInput as ApiVectorInput};
33use rig::embeddings::EmbeddingModel;
34use rig::vector_store::request::VectorSearchRequest;
35use rig::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
36use rig::{Embed, OneOrMany, embeddings::Embedding};
37use serde::{Deserialize, Serialize};
38use uuid::Uuid;
39
40impl From<VectorizeError> for VectorStoreError {
41    fn from(err: VectorizeError) -> Self {
42        VectorStoreError::DatastoreError(Box::new(err))
43    }
44}
45
46/// A vector store backed by Cloudflare Vectorize.
47///
48/// This struct implements [`VectorStoreIndex`] to provide vector similarity search
49/// using Cloudflare's globally distributed Vectorize service.
50#[derive(Debug, Clone)]
51pub struct VectorizeVectorStore<M> {
52    /// The embedding model used to generate query embeddings.
53    model: M,
54    /// The HTTP client for Vectorize API.
55    client: VectorizeClient,
56}
57
58impl<M> VectorizeVectorStore<M> {
59    /// Creates a new Vectorize vector store.
60    ///
61    /// # Arguments
62    /// * `model` - The embedding model to use for query embedding
63    /// * `account_id` - Cloudflare account ID
64    /// * `index_name` - Name of the Vectorize index
65    /// * `api_token` - Cloudflare API token with Vectorize read permissions
66    pub fn new(
67        model: M,
68        account_id: impl Into<String>,
69        index_name: impl Into<String>,
70        api_token: impl Into<String>,
71    ) -> Self {
72        Self {
73            model,
74            client: VectorizeClient::new(account_id, index_name, api_token),
75        }
76    }
77}
78
79impl<M> VectorStoreIndex for VectorizeVectorStore<M>
80where
81    M: EmbeddingModel + Sync + Send,
82{
83    type Filter = VectorizeFilter;
84
85    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
86        &self,
87        req: VectorSearchRequest<Self::Filter>,
88    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
89        if let Some(filter) = req.filter() {
90            filter.validate()?;
91        }
92
93        let embedding = self.model.embed_text(req.query()).await?;
94
95        let query_request = ApiQueryRequest {
96            vector: embedding.vec,
97            top_k: req.samples(),
98            return_values: Some(false),
99            return_metadata: Some(ReturnMetadata::All),
100            filter: req.filter().as_ref().map(|f| f.clone().into_inner()),
101        };
102
103        let result = self.client.query(query_request).await?;
104
105        // Convert results to the expected format
106        let results = result
107            .matches
108            .into_iter()
109            .filter(|m| req.threshold().is_none_or(|t| m.score >= t))
110            .map(|m| {
111                let metadata = m.metadata.unwrap_or(serde_json::Value::Null);
112                let doc: T = serde_json::from_value(metadata)?;
113                Ok((m.score, m.id, doc))
114            })
115            .collect::<Result<Vec<_>, serde_json::Error>>()?;
116
117        Ok(results)
118    }
119
120    async fn top_n_ids(
121        &self,
122        req: VectorSearchRequest<Self::Filter>,
123    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
124        if let Some(filter) = req.filter() {
125            filter.validate()?;
126        }
127
128        let embedding = self.model.embed_text(req.query()).await?;
129
130        let query_request = ApiQueryRequest {
131            vector: embedding.vec,
132            top_k: req.samples(),
133            return_values: Some(false),
134            return_metadata: Some(ReturnMetadata::None),
135            filter: req.filter().as_ref().map(|f| f.clone().into_inner()),
136        };
137
138        let result = self.client.query(query_request).await?;
139
140        // Convert results to (score, id) tuples
141        let results = result
142            .matches
143            .into_iter()
144            .filter(|m| req.threshold().is_none_or(|t| m.score >= t))
145            .map(|m| (m.score, m.id))
146            .collect();
147
148        Ok(results)
149    }
150}
151
152impl<M> InsertDocuments for VectorizeVectorStore<M>
153where
154    M: EmbeddingModel + Sync + Send,
155{
156    async fn insert_documents<Doc: Serialize + Embed + Send>(
157        &self,
158        documents: Vec<(Doc, OneOrMany<Embedding>)>,
159    ) -> Result<(), VectorStoreError> {
160        let mut vectors: Vec<ApiVectorInput> = Vec::new();
161
162        for (doc, embeddings) in documents {
163            let metadata = serde_json::to_value(&doc)?;
164
165            for embedding in embeddings {
166                vectors.push(ApiVectorInput {
167                    id: Uuid::new_v4().to_string(),
168                    values: embedding.vec,
169                    metadata: Some(metadata.clone()),
170                    namespace: None,
171                });
172            }
173        }
174
175        if vectors.is_empty() {
176            return Ok(());
177        }
178
179        tracing::debug!("Upserting {} vectors to Vectorize", vectors.len());
180
181        const BATCH_SIZE: usize = 1000;
182
183        for batch in vectors.chunks(BATCH_SIZE) {
184            let request = UpsertRequest {
185                vectors: batch.to_vec(),
186            };
187
188            self.client.upsert(request).await?;
189        }
190
191        Ok(())
192    }
193}