1use std::ops::Deref;
2use std::ops::DerefMut;
3use std::sync::atomic::{AtomicI64, Ordering};
4use std::time::Duration;
5
6use prost_types::Struct;
7
8use crate::session::ManagedSession;
9use crate::statement::Statement;
10use crate::transaction::{CallOptions, QueryOptions, Transaction};
11use crate::value::Timestamp;
12use google_cloud_gax::grpc::{Code, Status};
13use google_cloud_gax::retry::{RetrySetting, TryAs};
14use google_cloud_googleapis::spanner::v1::commit_request::Transaction::TransactionId;
15use google_cloud_googleapis::spanner::v1::transaction_options::IsolationLevel;
16use google_cloud_googleapis::spanner::v1::{
17 commit_request, execute_batch_dml_request, result_set_stats, transaction_options, transaction_selector,
18 BeginTransactionRequest, CommitRequest, CommitResponse, ExecuteBatchDmlRequest, ExecuteSqlRequest, Mutation,
19 ResultSetStats, RollbackRequest, TransactionOptions, TransactionSelector,
20};
21
22#[derive(Clone, Default)]
23pub struct CommitOptions {
24 pub return_commit_stats: bool,
25 pub call_options: CallOptions,
26 pub max_commit_delay: Option<Duration>,
27 pub transaction_tag: Option<String>,
29}
30
31#[derive(Clone)]
32pub struct CommitResult {
33 pub timestamp: Option<Timestamp>,
34 pub mutation_count: Option<u64>,
35}
36
37impl From<CommitResponse> for CommitResult {
38 fn from(value: CommitResponse) -> Self {
39 Self {
40 timestamp: value.commit_timestamp.map(|v| v.into()),
41 mutation_count: value.commit_stats.map(|s| s.mutation_count as u64),
42 }
43 }
44}
45
46pub struct ReadWriteTransaction {
99 base_tx: Transaction,
100 tx_id: Vec<u8>,
101 wb: Vec<Mutation>,
102}
103
104impl Deref for ReadWriteTransaction {
105 type Target = Transaction;
106
107 fn deref(&self) -> &Self::Target {
108 &self.base_tx
109 }
110}
111
112impl DerefMut for ReadWriteTransaction {
113 fn deref_mut(&mut self) -> &mut Transaction {
114 &mut self.base_tx
115 }
116}
117
118pub struct BeginError {
119 pub status: Status,
120 pub session: ManagedSession,
121}
122
123impl ReadWriteTransaction {
124 pub async fn begin(
125 session: ManagedSession,
126 options: CallOptions,
127 transaction_tag: Option<String>,
128 ) -> Result<ReadWriteTransaction, BeginError> {
129 ReadWriteTransaction::begin_internal(
130 session,
131 transaction_options::Mode::ReadWrite(transaction_options::ReadWrite::default()),
132 options,
133 transaction_tag,
134 )
135 .await
136 }
137
138 pub async fn begin_partitioned_dml(
139 session: ManagedSession,
140 options: CallOptions,
141 transaction_tag: Option<String>,
142 ) -> Result<ReadWriteTransaction, BeginError> {
143 ReadWriteTransaction::begin_internal(
144 session,
145 transaction_options::Mode::PartitionedDml(transaction_options::PartitionedDml {}),
146 options,
147 transaction_tag,
148 )
149 .await
150 }
151
152 async fn begin_internal(
153 mut session: ManagedSession,
154 mode: transaction_options::Mode,
155 options: CallOptions,
156 transaction_tag: Option<String>,
157 ) -> Result<ReadWriteTransaction, BeginError> {
158 let request = BeginTransactionRequest {
159 session: session.session.name.to_string(),
160 options: Some(TransactionOptions {
161 exclude_txn_from_change_streams: false,
162 mode: Some(mode),
163 isolation_level: IsolationLevel::Unspecified as i32,
164 }),
165 request_options: Transaction::create_request_options(options.priority, transaction_tag.clone()),
166 mutation_key: None,
167 };
168 let result = session.spanner_client.begin_transaction(request, options.retry).await;
169 let response = match session.invalidate_if_needed(result).await {
170 Ok(response) => response,
171 Err(err) => {
172 return Err(BeginError { status: err, session });
173 }
174 };
175 let tx = response.into_inner();
176 Ok(ReadWriteTransaction {
177 base_tx: Transaction {
178 session: Some(session),
179 sequence_number: AtomicI64::new(0),
180 transaction_selector: TransactionSelector {
181 selector: Some(transaction_selector::Selector::Id(tx.id.clone())),
182 },
183 transaction_tag,
184 },
185 tx_id: tx.id,
186 wb: vec![],
187 })
188 }
189
190 pub fn buffer_write(&mut self, ms: Vec<Mutation>) {
191 self.wb.extend_from_slice(&ms)
192 }
193
194 pub async fn update(&mut self, stmt: Statement) -> Result<i64, Status> {
195 self.update_with_option(stmt, QueryOptions::default()).await
196 }
197
198 pub async fn update_with_option(&mut self, stmt: Statement, options: QueryOptions) -> Result<i64, Status> {
199 let request = ExecuteSqlRequest {
200 session: self.get_session_name(),
201 transaction: Some(self.transaction_selector.clone()),
202 sql: stmt.sql.to_string(),
203 data_boost_enabled: false,
204 params: Some(prost_types::Struct { fields: stmt.params }),
205 param_types: stmt.param_types,
206 resume_token: vec![],
207 query_mode: options.mode.into(),
208 partition_token: vec![],
209 seqno: self.sequence_number.fetch_add(1, Ordering::Relaxed),
210 query_options: options.optimizer_options,
211 request_options: Transaction::create_request_options(
212 options.call_options.priority,
213 self.base_tx.transaction_tag.clone(),
214 ),
215 directed_read_options: None,
216 last_statement: false,
217 };
218
219 let session = self.as_mut_session();
220 let result = session
221 .spanner_client
222 .execute_sql(request, options.call_options.retry)
223 .await;
224 let response = session.invalidate_if_needed(result).await?;
225 Ok(extract_row_count(response.into_inner().stats))
226 }
227
228 pub async fn batch_update(&mut self, stmt: Vec<Statement>) -> Result<Vec<i64>, Status> {
229 self.batch_update_with_option(stmt, QueryOptions::default()).await
230 }
231
232 pub async fn batch_update_with_option(
233 &mut self,
234 stmt: Vec<Statement>,
235 options: QueryOptions,
236 ) -> Result<Vec<i64>, Status> {
237 let request = ExecuteBatchDmlRequest {
238 session: self.get_session_name(),
239 transaction: Some(self.transaction_selector.clone()),
240 seqno: self.sequence_number.fetch_add(1, Ordering::Relaxed),
241 request_options: Transaction::create_request_options(
242 options.call_options.priority,
243 self.base_tx.transaction_tag.clone(),
244 ),
245 statements: stmt
246 .into_iter()
247 .map(|x| execute_batch_dml_request::Statement {
248 sql: x.sql,
249 params: Some(Struct { fields: x.params }),
250 param_types: x.param_types,
251 })
252 .collect(),
253 last_statements: false,
254 };
255
256 let session = self.as_mut_session();
257 let result = session
258 .spanner_client
259 .execute_batch_dml(request, options.call_options.retry)
260 .await;
261 let response = session.invalidate_if_needed(result).await?;
262 Ok(response
263 .into_inner()
264 .result_sets
265 .into_iter()
266 .map(|x| extract_row_count(x.stats))
267 .collect())
268 }
269
270 pub async fn end<S, E>(
271 &mut self,
272 result: Result<S, E>,
273 options: Option<CommitOptions>,
274 ) -> Result<(CommitResult, S), E>
275 where
276 E: TryAs<Status> + From<Status>,
277 {
278 let opt = options.unwrap_or_default();
279 match result {
280 Ok(success) => {
281 let cr = self.commit(opt).await?;
282 Ok((cr.into(), success))
283 }
284 Err(err) => {
285 if let Some(status) = err.try_as() {
286 if status.code() == Code::Aborted {
288 return Err(err);
289 }
290 }
291 let _ = self.rollback(opt.call_options.retry).await;
292 Err(err)
293 }
294 }
295 }
296
297 pub(crate) async fn finish<T, E>(
298 &mut self,
299 result: Result<T, E>,
300 options: Option<CommitOptions>,
301 ) -> Result<(CommitResult, T), (E, Option<ManagedSession>)>
302 where
303 E: TryAs<Status> + From<Status>,
304 {
305 let opt = options.unwrap_or_default();
306
307 match result {
308 Ok(s) => match self.commit(opt).await {
309 Ok(c) => Ok((c.into(), s)),
310 Err(e) => Err((E::from(e), self.take_session())),
314 },
315
316 Err(err) => {
323 let status = match err.try_as() {
324 Some(status) => status,
325 None => {
326 let _ = self.rollback(opt.call_options.retry).await;
327 return Err((err, self.take_session()));
328 }
329 };
330 match status.code() {
331 Code::Aborted => Err((err, self.take_session())),
332 _ => {
333 let _ = self.rollback(opt.call_options.retry).await;
334 Err((err, self.take_session()))
335 }
336 }
337 }
338 }
339 }
340
341 pub(crate) async fn commit(&mut self, options: CommitOptions) -> Result<CommitResponse, Status> {
342 let tx_id = self.tx_id.clone();
343 let mutations = self.wb.to_vec();
344 let session = self.as_mut_session();
345 commit(session, mutations, TransactionId(tx_id), options).await
346 }
347
348 pub(crate) async fn rollback(&mut self, retry: Option<RetrySetting>) -> Result<(), Status> {
349 let request = RollbackRequest {
350 transaction_id: self.tx_id.clone(),
351 session: self.get_session_name(),
352 };
353 let session = self.as_mut_session();
354 let result = session.spanner_client.rollback(request, retry).await;
355 session.invalidate_if_needed(result).await?.into_inner();
356 Ok(())
357 }
358}
359
360pub(crate) async fn commit(
361 session: &mut ManagedSession,
362 ms: Vec<Mutation>,
363 tx: commit_request::Transaction,
364 commit_options: CommitOptions,
365) -> Result<CommitResponse, Status> {
366 let request = CommitRequest {
367 session: session.session.name.to_string(),
368 mutations: ms,
369 transaction: Some(tx),
370 request_options: Transaction::create_request_options(
371 commit_options.call_options.priority,
372 commit_options.transaction_tag.clone(),
373 ),
374 return_commit_stats: commit_options.return_commit_stats,
375 max_commit_delay: commit_options.max_commit_delay.map(|d| d.try_into().unwrap()),
376 precommit_token: None,
377 };
378 let result = session
379 .spanner_client
380 .commit(request, commit_options.call_options.retry)
381 .await;
382 let response = session.invalidate_if_needed(result).await;
383 match response {
384 Ok(r) => Ok(r.into_inner()),
385 Err(s) => Err(s),
386 }
387}
388
389fn extract_row_count(rs: Option<ResultSetStats>) -> i64 {
390 match rs {
391 Some(o) => match o.row_count {
392 Some(o) => match o {
393 result_set_stats::RowCount::RowCountExact(v) => v,
394 result_set_stats::RowCount::RowCountLowerBound(v) => v,
395 },
396 None => 0,
397 },
398 None => 0,
399 }
400}