use crate::core::config::Config;
use crate::core::error::{Error, Result};
use crate::core::types::{EmbeddingResult, Sector};
use crate::memory::embed::EmbeddingProvider;
#[cfg(feature = "aws")]
use crate::memory::embed::resize_vector;
use async_trait::async_trait;
#[cfg(feature = "aws")]
use aws_config::BehaviorVersion;
#[cfg(feature = "aws")]
use aws_sdk_bedrockruntime::Client as BedrockClient;
#[cfg(feature = "aws")]
use serde::{Deserialize, Serialize};
/// Embedding provider backed by the Amazon Bedrock runtime.
///
/// The AWS-specific fields only exist when the crate is compiled with the
/// `aws` feature; without it the provider can still be constructed, but
/// embedding requests return a configuration error.
pub struct BedrockProvider {
    /// Bedrock runtime client. `None` when the provider was built through the
    /// synchronous constructor, which never creates a client.
    #[cfg(feature = "aws")]
    client: Option<BedrockClient>,

    /// Identifier of the Bedrock model to invoke.
    // Only read on the `aws` code path, so builds without that feature would
    // otherwise flag the field as unused.
    #[allow(dead_code)]
    model_id: String,

    /// Length of the vectors handed back to callers.
    dim: usize,

    /// Value sent as `dimensions` in the request body.
    #[cfg(feature = "aws")]
    target_dim: usize,
}
/// JSON request body passed to Bedrock `InvokeModel`.
///
/// Field order is kept as declared because it determines the key order of the
/// serialized JSON object.
#[cfg(feature = "aws")]
#[derive(Serialize)]
struct BedrockRequest {
    /// Text to embed; serialized under the key `inputText`.
    #[serde(rename = "inputText")]
    input_text: String,

    /// Requested length of the embedding returned by the model.
    dimensions: usize,
}
/// The part of the Bedrock `InvokeModel` response body this provider reads.
///
/// Only the `embedding` key is declared here, so it is the only key that is
/// deserialized.
#[cfg(feature = "aws")]
#[derive(Deserialize)]
struct BedrockResponse {
    /// Embedding vector produced by the model.
    embedding: Vec<f32>,
}
impl BedrockProvider {
    /// Bedrock model identifier used for every embedding request.
    ///
    /// The request body sent by this provider carries a `dimensions` field
    /// chosen from 256/512/1024. That contract belongs to Titan Text
    /// Embeddings V2; the V1 model accepts `inputText` only, so pairing this
    /// request shape with the V1 identifier makes the call fail validation.
    const MODEL_ID: &str = "amazon.titan-embed-text-v2:0";

    /// Output sizes that can be requested through `dimensions`, ascending.
    #[cfg(feature = "aws")]
    const SUPPORTED_DIMS: [usize; 3] = [256, 512, 1024];

    /// Picks the smallest supported output size that is at least `vec_dim`.
    ///
    /// Falls back to 1024 (the largest supported size) when `vec_dim` exceeds
    /// every supported size.
    #[cfg(feature = "aws")]
    fn target_dim_for(vec_dim: usize) -> usize {
        Self::SUPPORTED_DIMS
            .iter()
            .copied()
            .find(|&candidate| candidate >= vec_dim)
            .unwrap_or(1024)
    }

    /// Builds a provider with a live Bedrock runtime client.
    ///
    /// AWS settings are loaded through `aws_config::defaults` with the latest
    /// behavior version. The provider reports `config.vec_dim` as its
    /// dimensionality and requests the smallest supported model output size
    /// that is at least that large.
    #[cfg(feature = "aws")]
    pub async fn new_async(config: &Config) -> Self {
        let aws_config = aws_config::defaults(BehaviorVersion::latest())
            .load()
            .await;

        Self {
            client: Some(BedrockClient::new(&aws_config)),
            model_id: Self::MODEL_ID.to_string(),
            dim: config.vec_dim,
            target_dim: Self::target_dim_for(config.vec_dim),
        }
    }

    /// Builds a provider without an AWS client.
    ///
    /// No I/O is performed. With the `aws` feature enabled the resulting
    /// provider has no client, so `embed` returns a configuration error that
    /// points callers at `new_async`.
    pub fn new(config: &Config) -> Self {
        Self {
            #[cfg(feature = "aws")]
            client: None,
            model_id: Self::MODEL_ID.to_string(),
            dim: config.vec_dim,
            // Derived the same way as in `new_async` so both constructors
            // agree on the requested output size.
            #[cfg(feature = "aws")]
            target_dim: Self::target_dim_for(config.vec_dim),
        }
    }
}
#[async_trait]
impl EmbeddingProvider for BedrockProvider {
    /// Embeds `text` through Bedrock `InvokeModel` and tags the result with
    /// `sector`.
    ///
    /// The model's output is passed through `resize_vector` with the
    /// configured dimensionality before it is returned.
    ///
    /// # Errors
    ///
    /// * A configuration error when the provider has no client (it was built
    ///   with `new` rather than `new_async`), or when the crate was compiled
    ///   without the `aws` feature.
    /// * An embedding error when the request cannot be serialized, the
    ///   Bedrock call fails, or the response body cannot be parsed.
    async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult> {
        #[cfg(feature = "aws")]
        {
            let client = self
                .client
                .as_ref()
                .ok_or_else(|| Error::config("Bedrock client not initialized. Use new_async()."))?;

            let request_body = BedrockRequest {
                input_text: text.to_string(),
                dimensions: self.target_dim,
            };
            let body_bytes =
                serde_json::to_vec(&request_body).map_err(|e| Error::embedding(e.to_string()))?;

            let response = client
                .invoke_model()
                .model_id(&self.model_id)
                .content_type("application/json")
                .accept("*/*")
                .body(aws_sdk_bedrockruntime::primitives::Blob::new(body_bytes))
                .send()
                .await
                .map_err(|e| Error::embedding(format!("Bedrock API error: {}", e)))?;

            let response_bytes = response.body().as_ref();
            let bedrock_response: BedrockResponse = serde_json::from_slice(response_bytes)
                .map_err(|e| Error::embedding(format!("Failed to parse Bedrock response: {}", e)))?;

            let vector = resize_vector(&bedrock_response.embedding, self.dim);
            // Read the length before the vector is moved into the result so
            // the embedding does not have to be cloned just to report it.
            let dim = vector.len();
            Ok(EmbeddingResult {
                sector: *sector,
                vector,
                dim,
            })
        }
        #[cfg(not(feature = "aws"))]
        {
            // Keeps the parameters "used" in builds without the feature.
            let _ = (text, sector);
            Err(Error::config(
                "AWS feature not enabled. Rebuild with --features aws",
            ))
        }
    }

    /// Embeds every `(text, sector)` pair by calling `embed` one item at a
    /// time, in order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `embed`; results gathered before
    /// that point are discarded.
    async fn embed_batch(&self, texts: &[(&str, &Sector)]) -> Result<Vec<EmbeddingResult>> {
        let mut results = Vec::with_capacity(texts.len());
        for (text, sector) in texts {
            results.push(self.embed(text, sector).await?);
        }
        Ok(results)
    }

    /// Returns the configured embedding dimensionality.
    fn dimensions(&self) -> usize {
        self.dim
    }

    /// Returns the provider's name, `"bedrock"`.
    fn name(&self) -> &'static str {
        "bedrock"
    }

    /// Reports that batching is not supported; `embed_batch` issues one
    /// request per input.
    fn supports_batch(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The synchronous constructor yields a provider that identifies itself
    /// as "bedrock" and reports no batch support.
    #[test]
    fn test_provider_creation() {
        let provider = BedrockProvider::new(&Config::default());

        let reported_name = provider.name();
        let batching = provider.supports_batch();

        assert_eq!(reported_name, "bedrock");
        assert!(!batching);
    }
}