Skip to main content

assay_auth/zanzibar/
postgres.rs

1//! Postgres [`ZanzibarStore`] implementation — recursive CTE walks
2//! over `auth.zanzibar_tuples` for `check`, `expand`, and the lookup
3//! pair.
4//!
//! Layout: each graph walk is one SQL round-trip. The CTE seeds with the
6//! direct tuples on the resource for whatever relation set the
7//! [`super::resolve`] pass produced, then expands one userset hop at
8//! a time bounded by [`super::types::MAX_DEPTH`]. Cycle detection
9//! uses an in-CTE `path` array — if the next subject is already on
10//! the path, the join skips it.
11//!
12//! Performance notes (PG18 EXPLAIN-checked once during
13//! development against a 100k-tuple seed):
14//!
15//! - Forward walk hits the composite PK `(object_type, object_id,
16//!   relation, …)` for the seed and then for each hop.
17//! - Reverse `lookup_resources` hits `idx_auth_zanzibar_tuples_rev`
18//!   for the seed; subsequent hops still use the PK because the next
19//!   step is "given subject, find object" → forward direction again.
20//! - JSONB `auth.zanzibar_namespaces.schema_json` is round-tripped
21//!   via `serde_json::Value` (sqlx's `JsonValue` mapping); for the
22//!   100-namespace order-of-magnitude any v0.2.0 deployment will
23//!   actually have, parsing on every check is fine — the schema
24//!   cache is a phase-9 follow-up.
25
26use std::sync::Arc;
27
28use anyhow::{Context, Result};
29use sqlx::{PgPool, Row};
30
31use super::resolve::resolve;
32use super::store::ZanzibarStore;
33use super::types::{
34    CheckResult, Consistency, NamespaceSchema, ObjectRef, SubjectRef, Tuple, TreeOp,
35    UsersetTree, MAX_DEPTH,
36};
37
38/// Postgres-backed Zanzibar store. Cheap to clone (the underlying
39/// `PgPool` is `Arc` internally).
40#[derive(Clone)]
41pub struct PostgresZanzibarStore {
42    pool: PgPool,
43}
44
45impl PostgresZanzibarStore {
46    pub fn new(pool: PgPool) -> Self {
47        Self { pool }
48    }
49
50    /// Wrap into an `Arc<dyn ZanzibarStore>` for [`crate::ctx::AuthCtx`].
51    pub fn into_dyn(self) -> Arc<dyn ZanzibarStore> {
52        Arc::new(self)
53    }
54}
55
56#[async_trait::async_trait]
57impl ZanzibarStore for PostgresZanzibarStore {
58    async fn define_namespace(&self, schema: &NamespaceSchema) -> Result<()> {
59        let json = serde_json::to_value(schema)
60            .context("zanzibar serialize NamespaceSchema")?;
61        sqlx::query(
62            "INSERT INTO auth.zanzibar_namespaces (name, schema_json, updated_at)
63             VALUES ($1, $2, EXTRACT(EPOCH FROM NOW()))
64             ON CONFLICT (name) DO UPDATE
65                 SET schema_json = EXCLUDED.schema_json,
66                     updated_at = EXCLUDED.updated_at",
67        )
68        .bind(&schema.name)
69        .bind(json)
70        .execute(&self.pool)
71        .await
72        .context("auth.zanzibar_namespaces upsert")?;
73        Ok(())
74    }
75
76    async fn get_namespace(&self, name: &str) -> Result<Option<NamespaceSchema>> {
77        let row: Option<(serde_json::Value,)> = sqlx::query_as(
78            "SELECT schema_json FROM auth.zanzibar_namespaces WHERE name = $1",
79        )
80        .bind(name)
81        .fetch_optional(&self.pool)
82        .await
83        .context("auth.zanzibar_namespaces get")?;
84        Ok(match row {
85            Some((json,)) => Some(
86                serde_json::from_value(json)
87                    .context("zanzibar deserialize NamespaceSchema")?,
88            ),
89            None => None,
90        })
91    }
92
93    async fn list_namespaces(&self) -> Result<Vec<NamespaceSchema>> {
94        let rows: Vec<(serde_json::Value,)> = sqlx::query_as(
95            "SELECT schema_json FROM auth.zanzibar_namespaces ORDER BY name",
96        )
97        .fetch_all(&self.pool)
98        .await
99        .context("auth.zanzibar_namespaces list")?;
100        rows.into_iter()
101            .map(|(json,)| {
102                serde_json::from_value(json).context("zanzibar deserialize NamespaceSchema")
103            })
104            .collect()
105    }
106
107    async fn write_tuple(&self, t: &Tuple) -> Result<()> {
108        sqlx::query(
109            "INSERT INTO auth.zanzibar_tuples
110                (object_type, object_id, relation,
111                 subject_type, subject_id, subject_rel, created_at)
112             VALUES ($1, $2, $3, $4, $5, $6, EXTRACT(EPOCH FROM NOW()))
113             ON CONFLICT DO NOTHING",
114        )
115        .bind(&t.object_type)
116        .bind(&t.object_id)
117        .bind(&t.relation)
118        .bind(&t.subject_type)
119        .bind(&t.subject_id)
120        .bind(&t.subject_rel)
121        .execute(&self.pool)
122        .await
123        .context("auth.zanzibar_tuples insert")?;
124        Ok(())
125    }
126
127    async fn write_tuples(&self, tuples: &[Tuple]) -> Result<()> {
128        if tuples.is_empty() {
129            return Ok(());
130        }
131        let mut tx = self.pool.begin().await.context("begin tuples txn")?;
132        for t in tuples {
133            sqlx::query(
134                "INSERT INTO auth.zanzibar_tuples
135                    (object_type, object_id, relation,
136                     subject_type, subject_id, subject_rel, created_at)
137                 VALUES ($1, $2, $3, $4, $5, $6, EXTRACT(EPOCH FROM NOW()))
138                 ON CONFLICT DO NOTHING",
139            )
140            .bind(&t.object_type)
141            .bind(&t.object_id)
142            .bind(&t.relation)
143            .bind(&t.subject_type)
144            .bind(&t.subject_id)
145            .bind(&t.subject_rel)
146            .execute(&mut *tx)
147            .await
148            .context("auth.zanzibar_tuples batch insert")?;
149        }
150        tx.commit().await.context("commit tuples txn")?;
151        Ok(())
152    }
153
154    async fn delete_tuple(&self, t: &Tuple) -> Result<bool> {
155        // subject_rel is NOT NULL ('' for direct), so plain equality
156        // suffices — no IS NOT DISTINCT FROM dance.
157        let res = sqlx::query(
158            "DELETE FROM auth.zanzibar_tuples
159             WHERE object_type = $1 AND object_id = $2 AND relation = $3
160               AND subject_type = $4 AND subject_id = $5
161               AND subject_rel = $6",
162        )
163        .bind(&t.object_type)
164        .bind(&t.object_id)
165        .bind(&t.relation)
166        .bind(&t.subject_type)
167        .bind(&t.subject_id)
168        .bind(&t.subject_rel)
169        .execute(&self.pool)
170        .await
171        .context("auth.zanzibar_tuples delete")?;
172        Ok(res.rows_affected() > 0)
173    }
174
175    async fn check(
176        &self,
177        resource: &ObjectRef,
178        permission: &str,
179        subject: &SubjectRef,
180        _consistency: Consistency,
181    ) -> Result<CheckResult> {
182        // 1) Resolve the permission to its candidate relation set.
183        //    No namespace defined → deny (the safe default).
184        let Some(schema) = self.get_namespace(&resource.object_type).await? else {
185            return Ok(CheckResult::Denied);
186        };
187        let Some(resolved) = resolve(&schema, permission) else {
188            // Schema cycle. Surface as `CycleDetected` so callers/UI
189            // can flag the bad schema; this is *not* an error.
190            return Ok(CheckResult::CycleDetected);
191        };
192        if resolved.union_relations.is_empty() {
193            return Ok(CheckResult::Denied);
194        }
195        let relation_list: Vec<String> = resolved.union_relations.into_iter().collect();
196
197        // 2) Run the recursive walk. subject_rel is NOT NULL — '' for
198        //    direct subjects (the terminal kind `check` answers) and a
199        //    relation name for usersets the CTE walks through. Cycle
200        //    guard via the in-CTE `path` array — if the next subject is
201        //    already on the path, the join produces zero rows.
202        let row: Option<(i32,)> = sqlx::query_as(
203            r#"
204            WITH RECURSIVE walk(subject_type, subject_id, subject_rel, depth, path) AS (
205                SELECT t.subject_type,
206                       t.subject_id,
207                       t.subject_rel,
208                       1 AS depth,
209                       ARRAY[t.subject_type || ':' || t.subject_id] AS path
210                FROM auth.zanzibar_tuples t
211                WHERE t.object_type = $1 AND t.object_id = $2 AND t.relation = ANY($3)
212                UNION ALL
213                SELECT t.subject_type,
214                       t.subject_id,
215                       t.subject_rel,
216                       w.depth + 1,
217                       w.path || (t.subject_type || ':' || t.subject_id)
218                FROM auth.zanzibar_tuples t
219                JOIN walk w
220                  ON t.object_type = w.subject_type
221                 AND t.object_id   = w.subject_id
222                 AND w.subject_rel <> ''
223                 AND t.relation = w.subject_rel
224                WHERE w.depth < $4
225                  AND NOT (t.subject_type || ':' || t.subject_id) = ANY(w.path)
226            )
227            SELECT CASE
228                WHEN EXISTS (
229                    SELECT 1 FROM walk
230                    WHERE subject_type = $5 AND subject_id = $6 AND subject_rel = ''
231                ) THEN 1
232                WHEN EXISTS (SELECT 1 FROM walk WHERE depth >= $4) THEN 2
233                ELSE 0
234            END AS verdict
235            "#,
236        )
237        .bind(&resource.object_type)
238        .bind(&resource.object_id)
239        .bind(&relation_list)
240        .bind(MAX_DEPTH as i32)
241        .bind(&subject.subject_type)
242        .bind(&subject.subject_id)
243        .fetch_optional(&self.pool)
244        .await
245        .context("auth.zanzibar check CTE")?;
246
247        let verdict = row.map(|(v,)| v).unwrap_or(0);
248        Ok(match verdict {
249            1 => CheckResult::Allowed {
250                resolved_via: Vec::new(),
251            },
252            2 => CheckResult::DepthExceeded,
253            _ => CheckResult::Denied,
254        })
255    }
256
257    async fn expand(
258        &self,
259        resource: &ObjectRef,
260        relation: &str,
261        depth_limit: u32,
262    ) -> Result<UsersetTree> {
263        // Single-relation expansion — pull every direct subject of
264        // `resource#relation`, then for each userset hop, recurse.
265        // Diagnostic-only path; we keep the implementation simple
266        // (recursive Rust calls) rather than another CTE.
267        let depth = depth_limit.min(MAX_DEPTH);
268        Ok(UsersetTree::Node {
269            op: TreeOp::Direct,
270            children: expand_pg(self, resource, relation, depth, &mut Vec::new()).await?,
271        })
272    }
273
274    async fn lookup_resources(
275        &self,
276        resource_type: &str,
277        permission: &str,
278        subject: &SubjectRef,
279    ) -> Result<Vec<ObjectRef>> {
280        // Cheap heuristic for v0.2.0: we walk forward from every
281        // resource of the requested type that has *any* tuple in the
282        // candidate relation set, then post-filter via `check`.
283        // Optimal for small fan-outs (the family/circle scale jeebon
284        // ships at); a real production deployment will want a
285        // dedicated reverse index.
286        let Some(schema) = self.get_namespace(resource_type).await? else {
287            return Ok(Vec::new());
288        };
289        let Some(resolved) = resolve(&schema, permission) else {
290            return Ok(Vec::new());
291        };
292        if resolved.union_relations.is_empty() {
293            return Ok(Vec::new());
294        }
295        let relation_list: Vec<String> = resolved.union_relations.into_iter().collect();
296        // First pass: all candidate resources (objects with a tuple in
297        // the relation set, regardless of subject — narrows the search
298        // to one per resource).
299        let rows = sqlx::query(
300            "SELECT DISTINCT object_type, object_id
301             FROM auth.zanzibar_tuples
302             WHERE object_type = $1 AND relation = ANY($2)",
303        )
304        .bind(resource_type)
305        .bind(&relation_list)
306        .fetch_all(&self.pool)
307        .await
308        .context("auth.zanzibar_tuples candidate resources")?;
309        let mut out = Vec::new();
310        for row in rows {
311            let object_type: String = row.get("object_type");
312            let object_id: String = row.get("object_id");
313            let r = ObjectRef::new(object_type, object_id);
314            if self
315                .check(&r, permission, subject, Consistency::Minimum)
316                .await?
317                .is_allowed()
318            {
319                out.push(r);
320            }
321        }
322        Ok(out)
323    }
324
325    async fn lookup_subjects(
326        &self,
327        subject_type: &str,
328        resource: &ObjectRef,
329        permission: &str,
330    ) -> Result<Vec<SubjectRef>> {
331        // Forward walk from the resource — every direct subject under
332        // the candidate relation set. Userset intermediates are
333        // followed transitively via the same recursive CTE as `check`.
334        let Some(schema) = self.get_namespace(&resource.object_type).await? else {
335            return Ok(Vec::new());
336        };
337        let Some(resolved) = resolve(&schema, permission) else {
338            return Ok(Vec::new());
339        };
340        if resolved.union_relations.is_empty() {
341            return Ok(Vec::new());
342        }
343        let relation_list: Vec<String> = resolved.union_relations.into_iter().collect();
344        let rows = sqlx::query(
345            r#"
346            WITH RECURSIVE walk(subject_type, subject_id, subject_rel, depth, path) AS (
347                SELECT t.subject_type, t.subject_id, t.subject_rel, 1,
348                       ARRAY[t.subject_type || ':' || t.subject_id]
349                FROM auth.zanzibar_tuples t
350                WHERE t.object_type = $1 AND t.object_id = $2 AND t.relation = ANY($3)
351                UNION ALL
352                SELECT t.subject_type, t.subject_id, t.subject_rel, w.depth + 1,
353                       w.path || (t.subject_type || ':' || t.subject_id)
354                FROM auth.zanzibar_tuples t
355                JOIN walk w
356                  ON t.object_type = w.subject_type
357                 AND t.object_id   = w.subject_id
358                 AND w.subject_rel <> ''
359                 AND t.relation = w.subject_rel
360                WHERE w.depth < $4
361                  AND NOT (t.subject_type || ':' || t.subject_id) = ANY(w.path)
362            )
363            SELECT DISTINCT subject_type, subject_id
364            FROM walk
365            WHERE subject_type = $5 AND subject_rel = ''
366            "#,
367        )
368        .bind(&resource.object_type)
369        .bind(&resource.object_id)
370        .bind(&relation_list)
371        .bind(MAX_DEPTH as i32)
372        .bind(subject_type)
373        .fetch_all(&self.pool)
374        .await
375        .context("auth.zanzibar lookup_subjects CTE")?;
376        Ok(rows
377            .into_iter()
378            .map(|row| {
379                SubjectRef::direct(
380                    row.get::<String, _>("subject_type"),
381                    row.get::<String, _>("subject_id"),
382                )
383            })
384            .collect())
385    }
386}
387
388/// Recursive helper for `expand` — fetches direct tuples and recurses
389/// into userset subjects up to `depth`.
390fn expand_pg<'a>(
391    store: &'a PostgresZanzibarStore,
392    resource: &'a ObjectRef,
393    relation: &'a str,
394    depth: u32,
395    seen: &'a mut Vec<String>,
396) -> std::pin::Pin<
397    Box<dyn std::future::Future<Output = Result<Vec<UsersetTree>>> + Send + 'a>,
398> {
399    Box::pin(async move {
400        if depth == 0 {
401            return Ok(Vec::new());
402        }
403        let key = format!("{}:{}#{}", resource.object_type, resource.object_id, relation);
404        if seen.contains(&key) {
405            return Ok(Vec::new());
406        }
407        seen.push(key);
408        let rows = sqlx::query(
409            "SELECT subject_type, subject_id, subject_rel
410             FROM auth.zanzibar_tuples
411             WHERE object_type = $1 AND object_id = $2 AND relation = $3",
412        )
413        .bind(&resource.object_type)
414        .bind(&resource.object_id)
415        .bind(relation)
416        .fetch_all(&store.pool)
417        .await
418        .context("auth.zanzibar_tuples expand fetch")?;
419        let mut children = Vec::new();
420        for row in rows {
421            let st: String = row.get("subject_type");
422            let sid: String = row.get("subject_id");
423            let sr: String = row.get("subject_rel");
424            if sr.is_empty() {
425                children.push(UsersetTree::Leaf {
426                    subject: SubjectRef::direct(st, sid),
427                });
428            } else {
429                let inner_resource = ObjectRef::new(st.clone(), sid.clone());
430                let sub = expand_pg(store, &inner_resource, &sr, depth - 1, seen).await?;
431                children.push(UsersetTree::Node {
432                    op: TreeOp::TuplesetArrow,
433                    children: vec![
434                        UsersetTree::Leaf {
435                            subject: SubjectRef::userset(st, sid, sr.clone()),
436                        },
437                        UsersetTree::Node {
438                            op: TreeOp::Direct,
439                            children: sub,
440                        },
441                    ],
442                });
443            }
444        }
445        Ok(children)
446    })
447}