use std::time::Duration;
use async_trait::async_trait;
use chio_core_types::GuardEvidence;
use chio_kernel::Verdict;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use reqwest::{Client, StatusCode};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use zeroize::Zeroizing;
use super::{ExternalGuard, ExternalGuardError, GuardCallContext};
/// Identifier this guard reports from `ExternalGuard::name` and therefore in
/// the evidence records it produces.
pub const GUARD_NAME: &str = "bedrock-guardrail";
/// HTTP request timeout assigned by `BedrockGuardrailConfig::new` unless the
/// caller overrides the `timeout` field.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Label sent as the `source` field of an `ApplyGuardrail` request, telling
/// Bedrock which side of a model exchange the submitted text belongs to.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum BedrockSource {
    /// Serialized as `"INPUT"`; this is the default.
    #[default]
    Input,
    /// Serialized as `"OUTPUT"`.
    Output,
}

impl BedrockSource {
    /// Returns the exact upper-case spelling placed on the wire.
    fn as_str(self) -> &'static str {
        match self {
            BedrockSource::Output => "OUTPUT",
            BedrockSource::Input => "INPUT",
        }
    }
}
/// Connection and policy settings for `BedrockGuardrailGuard`.
///
/// The hand-written `Debug` impl below never prints `api_key`.
#[derive(Clone)]
pub struct BedrockGuardrailConfig {
    /// Secret sent as a bearer token in the `Authorization` header; held in
    /// `Zeroizing` so the buffer is wiped when dropped.
    pub api_key: Zeroizing<String>,
    /// AWS region used to build the default `bedrock-runtime` host name.
    pub region: String,
    /// Guardrail identifier placed in the request path.
    pub guardrail_id: String,
    /// Guardrail version placed in the request path.
    pub guardrail_version: String,
    /// Optional base-URL override; `None` means the regional default host.
    pub endpoint: Option<String>,
    /// Value sent as the request body's `source` field.
    pub source: BedrockSource,
    /// Overall HTTP timeout applied by `BedrockGuardrailGuard::new`.
    pub timeout: Duration,
}

impl std::fmt::Debug for BedrockGuardrailConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("BedrockGuardrailConfig");
        // The key itself is replaced by a fixed placeholder so that logging a
        // config can never leak the credential.
        dbg.field("api_key", &"***redacted***");
        dbg.field("region", &self.region);
        dbg.field("guardrail_id", &self.guardrail_id);
        dbg.field("guardrail_version", &self.guardrail_version);
        dbg.field("endpoint", &self.endpoint);
        dbg.field("source", &self.source);
        dbg.field("timeout", &self.timeout);
        dbg.finish()
    }
}
impl BedrockGuardrailConfig {
    /// Creates a config for the given credential, region and guardrail.
    ///
    /// `endpoint` starts as `None` (use the regional default host), `source`
    /// as `BedrockSource::Input` and `timeout` as `DEFAULT_TIMEOUT`.
    pub fn new(
        api_key: impl Into<String>,
        region: impl Into<String>,
        guardrail_id: impl Into<String>,
        guardrail_version: impl Into<String>,
    ) -> Self {
        Self {
            api_key: Zeroizing::new(api_key.into()),
            region: region.into(),
            guardrail_id: guardrail_id.into(),
            guardrail_version: guardrail_version.into(),
            endpoint: None,
            source: BedrockSource::Input,
            timeout: DEFAULT_TIMEOUT,
        }
    }

    /// Overrides the regional default base URL. Trailing `/` characters are
    /// stripped when request URLs are built.
    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.endpoint = Some(endpoint.into());
        self
    }

    /// Returns the base URL: the configured override without trailing `/`,
    /// or `https://bedrock-runtime.{region}.amazonaws.com` when none is set.
    fn resolved_endpoint(&self) -> String {
        match self.endpoint.as_deref() {
            Some(ep) => ep.trim_end_matches('/').to_string(),
            None => format!("https://bedrock-runtime.{}.amazonaws.com", self.region),
        }
    }

    /// Builds `{endpoint}/guardrail/{id}/version/{version}/apply`.
    ///
    /// The identifier and version are each percent-encoded as a single path
    /// segment. A guardrail may be referenced by its full ARN, and the ARN's
    /// `/` and `:` must not be interpreted as URL structure. Plain IDs,
    /// `DRAFT` and numeric versions contain only unreserved characters and
    /// are emitted unchanged.
    fn apply_url(&self) -> String {
        format!(
            "{}/guardrail/{}/version/{}/apply",
            self.resolved_endpoint(),
            Self::encode_path_segment(&self.guardrail_id),
            Self::encode_path_segment(&self.guardrail_version)
        )
    }

    /// Percent-encodes `segment` for use as one URL path segment, leaving
    /// only RFC 3986 unreserved characters (`A-Z a-z 0-9 - . _ ~`) literal.
    fn encode_path_segment(segment: &str) -> String {
        use std::fmt::Write as _;

        let mut out = String::with_capacity(segment.len());
        for byte in segment.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                out.push(char::from(byte));
            } else {
                // Writing into a `String` cannot fail.
                let _ = write!(out, "%{byte:02X}");
            }
        }
        out
    }
}
/// Subset of the `ApplyGuardrail` JSON response that this guard reads.
/// Unknown response fields are ignored by serde; both fields below fall back
/// to their `Default` values when absent.
#[derive(Debug, Clone, Deserialize)]
struct ApplyGuardrailResponse {
/// Guardrail outcome string (e.g. `"GUARDRAIL_INTERVENED"`); empty when the
/// field is missing from the response.
#[serde(default)]
action: String,
/// Raw assessment objects kept as untyped JSON; empty when missing.
#[serde(default)]
assessments: Vec<serde_json::Value>,
}
/// JSON body POSTed to the `ApplyGuardrail` endpoint. Field names and order
/// define the wire format, so they should not be changed casually.
#[derive(Debug, Serialize)]
struct ApplyGuardrailRequest<'a> {
/// Source label, e.g. `"INPUT"` (see `BedrockSource::as_str`).
source: &'a str,
/// Content blocks submitted for evaluation.
content: Vec<GuardrailContentBlock<'a>>,
}
/// One content block; serializes as `{"text": {"text": "..."}}`.
#[derive(Debug, Serialize)]
struct GuardrailContentBlock<'a> {
text: GuardrailText<'a>,
}
/// Text payload of a content block, borrowed to avoid copying the input.
#[derive(Debug, Serialize)]
struct GuardrailText<'a> {
text: &'a str,
}
pub struct BedrockGuardrailGuard {
cfg: BedrockGuardrailConfig,
http: Client,
}
impl BedrockGuardrailGuard {
pub fn new(cfg: BedrockGuardrailConfig) -> Result<Self, ExternalGuardError> {
let http = Client::builder()
.timeout(cfg.timeout)
.build()
.map_err(|e| ExternalGuardError::Permanent(format!("reqwest build: {e}")))?;
Ok(Self { cfg, http })
}
pub fn with_client(cfg: BedrockGuardrailConfig, http: Client) -> Self {
Self { cfg, http }
}
pub fn evidence_from_decision(
&self,
verdict: Verdict,
details: Option<&BedrockDecisionDetails>,
) -> GuardEvidence {
GuardEvidence {
guard_name: self.name().to_string(),
verdict: matches!(verdict, Verdict::Allow),
details: details.and_then(|d| d.as_details_string()),
}
}
}
/// Structured decision details that can be attached to evidence as a JSON
/// string. Field order determines the serialized JSON layout.
#[derive(Debug, Clone, Serialize)]
pub struct BedrockDecisionDetails {
    /// Guardrail action string.
    pub action: String,
    /// Whether the guardrail intervened.
    pub intervened: bool,
    /// Assessment objects, kept as untyped JSON.
    pub assessments: Vec<serde_json::Value>,
}

impl BedrockDecisionDetails {
    /// Serializes these details to compact JSON, returning `None` if
    /// serialization fails.
    fn as_details_string(&self) -> Option<String> {
        let encoded = serde_json::to_string(self);
        encoded.ok()
    }
}
#[async_trait]
impl ExternalGuard for BedrockGuardrailGuard {
fn name(&self) -> &str {
GUARD_NAME
}
fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
let mut hasher = Sha256::new();
hasher.update(self.cfg.guardrail_id.as_bytes());
hasher.update(b":");
hasher.update(self.cfg.guardrail_version.as_bytes());
hasher.update(b":");
hasher.update(ctx.tool_name.as_bytes());
hasher.update(b":");
hasher.update(ctx.arguments_json.as_bytes());
let digest = hasher.finalize();
let mut hex = String::with_capacity(digest.len() * 2);
for b in digest {
hex.push_str(&format!("{b:02x}"));
}
Some(format!("bedrock:{hex}"))
}
async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
let url = self.cfg.apply_url();
super::endpoint_security::validate_external_guard_url("bedrock endpoint", &url)?;
let body = ApplyGuardrailRequest {
source: self.cfg.source.as_str(),
content: vec![GuardrailContentBlock {
text: GuardrailText {
text: &ctx.arguments_json,
},
}],
};
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
let auth_value = format!("Bearer {}", self.cfg.api_key.as_str());
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&auth_value)
.map_err(|e| ExternalGuardError::Permanent(format!("invalid api key: {e}")))?,
);
let resp = self
.http
.post(&url)
.headers(headers)
.json(&body)
.send()
.await
.map_err(classify_reqwest_error)?;
let status = resp.status();
let text = resp
.text()
.await
.map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
if !status.is_success() {
return Err(classify_status_error("bedrock", status, &text));
}
let parsed: ApplyGuardrailResponse = serde_json::from_str(&text)
.map_err(|e| ExternalGuardError::Transient(format!("parse bedrock response: {e}")))?;
let intervened = parsed.action.eq_ignore_ascii_case("GUARDRAIL_INTERVENED");
tracing::info!(
guard = GUARD_NAME,
action = %parsed.action,
intervened,
assessments = parsed.assessments.len(),
"bedrock ApplyGuardrail response"
);
Ok(if intervened {
Verdict::Deny
} else {
Verdict::Allow
})
}
}
/// Maps a non-success HTTP response from a guard provider onto an
/// `ExternalGuardError`.
///
/// Server errors (5xx), `429 Too Many Requests` and `408 Request Timeout`
/// are `Transient`, because repeating the same request may succeed. Every
/// other status is `Permanent`. Only the first 256 characters of `body` are
/// included in the message, so a large error page cannot bloat logs.
pub(crate) fn classify_status_error(
    provider: &'static str,
    status: StatusCode,
    body: &str,
) -> ExternalGuardError {
    let snippet: String = body.chars().take(256).collect();
    let message = format!("{provider} HTTP {}: {snippet}", status.as_u16());
    // 408 was previously treated as Permanent, although it signals that the
    // server gave up waiting and the request may be repeated.
    let retryable = status.is_server_error()
        || status == StatusCode::TOO_MANY_REQUESTS
        || status == StatusCode::REQUEST_TIMEOUT;
    if retryable {
        ExternalGuardError::Transient(message)
    } else {
        ExternalGuardError::Permanent(message)
    }
}
/// Maps a transport-level `reqwest` failure onto an `ExternalGuardError`.
///
/// Timeouts get the dedicated `Timeout` variant. Connection and request
/// errors are `Transient`; anything else is `Permanent`. The error's display
/// text becomes the message.
pub(crate) fn classify_reqwest_error(err: reqwest::Error) -> ExternalGuardError {
    if err.is_timeout() {
        return ExternalGuardError::Timeout;
    }
    let retryable = err.is_connect() || err.is_request();
    let message = err.to_string();
    if retryable {
        ExternalGuardError::Transient(message)
    } else {
        ExternalGuardError::Permanent(message)
    }
}