use std::collections::HashMap;
use async_trait::async_trait;
use aws_config::SdkConfig;
use aws_sdk_bedrockruntime::error::ProvideErrorMetadata;
use aws_sdk_bedrockruntime::primitives::Blob;
use aws_sdk_bedrockruntime::types::{
ContentBlock, ConversationRole, ConverseOutput, InferenceConfiguration, Message, StopReason,
SystemContentBlock, Tool, ToolChoice as BedrockToolChoice, ToolConfiguration, ToolInputSchema,
ToolSpecification, ToolUseBlock,
};
use aws_sdk_bedrockruntime::Client;
use aws_smithy_types::Document;
use futures::stream::BoxStream;
use tracing::{debug, instrument};
use crate::error::{LlmError, Result};
use crate::traits::{
ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, LLMProvider, LLMResponse,
ToolCall as EdgequakeToolCall, ToolChoice as EdgequakeToolChoice,
ToolDefinition as EdgequakeToolDefinition,
};
/// Render an AWS SDK error into our provider error type.
///
/// Prefers the structured service-error metadata (error code + message) when
/// the failure came from the Bedrock service itself; otherwise falls back to
/// the `Debug` rendering (covers dispatch failures, timeouts, etc.).
fn format_sdk_error<E: ProvideErrorMetadata + std::fmt::Debug>(
    label: &str,
    err: &aws_sdk_bedrockruntime::error::SdkError<E>,
) -> LlmError {
    let detail = match err.as_service_error() {
        Some(service_err) => {
            let meta = service_err.meta();
            format!(
                "{}: {}",
                meta.code().unwrap_or("Unknown"),
                meta.message().unwrap_or("No message")
            )
        }
        None => format!("{err:?}"),
    };
    LlmError::ProviderError(format!("Bedrock {label} error: {detail}"))
}
/// Default chat model when the caller / `AWS_BEDROCK_MODEL` supplies none.
const DEFAULT_MODEL: &str = "amazon.nova-lite-v1:0";
/// Fallback region when neither the SDK config nor the environment has one.
const DEFAULT_REGION: &str = "us-east-1";
/// Context window assumed for models not matched by `context_length_for_model`.
const DEFAULT_MAX_CONTEXT: usize = 300_000;
/// Default embedding model (`AWS_BEDROCK_EMBEDDING_MODEL` overrides).
const DEFAULT_EMBEDDING_MODEL: &str = "amazon.titan-embed-text-v2:0";
/// Vector width assumed for models not matched by `dimension_for_model`.
const DEFAULT_EMBEDDING_DIMENSION: usize = 1024;
/// Input-token budget assumed when `embedding_max_tokens_for_model` has no match.
const DEFAULT_EMBEDDING_MAX_TOKENS: usize = 8192;
/// LLM + embedding provider backed by the AWS Bedrock Converse and
/// InvokeModel APIs.
#[derive(Debug, Clone)]
pub struct BedrockProvider {
    // Bedrock Runtime client built from the caller's SdkConfig.
    client: Client,
    // Chat model id as configured; resolved to a regional inference-profile
    // id at request time by `resolve_model_id`.
    model: String,
    // AWS region; used to pick the inference-profile geography prefix.
    region: String,
    // Maximum context window in tokens (heuristic per model family).
    max_context_length: usize,
    // Model id used by the EmbeddingProvider impl.
    embedding_model: String,
    // Output vector width of `embedding_model`.
    embedding_dimension: usize,
}
impl BedrockProvider {
/// Build a provider from an already-loaded AWS [`SdkConfig`].
///
/// The region is taken from the config (falling back to `DEFAULT_REGION`)
/// and the context window is inferred from the model name. Embedding
/// settings start at their defaults; adjust with the `with_*` builders.
pub fn new(sdk_config: &SdkConfig, model: impl Into<String>) -> Self {
    let model = model.into();
    let region = sdk_config
        .region()
        .map(|r| r.to_string())
        .unwrap_or_else(|| DEFAULT_REGION.to_string());
    let max_context_length = Self::context_length_for_model(&model);
    Self {
        client: Client::new(sdk_config),
        model,
        region,
        max_context_length,
        embedding_model: DEFAULT_EMBEDDING_MODEL.to_string(),
        embedding_dimension: DEFAULT_EMBEDDING_DIMENSION,
    }
}
/// Build a provider from environment variables.
///
/// Reads `AWS_REGION` / `AWS_DEFAULT_REGION`, `AWS_BEDROCK_MODEL`, and
/// `AWS_BEDROCK_EMBEDDING_MODEL` (falling back to the module defaults),
/// then loads AWS credentials through the standard provider chain.
pub async fn from_env() -> Result<Self> {
    let region = std::env::var("AWS_REGION")
        .or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
        .unwrap_or_else(|_| DEFAULT_REGION.to_string());
    let model =
        std::env::var("AWS_BEDROCK_MODEL").unwrap_or_else(|_| DEFAULT_MODEL.to_string());
    let embedding_model = std::env::var("AWS_BEDROCK_EMBEDDING_MODEL")
        .unwrap_or_else(|_| DEFAULT_EMBEDDING_MODEL.to_string());
    let embedding_dimension = Self::dimension_for_model(&embedding_model);
    let sdk_config = aws_config::from_env()
        .region(aws_config::Region::new(region.clone()))
        .load()
        .await;
    let max_context_length = Self::context_length_for_model(&model);
    Ok(Self {
        client: Client::new(&sdk_config),
        model,
        region,
        max_context_length,
        embedding_model,
        embedding_dimension,
    })
}
/// Switch the chat model; the context window is re-derived from the new
/// model name (override afterwards with `with_max_context_length`).
pub fn with_model(mut self, model: impl Into<String>) -> Self {
    let model = model.into();
    self.max_context_length = Self::context_length_for_model(&model);
    self.model = model;
    self
}
/// Override the heuristic context window.
pub fn with_max_context_length(mut self, length: usize) -> Self {
    self.max_context_length = length;
    self
}
/// Switch the embedding model; the vector width is re-derived from the new
/// model name (override afterwards with `with_embedding_dimension`).
pub fn with_embedding_model(mut self, model: impl Into<String>) -> Self {
    let model = model.into();
    self.embedding_dimension = Self::dimension_for_model(&model);
    self.embedding_model = model;
    self
}
/// Override the embedding vector width.
pub fn with_embedding_dimension(mut self, dimension: usize) -> Self {
    self.embedding_dimension = dimension;
    self
}
/// AWS region this provider targets.
pub fn region(&self) -> &str {
    &self.region
}
/// Resolve the configured chat model id for this provider's region.
fn resolve_model_id(&self) -> String {
    Self::resolve_model_id_for_region(&self.model, &self.region)
}
/// Resolve a model id for a region, adding a cross-region inference-profile
/// prefix (e.g. `us.`) for the model families that require one.
///
/// ARNs and ids already carrying a geography prefix pass through untouched.
fn resolve_model_id_for_region(model: &str, region: &str) -> String {
    // Ids that are already fully qualified are returned as-is.
    const PASSTHROUGH_PREFIXES: [&str; 9] = [
        "arn:", "global.", "us.", "eu.", "ap.", "ca.", "sa.", "me.", "af.",
    ];
    if PASSTHROUGH_PREFIXES.iter().any(|p| model.starts_with(p)) {
        return model.to_string();
    }
    // Families exposed through cross-region inference profiles; these need
    // a geography prefix derived from the first region segment.
    const PROFILE_FAMILIES: [&str; 8] = [
        "amazon.nova",
        "anthropic.claude",
        "meta.llama",
        "cohere.embed",
        "deepseek.",
        "mistral.pixtral",
        "writer.",
        "twelvelabs.",
    ];
    if PROFILE_FAMILIES.iter().any(|p| model.starts_with(p)) {
        let geo = region.split('-').next().unwrap_or("us");
        format!("{geo}.{model}")
    } else {
        model.to_string()
    }
}
/// Heuristic context window (in tokens) for a model id.
///
/// First matching substring wins, so more specific patterns (e.g. "pixtral")
/// appear before the broad "mistral" bucket; unknown models fall back to
/// `DEFAULT_MAX_CONTEXT`.
fn context_length_for_model(model: &str) -> usize {
    const TABLE: &[(&str, usize)] = &[
        ("claude-3", 200_000),
        ("claude-4", 200_000),
        ("claude-2", 100_000),
        ("nova", 300_000),
        ("devstral", 256_000),
        ("minimax", 1_000_000),
        ("qwen", 131_072),
        ("llama", 128_000),
        ("cohere", 128_000),
        ("deepseek", 128_000),
        ("pixtral", 128_000),
        ("magistral", 128_000),
        ("writer", 128_000),
        ("palmyra", 128_000),
        ("nemotron", 128_000),
        ("gemma", 128_000),
        ("glm", 128_000),
        ("gpt-oss", 128_000),
        ("mistral", 32_000),
    ];
    let id = model.to_lowercase();
    TABLE
        .iter()
        .find(|&&(pattern, _)| id.contains(pattern))
        .map(|&(_, length)| length)
        .unwrap_or(DEFAULT_MAX_CONTEXT)
}
/// Output vector width for a known embedding model id.
///
/// The 1536-wide families are first-generation Titan text embeddings and
/// Cohere Embed v4; every other recognized family emits 1024-wide vectors.
/// Unknown models fall back to `DEFAULT_EMBEDDING_DIMENSION`.
pub fn dimension_for_model(model: &str) -> usize {
    let id = model.to_lowercase();
    let has = |pattern: &str| id.contains(pattern);
    if has("titan-embed-text-v1") || has("titan-embed-g1") || has("embed-v4") {
        1536
    } else if has("titan-embed-text-v2")
        || has("titan-embed-image")
        || has("embed-english-v3")
        || has("embed-multilingual-v3")
        || (has("nova") && has("embed"))
        || has("marengo")
    {
        1024
    } else {
        DEFAULT_EMBEDDING_DIMENSION
    }
}
/// Maximum input tokens accepted by a known embedding model.
///
/// Titan text v2 takes 8192 tokens, the older v1/G1 family only 512, and
/// Cohere embed models 2048; anything unrecognized uses the module default.
fn embedding_max_tokens_for_model(model: &str) -> usize {
    let id = model.to_lowercase();
    let has = |pattern: &str| id.contains(pattern);
    if has("titan-embed-text-v2") {
        8192
    } else if has("titan-embed-text-v1") || has("titan-embed-g1") {
        512
    } else if has("cohere") && has("embed") {
        2048
    } else {
        DEFAULT_EMBEDDING_MAX_TOKENS
    }
}
/// Build the InvokeModel JSON body for an embedding request.
///
/// Cohere models take a batch of `texts`; Titan/Nova (and unknown models,
/// which share the same `inputText` shape) embed exactly one text per call,
/// so `texts` must have length 1 for them — batching is the caller's job.
///
/// # Errors
/// `LlmError::InvalidRequest` on a multi-text request to a single-text
/// model, or if serialization fails.
fn build_embedding_request(model: &str, texts: &[String]) -> Result<Vec<u8>> {
    // Shared serializer so the error mapping is written once.
    fn to_body(value: &serde_json::Value) -> Result<Vec<u8>> {
        serde_json::to_vec(value)
            .map_err(|e| LlmError::InvalidRequest(format!("Failed to serialize body: {e}")))
    }
    let m = model.to_lowercase();
    if m.contains("cohere") && m.contains("embed") {
        // NOTE(review): input_type is fixed to "search_query"; confirm this
        // is also the intended mode when embedding documents for indexing.
        to_body(&serde_json::json!({
            "texts": texts,
            "input_type": "search_query"
        }))
    } else {
        // Titan/Nova and unknown models: single-text `inputText` payload.
        // Parenthesized explicitly (was `titan || nova && embed`, which
        // relied on `&&` binding tighter — same semantics, now unambiguous
        // and consistent with parse_embedding_response).
        if texts.len() != 1 {
            let msg = if m.contains("titan") || (m.contains("nova") && m.contains("embed")) {
                "Titan/Nova embedding models require single-text requests; \
                 batch is handled by the caller"
            } else {
                "Unknown embedding model; single-text requests only"
            };
            return Err(LlmError::InvalidRequest(msg.to_string()));
        }
        to_body(&serde_json::json!({ "inputText": texts[0] }))
    }
}
/// Parse an InvokeModel embedding response into one vector per input text.
///
/// Cohere responses carry `"embeddings"` as either a plain array of arrays
/// or, when embedding types were requested, a dict keyed by `"float"`.
/// Titan/Nova (and unknown models) return a single `"embedding"` array,
/// so the result always has exactly one row for them.
///
/// Non-numeric elements are leniently coerced to 0.0 (original behavior).
fn parse_embedding_response(model: &str, response_bytes: &[u8]) -> Result<Vec<Vec<f32>>> {
    // Shared numeric conversion; non-numbers coerce to 0.0.
    fn to_f32_vec(values: &[serde_json::Value]) -> Vec<f32> {
        values
            .iter()
            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
            .collect()
    }
    let m = model.to_lowercase();
    let json: serde_json::Value = serde_json::from_slice(response_bytes).map_err(|e| {
        LlmError::ProviderError(format!("Failed to parse embedding response: {e}"))
    })?;
    if m.contains("cohere") && m.contains("embed") {
        let embeddings_val = json.get("embeddings").ok_or_else(|| {
            LlmError::ProviderError("Missing 'embeddings' in Cohere response".to_string())
        })?;
        // Borrow the rows directly instead of cloning the whole array
        // (the previous version cloned it and also used
        // `unwrap_or(&Vec::new())`, which allocated even when unused).
        let rows: &Vec<serde_json::Value> = if let Some(obj) = embeddings_val.as_object() {
            obj.get("float").and_then(|v| v.as_array()).ok_or_else(|| {
                LlmError::ProviderError(
                    "Missing 'float' key in Cohere embeddings dict".to_string(),
                )
            })?
        } else if let Some(arr) = embeddings_val.as_array() {
            arr
        } else {
            return Err(LlmError::ProviderError(
                "Unexpected 'embeddings' format in Cohere response".to_string(),
            ));
        };
        Ok(rows
            .iter()
            .map(|emb| emb.as_array().map(|a| to_f32_vec(a)).unwrap_or_default())
            .collect())
    } else {
        // Titan/Nova and unknown models: single {"embedding": [...]}.
        // Error text is kept distinct per family, matching prior behavior.
        let missing_msg = if m.contains("titan") || (m.contains("nova") && m.contains("embed")) {
            "Missing 'embedding' array in Titan/Nova response"
        } else {
            "Missing 'embedding' array in response"
        };
        let embedding = json
            .get("embedding")
            .and_then(|v| v.as_array())
            .ok_or_else(|| LlmError::ProviderError(missing_msg.to_string()))?;
        Ok(vec![to_f32_vec(embedding)])
    }
}
/// True when `model` is a Cohere embedding model (ids look like
/// "cohere.embed-english-v3", possibly with a region prefix); the
/// case-insensitive substring check covers all variants.
fn is_cohere_embedding(model: &str) -> bool {
    let id = model.to_lowercase();
    ["cohere", "embed"].iter().all(|needle| id.contains(needle))
}
/// Recursively convert a `serde_json::Value` into a smithy `Document`
/// (the representation Converse uses for tool inputs and schemas).
fn json_to_document(value: &serde_json::Value) -> Document {
    match value {
        serde_json::Value::Null => Document::Null,
        serde_json::Value::Bool(b) => Document::Bool(*b),
        serde_json::Value::Number(n) => {
            // Check order matters: non-negative integers must map to PosInt
            // before the i64/f64 fallbacks. A number representable as none
            // of u64/i64/f64 degrades to Null.
            if let Some(u) = n.as_u64() {
                Document::Number(aws_smithy_types::Number::PosInt(u))
            } else if let Some(i) = n.as_i64() {
                Document::Number(aws_smithy_types::Number::NegInt(i))
            } else if let Some(f) = n.as_f64() {
                Document::Number(aws_smithy_types::Number::Float(f))
            } else {
                Document::Null
            }
        }
        serde_json::Value::String(s) => Document::String(s.clone()),
        serde_json::Value::Array(arr) => {
            Document::Array(arr.iter().map(Self::json_to_document).collect())
        }
        serde_json::Value::Object(obj) => Document::Object(
            obj.iter()
                .map(|(k, v)| (k.clone(), Self::json_to_document(v)))
                .collect(),
        ),
    }
}
/// Inverse of `json_to_document`: convert a smithy `Document` back into a
/// `serde_json::Value` (used to re-serialize tool-call arguments).
fn document_to_json(doc: &Document) -> serde_json::Value {
    match doc {
        Document::Null => serde_json::Value::Null,
        Document::Bool(b) => serde_json::Value::Bool(*b),
        Document::Number(n) => match n {
            aws_smithy_types::Number::PosInt(u) => serde_json::json!(*u),
            aws_smithy_types::Number::NegInt(i) => serde_json::json!(*i),
            aws_smithy_types::Number::Float(f) => serde_json::json!(*f),
        },
        Document::String(s) => serde_json::Value::String(s.clone()),
        Document::Array(arr) => {
            serde_json::Value::Array(arr.iter().map(Self::document_to_json).collect())
        }
        Document::Object(obj) => serde_json::Value::Object(
            obj.iter()
                .map(|(k, v)| (k.clone(), Self::document_to_json(v)))
                .collect(),
        ),
    }
}
/// Convert provider-agnostic chat messages into Bedrock `Message`s plus the
/// top-level system blocks Converse expects.
///
/// - The options-level `system_prompt` (if non-empty) becomes the first
///   system block, followed by any `ChatRole::System` messages in order.
/// - Assistant tool calls become `ToolUse` content blocks.
/// - Tool/function results become `ToolResult` blocks wrapped in a
///   *user*-role message, as the Converse API requires.
fn convert_messages(
    messages: &[ChatMessage],
    system_prompt: Option<&str>,
) -> Result<(Vec<Message>, Vec<SystemContentBlock>)> {
    let mut bedrock_messages: Vec<Message> = Vec::new();
    let mut system_blocks: Vec<SystemContentBlock> = Vec::new();
    // Options-level system prompt goes first; empty strings are dropped.
    if let Some(sys) = system_prompt {
        if !sys.is_empty() {
            system_blocks.push(SystemContentBlock::Text(sys.to_string()));
        }
    }
    for msg in messages {
        match msg.role {
            // System messages are hoisted out of the conversation into
            // Converse's dedicated `system` field.
            ChatRole::System => {
                system_blocks.push(SystemContentBlock::Text(msg.content.clone()));
            }
            ChatRole::User => {
                let content = ContentBlock::Text(msg.content.clone());
                let bedrock_msg = Message::builder()
                    .role(ConversationRole::User)
                    .content(content)
                    .build()
                    .map_err(|e| {
                        LlmError::ProviderError(format!(
                            "Failed to build Bedrock user message: {e}"
                        ))
                    })?;
                bedrock_messages.push(bedrock_msg);
            }
            ChatRole::Assistant => {
                // Assistant turns may carry text, tool calls, or both.
                let mut content_blocks: Vec<ContentBlock> = Vec::new();
                if !msg.content.is_empty() {
                    content_blocks.push(ContentBlock::Text(msg.content.clone()));
                }
                if let Some(ref tool_calls) = msg.tool_calls {
                    for tc in tool_calls {
                        // Arguments arrive as a JSON string; if they fail to
                        // parse, fall back to a raw string document rather
                        // than erroring out.
                        let input_doc =
                            serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
                                .map(|v| Self::json_to_document(&v))
                                .unwrap_or_else(|_| {
                                    Document::String(tc.function.arguments.clone())
                                });
                        let tool_use = ToolUseBlock::builder()
                            .tool_use_id(&tc.id)
                            .name(&tc.function.name)
                            .input(input_doc)
                            .build()
                            .map_err(|e| {
                                LlmError::ProviderError(format!(
                                    "Failed to build tool use block: {e}"
                                ))
                            })?;
                        content_blocks.push(ContentBlock::ToolUse(tool_use));
                    }
                }
                let mut builder = Message::builder().role(ConversationRole::Assistant);
                for block in content_blocks {
                    builder = builder.content(block);
                }
                let bedrock_msg = builder.build().map_err(|e| {
                    LlmError::ProviderError(format!(
                        "Failed to build Bedrock assistant message: {e}"
                    ))
                })?;
                bedrock_messages.push(bedrock_msg);
            }
            ChatRole::Tool | ChatRole::Function => {
                // Bedrock expects tool results as user-role messages carrying
                // a ToolResult block tied back to the originating call id.
                // NOTE(review): consecutive tool results become separate user
                // messages here; confirm Converse accepts that rather than
                // requiring them merged into one user turn.
                let tool_call_id = msg.tool_call_id.as_deref().unwrap_or("unknown").to_string();
                let result_content =
                    aws_sdk_bedrockruntime::types::ToolResultContentBlock::Text(
                        msg.content.clone(),
                    );
                let tool_result = aws_sdk_bedrockruntime::types::ToolResultBlock::builder()
                    .tool_use_id(tool_call_id)
                    .content(result_content)
                    .build()
                    .map_err(|e| {
                        LlmError::ProviderError(format!(
                            "Failed to build tool result block: {e}"
                        ))
                    })?;
                let content = ContentBlock::ToolResult(tool_result);
                let bedrock_msg = Message::builder()
                    .role(ConversationRole::User)
                    .content(content)
                    .build()
                    .map_err(|e| {
                        LlmError::ProviderError(format!(
                            "Failed to build Bedrock tool result message: {e}"
                        ))
                    })?;
                bedrock_messages.push(bedrock_msg);
            }
        }
    }
    Ok((bedrock_messages, system_blocks))
}
/// Translate completion options into a Converse `InferenceConfiguration`.
///
/// Returns `None` when no options are supplied or none of the relevant
/// knobs (max_tokens, temperature, top_p, stop) are set, so the request
/// omits `inferenceConfig` entirely.
fn build_inference_config(
    options: Option<&CompletionOptions>,
) -> Option<InferenceConfiguration> {
    let opts = options?;
    let has_any = opts.max_tokens.is_some()
        || opts.temperature.is_some()
        || opts.top_p.is_some()
        || opts.stop.is_some();
    if !has_any {
        return None;
    }
    let mut builder = InferenceConfiguration::builder();
    if let Some(max_tokens) = opts.max_tokens {
        builder = builder.max_tokens(max_tokens as i32);
    }
    if let Some(temperature) = opts.temperature {
        builder = builder.temperature(temperature);
    }
    if let Some(top_p) = opts.top_p {
        builder = builder.top_p(top_p);
    }
    if let Some(stops) = opts.stop.as_deref() {
        for stop in stops {
            builder = builder.stop_sequences(stop.clone());
        }
    }
    Some(builder.build())
}
/// Convert tool definitions plus an optional tool-choice into a Bedrock
/// `ToolConfiguration`.
///
/// Returns `Ok(None)` when no tools are supplied, or when the caller chose
/// `"none"` — Converse has no explicit "none" choice, so the tool config is
/// omitted from the request entirely in that case.
fn build_tool_config(
    tools: &[EdgequakeToolDefinition],
    tool_choice: Option<&EdgequakeToolChoice>,
) -> Result<Option<ToolConfiguration>> {
    if tools.is_empty() {
        return Ok(None);
    }
    let mut bedrock_tools = Vec::new();
    for tool in tools {
        // JSON-schema parameters travel as a smithy Document.
        let schema_doc = Self::json_to_document(&tool.function.parameters);
        let spec = ToolSpecification::builder()
            .name(&tool.function.name)
            .description(&tool.function.description)
            .input_schema(ToolInputSchema::Json(schema_doc))
            .build()
            .map_err(|e| {
                LlmError::ProviderError(format!("Failed to build tool specification: {e}"))
            })?;
        bedrock_tools.push(Tool::ToolSpec(spec));
    }
    let mut config_builder = ToolConfiguration::builder();
    for tool in bedrock_tools {
        config_builder = config_builder.tools(tool);
    }
    if let Some(choice) = tool_choice {
        let bedrock_choice = match choice {
            // "none" -> suppress the whole tool configuration.
            EdgequakeToolChoice::Auto(s) if s == "none" => {
                return Ok(None);
            }
            EdgequakeToolChoice::Auto(_) => BedrockToolChoice::Auto(
                aws_sdk_bedrockruntime::types::AutoToolChoice::builder().build(),
            ),
            // "required" maps to Bedrock Any: the model must call some tool.
            EdgequakeToolChoice::Required(_) => BedrockToolChoice::Any(
                aws_sdk_bedrockruntime::types::AnyToolChoice::builder().build(),
            ),
            // Force one specific tool by name.
            EdgequakeToolChoice::Function { function, .. } => BedrockToolChoice::Tool(
                aws_sdk_bedrockruntime::types::SpecificToolChoice::builder()
                    .name(&function.name)
                    .build()
                    .map_err(|e| {
                        LlmError::ProviderError(format!(
                            "Failed to build specific tool choice: {e}"
                        ))
                    })?,
            ),
        };
        config_builder = config_builder.tool_choice(bedrock_choice);
    }
    let config = config_builder.build().map_err(|e| {
        LlmError::ProviderError(format!("Failed to build tool configuration: {e}"))
    })?;
    Ok(Some(config))
}
/// Translate a Bedrock stop reason into the OpenAI-style `finish_reason`
/// string the rest of the crate uses.
fn map_stop_reason(reason: &StopReason) -> String {
    let mapped = match reason {
        StopReason::MaxTokens => "length",
        StopReason::ToolUse => "tool_calls",
        StopReason::ContentFiltered | StopReason::GuardrailIntervened => "content_filter",
        // EndTurn, StopSequence, and any future variants all read as "stop".
        _ => "stop",
    };
    mapped.to_string()
}
/// Pull the concatenated text and any tool calls out of a Converse output.
///
/// Tool-use inputs are serialized back to a JSON argument string so they
/// match the OpenAI-style `ToolCall` shape the rest of the crate expects;
/// other content block kinds are ignored.
fn extract_content(output: &ConverseOutput) -> (String, Vec<EdgequakeToolCall>) {
    let mut text_parts = Vec::new();
    let mut tool_calls = Vec::new();
    if let ConverseOutput::Message(msg) = output {
        for block in msg.content() {
            match block {
                ContentBlock::Text(text) => {
                    text_parts.push(text.clone());
                }
                ContentBlock::ToolUse(tool_use) => {
                    let arguments_json = Self::document_to_json(&tool_use.input);
                    // Serialization of a value built from document_to_json
                    // should not fail; empty string is the lenient fallback.
                    let arguments_str =
                        serde_json::to_string(&arguments_json).unwrap_or_default();
                    tool_calls.push(EdgequakeToolCall {
                        id: tool_use.tool_use_id.clone(),
                        call_type: "function".to_string(),
                        function: crate::traits::FunctionCall {
                            name: tool_use.name.clone(),
                            arguments: arguments_str,
                        },
                        thought_signature: None,
                    });
                }
                // Images, documents, reasoning blocks, etc. are skipped.
                _ => {}
            }
        }
    }
    // Text blocks are joined with no separator.
    (text_parts.join(""), tool_calls)
}
}
#[async_trait]
impl LLMProvider for BedrockProvider {
/// Provider identifier used in logs and metrics.
fn name(&self) -> &str {
    "bedrock"
}
/// Configured chat model id (unresolved; see `resolve_model_id`).
fn model(&self) -> &str {
    &self.model
}
/// Heuristic context window for the configured chat model.
fn max_context_length(&self) -> usize {
    self.max_context_length
}
/// Single-prompt completion: wraps the prompt in one user message and
/// delegates to `chat` with no options.
#[instrument(skip(self, prompt), fields(provider = "bedrock", model = %self.model))]
async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
    let messages = vec![ChatMessage::user(prompt)];
    self.chat(&messages, None).await
}
/// Single-prompt completion with generation options applied.
#[instrument(skip(self, prompt, options), fields(provider = "bedrock", model = %self.model))]
async fn complete_with_options(
    &self,
    prompt: &str,
    options: &CompletionOptions,
) -> Result<LLMResponse> {
    let messages = vec![ChatMessage::user(prompt)];
    self.chat(&messages, Some(options)).await
}
/// Multi-turn chat via the Bedrock Converse API.
///
/// System prompts (from `options` and any `ChatRole::System` messages) are
/// sent through Converse's dedicated `system` field. Token usage defaults
/// to zeros when Bedrock omits it from the response.
#[instrument(skip(self, messages, options), fields(provider = "bedrock", model = %self.model))]
async fn chat(
    &self,
    messages: &[ChatMessage],
    options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
    let system_prompt = options.and_then(|o| o.system_prompt.as_deref());
    let (bedrock_messages, system_blocks) = Self::convert_messages(messages, system_prompt)?;
    // Resolve to a cross-region inference-profile id where required.
    let resolved_model = self.resolve_model_id();
    let mut request = self.client.converse().model_id(&resolved_model);
    for msg in bedrock_messages {
        request = request.messages(msg);
    }
    for block in system_blocks {
        request = request.system(block);
    }
    if let Some(config) = Self::build_inference_config(options) {
        request = request.inference_config(config);
    }
    debug!(
        "Sending Bedrock Converse request for model: {} (resolved: {})",
        self.model, resolved_model
    );
    let response = request
        .send()
        .await
        .map_err(|e| format_sdk_error("Converse API", &e))?;
    // Text plus any tool-use blocks from the assistant turn.
    let (content, tool_calls) = response
        .output()
        .map(Self::extract_content)
        .unwrap_or_default();
    let (prompt_tokens, completion_tokens, total_tokens) = response
        .usage()
        .map(|u| {
            let input = u.input_tokens() as usize;
            let output = u.output_tokens() as usize;
            (input, output, input + output)
        })
        .unwrap_or((0, 0, 0));
    let finish_reason = Self::map_stop_reason(&response.stop_reason);
    Ok(LLMResponse {
        content,
        prompt_tokens,
        completion_tokens,
        total_tokens,
        model: resolved_model,
        finish_reason: Some(finish_reason),
        tool_calls,
        metadata: HashMap::new(),
        cache_hit_tokens: None,
        thinking_tokens: None,
        thinking_content: None,
    })
}
/// Multi-turn chat with tool definitions attached (Converse `toolConfig`).
///
/// When `tool_choice` resolves to "none" the tool config is omitted and
/// this behaves like `chat`. Tool calls returned by the model are surfaced
/// in `LLMResponse::tool_calls` with JSON-string arguments.
#[instrument(skip(self, messages, tools, tool_choice, options), fields(provider = "bedrock", model = %self.model))]
async fn chat_with_tools(
    &self,
    messages: &[ChatMessage],
    tools: &[EdgequakeToolDefinition],
    tool_choice: Option<EdgequakeToolChoice>,
    options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
    let system_prompt = options.and_then(|o| o.system_prompt.as_deref());
    let (bedrock_messages, system_blocks) = Self::convert_messages(messages, system_prompt)?;
    let resolved_model = self.resolve_model_id();
    let mut request = self.client.converse().model_id(&resolved_model);
    for msg in bedrock_messages {
        request = request.messages(msg);
    }
    for block in system_blocks {
        request = request.system(block);
    }
    if let Some(config) = Self::build_inference_config(options) {
        request = request.inference_config(config);
    }
    if let Some(tool_config) = Self::build_tool_config(tools, tool_choice.as_ref())? {
        request = request.tool_config(tool_config);
    }
    debug!(
        "Sending Bedrock Converse request with {} tools for model: {} (resolved: {})",
        tools.len(),
        self.model,
        resolved_model
    );
    let response = request
        .send()
        .await
        .map_err(|e| format_sdk_error("Converse API", &e))?;
    let (content, tool_calls) = response
        .output()
        .map(Self::extract_content)
        .unwrap_or_default();
    // Token accounting defaults to zeros when Bedrock omits usage.
    let (prompt_tokens, completion_tokens, total_tokens) = response
        .usage()
        .map(|u| {
            let input = u.input_tokens() as usize;
            let output = u.output_tokens() as usize;
            (input, output, input + output)
        })
        .unwrap_or((0, 0, 0));
    let finish_reason = Self::map_stop_reason(&response.stop_reason);
    Ok(LLMResponse {
        content,
        prompt_tokens,
        completion_tokens,
        total_tokens,
        model: resolved_model,
        finish_reason: Some(finish_reason),
        tool_calls,
        metadata: HashMap::new(),
        cache_hit_tokens: None,
        thinking_tokens: None,
        thinking_content: None,
    })
}
/// Stream a completion via ConverseStream, yielding text deltas.
///
/// Only `ContentBlockDelta::Text` events are surfaced; tool-use deltas and
/// other event types are silently skipped. The stream ends on `MessageStop`
/// or when the event channel closes; receive errors are yielded as
/// `LlmError` items and the stream continues polling afterwards.
#[instrument(skip(self, prompt), fields(provider = "bedrock", model = %self.model))]
async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
    let messages = vec![ChatMessage::user(prompt)];
    let (bedrock_messages, system_blocks) = Self::convert_messages(&messages, None)?;
    let resolved_model = self.resolve_model_id();
    let mut request = self.client.converse_stream().model_id(&resolved_model);
    for msg in bedrock_messages {
        request = request.messages(msg);
    }
    for block in system_blocks {
        request = request.system(block);
    }
    debug!(
        "Sending Bedrock ConverseStream request for model: {} (resolved: {})",
        self.model, resolved_model
    );
    let response = request
        .send()
        .await
        .map_err(|e| format_sdk_error("ConverseStream API", &e))?;
    use futures::stream;
    // Adapt the SDK's event receiver into a Stream of text chunks; the
    // receiver is threaded through unfold as the stream state.
    let mapped_stream = stream::unfold(response.stream, |mut rx| async move {
        loop {
            match rx.recv().await {
                Ok(Some(event)) => {
                    use aws_sdk_bedrockruntime::types::ConverseStreamOutput as CSO;
                    match event {
                        CSO::ContentBlockDelta(delta_event) => {
                            if let Some(delta) = delta_event.delta() {
                                use aws_sdk_bedrockruntime::types::ContentBlockDelta;
                                if let ContentBlockDelta::Text(text) = delta {
                                    return Some((Ok(text.clone()), rx));
                                }
                            }
                        }
                        CSO::MessageStop(_) => {
                            // Model finished the message: end the stream.
                            return None;
                        }
                        _ => {
                            // Ignore block start/stop, metadata, etc.
                        }
                    }
                }
                Ok(None) => {
                    // Event channel exhausted.
                    return None;
                }
                Err(e) => {
                    return Some((
                        Err(LlmError::ProviderError(format!(
                            "Bedrock stream error: {e}"
                        ))),
                        rx,
                    ));
                }
            }
        }
    });
    Ok(Box::pin(mapped_stream))
}
/// Text streaming is available via `stream` (ConverseStream).
fn supports_streaming(&self) -> bool {
    true
}
/// Tool-use deltas are not surfaced by `stream`.
fn supports_tool_streaming(&self) -> bool {
    false
}
/// No JSON-mode toggle is exposed by this provider.
fn supports_json_mode(&self) -> bool {
    false
}
/// Tool/function calling is supported via `chat_with_tools`.
fn supports_function_calling(&self) -> bool {
    true
}
}
#[async_trait]
impl EmbeddingProvider for BedrockProvider {
/// Provider identifier (same as the LLM side).
fn name(&self) -> &str {
    "bedrock"
}
/// Embedding model id — intentionally reads `embedding_model`, not the chat
/// `model` field, hence the clippy allow.
#[allow(clippy::misnamed_getters)]
fn model(&self) -> &str {
    &self.embedding_model
}
/// Output vector width of the configured embedding model.
fn dimension(&self) -> usize {
    self.embedding_dimension
}
/// Maximum input tokens accepted by the configured embedding model.
fn max_tokens(&self) -> usize {
    Self::embedding_max_tokens_for_model(&self.embedding_model)
}
/// Embed a batch of texts via InvokeModel.
///
/// Cohere models accept up to 96 texts per request (the Bedrock batch
/// limit), so inputs are chunked; Titan/Nova and unknown models embed one
/// text per call. The two paths were previously near-identical duplicated
/// loops differing only in chunk size and the debug label — they are now
/// unified, with identical behavior and identical rendered log output.
///
/// # Errors
/// Propagates request-building, InvokeModel, and response-parsing failures.
#[instrument(skip(self, texts), fields(provider = "bedrock", embedding_model = %self.embedding_model))]
async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    let is_cohere = Self::is_cohere_embedding(&self.embedding_model);
    // Cohere: 96-text batches; everyone else: one text per request.
    let batch_size = if is_cohere { 96 } else { 1 };
    let mut all_embeddings = Vec::with_capacity(texts.len());
    for chunk in texts.chunks(batch_size) {
        let body = Self::build_embedding_request(&self.embedding_model, chunk)?;
        let response = self
            .client
            .invoke_model()
            .model_id(&self.embedding_model)
            .content_type("application/json")
            .accept("application/json")
            .body(Blob::new(body))
            .send()
            .await
            .map_err(|e| format_sdk_error("InvokeModel (embedding)", &e))?;
        let mut embeddings =
            Self::parse_embedding_response(&self.embedding_model, response.body().as_ref())?;
        all_embeddings.append(&mut embeddings);
    }
    debug!(
        "Generated {} {} embeddings ({} dims)",
        all_embeddings.len(),
        if is_cohere { "Cohere" } else { "Titan/Nova" },
        self.embedding_dimension
    );
    Ok(all_embeddings)
}
}
#[cfg(test)]
mod tests {
use super::*;
// context_length_for_model: per-family context-window heuristics.
#[test]
fn test_context_length_claude3() {
    assert_eq!(
        BedrockProvider::context_length_for_model("anthropic.claude-3-5-sonnet-20241022-v2:0"),
        200_000
    );
    assert_eq!(
        BedrockProvider::context_length_for_model("anthropic.claude-4-sonnet-20250514-v1:0"),
        200_000
    );
}
#[test]
fn test_context_length_claude2() {
    assert_eq!(
        BedrockProvider::context_length_for_model("anthropic.claude-2"),
        100_000
    );
}
#[test]
fn test_context_length_nova() {
    assert_eq!(
        BedrockProvider::context_length_for_model("amazon.nova-pro-v1:0"),
        300_000
    );
}
#[test]
fn test_context_length_llama() {
    assert_eq!(
        BedrockProvider::context_length_for_model("meta.llama3-70b-instruct-v1:0"),
        128_000
    );
}
#[test]
fn test_context_length_mistral() {
    // Plain Mistral models (not Pixtral/Magistral/Devstral) get 32k.
    assert_eq!(
        BedrockProvider::context_length_for_model("mistral.mistral-large-2407-v1:0"),
        32_000
    );
}
#[test]
fn test_context_length_cohere() {
    assert_eq!(
        BedrockProvider::context_length_for_model("cohere.command-r-plus-v1:0"),
        128_000
    );
}
#[test]
fn test_context_length_deepseek() {
    assert_eq!(
        BedrockProvider::context_length_for_model("deepseek.r1-v1:0"),
        128_000
    );
}
#[test]
fn test_context_length_qwen() {
    assert_eq!(
        BedrockProvider::context_length_for_model("qwen.qwen2-5-72b-instruct-v1:0"),
        131_072
    );
}
#[test]
fn test_context_length_writer() {
    assert_eq!(
        BedrockProvider::context_length_for_model("writer.palmyra-x-004-v1:0"),
        128_000
    );
}
#[test]
fn test_context_length_default() {
    // Unrecognized models fall back to DEFAULT_MAX_CONTEXT.
    assert_eq!(
        BedrockProvider::context_length_for_model("some-unknown-model"),
        DEFAULT_MAX_CONTEXT
    );
}
// map_stop_reason: Bedrock stop reasons -> OpenAI-style finish_reason strings.
#[test]
fn test_stop_reason_mapping() {
    assert_eq!(
        BedrockProvider::map_stop_reason(&StopReason::EndTurn),
        "stop"
    );
    assert_eq!(
        BedrockProvider::map_stop_reason(&StopReason::MaxTokens),
        "length"
    );
    assert_eq!(
        BedrockProvider::map_stop_reason(&StopReason::StopSequence),
        "stop"
    );
    assert_eq!(
        BedrockProvider::map_stop_reason(&StopReason::ToolUse),
        "tool_calls"
    );
    assert_eq!(
        BedrockProvider::map_stop_reason(&StopReason::ContentFiltered),
        "content_filter"
    );
    assert_eq!(
        BedrockProvider::map_stop_reason(&StopReason::GuardrailIntervened),
        "content_filter"
    );
}
// build_inference_config: None in -> None out; any set knob yields a config.
#[test]
fn test_build_inference_config_none() {
    assert!(BedrockProvider::build_inference_config(None).is_none());
}
#[test]
fn test_build_inference_config_with_options() {
    let opts = CompletionOptions {
        max_tokens: Some(1024),
        temperature: Some(0.7),
        top_p: Some(0.9),
        stop: Some(vec!["END".to_string()]),
        ..Default::default()
    };
    let config = BedrockProvider::build_inference_config(Some(&opts));
    assert!(config.is_some());
}
#[test]
fn test_build_inference_config_empty_options() {
    // Options with no relevant knobs set produce no inference config.
    let opts = CompletionOptions::default();
    let config = BedrockProvider::build_inference_config(Some(&opts));
    assert!(config.is_none());
}
// convert_messages: system hoisting, tool-result wrapping, and role mapping.
#[test]
fn test_convert_messages_system() {
    let messages = vec![
        ChatMessage::system("You are helpful"),
        ChatMessage::user("Hello"),
    ];
    let (bedrock_msgs, system_blocks) =
        BedrockProvider::convert_messages(&messages, None).unwrap();
    assert_eq!(system_blocks.len(), 1);
    assert_eq!(bedrock_msgs.len(), 1);
}
#[test]
fn test_convert_messages_with_system_prompt_option() {
    let messages = vec![ChatMessage::user("Hello")];
    let (bedrock_msgs, system_blocks) =
        BedrockProvider::convert_messages(&messages, Some("Be concise")).unwrap();
    assert_eq!(system_blocks.len(), 1);
    assert_eq!(bedrock_msgs.len(), 1);
}
#[test]
fn test_convert_messages_empty_system_prompt_ignored() {
    // Empty options-level system prompts must not create a system block.
    let messages = vec![ChatMessage::user("Hello")];
    let (_, system_blocks) = BedrockProvider::convert_messages(&messages, Some("")).unwrap();
    assert_eq!(system_blocks.len(), 0);
}
#[test]
fn test_convert_messages_tool_result() {
    // Tool results are wrapped in user-role messages, not system blocks.
    let messages = vec![ChatMessage::tool_result("call_123", "Result data")];
    let (bedrock_msgs, system_blocks) =
        BedrockProvider::convert_messages(&messages, None).unwrap();
    assert_eq!(system_blocks.len(), 0);
    assert_eq!(bedrock_msgs.len(), 1);
}
#[test]
fn test_convert_messages_multiple_system_blocks() {
    // Options-level prompt first, then each system message in order.
    let messages = vec![
        ChatMessage::system("System 1"),
        ChatMessage::system("System 2"),
        ChatMessage::user("Hello"),
    ];
    let (bedrock_msgs, system_blocks) =
        BedrockProvider::convert_messages(&messages, Some("Prefix system")).unwrap();
    assert_eq!(system_blocks.len(), 3);
    assert_eq!(bedrock_msgs.len(), 1);
}
#[test]
fn test_convert_messages_user_and_assistant() {
    let messages = vec![
        ChatMessage::user("Hello"),
        ChatMessage::assistant("Hi there!"),
        ChatMessage::user("How are you?"),
    ];
    let (bedrock_msgs, system_blocks) =
        BedrockProvider::convert_messages(&messages, None).unwrap();
    assert_eq!(system_blocks.len(), 0);
    assert_eq!(bedrock_msgs.len(), 3);
}
// json_to_document / document_to_json conversions, variant by variant.
#[test]
fn test_json_to_document_null() {
    let doc = BedrockProvider::json_to_document(&serde_json::Value::Null);
    assert!(matches!(doc, Document::Null));
}
#[test]
fn test_json_to_document_bool() {
    let doc = BedrockProvider::json_to_document(&serde_json::json!(true));
    assert!(matches!(doc, Document::Bool(true)));
}
#[test]
fn test_json_to_document_string() {
    let doc = BedrockProvider::json_to_document(&serde_json::json!("hello"));
    assert!(matches!(doc, Document::String(s) if s == "hello"));
}
#[test]
fn test_json_to_document_number() {
    // Non-negative integers must become PosInt (u64 is tried first).
    let doc = BedrockProvider::json_to_document(&serde_json::json!(42));
    assert!(matches!(
        doc,
        Document::Number(aws_smithy_types::Number::PosInt(42))
    ));
}
#[test]
fn test_json_to_document_negative_number() {
    let doc = BedrockProvider::json_to_document(&serde_json::json!(-5));
    assert!(matches!(
        doc,
        Document::Number(aws_smithy_types::Number::NegInt(-5))
    ));
}
#[test]
fn test_json_to_document_float() {
    // 1.125 is exactly representable in binary, so comparison is exact.
    let doc = BedrockProvider::json_to_document(&serde_json::json!(1.125));
    if let Document::Number(aws_smithy_types::Number::Float(f)) = doc {
        assert!((f - 1.125).abs() < f64::EPSILON);
    } else {
        panic!("Expected float document");
    }
}
#[test]
fn test_json_to_document_array() {
    let doc = BedrockProvider::json_to_document(&serde_json::json!([1, "two", null]));
    if let Document::Array(arr) = doc {
        assert_eq!(arr.len(), 3);
    } else {
        panic!("Expected array document");
    }
}
#[test]
fn test_json_to_document_object() {
    let doc = BedrockProvider::json_to_document(&serde_json::json!({"key": "value"}));
    if let Document::Object(obj) = doc {
        assert_eq!(obj.len(), 1);
        assert!(obj.contains_key("key"));
    } else {
        panic!("Expected object document");
    }
}
#[test]
fn test_document_to_json_roundtrip() {
    // json -> Document -> json must be lossless for ordinary payloads.
    let original = serde_json::json!({
        "name": "test",
        "age": 30,
        "active": true,
        "tags": ["a", "b"],
        "nested": {"x": 1.5}
    });
    let doc = BedrockProvider::json_to_document(&original);
    let recovered = BedrockProvider::document_to_json(&doc);
    assert_eq!(original, recovered);
}
// extract_content: text concatenation and tool-use extraction.
#[test]
fn test_extract_content_text_only() {
    let msg = Message::builder()
        .role(ConversationRole::Assistant)
        .content(ContentBlock::Text("Hello world".to_string()))
        .build()
        .unwrap();
    let output = ConverseOutput::Message(msg);
    let (text, tool_calls) = BedrockProvider::extract_content(&output);
    assert_eq!(text, "Hello world");
    assert!(tool_calls.is_empty());
}
#[test]
fn test_extract_content_multiple_text_blocks() {
    // Multiple text blocks are joined with no separator.
    let msg = Message::builder()
        .role(ConversationRole::Assistant)
        .content(ContentBlock::Text("Hello ".to_string()))
        .content(ContentBlock::Text("world".to_string()))
        .build()
        .unwrap();
    let output = ConverseOutput::Message(msg);
    let (text, _) = BedrockProvider::extract_content(&output);
    assert_eq!(text, "Hello world");
}
#[test]
fn test_extract_content_with_tool_use() {
    // Tool-use blocks become ToolCalls with JSON-string arguments.
    let tool_use = ToolUseBlock::builder()
        .tool_use_id("call_123")
        .name("get_weather")
        .input(Document::Object(
            vec![("city".to_string(), Document::String("Paris".to_string()))]
                .into_iter()
                .collect(),
        ))
        .build()
        .unwrap();
    let msg = Message::builder()
        .role(ConversationRole::Assistant)
        .content(ContentBlock::Text("Let me check the weather.".to_string()))
        .content(ContentBlock::ToolUse(tool_use))
        .build()
        .unwrap();
    let output = ConverseOutput::Message(msg);
    let (text, tool_calls) = BedrockProvider::extract_content(&output);
    assert_eq!(text, "Let me check the weather.");
    assert_eq!(tool_calls.len(), 1);
    assert_eq!(tool_calls[0].id, "call_123");
    assert_eq!(tool_calls[0].call_type, "function");
    assert_eq!(tool_calls[0].function.name, "get_weather");
    assert!(tool_calls[0].function.arguments.contains("Paris"));
}
#[test]
fn test_build_tool_config_empty_tools() {
    // No tools declared -> no tool configuration at all.
    let built = BedrockProvider::build_tool_config(&[], None).unwrap();
    assert!(built.is_none());
}
#[test]
fn test_build_tool_config_auto_none_returns_none() {
    // Even with tools declared, an explicit tool_choice of "none" must
    // suppress the tool configuration entirely.
    let declared = vec![EdgequakeToolDefinition::function(
        "test_fn",
        "A test function",
        serde_json::json!({"type": "object", "properties": {}}),
    )];
    let suppress = EdgequakeToolChoice::none();
    let built = BedrockProvider::build_tool_config(&declared, Some(&suppress)).unwrap();
    assert!(
        built.is_none(),
        "tool_choice='none' should omit tool config"
    );
}
#[test]
fn test_build_tool_config_with_tools() {
    // A well-formed tool definition yields Some(config).
    let declared = vec![EdgequakeToolDefinition::function(
        "search",
        "Search the web",
        serde_json::json!({
            "type": "object",
            "properties": {
                "query": {"type": "string"}
            },
            "required": ["query"]
        }),
    )];
    let built = BedrockProvider::build_tool_config(&declared, None).unwrap();
    assert!(built.is_some());
}
#[test]
fn test_context_length_claude_haiku() {
    // Claude 3.5 Haiku resolves to a 200k-token context window.
    // Renamed from `test_provider_name_and_model`: the old name was
    // misleading — the test never touched the provider name or the model
    // field, it only checked `context_length_for_model`.
    assert_eq!(
        BedrockProvider::context_length_for_model("anthropic.claude-3-5-haiku-20241022-v1:0"),
        200_000
    );
}
#[test]
fn test_context_length_varies_by_model_family() {
    // Distinct model families must resolve to distinct context windows —
    // a single hard-coded value would trip these inequalities.
    // Renamed from `test_with_model_updates_context`: the old name was
    // misleading — `with_model` was never called here.
    let claude =
        BedrockProvider::context_length_for_model("anthropic.claude-3-5-sonnet-20241022-v2:0");
    let nova = BedrockProvider::context_length_for_model("amazon.nova-pro-v1:0");
    let llama = BedrockProvider::context_length_for_model("meta.llama3-70b-instruct-v1:0");
    assert_ne!(claude, nova);
    assert_ne!(nova, llama);
}
#[test]
fn test_resolve_model_id_bare_us_region() {
    // A bare model id gains the region's geography prefix (us-*).
    let resolved =
        BedrockProvider::resolve_model_id_for_region("amazon.nova-lite-v1:0", "us-east-1");
    assert_eq!(resolved, "us.amazon.nova-lite-v1:0");
}
#[test]
fn test_resolve_model_id_bare_eu_region() {
    // Same model id, EU region -> "eu." prefix.
    let resolved =
        BedrockProvider::resolve_model_id_for_region("amazon.nova-lite-v1:0", "eu-west-1");
    assert_eq!(resolved, "eu.amazon.nova-lite-v1:0");
}
#[test]
fn test_resolve_model_id_bare_ap_region() {
    // Asia-Pacific regions map onto the "ap." profile prefix.
    let resolved = BedrockProvider::resolve_model_id_for_region(
        "anthropic.claude-3-haiku-20240307-v1:0",
        "ap-southeast-1",
    );
    assert_eq!(resolved, "ap.anthropic.claude-3-haiku-20240307-v1:0");
}
#[test]
fn test_resolve_model_id_already_prefixed_us() {
    // An existing geography prefix wins over the client's region.
    let resolved =
        BedrockProvider::resolve_model_id_for_region("us.amazon.nova-lite-v1:0", "eu-west-1");
    assert_eq!(resolved, "us.amazon.nova-lite-v1:0");
}
#[test]
fn test_resolve_model_id_already_prefixed_eu() {
    // Symmetric case: an "eu." prefix survives a us-east-1 client.
    let resolved =
        BedrockProvider::resolve_model_id_for_region("eu.amazon.nova-lite-v1:0", "us-east-1");
    assert_eq!(resolved, "eu.amazon.nova-lite-v1:0");
}
#[test]
fn test_resolve_model_id_already_prefixed_global() {
    // The cross-region "global." profile prefix is also left untouched.
    let model = "global.anthropic.claude-sonnet-4-20250514-v1:0";
    assert_eq!(
        BedrockProvider::resolve_model_id_for_region(model, "us-east-1"),
        model
    );
}
#[test]
fn test_resolve_model_id_arn_passthrough() {
    // Full inference-profile ARNs are never rewritten.
    let arn = "arn:aws:bedrock:us-east-1:123456789:inference-profile/my-profile";
    assert_eq!(
        BedrockProvider::resolve_model_id_for_region(arn, "eu-west-1"),
        arn
    );
}
#[test]
fn test_resolve_model_id_other_geographies() {
    // Every remaining geography code is derived from the region's first
    // segment; table-driven to keep the cases side by side.
    for (region, expected) in [
        ("ca-central-1", "ca.amazon.nova-lite-v1:0"),
        ("sa-east-1", "sa.amazon.nova-lite-v1:0"),
        ("me-south-1", "me.amazon.nova-lite-v1:0"),
        ("af-south-1", "af.amazon.nova-lite-v1:0"),
    ] {
        assert_eq!(
            BedrockProvider::resolve_model_id_for_region("amazon.nova-lite-v1:0", region),
            expected
        );
    }
}
#[test]
fn test_resolve_model_id_no_prefix_for_non_profile_models() {
    // Vendors without cross-region inference profiles keep their bare
    // model ids regardless of region.
    for (model, region) in [
        ("mistral.mistral-large-2402-v1:0", "eu-west-1"),
        ("google.gemma-3-27b-it", "eu-west-1"),
        ("qwen.qwen3-32b-v1:0", "us-east-1"),
        ("nvidia.nemotron-nano-12b-v2", "us-east-1"),
        ("minimax.minimax-m2", "eu-west-1"),
    ] {
        assert_eq!(
            BedrockProvider::resolve_model_id_for_region(model, region),
            model
        );
    }
}
#[test]
fn test_dimension_titan_embed_v2() {
    // Titan Text Embeddings v2 -> 1024-dimensional vectors.
    let dim = BedrockProvider::dimension_for_model("amazon.titan-embed-text-v2:0");
    assert_eq!(dim, 1024);
}
#[test]
fn test_dimension_titan_embed_v1() {
    // The older Titan v1 model emits 1536 dimensions.
    let dim = BedrockProvider::dimension_for_model("amazon.titan-embed-text-v1");
    assert_eq!(dim, 1536);
}
#[test]
fn test_dimension_titan_embed_g1() {
    // The G1 alias of the v1 family is also 1536 dimensions.
    let dim = BedrockProvider::dimension_for_model("amazon.titan-embed-g1-text-02");
    assert_eq!(dim, 1536);
}
#[test]
fn test_dimension_cohere_embed_v4() {
    // Cohere Embed v4 -> 1536 dimensions.
    let dim = BedrockProvider::dimension_for_model("cohere.embed-v4:0");
    assert_eq!(dim, 1536);
}
#[test]
fn test_dimension_cohere_embed_v3() {
    // Both the English and the multilingual v3 Cohere models are 1024-dim.
    for model in ["cohere.embed-english-v3", "cohere.embed-multilingual-v3"] {
        assert_eq!(BedrockProvider::dimension_for_model(model), 1024);
    }
}
#[test]
fn test_dimension_unknown_defaults() {
    // Unrecognized embedding models fall back to the crate-wide default.
    let dim = BedrockProvider::dimension_for_model("some-unknown-embed-model");
    assert_eq!(dim, DEFAULT_EMBEDDING_DIMENSION);
}
#[test]
fn test_build_embedding_request_titan() {
    // Titan request bodies carry the single text under "inputText".
    let body = BedrockProvider::build_embedding_request(
        "amazon.titan-embed-text-v2:0",
        &["Hello world".to_string()],
    )
    .unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(parsed["inputText"], "Hello world");
}
#[test]
fn test_build_embedding_request_cohere() {
    // Cohere request bodies batch the texts and tag the input type.
    let texts = ["Hello".to_string(), "World".to_string()];
    let body =
        BedrockProvider::build_embedding_request("cohere.embed-english-v3", &texts).unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(parsed["texts"], serde_json::json!(["Hello", "World"]));
    assert_eq!(parsed["input_type"], "search_query");
}
#[test]
fn test_build_embedding_request_titan_rejects_batch() {
    // Titan embeds one text per request, so a batch must be an error.
    let attempt = BedrockProvider::build_embedding_request(
        "amazon.titan-embed-text-v2:0",
        &["Hello".to_string(), "World".to_string()],
    );
    assert!(attempt.is_err());
}
#[test]
fn test_parse_embedding_response_titan() {
    // Titan returns one "embedding" array; the parser wraps it in a
    // single-element Vec so the batch-shaped API stays uniform.
    let payload = serde_json::to_vec(&serde_json::json!({
        "embedding": [0.1, 0.2, 0.3],
        "inputTextTokenCount": 3
    }))
    .unwrap();
    let vectors =
        BedrockProvider::parse_embedding_response("amazon.titan-embed-text-v2:0", &payload)
            .unwrap();
    assert_eq!(vectors.len(), 1);
    assert_eq!(vectors[0].len(), 3);
    assert!((vectors[0][0] - 0.1).abs() < 0.001);
}
#[test]
fn test_parse_embedding_response_cohere_list() {
    // Legacy "embeddings_floats" responses are a plain list of vectors.
    let payload = serde_json::to_vec(&serde_json::json!({
        "embeddings": [[0.1, 0.2], [0.3, 0.4]],
        "id": "test",
        "response_type": "embeddings_floats"
    }))
    .unwrap();
    let vectors =
        BedrockProvider::parse_embedding_response("cohere.embed-english-v3", &payload).unwrap();
    assert_eq!(vectors.len(), 2);
    for vector in &vectors {
        assert_eq!(vector.len(), 2);
    }
}
#[test]
fn test_parse_embedding_response_cohere_dict() {
    // "embeddings_by_type" responses nest the vectors under a "float" key.
    let payload = serde_json::to_vec(&serde_json::json!({
        "embeddings": {"float": [[0.5, 0.6, 0.7]]},
        "id": "test",
        "response_type": "embeddings_by_type"
    }))
    .unwrap();
    let vectors =
        BedrockProvider::parse_embedding_response("cohere.embed-v4:0", &payload).unwrap();
    assert_eq!(vectors.len(), 1);
    assert_eq!(vectors[0].len(), 3);
    assert!((vectors[0][0] - 0.5).abs() < 0.001);
}
#[test]
fn test_is_cohere_embedding() {
    // Only Cohere *embedding* models match; Titan embeddings and Cohere
    // chat models must not.
    for model in ["cohere.embed-english-v3", "cohere.embed-v4:0"] {
        assert!(BedrockProvider::is_cohere_embedding(model));
    }
    for model in ["amazon.titan-embed-text-v2:0", "cohere.command-r-plus-v1:0"] {
        assert!(!BedrockProvider::is_cohere_embedding(model));
    }
}
#[test]
fn test_embedding_max_tokens() {
    // Per-model input token limits, table-driven.
    for (model, limit) in [
        ("amazon.titan-embed-text-v2:0", 8192),
        ("amazon.titan-embed-text-v1", 512),
        ("cohere.embed-english-v3", 2048),
    ] {
        assert_eq!(
            BedrockProvider::embedding_max_tokens_for_model(model),
            limit
        );
    }
}
}