use std::fs::OpenOptions;
use std::io::{BufRead, BufReader};
use std::time::Instant;
use claudius::{
push_or_merge_message, Anthropic, ContentBlock, JsonSchema, MessageCreateParams, MessageParam,
MessageRole, Metadata, Model, SystemPrompt, TextBlock, ToolChoice,
};
use policyai::data::{EvaluationReport, Metrics, TestDataPoint};
use policyai::{ApplyError, Field, Manager, Policy, Report, Usage};
/// Baseline ("naive") policy application, used as the comparison point for policyai.
///
/// Builds one request from `template`:
/// - it sets `metadata.user_id` to `"baseline"`;
/// - it uses `prompts/manager_naive.md` as the system prompt;
/// - it appends, as user content, the JSON-serialized default value of the first policy's
///   type (only when `policies` is non-empty), then one `<rule>…</rule>` per policy, and
///   finally the input wrapped in `<text>…</text>`.
///
/// The model is forced (via `tool_choice`) to call a single `output_json` tool. That tool's
/// input schema is an object with one property per field declared by the policies' types,
/// and no property is marked required. The tool input is returned unchanged.
///
/// If `usage` is `Some`, it is reset with `Usage::new()` and then filled with the usage of
/// the single request, one iteration, and the wall-clock time measured around `client.send`.
///
/// # Errors
///
/// Returns an error if `client.send` fails.
///
/// # Panics
///
/// Panics (via `todo!()`) if the response does not consist of exactly one content block,
/// or if that block is not a `ToolUse` block.
pub async fn naive_apply(
    client: &Anthropic,
    policies: &[Policy],
    template: &MessageCreateParams,
    text: &str,
    usage: Option<&mut Usage>,
) -> Result<serde_json::Value, ApplyError> {
    let mut req = template.clone();
    req.metadata = Some(Metadata {
        user_id: Some("baseline".into()),
    });
    req.system = Some(SystemPrompt::from_blocks(vec![TextBlock {
        text: include_str!("../../prompts/manager_naive.md").to_string(),
        cache_control: None,
        citations: None,
    }]));
    // JSON-schema `properties` for the output tool, keyed by field name.
    let mut properties = serde_json::json! {{}};
    if !policies.is_empty() {
        // NOTE(review): only `policies[0]`'s type defaults are shown to the model. If the
        // policies use different types, the other types' defaults are never sent (unlike
        // `build_expected_with_defaults`, which merges all of them). Confirm this is intended.
        // NOTE(review): push_or_merge_message presumably merges consecutive same-role
        // messages into one; the order of the pushes below shapes the prompt, so keep it.
        push_or_merge_message(
            &mut req.messages,
            MessageParam::new_with_string(
                format!(
                    "<default_value>{}</default_value>",
                    serde_json::to_string(&policies[0].r#type.default_value()).unwrap()
                ),
                MessageRole::User,
            ),
        );
    }
    for policy in policies.iter() {
        let content = policy.prompt.clone();
        // Map each declared field to a JSON schema. A later policy that declares the same
        // field name overwrites the earlier schema entry.
        for field in policy.r#type.fields.iter() {
            match field {
                Field::Bool {
                    name,
                    default: _,
                    on_conflict: _,
                } => {
                    properties[name.clone()] = bool::json_schema();
                }
                Field::Number {
                    name,
                    default: _,
                    on_conflict: _,
                } => {
                    properties[name.clone()] = f64::json_schema();
                }
                Field::String {
                    name,
                    default: _,
                    on_conflict: _,
                } => {
                    properties[name.clone()] = String::json_schema();
                }
                Field::StringEnum {
                    name,
                    values,
                    default: _,
                    on_conflict: _,
                } => {
                    // Start from the plain string schema and restrict it with an `enum` list.
                    let mut schema = String::json_schema();
                    if let serde_json::Value::Object(object) = &mut schema {
                        object.insert("enum".to_string(), values.clone().into());
                    }
                    properties[name.clone()] = schema;
                }
                Field::StringArray { name } => {
                    properties[name.clone()] = Vec::<String>::json_schema();
                }
            }
        }
        push_or_merge_message(
            &mut req.messages,
            MessageParam {
                role: MessageRole::User,
                content: format!("<rule>{content}</rule>").into(),
            },
        );
    }
    // The text to classify always comes last, after every rule.
    push_or_merge_message(
        &mut req.messages,
        MessageParam::new_with_string(format!("<text>{text}</text>"), MessageRole::User),
    );
    // Top-level object schema for the tool input; no field is required.
    let mut schema = serde_json::json! {{}};
    schema["type"] = "object".into();
    schema["required"] = serde_json::Value::Array(vec![]);
    schema["properties"] = properties;
    // Force the model to answer through the `output_json` tool so the result is structured.
    req.tool_choice = Some(ToolChoice::tool("output_json"));
    req.tools = Some(vec![claudius::ToolUnionParam::CustomTool(
        claudius::ToolParam {
            name: "output_json".to_string(),
            description: Some("output JSON according to policy".to_string()),
            input_schema: schema,
            cache_control: None,
        },
    )]);
    let start_time = Instant::now();
    let resp = client.send(req).await?;
    if let Some(u) = usage {
        // Overwrite, do not accumulate: the caller's usage reflects this call only.
        *u = Usage::new();
        u.add_claudius_usage(resp.usage);
        u.increment_iterations();
        u.set_wall_clock_time(start_time.elapsed());
    }
    // NOTE(review): the two `todo!()`s below panic on unexpected model output (zero or
    // several content blocks, or a non-tool-use block). That aborts the whole evaluation run
    // instead of being recorded as a baseline error; consider returning an `ApplyError`.
    if resp.content.len() != 1 {
        todo!();
    }
    let ContentBlock::ToolUse(t) = &resp.content[0] else {
        todo!();
    };
    Ok(t.input.clone())
}
/// Decides whether an `actual` JSON value should be scored as matching `expected`.
///
/// Structurally equal values always match. Beyond that, only two top-level numbers get
/// special treatment: they match when their relative difference, measured against
/// `expected`, is at most `0.00001`. This lets `42` and `42.0` compare equal. Zero is
/// special-cased: it matches only another zero. Every other unequal pair, including arrays
/// or objects that contain nearly-equal numbers, does not match.
fn values_match(expected: &serde_json::Value, actual: &serde_json::Value) -> bool {
    if expected == actual {
        return true;
    }
    let (serde_json::Value::Number(lhs), serde_json::Value::Number(rhs)) = (expected, actual)
    else {
        return false;
    };
    // Widen a JSON number to f64 by trying each representation serde_json may hold.
    let widen = |n: &serde_json::Number| {
        n.as_f64()
            .or_else(|| n.as_i64().map(|i| i as f64))
            .or_else(|| n.as_u64().map(|u| u as f64))
    };
    let (Some(want), Some(got)) = (widen(lhs), widen(rhs)) else {
        return false;
    };
    match (want == 0.0, got == 0.0) {
        (true, true) => true,
        // A relative tolerance against zero is meaningless, so zero only matches zero.
        (true, false) | (false, true) => false,
        (false, false) => ((want - got) / want).abs() <= 0.00001,
    }
}
/// Returns a copy of a baseline output without the `__rule_numbers__` bookkeeping key.
///
/// Objects are copied with that key removed, if it is present. Any non-object value is
/// returned as an unchanged copy.
fn clean_baseline(baseline: &serde_json::Value) -> serde_json::Value {
    // Clone exactly once and strip in place, so non-object inputs are not copied twice.
    let mut cleaned = baseline.clone();
    if let Some(obj) = cleaned.as_object_mut() {
        obj.remove("__rule_numbers__");
    }
    cleaned
}
/// Builds the full expected output for a data point.
///
/// The result is seeded with the default values of every policy's type. When several
/// policies declare the same key, the first policy in `policies` wins. Any key present in
/// `expected` then overrides the defaults. An `expected` that is `None`, or is not a JSON
/// object, contributes nothing.
fn build_expected_with_defaults(
    policies: &[Policy],
    expected: Option<&serde_json::Value>,
) -> serde_json::Map<String, serde_json::Value> {
    let mut merged = serde_json::Map::new();
    for policy in policies {
        let default_value = policy.r#type.default_value();
        let Some(defaults) = default_value.as_object() else {
            continue;
        };
        for (key, value) in defaults {
            // Keep the earliest policy's default for a shared key.
            if !merged.contains_key(key) {
                merged.insert(key.clone(), value.clone());
            }
        }
    }
    // Explicit expectations always take precedence over defaults.
    if let Some(overrides) = expected.and_then(serde_json::Value::as_object) {
        merged.extend(overrides.iter().map(|(k, v)| (k.clone(), v.clone())));
    }
    merged
}
/// Scores `actual` against the `expected` field map.
///
/// Returns `(matched, wrong_value, missing, extra)`:
/// - `matched`: expected keys present in `actual` whose value satisfies `values_match`;
/// - `wrong_value`: expected keys present in `actual` with a non-matching value;
/// - `missing`: expected keys absent from `actual`. When `actual` is not a JSON object,
///   every expected key counts as missing;
/// - `extra`: keys in `actual` that are not expected, ignoring `__rule_numbers__`.
fn calculate_field_metrics(
    expected: &serde_json::Map<String, serde_json::Value>,
    actual: &serde_json::Value,
) -> (usize, usize, usize, usize) {
    let Some(actual_obj) = actual.as_object() else {
        // Nothing to compare against, and nothing can be "extra".
        return (0, 0, expected.len(), 0);
    };
    let (mut matched, mut wrong_value, mut missing) = (0, 0, 0);
    for (key, want) in expected {
        match actual_obj.get(key) {
            Some(got) if values_match(want, got) => matched += 1,
            Some(_) => wrong_value += 1,
            None => missing += 1,
        }
    }
    let extra = actual_obj
        .keys()
        .filter(|key| key.as_str() != "__rule_numbers__" && !expected.contains_key(key.as_str()))
        .count();
    (matched, wrong_value, missing, extra)
}
#[tokio::main]
async fn main() {
let client = Anthropic::new(None).unwrap();
for file in std::env::args().skip(1) {
let file = OpenOptions::new()
.read(true)
.open(file)
.expect("could not read input");
let file = BufReader::new(file);
for line in file.lines() {
let line = line.expect("could not read data");
let point: TestDataPoint = match serde_json::from_str(&line) {
Ok(point) => point,
Err(err) => {
eprintln!("error parsing policy {line}: {err}");
continue;
}
};
let mut manager = Manager::default();
for policy in point.policies.iter() {
manager.add(policy.clone());
}
let expected = build_expected_with_defaults(&point.policies, point.expected.as_ref());
let mut metrics = Metrics::default();
let mut baseline_usage = Some(Usage::new());
let start = Instant::now();
let baseline = match naive_apply(
&client,
&point.policies,
&MessageCreateParams {
max_tokens: 4096,
model: Model::Custom("claude-sonnet-4-5".to_string()),
..Default::default()
},
&point.text,
baseline_usage.as_mut(),
)
.await
{
Ok(baseline) => Some(baseline),
Err(err) => {
metrics.baseline_error = Some(format!("{err:?}"));
None
}
};
metrics.baseline_apply_duration_ms = start.elapsed().as_millis() as u32;
metrics.baseline_usage = baseline_usage;
if let Some(ref baseline_val) = baseline {
let cleaned_baseline = clean_baseline(baseline_val);
let (matched, wrong, missing, extra) =
calculate_field_metrics(&expected, &cleaned_baseline);
metrics.baseline_fields_matched = matched;
metrics.baseline_fields_with_wrong_value = wrong;
metrics.baseline_fields_missing = missing;
metrics.baseline_extra_fields = extra;
}
let mut policyai_usage = Some(Usage::new());
let start = Instant::now();
let report = match manager
.apply(
&client,
MessageCreateParams {
max_tokens: 4096,
model: Model::Custom("claude-sonnet-4-5".to_string()),
..Default::default()
},
&point.text,
policyai_usage.as_mut(),
)
.await
{
Ok(returned) => returned,
Err(err) => {
metrics.policyai_error = Some(format!("{err:?}"));
metrics.policyai_apply_duration_ms = start.elapsed().as_millis() as u32;
Report::default()
}
};
metrics.policyai_apply_duration_ms = start.elapsed().as_millis() as u32;
metrics.policyai_usage = policyai_usage;
let output = report.value().clone();
let (matched, wrong, missing, extra) = calculate_field_metrics(&expected, &output);
metrics.policyai_fields_matched = matched;
metrics.policyai_fields_with_wrong_value = wrong;
metrics.policyai_fields_missing = missing;
metrics.policyai_extra_fields = extra;
let report = EvaluationReport {
input: point,
metrics,
report,
output,
baseline,
};
println!("{}", serde_json::to_string(&report).unwrap());
}
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the scoring helpers (`values_match`, `clean_baseline`,
    // `calculate_field_metrics`, `build_expected_with_defaults`) and for the
    // serialization/derive behavior of the report types this binary emits.
    use super::*;

    #[test]
    fn evaluation_report_minimal() {
        let report = EvaluationReport {
            input: TestDataPoint {
                text: "test".to_string(),
                policies: vec![],
                expected: None,
                conflicts: None,
            },
            metrics: Metrics::default(),
            report: Report::default(),
            output: serde_json::Value::Null,
            baseline: None,
        };
        let serialized = serde_json::to_string(&report).unwrap();
        // NOTE(review): these are substring checks on the whole JSON text. For example,
        // "baseline" is also satisfied by the `baseline_*` metric names, so they do not by
        // themselves prove that the top-level keys exist.
        assert!(serialized.contains("input"));
        assert!(serialized.contains("metrics"));
        assert!(serialized.contains("output"));
        assert!(serialized.contains("baseline"));
    }

    #[test]
    fn metrics_default_all_zeros() {
        let metrics = Metrics::default();
        assert_eq!(metrics.policyai_fields_matched, 0);
        assert_eq!(metrics.policyai_fields_with_wrong_value, 0);
        assert_eq!(metrics.policyai_fields_missing, 0);
        assert_eq!(metrics.policyai_extra_fields, 0);
        assert_eq!(metrics.baseline_fields_matched, 0);
        assert_eq!(metrics.baseline_fields_with_wrong_value, 0);
        assert_eq!(metrics.baseline_fields_missing, 0);
        assert_eq!(metrics.baseline_extra_fields, 0);
        assert!(metrics.policyai_error.is_none());
        assert!(metrics.baseline_error.is_none());
        assert_eq!(metrics.policyai_apply_duration_ms, 0);
        assert_eq!(metrics.baseline_apply_duration_ms, 0);
        assert!(metrics.policyai_usage.is_none());
        assert!(metrics.baseline_usage.is_none());
    }

    #[test]
    fn metrics_with_values() {
        let metrics = Metrics {
            policyai_fields_matched: 3,
            policyai_fields_with_wrong_value: 1,
            policyai_fields_missing: 2,
            policyai_extra_fields: 1,
            baseline_fields_matched: 2,
            baseline_fields_with_wrong_value: 2,
            baseline_fields_missing: 3,
            baseline_extra_fields: 0,
            policyai_error: Some("error1".to_string()),
            baseline_error: Some("error2".to_string()),
            policyai_apply_duration_ms: 100,
            baseline_apply_duration_ms: 200,
            policyai_usage: None,
            baseline_usage: None,
        };
        assert_eq!(metrics.policyai_fields_matched, 3);
        assert_eq!(metrics.policyai_fields_with_wrong_value, 1);
        assert_eq!(metrics.policyai_fields_missing, 2);
        assert_eq!(metrics.policyai_extra_fields, 1);
        assert_eq!(metrics.baseline_fields_matched, 2);
        assert_eq!(metrics.baseline_fields_with_wrong_value, 2);
        assert_eq!(metrics.baseline_fields_missing, 3);
        assert_eq!(metrics.baseline_extra_fields, 0);
        assert_eq!(metrics.policyai_error, Some("error1".to_string()));
        assert_eq!(metrics.baseline_error, Some("error2".to_string()));
        assert_eq!(metrics.policyai_apply_duration_ms, 100);
        assert_eq!(metrics.baseline_apply_duration_ms, 200);
    }

    #[test]
    fn clean_baseline_removes_rule_numbers() {
        let baseline = serde_json::json!({
            "field1": "value1",
            "field2": 42,
            "__rule_numbers__": [1, 2, 3]
        });
        let cleaned = clean_baseline(&baseline);
        let cleaned_obj = cleaned.as_object().unwrap();
        assert!(cleaned_obj.contains_key("field1"));
        assert!(cleaned_obj.contains_key("field2"));
        assert!(!cleaned_obj.contains_key("__rule_numbers__"));
        assert_eq!(cleaned_obj.len(), 2);
    }

    #[test]
    fn clean_baseline_handles_missing_rule_numbers() {
        let baseline = serde_json::json!({
            "field1": "value1",
            "field2": 42
        });
        let cleaned = clean_baseline(&baseline);
        assert_eq!(cleaned, baseline);
    }

    #[test]
    fn clean_baseline_handles_non_object() {
        let baseline = serde_json::json!("not an object");
        let cleaned = clean_baseline(&baseline);
        assert_eq!(cleaned, baseline);
    }

    #[test]
    fn calculate_field_metrics_ignores_rule_numbers() {
        let expected = serde_json::json!({
            "field1": "value1",
            "field2": 42
        });
        let expected_map = expected.as_object().unwrap();
        let actual = serde_json::json!({
            "field1": "value1",
            "field2": 42,
            "__rule_numbers__": [1, 2]
        });
        let (matched, wrong, missing, extra) = calculate_field_metrics(expected_map, &actual);
        assert_eq!(matched, 2);
        assert_eq!(wrong, 0);
        assert_eq!(missing, 0);
        assert_eq!(extra, 0);
    }

    #[test]
    fn values_match_identical_numbers() {
        assert!(values_match(&serde_json::json!(42), &serde_json::json!(42)));
        assert!(values_match(
            &serde_json::json!(2.71),
            &serde_json::json!(2.71)
        ));
        assert!(values_match(&serde_json::json!(0), &serde_json::json!(0)));
        assert!(values_match(
            &serde_json::json!(0.0),
            &serde_json::json!(0.0)
        ));
    }

    #[test]
    fn values_match_zero_equivalence() {
        assert!(values_match(&serde_json::json!(0.0), &serde_json::json!(0)));
        assert!(values_match(&serde_json::json!(0), &serde_json::json!(0.0)));
        let zero_float = serde_json::Number::from_f64(0.0).unwrap();
        let zero_int = serde_json::json!(0);
        assert!(values_match(
            &serde_json::Value::Number(zero_float),
            &zero_int
        ));
    }

    #[test]
    fn values_match_zero_vs_nonzero() {
        // Zero has no relative scale, so even a tiny non-zero value must not match it.
        assert!(!values_match(
            &serde_json::json!(0),
            &serde_json::json!(0.000001)
        ));
        assert!(!values_match(
            &serde_json::json!(0.000000001),
            &serde_json::json!(0)
        ));
    }

    #[test]
    fn values_match_with_tolerance() {
        assert!(values_match(
            &serde_json::json!(1000.0),
            &serde_json::json!(1000.009)
        ));
        assert!(values_match(
            &serde_json::json!(1000.0),
            &serde_json::json!(999.991)
        ));
        assert!(!values_match(
            &serde_json::json!(1000.0),
            &serde_json::json!(1000.011)
        ));
        assert!(!values_match(
            &serde_json::json!(1000.0),
            &serde_json::json!(999.989)
        ));
    }

    #[test]
    fn values_match_negative_numbers() {
        assert!(values_match(
            &serde_json::json!(-1000.0),
            &serde_json::json!(-1000.009)
        ));
        assert!(!values_match(
            &serde_json::json!(-1000.0),
            &serde_json::json!(1000.0)
        ));
    }

    #[test]
    fn values_match_different_types() {
        assert!(!values_match(
            &serde_json::json!("42"),
            &serde_json::json!(42)
        ));
        assert!(!values_match(
            &serde_json::json!(true),
            &serde_json::json!(1)
        ));
        assert!(!values_match(
            &serde_json::json!(null),
            &serde_json::json!(0)
        ));
    }

    #[test]
    fn values_match_non_numeric_and_nested() {
        assert!(values_match(&serde_json::json!("a"), &serde_json::json!("a")));
        assert!(!values_match(&serde_json::json!("a"), &serde_json::json!("b")));
        assert!(values_match(
            &serde_json::json!(["x", "y"]),
            &serde_json::json!(["x", "y"])
        ));
        // The numeric tolerance applies only to top-level numbers, not to nested ones.
        assert!(!values_match(
            &serde_json::json!([1.0]),
            &serde_json::json!([1.000001])
        ));
    }

    #[test]
    fn calculate_field_metrics_all_match() {
        let expected = serde_json::json!({
            "field1": "value1",
            "field2": 42,
            "field3": true
        });
        let expected_map = expected.as_object().unwrap();
        let actual = serde_json::json!({
            "field1": "value1",
            "field2": 42,
            "field3": true
        });
        let (matched, wrong, missing, extra) = calculate_field_metrics(expected_map, &actual);
        assert_eq!(matched, 3);
        assert_eq!(wrong, 0);
        assert_eq!(missing, 0);
        assert_eq!(extra, 0);
    }

    #[test]
    fn calculate_field_metrics_numeric_tolerance() {
        let expected = serde_json::json!({
            "count": 1000.0,
            "zero_float": 0.0,
            "value": 42
        });
        let expected_map = expected.as_object().unwrap();
        let actual = serde_json::json!({
            "count": 1000.009,
            "zero_float": 0,
            "value": 42.0
        });
        let (matched, wrong, missing, extra) = calculate_field_metrics(expected_map, &actual);
        assert_eq!(matched, 3);
        assert_eq!(wrong, 0);
        assert_eq!(missing, 0);
        assert_eq!(extra, 0);
    }

    #[test]
    fn calculate_field_metrics_with_wrong_values() {
        let expected = serde_json::json!({
            "field1": "value1",
            "field2": 42,
            "field3": true
        });
        let expected_map = expected.as_object().unwrap();
        let actual = serde_json::json!({
            "field1": "different",
            "field2": 99,
            "field3": true
        });
        let (matched, wrong, missing, extra) = calculate_field_metrics(expected_map, &actual);
        assert_eq!(matched, 1);
        assert_eq!(wrong, 2);
        assert_eq!(missing, 0);
        assert_eq!(extra, 0);
    }

    #[test]
    fn calculate_field_metrics_with_missing_fields() {
        let expected = serde_json::json!({
            "field1": "value1",
            "field2": 42,
            "field3": true
        });
        let expected_map = expected.as_object().unwrap();
        let actual = serde_json::json!({
            "field1": "value1"
        });
        let (matched, wrong, missing, extra) = calculate_field_metrics(expected_map, &actual);
        assert_eq!(matched, 1);
        assert_eq!(wrong, 0);
        assert_eq!(missing, 2);
        assert_eq!(extra, 0);
    }

    #[test]
    fn calculate_field_metrics_with_extra_fields() {
        let expected = serde_json::json!({
            "field1": "value1"
        });
        let expected_map = expected.as_object().unwrap();
        let actual = serde_json::json!({
            "field1": "value1",
            "field2": 42,
            "field3": true
        });
        let (matched, wrong, missing, extra) = calculate_field_metrics(expected_map, &actual);
        assert_eq!(matched, 1);
        assert_eq!(wrong, 0);
        assert_eq!(missing, 0);
        assert_eq!(extra, 2);
    }

    #[test]
    fn calculate_field_metrics_empty_expected() {
        let expected = serde_json::json!({});
        let expected_map = expected.as_object().unwrap();
        let actual = serde_json::json!({
            "field1": "value1"
        });
        let (matched, wrong, missing, extra) = calculate_field_metrics(expected_map, &actual);
        assert_eq!(matched, 0);
        assert_eq!(wrong, 0);
        assert_eq!(missing, 0);
        assert_eq!(extra, 1);
    }

    #[test]
    fn calculate_field_metrics_empty_actual() {
        let expected = serde_json::json!({
            "field1": "value1"
        });
        let expected_map = expected.as_object().unwrap();
        let actual = serde_json::json!({});
        let (matched, wrong, missing, extra) = calculate_field_metrics(expected_map, &actual);
        assert_eq!(matched, 0);
        assert_eq!(wrong, 0);
        assert_eq!(missing, 1);
        assert_eq!(extra, 0);
    }

    #[test]
    fn calculate_field_metrics_both_empty() {
        let expected = serde_json::json!({});
        let expected_map = expected.as_object().unwrap();
        let actual = serde_json::json!({});
        let (matched, wrong, missing, extra) = calculate_field_metrics(expected_map, &actual);
        assert_eq!(matched, 0);
        assert_eq!(wrong, 0);
        assert_eq!(missing, 0);
        assert_eq!(extra, 0);
    }

    #[test]
    fn calculate_field_metrics_actual_not_object() {
        let expected = serde_json::json!({
            "field1": "value1"
        });
        let expected_map = expected.as_object().unwrap();
        let actual = serde_json::json!("not an object");
        let (matched, wrong, missing, extra) = calculate_field_metrics(expected_map, &actual);
        assert_eq!(matched, 0);
        assert_eq!(wrong, 0);
        assert_eq!(missing, 1);
        assert_eq!(extra, 0);
    }

    #[test]
    fn evaluation_report_serialization() {
        use policyai::{Field, Policy, PolicyType};
        let policy_type = PolicyType {
            name: "TestPolicy".to_string(),
            fields: vec![Field::Bool {
                name: "enabled".to_string(),
                default: Some(false),
                on_conflict: policyai::OnConflict::Default,
            }],
        };
        let report = EvaluationReport {
            input: TestDataPoint {
                text: "test text".to_string(),
                policies: vec![Policy {
                    r#type: policy_type,
                    prompt: "test".to_string(),
                    action: serde_json::json!({"enabled": true}),
                }],
                expected: Some(serde_json::json!({"enabled": true})),
                conflicts: None,
            },
            metrics: Metrics {
                policyai_fields_matched: 1,
                policyai_fields_with_wrong_value: 0,
                policyai_fields_missing: 0,
                policyai_extra_fields: 0,
                baseline_fields_matched: 1,
                baseline_fields_with_wrong_value: 0,
                baseline_fields_missing: 0,
                baseline_extra_fields: 0,
                policyai_error: None,
                baseline_error: None,
                policyai_apply_duration_ms: 50,
                baseline_apply_duration_ms: 100,
                policyai_usage: None,
                baseline_usage: None,
            },
            report: Report::default(),
            output: serde_json::json!({"enabled": true}),
            baseline: Some(serde_json::json!({"enabled": true})),
        };
        let serialized = serde_json::to_string(&report).unwrap();
        let deserialized: EvaluationReport = serde_json::from_str(&serialized).unwrap();
        assert_eq!(
            report.metrics.policyai_fields_matched,
            deserialized.metrics.policyai_fields_matched
        );
        assert_eq!(
            report.metrics.policyai_apply_duration_ms,
            deserialized.metrics.policyai_apply_duration_ms
        );
        assert_eq!(report.output, deserialized.output);
        assert_eq!(report.baseline, deserialized.baseline);
    }

    #[test]
    fn metrics_clone() {
        let original = Metrics {
            policyai_fields_matched: 5,
            policyai_fields_with_wrong_value: 2,
            policyai_fields_missing: 1,
            policyai_extra_fields: 3,
            baseline_fields_matched: 4,
            baseline_fields_with_wrong_value: 1,
            baseline_fields_missing: 2,
            baseline_extra_fields: 1,
            policyai_error: Some("error".to_string()),
            baseline_error: None,
            policyai_apply_duration_ms: 150,
            baseline_apply_duration_ms: 250,
            policyai_usage: None,
            baseline_usage: None,
        };
        let cloned = original.clone();
        assert_eq!(
            original.policyai_fields_matched,
            cloned.policyai_fields_matched
        );
        assert_eq!(
            original.policyai_fields_with_wrong_value,
            cloned.policyai_fields_with_wrong_value
        );
        assert_eq!(
            original.policyai_fields_missing,
            cloned.policyai_fields_missing
        );
        assert_eq!(original.policyai_extra_fields, cloned.policyai_extra_fields);
        assert_eq!(original.policyai_error, cloned.policyai_error);
        assert_eq!(
            original.policyai_apply_duration_ms,
            cloned.policyai_apply_duration_ms
        );
        assert_eq!(
            original.baseline_fields_matched,
            cloned.baseline_fields_matched
        );
        assert_eq!(
            original.baseline_fields_with_wrong_value,
            cloned.baseline_fields_with_wrong_value
        );
        assert_eq!(
            original.baseline_fields_missing,
            cloned.baseline_fields_missing
        );
        assert_eq!(original.baseline_extra_fields, cloned.baseline_extra_fields);
        assert_eq!(original.baseline_error, cloned.baseline_error);
        assert_eq!(
            original.baseline_apply_duration_ms,
            cloned.baseline_apply_duration_ms
        );
        // Check the clone, not the original, so the assertion actually exercises `Clone`.
        assert!(cloned.policyai_usage.is_none());
        assert!(cloned.baseline_usage.is_none());
    }

    #[test]
    fn metrics_debug() {
        let metrics = Metrics {
            policyai_fields_matched: 1,
            policyai_fields_with_wrong_value: 2,
            policyai_fields_missing: 3,
            policyai_extra_fields: 4,
            baseline_fields_matched: 5,
            baseline_fields_with_wrong_value: 6,
            baseline_fields_missing: 7,
            baseline_extra_fields: 8,
            policyai_error: None,
            baseline_error: None,
            policyai_apply_duration_ms: 100,
            baseline_apply_duration_ms: 200,
            policyai_usage: None,
            baseline_usage: None,
        };
        let debug_str = format!("{metrics:?}");
        assert!(debug_str.contains("Metrics"));
        assert!(debug_str.contains("policyai_fields_matched"));
        assert!(debug_str.contains("policyai_apply_duration_ms"));
    }

    #[test]
    fn build_expected_with_defaults_no_expected() {
        use policyai::{Field, PolicyType};
        let policy_type = PolicyType {
            name: "TestPolicy".to_string(),
            fields: vec![
                Field::Bool {
                    name: "enabled".to_string(),
                    default: Some(true),
                    on_conflict: policyai::OnConflict::Default,
                },
                Field::String {
                    name: "message".to_string(),
                    default: Some("hello".to_string()),
                    on_conflict: policyai::OnConflict::Agreement,
                },
            ],
        };
        let policies = vec![Policy {
            r#type: policy_type,
            prompt: "test".to_string(),
            action: serde_json::json!({}),
        }];
        let result = build_expected_with_defaults(&policies, None);
        assert_eq!(result.len(), 2);
        assert_eq!(result.get("enabled"), Some(&serde_json::json!(true)));
        assert_eq!(result.get("message"), Some(&serde_json::json!("hello")));
    }

    #[test]
    fn build_expected_with_defaults_merges_expected() {
        use policyai::{Field, PolicyType};
        let policy_type = PolicyType {
            name: "TestPolicy".to_string(),
            fields: vec![
                Field::Bool {
                    name: "enabled".to_string(),
                    default: Some(true),
                    on_conflict: policyai::OnConflict::Default,
                },
                Field::String {
                    name: "message".to_string(),
                    default: Some("hello".to_string()),
                    on_conflict: policyai::OnConflict::Agreement,
                },
                Field::Number {
                    name: "count".to_string(),
                    default: Some(policyai::t64(0.0)),
                    on_conflict: policyai::OnConflict::LargestValue,
                },
            ],
        };
        let policies = vec![Policy {
            r#type: policy_type,
            prompt: "test".to_string(),
            action: serde_json::json!({}),
        }];
        let expected = serde_json::json!({
            "message": "goodbye",
            "count": 42
        });
        let result = build_expected_with_defaults(&policies, Some(&expected));
        assert_eq!(result.len(), 3);
        assert_eq!(result.get("enabled"), Some(&serde_json::json!(true)));
        assert_eq!(result.get("message"), Some(&serde_json::json!("goodbye")));
        assert_eq!(result.get("count"), Some(&serde_json::json!(42)));
    }

    #[test]
    fn build_expected_with_defaults_handles_null_defaults() {
        use policyai::{Field, PolicyType};
        let policy_type = PolicyType {
            name: "TestPolicy".to_string(),
            fields: vec![
                Field::String {
                    name: "optional".to_string(),
                    default: None,
                    on_conflict: policyai::OnConflict::Agreement,
                },
                Field::Bool {
                    name: "required".to_string(),
                    default: Some(false),
                    on_conflict: policyai::OnConflict::Default,
                },
            ],
        };
        let policies = vec![Policy {
            r#type: policy_type,
            prompt: "test".to_string(),
            action: serde_json::json!({}),
        }];
        let result = build_expected_with_defaults(&policies, None);
        assert_eq!(result.len(), 1);
        assert!(!result.contains_key("optional"));
        assert_eq!(result.get("required"), Some(&serde_json::json!(false)));
    }

    #[test]
    fn build_expected_with_defaults_string_array() {
        use policyai::{Field, PolicyType};
        let policy_type = PolicyType {
            name: "TestPolicy".to_string(),
            fields: vec![Field::StringArray {
                name: "tags".to_string(),
            }],
        };
        let policies = vec![Policy {
            r#type: policy_type,
            prompt: "test".to_string(),
            action: serde_json::json!({}),
        }];
        let result = build_expected_with_defaults(&policies, None);
        assert_eq!(result.len(), 0);
        assert_eq!(result.get("tags"), None);
    }

    #[test]
    fn build_expected_with_defaults_ignores_non_object_expected() {
        let expected = serde_json::json!("not an object");
        let result = build_expected_with_defaults(&[], Some(&expected));
        assert!(result.is_empty());
    }

    #[test]
    fn build_expected_with_defaults_multiple_policies() {
        use policyai::{Field, PolicyType};
        let policy_type1 = PolicyType {
            name: "Policy1".to_string(),
            fields: vec![Field::Bool {
                name: "field1".to_string(),
                default: Some(true),
                on_conflict: policyai::OnConflict::Default,
            }],
        };
        let policy_type2 = PolicyType {
            name: "Policy2".to_string(),
            fields: vec![
                Field::Bool {
                    name: "field1".to_string(),
                    default: Some(false),
                    on_conflict: policyai::OnConflict::Default,
                },
                Field::String {
                    name: "field2".to_string(),
                    default: Some("test".to_string()),
                    on_conflict: policyai::OnConflict::Agreement,
                },
            ],
        };
        let policies = vec![
            Policy {
                r#type: policy_type1,
                prompt: "test1".to_string(),
                action: serde_json::json!({}),
            },
            Policy {
                r#type: policy_type2,
                prompt: "test2".to_string(),
                action: serde_json::json!({}),
            },
        ];
        let result = build_expected_with_defaults(&policies, None);
        assert_eq!(result.len(), 2);
        assert_eq!(result.get("field1"), Some(&serde_json::json!(true)));
        assert_eq!(result.get("field2"), Some(&serde_json::json!("test")));
    }
}