agent_control_specification_core 0.3.1-beta.0

Stateless Rust core for Agent Control Specification
Documentation
use super::{
    fold_score_verdict, BundledClassifierProvider, ClassifierVerdict, HttpTransport,
    ResolvedClassifierConfig, TransportRequest,
};
use crate::dispatchers::constants::*;
use serde_json::json;
use std::collections::BTreeMap;

/// Value sent as the `api-version` query parameter when `provider_config`
/// does not contain an `api_version` string.
const DEFAULT_API_VERSION: &str = "2024-09-01";
/// Categories requested when `category_thresholds` is empty and
/// `provider_config` has no `categories` array.
const DEFAULT_CATEGORIES: &[&str] = &["Hate", "SelfHarm", "Sexual", "Violence"];

/// Stateless bundled classifier provider registered under the name "aacs".
///
/// It posts text to a `/contentsafety/text:analyze` endpoint and authenticates
/// with an `Ocp-Apim-Subscription-Key` header.
// NOTE(review): "aacs" presumably stands for Azure AI Content Safety — confirm.
#[derive(Debug, Default, Clone, Copy)]
pub struct AacsProvider;

impl BundledClassifierProvider for AacsProvider {
    fn classify(
        &self,
        cfg: &ResolvedClassifierConfig,
        subject: &str,
        transport: &dyn HttpTransport,
    ) -> Result<ClassifierVerdict, String> {
        let api_key = cfg
            .api_key
            .as_deref()
            .filter(|key| !key.is_empty())
            .ok_or_else(|| "aacs api key is required".to_string())?;
        if cfg.endpoint.is_empty() {
            return Err("aacs endpoint is required".to_string());
        }

        let api_version = cfg
            .provider_config
            .get("api_version")
            .and_then(|value| value.as_str())
            .unwrap_or(DEFAULT_API_VERSION);
        let categories = categories(cfg)?;
        let body = json!({
            "text": subject,
            "categories": categories,
            "outputType": "FourSeverityLevels",
        });
        let mut request = TransportRequest::post(analyze_url(&cfg.endpoint, api_version))
            .header(HEADER_CONTENT_TYPE, CONTENT_TYPE_JSON)
            .header(HEADER_ACCEPT, CONTENT_TYPE_JSON)
            .header("Ocp-Apim-Subscription-Key", api_key)
            .json(body)
            .timeout_ms(cfg.timeout_ms);
        for (name, value) in &cfg.extra_headers {
            request = request.header(name, value);
        }

        let response = transport.send(request)?;
        if !(200..300).contains(&response.status) {
            return Err(format!("aacs HTTP {} {}", response.status, response.body));
        }
        parse_response(cfg, &response.body, &categories)
    }
}

/// Builds the request URL for the text-analysis call.
///
/// * Trailing `/` characters are stripped from `endpoint`.
/// * If the path part already contains `/contentsafety/` it is used as given;
///   otherwise `/contentsafety/text:analyze` is appended to the path, before
///   any query string the endpoint carries.
/// * Existing query parameters are preserved. `api-version={api_version}` is
///   added unless the query already has a parameter named exactly
///   `api-version`; a look-alike such as `client-api-version=` does not count.
fn analyze_url(endpoint: &str, api_version: &str) -> String {
    let endpoint = endpoint.trim_end_matches('/');
    // Treat path and query separately so the default path is never appended
    // after a query string and parameter names can be matched exactly.
    let (path, query) = endpoint.split_once('?').unwrap_or((endpoint, ""));

    let path = if path.contains("/contentsafety/") {
        path.to_string()
    } else {
        format!("{}/contentsafety/text:analyze", path.trim_end_matches('/'))
    };

    let has_api_version = query
        .split('&')
        .any(|pair| pair.starts_with("api-version="));

    if has_api_version {
        format!("{path}?{query}")
    } else if query.is_empty() {
        // Also covers an endpoint that ended in a bare `?`.
        format!("{path}?api-version={api_version}")
    } else {
        format!("{path}?{query}&api-version={api_version}")
    }
}

/// Resolves the list of categories to request.
///
/// Precedence:
/// 1. the keys of `cfg.category_thresholds`, when that map is non-empty;
/// 2. the `categories` array in `cfg.provider_config`, when present;
/// 3. `DEFAULT_CATEGORIES`.
///
/// # Errors
///
/// Fails when the `categories` array contains an entry that is not a
/// non-empty string, or when the array itself is empty.
fn categories(cfg: &ResolvedClassifierConfig) -> Result<Vec<String>, String> {
    if !cfg.category_thresholds.is_empty() {
        return Ok(cfg.category_thresholds.keys().cloned().collect());
    }

    let configured = cfg
        .provider_config
        .get("categories")
        .and_then(|node| node.as_array());

    match configured {
        Some(entries) => {
            let mut names = Vec::with_capacity(entries.len());
            for entry in entries {
                match entry.as_str() {
                    Some(name) if !name.is_empty() => names.push(name.to_string()),
                    _ => {
                        return Err("aacs provider_config categories must be strings".to_string())
                    }
                }
            }
            if names.is_empty() {
                Err("aacs categories must not be empty".to_string())
            } else {
                Ok(names)
            }
        }
        // A missing key or a non-array value falls back to the defaults.
        None => Ok(DEFAULT_CATEGORIES
            .iter()
            .map(|name| (*name).to_string())
            .collect()),
    }
}

/// Converts a response body into a verdict.
///
/// The body must be JSON with a non-empty `categoriesAnalysis` array whose
/// items each carry a non-empty string `category` and an unsigned integer
/// `severity` of 0, 2, 4 or 6. Each severity is divided by 6.0 to give a
/// score, and the per-category scores are handed to `fold_score_verdict`
/// together with `cfg`.
///
/// # Errors
///
/// Fails on unparseable JSON, a missing or empty `categoriesAnalysis`, an item
/// without a usable `category` or `severity`, a severity outside the accepted
/// set, or when any of `expected_categories` is absent from the response.
fn parse_response(
    cfg: &ResolvedClassifierConfig,
    body: &str,
    expected_categories: &[String],
) -> Result<ClassifierVerdict, String> {
    let parsed: serde_json::Value = match serde_json::from_str(body) {
        Ok(parsed) => parsed,
        Err(error) => return Err(format!("aacs JSON parse failed {error}")),
    };

    let entries = match parsed
        .get("categoriesAnalysis")
        .and_then(|node| node.as_array())
    {
        Some(entries) if !entries.is_empty() => entries,
        Some(_) => return Err("aacs response categoriesAnalysis was empty".to_string()),
        None => return Err("aacs response missing categoriesAnalysis".to_string()),
    };

    let mut scores = BTreeMap::new();
    for entry in entries {
        let name = match entry.get("category").and_then(|node| node.as_str()) {
            Some(name) if !name.is_empty() => name,
            _ => return Err("aacs response item missing category".to_string()),
        };
        let level = match entry.get("severity").and_then(|node| node.as_u64()) {
            Some(level) => level,
            None => return Err("aacs response item missing severity".to_string()),
        };
        match level {
            0 | 2 | 4 | 6 => {
                // Divide by the top level (6) to get the score.
                scores.insert(name.to_string(), level as f64 / 6.0);
            }
            _ => return Err(format!("aacs response severity {level} is invalid")),
        }
    }

    // Every requested category must be present in the response.
    if let Some(missing) = expected_categories
        .iter()
        .find(|name| !scores.contains_key(*name))
    {
        return Err(format!("aacs response missing category {missing}"));
    }

    Ok(fold_score_verdict(cfg, &scores))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dispatchers::bundled::StubHttpTransport;

    /// Response body reporting the lowest severity for the "Hate" category.
    const HATE_CLEAN: &str = r#"{"categoriesAnalysis":[{"category":"Hate","severity":0}]}"#;
    /// Response body reporting the highest severity for the "Hate" category.
    const HATE_SEVERE: &str = r#"{"categoriesAnalysis":[{"category":"Hate","severity":6}]}"#;
    /// Full analyze path without a query string.
    const ANALYZE_PATH: &str =
        "https://example.cognitiveservices.azure.com/contentsafety/text:analyze";

    /// Baseline configuration: host-only endpoint, a key, and a single
    /// "Hate" threshold of 0.5.
    fn cfg() -> ResolvedClassifierConfig {
        ResolvedClassifierConfig {
            provider: "aacs".to_string(),
            endpoint: "https://example.cognitiveservices.azure.com".to_string(),
            api_key: Some("test-key".to_string()),
            timeout_ms: 1000,
            threshold: 0.5,
            category_thresholds: BTreeMap::from([("Hate".to_string(), 0.5)]),
            extra_headers: BTreeMap::new(),
            provider_config: serde_json::Value::Null,
        }
    }

    /// Baseline configuration with the endpoint replaced by the analyze path
    /// plus `query`.
    fn cfg_with_query(query: &str) -> ResolvedClassifierConfig {
        ResolvedClassifierConfig {
            endpoint: format!("{ANALYZE_PATH}{query}"),
            ..cfg()
        }
    }

    #[test]
    fn allow_when_severity_is_zero() {
        let transport = StubHttpTransport::with_response(200, HATE_CLEAN);

        let verdict = AacsProvider.classify(&cfg(), "hello", &transport).unwrap();

        assert!(!verdict.is_failure());
    }

    #[test]
    fn block_when_severity_is_high() {
        let transport = StubHttpTransport::with_response(200, HATE_SEVERE);

        let verdict = AacsProvider.classify(&cfg(), "hello", &transport).unwrap();

        assert!(verdict.is_failure());
        assert_eq!(verdict.label.as_deref(), Some("Hate"));
    }

    #[test]
    fn http_429_fails_closed() {
        let transport = StubHttpTransport::with_response(429, "rate limited");

        let outcome = AacsProvider.classify(&cfg(), "hello", &transport);

        assert!(outcome.unwrap_err().contains("HTTP 429"));
    }

    #[test]
    fn malformed_body_fails_closed() {
        let transport = StubHttpTransport::with_response(200, r#"{"unexpected":true}"#);

        let outcome = AacsProvider.classify(&cfg(), "hello", &transport);

        assert!(outcome.unwrap_err().contains("missing categoriesAnalysis"));
    }

    #[test]
    fn request_uses_aacs_header_and_url() {
        let transport = StubHttpTransport::with_response(200, HATE_CLEAN);

        let _ = AacsProvider.classify(&cfg(), "hello", &transport).unwrap();

        let sent = transport.last_request().unwrap();
        let key_header = sent
            .headers
            .get("Ocp-Apim-Subscription-Key")
            .map(String::as_str);
        assert_eq!(key_header, Some("test-key"));
        assert_eq!(
            sent.url,
            "https://example.cognitiveservices.azure.com/contentsafety/text:analyze?api-version=2024-09-01"
        );
    }

    #[test]
    fn request_preserves_existing_query_parameters() {
        let cfg = cfg_with_query("?logging=off");
        let transport = StubHttpTransport::with_response(200, HATE_CLEAN);

        let _ = AacsProvider.classify(&cfg, "hello", &transport).unwrap();

        let sent = transport.last_request().unwrap();
        assert_eq!(
            sent.url,
            "https://example.cognitiveservices.azure.com/contentsafety/text:analyze?logging=off&api-version=2024-09-01"
        );
    }

    #[test]
    fn request_does_not_duplicate_existing_api_version() {
        let cfg = cfg_with_query("?logging=off&api-version=2023-10-01");
        let transport = StubHttpTransport::with_response(200, HATE_CLEAN);

        let _ = AacsProvider.classify(&cfg, "hello", &transport).unwrap();

        let sent = transport.last_request().unwrap();
        assert_eq!(
            sent.url,
            "https://example.cognitiveservices.azure.com/contentsafety/text:analyze?logging=off&api-version=2023-10-01"
        );
    }
}