1use std::str::FromStr;
2
3use uuid::Uuid;
4
5use crate::ForgeError;
6
/// How strictly data access is scoped to a single tenant.
///
/// Each variant has a canonical string name (see `TenantIsolationMode::as_str`)
/// that the `FromStr` implementation accepts back.
#[derive(Debug, Clone, Copy)]
#[derive(Default, PartialEq, Eq)]
pub enum TenantIsolationMode {
    /// No isolation; this is the `Default`. String name: `"none"`.
    /// `TenantContext::requires_filtering` reports `false` in this mode.
    #[default]
    None,
    /// Strict isolation. String name: `"strict"`.
    Strict,
    /// String name: `"read_shared"`.
    /// NOTE(review): the name suggests reads may be shared across tenants, but
    /// `TenantContext::requires_filtering` treats it the same as `Strict` —
    /// confirm the intended semantics.
    ReadShared,
}
18
19impl TenantIsolationMode {
20 pub fn as_str(&self) -> &'static str {
22 match self {
23 Self::None => "none",
24 Self::Strict => "strict",
25 Self::ReadShared => "read_shared",
26 }
27 }
28}
29
30impl FromStr for TenantIsolationMode {
31 type Err = std::convert::Infallible;
32
33 fn from_str(s: &str) -> Result<Self, Self::Err> {
34 Ok(match s {
35 "strict" => Self::Strict,
36 "read_shared" => Self::ReadShared,
37 _ => Self::None,
38 })
39 }
40}
41
42#[derive(Debug, Clone, Default)]
44pub struct TenantContext {
45 pub tenant_id: Option<Uuid>,
47 pub isolation_mode: TenantIsolationMode,
49}
50
51impl TenantContext {
52 pub fn none() -> Self {
54 Self {
55 tenant_id: None,
56 isolation_mode: TenantIsolationMode::None,
57 }
58 }
59
60 pub fn new(tenant_id: Uuid, isolation_mode: TenantIsolationMode) -> Self {
62 Self {
63 tenant_id: Some(tenant_id),
64 isolation_mode,
65 }
66 }
67
68 pub fn strict(tenant_id: Uuid) -> Self {
70 Self::new(tenant_id, TenantIsolationMode::Strict)
71 }
72
73 pub fn has_tenant(&self) -> bool {
75 self.tenant_id.is_some()
76 }
77
78 pub fn require_tenant(&self) -> crate::Result<Uuid> {
80 self.tenant_id
81 .ok_or_else(|| ForgeError::Unauthorized("Tenant context required".into()))
82 }
83
84 pub fn requires_filtering(&self) -> bool {
86 self.tenant_id.is_some() && self.isolation_mode != TenantIsolationMode::None
87 }
88
89 pub fn sql_filter(&self, column: &str) -> Option<String> {
91 self.tenant_id.map(|id| format!("{} = '{}'", column, id))
92 }
93}
94
95pub trait HasTenant {
97 fn tenant(&self) -> &TenantContext;
99
100 fn tenant_id(&self) -> Option<Uuid> {
102 self.tenant().tenant_id
103 }
104
105 fn require_tenant(&self) -> crate::Result<Uuid> {
107 self.tenant().require_tenant()
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `HasTenant` implementor used to exercise the trait's default methods.
    struct Request {
        ctx: TenantContext,
    }

    impl HasTenant for Request {
        fn tenant(&self) -> &TenantContext {
            &self.ctx
        }
    }

    #[test]
    fn test_tenant_context_none() {
        let ctx = TenantContext::none();
        assert!(!ctx.has_tenant());
        assert!(!ctx.requires_filtering());
        assert_eq!(ctx.sql_filter("tenant_id"), None);
    }

    #[test]
    fn test_tenant_context_default_matches_none() {
        let ctx = TenantContext::default();
        assert_eq!(ctx.tenant_id, None);
        assert_eq!(ctx.isolation_mode, TenantIsolationMode::None);
    }

    #[test]
    fn test_tenant_context_strict() {
        let tenant_id = Uuid::new_v4();
        let ctx = TenantContext::strict(tenant_id);
        assert!(ctx.has_tenant());
        assert!(ctx.requires_filtering());
        assert_eq!(ctx.tenant_id, Some(tenant_id));
        assert_eq!(ctx.isolation_mode, TenantIsolationMode::Strict);
    }

    #[test]
    fn test_requires_filtering_by_mode() {
        let tenant_id = Uuid::new_v4();

        // A tenant with isolation disabled is not filtered...
        assert!(!TenantContext::new(tenant_id, TenantIsolationMode::None).requires_filtering());
        assert!(TenantContext::new(tenant_id, TenantIsolationMode::ReadShared).requires_filtering());

        // ...and a non-None mode without a tenant is never filtered either.
        let no_tenant = TenantContext {
            tenant_id: None,
            isolation_mode: TenantIsolationMode::Strict,
        };
        assert!(!no_tenant.requires_filtering());
    }

    #[test]
    fn test_tenant_sql_filter() {
        let tenant_id = Uuid::new_v4();
        let ctx = TenantContext::strict(tenant_id);
        // Pin the exact predicate shape, not just that the id appears somewhere.
        assert_eq!(
            ctx.sql_filter("tenant_id"),
            Some(format!("tenant_id = '{}'", tenant_id))
        );
    }

    #[test]
    fn test_isolation_mode_round_trip() {
        for mode in [
            TenantIsolationMode::None,
            TenantIsolationMode::Strict,
            TenantIsolationMode::ReadShared,
        ] {
            assert_eq!(mode.as_str().parse::<TenantIsolationMode>(), Ok(mode));
        }
    }

    #[test]
    fn test_isolation_mode_unknown_falls_back_to_none() {
        assert_eq!("".parse::<TenantIsolationMode>(), Ok(TenantIsolationMode::None));
        assert_eq!("bogus".parse::<TenantIsolationMode>(), Ok(TenantIsolationMode::None));
    }

    #[test]
    fn test_require_tenant() {
        let ctx = TenantContext::none();
        assert!(matches!(ctx.require_tenant(), Err(ForgeError::Unauthorized(_))));

        let tenant_id = Uuid::new_v4();
        let ctx = TenantContext::strict(tenant_id);
        assert_eq!(ctx.require_tenant().ok(), Some(tenant_id));
    }

    #[test]
    fn test_has_tenant_default_methods() {
        let tenant_id = Uuid::new_v4();
        let req = Request {
            ctx: TenantContext::strict(tenant_id),
        };
        assert_eq!(req.tenant_id(), Some(tenant_id));
        assert_eq!(req.require_tenant().ok(), Some(tenant_id));

        let anon = Request {
            ctx: TenantContext::none(),
        };
        assert_eq!(anon.tenant_id(), None);
        assert!(anon.require_tenant().is_err());
    }
}