redis_vl/vectorizers/
vertex_ai.rs1use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use super::{AsyncVectorizer, Vectorizer};
11use crate::error::Result;
12
/// Connection settings for the Vertex AI text-embedding `:predict` endpoint.
///
/// `Debug` is written by hand rather than derived so that `api_key` never
/// ends up in logs or panic messages; every other field is printed as usual.
#[derive(Clone)]
pub struct VertexAIConfig {
    /// Google Cloud project identifier, interpolated into the request path.
    pub project_id: String,
    /// Location interpolated into both the host name and the request path.
    pub location: String,
    /// Publisher model name, e.g. `textembedding-gecko@003`.
    pub model: String,
    /// Credential sent as the bearer token on every request.
    ///
    /// NOTE(review): Vertex AI normally expects an OAuth2 access token as the
    /// bearer value — confirm what callers are expected to supply here.
    pub api_key: String,
}

impl std::fmt::Debug for VertexAIConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VertexAIConfig")
            .field("project_id", &self.project_id)
            .field("location", &self.location)
            .field("model", &self.model)
            // Never echo the credential itself.
            .field("api_key", &"<redacted>")
            .finish()
    }
}
25
26impl VertexAIConfig {
27 pub fn new(
29 project_id: impl Into<String>,
30 location: impl Into<String>,
31 model: impl Into<String>,
32 api_key: impl Into<String>,
33 ) -> Self {
34 Self {
35 project_id: project_id.into(),
36 location: location.into(),
37 model: model.into(),
38 api_key: api_key.into(),
39 }
40 }
41
42 pub fn from_env(model: Option<String>) -> Result<Self> {
46 let project_id = std::env::var("GCP_PROJECT_ID")
47 .map_err(|_| crate::error::Error::InvalidInput("GCP_PROJECT_ID not set".into()))?;
48 let location = std::env::var("GCP_LOCATION")
49 .map_err(|_| crate::error::Error::InvalidInput("GCP_LOCATION not set".into()))?;
50 let api_key = std::env::var("GCP_API_KEY")
51 .map_err(|_| crate::error::Error::InvalidInput("GCP_API_KEY not set".into()))?;
52 Ok(Self::new(
53 project_id,
54 location,
55 model.unwrap_or_else(|| "textembedding-gecko@003".to_string()),
56 api_key,
57 ))
58 }
59
60 fn predict_url(&self) -> String {
61 format!(
62 "https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:predict",
63 location = self.location,
64 project = self.project_id,
65 model = self.model,
66 )
67 }
68}
69
70#[derive(Serialize)]
71struct VertexAIInstance<'a> {
72 content: &'a str,
73}
74
75#[derive(Serialize)]
76struct VertexAIPredictRequest<'a> {
77 instances: Vec<VertexAIInstance<'a>>,
78}
79
80#[derive(Deserialize)]
81struct VertexAIPredictResponse {
82 predictions: Vec<VertexAIPrediction>,
83}
84
85#[derive(Deserialize)]
86struct VertexAIPrediction {
87 embeddings: VertexAIEmbeddings,
88}
89
90#[derive(Deserialize)]
91struct VertexAIEmbeddings {
92 values: Vec<f32>,
93}
94
95#[derive(Debug, Clone)]
100pub struct VertexAITextVectorizer {
101 config: VertexAIConfig,
102 client: reqwest::Client,
103 blocking_client: reqwest::blocking::Client,
104}
105
106impl VertexAITextVectorizer {
107 pub fn new(config: VertexAIConfig) -> Self {
109 Self {
110 config,
111 client: reqwest::Client::new(),
112 blocking_client: reqwest::blocking::Client::new(),
113 }
114 }
115
116 fn build_request<'a>(&self, texts: &[&'a str]) -> VertexAIPredictRequest<'a> {
117 VertexAIPredictRequest {
118 instances: texts
119 .iter()
120 .map(|t| VertexAIInstance { content: t })
121 .collect(),
122 }
123 }
124
125 fn parse_response(resp: VertexAIPredictResponse) -> Result<Vec<Vec<f32>>> {
126 if resp.predictions.is_empty() {
127 return Err(crate::error::Error::InvalidInput(
128 "no predictions in Vertex AI response".into(),
129 ));
130 }
131 Ok(resp
132 .predictions
133 .into_iter()
134 .map(|p| p.embeddings.values)
135 .collect())
136 }
137
138 async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
139 let resp: VertexAIPredictResponse = self
140 .client
141 .post(self.config.predict_url())
142 .bearer_auth(&self.config.api_key)
143 .json(&self.build_request(texts))
144 .send()
145 .await?
146 .error_for_status()?
147 .json()
148 .await?;
149 Self::parse_response(resp)
150 }
151}
152
153impl Vectorizer for VertexAITextVectorizer {
154 fn embed(&self, text: &str) -> Result<Vec<f32>> {
155 let resp: VertexAIPredictResponse = self
156 .blocking_client
157 .post(self.config.predict_url())
158 .bearer_auth(&self.config.api_key)
159 .json(&self.build_request(&[text]))
160 .send()?
161 .error_for_status()?
162 .json()?;
163 let mut embeddings = Self::parse_response(resp)?;
164 Ok(embeddings.pop().unwrap_or_default())
165 }
166}
167
168#[async_trait]
169impl AsyncVectorizer for VertexAITextVectorizer {
170 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
171 let mut v = self.embed_many_inner(&[text]).await?;
172 Ok(v.pop().unwrap_or_default())
173 }
174
175 async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
176 self.embed_many_inner(texts).await
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a prediction that wraps the given embedding values.
    fn prediction(values: Vec<f32>) -> VertexAIPrediction {
        VertexAIPrediction {
            embeddings: VertexAIEmbeddings { values },
        }
    }

    #[test]
    fn vertex_ai_config_stores_fields() {
        let VertexAIConfig {
            project_id,
            location,
            model,
            api_key,
        } = VertexAIConfig::new("my-project", "us-central1", "textembedding-gecko@003", "key");

        assert_eq!(project_id, "my-project");
        assert_eq!(location, "us-central1");
        assert_eq!(model, "textembedding-gecko@003");
        assert_eq!(api_key, "key");
    }

    #[test]
    fn vertex_ai_config_builds_predict_url() {
        let expected = "https://us-central1-aiplatform.googleapis.com/v1/projects/proj/locations/us-central1/publishers/google/models/textembedding-gecko@003:predict";
        let actual =
            VertexAIConfig::new("proj", "us-central1", "textembedding-gecko@003", "k").predict_url();
        assert_eq!(actual, expected);
    }

    #[test]
    fn vertex_ai_request_serializes_correctly() {
        let vectorizer =
            VertexAITextVectorizer::new(VertexAIConfig::new("p", "us-central1", "model", "k"));
        let request = vectorizer.build_request(&["hello", "world"]);

        let json = serde_json::to_value(&request).unwrap();
        let contents: Vec<&str> = json["instances"]
            .as_array()
            .unwrap()
            .iter()
            .map(|instance| instance["content"].as_str().unwrap())
            .collect();

        assert_eq!(contents, ["hello", "world"]);
    }

    #[test]
    fn vertex_ai_parse_response_extracts_values() {
        let resp = VertexAIPredictResponse {
            predictions: vec![
                prediction(vec![1.0, 2.0, 3.0]),
                prediction(vec![4.0, 5.0, 6.0]),
            ],
        };

        let vectors = VertexAITextVectorizer::parse_response(resp).unwrap();

        assert_eq!(vectors, vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
    }

    #[test]
    fn vertex_ai_parse_response_errors_on_empty() {
        let resp = VertexAIPredictResponse {
            predictions: Vec::new(),
        };
        assert!(VertexAITextVectorizer::parse_response(resp).is_err());
    }

    #[test]
    fn vertex_ai_vectorizer_is_send_sync() {
        // Compiles only if the vectorizer can be shared across threads.
        fn assert_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        assert_send_sync::<VertexAITextVectorizer>();
    }
}