use super::{ChatMessage, LLMProvider, LLMRequest, LLMResponse, MessageRole, ModelInfo};
use crate::core::http_client::create_http_client;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
#[allow(unused_imports)]
use std::collections::HashMap;
#[allow(unused_imports)]
use std::pin::Pin;
#[allow(unused_imports)]
use tracing::{debug, error, info};
#[derive(Debug, Clone)]
pub struct AWSBedrockProvider {
#[allow(dead_code)]
client: reqwest::Client,
#[allow(dead_code)]
aws_access_key_id: String,
#[allow(dead_code)]
aws_secret_access_key: String,
aws_session_token: Option<String>,
region: String,
model: String,
}
impl AWSBedrockProvider {
pub fn new(aws_access_key_id: String, aws_secret_access_key: String, region: String) -> Self {
Self {
client: create_http_client(Some(120)),
aws_access_key_id,
aws_secret_access_key,
aws_session_token: None,
region,
model: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(),
}
}
pub fn with_session_token(mut self, session_token: String) -> Self {
self.aws_session_token = Some(session_token);
self
}
pub fn with_model(mut self, model: String) -> Self {
self.model = model;
self
}
pub fn with_region(mut self, region: String) -> Self {
self.region = region;
self
}
#[allow(dead_code)]
fn to_bedrock(&self, messages: &[ChatMessage]) -> (Option<String>, Vec<BedrockMessage>) {
let mut system_message = None;
let mut bedrock_messages = Vec::new();
for message in messages {
match message.role {
MessageRole::System => {
system_message = Some(message.content.clone());
}
MessageRole::User => {
bedrock_messages.push(BedrockMessage {
role: "user".to_string(),
content: vec![BedrockContent {
content_type: "text".to_string(),
text: Some(message.content.clone()),
}],
});
}
MessageRole::Assistant => {
bedrock_messages.push(BedrockMessage {
role: "assistant".to_string(),
content: vec![BedrockContent {
content_type: "text".to_string(),
text: Some(message.content.clone()),
}],
});
}
MessageRole::Tool => {
bedrock_messages.push(BedrockMessage {
role: "user".to_string(),
content: vec![BedrockContent {
content_type: "text".to_string(),
text: Some(format!("Tool response: {}", message.content)),
}],
});
}
}
}
(system_message, bedrock_messages)
}
#[allow(dead_code)]
fn get_bedrock_endpoint(&self, model: &str) -> String {
format!(
"https://bedrock-runtime.{}.amazonaws.com/model/{}/invoke",
self.region, model
)
}
#[allow(dead_code)]
async fn sign_request(
&self,
_method: &str,
_url: &str,
_payload: &str,
) -> Result<HashMap<String, String>> {
Err(anyhow!(
"AWS request signing temporarily disabled - API compatibility issues"
))
}
}
#[async_trait]
impl LLMProvider for AWSBedrockProvider {
async fn complete(&self, _request: LLMRequest) -> Result<LLMResponse> {
Err(anyhow!(
"AWS Bedrock provider not yet implemented. \
This is a placeholder for future AWS SDK integration. \
For now, use 'ollama' or 'openai' providers instead. \
See: https://github.com/Michael-A-Kuykendall/rustchain for roadmap."
))
}
async fn stream<'a>(
&'a self,
_request: LLMRequest,
) -> Result<Pin<Box<dyn futures::Stream<Item = Result<LLMResponse>> + Send + 'a>>, anyhow::Error>
{
Err(anyhow!(
"AWS Bedrock streaming not yet implemented. \
This requires AWS SDK integration which is planned for a future release."
))
}
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
Ok(vec![
ModelInfo {
id: "anthropic.claude-3-sonnet-20240229-v1:0".to_string(),
name: "Claude 3 Sonnet".to_string(),
provider: "aws-bedrock".to_string(),
context_length: 200_000,
max_output_tokens: 4096,
supports_tools: true,
supports_streaming: false, cost_per_input_token: Some(0.000003), cost_per_output_token: Some(0.000015), },
ModelInfo {
id: "anthropic.claude-3-haiku-20240307-v1:0".to_string(),
name: "Claude 3 Haiku".to_string(),
provider: "aws-bedrock".to_string(),
context_length: 200_000,
max_output_tokens: 4096,
supports_tools: true,
supports_streaming: false,
cost_per_input_token: Some(0.00000025), cost_per_output_token: Some(0.00000125), },
ModelInfo {
id: "amazon.titan-text-express-v1".to_string(),
name: "Titan Text G1 - Express".to_string(),
provider: "aws-bedrock".to_string(),
context_length: 8000,
max_output_tokens: 8000,
supports_tools: false,
supports_streaming: false,
cost_per_input_token: Some(0.0000008), cost_per_output_token: Some(0.0000016), },
ModelInfo {
id: "ai21.j2-ultra-v1".to_string(),
name: "Jurassic-2 Ultra".to_string(),
provider: "aws-bedrock".to_string(),
context_length: 8192,
max_output_tokens: 8192,
supports_tools: false,
supports_streaming: false,
cost_per_input_token: Some(0.0000188), cost_per_output_token: Some(0.0000188), },
ModelInfo {
id: "cohere.command-text-v14".to_string(),
name: "Command Text".to_string(),
provider: "aws-bedrock".to_string(),
context_length: 4096,
max_output_tokens: 4096,
supports_tools: false,
supports_streaming: false,
cost_per_input_token: Some(0.0000015), cost_per_output_token: Some(0.000002), },
ModelInfo {
id: "meta.llama2-70b-chat-v1".to_string(),
name: "Llama 2 70B Chat".to_string(),
provider: "aws-bedrock".to_string(),
context_length: 4096,
max_output_tokens: 2048,
supports_tools: false,
supports_streaming: false,
cost_per_input_token: Some(0.00000195), cost_per_output_token: Some(0.00000256), },
])
}
fn provider_name(&self) -> &str {
"aws-bedrock"
}
fn supports_streaming(&self) -> bool {
false }
fn supports_tools(&self) -> bool {
true
}
}
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct BedrockRequest {
anthropic_version: String,
max_tokens: u32,
messages: Vec<BedrockMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
stop_sequences: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
struct BedrockMessage {
role: String,
content: Vec<BedrockContent>,
}
#[derive(Debug, Serialize, Deserialize)]
struct BedrockContent {
#[serde(rename = "type")]
content_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
text: Option<String>,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct BedrockResponse {
content: Vec<BedrockContent>,
#[serde(rename = "stop_reason")]
stop_reason: Option<String>,
usage: BedrockUsage,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct BedrockUsage {
input_tokens: u32,
output_tokens: u32,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bedrock_provider_creation() {
let provider = AWSBedrockProvider::new(
"test-access-key".to_string(),
"test-secret-key".to_string(),
"us-east-1".to_string(),
);
assert_eq!(provider.provider_name(), "aws-bedrock");
assert_eq!(provider.region, "us-east-1");
assert_eq!(provider.model, "anthropic.claude-3-sonnet-20240229-v1:0");
}
#[test]
fn test_bedrock_provider_with_custom_model() {
let provider = AWSBedrockProvider::new(
"test-access-key".to_string(),
"test-secret-key".to_string(),
"us-west-2".to_string(),
)
.with_model("anthropic.claude-3-haiku-20240307-v1:0".to_string());
assert_eq!(provider.model, "anthropic.claude-3-haiku-20240307-v1:0");
assert_eq!(provider.region, "us-west-2");
}
#[test]
fn test_bedrock_provider_with_session_token() {
let provider = AWSBedrockProvider::new(
"test-access-key".to_string(),
"test-secret-key".to_string(),
"us-east-1".to_string(),
)
.with_session_token("test-session-token".to_string());
assert_eq!(
provider.aws_session_token,
Some("test-session-token".to_string())
);
}
#[test]
fn test_message_conversion() {
let provider = AWSBedrockProvider::new(
"test-access-key".to_string(),
"test-secret-key".to_string(),
"us-east-1".to_string(),
);
let messages = vec![
ChatMessage {
role: MessageRole::System,
content: "You are a helpful assistant".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: MessageRole::User,
content: "Hello!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
},
];
let (system, bedrock_messages) = provider.to_bedrock(&messages);
assert_eq!(system, Some("You are a helpful assistant".to_string()));
assert_eq!(bedrock_messages.len(), 1);
assert_eq!(bedrock_messages[0].role, "user");
assert_eq!(
bedrock_messages[0].content[0].text,
Some("Hello!".to_string())
);
}
#[test]
fn test_bedrock_endpoint_generation() {
let provider = AWSBedrockProvider::new(
"test-access-key".to_string(),
"test-secret-key".to_string(),
"us-east-1".to_string(),
);
let endpoint = provider.get_bedrock_endpoint("anthropic.claude-3-sonnet-20240229-v1:0");
assert_eq!(
endpoint,
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke"
);
}
#[tokio::test]
async fn test_list_models() {
let provider = AWSBedrockProvider::new(
"test-access-key".to_string(),
"test-secret-key".to_string(),
"us-east-1".to_string(),
);
let models = provider.list_models().await.unwrap();
assert_eq!(models.len(), 6);
assert!(models
.iter()
.any(|m| m.id == "anthropic.claude-3-sonnet-20240229-v1:0"));
assert!(models
.iter()
.any(|m| m.id == "anthropic.claude-3-haiku-20240307-v1:0"));
assert!(models
.iter()
.any(|m| m.id == "amazon.titan-text-express-v1"));
assert!(models.iter().any(|m| m.id == "ai21.j2-ultra-v1"));
assert!(models.iter().any(|m| m.id == "cohere.command-text-v14"));
assert!(models.iter().any(|m| m.id == "meta.llama2-70b-chat-v1"));
}
}