Skip to main content

redis_vl/vectorizers/
bedrock.rs

//! AWS Bedrock embedding adapter.
//!
//! Enabled by the `bedrock` feature flag. Uses the AWS SDK for Rust to call
//! Amazon Bedrock Runtime's `InvokeModel` API with SigV4 authentication.
//! The default model is `amazon.titan-embed-text-v2:0`.
//!
//! The Titan request body used here carries a single `inputText`, so batch
//! embedding is not supported — each text is embedded individually via
//! `invoke_model`.
9
10use async_trait::async_trait;
11use aws_sdk_bedrockruntime::primitives::Blob;
12
13use super::{AsyncVectorizer, Vectorizer};
14use crate::error::{Error, Result};
15
/// Configuration for the AWS Bedrock embedding provider.
///
/// Credentials are resolved through the standard AWS credential chain
/// (environment variables, shared config/credentials files, IAM roles, etc.)
/// unless explicit values are provided.
///
/// The `Debug` implementation redacts `secret_access_key`, so the config can
/// be logged without exposing the secret.
#[derive(Clone)]
pub struct BedrockConfig {
    /// Bedrock model ID (default: `amazon.titan-embed-text-v2:0`).
    pub model: String,
    /// AWS region (default: `us-east-1`).
    pub region: String,
    /// Optional explicit AWS access key ID.
    pub access_key_id: Option<String>,
    /// Optional explicit AWS secret access key.
    pub secret_access_key: Option<String>,
}

impl std::fmt::Debug for BedrockConfig {
    /// Formats the config like a derived `Debug` would, except that the
    /// secret access key is replaced by a placeholder when present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` shape visible so logs still show whether an
        // explicit secret was configured, without printing its value.
        let redacted_secret = self.secret_access_key.as_ref().map(|_| "** redacted **");
        f.debug_struct("BedrockConfig")
            .field("model", &self.model)
            .field("region", &self.region)
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &redacted_secret)
            .finish()
    }
}
32
33impl Default for BedrockConfig {
34    fn default() -> Self {
35        Self {
36            model: "amazon.titan-embed-text-v2:0".into(),
37            region: "us-east-1".into(),
38            access_key_id: None,
39            secret_access_key: None,
40        }
41    }
42}
43
44impl BedrockConfig {
45    /// Creates a new Bedrock config with the given model ID.
46    pub fn new(model: impl Into<String>) -> Self {
47        Self {
48            model: model.into(),
49            ..Default::default()
50        }
51    }
52
53    /// Sets the AWS region.
54    #[must_use]
55    pub fn with_region(mut self, region: impl Into<String>) -> Self {
56        self.region = region.into();
57        self
58    }
59
60    /// Sets explicit AWS credentials.
61    #[must_use]
62    pub fn with_credentials(
63        mut self,
64        access_key_id: impl Into<String>,
65        secret_access_key: impl Into<String>,
66    ) -> Self {
67        self.access_key_id = Some(access_key_id.into());
68        self.secret_access_key = Some(secret_access_key.into());
69        self
70    }
71
72    /// Constructs a config from environment variables.
73    ///
74    /// Reads `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and
75    /// `AWS_REGION` (defaults to `us-east-1`). The `BEDROCK_MODEL_ID`
76    /// env var overrides the default model if set.
77    pub fn from_env() -> Result<Self> {
78        let region = std::env::var("AWS_REGION").unwrap_or_else(|_| "us-east-1".into());
79        let model = std::env::var("BEDROCK_MODEL_ID")
80            .unwrap_or_else(|_| "amazon.titan-embed-text-v2:0".into());
81        // Credentials are optional here; the AWS SDK will resolve them
82        // through its default chain if not explicitly provided.
83        let access_key_id = std::env::var("AWS_ACCESS_KEY_ID").ok();
84        let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY").ok();
85        Ok(Self {
86            model,
87            region,
88            access_key_id,
89            secret_access_key,
90        })
91    }
92}
93
94/// Bedrock Titan embedding request body.
95#[derive(Debug, serde::Serialize)]
96struct TitanEmbedRequest<'a> {
97    /// Text to embed.
98    #[serde(rename = "inputText")]
99    input_text: &'a str,
100}
101
102/// Bedrock Titan embedding response body.
103#[derive(Debug, serde::Deserialize)]
104struct TitanEmbedResponse {
105    /// The embedding vector.
106    embedding: Vec<f32>,
107}
108
109/// AWS Bedrock embedding adapter.
110///
111/// Uses the Bedrock Runtime `InvokeModel` API with SigV4 authentication.
112/// Each text is embedded individually since Bedrock does not support batch
113/// embedding.
114///
115/// # Example
116///
117/// ```no_run
118/// use redis_vl::vectorizers::{BedrockConfig, BedrockTextVectorizer, Vectorizer};
119///
120/// # fn main() -> redis_vl::error::Result<()> {
121/// let config = BedrockConfig::from_env()?;
122/// let rt = tokio::runtime::Runtime::new().unwrap();
123/// let vectorizer = rt.block_on(BedrockTextVectorizer::new(config))?;
124/// # Ok(())
125/// # }
126/// ```
127pub struct BedrockTextVectorizer {
128    config: BedrockConfig,
129    client: aws_sdk_bedrockruntime::Client,
130}
131
132impl std::fmt::Debug for BedrockTextVectorizer {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        f.debug_struct("BedrockTextVectorizer")
135            .field("config", &self.config)
136            .finish_non_exhaustive()
137    }
138}
139
140impl BedrockTextVectorizer {
141    /// Creates a new Bedrock adapter by building an AWS SDK client from the
142    /// provided configuration.
143    ///
144    /// This is an `async` constructor because the AWS SDK credential
145    /// resolution is asynchronous.
146    pub async fn new(config: BedrockConfig) -> Result<Self> {
147        let mut aws_config_loader =
148            aws_config::from_env().region(aws_config::Region::new(config.region.clone()));
149
150        if let (Some(key_id), Some(secret)) = (&config.access_key_id, &config.secret_access_key) {
151            aws_config_loader = aws_config_loader.credentials_provider(
152                aws_sdk_bedrockruntime::config::Credentials::new(
153                    key_id.clone(),
154                    secret.clone(),
155                    None, // session token
156                    None, // expiry
157                    "redis-vl-bedrock",
158                ),
159            );
160        }
161
162        let sdk_config = aws_config_loader.load().await;
163        let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
164
165        Ok(Self { config, client })
166    }
167
168    /// Invokes the Bedrock model for a single text and returns the embedding.
169    async fn invoke_embed(&self, text: &str) -> Result<Vec<f32>> {
170        let body = serde_json::to_vec(&TitanEmbedRequest { input_text: text })?;
171
172        let response = self
173            .client
174            .invoke_model()
175            .model_id(&self.config.model)
176            .content_type("application/json")
177            .accept("application/json")
178            .body(Blob::new(body))
179            .send()
180            .await
181            .map_err(|e| Error::InvalidInput(format!("Bedrock invoke_model failed: {e}")))?;
182
183        let response_bytes = response.body().as_ref();
184        let parsed: TitanEmbedResponse = serde_json::from_slice(response_bytes)?;
185        Ok(parsed.embedding)
186    }
187}
188
189impl Vectorizer for BedrockTextVectorizer {
190    fn embed(&self, text: &str) -> Result<Vec<f32>> {
191        // Build a current-thread runtime for the blocking path.
192        let rt = tokio::runtime::Builder::new_current_thread()
193            .enable_all()
194            .build()
195            .map_err(|e| Error::InvalidInput(format!("failed to build tokio runtime: {e}")))?;
196        rt.block_on(self.invoke_embed(text))
197    }
198
199    fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
200        let rt = tokio::runtime::Builder::new_current_thread()
201            .enable_all()
202            .build()
203            .map_err(|e| Error::InvalidInput(format!("failed to build tokio runtime: {e}")))?;
204        rt.block_on(async {
205            let mut embeddings = Vec::with_capacity(texts.len());
206            for text in texts {
207                embeddings.push(self.invoke_embed(text).await?);
208            }
209            Ok(embeddings)
210        })
211    }
212}
213
214#[async_trait]
215impl AsyncVectorizer for BedrockTextVectorizer {
216    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
217        self.invoke_embed(text).await
218    }
219
220    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
221        let mut embeddings = Vec::with_capacity(texts.len());
222        for text in texts {
223            embeddings.push(self.invoke_embed(text).await?);
224        }
225        Ok(embeddings)
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields the documented model and region and no credentials.
    #[test]
    fn bedrock_config_defaults() {
        let defaults = BedrockConfig::default();
        assert_eq!(defaults.model, "amazon.titan-embed-text-v2:0");
        assert_eq!(defaults.region, "us-east-1");
        assert!(defaults.access_key_id.is_none());
        assert!(defaults.secret_access_key.is_none());
    }

    /// The builder methods override model, region and both credential halves.
    #[test]
    fn bedrock_config_builder() {
        let built = BedrockConfig::new("amazon.titan-embed-text-v1")
            .with_region("eu-west-1")
            .with_credentials("AKID", "SECRET");
        assert_eq!(built.model, "amazon.titan-embed-text-v1");
        assert_eq!(built.region, "eu-west-1");
        assert_eq!(built.access_key_id.as_deref(), Some("AKID"));
        assert_eq!(built.secret_access_key.as_deref(), Some("SECRET"));
    }

    /// The request body is exactly `{"inputText": ...}`.
    #[test]
    fn titan_request_serializes_correctly() {
        let request = TitanEmbedRequest {
            input_text: "hello world",
        };
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value["inputText"], "hello world");
        // Must not contain any other top-level keys
        let top_level = value.as_object().unwrap();
        assert_eq!(top_level.len(), 1);
    }

    /// The `embedding` array is read into the response struct.
    #[test]
    fn titan_response_deserializes_correctly() {
        let payload = r#"{"embedding": [0.1, 0.2, 0.3]}"#;
        let parsed: TitanEmbedResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.embedding, vec![0.1, 0.2, 0.3]);
    }

    /// Compile-time check that the vectorizer can be shared across threads.
    #[test]
    fn bedrock_vectorizer_is_send_sync() {
        fn require_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        require_send_sync::<BedrockTextVectorizer>();
    }
}
276}