1use std::str::FromStr;
2
3use uuid::Uuid;
4
5use crate::ForgeError;
6
/// How strictly data access is scoped to the current tenant.
///
/// The derived [`Default`] is [`TenantIsolationMode::None`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TenantIsolationMode {
    /// No tenant isolation; `TenantContext::requires_filtering` reports `false`
    /// for this mode even when a tenant id is present.
    #[default]
    None,
    /// Strict isolation: `TenantContext::requires_filtering` reports `true`
    /// whenever a tenant id is present.
    Strict,
    /// Read-shared isolation. NOTE(review): presumably reads may span tenants
    /// while writes stay scoped — the enforcing code is not in this file; confirm.
    ReadShared,
}
18
19impl TenantIsolationMode {
20 pub fn as_str(&self) -> &'static str {
22 match self {
23 Self::None => "none",
24 Self::Strict => "strict",
25 Self::ReadShared => "read_shared",
26 }
27 }
28}
29
/// Error produced by `TenantIsolationMode`'s `FromStr` implementation when the
/// input is not a recognised mode name.
///
/// The rejected input is kept verbatim in the tuple field so callers can report it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParseTenantIsolationModeError(pub String);

impl std::fmt::Display for ParseTenantIsolationModeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rejected = &self.0;
        write!(f, "invalid tenant isolation mode: '{}'", rejected)
    }
}

// No underlying cause to expose, so the default `source()` (None) is used.
impl std::error::Error for ParseTenantIsolationModeError {}
40
41impl FromStr for TenantIsolationMode {
42 type Err = ParseTenantIsolationModeError;
43
44 fn from_str(s: &str) -> Result<Self, Self::Err> {
45 match s {
46 "none" => Ok(Self::None),
47 "strict" => Ok(Self::Strict),
48 "read_shared" => Ok(Self::ReadShared),
49 _ => Err(ParseTenantIsolationModeError(s.to_string())),
50 }
51 }
52}
53
54#[derive(Debug, Clone, Default)]
56pub struct TenantContext {
57 pub tenant_id: Option<Uuid>,
59 pub isolation_mode: TenantIsolationMode,
61}
62
63impl TenantContext {
64 pub fn none() -> Self {
66 Self {
67 tenant_id: None,
68 isolation_mode: TenantIsolationMode::None,
69 }
70 }
71
72 pub fn new(tenant_id: Uuid, isolation_mode: TenantIsolationMode) -> Self {
74 Self {
75 tenant_id: Some(tenant_id),
76 isolation_mode,
77 }
78 }
79
80 pub fn strict(tenant_id: Uuid) -> Self {
82 Self::new(tenant_id, TenantIsolationMode::Strict)
83 }
84
85 pub fn has_tenant(&self) -> bool {
87 self.tenant_id.is_some()
88 }
89
90 pub fn require_tenant(&self) -> crate::Result<Uuid> {
92 self.tenant_id
93 .ok_or_else(|| ForgeError::Unauthorized("Tenant context required".into()))
94 }
95
96 pub fn requires_filtering(&self) -> bool {
98 self.tenant_id.is_some() && self.isolation_mode != TenantIsolationMode::None
99 }
100
101 pub fn sql_filter(&self, column: &str, param_index: u32) -> Option<(String, Uuid)> {
107 if column.is_empty()
109 || !column
110 .bytes()
111 .all(|b| b.is_ascii_alphanumeric() || b == b'_')
112 {
113 return None;
114 }
115 self.tenant_id
116 .map(|id| (format!("\"{}\" = ${}", column, param_index), id))
117 }
118}
119
120pub trait HasTenant {
122 fn tenant(&self) -> &TenantContext;
124
125 fn tenant_id(&self) -> Option<Uuid> {
127 self.tenant().tenant_id
128 }
129
130 fn require_tenant(&self) -> crate::Result<Uuid> {
132 self.tenant().require_tenant()
133 }
134}
135
#[cfg(test)]
// Tests unwrap freely: a panic is the desired failure signal here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_tenant_context_none() {
        let ctx = TenantContext::none();
        assert!(!ctx.has_tenant());
        assert!(!ctx.requires_filtering());
    }

    #[test]
    fn test_tenant_context_default_matches_none() {
        let ctx = TenantContext::default();
        assert!(ctx.tenant_id.is_none());
        assert_eq!(ctx.isolation_mode, TenantIsolationMode::None);
    }

    #[test]
    fn test_tenant_context_strict() {
        let tenant_id = Uuid::new_v4();
        let ctx = TenantContext::strict(tenant_id);
        assert!(ctx.has_tenant());
        assert!(ctx.requires_filtering());
        assert_eq!(ctx.tenant_id, Some(tenant_id));
    }

    #[test]
    fn test_requires_filtering_with_isolation_none() {
        // A tenant is present, but mode `None` must not demand filtering.
        let ctx = TenantContext::new(Uuid::new_v4(), TenantIsolationMode::None);
        assert!(ctx.has_tenant());
        assert!(!ctx.requires_filtering());
    }

    #[test]
    fn test_tenant_sql_filter() {
        let tenant_id = Uuid::new_v4();
        let ctx = TenantContext::strict(tenant_id);
        let filter = ctx.sql_filter("tenant_id", 1);
        assert!(filter.is_some());
        let (clause, id) = filter.unwrap();
        assert_eq!(clause, "\"tenant_id\" = $1");
        assert_eq!(id, tenant_id);
    }

    #[test]
    fn test_tenant_sql_filter_rejects_unsafe_columns() {
        let ctx = TenantContext::strict(Uuid::new_v4());
        assert!(ctx.sql_filter("", 1).is_none());
        assert!(ctx.sql_filter("tenant_id; DROP TABLE users", 1).is_none());
        assert!(ctx.sql_filter("tenant\"id", 1).is_none());
        assert!(ctx.sql_filter("tenant id", 1).is_none());
    }

    #[test]
    fn test_tenant_sql_filter_without_tenant() {
        assert!(TenantContext::none().sql_filter("tenant_id", 1).is_none());
    }

    #[test]
    fn test_require_tenant() {
        let ctx = TenantContext::none();
        assert!(ctx.require_tenant().is_err());

        let tenant_id = Uuid::new_v4();
        let ctx = TenantContext::strict(tenant_id);
        assert!(ctx.require_tenant().is_ok());
    }

    #[test]
    fn test_isolation_mode_round_trip() {
        for mode in [
            TenantIsolationMode::None,
            TenantIsolationMode::Strict,
            TenantIsolationMode::ReadShared,
        ] {
            assert_eq!(mode.as_str().parse::<TenantIsolationMode>(), Ok(mode));
        }
    }

    #[test]
    fn test_isolation_mode_parse_rejects_unknown() {
        let err = "STRICT".parse::<TenantIsolationMode>().unwrap_err();
        assert_eq!(err, ParseTenantIsolationModeError("STRICT".to_string()));
        assert_eq!(err.to_string(), "invalid tenant isolation mode: 'STRICT'");
    }

    #[test]
    fn test_has_tenant_trait_delegates() {
        struct Request {
            ctx: TenantContext,
        }

        impl HasTenant for Request {
            fn tenant(&self) -> &TenantContext {
                &self.ctx
            }
        }

        let tenant_id = Uuid::new_v4();
        let scoped = Request {
            ctx: TenantContext::strict(tenant_id),
        };
        assert_eq!(scoped.tenant_id(), Some(tenant_id));
        assert!(matches!(scoped.require_tenant(), Ok(id) if id == tenant_id));

        let anonymous = Request {
            ctx: TenantContext::none(),
        };
        assert!(anonymous.tenant_id().is_none());
        assert!(anonymous.require_tenant().is_err());
    }
}