Skip to main content

reddb_server/runtime/
within_clause.rs

//! `WITHIN TENANT '<id>' [USER '<u>'] [AS ROLE '<r>'] <stmt>` —
//! a per-statement scope override for tenant + auth identity.
//!
//! Designed for SaaS deployments where one process / one connection
//! pool serves many tenants. The clause carries the scope inline with
//! the query, so:
//!   * no thread-local state survives the call
//!   * connection pools cannot leak tenant context between checkouts
//!   * async runtimes that move tasks between threads stay correct
//!     (the scope lives in a stack pushed/popped by the same execute
//!     call — no `.await` in between)
//!   * clients can use prepared statements normally
//!
//! Values: a string literal (`'acme'`) or `NULL` (clears just that field,
//! hiding any inherited value). `TENANT` is required; `USER` and `ROLE` are
//! optional, the clauses may appear in any order, and the `AS` before
//! `ROLE` may be omitted.
15
16use crate::storage::query::lexer::{Lexer, Token};
17
/// How one overridable `WITHIN` field (tenant, user or role) relates to the
/// value the runtime would otherwise use for it.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub enum FieldOverride {
    /// The `WITHIN` clause did not mention this field. The runtime falls
    /// through to its prior source (session-level or auth-installed).
    #[default]
    Inherit,
    /// The clause explicitly set this field to `NULL`. This must hide the
    /// inherited value rather than fall through to it.
    Clear,
    /// The clause supplied this string literal.
    Set(String),
}

impl FieldOverride {
    /// Returns `true` when this override takes precedence over any inherited
    /// value, i.e. for [`FieldOverride::Clear`] and [`FieldOverride::Set`].
    pub fn is_active(&self) -> bool {
        match self {
            Self::Inherit => false,
            Self::Clear | Self::Set(_) => true,
        }
    }

    /// Combines this override with `inherited` to get the effective value.
    ///
    /// * `Set(v)` yields a copy of `v`, ignoring `inherited`.
    /// * `Clear` yields `None`, ignoring `inherited`.
    /// * `Inherit` yields `inherited` unchanged.
    pub fn resolve(&self, inherited: Option<String>) -> Option<String> {
        match self {
            Self::Set(literal) => Some(literal.to_owned()),
            Self::Clear => None,
            Self::Inherit => inherited,
        }
    }
}
50
51/// Per-statement scope override extracted from a `WITHIN ...` prefix.
52#[derive(Debug, Clone, Default, PartialEq, Eq)]
53pub struct ScopeOverride {
54    pub tenant: FieldOverride,
55    pub user: FieldOverride,
56    pub role: FieldOverride,
57}
58
59impl ScopeOverride {
60    pub fn is_empty(&self) -> bool {
61        !self.tenant.is_active() && !self.user.is_active() && !self.role.is_active()
62    }
63}
64
65/// Try to recognise a `WITHIN TENANT ... <stmt>` prefix at the start
66/// of `input`. Returns the parsed scope plus the remainder slice
67/// (the inner statement), or `None` when the input doesn't start with
68/// `WITHIN`. A malformed `WITHIN ...` clause returns `Err`.
69pub fn try_strip_within_prefix(input: &str) -> Result<Option<(ScopeOverride, &str)>, String> {
70    let trimmed = input.trim_start();
71    let first_word = trimmed
72        .split(|c: char| c.is_whitespace())
73        .next()
74        .unwrap_or("");
75    if !first_word.eq_ignore_ascii_case("WITHIN") {
76        return Ok(None);
77    }
78
79    let mut lexer = Lexer::new(input);
80    expect_ident(&mut lexer, "WITHIN")?;
81
82    let mut scope = ScopeOverride::default();
83    let mut tenant_seen = false;
84
85    loop {
86        let spanned = lexer.next_token().map_err(|e| e.to_string())?;
87        match spanned.token {
88            Token::Ident(ref name) if name.eq_ignore_ascii_case("TENANT") => {
89                if tenant_seen {
90                    return Err("duplicate TENANT clause in WITHIN prefix".into());
91                }
92                tenant_seen = true;
93                scope.tenant = parse_value(&mut lexer)?;
94            }
95            Token::Ident(ref name) if name.eq_ignore_ascii_case("USER") => {
96                if scope.user.is_active() {
97                    return Err("duplicate USER clause in WITHIN prefix".into());
98                }
99                scope.user = parse_value(&mut lexer)?;
100            }
101            Token::As => {
102                expect_ident(&mut lexer, "ROLE")?;
103                if scope.role.is_active() {
104                    return Err("duplicate AS ROLE clause in WITHIN prefix".into());
105                }
106                scope.role = parse_value(&mut lexer)?;
107            }
108            Token::Ident(ref name) if name.eq_ignore_ascii_case("ROLE") => {
109                if scope.role.is_active() {
110                    return Err("duplicate ROLE clause in WITHIN prefix".into());
111                }
112                scope.role = parse_value(&mut lexer)?;
113            }
114            // Anything else is the start of the inner statement — peel
115            // back to the offset where this token began so the inner
116            // query string slice keeps the leading keyword intact.
117            _ => {
118                if !tenant_seen {
119                    return Err("WITHIN clause requires at least TENANT '<id>' (or NULL)".into());
120                }
121                let offset = spanned.start.offset as usize;
122                if offset > input.len() {
123                    return Err("internal: WITHIN clause offset out of range".into());
124                }
125                let inner = input[offset..].trim_start();
126                if inner.is_empty() {
127                    return Err("WITHIN clause has no inner statement to execute".into());
128                }
129                return Ok(Some((scope, inner)));
130            }
131        }
132    }
133}
134
135fn expect_ident(lexer: &mut Lexer<'_>, expected: &str) -> Result<(), String> {
136    let spanned = lexer.next_token().map_err(|e| e.to_string())?;
137    match spanned.token {
138        Token::Ident(name) if name.eq_ignore_ascii_case(expected) => Ok(()),
139        other => Err(format!(
140            "expected `{expected}` in WITHIN prefix, got {other:?}"
141        )),
142    }
143}
144
145fn parse_value(lexer: &mut Lexer<'_>) -> Result<FieldOverride, String> {
146    let spanned = lexer.next_token().map_err(|e| e.to_string())?;
147    match spanned.token {
148        Token::String(s) => Ok(FieldOverride::Set(s)),
149        Token::Null => Ok(FieldOverride::Clear),
150        other => Err(format!(
151            "WITHIN clause value must be a string literal or NULL, got {other:?}"
152        )),
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_within_prefix_returns_none() {
        assert!(try_strip_within_prefix("SELECT * FROM x")
            .unwrap()
            .is_none());
        assert!(try_strip_within_prefix("  SELECT * FROM x")
            .unwrap()
            .is_none());
    }

    #[test]
    fn empty_input_returns_none() {
        assert!(try_strip_within_prefix("").unwrap().is_none());
        assert!(try_strip_within_prefix("   ").unwrap().is_none());
    }

    #[test]
    fn word_merely_starting_with_within_is_not_a_prefix() {
        assert!(try_strip_within_prefix("WITHINX TENANT 'a' SELECT 1")
            .unwrap()
            .is_none());
    }

    #[test]
    fn parses_tenant_only() {
        let (scope, inner) = try_strip_within_prefix("WITHIN TENANT 'acme' SELECT * FROM x")
            .unwrap()
            .unwrap();
        assert_eq!(scope.tenant, FieldOverride::Set("acme".into()));
        assert_eq!(scope.user, FieldOverride::Inherit);
        assert_eq!(scope.role, FieldOverride::Inherit);
        assert_eq!(inner, "SELECT * FROM x");
    }

    #[test]
    fn parses_full_clause() {
        let (scope, inner) = try_strip_within_prefix(
            "WITHIN TENANT 'acme' USER 'filipe' AS ROLE 'admin' SELECT * FROM x",
        )
        .unwrap()
        .unwrap();
        assert_eq!(scope.tenant, FieldOverride::Set("acme".into()));
        assert_eq!(scope.user, FieldOverride::Set("filipe".into()));
        assert_eq!(scope.role, FieldOverride::Set("admin".into()));
        assert_eq!(inner, "SELECT * FROM x");
    }

    #[test]
    fn leading_whitespace_before_within_is_ignored() {
        let (scope, inner) = try_strip_within_prefix("  WITHIN TENANT 'acme'   SELECT 1")
            .unwrap()
            .unwrap();
        assert_eq!(scope.tenant, FieldOverride::Set("acme".into()));
        assert_eq!(inner, "SELECT 1");
    }

    #[test]
    fn role_without_as_keyword() {
        let (scope, inner) = try_strip_within_prefix("WITHIN TENANT 'acme' ROLE 'reader' SELECT 1")
            .unwrap()
            .unwrap();
        assert_eq!(scope.role, FieldOverride::Set("reader".into()));
        assert_eq!(scope.user, FieldOverride::Inherit);
        assert_eq!(inner, "SELECT 1");
    }

    #[test]
    fn clauses_may_appear_in_any_order() {
        let (scope, inner) =
            try_strip_within_prefix("WITHIN USER 'u' AS ROLE 'r' TENANT 'acme' SELECT 1")
                .unwrap()
                .unwrap();
        assert_eq!(scope.tenant, FieldOverride::Set("acme".into()));
        assert_eq!(scope.user, FieldOverride::Set("u".into()));
        assert_eq!(scope.role, FieldOverride::Set("r".into()));
        assert_eq!(inner, "SELECT 1");
    }

    #[test]
    fn null_tenant_clears() {
        let (scope, _) = try_strip_within_prefix("WITHIN TENANT NULL SELECT 1")
            .unwrap()
            .unwrap();
        assert_eq!(scope.tenant, FieldOverride::Clear);
    }

    #[test]
    fn null_user_and_role_clear() {
        let (scope, _) =
            try_strip_within_prefix("WITHIN TENANT 'acme' USER NULL AS ROLE NULL SELECT 1")
                .unwrap()
                .unwrap();
        assert_eq!(scope.user, FieldOverride::Clear);
        assert_eq!(scope.role, FieldOverride::Clear);
    }

    #[test]
    fn parsed_scope_is_never_empty() {
        // TENANT is mandatory, so a successfully parsed prefix always has at
        // least one active field — even when that field is cleared.
        let (scope, _) = try_strip_within_prefix("WITHIN TENANT NULL SELECT 1")
            .unwrap()
            .unwrap();
        assert!(!scope.is_empty());
    }

    #[test]
    fn rejects_missing_tenant() {
        assert!(try_strip_within_prefix("WITHIN USER 'x' SELECT 1").is_err());
    }

    #[test]
    fn rejects_duplicate_clause() {
        assert!(try_strip_within_prefix("WITHIN TENANT 'a' TENANT 'b' SELECT 1").is_err());
    }

    #[test]
    fn rejects_duplicate_clause_even_after_null() {
        assert!(try_strip_within_prefix("WITHIN TENANT NULL TENANT 'b' SELECT 1").is_err());
        assert!(
            try_strip_within_prefix("WITHIN TENANT 'a' USER NULL USER 'b' SELECT 1").is_err()
        );
        assert!(
            try_strip_within_prefix("WITHIN TENANT 'a' ROLE 'r' AS ROLE 's' SELECT 1").is_err()
        );
    }

    #[test]
    fn rejects_non_literal_value() {
        assert!(try_strip_within_prefix("WITHIN TENANT acme SELECT 1").is_err());
        assert!(try_strip_within_prefix("WITHIN TENANT 42 SELECT 1").is_err());
    }

    #[test]
    fn rejects_as_without_role() {
        assert!(try_strip_within_prefix("WITHIN TENANT 'acme' AS USER 'u' SELECT 1").is_err());
    }

    #[test]
    fn rejects_missing_inner_statement() {
        assert!(try_strip_within_prefix("WITHIN TENANT 'acme'").is_err());
        assert!(try_strip_within_prefix("WITHIN TENANT 'acme'   ").is_err());
    }

    #[test]
    fn case_insensitive() {
        let (scope, inner) = try_strip_within_prefix("within tenant 'acme' select * from x")
            .unwrap()
            .unwrap();
        assert_eq!(scope.tenant, FieldOverride::Set("acme".into()));
        assert_eq!(inner, "select * from x");
    }
}