1use crate::{Guardrail, GuardrailError, GuardrailResult, Severity};
2use adk_core::{Content, Part};
3use async_trait::async_trait;
4use jsonschema::Validator;
5use serde_json::Value;
6
7pub struct SchemaValidator {
9 name: String,
10 validator: Validator,
11 severity: Severity,
12}
13
14impl SchemaValidator {
15 pub fn new(schema: &Value) -> Result<Self, GuardrailError> {
17 let validator = Validator::new(schema)
18 .map_err(|e| GuardrailError::Schema(format!("Invalid schema: {}", e)))?;
19
20 Ok(Self { name: "schema_validator".to_string(), validator, severity: Severity::High })
21 }
22
23 pub fn with_name(mut self, name: impl Into<String>) -> Self {
25 self.name = name.into();
26 self
27 }
28
29 pub fn with_severity(mut self, severity: Severity) -> Self {
31 self.severity = severity;
32 self
33 }
34
35 fn extract_json(&self, content: &Content) -> Option<Value> {
36 for part in &content.parts {
37 if let Part::Text { text } = part {
38 if let Ok(json) = serde_json::from_str(text) {
40 return Some(json);
41 }
42 if let Some(json_str) = Self::extract_json_from_markdown(text) {
44 if let Ok(json) = serde_json::from_str(&json_str) {
45 return Some(json);
46 }
47 }
48 }
49 }
50 None
51 }
52
53 fn extract_json_from_markdown(text: &str) -> Option<String> {
54 let start_markers = ["```json\n", "```json\r\n", "```\n", "```\r\n"];
56 let end_marker = "```";
57
58 for start in start_markers {
59 if let Some(start_idx) = text.find(start) {
60 let content_start = start_idx + start.len();
61 if let Some(end_idx) = text[content_start..].find(end_marker) {
62 return Some(text[content_start..content_start + end_idx].trim().to_string());
63 }
64 }
65 }
66 None
67 }
68}
69
70#[async_trait]
71impl Guardrail for SchemaValidator {
72 fn name(&self) -> &str {
73 &self.name
74 }
75
76 async fn validate(&self, content: &Content) -> GuardrailResult {
77 let json = match self.extract_json(content) {
78 Some(j) => j,
79 None => {
80 return GuardrailResult::Fail {
81 reason: "Content does not contain valid JSON".to_string(),
82 severity: self.severity,
83 };
84 }
85 };
86
87 let result = self.validator.validate(&json);
88 if let Err(error) = result {
89 return GuardrailResult::Fail {
90 reason: format!("Schema validation failed: {}", error),
91 severity: self.severity,
92 };
93 }
94
95 GuardrailResult::Pass
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn test_schema() -> Value {
        json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer", "minimum": 0 }
            },
            "required": ["name"]
        })
    }

    /// Runs a default-configured validator over a single model text part.
    async fn check(text: &str) -> GuardrailResult {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        let content = Content::new("model").with_text(text);
        validator.validate(&content).await
    }

    #[tokio::test]
    async fn test_valid_json() {
        assert!(check(r#"{"name": "Alice", "age": 30}"#).await.is_pass());
    }

    #[tokio::test]
    async fn test_invalid_json_missing_required() {
        assert!(check(r#"{"age": 30}"#).await.is_fail());
    }

    #[tokio::test]
    async fn test_invalid_json_wrong_type() {
        assert!(check(r#"{"name": "Alice", "age": "thirty"}"#).await.is_fail());
    }

    #[tokio::test]
    async fn test_json_in_markdown() {
        let result = check("Here is the result:\n```json\n{\"name\": \"Bob\"}\n```").await;
        assert!(result.is_pass());
    }

    #[tokio::test]
    async fn test_json_in_untagged_markdown_fence() {
        let result = check("Result:\n```\n{\"name\": \"Bob\"}\n```").await;
        assert!(result.is_pass());
    }

    #[tokio::test]
    async fn test_json_in_crlf_markdown_fence() {
        let result = check("Result:\r\n```json\r\n{\"name\": \"Bob\"}\r\n```").await;
        assert!(result.is_pass());
    }

    #[tokio::test]
    async fn test_markdown_json_violating_schema() {
        let result = check("```json\n{\"name\": \"Bob\", \"age\": -1}\n```").await;
        assert!(result.is_fail());
    }

    #[tokio::test]
    async fn test_no_json() {
        assert!(check("This is just plain text").await.is_fail());
    }

    #[tokio::test]
    async fn test_failure_uses_default_severity() {
        let result = check("This is just plain text").await;
        assert!(matches!(result, GuardrailResult::Fail { severity: Severity::High, .. }));
    }

    #[test]
    fn test_invalid_schema_is_rejected() {
        assert!(SchemaValidator::new(&json!({ "type": 12 })).is_err());
    }

    #[test]
    fn test_default_and_custom_name() {
        let validator = SchemaValidator::new(&test_schema()).unwrap();
        assert_eq!(validator.name(), "schema_validator");

        let renamed = validator.with_name("order_schema");
        assert_eq!(renamed.name(), "order_schema");
    }
}