chio_external_guards/external/
azure_content_safety.rs1use std::time::Duration;
18
19use async_trait::async_trait;
20use chio_core_types::GuardEvidence;
21use chio_kernel::Verdict;
22use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
23use reqwest::Client;
24use serde::{Deserialize, Serialize};
25use sha2::{Digest, Sha256};
26use zeroize::Zeroizing;
27
28use super::bedrock::{classify_reqwest_error, classify_status_error};
29use super::{ExternalGuard, ExternalGuardError, GuardCallContext};
30
/// Name this guard reports from `ExternalGuard::name` and uses in log fields.
pub const GUARD_NAME: &str = "azure-content-safety";

/// Request timeout used by `AzureContentSafetyConfig::new`.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);

/// `api-version` query parameter used by `AzureContentSafetyConfig::new`.
pub const DEFAULT_API_VERSION: &str = "2023-10-01";

/// Default severity at or above which the guard returns a deny verdict.
pub const DEFAULT_SEVERITY_THRESHOLD: u32 = 4;
44
/// Harm categories the guard can ask the Azure Content Safety service to score.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AzureCategory {
    Hate,
    SelfHarm,
    Sexual,
    Violence,
}

impl AzureCategory {
    /// Every category, in declaration order.
    const ALL: [Self; 4] = [
        AzureCategory::Hate,
        AzureCategory::SelfHarm,
        AzureCategory::Sexual,
        AzureCategory::Violence,
    ];

    /// Returns the category name exactly as it is placed in the request's
    /// `categories` list.
    pub const fn as_str(self) -> &'static str {
        match self {
            AzureCategory::Hate => "Hate",
            AzureCategory::SelfHarm => "SelfHarm",
            AzureCategory::Sexual => "Sexual",
            AzureCategory::Violence => "Violence",
        }
    }

    /// Returns all four categories in declaration order.
    pub const fn all() -> [Self; 4] {
        Self::ALL
    }
}
70
71#[derive(Clone)]
73pub struct AzureContentSafetyConfig {
74 pub api_key: Zeroizing<String>,
76 pub endpoint: String,
79 pub api_version: String,
81 pub timeout: Duration,
83 pub severity_threshold: u32,
87 pub categories: Vec<AzureCategory>,
90}
91
92impl std::fmt::Debug for AzureContentSafetyConfig {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("AzureContentSafetyConfig")
95 .field("api_key", &"***redacted***")
96 .field("endpoint", &self.endpoint)
97 .field("api_version", &self.api_version)
98 .field("timeout", &self.timeout)
99 .field("severity_threshold", &self.severity_threshold)
100 .field("categories", &self.categories)
101 .finish()
102 }
103}
104
105impl AzureContentSafetyConfig {
106 pub fn new(api_key: impl Into<String>, endpoint: impl Into<String>) -> Self {
108 Self {
109 api_key: Zeroizing::new(api_key.into()),
110 endpoint: endpoint.into(),
111 api_version: DEFAULT_API_VERSION.to_string(),
112 timeout: DEFAULT_TIMEOUT,
113 severity_threshold: DEFAULT_SEVERITY_THRESHOLD,
114 categories: AzureCategory::all().to_vec(),
115 }
116 }
117
118 pub fn with_severity_threshold(mut self, threshold: u32) -> Self {
120 self.severity_threshold = threshold;
121 self
122 }
123
124 pub fn with_categories(mut self, categories: Vec<AzureCategory>) -> Self {
126 self.categories = categories;
127 self
128 }
129
130 fn analyze_url(&self) -> String {
131 let base = self.endpoint.trim_end_matches('/');
132 format!(
133 "{base}/contentsafety/text:analyze?api-version={}",
134 self.api_version
135 )
136 }
137}
138
139#[derive(Debug, Serialize)]
140struct AnalyzeRequest<'a> {
141 text: &'a str,
142 categories: Vec<&'a str>,
143 #[serde(skip_serializing_if = "Option::is_none")]
144 #[serde(rename = "outputType")]
145 output_type: Option<&'a str>,
146}
147
148#[derive(Debug, Clone, Deserialize)]
149struct AnalyzeResponse {
150 #[serde(default, rename = "categoriesAnalysis")]
151 categories_analysis: Vec<CategoryResult>,
152}
153
154#[derive(Debug, Clone, Deserialize, Serialize)]
155struct CategoryResult {
156 #[serde(default)]
157 category: String,
158 #[serde(default)]
159 severity: u32,
160}
161
162pub struct AzureContentSafetyGuard {
164 cfg: AzureContentSafetyConfig,
165 http: Client,
166}
167
168impl AzureContentSafetyGuard {
169 pub fn new(cfg: AzureContentSafetyConfig) -> Result<Self, ExternalGuardError> {
171 let http = Client::builder()
172 .timeout(cfg.timeout)
173 .build()
174 .map_err(|e| ExternalGuardError::Permanent(format!("reqwest build: {e}")))?;
175 Ok(Self { cfg, http })
176 }
177
178 pub fn with_client(cfg: AzureContentSafetyConfig, http: Client) -> Self {
180 Self { cfg, http }
181 }
182
183 pub fn evidence_from_decision(
185 &self,
186 verdict: Verdict,
187 details: Option<&AzureDecisionDetails>,
188 ) -> GuardEvidence {
189 GuardEvidence {
190 guard_name: self.name().to_string(),
191 verdict: matches!(verdict, Verdict::Allow),
192 details: details.and_then(|d| d.as_details_string()),
193 }
194 }
195}
196
197#[derive(Debug, Clone, Serialize)]
199pub struct AzureDecisionDetails {
200 pub max_severity: u32,
202 pub severity_threshold: u32,
204 pub categories: Vec<AzureCategoryBreakdown>,
206}
207
208impl AzureDecisionDetails {
209 fn as_details_string(&self) -> Option<String> {
210 serde_json::to_string(self).ok()
211 }
212}
213
214#[derive(Debug, Clone, Serialize)]
215pub struct AzureCategoryBreakdown {
216 pub category: String,
217 pub severity: u32,
218}
219
220#[async_trait]
221impl ExternalGuard for AzureContentSafetyGuard {
222 fn name(&self) -> &str {
223 GUARD_NAME
224 }
225
226 fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
227 let mut hasher = Sha256::new();
228 hasher.update(self.cfg.endpoint.as_bytes());
229 hasher.update(b":");
230 hasher.update(ctx.tool_name.as_bytes());
231 hasher.update(b":");
232 hasher.update(ctx.arguments_json.as_bytes());
233 let digest = hasher.finalize();
234 let mut hex = String::with_capacity(digest.len() * 2);
235 for b in digest {
236 hex.push_str(&format!("{b:02x}"));
237 }
238 Some(format!("azure-cs:{hex}"))
239 }
240
241 async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
242 super::endpoint_security::validate_external_guard_url(
243 "azure-content-safety endpoint",
244 &self.cfg.endpoint,
245 )?;
246 let url = self.cfg.analyze_url();
247
248 let cats_ref: Vec<&str> = if self.cfg.categories.is_empty() {
249 AzureCategory::all().iter().map(|c| c.as_str()).collect()
250 } else {
251 self.cfg.categories.iter().map(|c| c.as_str()).collect()
252 };
253
254 let body = AnalyzeRequest {
255 text: &ctx.arguments_json,
256 categories: cats_ref,
257 output_type: Some("FourSeverityLevels"),
258 };
259
260 let mut headers = HeaderMap::new();
261 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
262 headers.insert(
263 "Ocp-Apim-Subscription-Key",
264 HeaderValue::from_str(self.cfg.api_key.as_str())
265 .map_err(|e| ExternalGuardError::Permanent(format!("invalid api key: {e}")))?,
266 );
267
268 let resp = self
269 .http
270 .post(&url)
271 .headers(headers)
272 .json(&body)
273 .send()
274 .await
275 .map_err(classify_reqwest_error)?;
276
277 let status = resp.status();
278 let text = resp
279 .text()
280 .await
281 .map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
282
283 if !status.is_success() {
284 return Err(classify_status_error("azure-content-safety", status, &text));
285 }
286
287 let parsed: AnalyzeResponse = serde_json::from_str(&text).map_err(|e| {
288 ExternalGuardError::Transient(format!("parse azure content safety response: {e}"))
289 })?;
290
291 let mut max_severity = 0_u32;
292 for entry in &parsed.categories_analysis {
293 if entry.severity > max_severity {
294 max_severity = entry.severity;
295 }
296 }
297
298 tracing::info!(
299 guard = GUARD_NAME,
300 max_severity,
301 threshold = self.cfg.severity_threshold,
302 categories = parsed.categories_analysis.len(),
303 "azure content safety response"
304 );
305
306 Ok(if max_severity >= self.cfg.severity_threshold {
307 Verdict::Deny
308 } else {
309 Verdict::Allow
310 })
311 }
312}