use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};
#[derive(Debug, Clone)]
/// Provider backed by the AWS Bedrock runtime (`bedrock-runtime.<region>.amazonaws.com`).
///
/// Requests are signed locally with AWS Signature Version 4; no AWS SDK is used.
pub struct AWSBedrockProvider {
    /// Shared HTTP client / base-URL / model-list plumbing common to all providers.
    base: BaseProvider,
    /// AWS region; sourced from `config.organization` (defaults to "us-east-1").
    region: String,
    /// AWS access key id; sourced from `config.project`, falling back to the
    /// `AWS_ACCESS_KEY_ID` environment variable.
    access_key_id: String,
    /// Secret key from `AWS_SECRET_ACCESS_KEY`; may be empty if unset.
    secret_access_key: String,
    /// Optional STS session token from `AWS_SESSION_TOKEN` (sent as
    /// `x-amz-security-token` when present).
    session_token: Option<String>,
    /// Static per-model pricing table built once at construction.
    pricing_cache: HashMap<String, ModelPricing>,
}
impl AWSBedrockProvider {
pub async fn new(config: &ProviderConfig) -> Result<Self> {
let base = BaseProvider::new(config)?;
let region = config
.organization
.as_ref()
.cloned()
.unwrap_or_else(|| "us-east-1".to_string());
let access_key_id = config
.project
.as_ref()
.cloned()
.unwrap_or_else(|| std::env::var("AWS_ACCESS_KEY_ID").unwrap_or_default());
let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY").unwrap_or_default();
let session_token = std::env::var("AWS_SESSION_TOKEN").ok();
let base_url = config
.base_url
.clone()
.unwrap_or_else(|| format!("https://bedrock-runtime.{}.amazonaws.com", region));
let provider = Self {
base: BaseProvider { base_url, ..base },
region,
access_key_id,
secret_access_key,
session_token,
pricing_cache: Self::initialize_pricing_cache(),
};
info!(
"AWS Bedrock provider '{}' initialized successfully",
config.name
);
Ok(provider)
}
fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
let mut cache = HashMap::new();
cache.insert(
"anthropic.claude-3-sonnet-20240229-v1:0".to_string(),
ModelPricing {
model: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(),
input_cost_per_1k: 0.003,
output_cost_per_1k: 0.015,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"anthropic.claude-3-haiku-20240307-v1:0".to_string(),
ModelPricing {
model: "anthropic.claude-3-haiku-20240307-v1:0".to_string(),
input_cost_per_1k: 0.00025,
output_cost_per_1k: 0.00125,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"amazon.titan-text-express-v1".to_string(),
ModelPricing {
model: "amazon.titan-text-express-v1".to_string(),
input_cost_per_1k: 0.0008,
output_cost_per_1k: 0.0016,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"meta.llama2-70b-chat-v1".to_string(),
ModelPricing {
model: "meta.llama2-70b-chat-v1".to_string(),
input_cost_per_1k: 0.00195,
output_cost_per_1k: 0.00256,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache
}
async fn create_aws_signature(
&self,
method: &str,
uri: &str,
query_string: &str,
headers: &HashMap<String, String>,
payload: &str,
timestamp: &str,
) -> Result<String> {
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
type HmacSha256 = Hmac<Sha256>;
let mut canonical_headers = String::new();
let mut signed_headers = Vec::new();
for (key, value) in headers {
canonical_headers.push_str(&format!("{}:{}\n", key.to_lowercase(), value));
signed_headers.push(key.to_lowercase());
}
signed_headers.sort();
let signed_headers_str = signed_headers.join(";");
let payload_hash = format!("{:x}", Sha256::digest(payload.as_bytes()));
let canonical_request = format!(
"{}\n{}\n{}\n{}\n{}\n{}",
method, uri, query_string, canonical_headers, signed_headers_str, payload_hash
);
let date = ×tamp[..8];
let credential_scope = format!("{}/{}/bedrock/aws4_request", date, self.region);
let canonical_request_hash = format!("{:x}", Sha256::digest(canonical_request.as_bytes()));
let string_to_sign = format!(
"AWS4-HMAC-SHA256\n{}\n{}\n{}",
timestamp, credential_scope, canonical_request_hash
);
let k_date =
HmacSha256::new_from_slice(format!("AWS4{}", self.secret_access_key).as_bytes())
.map_err(|e| ProviderError::Authentication(e.to_string()))?;
let k_date = k_date.finalize().into_bytes();
let mut k_region = HmacSha256::new_from_slice(&k_date)
.map_err(|e| ProviderError::Authentication(e.to_string()))?;
k_region.update(self.region.as_bytes());
let k_region = k_region.finalize().into_bytes();
let mut k_service = HmacSha256::new_from_slice(&k_region)
.map_err(|e| ProviderError::Authentication(e.to_string()))?;
k_service.update(b"bedrock");
let k_service = k_service.finalize().into_bytes();
let mut k_signing = HmacSha256::new_from_slice(&k_service)
.map_err(|e| ProviderError::Authentication(e.to_string()))?;
k_signing.update(b"aws4_request");
let k_signing = k_signing.finalize().into_bytes();
let mut signature_mac = HmacSha256::new_from_slice(&k_signing)
.map_err(|e| ProviderError::Authentication(e.to_string()))?;
signature_mac.update(string_to_sign.as_bytes());
let signature = format!("{:x}", signature_mac.finalize().into_bytes());
Ok(signature)
}
async fn create_auth_header(
&self,
method: &str,
uri: &str,
payload: &str,
) -> Result<HashMap<String, String>> {
let timestamp = chrono::Utc::now().format("%Y%m%dT%H%M%SZ").to_string();
let date = ×tamp[..8];
let mut headers = HashMap::new();
headers.insert(
"host".to_string(),
format!("bedrock-runtime.{}.amazonaws.com", self.region),
);
headers.insert("x-amz-date".to_string(), timestamp.clone());
headers.insert("content-type".to_string(), "application/json".to_string());
if let Some(token) = &self.session_token {
headers.insert("x-amz-security-token".to_string(), token.clone());
}
let signature = self
.create_aws_signature(method, uri, "", &headers, payload, ×tamp)
.await?;
let credential = format!("{}/{}/bedrock/aws4_request", date, self.region);
let signed_headers = "content-type;host;x-amz-date";
let authorization = format!(
"AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
self.access_key_id, credential, signed_headers, signature
);
headers.insert("authorization".to_string(), authorization);
Ok(headers)
}
fn convert_messages_to_bedrock(
&self,
messages: &[ChatMessage],
model: &str,
) -> serde_json::Value {
if model.starts_with("anthropic.claude") {
let mut system_message = None;
let mut claude_messages = Vec::new();
for message in messages {
match message.role {
MessageRole::System => {
if let Some(MessageContent::Text(text)) = &message.content {
system_message = Some(text.clone());
}
}
MessageRole::User => {
claude_messages.push(json!({
"role": "user",
"content": self.extract_text_content(message.content.as_ref())
}));
}
MessageRole::Assistant => {
claude_messages.push(json!({
"role": "assistant",
"content": self.extract_text_content(message.content.as_ref())
}));
}
_ => {}
}
}
let mut body = json!({
"messages": claude_messages,
"anthropic_version": "bedrock-2023-05-31"
});
if let Some(system) = system_message {
body["system"] = json!(system);
}
body
} else if model.starts_with("meta.llama") {
let prompt = messages
.iter()
.map(|msg| {
let role = match msg.role {
MessageRole::User => "Human",
MessageRole::Assistant => "Assistant",
MessageRole::System => "System",
_ => "Human",
};
format!(
"{}: {}",
role,
self.extract_text_content(msg.content.as_ref())
)
})
.collect::<Vec<String>>()
.join("\n\n");
json!({
"prompt": format!("{}\n\nAssistant:", prompt)
})
} else {
let input_text = messages
.iter()
.map(|msg| self.extract_text_content(msg.content.as_ref()))
.collect::<Vec<String>>()
.join("\n");
json!({
"inputText": input_text
})
}
}
fn extract_text_content(&self, content: Option<&MessageContent>) -> String {
match content {
Some(MessageContent::Text(text)) => text.clone(),
Some(MessageContent::Parts(parts)) => parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
_ => None,
})
.collect::<Vec<String>>()
.join(" "),
None => String::new(),
}
}
fn convert_bedrock_response_to_openai(
&self,
bedrock_response: serde_json::Value,
model: &str,
) -> Result<ChatCompletionResponse> {
let content = if model.starts_with("anthropic.claude") {
bedrock_response
.get("content")
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|item| item.get("text"))
.and_then(|text| text.as_str())
.unwrap_or("")
.to_string()
} else if model.starts_with("meta.llama") {
bedrock_response
.get("generation")
.and_then(|g| g.as_str())
.unwrap_or("")
.to_string()
} else {
bedrock_response
.get("results")
.and_then(|r| r.as_array())
.and_then(|arr| arr.first())
.and_then(|item| item.get("outputText"))
.and_then(|text| text.as_str())
.unwrap_or("")
.to_string()
};
let usage = if model.starts_with("anthropic.claude") {
bedrock_response.get("usage").map(|u| Usage {
prompt_tokens: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
completion_tokens: u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0)
as u32,
total_tokens: 0,
prompt_tokens_details: None,
completion_tokens_details: None,
})
} else {
None
};
let mut usage = usage.unwrap_or_default();
usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
Ok(ChatCompletionResponse {
id: format!("chatcmpl-bedrock-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: model.to_string(),
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: MessageRole::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
},
finish_reason: Some("stop".to_string()),
logprobs: None,
}],
usage: Some(usage),
system_fingerprint: None,
})
}
}
#[async_trait]
impl Provider for AWSBedrockProvider {
    /// Configured provider name (from `ProviderConfig::name`).
    fn name(&self) -> &str {
        &self.base.name
    }

    fn provider_type(&self) -> ProviderType {
        ProviderType::Custom("aws_bedrock".to_string())
    }

    /// A model is supported if the base config lists it, or if it carries one
    /// of the known Bedrock vendor prefixes.
    async fn supports_model(&self, model: &str) -> bool {
        self.base.is_model_supported(model)
            || model.starts_with("anthropic.")
            || model.starts_with("amazon.")
            || model.starts_with("meta.")
            || model.starts_with("ai21.")
            || model.starts_with("cohere.")
    }

    async fn supports_images(&self) -> bool {
        true
    }

    async fn supports_embeddings(&self) -> bool {
        true
    }

    async fn supports_streaming(&self) -> bool {
        true
    }

    /// Returns a hard-coded catalogue of well-known Bedrock model ids; the
    /// Bedrock ListFoundationModels API is not called.
    async fn list_models(&self) -> Result<Vec<Model>> {
        let known_models = vec![
            "anthropic.claude-3-sonnet-20240229-v1:0",
            "anthropic.claude-3-haiku-20240307-v1:0",
            "anthropic.claude-v2:1",
            "amazon.titan-text-express-v1",
            "amazon.titan-text-lite-v1",
            "amazon.titan-embed-text-v1",
            "meta.llama2-70b-chat-v1",
            "meta.llama2-13b-chat-v1",
            "ai21.j2-ultra-v1",
            "ai21.j2-mid-v1",
            "cohere.command-text-v14",
            "cohere.command-light-text-v14",
        ];
        let models = known_models
            .into_iter()
            .map(|model| Model {
                id: model.to_string(),
                object: "model".to_string(),
                created: chrono::Utc::now().timestamp() as u64,
                owned_by: "aws".to_string(),
            })
            .collect();
        Ok(models)
    }

    /// No-op health check; no request is made to AWS.
    async fn health_check(&self) -> Result<()> {
        debug!("Performing AWS Bedrock health check");
        Ok(())
    }

    /// Signs and sends a chat request to `POST /model/{id}/invoke`, translating
    /// between the OpenAI wire format and the model family's native body.
    ///
    /// # Errors
    /// Maps HTTP failures onto typed `ProviderError`s (401/403 -> Authentication,
    /// 429 -> RateLimit, 404 -> ModelNotFound, 400 -> InvalidRequest).
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
        _context: RequestContext,
    ) -> Result<ChatCompletionResponse> {
        debug!("AWS Bedrock chat completion for model: {}", request.model);
        let mut final_body = self.convert_messages_to_bedrock(&request.messages, &request.model);
        // NOTE(review): these keys are injected for every model family, but
        // Titan expects them nested under `textGenerationConfig` (see
        // `completion` below) — confirm against the Bedrock model reference.
        if let Some(max_tokens) = request.max_tokens {
            final_body["max_tokens"] = json!(max_tokens);
        }
        if let Some(temperature) = request.temperature {
            final_body["temperature"] = json!(temperature);
        }
        if let Some(top_p) = request.top_p {
            final_body["top_p"] = json!(top_p);
        }
        let payload = serde_json::to_string(&final_body)
            .map_err(|e| ProviderError::Parsing(e.to_string()))?;
        let uri = format!("/model/{}/invoke", request.model);
        let headers = self.create_auth_header("POST", &uri, &payload).await?;
        let url = format!("{}{}", self.base.base_url, uri);
        let mut req_builder = self.base.client.post(&url);
        for (key, value) in headers {
            req_builder = req_builder.header(key, value);
        }
        let response = req_builder
            .body(payload)
            .send()
            .await
            .map_err(|e| ProviderError::Network(e.to_string()))?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(match status.as_u16() {
                401 | 403 => ProviderError::Authentication(error_text),
                429 => ProviderError::RateLimit(error_text),
                404 => ProviderError::ModelNotFound(error_text),
                400 => ProviderError::InvalidRequest(error_text),
                _ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
            }
            .into());
        }
        let bedrock_response: serde_json::Value = self.base.parse_json_response(response).await?;
        self.convert_bedrock_response_to_openai(bedrock_response, &request.model)
    }

    /// Signs and sends a plain-text completion. Titan models get a
    /// `textGenerationConfig` body; all other families get a flat
    /// prompt/max_tokens/temperature body.
    ///
    /// Usage counts are always zero: the Bedrock completion responses handled
    /// here do not carry token counts.
    async fn completion(
        &self,
        request: CompletionRequest,
        _context: RequestContext,
    ) -> Result<CompletionResponse> {
        debug!("AWS Bedrock completion for model: {}", request.model);
        let body = if request.model.starts_with("amazon.titan") {
            json!({
                "inputText": request.prompt,
                "textGenerationConfig": {
                    "maxTokenCount": request.max_tokens.unwrap_or(512),
                    "temperature": request.temperature.unwrap_or(0.7),
                    "topP": request.top_p.unwrap_or(1.0)
                }
            })
        } else {
            json!({
                "prompt": request.prompt,
                "max_tokens": request.max_tokens.unwrap_or(512),
                "temperature": request.temperature.unwrap_or(0.7)
            })
        };
        let payload =
            serde_json::to_string(&body).map_err(|e| ProviderError::Parsing(e.to_string()))?;
        let uri = format!("/model/{}/invoke", request.model);
        let headers = self.create_auth_header("POST", &uri, &payload).await?;
        let url = format!("{}{}", self.base.base_url, uri);
        let mut req_builder = self.base.client.post(&url);
        for (key, value) in headers {
            req_builder = req_builder.header(key, value);
        }
        let response = req_builder
            .body(payload)
            .send()
            .await
            .map_err(|e| ProviderError::Network(e.to_string()))?;
        // NOTE(review): unlike chat_completion, no HTTP-status mapping happens
        // here — errors surface through parse_json_response. Confirm that is
        // the intended behavior before unifying.
        let bedrock_response: serde_json::Value = self.base.parse_json_response(response).await?;
        let text = if request.model.starts_with("amazon.titan") {
            bedrock_response
                .get("results")
                .and_then(|r| r.as_array())
                .and_then(|arr| arr.first())
                .and_then(|item| item.get("outputText"))
                .and_then(|text| text.as_str())
                .unwrap_or("")
                .to_string()
        } else {
            bedrock_response
                .get("generation")
                .and_then(|g| g.as_str())
                .unwrap_or("")
                .to_string()
        };
        Ok(CompletionResponse {
            id: format!("cmpl-bedrock-{}", uuid::Uuid::new_v4()),
            object: "text_completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: request.model,
            choices: vec![CompletionChoice {
                text,
                index: 0,
                logprobs: None,
                finish_reason: Some("stop".to_string()),
            }],
            usage: Some(Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
                prompt_tokens_details: None,
                completion_tokens_details: None,
            }),
        })
    }

    /// Signs and sends an embedding request (Titan-style `inputText` body) and
    /// returns the single `embedding` vector from the response. A missing or
    /// malformed vector degrades to an empty embedding rather than an error.
    async fn embedding(
        &self,
        request: EmbeddingRequest,
        _context: RequestContext,
    ) -> Result<EmbeddingResponse> {
        debug!("AWS Bedrock embedding for model: {}", request.model);
        let body = json!({
            "inputText": request.input
        });
        let payload =
            serde_json::to_string(&body).map_err(|e| ProviderError::Parsing(e.to_string()))?;
        let uri = format!("/model/{}/invoke", request.model);
        let headers = self.create_auth_header("POST", &uri, &payload).await?;
        let url = format!("{}{}", self.base.base_url, uri);
        let mut req_builder = self.base.client.post(&url);
        for (key, value) in headers {
            req_builder = req_builder.header(key, value);
        }
        let response = req_builder
            .body(payload)
            .send()
            .await
            .map_err(|e| ProviderError::Network(e.to_string()))?;
        let bedrock_response: serde_json::Value = self.base.parse_json_response(response).await?;
        // Avoids the original `.unwrap_or(&vec![])`, which allocated a
        // throwaway Vec on every call just to borrow it.
        let embedding_vec = bedrock_response
            .get("embedding")
            .and_then(|e| e.as_array())
            .map(|arr| arr.iter().filter_map(|v| v.as_f64()).collect())
            .unwrap_or_default();
        let embeddings = vec![EmbeddingObject {
            object: "embedding".to_string(),
            embedding: embedding_vec,
            index: 0,
        }];
        Ok(EmbeddingResponse {
            object: "list".to_string(),
            data: embeddings,
            model: request.model,
            // Token counts are not reported by this endpoint handler.
            usage: EmbeddingUsage {
                prompt_tokens: 0,
                total_tokens: 0,
            },
        })
    }

    /// Image generation is not wired up for Bedrock yet.
    async fn image_generation(
        &self,
        _request: ImageGenerationRequest,
        _context: RequestContext,
    ) -> Result<ImageGenerationResponse> {
        Err(ProviderError::InvalidRequest(
            "Image generation not implemented for Bedrock yet".to_string(),
        )
        .into())
    }

    /// Looks the model up in the static pricing cache, falling back to a
    /// generic default rate for unknown models.
    async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
        if let Some(pricing) = self.pricing_cache.get(model) {
            Ok(pricing.clone())
        } else {
            Ok(ModelPricing {
                model: model.to_string(),
                input_cost_per_1k: 0.001,
                output_cost_per_1k: 0.003,
                currency: "USD".to_string(),
                updated_at: chrono::Utc::now(),
            })
        }
    }

    /// Cost in the pricing currency: per-1K-token rates applied to input and
    /// output token counts.
    async fn calculate_cost(
        &self,
        model: &str,
        input_tokens: u32,
        output_tokens: u32,
    ) -> Result<f64> {
        let pricing = self.get_model_pricing(model).await?;
        let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
        let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
        Ok(input_cost + output_cost)
    }
}