chio_external_guards/external/
bedrock.rs1use std::time::Duration;
36
37use async_trait::async_trait;
38use chio_core_types::GuardEvidence;
39use chio_kernel::Verdict;
40use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
41use reqwest::{Client, StatusCode};
42use serde::{Deserialize, Serialize};
43use sha2::{Digest, Sha256};
44use zeroize::Zeroizing;
45
46use super::{ExternalGuard, ExternalGuardError, GuardCallContext};
47
/// Name this guard reports via `ExternalGuard::name`; also used in evidence and log fields.
pub const GUARD_NAME: &str = "bedrock-guardrail";

/// HTTP timeout installed by [`BedrockGuardrailConfig::new`] (ten seconds).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
54
/// Value sent in the `source` field of the `ApplyGuardrail` request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BedrockSource {
    /// Sent as `"INPUT"`; this is the default.
    #[default]
    Input,
    /// Sent as `"OUTPUT"`.
    Output,
}

impl BedrockSource {
    /// Wire string for this variant, as placed in the request body.
    fn as_str(self) -> &'static str {
        match self {
            BedrockSource::Output => "OUTPUT",
            BedrockSource::Input => "INPUT",
        }
    }
}
75
76#[derive(Clone)]
81pub struct BedrockGuardrailConfig {
82 pub api_key: Zeroizing<String>,
84 pub region: String,
86 pub guardrail_id: String,
88 pub guardrail_version: String,
91 pub endpoint: Option<String>,
94 pub source: BedrockSource,
96 pub timeout: Duration,
98}
99
100impl std::fmt::Debug for BedrockGuardrailConfig {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("BedrockGuardrailConfig")
103 .field("api_key", &"***redacted***")
104 .field("region", &self.region)
105 .field("guardrail_id", &self.guardrail_id)
106 .field("guardrail_version", &self.guardrail_version)
107 .field("endpoint", &self.endpoint)
108 .field("source", &self.source)
109 .field("timeout", &self.timeout)
110 .finish()
111 }
112}
113
114impl BedrockGuardrailConfig {
115 pub fn new(
117 api_key: impl Into<String>,
118 region: impl Into<String>,
119 guardrail_id: impl Into<String>,
120 guardrail_version: impl Into<String>,
121 ) -> Self {
122 Self {
123 api_key: Zeroizing::new(api_key.into()),
124 region: region.into(),
125 guardrail_id: guardrail_id.into(),
126 guardrail_version: guardrail_version.into(),
127 endpoint: None,
128 source: BedrockSource::Input,
129 timeout: DEFAULT_TIMEOUT,
130 }
131 }
132
133 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
135 self.endpoint = Some(endpoint.into());
136 self
137 }
138
139 fn resolved_endpoint(&self) -> String {
140 match self.endpoint.as_deref() {
141 Some(ep) => ep.trim_end_matches('/').to_string(),
142 None => format!("https://bedrock-runtime.{}.amazonaws.com", self.region),
143 }
144 }
145
146 fn apply_url(&self) -> String {
147 format!(
148 "{}/guardrail/{}/version/{}/apply",
149 self.resolved_endpoint(),
150 self.guardrail_id,
151 self.guardrail_version
152 )
153 }
154}
155
156#[derive(Debug, Clone, Deserialize)]
159struct ApplyGuardrailResponse {
160 #[serde(default)]
162 action: String,
163 #[serde(default)]
165 assessments: Vec<serde_json::Value>,
166}
167
168#[derive(Debug, Serialize)]
170struct ApplyGuardrailRequest<'a> {
171 source: &'a str,
172 content: Vec<GuardrailContentBlock<'a>>,
173}
174
175#[derive(Debug, Serialize)]
176struct GuardrailContentBlock<'a> {
177 text: GuardrailText<'a>,
178}
179
180#[derive(Debug, Serialize)]
181struct GuardrailText<'a> {
182 text: &'a str,
183}
184
185pub struct BedrockGuardrailGuard {
187 cfg: BedrockGuardrailConfig,
188 http: Client,
189}
190
191impl BedrockGuardrailGuard {
192 pub fn new(cfg: BedrockGuardrailConfig) -> Result<Self, ExternalGuardError> {
194 let http = Client::builder()
195 .timeout(cfg.timeout)
196 .build()
197 .map_err(|e| ExternalGuardError::Permanent(format!("reqwest build: {e}")))?;
198 Ok(Self { cfg, http })
199 }
200
201 pub fn with_client(cfg: BedrockGuardrailConfig, http: Client) -> Self {
204 Self { cfg, http }
205 }
206
207 pub fn evidence_from_decision(
212 &self,
213 verdict: Verdict,
214 details: Option<&BedrockDecisionDetails>,
215 ) -> GuardEvidence {
216 GuardEvidence {
217 guard_name: self.name().to_string(),
218 verdict: matches!(verdict, Verdict::Allow),
219 details: details.and_then(|d| d.as_details_string()),
220 }
221 }
222}
223
224#[derive(Debug, Clone, Serialize)]
227pub struct BedrockDecisionDetails {
228 pub action: String,
230 pub intervened: bool,
232 pub assessments: Vec<serde_json::Value>,
234}
235
236impl BedrockDecisionDetails {
237 fn as_details_string(&self) -> Option<String> {
238 serde_json::to_string(self).ok()
239 }
240}
241
242#[async_trait]
243impl ExternalGuard for BedrockGuardrailGuard {
244 fn name(&self) -> &str {
245 GUARD_NAME
246 }
247
248 fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
249 let mut hasher = Sha256::new();
250 hasher.update(self.cfg.guardrail_id.as_bytes());
251 hasher.update(b":");
252 hasher.update(self.cfg.guardrail_version.as_bytes());
253 hasher.update(b":");
254 hasher.update(ctx.tool_name.as_bytes());
255 hasher.update(b":");
256 hasher.update(ctx.arguments_json.as_bytes());
257 let digest = hasher.finalize();
258 let mut hex = String::with_capacity(digest.len() * 2);
259 for b in digest {
260 hex.push_str(&format!("{b:02x}"));
261 }
262 Some(format!("bedrock:{hex}"))
263 }
264
265 async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
266 let url = self.cfg.apply_url();
267 super::endpoint_security::validate_external_guard_url("bedrock endpoint", &url)?;
268 let body = ApplyGuardrailRequest {
269 source: self.cfg.source.as_str(),
270 content: vec![GuardrailContentBlock {
271 text: GuardrailText {
272 text: &ctx.arguments_json,
273 },
274 }],
275 };
276
277 let mut headers = HeaderMap::new();
278 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
279 let auth_value = format!("Bearer {}", self.cfg.api_key.as_str());
280 headers.insert(
281 AUTHORIZATION,
282 HeaderValue::from_str(&auth_value)
283 .map_err(|e| ExternalGuardError::Permanent(format!("invalid api key: {e}")))?,
284 );
285
286 let resp = self
287 .http
288 .post(&url)
289 .headers(headers)
290 .json(&body)
291 .send()
292 .await
293 .map_err(classify_reqwest_error)?;
294
295 let status = resp.status();
296 let text = resp
297 .text()
298 .await
299 .map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
300
301 if !status.is_success() {
302 return Err(classify_status_error("bedrock", status, &text));
303 }
304
305 let parsed: ApplyGuardrailResponse = serde_json::from_str(&text)
306 .map_err(|e| ExternalGuardError::Transient(format!("parse bedrock response: {e}")))?;
307
308 let intervened = parsed.action.eq_ignore_ascii_case("GUARDRAIL_INTERVENED");
309 tracing::info!(
310 guard = GUARD_NAME,
311 action = %parsed.action,
312 intervened,
313 assessments = parsed.assessments.len(),
314 "bedrock ApplyGuardrail response"
315 );
316 Ok(if intervened {
317 Verdict::Deny
318 } else {
319 Verdict::Allow
320 })
321 }
322}
323
324pub(crate) fn classify_status_error(
327 provider: &'static str,
328 status: StatusCode,
329 body: &str,
330) -> ExternalGuardError {
331 let snippet = body.chars().take(256).collect::<String>();
332 if status.is_server_error() || status == StatusCode::TOO_MANY_REQUESTS {
333 ExternalGuardError::Transient(format!("{provider} HTTP {}: {}", status.as_u16(), snippet))
334 } else {
335 ExternalGuardError::Permanent(format!("{provider} HTTP {}: {}", status.as_u16(), snippet))
336 }
337}
338
339pub(crate) fn classify_reqwest_error(err: reqwest::Error) -> ExternalGuardError {
342 if err.is_timeout() {
343 ExternalGuardError::Timeout
344 } else if err.is_connect() || err.is_request() {
345 ExternalGuardError::Transient(err.to_string())
346 } else {
347 ExternalGuardError::Permanent(err.to_string())
348 }
349}