1use serde::Serialize;
9
10use super::{RivetType, TypeFidelity, TypeMapping};
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
14#[serde(rename_all = "snake_case")]
15pub enum PolicyAction {
16 Fail,
18 Warn,
20 Allow,
22}
23
24#[derive(Debug, Clone)]
29pub struct TypePolicy {
30 pub on_lossy_mapping: PolicyAction,
32 pub on_unsupported_type: PolicyAction,
34}
35
36impl Default for TypePolicy {
37 fn default() -> Self {
38 Self::strict()
39 }
40}
41
42impl TypePolicy {
43 pub fn strict() -> Self {
45 Self {
46 on_lossy_mapping: PolicyAction::Fail,
47 on_unsupported_type: PolicyAction::Fail,
48 }
49 }
50
51 pub fn warn_only() -> Self {
54 Self {
55 on_lossy_mapping: PolicyAction::Warn,
56 on_unsupported_type: PolicyAction::Warn,
57 }
58 }
59}
60
61#[derive(Debug, Clone, Serialize)]
63pub struct PolicyViolation {
64 pub column_name: String,
66 pub fidelity: TypeFidelity,
68 pub message: String,
70 pub fatal: bool,
72}
73
74impl TypePolicy {
75 pub fn validate(&self, mappings: &[TypeMapping]) -> Vec<PolicyViolation> {
77 let mut out = Vec::new();
78 for m in mappings {
79 let (action, fidelity) = match m.fidelity {
80 TypeFidelity::Lossy => (self.on_lossy_mapping, TypeFidelity::Lossy),
81 TypeFidelity::Unsupported => (self.on_unsupported_type, TypeFidelity::Unsupported),
82 _ => continue,
83 };
84 if action == PolicyAction::Allow {
85 continue;
86 }
87 let detail = match &m.rivet_type {
88 RivetType::Unsupported { reason, .. } => format!(": {}", reason),
89 _ => String::new(),
90 };
91 out.push(PolicyViolation {
92 column_name: m.column_name.clone(),
93 fidelity,
94 message: format!(
95 "column '{}' (source type '{}'): fidelity={}{}",
96 m.column_name,
97 m.source_native_type,
98 fidelity.label(),
99 detail
100 ),
101 fatal: action == PolicyAction::Fail,
102 });
103 }
104 out
105 }
106
107 #[allow(dead_code)]
109 pub fn check_fail(&self, violations: &[PolicyViolation]) -> crate::error::Result<()> {
110 let fatal: Vec<&str> = violations
111 .iter()
112 .filter(|v| v.fatal)
113 .map(|v| v.message.as_str())
114 .collect();
115 if !fatal.is_empty() {
116 anyhow::bail!(
117 "strict mode: {} unsafe type mapping(s):\n{}",
118 fatal.len(),
119 fatal.join("\n")
120 );
121 }
122 Ok(())
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{SourceColumn, TypeMapping};

    /// Builds a mapping whose target type is `RivetType::Unsupported`.
    fn unsupported_mapping(name: &str, native: &str) -> TypeMapping {
        let col = SourceColumn::simple(name, native, true);
        TypeMapping::from_source(
            &col,
            RivetType::Unsupported {
                native_type: native.into(),
                reason: "test reason".into(),
            },
        )
    }

    /// Builds a mapping to `RivetType::Int64`, which the policy must accept.
    fn exact_mapping(name: &str, native: &str) -> TypeMapping {
        let col = SourceColumn::simple(name, native, true);
        TypeMapping::from_source(&col, RivetType::Int64)
    }

    #[test]
    fn default_policy_is_strict() {
        let policy = TypePolicy::default();
        assert_eq!(policy.on_lossy_mapping, PolicyAction::Fail);
        assert_eq!(policy.on_unsupported_type, PolicyAction::Fail);
    }

    #[test]
    fn strict_policy_fails_on_unsupported() {
        let policy = TypePolicy::strict();
        let mappings = vec![
            exact_mapping("id", "int8"),
            unsupported_mapping("location", "geometry"),
        ];
        let violations = policy.validate(&mappings);
        assert_eq!(violations.len(), 1);
        assert!(violations[0].fatal);
        assert_eq!(violations[0].column_name, "location");
        assert!(policy.check_fail(&violations).is_err());
    }

    #[test]
    fn warn_only_policy_does_not_fail() {
        let policy = TypePolicy::warn_only();
        let mappings = vec![unsupported_mapping("dur", "interval")];
        let violations = policy.validate(&mappings);
        assert_eq!(violations.len(), 1);
        assert!(!violations[0].fatal);
        assert!(policy.check_fail(&violations).is_ok());
    }

    #[test]
    fn allow_policy_produces_no_violations() {
        let policy = TypePolicy {
            on_lossy_mapping: PolicyAction::Allow,
            on_unsupported_type: PolicyAction::Allow,
        };
        let mappings = vec![unsupported_mapping("x", "hstore")];
        assert!(policy.validate(&mappings).is_empty());
    }

    #[test]
    fn actions_are_selected_per_fidelity() {
        // The lossy action must not leak into the handling of unsupported types.
        let policy = TypePolicy {
            on_lossy_mapping: PolicyAction::Allow,
            on_unsupported_type: PolicyAction::Warn,
        };
        let mappings = vec![unsupported_mapping("geom", "geometry")];
        let violations = policy.validate(&mappings);
        assert_eq!(violations.len(), 1);
        assert!(!violations[0].fatal);
        assert!(matches!(violations[0].fidelity, TypeFidelity::Unsupported));
    }

    #[test]
    fn exact_mappings_never_produce_violations() {
        let policy = TypePolicy::strict();
        let mappings = vec![
            exact_mapping("id", "int8"),
            TypeMapping::from_source(
                &SourceColumn::simple("name", "text", true),
                RivetType::Text,
            ),
        ];
        assert!(policy.validate(&mappings).is_empty());
    }

    #[test]
    fn check_fail_accepts_empty_violations() {
        assert!(TypePolicy::strict().check_fail(&[]).is_ok());
    }

    #[test]
    fn check_fail_reports_every_fatal_violation() {
        let policy = TypePolicy::strict();
        let mappings = vec![
            unsupported_mapping("a", "geometry"),
            exact_mapping("id", "int8"),
            unsupported_mapping("b", "hstore"),
        ];
        let violations = policy.validate(&mappings);
        assert_eq!(violations.len(), 2);

        let text = policy.check_fail(&violations).unwrap_err().to_string();
        assert!(text.starts_with("strict mode: 2 unsafe type mapping(s):"));
        assert!(text.contains("column 'a'"));
        assert!(text.contains("column 'b'"));
    }
}
193}