Skip to main content

qcs_api_client_grpc/tonic/
mod.rs

1//! QCS Middleware for [`tonic`] clients.
2
3use std::convert::Infallible;
4
5use futures_util::pin_mut;
6use http::Request;
7use http_body::{Body as HttpBody, Frame};
8use http_body_util::StreamBody;
9use prost::bytes::Bytes;
10
11mod channel;
12mod common;
13mod error;
14#[cfg(feature = "grpc-web")]
15mod grpc_web;
16mod refresh;
17mod retry;
18#[cfg(feature = "tracing")]
19mod trace;
20
21pub use channel::*;
22pub use error::*;
23#[cfg(feature = "grpc-web")]
24pub use grpc_web::*;
25pub use refresh::*;
26pub use retry::*;
27use tonic::body::Body;
28#[cfg(feature = "tracing")]
29pub use trace::*;
30/// An error observed while duplicating a request body. This may be returned by any
31/// [`tower::Service`] that duplicates a request body for the purpose of retrying a request.
32#[derive(Debug, thiserror::Error)]
33pub enum RequestBodyDuplicationError {
34    /// The inner service returned an error from the server, or the client cancelled the
35    /// request.
36    #[error(transparent)]
37    Status(#[from] tonic::Status),
38    /// Failed to read the request body for cloning.
39    #[error("failed to read request body for request clone: {0}")]
40    HttpBody(#[from] http::Error),
41}
42
43impl From<RequestBodyDuplicationError> for tonic::Status {
44    fn from(err: RequestBodyDuplicationError) -> tonic::Status {
45        match err {
46            RequestBodyDuplicationError::Status(status) => status,
47            RequestBodyDuplicationError::HttpBody(error) => tonic::Status::cancelled(format!(
48                "failed to read request body for request clone: {error}"
49            )),
50        }
51    }
52}
53
54type RequestBodyDuplicationResult<T> = Result<T, RequestBodyDuplicationError>;
55
56/// The AWS ALB can misinterpret requests which contain a `content-length` header,
57/// causing requests which are routed through a gateway to fail with an h2 protocol error.
58/// Using a [`StreamBody`], even for Unary requests, will prevent `tonic` from sending that header.
59fn make_stream_body(bytes: Bytes) -> tonic::body::Body {
60    let body = Frame::data(bytes);
61
62    // Create a stream that yields a single frame
63    let stream = futures_util::stream::once(std::future::ready(Ok::<_, Infallible>(body)));
64
65    let stream_body = StreamBody::new(stream);
66
67    tonic::body::Body::new(stream_body)
68}
69
70/// This function should only be used with Unary requests; Stream requests are
71/// untested. It eagerly collects all request data into a buffer, consuming the
72/// original stream. Additionally, it assumes that all frames are data frames
73/// (i.e. the stream cannot contain any trailers); if a trailer frame is found,
74/// the cancelled status will be returned.
75async fn build_duplicate_frame_bytes(
76    mut request: Request<tonic::body::Body>,
77) -> RequestBodyDuplicationResult<(tonic::body::Body, tonic::body::Body)> {
78    let mut bytes = Vec::new();
79
80    let body = request.body_mut();
81    pin_mut!(body);
82    while let Some(result) = std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
83        let frame_bytes = result?.into_data().map_err(|frame| {
84            tonic::Status::cancelled(format!(
85                "cannot duplicate a frame that is not a data frame: {frame:?}"
86            ))
87        })?;
88        bytes.extend(frame_bytes);
89    }
90
91    let bytes = Bytes::from(bytes);
92
93    Ok((make_stream_body(bytes.clone()), make_stream_body(bytes)))
94}
95
96/// This function should only be used with Unary requests; Stream requests are
97/// untested. See comment on `build_duplicate_frame_bytes`.
98async fn build_duplicate_request(
99    req: Request<Body>,
100) -> RequestBodyDuplicationResult<(Request<Body>, Request<Body>)> {
101    let mut builder_1 = Request::builder()
102        .method(req.method().clone())
103        .uri(req.uri().clone())
104        .version(req.version());
105
106    let mut builder_2 = Request::builder()
107        .method(req.method().clone())
108        .uri(req.uri().clone())
109        .version(req.version());
110
111    for (key, val) in req.headers() {
112        builder_1 = builder_1.header(key, val);
113        builder_2 = builder_2.header(key, val);
114    }
115
116    let (body_1, body_2) = build_duplicate_frame_bytes(req).await?;
117
118    let req_1 = builder_1.body(body_1)?;
119
120    let req_2 = builder_2.body(body_2)?;
121
122    Ok((req_1, req_2))
123}
124
/// This module manages a gRPC server-client connection over a Unix domain socket. It is useful
/// for testing servers or clients within unit tests: it supports parallelization within the
/// same process and requires no port management.
///
/// Derived largely from <https://stackoverflow.com/a/71808401> and
/// <https://github.com/hyperium/tonic/tree/master/examples/src/uds>.
#[cfg(test)]
pub(crate) mod uds_grpc_stream {
    use hyper_util::rt::TokioIo;
    use opentelemetry::trace::FutureExt;
    use std::convert::Infallible;
    use tempfile::TempDir;
    use tokio::net::UnixStream;
    use tokio_stream::wrappers::UnixListenerStream;
    use tonic::server::NamedService;
    use tonic::transport::{Channel, Endpoint, Server};

    /// This can be any valid URL. An [`Endpoint`] requires one, but the custom connector in
    /// `build_client_channel` ignores it and always dials the Unix domain socket instead.
    #[allow(dead_code)]
    static FAUX_URL: &str = "http://api.example.rigetti.com";

    /// Bind a Unix listener at `path` and expose its incoming connections as a stream.
    fn build_server_stream(path: String) -> Result<UnixListenerStream, Error> {
        let uds =
            tokio::net::UnixListener::bind(path.clone()).map_err(|_| Error::BindUnixPath(path))?;
        Ok(UnixListenerStream::new(uds))
    }

    /// Connect a client [`Channel`] whose every connection is a [`UnixStream`] to `path`.
    async fn build_client_channel(path: String) -> Result<Channel, Error> {
        let connector = tower::service_fn(move |_: tonic::transport::Uri| {
            let path = path.clone();
            async move {
                let connection = UnixStream::connect(path).await?;
                Ok::<_, std::io::Error>(TokioIo::new(connection))
            }
        });
        let channel = Endpoint::try_from(FAUX_URL)
            .map_err(|source| Error::Endpoint {
                url: FAUX_URL.to_string(),
                source,
            })?
            .connect_with_connector(connector)
            .await
            .map_err(|source| Error::Connect { source })?;
        Ok(channel)
    }

    /// Errors when connecting a server-client [`Channel`] over a Unix domain socket for testing gRPC
    /// services.
    #[derive(thiserror::Error, Debug)]
    pub enum Error {
        /// Failed to initialize an [`Endpoint`] for the provided URL.
        #[error("failed to initialize endpoint for {url}: {source}")]
        Endpoint {
            /// The URL that failed to initialize.
            url: String,
            /// The source of the error.
            #[source]
            source: tonic::transport::Error,
        },
        /// Failed to connect to the provided endpoint.
        #[error("failed to connect to endpoint: {source}")]
        Connect {
            /// The source of the error.
            #[source]
            source: tonic::transport::Error,
        },
        /// Failed to bind the provided path as a Unix domain socket.
        #[error("failed to bind path as unix listener: {0}")]
        BindUnixPath(String),
        /// Failed to initialize a temporary directory for the Unix domain socket.
        #[error("failed to initialize tempfile: {0}")]
        TempFile(#[from] std::io::Error),
        /// Failed to convert the socket path to a UTF-8 string (it is not valid
        /// [`std::ffi::OsString`] → [`String`] data).
        #[error("failed to convert tempfile to OsString")]
        TempFileOsString,
        /// The gRPC server returned an error while serving over the Unix domain socket.
        #[error("failed to bind to UDS: {0}")]
        TonicTransport(#[from] tonic::transport::Error),
    }

    /// Serve the provided gRPC service over a Unix domain socket for the duration of the
    /// provided callback.
    ///
    /// The callback receives a [`Channel`] connected to the service. This function returns
    /// `Ok(())` as soon as the callback completes. If the server stops first, its result is
    /// returned instead.
    ///
    /// # Errors
    ///
    /// See [`Error`].
    #[allow(clippy::significant_drop_tightening)]
    pub async fn serve<S, F, R, B>(service: S, f: F) -> Result<(), Error>
    where
        S: tower::Service<
                http::Request<tonic::body::Body>,
                Response = http::Response<B>,
                Error = Infallible,
            > + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Future: Send + 'static,
        B: http_body::Body<Data = tonic::codegen::Bytes> + Send + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
        F: FnOnce(Channel) -> R + Send,
        R: std::future::Future<Output = ()> + Send,
    {
        // `directory` is kept alive until this function returns, so the socket path inside it
        // stays valid for both the server and the client.
        let directory = TempDir::new()?;
        let file = directory
            .path()
            .join("test.sock")
            .into_os_string()
            .into_string()
            .map_err(|_| Error::TempFileOsString)?;
        let stream = build_server_stream(file.clone())?;

        let channel = build_client_channel(file).await?;
        let serve_future = Server::builder()
            .add_service(service)
            .serve_with_incoming(stream);

        tokio::select! {
           result = serve_future => result.map_err(Error::TonicTransport),
           () = f(channel).with_current_context() => Ok(()),
        }
    }
}
247
/// Tests checking that gRPC requests made through `wrap_channel_with_tracing` are traced, and
/// have their OpenTelemetry context propagated (or not), according to the client's tracing
/// configuration.
#[cfg(test)]
#[cfg(feature = "tracing-opentelemetry")]
mod otel_tests {
    use opentelemetry::propagation::TextMapPropagator;
    use opentelemetry::trace::{TraceContextExt, TraceId};
    use opentelemetry_http::HeaderExtractor;
    use opentelemetry_sdk::propagation::TraceContextPropagator;
    use qcs_api_client_common::configuration::secrets::{SecretAccessToken, SecretRefreshToken};
    use qcs_api_client_common::configuration::tokens::RefreshToken;
    use serde::{Deserialize, Serialize};
    use std::time::{Duration, SystemTime};
    use tonic::codegen::http::{HeaderMap, HeaderValue};
    use tonic::server::NamedService;
    use tonic::Request;
    use tonic_health::pb::health_check_response::ServingStatus;
    use tonic_health::pb::health_server::{Health, HealthServer};
    use tonic_health::{pb::health_client::HealthClient, server::HealthService};

    use crate::tonic::{uds_grpc_stream, wrap_channel_with_tracing};
    use qcs_api_client_common::configuration::ClientConfiguration;
    use qcs_api_client_common::configuration::{settings::AuthServer, tokens::OAuthSession};

    /// gRPC path of the health `Check` RPC. The assertions below look up the client span by this
    /// name, and the filter tests use it as the filter path.
    static HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// Base URL handed to `wrap_channel_with_tracing`; the channel itself talks over a Unix
    /// domain socket.
    const FAUX_BASE_URL: &str = "http://api.example.rigetti.com";

    /// Test that when tracing is enabled and no filter is set, any request is properly traced.
    #[tokio::test]
    async fn test_tracing_enabled_no_filter() {
        use qcs_api_client_common::tracing_configuration::TracingConfiguration;

        let tracing_configuration = TracingConfiguration::builder()
            .set_propagate_otel_context(true)
            .build();
        let client_config = ClientConfiguration::builder()
            .tracing_configuration(Some(tracing_configuration))
            .oauth_session(Some(OAuthSession::from_refresh_token(
                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
                AuthServer::default(),
                Some(create_jwt()),
            )))
            .build()
            .expect("should be able to build client config");
        assert_grpc_health_check_traced(client_config).await;
    }

    /// Test that when tracing is enabled, no filter is set, and OTel context propagation is
    /// disabled, any request is properly traced without propagation.
    #[tokio::test]
    async fn test_tracing_enabled_no_filter_nor_otel_context_propagation() {
        use qcs_api_client_common::tracing_configuration::TracingConfiguration;

        let tracing_configuration = TracingConfiguration::default();
        let client_config = ClientConfiguration::builder()
            .tracing_configuration(Some(tracing_configuration))
            .oauth_session(Some(OAuthSession::from_refresh_token(
                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
                AuthServer::default(),
                Some(create_jwt()),
            )))
            .build()
            .expect("failed to build client config");
        assert_grpc_health_check_traced(client_config).await;
    }

    /// Test that when tracing is enabled and the filter matches the gRPC request, the request is
    /// properly traced.
    #[tokio::test]
    async fn test_tracing_enabled_filter_passed() {
        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};

        let tracing_filter = TracingFilter::builder()
            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
            .expect("gRPC healthcheck path should be valid filter path")
            .build();

        let tracing_configuration = TracingConfiguration::builder()
            .set_filter(Some(tracing_filter))
            .set_propagate_otel_context(true)
            .build();

        let client_config = ClientConfiguration::builder()
            .tracing_configuration(Some(tracing_configuration))
            .oauth_session(Some(OAuthSession::from_refresh_token(
                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
                AuthServer::default(),
                Some(create_jwt()),
            )))
            .build()
            .expect("failed to build client config");
        assert_grpc_health_check_traced(client_config).await;
    }

    /// Make a health check through a channel wrapped by `wrap_channel_with_tracing` and assert
    /// how it was traced.
    ///
    /// - When OTel context propagation is enabled, the server-side interceptor checks that the
    ///   trace context arrives via the `traceparent` metadata header. Otherwise it checks that
    ///   no context headers are sent.
    /// - The gRPC request span exists with a reasonable duration, i.e. at least the 50ms the
    ///   server sleeps.
    /// - The span records an OK `rpc.grpc.status_code` attribute.
    #[allow(clippy::future_not_send)]
    async fn assert_grpc_health_check_traced(client_configuration: ClientConfiguration) {
        use opentelemetry::trace::FutureExt;

        let propagate_otel_context = client_configuration.tracing_configuration().is_some_and(
            qcs_api_client_common::tracing_configuration::TracingConfiguration::propagate_otel_context,
        );
        let spans: Vec<opentelemetry_sdk::trace::SpanData> = tracing_test::start(
            "test_trace_id_propagation",
            |trace_id, _span_id| async move {
                let sleepy_health_service = SleepyHealthService {
                    sleep_time: Duration::from_millis(50),
                };

                let interceptor = move |req| {
                    if propagate_otel_context {
                        validate_trace_id_propagated(trace_id, req)
                    } else {
                        validate_otel_context_not_propagated(req)
                    }
                };
                let health_server =
                    HealthServer::with_interceptor(sleepy_health_service, interceptor);

                uds_grpc_stream::serve(health_server, |channel| {
                    async {
                        let response = HealthClient::new(wrap_channel_with_tracing(
                            channel,
                            FAUX_BASE_URL.to_string(),
                            client_configuration
                                .tracing_configuration()
                                .unwrap()
                                .clone(),
                        ))
                        .check(Request::new(tonic_health::pb::HealthCheckRequest {
                            service: <HealthServer<HealthService> as NamedService>::NAME
                                .to_string(),
                        }))
                        .await
                        .unwrap();
                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
                    }
                    .with_current_context()
                })
                .await
                .unwrap();
            },
        )
        .await
        .unwrap();

        let grpc_span = spans
            .iter()
            .find(|span| span.name == *HEALTH_CHECK_PATH)
            .expect("failed to find gRPC span");
        let duration = grpc_span
            .end_time
            .duration_since(grpc_span.start_time)
            .expect("span should have valid timestamps");
        // The server sleeps for 50ms, so the client span must last at least that long.
        assert!(duration.as_millis() >= 50u128);

        let status_code_attribute =
            tracing_test::get_span_attribute(grpc_span, "rpc.grpc.status_code")
                .expect("gRPC span should have status code attribute");
        assert_eq!(status_code_attribute, (tonic::Code::Ok as u8).to_string());
    }

    /// Test that when tracing is enabled but the request does not match the configured filter, the
    /// request is not traced.
    #[tokio::test]
    async fn test_tracing_enabled_filter_not_passed() {
        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};

        let tracing_filter = TracingFilter::builder()
            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
            .expect("healthcheck path should be a valid filter path")
            .set_is_negated(true)
            .build();

        let tracing_configuration = TracingConfiguration::builder()
            .set_filter(Some(tracing_filter))
            .set_propagate_otel_context(true)
            .build();

        let client_config = ClientConfiguration::builder()
            .tracing_configuration(Some(tracing_configuration))
            .oauth_session(Some(OAuthSession::from_refresh_token(
                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
                AuthServer::default(),
                Some(create_jwt()),
            )))
            .build()
            .expect("should be able to build client config");

        assert_grpc_health_check_not_traced(client_config.clone()).await;
    }

    /// Check that the traceparent metadata header is not set on the gRPC request and no tracing
    /// spans are produced for the gRPC request.
    #[allow(clippy::future_not_send)]
    async fn assert_grpc_health_check_not_traced(client_configuration: ClientConfiguration) {
        use opentelemetry::trace::FutureExt;

        let spans: Vec<opentelemetry_sdk::trace::SpanData> =
            tracing_test::start("test_tracing_disabled", |_trace_id, _span_id| async move {
                let interceptor = validate_otel_context_not_propagated;
                let health_server = HealthServer::with_interceptor(
                    SleepyHealthService {
                        sleep_time: Duration::from_millis(0),
                    },
                    interceptor,
                );

                uds_grpc_stream::serve(health_server, |channel| {
                    async {
                        let response = HealthClient::new(wrap_channel_with_tracing(
                            channel,
                            FAUX_BASE_URL.to_string(),
                            client_configuration
                                .tracing_configuration()
                                .unwrap()
                                .clone(),
                        ))
                        .check(Request::new(tonic_health::pb::HealthCheckRequest {
                            service: <HealthServer<HealthService> as NamedService>::NAME
                                .to_string(),
                        }))
                        .await
                        .unwrap();
                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
                    }
                    .with_current_context()
                })
                .await
                .unwrap();
            })
            .await
            .unwrap();

        assert!(spans.iter().all(|span| { span.name != *HEALTH_CHECK_PATH }));
    }

    /// HMAC secret used to sign the test JWT; the value is arbitrary.
    const JWT_SECRET: &[u8] = b"top-secret";

    /// Claims encoded into the test JWT.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct Claims {
        sub: String,
        exp: u64,
    }

    /// Create an HS256 signed JWT token with `sub` and `exp` claims, expiring 60 seconds from now.
    /// This is intended to be good enough to pass the client's token validation, which (per the
    /// note below) does not check the signature.
    pub(crate) fn create_jwt() -> SecretAccessToken {
        use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
        let expiration = SystemTime::now()
            .checked_add(Duration::from_secs(60))
            .unwrap()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let claims = Claims {
            sub: "test-uid".to_string(),
            exp: expiration,
        };
        // The client doesn't check the signature, so for convenience here, we just sign with HS256
        // instead of RS256.
        let header = Header::new(Algorithm::HS256);
        encode(&header, &claims, &EncodingKey::from_secret(JWT_SECRET))
            .map(SecretAccessToken::from)
            .unwrap()
    }

    /// Failures detected by the server-side interceptors. They are surfaced to the client as an
    /// `invalid_argument` status.
    #[derive(Debug, thiserror::Error)]
    #[allow(variant_size_differences)]
    enum ServerAssertionError {
        #[error("trace id did not propagate to server: {0}")]
        UnexpectedTraceId(String),
        #[error("otel context headers unexpectedly sent to server")]
        UnexpectedOTelContextHeaders,
    }

    impl From<ServerAssertionError> for tonic::Status {
        fn from(server_assertion_error: ServerAssertionError) -> Self {
            Self::invalid_argument(server_assertion_error.to_string())
        }
    }

    /// Given an incoming gRPC request, validate that the specified [`TraceId`] is propagated
    /// via the `traceparent` metadata header.
    #[allow(clippy::result_large_err)]
    fn validate_trace_id_propagated(
        trace_id: TraceId,
        req: Request<()>,
    ) -> Result<Request<()>, tonic::Status> {
        req.metadata()
            .get("traceparent")
            .ok_or_else(|| {
                ServerAssertionError::UnexpectedTraceId(
                    "request traceparent metadata not present".to_string(),
                )
            })
            // Copy the `traceparent` metadata into an HTTP header map so the W3C propagator can
            // extract it.
            .and_then(|traceparent| {
                let mut headers = HeaderMap::new();
                headers.insert(
                    "traceparent",
                    HeaderValue::from_str(traceparent.to_str().map_err(|_| {
                        ServerAssertionError::UnexpectedTraceId(
                            "failed to deserialize trace parent".to_string(),
                        )
                    })?)
                    .map_err(|_| {
                        ServerAssertionError::UnexpectedTraceId(
                            "failed to serialize trace parent as HeaderValue".to_string(),
                        )
                    })?,
                );
                Ok(headers)
            })
            .and_then(|headers| {
                let extractor = HeaderExtractor(&headers);
                let propagator = TraceContextPropagator::new();
                let context = propagator.extract(&extractor);
                let propagated_trace_id = context.span().span_context().trace_id();
                if propagated_trace_id == trace_id {
                    Ok(req)
                } else {
                    Err(ServerAssertionError::UnexpectedTraceId(format!(
                        "expected trace id {trace_id}, got {propagated_trace_id}",
                    )))
                }
            })
            .map_err(Into::into)
    }

    /// Simply validate that the `traceparent` and `tracestate` metadata headers are not present
    /// on the incoming gRPC request.
    #[allow(clippy::result_large_err)]
    fn validate_otel_context_not_propagated(
        req: Request<()>,
    ) -> Result<Request<()>, tonic::Status> {
        if req.metadata().get("traceparent").is_some() || req.metadata().get("tracestate").is_some()
        {
            Err(ServerAssertionError::UnexpectedOTelContextHeaders.into())
        } else {
            Ok(req)
        }
    }

    /// An implementation of the gRPC [`HealthService`] that sleeps for the configured duration before returning a response.
    /// This is useful for making assertions on span durations. It is also necessary in order to
    /// wrap the [`HealthServer`] in an interceptor, which is not possible with public methods in
    /// the health service crate.
    ///
    /// Derived in part from <https://github.com/hyperium/tonic/blob/master/tonic-health/src/generated/grpc.health.v1.rs/>
    #[derive(Clone)]
    struct SleepyHealthService {
        sleep_time: Duration,
    }

    #[tonic::async_trait]
    impl Health for SleepyHealthService {
        async fn check(
            &self,
            _request: Request<tonic_health::pb::HealthCheckRequest>,
        ) -> Result<tonic::Response<tonic_health::pb::HealthCheckResponse>, tonic::Status> {
            tokio::time::sleep(self.sleep_time).await;
            let response = tonic_health::pb::HealthCheckResponse {
                status: ServingStatus::Serving as i32,
            };
            Ok(tonic::Response::new(response))
        }

        type WatchStream = tokio_stream::wrappers::ReceiverStream<
            Result<tonic_health::pb::HealthCheckResponse, tonic::Status>,
        >;

        /// Sends a single `Serving` response on a bounded channel and returns the receiving
        /// end as the response stream (no sleep here).
        async fn watch(
            &self,
            _request: Request<tonic_health::pb::HealthCheckRequest>,
        ) -> Result<tonic::Response<Self::WatchStream>, tonic::Status> {
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            let response = tonic_health::pb::HealthCheckResponse {
                status: ServingStatus::Serving as i32,
            };
            tx.send(Ok(response)).await.unwrap();
            Ok(tonic::Response::new(
                tokio_stream::wrappers::ReceiverStream::new(rx),
            ))
        }
    }

    /// We need a single global `SpanProcessor` because these tests have to work using
    /// `opentelemetry::global` and `tracing_subscriber::set_global_default`. Otherwise,
    /// we can't make guarantees about where the spans are processed and therefore could
    /// not make assertions about the traced spans.
    mod tracing_test {
        use futures_util::Future;
        use opentelemetry::global::BoxedTracer;
        use opentelemetry::trace::{
            mark_span_as_active, FutureExt, Span, SpanId, TraceId, Tracer, TracerProvider,
        };
        use opentelemetry_sdk::error::OTelSdkError;
        use opentelemetry_sdk::trace::{SpanData, SpanProcessor};
        use std::collections::HashMap;
        use std::sync::{Arc, RwLock};
        use tokio::sync::oneshot;

        /// Start a new test span and run the specified callback with the span as the active span.
        /// The callback is provided the span and trace ids. At the end of the callback, we wait
        /// for the span to be processed and then return all of the spans that were processed for
        /// this particular test.
        #[allow(clippy::future_not_send)]
        pub(crate) async fn start<F, R>(name: &'static str, f: F) -> Result<Vec<SpanData>, String>
        where
            F: FnOnce(TraceId, SpanId) -> R + Send,
            R: Future<Output = ()> + Send,
        {
            let tracer = CacheProcessor::tracer();
            let span = tracer.start(name);
            let span_id = span.span_context().span_id();
            let trace_id = span.span_context().trace_id();
            let cache = CACHE
                .get()
                .expect("cache should be initialized with cache tracer");
            // Subscribe before the span can end, so `on_end` finds the notification sender.
            let span_rx = cache.subscribe(span_id);
            async {
                let _guard = mark_span_as_active(span);
                f(trace_id, span_id).with_current_context().await;
            }
            .await;

            // wait for test span to be processed.
            span_rx.await.unwrap();

            // remove and return the spans processed for this test.
            let mut data = cache.data.write().map_err(|e| format!("{e:?}"))?;
            Ok(data.remove(&trace_id).unwrap_or_default())
        }

        /// Process-wide span cache shared by every test; initialized once by
        /// [`CacheProcessor::tracer`].
        static CACHE: once_cell::sync::OnceCell<CacheProcessor> = once_cell::sync::OnceCell::new();

        /// A [`SpanProcessor`] that stores finished spans grouped by trace id and signals
        /// per-span subscribers when a subscribed span ends.
        #[derive(Debug, Clone, Default)]
        struct CacheProcessor {
            data: Arc<RwLock<HashMap<TraceId, Vec<SpanData>>>>,
            notifications: Arc<RwLock<HashMap<SpanId, tokio::sync::oneshot::Sender<()>>>>,
        }

        impl CacheProcessor {
            /// Initializes the [`CACHE`] and sets the global `opentelemetry::global::tracer_provider` and `tracing_subscriber::global_default`.
            /// These initializations occur safely behind a `OnceCell` initialization, so they can be used by several tests.
            fn tracer() -> BoxedTracer {
                use tracing_subscriber::layer::SubscriberExt;

                CACHE.get_or_init(|| {
                    let processor = Self::default();
                    let provider = opentelemetry_sdk::trace::SdkTracerProvider::builder()
                        .with_span_processor(processor.clone())
                        .build();
                    opentelemetry::global::set_tracer_provider(provider.clone());
                    let tracer = provider.tracer("test_channel");
                    let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
                    let subscriber = tracing_subscriber::Registry::default().with(telemetry);
                    tracing::subscriber::set_global_default(subscriber)
                        .expect("tracing subscriber already set");
                    processor
                });
                opentelemetry::global::tracer("test_channel")
            }

            /// Register a one-shot notification for the provided span id. The returned receiver
            /// resolves once `on_end` processes that span.
            fn subscribe(&self, span_id: SpanId) -> oneshot::Receiver<()> {
                let (tx, rx) = oneshot::channel();
                self.notifications.write().unwrap().insert(span_id, tx);
                rx
            }
        }

        impl SpanProcessor for CacheProcessor {
            fn on_start(
                &self,
                _span: &mut opentelemetry_sdk::trace::Span,
                _cx: &opentelemetry::Context,
            ) {
            }

            /// Cache the finished span under its trace id, then notify any subscriber
            /// waiting on this span id.
            fn on_end(&self, span: SpanData) {
                let trace_id = span.span_context.trace_id();
                let span_id = span.span_context.span_id();
                {
                    let mut data = self
                        .data
                        .write()
                        .expect("failed to write access cache span data");
                    data.entry(trace_id).or_default().push(span);
                }

                if let Some(notify) = self
                    .notifications
                    .write()
                    .expect("failed to read access span notifications during span processing")
                    .remove(&span_id)
                {
                    notify.send(()).unwrap();
                }
            }

            /// This is a no-op because spans are processed synchronously in `on_end`.
            fn force_flush(&self) -> Result<(), OTelSdkError> {
                Ok(())
            }

            fn shutdown(&self) -> Result<(), OTelSdkError> {
                Ok(())
            }

            fn shutdown_with_timeout(
                &self,
                _timeout: std::time::Duration,
            ) -> Result<(), OTelSdkError> {
                Ok(())
            }
        }

        /// Get the Opentelemetry attribute value for the provided key.
        #[must_use]
        pub(super) fn get_span_attribute(
            span_data: &SpanData,
            key: &'static str,
        ) -> Option<String> {
            span_data
                .attributes
                .iter()
                .find(|attr| attr.key.to_string() == *key)
                .map(|kv| kv.value.to_string())
        }
    }
}