Skip to main content

redis_vl/vectorizers/
vertex_ai.rs

1//! Google Vertex AI embedding adapter.
2//!
3//! Enabled by the `vertex-ai` feature flag. Uses the Vertex AI `predict`
4//! REST API endpoint for text embedding models such as
5//! `textembedding-gecko@003`.
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use super::{AsyncVectorizer, Vectorizer};
11use crate::error::Result;
12
13/// Configuration for the Vertex AI embedding provider.
/// Connection settings for the Vertex AI text-embedding endpoint.
#[derive(Debug, Clone)]
pub struct VertexAIConfig {
    /// Google Cloud project identifier.
    pub project_id: String,
    /// Deployment region, e.g. `us-central1`.
    pub location: String,
    /// Name of the embedding model (default: `textembedding-gecko@003`).
    pub model: String,
    /// Bearer credential: an API key or OAuth2 access token.
    pub api_key: String,
}
25
26impl VertexAIConfig {
27    /// Creates a new Vertex AI config.
28    pub fn new(
29        project_id: impl Into<String>,
30        location: impl Into<String>,
31        model: impl Into<String>,
32        api_key: impl Into<String>,
33    ) -> Self {
34        Self {
35            project_id: project_id.into(),
36            location: location.into(),
37            model: model.into(),
38            api_key: api_key.into(),
39        }
40    }
41
42    /// Constructs from environment variables:
43    /// `GCP_PROJECT_ID`, `GCP_LOCATION`, `GCP_API_KEY`.
44    /// Model defaults to `textembedding-gecko@003`.
45    pub fn from_env(model: Option<String>) -> Result<Self> {
46        let project_id = std::env::var("GCP_PROJECT_ID")
47            .map_err(|_| crate::error::Error::InvalidInput("GCP_PROJECT_ID not set".into()))?;
48        let location = std::env::var("GCP_LOCATION")
49            .map_err(|_| crate::error::Error::InvalidInput("GCP_LOCATION not set".into()))?;
50        let api_key = std::env::var("GCP_API_KEY")
51            .map_err(|_| crate::error::Error::InvalidInput("GCP_API_KEY not set".into()))?;
52        Ok(Self::new(
53            project_id,
54            location,
55            model.unwrap_or_else(|| "textembedding-gecko@003".to_string()),
56            api_key,
57        ))
58    }
59
60    fn predict_url(&self) -> String {
61        format!(
62            "https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:predict",
63            location = self.location,
64            project = self.project_id,
65            model = self.model,
66        )
67    }
68}
69
/// A single input text in the `predict` request payload.
#[derive(Serialize)]
struct VertexAIInstance<'a> {
    // Field name `content` is fixed by the Vertex AI wire format.
    content: &'a str,
}

/// Request body for the `:predict` endpoint: one instance per input text.
#[derive(Serialize)]
struct VertexAIPredictRequest<'a> {
    instances: Vec<VertexAIInstance<'a>>,
}

/// Top-level `:predict` response: one prediction per request instance.
#[derive(Deserialize)]
struct VertexAIPredictResponse {
    predictions: Vec<VertexAIPrediction>,
}

/// One prediction, wrapping the embedding payload.
#[derive(Deserialize)]
struct VertexAIPrediction {
    embeddings: VertexAIEmbeddings,
}

/// The embedding vector itself (`embeddings.values` in the response JSON).
#[derive(Deserialize)]
struct VertexAIEmbeddings {
    values: Vec<f32>,
}
94
/// Vertex AI embedding adapter.
///
/// Uses the Vertex AI `predict` REST endpoint which takes `instances`
/// with `content` fields and returns `predictions` with `embeddings.values`.
#[derive(Debug, Clone)]
pub struct VertexAITextVectorizer {
    // Endpoint coordinates and credential; cloned with the vectorizer.
    config: VertexAIConfig,
    // Async HTTP client used by the `AsyncVectorizer` impl.
    client: reqwest::Client,
    // Blocking HTTP client used by the sync `Vectorizer` impl.
    // NOTE(review): reqwest's blocking client panics when constructed or used
    // from within an async runtime — confirm `new()` and the sync `embed`
    // are never called from async contexts.
    blocking_client: reqwest::blocking::Client,
}
105
106impl VertexAITextVectorizer {
107    /// Creates a new Vertex AI adapter.
108    pub fn new(config: VertexAIConfig) -> Self {
109        Self {
110            config,
111            client: reqwest::Client::new(),
112            blocking_client: reqwest::blocking::Client::new(),
113        }
114    }
115
116    fn build_request<'a>(&self, texts: &[&'a str]) -> VertexAIPredictRequest<'a> {
117        VertexAIPredictRequest {
118            instances: texts
119                .iter()
120                .map(|t| VertexAIInstance { content: t })
121                .collect(),
122        }
123    }
124
125    fn parse_response(resp: VertexAIPredictResponse) -> Result<Vec<Vec<f32>>> {
126        if resp.predictions.is_empty() {
127            return Err(crate::error::Error::InvalidInput(
128                "no predictions in Vertex AI response".into(),
129            ));
130        }
131        Ok(resp
132            .predictions
133            .into_iter()
134            .map(|p| p.embeddings.values)
135            .collect())
136    }
137
138    async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
139        let resp: VertexAIPredictResponse = self
140            .client
141            .post(self.config.predict_url())
142            .bearer_auth(&self.config.api_key)
143            .json(&self.build_request(texts))
144            .send()
145            .await?
146            .error_for_status()?
147            .json()
148            .await?;
149        Self::parse_response(resp)
150    }
151}
152
153impl Vectorizer for VertexAITextVectorizer {
154    fn embed(&self, text: &str) -> Result<Vec<f32>> {
155        let resp: VertexAIPredictResponse = self
156            .blocking_client
157            .post(self.config.predict_url())
158            .bearer_auth(&self.config.api_key)
159            .json(&self.build_request(&[text]))
160            .send()?
161            .error_for_status()?
162            .json()?;
163        let mut embeddings = Self::parse_response(resp)?;
164        Ok(embeddings.pop().unwrap_or_default())
165    }
166}
167
168#[async_trait]
169impl AsyncVectorizer for VertexAITextVectorizer {
170    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
171        let mut v = self.embed_many_inner(&[text]).await?;
172        Ok(v.pop().unwrap_or_default())
173    }
174
175    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
176        self.embed_many_inner(texts).await
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor for a response prediction with the given values.
    fn prediction(values: Vec<f32>) -> VertexAIPrediction {
        VertexAIPrediction {
            embeddings: VertexAIEmbeddings { values },
        }
    }

    #[test]
    fn vertex_ai_config_stores_fields() {
        let cfg =
            VertexAIConfig::new("my-project", "us-central1", "textembedding-gecko@003", "key");
        assert_eq!(cfg.project_id, "my-project");
        assert_eq!(cfg.location, "us-central1");
        assert_eq!(cfg.model, "textembedding-gecko@003");
        assert_eq!(cfg.api_key, "key");
    }

    #[test]
    fn vertex_ai_config_builds_predict_url() {
        let cfg = VertexAIConfig::new("proj", "us-central1", "textembedding-gecko@003", "k");
        let expected = "https://us-central1-aiplatform.googleapis.com/v1/projects/proj\
                        /locations/us-central1/publishers/google/models\
                        /textembedding-gecko@003:predict";
        assert_eq!(cfg.predict_url(), expected);
    }

    #[test]
    fn vertex_ai_request_serializes_correctly() {
        let v = VertexAITextVectorizer::new(VertexAIConfig::new("p", "us-central1", "model", "k"));
        let json = serde_json::to_value(v.build_request(&["hello", "world"])).unwrap();
        let contents: Vec<_> = json["instances"]
            .as_array()
            .unwrap()
            .iter()
            .map(|i| i["content"].as_str().unwrap())
            .collect();
        assert_eq!(contents, ["hello", "world"]);
    }

    #[test]
    fn vertex_ai_parse_response_extracts_values() {
        let resp = VertexAIPredictResponse {
            predictions: vec![
                prediction(vec![1.0, 2.0, 3.0]),
                prediction(vec![4.0, 5.0, 6.0]),
            ],
        };
        let result = VertexAITextVectorizer::parse_response(resp).unwrap();
        assert_eq!(result, vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
    }

    #[test]
    fn vertex_ai_parse_response_errors_on_empty() {
        let empty = VertexAIPredictResponse {
            predictions: Vec::new(),
        };
        assert!(VertexAITextVectorizer::parse_response(empty).is_err());
    }

    #[test]
    fn vertex_ai_vectorizer_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<VertexAITextVectorizer>();
    }
}