google_cloud_spanner/
batch_write_transaction.rs1use crate::client::DatabaseClient;
16use crate::model::BatchWriteRequest;
17use crate::model::BatchWriteResponse;
18use crate::mutation::MutationGroup;
19use crate::server_streaming::stream::BatchWriteStream;
20use gaxi::prost::FromProto;
21
22#[cfg(feature = "unstable-stream")]
23use futures::Stream;
24
25pub struct BatchWriteTransactionBuilder {
27 client: DatabaseClient,
28}
29
30impl BatchWriteTransactionBuilder {
31 pub(crate) fn new(client: DatabaseClient) -> Self {
32 Self { client }
33 }
34
35 pub fn build(self) -> BatchWriteTransaction {
37 let session_name = self.client.session_name();
38 let channel_hint = self.client.spanner.next_channel_hint();
39 BatchWriteTransaction {
40 session_name,
41 client: self.client,
42 channel_hint,
43 }
44 }
45}
46
47pub struct BatchWriteTransaction {
52 session_name: String,
53 client: DatabaseClient,
54 channel_hint: usize,
55}
56
57impl BatchWriteTransaction {
58 pub async fn execute_streaming<I>(self, groups: I) -> crate::Result<BatchWriteResponseStream>
98 where
99 I: IntoIterator<Item = MutationGroup>,
100 {
101 let req = BatchWriteRequest::new()
102 .set_session(self.session_name.clone())
103 .set_mutation_groups(groups.into_iter().map(|g| g.build_proto()));
104
105 let stream = self
106 .client
107 .spanner
108 .batch_write(req, crate::RequestOptions::default(), self.channel_hint)
109 .send()
110 .await?;
111 Ok(BatchWriteResponseStream { inner: stream })
112 }
113}
114
115pub struct BatchWriteResponseStream {
117 pub(crate) inner: BatchWriteStream,
118}
119
120impl BatchWriteResponseStream {
121 pub async fn next(&mut self) -> Option<crate::Result<BatchWriteResponse>> {
126 let proto_opt = self.inner.next_message().await?;
127 match proto_opt {
128 Ok(proto) => match proto.cnv() {
129 Ok(model) => Some(Ok(model)),
130 Err(e) => Some(Err(crate::Error::deser(e))),
131 },
132 Err(e) => Some(Err(e)),
133 }
134 }
135
136 #[cfg(feature = "unstable-stream")]
140 pub fn into_stream(self) -> impl Stream<Item = crate::Result<BatchWriteResponse>> + Unpin {
141 use futures::stream::unfold;
142 Box::pin(unfold(self, |mut stream| async move {
143 stream.next().await.map(|res| (res, stream))
144 }))
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::Spanner;
    use crate::mutation::Mutation;
    use crate::result_set::tests::adapt;
    use anyhow::Result;
    use gaxi::grpc::tonic::Response;
    use google_cloud_test_macros::tokio_test_no_panics;
    use spanner_grpc_mock::MockSpanner;
    use spanner_grpc_mock::google::spanner::v1 as mock_v1;

    /// Session name returned by the mock `CreateSession` stub and expected on
    /// every `BatchWrite` request.
    const SESSION_NAME: &str = "projects/p/instances/i/databases/d/sessions/123";

    /// Starts a mock Spanner server backed by `mock` and returns a
    /// `DatabaseClient` connected to it, plus the server task handle.
    ///
    /// Panics if the server or client cannot be created (test-only helper).
    pub(crate) async fn setup_db_client(
        mock: MockSpanner,
    ) -> (DatabaseClient, tokio::task::JoinHandle<()>) {
        use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
        let (address, server) = spanner_grpc_mock::start("0.0.0.0:0", mock)
            .await
            .expect("Failed to start mock server");
        let spanner = Spanner::builder()
            .with_endpoint(address)
            .with_credentials(Anonymous::new().build())
            .build()
            .await
            .expect("Failed to build client");

        let db_client = spanner
            .database_client("projects/p/instances/i/databases/d")
            .build()
            .await
            .expect("Failed to create DatabaseClient");

        (db_client, server)
    }

    /// Builds a mock whose `CreateSession` returns [`SESSION_NAME`] and which
    /// expects exactly one `BatchWrite` call on that session carrying a single
    /// mutation group. The call is answered with one response whose `indexes`
    /// is `[0]`.
    fn mock_single_batch_write() -> MockSpanner {
        let mut mock = MockSpanner::new();
        mock.expect_create_session().returning(|_| {
            Ok(Response::new(mock_v1::Session {
                name: SESSION_NAME.to_string(),
                ..Default::default()
            }))
        });

        mock.expect_batch_write().once().returning(|req| {
            let req = req.into_inner();
            assert_eq!(req.session, SESSION_NAME, "session name should match");
            assert_eq!(
                req.mutation_groups.len(),
                1,
                "should contain precisely 1 mutation group"
            );

            let response = mock_v1::BatchWriteResponse {
                indexes: vec![0],
                status: None,
                commit_timestamp: None,
            };

            Ok(Response::from(adapt([Ok(response)])))
        });
        mock
    }

    /// A single mutation group containing one insert into `Users`.
    fn single_insert_group() -> MutationGroup {
        let mutation = Mutation::new_insert_builder("Users")
            .set("UserId")
            .to(&1)
            .build();
        MutationGroup::new(vec![mutation])
    }

    #[tokio_test_no_panics]
    async fn execute_streaming() -> Result<()> {
        let (db_client, _server) = setup_db_client(mock_single_batch_write()).await;

        let tx = db_client.batch_write_transaction().build();
        let mut stream = tx.execute_streaming(vec![single_insert_group()]).await?;

        let result = stream
            .next()
            .await
            .expect("stream should have yielded a message")?;
        assert_eq!(
            result.indexes,
            vec![0],
            "indexes should match the mocked response"
        );

        Ok(())
    }

    #[cfg(feature = "unstable-stream")]
    #[tokio_test_no_panics]
    async fn execute_streaming_into_stream() -> Result<()> {
        use futures::StreamExt;

        let (db_client, _server) = setup_db_client(mock_single_batch_write()).await;

        let transaction = db_client.batch_write_transaction().build();
        let stream = transaction
            .execute_streaming(vec![single_insert_group()])
            .await?;
        let mut stream = stream.into_stream();

        let result = stream
            .next()
            .await
            .expect("stream should have yielded a message")?;
        assert_eq!(
            result.indexes,
            vec![0],
            "indexes should match the mocked response"
        );

        Ok(())
    }
}