Skip to main content

chio_external_guards/external/
azure_content_safety.rs

//! Azure Content Safety `text:analyze` adapter (phase 13.2).
//!
//! Wraps the Azure Content Safety `text:analyze` endpoint as an
//! [`ExternalGuard`]. Each category returned by the API carries a
//! `severity` value; the guard denies the call when any analyzed category's
//! severity reaches or exceeds the configured severity threshold.
//!
//! See: <https://learn.microsoft.com/azure/ai-services/content-safety/reference/rest-api-reference-text>
//!
//! # Fail-closed
//!
//! HTTP, transport, and response-parsing errors propagate as
//! [`ExternalGuardError`]; the [`AsyncGuardAdapter`] maps those into
//! [`Verdict::Deny`].
//!
//! [`AsyncGuardAdapter`]: super::AsyncGuardAdapter

use std::fmt::Write as _;
use std::time::Duration;

use async_trait::async_trait;
use chio_core_types::GuardEvidence;
use chio_kernel::Verdict;
use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use zeroize::Zeroizing;

use super::bedrock::{classify_reqwest_error, classify_status_error};
use super::{ExternalGuard, ExternalGuardError, GuardCallContext};

/// Name this guard reports for itself (see [`AzureContentSafetyGuard::name`]).
pub const GUARD_NAME: &str = "azure-content-safety";

/// Per-request HTTP timeout used when the config does not override it.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);

/// `api-version` query-string value used when the config does not override it.
pub const DEFAULT_API_VERSION: &str = "2023-10-01";

/// Severity at or above which the guard denies by default.
///
/// On the four-level output scale Azure Content Safety reports severities in
/// `{0, 2, 4, 6}` (an 8-level scale also exists). `4` is the "Medium" band,
/// which is the usual blocking point.
pub const DEFAULT_SEVERITY_THRESHOLD: u32 = 4;

/// A harm category understood by Azure Content Safety's text analysis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AzureCategory {
    /// Sent on the wire as `"Hate"`.
    Hate,
    /// Sent on the wire as `"SelfHarm"`.
    SelfHarm,
    /// Sent on the wire as `"Sexual"`.
    Sexual,
    /// Sent on the wire as `"Violence"`.
    Violence,
}

impl AzureCategory {
    /// Every known category, in declaration order.
    const ALL: [Self; 4] = [Self::Hate, Self::SelfHarm, Self::Sexual, Self::Violence];

    /// Name of the category as submitted in the request's `categories` list.
    pub const fn as_str(self) -> &'static str {
        match self {
            AzureCategory::Hate => "Hate",
            AzureCategory::SelfHarm => "SelfHarm",
            AzureCategory::Sexual => "Sexual",
            AzureCategory::Violence => "Violence",
        }
    }

    /// Returns all known categories, in declaration order.
    pub const fn all() -> [Self; 4] {
        Self::ALL
    }
}

71/// Configuration for [`AzureContentSafetyGuard`].
72#[derive(Clone)]
73pub struct AzureContentSafetyConfig {
74    /// Content Safety API key (`Ocp-Apim-Subscription-Key` header).
75    pub api_key: Zeroizing<String>,
76    /// Content Safety endpoint (e.g.
77    /// `https://<region>.api.cognitive.microsoft.com`).
78    pub endpoint: String,
79    /// `api-version` query parameter.
80    pub api_version: String,
81    /// Per-request HTTP timeout.
82    pub timeout: Duration,
83    /// Severity threshold; any category at or above this value triggers a
84    /// [`Verdict::Deny`]. Azure uses severity `0..=7` on the 8-level
85    /// scale (`0, 2, 4, 6` on the 4-level scale).
86    pub severity_threshold: u32,
87    /// Categories to submit. An empty vector means "all known
88    /// categories".
89    pub categories: Vec<AzureCategory>,
90}
91
92impl std::fmt::Debug for AzureContentSafetyConfig {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("AzureContentSafetyConfig")
95            .field("api_key", &"***redacted***")
96            .field("endpoint", &self.endpoint)
97            .field("api_version", &self.api_version)
98            .field("timeout", &self.timeout)
99            .field("severity_threshold", &self.severity_threshold)
100            .field("categories", &self.categories)
101            .finish()
102    }
103}
104
105impl AzureContentSafetyConfig {
106    /// Construct a minimal config with defaults.
107    pub fn new(api_key: impl Into<String>, endpoint: impl Into<String>) -> Self {
108        Self {
109            api_key: Zeroizing::new(api_key.into()),
110            endpoint: endpoint.into(),
111            api_version: DEFAULT_API_VERSION.to_string(),
112            timeout: DEFAULT_TIMEOUT,
113            severity_threshold: DEFAULT_SEVERITY_THRESHOLD,
114            categories: AzureCategory::all().to_vec(),
115        }
116    }
117
118    /// Override the severity threshold.
119    pub fn with_severity_threshold(mut self, threshold: u32) -> Self {
120        self.severity_threshold = threshold;
121        self
122    }
123
124    /// Override the category list.
125    pub fn with_categories(mut self, categories: Vec<AzureCategory>) -> Self {
126        self.categories = categories;
127        self
128    }
129
130    fn analyze_url(&self) -> String {
131        let base = self.endpoint.trim_end_matches('/');
132        format!(
133            "{base}/contentsafety/text:analyze?api-version={}",
134            self.api_version
135        )
136    }
137}
138
139#[derive(Debug, Serialize)]
140struct AnalyzeRequest<'a> {
141    text: &'a str,
142    categories: Vec<&'a str>,
143    #[serde(skip_serializing_if = "Option::is_none")]
144    #[serde(rename = "outputType")]
145    output_type: Option<&'a str>,
146}
147
148#[derive(Debug, Clone, Deserialize)]
149struct AnalyzeResponse {
150    #[serde(default, rename = "categoriesAnalysis")]
151    categories_analysis: Vec<CategoryResult>,
152}
153
154#[derive(Debug, Clone, Deserialize, Serialize)]
155struct CategoryResult {
156    #[serde(default)]
157    category: String,
158    #[serde(default)]
159    severity: u32,
160}
161
162/// Guard wrapping Azure Content Safety's `text:analyze` endpoint.
163pub struct AzureContentSafetyGuard {
164    cfg: AzureContentSafetyConfig,
165    http: Client,
166}
167
168impl AzureContentSafetyGuard {
169    /// Build a guard with an internally-owned [`reqwest::Client`].
170    pub fn new(cfg: AzureContentSafetyConfig) -> Result<Self, ExternalGuardError> {
171        let http = Client::builder()
172            .timeout(cfg.timeout)
173            .build()
174            .map_err(|e| ExternalGuardError::Permanent(format!("reqwest build: {e}")))?;
175        Ok(Self { cfg, http })
176    }
177
178    /// Build a guard with a caller-supplied client (for tests).
179    pub fn with_client(cfg: AzureContentSafetyConfig, http: Client) -> Self {
180        Self { cfg, http }
181    }
182
183    /// Build a [`GuardEvidence`] record for a prior decision.
184    pub fn evidence_from_decision(
185        &self,
186        verdict: Verdict,
187        details: Option<&AzureDecisionDetails>,
188    ) -> GuardEvidence {
189        GuardEvidence {
190            guard_name: self.name().to_string(),
191            verdict: matches!(verdict, Verdict::Allow),
192            details: details.and_then(|d| d.as_details_string()),
193        }
194    }
195}
196
197/// Structured details extracted from a Content Safety response.
198#[derive(Debug, Clone, Serialize)]
199pub struct AzureDecisionDetails {
200    /// Maximum severity observed across any category.
201    pub max_severity: u32,
202    /// Threshold used to make the decision.
203    pub severity_threshold: u32,
204    /// Category-level severity breakdown.
205    pub categories: Vec<AzureCategoryBreakdown>,
206}
207
208impl AzureDecisionDetails {
209    fn as_details_string(&self) -> Option<String> {
210        serde_json::to_string(self).ok()
211    }
212}
213
214#[derive(Debug, Clone, Serialize)]
215pub struct AzureCategoryBreakdown {
216    pub category: String,
217    pub severity: u32,
218}
219
220#[async_trait]
221impl ExternalGuard for AzureContentSafetyGuard {
222    fn name(&self) -> &str {
223        GUARD_NAME
224    }
225
226    fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
227        let mut hasher = Sha256::new();
228        hasher.update(self.cfg.endpoint.as_bytes());
229        hasher.update(b":");
230        hasher.update(ctx.tool_name.as_bytes());
231        hasher.update(b":");
232        hasher.update(ctx.arguments_json.as_bytes());
233        let digest = hasher.finalize();
234        let mut hex = String::with_capacity(digest.len() * 2);
235        for b in digest {
236            hex.push_str(&format!("{b:02x}"));
237        }
238        Some(format!("azure-cs:{hex}"))
239    }
240
241    async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
242        super::endpoint_security::validate_external_guard_url(
243            "azure-content-safety endpoint",
244            &self.cfg.endpoint,
245        )?;
246        let url = self.cfg.analyze_url();
247
248        let cats_ref: Vec<&str> = if self.cfg.categories.is_empty() {
249            AzureCategory::all().iter().map(|c| c.as_str()).collect()
250        } else {
251            self.cfg.categories.iter().map(|c| c.as_str()).collect()
252        };
253
254        let body = AnalyzeRequest {
255            text: &ctx.arguments_json,
256            categories: cats_ref,
257            output_type: Some("FourSeverityLevels"),
258        };
259
260        let mut headers = HeaderMap::new();
261        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
262        headers.insert(
263            "Ocp-Apim-Subscription-Key",
264            HeaderValue::from_str(self.cfg.api_key.as_str())
265                .map_err(|e| ExternalGuardError::Permanent(format!("invalid api key: {e}")))?,
266        );
267
268        let resp = self
269            .http
270            .post(&url)
271            .headers(headers)
272            .json(&body)
273            .send()
274            .await
275            .map_err(classify_reqwest_error)?;
276
277        let status = resp.status();
278        let text = resp
279            .text()
280            .await
281            .map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
282
283        if !status.is_success() {
284            return Err(classify_status_error("azure-content-safety", status, &text));
285        }
286
287        let parsed: AnalyzeResponse = serde_json::from_str(&text).map_err(|e| {
288            ExternalGuardError::Transient(format!("parse azure content safety response: {e}"))
289        })?;
290
291        let mut max_severity = 0_u32;
292        for entry in &parsed.categories_analysis {
293            if entry.severity > max_severity {
294                max_severity = entry.severity;
295            }
296        }
297
298        tracing::info!(
299            guard = GUARD_NAME,
300            max_severity,
301            threshold = self.cfg.severity_threshold,
302            categories = parsed.categories_analysis.len(),
303            "azure content safety response"
304        );
305
306        Ok(if max_severity >= self.cfg.severity_threshold {
307            Verdict::Deny
308        } else {
309            Verdict::Allow
310        })
311    }
312}