qcs_api_client_grpc/tonic/
mod.rs

1use std::convert::Infallible;
2
3/// QCS Middleware for [`tonic`] clients.
4///
5use futures_util::pin_mut;
6use http::Request;
7use http_body::Body as HttpBody;
8use http_body_util::BodyExt;
9
10mod channel;
11mod common;
12mod error;
13#[cfg(feature = "grpc-web")]
14mod grpc_web;
15mod refresh;
16mod retry;
17#[cfg(feature = "tracing")]
18mod trace;
19
20pub use channel::*;
21pub use error::*;
22#[cfg(feature = "grpc-web")]
23pub use grpc_web::*;
24pub use refresh::*;
25pub use retry::*;
26use tonic::body::Body;
27#[cfg(feature = "tracing")]
28pub use trace::*;
/// An error observed while duplicating a request body. This may be returned by any
/// [`tower::Service`] that duplicates a request body for the purpose of retrying a request.
///
/// Both variants are constructible with `?` thanks to their `#[from]` conversions, and the
/// error as a whole converts into a [`tonic::Status`].
#[derive(Debug, thiserror::Error)]
pub enum RequestBodyDuplicationError {
    /// The inner service returned an error from the server, or the client cancelled the
    /// request.
    #[error(transparent)]
    Status(#[from] tonic::Status),
    /// Failed to read the request body for cloning. Wraps the underlying [`http::Error`].
    #[error("failed to read request body for request clone: {0}")]
    HttpBody(#[from] http::Error),
}
41
42impl From<RequestBodyDuplicationError> for tonic::Status {
43    fn from(err: RequestBodyDuplicationError) -> tonic::Status {
44        match err {
45            RequestBodyDuplicationError::Status(status) => status,
46            RequestBodyDuplicationError::HttpBody(error) => tonic::Status::cancelled(format!(
47                "failed to read request body for request clone: {error}"
48            )),
49        }
50    }
51}
52
/// Shorthand for a [`Result`] whose error type is [`RequestBodyDuplicationError`].
type RequestBodyDuplicationResult<T> = Result<T, RequestBodyDuplicationError>;
54
55/// This function should only be used with Unary requests; Stream requests are
56/// untested. It eagerly collects all request data into a buffer, consuming the
57/// original stream. Additionally, it assumes that all frames are data frames
58/// (i.e. the stream cannot contain any trailers); if a trailer frame is found,
59/// the cancelled status will be returned.
60async fn build_duplicate_frame_bytes(
61    mut request: Request<tonic::body::Body>,
62) -> RequestBodyDuplicationResult<(tonic::body::Body, tonic::body::Body)> {
63    let mut bytes = Vec::new();
64
65    let body = request.body_mut();
66    pin_mut!(body);
67    while let Some(result) = std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
68        let frame_bytes = result?.into_data().map_err(|frame| {
69            tonic::Status::cancelled(format!(
70                "cannot duplicate a frame that is not a data frame: {frame:?}"
71            ))
72        })?;
73        bytes.extend(frame_bytes);
74    }
75
76    let bytes = http_body_util::Full::from(bytes)
77        .map_err(|_: Infallible| -> tonic::Status { unreachable!() });
78    Ok((
79        tonic::body::Body::new(bytes.clone()),
80        tonic::body::Body::new(bytes),
81    ))
82}
83
84/// This function should only be used with Unary requests; Stream requests are
85/// untested. See comment on `build_duplicate_frame_bytes`.
86async fn build_duplicate_request(
87    req: Request<Body>,
88) -> RequestBodyDuplicationResult<(Request<Body>, Request<Body>)> {
89    let mut builder_1 = Request::builder()
90        .method(req.method().clone())
91        .uri(req.uri().clone())
92        .version(req.version());
93
94    let mut builder_2 = Request::builder()
95        .method(req.method().clone())
96        .uri(req.uri().clone())
97        .version(req.version());
98
99    for (key, val) in req.headers() {
100        builder_1 = builder_1.header(key, val);
101        builder_2 = builder_2.header(key, val);
102    }
103
104    let (body_1, body_2) = build_duplicate_frame_bytes(req).await?;
105
106    let req_1 = builder_1.body(body_1)?;
107
108    let req_2 = builder_2.body(body_2)?;
109
110    Ok((req_1, req_2))
111}
112
/// This module manages a gRPC server-client connection over a Unix domain socket. Useful for
/// testing servers or clients within unit tests - supports parallelization within the same
/// process and requires no port management.
116///
117/// Derived largely from <https://stackoverflow.com/a/71808401> and
118/// <https://github.com/hyperium/tonic/tree/master/examples/src/uds>.
119#[cfg(test)]
120pub(crate) mod uds_grpc_stream {
121    use hyper_util::rt::TokioIo;
122    use opentelemetry::trace::FutureExt;
123    use std::convert::Infallible;
124    use tempfile::TempDir;
125    use tokio::net::UnixStream;
126    use tokio_stream::wrappers::UnixListenerStream;
127    use tonic::server::NamedService;
128    use tonic::transport::{Channel, Endpoint, Server};
129
    /// This can be any valid URL. It is necessary to initialize an [`Endpoint`].
    // The connector used by `build_client_channel` ignores the `Uri` it receives and always
    // connects to the Unix domain socket instead.
    #[allow(dead_code)]
    static FAUX_URL: &str = "http://api.example.rigetti.com";
133
134    fn build_server_stream(path: String) -> Result<UnixListenerStream, Error> {
135        let uds =
136            tokio::net::UnixListener::bind(path.clone()).map_err(|_| Error::BindUnixPath(path))?;
137        Ok(UnixListenerStream::new(uds))
138    }
139
140    async fn build_client_channel(path: String) -> Result<Channel, Error> {
141        let connector = tower::service_fn(move |_: tonic::transport::Uri| {
142            let path = path.clone();
143            async move {
144                let connection = UnixStream::connect(path).await?;
145                Ok::<_, std::io::Error>(TokioIo::new(connection))
146            }
147        });
148        let channel = Endpoint::try_from(FAUX_URL)
149            .map_err(|source| Error::Endpoint {
150                url: FAUX_URL.to_string(),
151                source,
152            })?
153            .connect_with_connector(connector)
154            .await
155            .map_err(|source| Error::Connect { source })?;
156        Ok(channel)
157    }
158
159    /// Errors when connecting a server-client [`Channel`] over a Unix domain socket for testing gRPC
160    /// services.
161    #[derive(thiserror::Error, Debug)]
162    pub enum Error {
163        /// Failed to initialize an [`Endpoint`] for the provided URL.
164        #[error("failed to initialize endpoint for {url}: {source}")]
165        Endpoint {
166            /// The URL that failed to initialize.
167            url: String,
168            /// The source of the error.
169            #[source]
170            source: tonic::transport::Error,
171        },
172        /// Failed to connect to the provided endpoint.
173        #[error("failed to connect to endpoint: {source}")]
174        Connect {
175            /// The source of the error.
176            #[source]
177            source: tonic::transport::Error,
178        },
179        /// Failed to bind the provided path as a Unix domain socket.
180        #[error("failed to bind path as unix listener: {0}")]
181        BindUnixPath(String),
182        /// Failed to initialize a temporary file for the Unix domain socket.
183        #[error("failed in initialize tempfile: {0}")]
184        TempFile(#[from] std::io::Error),
185        /// Failed to convert the tempfile to an [`OsString`].
186        #[error("failed to convert tempfile to OsString")]
187        TempFileOsString,
188        /// Failed to bind to the Unix domain socket.
189        #[error("failed to bind to UDS: {0}")]
190        TonicTransport(#[from] tonic::transport::Error),
191    }
192
193    /// Serve the provided gRPC service over a Unix domain socket for the duration of the
194    /// provided callback.
195    ///
196    /// # Errors
197    ///
198    /// See [`Error`].
199    #[allow(clippy::significant_drop_tightening)]
200    pub async fn serve<S, F, R, B>(service: S, f: F) -> Result<(), Error>
201    where
202        S: tower::Service<
203                http::Request<tonic::body::Body>,
204                Response = http::Response<B>,
205                Error = Infallible,
206            > + NamedService
207            + Clone
208            + Send
209            + Sync
210            + 'static,
211        S::Future: Send + 'static,
212        B: http_body::Body<Data = tonic::codegen::Bytes> + Send + 'static,
213        B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
214        F: FnOnce(Channel) -> R + Send,
215        R: std::future::Future<Output = ()> + Send,
216    {
217        let directory = TempDir::new()?;
218        let file = directory.path().as_os_str();
219        let file = file.to_os_string();
220        let file = file.into_string().map_err(|_| Error::TempFileOsString)?;
221        let file = format!("{file}/test.sock");
222        let stream = build_server_stream(file.clone())?;
223
224        let channel = build_client_channel(file).await?;
225        let serve_future = Server::builder()
226            .add_service(service)
227            .serve_with_incoming(stream);
228
229        tokio::select! {
230           result = serve_future => result.map_err(Error::TonicTransport),
231           () = f(channel).with_current_context() => Ok(()),
232        }
233    }
234}
235
236#[cfg(test)]
237#[cfg(feature = "tracing-opentelemetry")]
238mod otel_tests {
239    use opentelemetry::propagation::TextMapPropagator;
240    use opentelemetry::trace::{TraceContextExt, TraceId};
241    use opentelemetry_http::HeaderExtractor;
242    use opentelemetry_sdk::propagation::TraceContextPropagator;
243    use qcs_api_client_common::configuration::secrets::{SecretAccessToken, SecretRefreshToken};
244    use qcs_api_client_common::configuration::tokens::RefreshToken;
245    use serde::{Deserialize, Serialize};
246    use std::time::{Duration, SystemTime};
247    use tonic::codegen::http::{HeaderMap, HeaderValue};
248    use tonic::server::NamedService;
249    use tonic::Request;
250    use tonic_health::pb::health_check_response::ServingStatus;
251    use tonic_health::pb::health_server::{Health, HealthServer};
252    use tonic_health::{pb::health_client::HealthClient, server::HealthService};
253
254    use crate::tonic::{uds_grpc_stream, wrap_channel_with_tracing};
255    use qcs_api_client_common::configuration::ClientConfiguration;
256    use qcs_api_client_common::configuration::{settings::AuthServer, tokens::OAuthSession};
257
    /// Path of the gRPC health check method. The tests use it both as the tracing filter path
    /// and as the name of the request span they look for.
    static HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// Base URL passed to `wrap_channel_with_tracing` in these tests.
    const FAUX_BASE_URL: &str = "http://api.example.rigetti.com";
261
262    /// Test that when tracing is enabled and no filter is set, any request is properly traced.
263    #[tokio::test]
264    async fn test_tracing_enabled_no_filter() {
265        use qcs_api_client_common::tracing_configuration::TracingConfiguration;
266
267        let tracing_configuration = TracingConfiguration::builder()
268            .set_propagate_otel_context(true)
269            .build();
270        let client_config = ClientConfiguration::builder()
271            .tracing_configuration(Some(tracing_configuration))
272            .oauth_session(Some(OAuthSession::from_refresh_token(
273                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
274                AuthServer::default(),
275                Some(create_jwt()),
276            )))
277            .build()
278            .expect("should be able to build client config");
279        assert_grpc_health_check_traced(client_config).await;
280    }
281
282    /// Test that when tracing is enabled, no filter is set, and OTel context propagation is
283    /// disabled, any request is properly traced without propagation.
284    #[tokio::test]
285    async fn test_tracing_enabled_no_filter_nor_otel_context_propagation() {
286        use qcs_api_client_common::tracing_configuration::TracingConfiguration;
287
288        let tracing_configuration = TracingConfiguration::default();
289        let client_config = ClientConfiguration::builder()
290            .tracing_configuration(Some(tracing_configuration))
291            .oauth_session(Some(OAuthSession::from_refresh_token(
292                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
293                AuthServer::default(),
294                Some(create_jwt()),
295            )))
296            .build()
297            .expect("failed to build client config");
298        assert_grpc_health_check_traced(client_config).await;
299    }
300
301    /// Test that when tracing is enabled and the filter matches the gRPC request, the request is
302    /// properly traced.
303    #[tokio::test]
304    async fn test_tracing_enabled_filter_passed() {
305        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};
306
307        let tracing_filter = TracingFilter::builder()
308            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
309            .expect("gRPC healthcheck path should be valid filter path")
310            .build();
311
312        let tracing_configuration = TracingConfiguration::builder()
313            .set_filter(Some(tracing_filter))
314            .set_propagate_otel_context(true)
315            .build();
316
317        let client_config = ClientConfiguration::builder()
318            .tracing_configuration(Some(tracing_configuration))
319            .oauth_session(Some(OAuthSession::from_refresh_token(
320                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
321                AuthServer::default(),
322                Some(create_jwt()),
323            )))
324            .build()
325            .expect("failed to build client config");
326        assert_grpc_health_check_traced(client_config).await;
327    }
328
329    /// Checks that the the [`RefreshService`] propagates the trace context via the traceparent metadata header and that the gRPC
330    /// request span is properly created (ie the span duration is reasonable).
331    #[allow(clippy::future_not_send)]
332    async fn assert_grpc_health_check_traced(client_configuration: ClientConfiguration) {
333        use opentelemetry::trace::FutureExt;
334
335        let propagate_otel_context = client_configuration.tracing_configuration().is_some_and(
336            qcs_api_client_common::tracing_configuration::TracingConfiguration::propagate_otel_context,
337        );
338        let spans: Vec<opentelemetry_sdk::trace::SpanData> = tracing_test::start(
339            "test_trace_id_propagation",
340            |trace_id, _span_id| async move {
341                let sleepy_health_service = SleepyHealthService {
342                    sleep_time: Duration::from_millis(50),
343                };
344
345                let interceptor = move |req| {
346                    if propagate_otel_context {
347                        validate_trace_id_propagated(trace_id, req)
348                    } else {
349                        validate_otel_context_not_propagated(req)
350                    }
351                };
352                let health_server =
353                    HealthServer::with_interceptor(sleepy_health_service, interceptor);
354
355                uds_grpc_stream::serve(health_server, |channel| {
356                    async {
357                        let response = HealthClient::new(wrap_channel_with_tracing(
358                            channel,
359                            FAUX_BASE_URL.to_string(),
360                            client_configuration
361                                .tracing_configuration()
362                                .unwrap()
363                                .clone(),
364                        ))
365                        .check(Request::new(tonic_health::pb::HealthCheckRequest {
366                            service: <HealthServer<HealthService> as NamedService>::NAME
367                                .to_string(),
368                        }))
369                        .await
370                        .unwrap();
371                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
372                    }
373                    .with_current_context()
374                })
375                .await
376                .unwrap();
377            },
378        )
379        .await
380        .unwrap();
381
382        let grpc_span = spans
383            .iter()
384            .find(|span| span.name == *HEALTH_CHECK_PATH)
385            .expect("failed to find gRPC span");
386        let duration = grpc_span
387            .end_time
388            .duration_since(grpc_span.start_time)
389            .expect("span should have valid timestamps");
390        assert!(duration.as_millis() >= 50u128);
391
392        let status_code_attribute =
393            tracing_test::get_span_attribute(grpc_span, "rpc.grpc.status_code")
394                .expect("gRPC span should have status code attribute");
395        assert_eq!(status_code_attribute, (tonic::Code::Ok as u8).to_string());
396    }
397
398    /// Test that when tracing is enabled but the request does not match the configured filter, the
399    /// request is not traced.
400    #[tokio::test]
401    async fn test_tracing_enabled_filter_not_passed() {
402        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};
403
404        let tracing_filter = TracingFilter::builder()
405            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
406            .expect("healthcheck path should be a valid filter path")
407            .set_is_negated(true)
408            .build();
409
410        let tracing_configuration = TracingConfiguration::builder()
411            .set_filter(Some(tracing_filter))
412            .set_propagate_otel_context(true)
413            .build();
414
415        let client_config = ClientConfiguration::builder()
416            .tracing_configuration(Some(tracing_configuration))
417            .oauth_session(Some(OAuthSession::from_refresh_token(
418                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
419                AuthServer::default(),
420                Some(create_jwt()),
421            )))
422            .build()
423            .expect("should be able to build client config");
424
425        assert_grpc_health_check_not_traced(client_config.clone()).await;
426    }
427
428    /// Check that the traceparent metadata header is not set on the gRPC request and no tracing
429    /// spans are produced for the gRPC request.
430    #[allow(clippy::future_not_send)]
431    async fn assert_grpc_health_check_not_traced(client_configuration: ClientConfiguration) {
432        use opentelemetry::trace::FutureExt;
433
434        let spans: Vec<opentelemetry_sdk::trace::SpanData> =
435            tracing_test::start("test_tracing_disabled", |_trace_id, _span_id| async move {
436                let interceptor = validate_otel_context_not_propagated;
437                let health_server = HealthServer::with_interceptor(
438                    SleepyHealthService {
439                        sleep_time: Duration::from_millis(0),
440                    },
441                    interceptor,
442                );
443
444                uds_grpc_stream::serve(health_server, |channel| {
445                    async {
446                        let response = HealthClient::new(wrap_channel_with_tracing(
447                            channel,
448                            FAUX_BASE_URL.to_string(),
449                            client_configuration
450                                .tracing_configuration()
451                                .unwrap()
452                                .clone(),
453                        ))
454                        .check(Request::new(tonic_health::pb::HealthCheckRequest {
455                            service: <HealthServer<HealthService> as NamedService>::NAME
456                                .to_string(),
457                        }))
458                        .await
459                        .unwrap();
460                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
461                    }
462                    .with_current_context()
463                })
464                .await
465                .unwrap();
466            })
467            .await
468            .unwrap();
469
470        assert!(spans.iter().all(|span| { span.name != *HEALTH_CHECK_PATH }));
471    }
472
    /// Secret used by `create_jwt` to sign the test token.
    const JWT_SECRET: &[u8] = b"top-secret";

    /// The claims encoded into the test token by `create_jwt`.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct Claims {
        /// Subject claim.
        sub: String,
        /// Expiration claim; `create_jwt` fills it with seconds since the Unix epoch.
        exp: u64,
    }
480
481    /// Create an HS256 signed JWT token with sub and exp claims. This is good enough to pass the
482    /// [`RefreshService`] token validation.
483    pub(crate) fn create_jwt() -> SecretAccessToken {
484        use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
485        let expiration = SystemTime::now()
486            .checked_add(Duration::from_secs(60))
487            .unwrap()
488            .duration_since(SystemTime::UNIX_EPOCH)
489            .unwrap()
490            .as_secs();
491
492        let claims = Claims {
493            sub: "test-uid".to_string(),
494            exp: expiration,
495        };
496        // The client doesn't check the signature, so for convenience here, we just sign with HS256
497        // instead of RS256.
498        let header = Header::new(Algorithm::HS256);
499        encode(&header, &claims, &EncodingKey::from_secret(JWT_SECRET))
500            .map(SecretAccessToken::from)
501            .unwrap()
502    }
503
    /// Failures raised by the server-side interceptors used in these tests. Each one is turned
    /// into an `invalid_argument` [`tonic::Status`] before it is returned to the client.
    #[derive(Debug, thiserror::Error)]
    #[allow(variant_size_differences)]
    enum ServerAssertionError {
        /// The expected trace id did not reach the server; the payload explains why.
        #[error("trace id did not propagate to server: {0}")]
        UnexpectedTraceId(String),
        /// OTel context headers were present although none were expected.
        #[error("otel context headers unexpectedly sent to server")]
        UnexpectedOTelContextHeaders,
    }
512
513    impl From<ServerAssertionError> for tonic::Status {
514        fn from(server_assertion_error: ServerAssertionError) -> Self {
515            Self::invalid_argument(server_assertion_error.to_string())
516        }
517    }
518
519    /// Given an incoming gRPC request, validate that the the specified [`TraceId`] is propagated
520    /// via the `traceparent` metadata header.
521    #[allow(clippy::result_large_err)]
522    fn validate_trace_id_propagated(
523        trace_id: TraceId,
524        req: Request<()>,
525    ) -> Result<Request<()>, tonic::Status> {
526        req.metadata()
527            .get("traceparent")
528            .ok_or_else(|| {
529                ServerAssertionError::UnexpectedTraceId(
530                    "request traceparent metadata not present".to_string(),
531                )
532            })
533            .and_then(|traceparent| {
534                let mut headers = HeaderMap::new();
535                headers.insert(
536                    "traceparent",
537                    HeaderValue::from_str(traceparent.to_str().map_err(|_| {
538                        ServerAssertionError::UnexpectedTraceId(
539                            "failed to deserialize trace parent".to_string(),
540                        )
541                    })?)
542                    .map_err(|_| {
543                        ServerAssertionError::UnexpectedTraceId(
544                            "failed to serialize trace parent as HeaderValue".to_string(),
545                        )
546                    })?,
547                );
548                Ok(headers)
549            })
550            .and_then(|headers| {
551                let extractor = HeaderExtractor(&headers);
552                let propagator = TraceContextPropagator::new();
553                let context = propagator.extract(&extractor);
554                let propagated_trace_id = context.span().span_context().trace_id();
555                if propagated_trace_id == trace_id {
556                    Ok(req)
557                } else {
558                    Err(ServerAssertionError::UnexpectedTraceId(format!(
559                        "expected trace id {trace_id}, got {propagated_trace_id}",
560                    )))
561                }
562            })
563            .map_err(Into::into)
564    }
565
566    /// Simply validate that the `traceparent` and `tracestate` metadata headers are not present
567    /// on the incoming gRPC.
568    #[allow(clippy::result_large_err)]
569    fn validate_otel_context_not_propagated(
570        req: Request<()>,
571    ) -> Result<Request<()>, tonic::Status> {
572        if req.metadata().get("traceparent").is_some() || req.metadata().get("tracestate").is_some()
573        {
574            Err(ServerAssertionError::UnexpectedOTelContextHeaders.into())
575        } else {
576            Ok(req)
577        }
578    }
579
    /// An implementation of the gRPC [`HealthService`] that sleeps for the configured duration
    /// before returning a response.
    /// This is useful for making assertions on span durations. It is also necessary in order to
    /// wrap the [`HealthServer`] in an interceptor, which is not possible with public methods in
    /// the health service crate.
    ///
    /// Derived in part from <https://github.com/hyperium/tonic/blob/master/tonic-health/src/generated/grpc.health.v1.rs/>
    #[derive(Clone)]
    struct SleepyHealthService {
        /// How long `check` sleeps before it responds.
        sleep_time: Duration,
    }
590
591    #[tonic::async_trait]
592    impl Health for SleepyHealthService {
593        async fn check(
594            &self,
595            _request: Request<tonic_health::pb::HealthCheckRequest>,
596        ) -> Result<tonic::Response<tonic_health::pb::HealthCheckResponse>, tonic::Status> {
597            tokio::time::sleep(self.sleep_time).await;
598            let response = tonic_health::pb::HealthCheckResponse {
599                status: ServingStatus::Serving as i32,
600            };
601            Ok(tonic::Response::new(response))
602        }
603
604        type WatchStream = tokio_stream::wrappers::ReceiverStream<
605            Result<tonic_health::pb::HealthCheckResponse, tonic::Status>,
606        >;
607
608        async fn watch(
609            &self,
610            _request: Request<tonic_health::pb::HealthCheckRequest>,
611        ) -> Result<tonic::Response<Self::WatchStream>, tonic::Status> {
612            let (tx, rx) = tokio::sync::mpsc::channel(1);
613            let response = tonic_health::pb::HealthCheckResponse {
614                status: ServingStatus::Serving as i32,
615            };
616            tx.send(Ok(response)).await.unwrap();
617            Ok(tonic::Response::new(
618                tokio_stream::wrappers::ReceiverStream::new(rx),
619            ))
620        }
621    }
622
623    /// We need a single global ``SpanProcessor`` because these tests have to work using
624    /// ``opentelemetry::global`` and ``tracing_subscriber::set_global_default``. Otherwise,
625    /// we can't make guarantees about where the spans are processed and therefore could
626    /// not make assertions about the traced spans.
627    mod tracing_test {
628        use futures_util::Future;
629        use opentelemetry::global::BoxedTracer;
630        use opentelemetry::trace::{
631            mark_span_as_active, FutureExt, Span, SpanId, TraceId, Tracer, TracerProvider,
632        };
633        use opentelemetry_sdk::error::OTelSdkError;
634        use opentelemetry_sdk::trace::{SpanData, SpanProcessor};
635        use std::collections::HashMap;
636        use std::sync::{Arc, RwLock};
637        use tokio::sync::oneshot;
638
639        /// Start a new test span and run the specified callback with the span as the active span.
640        /// The call back is provided the span and trace ids. At the end of the callback, we wait
641        /// for the span to be processed and then return all of the spans that were processed for
642        /// this particular test.
643        #[allow(clippy::future_not_send)]
644        pub(crate) async fn start<F, R>(name: &'static str, f: F) -> Result<Vec<SpanData>, String>
645        where
646            F: FnOnce(TraceId, SpanId) -> R + Send,
647            R: Future<Output = ()> + Send,
648        {
649            let tracer = CacheProcessor::tracer();
650            let span = tracer.start(name);
651            let span_id = span.span_context().span_id();
652            let trace_id = span.span_context().trace_id();
653            let cache = CACHE
654                .get()
655                .expect("cache should be initialized with cache tracer");
656            let span_rx = cache.subscribe(span_id);
657            async {
658                let _guard = mark_span_as_active(span);
659                f(trace_id, span_id).with_current_context().await;
660            }
661            .await;
662
663            // wait for test span to be processed.
664            span_rx.await.unwrap();
665
666            // remove and return the spans processed for this test.
667            let mut data = cache.data.write().map_err(|e| format!("{e:?}"))?;
668            Ok(data.remove(&trace_id).unwrap_or_default())
669        }
670
        /// The process-wide span cache; it is populated on the first call to
        /// `CacheProcessor::tracer`.
        static CACHE: once_cell::sync::OnceCell<CacheProcessor> = once_cell::sync::OnceCell::new();

        /// A span processor that stores every finished span by trace id and notifies
        /// subscribers when a span they are waiting for has ended.
        #[derive(Debug, Clone, Default)]
        struct CacheProcessor {
            /// Finished spans, grouped by the trace they belong to.
            data: Arc<RwLock<HashMap<TraceId, Vec<SpanData>>>>,
            /// One-shot senders to fire when the span with the given id ends.
            notifications: Arc<RwLock<HashMap<SpanId, tokio::sync::oneshot::Sender<()>>>>,
        }
678
679        impl CacheProcessor {
680            /// Initializes the [`CACHE`] and sets the global `opentelemetry::global::tracer_provider` and `tracing_subscriber::global_default`.
681            /// These initializations occur safely behind a `OnceCell` initialization, so they can be used by several tests.
682            fn tracer() -> BoxedTracer {
683                use tracing_subscriber::layer::SubscriberExt;
684
685                CACHE.get_or_init(|| {
686                    let processor = Self::default();
687                    let provider = opentelemetry_sdk::trace::SdkTracerProvider::builder()
688                        .with_span_processor(processor.clone())
689                        .build();
690                    opentelemetry::global::set_tracer_provider(provider.clone());
691                    let tracer = provider.tracer("test_channel");
692                    let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
693                    let subscriber = tracing_subscriber::Registry::default().with(telemetry);
694                    tracing::subscriber::set_global_default(subscriber)
695                        .expect("tracing subscriber already set");
696                    processor
697                });
698                opentelemetry::global::tracer("test_channel")
699            }
700
701            /// Ensure that a [`Notification`] exists for the provided span id.
702            fn subscribe(&self, span_id: SpanId) -> oneshot::Receiver<()> {
703                let (tx, rx) = oneshot::channel();
704                self.notifications.write().unwrap().insert(span_id, tx);
705                rx
706            }
707        }
708
709        impl SpanProcessor for CacheProcessor {
710            fn on_start(
711                &self,
712                _span: &mut opentelemetry_sdk::trace::Span,
713                _cx: &opentelemetry::Context,
714            ) {
715            }
716
717            fn on_end(&self, span: SpanData) {
718                let trace_id = span.span_context.trace_id();
719                let span_id = span.span_context.span_id();
720                {
721                    let mut data = self
722                        .data
723                        .write()
724                        .expect("failed to write access cache span data");
725                    data.entry(trace_id).or_default().push(span);
726                }
727
728                if let Some(notify) = self
729                    .notifications
730                    .write()
731                    .expect("failed to read access span notifications during span processing")
732                    .remove(&span_id)
733                {
734                    notify.send(()).unwrap();
735                }
736            }
737
738            /// This is a no-op because spans are processed synchronously in `on_end`.
739            fn force_flush(&self) -> Result<(), OTelSdkError> {
740                Ok(())
741            }
742
743            fn shutdown(&self) -> Result<(), OTelSdkError> {
744                Ok(())
745            }
746
747            fn shutdown_with_timeout(
748                &self,
749                _timeout: std::time::Duration,
750            ) -> Result<(), OTelSdkError> {
751                Ok(())
752            }
753        }
754
755        /// Get the Opentelemetry attribute value for the provided key.
756        #[must_use]
757        pub(super) fn get_span_attribute(
758            span_data: &SpanData,
759            key: &'static str,
760        ) -> Option<String> {
761            span_data
762                .attributes
763                .iter()
764                .find(|attr| attr.key.to_string() == *key)
765                .map(|kv| kv.value.to_string())
766        }
767    }
768}