qcs_api_client_grpc/tonic/
mod.rs

1use std::convert::Infallible;
2
3/// QCS Middleware for [`tonic`] clients.
4///
5use futures_util::pin_mut;
6use http::Request;
7use http_body::Body;
8use http_body_util::BodyExt;
9use tonic::body::BoxBody;
10
11mod channel;
12mod common;
13mod error;
14#[cfg(feature = "grpc-web")]
15mod grpc_web;
16mod refresh;
17mod retry;
18#[cfg(feature = "tracing")]
19mod trace;
20
21pub use channel::*;
22pub use error::*;
23#[cfg(feature = "grpc-web")]
24pub use grpc_web::*;
25pub use refresh::*;
26pub use retry::*;
27#[cfg(feature = "tracing")]
28pub use trace::*;
29
/// An error observed while duplicating a request body. This may be returned by any
/// [`tower::Service`] that duplicates a request body for the purpose of retrying a request.
///
/// The error can be converted into a [`tonic::Status`] through the `From` implementation that
/// accompanies this type.
#[derive(Debug, thiserror::Error)]
pub enum RequestBodyDuplicationError {
    /// The inner service returned an error from the server, or the client cancelled the
    /// request.
    #[error(transparent)]
    Status(#[from] tonic::Status),
    /// Failed to read the request body for cloning.
    ///
    /// NOTE(review): within this module the wrapped [`http::Error`] is produced by
    /// `http::request::Builder::body` while the duplicated requests are re-assembled, not by
    /// reading the body itself - confirm whether the "read request body" wording still fits.
    #[error("failed to read request body for request clone: {0}")]
    HttpBody(#[from] http::Error),
}
42
43impl From<RequestBodyDuplicationError> for tonic::Status {
44    fn from(err: RequestBodyDuplicationError) -> tonic::Status {
45        match err {
46            RequestBodyDuplicationError::Status(status) => status,
47            RequestBodyDuplicationError::HttpBody(error) => tonic::Status::cancelled(format!(
48                "failed to read request body for request clone: {error}"
49            )),
50        }
51    }
52}
53
/// Shorthand for a [`Result`] whose error type is [`RequestBodyDuplicationError`].
type RequestBodyDuplicationResult<T> = Result<T, RequestBodyDuplicationError>;
55
56/// This function should only be used with Unary requests; Stream requests are
57/// untested. It eagerly collects all request data into a buffer, consuming the
58/// original stream. Additionally, it assumes that all frames are data frames
59/// (i.e. the stream cannot contain any trailers); if a trailer frame is found,
60/// the cancelled status will be returned.
61async fn build_duplicate_frame_bytes(
62    mut request: Request<tonic::body::BoxBody>,
63) -> RequestBodyDuplicationResult<(tonic::body::BoxBody, tonic::body::BoxBody)> {
64    let mut bytes = Vec::new();
65
66    let body = request.body_mut();
67    pin_mut!(body);
68    while let Some(result) = std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
69        let frame_bytes = result?.into_data().map_err(|frame| {
70            tonic::Status::cancelled(format!(
71                "cannot duplicate a frame that is not a data frame: {frame:?}"
72            ))
73        })?;
74        bytes.extend(frame_bytes);
75    }
76
77    let bytes = http_body_util::Full::from(bytes)
78        .map_err(|_: Infallible| -> tonic::Status { unreachable!() });
79    Ok((
80        tonic::body::BoxBody::new(bytes.clone()),
81        tonic::body::BoxBody::new(bytes),
82    ))
83}
84
85/// This function should only be used with Unary requests; Stream requests are
86/// untested. See comment on `build_duplicate_frame_bytes`.
87async fn build_duplicate_request(
88    req: Request<BoxBody>,
89) -> RequestBodyDuplicationResult<(Request<BoxBody>, Request<BoxBody>)> {
90    let mut builder_1 = Request::builder()
91        .method(req.method().clone())
92        .uri(req.uri().clone())
93        .version(req.version());
94
95    let mut builder_2 = Request::builder()
96        .method(req.method().clone())
97        .uri(req.uri().clone())
98        .version(req.version());
99
100    for (key, val) in req.headers() {
101        builder_1 = builder_1.header(key.clone(), val.clone());
102        builder_2 = builder_2.header(key.clone(), val.clone());
103    }
104
105    let (body_1, body_2) = build_duplicate_frame_bytes(req).await?;
106
107    let req_1 = builder_1.body(body_1)?;
108
109    let req_2 = builder_2.body(body_2)?;
110
111    Ok((req_1, req_2))
112}
113
/// This module manages a gRPC server-client connection over a Unix domain socket. It is useful
/// for unit testing servers or clients: it supports parallel tests within the same process and
/// requires no port management.
117///
118/// Derived largely from <https://stackoverflow.com/a/71808401> and
119/// <https://github.com/hyperium/tonic/tree/master/examples/src/uds>.
120#[cfg(test)]
121pub(crate) mod uds_grpc_stream {
122    use hyper_util::rt::TokioIo;
123    use opentelemetry::trace::FutureExt;
124    use std::convert::Infallible;
125    use tempfile::TempDir;
126    use tokio::net::UnixStream;
127    use tokio_stream::wrappers::UnixListenerStream;
128    use tonic::server::NamedService;
129    use tonic::transport::{Channel, Endpoint, Server};
130
    /// This can be any valid URL. It is necessary to initialize an [`Endpoint`]; the connector
    /// built in `build_client_channel` ignores the URI it receives and always connects to the
    /// Unix domain socket instead.
    #[allow(dead_code)]
    static FAUX_URL: &str = "http://api.example.rigetti.com";
134
135    fn build_server_stream(path: String) -> Result<UnixListenerStream, Error> {
136        let uds =
137            tokio::net::UnixListener::bind(path.clone()).map_err(|_| Error::BindUnixPath(path))?;
138        Ok(UnixListenerStream::new(uds))
139    }
140
141    async fn build_client_channel(path: String) -> Result<Channel, Error> {
142        let connector = tower::service_fn(move |_: tonic::transport::Uri| {
143            let path = path.clone();
144            async move {
145                let connection = UnixStream::connect(path).await?;
146                Ok::<_, std::io::Error>(TokioIo::new(connection))
147            }
148        });
149        let channel = Endpoint::try_from(FAUX_URL)
150            .map_err(|source| Error::Endpoint {
151                url: FAUX_URL.to_string(),
152                source,
153            })?
154            .connect_with_connector(connector)
155            .await
156            .map_err(|source| Error::Connect { source })?;
157        Ok(channel)
158    }
159
160    /// Errors when connecting a server-client [`Channel`] over a Unix domain socket for testing gRPC
161    /// services.
162    #[derive(thiserror::Error, Debug)]
163    pub enum Error {
164        /// Failed to initialize an [`Endpoint`] for the provided URL.
165        #[error("failed to initialize endpoint for {url}: {source}")]
166        Endpoint {
167            /// The URL that failed to initialize.
168            url: String,
169            /// The source of the error.
170            #[source]
171            source: tonic::transport::Error,
172        },
173        /// Failed to connect to the provided endpoint.
174        #[error("failed to connect to endpoint: {source}")]
175        Connect {
176            /// The source of the error.
177            #[source]
178            source: tonic::transport::Error,
179        },
180        /// Failed to bind the provided path as a Unix domain socket.
181        #[error("failed to bind path as unix listener: {0}")]
182        BindUnixPath(String),
183        /// Failed to initialize a temporary file for the Unix domain socket.
184        #[error("failed in initialize tempfile: {0}")]
185        TempFile(#[from] std::io::Error),
186        /// Failed to convert the tempfile to an [`OsString`].
187        #[error("failed to convert tempfile to OsString")]
188        TempFileOsString,
189        /// Failed to bind to the Unix domain socket.
190        #[error("failed to bind to UDS: {0}")]
191        TonicTransport(#[from] tonic::transport::Error),
192    }
193
194    /// Serve the provided gRPC service over a Unix domain socket for the duration of the
195    /// provided callback.
196    ///
197    /// # Errors
198    ///
199    /// See [`Error`].
200    #[allow(clippy::significant_drop_tightening)]
201    pub async fn serve<S, F, R>(service: S, f: F) -> Result<(), Error>
202    where
203        S: tower::Service<
204                http::Request<tonic::body::BoxBody>,
205                Response = http::Response<tonic::body::BoxBody>,
206                Error = Infallible,
207            > + NamedService
208            + Clone
209            + Send
210            + 'static,
211        S::Future: Send,
212        F: FnOnce(Channel) -> R + Send,
213        R: std::future::Future<Output = ()> + Send,
214    {
215        let directory = TempDir::new()?;
216        let file = directory.path().as_os_str();
217        let file = file.to_os_string();
218        let file = file.into_string().map_err(|_| Error::TempFileOsString)?;
219        let file = format!("{file}/test.sock");
220        let stream = build_server_stream(file.clone())?;
221
222        let channel = build_client_channel(file).await?;
223        let serve_future = Server::builder()
224            .add_service(service)
225            .serve_with_incoming(stream);
226
227        tokio::select! {
228           result = serve_future => result.map_err(Error::TonicTransport),
229           () = f(channel).with_current_context() => Ok(()),
230        }
231    }
232}
233
234#[cfg(test)]
235#[cfg(feature = "tracing-opentelemetry")]
236mod otel_tests {
237    use opentelemetry::propagation::TextMapPropagator;
238    use opentelemetry::trace::{TraceContextExt, TraceId};
239    use opentelemetry_http::HeaderExtractor;
240    use opentelemetry_sdk::propagation::TraceContextPropagator;
241    use serde::{Deserialize, Serialize};
242    use std::time::{Duration, SystemTime};
243    use tonic::codegen::http::{HeaderMap, HeaderValue};
244    use tonic::server::NamedService;
245    use tonic::Request;
246    use tonic_health::pb::health_check_response::ServingStatus;
247    use tonic_health::pb::health_server::{Health, HealthServer};
248    use tonic_health::{pb::health_client::HealthClient, server::HealthService};
249
250    use crate::tonic::{uds_grpc_stream, wrap_channel_with_tracing};
251    use qcs_api_client_common::configuration::ClientConfiguration;
252    use qcs_api_client_common::configuration::{AuthServer, OAuthSession, RefreshToken};
253
    /// Path of the gRPC health `Check` method. The tests use it both as a tracing filter path
    /// and as the name of the span they look for.
    static HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// Base URL handed to `wrap_channel_with_tracing` by the tests in this module.
    const FAUX_BASE_URL: &str = "http://api.example.rigetti.com";
257
258    /// Test that when tracing is enabled and no filter is set, any request is properly traced.
259    #[tokio::test]
260    async fn test_tracing_enabled_no_filter() {
261        use qcs_api_client_common::tracing_configuration::TracingConfiguration;
262
263        let tracing_configuration = TracingConfiguration::builder()
264            .set_propagate_otel_context(true)
265            .build();
266        let client_config = ClientConfiguration::builder()
267            .tracing_configuration(Some(tracing_configuration))
268            .oauth_session(Some(OAuthSession::from_refresh_token(
269                RefreshToken::new("refresh_token".to_string()),
270                AuthServer::default(),
271                Some(create_jwt()),
272            )))
273            .build()
274            .expect("should be able to build client config");
275        assert_grpc_health_check_traced(client_config).await;
276    }
277
278    /// Test that when tracing is enabled, no filter is set, and OTel context propagation is
279    /// disabled, any request is properly traced without propagation.
280    #[tokio::test]
281    async fn test_tracing_enabled_no_filter_nor_otel_context_propagation() {
282        use qcs_api_client_common::tracing_configuration::TracingConfiguration;
283
284        let tracing_configuration = TracingConfiguration::default();
285        let client_config = ClientConfiguration::builder()
286            .tracing_configuration(Some(tracing_configuration))
287            .oauth_session(Some(OAuthSession::from_refresh_token(
288                RefreshToken::new("refresh_token".to_string()),
289                AuthServer::default(),
290                Some(create_jwt()),
291            )))
292            .build()
293            .expect("failed to build client config");
294        assert_grpc_health_check_traced(client_config).await;
295    }
296
297    /// Test that when tracing is enabled and the filter matches the gRPC request, the request is
298    /// properly traced.
299    #[tokio::test]
300    async fn test_tracing_enabled_filter_passed() {
301        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};
302
303        let tracing_filter = TracingFilter::builder()
304            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
305            .expect("gRPC healthcheck path should be valid filter path")
306            .build();
307
308        let tracing_configuration = TracingConfiguration::builder()
309            .set_filter(Some(tracing_filter))
310            .set_propagate_otel_context(true)
311            .build();
312
313        let client_config = ClientConfiguration::builder()
314            .tracing_configuration(Some(tracing_configuration))
315            .oauth_session(Some(OAuthSession::from_refresh_token(
316                RefreshToken::new("refresh_token".to_string()),
317                AuthServer::default(),
318                Some(create_jwt()),
319            )))
320            .build()
321            .expect("failed to build client config");
322        assert_grpc_health_check_traced(client_config).await;
323    }
324
    /// Runs a health check against a [`SleepyHealthService`] through a channel wrapped with
    /// `wrap_channel_with_tracing` and asserts that the request was traced.
    ///
    /// Server side, an interceptor checks the incoming metadata: when the configuration asks
    /// for OTel context propagation the `traceparent` header must carry the test's trace id,
    /// otherwise no OTel context headers may be present. Client side, the recorded gRPC span
    /// must last at least as long as the server sleeps (50ms) and carry an `Ok` status code
    /// attribute.
    ///
    /// # Panics
    ///
    /// Panics if `client_configuration` has no tracing configuration, or if any assertion
    /// fails.
    #[allow(clippy::future_not_send)]
    async fn assert_grpc_health_check_traced(client_configuration: ClientConfiguration) {
        use opentelemetry::trace::FutureExt;

        // Decides which server-side header validation applies (see the interceptor below).
        let propagate_otel_context = client_configuration.tracing_configuration().is_some_and(
            qcs_api_client_common::tracing_configuration::TracingConfiguration::propagate_otel_context,
        );
        let spans = tracing_test::start(
            "test_trace_id_propagation",
            |trace_id, _span_id| async move {
                // The server sleeps so that the client span has a measurable minimum duration.
                let sleepy_health_service = SleepyHealthService {
                    sleep_time: Duration::from_millis(50),
                };

                let interceptor = move |req| {
                    if propagate_otel_context {
                        validate_trace_id_propagated(trace_id, req)
                    } else {
                        validate_otel_context_not_propagated(req)
                    }
                };
                let health_server =
                    HealthServer::with_interceptor(sleepy_health_service, interceptor);

                uds_grpc_stream::serve(health_server, |channel| {
                    async {
                        let response = HealthClient::new(wrap_channel_with_tracing(
                            channel,
                            FAUX_BASE_URL.to_string(),
                            client_configuration
                                .tracing_configuration()
                                .unwrap()
                                .clone(),
                        ))
                        .check(Request::new(tonic_health::pb::HealthCheckRequest {
                            service: <HealthServer<HealthService> as NamedService>::NAME
                                .to_string(),
                        }))
                        .await
                        .unwrap();
                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
                    }
                    // Run the request with the test span's context as the current context.
                    .with_current_context()
                })
                .await
                .unwrap();
            },
        )
        .await
        .unwrap();

        // The gRPC span is identified by its name, which is the request path.
        let grpc_span = spans
            .iter()
            .find(|span| span.name == *HEALTH_CHECK_PATH)
            .expect("failed to find gRPC span");
        let duration = grpc_span
            .end_time
            .duration_since(grpc_span.start_time)
            .expect("span should have valid timestamps");
        // The span must cover at least the server's 50ms sleep.
        assert!(duration.as_millis() >= 50u128);

        let status_code_attribute =
            tracing_test::get_span_attribute(grpc_span, "rpc.grpc.status_code")
                .expect("gRPC span should have status code attribute");
        assert_eq!(status_code_attribute, (tonic::Code::Ok as u8).to_string());
    }
393
394    /// Test that when tracing is enabled but the request does not match the configured filter, the
395    /// request is not traced.
396    #[tokio::test]
397    async fn test_tracing_enabled_filter_not_passed() {
398        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};
399
400        let tracing_filter = TracingFilter::builder()
401            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
402            .expect("healthcheck path should be a valid filter path")
403            .set_is_negated(true)
404            .build();
405
406        let tracing_configuration = TracingConfiguration::builder()
407            .set_filter(Some(tracing_filter))
408            .set_propagate_otel_context(true)
409            .build();
410
411        let client_config = ClientConfiguration::builder()
412            .tracing_configuration(Some(tracing_configuration))
413            .oauth_session(Some(OAuthSession::from_refresh_token(
414                RefreshToken::new("refresh_token".to_string()),
415                AuthServer::default(),
416                Some(create_jwt()),
417            )))
418            .build()
419            .expect("should be able to build client config");
420
421        assert_grpc_health_check_not_traced(client_config.clone()).await;
422    }
423
    /// Check that the traceparent metadata header is not set on the gRPC request and no tracing
    /// spans are produced for the gRPC request.
    ///
    /// The header check happens server side in an interceptor; the span check happens on the
    /// spans recorded for the test trace once the request has completed.
    ///
    /// # Panics
    ///
    /// Panics if `client_configuration` has no tracing configuration, or if any assertion
    /// fails.
    #[allow(clippy::future_not_send)]
    async fn assert_grpc_health_check_not_traced(client_configuration: ClientConfiguration) {
        use opentelemetry::trace::FutureExt;

        let spans =
            tracing_test::start("test_tracing_disabled", |_trace_id, _span_id| async move {
                // Rejects the request if any OTel context header reaches the server.
                let interceptor = validate_otel_context_not_propagated;
                let health_server = HealthServer::with_interceptor(
                    SleepyHealthService {
                        // No span duration is asserted here, so the server need not sleep.
                        sleep_time: Duration::from_millis(0),
                    },
                    interceptor,
                );

                uds_grpc_stream::serve(health_server, |channel| {
                    async {
                        let response = HealthClient::new(wrap_channel_with_tracing(
                            channel,
                            FAUX_BASE_URL.to_string(),
                            client_configuration
                                .tracing_configuration()
                                .unwrap()
                                .clone(),
                        ))
                        .check(Request::new(tonic_health::pb::HealthCheckRequest {
                            service: <HealthServer<HealthService> as NamedService>::NAME
                                .to_string(),
                        }))
                        .await
                        .unwrap();
                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
                    }
                    .with_current_context()
                })
                .await
                .unwrap();
            })
            .await
            .unwrap();

        // None of the spans recorded for this trace may be named after the health check path.
        assert!(spans.iter().all(|span| { span.name != *HEALTH_CHECK_PATH }));
    }
468
    /// Secret that `create_jwt` uses to sign the test token.
    const JWT_SECRET: &[u8] = b"top-secret";
470
    /// Claims serialized into the JWT produced by `create_jwt`.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct Claims {
        /// Subject claim.
        sub: String,
        /// Expiration claim; `create_jwt` fills it with seconds since the Unix epoch.
        exp: u64,
    }
476
477    /// Create an HS256 signed JWT token with sub and exp claims. This is good enough to pass the
478    /// [`RefreshService`] token validation.
479    pub(crate) fn create_jwt() -> String {
480        use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
481        let expiration = SystemTime::now()
482            .checked_add(Duration::from_secs(60))
483            .unwrap()
484            .duration_since(SystemTime::UNIX_EPOCH)
485            .unwrap()
486            .as_secs();
487
488        let claims = Claims {
489            sub: "test-uid".to_string(),
490            exp: expiration,
491        };
492        // The client doesn't check the signature, so for convenience here, we just sign with HS256
493        // instead of RS256.
494        let header = Header::new(Algorithm::HS256);
495        encode(&header, &claims, &EncodingKey::from_secret(JWT_SECRET)).unwrap()
496    }
497
    /// Failures raised by the server-side interceptors when the incoming request metadata does
    /// not match what a test expects. Converted into a [`tonic::Status`] before being returned
    /// to the client.
    #[derive(Debug, thiserror::Error)]
    #[allow(variant_size_differences)]
    enum ServerAssertionError {
        /// The `traceparent` header was missing, unreadable, or carried a different trace id;
        /// the payload describes which.
        #[error("trace id did not propagate to server: {0}")]
        UnexpectedTraceId(String),
        /// A `traceparent` or `tracestate` header was present although none was expected.
        #[error("otel context headers unexpectedly sent to server")]
        UnexpectedOTelContextHeaders,
    }
506
507    impl From<ServerAssertionError> for tonic::Status {
508        fn from(server_assertion_error: ServerAssertionError) -> Self {
509            Self::invalid_argument(server_assertion_error.to_string())
510        }
511    }
512
513    /// Given an incoming gRPC request, validate that the the specified [`TraceId`] is propagated
514    /// via the `traceparent` metadata header.
515    fn validate_trace_id_propagated(
516        trace_id: TraceId,
517        req: Request<()>,
518    ) -> Result<Request<()>, tonic::Status> {
519        req.metadata()
520            .get("traceparent")
521            .ok_or_else(|| {
522                ServerAssertionError::UnexpectedTraceId(
523                    "request traceparent metadata not present".to_string(),
524                )
525            })
526            .and_then(|traceparent| {
527                let mut headers = HeaderMap::new();
528                headers.insert(
529                    "traceparent",
530                    HeaderValue::from_str(traceparent.to_str().map_err(|_| {
531                        ServerAssertionError::UnexpectedTraceId(
532                            "failed to deserialize trace parent".to_string(),
533                        )
534                    })?)
535                    .map_err(|_| {
536                        ServerAssertionError::UnexpectedTraceId(
537                            "failed to serialize trace parent as HeaderValue".to_string(),
538                        )
539                    })?,
540                );
541                Ok(headers)
542            })
543            .and_then(|headers| {
544                let extractor = HeaderExtractor(&headers);
545                let propagator = TraceContextPropagator::new();
546                let context = propagator.extract(&extractor);
547                let propagated_trace_id = context.span().span_context().trace_id();
548                if propagated_trace_id == trace_id {
549                    Ok(req)
550                } else {
551                    Err(ServerAssertionError::UnexpectedTraceId(format!(
552                        "expected trace id {trace_id}, got {propagated_trace_id}",
553                    )))
554                }
555            })
556            .map_err(Into::into)
557    }
558
559    /// Simply validate that the `traceparent` and `tracestate` metadata headers are not present
560    /// on the incoming gRPC.
561    fn validate_otel_context_not_propagated(
562        req: Request<()>,
563    ) -> Result<Request<()>, tonic::Status> {
564        if req.metadata().get("traceparent").is_some() || req.metadata().get("tracestate").is_some()
565        {
566            Err(ServerAssertionError::UnexpectedOTelContextHeaders.into())
567        } else {
568            Ok(req)
569        }
570    }
571
    /// An implementation of the gRPC [`HealthService`] that sleeps for the configured duration
    /// before returning a response to `check`.
    /// This is useful for making assertions on span durations. It is also necessary in order to
    /// wrap the [`HealthServer`] in an interceptor, which is not possible with public methods in
    /// the health service crate.
    ///
    /// Derived in part from <https://github.com/hyperium/tonic/blob/master/tonic-health/src/generated/grpc.health.v1.rs/>
    #[derive(Clone)]
    struct SleepyHealthService {
        /// How long `check` sleeps before it responds.
        sleep_time: Duration,
    }
582
583    #[tonic::async_trait]
584    impl Health for SleepyHealthService {
585        async fn check(
586            &self,
587            _request: Request<tonic_health::pb::HealthCheckRequest>,
588        ) -> Result<tonic::Response<tonic_health::pb::HealthCheckResponse>, tonic::Status> {
589            tokio::time::sleep(self.sleep_time).await;
590            let response = tonic_health::pb::HealthCheckResponse {
591                status: ServingStatus::Serving as i32,
592            };
593            Ok(tonic::Response::new(response))
594        }
595
596        type WatchStream = tokio_stream::wrappers::ReceiverStream<
597            Result<tonic_health::pb::HealthCheckResponse, tonic::Status>,
598        >;
599
600        async fn watch(
601            &self,
602            _request: Request<tonic_health::pb::HealthCheckRequest>,
603        ) -> Result<tonic::Response<Self::WatchStream>, tonic::Status> {
604            let (tx, rx) = tokio::sync::mpsc::channel(1);
605            let response = tonic_health::pb::HealthCheckResponse {
606                status: ServingStatus::Serving as i32,
607            };
608            tx.send(Ok(response)).await.unwrap();
609            Ok(tonic::Response::new(
610                tokio_stream::wrappers::ReceiverStream::new(rx),
611            ))
612        }
613    }
614
615    /// We need a single global ``SpanProcessor`` because these tests have to work using
616    /// ``opentelemetry::global`` and ``tracing_subscriber::set_global_default``. Otherwise,
617    /// we can't make guarantees about where the spans are processed and therefore could
618    /// not make assertions about the traced spans.
619    mod tracing_test {
620        use futures_util::Future;
621        use opentelemetry::global::BoxedTracer;
622        use opentelemetry::trace::{
623            mark_span_as_active, FutureExt, Span, SpanId, TraceId, TraceResult, Tracer,
624            TracerProvider,
625        };
626        use opentelemetry_sdk::export::trace::SpanData;
627        use opentelemetry_sdk::trace::SpanProcessor;
628        use std::collections::HashMap;
629        use std::sync::{Arc, RwLock};
630
        /// Start a new test span and run the specified callback with the span as the active span.
        /// The callback is provided the trace and span ids (in that order). At the end of the
        /// callback, we wait for the span to be processed and then return all of the spans that
        /// were processed for this particular test.
        ///
        /// # Errors
        ///
        /// Returns a `String` describing the failure if one of the cache locks is poisoned or if
        /// no notification was registered for the span.
        ///
        /// # Panics
        ///
        /// Panics if the global [`CACHE`] has not been initialized.
        #[allow(clippy::future_not_send)]
        pub(crate) async fn start<F, R>(name: &'static str, f: F) -> Result<Vec<SpanData>, String>
        where
            F: FnOnce(TraceId, SpanId) -> R + Send,
            R: Future<Output = ()> + Send,
        {
            // Also initializes `CACHE` on first use, so the `expect` below holds.
            let tracer = CacheProcessor::tracer();
            let span = tracer.start(name);
            let span_id = span.span_context().span_id();
            let trace_id = span.span_context().trace_id();
            let cache = CACHE
                .get()
                .expect("cache should be initialized with cache tracer");
            // Register the notification before the span can possibly end.
            cache.subscribe(span_id)?;
            async {
                // The span is active only for the duration of this block.
                let _guard = mark_span_as_active(span);
                f(trace_id, span_id).with_current_context().await;
            }
            .await;

            // Wait for the test span to be processed.
            // NOTE(review): this assumes the processor's notification is still observable if
            // the span ended before this call - confirm against the notify method used in
            // `on_end`.
            cache.notified(span_id).await?;

            // Remove and return the spans processed for this test.
            let mut data = cache.data.write().map_err(|e| e.to_string())?;
            Ok(data.remove(&trace_id).unwrap_or_default())
        }
662
        /// Process-wide span cache shared by all tests; initialized by `CacheProcessor::tracer`.
        static CACHE: once_cell::sync::OnceCell<CacheProcessor> = once_cell::sync::OnceCell::new();
664
        /// A span processor that stores every finished span in memory, grouped by trace id, and
        /// can notify a waiter when a specific span has finished.
        #[derive(Debug, Clone, Default)]
        struct CacheProcessor {
            /// Finished spans, keyed by the trace they belong to.
            data: Arc<RwLock<HashMap<TraceId, Vec<SpanData>>>>,
            /// One notifier per subscribed span id, signalled when that span ends.
            notifications: Arc<RwLock<HashMap<SpanId, Arc<tokio::sync::Notify>>>>,
        }
670
671        impl CacheProcessor {
672            /// Initializes the [`CACHE`] and sets the global `opentelemetry::global::tracer_provider` and `tracing_subscriber::global_default`.
673            /// These initializations occur safely behind a `OnceCell` initialization, so they can be used by several tests.
674            fn tracer() -> BoxedTracer {
675                use tracing_subscriber::layer::SubscriberExt;
676
677                CACHE.get_or_init(|| {
678                    let processor = Self::default();
679                    let provider = opentelemetry_sdk::trace::TracerProvider::builder()
680                        .with_span_processor(processor.clone())
681                        .build();
682                    opentelemetry::global::set_tracer_provider(provider.clone());
683                    let tracer = provider.tracer("test_channel");
684                    let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
685                    let subscriber = tracing_subscriber::Registry::default().with(telemetry);
686                    tracing::subscriber::set_global_default(subscriber)
687                        .expect("tracing subscriber already set");
688                    processor
689                });
690                opentelemetry::global::tracer("test_channel")
691            }
692
693            /// Ensure that a [`Notification`] exists for the provided span id.
694            fn subscribe(&self, span_id: SpanId) -> Result<(), String> {
695                let notify = Arc::new(tokio::sync::Notify::new());
696                self.notifications
697                    .write()
698                    .map_err(|e| e.to_string())?
699                    .insert(span_id, notify);
700                Ok(())
701            }
702
703            /// Wait for the specified [`SpanId`] to be processed.
704            async fn notified(&self, span_id: SpanId) -> Result<(), String> {
705                let notify = {
706                    let notifications = self.notifications.read().map_err(|e| format!("{e}"))?;
707                    notifications
708                        .get(&span_id)
709                        .ok_or("span notification never subscribed")?
710                        .clone()
711                };
712                notify.notified().await;
713                Ok(())
714            }
715        }
716
717        impl SpanProcessor for CacheProcessor {
718            fn on_start(
719                &self,
720                _span: &mut opentelemetry_sdk::trace::Span,
721                _cx: &opentelemetry::Context,
722            ) {
723            }
724
725            fn on_end(&self, span: SpanData) {
726                let trace_id = span.span_context.trace_id();
727                let span_id = span.span_context.span_id();
728                {
729                    let mut data = self
730                        .data
731                        .write()
732                        .expect("failed to write access cache span data");
733                    if let Some(spans) = data.get_mut(&trace_id) {
734                        spans.push(span);
735                    } else {
736                        data.insert(trace_id, vec![span]);
737                    }
738                }
739
740                if let Some(notify) = self
741                    .notifications
742                    .read()
743                    .expect("failed to read access span notifications during span processing")
744                    .get(&span_id)
745                {
746                    notify.notify_waiters();
747                }
748            }
749
750            /// This is a no-op because spans are processed synchronously in `on_end`.
751            fn force_flush(&self) -> TraceResult<()> {
752                Ok(())
753            }
754
755            fn shutdown(&self) -> TraceResult<()> {
756                Ok(())
757            }
758        }
759
760        /// Get the Opentelemetry attribute value for the provided key.
761        #[must_use]
762        pub(super) fn get_span_attribute(
763            span_data: &SpanData,
764            key: &'static str,
765        ) -> Option<String> {
766            span_data
767                .attributes
768                .iter()
769                .find(|attr| attr.key.to_string() == *key)
770                .map(|kv| kv.value.to_string())
771        }
772    }
773}