Skip to main content

contextdb_relational/
mem.rs

1use crate::store::RelationalStore;
2use contextdb_core::*;
3use contextdb_tx::{TxManager, WriteSetApplicator};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8pub struct MemRelationalExecutor<S: WriteSetApplicator> {
9    store: Arc<RelationalStore>,
10    tx_mgr: Arc<TxManager<S>>,
11}
12
13impl<S: WriteSetApplicator> MemRelationalExecutor<S> {
14    pub fn new(store: Arc<RelationalStore>, tx_mgr: Arc<TxManager<S>>) -> Self {
15        Self { store, tx_mgr }
16    }
17
18    fn ensure_table_exists(&self, table: &str) -> Result<()> {
19        if self.store.table_meta.read().contains_key(table) {
20            Ok(())
21        } else {
22            Err(Error::TableNotFound(table.to_string()))
23        }
24    }
25
26    pub fn scan_with_tx(
27        &self,
28        tx: Option<TxId>,
29        table: &str,
30        snapshot: SnapshotId,
31    ) -> Result<Vec<VersionedRow>> {
32        let tables = self.store.tables.read();
33        let rows = tables
34            .get(table)
35            .ok_or_else(|| Error::TableNotFound(table.to_string()))?;
36
37        let mut result: Vec<VersionedRow> = rows
38            .iter()
39            .filter(|r| r.visible_at(snapshot))
40            .cloned()
41            .collect();
42
43        if let Some(tx_id) = tx {
44            let _ = self.tx_mgr.with_write_set(tx_id, |ws| {
45                let deleted_row_ids: std::collections::HashSet<RowId> = ws
46                    .relational_deletes
47                    .iter()
48                    .filter(|(t, _, _)| t == table)
49                    .map(|(_, row_id, _)| *row_id)
50                    .collect();
51                result.retain(|row| !deleted_row_ids.contains(&row.row_id));
52                for (t, row) in &ws.relational_inserts {
53                    if t == table {
54                        result.push(row.clone());
55                    }
56                }
57            });
58        }
59
60        Ok(result)
61    }
62
63    pub fn scan_filter_with_tx(
64        &self,
65        tx: Option<TxId>,
66        table: &str,
67        snapshot: SnapshotId,
68        predicate: &dyn Fn(&VersionedRow) -> bool,
69    ) -> Result<Vec<VersionedRow>> {
70        let all = self.scan_with_tx(tx, table, snapshot)?;
71        Ok(all.into_iter().filter(|r| predicate(r)).collect())
72    }
73
74    pub fn point_lookup_with_tx(
75        &self,
76        tx: Option<TxId>,
77        table: &str,
78        col: &str,
79        value: &Value,
80        snapshot: SnapshotId,
81    ) -> Result<Option<VersionedRow>> {
82        let all = self.scan_with_tx(tx, table, snapshot)?;
83        Ok(all.into_iter().find(|r| r.values.get(col) == Some(value)))
84    }
85
86    fn validate_state_transition(
87        &self,
88        tx: TxId,
89        table: &str,
90        values: &HashMap<ColName, Value>,
91        snapshot: SnapshotId,
92    ) -> Result<()> {
93        let meta = self.store.table_meta.read();
94        let Some(sm) = meta.get(table).and_then(|m| m.state_machine.as_ref()) else {
95            return Ok(());
96        };
97        let col = &sm.column;
98
99        let new_status = match values.get(col) {
100            Some(Value::Text(s)) => s.as_str(),
101            _ => return Ok(()),
102        };
103
104        let id = match values.get("id") {
105            Some(v @ Value::Uuid(_)) => v.clone(),
106            _ => return Ok(()),
107        };
108
109        if let Some(existing) = self.point_lookup_with_tx(Some(tx), table, "id", &id, snapshot)? {
110            let old_status = existing
111                .values
112                .get(col)
113                .and_then(Value::as_text)
114                .unwrap_or("");
115            if !self
116                .store
117                .validate_state_transition(table, col, old_status, new_status)
118            {
119                return Err(Error::InvalidStateTransition(format!(
120                    "{} -> {}",
121                    old_status, new_status
122                )));
123            }
124        }
125
126        Ok(())
127    }
128
129    pub fn insert_with_tx(
130        &self,
131        tx: TxId,
132        table: &str,
133        values: HashMap<ColName, Value>,
134        snapshot: SnapshotId,
135    ) -> Result<RowId> {
136        self.ensure_table_exists(table)?;
137        self.validate_state_transition(tx, table, &values, snapshot)?;
138
139        let row_id = self.store.new_row_id();
140        let row = VersionedRow {
141            row_id,
142            values,
143            created_tx: tx,
144            deleted_tx: None,
145            lsn: 0,
146            created_at: Some(
147                SystemTime::now()
148                    .duration_since(UNIX_EPOCH)
149                    .unwrap_or_default()
150                    .as_millis() as u64,
151            ),
152        };
153
154        self.tx_mgr.with_write_set(tx, |ws| {
155            ws.relational_inserts.push((table.to_string(), row));
156        })?;
157
158        Ok(row_id)
159    }
160
161    pub fn upsert_with_tx(
162        &self,
163        tx: TxId,
164        table: &str,
165        conflict_col: &str,
166        values: HashMap<ColName, Value>,
167        snapshot: SnapshotId,
168    ) -> Result<UpsertResult> {
169        self.ensure_table_exists(table)?;
170        if self.store.is_immutable(table) {
171            return Err(Error::ImmutableTable(table.to_string()));
172        }
173
174        self.validate_state_transition(tx, table, &values, snapshot)?;
175
176        let conflict_val = values
177            .get(conflict_col)
178            .ok_or_else(|| Error::Other("conflict column not in values".to_string()))?
179            .clone();
180
181        let existing =
182            self.point_lookup_with_tx(Some(tx), table, conflict_col, &conflict_val, snapshot)?;
183
184        match existing {
185            None => {
186                self.insert_with_tx(tx, table, values, snapshot)?;
187                Ok(UpsertResult::Inserted)
188            }
189            Some(existing_row) => {
190                let changed = values
191                    .iter()
192                    .any(|(k, v)| existing_row.values.get(k) != Some(v));
193                if !changed {
194                    return Ok(UpsertResult::NoOp);
195                }
196
197                self.delete(tx, table, existing_row.row_id)?;
198                self.insert_with_tx(tx, table, values, snapshot)?;
199                Ok(UpsertResult::Updated)
200            }
201        }
202    }
203}
204
205impl<S: WriteSetApplicator> RelationalExecutor for MemRelationalExecutor<S> {
206    fn scan(&self, table: &str, snapshot: SnapshotId) -> Result<Vec<VersionedRow>> {
207        self.scan_with_tx(None, table, snapshot)
208    }
209
210    fn scan_filter(
211        &self,
212        table: &str,
213        snapshot: SnapshotId,
214        predicate: &dyn Fn(&VersionedRow) -> bool,
215    ) -> Result<Vec<VersionedRow>> {
216        self.scan_filter_with_tx(None, table, snapshot, predicate)
217    }
218
219    fn point_lookup(
220        &self,
221        table: &str,
222        col: &str,
223        value: &Value,
224        snapshot: SnapshotId,
225    ) -> Result<Option<VersionedRow>> {
226        self.point_lookup_with_tx(None, table, col, value, snapshot)
227    }
228
229    fn insert(&self, tx: TxId, table: &str, values: HashMap<ColName, Value>) -> Result<RowId> {
230        let snapshot = self.tx_mgr.snapshot();
231        self.insert_with_tx(tx, table, values, snapshot)
232    }
233
234    fn upsert(
235        &self,
236        tx: TxId,
237        table: &str,
238        conflict_col: &str,
239        values: HashMap<ColName, Value>,
240        snapshot: SnapshotId,
241    ) -> Result<UpsertResult> {
242        self.upsert_with_tx(tx, table, conflict_col, values, snapshot)
243    }
244
245    fn delete(&self, tx: TxId, table: &str, row_id: RowId) -> Result<()> {
246        self.ensure_table_exists(table)?;
247        if self.store.is_immutable(table) {
248            return Err(Error::ImmutableTable(table.to_string()));
249        }
250
251        self.tx_mgr.with_write_set(tx, |ws| {
252            ws.relational_deletes.push((table.to_string(), row_id, tx));
253        })?;
254
255        Ok(())
256    }
257}