use std::env;
use std::fmt::Debug;
use aws_config::BehaviorVersion;
use aws_sdk_bedrockruntime::Client as BedrockClient;
use aws_sdk_bedrockruntime::operation::invoke_model_with_response_stream::InvokeModelWithResponseStreamError;
use aws_sdk_bedrockruntime::primitives::{Blob, event_stream::EventReceiver};
use aws_sdk_bedrockruntime::types::{ResponseStream, error::ResponseStreamError};
use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
use aws_smithy_runtime_api::client::result::SdkError;
use aws_smithy_types::error::metadata::ProvideErrorMetadata;
use aws_smithy_types::event_stream::RawMessage;
use defect_core::error::BoxError;
use defect_core::llm::{
Capabilities, CompletionRequest, FeatureSupport, LlmProvider, ModelCapabilityOverrides,
ModelInfo, ProtocolId, ProviderError, ProviderErrorKind, ProviderInfo, ProviderStream,
RateLimitScope, ThinkingEcho, TimeoutPhase,
};
use futures::FutureExt;
use futures::future::BoxFuture;
use futures::{Stream, stream};
use serde_json::Value;
use sse_stream::Sse;
use tokio_util::sync::CancellationToken;
use tracing::warn;
use std::collections::HashMap;
use crate::protocol::anthropic_messages::{self, ThinkingWireFormat};
use crate::wire::anthropic::components as wire;
// --- Provider defaults -----------------------------------------------------
// Region used when neither the config nor the environment supplies one.
const DEFAULT_AWS_REGION: &str = "us-east-1";
// Vendor identifier reported in `ProviderInfo` when the config does not set one.
// Also used as the placeholder model name in `ModelNotFound` when the failing
// model id is not known to the error mapper.
const DEFAULT_VENDOR: &str = "bedrock";
// Human-readable provider name used when the config does not set one.
const DEFAULT_DISPLAY_NAME: &str = "Amazon Bedrock";
// Value written into the `anthropic_version` field of every request body.
const ANTHROPIC_VERSION: &str = "bedrock-2023-05-31";
// Sent as both the `content_type` and the `accept` value of the invoke call.
const CONTENT_TYPE_JSON: &str = "application/json";
// --- Environment variables consulted as configuration fallbacks ------------
const AWS_REGION_ENV: &str = "AWS_REGION";
const AWS_PROFILE_ENV: &str = "AWS_PROFILE";
// --- Request body field names ----------------------------------------------
// `model` and `stream` are removed from the Anthropic wire body before sending;
// the model id is passed on the SDK call itself.
const BODY_MODEL_FIELD: &str = "model";
const BODY_STREAM_FIELD: &str = "stream";
// Fields inserted into the body by `bedrock_request_body`.
const BODY_ANTHROPIC_VERSION_FIELD: &str = "anthropic_version";
const BODY_ANTHROPIC_BETA_FIELD: &str = "anthropic_beta";
// --- Service error codes ---------------------------------------------------
// Compared against `ProvideErrorMetadata::code()` in `bedrock_service_error`
// to pick the matching `ProviderErrorKind`.
const ERR_ACCESS_DENIED: &str = "AccessDeniedException";
const ERR_VALIDATION: &str = "ValidationException";
const ERR_MODEL_NOT_READY: &str = "ModelNotReadyException";
const ERR_SERVICE_UNAVAILABLE: &str = "ServiceUnavailableException";
const ERR_THROTTLING: &str = "ThrottlingException";
const ERR_INTERNAL_SERVER: &str = "InternalServerException";
const ERR_MODEL_STREAM: &str = "ModelStreamErrorException";
const ERR_MODEL_TIMEOUT: &str = "ModelTimeoutException";
const ERR_RESOURCE_NOT_FOUND: &str = "ResourceNotFoundException";
const ERR_SERVICE_QUOTA_EXCEEDED: &str = "ServiceQuotaExceededException";
const ERR_MODEL_ERROR: &str = "ModelErrorException";
/// Static description of one model exposed through the Bedrock provider.
///
/// Entries are converted into `ModelInfo` records when the provider is built.
#[derive(Debug, Default, Clone)]
pub struct BedrockModel {
/// Model identifier; passed as the `model_id` of the invoke call and used as
/// the lookup key for per-model settings.
pub id: String,
/// Context window size, forwarded unchanged to `ModelInfo::context_window`.
pub context_window: Option<u64>,
/// Output token limit, forwarded unchanged to `ModelInfo::max_output_tokens`.
pub max_output_tokens: Option<u64>,
/// Thinking wire format to use when encoding requests for this model.
/// `None` means the default `ThinkingWireFormat` is used.
pub thinking_format: Option<ThinkingWireFormat>,
}
impl BedrockModel {
    /// Builds a model entry that carries only an identifier.
    ///
    /// Every optional attribute (context window, output limit, thinking
    /// format) starts out as `None`.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            // All remaining fields are `Option`s whose default is `None`.
            ..Self::default()
        }
    }
}
/// User-facing configuration for constructing a `BedrockProvider`.
///
/// Every field is optional; unset fields fall back to built-in defaults or to
/// the process environment as described per field.
#[derive(Debug, Default, Clone)]
pub struct BedrockConfig {
/// Vendor identifier reported in `ProviderInfo`; defaults to `"bedrock"`.
pub vendor: Option<String>,
/// Display name reported in `ProviderInfo`; defaults to `"Amazon Bedrock"`.
pub display_name: Option<String>,
/// Endpoint URL override handed to the AWS config loader when set.
pub base_url: Option<String>,
/// Model id that is guaranteed to appear in the model list; it is inserted
/// at the front when `models` does not already contain it.
pub default_model: Option<String>,
/// Explicitly configured models.
pub models: Vec<BedrockModel>,
/// AWS profile name; when unset the environment is consulted.
pub aws_profile: Option<String>,
/// AWS region; when unset the region is resolved from the environment and
/// finally falls back to `us-east-1`.
pub aws_region: Option<String>,
/// Beta flags sent as the `anthropic_beta` array of each request body.
/// `None` and an empty list both mean the field is omitted.
pub anthropic_beta: Option<Vec<String>>,
}
impl BedrockConfig {
    /// Treats empty or whitespace-only values as "not configured".
    ///
    /// An exported-but-empty variable (for example `AWS_REGION=`) or an empty
    /// config string would otherwise be passed on as a real region or profile
    /// name instead of falling through to the next source.
    fn non_blank(value: Option<String>) -> Option<String> {
        value.filter(|candidate| !candidate.trim().is_empty())
    }

    /// Resolves the AWS region to use.
    ///
    /// Precedence: explicit `aws_region`, then `AWS_REGION`, then
    /// `AWS_DEFAULT_REGION`, then the built-in `us-east-1` default. Blank
    /// values are skipped at every step.
    fn resolve_region(&self) -> String {
        // AWS tooling treats this variable as the fallback for `AWS_REGION`.
        // Because the provider always sets the region explicitly, it has to be
        // honoured here or users who only export it would silently end up in
        // the built-in default region.
        const AWS_DEFAULT_REGION_ENV: &str = "AWS_DEFAULT_REGION";

        Self::non_blank(self.aws_region.clone())
            .or_else(|| Self::non_blank(env::var(AWS_REGION_ENV).ok()))
            .or_else(|| Self::non_blank(env::var(AWS_DEFAULT_REGION_ENV).ok()))
            .unwrap_or_else(|| DEFAULT_AWS_REGION.to_owned())
    }

    /// Resolves the AWS profile name to use, if any.
    ///
    /// Precedence: explicit `aws_profile`, then `AWS_PROFILE`. Returns `None`
    /// when neither yields a non-blank value, in which case no profile is
    /// forced on the config loader.
    fn resolve_profile(&self) -> Option<String> {
        Self::non_blank(self.aws_profile.clone())
            .or_else(|| Self::non_blank(env::var(AWS_PROFILE_ENV).ok()))
    }
}
/// `LlmProvider` implementation that talks to Amazon Bedrock through the AWS
/// SDK using the Anthropic Messages wire format.
pub struct BedrockProvider {
// SDK client used to issue `invoke_model_with_response_stream` calls.
client: BedrockClient,
// Identity returned from `LlmProvider::info`.
info: ProviderInfo,
// Provider-wide capability set returned from `LlmProvider::capabilities`.
capabilities: Capabilities,
// Model list built from configuration; served by `list_models`/`model_info`.
models: Vec<ModelInfo>,
// Beta flags added to each request body when the list is non-empty.
anthropic_beta: Vec<String>,
// Per-model thinking wire format, keyed by model id. Models without an entry
// use the default `ThinkingWireFormat`.
thinking_formats: HashMap<String, ThinkingWireFormat>,
}
impl std::fmt::Debug for BedrockProvider {
    /// Prints only the provider identity and capabilities; the remaining
    /// fields are elided and the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            info, capabilities, ..
        } = self;
        let mut out = f.debug_struct("BedrockProvider");
        out.field("info", info);
        out.field("capabilities", capabilities);
        out.finish_non_exhaustive()
    }
}
impl BedrockProvider {
pub async fn new(config: BedrockConfig) -> Result<Self, ProviderError> {
let region = config.resolve_region();
let vendor = config
.vendor
.clone()
.unwrap_or_else(|| DEFAULT_VENDOR.to_owned());
let display_name = config
.display_name
.clone()
.unwrap_or_else(|| DEFAULT_DISPLAY_NAME.to_owned());
let mut loader =
aws_config::defaults(BehaviorVersion::latest()).region(aws_config::Region::new(region));
if let Some(profile) = config.resolve_profile() {
loader = loader.profile_name(profile);
}
if let Some(endpoint) = config.base_url.clone() {
loader = loader.endpoint_url(endpoint);
}
let sdk_config = loader.load().await;
let client = BedrockClient::new(&sdk_config);
let anthropic_beta = config.anthropic_beta.unwrap_or_default();
let thinking_formats: HashMap<String, ThinkingWireFormat> = config
.models
.iter()
.filter_map(|m| m.thinking_format.map(|f| (m.id.clone(), f)))
.collect();
Ok(Self {
client,
anthropic_beta,
thinking_formats,
info: ProviderInfo {
vendor,
protocol: ProtocolId::AnthropicMessages,
display_name,
},
capabilities: Capabilities {
tool_calls: FeatureSupport::Supported,
parallel_tool_calls: FeatureSupport::Supported,
thinking: FeatureSupport::Supported,
vision: FeatureSupport::Supported,
prompt_cache: FeatureSupport::Supported,
thinking_echo: ThinkingEcho::Required,
},
models: model_infos_from_config(config.models, config.default_model),
})
}
}
fn model_infos_from_config(
models: Vec<BedrockModel>,
default_model: Option<String>,
) -> Vec<ModelInfo> {
let mut models = models;
if let Some(default_model) = default_model
&& !models.iter().any(|m| m.id == default_model)
{
models.insert(0, BedrockModel::new(default_model));
}
models
.into_iter()
.map(|m| ModelInfo {
id: m.id,
display_name: None,
context_window: m.context_window,
max_output_tokens: m.max_output_tokens,
deprecated: false,
capabilities_overrides: ModelCapabilityOverrides::default(),
})
.collect()
}
impl LlmProvider for BedrockProvider {
fn info(&self) -> ProviderInfo {
self.info.clone()
}
fn capabilities(&self) -> Capabilities {
self.capabilities
}
fn list_models(&self) -> BoxFuture<'_, Result<Vec<ModelInfo>, ProviderError>> {
async move { Ok(self.models.clone()) }.boxed()
}
fn model_info(&self, model_id: &str) -> Option<ModelInfo> {
self.models
.iter()
.find(|model| model.id == model_id)
.cloned()
}
fn complete(
&self,
req: CompletionRequest,
cancel: CancellationToken,
) -> BoxFuture<'_, Result<ProviderStream, ProviderError>> {
async move {
let thinking_format = self
.thinking_formats
.get(&req.model)
.copied()
.unwrap_or_default();
let body = anthropic_messages::encode_request(&req, thinking_format);
let payload = serde_json::to_vec(&bedrock_request_body(body, &self.anthropic_beta))
.map_err(|e| {
ProviderError::new(ProviderErrorKind::BadRequest {
hint: Some(e.to_string()),
})
})?;
let resp = tokio::select! {
biased;
_ = cancel.cancelled() => {
return Err(ProviderError::new(ProviderErrorKind::Canceled));
}
r = self
.client
.invoke_model_with_response_stream()
.model_id(req.model.clone())
.content_type(CONTENT_TYPE_JSON)
.accept(CONTENT_TYPE_JSON)
.body(Blob::new(payload))
.send() => r,
};
let output = match resp {
Ok(output) => output,
Err(err) => return Err(map_bedrock_error(err, &req.model)),
};
let events = bedrock_event_stream(output.body, cancel.clone());
let chunks = anthropic_messages::decode_stream_provider_errors(events, cancel);
Ok(Box::pin(chunks) as ProviderStream)
}
.boxed()
}
}
/// Adapts an Anthropic Messages request body for Bedrock.
///
/// The `model` and `stream` fields are dropped, `anthropic_version` is set,
/// and `anthropic_beta` is added only when at least one flag is configured.
/// If the serialized body is not a JSON object it is returned untouched.
///
/// # Panics
///
/// Panics if the wire body cannot be serialized to JSON.
fn bedrock_request_body(body: wire::CreateMessageParams, anthropic_beta: &[String]) -> Value {
    let mut value = serde_json::to_value(body).expect("Anthropic wire body should serialize");

    if let Value::Object(fields) = &mut value {
        for dropped in [BODY_MODEL_FIELD, BODY_STREAM_FIELD] {
            fields.remove(dropped);
        }
        fields.insert(
            String::from(BODY_ANTHROPIC_VERSION_FIELD),
            Value::from(ANTHROPIC_VERSION),
        );
        if !anthropic_beta.is_empty() {
            let flags: Vec<Value> = anthropic_beta.iter().cloned().map(Value::String).collect();
            fields.insert(String::from(BODY_ANTHROPIC_BETA_FIELD), Value::Array(flags));
        }
    }

    value
}
// Error produced by sending the `invoke_model_with_response_stream` call.
type InvokeModelError = SdkError<InvokeModelWithResponseStreamError, HttpResponse>;
// Error produced while receiving events from the response stream.
type BedrockStreamError = SdkError<ResponseStreamError, RawMessage>;
/// Opaque error carrying the `Debug` rendering of an SDK error, so the SDK
/// value can be stored as a `BoxError` source. Its `Display` output is the
/// captured message verbatim.
#[derive(Debug, thiserror::Error)]
#[error("{message}")]
struct BedrockSdkError {
message: String,
}
/// Translates a failure of the initial invoke call into a `ProviderError`.
///
/// Service errors are classified by their error code (with `model` attached
/// for not-found errors); every other SDK failure is mapped by its variant.
fn map_bedrock_error(err: InvokeModelError, model: &str) -> ProviderError {
    let kind = match err {
        // Service errors carry their own classification logic.
        SdkError::ServiceError(context) => {
            return bedrock_service_error(context.err(), Some(model));
        }
        SdkError::TimeoutError(_) => ProviderErrorKind::Timeout {
            phase: TimeoutPhase::Total,
        },
        SdkError::ConstructionFailure(cause) => ProviderErrorKind::BadRequest {
            hint: Some(format!("{cause:?}")),
        },
        SdkError::DispatchFailure(cause) => ProviderErrorKind::Transport(box_debug_error(cause)),
        SdkError::ResponseError(cause) => ProviderErrorKind::Transport(box_debug_error(cause)),
        // Any variant not handled above.
        other => ProviderErrorKind::Other(box_debug_error(other)),
    };
    ProviderError::new(kind)
}
/// Wraps the `Debug` rendering of `error` in a boxed `BedrockSdkError`.
fn box_debug_error(error: impl Debug) -> BoxError {
    let message = format!("{error:?}");
    BoxError::new(BedrockSdkError { message })
}
/// Adapts the Bedrock event receiver into a stream of SSE-shaped items.
///
/// * `Chunk` events become one item each via `bedrock_chunk_to_sse`.
/// * Any other event is logged at warn level and skipped.
/// * Receive errors are yielded as `Err` items; they do not end the stream by
///   themselves.
/// * The stream ends when the receiver reports end-of-stream or when `cancel`
///   is triggered.
fn bedrock_event_stream(
    body: EventReceiver<ResponseStream, ResponseStreamError>,
    cancel: CancellationToken,
) -> impl Stream<Item = Result<Sse, ProviderError>> + Send {
    stream::unfold((body, cancel), |(mut receiver, token)| async move {
        // Loop until there is something to yield; skipped events just go
        // around again.
        loop {
            if token.is_cancelled() {
                return None;
            }
            let received = tokio::select! {
                biased;
                _ = token.cancelled() => return None,
                outcome = receiver.recv() => outcome,
            };

            let event = match received {
                Ok(Some(event)) => event,
                Ok(None) => return None,
                Err(err) => {
                    return Some((Err(map_bedrock_stream_error(err)), (receiver, token)));
                }
            };

            match event {
                ResponseStream::Chunk(chunk) => {
                    return Some((bedrock_chunk_to_sse(chunk), (receiver, token)));
                }
                event if event.is_unknown() => {
                    warn!("bedrock returned an unknown response stream event");
                }
                event => {
                    warn!(
                        ?event,
                        "bedrock returned an unhandled response stream event"
                    );
                }
            }
        }
    })
}
/// Converts one Bedrock payload chunk into an `Sse` record whose `data` is the
/// chunk's bytes decoded as UTF-8; all other `Sse` fields are left unset.
///
/// # Errors
///
/// * `ProtocolViolation` when the chunk carries no bytes.
/// * `Malformed` when the bytes are not valid UTF-8.
fn bedrock_chunk_to_sse(
    chunk: aws_sdk_bedrockruntime::types::PayloadPart,
) -> Result<Sse, ProviderError> {
    let payload = chunk.bytes.ok_or_else(|| {
        ProviderError::new(ProviderErrorKind::ProtocolViolation {
            hint: "bedrock response chunk did not include bytes".into(),
        })
    })?;

    match String::from_utf8(payload.into_inner()) {
        Ok(text) => Ok(Sse {
            event: None,
            data: Some(text),
            id: None,
            retry: None,
        }),
        Err(e) => Err(ProviderError::new(ProviderErrorKind::Malformed(
            BoxError::new(e),
        ))),
    }
}
/// Translates an error raised while reading the response stream into a
/// `ProviderError`.
///
/// Service errors are classified by their error code without a model id;
/// timeouts are reported as occurring while reading the body.
fn map_bedrock_stream_error(err: BedrockStreamError) -> ProviderError {
    let kind = match err {
        // Service errors carry their own classification logic.
        SdkError::ServiceError(context) => return bedrock_service_error(context.err(), None),
        SdkError::TimeoutError(_) => ProviderErrorKind::Timeout {
            phase: TimeoutPhase::ReadBody,
        },
        SdkError::ConstructionFailure(cause) => ProviderErrorKind::BadRequest {
            hint: Some(format!("{cause:?}")),
        },
        SdkError::DispatchFailure(cause) => ProviderErrorKind::Transport(box_debug_error(cause)),
        SdkError::ResponseError(cause) => ProviderErrorKind::Transport(box_debug_error(cause)),
        // Any variant not handled above.
        other => ProviderErrorKind::Other(box_debug_error(other)),
    };
    ProviderError::new(kind)
}
fn bedrock_service_error(source: &dyn ProvideErrorMetadata, model: Option<&str>) -> ProviderError {
let hint = source.message().map(str::to_owned);
match source.code() {
Some(ERR_ACCESS_DENIED) => ProviderError::new(ProviderErrorKind::AuthRejected { hint }),
Some(ERR_VALIDATION) => ProviderError::new(ProviderErrorKind::BadRequest { hint }),
Some(ERR_RESOURCE_NOT_FOUND) => ProviderError::new(ProviderErrorKind::ModelNotFound {
model: model.unwrap_or(DEFAULT_VENDOR).to_owned(),
}),
Some(ERR_SERVICE_QUOTA_EXCEEDED) => {
ProviderError::new(ProviderErrorKind::QuotaExceeded { hint })
}
Some(ERR_THROTTLING) => ProviderError::new(ProviderErrorKind::RateLimit {
retry_after: None,
scope: RateLimitScope::Unspecified,
}),
Some(ERR_MODEL_TIMEOUT) => ProviderError::new(ProviderErrorKind::Timeout {
phase: TimeoutPhase::ReadBody,
}),
Some(ERR_MODEL_STREAM) => {
ProviderError::new(ProviderErrorKind::ServerStreamAborted { hint })
}
Some(ERR_MODEL_NOT_READY)
| Some(ERR_SERVICE_UNAVAILABLE)
| Some(ERR_INTERNAL_SERVER)
| Some(ERR_MODEL_ERROR) => {
ProviderError::new(ProviderErrorKind::ServerError { status: None, hint })
}
Some(_) | None => ProviderError::new(ProviderErrorKind::ServerError { status: None, hint }),
}
}
// Unit tests are declared as an out-of-line `tests` module and are only
// compiled for test builds.
#[cfg(test)]
mod tests;