use crate::schemas::get_schema_info;
use matchy_data_format::DataValue;
use matchy_format::EntryValidator;
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use thiserror::Error;
/// Error returned when a record fails schema validation.
///
/// Carries every finding rather than only the first one. `Display` prints a
/// lone finding on a single line; any other count is printed as a header line
/// followed by one numbered line per finding.
#[derive(Debug, Clone)]
pub struct SchemaValidationError {
    /// The individual findings, in the order they were reported.
    pub errors: Vec<ValidationErrorDetail>,
}

impl fmt::Display for SchemaValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exactly one finding: compact form, no trailing newline.
        if let [only] = self.errors.as_slice() {
            return write!(f, "Schema validation failed: {}", only);
        }

        // Zero or several findings: header plus a 1-based numbered list.
        writeln!(
            f,
            "Schema validation failed with {} errors:",
            self.errors.len()
        )?;
        self.errors
            .iter()
            .zip(1usize..)
            .try_for_each(|(detail, ordinal)| writeln!(f, " {}. {}", ordinal, detail))
    }
}

impl std::error::Error for SchemaValidationError {}
/// One finding produced while validating a record.
#[derive(Debug, Clone)]
pub struct ValidationErrorDetail {
    /// Slash-prefixed location of the offending value (for example
    /// `/confidence`); empty or `/` when the finding concerns the whole record.
    pub path: String,
    /// Human-readable description of what is wrong.
    pub message: String,
}

impl fmt::Display for ValidationErrorDetail {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.path.as_str() {
            // Record-level findings carry no useful location, so print the message alone.
            "" | "/" => f.write_str(&self.message),
            location => write!(f, "{}: {}", location, self.message),
        }
    }
}
/// Errors raised when a schema cannot be selected for validation.
#[derive(Debug, Error)]
pub enum SchemaError {
    /// The requested database type matched no known schema.
    ///
    /// Field `0` is rendered as the unknown type name and field `1` as the
    /// text following "Known types with validation:" in the message.
    #[error("Unknown database type: '{0}'. Known types with validation: {1}")]
    UnknownSchema(String, String),
}
/// Values accepted for the required `threat_level` field; compared exactly
/// (lower-case as written here).
const VALID_THREAT_LEVELS: &[&str] = &["critical", "high", "medium", "low", "unknown"];
/// Values accepted for the optional `tlp` field; compared exactly (upper-case
/// as written here).
// NOTE(review): these look like Traffic Light Protocol labels - confirm.
const VALID_TLP: &[&str] = &["CLEAR", "GREEN", "AMBER", "AMBER+STRICT", "RED"];
/// Validates record data against a named schema.
pub struct SchemaValidator {
    /// Short name of the schema this validator was created for.
    schema_name: String,
}

impl fmt::Debug for SchemaValidator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Rendered as non-exhaustive (trailing `..`) even though only one
        // field exists today.
        let mut view = f.debug_struct("SchemaValidator");
        view.field("schema_name", &self.schema_name);
        view.finish_non_exhaustive()
    }
}
impl SchemaValidator {
    /// Creates a validator for `database_type`.
    ///
    /// The name is used as-is when `get_schema_info` recognises it; otherwise
    /// it is mapped to a schema name through
    /// `crate::schemas::detect_schema_from_database_type`.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaError::UnknownSchema`] carrying the requested name and a
    /// comma-separated list of the available schemas when neither lookup
    /// succeeds.
    pub fn new(database_type: &str) -> Result<Self, SchemaError> {
        // Directly known schema name.
        if get_schema_info(database_type).is_some() {
            return Ok(Self {
                schema_name: database_type.to_string(),
            });
        }

        // Otherwise try to map the database type onto a schema name.
        if let Some(short_name) = crate::schemas::detect_schema_from_database_type(database_type)
        {
            return Ok(Self {
                schema_name: short_name.to_string(),
            });
        }

        let known = crate::schemas::available_schemas()
            .collect::<Vec<_>>()
            .join(", ");
        Err(SchemaError::UnknownSchema(database_type.to_string(), known))
    }

    /// Returns the schema name resolved at construction time.
    #[must_use]
    pub fn schema_name(&self) -> &str {
        self.schema_name.as_str()
    }

    /// Returns the `database_type` recorded for this schema, or `None` when
    /// `get_schema_info` has no entry for the stored name.
    #[must_use]
    pub fn database_type(&self) -> Option<&'static str> {
        let info = get_schema_info(&self.schema_name)?;
        Some(info.database_type)
    }

    /// Validates `data`, collecting every finding into one error.
    ///
    /// # Errors
    ///
    /// Returns [`SchemaValidationError`] when at least one finding is reported.
    pub fn validate(&self, data: &HashMap<String, DataValue>) -> Result<(), SchemaValidationError> {
        match self.validate_detailed(data) {
            errors if errors.is_empty() => Ok(()),
            errors => Err(SchemaValidationError { errors }),
        }
    }

    /// Validates `data` and returns the findings directly; an empty vector
    /// means the record is valid.
    // NOTE(review): the stored schema name is not consulted here; every
    // validator applies the ThreatDB rules.
    #[must_use]
    pub fn validate_detailed(
        &self,
        data: &HashMap<String, DataValue>,
    ) -> Vec<ValidationErrorDetail> {
        validate_threatdb(data)
    }
}
fn validate_threatdb(data: &HashMap<String, DataValue>) -> Vec<ValidationErrorDetail> {
let mut errors = Vec::new();
match data.get("threat_level") {
None => errors.push(ValidationErrorDetail {
path: String::new(),
message: "\"threat_level\" is a required property".to_string(),
}),
Some(DataValue::String(s)) => {
if !VALID_THREAT_LEVELS.contains(&s.as_str()) {
errors.push(ValidationErrorDetail {
path: "/threat_level".to_string(),
message: format!(
"\"{s}\" is not one of [\"critical\", \"high\", \"medium\", \"low\", \"unknown\"]"
),
});
}
}
Some(_) => {
errors.push(ValidationErrorDetail {
path: "/threat_level".to_string(),
message: "expected string type".to_string(),
});
}
}
match data.get("category") {
None => errors.push(ValidationErrorDetail {
path: String::new(),
message: "\"category\" is a required property".to_string(),
}),
Some(DataValue::String(s)) => {
if s.is_empty() {
errors.push(ValidationErrorDetail {
path: "/category".to_string(),
message: "string length 0 is less than minLength 1".to_string(),
});
}
}
Some(_) => {
errors.push(ValidationErrorDetail {
path: "/category".to_string(),
message: "expected string type".to_string(),
});
}
}
match data.get("source") {
None => errors.push(ValidationErrorDetail {
path: String::new(),
message: "\"source\" is a required property".to_string(),
}),
Some(DataValue::String(s)) => {
if s.is_empty() {
errors.push(ValidationErrorDetail {
path: "/source".to_string(),
message: "string length 0 is less than minLength 1".to_string(),
});
}
}
Some(_) => {
errors.push(ValidationErrorDetail {
path: "/source".to_string(),
message: "expected string type".to_string(),
});
}
}
if let Some(v) = data.get("confidence") {
match v {
DataValue::Uint32(n) => {
if *n > 100 {
errors.push(ValidationErrorDetail {
path: "/confidence".to_string(),
message: format!("{n} is greater than the maximum of 100"),
});
}
}
DataValue::Int32(n) => {
if *n < 0 {
errors.push(ValidationErrorDetail {
path: "/confidence".to_string(),
message: format!("{n} is less than the minimum of 0"),
});
} else if *n > 100 {
errors.push(ValidationErrorDetail {
path: "/confidence".to_string(),
message: format!("{n} is greater than the maximum of 100"),
});
}
}
DataValue::Uint64(n) => {
if *n > 100 {
errors.push(ValidationErrorDetail {
path: "/confidence".to_string(),
message: format!("{n} is greater than the maximum of 100"),
});
}
}
DataValue::Uint16(n) => {
if *n > 100 {
errors.push(ValidationErrorDetail {
path: "/confidence".to_string(),
message: format!("{n} is greater than the maximum of 100"),
});
}
}
_ => {
errors.push(ValidationErrorDetail {
path: "/confidence".to_string(),
message: "expected integer type".to_string(),
});
}
}
}
if let Some(v) = data.get("tlp") {
match v {
DataValue::String(s) => {
if !VALID_TLP.contains(&s.as_str()) {
errors.push(ValidationErrorDetail {
path: "/tlp".to_string(),
message: format!(
"\"{s}\" is not one of [\"CLEAR\", \"GREEN\", \"AMBER\", \"AMBER+STRICT\", \"RED\"]"
),
});
}
}
_ => {
errors.push(ValidationErrorDetail {
path: "/tlp".to_string(),
message: "expected string type".to_string(),
});
}
}
}
if let Some(v) = data.get("tags") {
match v {
DataValue::Array(arr) => {
for (i, item) in arr.iter().enumerate() {
if !matches!(item, DataValue::String(_)) {
errors.push(ValidationErrorDetail {
path: format!("/tags/{i}"),
message: "expected string type".to_string(),
});
}
}
}
_ => {
errors.push(ValidationErrorDetail {
path: "/tags".to_string(),
message: "expected array type".to_string(),
});
}
}
}
if let Some(v) = data.get("indicator_type") {
match v {
DataValue::String(s) => {
if s.is_empty() {
errors.push(ValidationErrorDetail {
path: "/indicator_type".to_string(),
message: "string length 0 is less than minLength 1".to_string(),
});
}
}
_ => {
errors.push(ValidationErrorDetail {
path: "/indicator_type".to_string(),
message: "expected string type".to_string(),
});
}
}
}
for field in &["description", "reference"] {
if let Some(v) = data.get(*field) {
if !matches!(v, DataValue::String(_)) {
errors.push(ValidationErrorDetail {
path: format!("/{field}"),
message: "expected string type".to_string(),
});
}
}
}
for field in &["first_seen", "last_seen"] {
if let Some(v) = data.get(*field) {
if !matches!(v, DataValue::String(_) | DataValue::Timestamp(_)) {
errors.push(ValidationErrorDetail {
path: format!("/{field}"),
message: "expected string or timestamp type".to_string(),
});
}
}
}
errors
}
impl EntryValidator for SchemaValidator {
    /// Validates the data of one entry, tagging any failure with the entry key.
    ///
    /// On failure the original error is flattened to text and re-packaged as a
    /// `SchemaValidationError` holding a single detail with an empty path,
    /// whose message is `Entry '<key>': ` followed by the original error's
    /// display text.
    fn validate(
        &self,
        key: &str,
        data: &HashMap<String, DataValue>,
    ) -> Result<(), Box<dyn Error + Send + Sync>> {
        // Resolves to the inherent two-argument `validate`, not this trait method.
        if let Err(failure) = self.validate(data) {
            let tagged = SchemaValidationError {
                errors: vec![ValidationErrorDetail {
                    path: String::new(),
                    message: format!("Entry '{key}': {failure}"),
                }],
            };
            return Err(Box::new(tagged));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a string slice in the string variant of `DataValue`.
    fn text(value: &str) -> DataValue {
        DataValue::String(value.to_string())
    }

    /// Validator for the `threatdb` schema, which most tests exercise.
    fn threatdb_validator() -> SchemaValidator {
        SchemaValidator::new("threatdb").unwrap()
    }

    /// Minimal record carrying only the three required fields.
    fn valid_threatdb_data() -> HashMap<String, DataValue> {
        [
            ("threat_level", "high"),
            ("category", "malware"),
            ("source", "abuse.ch"),
        ]
        .iter()
        .map(|&(field, value)| (field.to_string(), text(value)))
        .collect()
    }

    /// The minimal valid record with one field added or overwritten.
    fn valid_with(field: &str, value: DataValue) -> HashMap<String, DataValue> {
        let mut record = valid_threatdb_data();
        record.insert(field.to_string(), value);
        record
    }

    #[test]
    fn test_validator_creation() {
        let validator = SchemaValidator::new("threatdb").expect("should create validator");
        assert_eq!(validator.schema_name(), "threatdb");
        assert_eq!(validator.database_type(), Some("ThreatDB-v1"));
    }

    #[test]
    fn test_validator_creation_from_canonical_name() {
        let validator = SchemaValidator::new("ThreatDB-v1").expect("should create validator");
        assert_eq!(validator.schema_name(), "threatdb");
        assert_eq!(validator.database_type(), Some("ThreatDB-v1"));
    }

    #[test]
    fn test_unknown_schema() {
        let outcome = SchemaValidator::new("nonexistent");
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            SchemaError::UnknownSchema(_, _)
        ));
    }

    #[test]
    fn test_valid_threatdb_record() {
        let record = valid_threatdb_data();
        assert!(threatdb_validator().validate(&record).is_ok());
    }

    #[test]
    fn test_valid_threatdb_with_optional_fields() {
        let mut record = valid_threatdb_data();
        record.insert("confidence".to_string(), DataValue::Uint32(85));
        record.insert("description".to_string(), text("Known malware C2"));
        record.insert("tlp".to_string(), text("AMBER"));
        assert!(threatdb_validator().validate(&record).is_ok());
    }

    #[test]
    fn test_missing_required_field() {
        // Only one of the three required fields is present.
        let record: HashMap<String, DataValue> =
            std::iter::once(("threat_level".to_string(), text("high"))).collect();
        let outcome = threatdb_validator().validate(&record);
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(!err.errors.is_empty());
        let error_text = format!("{err}");
        assert!(
            error_text.contains("category") || error_text.contains("source"),
            "Error should mention missing field: {error_text}"
        );
    }

    #[test]
    fn test_invalid_enum_value() {
        let record = valid_with("threat_level", text("super-critical"));
        let outcome = threatdb_validator().validate(&record);
        assert!(outcome.is_err());
        let error_text = format!("{}", outcome.unwrap_err());
        assert!(
            error_text.contains("threat_level"),
            "Error should mention invalid enum: {error_text}"
        );
    }

    #[test]
    fn test_invalid_confidence_range() {
        let record = valid_with("confidence", DataValue::Uint32(150));
        let outcome = threatdb_validator().validate(&record);
        assert!(outcome.is_err());
        let error_text = format!("{}", outcome.unwrap_err());
        assert!(
            error_text.contains("confidence") || error_text.contains("maximum"),
            "Error should mention confidence range: {error_text}"
        );
    }

    #[test]
    fn test_invalid_tlp_value() {
        let record = valid_with("tlp", text("purple"));
        assert!(threatdb_validator().validate(&record).is_err());
    }

    #[test]
    fn test_wrong_type_for_field() {
        let record = valid_with("confidence", text("high"));
        assert!(threatdb_validator().validate(&record).is_err());
    }

    #[test]
    fn test_additional_properties_allowed() {
        let record = valid_with("custom_field", text("custom value"));
        assert!(threatdb_validator().validate(&record).is_ok());
    }

    #[test]
    fn test_tags_array() {
        let tags = DataValue::Array(vec![text("emotet"), text("banking-trojan")]);
        let record = valid_with("tags", tags);
        assert!(threatdb_validator().validate(&record).is_ok());
    }

    #[test]
    fn test_validate_detailed() {
        let record = HashMap::new();
        let findings = threatdb_validator().validate_detailed(&record);
        assert!(!findings.is_empty());
    }

    #[test]
    fn test_error_display() {
        let first = ValidationErrorDetail {
            path: "/threat_level".to_string(),
            message: "value must be one of: critical, high, medium, low, unknown".to_string(),
        };
        let second = ValidationErrorDetail {
            path: "/confidence".to_string(),
            message: "value must be <= 100".to_string(),
        };
        let err = SchemaValidationError {
            errors: vec![first, second],
        };
        let display = format!("{err}");
        for fragment in ["2 errors", "threat_level", "confidence"] {
            assert!(display.contains(fragment));
        }
    }

    #[test]
    fn test_timestamp_fields_accept_string_and_timestamp() {
        let validator = threatdb_validator();

        let data_with_string = valid_with("first_seen", text("2025-10-02T18:44:31Z"));
        assert!(validator.validate(&data_with_string).is_ok());

        let data_with_timestamp = valid_with("first_seen", DataValue::Timestamp(1727891071));
        assert!(validator.validate(&data_with_timestamp).is_ok());

        let data_with_wrong_type = valid_with("first_seen", DataValue::Uint32(12345));
        assert!(validator.validate(&data_with_wrong_type).is_err());
    }
}