Skip to main content

adk_guardrail/
schema.rs

1use crate::{Guardrail, GuardrailError, GuardrailResult, Severity};
2use adk_core::{Content, Part};
3use async_trait::async_trait;
4use jsonschema::Validator;
5use serde_json::Value;
6
7/// JSON Schema validator guardrail for enforcing output structure
8pub struct SchemaValidator {
9    name: String,
10    validator: Validator,
11    severity: Severity,
12}
13
14impl SchemaValidator {
15    /// Create a new schema validator from a JSON Schema value
16    pub fn new(schema: &Value) -> Result<Self, GuardrailError> {
17        let validator = Validator::new(schema)
18            .map_err(|e| GuardrailError::Schema(format!("Invalid schema: {}", e)))?;
19
20        Ok(Self { name: "schema_validator".to_string(), validator, severity: Severity::High })
21    }
22
23    /// Create with a custom name
24    pub fn with_name(mut self, name: impl Into<String>) -> Self {
25        self.name = name.into();
26        self
27    }
28
29    /// Set severity level
30    pub fn with_severity(mut self, severity: Severity) -> Self {
31        self.severity = severity;
32        self
33    }
34
35    fn extract_json(&self, content: &Content) -> Option<Value> {
36        for part in &content.parts {
37            if let Part::Text { text } = part {
38                // Try to parse as JSON directly
39                if let Ok(json) = serde_json::from_str(text) {
40                    return Some(json);
41                }
42                // Try to extract JSON from markdown code block
43                if let Some(json_str) = Self::extract_json_from_markdown(text) {
44                    if let Ok(json) = serde_json::from_str(&json_str) {
45                        return Some(json);
46                    }
47                }
48            }
49        }
50        None
51    }
52
53    fn extract_json_from_markdown(text: &str) -> Option<String> {
54        // Look for ```json ... ``` blocks
55        let start_markers = ["```json\n", "```json\r\n", "```\n", "```\r\n"];
56        let end_marker = "```";
57
58        for start in start_markers {
59            if let Some(start_idx) = text.find(start) {
60                let content_start = start_idx + start.len();
61                if let Some(end_idx) = text[content_start..].find(end_marker) {
62                    return Some(text[content_start..content_start + end_idx].trim().to_string());
63                }
64            }
65        }
66        None
67    }
68}
69
70#[async_trait]
71impl Guardrail for SchemaValidator {
72    fn name(&self) -> &str {
73        &self.name
74    }
75
76    async fn validate(&self, content: &Content) -> GuardrailResult {
77        let json = match self.extract_json(content) {
78            Some(j) => j,
79            None => {
80                return GuardrailResult::Fail {
81                    reason: "Content does not contain valid JSON".to_string(),
82                    severity: self.severity,
83                };
84            }
85        };
86
87        let result = self.validator.validate(&json);
88        if let Err(error) = result {
89            return GuardrailResult::Fail {
90                reason: format!("Schema validation failed: {}", error),
91                severity: self.severity,
92            };
93        }
94
95        GuardrailResult::Pass
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Schema shared by most tests: an object with a required string `name` and an
    /// optional non-negative integer `age`.
    fn test_schema() -> Value {
        json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer", "minimum": 0 }
            },
            "required": ["name"]
        })
    }

    #[tokio::test]
    async fn test_valid_json() {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        let content = Content::new("model").with_text(r#"{"name": "Alice", "age": 30}"#);
        let result = validator.validate(&content).await;
        assert!(result.is_pass());
    }

    #[tokio::test]
    async fn test_invalid_json_missing_required() {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        let content = Content::new("model").with_text(r#"{"age": 30}"#);
        let result = validator.validate(&content).await;
        assert!(result.is_fail());
    }

    #[tokio::test]
    async fn test_invalid_json_wrong_type() {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        let content = Content::new("model").with_text(r#"{"name": "Alice", "age": "thirty"}"#);
        let result = validator.validate(&content).await;
        assert!(result.is_fail());
    }

    #[tokio::test]
    async fn test_negative_age_violates_minimum() {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        let content = Content::new("model").with_text(r#"{"name": "Alice", "age": -1}"#);
        let result = validator.validate(&content).await;
        assert!(result.is_fail());
    }

    #[tokio::test]
    async fn test_json_array_fails_object_schema() {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        let content = Content::new("model").with_text("[1, 2, 3]");
        let result = validator.validate(&content).await;
        assert!(result.is_fail());
    }

    #[tokio::test]
    async fn test_json_in_markdown() {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        let content = Content::new("model")
            .with_text("Here is the result:\n```json\n{\"name\": \"Bob\"}\n```");
        let result = validator.validate(&content).await;
        assert!(result.is_pass());
    }

    #[tokio::test]
    async fn test_json_in_markdown_with_crlf() {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        let content =
            Content::new("model").with_text("Result:\r\n```json\r\n{\"name\": \"Carol\"}\r\n```");
        let result = validator.validate(&content).await;
        assert!(result.is_pass());
    }

    #[tokio::test]
    async fn test_no_json() {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        let content = Content::new("model").with_text("This is just plain text");
        let result = validator.validate(&content).await;
        assert!(result.is_fail());
    }

    #[test]
    fn test_invalid_schema_is_rejected() {
        let result = SchemaValidator::new(&json!({ "type": 42 }));
        assert!(matches!(result, Err(GuardrailError::Schema(_))));
    }

    #[test]
    fn test_name_defaults_and_can_be_overridden() {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        assert_eq!(validator.name(), "schema_validator");
        let validator = validator.with_name("user_schema");
        assert_eq!(validator.name(), "user_schema");
    }
}