1use std::sync::{Arc, RwLock};
2
3use arrow::record_batch::RecordBatch;
4use llkv_executor::ExecutorRowBatch;
5use llkv_expr::expr::Expr as LlkvExpr;
6use llkv_result::{Error, Result as LlkvResult};
7use llkv_storage::pager::Pager;
8use llkv_table::{CatalogDdl, SingleColumnIndexDescriptor, TableId};
9use llkv_transaction::{TransactionContext, TransactionResult, TransactionSnapshot, TxnId};
10use simd_r_drive_entry_handle::EntryHandle;
11
12use crate::{
13 AlterTablePlan, CreateIndexPlan, CreateTablePlan, DeletePlan, DropIndexPlan, DropTablePlan,
14 InsertPlan, PlanColumnSpec, RenameTablePlan, RuntimeContext, RuntimeStatementResult,
15 SelectExecution, SelectPlan, UpdatePlan,
16};
17
18pub struct RuntimeTransactionContext<P>
26where
27 P: Pager<Blob = EntryHandle> + Send + Sync,
28{
29 ctx: Arc<RuntimeContext<P>>,
30 snapshot: RwLock<TransactionSnapshot>,
31}
32
33impl<P> RuntimeTransactionContext<P>
34where
35 P: Pager<Blob = EntryHandle> + Send + Sync,
36{
37 pub(crate) fn new(ctx: Arc<RuntimeContext<P>>) -> Self {
38 let snapshot = ctx.default_snapshot();
39 Self {
40 ctx,
41 snapshot: RwLock::new(snapshot),
42 }
43 }
44
45 fn update_snapshot(&self, snapshot: TransactionSnapshot) {
46 let mut guard = self.snapshot.write().expect("snapshot lock poisoned");
47 *guard = snapshot;
48 }
49
50 fn current_snapshot(&self) -> TransactionSnapshot {
51 *self.snapshot.read().expect("snapshot lock poisoned")
52 }
53
54 pub(crate) fn context(&self) -> &Arc<RuntimeContext<P>> {
55 &self.ctx
56 }
57
58 pub(crate) fn ctx(&self) -> &RuntimeContext<P> {
59 &self.ctx
60 }
61}
62
63impl<P> CatalogDdl for RuntimeTransactionContext<P>
64where
65 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
66{
67 type CreateTableOutput = TransactionResult<P>;
68 type DropTableOutput = ();
69 type RenameTableOutput = ();
70 type AlterTableOutput = TransactionResult<P>;
71 type CreateIndexOutput = TransactionResult<P>;
72 type DropIndexOutput = Option<SingleColumnIndexDescriptor>;
73
74 fn create_table(&self, plan: CreateTablePlan) -> LlkvResult<Self::CreateTableOutput> {
75 let ctx = self.context();
76 let result = CatalogDdl::create_table(ctx.as_ref(), plan)?;
77 Ok(convert_statement_result(result))
78 }
79
80 fn drop_table(&self, plan: DropTablePlan) -> LlkvResult<Self::DropTableOutput> {
81 CatalogDdl::drop_table(self.ctx.as_ref(), plan)
82 }
83
84 fn rename_table(&self, plan: RenameTablePlan) -> LlkvResult<Self::RenameTableOutput> {
85 CatalogDdl::rename_table(self.ctx.as_ref(), plan)
86 }
87
88 fn alter_table(&self, plan: AlterTablePlan) -> LlkvResult<Self::AlterTableOutput> {
89 let ctx = self.context();
90 let result = CatalogDdl::alter_table(ctx.as_ref(), plan)?;
91 Ok(convert_statement_result(result))
92 }
93
94 fn create_index(&self, plan: CreateIndexPlan) -> LlkvResult<Self::CreateIndexOutput> {
95 let ctx = self.context();
96 let result = CatalogDdl::create_index(ctx.as_ref(), plan)?;
97 Ok(convert_statement_result(result))
98 }
99
100 fn drop_index(&self, plan: DropIndexPlan) -> LlkvResult<Self::DropIndexOutput> {
101 CatalogDdl::drop_index(self.ctx.as_ref(), plan)
102 }
103}
104
105impl<P> TransactionContext for RuntimeTransactionContext<P>
107where
108 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
109{
110 type Pager = P;
111 type Snapshot = llkv_table::catalog::TableCatalogSnapshot;
112
113 fn set_snapshot(&self, snapshot: TransactionSnapshot) {
114 self.update_snapshot(snapshot);
115 }
116
117 fn snapshot(&self) -> TransactionSnapshot {
118 self.current_snapshot()
119 }
120
121 fn table_column_specs(&self, table_name: &str) -> LlkvResult<Vec<PlanColumnSpec>> {
122 let (_, canonical_name) = llkv_table::canonical_table_name(table_name)?;
123 self.context().catalog().table_column_specs(&canonical_name)
124 }
125
126 fn export_table_rows(&self, table_name: &str) -> LlkvResult<ExecutorRowBatch> {
127 RuntimeContext::export_table_rows(self.context(), table_name)
128 }
129
130 fn get_batches_with_row_ids(
131 &self,
132 table_name: &str,
133 filter: Option<LlkvExpr<'static, String>>,
134 ) -> LlkvResult<Vec<RecordBatch>> {
135 self.context()
136 .get_batches_with_row_ids(table_name, filter, self.snapshot())
137 }
138
139 fn execute_select(&self, plan: SelectPlan) -> LlkvResult<SelectExecution<Self::Pager>> {
140 self.context().execute_select(plan, self.snapshot())
141 }
142
143 fn apply_create_table_plan(&self, plan: CreateTablePlan) -> LlkvResult<TransactionResult<P>> {
144 let ctx = self.context();
145 let result = CatalogDdl::create_table(ctx.as_ref(), plan)?;
146 Ok(convert_statement_result(result))
147 }
148
149 fn drop_table(&self, plan: DropTablePlan) -> LlkvResult<()> {
150 CatalogDdl::drop_table(self.ctx.as_ref(), plan)
151 }
152
153 fn insert(&self, plan: InsertPlan) -> LlkvResult<TransactionResult<P>> {
154 tracing::trace!(
155 "[TX_RUNTIME] TransactionContext::insert plan.table='{}', context_pager={:p}",
156 plan.table,
157 &*self.ctx.pager
158 );
159 let snapshot = self.current_snapshot();
160 let result = self.ctx().insert(plan, snapshot)?;
161 Ok(convert_statement_result(result))
162 }
163
164 fn update(&self, plan: UpdatePlan) -> LlkvResult<TransactionResult<P>> {
165 let snapshot = self.current_snapshot();
166 let result = self.ctx().update(plan, snapshot)?;
167 Ok(convert_statement_result(result))
168 }
169
170 fn delete(&self, plan: DeletePlan) -> LlkvResult<TransactionResult<P>> {
171 let snapshot = self.current_snapshot();
172 let result = self.ctx().delete(plan, snapshot)?;
173 Ok(convert_statement_result(result))
174 }
175
176 fn create_index(&self, plan: CreateIndexPlan) -> LlkvResult<TransactionResult<P>> {
177 let ctx = self.context();
178 let result = CatalogDdl::create_index(ctx.as_ref(), plan)?;
179 Ok(convert_statement_result(result))
180 }
181
182 fn append_batches_with_row_ids(
183 &self,
184 table_name: &str,
185 batches: Vec<RecordBatch>,
186 ) -> LlkvResult<usize> {
187 RuntimeContext::append_batches_with_row_ids(self.context(), table_name, batches)
188 }
189
190 fn table_names(&self) -> Vec<String> {
191 RuntimeContext::table_names(self.context())
192 }
193
194 fn table_id(&self, table_name: &str) -> LlkvResult<TableId> {
195 let ctx = self.context();
196 if ctx.is_table_marked_dropped(table_name) {
197 return Err(Error::InvalidArgumentError(format!(
198 "table '{}' has been dropped",
199 table_name
200 )));
201 }
202
203 let table = ctx.lookup_table(table_name)?;
204 Ok(table.table.table_id())
205 }
206
207 fn catalog_snapshot(&self) -> Self::Snapshot {
208 let ctx = self.context();
209 ctx.catalog.snapshot()
210 }
211
212 fn validate_commit_constraints(&self, txn_id: TxnId) -> LlkvResult<()> {
213 self.ctx.validate_primary_keys_for_commit(txn_id)
214 }
215
216 fn clear_transaction_state(&self, txn_id: TxnId) {
217 self.ctx.clear_transaction_state(txn_id);
218 }
219}
220
221fn convert_statement_result<P>(result: RuntimeStatementResult<P>) -> TransactionResult<P>
222where
223 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
224{
225 use llkv_transaction::TransactionResult as TxResult;
226 match result {
227 RuntimeStatementResult::CreateTable { table_name } => TxResult::CreateTable { table_name },
228 RuntimeStatementResult::CreateIndex {
229 table_name,
230 index_name,
231 } => TxResult::CreateIndex {
232 table_name,
233 index_name,
234 },
235 RuntimeStatementResult::Insert { rows_inserted, .. } => TxResult::Insert { rows_inserted },
236 RuntimeStatementResult::Update { rows_updated, .. } => TxResult::Update {
237 rows_matched: rows_updated,
238 rows_updated,
239 },
240 RuntimeStatementResult::Delete { rows_deleted, .. } => TxResult::Delete { rows_deleted },
241 RuntimeStatementResult::Transaction { kind } => TxResult::Transaction { kind },
242 RuntimeStatementResult::NoOp => TxResult::NoOp,
243 _ => panic!("unsupported StatementResult conversion"),
244 }
245}