use crate::llm::traits::AiProvider;
use crate::llm::types::{ChatCompletionParams, ProviderResponse};
use crate::llm::utils::normalize_model_name;
use anyhow::Result;
use std::env;
/// AI provider for models hosted on Amazon Bedrock (Anthropic Claude,
/// Amazon Titan, Meta Llama, AI21, Cohere, Mistral).
///
/// Stateless unit struct: AWS credentials are read from the environment
/// on demand rather than stored on the provider.
#[derive(Debug, Clone)]
pub struct AmazonBedrockProvider;
impl Default for AmazonBedrockProvider {
fn default() -> Self {
Self::new()
}
}
/// Environment variable holding the AWS access key id.
const AWS_ACCESS_KEY_ID_ENV: &str = "AWS_ACCESS_KEY_ID";
/// Environment variable holding the AWS secret access key.
const AWS_SECRET_ACCESS_KEY_ENV: &str = "AWS_SECRET_ACCESS_KEY";

impl AmazonBedrockProvider {
    /// Creates a new (stateless) Bedrock provider.
    pub fn new() -> Self {
        Self
    }

    /// Reads the AWS access key id from the environment.
    ///
    /// # Errors
    /// Fails when `AWS_ACCESS_KEY_ID` is not set.
    fn get_aws_access_key_id(&self) -> Result<String> {
        match env::var(AWS_ACCESS_KEY_ID_ENV) {
            Ok(value) => Ok(value),
            Err(_) => Err(anyhow::anyhow!("AWS_ACCESS_KEY_ID not found in environment")),
        }
    }

    /// Reads the AWS secret access key from the environment.
    ///
    /// # Errors
    /// Fails when `AWS_SECRET_ACCESS_KEY` is not set.
    fn get_aws_secret_access_key(&self) -> Result<String> {
        match env::var(AWS_SECRET_ACCESS_KEY_ENV) {
            Ok(value) => Ok(value),
            Err(_) => Err(anyhow::anyhow!("AWS_SECRET_ACCESS_KEY not found in environment")),
        }
    }
}
#[async_trait::async_trait]
impl AiProvider for AmazonBedrockProvider {
    /// Provider identifier used for routing/model selection.
    fn name(&self) -> &str {
        "amazon"
    }

    /// Returns `true` when `model` looks like a Bedrock-hosted model,
    /// matched case-insensitively against known family names and
    /// Bedrock vendor prefixes (e.g. `anthropic.`, `meta.`).
    fn supports_model(&self, model: &str) -> bool {
        let model_lower = normalize_model_name(model);
        model_lower.contains("claude")
            || model_lower.contains("titan")
            || model_lower.contains("llama")
            || model_lower.contains("anthropic.")
            || model_lower.contains("meta.")
            || model_lower.contains("amazon.")
            || model_lower.contains("ai21.")
            || model_lower.contains("cohere.")
            || model_lower.contains("mistral.")
    }

    /// Returns the AWS access key id.
    ///
    /// The secret access key is also read so that a missing
    /// `AWS_SECRET_ACCESS_KEY` surfaces as an error here, even though
    /// only the access key id is returned.
    ///
    /// # Errors
    /// Fails when either `AWS_ACCESS_KEY_ID` or `AWS_SECRET_ACCESS_KEY`
    /// is absent from the environment.
    fn get_api_key(&self) -> Result<String> {
        let access_key_id = self.get_aws_access_key_id()?;
        // Probe the secret too so misconfiguration is caught early.
        let _secret_access_key = self.get_aws_secret_access_key()?;
        Ok(access_key_id)
    }

    /// Prompt caching is not supported through this provider.
    fn supports_caching(&self, _model: &str) -> bool {
        false
    }

    /// Vision input is reported for Claude 3 / Claude 4 family models.
    fn supports_vision(&self, model: &str) -> bool {
        let model_lower = normalize_model_name(model);
        model_lower.contains("claude-3")
            || model_lower.contains("claude-4")
            || model_lower.contains("anthropic.claude")
    }

    /// Maximum input context window, keyed by model family.
    ///
    /// More specific patterns are checked first (`llama3-2-90b` before
    /// the generic `llama` bucket). Vendor-prefixed ids such as
    /// `anthropic.claude` or `meta.llama` already contain the bare
    /// family substring, so a single `contains` per family suffices.
    fn get_max_input_tokens(&self, model: &str) -> usize {
        let model_lower = normalize_model_name(model);
        if model_lower.contains("claude") {
            200_000
        } else if model_lower.contains("llama3-2-90b") {
            128_000
        } else if model_lower.contains("llama") {
            32_768
        } else if model_lower.contains("titan") {
            32_000
        } else {
            // Conservative default for unrecognized models.
            32_768
        }
    }

    /// Chat completions over Bedrock are not implemented; always errors.
    async fn chat_completion(&self, _params: ChatCompletionParams) -> Result<ProviderResponse> {
        Err(anyhow::anyhow!(
            "Amazon Bedrock provider not fully implemented in octolib"
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Model ids from every supported Bedrock family are accepted,
    /// while other providers' models are rejected.
    #[test]
    fn test_supports_model() {
        let provider = AmazonBedrockProvider::new();
        let accepted = [
            "anthropic.claude-3-haiku-20240307-v1:0",
            "anthropic.claude-3-5-sonnet-20241022-v2:0",
            "meta.llama3-2-90b-instruct-v1:0",
            "amazon.titan-embed-text-v2:0",
        ];
        for model in accepted {
            assert!(provider.supports_model(model), "should accept {model}");
        }
        for model in ["gpt-4", "deepseek-chat"] {
            assert!(!provider.supports_model(model), "should reject {model}");
        }
    }

    /// Matching must ignore the case of the incoming model id.
    #[test]
    fn test_supports_model_case_insensitive() {
        let provider = AmazonBedrockProvider::new();
        let models = [
            "ANTHROPIC.CLAUDE-3-HAIKU-20240307-V1:0",
            "META.LLAMA3-2-90B-INSTRUCT-V1:0",
            "Anthropic.Claude-3-Haiku",
            "AMAZON.TITAN-EMBED-TEXT-V2:0",
        ];
        for model in models {
            assert!(provider.supports_model(model), "should accept {model}");
        }
    }

    /// Vision detection must also be case-insensitive.
    #[test]
    fn test_supports_vision_case_insensitive() {
        let provider = AmazonBedrockProvider::new();
        let models = [
            "claude-3-haiku",
            "claude-3-sonnet",
            "CLAUDE-3-HAIKU",
            "CLAUDE-3-SONNET",
            "Anthropic.Claude-3-Haiku",
        ];
        for model in models {
            assert!(provider.supports_vision(model), "expected vision for {model}");
        }
    }
}