1use std::cell::RefCell;
11
12use wasm_dbms_api::prelude::{DbmsResult, TransactionId};
13use wasm_dbms_memory::prelude::{
14 AccessControl, AccessControlList, MemoryManager, MemoryProvider, SchemaRegistry,
15 TableRegistryPage,
16};
17
18use crate::transaction::journal::Journal;
19use crate::transaction::session::TransactionSession;
20
21pub struct DbmsContext<M, A = AccessControlList>
39where
40 M: MemoryProvider,
41 A: AccessControl,
42{
43 pub(crate) mm: RefCell<MemoryManager<M>>,
45
46 pub(crate) schema_registry: RefCell<SchemaRegistry>,
48
49 pub(crate) acl: RefCell<A>,
51
52 pub(crate) transaction_session: RefCell<TransactionSession>,
54
55 pub(crate) journal: RefCell<Option<Journal>>,
57}
58
59impl<M> DbmsContext<M>
60where
61 M: MemoryProvider,
62{
63 pub fn new(memory: M) -> Self {
66 let mm = MemoryManager::init(memory);
67 let schema_registry = SchemaRegistry::load(&mm).unwrap_or_default();
68 let acl = AccessControlList::load(&mm).unwrap_or_default();
69
70 Self {
71 mm: RefCell::new(mm),
72 schema_registry: RefCell::new(schema_registry),
73 acl: RefCell::new(acl),
74 transaction_session: RefCell::new(TransactionSession::default()),
75 journal: RefCell::new(None),
76 }
77 }
78}
79
80impl<M, A> DbmsContext<M, A>
81where
82 M: MemoryProvider,
83 A: AccessControl,
84{
85 pub fn with_acl(memory: M) -> Self {
87 let mm = MemoryManager::init(memory);
88 let schema_registry = SchemaRegistry::load(&mm).unwrap_or_default();
89 let acl = A::load(&mm).unwrap_or_default();
90
91 Self {
92 mm: RefCell::new(mm),
93 schema_registry: RefCell::new(schema_registry),
94 acl: RefCell::new(acl),
95 transaction_session: RefCell::new(TransactionSession::default()),
96 journal: RefCell::new(None),
97 }
98 }
99
100 pub fn register_table<T: wasm_dbms_api::prelude::TableSchema>(
102 &self,
103 ) -> DbmsResult<TableRegistryPage> {
104 let mut sr = self.schema_registry.borrow_mut();
105 let mut mm = self.mm.borrow_mut();
106 sr.register_table::<T>(&mut mm).map_err(Into::into)
107 }
108
109 pub fn acl_add(&self, identity: A::Id) -> DbmsResult<()> {
111 let mut acl = self.acl.borrow_mut();
112 let mut mm = self.mm.borrow_mut();
113 acl.add_identity(identity, &mut mm).map_err(Into::into)
114 }
115
116 pub fn acl_remove(&self, identity: &A::Id) -> DbmsResult<()> {
118 let mut acl = self.acl.borrow_mut();
119 let mut mm = self.mm.borrow_mut();
120 acl.remove_identity(identity, &mut mm).map_err(Into::into)
121 }
122
123 pub fn acl_allowed(&self) -> Vec<A::Id> {
125 let acl = self.acl.borrow();
126 acl.allowed_identities()
127 }
128
129 pub fn acl_is_allowed(&self, identity: &A::Id) -> bool {
131 let acl = self.acl.borrow();
132 acl.is_allowed(identity)
133 }
134
135 pub fn begin_transaction(&self, owner: Vec<u8>) -> TransactionId {
137 let mut ts = self.transaction_session.borrow_mut();
138 ts.begin_transaction(owner)
139 }
140
141 pub fn has_transaction(&self, tx_id: &TransactionId, caller: &[u8]) -> bool {
143 let ts = self.transaction_session.borrow();
144 ts.has_transaction(tx_id, caller)
145 }
146}
147
148impl<M, A> std::fmt::Debug for DbmsContext<M, A>
149where
150 M: MemoryProvider,
151 A: AccessControl + std::fmt::Debug,
152{
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 f.debug_struct("DbmsContext")
155 .field("schema_registry", &self.schema_registry)
156 .field("acl", &self.acl)
157 .field("transaction_session", &self.transaction_session)
158 .finish_non_exhaustive()
159 }
160}
161
#[cfg(test)]
mod tests {
    use wasm_dbms_memory::prelude::HeapMemoryProvider;

    use super::*;

    /// Fresh context over default heap memory, using the default ACL type.
    fn heap_context() -> DbmsContext<HeapMemoryProvider> {
        DbmsContext::new(HeapMemoryProvider::default())
    }

    #[test]
    fn test_should_create_context() {
        // A brand-new context has no allowed identities.
        let allowed = heap_context().acl_allowed();
        assert!(allowed.is_empty());
    }

    #[test]
    fn test_should_add_acl_identity() {
        let ctx = heap_context();
        let added = vec![1, 2, 3];
        let unknown = vec![4, 5, 6];

        ctx.acl_add(added.clone()).unwrap();

        assert!(ctx.acl_is_allowed(&added));
        assert!(!ctx.acl_is_allowed(&unknown));
    }

    #[test]
    fn test_should_remove_acl_identity() {
        let ctx = heap_context();
        let removed = vec![1, 2, 3];
        let kept = vec![4, 5, 6];

        ctx.acl_add(removed.clone()).unwrap();
        ctx.acl_add(kept.clone()).unwrap();
        ctx.acl_remove(&removed).unwrap();

        // Only the removed identity loses access.
        assert!(!ctx.acl_is_allowed(&removed));
        assert!(ctx.acl_is_allowed(&kept));
    }

    #[test]
    fn test_should_begin_transaction() {
        let ctx = heap_context();
        let owner = vec![1, 2, 3];

        let tx_id = ctx.begin_transaction(owner.clone());

        // The transaction is visible to its owner but not to another caller.
        assert!(ctx.has_transaction(&tx_id, &owner));
        assert!(!ctx.has_transaction(&tx_id, &[4, 5, 6]));
    }

    #[test]
    fn test_should_debug_context() {
        let ctx = heap_context();
        let rendered = format!("{ctx:?}");
        assert!(rendered.contains("DbmsContext"));
    }
}