1use std::sync::{Arc, RwLock};
2
3use arrow::record_batch::RecordBatch;
4use llkv_executor::ExecutorRowBatch;
5use llkv_expr::expr::Expr as LlkvExpr;
6use llkv_result::{Error, Result as LlkvResult};
7use llkv_storage::pager::Pager;
8use llkv_table::{CatalogDdl, ConstraintEnforcementMode, SingleColumnIndexDescriptor, TableId};
9use llkv_transaction::{TransactionContext, TransactionResult, TransactionSnapshot, TxnId};
10use simd_r_drive_entry_handle::EntryHandle;
11
12use crate::{
13 AlterTablePlan, CreateIndexPlan, CreateTablePlan, CreateViewPlan, DeletePlan, DropIndexPlan,
14 DropTablePlan, DropViewPlan, InsertPlan, PlanColumnSpec, RenameTablePlan, RuntimeContext,
15 RuntimeStatementResult, SelectExecution, SelectPlan, UpdatePlan,
16};
17use llkv_plan::TruncatePlan;
18
19pub struct RuntimeTransactionContext<P>
27where
28 P: Pager<Blob = EntryHandle> + Send + Sync,
29{
30 ctx: Arc<RuntimeContext<P>>,
31 snapshot: RwLock<TransactionSnapshot>,
32 constraint_mode: RwLock<ConstraintEnforcementMode>,
33}
34
35impl<P> RuntimeTransactionContext<P>
36where
37 P: Pager<Blob = EntryHandle> + Send + Sync,
38{
39 pub(crate) fn new(ctx: Arc<RuntimeContext<P>>) -> Self {
40 let snapshot = ctx.default_snapshot();
41 Self {
42 ctx,
43 snapshot: RwLock::new(snapshot),
44 constraint_mode: RwLock::new(ConstraintEnforcementMode::Immediate),
45 }
46 }
47
48 pub(crate) fn set_constraint_mode(&self, mode: ConstraintEnforcementMode) {
49 *self
50 .constraint_mode
51 .write()
52 .expect("constraint mode lock poisoned") = mode;
53 }
54
55 pub(crate) fn constraint_mode(&self) -> ConstraintEnforcementMode {
56 *self
57 .constraint_mode
58 .read()
59 .expect("constraint mode lock poisoned")
60 }
61
62 fn update_snapshot(&self, snapshot: TransactionSnapshot) {
63 let mut guard = self.snapshot.write().expect("snapshot lock poisoned");
64 *guard = snapshot;
65 }
66
67 fn current_snapshot(&self) -> TransactionSnapshot {
68 *self.snapshot.read().expect("snapshot lock poisoned")
69 }
70
71 pub(crate) fn context(&self) -> &Arc<RuntimeContext<P>> {
72 &self.ctx
73 }
74
75 pub(crate) fn ctx(&self) -> &RuntimeContext<P> {
76 &self.ctx
77 }
78}
79
80impl<P> CatalogDdl for RuntimeTransactionContext<P>
81where
82 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
83{
84 type CreateTableOutput = TransactionResult<P>;
85 type DropTableOutput = ();
86 type RenameTableOutput = ();
87 type AlterTableOutput = TransactionResult<P>;
88 type CreateIndexOutput = TransactionResult<P>;
89 type DropIndexOutput = Option<SingleColumnIndexDescriptor>;
90
91 fn create_table(&self, plan: CreateTablePlan) -> LlkvResult<Self::CreateTableOutput> {
92 let ctx = self.context();
93 let result = CatalogDdl::create_table(ctx.as_ref(), plan)?;
94 Ok(convert_statement_result(result))
95 }
96
97 fn drop_table(&self, plan: DropTablePlan) -> LlkvResult<Self::DropTableOutput> {
98 CatalogDdl::drop_table(self.ctx.as_ref(), plan)
99 }
100
101 fn rename_table(&self, plan: RenameTablePlan) -> LlkvResult<Self::RenameTableOutput> {
102 CatalogDdl::rename_table(self.ctx.as_ref(), plan)
103 }
104
105 fn alter_table(&self, plan: AlterTablePlan) -> LlkvResult<Self::AlterTableOutput> {
106 let ctx = self.context();
107 let result = CatalogDdl::alter_table(ctx.as_ref(), plan)?;
108 Ok(convert_statement_result(result))
109 }
110
111 fn create_index(&self, plan: CreateIndexPlan) -> LlkvResult<Self::CreateIndexOutput> {
112 let ctx = self.context();
113 let result = CatalogDdl::create_index(ctx.as_ref(), plan)?;
114 Ok(convert_statement_result(result))
115 }
116
117 fn drop_index(&self, plan: DropIndexPlan) -> LlkvResult<Self::DropIndexOutput> {
118 CatalogDdl::drop_index(self.ctx.as_ref(), plan)
119 }
120
121 fn create_view(&self, plan: CreateViewPlan) -> LlkvResult<()> {
122 CatalogDdl::create_view(self.ctx.as_ref(), plan)
123 }
124
125 fn drop_view(&self, plan: DropViewPlan) -> LlkvResult<()> {
126 CatalogDdl::drop_view(self.ctx.as_ref(), plan)
127 }
128}
129
130impl<P> TransactionContext for RuntimeTransactionContext<P>
132where
133 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
134{
135 type Pager = P;
136 type Snapshot = llkv_table::catalog::TableCatalogSnapshot;
137
138 fn set_snapshot(&self, snapshot: TransactionSnapshot) {
139 self.update_snapshot(snapshot);
140 }
141
142 fn snapshot(&self) -> TransactionSnapshot {
143 self.current_snapshot()
144 }
145
146 fn table_column_specs(&self, table_name: &str) -> LlkvResult<Vec<PlanColumnSpec>> {
147 let (_, canonical_name) = llkv_table::canonical_table_name(table_name)?;
148 self.context().catalog().table_column_specs(&canonical_name)
149 }
150
151 fn export_table_rows(&self, table_name: &str) -> LlkvResult<ExecutorRowBatch> {
152 RuntimeContext::export_table_rows(self.context(), table_name)
153 }
154
155 fn get_batches_with_row_ids(
156 &self,
157 table_name: &str,
158 filter: Option<LlkvExpr<'static, String>>,
159 ) -> LlkvResult<Vec<RecordBatch>> {
160 self.context()
161 .get_batches_with_row_ids(table_name, filter, self.snapshot())
162 }
163
164 fn execute_select(&self, plan: SelectPlan) -> LlkvResult<SelectExecution<Self::Pager>> {
165 self.context().execute_select(plan, self.snapshot())
166 }
167
168 fn apply_create_table_plan(&self, plan: CreateTablePlan) -> LlkvResult<TransactionResult<P>> {
169 let ctx = self.context();
170 let result = CatalogDdl::create_table(ctx.as_ref(), plan)?;
171 Ok(convert_statement_result(result))
172 }
173
174 fn drop_table(&self, plan: DropTablePlan) -> LlkvResult<()> {
175 CatalogDdl::drop_table(self.ctx.as_ref(), plan)
176 }
177
178 fn insert(&self, plan: InsertPlan) -> LlkvResult<TransactionResult<P>> {
179 tracing::trace!(
180 "[TX_RUNTIME] TransactionContext::insert plan.table='{}', context_pager={:p}",
181 plan.table,
182 &*self.ctx.pager
183 );
184 let snapshot = self.current_snapshot();
185 let mode = self.constraint_mode();
186 let result = self.ctx().insert(plan, snapshot, mode)?;
187 Ok(convert_statement_result(result))
188 }
189
190 fn update(&self, plan: UpdatePlan) -> LlkvResult<TransactionResult<P>> {
191 let snapshot = self.current_snapshot();
192 let result = self.ctx().update(plan, snapshot)?;
193 Ok(convert_statement_result(result))
194 }
195
196 fn delete(&self, plan: DeletePlan) -> LlkvResult<TransactionResult<P>> {
197 let snapshot = self.current_snapshot();
198 let result = self.ctx().delete(plan, snapshot)?;
199 Ok(convert_statement_result(result))
200 }
201
202 fn truncate(&self, plan: TruncatePlan) -> LlkvResult<TransactionResult<P>> {
203 let snapshot = self.current_snapshot();
204 let result = self.ctx().truncate(plan, snapshot)?;
205 Ok(convert_statement_result(result))
206 }
207
208 fn create_index(&self, plan: CreateIndexPlan) -> LlkvResult<TransactionResult<P>> {
209 let ctx = self.context();
210 let result = CatalogDdl::create_index(ctx.as_ref(), plan)?;
211 Ok(convert_statement_result(result))
212 }
213
214 fn append_batches_with_row_ids(
215 &self,
216 table_name: &str,
217 batches: Vec<RecordBatch>,
218 ) -> LlkvResult<usize> {
219 RuntimeContext::append_batches_with_row_ids(self.context(), table_name, batches)
220 }
221
222 fn table_names(&self) -> Vec<String> {
223 RuntimeContext::table_names(self.context())
224 }
225
226 fn table_id(&self, table_name: &str) -> LlkvResult<TableId> {
227 let ctx = self.context();
228 if ctx.is_table_marked_dropped(table_name) {
229 return Err(Error::InvalidArgumentError(format!(
230 "table '{}' has been dropped",
231 table_name
232 )));
233 }
234
235 let table = ctx.lookup_table(table_name)?;
236 Ok(table.table_id())
237 }
238
239 fn catalog_snapshot(&self) -> Self::Snapshot {
240 let ctx = self.context();
241 ctx.catalog.snapshot()
242 }
243
244 fn validate_commit_constraints(&self, txn_id: TxnId) -> LlkvResult<()> {
245 self.ctx.validate_primary_keys_for_commit(txn_id)
246 }
247
248 fn clear_transaction_state(&self, txn_id: TxnId) {
249 self.ctx.clear_transaction_state(txn_id);
250 }
251}
252
253fn convert_statement_result<P>(result: RuntimeStatementResult<P>) -> TransactionResult<P>
254where
255 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
256{
257 use llkv_transaction::TransactionResult as TxResult;
258 match result {
259 RuntimeStatementResult::CreateTable { table_name } => TxResult::CreateTable { table_name },
260 RuntimeStatementResult::CreateIndex {
261 table_name,
262 index_name,
263 } => TxResult::CreateIndex {
264 table_name,
265 index_name,
266 },
267 RuntimeStatementResult::Insert { rows_inserted, .. } => TxResult::Insert { rows_inserted },
268 RuntimeStatementResult::Update { rows_updated, .. } => TxResult::Update {
269 rows_matched: rows_updated,
270 rows_updated,
271 },
272 RuntimeStatementResult::Delete { rows_deleted, .. } => TxResult::Delete { rows_deleted },
273 RuntimeStatementResult::Transaction { kind } => TxResult::Transaction { kind },
274 RuntimeStatementResult::NoOp => TxResult::NoOp,
275 _ => panic!("unsupported StatementResult conversion"),
276 }
277}