Skip to main content

qcs_api_client_grpc/tonic/
mod.rs

1//! QCS Middleware for [`tonic`] clients.
2
3use std::convert::Infallible;
4
5use futures_util::pin_mut;
6use http_body::{Body as HttpBody, Frame};
7use http_body_util::StreamBody;
8use qcs_dependencies_client::http::Request;
9use qcs_dependencies_client::prost::bytes::Bytes;
10
11mod channel;
12mod common;
13mod error;
14#[cfg(feature = "grpc-web")]
15mod grpc_web;
16mod refresh;
17mod retry;
18#[cfg(feature = "tracing")]
19mod trace;
20
21pub use channel::*;
22pub use error::*;
23#[cfg(feature = "grpc-web")]
24pub use grpc_web::*;
25use qcs_dependencies_client::tonic::body::Body;
26pub use refresh::*;
27pub use retry::*;
28#[cfg(feature = "tracing")]
29pub use trace::*;
30/// An error observed while duplicating a request body. This may be returned by any
31/// [`qcs_dependencies_client::tower::Service`] that duplicates a request body for the purpose of retrying a request.
32#[derive(Debug, thiserror::Error)]
33pub enum RequestBodyDuplicationError {
34    /// The inner service returned an error from the server, or the client cancelled the
35    /// request.
36    #[error(transparent)]
37    Status(#[from] qcs_dependencies_client::tonic::Status),
38    /// Failed to read the request body for cloning.
39    #[error("failed to read request body for request clone: {0}")]
40    HttpBody(#[from] qcs_dependencies_client::http::Error),
41}
42
43impl From<RequestBodyDuplicationError> for qcs_dependencies_client::tonic::Status {
44    fn from(err: RequestBodyDuplicationError) -> qcs_dependencies_client::tonic::Status {
45        match err {
46            RequestBodyDuplicationError::Status(status) => status,
47            RequestBodyDuplicationError::HttpBody(error) => {
48                qcs_dependencies_client::tonic::Status::cancelled(format!(
49                    "failed to read request body for request clone: {error}"
50                ))
51            }
52        }
53    }
54}
55
56type RequestBodyDuplicationResult<T> = Result<T, RequestBodyDuplicationError>;
57
58/// The AWS ALB can misinterpret requests which contain a `content-length` header,
59/// causing requests which are routed through a gateway to fail with an h2 protocol error.
60/// Using a [`StreamBody`], even for Unary requests, will prevent `tonic` from sending that header.
61fn make_stream_body(bytes: Bytes) -> qcs_dependencies_client::tonic::body::Body {
62    let body = Frame::data(bytes);
63
64    // Create a stream that yields a single frame
65    let stream = futures_util::stream::once(std::future::ready(Ok::<_, Infallible>(body)));
66
67    let stream_body = StreamBody::new(stream);
68
69    qcs_dependencies_client::tonic::body::Body::new(stream_body)
70}
71
72/// This function should only be used with Unary requests; Stream requests are
73/// untested. It eagerly collects all request data into a buffer, consuming the
74/// original stream. Additionally, it assumes that all frames are data frames
75/// (i.e. the stream cannot contain any trailers); if a trailer frame is found,
76/// the cancelled status will be returned.
77async fn build_duplicate_frame_bytes(
78    mut request: Request<qcs_dependencies_client::tonic::body::Body>,
79) -> RequestBodyDuplicationResult<(
80    qcs_dependencies_client::tonic::body::Body,
81    qcs_dependencies_client::tonic::body::Body,
82)> {
83    let mut bytes = Vec::new();
84
85    let body = request.body_mut();
86    pin_mut!(body);
87    while let Some(result) = std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
88        let frame_bytes = result?.into_data().map_err(|frame| {
89            qcs_dependencies_client::tonic::Status::cancelled(format!(
90                "cannot duplicate a frame that is not a data frame: {frame:?}"
91            ))
92        })?;
93        bytes.extend(frame_bytes);
94    }
95
96    let bytes = Bytes::from(bytes);
97
98    Ok((make_stream_body(bytes.clone()), make_stream_body(bytes)))
99}
100
101/// This function should only be used with Unary requests; Stream requests are
102/// untested. See comment on `build_duplicate_frame_bytes`.
103async fn build_duplicate_request(
104    req: Request<Body>,
105) -> RequestBodyDuplicationResult<(Request<Body>, Request<Body>)> {
106    let mut builder_1 = Request::builder()
107        .method(req.method().clone())
108        .uri(req.uri().clone())
109        .version(req.version());
110
111    let mut builder_2 = Request::builder()
112        .method(req.method().clone())
113        .uri(req.uri().clone())
114        .version(req.version());
115
116    for (key, val) in req.headers() {
117        builder_1 = builder_1.header(key, val);
118        builder_2 = builder_2.header(key, val);
119    }
120
121    let (body_1, body_2) = build_duplicate_frame_bytes(req).await?;
122
123    let req_1 = builder_1.body(body_1)?;
124
125    let req_2 = builder_2.body(body_2)?;
126
127    Ok((req_1, req_2))
128}
129
/// This module manages a gRPC server-client connection over a Unix domain socket. Useful for unit testing
/// servers or clients within unit tests - supports parallelization within same process and
/// requires no port management.
///
/// Derived largely from <https://stackoverflow.com/a/71808401> and
/// <https://github.com/hyperium/tonic/tree/master/examples/src/uds>.
#[cfg(test)]
pub(crate) mod uds_grpc_stream {
    use hyper_util::rt::TokioIo;
    use qcs_dependencies_client::opentelemetry::trace::FutureExt;
    use qcs_dependencies_client::tonic::server::NamedService;
    use qcs_dependencies_client::tonic::transport::{Channel, Endpoint, Server};
    use std::convert::Infallible;
    use tempfile::TempDir;
    use tokio::net::UnixStream;
    use tokio_stream::wrappers::UnixListenerStream;

    /// This can be any valid URL. It is necessary to initialize an [`Endpoint`], but the
    /// connector built in `build_client_channel` ignores it and always dials the socket.
    // NOTE(review): this constant is referenced by `build_client_channel`; the allow is kept in
    // case this module is compiled without any caller of `serve`.
    #[allow(dead_code)]
    static FAUX_URL: &str = "http://api.example.rigetti.com";

    /// Bind a Unix domain socket listener at `path` and expose it as a stream of incoming
    /// connections for `Server::serve_with_incoming`.
    fn build_server_stream(path: String) -> Result<UnixListenerStream, Error> {
        let uds = tokio::net::UnixListener::bind(&path).map_err(|_| Error::BindUnixPath(path))?;
        Ok(UnixListenerStream::new(uds))
    }

    /// Build a client [`Channel`] whose every connection goes to the Unix domain socket at `path`.
    async fn build_client_channel(path: String) -> Result<Channel, Error> {
        // The requested URI is ignored: every connection is made to the socket at `path`.
        let connector = qcs_dependencies_client::tower::service_fn(
            move |_: qcs_dependencies_client::tonic::transport::Uri| {
                let path = path.clone();
                async move {
                    let connection = UnixStream::connect(path).await?;
                    Ok::<_, std::io::Error>(TokioIo::new(connection))
                }
            },
        );
        let channel = Endpoint::try_from(FAUX_URL)
            .map_err(|source| Error::Endpoint {
                url: FAUX_URL.to_string(),
                source,
            })?
            .connect_with_connector(connector)
            .await
            .map_err(|source| Error::Connect { source })?;
        Ok(channel)
    }

    /// Errors when connecting a server-client [`Channel`] over a Unix domain socket for testing gRPC
    /// services.
    #[derive(thiserror::Error, Debug)]
    pub enum Error {
        /// Failed to initialize an [`Endpoint`] for the provided URL.
        #[error("failed to initialize endpoint for {url}: {source}")]
        Endpoint {
            /// The URL that failed to initialize.
            url: String,
            /// The source of the error.
            #[source]
            source: qcs_dependencies_client::tonic::transport::Error,
        },
        /// Failed to connect to the provided endpoint.
        #[error("failed to connect to endpoint: {source}")]
        Connect {
            /// The source of the error.
            #[source]
            source: qcs_dependencies_client::tonic::transport::Error,
        },
        /// Failed to bind the provided path as a Unix domain socket.
        #[error("failed to bind path as unix listener: {0}")]
        BindUnixPath(String),
        /// Failed to create the temporary directory that holds the Unix domain socket.
        #[error("failed to initialize tempfile: {0}")]
        TempFile(#[from] std::io::Error),
        /// The temporary socket path is not valid UTF-8 and cannot be converted from an
        /// `OsString` into a `String`.
        #[error("failed to convert tempfile to OsString")]
        TempFileOsString,
        /// The gRPC server failed while serving over the Unix domain socket.
        #[error("failed to bind to UDS: {0}")]
        TonicTransport(#[from] qcs_dependencies_client::tonic::transport::Error),
    }

    /// Serve the provided gRPC service over a Unix domain socket for the duration of the
    /// provided callback.
    ///
    /// A socket is created inside a fresh temporary directory, and `f` receives a [`Channel`]
    /// connected to it. The function returns when either of these finishes first:
    /// - the future returned by `f` completes, giving `Ok(())`;
    /// - the server stops, giving the server's result.
    ///
    /// # Errors
    ///
    /// See [`Error`].
    // `directory` must stay alive until the socket is no longer in use; this lint would
    // presumably suggest dropping it as soon as its path has been read.
    #[allow(clippy::significant_drop_tightening)]
    pub async fn serve<S, F, R, B>(service: S, f: F) -> Result<(), Error>
    where
        S: qcs_dependencies_client::tower::Service<
                qcs_dependencies_client::http::Request<qcs_dependencies_client::tonic::body::Body>,
                Response = qcs_dependencies_client::http::Response<B>,
                Error = Infallible,
            > + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Future: Send + 'static,
        B: http_body::Body<Data = qcs_dependencies_client::tonic::codegen::Bytes> + Send + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
        F: FnOnce(Channel) -> R + Send,
        R: std::future::Future<Output = ()> + Send,
    {
        // `TempDir` deletes its directory (and the socket inside it) on drop, so it is held
        // until this function returns.
        let directory = TempDir::new()?;
        let file = directory
            .path()
            .join("test.sock")
            .into_os_string()
            .into_string()
            .map_err(|_| Error::TempFileOsString)?;
        let stream = build_server_stream(file.clone())?;

        let channel = build_client_channel(file).await?;
        let serve_future = Server::builder()
            .add_service(service)
            .serve_with_incoming(stream);

        tokio::select! {
            result = serve_future => result.map_err(Error::TonicTransport),
            () = f(channel).with_current_context() => Ok(()),
        }
    }
}
254
255#[cfg(test)]
256#[cfg(feature = "tracing-opentelemetry")]
257mod otel_tests {
258    use qcs_api_client_common::configuration::secrets::{SecretAccessToken, SecretRefreshToken};
259    use qcs_api_client_common::configuration::tokens::RefreshToken;
260    use qcs_dependencies_client::opentelemetry::propagation::TextMapPropagator;
261    use qcs_dependencies_client::opentelemetry::trace::{TraceContextExt, TraceId};
262    use qcs_dependencies_client::opentelemetry_http::HeaderExtractor;
263    use qcs_dependencies_client::opentelemetry_sdk::propagation::TraceContextPropagator;
264    use qcs_dependencies_client::tonic::Request;
265    use qcs_dependencies_client::tonic::codegen::http::{HeaderMap, HeaderValue};
266    use qcs_dependencies_client::tonic::server::NamedService;
267    use qcs_dependencies_client::tonic_health::pb::health_check_response::ServingStatus;
268    use qcs_dependencies_client::tonic_health::pb::health_server::{Health, HealthServer};
269    use qcs_dependencies_client::tonic_health::{
270        pb::health_client::HealthClient, server::HealthService,
271    };
272    use serde::{Deserialize, Serialize};
273    use std::time::{Duration, SystemTime};
274
275    use super::{uds_grpc_stream, wrap_channel_with_tracing};
276    use qcs_api_client_common::configuration::ClientConfiguration;
277    use qcs_api_client_common::configuration::{settings::AuthServer, tokens::OAuthSession};
278
    /// Fully-qualified gRPC path of the health `Check` RPC. It is used both as a tracing-filter
    /// path and as the expected name of the client span for that RPC.
    static HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// Base URL passed to `wrap_channel_with_tracing`. The test channel always connects over a
    /// Unix domain socket regardless of this value.
    const FAUX_BASE_URL: &str = "http://api.example.rigetti.com";
282
283    /// Test that when tracing is enabled and no filter is set, any request is properly traced.
284    #[tokio::test]
285    async fn test_tracing_enabled_no_filter() {
286        use qcs_api_client_common::tracing_configuration::TracingConfiguration;
287
288        let tracing_configuration = TracingConfiguration::builder()
289            .set_propagate_otel_context(true)
290            .build();
291        let client_config = ClientConfiguration::builder()
292            .tracing_configuration(Some(tracing_configuration))
293            .oauth_session(Some(OAuthSession::from_refresh_token(
294                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
295                AuthServer::default(),
296                Some(create_jwt()),
297            )))
298            .build()
299            .expect("should be able to build client config");
300        assert_grpc_health_check_traced(client_config).await;
301    }
302
303    /// Test that when tracing is enabled, no filter is set, and OTel context propagation is
304    /// disabled, any request is properly traced without propagation.
305    #[tokio::test]
306    async fn test_tracing_enabled_no_filter_nor_otel_context_propagation() {
307        use qcs_api_client_common::tracing_configuration::TracingConfiguration;
308
309        let tracing_configuration = TracingConfiguration::default();
310        let client_config = ClientConfiguration::builder()
311            .tracing_configuration(Some(tracing_configuration))
312            .oauth_session(Some(OAuthSession::from_refresh_token(
313                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
314                AuthServer::default(),
315                Some(create_jwt()),
316            )))
317            .build()
318            .expect("failed to build client config");
319        assert_grpc_health_check_traced(client_config).await;
320    }
321
322    /// Test that when tracing is enabled and the filter matches the gRPC request, the request is
323    /// properly traced.
324    #[tokio::test]
325    async fn test_tracing_enabled_filter_passed() {
326        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};
327
328        let tracing_filter = TracingFilter::builder()
329            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
330            .expect("gRPC healthcheck path should be valid filter path")
331            .build();
332
333        let tracing_configuration = TracingConfiguration::builder()
334            .set_filter(Some(tracing_filter))
335            .set_propagate_otel_context(true)
336            .build();
337
338        let client_config = ClientConfiguration::builder()
339            .tracing_configuration(Some(tracing_configuration))
340            .oauth_session(Some(OAuthSession::from_refresh_token(
341                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
342                AuthServer::default(),
343                Some(create_jwt()),
344            )))
345            .build()
346            .expect("failed to build client config");
347        assert_grpc_health_check_traced(client_config).await;
348    }
349
    /// Sends a health check through a channel wrapped by `wrap_channel_with_tracing` and asserts
    /// that the gRPC request was traced.
    ///
    /// The server-side interceptor behaves according to the configuration:
    /// - With OTel context propagation enabled, it checks that the test span's trace id arrives
    ///   via the `traceparent` metadata header.
    /// - Otherwise, it checks that no OTel context headers arrive.
    ///
    /// Afterwards, the gRPC request span is checked for:
    /// - a reasonable duration (at least the server's 50ms sleep);
    /// - an `Ok` `rpc.grpc.status_code` attribute.
    ///
    /// # Panics
    ///
    /// Panics if `client_configuration` has no tracing configuration, or if any assertion fails.
    // Only awaited directly from `#[tokio::test]` functions, so the future need not be `Send`.
    #[allow(clippy::future_not_send)]
    async fn assert_grpc_health_check_traced(client_configuration: ClientConfiguration) {
        use qcs_dependencies_client::opentelemetry::trace::FutureExt;

        let propagate_otel_context = client_configuration.tracing_configuration().is_some_and(
            qcs_api_client_common::tracing_configuration::TracingConfiguration::propagate_otel_context,
        );
        let spans: Vec<qcs_dependencies_client::opentelemetry_sdk::trace::SpanData> =
            tracing_test::start(
                "test_trace_id_propagation",
                |trace_id, _span_id| async move {
                    // The sleep gives the request span a lower bound on its duration.
                    let sleepy_health_service = SleepyHealthService {
                        sleep_time: Duration::from_millis(50),
                    };

                    // Server-side assertion on the incoming metadata; a failure is returned to
                    // the client as an error status, which makes the `unwrap` below panic.
                    let interceptor = move |req| {
                        if propagate_otel_context {
                            validate_trace_id_propagated(trace_id, req)
                        } else {
                            validate_otel_context_not_propagated(req)
                        }
                    };
                    let health_server =
                        HealthServer::with_interceptor(sleepy_health_service, interceptor);

                    uds_grpc_stream::serve(health_server, |channel| {
                        async {
                            let response = HealthClient::new(wrap_channel_with_tracing(
                                channel,
                                FAUX_BASE_URL.to_string(),
                                client_configuration
                                    .tracing_configuration()
                                    .unwrap()
                                    .clone(),
                            ))
                            .check(Request::new(
                                qcs_dependencies_client::tonic_health::pb::HealthCheckRequest {
                                    service: <HealthServer<HealthService> as NamedService>::NAME
                                        .to_string(),
                                },
                            ))
                            .await
                            .unwrap();
                            assert_eq!(response.into_inner().status(), ServingStatus::Serving);
                        }
                        .with_current_context()
                    })
                    .await
                    .unwrap();
                },
            )
            .await
            .unwrap();

        let grpc_span = spans
            .iter()
            .find(|span| span.name == *HEALTH_CHECK_PATH)
            .expect("failed to find gRPC span");
        // The span must cover at least the server-side sleep configured above.
        let duration = grpc_span
            .end_time
            .duration_since(grpc_span.start_time)
            .expect("span should have valid timestamps");
        assert!(duration.as_millis() >= 50u128);

        let status_code_attribute =
            tracing_test::get_span_attribute(grpc_span, "rpc.grpc.status_code")
                .expect("gRPC span should have status code attribute");
        assert_eq!(
            status_code_attribute,
            (qcs_dependencies_client::tonic::Code::Ok as u8).to_string()
        );
    }
424
425    /// Test that when tracing is enabled but the request does not match the configured filter, the
426    /// request is not traced.
427    #[tokio::test]
428    async fn test_tracing_enabled_filter_not_passed() {
429        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};
430
431        let tracing_filter = TracingFilter::builder()
432            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
433            .expect("healthcheck path should be a valid filter path")
434            .set_is_negated(true)
435            .build();
436
437        let tracing_configuration = TracingConfiguration::builder()
438            .set_filter(Some(tracing_filter))
439            .set_propagate_otel_context(true)
440            .build();
441
442        let client_config = ClientConfiguration::builder()
443            .tracing_configuration(Some(tracing_configuration))
444            .oauth_session(Some(OAuthSession::from_refresh_token(
445                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
446                AuthServer::default(),
447                Some(create_jwt()),
448            )))
449            .build()
450            .expect("should be able to build client config");
451
452        assert_grpc_health_check_not_traced(client_config.clone()).await;
453    }
454
    /// Sends a health check through a channel wrapped by `wrap_channel_with_tracing` and asserts
    /// two things:
    /// - the server sees neither `traceparent` nor `tracestate` metadata;
    /// - no span named after the health check path is recorded for the test's trace.
    ///
    /// # Panics
    ///
    /// Panics if `client_configuration` has no tracing configuration, or if any assertion fails.
    // Only awaited directly from `#[tokio::test]` functions, so the future need not be `Send`.
    #[allow(clippy::future_not_send)]
    async fn assert_grpc_health_check_not_traced(client_configuration: ClientConfiguration) {
        use qcs_dependencies_client::opentelemetry::trace::FutureExt;

        let spans: Vec<qcs_dependencies_client::opentelemetry_sdk::trace::SpanData> =
            tracing_test::start("test_tracing_disabled", |_trace_id, _span_id| async move {
                // Rejects the request (making the client `unwrap` panic) if OTel headers arrive.
                let interceptor = validate_otel_context_not_propagated;
                let health_server = HealthServer::with_interceptor(
                    SleepyHealthService {
                        sleep_time: Duration::from_millis(0),
                    },
                    interceptor,
                );

                uds_grpc_stream::serve(health_server, |channel| {
                    async {
                        let response = HealthClient::new(wrap_channel_with_tracing(
                            channel,
                            FAUX_BASE_URL.to_string(),
                            client_configuration
                                .tracing_configuration()
                                .unwrap()
                                .clone(),
                        ))
                        .check(Request::new(
                            qcs_dependencies_client::tonic_health::pb::HealthCheckRequest {
                                service: <HealthServer<HealthService> as NamedService>::NAME
                                    .to_string(),
                            },
                        ))
                        .await
                        .unwrap();
                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
                    }
                    .with_current_context()
                })
                .await
                .unwrap();
            })
            .await
            .unwrap();

        assert!(spans.iter().all(|span| { span.name != *HEALTH_CHECK_PATH }));
    }
501
    /// Symmetric key used to sign test JWTs; the client does not check the signature.
    const JWT_SECRET: &[u8] = b"top-secret";

    /// Claims set encoded into the test JWT by `create_jwt`.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct Claims {
        /// Subject of the token.
        sub: String,
        /// Expiration time, in whole seconds since the Unix epoch.
        exp: u64,
    }
509
510    /// Create an HS256 signed JWT token with sub and exp claims. This is good enough to pass the
511    /// [`RefreshService`] token validation.
512    pub(crate) fn create_jwt() -> SecretAccessToken {
513        use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
514        let expiration = SystemTime::now()
515            .checked_add(Duration::from_secs(60))
516            .unwrap()
517            .duration_since(SystemTime::UNIX_EPOCH)
518            .unwrap()
519            .as_secs();
520
521        let claims = Claims {
522            sub: "test-uid".to_string(),
523            exp: expiration,
524        };
525        // The client doesn't check the signature, so for convenience here, we just sign with HS256
526        // instead of RS256.
527        let header = Header::new(Algorithm::HS256);
528        encode(&header, &claims, &EncodingKey::from_secret(JWT_SECRET))
529            .map(SecretAccessToken::from)
530            .unwrap()
531    }
532
    /// Failures detected by the server-side interceptors in these tests. They are reported to
    /// the client as `invalid_argument` statuses.
    #[derive(Debug, thiserror::Error)]
    // NOTE(review): presumably silences `variant_size_differences` because one variant carries a
    // `String` while the other is empty.
    #[allow(variant_size_differences)]
    enum ServerAssertionError {
        /// The `traceparent` header was missing, unreadable, or carried an unexpected trace id.
        #[error("trace id did not propagate to server: {0}")]
        UnexpectedTraceId(String),
        /// `traceparent` or `tracestate` arrived although no OTel context should be propagated.
        #[error("otel context headers unexpectedly sent to server")]
        UnexpectedOTelContextHeaders,
    }
541
542    impl From<ServerAssertionError> for qcs_dependencies_client::tonic::Status {
543        fn from(server_assertion_error: ServerAssertionError) -> Self {
544            Self::invalid_argument(server_assertion_error.to_string())
545        }
546    }
547
548    /// Given an incoming gRPC request, validate that the the specified [`TraceId`] is propagated
549    /// via the `traceparent` metadata header.
550    #[allow(clippy::result_large_err)]
551    fn validate_trace_id_propagated(
552        trace_id: TraceId,
553        req: Request<()>,
554    ) -> Result<Request<()>, qcs_dependencies_client::tonic::Status> {
555        req.metadata()
556            .get("traceparent")
557            .ok_or_else(|| {
558                ServerAssertionError::UnexpectedTraceId(
559                    "request traceparent metadata not present".to_string(),
560                )
561            })
562            .and_then(|traceparent| {
563                let mut headers = HeaderMap::new();
564                headers.insert(
565                    "traceparent",
566                    HeaderValue::from_str(traceparent.to_str().map_err(|_| {
567                        ServerAssertionError::UnexpectedTraceId(
568                            "failed to deserialize trace parent".to_string(),
569                        )
570                    })?)
571                    .map_err(|_| {
572                        ServerAssertionError::UnexpectedTraceId(
573                            "failed to serialize trace parent as HeaderValue".to_string(),
574                        )
575                    })?,
576                );
577                Ok(headers)
578            })
579            .and_then(|headers| {
580                let extractor = HeaderExtractor(&headers);
581                let propagator = TraceContextPropagator::new();
582                let context = propagator.extract(&extractor);
583                let propagated_trace_id = context.span().span_context().trace_id();
584                if propagated_trace_id == trace_id {
585                    Ok(req)
586                } else {
587                    Err(ServerAssertionError::UnexpectedTraceId(format!(
588                        "expected trace id {trace_id}, got {propagated_trace_id}",
589                    )))
590                }
591            })
592            .map_err(Into::into)
593    }
594
595    /// Simply validate that the `traceparent` and `tracestate` metadata headers are not present
596    /// on the incoming gRPC.
597    #[allow(clippy::result_large_err)]
598    fn validate_otel_context_not_propagated(
599        req: Request<()>,
600    ) -> Result<Request<()>, qcs_dependencies_client::tonic::Status> {
601        if req.metadata().get("traceparent").is_some() || req.metadata().get("tracestate").is_some()
602        {
603            Err(ServerAssertionError::UnexpectedOTelContextHeaders.into())
604        } else {
605            Ok(req)
606        }
607    }
608
    /// An implementation of the gRPC [`HealthService`] that sleeps for the configured duration
    /// before returning a response from `check` (`watch` does not sleep).
    /// This is useful for making assertions on span durations. It is also necessary in order to
    /// wrap the [`HealthServer`] in an interceptor, which is not possible with public methods in
    /// the health service crate.
    ///
    /// Derived in part from <https://github.com/hyperium/tonic/blob/master/tonic-health/src/generated/grpc.health.v1.rs/>
    #[derive(Clone)]
    struct SleepyHealthService {
        /// How long `check` sleeps before responding.
        sleep_time: Duration,
    }
619
620    #[qcs_dependencies_client::tonic::async_trait]
621    impl Health for SleepyHealthService {
622        async fn check(
623            &self,
624            _request: Request<qcs_dependencies_client::tonic_health::pb::HealthCheckRequest>,
625        ) -> Result<
626            qcs_dependencies_client::tonic::Response<
627                qcs_dependencies_client::tonic_health::pb::HealthCheckResponse,
628            >,
629            qcs_dependencies_client::tonic::Status,
630        > {
631            tokio::time::sleep(self.sleep_time).await;
632            let response = qcs_dependencies_client::tonic_health::pb::HealthCheckResponse {
633                status: ServingStatus::Serving as i32,
634            };
635            Ok(qcs_dependencies_client::tonic::Response::new(response))
636        }
637
638        type WatchStream = tokio_stream::wrappers::ReceiverStream<
639            Result<
640                qcs_dependencies_client::tonic_health::pb::HealthCheckResponse,
641                qcs_dependencies_client::tonic::Status,
642            >,
643        >;
644
645        async fn watch(
646            &self,
647            _request: Request<qcs_dependencies_client::tonic_health::pb::HealthCheckRequest>,
648        ) -> Result<
649            qcs_dependencies_client::tonic::Response<Self::WatchStream>,
650            qcs_dependencies_client::tonic::Status,
651        > {
652            let (tx, rx) = tokio::sync::mpsc::channel(1);
653            let response = qcs_dependencies_client::tonic_health::pb::HealthCheckResponse {
654                status: ServingStatus::Serving as i32,
655            };
656            tx.send(Ok(response)).await.unwrap();
657            Ok(qcs_dependencies_client::tonic::Response::new(
658                tokio_stream::wrappers::ReceiverStream::new(rx),
659            ))
660        }
661    }
662
663    /// We need a single global ``SpanProcessor`` because these tests have to work using
664    /// ``qcs_dependencies_client::opentelemetry::global`` and ``tracing_subscriber::set_global_default``. Otherwise,
665    /// we can't make guarantees about where the spans are processed and therefore could
666    /// not make assertions about the traced spans.
667    mod tracing_test {
668        use futures_util::Future;
669        use qcs_dependencies_client::opentelemetry::global::BoxedTracer;
670        use qcs_dependencies_client::opentelemetry::trace::{
671            FutureExt, Span, SpanId, TraceId, Tracer, TracerProvider, mark_span_as_active,
672        };
673        use qcs_dependencies_client::opentelemetry_sdk::error::OTelSdkError;
674        use qcs_dependencies_client::opentelemetry_sdk::trace::{SpanData, SpanProcessor};
675        use std::collections::HashMap;
676        use std::sync::{Arc, RwLock};
677        use tokio::sync::oneshot;
678
        /// Start a new test span and run the specified callback with the span as the active span.
        /// The callback is provided the trace and span ids. At the end of the callback, we wait
        /// for the span to be processed and then return all of the spans that were processed for
        /// this particular test (i.e. those recorded under its trace id).
        ///
        /// # Errors
        ///
        /// Returns the lock error as a `String` if the span cache's `RwLock` is poisoned.
        ///
        /// # Panics
        ///
        /// - Panics if [`CACHE`] is not initialized, which `CacheProcessor::tracer` ensures.
        /// - Panics if the notification sender is dropped before signalling.
        // Awaited directly by test helpers, so the future need not be `Send`.
        #[allow(clippy::future_not_send)]
        pub(crate) async fn start<F, R>(name: &'static str, f: F) -> Result<Vec<SpanData>, String>
        where
            F: FnOnce(TraceId, SpanId) -> R + Send,
            R: Future<Output = ()> + Send,
        {
            // Also performs the one-time global tracer/subscriber initialization.
            let tracer = CacheProcessor::tracer();
            let span = tracer.start(name);
            let span_id = span.span_context().span_id();
            let trace_id = span.span_context().trace_id();
            let cache = CACHE
                .get()
                .expect("cache should be initialized with cache tracer");
            // Subscribe before running the callback so the notification cannot be missed.
            // NOTE(review): `subscribe` is defined below this excerpt; presumably the receiver
            // resolves when the span with `span_id` has been processed.
            let span_rx = cache.subscribe(span_id);
            async {
                // The span is owned by the active-context guard. It is presumably ended when the
                // guard is dropped at the end of this block — TODO confirm against the
                // `opentelemetry` docs for `mark_span_as_active`.
                let _guard = mark_span_as_active(span);
                f(trace_id, span_id).with_current_context().await;
            }
            .await;

            // wait for test span to be processed.
            span_rx.await.unwrap();

            // remove and return the spans processed for this test.
            let mut data = cache.data.write().map_err(|e| e.to_string())?;
            Ok(data.remove(&trace_id).unwrap_or_default())
        }
710
711        static CACHE: once_cell::sync::OnceCell<CacheProcessor> = once_cell::sync::OnceCell::new();
712
713        #[derive(Debug, Clone, Default)]
714        struct CacheProcessor {
715            data: Arc<RwLock<HashMap<TraceId, Vec<SpanData>>>>,
716            notifications: Arc<RwLock<HashMap<SpanId, tokio::sync::oneshot::Sender<()>>>>,
717        }
718
719        impl CacheProcessor {
720            /// Initializes the [`CACHE`] and sets the global `qcs_dependencies_client::opentelemetry::global::tracer_provider` and `tracing_subscriber::global_default`.
721            /// These initializations occur safely behind a `OnceCell` initialization, so they can be used by several tests.
722            fn tracer() -> BoxedTracer {
723                use tracing_subscriber::layer::SubscriberExt;
724
725                CACHE.get_or_init(|| {
726                    let processor = Self::default();
727                    let provider = qcs_dependencies_client::opentelemetry_sdk::trace::SdkTracerProvider::builder()
728                        .with_span_processor(processor.clone())
729                        .build();
730                    qcs_dependencies_client::opentelemetry::global::set_tracer_provider(
731                        provider.clone(),
732                    );
733                    let tracer = provider.tracer("test_channel");
734                    let telemetry =
735                        qcs_dependencies_client::tracing_opentelemetry::layer().with_tracer(tracer);
736                    let subscriber =
737                        tracing_subscriber::Registry::default()
738                            .with(telemetry);
739                    tracing::subscriber::set_global_default(subscriber)
740                        .expect("tracing subscriber already set");
741                    processor
742                });
743                qcs_dependencies_client::opentelemetry::global::tracer("test_channel")
744            }
745
746            /// Ensure that a [`Notification`] exists for the provided span id.
747            fn subscribe(&self, span_id: SpanId) -> oneshot::Receiver<()> {
748                let (tx, rx) = oneshot::channel();
749                self.notifications.write().unwrap().insert(span_id, tx);
750                rx
751            }
752        }
753
754        impl SpanProcessor for CacheProcessor {
755            fn on_start(
756                &self,
757                _span: &mut qcs_dependencies_client::opentelemetry_sdk::trace::Span,
758                _cx: &qcs_dependencies_client::opentelemetry::Context,
759            ) {
760            }
761
762            fn on_end(&self, span: SpanData) {
763                let trace_id = span.span_context.trace_id();
764                let span_id = span.span_context.span_id();
765                {
766                    let mut data = self
767                        .data
768                        .write()
769                        .expect("failed to write access cache span data");
770                    data.entry(trace_id).or_default().push(span);
771                }
772
773                if let Some(notify) = self
774                    .notifications
775                    .write()
776                    .expect("failed to read access span notifications during span processing")
777                    .remove(&span_id)
778                {
779                    notify.send(()).unwrap();
780                }
781            }
782
783            /// This is a no-op because spans are processed synchronously in `on_end`.
784            fn force_flush(&self) -> Result<(), OTelSdkError> {
785                Ok(())
786            }
787
788            fn shutdown(&self) -> Result<(), OTelSdkError> {
789                Ok(())
790            }
791
792            fn shutdown_with_timeout(
793                &self,
794                _timeout: std::time::Duration,
795            ) -> Result<(), OTelSdkError> {
796                Ok(())
797            }
798        }
799
800        /// Get the Opentelemetry attribute value for the provided key.
801        #[must_use]
802        pub(super) fn get_span_attribute(
803            span_data: &SpanData,
804            key: &'static str,
805        ) -> Option<String> {
806            span_data
807                .attributes
808                .iter()
809                .find(|attr| attr.key.to_string() == *key)
810                .map(|kv| kv.value.to_string())
811        }
812    }
813}