1use std::ops::Deref;
2use std::ops::DerefMut;
3use std::sync::atomic::{AtomicI64, Ordering};
4use std::time::Duration;
5
6use prost_types::Struct;
7
8use google_cloud_gax::grpc::{Code, Status};
9use google_cloud_gax::retry::{RetrySetting, TryAs};
10use google_cloud_googleapis::spanner::v1::commit_request::Transaction::TransactionId;
11use google_cloud_googleapis::spanner::v1::{
12 commit_request, execute_batch_dml_request, result_set_stats, transaction_options, transaction_selector,
13 BeginTransactionRequest, CommitRequest, CommitResponse, ExecuteBatchDmlRequest, ExecuteSqlRequest, Mutation,
14 ResultSetStats, RollbackRequest, TransactionOptions, TransactionSelector,
15};
16
17use crate::session::ManagedSession;
18use crate::statement::Statement;
19use crate::transaction::{CallOptions, QueryOptions, Transaction};
20use crate::value::Timestamp;
21
22#[derive(Clone, Default)]
23pub struct CommitOptions {
24 pub return_commit_stats: bool,
25 pub call_options: CallOptions,
26 pub max_commit_delay: Option<Duration>,
27}
28
29pub struct ReadWriteTransaction {
82 base_tx: Transaction,
83 tx_id: Vec<u8>,
84 wb: Vec<Mutation>,
85}
86
87impl Deref for ReadWriteTransaction {
88 type Target = Transaction;
89
90 fn deref(&self) -> &Self::Target {
91 &self.base_tx
92 }
93}
94
95impl DerefMut for ReadWriteTransaction {
96 fn deref_mut(&mut self) -> &mut Transaction {
97 &mut self.base_tx
98 }
99}
100
101pub struct BeginError {
102 pub status: Status,
103 pub session: ManagedSession,
104}
105
106impl ReadWriteTransaction {
107 pub async fn begin(session: ManagedSession, options: CallOptions) -> Result<ReadWriteTransaction, BeginError> {
108 ReadWriteTransaction::begin_internal(
109 session,
110 transaction_options::Mode::ReadWrite(transaction_options::ReadWrite::default()),
111 options,
112 )
113 .await
114 }
115
116 pub async fn begin_partitioned_dml(
117 session: ManagedSession,
118 options: CallOptions,
119 ) -> Result<ReadWriteTransaction, BeginError> {
120 ReadWriteTransaction::begin_internal(
121 session,
122 transaction_options::Mode::PartitionedDml(transaction_options::PartitionedDml {}),
123 options,
124 )
125 .await
126 }
127
128 async fn begin_internal(
129 mut session: ManagedSession,
130 mode: transaction_options::Mode,
131 options: CallOptions,
132 ) -> Result<ReadWriteTransaction, BeginError> {
133 let request = BeginTransactionRequest {
134 session: session.session.name.to_string(),
135 options: Some(TransactionOptions {
136 exclude_txn_from_change_streams: false,
137 mode: Some(mode),
138 }),
139 request_options: Transaction::create_request_options(options.priority),
140 };
141 let result = session.spanner_client.begin_transaction(request, options.retry).await;
142 let response = match session.invalidate_if_needed(result).await {
143 Ok(response) => response,
144 Err(err) => {
145 return Err(BeginError { status: err, session });
146 }
147 };
148 let tx = response.into_inner();
149 Ok(ReadWriteTransaction {
150 base_tx: Transaction {
151 session: Some(session),
152 sequence_number: AtomicI64::new(0),
153 transaction_selector: TransactionSelector {
154 selector: Some(transaction_selector::Selector::Id(tx.id.clone())),
155 },
156 },
157 tx_id: tx.id,
158 wb: vec![],
159 })
160 }
161
162 pub fn buffer_write(&mut self, ms: Vec<Mutation>) {
163 self.wb.extend_from_slice(&ms)
164 }
165
166 pub async fn update(&mut self, stmt: Statement) -> Result<i64, Status> {
167 self.update_with_option(stmt, QueryOptions::default()).await
168 }
169
170 pub async fn update_with_option(&mut self, stmt: Statement, options: QueryOptions) -> Result<i64, Status> {
171 let request = ExecuteSqlRequest {
172 session: self.get_session_name(),
173 transaction: Some(self.transaction_selector.clone()),
174 sql: stmt.sql.to_string(),
175 data_boost_enabled: false,
176 params: Some(prost_types::Struct { fields: stmt.params }),
177 param_types: stmt.param_types,
178 resume_token: vec![],
179 query_mode: options.mode.into(),
180 partition_token: vec![],
181 seqno: self.sequence_number.fetch_add(1, Ordering::Relaxed),
182 query_options: options.optimizer_options,
183 request_options: Transaction::create_request_options(options.call_options.priority),
184 directed_read_options: None,
185 };
186
187 let session = self.as_mut_session();
188 let result = session
189 .spanner_client
190 .execute_sql(request, options.call_options.retry)
191 .await;
192 let response = session.invalidate_if_needed(result).await?;
193 Ok(extract_row_count(response.into_inner().stats))
194 }
195
196 pub async fn batch_update(&mut self, stmt: Vec<Statement>) -> Result<Vec<i64>, Status> {
197 self.batch_update_with_option(stmt, QueryOptions::default()).await
198 }
199
200 pub async fn batch_update_with_option(
201 &mut self,
202 stmt: Vec<Statement>,
203 options: QueryOptions,
204 ) -> Result<Vec<i64>, Status> {
205 let request = ExecuteBatchDmlRequest {
206 session: self.get_session_name(),
207 transaction: Some(self.transaction_selector.clone()),
208 seqno: self.sequence_number.fetch_add(1, Ordering::Relaxed),
209 request_options: Transaction::create_request_options(options.call_options.priority),
210 statements: stmt
211 .into_iter()
212 .map(|x| execute_batch_dml_request::Statement {
213 sql: x.sql,
214 params: Some(Struct { fields: x.params }),
215 param_types: x.param_types,
216 })
217 .collect(),
218 };
219
220 let session = self.as_mut_session();
221 let result = session
222 .spanner_client
223 .execute_batch_dml(request, options.call_options.retry)
224 .await;
225 let response = session.invalidate_if_needed(result).await?;
226 Ok(response
227 .into_inner()
228 .result_sets
229 .into_iter()
230 .map(|x| extract_row_count(x.stats))
231 .collect())
232 }
233
234 pub async fn end<S, E>(
235 &mut self,
236 result: Result<S, E>,
237 options: Option<CommitOptions>,
238 ) -> Result<(Option<Timestamp>, S), E>
239 where
240 E: TryAs<Status> + From<Status>,
241 {
242 let opt = options.unwrap_or_default();
243 match result {
244 Ok(success) => {
245 let cr = self.commit(opt).await?;
246 Ok((cr.commit_timestamp.map(|e| e.into()), success))
247 }
248 Err(err) => {
249 if let Some(status) = err.try_as() {
250 if status.code() == Code::Aborted {
252 return Err(err);
253 }
254 }
255 let _ = self.rollback(opt.call_options.retry).await;
256 Err(err)
257 }
258 }
259 }
260
261 pub(crate) async fn finish<T, E>(
262 &mut self,
263 result: Result<T, E>,
264 options: Option<CommitOptions>,
265 ) -> Result<(Option<Timestamp>, T), (E, Option<ManagedSession>)>
266 where
267 E: TryAs<Status> + From<Status>,
268 {
269 let opt = options.unwrap_or_default();
270
271 match result {
272 Ok(s) => match self.commit(opt).await {
273 Ok(c) => Ok((c.commit_timestamp.map(|ts| ts.into()), s)),
274 Err(e) => Err((E::from(e), self.take_session())),
278 },
279
280 Err(err) => {
287 let status = match err.try_as() {
288 Some(status) => status,
289 None => {
290 let _ = self.rollback(opt.call_options.retry).await;
291 return Err((err, self.take_session()));
292 }
293 };
294 match status.code() {
295 Code::Aborted => Err((err, self.take_session())),
296 _ => {
297 let _ = self.rollback(opt.call_options.retry).await;
298 Err((err, self.take_session()))
299 }
300 }
301 }
302 }
303 }
304
305 pub(crate) async fn commit(&mut self, options: CommitOptions) -> Result<CommitResponse, Status> {
306 let tx_id = self.tx_id.clone();
307 let mutations = self.wb.to_vec();
308 let session = self.as_mut_session();
309 commit(session, mutations, TransactionId(tx_id), options).await
310 }
311
312 pub(crate) async fn rollback(&mut self, retry: Option<RetrySetting>) -> Result<(), Status> {
313 let request = RollbackRequest {
314 transaction_id: self.tx_id.clone(),
315 session: self.get_session_name(),
316 };
317 let session = self.as_mut_session();
318 let result = session.spanner_client.rollback(request, retry).await;
319 session.invalidate_if_needed(result).await?.into_inner();
320 Ok(())
321 }
322}
323
324pub(crate) async fn commit(
325 session: &mut ManagedSession,
326 ms: Vec<Mutation>,
327 tx: commit_request::Transaction,
328 commit_options: CommitOptions,
329) -> Result<CommitResponse, Status> {
330 let request = CommitRequest {
331 session: session.session.name.to_string(),
332 mutations: ms,
333 transaction: Some(tx),
334 request_options: Transaction::create_request_options(commit_options.call_options.priority),
335 return_commit_stats: commit_options.return_commit_stats,
336 max_commit_delay: commit_options.max_commit_delay.map(|d| d.try_into().unwrap()),
337 };
338 let result = session
339 .spanner_client
340 .commit(request, commit_options.call_options.retry)
341 .await;
342 let response = session.invalidate_if_needed(result).await;
343 match response {
344 Ok(r) => Ok(r.into_inner()),
345 Err(s) => Err(s),
346 }
347}
348
349fn extract_row_count(rs: Option<ResultSetStats>) -> i64 {
350 match rs {
351 Some(o) => match o.row_count {
352 Some(o) => match o {
353 result_set_stats::RowCount::RowCountExact(v) => v,
354 result_set_stats::RowCount::RowCountLowerBound(v) => v,
355 },
356 None => 0,
357 },
358 None => 0,
359 }
360}