use std::sync::OnceLock;
use crate::error::{CognisError, Result};
// Re-export of the `regex` crate under a hidden, stable path.
// NOTE(review): presumably exists so generated code (e.g. a derive macro) can name
// `Regex` without the downstream crate depending on `regex` directly — confirm.
#[doc(hidden)]
pub mod __regex {
    pub use ::regex::*;
}
/// Types that can check their own field constraints after construction.
///
/// The provided [`ValidateArgs::validate`] accepts every value, so an empty
/// `impl ValidateArgs for T {}` is a valid opt-in. Override it to combine the
/// `check_*` helpers in this module (e.g. [`check_range`], [`check_enum`]).
pub trait ValidateArgs {
    /// Verifies `self`, returning an error describing a violated constraint.
    ///
    /// The default implementation imposes no constraints and always returns `Ok(())`.
    fn validate(&self) -> Result<()> {
        Ok(())
    }
}
/// Checks that `value` lies in the inclusive range `[min, max]`.
///
/// Either bound may be `None`, meaning "unbounded on that side". A `NaN` value is
/// always rejected. A `NaN` *bound* never causes a rejection, because every comparison
/// against `NaN` is false.
///
/// # Errors
///
/// Returns [`CognisError::ToolValidationError`] naming `field` when `value` is `NaN`,
/// below `min`, or above `max` — checked in that order, first failure wins.
pub fn check_range(field: &str, value: f64, min: Option<f64>, max: Option<f64>) -> Result<()> {
    // Every message shares the "field `<name>`: " prefix.
    let reject = |detail: String| -> Result<()> {
        Err(CognisError::ToolValidationError(format!(
            "field `{field}`: {detail}"
        )))
    };

    if value.is_nan() {
        return reject(String::from("value is NaN"));
    }
    if let Some(lo) = min.filter(|&lo| value < lo) {
        return reject(format!("{value} is less than minimum {lo}"));
    }
    match max {
        Some(hi) if value > hi => reject(format!("{value} is greater than maximum {hi}")),
        _ => Ok(()),
    }
}
/// Checks that a length `len` lies in the inclusive range `[min, max]`.
///
/// The caller decides what "length" means (bytes, chars, items) and passes the count.
/// Either bound may be `None`. When both are violated — only possible if `min > max` —
/// the minimum violation is reported.
///
/// # Errors
///
/// Returns [`CognisError::ToolValidationError`] naming `field` when `len` is below
/// `min` or above `max`.
pub fn check_length(field: &str, len: usize, min: Option<usize>, max: Option<usize>) -> Result<()> {
    let violation = match (min, max) {
        (Some(lo), _) if len < lo => Some(format!("length {len} is less than minimum {lo}")),
        (_, Some(hi)) if len > hi => Some(format!("length {len} is greater than maximum {hi}")),
        _ => None,
    };
    match violation {
        Some(detail) => Err(CognisError::ToolValidationError(format!(
            "field `{field}`: {detail}"
        ))),
        None => Ok(()),
    }
}
/// Checks that `value` is exactly one of the strings in `allowed` (case-sensitive).
///
/// Accepts any slice of string-like items (`&[&str]`, `&[String]`, ...).
///
/// # Errors
///
/// Returns [`CognisError::ToolValidationError`] naming `field` and listing every
/// allowed value, each wrapped in backticks, when `value` is not listed.
pub fn check_enum<S: AsRef<str>>(field: &str, value: &str, allowed: &[S]) -> Result<()> {
    let listed = allowed.iter().any(|option| option.as_ref() == value);
    if listed {
        return Ok(());
    }

    // Render the allowed values as "`a`, `b`, `c`".
    let mut rendered = String::new();
    for (idx, option) in allowed.iter().enumerate() {
        if idx > 0 {
            rendered.push_str(", ");
        }
        rendered.push('`');
        rendered.push_str(option.as_ref());
        rendered.push('`');
    }

    Err(CognisError::ToolValidationError(format!(
        "field `{field}`: \"{value}\" must be one of [{rendered}]"
    )))
}
pub fn check_pattern(field: &str, value: &str, re: ®ex::Regex) -> Result<()> {
if re.is_match(value) {
Ok(())
} else {
Err(CognisError::ToolValidationError(format!(
"field `{field}`: \"{value}\" does not match pattern `{pat}`",
pat = re.as_str(),
)))
}
}
/// Built-in string formats understood by [`check_format`].
///
/// The canonical names returned by [`Format::as_str`] use the spellings of the JSON
/// Schema `format` keyword (`"email"`, `"date-time"`, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Format {
    /// An email address (basic shape check only).
    Email,
    /// A URI; `check_format` accepts any value for this format.
    Uri,
    /// A hyphenated, hexadecimal UUID.
    Uuid,
    /// A date-time; `check_format` accepts any value for this format.
    DateTime,
    /// A dotted-quad IPv4 address.
    Ipv4,
    /// An IPv6 address.
    Ipv6,
}

impl Format {
    /// Returns the canonical lowercase name of this format.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Email => "email",
            Self::Uri => "uri",
            Self::Uuid => "uuid",
            Self::DateTime => "date-time",
            Self::Ipv4 => "ipv4",
            Self::Ipv6 => "ipv6",
        }
    }
}
/// Returns the shared email regex, compiling it once on first use.
///
/// Deliberately permissive. It requires:
/// - exactly one `@`;
/// - no whitespace;
/// - a non-empty local part;
/// - a part after the `@` that contains a `.` with at least one character on each side.
fn email_regex() -> &'static regex::Regex {
    static EMAIL: OnceLock<regex::Regex> = OnceLock::new();
    EMAIL.get_or_init(|| {
        let pattern = r"^[^\s@]+@[^\s@]+\.[^\s@]+$";
        regex::Regex::new(pattern).expect("valid email regex")
    })
}
/// Returns the shared UUID regex, compiling it once on first use.
///
/// Matches the case-insensitive 8-4-4-4-12 hexadecimal form. Version and variant
/// bits are not checked.
fn uuid_regex() -> &'static regex::Regex {
    static UUID: OnceLock<regex::Regex> = OnceLock::new();
    UUID.get_or_init(|| {
        let pattern = r"(?i)^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$";
        regex::Regex::new(pattern).expect("valid uuid regex")
    })
}
/// Checks `value` against one of the built-in string [`Format`]s.
///
/// How each format is checked:
/// - `Email` and `Uuid`: matched against lazily compiled regexes.
/// - `Ipv4` / `Ipv6`: must parse as [`std::net::Ipv4Addr`] / [`std::net::Ipv6Addr`].
/// - `Uri` and `DateTime`: accepted as-is, with no runtime check.
///
/// # Errors
///
/// Returns [`CognisError::ToolValidationError`] naming `field`, the rejected `value`,
/// and the format name when a checked format does not match.
pub fn check_format(field: &str, value: &str, format: Format) -> Result<()> {
    let (label, valid) = match format {
        // No runtime validation for these formats; every value passes.
        Format::Uri | Format::DateTime => return Ok(()),
        Format::Email => ("email", email_regex().is_match(value)),
        Format::Uuid => ("uuid", uuid_regex().is_match(value)),
        Format::Ipv4 => ("ipv4", value.parse::<std::net::Ipv4Addr>().is_ok()),
        Format::Ipv6 => ("ipv6", value.parse::<std::net::Ipv6Addr>().is_ok()),
    };
    if !valid {
        return Err(CognisError::ToolValidationError(format!(
            "field `{field}`: \"{value}\" is not a valid {label}"
        )));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `r` is a `ToolValidationError` whose message contains `needle`.
    fn assert_validation_err(r: Result<()>, needle: &str) {
        match r {
            Err(CognisError::ToolValidationError(msg)) => assert!(
                msg.contains(needle),
                "expected msg to contain {needle:?}, got {msg:?}"
            ),
            other => panic!("expected ToolValidationError, got {other:?}"),
        }
    }

    // ----- check_range -----

    #[test]
    fn range_passes_within_bounds() {
        assert!(check_range("limit", 5.0, Some(1.0), Some(50.0)).is_ok());
    }

    #[test]
    fn range_passes_on_boundaries_inclusive() {
        assert!(check_range("x", 1.0, Some(1.0), Some(50.0)).is_ok());
        assert!(check_range("x", 50.0, Some(1.0), Some(50.0)).is_ok());
    }

    #[test]
    fn range_rejects_below_min() {
        // The field name appears in every error message, so also pin the minimum
        // branch; otherwise a max-side rejection would satisfy this test too.
        assert_validation_err(check_range("limit", 0.5, Some(1.0), Some(50.0)), "limit");
        assert_validation_err(check_range("limit", 0.5, Some(1.0), Some(50.0)), "minimum");
    }

    #[test]
    fn range_rejects_above_max() {
        assert_validation_err(
            check_range("limit", 100.0, Some(1.0), Some(50.0)),
            "maximum",
        );
    }

    #[test]
    fn range_error_reports_value_and_bound() {
        assert_validation_err(
            check_range("limit", 100.0, Some(1.0), Some(50.0)),
            "field `limit`: 100 is greater than maximum 50",
        );
    }

    #[test]
    fn range_with_only_min_ignores_max() {
        assert!(check_range("x", 1e9, Some(0.0), None).is_ok());
    }

    #[test]
    fn range_with_only_max_ignores_min() {
        assert!(check_range("x", -1e9, None, Some(10.0)).is_ok());
    }

    #[test]
    fn range_rejects_nan() {
        assert_validation_err(check_range("x", f64::NAN, Some(0.0), Some(10.0)), "NaN");
    }

    // ----- check_length -----

    #[test]
    fn length_passes_within_bounds() {
        assert!(check_length("name", 5, Some(1), Some(10)).is_ok());
    }

    #[test]
    fn length_rejects_below_min() {
        assert_validation_err(check_length("name", 0, Some(1), Some(10)), "minimum");
    }

    #[test]
    fn length_rejects_above_max() {
        assert_validation_err(check_length("name", 11, Some(1), Some(10)), "maximum");
    }

    #[test]
    fn length_boundaries_inclusive() {
        assert!(check_length("name", 1, Some(1), Some(10)).is_ok());
        assert!(check_length("name", 10, Some(1), Some(10)).is_ok());
    }

    #[test]
    fn length_with_no_bounds_accepts_anything() {
        assert!(check_length("x", 0, None, None).is_ok());
        assert!(check_length("x", 1_000_000, None, None).is_ok());
    }

    #[test]
    fn length_uses_unicode_char_count_contract() {
        // check_length takes a precomputed count; callers are expected to pass the
        // char count, not the byte length, for string fields.
        let s = "こんにちは";
        assert_eq!(s.chars().count(), 5);
        assert!(check_length("greeting", s.chars().count(), Some(1), Some(5)).is_ok());
        assert_validation_err(
            check_length("greeting", s.chars().count(), Some(1), Some(4)),
            "maximum",
        );
    }

    // ----- check_enum -----

    #[test]
    fn enum_accepts_listed_value() {
        assert!(check_enum("order", "asc", &["asc", "desc"]).is_ok());
        assert!(check_enum("order", "desc", &["asc", "desc"]).is_ok());
    }

    #[test]
    fn enum_rejects_unlisted_value() {
        assert_validation_err(check_enum("order", "random", &["asc", "desc"]), "one of");
    }

    #[test]
    fn enum_error_includes_allowed_values() {
        let err = check_enum("order", "x", &["asc", "desc"]).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("asc") && msg.contains("desc"), "got {msg}");
    }

    #[test]
    fn enum_accepts_owned_strings() {
        let allowed = vec!["asc".to_string(), "desc".to_string()];
        assert!(check_enum("order", "asc", &allowed).is_ok());
        assert_validation_err(check_enum("order", "nope", &allowed), "nope");
    }

    #[test]
    fn enum_rejects_empty_string_when_not_listed() {
        assert_validation_err(check_enum("x", "", &["a", "b"]), "must be one of");
    }

    #[test]
    fn enum_with_no_allowed_values_rejects_everything() {
        assert_validation_err(check_enum::<&str>("x", "a", &[]), "must be one of []");
    }

    // ----- check_pattern -----

    #[test]
    fn pattern_accepts_match() {
        let re = regex::Regex::new("^[a-z]+$").unwrap();
        assert!(check_pattern("slug", "hello", &re).is_ok());
    }

    #[test]
    fn pattern_rejects_non_match() {
        let re = regex::Regex::new("^[a-z]+$").unwrap();
        assert_validation_err(check_pattern("slug", "Hello", &re), "pattern");
    }

    #[test]
    fn pattern_error_includes_pattern_string() {
        let re = regex::Regex::new("^[a-z]+$").unwrap();
        let err = check_pattern("slug", "Hi", &re).unwrap_err();
        assert!(err.to_string().contains("^[a-z]+$"), "got {err}");
    }

    #[test]
    fn pattern_matches_anywhere_when_unanchored() {
        let re = regex::Regex::new("[0-9]+").unwrap();
        assert!(check_pattern("code", "abc123", &re).is_ok());
    }

    // ----- check_format -----

    #[test]
    fn format_email_accepts_basic_address() {
        assert!(check_format("to", "a@b.c", Format::Email).is_ok());
    }

    #[test]
    fn format_email_rejects_bare_token() {
        assert_validation_err(check_format("to", "no-at-sign", Format::Email), "email");
    }

    #[test]
    fn format_email_rejects_whitespace_and_missing_dot() {
        assert_validation_err(check_format("to", "a b@c.d", Format::Email), "email");
        assert_validation_err(check_format("to", "a@b", Format::Email), "email");
    }

    #[test]
    fn format_uuid_accepts_v4() {
        assert!(check_format("id", "550e8400-e29b-41d4-a716-446655440000", Format::Uuid).is_ok());
    }

    #[test]
    fn format_uuid_is_case_insensitive_and_requires_hyphens() {
        assert!(check_format("id", "550E8400-E29B-41D4-A716-446655440000", Format::Uuid).is_ok());
        assert_validation_err(
            check_format("id", "550e8400e29b41d4a716446655440000", Format::Uuid),
            "uuid",
        );
    }

    #[test]
    fn format_uuid_rejects_non_uuid() {
        assert_validation_err(check_format("id", "not-a-uuid", Format::Uuid), "uuid");
    }

    #[test]
    fn format_ipv4_accepts() {
        assert!(check_format("ip", "192.168.1.1", Format::Ipv4).is_ok());
    }

    #[test]
    fn format_ipv4_rejects() {
        assert_validation_err(check_format("ip", "300.1.1.1", Format::Ipv4), "ipv4");
    }

    #[test]
    fn format_ip_families_do_not_cross_accept() {
        assert_validation_err(check_format("ip", "::1", Format::Ipv4), "ipv4");
        assert_validation_err(check_format("ip", "192.168.1.1", Format::Ipv6), "ipv6");
    }

    #[test]
    fn format_uri_is_schema_only_passthrough() {
        assert!(check_format("link", "anything goes here", Format::Uri).is_ok());
        assert!(check_format("t", "whatever", Format::DateTime).is_ok());
    }

    #[test]
    fn format_ipv6_rejects_malformed_with_multiple_double_colons() {
        assert_validation_err(check_format("ip", "::1::2", Format::Ipv6), "ipv6");
    }

    #[test]
    fn format_ipv6_accepts_standard_forms() {
        assert!(check_format("ip", "::1", Format::Ipv6).is_ok());
        assert!(check_format("ip", "2001:db8::1", Format::Ipv6).is_ok());
        assert!(check_format(
            "ip",
            "fe80:0000:0000:0000:0202:b3ff:fe1e:8329",
            Format::Ipv6,
        )
        .is_ok());
    }

    #[test]
    fn format_as_str_returns_canonical_names() {
        assert_eq!(Format::Email.as_str(), "email");
        assert_eq!(Format::DateTime.as_str(), "date-time");
        assert_eq!(Format::Uri.as_str(), "uri");
    }

    // ----- ValidateArgs -----

    #[test]
    fn validate_args_default_impl_compiles() {
        struct X;
        impl ValidateArgs for X {}
        assert!(X.validate().is_ok());
    }

    #[test]
    fn validate_args_custom_impl_surfaces_error() {
        struct Y {
            limit: u32,
        }
        impl ValidateArgs for Y {
            fn validate(&self) -> Result<()> {
                check_range("limit", self.limit as f64, Some(1.0), Some(50.0))
            }
        }
        assert!(Y { limit: 10 }.validate().is_ok());
        assert_validation_err(Y { limit: 100 }.validate(), "maximum");
    }
}