1use std::str::FromStr;
2
3use uuid::Uuid;
4
5use crate::ForgeError;
6
/// Isolation setting stored in a `TenantContext`.
///
/// The textual name of each variant is produced by [`TenantIsolationMode::as_str`].
// NOTE(review): how each mode is enforced is not visible in this file — confirm with callers.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum TenantIsolationMode {
    /// The `Default` value; rendered as `"none"`.
    #[default]
    None,
    /// Rendered as `"strict"`.
    Strict,
    /// Rendered as `"read_shared"`.
    ReadShared,
}

impl TenantIsolationMode {
    /// Returns the lowercase `snake_case` name of this mode.
    ///
    /// The returned strings are the same ones the `FromStr` implementation accepts.
    pub fn as_str(&self) -> &'static str {
        match *self {
            TenantIsolationMode::ReadShared => "read_shared",
            TenantIsolationMode::Strict => "strict",
            TenantIsolationMode::None => "none",
        }
    }
}
25
/// Error produced when a string does not name a known tenant isolation mode.
///
/// The wrapped `String` is echoed inside the `Display` message; the `FromStr`
/// implementation for `TenantIsolationMode` stores the rejected input there.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParseTenantIsolationModeError(pub String);

impl std::fmt::Display for ParseTenantIsolationModeError {
    /// Writes `invalid tenant isolation mode: '<wrapped string>'`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = self;
        f.write_fmt(format_args!(
            "invalid tenant isolation mode: '{}'",
            rejected
        ))
    }
}

// No `source()` override: this error does not wrap another error.
impl std::error::Error for ParseTenantIsolationModeError {}
36
37impl FromStr for TenantIsolationMode {
38 type Err = ParseTenantIsolationModeError;
39
40 fn from_str(s: &str) -> Result<Self, Self::Err> {
41 match s {
42 "none" => Ok(Self::None),
43 "strict" => Ok(Self::Strict),
44 "read_shared" => Ok(Self::ReadShared),
45 _ => Err(ParseTenantIsolationModeError(s.to_string())),
46 }
47 }
48}
49
50#[derive(Debug, Clone, Default)]
51#[non_exhaustive]
52pub struct TenantContext {
53 pub tenant_id: Option<Uuid>,
54 pub isolation_mode: TenantIsolationMode,
55}
56
57impl TenantContext {
58 pub fn none() -> Self {
59 Self {
60 tenant_id: None,
61 isolation_mode: TenantIsolationMode::None,
62 }
63 }
64
65 pub fn new(tenant_id: Uuid, isolation_mode: TenantIsolationMode) -> Self {
66 Self {
67 tenant_id: Some(tenant_id),
68 isolation_mode,
69 }
70 }
71
72 pub fn strict(tenant_id: Uuid) -> Self {
73 Self::new(tenant_id, TenantIsolationMode::Strict)
74 }
75
76 pub fn has_tenant(&self) -> bool {
77 self.tenant_id.is_some()
78 }
79
80 pub fn require_tenant(&self) -> crate::Result<Uuid> {
81 self.tenant_id
82 .ok_or_else(|| ForgeError::Unauthorized("Tenant context required".into()))
83 }
84
85 pub fn requires_filtering(&self) -> bool {
86 self.tenant_id.is_some() && self.isolation_mode != TenantIsolationMode::None
87 }
88
89 pub fn sql_filter(&self, column: &str, param_index: u32) -> Option<(String, Uuid)> {
91 if column.is_empty()
93 || !column
94 .bytes()
95 .all(|b| b.is_ascii_alphanumeric() || b == b'_')
96 {
97 return None;
98 }
99 self.tenant_id
100 .map(|id| (format!("\"{}\" = ${}", column, param_index), id))
101 }
102}
103
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// A context built with `none()` has no tenant and needs no filtering.
    #[test]
    fn test_tenant_context_none() {
        let empty = TenantContext::none();
        assert!(!empty.has_tenant());
        assert!(!empty.requires_filtering());
    }

    /// A strict context keeps the tenant id and requires filtering.
    #[test]
    fn test_tenant_context_strict() {
        let expected = Uuid::new_v4();
        let strict = TenantContext::strict(expected);
        assert!(strict.has_tenant());
        assert!(strict.requires_filtering());
        assert_eq!(strict.tenant_id, Some(expected));
    }

    /// The generated clause quotes the column and binds the tenant id.
    #[test]
    fn test_tenant_sql_filter() {
        let expected = Uuid::new_v4();
        let filter = TenantContext::strict(expected).sql_filter("tenant_id", 1);
        assert_eq!(
            filter,
            Some((String::from("\"tenant_id\" = $1"), expected))
        );
    }

    /// `require_tenant` fails without a tenant and succeeds with one.
    #[test]
    fn test_require_tenant() {
        assert!(TenantContext::none().require_tenant().is_err());
        assert!(TenantContext::strict(Uuid::new_v4())
            .require_tenant()
            .is_ok());
    }
}