Skip to main content

google_cloud_spanner/
batch_write_transaction.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::client::DatabaseClient;
16use crate::model::BatchWriteRequest;
17use crate::model::BatchWriteResponse;
18use crate::mutation::MutationGroup;
19use crate::server_streaming::stream::BatchWriteStream;
20use gaxi::prost::FromProto;
21
22#[cfg(feature = "unstable-stream")]
23use futures::Stream;
24
25/// A builder for [BatchWriteTransaction].
26pub struct BatchWriteTransactionBuilder {
27    client: DatabaseClient,
28}
29
30impl BatchWriteTransactionBuilder {
31    pub(crate) fn new(client: DatabaseClient) -> Self {
32        Self { client }
33    }
34
35    /// Builds the [BatchWriteTransaction].
36    pub fn build(self) -> BatchWriteTransaction {
37        let session_name = self.client.session_name();
38        let channel_hint = self.client.spanner.next_channel_hint();
39        BatchWriteTransaction {
40            session_name,
41            client: self.client,
42            channel_hint,
43        }
44    }
45}
46
47/// A transaction for executing batch writes.
48///
49/// Batch writes are not guaranteed to be atomic across mutation groups.
50/// All mutations within a group are applied atomically.
51pub struct BatchWriteTransaction {
52    session_name: String,
53    client: DatabaseClient,
54    channel_hint: usize,
55}
56
57impl BatchWriteTransaction {
58    /// Executes the batch write and returns a stream of responses.
59    ///
60    /// # Example
61    /// ```
62    /// # use google_cloud_spanner::mutation::Mutation;
63    /// # use google_cloud_spanner::client::Spanner;
64    /// # use google_cloud_spanner::mutation::MutationGroup;
65    /// # use google_cloud_gax::error::rpc::Code;
66    /// # async fn sample() -> Result<(), Box<dyn std::error::Error>> {
67    /// let client = Spanner::builder().build().await?;
68    /// let db = client.database_client("projects/p/instances/i/databases/d").build().await?;
69    ///
70    /// let mutation = Mutation::new_insert_builder("Users")
71    ///     .set("UserId").to(&1)
72    ///     .build();
73    /// let group = MutationGroup::new(vec![mutation]);
74    ///
75    /// let tx = db.batch_write_transaction().build();
76    /// let mut stream = tx.execute_streaming(vec![group]).await?;
77    ///
78    /// while let Some(response) = stream.next().await {
79    ///     let response = response?;
80    ///     if let Some(status) = response.status.as_ref().filter(|s| s.code != Code::Ok as i32) {
81    ///         eprintln!("Error applying groups {:?}: {}", response.indexes, status.message);
82    ///     } else {
83    ///         println!("Applied groups: {:?}", response.indexes);
84    ///     }
85    /// }
86    /// # Ok(())
87    /// # }
88    /// ```
89    ///
90    /// This method sends the mutation groups to Spanner and returns the responses as a stream.
91    /// Each response includes a status code that indicates whether the mutation groups that
92    /// it references were applied successfully.
93    ///
94    /// The method does not handle any errors, including retryable errors like Aborted.
95    /// The caller is responsible for handling any errors and for retrying the transaction in
96    /// case it is aborted by Spanner.
97    pub async fn execute_streaming<I>(self, groups: I) -> crate::Result<BatchWriteResponseStream>
98    where
99        I: IntoIterator<Item = MutationGroup>,
100    {
101        let req = BatchWriteRequest::new()
102            .set_session(self.session_name.clone())
103            .set_mutation_groups(groups.into_iter().map(|g| g.build_proto()));
104
105        let stream = self
106            .client
107            .spanner
108            .batch_write(req, crate::RequestOptions::default(), self.channel_hint)
109            .send()
110            .await?;
111        Ok(BatchWriteResponseStream { inner: stream })
112    }
113}
114
115/// A stream of [BatchWriteResponse] messages.
116pub struct BatchWriteResponseStream {
117    pub(crate) inner: BatchWriteStream,
118}
119
120impl BatchWriteResponseStream {
121    /// Fetches the next [BatchWriteResponse] from the stream.
122    ///
123    /// Returns `Some(Ok(BatchWriteResponse))` when a message is successfully received,
124    /// `None` when the stream concludes naturally, or `Some(Err(_))` on RPC errors.
125    pub async fn next(&mut self) -> Option<crate::Result<BatchWriteResponse>> {
126        let proto_opt = self.inner.next_message().await?;
127        match proto_opt {
128            Ok(proto) => match proto.cnv() {
129                Ok(model) => Some(Ok(model)),
130                Err(e) => Some(Err(crate::Error::deser(e))),
131            },
132            Err(e) => Some(Err(e)),
133        }
134    }
135
136    /// Converts the [`BatchWriteResponseStream`] into a [`Stream`].
137    ///
138    /// This consumes the [`BatchWriteResponseStream`] and returns a stream of responses.
139    #[cfg(feature = "unstable-stream")]
140    pub fn into_stream(self) -> impl Stream<Item = crate::Result<BatchWriteResponse>> + Unpin {
141        use futures::stream::unfold;
142        Box::pin(unfold(self, |mut stream| async move {
143            stream.next().await.map(|res| (res, stream))
144        }))
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Spanner;
    use crate::mutation::Mutation;
    use crate::result_set::tests::adapt;
    use anyhow::Result;
    use gaxi::grpc::tonic::Response;
    use google_cloud_test_macros::tokio_test_no_panics;
    use spanner_grpc_mock::MockSpanner;
    use spanner_grpc_mock::google::spanner::v1 as mock_v1;

    /// Session name handed out by the mock and expected on every `BatchWrite` request.
    const SESSION_NAME: &str = "projects/p/instances/i/databases/d/sessions/123";

    pub(crate) async fn setup_db_client(
        mock: MockSpanner,
    ) -> (DatabaseClient, tokio::task::JoinHandle<()>) {
        use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
        let (address, server) = spanner_grpc_mock::start("0.0.0.0:0", mock)
            .await
            .expect("Failed to start mock server");
        let spanner = Spanner::builder()
            .with_endpoint(address)
            .with_credentials(Anonymous::new().build())
            .build()
            .await
            .expect("Failed to build client");

        let db_client = spanner
            .database_client("projects/p/instances/i/databases/d")
            .build()
            .await
            .expect("Failed to create DatabaseClient");

        (db_client, server)
    }

    /// Builds a mock that creates [`SESSION_NAME`] and expects exactly one
    /// `BatchWrite` call with a single mutation group. The mock answers that call
    /// with one response whose `indexes` is `[0]`.
    fn mock_single_group_batch_write() -> MockSpanner {
        let mut mock = MockSpanner::new();
        mock.expect_create_session().returning(|_| {
            Ok(Response::new(mock_v1::Session {
                name: SESSION_NAME.to_string(),
                ..Default::default()
            }))
        });

        mock.expect_batch_write().once().returning(|req| {
            let req = req.into_inner();
            assert_eq!(req.session, SESSION_NAME, "session name should match");
            assert_eq!(
                req.mutation_groups.len(),
                1,
                "should contain precisely 1 mutation group"
            );

            let response = mock_v1::BatchWriteResponse {
                indexes: vec![0],
                status: None,
                commit_timestamp: None,
            };

            Ok(Response::from(adapt([Ok(response)])))
        });
        mock
    }

    /// Builds the single-insert mutation group sent by the tests in this module.
    fn single_insert_group() -> MutationGroup {
        let mutation = Mutation::new_insert_builder("Users")
            .set("UserId")
            .to(&1)
            .build();
        MutationGroup::new(vec![mutation])
    }

    #[tokio_test_no_panics]
    async fn execute_streaming() -> Result<()> {
        let (db_client, _server) = setup_db_client(mock_single_group_batch_write()).await;

        let tx = db_client.batch_write_transaction().build();
        let mut stream = tx.execute_streaming(vec![single_insert_group()]).await?;

        let result = stream
            .next()
            .await
            .expect("stream should have yielded a message")?;
        assert_eq!(
            result.indexes,
            vec![0],
            "indexes should match the mocked response"
        );

        Ok(())
    }

    #[cfg(feature = "unstable-stream")]
    #[tokio_test_no_panics]
    async fn execute_streaming_into_stream() -> Result<()> {
        use futures::StreamExt;

        let (db_client, _server) = setup_db_client(mock_single_group_batch_write()).await;

        let transaction = db_client.batch_write_transaction().build();
        let stream = transaction
            .execute_streaming(vec![single_insert_group()])
            .await?;
        let mut stream = stream.into_stream();

        let result = stream
            .next()
            .await
            .expect("stream should have yielded a message")?;
        assert_eq!(
            result.indexes,
            vec![0],
            "indexes should match the mocked response"
        );

        Ok(())
    }
}