agent_control_specification_core 0.3.1-beta.0

Stateless Rust core for Agent Control Specification
Documentation
use crate::dispatchers::{
    bundled::{self, HttpTransport, ResolvedClassifierConfig, UreqHttpTransport},
    constants::*,
    http, resolve,
};
use crate::{AnnotatorDispatcher, AnnotatorInvocation, JsonValue, RuntimeError};
use serde_json::json;

#[derive(Debug, Default, Clone, Copy)]
pub struct ClassifierAnnotator;

impl ClassifierAnnotator {
    pub fn new() -> Self {
        Self
    }

    pub fn dispatch_with_transport(
        &self,
        annotator_name: &str,
        annotator: &AnnotatorInvocation,
        preliminary_policy_input: &JsonValue,
        transport: &dyn HttpTransport,
    ) -> Result<JsonValue, RuntimeError> {
        if annotator.field(ANNOTATOR_TYPE).and_then(JsonValue::as_str) != Some(TYPE_CLASSIFIER) {
            return Err(resolve::failed(
                annotator_name,
                "classifier dispatcher received a non-classifier annotator",
            ));
        }
        let policy_target =
            resolve::policy_target_text(annotator_name, annotator, preliminary_policy_input)?;
        if http::optional_string_field(&annotator.fields, FIELD_PROVIDER).is_none() {
            return Err(resolve::failed(
                annotator_name,
                "custom transport requires a bundled classifier provider",
            ));
        }
        dispatch_bundled_with_transport(annotator_name, annotator, &policy_target, transport)
    }
}

impl AnnotatorDispatcher for ClassifierAnnotator {
    fn dispatch(
        &self,
        annotator_name: &str,
        annotator: &AnnotatorInvocation,
        preliminary_policy_input: &JsonValue,
    ) -> Result<JsonValue, RuntimeError> {
        if annotator.field(ANNOTATOR_TYPE).and_then(JsonValue::as_str) != Some(TYPE_CLASSIFIER) {
            return Err(resolve::failed(
                annotator_name,
                "classifier dispatcher received a non-classifier annotator",
            ));
        }
        let policy_target =
            resolve::policy_target_text(annotator_name, annotator, preliminary_policy_input)?;
        if http::optional_string_field(&annotator.fields, FIELD_PROVIDER).is_some() {
            return dispatch_bundled_with_transport(
                annotator_name,
                annotator,
                &policy_target,
                &UreqHttpTransport,
            );
        }
        dispatch_generic(annotator_name, annotator, policy_target)
    }
}

fn dispatch_generic(
    annotator_name: &str,
    annotator: &AnnotatorInvocation,
    policy_target: String,
) -> Result<JsonValue, RuntimeError> {
    let url = http::required_string_field(annotator_name, &annotator.fields, FIELD_URL)?;
    let input_field = http::optional_string_field(&annotator.fields, FIELD_INPUT_FIELD)
        .unwrap_or(DEFAULT_INPUT_FIELD);
    let api_key = http::env_api_key(annotator_name, &annotator.fields)?;
    let timeout_ms = http::timeout_ms(annotator_name, &annotator.fields)?;
    let response = http::post_json(
        annotator_name,
        url,
        json!({ input_field: policy_target }),
        api_key,
        timeout_ms,
    )?;
    match http::optional_string_field(&annotator.fields, FIELD_RESPONSE_FIELD) {
        Some(field) => response.get(field).cloned().ok_or_else(|| {
            resolve::failed(annotator_name, format!("response missing field '{field}'"))
        }),
        None => Ok(response),
    }
}

fn dispatch_bundled_with_transport(
    annotator_name: &str,
    annotator: &AnnotatorInvocation,
    policy_target: &str,
    transport: &dyn HttpTransport,
) -> Result<JsonValue, RuntimeError> {
    let cfg = ResolvedClassifierConfig::from_fields(&annotator.fields)
        .map_err(|error| resolve::failed(annotator_name, error))?;
    bundled::classify(&cfg, policy_target, transport)
        .map(|verdict| verdict.to_json())
        .map_err(|error| resolve::failed(annotator_name, error))
}

#[cfg(all(test, feature = "aacs"))]
mod tests {
    use super::*;
    use crate::dispatchers::bundled::StubHttpTransport;
    use serde_json::json;
    use std::collections::BTreeMap;

    #[test]
    fn provider_field_routes_to_bundled_classifier() {
        std::env::set_var("ACS_AACS_TEST_KEY", "test-key");
        let annotator = AnnotatorInvocation {
            fields: BTreeMap::from([
                (ANNOTATOR_TYPE.to_string(), json!(TYPE_CLASSIFIER)),
                (FIELD_PROVIDER.to_string(), json!("aacs")),
                (FIELD_ENDPOINT.to_string(), json!("https://example.test")),
                (FIELD_API_KEY_ENV.to_string(), json!("ACS_AACS_TEST_KEY")),
            ]),
        };
        let transport = StubHttpTransport::with_response(
            200,
            r#"{"categoriesAnalysis":[{"category":"Hate","severity":0},{"category":"SelfHarm","severity":0},{"category":"Sexual","severity":0},{"category":"Violence","severity":0}]}"#,
        );

        let output = dispatch_bundled_with_transport("classifier", &annotator, "hello", &transport)
            .expect("provider succeeds");

        assert_eq!(output["verdict"], json!("allow"));
        assert_eq!(output["flagged"], json!(false));
    }
}