Skip to main content

citadel_sql/
schema.rs

1//! Schema manager: in-memory cache of table schemas.
2
3use std::sync::Arc;
4
5use rustc_hash::{FxHashMap, FxHashSet};
6
7use citadel::{Database, SqlCacheHandle};
8use parking_lot::Mutex;
9
10use crate::error::{Result, SqlError};
11use crate::system_tables::{self, VirtualTable};
12use crate::types::{ForeignKeySchemaEntry, TableSchema, ViewDef};
13
/// Catalog table of serialized `TableSchema` rows, keyed by lowercase table name.
const SCHEMA_TABLE: &[u8] = "_schema".as_bytes();
/// Catalog table of serialized view definitions, keyed by lowercase view name.
const VIEWS_TABLE: &[u8] = "_views".as_bytes();
/// Catalog table of serialized trigger definitions, keyed by lowercase trigger name.
const TRIGGERS_TABLE: &[u8] = "_triggers".as_bytes();
/// Catalog table of serialized materialized-view definitions, keyed by lowercase name.
const MATVIEWS_TABLE: &[u8] = "_matviews".as_bytes();
18
19thread_local! {
20    /// Stack of `(alias → storage_name)` frames pushed by FOR EACH STATEMENT trigger
21    /// firings so `REFERENCING NEW TABLE AS new_t` resolves while the body runs.
22    static TRANSITION_TABLES: std::cell::RefCell<Vec<FxHashMap<String, String>>> =
23        const { std::cell::RefCell::new(Vec::new()) };
24}
25
26fn transition_table_lookup(name_lower: &str) -> Option<String> {
27    TRANSITION_TABLES.with(|cell| {
28        let stack = cell.borrow();
29        for frame in stack.iter().rev() {
30            if let Some(storage) = frame.get(name_lower) {
31                return Some(storage.clone());
32            }
33        }
34        None
35    })
36}
37
38pub(crate) fn push_transition_tables(aliases: FxHashMap<String, String>) -> TransitionGuard {
39    TRANSITION_TABLES.with(|cell| cell.borrow_mut().push(aliases));
40    TransitionGuard
41}
42
43pub(crate) struct TransitionGuard;
44impl Drop for TransitionGuard {
45    fn drop(&mut self) {
46        TRANSITION_TABLES.with(|cell| {
47            cell.borrow_mut().pop();
48        });
49    }
50}
51
52/// Manages table schemas in memory, backed by the `_schema` table.
53pub struct SchemaManager {
54    tables: FxHashMap<String, TableSchema>,
55    views: FxHashMap<String, ViewDef>,
56    virtual_tables: FxHashMap<String, Arc<dyn VirtualTable>>,
57    /// Within a `(target, timing, event)` group, triggers fire in name order.
58    triggers: FxHashMap<String, Vec<crate::types::TriggerDef>>,
59    /// Matview catalog. Backing table shares the matview's name in `tables`; this map
60    /// also gates DML rejection (matviews are read-only outside REFRESH).
61    matviews: FxHashMap<String, crate::types::MatviewDef>,
62    /// Maps user-typed TEMP name to prefixed storage name (`__temp_<conn_id>_<name>`).
63    temp_aliases: FxHashMap<String, String>,
64    /// Each entry is leaked once via `Box::leak` so `get()` can hand out a `&TableSchema`
65    /// from inside `&self` methods. Bounded by `(active triggers × transition aliases)`.
66    transition_schemas: std::cell::RefCell<FxHashMap<String, &'static TableSchema>>,
67    generation: u64,
68    /// Per-Database shared cache (e.g. ANN indexes). Cloned from the Database
69    /// when the Connection opens; all Connections to the same DB share entries.
70    /// Tests created via `empty()` get their own isolated cache.
71    pub sql_caches: SqlCacheHandle,
72    /// Tables modified by DML since the last commit/rollback. Drained on
73    /// commit to invalidate dependent shared caches (e.g. ANN indexes).
74    dml_dirty_tables: std::cell::RefCell<FxHashSet<String>>,
75}
76
77#[derive(Clone)]
78pub struct SchemaSnapshot {
79    tables: FxHashMap<String, TableSchema>,
80    views: FxHashMap<String, ViewDef>,
81    generation: u64,
82}
83
84impl SchemaManager {
85    pub fn empty() -> Self {
86        Self {
87            tables: FxHashMap::default(),
88            views: FxHashMap::default(),
89            virtual_tables: FxHashMap::default(),
90            triggers: FxHashMap::default(),
91            matviews: FxHashMap::default(),
92            temp_aliases: FxHashMap::default(),
93            transition_schemas: std::cell::RefCell::new(FxHashMap::default()),
94            generation: 0,
95            sql_caches: Arc::new(Mutex::new(FxHashMap::default())),
96            dml_dirty_tables: std::cell::RefCell::new(FxHashSet::default()),
97        }
98    }
99
100    /// Mark a table as modified by DML. Caller should normalize the name
101    /// (lowercase) so dedup works correctly.
102    pub fn mark_dml(&self, table_name: &str) {
103        self.dml_dirty_tables
104            .borrow_mut()
105            .insert(table_name.to_ascii_lowercase());
106    }
107
108    /// Take the set of tables modified since the last drain. Returns an empty
109    /// vec if no DML has run since the last commit/rollback.
110    pub fn drain_dml_dirty(&self) -> Vec<String> {
111        self.dml_dirty_tables.borrow_mut().drain().collect()
112    }
113
114    /// Forget pending DML markers without invalidating downstream caches.
115    /// Used on rollback (uncommitted writes leave no caches stale).
116    pub fn clear_dml_dirty(&self) {
117        self.dml_dirty_tables.borrow_mut().clear();
118    }
119
120    pub fn register_temp_alias(&mut self, user_name: &str, prefixed_name: String) {
121        self.temp_aliases
122            .insert(user_name.to_ascii_lowercase(), prefixed_name);
123        self.generation += 1;
124    }
125
126    pub fn unregister_temp_alias(&mut self, user_name: &str) -> Option<String> {
127        let lower = user_name.to_ascii_lowercase();
128        let removed = self.temp_aliases.remove(&lower);
129        if removed.is_some() {
130            self.generation += 1;
131        }
132        removed
133    }
134
135    pub fn temp_alias_iter(&self) -> impl Iterator<Item = (&str, &str)> + '_ {
136        self.temp_aliases
137            .iter()
138            .map(|(k, v)| (k.as_str(), v.as_str()))
139    }
140
141    pub fn resolve_temp(&self, name: &str) -> String {
142        let lower = name.to_ascii_lowercase();
143        if let Some(prefixed) = self.temp_aliases.get(&lower) {
144            return prefixed.clone();
145        }
146        name.to_string()
147    }
148
149    pub fn load(db: &Database) -> Result<Self> {
150        let mut tables = FxHashMap::default();
151
152        let mut rtx = db.begin_read();
153        let mut parse_err: Option<crate::error::SqlError> = None;
154        let scan_result = rtx.table_for_each(SCHEMA_TABLE, |_key, value| {
155            match TableSchema::deserialize(value) {
156                Ok(schema) => {
157                    tables.insert(schema.name.clone(), schema);
158                }
159                Err(e) => {
160                    parse_err = Some(e);
161                }
162            }
163            Ok(())
164        });
165
166        match scan_result {
167            Ok(()) => {}
168            Err(citadel_core::Error::TableNotFound(_)) => {}
169            Err(e) => return Err(e.into()),
170        }
171        if let Some(e) = parse_err {
172            return Err(e);
173        }
174
175        let mut views = FxHashMap::default();
176        let mut rtx2 = db.begin_read();
177        let mut view_err: Option<crate::error::SqlError> = None;
178        let view_scan = rtx2.table_for_each(VIEWS_TABLE, |_key, value| {
179            match ViewDef::deserialize(value) {
180                Ok(vd) => {
181                    views.insert(vd.name.clone(), vd);
182                }
183                Err(e) => {
184                    view_err = Some(e);
185                }
186            }
187            Ok(())
188        });
189
190        match view_scan {
191            Ok(()) => {}
192            Err(citadel_core::Error::TableNotFound(_)) => {}
193            Err(e) => return Err(e.into()),
194        }
195        if let Some(e) = view_err {
196            return Err(e);
197        }
198
199        let mut triggers: FxHashMap<String, Vec<crate::types::TriggerDef>> = FxHashMap::default();
200        let mut rtx3 = db.begin_read();
201        let mut trig_err: Option<crate::error::SqlError> = None;
202        let trig_scan = rtx3.table_for_each(TRIGGERS_TABLE, |_key, value| {
203            match crate::types::TriggerDef::deserialize(value) {
204                Ok(td) => {
205                    triggers
206                        .entry(td.target.to_ascii_lowercase())
207                        .or_default()
208                        .push(td);
209                }
210                Err(e) => {
211                    trig_err = Some(e);
212                }
213            }
214            Ok(())
215        });
216        match trig_scan {
217            Ok(()) => {}
218            Err(citadel_core::Error::TableNotFound(_)) => {}
219            Err(e) => return Err(e.into()),
220        }
221        if let Some(e) = trig_err {
222            return Err(e);
223        }
224        // PG-faithful: triggers fire in name order within a (target, timing, event) group.
225        for v in triggers.values_mut() {
226            v.sort_by(|a, b| a.name.cmp(&b.name));
227        }
228
229        let mut matviews: FxHashMap<String, crate::types::MatviewDef> = FxHashMap::default();
230        let mut rtx4 = db.begin_read();
231        let mut mv_err: Option<crate::error::SqlError> = None;
232        let mv_scan = rtx4.table_for_each(MATVIEWS_TABLE, |_key, value| {
233            match crate::types::MatviewDef::deserialize(value) {
234                Ok(mv) => {
235                    matviews.insert(mv.name.to_ascii_lowercase(), mv);
236                }
237                Err(e) => {
238                    mv_err = Some(e);
239                }
240            }
241            Ok(())
242        });
243        match mv_scan {
244            Ok(()) => {}
245            Err(citadel_core::Error::TableNotFound(_)) => {}
246            Err(e) => return Err(e.into()),
247        }
248        if let Some(e) = mv_err {
249            return Err(e);
250        }
251
252        let mut mgr = Self {
253            tables,
254            views,
255            virtual_tables: FxHashMap::default(),
256            triggers,
257            matviews,
258            temp_aliases: FxHashMap::default(),
259            transition_schemas: std::cell::RefCell::new(FxHashMap::default()),
260            generation: 0,
261            sql_caches: db.sql_cache_handle(),
262            dml_dirty_tables: std::cell::RefCell::new(FxHashSet::default()),
263        };
264        system_tables::register_builtins(&mut mgr);
265        Ok(mgr)
266    }
267
268    pub fn get_virtual(&self, name: &str) -> Option<&Arc<dyn VirtualTable>> {
269        self.virtual_tables.get(name)
270    }
271
272    pub fn register_virtual(&mut self, vt: Arc<dyn VirtualTable>) {
273        let name = vt.name().to_ascii_lowercase();
274        self.virtual_tables.insert(name, vt);
275    }
276
277    pub fn get(&self, name: &str) -> Option<&TableSchema> {
278        let lower = name.to_ascii_lowercase();
279        if let Some(prefixed) = transition_table_lookup(&lower) {
280            if let Some(s) = self.tables.get(&prefixed) {
281                return Some(s);
282            }
283            if let Some(&leaked) = self.transition_schemas.borrow().get(&prefixed) {
284                return Some(leaked);
285            }
286        }
287        if let Some(mv) = self.matviews.get(&lower) {
288            return self.tables.get(&mv.backing_table);
289        }
290        if let Some(prefixed) = self.temp_aliases.get(&lower) {
291            return self.tables.get(prefixed);
292        }
293        if let Some(s) = self.tables.get(name) {
294            return Some(s);
295        }
296        if name.bytes().any(|b| b.is_ascii_uppercase()) {
297            self.tables.get(&lower)
298        } else {
299            None
300        }
301    }
302
303    pub fn register_transition_schema(&self, storage_name: String, schema: TableSchema) {
304        let leaked: &'static TableSchema = Box::leak(Box::new(schema));
305        self.transition_schemas
306            .borrow_mut()
307            .insert(storage_name, leaked);
308    }
309
310    pub fn unregister_transition_schema(&self, storage_name: &str) {
311        self.transition_schemas.borrow_mut().remove(storage_name);
312    }
313
314    pub fn contains(&self, name: &str) -> bool {
315        let lower = name.to_ascii_lowercase();
316        if transition_table_lookup(&lower).is_some() {
317            return true;
318        }
319        if self.matviews.contains_key(&lower) {
320            return true;
321        }
322        if self.temp_aliases.contains_key(&lower) {
323            return true;
324        }
325        if self.tables.contains_key(name) {
326            return true;
327        }
328        if name.bytes().any(|b| b.is_ascii_uppercase()) {
329            self.tables.contains_key(&lower)
330        } else {
331            false
332        }
333    }
334
335    pub fn generation(&self) -> u64 {
336        self.generation
337    }
338
339    pub fn register(&mut self, schema: TableSchema) {
340        let lower = schema.name.to_ascii_lowercase();
341        self.tables.insert(lower, schema);
342        self.generation += 1;
343    }
344
345    pub fn remove(&mut self, name: &str) -> Option<TableSchema> {
346        let lower = name.to_ascii_lowercase();
347        let result = self.tables.remove(&lower);
348        if result.is_some() {
349            self.generation += 1;
350        }
351        result
352    }
353
354    pub fn table_names(&self) -> Vec<&str> {
355        self.tables.keys().map(|s| s.as_str()).collect()
356    }
357
358    pub fn all_schemas(&self) -> impl Iterator<Item = &TableSchema> {
359        self.tables.values()
360    }
361
362    pub fn get_view(&self, name: &str) -> Option<&ViewDef> {
363        if let Some(v) = self.views.get(name) {
364            return Some(v);
365        }
366        if name.bytes().any(|b| b.is_ascii_uppercase()) {
367            self.views.get(&name.to_ascii_lowercase())
368        } else {
369            None
370        }
371    }
372
373    pub fn register_view(&mut self, view: ViewDef) {
374        let lower = view.name.to_ascii_lowercase();
375        self.views.insert(lower, view);
376        self.generation += 1;
377    }
378
379    pub fn remove_view(&mut self, name: &str) -> Option<ViewDef> {
380        let lower = name.to_ascii_lowercase();
381        let result = self.views.remove(&lower);
382        if result.is_some() {
383            self.generation += 1;
384        }
385        result
386    }
387
388    pub fn view_names(&self) -> Vec<&str> {
389        self.views.keys().map(|s| s.as_str()).collect()
390    }
391
392    pub fn triggers_for(&self, target: &str) -> &[crate::types::TriggerDef] {
393        let key = target.to_ascii_lowercase();
394        self.triggers.get(&key).map(|v| v.as_slice()).unwrap_or(&[])
395    }
396
397    pub fn all_triggers(&self) -> impl Iterator<Item = &crate::types::TriggerDef> + '_ {
398        self.triggers.values().flatten()
399    }
400
401    pub fn register_trigger(&mut self, trig: crate::types::TriggerDef) {
402        let target = trig.target.to_ascii_lowercase();
403        let bucket = self.triggers.entry(target).or_default();
404        bucket.push(trig);
405        bucket.sort_by(|a, b| a.name.cmp(&b.name));
406        self.generation += 1;
407    }
408
409    pub fn remove_trigger(&mut self, name: &str) -> Option<crate::types::TriggerDef> {
410        let lower = name.to_ascii_lowercase();
411        let mut result = None;
412        for bucket in self.triggers.values_mut() {
413            if let Some(pos) = bucket
414                .iter()
415                .position(|t| t.name.eq_ignore_ascii_case(&lower))
416            {
417                result = Some(bucket.remove(pos));
418                break;
419            }
420        }
421        if result.is_some() {
422            self.generation += 1;
423        }
424        result
425    }
426
427    /// Caller is responsible for dropping the returned triggers' on-disk catalog rows.
428    pub fn remove_triggers_for(&mut self, target: &str) -> Vec<crate::types::TriggerDef> {
429        let key = target.to_ascii_lowercase();
430        let removed = self.triggers.remove(&key).unwrap_or_default();
431        if !removed.is_empty() {
432            self.generation += 1;
433        }
434        removed
435    }
436
437    pub fn find_trigger(&self, name: &str) -> Option<(&str, &crate::types::TriggerDef)> {
438        let lower = name.to_ascii_lowercase();
439        for (target, bucket) in &self.triggers {
440            if let Some(t) = bucket.iter().find(|t| t.name.eq_ignore_ascii_case(&lower)) {
441                return Some((target.as_str(), t));
442            }
443        }
444        None
445    }
446
447    pub fn set_trigger_enabled(&mut self, name: &str, enabled: bool) -> bool {
448        let lower = name.to_ascii_lowercase();
449        for bucket in self.triggers.values_mut() {
450            if let Some(t) = bucket
451                .iter_mut()
452                .find(|t| t.name.eq_ignore_ascii_case(&lower))
453            {
454                t.enabled = enabled;
455                self.generation += 1;
456                return true;
457            }
458        }
459        false
460    }
461
462    pub fn set_all_triggers_enabled(&mut self, target: &str, enabled: bool) -> usize {
463        let key = target.to_ascii_lowercase();
464        let bucket = match self.triggers.get_mut(&key) {
465            Some(b) => b,
466            None => return 0,
467        };
468        let count = bucket.len();
469        for t in bucket {
470            t.enabled = enabled;
471        }
472        if count > 0 {
473            self.generation += 1;
474        }
475        count
476    }
477
478    pub fn ensure_triggers_table(wtx: &mut citadel_txn::write_txn::WriteTxn<'_>) -> Result<()> {
479        match wtx.create_table(TRIGGERS_TABLE) {
480            Ok(()) => Ok(()),
481            Err(citadel_core::Error::TableAlreadyExists(_)) => Ok(()),
482            Err(e) => Err(e.into()),
483        }
484    }
485
486    pub fn save_trigger(
487        wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
488        trig: &crate::types::TriggerDef,
489    ) -> Result<()> {
490        Self::ensure_triggers_table(wtx)?;
491        let data = trig.serialize();
492        let lower = trig.name.to_ascii_lowercase();
493        wtx.table_insert(TRIGGERS_TABLE, lower.as_bytes(), &data)
494            .map_err(crate::error::SqlError::from)?;
495        Ok(())
496    }
497
498    pub fn delete_trigger(
499        wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
500        name: &str,
501    ) -> Result<()> {
502        Self::ensure_triggers_table(wtx)?;
503        let lower = name.to_ascii_lowercase();
504        wtx.table_delete(TRIGGERS_TABLE, lower.as_bytes())
505            .map_err(crate::error::SqlError::from)?;
506        Ok(())
507    }
508
509    pub fn save_view(wtx: &mut citadel_txn::write_txn::WriteTxn<'_>, view: &ViewDef) -> Result<()> {
510        let lower = view.name.to_ascii_lowercase();
511        let data = view.serialize();
512        wtx.table_insert(VIEWS_TABLE, lower.as_bytes(), &data)?;
513        Ok(())
514    }
515
516    pub fn delete_view(wtx: &mut citadel_txn::write_txn::WriteTxn<'_>, name: &str) -> Result<()> {
517        let lower = name.to_ascii_lowercase();
518        wtx.table_delete(VIEWS_TABLE, lower.as_bytes())
519            .map_err(|e| match e {
520                citadel_core::Error::TableNotFound(_) => SqlError::ViewNotFound(name.into()),
521                other => SqlError::Storage(other),
522            })?;
523        Ok(())
524    }
525
526    pub fn ensure_views_table(wtx: &mut citadel_txn::write_txn::WriteTxn<'_>) -> Result<()> {
527        match wtx.create_table(VIEWS_TABLE) {
528            Ok(()) => Ok(()),
529            Err(citadel_core::Error::TableAlreadyExists(_)) => Ok(()),
530            Err(e) => Err(e.into()),
531        }
532    }
533
534    pub fn get_matview(&self, name: &str) -> Option<&crate::types::MatviewDef> {
535        let lower = name.to_ascii_lowercase();
536        self.matviews.get(&lower)
537    }
538
539    pub fn matview_names(&self) -> Vec<&str> {
540        self.matviews.keys().map(|s| s.as_str()).collect()
541    }
542
543    pub fn all_matviews(&self) -> impl Iterator<Item = &crate::types::MatviewDef> + '_ {
544        self.matviews.values()
545    }
546
547    pub fn register_matview(&mut self, mv: crate::types::MatviewDef) {
548        let lower = mv.name.to_ascii_lowercase();
549        self.matviews.insert(lower, mv);
550        self.generation += 1;
551    }
552
553    pub fn remove_matview(&mut self, name: &str) -> Option<crate::types::MatviewDef> {
554        let lower = name.to_ascii_lowercase();
555        let removed = self.matviews.remove(&lower);
556        if removed.is_some() {
557            self.generation += 1;
558        }
559        removed
560    }
561
562    pub fn ensure_matviews_table(wtx: &mut citadel_txn::write_txn::WriteTxn<'_>) -> Result<()> {
563        match wtx.create_table(MATVIEWS_TABLE) {
564            Ok(()) => Ok(()),
565            Err(citadel_core::Error::TableAlreadyExists(_)) => Ok(()),
566            Err(e) => Err(e.into()),
567        }
568    }
569
570    pub fn save_matview(
571        wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
572        mv: &crate::types::MatviewDef,
573    ) -> Result<()> {
574        Self::ensure_matviews_table(wtx)?;
575        let lower = mv.name.to_ascii_lowercase();
576        let data = mv.serialize();
577        wtx.table_insert(MATVIEWS_TABLE, lower.as_bytes(), &data)?;
578        Ok(())
579    }
580
581    pub fn delete_matview(
582        wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
583        name: &str,
584    ) -> Result<()> {
585        Self::ensure_matviews_table(wtx)?;
586        let lower = name.to_ascii_lowercase();
587        wtx.table_delete(MATVIEWS_TABLE, lower.as_bytes())
588            .map_err(crate::error::SqlError::from)?;
589        Ok(())
590    }
591
592    pub fn child_fks_for(&self, parent: &str) -> Vec<(&str, &ForeignKeySchemaEntry)> {
593        self.tables
594            .iter()
595            .flat_map(|(name, schema)| {
596                schema
597                    .foreign_keys
598                    .iter()
599                    .filter(|fk| fk.foreign_table == parent)
600                    .map(move |fk| (name.as_str(), fk))
601            })
602            .collect()
603    }
604
605    pub fn save_schema(
606        wtx: &mut citadel_txn::write_txn::WriteTxn<'_>,
607        schema: &TableSchema,
608    ) -> Result<()> {
609        let lower = schema.name.to_ascii_lowercase();
610        let data = schema.serialize();
611        wtx.table_insert(SCHEMA_TABLE, lower.as_bytes(), &data)?;
612        Ok(())
613    }
614
615    pub fn delete_schema(wtx: &mut citadel_txn::write_txn::WriteTxn<'_>, name: &str) -> Result<()> {
616        let lower = name.to_ascii_lowercase();
617        wtx.table_delete(SCHEMA_TABLE, lower.as_bytes())
618            .map_err(|e| match e {
619                citadel_core::Error::TableNotFound(_) => SqlError::TableNotFound(name.into()),
620                other => SqlError::Storage(other),
621            })?;
622        Ok(())
623    }
624
625    pub fn ensure_schema_table(wtx: &mut citadel_txn::write_txn::WriteTxn<'_>) -> Result<()> {
626        match wtx.create_table(SCHEMA_TABLE) {
627            Ok(()) => Ok(()),
628            Err(citadel_core::Error::TableAlreadyExists(_)) => Ok(()),
629            Err(e) => Err(e.into()),
630        }
631    }
632
633    pub fn save_snapshot(&self) -> SchemaSnapshot {
634        SchemaSnapshot {
635            tables: self.tables.clone(),
636            views: self.views.clone(),
637            generation: self.generation,
638        }
639    }
640
641    pub fn restore_snapshot(&mut self, snap: SchemaSnapshot) {
642        self.tables = snap.tables;
643        self.views = snap.views;
644        self.generation = snap.generation;
645    }
646}
647
// Unit tests live in a sibling file so this module stays focused on the manager.
#[path = "schema_tests.rs"]
#[cfg(test)]
mod tests;