1use std::convert::Infallible;
2
3use futures_util::pin_mut;
6use http::Request;
7use http_body::Body as HttpBody;
8use http_body_util::BodyExt;
9
10mod channel;
11mod common;
12mod error;
13#[cfg(feature = "grpc-web")]
14mod grpc_web;
15mod refresh;
16mod retry;
17#[cfg(feature = "tracing")]
18mod trace;
19
20pub use channel::*;
21pub use error::*;
22#[cfg(feature = "grpc-web")]
23pub use grpc_web::*;
24pub use refresh::*;
25pub use retry::*;
26use tonic::body::Body;
27#[cfg(feature = "tracing")]
28pub use trace::*;
29#[derive(Debug, thiserror::Error)]
32pub enum RequestBodyDuplicationError {
33 #[error(transparent)]
36 Status(#[from] tonic::Status),
37 #[error("failed to read request body for request clone: {0}")]
39 HttpBody(#[from] http::Error),
40}
41
42impl From<RequestBodyDuplicationError> for tonic::Status {
43 fn from(err: RequestBodyDuplicationError) -> tonic::Status {
44 match err {
45 RequestBodyDuplicationError::Status(status) => status,
46 RequestBodyDuplicationError::HttpBody(error) => tonic::Status::cancelled(format!(
47 "failed to read request body for request clone: {error}"
48 )),
49 }
50 }
51}
52
/// Result alias for the request-duplication helpers, failing with [`RequestBodyDuplicationError`].
type RequestBodyDuplicationResult<T> = Result<T, RequestBodyDuplicationError>;
54
55async fn build_duplicate_frame_bytes(
61 mut request: Request<tonic::body::Body>,
62) -> RequestBodyDuplicationResult<(tonic::body::Body, tonic::body::Body)> {
63 let mut bytes = Vec::new();
64
65 let body = request.body_mut();
66 pin_mut!(body);
67 while let Some(result) = std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
68 let frame_bytes = result?.into_data().map_err(|frame| {
69 tonic::Status::cancelled(format!(
70 "cannot duplicate a frame that is not a data frame: {frame:?}"
71 ))
72 })?;
73 bytes.extend(frame_bytes);
74 }
75
76 let bytes = http_body_util::Full::from(bytes)
77 .map_err(|_: Infallible| -> tonic::Status { unreachable!() });
78 Ok((
79 tonic::body::Body::new(bytes.clone()),
80 tonic::body::Body::new(bytes),
81 ))
82}
83
84async fn build_duplicate_request(
87 req: Request<Body>,
88) -> RequestBodyDuplicationResult<(Request<Body>, Request<Body>)> {
89 let mut builder_1 = Request::builder()
90 .method(req.method().clone())
91 .uri(req.uri().clone())
92 .version(req.version());
93
94 let mut builder_2 = Request::builder()
95 .method(req.method().clone())
96 .uri(req.uri().clone())
97 .version(req.version());
98
99 for (key, val) in req.headers() {
100 builder_1 = builder_1.header(key, val);
101 builder_2 = builder_2.header(key, val);
102 }
103
104 let (body_1, body_2) = build_duplicate_frame_bytes(req).await?;
105
106 let req_1 = builder_1.body(body_1)?;
107
108 let req_2 = builder_2.body(body_2)?;
109
110 Ok((req_1, req_2))
111}
112
#[cfg(test)]
/// Test helpers that serve a tonic service over a Unix domain socket inside a temporary
/// directory and hand a connected client `Channel` to the test body.
pub(crate) mod uds_grpc_stream {
    use hyper_util::rt::TokioIo;
    use opentelemetry::trace::FutureExt;
    use std::convert::Infallible;
    use tempfile::TempDir;
    use tokio::net::UnixStream;
    use tokio_stream::wrappers::UnixListenerStream;
    use tonic::server::NamedService;
    use tonic::transport::{Channel, Endpoint, Server};

    /// Placeholder URL for the client endpoint.
    ///
    /// The connector in `build_client_channel` ignores the URI and dials the socket path
    /// instead, so this host is never contacted.
    #[allow(dead_code)]
    static FAUX_URL: &str = "http://api.example.rigetti.com";

    /// Binds a Unix listener at `path` and wraps it as a stream of incoming connections.
    fn build_server_stream(path: String) -> Result<UnixListenerStream, Error> {
        let uds = tokio::net::UnixListener::bind(&path).map_err(|_| Error::BindUnixPath(path))?;
        Ok(UnixListenerStream::new(uds))
    }

    /// Creates a client `Channel` whose connections are Unix-socket connections to `path`.
    async fn build_client_channel(path: String) -> Result<Channel, Error> {
        let connector = tower::service_fn(move |_: tonic::transport::Uri| {
            let path = path.clone();
            async move {
                let connection = UnixStream::connect(path).await?;
                Ok::<_, std::io::Error>(TokioIo::new(connection))
            }
        });
        let channel = Endpoint::try_from(FAUX_URL)
            .map_err(|source| Error::Endpoint {
                url: FAUX_URL.to_string(),
                source,
            })?
            .connect_with_connector(connector)
            .await
            .map_err(|source| Error::Connect { source })?;
        Ok(channel)
    }

    /// Errors raised while setting up or running the Unix-socket server/client pair.
    #[derive(thiserror::Error, Debug)]
    pub enum Error {
        /// `FAUX_URL` could not be turned into an `Endpoint`.
        #[error("failed to initialize endpoint for {url}: {source}")]
        Endpoint {
            /// The URL that was rejected.
            url: String,
            /// The underlying transport error.
            #[source]
            source: tonic::transport::Error,
        },
        /// The client channel could not connect over the socket.
        #[error("failed to connect to endpoint: {source}")]
        Connect {
            /// The underlying transport error.
            #[source]
            source: tonic::transport::Error,
        },
        /// The socket path (held in the variant) could not be bound as a Unix listener.
        #[error("failed to bind path as unix listener: {0}")]
        BindUnixPath(String),
        /// An I/O error occurred, for example while creating the temporary directory.
        #[error("failed to initialize tempfile: {0}")]
        TempFile(#[from] std::io::Error),
        /// The temporary socket path is not valid UTF-8.
        #[error("failed to convert tempfile path to a UTF-8 string")]
        TempFileOsString,
        /// The server failed while serving connections.
        #[error("failed to bind to UDS: {0}")]
        TonicTransport(#[from] tonic::transport::Error),
    }

    /// Serves `service` on a Unix socket in a fresh temporary directory and runs `f` with a
    /// client `Channel` connected to it.
    ///
    /// Returns `Ok(())` as soon as `f` completes. If the server stops first, its result is
    /// returned instead. The temporary directory, and the socket inside it, is removed when
    /// this function returns.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] in any of these cases:
    /// - the temporary directory or socket cannot be set up;
    /// - the client fails to connect;
    /// - the server fails while `f` is still running.
    // NOTE(review): the allow presumably exists because `directory` must stay alive for the
    // whole call; dropping it early would delete the socket directory.
    #[allow(clippy::significant_drop_tightening)]
    pub async fn serve<S, F, R, B>(service: S, f: F) -> Result<(), Error>
    where
        S: tower::Service<
                http::Request<tonic::body::Body>,
                Response = http::Response<B>,
                Error = Infallible,
            > + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Future: Send + 'static,
        B: http_body::Body<Data = tonic::codegen::Bytes> + Send + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
        F: FnOnce(Channel) -> R + Send,
        R: std::future::Future<Output = ()> + Send,
    {
        let directory = TempDir::new()?;
        let file = directory
            .path()
            .join("test.sock")
            .into_os_string()
            .into_string()
            .map_err(|_| Error::TempFileOsString)?;
        let stream = build_server_stream(file.clone())?;

        let channel = build_client_channel(file).await?;
        let serve_future = Server::builder()
            .add_service(service)
            .serve_with_incoming(stream);

        // Race the server against the client callback. The test is finished once `f`
        // completes; a server error that happens first is surfaced to the caller.
        tokio::select! {
            result = serve_future => result.map_err(Error::TonicTransport),
            () = f(channel).with_current_context() => Ok(()),
        }
    }
}
235
#[cfg(test)]
#[cfg(feature = "tracing-opentelemetry")]
mod otel_tests {
    use opentelemetry::propagation::TextMapPropagator;
    use opentelemetry::trace::{TraceContextExt, TraceId};
    use opentelemetry_http::HeaderExtractor;
    use opentelemetry_sdk::propagation::TraceContextPropagator;
    use qcs_api_client_common::configuration::secrets::{SecretAccessToken, SecretRefreshToken};
    use qcs_api_client_common::configuration::tokens::RefreshToken;
    use serde::{Deserialize, Serialize};
    use std::time::{Duration, SystemTime};
    use tonic::codegen::http::{HeaderMap, HeaderValue};
    use tonic::server::NamedService;
    use tonic::Request;
    use tonic_health::pb::health_check_response::ServingStatus;
    use tonic_health::pb::health_server::{Health, HealthServer};
    use tonic_health::{pb::health_client::HealthClient, server::HealthService};

    use crate::tonic::{uds_grpc_stream, wrap_channel_with_tracing};
    use qcs_api_client_common::configuration::ClientConfiguration;
    use qcs_api_client_common::configuration::{settings::AuthServer, tokens::OAuthSession};

    /// gRPC path of the health-check RPC.
    ///
    /// The assertions below look for a span with exactly this name.
    static HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// Base URL passed to `wrap_channel_with_tracing`.
    ///
    /// Traffic actually flows over the Unix socket set up by `uds_grpc_stream::serve`.
    const FAUX_BASE_URL: &str = "http://api.example.rigetti.com";

    /// With OTel context propagation enabled and no filter, the health check is traced and the
    /// server observes the client's trace id.
    #[tokio::test]
    async fn test_tracing_enabled_no_filter() {
        use qcs_api_client_common::tracing_configuration::TracingConfiguration;

        let tracing_configuration = TracingConfiguration::builder()
            .set_propagate_otel_context(true)
            .build();
        let client_config = ClientConfiguration::builder()
            .tracing_configuration(Some(tracing_configuration))
            .oauth_session(Some(OAuthSession::from_refresh_token(
                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
                AuthServer::default(),
                Some(create_jwt()),
            )))
            .build()
            .expect("should be able to build client config");
        assert_grpc_health_check_traced(client_config).await;
    }

    /// Uses `TracingConfiguration::default()`, which (per this test's name) neither filters
    /// nor propagates OTel context.
    ///
    /// The health check must still be traced. Whether context headers are expected at the
    /// server is decided from the configuration's `propagate_otel_context` flag.
    #[tokio::test]
    async fn test_tracing_enabled_no_filter_nor_otel_context_propagation() {
        use qcs_api_client_common::tracing_configuration::TracingConfiguration;

        let tracing_configuration = TracingConfiguration::default();
        let client_config = ClientConfiguration::builder()
            .tracing_configuration(Some(tracing_configuration))
            .oauth_session(Some(OAuthSession::from_refresh_token(
                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
                AuthServer::default(),
                Some(create_jwt()),
            )))
            .build()
            .expect("failed to build client config");
        assert_grpc_health_check_traced(client_config).await;
    }

    /// A filter that selects the health-check path still traces the call.
    #[tokio::test]
    async fn test_tracing_enabled_filter_passed() {
        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};

        let tracing_filter = TracingFilter::builder()
            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
            .expect("gRPC healthcheck path should be valid filter path")
            .build();

        let tracing_configuration = TracingConfiguration::builder()
            .set_filter(Some(tracing_filter))
            .set_propagate_otel_context(true)
            .build();

        let client_config = ClientConfiguration::builder()
            .tracing_configuration(Some(tracing_configuration))
            .oauth_session(Some(OAuthSession::from_refresh_token(
                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
                AuthServer::default(),
                Some(create_jwt()),
            )))
            .build()
            .expect("failed to build client config");
        assert_grpc_health_check_traced(client_config).await;
    }

    /// Runs a health check against a server that sleeps 50 ms, then checks the recorded span.
    ///
    /// Asserts that a span named `HEALTH_CHECK_PATH` was recorded, that it lasted at least
    /// 50 ms, and that it carries an `Ok` gRPC status code.
    ///
    /// The server-side interceptor also checks context propagation:
    /// - when the configuration enables propagation, the trace id must reach the server;
    /// - otherwise, no OTel context headers may be sent.
    ///
    /// # Panics
    ///
    /// Panics on any failed assertion, or if `client_configuration` has no tracing
    /// configuration.
    #[allow(clippy::future_not_send)]
    async fn assert_grpc_health_check_traced(client_configuration: ClientConfiguration) {
        use opentelemetry::trace::FutureExt;

        let propagate_otel_context = client_configuration.tracing_configuration().is_some_and(
            qcs_api_client_common::tracing_configuration::TracingConfiguration::propagate_otel_context,
        );
        let spans: Vec<opentelemetry_sdk::trace::SpanData> = tracing_test::start(
            "test_trace_id_propagation",
            |trace_id, _span_id| async move {
                let sleepy_health_service = SleepyHealthService {
                    sleep_time: Duration::from_millis(50),
                };

                let interceptor = move |req| {
                    if propagate_otel_context {
                        validate_trace_id_propagated(trace_id, req)
                    } else {
                        validate_otel_context_not_propagated(req)
                    }
                };
                let health_server =
                    HealthServer::with_interceptor(sleepy_health_service, interceptor);

                uds_grpc_stream::serve(health_server, |channel| {
                    async {
                        let response = HealthClient::new(wrap_channel_with_tracing(
                            channel,
                            FAUX_BASE_URL.to_string(),
                            client_configuration
                                .tracing_configuration()
                                .unwrap()
                                .clone(),
                        ))
                        .check(Request::new(tonic_health::pb::HealthCheckRequest {
                            service: <HealthServer<HealthService> as NamedService>::NAME
                                .to_string(),
                        }))
                        .await
                        .unwrap();
                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
                    }
                    .with_current_context()
                })
                .await
                .unwrap();
            },
        )
        .await
        .unwrap();

        let grpc_span = spans
            .iter()
            .find(|span| span.name == *HEALTH_CHECK_PATH)
            .expect("failed to find gRPC span");
        let duration = grpc_span
            .end_time
            .duration_since(grpc_span.start_time)
            .expect("span should have valid timestamps");
        // The server sleeps 50 ms, so the client span must cover at least that long.
        assert!(duration.as_millis() >= 50u128);

        let status_code_attribute =
            tracing_test::get_span_attribute(grpc_span, "rpc.grpc.status_code")
                .expect("gRPC span should have status code attribute");
        assert_eq!(status_code_attribute, (tonic::Code::Ok as u8).to_string());
    }

    /// A negated filter on the health-check path suppresses tracing for that call.
    #[tokio::test]
    async fn test_tracing_enabled_filter_not_passed() {
        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};

        let tracing_filter = TracingFilter::builder()
            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
            .expect("healthcheck path should be a valid filter path")
            .set_is_negated(true)
            .build();

        let tracing_configuration = TracingConfiguration::builder()
            .set_filter(Some(tracing_filter))
            .set_propagate_otel_context(true)
            .build();

        let client_config = ClientConfiguration::builder()
            .tracing_configuration(Some(tracing_configuration))
            .oauth_session(Some(OAuthSession::from_refresh_token(
                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
                AuthServer::default(),
                Some(create_jwt()),
            )))
            .build()
            .expect("should be able to build client config");

        assert_grpc_health_check_not_traced(client_config).await;
    }

    /// Runs a health check and asserts that it left no trace.
    ///
    /// No span named `HEALTH_CHECK_PATH` may be recorded, and no OTel context headers may
    /// reach the server.
    ///
    /// # Panics
    ///
    /// Panics on any failed assertion, or if `client_configuration` has no tracing
    /// configuration.
    #[allow(clippy::future_not_send)]
    async fn assert_grpc_health_check_not_traced(client_configuration: ClientConfiguration) {
        use opentelemetry::trace::FutureExt;

        let spans: Vec<opentelemetry_sdk::trace::SpanData> =
            tracing_test::start("test_tracing_disabled", |_trace_id, _span_id| async move {
                let interceptor = validate_otel_context_not_propagated;
                let health_server = HealthServer::with_interceptor(
                    SleepyHealthService {
                        sleep_time: Duration::from_millis(0),
                    },
                    interceptor,
                );

                uds_grpc_stream::serve(health_server, |channel| {
                    async {
                        let response = HealthClient::new(wrap_channel_with_tracing(
                            channel,
                            FAUX_BASE_URL.to_string(),
                            client_configuration
                                .tracing_configuration()
                                .unwrap()
                                .clone(),
                        ))
                        .check(Request::new(tonic_health::pb::HealthCheckRequest {
                            service: <HealthServer<HealthService> as NamedService>::NAME
                                .to_string(),
                        }))
                        .await
                        .unwrap();
                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
                    }
                    .with_current_context()
                })
                .await
                .unwrap();
            })
            .await
            .unwrap();

        assert!(spans.iter().all(|span| { span.name != *HEALTH_CHECK_PATH }));
    }

    /// HMAC signing key for the throwaway test JWT.
    const JWT_SECRET: &[u8] = b"top-secret";

    /// Minimal JWT claims: a subject and an expiry in seconds since the Unix epoch.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct Claims {
        sub: String,
        exp: u64,
    }

    /// Creates an HS256-signed access token for subject `test-uid` that expires 60 seconds
    /// from now.
    pub(crate) fn create_jwt() -> SecretAccessToken {
        use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
        let expiration = SystemTime::now()
            .checked_add(Duration::from_secs(60))
            .unwrap()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let claims = Claims {
            sub: "test-uid".to_string(),
            exp: expiration,
        };
        let header = Header::new(Algorithm::HS256);
        encode(&header, &claims, &EncodingKey::from_secret(JWT_SECRET))
            .map(SecretAccessToken::from)
            .unwrap()
    }

    /// Reasons the test server's interceptor rejects a request.
    #[derive(Debug, thiserror::Error)]
    #[allow(variant_size_differences)]
    enum ServerAssertionError {
        #[error("trace id did not propagate to server: {0}")]
        UnexpectedTraceId(String),
        #[error("otel context headers unexpectedly sent to server")]
        UnexpectedOTelContextHeaders,
    }

    /// Surfaces a server-side assertion failure to the client as `INVALID_ARGUMENT`.
    impl From<ServerAssertionError> for tonic::Status {
        fn from(server_assertion_error: ServerAssertionError) -> Self {
            Self::invalid_argument(server_assertion_error.to_string())
        }
    }

    /// Interceptor that rejects the request unless its `traceparent` metadata decodes to
    /// `trace_id`.
    #[allow(clippy::result_large_err)]
    fn validate_trace_id_propagated(
        trace_id: TraceId,
        req: Request<()>,
    ) -> Result<Request<()>, tonic::Status> {
        req.metadata()
            .get("traceparent")
            .ok_or_else(|| {
                ServerAssertionError::UnexpectedTraceId(
                    "request traceparent metadata not present".to_string(),
                )
            })
            .and_then(|traceparent| {
                // Re-wrap the gRPC metadata value as an HTTP header so the standard
                // W3C propagator can extract the context from it.
                let mut headers = HeaderMap::new();
                headers.insert(
                    "traceparent",
                    HeaderValue::from_str(traceparent.to_str().map_err(|_| {
                        ServerAssertionError::UnexpectedTraceId(
                            "failed to deserialize trace parent".to_string(),
                        )
                    })?)
                    .map_err(|_| {
                        ServerAssertionError::UnexpectedTraceId(
                            "failed to serialize trace parent as HeaderValue".to_string(),
                        )
                    })?,
                );
                Ok(headers)
            })
            .and_then(|headers| {
                let extractor = HeaderExtractor(&headers);
                let propagator = TraceContextPropagator::new();
                let context = propagator.extract(&extractor);
                let propagated_trace_id = context.span().span_context().trace_id();
                if propagated_trace_id == trace_id {
                    Ok(req)
                } else {
                    Err(ServerAssertionError::UnexpectedTraceId(format!(
                        "expected trace id {trace_id}, got {propagated_trace_id}",
                    )))
                }
            })
            .map_err(Into::into)
    }

    /// Interceptor that rejects the request if it carries `traceparent` or `tracestate`
    /// metadata.
    #[allow(clippy::result_large_err)]
    fn validate_otel_context_not_propagated(
        req: Request<()>,
    ) -> Result<Request<()>, tonic::Status> {
        if req.metadata().get("traceparent").is_some() || req.metadata().get("tracestate").is_some()
        {
            Err(ServerAssertionError::UnexpectedOTelContextHeaders.into())
        } else {
            Ok(req)
        }
    }

    /// Health service that always reports `Serving`.
    ///
    /// `check` first sleeps for `sleep_time`, which gives the client span a known minimum
    /// duration.
    #[derive(Clone)]
    struct SleepyHealthService {
        sleep_time: Duration,
    }

    #[tonic::async_trait]
    impl Health for SleepyHealthService {
        async fn check(
            &self,
            _request: Request<tonic_health::pb::HealthCheckRequest>,
        ) -> Result<tonic::Response<tonic_health::pb::HealthCheckResponse>, tonic::Status> {
            tokio::time::sleep(self.sleep_time).await;
            let response = tonic_health::pb::HealthCheckResponse {
                status: ServingStatus::Serving as i32,
            };
            Ok(tonic::Response::new(response))
        }

        type WatchStream = tokio_stream::wrappers::ReceiverStream<
            Result<tonic_health::pb::HealthCheckResponse, tonic::Status>,
        >;

        /// Streams a single `Serving` response; the stream ends once the sender is dropped.
        async fn watch(
            &self,
            _request: Request<tonic_health::pb::HealthCheckRequest>,
        ) -> Result<tonic::Response<Self::WatchStream>, tonic::Status> {
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            let response = tonic_health::pb::HealthCheckResponse {
                status: ServingStatus::Serving as i32,
            };
            tx.send(Ok(response)).await.unwrap();
            Ok(tonic::Response::new(
                tokio_stream::wrappers::ReceiverStream::new(rx),
            ))
        }
    }

    /// In-process OpenTelemetry harness for the tests above.
    ///
    /// It installs a global tracer provider whose span processor caches finished spans by
    /// trace id, so a test can inspect every span produced under its root span.
    mod tracing_test {
        use futures_util::Future;
        use opentelemetry::global::BoxedTracer;
        use opentelemetry::trace::{
            mark_span_as_active, FutureExt, Span, SpanId, TraceId, Tracer, TracerProvider,
        };
        use opentelemetry_sdk::error::OTelSdkError;
        use opentelemetry_sdk::trace::{SpanData, SpanProcessor};
        use std::collections::HashMap;
        use std::sync::{Arc, RwLock};
        use tokio::sync::oneshot;

        /// Starts a root span named `name` and runs `f` while that span is active.
        ///
        /// `f` receives the root span's trace id and span id. After `f` finishes, this waits
        /// until the root span has been processed, then returns every cached span of that
        /// trace.
        ///
        /// # Errors
        ///
        /// Returns the lock error's debug text if the span cache lock is poisoned.
        #[allow(clippy::future_not_send)]
        pub(crate) async fn start<F, R>(name: &'static str, f: F) -> Result<Vec<SpanData>, String>
        where
            F: FnOnce(TraceId, SpanId) -> R + Send,
            R: Future<Output = ()> + Send,
        {
            let tracer = CacheProcessor::tracer();
            let span = tracer.start(name);
            let span_id = span.span_context().span_id();
            let trace_id = span.span_context().trace_id();
            let cache = CACHE
                .get()
                .expect("cache should be initialized with cache tracer");
            // Subscribe before the span can end so the end notification cannot be missed.
            let span_rx = cache.subscribe(span_id);
            async {
                // Dropping the guard at the end of this block ends the root span.
                let _guard = mark_span_as_active(span);
                f(trace_id, span_id).with_current_context().await;
            }
            .await;

            span_rx.await.unwrap();

            let mut data = cache.data.write().map_err(|e| format!("{e:?}"))?;
            Ok(data.remove(&trace_id).unwrap_or_default())
        }

        /// Process-wide span cache, initialized on first use by `CacheProcessor::tracer`.
        static CACHE: once_cell::sync::OnceCell<CacheProcessor> = once_cell::sync::OnceCell::new();

        /// Span processor that stores finished spans by trace id and signals subscribers when
        /// a given span ends.
        #[derive(Debug, Clone, Default)]
        struct CacheProcessor {
            data: Arc<RwLock<HashMap<TraceId, Vec<SpanData>>>>,
            notifications: Arc<RwLock<HashMap<SpanId, tokio::sync::oneshot::Sender<()>>>>,
        }

        impl CacheProcessor {
            /// Returns the global `test_channel` tracer.
            ///
            /// On first call in the process, this installs the caching tracer provider and a
            /// global `tracing` subscriber that forwards to it.
            ///
            /// # Panics
            ///
            /// Panics if another global `tracing` subscriber was already installed.
            fn tracer() -> BoxedTracer {
                use tracing_subscriber::layer::SubscriberExt;

                CACHE.get_or_init(|| {
                    let processor = Self::default();
                    let provider = opentelemetry_sdk::trace::SdkTracerProvider::builder()
                        .with_span_processor(processor.clone())
                        .build();
                    opentelemetry::global::set_tracer_provider(provider.clone());
                    let tracer = provider.tracer("test_channel");
                    let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
                    let subscriber = tracing_subscriber::Registry::default().with(telemetry);
                    tracing::subscriber::set_global_default(subscriber)
                        .expect("tracing subscriber already set");
                    processor
                });
                opentelemetry::global::tracer("test_channel")
            }

            /// Registers interest in `span_id`; the returned receiver fires once that span ends.
            fn subscribe(&self, span_id: SpanId) -> oneshot::Receiver<()> {
                let (tx, rx) = oneshot::channel();
                self.notifications.write().unwrap().insert(span_id, tx);
                rx
            }
        }

        impl SpanProcessor for CacheProcessor {
            fn on_start(
                &self,
                _span: &mut opentelemetry_sdk::trace::Span,
                _cx: &opentelemetry::Context,
            ) {
            }

            /// Caches the finished span, then notifies any subscriber waiting on it.
            fn on_end(&self, span: SpanData) {
                let trace_id = span.span_context.trace_id();
                let span_id = span.span_context.span_id();
                {
                    let mut data = self
                        .data
                        .write()
                        .expect("failed to write access cache span data");
                    data.entry(trace_id).or_default().push(span);
                }

                if let Some(notify) = self
                    .notifications
                    .write()
                    .expect("failed to read access span notifications during span processing")
                    .remove(&span_id)
                {
                    // The receiver may already be gone, e.g. when a failing test's future is
                    // dropped during unwinding. Panicking here could turn that test failure
                    // into a double-panic abort, so a failed notification is ignored.
                    let _ = notify.send(());
                }
            }

            fn force_flush(&self) -> Result<(), OTelSdkError> {
                Ok(())
            }

            fn shutdown(&self) -> Result<(), OTelSdkError> {
                Ok(())
            }

            fn shutdown_with_timeout(
                &self,
                _timeout: std::time::Duration,
            ) -> Result<(), OTelSdkError> {
                Ok(())
            }
        }

        /// Returns the display form of the attribute named `key` on `span_data`, if present.
        #[must_use]
        pub(super) fn get_span_attribute(
            span_data: &SpanData,
            key: &'static str,
        ) -> Option<String> {
            span_data
                .attributes
                .iter()
                // Compare borrowed key text; no per-attribute `String` allocation.
                .find(|attr| attr.key.as_str() == key)
                .map(|kv| kv.value.to_string())
        }
    }
}