1use crate::generated::gapic_dataplane::client::Spanner as GapicSpanner;
16use crate::model::{
17 BeginTransactionRequest, CommitRequest, CommitResponse, CreateSessionRequest,
18 ExecuteBatchDmlRequest, ExecuteBatchDmlResponse, ExecuteSqlRequest, PartitionQueryRequest,
19 PartitionReadRequest, PartitionResponse, RollbackRequest, Session, Transaction,
20};
21use crate::server_streaming::builder;
22use gaxi::options::{ClientConfig, Credentials};
23use google_cloud_gax::client_builder::ClientBuilder as GaxClientBuilder;
24use google_cloud_gax::options::{
25 RequestOptions as GaxRequestOptions, internal::RequestOptionsExt as _,
26};
27use google_cloud_spanner_admin_database_v1::builder::database_admin::ClientBuilder as DatabaseAdminBuilder;
28use google_cloud_spanner_admin_instance_v1::builder::instance_admin::ClientBuilder as InstanceAdminBuilder;
29use http::{
30 HeaderMap,
31 header::{HeaderName, HeaderValue},
32};
33use std::sync::{
34 LazyLock,
35 atomic::{AtomicUsize, Ordering},
36};
37
38pub use crate::database_client::DatabaseClient;
39pub use google_cloud_spanner_admin_database_v1::client::DatabaseAdmin;
40pub use google_cloud_spanner_admin_instance_v1::client::InstanceAdmin;
41
42#[derive(Clone, Debug)]
48pub struct Spanner {
49 pub(crate) channels: Vec<Channel>,
50 pub(crate) counter: std::sync::Arc<AtomicUsize>,
51 pub(crate) config: ClientConfig,
52}
53
54pub struct Factory;
56
57impl google_cloud_gax::client_builder::internal::ClientFactory for Factory {
58 type Client = Spanner;
59 type Credentials = Credentials;
60
61 async fn build(self, config: ClientConfig) -> crate::ClientBuilderResult<Self::Client> {
62 let num_channels = std::env::var("SPANNER_NUM_CHANNELS")
63 .ok()
64 .and_then(|s| s.parse::<usize>().ok())
65 .unwrap_or(4);
66
67 let mut channels = Vec::with_capacity(num_channels);
68 for _ in 0..num_channels {
69 channels.push(Channel::create(&config).await?);
70 }
71
72 Ok(Spanner {
73 channels,
74 counter: std::sync::Arc::new(AtomicUsize::new(0)),
75 config,
76 })
77 }
78}
79
80pub type ClientBuilder = google_cloud_gax::client_builder::ClientBuilder<Factory, Credentials>;
82
83fn parse_emulator_endpoint(endpoint: &str) -> String {
84 match url::Url::parse(endpoint) {
85 Ok(url) if url.has_host() => endpoint.to_string(),
86 _ => format!("http://{}", endpoint),
87 }
88}
89
/// Generates a `pub(crate)` async wrapper for a unary Spanner RPC.
///
/// The generated method picks a channel with `self.get_channel(channel_hint)`,
/// defaults the request's idempotency to `true` unless the caller already set it
/// (via `with_default_idempotency`), and sends `request` on that channel.
macro_rules! define_idempotent_rpc {
    ($method:ident, $request_type:ty, $response_type:ty) => {
        pub(crate) async fn $method(
            &self,
            request: $request_type,
            options: crate::RequestOptions,
            channel_hint: usize,
        ) -> crate::Result<$response_type> {
            let channel = self.get_channel(channel_hint);
            let options = with_default_idempotency(options);
            channel
                .inner
                .$method()
                .with_request(request)
                .with_options(options)
                .send()
                .await
        }
    };
}
108
109fn with_default_idempotency(mut options: crate::RequestOptions) -> crate::RequestOptions {
110 if options.idempotent().is_none() {
111 options.set_idempotency(true);
112 }
113 options
114}
115
116pub(crate) static LAR_HEADER_MAP: LazyLock<HeaderMap> = LazyLock::new(|| {
117 let mut map = HeaderMap::new();
118 map.insert(
119 HeaderName::from_static("x-goog-spanner-route-to-leader"),
120 HeaderValue::from_static("true"),
121 );
122 map
123});
124
125pub(crate) fn amend_request_options_for_lar(
126 leader_aware_routing_enabled: bool,
127 mut options: GaxRequestOptions,
128) -> GaxRequestOptions {
129 if leader_aware_routing_enabled {
130 let mut headers = options
131 .get_extension::<HeaderMap>()
132 .cloned()
133 .unwrap_or_default();
134 headers.extend((*LAR_HEADER_MAP).clone());
135 options = options.insert_extension(headers);
136 }
137 options
138}
139
/// Derives the admin-API endpoint from the data-plane endpoint.
///
/// Trailing slashes are always removed. When running against the emulator
/// (`is_emulator`) and the endpoint ends with port `:9010`, that trailing port is
/// rewritten to `:9020`. Only the suffix is touched: earlier occurrences of
/// `:9010` (e.g. in userinfo) are preserved. All other endpoints are returned as-is.
fn map_emulator_admin_endpoint(endpoint: &str, is_emulator: bool) -> String {
    let trimmed = endpoint.trim_end_matches('/');
    match trimmed.strip_suffix(":9010") {
        Some(prefix) if is_emulator => format!("{prefix}:9020"),
        _ => trimmed.to_string(),
    }
}
147
148impl Spanner {
149 pub fn builder() -> ClientBuilder {
178 let builder = google_cloud_gax::client_builder::internal::new_builder(Factory);
179 let Some(endpoint) = std::env::var("SPANNER_EMULATOR_HOST")
182 .ok()
183 .filter(|s| !s.is_empty())
184 else {
185 return builder;
186 };
187
188 let full_endpoint = parse_emulator_endpoint(&endpoint);
190
191 builder
192 .with_endpoint(full_endpoint)
193 .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build())
194 }
195
196 pub fn database_admin_builder(&self) -> DatabaseAdminBuilder {
203 self.configure_admin_builder(DatabaseAdmin::builder())
204 }
205
206 pub fn instance_admin_builder(&self) -> InstanceAdminBuilder {
213 self.configure_admin_builder(InstanceAdmin::builder())
214 }
215
216 fn configure_admin_builder<F, C>(
217 &self,
218 mut builder: GaxClientBuilder<F, C>,
219 ) -> GaxClientBuilder<F, C>
220 where
221 C: Clone + From<Credentials>,
222 {
223 if let Some(ref endpoint) = self.config.endpoint {
224 let is_emulator = std::env::var("SPANNER_EMULATOR_HOST")
225 .ok()
226 .filter(|s| !s.is_empty())
227 .is_some();
228 let ep = map_emulator_admin_endpoint(endpoint, is_emulator);
229 builder = builder.with_endpoint(ep);
230 }
231 if let Some(ref cred) = self.config.cred {
232 builder = builder.with_credentials(cred.clone());
233 }
234 if let Some(ref ud) = self.config.universe_domain {
235 builder = builder.with_universe_domain(ud.clone());
236 }
237 builder
238 }
239
240 pub fn database_client(
259 &self,
260 database: impl Into<String>,
261 ) -> crate::builder::DatabaseClientBuilder {
262 crate::builder::DatabaseClientBuilder::new(self.clone(), database.into())
263 }
264
265 pub fn from_stub<T>(stub: T) -> Self
270 where
271 T: crate::generated::gapic_dataplane::stub::Spanner + 'static,
272 {
273 Self {
276 channels: vec![Channel {
277 inner: GapicSpanner::from_stub(stub),
278 grpc_client: None,
279 }],
280 counter: std::sync::Arc::new(AtomicUsize::new(0)),
281 config: ClientConfig::default(),
282 }
283 }
284
285 pub(crate) fn get_channel(&self, hint: usize) -> &Channel {
286 let idx = hint % self.channels.len();
287 &self.channels[idx]
288 }
289
290 pub(crate) fn next_channel_hint(&self) -> usize {
291 self.counter.fetch_add(1, Ordering::Relaxed)
292 }
293
294 define_idempotent_rpc!(create_session, CreateSessionRequest, Session);
295 define_idempotent_rpc!(execute_sql, ExecuteSqlRequest, crate::model::ResultSet);
296 define_idempotent_rpc!(
297 execute_batch_dml,
298 ExecuteBatchDmlRequest,
299 ExecuteBatchDmlResponse
300 );
301 define_idempotent_rpc!(begin_transaction, BeginTransactionRequest, Transaction);
302 define_idempotent_rpc!(commit, CommitRequest, CommitResponse);
303 define_idempotent_rpc!(rollback, RollbackRequest, ());
304 define_idempotent_rpc!(partition_query, PartitionQueryRequest, PartitionResponse);
305 define_idempotent_rpc!(partition_read, PartitionReadRequest, PartitionResponse);
306
307 pub(crate) fn execute_streaming_sql(
312 &self,
313 request: crate::model::ExecuteSqlRequest,
314 options: crate::RequestOptions,
315 channel_hint: usize,
316 ) -> builder::ExecuteStreamingSql {
317 let channel = self.get_channel(channel_hint);
318 let grpc = channel
319 .grpc_client
320 .as_ref()
321 .expect("Streaming RPCs are not supported when using a stub client");
322 builder::ExecuteStreamingSql::new(grpc.clone())
323 .with_request(request)
324 .with_options(options)
325 }
326
327 pub(crate) fn streaming_read(
332 &self,
333 request: crate::model::ReadRequest,
334 options: crate::RequestOptions,
335 channel_hint: usize,
336 ) -> builder::StreamingRead {
337 let channel = self.get_channel(channel_hint);
338 let grpc = channel
339 .grpc_client
340 .as_ref()
341 .expect("Streaming RPCs are not supported when using a stub client");
342 builder::StreamingRead::new(grpc.clone())
343 .with_request(request)
344 .with_options(options)
345 }
346
347 pub(crate) fn batch_write(
348 &self,
349 request: crate::model::BatchWriteRequest,
350 options: crate::RequestOptions,
351 channel_hint: usize,
352 ) -> builder::BatchWrite {
353 let channel = self.get_channel(channel_hint);
354 let grpc = channel
355 .grpc_client
356 .as_ref()
357 .expect("Streaming RPCs are not supported when using a stub client");
358 builder::BatchWrite::new(grpc.clone())
359 .with_request(request)
360 .with_options(options)
361 }
362}
363
364#[derive(Clone, Debug)]
365pub(crate) struct Channel {
366 pub(crate) inner: GapicSpanner,
367 pub(crate) grpc_client: Option<gaxi::grpc::Client>,
368}
369
370impl Channel {
371 pub(crate) async fn create(config: &ClientConfig) -> crate::ClientBuilderResult<Self> {
372 let transport =
373 crate::generated::gapic_dataplane::transport::Spanner::new(config.clone()).await?;
374 let grpc_client = transport.inner.clone();
375
376 let inner = if gaxi::options::tracing_enabled(config) {
377 GapicSpanner::from_stub(crate::generated::gapic_dataplane::tracing::Spanner::new(
378 transport,
379 ))
380 } else {
381 GapicSpanner::from_stub(transport)
382 };
383 Ok(Self {
384 inner,
385 grpc_client: Some(grpc_client),
386 })
387 }
388}
389
390#[cfg(test)]
391mod tests {
392 use super::*;
393 use crate::model::CreateSessionRequest;
394 use crate::read::ReadRequest;
395 use crate::result_set::tests::adapt;
396 use crate::statement::Statement;
397 use gaxi::grpc::tonic::MetadataMap;
398 use gaxi::grpc::tonic::{Code as GrpcCode, Response, Status};
399 use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
400 use google_cloud_gax::backoff_policy::BackoffPolicy;
401 use google_cloud_gax::error::rpc::Code;
402 use google_cloud_gax::retry_state::RetryState;
403 use google_cloud_test_macros::tokio_test_no_panics;
404 use spanner_grpc_mock::google::rpc as mock_rpc;
405 use spanner_grpc_mock::google::spanner::v1 as mock_v1;
406 use spanner_grpc_mock::google::spanner::v1::CommitResponse;
407 use spanner_grpc_mock::google::spanner::v1::ResultSet;
408 use spanner_grpc_mock::google::spanner::v1::ResultSetStats;
409 use spanner_grpc_mock::google::spanner::v1::Session;
410 use spanner_grpc_mock::google::spanner::v1::result_set_stats::RowCount;
411 use spanner_grpc_mock::{MockSpanner, start};
412 use static_assertions::{assert_impl_all, assert_not_impl_any};
413 use std::sync::Arc;
414 use std::sync::atomic::{AtomicU64, Ordering};
415 use std::time::Duration;
416
417 mockall::mock! {
418 #[derive(Debug)]
419 BackoffPolicy {}
420 impl BackoffPolicy for BackoffPolicy {
421 fn on_failure(&self, state: &RetryState) -> Duration;
422 }
423 }
424
425 #[test]
426 fn auto_traits() {
427 assert_impl_all!(Spanner: std::fmt::Debug, Clone, Send, Sync);
428 assert_not_impl_any!(Spanner: std::panic::RefUnwindSafe, std::panic::UnwindSafe);
429 }
430
431 #[tokio_test_no_panics]
432 async fn channel_pool_default_size() {
433 let mock = MockSpanner::new();
434 let (address, _server) = start("0.0.0.0:0", mock)
435 .await
436 .expect("Failed to start mock server");
437
438 let client = Spanner::builder()
439 .with_endpoint(address)
440 .with_credentials(Anonymous::new().build())
441 .build()
442 .await
443 .expect("Failed to build client");
444
445 assert_eq!(client.channels.len(), 4);
446 }
447
448 #[test]
449 fn test_map_emulator_admin_endpoint() {
450 assert_eq!(
452 map_emulator_admin_endpoint("https://spanner.googleapis.com", false),
453 "https://spanner.googleapis.com"
454 );
455
456 assert_eq!(
458 map_emulator_admin_endpoint("http://localhost:9010", true),
459 "http://localhost:9020"
460 );
461
462 assert_eq!(
464 map_emulator_admin_endpoint("http://127.0.0.1:9010/", true),
465 "http://127.0.0.1:9020"
466 );
467
468 assert_eq!(
470 map_emulator_admin_endpoint("http://localhost:9010", false),
471 "http://localhost:9010"
472 );
473 }
474
475 #[tokio_test_no_panics]
476 async fn channel_selection() {
477 let mock = MockSpanner::new();
478 let (address, _server) = start("0.0.0.0:0", mock)
479 .await
480 .expect("Failed to start mock server");
481
482 let client = Spanner::builder()
483 .with_endpoint(address)
484 .with_credentials(Anonymous::new().build())
485 .build()
486 .await
487 .expect("Failed to build client");
488
489 let hint0 = client.next_channel_hint();
490 let hint1 = client.next_channel_hint();
491 let hint2 = client.next_channel_hint();
492 let hint3 = client.next_channel_hint();
493 let hint4 = client.next_channel_hint();
494
495 assert_eq!(hint0 % 4, 0);
496 assert_eq!(hint1 % 4, 1);
497 assert_eq!(hint2 % 4, 2);
498 assert_eq!(hint3 % 4, 3);
499 assert_eq!(hint4 % 4, 0);
500 }
501
502 #[tokio_test_no_panics]
503 async fn test_create_session() {
504 let mut mock = MockSpanner::new();
506 mock.expect_create_session().once().returning(|_| {
507 Ok(gaxi::grpc::tonic::Response::new(mock_v1::Session {
508 name:
509 "projects/test-project/instances/test-instance/databases/test-db/sessions/123"
510 .to_string(),
511 ..Default::default()
512 }))
513 });
514
515 let (address, _server) = start("0.0.0.0:0", mock)
517 .await
518 .expect("Failed to start mock server");
519
520 let client = Spanner::builder()
522 .with_endpoint(address)
523 .with_credentials(Anonymous::new().build())
524 .build()
525 .await
526 .expect("Failed to build client");
527
528 let mut req = CreateSessionRequest::new();
530 req.database =
531 "projects/test-project/instances/test-instance/databases/test-db".to_string();
532
533 let session = client
534 .create_session(
535 req,
536 crate::RequestOptions::default(),
537 client.next_channel_hint(),
538 )
539 .await
540 .expect("Failed to call create_session");
541
542 assert_eq!(
544 session.name,
545 "projects/test-project/instances/test-instance/databases/test-db/sessions/123"
546 );
547 }
548
549 #[tokio_test_no_panics]
550 async fn test_create_session_retry() {
551 use google_cloud_gax::options::RequestOptionsBuilder;
552 use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt};
553
554 let mut mock = MockSpanner::new();
556 let mut seq = mockall::Sequence::new();
557 mock.expect_create_session()
558 .once()
559 .in_sequence(&mut seq)
560 .returning(|_| {
561 Err(gaxi::grpc::tonic::Status::unavailable(
562 "server is unavailable",
563 ))
564 });
565 mock.expect_create_session().once().in_sequence(&mut seq).returning(|_| {
566 Ok(gaxi::grpc::tonic::Response::new(mock_v1::Session {
567 name: "projects/test-project/instances/test-instance/databases/test-db/sessions/456".to_string(),
568 ..Default::default()
569 }))
570 });
571
572 let (address, _server) = start("0.0.0.0:0", mock)
574 .await
575 .expect("Failed to start mock server");
576
577 let client = Spanner::builder()
580 .with_endpoint(address)
581 .with_credentials(Anonymous::new().build())
582 .build()
583 .await
584 .expect("Failed to build client");
585
586 let mut req = CreateSessionRequest::new();
588 req.database =
589 "projects/test-project/instances/test-instance/databases/test-db".to_string();
590
591 let session = client
592 .get_channel(client.next_channel_hint())
593 .inner
594 .create_session()
595 .with_request(req)
596 .with_idempotency(true)
597 .with_retry_policy(Aip194Strict.with_attempt_limit(3))
598 .send()
599 .await
600 .expect("Failed to call create_session");
601
602 assert_eq!(
604 session.name,
605 "projects/test-project/instances/test-instance/databases/test-db/sessions/456"
606 );
607 }
608
609 #[tokio_test_no_panics]
610 async fn test_execute_sql() {
611 use crate::model::ExecuteSqlRequest;
612
613 let mut mock = MockSpanner::new();
614 mock.expect_execute_sql().once().returning(|_| {
615 Ok(gaxi::grpc::tonic::Response::new(mock_v1::ResultSet {
616 metadata: Some(mock_v1::ResultSetMetadata {
617 row_type: Some(mock_v1::StructType { fields: vec![] }),
618 transaction: None,
619 undeclared_parameters: None,
620 }),
621 rows: vec![],
622 stats: None,
623 precommit_token: None,
624 cache_update: None,
625 }))
626 });
627
628 let (address, _server) = start("0.0.0.0:0", mock)
629 .await
630 .expect("Failed to start mock server");
631 let client = Spanner::builder()
632 .with_endpoint(address)
633 .with_credentials(Anonymous::new().build())
634 .build()
635 .await
636 .expect("Failed to build client");
637
638 let mut req = ExecuteSqlRequest::new();
639 req.sql = "SELECT 1".to_string();
640
641 let result_set = client
642 .execute_sql(
643 req,
644 crate::RequestOptions::default(),
645 client.next_channel_hint(),
646 )
647 .await
648 .expect("Failed to call execute_sql");
649 assert!(result_set.metadata.is_some());
650 }
651
652 #[tokio_test_no_panics]
653 async fn test_execute_batch_dml() {
654 use crate::model::ExecuteBatchDmlRequest;
655
656 let mut mock = MockSpanner::new();
657 mock.expect_execute_batch_dml().once().returning(|_| {
658 Ok(gaxi::grpc::tonic::Response::new(
659 mock_v1::ExecuteBatchDmlResponse {
660 result_sets: vec![],
661 status: Some(mock_rpc::Status {
662 code: 0,
663 message: "OK".to_string(),
664 details: vec![],
665 }),
666 precommit_token: None,
667 },
668 ))
669 });
670
671 let (address, _server) = start("0.0.0.0:0", mock)
672 .await
673 .expect("Failed to start mock server");
674 let client = Spanner::builder()
675 .with_endpoint(address)
676 .with_credentials(Anonymous::new().build())
677 .build()
678 .await
679 .expect("Failed to build client");
680
681 let mut req = ExecuteBatchDmlRequest::new();
682 req.session = "test_session".to_string();
683
684 let response = client
685 .execute_batch_dml(
686 req,
687 crate::RequestOptions::default(),
688 client.next_channel_hint(),
689 )
690 .await
691 .expect("Failed to call execute_batch_dml");
692 assert!(response.status.is_some());
693 }
694
695 #[tokio_test_no_panics]
696 async fn test_begin_transaction() {
697 use crate::model::BeginTransactionRequest;
698
699 let mut mock = MockSpanner::new();
700 mock.expect_begin_transaction().once().returning(|_| {
701 Ok(gaxi::grpc::tonic::Response::new(mock_v1::Transaction {
702 id: vec![1, 2, 3],
703 read_timestamp: None,
704 precommit_token: None,
705 ..Default::default()
706 }))
707 });
708
709 let (address, _server) = start("0.0.0.0:0", mock)
710 .await
711 .expect("Failed to start mock server");
712 let client = Spanner::builder()
713 .with_endpoint(address)
714 .with_credentials(Anonymous::new().build())
715 .build()
716 .await
717 .expect("Failed to build client");
718
719 let mut req = BeginTransactionRequest::new();
720 req.session = "test_session".to_string();
721
722 let tx = client
723 .begin_transaction(
724 req,
725 crate::RequestOptions::default(),
726 client.next_channel_hint(),
727 )
728 .await
729 .expect("Failed to call begin_transaction");
730 assert_eq!(tx.id, vec![1, 2, 3]);
731 }
732
733 #[tokio_test_no_panics]
734 async fn test_commit() {
735 use crate::model::CommitRequest;
736
737 let mut mock = MockSpanner::new();
738 mock.expect_commit().once().returning(|_| {
739 Ok(gaxi::grpc::tonic::Response::new(mock_v1::CommitResponse {
740 commit_timestamp: Some(prost_types::Timestamp {
741 seconds: 12345,
742 nanos: 0,
743 }),
744 commit_stats: None,
745 multiplexed_session_retry: None,
746 snapshot_timestamp: None,
747 ..Default::default()
748 }))
749 });
750
751 let (address, _server) = start("0.0.0.0:0", mock)
752 .await
753 .expect("Failed to start mock server");
754 let client = Spanner::builder()
755 .with_endpoint(address)
756 .with_credentials(Anonymous::new().build())
757 .build()
758 .await
759 .expect("Failed to build client");
760
761 let mut req = CommitRequest::new();
762 req.session = "test_session".to_string();
763
764 let response = client
765 .commit(
766 req,
767 crate::RequestOptions::default(),
768 client.next_channel_hint(),
769 )
770 .await
771 .expect("Failed to call commit");
772 assert!(response.commit_timestamp.is_some());
773 }
774
775 #[tokio_test_no_panics]
776 async fn test_rollback() {
777 use crate::model::RollbackRequest;
778
779 let mut mock = MockSpanner::new();
780 mock.expect_rollback()
781 .once()
782 .returning(|_| Ok(gaxi::grpc::tonic::Response::new(())));
783
784 let (address, _server) = start("0.0.0.0:0", mock)
785 .await
786 .expect("Failed to start mock server");
787 let client = Spanner::builder()
788 .with_endpoint(address)
789 .with_credentials(Anonymous::new().build())
790 .build()
791 .await
792 .expect("Failed to build client");
793
794 let mut req = RollbackRequest::new();
795 req.session = "test_session".to_string();
796
797 client
798 .rollback(
799 req,
800 crate::RequestOptions::default(),
801 client.next_channel_hint(),
802 )
803 .await
804 .expect("Failed to call rollback");
805 }
806
807 #[tokio_test_no_panics]
808 async fn test_execute_streaming_sql() {
809 use crate::model::ExecuteSqlRequest;
810
811 let mut mock = MockSpanner::new();
812 mock.expect_execute_streaming_sql().once().returning(|_| {
813 let result_set = mock_v1::PartialResultSet {
814 metadata: Some(mock_v1::ResultSetMetadata {
815 row_type: Some(mock_v1::StructType { fields: vec![] }),
816 transaction: None,
817 undeclared_parameters: None,
818 }),
819 values: vec![],
820 chunked_value: false,
821 resume_token: vec![],
822 stats: None,
823 precommit_token: None,
824 cache_update: None,
825 last: false,
826 };
827 Ok(gaxi::grpc::tonic::Response::new(adapt([Ok(result_set)])))
828 });
829
830 let (address, _server) = start("0.0.0.0:0", mock)
831 .await
832 .expect("Failed to start mock server");
833 let client = Spanner::builder()
834 .with_endpoint(address)
835 .with_credentials(Anonymous::new().build())
836 .build()
837 .await
838 .expect("Failed to build client");
839
840 let mut req = ExecuteSqlRequest::new();
841 req.sql = "SELECT 1".to_string();
842
843 let mut stream = client
844 .execute_streaming_sql(
845 req,
846 crate::RequestOptions::default(),
847 client.next_channel_hint(),
848 )
849 .send()
850 .await
851 .expect("Failed to call execute_streaming_sql");
852
853 let result = stream.next_message().await;
854 assert!(result.is_some());
855 assert!(result.unwrap().is_ok());
856 }
857
858 #[tokio_test_no_panics]
859 async fn test_streaming_read() {
860 use crate::model::ReadRequest;
861
862 let mut mock = MockSpanner::new();
863 mock.expect_streaming_read().once().returning(|_| {
864 let result_set = mock_v1::PartialResultSet {
865 metadata: Some(mock_v1::ResultSetMetadata {
866 row_type: Some(mock_v1::StructType { fields: vec![] }),
867 transaction: None,
868 undeclared_parameters: None,
869 }),
870 values: vec![],
871 chunked_value: false,
872 resume_token: vec![],
873 stats: None,
874 precommit_token: None,
875 cache_update: None,
876 last: false,
877 };
878 Ok(gaxi::grpc::tonic::Response::from(adapt([Ok(result_set)])))
879 });
880
881 let (address, _server) = start("0.0.0.0:0", mock)
882 .await
883 .expect("Failed to start mock server");
884 let client = Spanner::builder()
885 .with_endpoint(address)
886 .with_credentials(Anonymous::new().build())
887 .build()
888 .await
889 .expect("Failed to build client");
890
891 let mut req = ReadRequest::new();
892 req.table = "test_table".to_string();
893 req.columns = vec!["col1".to_string()];
894
895 let mut stream = client
896 .streaming_read(
897 req,
898 crate::RequestOptions::default(),
899 client.next_channel_hint(),
900 )
901 .send()
902 .await
903 .expect("Failed to call streaming_read");
904
905 let result = stream.next_message().await;
906 assert!(result.is_some());
907 assert!(result.unwrap().is_ok());
908 }
909
910 #[tokio_test_no_panics]
911 async fn test_batch_write() {
912 use crate::model::BatchWriteRequest;
913
914 let mut mock = MockSpanner::new();
915 mock.expect_batch_write().once().returning(|_| {
916 let response = mock_v1::BatchWriteResponse {
917 indexes: vec![],
918 status: None,
919 commit_timestamp: None,
920 };
921 Ok(gaxi::grpc::tonic::Response::from(adapt([Ok(response)])))
922 });
923
924 let (address, _server) = start("0.0.0.0:0", mock)
925 .await
926 .expect("Failed to start mock server");
927 let client = Spanner::builder()
928 .with_endpoint(address)
929 .with_credentials(Anonymous::new().build())
930 .build()
931 .await
932 .expect("Failed to build client");
933
934 let mut req = BatchWriteRequest::new();
935 req.session = "test_session".to_string();
936
937 let mut stream = client
938 .batch_write(
939 req,
940 crate::RequestOptions::default(),
941 client.next_channel_hint(),
942 )
943 .send()
944 .await
945 .expect("Failed to call batch_write");
946
947 let result = stream.next_message().await;
948 assert!(result.is_some());
949 assert!(result.unwrap().is_ok());
950 }
951
952 #[tokio_test_no_panics]
953 async fn test_execute_streaming_sql_error() {
954 use crate::model::ExecuteSqlRequest;
955
956 let mut mock = MockSpanner::new();
957 mock.expect_execute_streaming_sql().once().returning(|_| {
958 let stream = adapt([Err(gaxi::grpc::tonic::Status::internal(
959 "unexpected internal error",
960 ))]);
961 Ok(gaxi::grpc::tonic::Response::from(stream))
962 });
963
964 let (address, _server) = start("0.0.0.0:0", mock)
965 .await
966 .expect("Failed to start mock server");
967 let client = Spanner::builder()
968 .with_endpoint(address)
969 .with_credentials(Anonymous::new().build())
970 .build()
971 .await
972 .expect("Failed to build client");
973
974 let mut req = ExecuteSqlRequest::new();
975 req.sql = "SELECT 1".to_string();
976
977 let mut stream = client
978 .execute_streaming_sql(
979 req,
980 crate::RequestOptions::default(),
981 client.next_channel_hint(),
982 )
983 .send()
984 .await
985 .expect("Failed to call execute_streaming_sql");
986
987 let result = stream.next_message().await;
988 assert!(result.is_some());
989 let err = result.unwrap().expect_err("expected error");
990 assert_eq!(
991 err.status().unwrap().code,
992 google_cloud_gax::error::rpc::Code::Internal
993 );
994 }
995
996 #[tokio_test_no_panics]
997 async fn default_retry_respected() -> anyhow::Result<()> {
998 use crate::model::CreateSessionRequest;
999
1000 let mut mock = MockSpanner::new();
1002 let mut seq = mockall::Sequence::new();
1003 mock.expect_create_session()
1004 .once()
1005 .in_sequence(&mut seq)
1006 .returning(|_| Err(Status::unavailable("server is unavailable")));
1007 mock.expect_create_session().once().in_sequence(&mut seq).returning(|_| {
1008 Ok(Response::new(Session {
1009 name: "projects/test-project/instances/test-instance/databases/test-db/sessions/456".to_string(),
1010 ..Default::default()
1011 }))
1012 });
1013
1014 let (address, _server) = start("0.0.0.0:0", mock).await?;
1016
1017 let client = Spanner::builder()
1019 .with_endpoint(address)
1020 .with_credentials(Anonymous::new().build())
1021 .build()
1022 .await?;
1023
1024 let mut req = CreateSessionRequest::new();
1026 req.database =
1027 "projects/test-project/instances/test-instance/databases/test-db".to_string();
1028
1029 let session = client
1030 .create_session(
1031 req,
1032 crate::RequestOptions::default(),
1033 client.next_channel_hint(),
1034 )
1035 .await
1036 .expect("Failed to call create_session");
1037
1038 assert_eq!(
1040 session.name,
1041 "projects/test-project/instances/test-instance/databases/test-db/sessions/456"
1042 );
1043
1044 Ok(())
1045 }
1046
1047 #[tokio_test_no_panics]
1048 async fn override_idempotency_to_false() -> anyhow::Result<()> {
1049 use crate::model::CreateSessionRequest;
1050
1051 let mut mock = MockSpanner::new();
1053 mock.expect_create_session()
1054 .once()
1055 .returning(|_| Err(Status::unavailable("server is unavailable")));
1056
1057 let (address, _server) = start("0.0.0.0:0", mock).await?;
1059
1060 let client = Spanner::builder()
1062 .with_endpoint(address)
1063 .with_credentials(Anonymous::new().build())
1064 .build()
1065 .await?;
1066
1067 let mut req = CreateSessionRequest::new();
1069 req.database =
1070 "projects/test-project/instances/test-instance/databases/test-db".to_string();
1071
1072 let mut options = crate::RequestOptions::default();
1073 options.set_idempotency(false);
1074
1075 let result = client
1076 .create_session(req, options, client.next_channel_hint())
1077 .await;
1078
1079 assert!(result.is_err(), "Expected error, got {:?}", result);
1081 let err = result.unwrap_err();
1082 assert_eq!(err.status().map(|s| s.code), Some(Code::Unavailable));
1083
1084 Ok(())
1085 }
1086
1087 #[tokio_test_no_panics]
1088 async fn timeout_respected() -> anyhow::Result<()> {
1089 use crate::batch_dml::BatchDml;
1090 use std::time::Duration;
1091
1092 let mut mock = MockSpanner::new();
1094
1095 mock.expect_create_session().returning(|_| {
1096 Ok(Response::new(Session {
1097 name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1098 ..Default::default()
1099 }))
1100 });
1101
1102 mock.expect_begin_transaction().returning(|_| {
1103 Ok(Response::new(mock_v1::Transaction {
1104 id: vec![42],
1105 ..Default::default()
1106 }))
1107 });
1108
1109 mock.expect_execute_streaming_sql().once().returning(|req| {
1110 let metadata = req.metadata();
1111 let timeout = metadata.get("grpc-timeout");
1112 assert!(
1113 timeout.is_some(),
1114 "grpc-timeout header should be present for query"
1115 );
1116
1117 let (tx, rx) = tokio::sync::mpsc::channel(1);
1118 let metadata = mock_v1::ResultSetMetadata {
1119 transaction: Some(mock_v1::Transaction {
1120 id: vec![42],
1121 ..Default::default()
1122 }),
1123 ..Default::default()
1124 };
1125 let prs = mock_v1::PartialResultSet {
1126 metadata: Some(metadata),
1127 ..Default::default()
1128 };
1129 tx.try_send(Ok(prs)).unwrap();
1130 Ok(Response::new(rx))
1131 });
1132
1133 mock.expect_streaming_read().once().returning(|req| {
1134 let metadata = req.metadata();
1135 let timeout = metadata.get("grpc-timeout");
1136 assert!(
1137 timeout.is_some(),
1138 "grpc-timeout header should be present for read"
1139 );
1140
1141 let (tx, rx) = tokio::sync::mpsc::channel(1);
1142 let metadata = mock_v1::ResultSetMetadata {
1143 transaction: None,
1144 ..Default::default()
1145 };
1146 let prs = mock_v1::PartialResultSet {
1147 metadata: Some(metadata),
1148 ..Default::default()
1149 };
1150 tx.try_send(Ok(prs)).unwrap();
1151 Ok(Response::new(rx))
1152 });
1153
1154 mock.expect_execute_sql().once().returning(|req| {
1155 let metadata = req.metadata();
1156 let timeout = metadata.get("grpc-timeout");
1157 assert!(
1158 timeout.is_some(),
1159 "grpc-timeout header should be present for single DML"
1160 );
1161
1162 Ok(Response::new(mock_v1::ResultSet {
1163 metadata: Some(mock_v1::ResultSetMetadata {
1164 transaction: Some(mock_v1::Transaction {
1165 id: vec![42],
1166 ..Default::default()
1167 }),
1168 ..Default::default()
1169 }),
1170 stats: Some(mock_v1::ResultSetStats {
1171 row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)),
1172 ..Default::default()
1173 }),
1174 ..Default::default()
1175 }))
1176 });
1177
1178 mock.expect_execute_batch_dml().once().returning(|req| {
1179 let metadata = req.metadata();
1180 let timeout = metadata.get("grpc-timeout");
1181 assert!(
1182 timeout.is_some(),
1183 "grpc-timeout header should be present for batch dml"
1184 );
1185
1186 Ok(Response::new(mock_v1::ExecuteBatchDmlResponse {
1187 result_sets: vec![mock_v1::ResultSet {
1188 stats: Some(mock_v1::ResultSetStats {
1189 row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)),
1190 ..Default::default()
1191 }),
1192 ..Default::default()
1193 }],
1194 ..Default::default()
1195 }))
1196 });
1197
1198 mock.expect_commit().returning(|_| {
1199 Ok(Response::new(mock_v1::CommitResponse {
1200 commit_timestamp: Some(prost_types::Timestamp {
1201 seconds: 1234,
1202 nanos: 0,
1203 }),
1204 ..Default::default()
1205 }))
1206 });
1207
1208 let (address, _server) = start("0.0.0.0:0", mock).await?;
1210
1211 let client = Spanner::builder()
1213 .with_endpoint(address)
1214 .with_credentials(Anonymous::new().build())
1215 .build()
1216 .await?;
1217
1218 let db = client
1219 .database_client("projects/p/instances/i/databases/d")
1220 .build()
1221 .await?;
1222 let runner = db.read_write_transaction().build().await?;
1223
1224 runner
1226 .run(async |tx| {
1227 let stmt = Statement::builder("SELECT 1")
1229 .with_attempt_timeout(Duration::from_secs(10))
1230 .build();
1231 let _rs = tx.execute_query(stmt).await?;
1233
1234 let req = ReadRequest::builder("Table", vec!["Col"])
1236 .with_keys(crate::key::KeySet::all())
1237 .with_attempt_timeout(Duration::from_secs(5))
1238 .build();
1239 let _ = tx.execute_read(req).await?;
1240
1241 let dml = Statement::builder("UPDATE t SET c = 1")
1243 .with_attempt_timeout(Duration::from_secs(7))
1244 .build();
1245 let _ = tx.execute_update(dml).await?;
1246
1247 let batch = BatchDml::builder()
1249 .add_statement("UPDATE t SET c = 2")
1250 .with_attempt_timeout(Duration::from_secs(8))
1251 .build();
1252 let _ = tx.execute_batch_update(batch).await?;
1253
1254 Ok(())
1255 })
1256 .await?;
1257
1258 Ok(())
1259 }
1260
1261 #[tokio_test_no_panics]
1262 async fn retry_policy_respected() -> anyhow::Result<()> {
1263 use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt};
1264
1265 let retry_policy = Aip194Strict.continue_on_too_many_requests();
1267
1268 let mut mock = MockSpanner::new();
1270
1271 mock.expect_create_session().returning(|_| {
1272 Ok(Response::new(Session {
1273 name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1274 ..Default::default()
1275 }))
1276 });
1277
1278 mock.expect_begin_transaction().returning(|_| {
1279 Ok(Response::new(mock_v1::Transaction {
1280 id: vec![42],
1281 ..Default::default()
1282 }))
1283 });
1284
1285 let mut seq = mockall::Sequence::new();
1287
1288 mock.expect_execute_sql()
1289 .once()
1290 .in_sequence(&mut seq)
1291 .returning(|_| Err(Status::new(GrpcCode::ResourceExhausted, "quota exceeded")));
1292
1293 mock.expect_execute_sql()
1294 .once()
1295 .in_sequence(&mut seq)
1296 .returning(|_| {
1297 Ok(Response::new(mock_v1::ResultSet {
1298 metadata: Some(mock_v1::ResultSetMetadata {
1299 transaction: Some(mock_v1::Transaction {
1300 id: vec![42],
1301 ..Default::default()
1302 }),
1303 ..Default::default()
1304 }),
1305 stats: Some(mock_v1::ResultSetStats {
1306 row_count: Some(mock_v1::result_set_stats::RowCount::RowCountExact(1)),
1307 ..Default::default()
1308 }),
1309 ..Default::default()
1310 }))
1311 });
1312
1313 mock.expect_commit().returning(|_| {
1314 Ok(Response::new(mock_v1::CommitResponse {
1315 commit_timestamp: Some(prost_types::Timestamp {
1316 seconds: 1234,
1317 nanos: 0,
1318 }),
1319 ..Default::default()
1320 }))
1321 });
1322
1323 let (address, _server) = start("0.0.0.0:0", mock).await?;
1325
1326 let client = Spanner::builder()
1328 .with_endpoint(address)
1329 .with_credentials(Anonymous::new().build())
1330 .build()
1331 .await?;
1332
1333 let db = client
1334 .database_client("projects/p/instances/i/databases/d")
1335 .build()
1336 .await?;
1337 let runner = db.read_write_transaction().build().await?;
1338
1339 let mut mock_backoff = MockBackoffPolicy::new();
1341 mock_backoff
1342 .expect_on_failure()
1343 .once()
1344 .returning(|_| Duration::from_nanos(1));
1345
1346 let stmt = Statement::builder("UPDATE t SET c = 1")
1347 .with_retry_policy(retry_policy)
1348 .with_backoff_policy(mock_backoff)
1349 .build();
1350
1351 let result = runner
1352 .run(async |tx| {
1353 let count = tx.execute_update(stmt.clone()).await?;
1354 Ok(count)
1355 })
1356 .await?;
1357
1358 assert_eq!(result.result, 1);
1360
1361 Ok(())
1362 }
1363
1364 fn parse_timeout(metadata: &MetadataMap) -> u64 {
1365 let timeout = metadata
1366 .get("grpc-timeout")
1367 .expect("grpc-timeout header should be present");
1368 let timeout_str = timeout
1369 .to_str()
1370 .expect("grpc-timeout should be a valid string");
1371 if timeout_str.ends_with('u') {
1372 timeout_str
1373 .trim_end_matches('u')
1374 .parse()
1375 .expect("valid u64")
1376 } else if timeout_str.ends_with('m') {
1377 timeout_str
1378 .trim_end_matches('m')
1379 .parse::<u64>()
1380 .expect("valid u64")
1381 * 1000
1382 } else if timeout_str.ends_with('n') {
1383 timeout_str
1384 .trim_end_matches('n')
1385 .parse::<u64>()
1386 .expect("valid u64")
1387 / 1000
1388 } else {
1389 panic!("Unknown timeout unit in {}", timeout_str);
1390 }
1391 }
1392
1393 #[tokio_test_no_panics]
1394 async fn transaction_timeout_respected() -> anyhow::Result<()> {
1395 use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicyExt};
1396 use spanner_grpc_mock::google::spanner::v1::Transaction;
1397
1398 let mut mock = MockSpanner::new();
1400
1401 mock.expect_create_session().returning(|_| {
1402 Ok(Response::new(Session {
1403 name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1404 ..Default::default()
1405 }))
1406 });
1407
1408 mock.expect_begin_transaction().returning(|_| {
1409 Ok(Response::new(Transaction {
1410 id: vec![1, 2, 3],
1411 ..Default::default()
1412 }))
1413 });
1414
1415 mock.expect_commit().once().returning(|_| {
1416 Ok(Response::new(CommitResponse {
1417 commit_timestamp: Some(prost_types::Timestamp {
1418 seconds: 12345,
1419 nanos: 0,
1420 }),
1421 ..Default::default()
1422 }))
1423 });
1424
1425 let mut seq = mockall::Sequence::new();
1427
1428 mock.expect_execute_sql()
1429 .once()
1430 .in_sequence(&mut seq)
1431 .returning(|req| {
1432 let timeout_val = parse_timeout(req.metadata());
1433 assert!(
1434 timeout_val <= 100000,
1435 "Expected timeout to be <= 100ms, got {}",
1436 timeout_val
1437 );
1438 Err(Status::new(GrpcCode::ResourceExhausted, "quota exceeded"))
1439 });
1440
1441 mock.expect_execute_sql()
1442 .once()
1443 .in_sequence(&mut seq)
1444 .returning(|req| {
1445 let timeout_val = parse_timeout(req.metadata());
1446 assert!(
1447 timeout_val <= 100000,
1448 "Expected timeout to be <= 100ms, got {}",
1449 timeout_val
1450 );
1451
1452 let res = ResultSet {
1453 metadata: Some(spanner_grpc_mock::google::spanner::v1::ResultSetMetadata {
1454 transaction: Some(Transaction {
1455 id: vec![1, 2, 3],
1456 ..Default::default()
1457 }),
1458 ..Default::default()
1459 }),
1460 stats: Some(ResultSetStats {
1461 row_count: Some(RowCount::RowCountExact(1)),
1462 ..Default::default()
1463 }),
1464 ..Default::default()
1465 };
1466 Ok(Response::new(res))
1467 });
1468
1469 let (address, _server) = start("127.0.0.1:0", mock).await?;
1471 let client = Spanner::builder()
1472 .with_endpoint(address)
1473 .with_credentials(Anonymous::new().build())
1474 .build()
1475 .await?;
1476 let db = client
1477 .database_client("projects/p/instances/i/databases/d")
1478 .build()
1479 .await?;
1480
1481 let runner = db
1483 .read_write_transaction()
1484 .with_transaction_timeout(Duration::from_millis(100))
1485 .build()
1486 .await?;
1487
1488 let result = runner
1490 .run(async |tx| {
1491 let mut mock_backoff = MockBackoffPolicy::new();
1492 mock_backoff
1493 .expect_on_failure()
1494 .times(1)
1495 .returning(|_| Duration::from_nanos(1));
1496
1497 let retry_policy = Aip194Strict.continue_on_too_many_requests();
1498
1499 let stmt = Statement::builder("SELECT 1")
1500 .with_retry_policy(retry_policy)
1501 .with_backoff_policy(mock_backoff)
1502 .build();
1503 tx.execute_update(stmt).await?;
1504 Ok(())
1505 })
1506 .await;
1507
1508 result.expect("Transaction should have succeeded");
1509
1510 Ok(())
1511 }
1512
1513 #[tokio_test_no_panics]
1514 async fn transaction_timeout_ticks_down() -> anyhow::Result<()> {
1515 use spanner_grpc_mock::google::spanner::v1::Transaction;
1516
1517 let mut mock = MockSpanner::new();
1518
1519 mock.expect_create_session().returning(|_| {
1520 Ok(Response::new(Session {
1521 name: "projects/p/instances/i/databases/d/sessions/123".to_string(),
1522 ..Default::default()
1523 }))
1524 });
1525
1526 let mut seq = mockall::Sequence::new();
1527
1528 let previous_timeout = Arc::new(AtomicU64::new(0));
1529 let prev_clone1 = previous_timeout.clone();
1530 mock.expect_execute_sql()
1531 .once()
1532 .in_sequence(&mut seq)
1533 .returning(move |req| {
1534 let timeout_val = parse_timeout(req.metadata());
1535 assert!(
1536 timeout_val <= 500000,
1537 "Expected timeout to be <= 500ms, got {}",
1538 timeout_val
1539 );
1540 prev_clone1.store(timeout_val, Ordering::SeqCst);
1541 Err(Status::new(GrpcCode::Aborted, "Aborted"))
1542 });
1543
1544 let prev_clone2 = previous_timeout.clone();
1547 mock.expect_execute_sql()
1548 .once()
1549 .in_sequence(&mut seq)
1550 .returning(move |req| {
1551 let timeout_val = parse_timeout(req.metadata());
1552 let prev = prev_clone2.load(Ordering::SeqCst);
1553 assert!(
1554 timeout_val <= prev,
1555 "Timeout should tick down between attempts or be equal, got {} and {}",
1556 timeout_val,
1557 prev
1558 );
1559 prev_clone2.store(timeout_val, Ordering::SeqCst); let res = ResultSet {
1562 metadata: Some(spanner_grpc_mock::google::spanner::v1::ResultSetMetadata {
1563 transaction: Some(Transaction {
1564 id: vec![2],
1565 ..Default::default()
1566 }),
1567 ..Default::default()
1568 }),
1569 stats: Some(ResultSetStats {
1570 row_count: Some(RowCount::RowCountExact(1)),
1571 ..Default::default()
1572 }),
1573 ..Default::default()
1574 };
1575 Ok(Response::new(res))
1576 });
1577
1578 let prev_clone3 = previous_timeout.clone();
1579 mock.expect_commit().once().returning(move |req| {
1580 let timeout_val = parse_timeout(req.metadata());
1581 let prev = prev_clone3.load(Ordering::SeqCst);
1582 assert!(
1583 timeout_val < prev,
1584 "Timeout should be smaller for commit, got {} and {}",
1585 timeout_val,
1586 prev
1587 );
1588
1589 Ok(Response::new(CommitResponse {
1590 commit_timestamp: Some(prost_types::Timestamp {
1591 seconds: 12345,
1592 nanos: 0,
1593 }),
1594 ..Default::default()
1595 }))
1596 });
1597
1598 let (address, _server) = start("127.0.0.1:0", mock).await?;
1599 let client = Spanner::builder()
1600 .with_endpoint(address)
1601 .with_credentials(Anonymous::new().build())
1602 .build()
1603 .await?;
1604 let db = client
1605 .database_client("projects/p/instances/i/databases/d")
1606 .build()
1607 .await?;
1608
1609 let runner = db
1610 .read_write_transaction()
1611 .with_transaction_timeout(Duration::from_millis(500))
1612 .build()
1613 .await?;
1614
1615 let result = runner
1616 .run(async |tx| {
1617 let stmt = Statement::builder("SELECT 1").build();
1618 tx.execute_update(stmt).await?;
1619 Ok(())
1620 })
1621 .await;
1622
1623 result.expect("Transaction should have succeeded");
1624
1625 Ok(())
1626 }
1627
1628 #[test]
1629 fn test_parse_emulator_endpoint() {
1630 assert_eq!(
1631 super::parse_emulator_endpoint("localhost:9010"),
1632 "http://localhost:9010"
1633 );
1634 assert_eq!(
1635 super::parse_emulator_endpoint("spanner-emulator:9010"),
1636 "http://spanner-emulator:9010"
1637 );
1638 assert_eq!(
1639 super::parse_emulator_endpoint("http://localhost:9010"),
1640 "http://localhost:9010"
1641 );
1642 assert_eq!(
1643 super::parse_emulator_endpoint("https://localhost:9010"),
1644 "https://localhost:9010"
1645 );
1646 assert_eq!(
1647 super::parse_emulator_endpoint("grpc://localhost:9010"),
1648 "grpc://localhost:9010"
1649 );
1650 assert_eq!(
1651 super::parse_emulator_endpoint("http_localhost:9010"),
1652 "http://http_localhost:9010"
1653 );
1654 }
1655}