use serde_json::Value;
use std::collections::HashMap;
use crate::state::{State, StateSchema};
use super::error::FunctionalError;
/// The JSON value kinds a state field can be required to have.
///
/// Each variant corresponds to one variant of `serde_json::Value`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpectedType {
    /// JSON `null`.
    Null,
    /// JSON `true` / `false`.
    Boolean,
    /// Any JSON number.
    Number,
    /// JSON string.
    String,
    /// JSON array.
    Array,
    /// JSON object.
    Object,
}
impl ExpectedType {
pub fn matches(&self, value: &Value) -> bool {
match self {
ExpectedType::Null => value.is_null(),
ExpectedType::Boolean => value.is_boolean(),
ExpectedType::Number => value.is_number(),
ExpectedType::String => value.is_string(),
ExpectedType::Array => value.is_array(),
ExpectedType::Object => value.is_object(),
}
}
pub fn type_name(&self) -> &'static str {
match self {
ExpectedType::Null => "null",
ExpectedType::Boolean => "boolean",
ExpectedType::Number => "number",
ExpectedType::String => "string",
ExpectedType::Array => "array",
ExpectedType::Object => "object",
}
}
}
impl std::fmt::Display for ExpectedType {
    /// Writes the lowercase type name, i.e. the same text as [`ExpectedType::type_name`].
    ///
    /// The name is emitted through [`std::fmt::Formatter::pad`] so that width,
    /// fill/alignment and precision requested by the caller (for example
    /// `{:>8}`) are honored. With a plain `{}` the output is just the name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.type_name())
    }
}
/// Returns the lowercase JSON type name of `value`.
///
/// The names are the same strings that `ExpectedType::type_name` produces, so
/// the "expected" and "actual" parts of a validation error are comparable.
fn value_type_name(value: &Value) -> &'static str {
    let name = match value {
        Value::Object(_) => "object",
        Value::Array(_) => "array",
        Value::String(_) => "string",
        Value::Number(_) => "number",
        Value::Bool(_) => "boolean",
        Value::Null => "null",
    };
    name
}
/// Validates workflow state and task output against per-field rules.
///
/// Rules are added with the consuming builder methods `expect_type` and
/// `require_field`; state updates are forwarded to the wrapped schema.
#[derive(Clone)]
pub struct StateSchemaValidator {
    /// Schema that `apply_update` delegates to.
    schema: StateSchema,
    /// JSON type each named field must have when it is present.
    type_expectations: HashMap<String, ExpectedType>,
    /// Names of fields that `validate_state` requires to be present.
    required_fields: Vec<String>,
}
impl std::fmt::Debug for StateSchemaValidator {
    /// Formats the validation rules; the wrapped `schema` is left out and the
    /// output is marked as non-exhaustive to signal the omission.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("StateSchemaValidator");
        out.field("type_expectations", &self.type_expectations);
        out.field("required_fields", &self.required_fields);
        out.finish_non_exhaustive()
    }
}
impl StateSchemaValidator {
    /// Creates a validator around `schema` with no type expectations and no
    /// required fields.
    pub fn new(schema: StateSchema) -> Self {
        Self {
            schema,
            type_expectations: HashMap::new(),
            required_fields: Vec::new(),
        }
    }

    /// Declares that `field`, whenever it is present, must hold a value of
    /// type `expected`.
    ///
    /// Declaring the same field again replaces the earlier expectation.
    pub fn expect_type(mut self, field: &str, expected: ExpectedType) -> Self {
        self.type_expectations.insert(field.to_string(), expected);
        self
    }

    /// Declares that `field` must be present for [`Self::validate_state`] to
    /// succeed. Requiring the same field more than once has no extra effect.
    pub fn require_field(mut self, field: &str) -> Self {
        // Compare as `&str` so the duplicate check does not allocate a `String`.
        let already_required = self
            .required_fields
            .iter()
            .any(|existing| existing.as_str() == field);
        if !already_required {
            self.required_fields.push(field.to_string());
        }
        self
    }

    /// Returns the schema this validator wraps.
    pub fn schema(&self) -> &StateSchema {
        &self.schema
    }

    /// Checks `state` against the required fields and the type expectations.
    ///
    /// Required fields are checked first, in the order they were declared.
    /// Type expectations apply only to fields that are present in `state`.
    ///
    /// # Errors
    ///
    /// Returns [`FunctionalError::SchemaValidation`] with
    /// `expected: "present"` / `actual: "missing"` for the first absent
    /// required field. Otherwise returns it for a present field whose value
    /// has the wrong type; when several fields mismatch, the one with the
    /// lexicographically smallest name is reported.
    pub fn validate_state(&self, state: &State) -> Result<(), FunctionalError> {
        for field in &self.required_fields {
            if !state.contains_key(field) {
                return Err(FunctionalError::SchemaValidation {
                    field: field.clone(),
                    expected: "present".to_string(),
                    actual: "missing".to_string(),
                });
            }
        }

        // `type_expectations` is a `HashMap`, whose iteration order is
        // arbitrary. Picking the smallest offending field name makes the
        // reported error reproducible from run to run.
        let first_mismatch = self
            .type_expectations
            .iter()
            .filter_map(|(field, expected_type)| {
                let value = state.get(field)?;
                if expected_type.matches(value) {
                    None
                } else {
                    Some((field, *expected_type, value))
                }
            })
            .min_by_key(|(field, _, _)| *field);

        match first_mismatch {
            Some((field, expected_type, value)) => {
                Err(Self::type_mismatch(field, expected_type, value))
            }
            None => Ok(()),
        }
    }

    /// Checks the fields of a task's `output` against the type expectations.
    ///
    /// Fields without a declared expectation are accepted as-is, and required
    /// fields are not enforced here.
    ///
    /// # Errors
    ///
    /// Returns [`FunctionalError::SchemaValidation`] for the first field, in
    /// `output`'s own iteration order, whose value has the wrong type.
    pub fn validate_task_output(&self, output: &State) -> Result<(), FunctionalError> {
        for (field, value) in output {
            if let Some(expected_type) = self.type_expectations.get(field) {
                if !expected_type.matches(value) {
                    return Err(Self::type_mismatch(field, *expected_type, value));
                }
            }
        }
        Ok(())
    }

    /// Applies `value` to `key` in `state` by delegating to the wrapped schema.
    pub fn apply_update(&self, state: &mut State, key: &str, value: Value) {
        self.schema.apply_update(state, key, value);
    }

    /// Builds the error describing that `field` holds `value` instead of a
    /// value of type `expected`.
    fn type_mismatch(field: &str, expected: ExpectedType, value: &Value) -> FunctionalError {
        FunctionalError::SchemaValidation {
            field: field.to_string(),
            expected: expected.type_name().to_string(),
            actual: value_type_name(value).to_string(),
        }
    }
}
impl Default for StateSchemaValidator {
    /// Builds a validator around `StateSchema::default()` with no rules declared.
    fn default() -> Self {
        let schema = StateSchema::default();
        Self::new(schema)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Asserts that `err` is a `SchemaValidation` error carrying exactly the
    /// given field name and expected/actual descriptions.
    ///
    /// `#[track_caller]` makes a failure point at the calling test rather
    /// than at this helper.
    #[track_caller]
    fn assert_schema_error(err: FunctionalError, field: &str, expected: &str, actual: &str) {
        match err {
            FunctionalError::SchemaValidation {
                field: got_field,
                expected: got_expected,
                actual: got_actual,
            } => {
                assert_eq!(got_field, field);
                assert_eq!(got_expected, expected);
                assert_eq!(got_actual, actual);
            }
            _ => panic!("unexpected error variant"),
        }
    }

    #[test]
    fn test_validate_state_passes_with_correct_types() {
        let schema = StateSchema::builder()
            .channel("name")
            .counter_channel("count")
            .build();
        let validator = StateSchemaValidator::new(schema)
            .expect_type("name", ExpectedType::String)
            .expect_type("count", ExpectedType::Number)
            .require_field("name");
        let mut state = State::new();
        state.insert("name".to_string(), json!("workflow_1"));
        state.insert("count".to_string(), json!(0));
        assert!(validator.validate_state(&state).is_ok());
    }

    #[test]
    fn test_validate_state_fails_on_missing_required_field() {
        let validator =
            StateSchemaValidator::new(StateSchema::default()).require_field("required_field");
        let state = State::new();
        let err = validator.validate_state(&state).unwrap_err();
        assert_schema_error(err, "required_field", "present", "missing");
    }

    #[test]
    fn test_validate_state_fails_on_type_mismatch() {
        let validator = StateSchemaValidator::new(StateSchema::default())
            .expect_type("count", ExpectedType::Number);
        let mut state = State::new();
        state.insert("count".to_string(), json!("not_a_number"));
        let err = validator.validate_state(&state).unwrap_err();
        assert_schema_error(err, "count", "number", "string");
    }

    #[test]
    fn test_validate_task_output_passes_with_correct_types() {
        let validator = StateSchemaValidator::new(StateSchema::default())
            .expect_type("result", ExpectedType::Object)
            .expect_type("score", ExpectedType::Number);
        let mut output = State::new();
        output.insert("result".to_string(), json!({"key": "value"}));
        output.insert("score".to_string(), json!(95));
        assert!(validator.validate_task_output(&output).is_ok());
    }

    #[test]
    fn test_validate_task_output_fails_on_type_mismatch() {
        let validator = StateSchemaValidator::new(StateSchema::default())
            .expect_type("items", ExpectedType::Array);
        let mut output = State::new();
        output.insert("items".to_string(), json!("not_an_array"));
        let err = validator.validate_task_output(&output).unwrap_err();
        assert_schema_error(err, "items", "array", "string");
    }

    #[test]
    fn test_validate_state_skips_absent_optional_fields() {
        let validator = StateSchemaValidator::new(StateSchema::default())
            .expect_type("optional_field", ExpectedType::String);
        let state = State::new();
        assert!(validator.validate_state(&state).is_ok());
    }

    #[test]
    fn test_validate_task_output_ignores_unknown_fields() {
        let validator = StateSchemaValidator::new(StateSchema::default())
            .expect_type("known", ExpectedType::Number);
        let mut output = State::new();
        output.insert("known".to_string(), json!(42));
        output.insert("unknown".to_string(), json!("anything"));
        assert!(validator.validate_task_output(&output).is_ok());
    }

    #[test]
    fn test_validate_task_output_does_not_enforce_required_fields() {
        let validator =
            StateSchemaValidator::new(StateSchema::default()).require_field("required_field");
        let output = State::new();
        assert!(validator.validate_task_output(&output).is_ok());
    }

    #[test]
    fn test_require_field_ignores_duplicates() {
        let validator = StateSchemaValidator::new(StateSchema::default())
            .require_field("a")
            .require_field("a")
            .require_field("b");
        assert_eq!(validator.required_fields, ["a", "b"]);
    }

    #[test]
    fn test_expect_type_last_declaration_wins() {
        let validator = StateSchemaValidator::new(StateSchema::default())
            .expect_type("value", ExpectedType::String)
            .expect_type("value", ExpectedType::Number);
        let mut state = State::new();
        state.insert("value".to_string(), json!(1));
        assert!(validator.validate_state(&state).is_ok());
    }

    #[test]
    fn test_expected_type_matches() {
        assert!(ExpectedType::Null.matches(&json!(null)));
        assert!(ExpectedType::Boolean.matches(&json!(true)));
        assert!(ExpectedType::Number.matches(&json!(42)));
        // A non-integer number; deliberately not a literal resembling a
        // well-known constant such as pi.
        assert!(ExpectedType::Number.matches(&json!(2.5)));
        assert!(ExpectedType::String.matches(&json!("hello")));
        assert!(ExpectedType::Array.matches(&json!([1, 2, 3])));
        assert!(ExpectedType::Object.matches(&json!({"key": "value"})));
        assert!(!ExpectedType::Number.matches(&json!("42")));
        assert!(!ExpectedType::String.matches(&json!(42)));
        assert!(!ExpectedType::Array.matches(&json!({})));
    }

    #[test]
    fn test_type_names_agree_across_helpers() {
        let samples = [
            (ExpectedType::Null, json!(null)),
            (ExpectedType::Boolean, json!(false)),
            (ExpectedType::Number, json!(1)),
            (ExpectedType::String, json!("s")),
            (ExpectedType::Array, json!([])),
            (ExpectedType::Object, json!({})),
        ];
        for (expected, value) in &samples {
            assert!(expected.matches(value));
            assert_eq!(value_type_name(value), expected.type_name());
            assert_eq!(expected.to_string(), expected.type_name());
        }
    }
}