use serde::Serialize;
use super::{RivetType, TypeFidelity, TypeMapping};
/// What to do when a column's type mapping reaches a policy-relevant fidelity
/// level (lossy or unsupported); configured per level on `TypePolicy`.
///
/// Serialized in `snake_case`: `"fail"`, `"warn"`, `"allow"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum PolicyAction {
    /// Report the mapping as a violation with `fatal == true`.
    Fail,
    /// Report the mapping as a violation with `fatal == false`.
    Warn,
    /// Do not report the mapping at all.
    Allow,
}
/// Per-fidelity actions applied by `TypePolicy::validate`.
///
/// `TypePolicy::default()` is the same as `TypePolicy::strict()`.
#[derive(Debug, Clone)]
pub struct TypePolicy {
    /// Action taken for mappings whose fidelity is `TypeFidelity::Lossy`.
    pub on_lossy_mapping: PolicyAction,
    /// Action taken for mappings whose fidelity is `TypeFidelity::Unsupported`.
    pub on_unsupported_type: PolicyAction,
}
impl Default for TypePolicy {
fn default() -> Self {
Self::strict()
}
}
impl TypePolicy {
    /// Builds a policy under which every lossy or unsupported mapping becomes a
    /// fatal violation (`PolicyAction::Fail` for both levels).
    pub fn strict() -> Self {
        let action = PolicyAction::Fail;
        Self {
            on_lossy_mapping: action,
            on_unsupported_type: action,
        }
    }

    /// Builds a policy under which lossy and unsupported mappings are still
    /// reported, but only as non-fatal violations (`PolicyAction::Warn` for both).
    pub fn warn_only() -> Self {
        let action = PolicyAction::Warn;
        Self {
            on_lossy_mapping: action,
            on_unsupported_type: action,
        }
    }
}
/// One column whose type mapping hit a non-`Allow` policy action, as produced
/// by `TypePolicy::validate`.
#[derive(Debug, Clone, Serialize)]
pub struct PolicyViolation {
    /// Name of the offending source column.
    pub column_name: String,
    /// Fidelity level that triggered the violation; `validate` only emits
    /// `Lossy` or `Unsupported` here.
    pub fidelity: TypeFidelity,
    /// Human-readable description. It holds the column name, the source native
    /// type and the fidelity label, plus the reason when the Rivet type is `Unsupported`.
    pub message: String,
    /// `true` when the configured action was `PolicyAction::Fail`; only these
    /// violations make `TypePolicy::check_fail` return an error.
    pub fatal: bool,
}
impl TypePolicy {
    /// Checks every mapping against this policy and reports the ones needing attention.
    ///
    /// Only mappings whose fidelity is `TypeFidelity::Lossy` or
    /// `TypeFidelity::Unsupported` are considered; any other fidelity is skipped.
    /// A considered mapping whose configured action is [`PolicyAction::Allow`] is
    /// skipped as well. Each remaining mapping yields exactly one
    /// [`PolicyViolation`], in input order, marked `fatal` exactly when the
    /// action is [`PolicyAction::Fail`].
    pub fn validate(&self, mappings: &[TypeMapping]) -> Vec<PolicyViolation> {
        mappings
            .iter()
            .filter_map(|mapping| {
                let (action, fidelity) = match mapping.fidelity {
                    TypeFidelity::Lossy => (self.on_lossy_mapping, TypeFidelity::Lossy),
                    TypeFidelity::Unsupported => {
                        (self.on_unsupported_type, TypeFidelity::Unsupported)
                    }
                    _ => return None,
                };
                if matches!(action, PolicyAction::Allow) {
                    return None;
                }
                // Unsupported Rivet types carry a reason worth surfacing to the user.
                let suffix = if let RivetType::Unsupported { reason, .. } = &mapping.rivet_type {
                    format!(": {}", reason)
                } else {
                    String::new()
                };
                let message = format!(
                    "column '{}' (source type '{}'): fidelity={}{}",
                    mapping.column_name,
                    mapping.source_native_type,
                    fidelity.label(),
                    suffix
                );
                Some(PolicyViolation {
                    column_name: mapping.column_name.clone(),
                    fidelity,
                    message,
                    fatal: matches!(action, PolicyAction::Fail),
                })
            })
            .collect()
    }

    /// Turns the fatal violations produced by [`TypePolicy::validate`] into an error.
    ///
    /// Non-fatal violations are ignored. `self` is not consulted: fatality was
    /// already decided when the violations were built.
    ///
    /// # Errors
    ///
    /// Returns an error if at least one violation has `fatal == true`. The error
    /// message gives the count of fatal violations, then lists each one's
    /// message on its own line.
    // NOTE(review): within this module only the unit tests call this, which is
    // presumably why `dead_code` is allowed for non-test builds — confirm
    // whether callers elsewhere exist before removing the allowance.
    #[allow(dead_code)]
    pub fn check_fail(&self, violations: &[PolicyViolation]) -> crate::error::Result<()> {
        let blocking: Vec<&str> = violations
            .iter()
            .filter_map(|v| v.fatal.then_some(v.message.as_str()))
            .collect();
        if blocking.is_empty() {
            return Ok(());
        }
        anyhow::bail!(
            "strict mode: {} unsafe type mapping(s):\n{}",
            blocking.len(),
            blocking.join("\n")
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{SourceColumn, TypeMapping};

    /// Builds a mapping for `name` whose Rivet type is `Unsupported` with reason "test reason".
    fn unsupported_mapping(name: &str, native: &str) -> TypeMapping {
        let col = SourceColumn::simple(name, native, true);
        TypeMapping::from_source(
            &col,
            RivetType::Unsupported {
                native_type: native.into(),
                reason: "test reason".into(),
            },
        )
    }

    /// Builds a mapping for `name` onto `RivetType::Int64`.
    fn exact_mapping(name: &str, native: &str) -> TypeMapping {
        let col = SourceColumn::simple(name, native, true);
        TypeMapping::from_source(&col, crate::types::RivetType::Int64)
    }

    /// Builds a hand-made violation so `check_fail` can be exercised in isolation.
    fn violation(message: &str, fatal: bool) -> PolicyViolation {
        PolicyViolation {
            column_name: "c".into(),
            fidelity: TypeFidelity::Unsupported,
            message: message.into(),
            fatal,
        }
    }

    #[test]
    fn strict_policy_fails_on_unsupported() {
        let policy = TypePolicy::strict();
        let mappings = vec![
            exact_mapping("id", "int8"),
            unsupported_mapping("location", "geometry"),
        ];
        let violations = policy.validate(&mappings);
        assert_eq!(violations.len(), 1);
        assert!(violations[0].fatal);
        assert_eq!(violations[0].column_name, "location");
        assert!(policy.check_fail(&violations).is_err());
    }

    #[test]
    fn warn_only_policy_does_not_fail() {
        let policy = TypePolicy::warn_only();
        let mappings = vec![unsupported_mapping("dur", "interval")];
        let violations = policy.validate(&mappings);
        assert_eq!(violations.len(), 1);
        assert!(!violations[0].fatal);
        assert!(policy.check_fail(&violations).is_ok());
    }

    #[test]
    fn allow_policy_produces_no_violations() {
        let policy = TypePolicy {
            on_lossy_mapping: PolicyAction::Allow,
            on_unsupported_type: PolicyAction::Allow,
        };
        let mappings = vec![unsupported_mapping("x", "hstore")];
        assert!(policy.validate(&mappings).is_empty());
    }

    #[test]
    fn exact_mappings_never_produce_violations() {
        let policy = TypePolicy::strict();
        let mappings = vec![
            exact_mapping("id", "int8"),
            TypeMapping::from_source(
                &SourceColumn::simple("name", "text", true),
                crate::types::RivetType::Text,
            ),
        ];
        assert!(policy.validate(&mappings).is_empty());
    }

    #[test]
    fn default_policy_is_strict() {
        let mappings = vec![unsupported_mapping("geom", "geometry")];
        let violations = TypePolicy::default().validate(&mappings);
        assert_eq!(violations.len(), 1);
        assert!(violations[0].fatal);
    }

    #[test]
    fn empty_inputs_are_accepted() {
        let policy = TypePolicy::strict();
        assert!(policy.validate(&[]).is_empty());
        assert!(policy.check_fail(&[]).is_ok());
    }

    #[test]
    fn violation_message_names_column_and_reason() {
        let policy = TypePolicy::warn_only();
        let violations = policy.validate(&[unsupported_mapping("loc", "geometry")]);
        assert_eq!(violations.len(), 1);
        let message = &violations[0].message;
        assert!(message.starts_with("column 'loc' "), "{}", message);
        assert!(message.ends_with(": test reason"), "{}", message);
    }

    #[test]
    fn check_fail_reports_only_fatal_messages() {
        let violations = vec![violation("boom", true), violation("meh", false)];
        let err = TypePolicy::strict().check_fail(&violations).unwrap_err();
        assert_eq!(
            err.to_string(),
            "strict mode: 1 unsafe type mapping(s):\nboom"
        );
    }
}