redis_vl/vectorizers/
bedrock.rs1use async_trait::async_trait;
11use aws_sdk_bedrockruntime::primitives::Blob;
12
13use super::{AsyncVectorizer, Vectorizer};
14use crate::error::{Error, Result};
15
/// Settings used to build a Bedrock-backed text vectorizer.
///
/// The [`std::fmt::Debug`] implementation is written by hand so that the
/// secret access key is never printed; only whether one is present is shown.
#[derive(Clone)]
pub struct BedrockConfig {
    /// Bedrock model identifier passed as the `model_id` of each invocation.
    pub model: String,
    /// AWS region the SDK configuration is loaded for.
    pub region: String,
    /// Access key id of optional explicit credentials. Explicit credentials
    /// are only used when `secret_access_key` is set as well.
    pub access_key_id: Option<String>,
    /// Secret access key of optional explicit credentials. Redacted in
    /// `Debug` output.
    pub secret_access_key: Option<String>,
}

impl std::fmt::Debug for BedrockConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` distinction visible for diagnostics, but
        // never leak the secret itself into logs or panic messages.
        let secret = self.secret_access_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("BedrockConfig")
            .field("model", &self.model)
            .field("region", &self.region)
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &secret)
            .finish()
    }
}
32
33impl Default for BedrockConfig {
34 fn default() -> Self {
35 Self {
36 model: "amazon.titan-embed-text-v2:0".into(),
37 region: "us-east-1".into(),
38 access_key_id: None,
39 secret_access_key: None,
40 }
41 }
42}
43
44impl BedrockConfig {
45 pub fn new(model: impl Into<String>) -> Self {
47 Self {
48 model: model.into(),
49 ..Default::default()
50 }
51 }
52
53 #[must_use]
55 pub fn with_region(mut self, region: impl Into<String>) -> Self {
56 self.region = region.into();
57 self
58 }
59
60 #[must_use]
62 pub fn with_credentials(
63 mut self,
64 access_key_id: impl Into<String>,
65 secret_access_key: impl Into<String>,
66 ) -> Self {
67 self.access_key_id = Some(access_key_id.into());
68 self.secret_access_key = Some(secret_access_key.into());
69 self
70 }
71
72 pub fn from_env() -> Result<Self> {
78 let region = std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".into());
79 let model = std::env::var("BEDROCK_MODEL_ID")
80 .unwrap_or_else(|_| "amazon.titan-embed-text-v2:0".into());
81 let access_key_id = std::env::var("AWS_ACCESS_KEY_ID").ok();
84 let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY").ok();
85 Ok(Self {
86 model,
87 region,
88 access_key_id,
89 secret_access_key,
90 })
91 }
92}
93
94#[derive(Debug, serde::Serialize)]
96struct TitanEmbedRequest<'a> {
97 #[serde(rename = "inputText")]
99 input_text: &'a str,
100}
101
102#[derive(Debug, serde::Deserialize)]
104struct TitanEmbedResponse {
105 embedding: Vec<f32>,
107}
108
109pub struct BedrockTextVectorizer {
128 config: BedrockConfig,
129 client: aws_sdk_bedrockruntime::Client,
130}
131
132impl std::fmt::Debug for BedrockTextVectorizer {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 f.debug_struct("BedrockTextVectorizer")
135 .field("config", &self.config)
136 .finish_non_exhaustive()
137 }
138}
139
140impl BedrockTextVectorizer {
141 pub async fn new(config: BedrockConfig) -> Result<Self> {
147 let mut aws_config_loader =
148 aws_config::from_env().region(aws_config::Region::new(config.region.clone()));
149
150 if let (Some(key_id), Some(secret)) = (&config.access_key_id, &config.secret_access_key) {
151 aws_config_loader = aws_config_loader.credentials_provider(
152 aws_sdk_bedrockruntime::config::Credentials::new(
153 key_id.clone(),
154 secret.clone(),
155 None, None, "redis-vl-bedrock",
158 ),
159 );
160 }
161
162 let sdk_config = aws_config_loader.load().await;
163 let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
164
165 Ok(Self { config, client })
166 }
167
168 async fn invoke_embed(&self, text: &str) -> Result<Vec<f32>> {
170 let body = serde_json::to_vec(&TitanEmbedRequest { input_text: text })?;
171
172 let response = self
173 .client
174 .invoke_model()
175 .model_id(&self.config.model)
176 .content_type("application/json")
177 .accept("application/json")
178 .body(Blob::new(body))
179 .send()
180 .await
181 .map_err(|e| Error::InvalidInput(format!("Bedrock invoke_model failed: {e}")))?;
182
183 let response_bytes = response.body().as_ref();
184 let parsed: TitanEmbedResponse = serde_json::from_slice(response_bytes)?;
185 Ok(parsed.embedding)
186 }
187}
188
189impl Vectorizer for BedrockTextVectorizer {
190 fn embed(&self, text: &str) -> Result<Vec<f32>> {
191 let rt = tokio::runtime::Builder::new_current_thread()
193 .enable_all()
194 .build()
195 .map_err(|e| Error::InvalidInput(format!("failed to build tokio runtime: {e}")))?;
196 rt.block_on(self.invoke_embed(text))
197 }
198
199 fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
200 let rt = tokio::runtime::Builder::new_current_thread()
201 .enable_all()
202 .build()
203 .map_err(|e| Error::InvalidInput(format!("failed to build tokio runtime: {e}")))?;
204 rt.block_on(async {
205 let mut embeddings = Vec::with_capacity(texts.len());
206 for text in texts {
207 embeddings.push(self.invoke_embed(text).await?);
208 }
209 Ok(embeddings)
210 })
211 }
212}
213
214#[async_trait]
215impl AsyncVectorizer for BedrockTextVectorizer {
216 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
217 self.invoke_embed(text).await
218 }
219
220 async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
221 let mut embeddings = Vec::with_capacity(texts.len());
222 for text in texts {
223 embeddings.push(self.invoke_embed(text).await?);
224 }
225 Ok(embeddings)
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields the stock model and region with no credentials.
    #[test]
    fn bedrock_config_defaults() {
        let BedrockConfig {
            model,
            region,
            access_key_id,
            secret_access_key,
        } = BedrockConfig::default();

        assert_eq!(model, "amazon.titan-embed-text-v2:0");
        assert_eq!(region, "us-east-1");
        assert!(access_key_id.is_none());
        assert!(secret_access_key.is_none());
    }

    /// Builder methods override the model, region and credential pair.
    #[test]
    fn bedrock_config_builder() {
        let built = BedrockConfig::new("amazon.titan-embed-text-v1")
            .with_region("eu-west-1")
            .with_credentials("AKID", "SECRET");

        assert_eq!(built.model, "amazon.titan-embed-text-v1");
        assert_eq!(built.region, "eu-west-1");
        assert_eq!(built.access_key_id.as_deref(), Some("AKID"));
        assert_eq!(built.secret_access_key.as_deref(), Some("SECRET"));
    }

    /// The request serializes to exactly one key, `inputText`.
    #[test]
    fn titan_request_serializes_correctly() {
        let request = TitanEmbedRequest {
            input_text: "hello world",
        };
        let value = serde_json::to_value(&request).unwrap();

        assert_eq!(value["inputText"], "hello world");
        let object = value.as_object().unwrap();
        assert_eq!(object.len(), 1);
    }

    /// The response parser reads the `embedding` array.
    #[test]
    fn titan_response_deserializes_correctly() {
        let body = r#"{"embedding": [0.1, 0.2, 0.3]}"#;
        let parsed: TitanEmbedResponse = serde_json::from_str(body).unwrap();

        assert_eq!(parsed.embedding, vec![0.1, 0.2, 0.3]);
    }

    /// Compile-time check that the vectorizer can be shared across threads.
    #[test]
    fn bedrock_vectorizer_is_send_sync() {
        fn assert_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        assert_send_sync::<BedrockTextVectorizer>();
    }
}