openmemory 0.1.1

OpenMemory - Cognitive memory system for AI applications
Documentation
//! AWS Bedrock embedding provider
//!
//! Uses AWS Bedrock's embedding models. Requires the `aws` feature flag.

use crate::core::config::Config;
use crate::core::error::{Error, Result};
use crate::core::types::{EmbeddingResult, Sector};
use crate::memory::embed::EmbeddingProvider;
#[cfg(feature = "aws")]
use crate::memory::embed::resize_vector;
use async_trait::async_trait;

#[cfg(feature = "aws")]
use aws_config::BehaviorVersion;
#[cfg(feature = "aws")]
use aws_sdk_bedrockruntime::Client as BedrockClient;
#[cfg(feature = "aws")]
use serde::{Deserialize, Serialize};

/// AWS Bedrock embedding provider.
///
/// Build it with `new_async` (requires the `aws` feature) to get a usable
/// client; `new` creates the provider without a client, in which case
/// `embed` returns an error instead of calling Bedrock.
pub struct BedrockProvider {
    /// Bedrock runtime client; `None` when the provider was built with `new`.
    #[cfg(feature = "aws")]
    client: Option<BedrockClient>,
    /// Identifier of the Bedrock model to invoke.
    // Only the `aws`-gated request path reads this field, so it is unused
    // (and would trigger `dead_code`) solely when the feature is disabled.
    #[cfg_attr(not(feature = "aws"), allow(dead_code))]
    model_id: String,
    /// Dimensionality reported by `dimensions()` and used as the resize target
    /// for returned vectors.
    dim: usize,
    /// Value sent as the `dimensions` field of each Bedrock request.
    #[cfg(feature = "aws")]
    target_dim: usize,
}

/// JSON request body for a Bedrock embedding invocation.
///
/// Serializes with the keys `inputText` and `dimensions`.
#[cfg(feature = "aws")]
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct BedrockRequest {
    /// Text to embed (emitted as `inputText`).
    input_text: String,
    /// Requested size of the returned embedding (emitted as `dimensions`).
    dimensions: usize,
}

/// Subset of the Bedrock invocation response that this provider reads.
///
/// Only the `embedding` field is deserialized.
#[cfg(feature = "aws")]
#[derive(Deserialize)]
struct BedrockResponse {
    /// Embedding vector returned by the model.
    embedding: Vec<f32>,
}

impl BedrockProvider {
    /// Bedrock model invoked by this provider.
    ///
    /// The request body carries a `dimensions` field restricted to 256/512/1024,
    /// which is the Titan Text Embeddings V2 request shape. The G1 (v1) text
    /// model accepts only `inputText`, so the V2 model ID is required here.
    const MODEL_ID: &'static str = "amazon.titan-embed-text-v2:0";

    /// Output sizes that may be requested from the model, in ascending order.
    #[cfg(feature = "aws")]
    const SUPPORTED_DIMS: [usize; 3] = [256, 512, 1024];

    /// Pick the smallest supported output size that is at least `vec_dim`.
    ///
    /// Falls back to 1024 (the largest supported size) when `vec_dim` exceeds
    /// every supported size.
    #[cfg(feature = "aws")]
    fn select_target_dim(vec_dim: usize) -> usize {
        Self::SUPPORTED_DIMS
            .iter()
            .copied()
            .find(|&d| d >= vec_dim)
            .unwrap_or(1024)
    }

    /// Create a new Bedrock provider with an initialized client.
    ///
    /// Loads the AWS configuration from the default provider chain and builds
    /// a Bedrock runtime client from it. The dimension requested from Bedrock
    /// is derived from `config.vec_dim`; the provider itself reports
    /// `config.vec_dim` as its dimensionality.
    #[cfg(feature = "aws")]
    pub async fn new_async(config: &Config) -> Self {
        let aws_config = aws_config::defaults(BehaviorVersion::latest())
            .load()
            .await;

        Self {
            client: Some(BedrockClient::new(&aws_config)),
            model_id: Self::MODEL_ID.to_string(),
            dim: config.vec_dim,
            target_dim: Self::select_target_dim(config.vec_dim),
        }
    }

    /// Create a new Bedrock provider (sync version, no client initialization).
    ///
    /// No Bedrock client is created, so `embed` on the returned provider
    /// returns an error; use `new_async` (with the `aws` feature) to obtain a
    /// provider that can actually produce embeddings.
    pub fn new(config: &Config) -> Self {
        Self {
            #[cfg(feature = "aws")]
            client: None,
            model_id: Self::MODEL_ID.to_string(),
            dim: config.vec_dim,
            // Same selection rule as `new_async`, so both constructors agree.
            #[cfg(feature = "aws")]
            target_dim: Self::select_target_dim(config.vec_dim),
        }
    }
}

#[async_trait]
impl EmbeddingProvider for BedrockProvider {
    /// Embed a single text for the given sector.
    ///
    /// With the `aws` feature enabled, sends `text` and the provider's target
    /// dimension to the configured Bedrock model, then resizes the returned
    /// embedding to the provider's dimensionality.
    ///
    /// # Errors
    ///
    /// - A config error if the client was never initialized (provider built
    ///   with `new`), or if the crate was built without the `aws` feature.
    /// - An embedding error if the request cannot be serialized, the Bedrock
    ///   call fails, or the response cannot be parsed.
    async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult> {
        #[cfg(feature = "aws")]
        {
            let client = self
                .client
                .as_ref()
                .ok_or_else(|| Error::config("Bedrock client not initialized. Use new_async()."))?;

            let request_body = BedrockRequest {
                input_text: text.to_string(),
                dimensions: self.target_dim,
            };

            let body_bytes =
                serde_json::to_vec(&request_body).map_err(|e| Error::embedding(e.to_string()))?;

            let response = client
                .invoke_model()
                .model_id(&self.model_id)
                .content_type("application/json")
                .accept("*/*")
                .body(aws_sdk_bedrockruntime::primitives::Blob::new(body_bytes))
                .send()
                .await
                .map_err(|e| Error::embedding(format!("Bedrock API error: {}", e)))?;

            let response_bytes = response.body().as_ref();
            let bedrock_response: BedrockResponse = serde_json::from_slice(response_bytes)
                .map_err(|e| Error::embedding(format!("Failed to parse Bedrock response: {}", e)))?;

            let vector = resize_vector(&bedrock_response.embedding, self.dim);
            // Read the length before moving the vector into the result, so the
            // embedding does not have to be cloned.
            let dim = vector.len();

            Ok(EmbeddingResult {
                sector: *sector,
                vector,
                dim,
            })
        }

        #[cfg(not(feature = "aws"))]
        {
            // Parameters are only used by the `aws`-gated path above.
            let _ = (text, sector);
            Err(Error::config(
                "AWS feature not enabled. Rebuild with --features aws",
            ))
        }
    }

    /// Embed several texts by calling `embed` on each one in order.
    ///
    /// Returns the results in input order, or the first error encountered;
    /// remaining inputs are not processed after a failure.
    async fn embed_batch(&self, texts: &[(&str, &Sector)]) -> Result<Vec<EmbeddingResult>> {
        // Bedrock doesn't have native batch support, process sequentially
        let mut results = Vec::with_capacity(texts.len());
        for (text, sector) in texts {
            results.push(self.embed(text, sector).await?);
        }
        Ok(results)
    }

    /// Dimensionality of the vectors this provider is configured to return.
    fn dimensions(&self) -> usize {
        self.dim
    }

    /// Static provider name.
    fn name(&self) -> &'static str {
        "bedrock"
    }

    /// Always `false`: `embed_batch` processes inputs one at a time.
    fn supports_batch(&self) -> bool {
        false // Sequential processing
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A provider built with the synchronous constructor exposes its static
    /// metadata: the fixed name and no batch support.
    #[test]
    fn test_provider_creation() {
        let provider = BedrockProvider::new(&Config::default());

        let batch_capable = provider.supports_batch();
        assert!(!batch_capable);
        assert_eq!("bedrock", provider.name());
    }
}