1use std::convert::Infallible;
2
3use futures_util::pin_mut;
6use http::Request;
7use http_body::Body;
8use http_body_util::BodyExt;
9use tonic::body::BoxBody;
10
11mod channel;
12mod common;
13mod error;
14#[cfg(feature = "grpc-web")]
15mod grpc_web;
16mod refresh;
17mod retry;
18#[cfg(feature = "tracing")]
19mod trace;
20
21pub use channel::*;
22pub use error::*;
23#[cfg(feature = "grpc-web")]
24pub use grpc_web::*;
25pub use refresh::*;
26pub use retry::*;
27#[cfg(feature = "tracing")]
28pub use trace::*;
29
/// Errors that can occur while duplicating a request so that it can be sent more than once.
#[derive(Debug, thiserror::Error)]
pub enum RequestBodyDuplicationError {
    /// A [`tonic::Status`] raised while reading the body; displayed transparently.
    ///
    /// In this file it is produced when polling a frame fails or when a frame
    /// that is not a data frame is encountered.
    #[error(transparent)]
    Status(#[from] tonic::Status),
    /// An [`http::Error`], converted automatically through `#[from]`.
    ///
    /// In this file it is produced when the duplicated requests are assembled
    /// with the `http` request builder.
    #[error("failed to read request body for request clone: {0}")]
    HttpBody(#[from] http::Error),
}
42
43impl From<RequestBodyDuplicationError> for tonic::Status {
44 fn from(err: RequestBodyDuplicationError) -> tonic::Status {
45 match err {
46 RequestBodyDuplicationError::Status(status) => status,
47 RequestBodyDuplicationError::HttpBody(error) => tonic::Status::cancelled(format!(
48 "failed to read request body for request clone: {error}"
49 )),
50 }
51 }
52}
53
/// Shorthand for a [`Result`] whose error type is [`RequestBodyDuplicationError`].
type RequestBodyDuplicationResult<T> = Result<T, RequestBodyDuplicationError>;
55
56async fn build_duplicate_frame_bytes(
62 mut request: Request<tonic::body::BoxBody>,
63) -> RequestBodyDuplicationResult<(tonic::body::BoxBody, tonic::body::BoxBody)> {
64 let mut bytes = Vec::new();
65
66 let body = request.body_mut();
67 pin_mut!(body);
68 while let Some(result) = std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
69 let frame_bytes = result?.into_data().map_err(|frame| {
70 tonic::Status::cancelled(format!(
71 "cannot duplicate a frame that is not a data frame: {frame:?}"
72 ))
73 })?;
74 bytes.extend(frame_bytes);
75 }
76
77 let bytes = http_body_util::Full::from(bytes)
78 .map_err(|_: Infallible| -> tonic::Status { unreachable!() });
79 Ok((
80 tonic::body::BoxBody::new(bytes.clone()),
81 tonic::body::BoxBody::new(bytes),
82 ))
83}
84
85async fn build_duplicate_request(
88 req: Request<BoxBody>,
89) -> RequestBodyDuplicationResult<(Request<BoxBody>, Request<BoxBody>)> {
90 let mut builder_1 = Request::builder()
91 .method(req.method().clone())
92 .uri(req.uri().clone())
93 .version(req.version());
94
95 let mut builder_2 = Request::builder()
96 .method(req.method().clone())
97 .uri(req.uri().clone())
98 .version(req.version());
99
100 for (key, val) in req.headers() {
101 builder_1 = builder_1.header(key.clone(), val.clone());
102 builder_2 = builder_2.header(key.clone(), val.clone());
103 }
104
105 let (body_1, body_2) = build_duplicate_frame_bytes(req).await?;
106
107 let req_1 = builder_1.body(body_1)?;
108
109 let req_2 = builder_2.body(body_2)?;
110
111 Ok((req_1, req_2))
112}
113
114#[cfg(test)]
121pub(crate) mod uds_grpc_stream {
122 use hyper_util::rt::TokioIo;
123 use opentelemetry::trace::FutureExt;
124 use std::convert::Infallible;
125 use tempfile::TempDir;
126 use tokio::net::UnixStream;
127 use tokio_stream::wrappers::UnixListenerStream;
128 use tonic::server::NamedService;
129 use tonic::transport::{Channel, Endpoint, Server};
130
    /// Placeholder URL handed to the endpoint builder. The connector used in
    /// this module ignores the URI it receives and dials a unix socket instead.
    #[allow(dead_code)]
    static FAUX_URL: &str = "http://api.example.rigetti.com";
134
135 fn build_server_stream(path: String) -> Result<UnixListenerStream, Error> {
136 let uds =
137 tokio::net::UnixListener::bind(path.clone()).map_err(|_| Error::BindUnixPath(path))?;
138 Ok(UnixListenerStream::new(uds))
139 }
140
141 async fn build_client_channel(path: String) -> Result<Channel, Error> {
142 let connector = tower::service_fn(move |_: tonic::transport::Uri| {
143 let path = path.clone();
144 async move {
145 let connection = UnixStream::connect(path).await?;
146 Ok::<_, std::io::Error>(TokioIo::new(connection))
147 }
148 });
149 let channel = Endpoint::try_from(FAUX_URL)
150 .map_err(|source| Error::Endpoint {
151 url: FAUX_URL.to_string(),
152 source,
153 })?
154 .connect_with_connector(connector)
155 .await
156 .map_err(|source| Error::Connect { source })?;
157 Ok(channel)
158 }
159
160 #[derive(thiserror::Error, Debug)]
163 pub enum Error {
164 #[error("failed to initialize endpoint for {url}: {source}")]
166 Endpoint {
167 url: String,
169 #[source]
171 source: tonic::transport::Error,
172 },
173 #[error("failed to connect to endpoint: {source}")]
175 Connect {
176 #[source]
178 source: tonic::transport::Error,
179 },
180 #[error("failed to bind path as unix listener: {0}")]
182 BindUnixPath(String),
183 #[error("failed in initialize tempfile: {0}")]
185 TempFile(#[from] std::io::Error),
186 #[error("failed to convert tempfile to OsString")]
188 TempFileOsString,
189 #[error("failed to bind to UDS: {0}")]
191 TonicTransport(#[from] tonic::transport::Error),
192 }
193
194 #[allow(clippy::significant_drop_tightening)]
201 pub async fn serve<S, F, R>(service: S, f: F) -> Result<(), Error>
202 where
203 S: tower::Service<
204 http::Request<tonic::body::BoxBody>,
205 Response = http::Response<tonic::body::BoxBody>,
206 Error = Infallible,
207 > + NamedService
208 + Clone
209 + Send
210 + 'static,
211 S::Future: Send,
212 F: FnOnce(Channel) -> R + Send,
213 R: std::future::Future<Output = ()> + Send,
214 {
215 let directory = TempDir::new()?;
216 let file = directory.path().as_os_str();
217 let file = file.to_os_string();
218 let file = file.into_string().map_err(|_| Error::TempFileOsString)?;
219 let file = format!("{file}/test.sock");
220 let stream = build_server_stream(file.clone())?;
221
222 let channel = build_client_channel(file).await?;
223 let serve_future = Server::builder()
224 .add_service(service)
225 .serve_with_incoming(stream);
226
227 tokio::select! {
228 result = serve_future => result.map_err(Error::TonicTransport),
229 () = f(channel).with_current_context() => Ok(()),
230 }
231 }
232}
233
234#[cfg(test)]
235#[cfg(feature = "tracing-opentelemetry")]
236mod otel_tests {
237 use opentelemetry::propagation::TextMapPropagator;
238 use opentelemetry::trace::{TraceContextExt, TraceId};
239 use opentelemetry_http::HeaderExtractor;
240 use opentelemetry_sdk::propagation::TraceContextPropagator;
241 use serde::{Deserialize, Serialize};
242 use std::time::{Duration, SystemTime};
243 use tonic::codegen::http::{HeaderMap, HeaderValue};
244 use tonic::server::NamedService;
245 use tonic::Request;
246 use tonic_health::pb::health_check_response::ServingStatus;
247 use tonic_health::pb::health_server::{Health, HealthServer};
248 use tonic_health::{pb::health_client::HealthClient, server::HealthService};
249
250 use crate::tonic::{uds_grpc_stream, wrap_channel_with_tracing};
251 use qcs_api_client_common::configuration::ClientConfiguration;
252 use qcs_api_client_common::configuration::{AuthServer, OAuthSession, RefreshToken};
253
    /// gRPC path of the health `Check` method; the tests use it both as a
    /// tracing filter path and as the expected span name.
    static HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// Placeholder base URL passed to `wrap_channel_with_tracing` in these tests.
    const FAUX_BASE_URL: &str = "http://api.example.rigetti.com";
257
258 #[tokio::test]
260 async fn test_tracing_enabled_no_filter() {
261 use qcs_api_client_common::tracing_configuration::TracingConfiguration;
262
263 let tracing_configuration = TracingConfiguration::builder()
264 .set_propagate_otel_context(true)
265 .build();
266 let client_config = ClientConfiguration::builder()
267 .tracing_configuration(Some(tracing_configuration))
268 .oauth_session(Some(OAuthSession::from_refresh_token(
269 RefreshToken::new("refresh_token".to_string()),
270 AuthServer::default(),
271 Some(create_jwt()),
272 )))
273 .build()
274 .expect("should be able to build client config");
275 assert_grpc_health_check_traced(client_config).await;
276 }
277
278 #[tokio::test]
281 async fn test_tracing_enabled_no_filter_nor_otel_context_propagation() {
282 use qcs_api_client_common::tracing_configuration::TracingConfiguration;
283
284 let tracing_configuration = TracingConfiguration::default();
285 let client_config = ClientConfiguration::builder()
286 .tracing_configuration(Some(tracing_configuration))
287 .oauth_session(Some(OAuthSession::from_refresh_token(
288 RefreshToken::new("refresh_token".to_string()),
289 AuthServer::default(),
290 Some(create_jwt()),
291 )))
292 .build()
293 .expect("failed to build client config");
294 assert_grpc_health_check_traced(client_config).await;
295 }
296
297 #[tokio::test]
300 async fn test_tracing_enabled_filter_passed() {
301 use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};
302
303 let tracing_filter = TracingFilter::builder()
304 .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
305 .expect("gRPC healthcheck path should be valid filter path")
306 .build();
307
308 let tracing_configuration = TracingConfiguration::builder()
309 .set_filter(Some(tracing_filter))
310 .set_propagate_otel_context(true)
311 .build();
312
313 let client_config = ClientConfiguration::builder()
314 .tracing_configuration(Some(tracing_configuration))
315 .oauth_session(Some(OAuthSession::from_refresh_token(
316 RefreshToken::new("refresh_token".to_string()),
317 AuthServer::default(),
318 Some(create_jwt()),
319 )))
320 .build()
321 .expect("failed to build client config");
322 assert_grpc_health_check_traced(client_config).await;
323 }
324
325 #[allow(clippy::future_not_send)]
328 async fn assert_grpc_health_check_traced(client_configuration: ClientConfiguration) {
329 use opentelemetry::trace::FutureExt;
330
331 let propagate_otel_context = client_configuration.tracing_configuration().is_some_and(
332 qcs_api_client_common::tracing_configuration::TracingConfiguration::propagate_otel_context,
333 );
334 let spans = tracing_test::start(
335 "test_trace_id_propagation",
336 |trace_id, _span_id| async move {
337 let sleepy_health_service = SleepyHealthService {
338 sleep_time: Duration::from_millis(50),
339 };
340
341 let interceptor = move |req| {
342 if propagate_otel_context {
343 validate_trace_id_propagated(trace_id, req)
344 } else {
345 validate_otel_context_not_propagated(req)
346 }
347 };
348 let health_server =
349 HealthServer::with_interceptor(sleepy_health_service, interceptor);
350
351 uds_grpc_stream::serve(health_server, |channel| {
352 async {
353 let response = HealthClient::new(wrap_channel_with_tracing(
354 channel,
355 FAUX_BASE_URL.to_string(),
356 client_configuration
357 .tracing_configuration()
358 .unwrap()
359 .clone(),
360 ))
361 .check(Request::new(tonic_health::pb::HealthCheckRequest {
362 service: <HealthServer<HealthService> as NamedService>::NAME
363 .to_string(),
364 }))
365 .await
366 .unwrap();
367 assert_eq!(response.into_inner().status(), ServingStatus::Serving);
368 }
369 .with_current_context()
370 })
371 .await
372 .unwrap();
373 },
374 )
375 .await
376 .unwrap();
377
378 let grpc_span = spans
379 .iter()
380 .find(|span| span.name == *HEALTH_CHECK_PATH)
381 .expect("failed to find gRPC span");
382 let duration = grpc_span
383 .end_time
384 .duration_since(grpc_span.start_time)
385 .expect("span should have valid timestamps");
386 assert!(duration.as_millis() >= 50u128);
387
388 let status_code_attribute =
389 tracing_test::get_span_attribute(grpc_span, "rpc.grpc.status_code")
390 .expect("gRPC span should have status code attribute");
391 assert_eq!(status_code_attribute, (tonic::Code::Ok as u8).to_string());
392 }
393
394 #[tokio::test]
397 async fn test_tracing_enabled_filter_not_passed() {
398 use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};
399
400 let tracing_filter = TracingFilter::builder()
401 .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
402 .expect("healthcheck path should be a valid filter path")
403 .set_is_negated(true)
404 .build();
405
406 let tracing_configuration = TracingConfiguration::builder()
407 .set_filter(Some(tracing_filter))
408 .set_propagate_otel_context(true)
409 .build();
410
411 let client_config = ClientConfiguration::builder()
412 .tracing_configuration(Some(tracing_configuration))
413 .oauth_session(Some(OAuthSession::from_refresh_token(
414 RefreshToken::new("refresh_token".to_string()),
415 AuthServer::default(),
416 Some(create_jwt()),
417 )))
418 .build()
419 .expect("should be able to build client config");
420
421 assert_grpc_health_check_not_traced(client_config.clone()).await;
422 }
423
424 #[allow(clippy::future_not_send)]
427 async fn assert_grpc_health_check_not_traced(client_configuration: ClientConfiguration) {
428 use opentelemetry::trace::FutureExt;
429
430 let spans =
431 tracing_test::start("test_tracing_disabled", |_trace_id, _span_id| async move {
432 let interceptor = validate_otel_context_not_propagated;
433 let health_server = HealthServer::with_interceptor(
434 SleepyHealthService {
435 sleep_time: Duration::from_millis(0),
436 },
437 interceptor,
438 );
439
440 uds_grpc_stream::serve(health_server, |channel| {
441 async {
442 let response = HealthClient::new(wrap_channel_with_tracing(
443 channel,
444 FAUX_BASE_URL.to_string(),
445 client_configuration
446 .tracing_configuration()
447 .unwrap()
448 .clone(),
449 ))
450 .check(Request::new(tonic_health::pb::HealthCheckRequest {
451 service: <HealthServer<HealthService> as NamedService>::NAME
452 .to_string(),
453 }))
454 .await
455 .unwrap();
456 assert_eq!(response.into_inner().status(), ServingStatus::Serving);
457 }
458 .with_current_context()
459 })
460 .await
461 .unwrap();
462 })
463 .await
464 .unwrap();
465
466 assert!(spans.iter().all(|span| { span.name != *HEALTH_CHECK_PATH }));
467 }
468
    /// Secret used to sign the JWTs created for these tests.
    const JWT_SECRET: &[u8] = b"top-secret";
470
    /// Claims encoded into the test JWT.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct Claims {
        /// Subject claim.
        sub: String,
        /// Expiration claim; `create_jwt` fills it with seconds since the unix epoch.
        exp: u64,
    }
476
477 pub(crate) fn create_jwt() -> String {
480 use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
481 let expiration = SystemTime::now()
482 .checked_add(Duration::from_secs(60))
483 .unwrap()
484 .duration_since(SystemTime::UNIX_EPOCH)
485 .unwrap()
486 .as_secs();
487
488 let claims = Claims {
489 sub: "test-uid".to_string(),
490 exp: expiration,
491 };
492 let header = Header::new(Algorithm::HS256);
495 encode(&header, &claims, &EncodingKey::from_secret(JWT_SECRET)).unwrap()
496 }
497
    /// Failures reported by the server-side interceptors when request metadata
    /// does not match what a test expects.
    #[derive(Debug, thiserror::Error)]
    #[allow(variant_size_differences)]
    enum ServerAssertionError {
        /// The `traceparent` metadata was missing, unreadable, or carried a
        /// different trace id; the payload describes which.
        #[error("trace id did not propagate to server: {0}")]
        UnexpectedTraceId(String),
        /// `traceparent` or `tracestate` metadata was present although none was expected.
        #[error("otel context headers unexpectedly sent to server")]
        UnexpectedOTelContextHeaders,
    }
506
507 impl From<ServerAssertionError> for tonic::Status {
508 fn from(server_assertion_error: ServerAssertionError) -> Self {
509 Self::invalid_argument(server_assertion_error.to_string())
510 }
511 }
512
513 fn validate_trace_id_propagated(
516 trace_id: TraceId,
517 req: Request<()>,
518 ) -> Result<Request<()>, tonic::Status> {
519 req.metadata()
520 .get("traceparent")
521 .ok_or_else(|| {
522 ServerAssertionError::UnexpectedTraceId(
523 "request traceparent metadata not present".to_string(),
524 )
525 })
526 .and_then(|traceparent| {
527 let mut headers = HeaderMap::new();
528 headers.insert(
529 "traceparent",
530 HeaderValue::from_str(traceparent.to_str().map_err(|_| {
531 ServerAssertionError::UnexpectedTraceId(
532 "failed to deserialize trace parent".to_string(),
533 )
534 })?)
535 .map_err(|_| {
536 ServerAssertionError::UnexpectedTraceId(
537 "failed to serialize trace parent as HeaderValue".to_string(),
538 )
539 })?,
540 );
541 Ok(headers)
542 })
543 .and_then(|headers| {
544 let extractor = HeaderExtractor(&headers);
545 let propagator = TraceContextPropagator::new();
546 let context = propagator.extract(&extractor);
547 let propagated_trace_id = context.span().span_context().trace_id();
548 if propagated_trace_id == trace_id {
549 Ok(req)
550 } else {
551 Err(ServerAssertionError::UnexpectedTraceId(format!(
552 "expected trace id {trace_id}, got {propagated_trace_id}",
553 )))
554 }
555 })
556 .map_err(Into::into)
557 }
558
559 fn validate_otel_context_not_propagated(
562 req: Request<()>,
563 ) -> Result<Request<()>, tonic::Status> {
564 if req.metadata().get("traceparent").is_some() || req.metadata().get("tracestate").is_some()
565 {
566 Err(ServerAssertionError::UnexpectedOTelContextHeaders.into())
567 } else {
568 Ok(req)
569 }
570 }
571
    /// Health service used by these tests; its `check` handler sleeps for
    /// `sleep_time` before reporting `Serving`.
    #[derive(Clone)]
    struct SleepyHealthService {
        /// How long `check` sleeps before responding.
        sleep_time: Duration,
    }
582
583 #[tonic::async_trait]
584 impl Health for SleepyHealthService {
585 async fn check(
586 &self,
587 _request: Request<tonic_health::pb::HealthCheckRequest>,
588 ) -> Result<tonic::Response<tonic_health::pb::HealthCheckResponse>, tonic::Status> {
589 tokio::time::sleep(self.sleep_time).await;
590 let response = tonic_health::pb::HealthCheckResponse {
591 status: ServingStatus::Serving as i32,
592 };
593 Ok(tonic::Response::new(response))
594 }
595
596 type WatchStream = tokio_stream::wrappers::ReceiverStream<
597 Result<tonic_health::pb::HealthCheckResponse, tonic::Status>,
598 >;
599
600 async fn watch(
601 &self,
602 _request: Request<tonic_health::pb::HealthCheckRequest>,
603 ) -> Result<tonic::Response<Self::WatchStream>, tonic::Status> {
604 let (tx, rx) = tokio::sync::mpsc::channel(1);
605 let response = tonic_health::pb::HealthCheckResponse {
606 status: ServingStatus::Serving as i32,
607 };
608 tx.send(Ok(response)).await.unwrap();
609 Ok(tonic::Response::new(
610 tokio_stream::wrappers::ReceiverStream::new(rx),
611 ))
612 }
613 }
614
615 mod tracing_test {
620 use futures_util::Future;
621 use opentelemetry::global::BoxedTracer;
622 use opentelemetry::trace::{
623 mark_span_as_active, FutureExt, Span, SpanId, TraceId, TraceResult, Tracer,
624 TracerProvider,
625 };
626 use opentelemetry_sdk::export::trace::SpanData;
627 use opentelemetry_sdk::trace::SpanProcessor;
628 use std::collections::HashMap;
629 use std::sync::{Arc, RwLock};
630
631 #[allow(clippy::future_not_send)]
636 pub(crate) async fn start<F, R>(name: &'static str, f: F) -> Result<Vec<SpanData>, String>
637 where
638 F: FnOnce(TraceId, SpanId) -> R + Send,
639 R: Future<Output = ()> + Send,
640 {
641 let tracer = CacheProcessor::tracer();
642 let span = tracer.start(name);
643 let span_id = span.span_context().span_id();
644 let trace_id = span.span_context().trace_id();
645 let cache = CACHE
646 .get()
647 .expect("cache should be initialized with cache tracer");
648 cache.subscribe(span_id)?;
649 async {
650 let _guard = mark_span_as_active(span);
651 f(trace_id, span_id).with_current_context().await;
652 }
653 .await;
654
655 cache.notified(span_id).await?;
657
658 let mut data = cache.data.write().map_err(|e| e.to_string())?;
660 Ok(data.remove(&trace_id).unwrap_or_default())
661 }
662
        /// Process-wide span cache, initialised on first use by `CacheProcessor::tracer`.
        static CACHE: once_cell::sync::OnceCell<CacheProcessor> = once_cell::sync::OnceCell::new();
664
        /// Span processor that stores ended spans in memory, grouped by trace id,
        /// and keeps one notifier per subscribed span id.
        #[derive(Debug, Clone, Default)]
        struct CacheProcessor {
            /// Ended spans, keyed by the trace they belong to.
            data: Arc<RwLock<HashMap<TraceId, Vec<SpanData>>>>,
            /// Notifiers keyed by span id, signalled when that span ends.
            notifications: Arc<RwLock<HashMap<SpanId, Arc<tokio::sync::Notify>>>>,
        }
670
671 impl CacheProcessor {
672 fn tracer() -> BoxedTracer {
675 use tracing_subscriber::layer::SubscriberExt;
676
677 CACHE.get_or_init(|| {
678 let processor = Self::default();
679 let provider = opentelemetry_sdk::trace::TracerProvider::builder()
680 .with_span_processor(processor.clone())
681 .build();
682 opentelemetry::global::set_tracer_provider(provider.clone());
683 let tracer = provider.tracer("test_channel");
684 let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
685 let subscriber = tracing_subscriber::Registry::default().with(telemetry);
686 tracing::subscriber::set_global_default(subscriber)
687 .expect("tracing subscriber already set");
688 processor
689 });
690 opentelemetry::global::tracer("test_channel")
691 }
692
693 fn subscribe(&self, span_id: SpanId) -> Result<(), String> {
695 let notify = Arc::new(tokio::sync::Notify::new());
696 self.notifications
697 .write()
698 .map_err(|e| e.to_string())?
699 .insert(span_id, notify);
700 Ok(())
701 }
702
703 async fn notified(&self, span_id: SpanId) -> Result<(), String> {
705 let notify = {
706 let notifications = self.notifications.read().map_err(|e| format!("{e}"))?;
707 notifications
708 .get(&span_id)
709 .ok_or("span notification never subscribed")?
710 .clone()
711 };
712 notify.notified().await;
713 Ok(())
714 }
715 }
716
717 impl SpanProcessor for CacheProcessor {
718 fn on_start(
719 &self,
720 _span: &mut opentelemetry_sdk::trace::Span,
721 _cx: &opentelemetry::Context,
722 ) {
723 }
724
725 fn on_end(&self, span: SpanData) {
726 let trace_id = span.span_context.trace_id();
727 let span_id = span.span_context.span_id();
728 {
729 let mut data = self
730 .data
731 .write()
732 .expect("failed to write access cache span data");
733 if let Some(spans) = data.get_mut(&trace_id) {
734 spans.push(span);
735 } else {
736 data.insert(trace_id, vec![span]);
737 }
738 }
739
740 if let Some(notify) = self
741 .notifications
742 .read()
743 .expect("failed to read access span notifications during span processing")
744 .get(&span_id)
745 {
746 notify.notify_waiters();
747 }
748 }
749
750 fn force_flush(&self) -> TraceResult<()> {
752 Ok(())
753 }
754
755 fn shutdown(&self) -> TraceResult<()> {
756 Ok(())
757 }
758 }
759
760 #[must_use]
762 pub(super) fn get_span_attribute(
763 span_data: &SpanData,
764 key: &'static str,
765 ) -> Option<String> {
766 span_data
767 .attributes
768 .iter()
769 .find(|attr| attr.key.to_string() == *key)
770 .map(|kv| kv.value.to_string())
771 }
772 }
773}