Skip to main content

google_cloud_spanner/
client.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::generated::gapic_dataplane::client::Spanner as GapicSpanner;
16use crate::model::{
17    BeginTransactionRequest, CommitRequest, CommitResponse, CreateSessionRequest,
18    ExecuteBatchDmlRequest, ExecuteBatchDmlResponse, ExecuteSqlRequest, PartitionQueryRequest,
19    PartitionReadRequest, PartitionResponse, RollbackRequest, Session, Transaction,
20};
21use crate::server_streaming::builder;
22use gaxi::options::{ClientConfig, Credentials};
23use google_cloud_auth::credentials::anonymous;
24use google_cloud_gax::client_builder::ClientBuilder as GaxClientBuilder;
25use google_cloud_gax::client_builder::internal::new_builder;
26use google_cloud_gax::options::{
27    RequestOptions as GaxRequestOptions, internal::RequestOptionsExt as _,
28};
29use google_cloud_spanner_admin_database_v1::builder::database_admin::ClientBuilder as DatabaseAdminBuilder;
30use google_cloud_spanner_admin_instance_v1::builder::instance_admin::ClientBuilder as InstanceAdminBuilder;
31use http::{
32    HeaderMap,
33    header::{HeaderName, HeaderValue},
34};
35use std::sync::{
36    LazyLock,
37    atomic::{AtomicUsize, Ordering},
38};
39
40pub use crate::database_client::DatabaseClient;
41pub use google_cloud_spanner_admin_database_v1::client::DatabaseAdmin;
42pub use google_cloud_spanner_admin_instance_v1::client::InstanceAdmin;
43
/// A client for the [Spanner] API.
///
/// Use this client to interact with the Spanner service.
///
/// The client holds a pool of channels; each RPC is sent on the channel selected by a
/// channel hint. Clones of a `Spanner` share the same hint counter (it lives behind an
/// `Arc`), so round-robin selection continues across clones.
///
/// [Spanner]: https://docs.cloud.google.com/spanner/docs
#[derive(Clone, Debug)]
pub struct Spanner {
    /// Channel pool; an RPC picks an entry with `get_channel(hint)`, i.e. `hint % len`.
    pub(crate) channels: Vec<Channel>,
    /// Source of round-robin channel hints handed out by `next_channel_hint`.
    pub(crate) counter: std::sync::Arc<AtomicUsize>,
    /// Configuration the client was built with; reused to pre-configure admin builders.
    pub(crate) config: ClientConfig,
    /// `true` when a non-empty `SPANNER_EMULATOR_HOST` was present at build time.
    pub(crate) is_emulator: bool,
}
56
/// A factory for constructing `Spanner` clients.
///
/// This is the factory type behind [ClientBuilder]; [Spanner::builder] creates a builder
/// from it, and the builder's `build()` delegates the actual construction to this type.
pub struct Factory;
59
60impl google_cloud_gax::client_builder::internal::ClientFactory for Factory {
61    type Client = Spanner;
62    type Credentials = Credentials;
63
64    async fn build(self, mut config: ClientConfig) -> crate::ClientBuilderResult<Self::Client> {
65        let mut is_emulator = false;
66        if let Some(endpoint) = std::env::var("SPANNER_EMULATOR_HOST")
67            .ok()
68            .filter(|s| !s.is_empty())
69        {
70            is_emulator = true;
71            if config.endpoint.is_none() {
72                config.endpoint = Some(parse_emulator_endpoint(&endpoint));
73            }
74            if config.cred.is_none() {
75                config.cred = Some(anonymous::Builder::new().build());
76            }
77        }
78
79        let num_channels = std::env::var("SPANNER_NUM_CHANNELS")
80            .ok()
81            .and_then(|s| s.parse::<usize>().ok())
82            .unwrap_or(4);
83
84        let mut channels = Vec::with_capacity(num_channels);
85        for _ in 0..num_channels {
86            channels.push(Channel::create(&config).await?);
87        }
88
89        Ok(Spanner {
90            channels,
91            counter: std::sync::Arc::new(AtomicUsize::new(0)),
92            config,
93            is_emulator,
94        })
95    }
96}
97
98/// A builder for the Spanner client.
99pub type ClientBuilder = google_cloud_gax::client_builder::ClientBuilder<Factory, Credentials>;
100
101fn parse_emulator_endpoint(endpoint: &str) -> String {
102    match url::Url::parse(endpoint) {
103        Ok(url) if url.has_host() => endpoint.to_string(),
104        _ => format!("http://{}", endpoint),
105    }
106}
107
// Generates a `pub(crate) async fn $rpc(&self, request, options, channel_hint)` wrapper
// that sends one unary RPC on the pooled channel chosen by `channel_hint`. Request
// options are first passed through `apply_request_defaults`, so options the caller left
// unset get the Spanner defaults.
macro_rules! define_idempotent_rpc {
    ($rpc:ident, $req:ty, $resp:ty) => {
        pub(crate) async fn $rpc(
            &self,
            request: $req,
            options: crate::RequestOptions,
            channel_hint: usize,
        ) -> crate::Result<$resp> {
            let channel = self.get_channel(channel_hint);
            let effective_options = apply_request_defaults(options);
            channel
                .inner
                .$rpc()
                .with_request(request)
                .with_options(effective_options)
                .send()
                .await
        }
    };
}
126
127fn apply_request_defaults(mut options: crate::RequestOptions) -> crate::RequestOptions {
128    if options.idempotent().is_none() {
129        options.set_idempotency(true);
130    }
131    if options.retry_policy().is_none() {
132        options.set_retry_policy(crate::retry_policy::SpannerRetryPolicy::new());
133    }
134    options
135}
136
137pub(crate) static LAR_HEADER_MAP: LazyLock<HeaderMap> = LazyLock::new(|| {
138    let mut map = HeaderMap::new();
139    map.insert(
140        HeaderName::from_static("x-goog-spanner-route-to-leader"),
141        HeaderValue::from_static("true"),
142    );
143    map
144});
145
146pub(crate) fn amend_request_options_for_lar(
147    leader_aware_routing_enabled: bool,
148    mut options: GaxRequestOptions,
149) -> GaxRequestOptions {
150    if leader_aware_routing_enabled {
151        let mut headers = options
152            .get_extension::<HeaderMap>()
153            .cloned()
154            .unwrap_or_default();
155        headers.extend((*LAR_HEADER_MAP).clone());
156        options = options.insert_extension(headers);
157    }
158    options
159}
160
/// Returns the admin-API endpoint corresponding to the data-plane `endpoint`.
///
/// Trailing `/` characters are always removed. When `is_emulator` is true and the
/// endpoint ends with the emulator's gRPC port `:9010`, that trailing port is rewritten
/// to the emulator's REST admin port `:9020`. Only the trailing port is touched; any
/// other `:9010` in the string (e.g. inside an IPv6 literal such as `[::9010]`) is kept.
fn map_emulator_admin_endpoint(endpoint: &str, is_emulator: bool) -> String {
    let trimmed = endpoint.trim_end_matches('/');
    match trimmed.strip_suffix(":9010") {
        Some(base) if is_emulator => format!("{base}:9020"),
        _ => trimmed.to_string(),
    }
}
168
169impl Spanner {
170    /// Returns a builder for the `Spanner` client.
171    ///
172    /// # Example
173    /// ```
174    /// # use google_cloud_spanner::client::Spanner;
175    /// # async fn sample() -> anyhow::Result<()> {
176    /// let spanner = Spanner::builder().build().await?;
177    ///
178    /// let db_client = spanner
179    ///     .database_client("projects/my-project/instances/my-instance/databases/my-db")
180    ///     .build()
181    ///     .await?;
182    ///
183    /// let tx = db_client.single_use().build();
184    /// let mut rs = tx.execute_query("SELECT 1").await?;
185    ///
186    /// while let Some(row) = rs.next().await {
187    ///     let row = row?;
188    ///     let val: i64 = row.get(0);
189    ///     assert_eq!(val, 1);
190    /// }
191    /// # Ok(())
192    /// # }
193    /// ```
194    ///
195    /// The returned builder is pre-configured with standard defaults. It automatically
196    /// detects and connects to the Spanner emulator if the `SPANNER_EMULATOR_HOST`
197    /// environment variable is set.
198    pub fn builder() -> ClientBuilder {
199        new_builder(Factory)
200    }
201
202    /// Returns a builder for the [DatabaseAdmin] client.
203    ///
204    /// This builder is automatically pre-configured with the same endpoints, credentials,
205    /// and routing configurations as this `Spanner` instance.
206    /// If configured to use the Emulator (via `SPANNER_EMULATOR_HOST`), it maps the gRPC endpoint port
207    /// (`9010`) to the REST admin port (`9020`).
208    pub fn database_admin_builder(&self) -> DatabaseAdminBuilder {
209        self.configure_admin_builder(DatabaseAdmin::builder())
210    }
211
212    /// Returns a builder for the [InstanceAdmin] client.
213    ///
214    /// This builder is automatically pre-configured with the same endpoints, credentials,
215    /// and routing configurations as this `Spanner` instance.
216    /// If configured to use the Emulator (via `SPANNER_EMULATOR_HOST`), it maps the gRPC endpoint port
217    /// (`9010`) to the REST admin port (`9020`).
218    pub fn instance_admin_builder(&self) -> InstanceAdminBuilder {
219        self.configure_admin_builder(InstanceAdmin::builder())
220    }
221
222    fn configure_admin_builder<F, C>(
223        &self,
224        mut builder: GaxClientBuilder<F, C>,
225    ) -> GaxClientBuilder<F, C>
226    where
227        C: Clone + From<Credentials>,
228    {
229        if let Some(ref endpoint) = self.config.endpoint {
230            let ep = map_emulator_admin_endpoint(endpoint, self.is_emulator);
231            builder = builder.with_endpoint(ep);
232        }
233        if let Some(ref cred) = self.config.cred {
234            builder = builder.with_credentials(cred.clone());
235        }
236        if let Some(ref ud) = self.config.universe_domain {
237            builder = builder.with_universe_domain(ud.clone());
238        }
239        builder
240    }
241
242    /// Returns a new [DatabaseClientBuilder](crate::database_client::DatabaseClientBuilder) for
243    /// interacting with a specific database.
244    ///
245    /// # Example
246    /// ```
247    /// # use google_cloud_spanner::client::Spanner;
248    /// # async fn sample() -> anyhow::Result<()> {
249    ///     let spanner = Spanner::builder().build().await?;
250    ///     let database_client = spanner
251    ///         .database_client("projects/my-project/instances/my-instance/databases/my-db")
252    ///         .build()
253    ///         .await?;
254    ///     # Ok(())
255    /// # }
256    /// ```
257    ///
258    /// The returned `DatabaseClient` is intended to be a long-lived object and should be reused
259    /// for all operations on the database.
260    pub fn database_client(
261        &self,
262        database: impl Into<String>,
263    ) -> crate::builder::DatabaseClientBuilder {
264        crate::builder::DatabaseClientBuilder::new(self.clone(), database.into())
265    }
266
267    /// Creates a new client from the provided stub.
268    ///
269    /// The most common case for calling this function is in tests mocking the
270    /// client's behavior.
271    pub fn from_stub<T>(stub: T) -> Self
272    where
273        T: crate::generated::gapic_dataplane::stub::Spanner + 'static,
274    {
275        // This method is primarily for testing and doesn't fully initialize grpc_client.
276        // For production use, prefer `Spanner::builder().build()`.
277        Self {
278            channels: vec![Channel {
279                inner: GapicSpanner::from_stub(stub),
280                grpc_client: None,
281            }],
282            counter: std::sync::Arc::new(AtomicUsize::new(0)),
283            config: ClientConfig::default(),
284            is_emulator: false,
285        }
286    }
287
288    pub(crate) fn is_emulator(&self) -> bool {
289        self.is_emulator
290    }
291
292    pub(crate) fn get_channel(&self, hint: usize) -> &Channel {
293        let idx = hint % self.channels.len();
294        &self.channels[idx]
295    }
296
297    pub(crate) fn next_channel_hint(&self) -> usize {
298        self.counter.fetch_add(1, Ordering::Relaxed)
299    }
300
301    define_idempotent_rpc!(create_session, CreateSessionRequest, Session);
302    define_idempotent_rpc!(execute_sql, ExecuteSqlRequest, crate::model::ResultSet);
303    define_idempotent_rpc!(
304        execute_batch_dml,
305        ExecuteBatchDmlRequest,
306        ExecuteBatchDmlResponse
307    );
308    define_idempotent_rpc!(begin_transaction, BeginTransactionRequest, Transaction);
309    define_idempotent_rpc!(commit, CommitRequest, CommitResponse);
310    define_idempotent_rpc!(rollback, RollbackRequest, ());
311    define_idempotent_rpc!(partition_query, PartitionQueryRequest, PartitionResponse);
312    define_idempotent_rpc!(partition_read, PartitionReadRequest, PartitionResponse);
313
314    /// Executes an SQL statement, returning a stream of results.
315    ///
316    /// This is a custom streaming implementation over the underlying Spanner gRPC
317    /// transport, since streaming responses are not yet auto-generated here.
318    pub(crate) fn execute_streaming_sql(
319        &self,
320        request: crate::model::ExecuteSqlRequest,
321        options: crate::RequestOptions,
322        channel_hint: usize,
323    ) -> builder::ExecuteStreamingSql {
324        let channel = self.get_channel(channel_hint);
325        let grpc = channel
326            .grpc_client
327            .as_ref()
328            .expect("Streaming RPCs are not supported when using a stub client");
329        builder::ExecuteStreamingSql::new(grpc.clone())
330            .with_request(request)
331            .with_options(options)
332    }
333
334    /// Reads rows from the database, returning a stream of results.
335    ///
336    /// This is a custom streaming implementation over the underlying Spanner gRPC
337    /// transport, since streaming responses are not yet auto-generated here.
338    pub(crate) fn streaming_read(
339        &self,
340        request: crate::model::ReadRequest,
341        options: crate::RequestOptions,
342        channel_hint: usize,
343    ) -> builder::StreamingRead {
344        let channel = self.get_channel(channel_hint);
345        let grpc = channel
346            .grpc_client
347            .as_ref()
348            .expect("Streaming RPCs are not supported when using a stub client");
349        builder::StreamingRead::new(grpc.clone())
350            .with_request(request)
351            .with_options(options)
352    }
353
354    pub(crate) fn batch_write(
355        &self,
356        request: crate::model::BatchWriteRequest,
357        options: crate::RequestOptions,
358        channel_hint: usize,
359    ) -> builder::BatchWrite {
360        let channel = self.get_channel(channel_hint);
361        let grpc = channel
362            .grpc_client
363            .as_ref()
364            .expect("Streaming RPCs are not supported when using a stub client");
365        builder::BatchWrite::new(grpc.clone())
366            .with_request(request)
367            .with_options(options)
368    }
369}
370
/// One entry in the client's channel pool.
#[derive(Clone, Debug)]
pub(crate) struct Channel {
    /// Generated client used for the unary RPCs.
    pub(crate) inner: GapicSpanner,
    /// Raw gRPC client used for streaming RPCs; `None` for clients built with
    /// `Spanner::from_stub`.
    pub(crate) grpc_client: Option<gaxi::grpc::Client>,
}
376
377impl Channel {
378    pub(crate) async fn create(config: &ClientConfig) -> crate::ClientBuilderResult<Self> {
379        let transport =
380            crate::generated::gapic_dataplane::transport::Spanner::new(config.clone()).await?;
381        let grpc_client = transport.inner.clone();
382
383        let inner = if gaxi::options::tracing_enabled(config) {
384            GapicSpanner::from_stub(crate::generated::gapic_dataplane::tracing::Spanner::new(
385                transport,
386            ))
387        } else {
388            GapicSpanner::from_stub(transport)
389        };
390        Ok(Self {
391            inner,
392            grpc_client: Some(grpc_client),
393        })
394    }
395}
396
397#[cfg(test)]
398mod tests {
399    use super::*;
400    use crate::model::CreateSessionRequest;
401    use crate::read::ReadRequest;
402    use crate::result_set::tests::adapt;
403    use crate::statement::Statement;
404    use gaxi::grpc::tonic::MetadataMap;
405    use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status};
406    use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
407    use google_cloud_gax::backoff_policy::BackoffPolicy;
408    use google_cloud_gax::error::rpc::Code;
409    use google_cloud_gax::retry_state::RetryState;
410    use google_cloud_test_macros::tokio_test_no_panics;
411    use spanner_grpc_mock::google::rpc as mock_rpc;
412    use spanner_grpc_mock::google::spanner::v1 as mock_v1;
413    use spanner_grpc_mock::google::spanner::v1::CommitResponse;
414    use spanner_grpc_mock::google::spanner::v1::ResultSet;
415    use spanner_grpc_mock::google::spanner::v1::ResultSetStats;
416    use spanner_grpc_mock::google::spanner::v1::Session;
417    use spanner_grpc_mock::google::spanner::v1::result_set_stats::RowCount;
418    use spanner_grpc_mock::{MockSpanner, start};
419    use static_assertions::{assert_impl_all, assert_not_impl_any};
420    use std::sync::Arc;
421    use std::sync::atomic::{AtomicU64, Ordering};
422    use std::time::Duration;
423
424    mockall::mock! {
425        #[derive(Debug)]
426        BackoffPolicy {}
427        impl BackoffPolicy for BackoffPolicy {
428            fn on_failure(&self, state: &RetryState) -> Duration;
429        }
430    }
431
432    #[test]
433    fn auto_traits() {
434        assert_impl_all!(Spanner: std::fmt::Debug, Clone, Send, Sync);
435        assert_not_impl_any!(Spanner: std::panic::RefUnwindSafe, std::panic::UnwindSafe);
436    }
437
438    #[tokio_test_no_panics]
439    async fn channel_pool_default_size() {
440        let mock = MockSpanner::new();
441        let (address, _server) = start("0.0.0.0:0", mock)
442            .await
443            .expect("Failed to start mock server");
444
445        let client = Spanner::builder()
446            .with_endpoint(address)
447            .with_credentials(Anonymous::new().build())
448            .build()
449            .await
450            .expect("Failed to build client");
451
452        assert_eq!(client.channels.len(), 4);
453    }
454
455    #[test]
456    fn test_map_emulator_admin_endpoint() {
457        // 1. Test normal endpoint without emulator (should remain unchanged)
458        assert_eq!(
459            map_emulator_admin_endpoint("https://spanner.googleapis.com", false),
460            "https://spanner.googleapis.com"
461        );
462
463        // 2. Test emulator endpoint mapping (9010 -> 9020)
464        assert_eq!(
465            map_emulator_admin_endpoint("http://localhost:9010", true),
466            "http://localhost:9020"
467        );
468
469        // 3. Test emulator endpoint with trailing slash (should be trimmed and mapped)
470        assert_eq!(
471            map_emulator_admin_endpoint("http://127.0.0.1:9010/", true),
472            "http://127.0.0.1:9020"
473        );
474
475        // 4. Test emulator endpoint without is_emulator active (should remain unchanged)
476        assert_eq!(
477            map_emulator_admin_endpoint("http://localhost:9010", false),
478            "http://localhost:9010"
479        );
480    }
481
482    #[tokio_test_no_panics]
483    async fn channel_selection() {
484        let mock = MockSpanner::new();
485        let (address, _server) = start("0.0.0.0:0", mock)
486            .await
487            .expect("Failed to start mock server");
488
489        let client = Spanner::builder()
490            .with_endpoint(address)
491            .with_credentials(Anonymous::new().build())
492            .build()
493            .await
494            .expect("Failed to build client");
495
496        let hint0 = client.next_channel_hint();
497        let hint1 = client.next_channel_hint();
498        let hint2 = client.next_channel_hint();
499        let hint3 = client.next_channel_hint();
500        let hint4 = client.next_channel_hint();
501
502        assert_eq!(hint0 % 4, 0);
503        assert_eq!(hint1 % 4, 1);
504        assert_eq!(hint2 % 4, 2);
505        assert_eq!(hint3 % 4, 3);
506        assert_eq!(hint4 % 4, 0);
507    }
508
509    #[tokio_test_no_panics]
510    async fn test_create_session() {
511        // 1. Setup Mock Server
512        let mut mock = MockSpanner::new();
513        mock.expect_create_session().once().returning(|_| {
514            Ok(gaxi::grpc::tonic::Response::new(mock_v1::Session {
515                name:
516                    "projects/test-project/instances/test-instance/databases/test-db/sessions/123"
517                        .to_string(),
518                ..Default::default()
519            }))
520        });
521
522        // 2. Start mock server
523        let (address, _server) = start("0.0.0.0:0", mock)
524            .await
525            .expect("Failed to start mock server");
526
527        // 3. Configure Client to use mock endpoint
528        let client = Spanner::builder()
529            .with_endpoint(address)
530            .with_credentials(Anonymous::new().build())
531            .build()
532            .await
533            .expect("Failed to build client");
534
535        // 4. Call CreateSession
536        let mut req = CreateSessionRequest::new();
537        req.database =
538            "projects/test-project/instances/test-instance/databases/test-db".to_string();
539
540        let session = client
541            .create_session(
542                req,
543                crate::RequestOptions::default(),
544                client.next_channel_hint(),
545            )
546            .await
547            .expect("Failed to call create_session");
548
549        // 5. Verify Response
550        assert_eq!(
551            session.name,
552            "projects/test-project/instances/test-instance/databases/test-db/sessions/123"
553        );
554    }
555
556    #[tokio_test_no_panics]
557    async fn test_create_session_retry() {
558        use google_cloud_gax::options::RequestOptionsBuilder;
559        use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt};
560
561        // 1. Setup Mock Server
562        let mut mock = MockSpanner::new();
563        let mut seq = mockall::Sequence::new();
564        mock.expect_create_session()
565            .once()
566            .in_sequence(&mut seq)
567            .returning(|_| {
568                Err(gaxi::grpc::tonic::Status::unavailable(
569                    "server is unavailable",
570                ))
571            });
572        mock.expect_create_session().once().in_sequence(&mut seq).returning(|_| {
573            Ok(gaxi::grpc::tonic::Response::new(mock_v1::Session {
574                name: "projects/test-project/instances/test-instance/databases/test-db/sessions/456".to_string(),
575                ..Default::default()
576            }))
577        });
578
579        // 2. Start mock server
580        let (address, _server) = start("0.0.0.0:0", mock)
581            .await
582            .expect("Failed to start mock server");
583
584        // 3. Configure Client to use mock endpoint
585        // NOTE: Default retry policy is assigned automatically for GAPIC methods.
586        let client = Spanner::builder()
587            .with_endpoint(address)
588            .with_credentials(Anonymous::new().build())
589            .build()
590            .await
591            .expect("Failed to build client");
592
593        // 4. Call CreateSession with intentional retry configurations
594        let mut req = CreateSessionRequest::new();
595        req.database =
596            "projects/test-project/instances/test-instance/databases/test-db".to_string();
597
598        let session = client
599            .get_channel(client.next_channel_hint())
600            .inner
601            .create_session()
602            .with_request(req)
603            .with_idempotency(true)
604            .with_retry_policy(Aip194Strict.with_attempt_limit(3))
605            .send()
606            .await
607            .expect("Failed to call create_session");
608
609        // 5. Verify Response
610        assert_eq!(
611            session.name,
612            "projects/test-project/instances/test-instance/databases/test-db/sessions/456"
613        );
614    }
615
616    #[tokio_test_no_panics]
617    async fn test_create_session_transport_retry() {
618        // 1. Setup Mock Server
619        let mut mock = MockSpanner::new();
620        let mut seq = mockall::Sequence::new();
621        mock.expect_create_session()
622            .once()
623            .in_sequence(&mut seq)
624            .returning(|_| {
625                let mut status = Status::unavailable("connection reset");
626                let mut headers = std::mem::take(status.metadata_mut()).into_headers();
627                headers.insert("content-type", http::HeaderValue::from_static("text/html"));
628                *status.metadata_mut() = MetadataMap::from_headers(headers);
629                Err(status)
630            });
631        mock.expect_create_session()
632            .once()
633            .in_sequence(&mut seq)
634            .returning(|_| {
635                Ok(gaxi::grpc::tonic::Response::new(mock_v1::Session {
636                    name: "projects/test-project/instances/test-instance/databases/test-db/sessions/789".to_string(),
637                    ..Default::default()
638                }))
639            });
640
641        // 2. Start mock server
642        let (address, _server) = start("0.0.0.0:0", mock)
643            .await
644            .expect("Failed to start mock server");
645
646        // 3. Configure Client to use mock endpoint
647        let client = Spanner::builder()
648            .with_endpoint(address)
649            .with_credentials(Anonymous::new().build())
650            .build()
651            .await
652            .expect("Failed to build client");
653
654        // 4. Call CreateSession
655        let mut req = CreateSessionRequest::new();
656        req.database =
657            "projects/test-project/instances/test-instance/databases/test-db".to_string();
658
659        let session = client
660            .create_session(
661                req,
662                crate::RequestOptions::default(),
663                client.next_channel_hint(),
664            )
665            .await
666            .expect("Failed to call create_session after transport error retry");
667
668        // 5. Verify Response
669        assert_eq!(
670            session.name,
671            "projects/test-project/instances/test-instance/databases/test-db/sessions/789",
672            "Expected session name to match the second successful response after transport retry"
673        );
674    }
675
676    #[tokio_test_no_panics]
677    async fn test_execute_sql() {
678        use crate::model::ExecuteSqlRequest;
679
680        let mut mock = MockSpanner::new();
681        mock.expect_execute_sql().once().returning(|_| {
682            Ok(gaxi::grpc::tonic::Response::new(mock_v1::ResultSet {
683                metadata: Some(mock_v1::ResultSetMetadata {
684                    row_type: Some(mock_v1::StructType { fields: vec![] }),
685                    transaction: None,
686                    undeclared_parameters: None,
687                }),
688                rows: vec![],
689                stats: None,
690                precommit_token: None,
691                cache_update: None,
692            }))
693        });
694
695        let (address, _server) = start("0.0.0.0:0", mock)
696            .await
697            .expect("Failed to start mock server");
698        let client = Spanner::builder()
699            .with_endpoint(address)
700            .with_credentials(Anonymous::new().build())
701            .build()
702            .await
703            .expect("Failed to build client");
704
705        let mut req = ExecuteSqlRequest::new();
706        req.sql = "SELECT 1".to_string();
707
708        let result_set = client
709            .execute_sql(
710                req,
711                crate::RequestOptions::default(),
712                client.next_channel_hint(),
713            )
714            .await
715            .expect("Failed to call execute_sql");
716        assert!(result_set.metadata.is_some());
717    }
718
719    #[tokio_test_no_panics]
720    async fn test_execute_batch_dml() {
721        use crate::model::ExecuteBatchDmlRequest;
722
723        let mut mock = MockSpanner::new();
724        mock.expect_execute_batch_dml().once().returning(|_| {
725            Ok(gaxi::grpc::tonic::Response::new(
726                mock_v1::ExecuteBatchDmlResponse {
727                    result_sets: vec![],
728                    status: Some(mock_rpc::Status {
729                        code: 0,
730                        message: "OK".to_string(),
731                        details: vec![],
732                    }),
733                    precommit_token: None,
734                },
735            ))
736        });
737
738        let (address, _server) = start("0.0.0.0:0", mock)
739            .await
740            .expect("Failed to start mock server");
741        let client = Spanner::builder()
742            .with_endpoint(address)
743            .with_credentials(Anonymous::new().build())
744            .build()
745            .await
746            .expect("Failed to build client");
747
748        let mut req = ExecuteBatchDmlRequest::new();
749        req.session = "test_session".to_string();
750
751        let response = client
752            .execute_batch_dml(
753                req,
754                crate::RequestOptions::default(),
755                client.next_channel_hint(),
756            )
757            .await
758            .expect("Failed to call execute_batch_dml");
759        assert!(response.status.is_some());
760    }
761
762    #[tokio_test_no_panics]
763    async fn test_begin_transaction() {
764        use crate::model::BeginTransactionRequest;
765
766        let mut mock = MockSpanner::new();
767        mock.expect_begin_transaction().once().returning(|_| {
768            Ok(gaxi::grpc::tonic::Response::new(mock_v1::Transaction {
769                id: vec![1, 2, 3],
770                read_timestamp: None,
771                precommit_token: None,
772                ..Default::default()
773            }))
774        });
775
776        let (address, _server) = start("0.0.0.0:0", mock)
777            .await
778            .expect("Failed to start mock server");
779        let client = Spanner::builder()
780            .with_endpoint(address)
781            .with_credentials(Anonymous::new().build())
782            .build()
783            .await
784            .expect("Failed to build client");
785
786        let mut req = BeginTransactionRequest::new();
787        req.session = "test_session".to_string();
788
789        let tx = client
790            .begin_transaction(
791                req,
792                crate::RequestOptions::default(),
793                client.next_channel_hint(),
794            )
795            .await
796            .expect("Failed to call begin_transaction");
797        assert_eq!(tx.id, vec![1, 2, 3]);
798    }
799
800    #[tokio_test_no_panics]
801    async fn test_commit() {
802        use crate::model::CommitRequest;
803
804        let mut mock = MockSpanner::new();
805        mock.expect_commit().once().returning(|_| {
806            Ok(gaxi::grpc::tonic::Response::new(mock_v1::CommitResponse {
807                commit_timestamp: Some(prost_types::Timestamp {
808                    seconds: 12345,
809                    nanos: 0,
810                }),
811                commit_stats: None,
812                multiplexed_session_retry: None,
813                snapshot_timestamp: None,
814                ..Default::default()
815            }))
816        });
817
818        let (address, _server) = start("0.0.0.0:0", mock)
819            .await
820            .expect("Failed to start mock server");
821        let client = Spanner::builder()
822            .with_endpoint(address)
823            .with_credentials(Anonymous::new().build())
824            .build()
825            .await
826            .expect("Failed to build client");
827
828        let mut req = CommitRequest::new();
829        req.session = "test_session".to_string();
830
831        let response = client
832            .commit(
833                req,
834                crate::RequestOptions::default(),
835                client.next_channel_hint(),
836            )
837            .await
838            .expect("Failed to call commit");
839        assert!(response.commit_timestamp.is_some());
840    }
841
842    #[tokio_test_no_panics]
843    async fn test_rollback() {
844        use crate::model::RollbackRequest;
845
846        let mut mock = MockSpanner::new();
847        mock.expect_rollback()
848            .once()
849            .returning(|_| Ok(gaxi::grpc::tonic::Response::new(())));
850
851        let (address, _server) = start("0.0.0.0:0", mock)
852            .await
853            .expect("Failed to start mock server");
854        let client = Spanner::builder()
855            .with_endpoint(address)
856            .with_credentials(Anonymous::new().build())
857            .build()
858            .await
859            .expect("Failed to build client");
860
861        let mut req = RollbackRequest::new();
862        req.session = "test_session".to_string();
863
864        client
865            .rollback(
866                req,
867                crate::RequestOptions::default(),
868                client.next_channel_hint(),
869            )
870            .await
871            .expect("Failed to call rollback");
872    }
873
874    #[tokio_test_no_panics]
875    async fn test_execute_streaming_sql() {
876        use crate::model::ExecuteSqlRequest;
877
878        let mut mock = MockSpanner::new();
879        mock.expect_execute_streaming_sql().once().returning(|_| {
880            let result_set = mock_v1::PartialResultSet {
881                metadata: Some(mock_v1::ResultSetMetadata {
882                    row_type: Some(mock_v1::StructType { fields: vec![] }),
883                    transaction: None,
884                    undeclared_parameters: None,
885                }),
886                values: vec![],
887                chunked_value: false,
888                resume_token: vec![],
889                stats: None,
890                precommit_token: None,
891                cache_update: None,
892                last: false,
893            };
894            Ok(gaxi::grpc::tonic::Response::new(adapt([Ok(result_set)])))
895        });
896
897        let (address, _server) = start("0.0.0.0:0", mock)
898            .await
899            .expect("Failed to start mock server");
900        let client = Spanner::builder()
901            .with_endpoint(address)
902            .with_credentials(Anonymous::new().build())
903            .build()
904            .await
905            .expect("Failed to build client");
906
907        let mut req = ExecuteSqlRequest::new();
908        req.sql = "SELECT 1".to_string();
909
910        let mut stream = client
911            .execute_streaming_sql(
912                req,
913                crate::RequestOptions::default(),
914                client.next_channel_hint(),
915            )
916            .send()
917            .await
918            .expect("Failed to call execute_streaming_sql");
919
920        let result = stream.next_message().await;
921        assert!(result.is_some());
922        assert!(result.unwrap().is_ok());
923    }
924
925    #[tokio_test_no_panics]
926    async fn test_streaming_read() {
927        use crate::model::ReadRequest;
928
929        let mut mock = MockSpanner::new();
930        mock.expect_streaming_read().once().returning(|_| {
931            let result_set = mock_v1::PartialResultSet {
932                metadata: Some(mock_v1::ResultSetMetadata {
933                    row_type: Some(mock_v1::StructType { fields: vec![] }),
934                    transaction: None,
935                    undeclared_parameters: None,
936                }),
937                values: vec![],
938                chunked_value: false,
939                resume_token: vec![],
940                stats: None,
941                precommit_token: None,
942                cache_update: None,
943                last: false,
944            };
945            Ok(gaxi::grpc::tonic::Response::from(adapt([Ok(result_set)])))
946        });
947
948        let (address, _server) = start("0.0.0.0:0", mock)
949            .await
950            .expect("Failed to start mock server");
951        let client = Spanner::builder()
952            .with_endpoint(address)
953            .with_credentials(Anonymous::new().build())
954            .build()
955            .await
956            .expect("Failed to build client");
957
958        let mut req = ReadRequest::new();
959        req.table = "test_table".to_string();
960        req.columns = vec!["col1".to_string()];
961
962        let mut stream = client
963            .streaming_read(
964                req,
965                crate::RequestOptions::default(),
966                client.next_channel_hint(),
967            )
968            .send()
969            .await
970            .expect("Failed to call streaming_read");
971
972        let result = stream.next_message().await;
973        assert!(result.is_some());
974        assert!(result.unwrap().is_ok());
975    }
976
977    #[tokio_test_no_panics]
978    async fn test_batch_write() {
979        use crate::model::BatchWriteRequest;
980
981        let mut mock = MockSpanner::new();
982        mock.expect_batch_write().once().returning(|_| {
983            let response = mock_v1::BatchWriteResponse {
984                indexes: vec![],
985                status: None,
986                commit_timestamp: None,
987            };
988            Ok(gaxi::grpc::tonic::Response::from(adapt([Ok(response)])))
989        });
990
991        let (address, _server) = start("0.0.0.0:0", mock)
992            .await
993            .expect("Failed to start mock server");
994        let client = Spanner::builder()
995            .with_endpoint(address)
996            .with_credentials(Anonymous::new().build())
997            .build()
998            .await
999            .expect("Failed to build client");
1000
1001        let mut req = BatchWriteRequest::new();
1002        req.session = "test_session".to_string();
1003
1004        let mut stream = client
1005            .batch_write(
1006                req,
1007                crate::RequestOptions::default(),
1008                client.next_channel_hint(),
1009            )
1010            .send()
1011            .await
1012            .expect("Failed to call batch_write");
1013
1014        let result = stream.next_message().await;
1015        assert!(result.is_some());
1016        assert!(result.unwrap().is_ok());
1017    }
1018
1019    #[tokio_test_no_panics]
1020    async fn test_execute_streaming_sql_error() {
1021        use crate::model::ExecuteSqlRequest;
1022
1023        let mut mock = MockSpanner::new();
1024        mock.expect_execute_streaming_sql().once().returning(|_| {
1025            let stream = adapt([Err(gaxi::grpc::tonic::Status::internal(
1026                "unexpected internal error",
1027            ))]);
1028            Ok(gaxi::grpc::tonic::Response::from(stream))
1029        });
1030
1031        let (address, _server) = start("0.0.0.0:0", mock)
1032            .await
1033            .expect("Failed to start mock server");
1034        let client = Spanner::builder()
1035            .with_endpoint(address)
1036            .with_credentials(Anonymous::new().build())
1037            .build()
1038            .await
1039            .expect("Failed to build client");
1040
1041        let mut req = ExecuteSqlRequest::new();
1042        req.sql = "SELECT 1".to_string();
1043
1044        let mut stream = client
1045            .execute_streaming_sql(
1046                req,
1047                crate::RequestOptions::default(),
1048                client.next_channel_hint(),
1049            )
1050            .send()
1051            .await
1052            .expect("Failed to call execute_streaming_sql");
1053
1054        let result = stream.next_message().await;
1055        assert!(result.is_some());
1056        let err = result.unwrap().expect_err("expected error");
1057        assert_eq!(
1058            err.status().unwrap().code,
1059            google_cloud_gax::error::rpc::Code::Internal
1060        );
1061    }
1062
1063    #[tokio_test_no_panics]
1064    async fn default_retry_respected() -> anyhow::Result<()> {
1065        use crate::model::CreateSessionRequest;
1066
1067        // 1. Setup Mock Server
1068        let mut mock = MockSpanner::new();
1069        let mut seq = mockall::Sequence::new();
1070        mock.expect_create_session()
1071            .once()
1072            .in_sequence(&mut seq)
1073            .returning(|_| Err(Status::unavailable("server is unavailable")));
1074        mock.expect_create_session().once().in_sequence(&mut seq).returning(|_| {
1075            Ok(Response::new(Session {
1076                name: "projects/test-project/instances/test-instance/databases/test-db/sessions/456".to_string(),
1077                ..Default::default()
1078            }))
1079        });
1080
1081        // 2. Start mock server
1082        let (address, _server) = start("0.0.0.0:0", mock).await?;
1083
1084        // 3. Configure Client
1085        let client = Spanner::builder()
1086            .with_endpoint(address)
1087            .with_credentials(Anonymous::new().build())
1088            .build()
1089            .await?;
1090
1091        // 4. Call CreateSession using the hand-written wrapper
1092        let mut req = CreateSessionRequest::new();
1093        req.database =
1094            "projects/test-project/instances/test-instance/databases/test-db".to_string();
1095
1096        let session = client
1097            .create_session(
1098                req,
1099                crate::RequestOptions::default(),
1100                client.next_channel_hint(),
1101            )
1102            .await
1103            .expect("Failed to call create_session");
1104
1105        // 5. Verify Response
1106        assert_eq!(
1107            session.name,
1108            "projects/test-project/instances/test-instance/databases/test-db/sessions/456"
1109        );
1110
1111        Ok(())
1112    }
1113
1114    #[tokio_test_no_panics]
1115    async fn override_idempotency_to_false() -> anyhow::Result<()> {
1116        use crate::model::CreateSessionRequest;
1117
1118        // 1. Setup Mock Server to fail with UNAVAILABLE
1119        let mut mock = MockSpanner::new();
1120        mock.expect_create_session()
1121            .once()
1122            .returning(|_| Err(Status::unavailable("server is unavailable")));
1123
1124        // 2. Start mock server
1125        let (address, _server) = start("0.0.0.0:0", mock).await?;
1126
1127        // 3. Configure Client
1128        let client = Spanner::builder()
1129            .with_endpoint(address)
1130            .with_credentials(Anonymous::new().build())
1131            .build()
1132            .await?;
1133
1134        // 4. Call CreateSession with explicit idempotency = false
1135        let mut req = CreateSessionRequest::new();
1136        req.database =
1137            "projects/test-project/instances/test-instance/databases/test-db".to_string();
1138
1139        let mut options = crate::RequestOptions::default();
1140        options.set_idempotency(false);
1141
1142        let result = client
1143            .create_session(req, options, client.next_channel_hint())
1144            .await;
1145
1146        // 5. Verify that it failed and did not retry
1147        assert!(result.is_err(), "Expected error, got {:?}", result);
1148        let err = result.unwrap_err();
1149        assert_eq!(err.status().map(|s| s.code), Some(Code::Unavailable));
1150
1151        Ok(())
1152    }
1153
1154    #[tokio_test_no_panics]
1155    async fn timeout_respected() -> anyhow::Result<()> {
1156        use crate::batch_dml::BatchDml;
1157        use std::time::Duration;
1158
1159        // 1. Setup Mock Server
1160        let mut mock = MockSpanner::new();
1161
1162        mock.expect_create_session().returning(|_| {
1163            Ok(Response::new(Session {
1164                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1165                ..Default::default()
1166            }))
1167        });
1168
1169        mock.expect_begin_transaction().returning(|_| {
1170            Ok(Response::new(mock_v1::Transaction {
1171                id: vec![42],
1172                ..Default::default()
1173            }))
1174        });
1175
1176        mock.expect_execute_streaming_sql().once().returning(|req| {
1177            let metadata = req.metadata();
1178            let timeout = metadata.get("grpc-timeout");
1179            assert!(
1180                timeout.is_some(),
1181                "grpc-timeout header should be present for query"
1182            );
1183
1184            let (tx, rx) = tokio::sync::mpsc::channel(1);
1185            let metadata = mock_v1::ResultSetMetadata {
1186                transaction: Some(mock_v1::Transaction {
1187                    id: vec![42],
1188                    ..Default::default()
1189                }),
1190                ..Default::default()
1191            };
1192            let prs = mock_v1::PartialResultSet {
1193                metadata: Some(metadata),
1194                ..Default::default()
1195            };
1196            tx.try_send(Ok(prs)).unwrap();
1197            Ok(Response::new(rx))
1198        });
1199
1200        mock.expect_streaming_read().once().returning(|req| {
1201            let metadata = req.metadata();
1202            let timeout = metadata.get("grpc-timeout");
1203            assert!(
1204                timeout.is_some(),
1205                "grpc-timeout header should be present for read"
1206            );
1207
1208            let (tx, rx) = tokio::sync::mpsc::channel(1);
1209            let metadata = mock_v1::ResultSetMetadata {
1210                transaction: None,
1211                ..Default::default()
1212            };
1213            let prs = mock_v1::PartialResultSet {
1214                metadata: Some(metadata),
1215                ..Default::default()
1216            };
1217            tx.try_send(Ok(prs)).unwrap();
1218            Ok(Response::new(rx))
1219        });
1220
1221        mock.expect_execute_sql().once().returning(|req| {
1222            let metadata = req.metadata();
1223            let timeout = metadata.get("grpc-timeout");
1224            assert!(
1225                timeout.is_some(),
1226                "grpc-timeout header should be present for single DML"
1227            );
1228
1229            Ok(Response::new(mock_v1::ResultSet {
1230                metadata: Some(mock_v1::ResultSetMetadata {
1231                    transaction: Some(mock_v1::Transaction {
1232                        id: vec![42],
1233                        ..Default::default()
1234                    }),
1235                    ..Default::default()
1236                }),
1237                stats: Some(mock_v1::ResultSetStats {
1238                    row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)),
1239                    ..Default::default()
1240                }),
1241                ..Default::default()
1242            }))
1243        });
1244
1245        mock.expect_execute_batch_dml().once().returning(|req| {
1246            let metadata = req.metadata();
1247            let timeout = metadata.get("grpc-timeout");
1248            assert!(
1249                timeout.is_some(),
1250                "grpc-timeout header should be present for batch dml"
1251            );
1252
1253            Ok(Response::new(mock_v1::ExecuteBatchDmlResponse {
1254                result_sets: vec![mock_v1::ResultSet {
1255                    stats: Some(mock_v1::ResultSetStats {
1256                        row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)),
1257                        ..Default::default()
1258                    }),
1259                    ..Default::default()
1260                }],
1261                ..Default::default()
1262            }))
1263        });
1264
1265        mock.expect_commit().returning(|_| {
1266            Ok(Response::new(mock_v1::CommitResponse {
1267                commit_timestamp: Some(prost_types::Timestamp {
1268                    seconds: 1234,
1269                    nanos: 0,
1270                }),
1271                ..Default::default()
1272            }))
1273        });
1274
1275        // 2. Start mock server
1276        let (address, _server) = start("0.0.0.0:0", mock).await?;
1277
1278        // 3. Configure Client
1279        let client = Spanner::builder()
1280            .with_endpoint(address)
1281            .with_credentials(Anonymous::new().build())
1282            .build()
1283            .await?;
1284
1285        let db = client
1286            .database_client("projects/p/instances/i/databases/d")
1287            .build()
1288            .await?;
1289        let runner = db.read_write_transaction().build().await?;
1290
1291        // 4. Run transaction
1292        runner
1293            .run(async |tx| {
1294                // Query
1295                let stmt = Statement::builder("SELECT 1")
1296                    .with_attempt_timeout(Duration::from_secs(10))
1297                    .build();
1298                // TODO(#5673): ensure that transaction ID is processed even if ResultSet is dropped
1299                let _rs = tx.execute_query(stmt).await?;
1300
1301                // Read
1302                let req = ReadRequest::builder("Table", vec!["Col"])
1303                    .with_keys(crate::key::KeySet::all())
1304                    .with_attempt_timeout(Duration::from_secs(5))
1305                    .build();
1306                let _ = tx.execute_read(req).await?;
1307
1308                // Single DML
1309                let dml = Statement::builder("UPDATE t SET c = 1")
1310                    .with_attempt_timeout(Duration::from_secs(7))
1311                    .build();
1312                let _ = tx.execute_update(dml).await?;
1313
1314                // Batch DML
1315                let batch = BatchDml::builder()
1316                    .add_statement("UPDATE t SET c = 2")
1317                    .with_attempt_timeout(Duration::from_secs(8))
1318                    .build();
1319                let _ = tx.execute_batch_update(batch).await?;
1320
1321                Ok(())
1322            })
1323            .await?;
1324
1325        Ok(())
1326    }
1327
1328    #[tokio_test_no_panics]
1329    async fn retry_policy_respected() -> anyhow::Result<()> {
1330        use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt};
1331
1332        // Extend the default retry policy to also retry on ResourceExhausted.
1333        let retry_policy = Aip194Strict.continue_on_too_many_requests();
1334
1335        // 1. Setup Mock Server
1336        let mut mock = MockSpanner::new();
1337
1338        mock.expect_create_session().returning(|_| {
1339            Ok(Response::new(Session {
1340                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1341                ..Default::default()
1342            }))
1343        });
1344
1345        mock.expect_begin_transaction().returning(|_| {
1346            Ok(Response::new(mock_v1::Transaction {
1347                id: vec![42],
1348                ..Default::default()
1349            }))
1350        });
1351
1352        // Mock ExecuteSql to first return RESOURCE_EXHAUSTED and then succeed.
1353        let mut seq = mockall::Sequence::new();
1354
1355        mock.expect_execute_sql()
1356            .once()
1357            .in_sequence(&mut seq)
1358            .returning(|_| Err(Status::new(GrpcCode::ResourceExhausted, "quota exceeded")));
1359
1360        mock.expect_execute_sql()
1361            .once()
1362            .in_sequence(&mut seq)
1363            .returning(|_| {
1364                Ok(Response::new(mock_v1::ResultSet {
1365                    metadata: Some(mock_v1::ResultSetMetadata {
1366                        transaction: Some(mock_v1::Transaction {
1367                            id: vec![42],
1368                            ..Default::default()
1369                        }),
1370                        ..Default::default()
1371                    }),
1372                    stats: Some(mock_v1::ResultSetStats {
1373                        row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)),
1374                        ..Default::default()
1375                    }),
1376                    ..Default::default()
1377                }))
1378            });
1379
1380        mock.expect_commit().returning(|_| {
1381            Ok(Response::new(mock_v1::CommitResponse {
1382                commit_timestamp: Some(prost_types::Timestamp {
1383                    seconds: 1234,
1384                    nanos: 0,
1385                }),
1386                ..Default::default()
1387            }))
1388        });
1389
1390        // 2. Start mock server
1391        let (address, _server) = start("0.0.0.0:0", mock).await?;
1392
1393        // 3. Configure Client
1394        let client = Spanner::builder()
1395            .with_endpoint(address)
1396            .with_credentials(Anonymous::new().build())
1397            .build()
1398            .await?;
1399
1400        let db = client
1401            .database_client("projects/p/instances/i/databases/d")
1402            .build()
1403            .await?;
1404        let runner = db.read_write_transaction().build().await?;
1405
1406        // 4. Call execute_update with custom retry and backoff
1407        let mut mock_backoff = MockBackoffPolicy::new();
1408        mock_backoff
1409            .expect_on_failure()
1410            .once()
1411            .returning(|_| Duration::from_nanos(1));
1412
1413        let stmt = Statement::builder("UPDATE t SET c = 1")
1414            .with_retry_policy(retry_policy)
1415            .with_backoff_policy(mock_backoff)
1416            .build();
1417
1418        let result = runner
1419            .run(async |tx| {
1420                let count = tx.execute_update(stmt.clone()).await?;
1421                Ok(count)
1422            })
1423            .await?;
1424
1425        // 5. Verify success after retry
1426        assert_eq!(result.result, 1);
1427
1428        Ok(())
1429    }
1430
1431    fn parse_timeout(metadata: &MetadataMap) -> u64 {
1432        let timeout = metadata
1433            .get("grpc-timeout")
1434            .expect("grpc-timeout header should be present");
1435        let timeout_str = timeout
1436            .to_str()
1437            .expect("grpc-timeout should be a valid string");
1438        if timeout_str.ends_with('u') {
1439            timeout_str
1440                .trim_end_matches('u')
1441                .parse()
1442                .expect("valid u64")
1443        } else if timeout_str.ends_with('m') {
1444            timeout_str
1445                .trim_end_matches('m')
1446                .parse::<u64>()
1447                .expect("valid u64")
1448                * 1000
1449        } else if timeout_str.ends_with('n') {
1450            timeout_str
1451                .trim_end_matches('n')
1452                .parse::<u64>()
1453                .expect("valid u64")
1454                / 1000
1455        } else {
1456            panic!("Unknown timeout unit in {}", timeout_str);
1457        }
1458    }
1459
1460    #[tokio_test_no_panics]
1461    async fn transaction_timeout_respected() -> anyhow::Result<()> {
1462        use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt};
1463        use spanner_grpc_mock::google::spanner::v1::Transaction;
1464
1465        // 1. Setup Mock Server
1466        let mut mock = MockSpanner::new();
1467
1468        mock.expect_create_session().returning(|_| {
1469            Ok(Response::new(Session {
1470                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1471                ..Default::default()
1472            }))
1473        });
1474
1475        mock.expect_begin_transaction().returning(|_| {
1476            Ok(Response::new(Transaction {
1477                id: vec![1, 2, 3],
1478                ..Default::default()
1479            }))
1480        });
1481
1482        mock.expect_commit().once().returning(|_| {
1483            Ok(Response::new(CommitResponse {
1484                commit_timestamp: Some(prost_types::Timestamp {
1485                    seconds: 12345,
1486                    nanos: 0,
1487                }),
1488                ..Default::default()
1489            }))
1490        });
1491
1492        // Mock execute_sql to first fail and then succeed, checking timeout header on both
1493        let mut seq = mockall::Sequence::new();
1494
1495        mock.expect_execute_sql()
1496            .once()
1497            .in_sequence(&mut seq)
1498            .returning(|req| {
1499                let timeout_val = parse_timeout(req.metadata());
1500                assert!(
1501                    timeout_val <= 100000,
1502                    "Expected timeout to be <= 100ms, got {}",
1503                    timeout_val
1504                );
1505                Err(Status::new(GrpcCode::ResourceExhausted, "quota exceeded"))
1506            });
1507
1508        mock.expect_execute_sql()
1509            .once()
1510            .in_sequence(&mut seq)
1511            .returning(|req| {
1512                let timeout_val = parse_timeout(req.metadata());
1513                assert!(
1514                    timeout_val <= 100000,
1515                    "Expected timeout to be <= 100ms, got {}",
1516                    timeout_val
1517                );
1518
1519                let res = ResultSet {
1520                    metadata: Some(spanner_grpc_mock::google::spanner::v1::ResultSetMetadata {
1521                        transaction: Some(Transaction {
1522                            id: vec![1, 2, 3],
1523                            ..Default::default()
1524                        }),
1525                        ..Default::default()
1526                    }),
1527                    stats: Some(ResultSetStats {
1528                        row_count: Some(RowCount::RowCountExact(1)),
1529                        ..Default::default()
1530                    }),
1531                    ..Default::default()
1532                };
1533                Ok(Response::new(res))
1534            });
1535
1536        // 2. Initialize Client
1537        let (address, _server) = start("127.0.0.1:0", mock).await?;
1538        let client = Spanner::builder()
1539            .with_endpoint(address)
1540            .with_credentials(Anonymous::new().build())
1541            .build()
1542            .await?;
1543        let db = client
1544            .database_client("projects/p/instances/i/databases/d")
1545            .build()
1546            .await?;
1547
1548        // 3. Setup Transaction Runner with 100ms timeout
1549        let runner = db
1550            .read_write_transaction()
1551            .with_transaction_timeout(Duration::from_millis(100))
1552            .build()
1553            .await?;
1554
1555        // 4. Run transaction and expect success after retry
1556        let result = runner
1557            .run(async |tx| {
1558                let mut mock_backoff = MockBackoffPolicy::new();
1559                mock_backoff
1560                    .expect_on_failure()
1561                    .times(1)
1562                    .returning(|_| Duration::from_nanos(1));
1563
1564                let retry_policy = Aip194Strict.continue_on_too_many_requests();
1565
1566                let stmt = Statement::builder("SELECT 1")
1567                    .with_retry_policy(retry_policy)
1568                    .with_backoff_policy(mock_backoff)
1569                    .build();
1570                tx.execute_update(stmt).await?;
1571                Ok(())
1572            })
1573            .await;
1574
1575        result.expect("Transaction should have succeeded");
1576
1577        Ok(())
1578    }
1579
1580    #[tokio_test_no_panics]
1581    async fn transaction_timeout_ticks_down() -> anyhow::Result<()> {
1582        use spanner_grpc_mock::google::spanner::v1::Transaction;
1583
1584        let mut mock = MockSpanner::new();
1585
1586        mock.expect_create_session().returning(|_| {
1587            Ok(Response::new(Session {
1588                name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1589                ..Default::default()
1590            }))
1591        });
1592
1593        let mut seq = mockall::Sequence::new();
1594
1595        let previous_timeout = Arc::new(AtomicU64::new(0));
1596        let prev_clone1 = previous_timeout.clone();
1597        mock.expect_execute_sql()
1598            .once()
1599            .in_sequence(&mut seq)
1600            .returning(move |req| {
1601                let timeout_val = parse_timeout(req.metadata());
1602                assert!(
1603                    timeout_val <= 500000,
1604                    "Expected timeout to be <= 500ms, got {}",
1605                    timeout_val
1606                );
1607                prev_clone1.store(timeout_val, Ordering::SeqCst);
1608                Err(Status::new(GrpcCode::Aborted, "Aborted"))
1609            });
1610
1611        // Second attempt: Checks that timeout is <= previous
1612
1613        let prev_clone2 = previous_timeout.clone();
1614        mock.expect_execute_sql()
1615            .once()
1616            .in_sequence(&mut seq)
1617            .returning(move |req| {
1618                let timeout_val = parse_timeout(req.metadata());
1619                let prev = prev_clone2.load(Ordering::SeqCst);
1620                assert!(
1621                    timeout_val <= prev,
1622                    "Timeout should tick down between attempts or be equal, got {} and {}",
1623                    timeout_val,
1624                    prev
1625                );
1626                prev_clone2.store(timeout_val, Ordering::SeqCst); // store for next check
1627
1628                let res = ResultSet {
1629                    metadata: Some(spanner_grpc_mock::google::spanner::v1::ResultSetMetadata {
1630                        transaction: Some(Transaction {
1631                            id: vec![2],
1632                            ..Default::default()
1633                        }),
1634                        ..Default::default()
1635                    }),
1636                    stats: Some(ResultSetStats {
1637                        row_count: Some(RowCount::RowCountExact(1)),
1638                        ..Default::default()
1639                    }),
1640                    ..Default::default()
1641                };
1642                Ok(Response::new(res))
1643            });
1644
1645        let prev_clone3 = previous_timeout.clone();
1646        mock.expect_commit().once().returning(move |req| {
1647            let timeout_val = parse_timeout(req.metadata());
1648            let prev = prev_clone3.load(Ordering::SeqCst);
1649            assert!(
1650                timeout_val < prev,
1651                "Timeout should be smaller for commit, got {} and {}",
1652                timeout_val,
1653                prev
1654            );
1655
1656            Ok(Response::new(CommitResponse {
1657                commit_timestamp: Some(prost_types::Timestamp {
1658                    seconds: 12345,
1659                    nanos: 0,
1660                }),
1661                ..Default::default()
1662            }))
1663        });
1664
1665        let (address, _server) = start("127.0.0.1:0", mock).await?;
1666        let client = Spanner::builder()
1667            .with_endpoint(address)
1668            .with_credentials(Anonymous::new().build())
1669            .build()
1670            .await?;
1671        let db = client
1672            .database_client("projects/p/instances/i/databases/d")
1673            .build()
1674            .await?;
1675
1676        let runner = db
1677            .read_write_transaction()
1678            .with_transaction_timeout(Duration::from_millis(500))
1679            .build()
1680            .await?;
1681
1682        let result = runner
1683            .run(async |tx| {
1684                let stmt = Statement::builder("SELECT 1").build();
1685                tx.execute_update(stmt).await?;
1686                Ok(())
1687            })
1688            .await;
1689
1690        result.expect("Transaction should have succeeded");
1691
1692        Ok(())
1693    }
1694
1695    #[test]
1696    fn test_parse_emulator_endpoint() {
1697        assert_eq!(
1698            super::parse_emulator_endpoint("localhost:9010"),
1699            "http://localhost:9010"
1700        );
1701        assert_eq!(
1702            super::parse_emulator_endpoint("spanner-emulator:9010"),
1703            "http://spanner-emulator:9010"
1704        );
1705        assert_eq!(
1706            super::parse_emulator_endpoint("http://localhost:9010"),
1707            "http://localhost:9010"
1708        );
1709        assert_eq!(
1710            super::parse_emulator_endpoint("https://localhost:9010"),
1711            "https://localhost:9010"
1712        );
1713        assert_eq!(
1714            super::parse_emulator_endpoint("grpc://localhost:9010"),
1715            "grpc://localhost:9010"
1716        );
1717        assert_eq!(
1718            super::parse_emulator_endpoint("http_localhost:9010"),
1719            "http://http_localhost:9010"
1720        );
1721    }
1722}