Skip to main content

forge_core/tenant/
mod.rs

1use std::str::FromStr;
2
3use uuid::Uuid;
4
5use crate::ForgeError;
6
/// Isolation policy attached to a tenant context.
///
/// [`TenantIsolationMode::None`] is the default. The enum is
/// `#[non_exhaustive]`, so matches written outside this crate need a wildcard
/// arm.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum TenantIsolationMode {
    /// No isolation. This is the `Default` value.
    #[default]
    None,
    /// NOTE(review): enforcement is not visible in this module; presumably
    /// both reads and writes are confined to the tenant. Confirm with the
    /// code that consumes this mode.
    Strict,
    /// NOTE(review): enforcement is not visible in this module; presumably
    /// reads may cross tenants while writes stay scoped. Confirm with the
    /// code that consumes this mode.
    ReadShared,
}
15
16impl TenantIsolationMode {
17    pub fn as_str(&self) -> &'static str {
18        match self {
19            Self::None => "none",
20            Self::Strict => "strict",
21            Self::ReadShared => "read_shared",
22        }
23    }
24}
25
/// Error produced when a string does not name a tenant isolation mode.
///
/// The wrapped `String` is the rejected input, kept verbatim so it can be
/// echoed back in the error message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParseTenantIsolationModeError(pub String);

impl std::fmt::Display for ParseTenantIsolationModeError {
    /// Writes the rejected input wrapped in single quotes after a fixed prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = self;
        write!(f, "invalid tenant isolation mode: '{}'", rejected)
    }
}

// No underlying cause to expose, so the default `Error` methods are enough.
impl std::error::Error for ParseTenantIsolationModeError {}
36
37impl FromStr for TenantIsolationMode {
38    type Err = ParseTenantIsolationModeError;
39
40    fn from_str(s: &str) -> Result<Self, Self::Err> {
41        match s {
42            "none" => Ok(Self::None),
43            "strict" => Ok(Self::Strict),
44            "read_shared" => Ok(Self::ReadShared),
45            _ => Err(ParseTenantIsolationModeError(s.to_string())),
46        }
47    }
48}
49
50#[derive(Debug, Clone, Default)]
51#[non_exhaustive]
52pub struct TenantContext {
53    pub tenant_id: Option<Uuid>,
54    pub isolation_mode: TenantIsolationMode,
55}
56
57impl TenantContext {
58    pub fn none() -> Self {
59        Self {
60            tenant_id: None,
61            isolation_mode: TenantIsolationMode::None,
62        }
63    }
64
65    pub fn new(tenant_id: Uuid, isolation_mode: TenantIsolationMode) -> Self {
66        Self {
67            tenant_id: Some(tenant_id),
68            isolation_mode,
69        }
70    }
71
72    pub fn strict(tenant_id: Uuid) -> Self {
73        Self::new(tenant_id, TenantIsolationMode::Strict)
74    }
75
76    pub fn has_tenant(&self) -> bool {
77        self.tenant_id.is_some()
78    }
79
80    pub fn require_tenant(&self) -> crate::Result<Uuid> {
81        self.tenant_id
82            .ok_or_else(|| ForgeError::Unauthorized("Tenant context required".into()))
83    }
84
85    pub fn requires_filtering(&self) -> bool {
86        self.tenant_id.is_some() && self.isolation_mode != TenantIsolationMode::None
87    }
88
89    /// Returns (SQL clause, param value) for tenant-scoped WHERE filtering.
90    pub fn sql_filter(&self, column: &str, param_index: u32) -> Option<(String, Uuid)> {
91        // Validate column name to prevent SQL injection via dynamic column names
92        if column.is_empty()
93            || !column
94                .bytes()
95                .all(|b| b.is_ascii_alphanumeric() || b == b'_')
96        {
97            return None;
98        }
99        self.tenant_id
100            .map(|id| (format!("\"{}\" = ${}", column, param_index), id))
101    }
102}
103
#[cfg(test)]
// Tests unwrap freely: a panic is how a failing test reports itself.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// A context built with `none()` has no tenant and asks for no filtering.
    #[test]
    fn test_tenant_context_none() {
        let empty = TenantContext::none();
        assert!(!empty.has_tenant());
        assert!(!empty.requires_filtering());
    }

    /// A strict context keeps the tenant id it was given and requires filtering.
    #[test]
    fn test_tenant_context_strict() {
        let expected = Uuid::new_v4();
        let scoped = TenantContext::strict(expected);
        assert!(scoped.has_tenant());
        assert!(scoped.requires_filtering());
        assert_eq!(scoped.tenant_id, Some(expected));
    }

    /// The filter quotes the column, uses the given placeholder index and
    /// hands back the tenant id as the value to bind.
    #[test]
    fn test_tenant_sql_filter() {
        let expected = Uuid::new_v4();
        let produced = TenantContext::strict(expected).sql_filter("tenant_id", 1);
        assert!(produced.is_some());
        let (sql, bound) = produced.unwrap();
        assert_eq!(sql, "\"tenant_id\" = $1");
        assert_eq!(bound, expected);
    }

    /// `require_tenant` fails without a tenant and succeeds with one.
    #[test]
    fn test_require_tenant() {
        assert!(TenantContext::none().require_tenant().is_err());

        let scoped = TenantContext::strict(Uuid::new_v4());
        assert!(scoped.require_tenant().is_ok());
    }
}
145}