1use std::ops::Deref;
2use std::ops::DerefMut;
3use std::sync::atomic::{AtomicI64, Ordering};
4use std::time::Duration;
5
6use prost_types::Struct;
7
8use crate::session::ManagedSession;
9use crate::statement::Statement;
10use crate::transaction::{CallOptions, QueryOptions, Transaction};
11use crate::value::Timestamp;
12use google_cloud_gax::grpc::{Code, Status};
13use google_cloud_gax::retry::{RetrySetting, TryAs};
14use google_cloud_googleapis::spanner::v1::commit_request::Transaction::TransactionId;
15use google_cloud_googleapis::spanner::v1::transaction_options::IsolationLevel;
16use google_cloud_googleapis::spanner::v1::{
17 commit_request, execute_batch_dml_request, result_set_stats, transaction_options, transaction_selector,
18 BeginTransactionRequest, CommitRequest, CommitResponse, ExecuteBatchDmlRequest, ExecuteSqlRequest, Mutation,
19 ResultSetStats, RollbackRequest, TransactionOptions, TransactionSelector,
20};
21
22#[derive(Clone, Default)]
23pub struct CommitOptions {
24 pub return_commit_stats: bool,
25 pub call_options: CallOptions,
26 pub max_commit_delay: Option<Duration>,
27 pub transaction_tag: Option<String>,
29}
30
31#[derive(Clone)]
32pub struct CommitResult {
33 pub timestamp: Option<Timestamp>,
34 pub mutation_count: Option<u64>,
35}
36
37impl From<CommitResponse> for CommitResult {
38 fn from(value: CommitResponse) -> Self {
39 Self {
40 timestamp: value.commit_timestamp.map(|v| v.into()),
41 mutation_count: value.commit_stats.map(|s| s.mutation_count as u64),
42 }
43 }
44}
45
46pub struct ReadWriteTransaction {
99 base_tx: Transaction,
100 tx_id: Vec<u8>,
101 wb: Vec<Mutation>,
102}
103
104impl Deref for ReadWriteTransaction {
105 type Target = Transaction;
106
107 fn deref(&self) -> &Self::Target {
108 &self.base_tx
109 }
110}
111
112impl DerefMut for ReadWriteTransaction {
113 fn deref_mut(&mut self) -> &mut Transaction {
114 &mut self.base_tx
115 }
116}
117
118pub struct BeginError {
119 pub status: Status,
120 pub session: ManagedSession,
121}
122
123impl ReadWriteTransaction {
124 pub async fn begin(
125 session: ManagedSession,
126 options: CallOptions,
127 transaction_tag: Option<String>,
128 disable_route_to_leader: bool,
129 ) -> Result<ReadWriteTransaction, BeginError> {
130 ReadWriteTransaction::begin_internal(
131 session,
132 transaction_options::Mode::ReadWrite(transaction_options::ReadWrite::default()),
133 options,
134 transaction_tag,
135 disable_route_to_leader,
136 )
137 .await
138 }
139
140 pub async fn begin_partitioned_dml(
141 session: ManagedSession,
142 options: CallOptions,
143 transaction_tag: Option<String>,
144 disable_route_to_leader: bool,
145 ) -> Result<ReadWriteTransaction, BeginError> {
146 ReadWriteTransaction::begin_internal(
147 session,
148 transaction_options::Mode::PartitionedDml(transaction_options::PartitionedDml {}),
149 options,
150 transaction_tag,
151 disable_route_to_leader,
152 )
153 .await
154 }
155
156 async fn begin_internal(
157 mut session: ManagedSession,
158 mode: transaction_options::Mode,
159 options: CallOptions,
160 transaction_tag: Option<String>,
161 disable_route_to_leader: bool,
162 ) -> Result<ReadWriteTransaction, BeginError> {
163 let request = BeginTransactionRequest {
164 session: session.session.name.to_string(),
165 options: Some(TransactionOptions {
166 exclude_txn_from_change_streams: false,
167 mode: Some(mode),
168 isolation_level: IsolationLevel::Unspecified as i32,
169 }),
170 request_options: Transaction::create_request_options(options.priority, transaction_tag.clone()),
171 mutation_key: None,
172 };
173 let result = session
174 .spanner_client
175 .begin_transaction(request, disable_route_to_leader, options.retry)
176 .await;
177 let response = match session.invalidate_if_needed(result).await {
178 Ok(response) => response,
179 Err(err) => {
180 return Err(BeginError { status: err, session });
181 }
182 };
183 let tx = response.into_inner();
184 Ok(ReadWriteTransaction {
185 base_tx: Transaction {
186 session: Some(session),
187 sequence_number: AtomicI64::new(0),
188 transaction_selector: TransactionSelector {
189 selector: Some(transaction_selector::Selector::Id(tx.id.clone())),
190 },
191 transaction_tag,
192 disable_route_to_leader,
193 },
194 tx_id: tx.id,
195 wb: vec![],
196 })
197 }
198
199 pub fn buffer_write(&mut self, ms: Vec<Mutation>) {
200 self.wb.extend_from_slice(&ms)
201 }
202
203 pub async fn update(&mut self, stmt: Statement) -> Result<i64, Status> {
204 self.update_with_option(stmt, QueryOptions::default()).await
205 }
206
207 pub async fn update_with_option(&mut self, stmt: Statement, options: QueryOptions) -> Result<i64, Status> {
208 let request = ExecuteSqlRequest {
209 session: self.get_session_name(),
210 transaction: Some(self.transaction_selector.clone()),
211 sql: stmt.sql.to_string(),
212 data_boost_enabled: false,
213 params: Some(prost_types::Struct { fields: stmt.params }),
214 param_types: stmt.param_types,
215 resume_token: vec![],
216 query_mode: options.mode.into(),
217 partition_token: vec![],
218 seqno: self.sequence_number.fetch_add(1, Ordering::Relaxed),
219 query_options: options.optimizer_options,
220 request_options: Transaction::create_request_options(
221 options.call_options.priority,
222 self.base_tx.transaction_tag.clone(),
223 ),
224 directed_read_options: None,
225 last_statement: false,
226 };
227 let disable_route_to_leader = self.disable_route_to_leader;
228 let session = self.as_mut_session();
229 let result = session
230 .spanner_client
231 .execute_sql(request, disable_route_to_leader, options.call_options.retry)
232 .await;
233 let response = session.invalidate_if_needed(result).await?;
234 Ok(extract_row_count(response.into_inner().stats))
235 }
236
237 pub async fn batch_update(&mut self, stmt: Vec<Statement>) -> Result<Vec<i64>, Status> {
238 self.batch_update_with_option(stmt, QueryOptions::default()).await
239 }
240
241 pub async fn batch_update_with_option(
242 &mut self,
243 stmt: Vec<Statement>,
244 options: QueryOptions,
245 ) -> Result<Vec<i64>, Status> {
246 let request = ExecuteBatchDmlRequest {
247 session: self.get_session_name(),
248 transaction: Some(self.transaction_selector.clone()),
249 seqno: self.sequence_number.fetch_add(1, Ordering::Relaxed),
250 request_options: Transaction::create_request_options(
251 options.call_options.priority,
252 self.base_tx.transaction_tag.clone(),
253 ),
254 statements: stmt
255 .into_iter()
256 .map(|x| execute_batch_dml_request::Statement {
257 sql: x.sql,
258 params: Some(Struct { fields: x.params }),
259 param_types: x.param_types,
260 })
261 .collect(),
262 last_statements: false,
263 };
264
265 let disable_route_to_leader = self.disable_route_to_leader;
266 let session = self.as_mut_session();
267 let result = session
268 .spanner_client
269 .execute_batch_dml(request, disable_route_to_leader, options.call_options.retry)
270 .await;
271 let response = session.invalidate_if_needed(result).await?;
272 Ok(response
273 .into_inner()
274 .result_sets
275 .into_iter()
276 .map(|x| extract_row_count(x.stats))
277 .collect())
278 }
279
280 pub async fn end<S, E>(
281 &mut self,
282 result: Result<S, E>,
283 options: Option<CommitOptions>,
284 ) -> Result<(CommitResult, S), E>
285 where
286 E: TryAs<Status> + From<Status>,
287 {
288 let opt = options.unwrap_or_default();
289 match result {
290 Ok(success) => {
291 let cr = self.commit(opt).await?;
292 Ok((cr.into(), success))
293 }
294 Err(err) => {
295 if let Some(status) = err.try_as() {
296 if status.code() == Code::Aborted {
298 return Err(err);
299 }
300 }
301 let _ = self.rollback(opt.call_options.retry).await;
302 Err(err)
303 }
304 }
305 }
306
307 pub(crate) async fn finish<T, E>(
308 &mut self,
309 result: Result<T, E>,
310 options: Option<CommitOptions>,
311 ) -> Result<(CommitResult, T), (E, Option<ManagedSession>)>
312 where
313 E: TryAs<Status> + From<Status>,
314 {
315 let opt = options.unwrap_or_default();
316
317 match result {
318 Ok(s) => match self.commit(opt).await {
319 Ok(c) => Ok((c.into(), s)),
320 Err(e) => Err((E::from(e), self.take_session())),
324 },
325
326 Err(err) => {
333 let status = match err.try_as() {
334 Some(status) => status,
335 None => {
336 let _ = self.rollback(opt.call_options.retry).await;
337 return Err((err, self.take_session()));
338 }
339 };
340 match status.code() {
341 Code::Aborted => Err((err, self.take_session())),
342 _ => {
343 let _ = self.rollback(opt.call_options.retry).await;
344 Err((err, self.take_session()))
345 }
346 }
347 }
348 }
349 }
350
351 pub(crate) async fn commit(&mut self, options: CommitOptions) -> Result<CommitResponse, Status> {
352 let tx_id = self.tx_id.clone();
353 let mutations = self.wb.to_vec();
354 let disable_route_to_leader = self.disable_route_to_leader;
355 let session = self.as_mut_session();
356 commit(session, mutations, TransactionId(tx_id), options, disable_route_to_leader).await
357 }
358
359 pub(crate) async fn rollback(&mut self, retry: Option<RetrySetting>) -> Result<(), Status> {
360 let request = RollbackRequest {
361 transaction_id: self.tx_id.clone(),
362 session: self.get_session_name(),
363 };
364 let disable_route_to_leader = self.disable_route_to_leader;
365 let session = self.as_mut_session();
366 let result = session
367 .spanner_client
368 .rollback(request, disable_route_to_leader, retry)
369 .await;
370 session.invalidate_if_needed(result).await?.into_inner();
371 Ok(())
372 }
373}
374
375pub(crate) async fn commit(
376 session: &mut ManagedSession,
377 ms: Vec<Mutation>,
378 tx: commit_request::Transaction,
379 commit_options: CommitOptions,
380 disable_route_to_leader: bool,
381) -> Result<CommitResponse, Status> {
382 let request = CommitRequest {
383 session: session.session.name.to_string(),
384 mutations: ms,
385 transaction: Some(tx),
386 request_options: Transaction::create_request_options(
387 commit_options.call_options.priority,
388 commit_options.transaction_tag.clone(),
389 ),
390 return_commit_stats: commit_options.return_commit_stats,
391 max_commit_delay: commit_options.max_commit_delay.map(|d| d.try_into().unwrap()),
392 precommit_token: None,
393 };
394 let result = session
395 .spanner_client
396 .commit(request, disable_route_to_leader, commit_options.call_options.retry)
397 .await;
398 let response = session.invalidate_if_needed(result).await;
399 match response {
400 Ok(r) => Ok(r.into_inner()),
401 Err(s) => Err(s),
402 }
403}
404
405fn extract_row_count(rs: Option<ResultSetStats>) -> i64 {
406 match rs {
407 Some(o) => match o.row_count {
408 Some(o) => match o {
409 result_set_stats::RowCount::RowCountExact(v) => v,
410 result_set_stats::RowCount::RowCountLowerBound(v) => v,
411 },
412 None => 0,
413 },
414 None => 0,
415 }
416}