use super::super::*;
use super::super::load::load_rules;
use super::super::error::LoaderError;
use std::fs;
use tempfile::tempdir;
/// A single well-formed rule file loads into a registry holding exactly one
/// `add` op (arity 2) and exactly one witness (seed 42).
#[test]
fn test_load_valid_toml() {
    let tmp = tempdir().unwrap();
    let rule_toml = r#"
version = "1"
[[ops]]
name = "add"
arity = 2
inputs = ["i32", "i32"]
output = "i32"
declared_laws = ["commutative"]
reference_impl_id = "cpu_add"
[[witnesses]]
op = "add"
seed = 42
count = 100
distribution = "uniform"
"#;
    fs::write(tmp.path().join("rule.toml"), rule_toml).unwrap();

    let registry = load_rules(tmp.path()).unwrap();

    // Ops: exactly one entry, keyed by name.
    assert_eq!(registry.ops.len(), 1);
    let add = &registry.ops["add"];
    assert_eq!(add.arity, 2);

    // Witnesses: exactly one entry, fields carried through from the file.
    assert_eq!(registry.witnesses.len(), 1);
    let witness = &registry.witnesses[0];
    assert_eq!(witness.seed, 42);
}
/// A rule file declaring `version = "2"` must be rejected with
/// `LoaderError::UnsupportedVersion` carrying the offending version string.
#[test]
fn test_unsupported_version() {
    let dir = tempdir().unwrap();
    let file_path = dir.path().join("rule.toml");
    let toml_content = r#"
version = "2"
"#;
    fs::write(&file_path, toml_content).unwrap();
    let err = load_rules(dir.path()).unwrap_err();
    // Match on a borrow: a by-value binding would move the payload out of
    // `err`, making it unusable in the failure message below.
    assert!(
        matches!(&err, LoaderError::UnsupportedVersion(v) if *v == "2"),
        "expected UnsupportedVersion(\"2\"), got: {err}"
    );
}
/// Syntactically broken TOML (an `[[ops]` header missing its closing `]`)
/// must surface as `LoaderError::Parse`.
#[test]
fn test_malformed_toml() {
    let dir = tempdir().unwrap();
    let file_path = dir.path().join("rule.toml");
    let toml_content = r#"
version = "1"
[[ops] # Missing bracket
name = "add"
"#;
    fs::write(&file_path, toml_content).unwrap();
    let err = load_rules(dir.path()).unwrap_err();
    // Report the actual error on failure, consistent with the other tests.
    assert!(
        matches!(err, LoaderError::Parse(_)),
        "expected Parse error for malformed TOML, got: {err}"
    );
}
/// Two files in the same directory that both define op `add` still load
/// successfully. The registry ends up with a single `add` whose
/// `reference_impl_id` comes from one of the two files.
///
/// NOTE(review): this deliberately does not pin *which* file wins;
/// presumably that depends on the loader's file ordering — confirm.
#[test]
fn test_conflicting_override() {
    let dir = tempdir().unwrap();

    // Both rule files are identical apart from the reference impl id.
    let rule_with_impl = |impl_id: &str| {
        format!(
            r#"
version = "1"
[[ops]]
name = "add"
arity = 2
inputs = ["i32", "i32"]
output = "i32"
declared_laws = []
reference_impl_id = "{impl_id}"
"#
        )
    };

    for &(file_name, impl_id) in &[("rule1.toml", "cpu_add_1"), ("rule2.toml", "cpu_add_2")] {
        fs::write(dir.path().join(file_name), rule_with_impl(impl_id)).unwrap();
    }

    let registry = load_rules(dir.path()).unwrap();
    assert_eq!(registry.ops.len(), 1);

    let winner = &registry.ops["add"].reference_impl_id;
    assert!(winner.starts_with("cpu_add_"));
}
/// A rule file with no top-level `version` key is rejected as
/// `LoaderError::Parse`.
#[test]
fn test_missing_version_field() {
    let dir = tempdir().unwrap();
    let file_path = dir.path().join("rule.toml");
    let toml_content = r#"
[[ops]]
name = "add"
arity = 2
inputs = ["i32", "i32"]
output = "i32"
declared_laws = []
reference_impl_id = "cpu_add"
"#;
    fs::write(&file_path, toml_content).unwrap();
    let err = load_rules(dir.path()).unwrap_err();
    // Report the actual error on failure, consistent with the other tests.
    assert!(
        matches!(err, LoaderError::Parse(_)),
        "expected Parse error for missing version, got: {err}"
    );
}
/// A rule file padded past 1 MiB (1_048_576 bytes) is rejected with
/// `LoaderError::TomlTooLarge`, whose reported size exceeds that limit.
///
/// The padding is a bare run of `x` characters, which is not valid TOML.
/// The test therefore also requires the size error, not a parse error, to
/// be the one returned.
#[test]
fn test_toml_too_large() {
    let scratch = tempdir().unwrap();
    let padding = "x".repeat(1_048_600);
    let oversized = format!(
        r#"
version = "1"
[[ops]]
name = "add"
arity = 2
inputs = ["i32", "i32"]
output = "i32"
declared_laws = []
reference_impl_id = "cpu_add"
{padding}
"#
    );
    fs::write(scratch.path().join("rule.toml"), &oversized).unwrap();

    let err = load_rules(scratch.path()).unwrap_err();
    assert!(
        matches!(err, LoaderError::TomlTooLarge { bytes, .. } if bytes > 1_048_576),
        "expected TomlTooLarge, got: {err}"
    );
}
/// A witness requesting `count = 1000001` is rejected with
/// `LoaderError::UnreasonableCount` reporting exactly that count.
#[test]
fn test_unreasonable_witness_count() {
    let scratch = tempdir().unwrap();
    let oversized_witness = r#"
version = "1"
[[witnesses]]
op = "add"
seed = 42
count = 1000001
distribution = "uniform"
"#;
    fs::write(scratch.path().join("rule.toml"), oversized_witness).unwrap();

    let outcome = load_rules(scratch.path());
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, LoaderError::UnreasonableCount { count: 1_000_001 }),
        "expected UnreasonableCount, got: {err}"
    );
}
/// An unrecognised top-level key is rejected as `LoaderError::Parse`.
/// The rendered message must contain both the phrase "unknown field" and
/// the offending key name.
#[test]
fn test_deny_unknown_fields() {
    let workdir = tempdir().unwrap();
    let rule_path = workdir.path().join("rule.toml");
    fs::write(
        &rule_path,
        r#"
version = "1"
unknown_field = true
"#,
    )
    .unwrap();

    let err = load_rules(workdir.path()).unwrap_err();
    assert!(
        matches!(err, LoaderError::Parse(_)),
        "expected Parse error for unknown field, got: {err}"
    );

    let err_str = err.to_string();
    let names_the_field = ["unknown field", "unknown_field"]
        .iter()
        .all(|needle| err_str.contains(needle));
    assert!(names_the_field, "error should mention unknown field: {err_str}");
}