1use std::sync::{Arc, RwLock};
2
3use arrow::record_batch::RecordBatch;
4use llkv_executor::ExecutorRowBatch;
5use llkv_expr::expr::Expr as LlkvExpr;
6use llkv_result::{Error, Result as LlkvResult};
7use llkv_storage::pager::Pager;
8use llkv_table::{CatalogDdl, SingleColumnIndexDescriptor, TableId};
9use llkv_transaction::{TransactionContext, TransactionResult, TransactionSnapshot, TxnId};
10use simd_r_drive_entry_handle::EntryHandle;
11
12use crate::{
13 AlterTablePlan, CreateIndexPlan, CreateTablePlan, CreateViewPlan, DeletePlan, DropIndexPlan,
14 DropTablePlan, DropViewPlan, InsertPlan, PlanColumnSpec, RenameTablePlan, RuntimeContext,
15 RuntimeStatementResult, SelectExecution, SelectPlan, UpdatePlan,
16};
17use llkv_plan::TruncatePlan;
18
19pub struct RuntimeTransactionContext<P>
27where
28 P: Pager<Blob = EntryHandle> + Send + Sync,
29{
30 ctx: Arc<RuntimeContext<P>>,
31 snapshot: RwLock<TransactionSnapshot>,
32}
33
34impl<P> RuntimeTransactionContext<P>
35where
36 P: Pager<Blob = EntryHandle> + Send + Sync,
37{
38 pub(crate) fn new(ctx: Arc<RuntimeContext<P>>) -> Self {
39 let snapshot = ctx.default_snapshot();
40 Self {
41 ctx,
42 snapshot: RwLock::new(snapshot),
43 }
44 }
45
46 fn update_snapshot(&self, snapshot: TransactionSnapshot) {
47 let mut guard = self.snapshot.write().expect("snapshot lock poisoned");
48 *guard = snapshot;
49 }
50
51 fn current_snapshot(&self) -> TransactionSnapshot {
52 *self.snapshot.read().expect("snapshot lock poisoned")
53 }
54
55 pub(crate) fn context(&self) -> &Arc<RuntimeContext<P>> {
56 &self.ctx
57 }
58
59 pub(crate) fn ctx(&self) -> &RuntimeContext<P> {
60 &self.ctx
61 }
62}
63
64impl<P> CatalogDdl for RuntimeTransactionContext<P>
65where
66 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
67{
68 type CreateTableOutput = TransactionResult<P>;
69 type DropTableOutput = ();
70 type RenameTableOutput = ();
71 type AlterTableOutput = TransactionResult<P>;
72 type CreateIndexOutput = TransactionResult<P>;
73 type DropIndexOutput = Option<SingleColumnIndexDescriptor>;
74
75 fn create_table(&self, plan: CreateTablePlan) -> LlkvResult<Self::CreateTableOutput> {
76 let ctx = self.context();
77 let result = CatalogDdl::create_table(ctx.as_ref(), plan)?;
78 Ok(convert_statement_result(result))
79 }
80
81 fn drop_table(&self, plan: DropTablePlan) -> LlkvResult<Self::DropTableOutput> {
82 CatalogDdl::drop_table(self.ctx.as_ref(), plan)
83 }
84
85 fn rename_table(&self, plan: RenameTablePlan) -> LlkvResult<Self::RenameTableOutput> {
86 CatalogDdl::rename_table(self.ctx.as_ref(), plan)
87 }
88
89 fn alter_table(&self, plan: AlterTablePlan) -> LlkvResult<Self::AlterTableOutput> {
90 let ctx = self.context();
91 let result = CatalogDdl::alter_table(ctx.as_ref(), plan)?;
92 Ok(convert_statement_result(result))
93 }
94
95 fn create_index(&self, plan: CreateIndexPlan) -> LlkvResult<Self::CreateIndexOutput> {
96 let ctx = self.context();
97 let result = CatalogDdl::create_index(ctx.as_ref(), plan)?;
98 Ok(convert_statement_result(result))
99 }
100
101 fn drop_index(&self, plan: DropIndexPlan) -> LlkvResult<Self::DropIndexOutput> {
102 CatalogDdl::drop_index(self.ctx.as_ref(), plan)
103 }
104
105 fn create_view(&self, plan: CreateViewPlan) -> LlkvResult<()> {
106 CatalogDdl::create_view(self.ctx.as_ref(), plan)
107 }
108
109 fn drop_view(&self, plan: DropViewPlan) -> LlkvResult<()> {
110 CatalogDdl::drop_view(self.ctx.as_ref(), plan)
111 }
112}
113
114impl<P> TransactionContext for RuntimeTransactionContext<P>
116where
117 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
118{
119 type Pager = P;
120 type Snapshot = llkv_table::catalog::TableCatalogSnapshot;
121
122 fn set_snapshot(&self, snapshot: TransactionSnapshot) {
123 self.update_snapshot(snapshot);
124 }
125
126 fn snapshot(&self) -> TransactionSnapshot {
127 self.current_snapshot()
128 }
129
130 fn table_column_specs(&self, table_name: &str) -> LlkvResult<Vec<PlanColumnSpec>> {
131 let (_, canonical_name) = llkv_table::canonical_table_name(table_name)?;
132 self.context().catalog().table_column_specs(&canonical_name)
133 }
134
135 fn export_table_rows(&self, table_name: &str) -> LlkvResult<ExecutorRowBatch> {
136 RuntimeContext::export_table_rows(self.context(), table_name)
137 }
138
139 fn get_batches_with_row_ids(
140 &self,
141 table_name: &str,
142 filter: Option<LlkvExpr<'static, String>>,
143 ) -> LlkvResult<Vec<RecordBatch>> {
144 self.context()
145 .get_batches_with_row_ids(table_name, filter, self.snapshot())
146 }
147
148 fn execute_select(&self, plan: SelectPlan) -> LlkvResult<SelectExecution<Self::Pager>> {
149 self.context().execute_select(plan, self.snapshot())
150 }
151
152 fn apply_create_table_plan(&self, plan: CreateTablePlan) -> LlkvResult<TransactionResult<P>> {
153 let ctx = self.context();
154 let result = CatalogDdl::create_table(ctx.as_ref(), plan)?;
155 Ok(convert_statement_result(result))
156 }
157
158 fn drop_table(&self, plan: DropTablePlan) -> LlkvResult<()> {
159 CatalogDdl::drop_table(self.ctx.as_ref(), plan)
160 }
161
162 fn insert(&self, plan: InsertPlan) -> LlkvResult<TransactionResult<P>> {
163 tracing::trace!(
164 "[TX_RUNTIME] TransactionContext::insert plan.table='{}', context_pager={:p}",
165 plan.table,
166 &*self.ctx.pager
167 );
168 let snapshot = self.current_snapshot();
169 let result = self.ctx().insert(plan, snapshot)?;
170 Ok(convert_statement_result(result))
171 }
172
173 fn update(&self, plan: UpdatePlan) -> LlkvResult<TransactionResult<P>> {
174 let snapshot = self.current_snapshot();
175 let result = self.ctx().update(plan, snapshot)?;
176 Ok(convert_statement_result(result))
177 }
178
179 fn delete(&self, plan: DeletePlan) -> LlkvResult<TransactionResult<P>> {
180 let snapshot = self.current_snapshot();
181 let result = self.ctx().delete(plan, snapshot)?;
182 Ok(convert_statement_result(result))
183 }
184
185 fn truncate(&self, plan: TruncatePlan) -> LlkvResult<TransactionResult<P>> {
186 let snapshot = self.current_snapshot();
187 let result = self.ctx().truncate(plan, snapshot)?;
188 Ok(convert_statement_result(result))
189 }
190
191 fn create_index(&self, plan: CreateIndexPlan) -> LlkvResult<TransactionResult<P>> {
192 let ctx = self.context();
193 let result = CatalogDdl::create_index(ctx.as_ref(), plan)?;
194 Ok(convert_statement_result(result))
195 }
196
197 fn append_batches_with_row_ids(
198 &self,
199 table_name: &str,
200 batches: Vec<RecordBatch>,
201 ) -> LlkvResult<usize> {
202 RuntimeContext::append_batches_with_row_ids(self.context(), table_name, batches)
203 }
204
205 fn table_names(&self) -> Vec<String> {
206 RuntimeContext::table_names(self.context())
207 }
208
209 fn table_id(&self, table_name: &str) -> LlkvResult<TableId> {
210 let ctx = self.context();
211 if ctx.is_table_marked_dropped(table_name) {
212 return Err(Error::InvalidArgumentError(format!(
213 "table '{}' has been dropped",
214 table_name
215 )));
216 }
217
218 let table = ctx.lookup_table(table_name)?;
219 Ok(table.table.table_id())
220 }
221
222 fn catalog_snapshot(&self) -> Self::Snapshot {
223 let ctx = self.context();
224 ctx.catalog.snapshot()
225 }
226
227 fn validate_commit_constraints(&self, txn_id: TxnId) -> LlkvResult<()> {
228 self.ctx.validate_primary_keys_for_commit(txn_id)
229 }
230
231 fn clear_transaction_state(&self, txn_id: TxnId) {
232 self.ctx.clear_transaction_state(txn_id);
233 }
234}
235
236fn convert_statement_result<P>(result: RuntimeStatementResult<P>) -> TransactionResult<P>
237where
238 P: Pager<Blob = EntryHandle> + Send + Sync + 'static,
239{
240 use llkv_transaction::TransactionResult as TxResult;
241 match result {
242 RuntimeStatementResult::CreateTable { table_name } => TxResult::CreateTable { table_name },
243 RuntimeStatementResult::CreateIndex {
244 table_name,
245 index_name,
246 } => TxResult::CreateIndex {
247 table_name,
248 index_name,
249 },
250 RuntimeStatementResult::Insert { rows_inserted, .. } => TxResult::Insert { rows_inserted },
251 RuntimeStatementResult::Update { rows_updated, .. } => TxResult::Update {
252 rows_matched: rows_updated,
253 rows_updated,
254 },
255 RuntimeStatementResult::Delete { rows_deleted, .. } => TxResult::Delete { rows_deleted },
256 RuntimeStatementResult::Transaction { kind } => TxResult::Transaction { kind },
257 RuntimeStatementResult::NoOp => TxResult::NoOp,
258 _ => panic!("unsupported StatementResult conversion"),
259 }
260}