1use crate::roles::Role;
6use kimberlite_types::{SqlIdentifier, TenantId};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub struct StreamFilter {
12 pub pattern: String,
19
20 pub allow: bool,
22}
23
24impl StreamFilter {
25 pub fn new(pattern: impl Into<String>, allow: bool) -> Self {
27 Self {
28 pattern: pattern.into(),
29 allow,
30 }
31 }
32
33 pub fn matches(&self, stream_name: &str) -> bool {
35 let pattern = &self.pattern;
37
38 if pattern == "*" {
39 return true;
40 }
41
42 if pattern.ends_with('*') {
43 let prefix = &pattern[..pattern.len() - 1];
44 return stream_name.starts_with(prefix);
45 }
46
47 if let Some(suffix) = pattern.strip_prefix('*') {
48 return stream_name.ends_with(suffix);
49 }
50
51 stream_name == pattern
53 }
54}
55
56#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
58pub struct ColumnFilter {
59 pub pattern: String,
66
67 pub allow: bool,
69}
70
71impl ColumnFilter {
72 pub fn new(pattern: impl Into<String>, allow: bool) -> Self {
74 Self {
75 pattern: pattern.into(),
76 allow,
77 }
78 }
79
80 pub fn matches(&self, column_name: &str) -> bool {
95 if let Ok(id) = SqlIdentifier::try_new(self.pattern.clone()) {
96 return id.matches(column_name);
97 }
98 let lhs = self.pattern.to_ascii_lowercase();
102 let rhs = column_name.to_ascii_lowercase();
103 lhs == rhs
104 }
105}
106
107#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
109pub struct RowFilter {
110 pub column: String,
112
113 pub operator: RowFilterOperator,
115
116 pub value: String,
118}
119
120impl RowFilter {
121 pub fn new(
123 column: impl Into<String>,
124 operator: RowFilterOperator,
125 value: impl Into<String>,
126 ) -> Self {
127 Self {
128 column: column.into(),
129 operator,
130 value: value.into(),
131 }
132 }
133}
134
135#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
137pub enum RowFilterOperator {
138 Eq,
140
141 Ne,
143
144 Lt,
146
147 Le,
149
150 Gt,
152
153 Ge,
155
156 In,
158
159 NotIn,
161}
162
163impl RowFilterOperator {
164 pub fn to_sql(&self) -> &'static str {
166 match self {
167 RowFilterOperator::Eq => "=",
168 RowFilterOperator::Ne => "!=",
169 RowFilterOperator::Lt => "<",
170 RowFilterOperator::Le => "<=",
171 RowFilterOperator::Gt => ">",
172 RowFilterOperator::Ge => ">=",
173 RowFilterOperator::In => "IN",
174 RowFilterOperator::NotIn => "NOT IN",
175 }
176 }
177}
178
179#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
183pub struct AccessPolicy {
184 pub role: Role,
186
187 pub tenant_id: Option<TenantId>,
191
192 pub stream_filters: Vec<StreamFilter>,
200
201 pub column_filters: Vec<ColumnFilter>,
205
206 pub row_filters: Vec<RowFilter>,
210
211 pub masking_policy: Option<crate::masking::MaskingPolicy>,
216}
217
218impl AccessPolicy {
219 pub fn new(role: Role) -> Self {
221 Self {
222 role,
223 tenant_id: None,
224 stream_filters: Vec::new(),
225 column_filters: Vec::new(),
226 row_filters: Vec::new(),
227 masking_policy: None,
228 }
229 }
230
231 pub fn with_tenant(mut self, tenant_id: TenantId) -> Self {
233 self.tenant_id = Some(tenant_id);
234 self
235 }
236
237 pub fn with_masking(mut self, policy: crate::masking::MaskingPolicy) -> Self {
241 self.masking_policy = Some(policy);
242 self
243 }
244
245 pub fn allow_stream(mut self, pattern: impl Into<String>) -> Self {
247 self.stream_filters.push(StreamFilter::new(pattern, true));
248 self
249 }
250
251 pub fn deny_stream(mut self, pattern: impl Into<String>) -> Self {
253 self.stream_filters.push(StreamFilter::new(pattern, false));
254 self
255 }
256
257 pub fn allow_column(mut self, pattern: impl Into<String>) -> Self {
259 self.column_filters.push(ColumnFilter::new(pattern, true));
260 self
261 }
262
263 pub fn deny_column(mut self, pattern: impl Into<String>) -> Self {
265 self.column_filters.push(ColumnFilter::new(pattern, false));
266 self
267 }
268
269 pub fn with_row_filter(mut self, filter: RowFilter) -> Self {
271 self.row_filters.push(filter);
272 self
273 }
274
275 pub fn allows_stream(&self, stream_name: &str) -> bool {
277 for filter in &self.stream_filters {
279 if !filter.allow && filter.matches(stream_name) {
280 return false; }
282 }
283
284 for filter in &self.stream_filters {
286 if filter.allow && filter.matches(stream_name) {
287 return true; }
289 }
290
291 false
293 }
294
295 pub fn allows_column(&self, column_name: &str) -> bool {
297 for filter in &self.column_filters {
299 if !filter.allow && filter.matches(column_name) {
300 return false; }
302 }
303
304 for filter in &self.column_filters {
306 if filter.allow && filter.matches(column_name) {
307 return true; }
309 }
310
311 false
313 }
314
315 pub fn row_filters(&self) -> &[RowFilter] {
317 &self.row_filters
318 }
319}
320
321pub struct StandardPolicies;
323
324impl StandardPolicies {
325 pub fn admin() -> AccessPolicy {
332 AccessPolicy::new(Role::Admin)
333 .allow_stream("*") .allow_column("*") }
336
337 pub fn analyst() -> AccessPolicy {
344 AccessPolicy::new(Role::Analyst)
345 .allow_stream("*")
346 .deny_stream("audit_*") .allow_column("*")
348 }
349
350 pub fn user(tenant_id: TenantId) -> AccessPolicy {
357 AccessPolicy::new(Role::User)
358 .with_tenant(tenant_id)
359 .allow_stream("*")
360 .allow_column("*")
361 .with_row_filter(RowFilter::new(
362 "tenant_id",
363 RowFilterOperator::Eq,
364 u64::from(tenant_id).to_string(),
365 ))
366 }
367
368 pub fn auditor() -> AccessPolicy {
375 AccessPolicy::new(Role::Auditor)
376 .allow_stream("audit_*") .allow_column("*")
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_filter_wildcard() {
        let patient_streams = StreamFilter::new("patient_*", true);
        for stream in ["patient_records", "patient_vitals"] {
            assert!(patient_streams.matches(stream), "`patient_*` should match {}", stream);
        }
        assert!(!patient_streams.matches("audit_log"));

        assert!(StreamFilter::new("*", true).matches("any_stream"));
    }

    #[test]
    fn test_column_filter_wildcard() {
        let pii = ColumnFilter::new("pii_*", false);
        let expectations = [
            ("pii_ssn", true),
            ("pii_address", true),
            ("public_name", false),
        ];
        for (column, expected) in expectations {
            assert_eq!(pii.matches(column), expected, "column {}", column);
        }
    }

    // Column matching ignores ASCII case for exact, prefix and suffix patterns.
    #[test]
    fn test_column_filter_case_insensitive() {
        let exact = ColumnFilter::new("NAME", false);
        for column in ["name", "NAME", "Name"] {
            assert!(exact.matches(column), "`NAME` should match {}", column);
        }
        assert!(!exact.matches("full_name"));

        let prefix = ColumnFilter::new("PII_*", false);
        for column in ["pii_ssn", "PII_ADDRESS", "Pii_Date_Of_Birth"] {
            assert!(prefix.matches(column), "`PII_*` should match {}", column);
        }

        let suffix = ColumnFilter::new("*_SECRET", false);
        for column in ["internal_secret", "API_Secret"] {
            assert!(suffix.matches(column), "`*_SECRET` should match {}", column);
        }
    }

    // A case-insensitive deny rule must override a wildcard allow.
    #[test]
    fn test_policy_column_access_case_insensitive() {
        let policy = AccessPolicy::new(Role::Analyst)
            .allow_column("*")
            .deny_column("NAME");

        for column in ["name", "NAME", "Name"] {
            assert!(!policy.allows_column(column), "{} should be denied", column);
        }
    }

    #[test]
    fn test_policy_stream_access() {
        let policy = AccessPolicy::new(Role::User)
            .allow_stream("patient_*")
            .deny_stream("patient_confidential");

        let expectations = [
            // Explicit deny beats the prefix allow.
            ("patient_confidential", false),
            // Covered by the prefix allow.
            ("patient_records", true),
            // Not covered by any allow rule.
            ("audit_log", false),
        ];
        for (stream, expected) in expectations {
            assert_eq!(policy.allows_stream(stream), expected, "stream {}", stream);
        }
    }

    #[test]
    fn test_policy_column_access() {
        let policy = AccessPolicy::new(Role::Analyst)
            .allow_column("*")
            .deny_column("ssn");

        let expectations = [("name", true), ("email", true), ("ssn", false)];
        for (column, expected) in expectations {
            assert_eq!(policy.allows_column(column), expected, "column {}", column);
        }
    }

    #[test]
    fn test_standard_policies() {
        let admin = StandardPolicies::admin();
        assert!(admin.allows_stream("any_stream"));
        assert!(admin.allows_column("any_column"));

        let analyst = StandardPolicies::analyst();
        assert!(analyst.allows_stream("patient_records"));
        assert!(!analyst.allows_stream("audit_system_events"));

        let tenant_id = TenantId::new(42);
        let user = StandardPolicies::user(tenant_id);
        assert_eq!(user.tenant_id, Some(tenant_id));
        assert_eq!(user.row_filters().len(), 1);

        let auditor = StandardPolicies::auditor();
        assert!(auditor.allows_stream("audit_access_log"));
        assert!(!auditor.allows_stream("patient_records"));
    }

    #[test]
    fn test_row_filter_operator_sql() {
        let cases = [
            (RowFilterOperator::Eq, "="),
            (RowFilterOperator::Ne, "!="),
            (RowFilterOperator::Lt, "<"),
            (RowFilterOperator::Le, "<="),
            (RowFilterOperator::Gt, ">"),
            (RowFilterOperator::Ge, ">="),
            (RowFilterOperator::In, "IN"),
            (RowFilterOperator::NotIn, "NOT IN"),
        ];
        for (operator, sql) in cases {
            assert_eq!(operator.to_sql(), sql, "operator {:?}", operator);
        }
    }
}