1use std::str::FromStr;
2
3use uuid::Uuid;
4
5use crate::ForgeError;
6
/// Tenant isolation level attached to a [`TenantContext`].
///
/// The canonical string form returned by [`TenantIsolationMode::as_str`] (and
/// produced by `Display`) is exactly the text accepted by this type's `FromStr`
/// implementation: `"none"`, `"strict"`, `"read_shared"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum TenantIsolationMode {
    /// No isolation; this is the `Default`. A context in this mode never
    /// reports `requires_filtering() == true`.
    #[default]
    None,
    /// Strict isolation (canonical name `"strict"`); the mode used by
    /// `TenantContext::strict`.
    Strict,
    /// Canonical name `"read_shared"`.
    /// NOTE(review): exact read/write sharing semantics are defined by callers — confirm and document.
    ReadShared,
}

impl TenantIsolationMode {
    /// Returns the canonical lowercase name of this mode.
    ///
    /// This is a `const fn`, so it can also be used in constant contexts.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::None => "none",
            Self::Strict => "strict",
            Self::ReadShared => "read_shared",
        }
    }
}

impl std::fmt::Display for TenantIsolationMode {
    /// Writes the canonical name (same text as `as_str`).
    ///
    /// `pad` is used instead of `write_str` so width/alignment flags such as
    /// `{:>10}` behave the same way they do for a plain `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
29
/// Error produced when a string is not one of the canonical
/// [`TenantIsolationMode`] names.
///
/// The wrapped `String` is the rejected input, stored verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseTenantIsolationModeError(pub String);

impl std::fmt::Display for ParseTenantIsolationModeError {
    /// Renders as `invalid tenant isolation mode: '<input>'`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rejected = &self.0;
        write!(f, "invalid tenant isolation mode: '{rejected}'")
    }
}

// No source error to chain; the `Display` text carries all the detail.
impl std::error::Error for ParseTenantIsolationModeError {}
40
41impl FromStr for TenantIsolationMode {
42 type Err = ParseTenantIsolationModeError;
43
44 fn from_str(s: &str) -> Result<Self, Self::Err> {
45 match s {
46 "none" => Ok(Self::None),
47 "strict" => Ok(Self::Strict),
48 "read_shared" => Ok(Self::ReadShared),
49 _ => Err(ParseTenantIsolationModeError(s.to_string())),
50 }
51 }
52}
53
54#[derive(Debug, Clone, Default)]
56pub struct TenantContext {
57 pub tenant_id: Option<Uuid>,
59 pub isolation_mode: TenantIsolationMode,
61}
62
63impl TenantContext {
64 pub fn none() -> Self {
66 Self {
67 tenant_id: None,
68 isolation_mode: TenantIsolationMode::None,
69 }
70 }
71
72 pub fn new(tenant_id: Uuid, isolation_mode: TenantIsolationMode) -> Self {
74 Self {
75 tenant_id: Some(tenant_id),
76 isolation_mode,
77 }
78 }
79
80 pub fn strict(tenant_id: Uuid) -> Self {
82 Self::new(tenant_id, TenantIsolationMode::Strict)
83 }
84
85 pub fn has_tenant(&self) -> bool {
87 self.tenant_id.is_some()
88 }
89
90 pub fn require_tenant(&self) -> crate::Result<Uuid> {
92 self.tenant_id
93 .ok_or_else(|| ForgeError::Unauthorized("Tenant context required".into()))
94 }
95
96 pub fn requires_filtering(&self) -> bool {
98 self.tenant_id.is_some() && self.isolation_mode != TenantIsolationMode::None
99 }
100
101 pub fn sql_filter(&self, column: &str, param_index: u32) -> Option<(String, Uuid)> {
104 self.tenant_id
105 .map(|id| (format!("{} = ${}", column, param_index), id))
106 }
107}
108
109pub trait HasTenant {
111 fn tenant(&self) -> &TenantContext;
113
114 fn tenant_id(&self) -> Option<Uuid> {
116 self.tenant().tenant_id
117 }
118
119 fn require_tenant(&self) -> crate::Result<Uuid> {
121 self.tenant().require_tenant()
122 }
123}
124
#[cfg(test)]
// Tests may unwrap/index freely: a panic is exactly the failure signal we want.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Minimal `HasTenant` implementor used to exercise the trait's default methods.
    struct Holder(TenantContext);

    impl HasTenant for Holder {
        fn tenant(&self) -> &TenantContext {
            &self.0
        }
    }

    #[test]
    fn test_tenant_context_none() {
        let ctx = TenantContext::none();
        assert!(!ctx.has_tenant());
        assert!(!ctx.requires_filtering());
        assert!(ctx.sql_filter("tenant_id", 1).is_none());
    }

    #[test]
    fn test_tenant_context_default_matches_none() {
        let ctx = TenantContext::default();
        assert!(ctx.tenant_id.is_none());
        assert_eq!(ctx.isolation_mode, TenantIsolationMode::None);
        assert_eq!(TenantIsolationMode::default(), TenantIsolationMode::None);
    }

    #[test]
    fn test_tenant_context_strict() {
        let tenant_id = Uuid::new_v4();
        let ctx = TenantContext::strict(tenant_id);
        assert!(ctx.has_tenant());
        assert!(ctx.requires_filtering());
        assert_eq!(ctx.tenant_id, Some(tenant_id));
        assert_eq!(ctx.isolation_mode, TenantIsolationMode::Strict);
    }

    #[test]
    fn test_filtering_depends_on_isolation_mode() {
        let tenant_id = Uuid::new_v4();
        let unisolated = TenantContext::new(tenant_id, TenantIsolationMode::None);
        assert!(unisolated.has_tenant());
        assert!(!unisolated.requires_filtering());

        let shared = TenantContext::new(tenant_id, TenantIsolationMode::ReadShared);
        assert!(shared.requires_filtering());
    }

    #[test]
    fn test_tenant_sql_filter() {
        let tenant_id = Uuid::new_v4();
        let ctx = TenantContext::strict(tenant_id);
        let (clause, id) = ctx
            .sql_filter("tenant_id", 1)
            .expect("a context with a tenant must yield a filter");
        assert_eq!(clause, "tenant_id = $1");
        assert_eq!(id, tenant_id);
    }

    #[test]
    fn test_require_tenant() {
        let ctx = TenantContext::none();
        assert!(ctx.require_tenant().is_err());

        let tenant_id = Uuid::new_v4();
        let ctx = TenantContext::strict(tenant_id);
        assert_eq!(ctx.require_tenant().ok(), Some(tenant_id));
    }

    #[test]
    fn test_isolation_mode_round_trip() {
        for mode in [
            TenantIsolationMode::None,
            TenantIsolationMode::Strict,
            TenantIsolationMode::ReadShared,
        ]
        .iter()
        .copied()
        {
            assert_eq!(mode.as_str().parse::<TenantIsolationMode>(), Ok(mode));
        }
    }

    #[test]
    fn test_isolation_mode_parse_rejects_unknown() {
        // Parsing is case-sensitive: only the exact canonical names are accepted.
        let err = "Strict".parse::<TenantIsolationMode>().unwrap_err();
        assert_eq!(err, ParseTenantIsolationModeError("Strict".to_string()));
        assert_eq!(err.to_string(), "invalid tenant isolation mode: 'Strict'");
        assert!("".parse::<TenantIsolationMode>().is_err());
    }

    #[test]
    fn test_has_tenant_trait_defaults() {
        let tenant_id = Uuid::new_v4();
        let holder = Holder(TenantContext::strict(tenant_id));
        assert_eq!(holder.tenant_id(), Some(tenant_id));
        assert_eq!(holder.require_tenant().ok(), Some(tenant_id));

        let empty = Holder(TenantContext::none());
        assert_eq!(empty.tenant_id(), None);
        assert!(empty.require_tenant().is_err());
    }
}