1use crate::store::RelationalStore;
2use contextdb_core::*;
3use contextdb_tx::{TxManager, WriteSetApplicator};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
/// Relational executor backed by an in-memory [`RelationalStore`].
///
/// Reads come from the store's `tables` (filtered by snapshot visibility) and
/// may be overlaid with a transaction's staged writes. Inserts and deletes are
/// not written to the store directly; they are pushed into the transaction's
/// write set through `TxManager::with_write_set`.
pub struct MemRelationalExecutor<S: WriteSetApplicator> {
    /// Source of table rows (`tables`), table metadata (`table_meta`), and row ids.
    store: Arc<RelationalStore>,
    /// Holds per-transaction write sets and supplies the current snapshot.
    tx_mgr: Arc<TxManager<S>>,
}
12
13impl<S: WriteSetApplicator> MemRelationalExecutor<S> {
14 pub fn new(store: Arc<RelationalStore>, tx_mgr: Arc<TxManager<S>>) -> Self {
15 Self { store, tx_mgr }
16 }
17
18 fn ensure_table_exists(&self, table: &str) -> Result<()> {
19 if self.store.table_meta.read().contains_key(table) {
20 Ok(())
21 } else {
22 Err(Error::TableNotFound(table.to_string()))
23 }
24 }
25
26 pub fn scan_with_tx(
27 &self,
28 tx: Option<TxId>,
29 table: &str,
30 snapshot: SnapshotId,
31 ) -> Result<Vec<VersionedRow>> {
32 let tables = self.store.tables.read();
33 let rows = tables
34 .get(table)
35 .ok_or_else(|| Error::TableNotFound(table.to_string()))?;
36
37 let mut result: Vec<VersionedRow> = rows
38 .iter()
39 .filter(|r| r.visible_at(snapshot))
40 .cloned()
41 .collect();
42
43 if let Some(tx_id) = tx {
44 let _ = self.tx_mgr.with_write_set(tx_id, |ws| {
45 let deleted_row_ids: std::collections::HashSet<RowId> = ws
46 .relational_deletes
47 .iter()
48 .filter(|(t, _, _)| t == table)
49 .map(|(_, row_id, _)| *row_id)
50 .collect();
51 result.retain(|row| !deleted_row_ids.contains(&row.row_id));
52 for (t, row) in &ws.relational_inserts {
53 if t == table && !deleted_row_ids.contains(&row.row_id) {
54 result.push(row.clone());
55 }
56 }
57 });
58 }
59
60 Ok(result)
61 }
62
63 pub fn scan_filter_with_tx(
64 &self,
65 tx: Option<TxId>,
66 table: &str,
67 snapshot: SnapshotId,
68 predicate: &dyn Fn(&VersionedRow) -> bool,
69 ) -> Result<Vec<VersionedRow>> {
70 let all = self.scan_with_tx(tx, table, snapshot)?;
71 Ok(all.into_iter().filter(|r| predicate(r)).collect())
72 }
73
74 pub fn point_lookup_with_tx(
75 &self,
76 tx: Option<TxId>,
77 table: &str,
78 col: &str,
79 value: &Value,
80 snapshot: SnapshotId,
81 ) -> Result<Option<VersionedRow>> {
82 let all = self.scan_with_tx(tx, table, snapshot)?;
83 Ok(all.into_iter().find(|r| r.values.get(col) == Some(value)))
84 }
85
86 fn validate_state_transition(
87 &self,
88 tx: TxId,
89 table: &str,
90 values: &HashMap<ColName, Value>,
91 snapshot: SnapshotId,
92 ) -> Result<()> {
93 let meta = self.store.table_meta.read();
94 let Some(sm) = meta.get(table).and_then(|m| m.state_machine.as_ref()) else {
95 return Ok(());
96 };
97 let col = &sm.column;
98
99 let new_status = match values.get(col) {
100 Some(Value::Text(s)) => s.as_str(),
101 _ => return Ok(()),
102 };
103
104 let id = match values.get("id") {
105 Some(v @ Value::Uuid(_)) => v.clone(),
106 _ => return Ok(()),
107 };
108
109 if let Some(existing) = self.point_lookup_with_tx(Some(tx), table, "id", &id, snapshot)? {
110 let old_status = existing
111 .values
112 .get(col)
113 .and_then(Value::as_text)
114 .unwrap_or("");
115 if !self
116 .store
117 .validate_state_transition(table, col, old_status, new_status)
118 {
119 return Err(Error::InvalidStateTransition(format!(
120 "{} -> {}",
121 old_status, new_status
122 )));
123 }
124 }
125
126 Ok(())
127 }
128
129 pub fn insert_with_tx(
130 &self,
131 tx: TxId,
132 table: &str,
133 values: HashMap<ColName, Value>,
134 snapshot: SnapshotId,
135 ) -> Result<RowId> {
136 self.ensure_table_exists(table)?;
137 self.validate_state_transition(tx, table, &values, snapshot)?;
138
139 let row_id = self.store.new_row_id();
140 let row = VersionedRow {
141 row_id,
142 values,
143 created_tx: tx,
144 deleted_tx: None,
145 lsn: contextdb_core::Lsn(0),
146 created_at: Some(contextdb_core::Wallclock(
147 SystemTime::now()
148 .duration_since(UNIX_EPOCH)
149 .unwrap_or_default()
150 .as_millis() as u64,
151 )),
152 };
153
154 self.tx_mgr.with_write_set(tx, |ws| {
155 ws.relational_inserts.push((table.to_string(), row));
156 })?;
157
158 Ok(row_id)
159 }
160
161 pub fn upsert_with_tx(
162 &self,
163 tx: TxId,
164 table: &str,
165 conflict_col: &str,
166 values: HashMap<ColName, Value>,
167 snapshot: SnapshotId,
168 ) -> Result<UpsertResult> {
169 self.ensure_table_exists(table)?;
170 if self.store.is_immutable(table) {
171 return Err(Error::ImmutableTable(table.to_string()));
172 }
173
174 self.validate_state_transition(tx, table, &values, snapshot)?;
175
176 let conflict_val = values
177 .get(conflict_col)
178 .ok_or_else(|| Error::Other("conflict column not in values".to_string()))?
179 .clone();
180
181 let existing =
182 self.point_lookup_with_tx(Some(tx), table, conflict_col, &conflict_val, snapshot)?;
183
184 match existing {
185 None => {
186 self.insert_with_tx(tx, table, values, snapshot)?;
187 Ok(UpsertResult::Inserted)
188 }
189 Some(existing_row) => {
190 let changed = values
191 .iter()
192 .any(|(k, v)| existing_row.values.get(k) != Some(v));
193 if !changed {
194 return Ok(UpsertResult::NoOp);
195 }
196
197 self.delete(tx, table, existing_row.row_id)?;
198 self.insert_with_tx(tx, table, values, snapshot)?;
199 Ok(UpsertResult::Updated)
200 }
201 }
202 }
203}
204
205impl<S: WriteSetApplicator> RelationalExecutor for MemRelationalExecutor<S> {
206 fn scan(&self, table: &str, snapshot: SnapshotId) -> Result<Vec<VersionedRow>> {
207 self.scan_with_tx(None, table, snapshot)
208 }
209
210 fn scan_filter(
211 &self,
212 table: &str,
213 snapshot: SnapshotId,
214 predicate: &dyn Fn(&VersionedRow) -> bool,
215 ) -> Result<Vec<VersionedRow>> {
216 self.scan_filter_with_tx(None, table, snapshot, predicate)
217 }
218
219 fn point_lookup(
220 &self,
221 table: &str,
222 col: &str,
223 value: &Value,
224 snapshot: SnapshotId,
225 ) -> Result<Option<VersionedRow>> {
226 self.point_lookup_with_tx(None, table, col, value, snapshot)
227 }
228
229 fn insert(&self, tx: TxId, table: &str, values: HashMap<ColName, Value>) -> Result<RowId> {
230 let snapshot = self.tx_mgr.snapshot();
231 self.insert_with_tx(tx, table, values, snapshot)
232 }
233
234 fn upsert(
235 &self,
236 tx: TxId,
237 table: &str,
238 conflict_col: &str,
239 values: HashMap<ColName, Value>,
240 snapshot: SnapshotId,
241 ) -> Result<UpsertResult> {
242 self.upsert_with_tx(tx, table, conflict_col, values, snapshot)
243 }
244
245 fn delete(&self, tx: TxId, table: &str, row_id: RowId) -> Result<()> {
246 self.ensure_table_exists(table)?;
247 if self.store.is_immutable(table) {
248 return Err(Error::ImmutableTable(table.to_string()));
249 }
250
251 self.tx_mgr.with_write_set(tx, |ws| {
252 ws.relational_deletes.push((table.to_string(), row_id, tx));
253 })?;
254
255 Ok(())
256 }
257}