use super::{build_http_client, RateLimitHeaders};
use crate::config::TlsConfig;
use aws_credential_types::Credentials;
use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings};
use aws_sigv4::sign::v4;
use reqwest::Client;
use tokio::time::sleep;
use zeroize::Zeroizing;
/// SigV4-signing HTTP client for the Amazon Bedrock runtime.
///
/// Credentials are stored as plain fields and turned into a signature for each request.
/// The secret parts are held in [`Zeroizing`] buffers. Only `Clone` is derived (no
/// `Debug`), which keeps the credential fields out of derived `{:?}` output.
#[derive(Clone)]
pub struct BedrockClient {
    /// HTTP client used to send every request.
    client: Client,
    /// AWS region: part of the endpoint host name and the SigV4 signing region.
    region: String,
    /// Model ID exposed through `big_model()`.
    big_model: String,
    /// Model ID exposed through `small_model()`.
    small_model: String,
    /// AWS access key ID.
    access_key_id: String,
    /// AWS secret access key.
    secret_access_key: Zeroizing<String>,
    /// Optional session token accompanying temporary credentials.
    session_token: Option<Zeroizing<String>>,
}
/// Errors returned by [`BedrockClient`] operations.
#[derive(Debug)]
pub enum BedrockClientError {
    /// The request could not be sent, or a response body could not be read.
    Transport(String),
    /// Bedrock answered with a non-2xx status. `body` holds the raw response body; it may
    /// be empty when the body could not be read.
    ApiError { status: u16, body: bytes::Bytes },
    /// Building the SigV4 signature for the request failed.
    Signing(String),
}

impl std::fmt::Display for BedrockClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Transport(msg) => write!(f, "Bedrock transport error: {msg}"),
            // The body is deliberately left out of the one-line message; it is available
            // on the variant for callers that want to log or forward it.
            Self::ApiError { status, .. } => write!(f, "Bedrock API error (status {status})"),
            Self::Signing(msg) => write!(f, "Bedrock signing error: {msg}"),
        }
    }
}

/// Lets the error be used as `Box<dyn std::error::Error>`, with `?` in generic error
/// contexts, and as the `source` of other errors.
impl std::error::Error for BedrockClientError {}
impl BedrockClient {
    /// Creates a client for the Bedrock runtime in `region`.
    ///
    /// The HTTP client is built from `tls`. The secret access key and optional session
    /// token are copied out of `credentials` into [`Zeroizing`] buffers, which are wiped
    /// when dropped.
    pub fn new(
        region: String,
        credentials: Credentials,
        big_model: String,
        small_model: String,
        tls: &TlsConfig,
    ) -> Self {
        let client = build_http_client(tls);
        let access_key_id = credentials.access_key_id().to_string();
        let secret_access_key = Zeroizing::new(credentials.secret_access_key().to_string());
        let session_token = credentials
            .session_token()
            .map(|t| Zeroizing::new(t.to_string()));
        Self {
            client,
            region,
            access_key_id,
            secret_access_key,
            session_token,
            big_model,
            small_model,
        }
    }

    /// Returns the model ID configured as the "big" model.
    pub fn big_model(&self) -> &str {
        &self.big_model
    }

    /// Returns the model ID configured as the "small" model.
    pub fn small_model(&self) -> &str {
        &self.small_model
    }

    /// Returns the Bedrock runtime URL for `model_id` and the operation path `suffix`
    /// (for example `"invoke"` or `"invoke-with-response-stream"`).
    ///
    /// `model_id` is percent-encoded as a single path segment, as the AWS SDK does for
    /// the `modelId` label. This keeps inference-profile ARNs (which contain `/`) and
    /// versioned IDs (which contain `:`) inside one segment. `%` is left untouched, so
    /// an ID that is already percent-encoded is not encoded a second time.
    pub fn native_endpoint_url(&self, model_id: &str, suffix: &str) -> String {
        format!(
            "https://bedrock-runtime.{}.amazonaws.com/model/{}/{suffix}",
            self.region,
            Self::encode_path_segment(model_id)
        )
    }

    /// Percent-encodes every byte of `segment` outside the RFC 3986 unreserved set
    /// (`A-Z a-z 0-9 - . _ ~`), except `%`, which passes through unchanged.
    fn encode_path_segment(segment: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(segment.len());
        for byte in segment.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~' | b'%')
            {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
        encoded
    }

    /// URL of the non-streaming `invoke` operation for `model_id`.
    fn invoke_url(&self, model_id: &str) -> String {
        self.native_endpoint_url(model_id, "invoke")
    }

    /// URL of the streaming `invoke-with-response-stream` operation for `model_id`.
    fn invoke_stream_url(&self, model_id: &str) -> String {
        self.native_endpoint_url(model_id, "invoke-with-response-stream")
    }

    /// POSTs `body` to an already-built Bedrock `url` (for example one returned by
    /// [`Self::native_endpoint_url`]) as a single signed request, without retries.
    ///
    /// `streaming` selects the `accept` header (event stream vs. JSON).
    ///
    /// # Errors
    ///
    /// - [`BedrockClientError::Signing`] if the request cannot be signed.
    /// - [`BedrockClientError::Transport`] if sending fails, or if the body of a non-2xx
    ///   response cannot be read.
    /// - [`BedrockClientError::ApiError`] for any non-2xx status.
    pub async fn forward_native(
        &self,
        url: &str,
        body: bytes::Bytes,
        streaming: bool,
    ) -> Result<reqwest::Response, BedrockClientError> {
        let response = self.send_signed(url, body, streaming).await?;
        let status = response.status().as_u16();
        if !(200..300).contains(&status) {
            let resp_body = response
                .bytes()
                .await
                .map_err(|e| BedrockClientError::Transport(e.to_string()))?;
            return Err(BedrockClientError::ApiError {
                status,
                body: resp_body,
            });
        }
        Ok(response)
    }

    /// Signs and sends one POST of `body` to `url`, returning the raw response whatever
    /// its status.
    ///
    /// The `content-type` and `accept` headers are both signed and sent, so the signature
    /// covers exactly what goes on the wire. Only send failures become errors here
    /// ([`BedrockClientError::Transport`]), besides signing failures.
    async fn send_signed(
        &self,
        url: &str,
        body: bytes::Bytes,
        streaming: bool,
    ) -> Result<reqwest::Response, BedrockClientError> {
        let accept = if streaming {
            "application/vnd.amazon.eventstream"
        } else {
            "application/json"
        };
        let base_headers = [("content-type", "application/json"), ("accept", accept)];
        let signing_headers = self.sign_request("POST", url, &body, &base_headers)?;
        let mut builder = self.client.post(url).body(body);
        for (name, value) in base_headers {
            builder = builder.header(name, value);
        }
        for (name, value) in &signing_headers {
            builder = builder.header(name.as_str(), value.as_str());
        }
        builder
            .send()
            .await
            .map_err(|e| BedrockClientError::Transport(e.to_string()))
    }

    /// Computes SigV4 headers for a request to the `bedrock` service in this client's
    /// region, timestamped with the current system time.
    ///
    /// `extra_headers` are included in the signature, so they must be sent unchanged.
    /// Returns the header name/value pairs produced by the signer, which the caller must
    /// add to the request.
    ///
    /// # Errors
    ///
    /// [`BedrockClientError::Signing`] if the signing parameters, the signable request or
    /// the signature itself cannot be built.
    fn sign_request(
        &self,
        method: &str,
        url: &str,
        body_bytes: &[u8],
        extra_headers: &[(&str, &str)],
    ) -> Result<Vec<(String, String)>, BedrockClientError> {
        // No expiry is attached; "anyllm" is the provider name recorded on the credentials.
        let creds = Credentials::new(
            self.access_key_id.clone(),
            self.secret_access_key.as_str(),
            self.session_token.as_deref().map(|s| s.to_string()),
            None,
            "anyllm",
        );
        let identity: aws_smithy_runtime_api::client::identity::Identity = creds.into();
        let settings = SigningSettings::default();
        let params = v4::SigningParams::builder()
            .identity(&identity)
            .region(&self.region)
            .name("bedrock")
            .time(std::time::SystemTime::now())
            .settings(settings)
            .build()
            .map_err(|e| BedrockClientError::Signing(e.to_string()))?;
        let signing_params = params.into();
        let signable = SignableRequest::new(
            method,
            url,
            extra_headers.iter().copied(),
            SignableBody::Bytes(body_bytes),
        )
        .map_err(|e| BedrockClientError::Signing(e.to_string()))?;
        let (instructions, _signature) = sign(signable, &signing_params)
            .map_err(|e| BedrockClientError::Signing(e.to_string()))?
            .into_parts();
        let headers: Vec<(String, String)> = instructions
            .headers()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect();
        Ok(headers)
    }

    /// Invokes `model_id` (non-streaming) with retries and returns the whole response body.
    ///
    /// Bedrock rate-limit headers are not parsed; an empty [`RateLimitHeaders`] is returned.
    pub async fn forward(
        &self,
        body: bytes::Bytes,
        model_id: &str,
    ) -> Result<(bytes::Bytes, RateLimitHeaders), BedrockClientError> {
        let response = self.send_with_retry(body, model_id, false).await?;
        let rate_limits = RateLimitHeaders::default();
        let resp_body = response
            .bytes()
            .await
            .map_err(|e| BedrockClientError::Transport(e.to_string()))?;
        Ok((resp_body, rate_limits))
    }

    /// Invokes `model_id` through the streaming endpoint with retries and returns the
    /// live response, whose body is an AWS event stream.
    ///
    /// Bedrock rate-limit headers are not parsed; an empty [`RateLimitHeaders`] is returned.
    pub async fn forward_stream(
        &self,
        body: bytes::Bytes,
        model_id: &str,
    ) -> Result<(reqwest::Response, RateLimitHeaders), BedrockClientError> {
        let response = self.send_with_retry(body, model_id, true).await?;
        let rate_limits = RateLimitHeaders::default();
        Ok((response, rate_limits))
    }

    /// Sends a signed invoke request for `model_id`, retrying statuses accepted by
    /// `super::is_retryable` up to `super::MAX_RETRIES` extra times.
    ///
    /// The delay before each retry comes from `super::backoff_delay`, which receives any
    /// value parsed from the response by `super::parse_retry_after`. The request is
    /// re-signed on every attempt, so each one carries a fresh timestamp. Send failures
    /// are returned immediately and are not retried.
    ///
    /// # Errors
    ///
    /// [`BedrockClientError::ApiError`] with the last status once retries are exhausted,
    /// or for a non-retryable status. The body is empty if it could not be read.
    async fn send_with_retry(
        &self,
        body: bytes::Bytes,
        model_id: &str,
        stream: bool,
    ) -> Result<reqwest::Response, BedrockClientError> {
        let url = if stream {
            self.invoke_stream_url(model_id)
        } else {
            self.invoke_url(model_id)
        };
        let mut attempt = 0;
        loop {
            // `Bytes::clone` is a cheap reference-count bump, not a copy of the payload.
            let response = self.send_signed(&url, body.clone(), stream).await?;
            let status = response.status().as_u16();
            if (200..300).contains(&status) {
                return Ok(response);
            }
            if attempt >= super::MAX_RETRIES || !super::is_retryable(status) {
                // Best effort: an unreadable error body becomes empty so the status survives.
                let resp_body = response.bytes().await.unwrap_or_default();
                return Err(BedrockClientError::ApiError {
                    status,
                    body: resp_body,
                });
            }
            let retry_after = super::parse_retry_after(response.headers());
            let delay = super::backoff_delay(attempt, retry_after);
            tracing::warn!(
                status,
                attempt = attempt + 1,
                max_retries = super::MAX_RETRIES,
                delay_ms = delay.as_millis() as u64,
                "retryable error from Bedrock, backing off"
            );
            // Consume and discard the error body before backing off; a read failure is irrelevant.
            drop(response.bytes().await);
            sleep(delay).await;
            attempt += 1;
        }
    }
}
pub mod eventstream {
    //! Decoder for the AWS event-stream binary framing used by Bedrock's
    //! `invoke-with-response-stream` responses.
    //!
    //! Frame layout (integers are big-endian `u32`):
    //! `total_len | headers_len | prelude_crc | headers | payload | message_crc`.
    //! `prelude_crc` is the CRC-32 of the first 8 bytes, and `message_crc` is the CRC-32
    //! of every byte before it.
    use bytes::BytesMut;

    /// Length of the prelude: total length, headers length and prelude CRC.
    const PRELUDE_LEN: usize = 12;
    /// Length of the trailing message CRC.
    const MESSAGE_CRC_LEN: usize = 4;
    /// Smallest well-formed frame: a prelude and a message CRC with nothing in between.
    const MIN_FRAME_SIZE: usize = PRELUDE_LEN + MESSAGE_CRC_LEN;

    /// Reads a big-endian `u32` at `offset`; the caller ensures `offset + 4 <= bytes.len()`.
    fn read_u32_be(bytes: &[u8], offset: usize) -> u32 {
        u32::from_be_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ])
    }

    /// Decodes one frame from the front of `buf` and returns its payload bytes.
    ///
    /// Returns `Ok(None)` when `buf` does not yet hold a complete frame; nothing is
    /// consumed. On success, the whole frame is removed from `buf`. Frame headers are
    /// skipped, not parsed. No upper bound is enforced on the declared frame length.
    ///
    /// # Errors
    ///
    /// - Prelude CRC mismatch, or a prelude whose declared lengths cannot describe a
    ///   frame: `buf` is left untouched, because its declared length cannot be trusted to
    ///   locate the next frame.
    /// - Message CRC mismatch: the frame is consumed from `buf` before the error is
    ///   returned.
    pub fn decode_frame(buf: &mut BytesMut) -> Result<Option<Vec<u8>>, String> {
        if buf.len() < MIN_FRAME_SIZE {
            return Ok(None);
        }
        // Verify the prelude before trusting the lengths it carries. A corrupt
        // `total_len` is then reported at once instead of stalling while waiting for
        // bytes that will never arrive.
        let prelude = &buf[..PRELUDE_LEN];
        let prelude_crc_stored = read_u32_be(prelude, 8);
        let prelude_crc_computed = crc32fast::hash(&prelude[..8]);
        if prelude_crc_stored != prelude_crc_computed {
            return Err(format!(
                "event stream prelude CRC mismatch: stored={prelude_crc_stored:#010x} computed={prelude_crc_computed:#010x}"
            ));
        }
        let total_len = read_u32_be(prelude, 0) as usize;
        let headers_len = read_u32_be(prelude, 4) as usize;
        // A prelude can pass its CRC and still declare impossible lengths. Reject those
        // before computing `total_len - MESSAGE_CRC_LEN` and the payload slice, which would
        // otherwise underflow or index out of bounds.
        if total_len < MIN_FRAME_SIZE || headers_len > total_len - MIN_FRAME_SIZE {
            return Err(format!(
                "event stream prelude declares invalid lengths: total={total_len} headers={headers_len}"
            ));
        }
        if buf.len() < total_len {
            return Ok(None);
        }
        // Consume the frame up front so a message-CRC failure still advances the buffer.
        let frame = buf.split_to(total_len);
        let msg_crc_offset = total_len - MESSAGE_CRC_LEN;
        let msg_crc_stored = read_u32_be(&frame, msg_crc_offset);
        let msg_crc_computed = crc32fast::hash(&frame[..msg_crc_offset]);
        if msg_crc_stored != msg_crc_computed {
            return Err(format!(
                "event stream message CRC mismatch: stored={msg_crc_stored:#010x} computed={msg_crc_computed:#010x}"
            ));
        }
        // In bounds: headers_len <= total_len - MIN_FRAME_SIZE was checked above.
        let payload_start = PRELUDE_LEN + headers_len;
        Ok(Some(frame[payload_start..msg_crc_offset].to_vec()))
    }

    /// Extracts the inner event from a payload shaped like `{"bytes": "<base64>"}`.
    ///
    /// Returns the base64-decoded bytes as a UTF-8 string. Returns `None` if the payload
    /// is empty or not JSON, has no string `bytes` field, or does not decode to valid
    /// base64 / UTF-8.
    pub fn extract_event_from_payload(payload: &[u8]) -> Option<String> {
        if payload.is_empty() {
            return None;
        }
        let parsed: serde_json::Value = serde_json::from_slice(payload).ok()?;
        let b64 = parsed.get("bytes")?.as_str()?;
        use base64::Engine;
        let decoded = base64::engine::general_purpose::STANDARD.decode(b64).ok()?;
        String::from_utf8(decoded).ok()
    }
}
#[cfg(test)]
mod tests {
    use super::eventstream::{decode_frame, extract_event_from_payload};
    use bytes::BytesMut;

    /// Serialises one event-stream frame around raw `headers` and `payload` bytes,
    /// filling in both the prelude CRC and the trailing message CRC.
    fn build_frame(headers: &[u8], payload: &[u8]) -> Vec<u8> {
        let total_len = 12 + headers.len() + payload.len() + 4;
        let mut out = Vec::with_capacity(total_len);
        out.extend_from_slice(&(total_len as u32).to_be_bytes());
        out.extend_from_slice(&(headers.len() as u32).to_be_bytes());
        // At this point `out` holds exactly the 8 bytes covered by the prelude CRC.
        let prelude_crc = crc32fast::hash(&out);
        out.extend_from_slice(&prelude_crc.to_be_bytes());
        out.extend_from_slice(headers);
        out.extend_from_slice(payload);
        let message_crc = crc32fast::hash(&out);
        out.extend_from_slice(&message_crc.to_be_bytes());
        out
    }

    /// Copies `bytes` into a fresh decode buffer.
    fn buffer_of(bytes: &[u8]) -> BytesMut {
        BytesMut::from(bytes)
    }

    /// Returns `frame` with every bit of the byte at `index` flipped.
    fn flip_byte(mut frame: Vec<u8>, index: usize) -> Vec<u8> {
        frame[index] ^= 0xFF;
        frame
    }

    #[test]
    fn decode_frame_empty_payload() {
        let mut buf = buffer_of(&build_frame(&[], &[]));
        let payload = decode_frame(&mut buf).unwrap().unwrap();
        assert!(payload.is_empty());
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_frame_with_payload() {
        let mut buf = buffer_of(&build_frame(&[], b"hello world"));
        let payload = decode_frame(&mut buf).unwrap().unwrap();
        assert_eq!(payload, b"hello world");
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_frame_incomplete() {
        let frame = build_frame(&[], b"hello");
        // Drop the last two bytes so the frame is one CRC short of complete.
        let mut buf = buffer_of(&frame[..frame.len() - 2]);
        assert!(decode_frame(&mut buf).unwrap().is_none());
    }

    #[test]
    fn decode_multiple_frames() {
        let mut buf = BytesMut::new();
        for payload in [&b"first"[..], &b"second"[..]] {
            buf.extend_from_slice(&build_frame(&[], payload));
        }
        assert_eq!(decode_frame(&mut buf).unwrap().unwrap(), b"first");
        assert_eq!(decode_frame(&mut buf).unwrap().unwrap(), b"second");
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_frame_with_headers() {
        let mut buf = buffer_of(&build_frame(b"\x00\x04test", b"data"));
        let payload = decode_frame(&mut buf).unwrap().unwrap();
        assert_eq!(payload, b"data");
    }

    #[test]
    fn decode_frame_rejects_bad_prelude_crc() {
        // Byte 8 is the first byte of the prelude CRC.
        let frame = flip_byte(build_frame(b"", b"{}"), 8);
        let mut buf = buffer_of(&frame);
        assert!(
            decode_frame(&mut buf).is_err(),
            "bad prelude CRC must be rejected"
        );
    }

    #[test]
    fn decode_frame_prelude_crc_failure_does_not_advance_buffer() {
        let frame = flip_byte(build_frame(b"", b"{}"), 8);
        let mut buf = buffer_of(&frame);
        assert!(decode_frame(&mut buf).is_err());
        assert_eq!(
            buf.len(),
            frame.len(),
            "buffer must not be consumed when prelude CRC fails (total_len is untrustworthy)"
        );
    }

    #[test]
    fn decode_frame_rejects_bad_message_crc() {
        let intact = build_frame(b"", b"{}");
        let last = intact.len() - 1;
        let mut buf = buffer_of(&flip_byte(intact, last));
        assert!(
            decode_frame(&mut buf).is_err(),
            "bad message CRC must be rejected"
        );
    }

    #[test]
    fn decode_frame_accepts_valid_crc() {
        let mut buf = buffer_of(&build_frame(b"", b"{}"));
        let outcome = decode_frame(&mut buf);
        assert!(outcome.is_ok(), "valid CRC must be accepted");
        assert!(outcome.unwrap().is_some());
    }

    #[test]
    fn extract_event_from_valid_payload() {
        use base64::Engine;
        let event_json = r#"{"type":"content_block_delta","index":0}"#;
        let encoded = base64::engine::general_purpose::STANDARD.encode(event_json);
        let wrapper = format!(r#"{{"bytes":"{encoded}"}}"#);
        let event = extract_event_from_payload(wrapper.as_bytes());
        assert_eq!(event.unwrap(), event_json);
    }

    #[test]
    fn extract_event_empty_payload() {
        assert!(extract_event_from_payload(&[]).is_none());
    }

    #[test]
    fn extract_event_invalid_json() {
        assert!(extract_event_from_payload(b"not json").is_none());
    }

    #[test]
    fn extract_event_missing_bytes_field() {
        let payload = r#"{"other":"field"}"#;
        assert!(extract_event_from_payload(payload.as_bytes()).is_none());
    }
}