1use std::convert::Infallible;
4
5use futures_util::pin_mut;
6use http::Request;
7use http_body::{Body as HttpBody, Frame};
8use http_body_util::StreamBody;
9use prost::bytes::Bytes;
10
11mod channel;
12mod common;
13mod error;
14#[cfg(feature = "grpc-web")]
15mod grpc_web;
16mod refresh;
17mod retry;
18#[cfg(feature = "tracing")]
19mod trace;
20
21pub use channel::*;
22pub use error::*;
23#[cfg(feature = "grpc-web")]
24pub use grpc_web::*;
25pub use refresh::*;
26pub use retry::*;
27use tonic::body::Body;
28#[cfg(feature = "tracing")]
29pub use trace::*;
30#[derive(Debug, thiserror::Error)]
33pub enum RequestBodyDuplicationError {
34 #[error(transparent)]
37 Status(#[from] tonic::Status),
38 #[error("failed to read request body for request clone: {0}")]
40 HttpBody(#[from] http::Error),
41}
42
43impl From<RequestBodyDuplicationError> for tonic::Status {
44 fn from(err: RequestBodyDuplicationError) -> tonic::Status {
45 match err {
46 RequestBodyDuplicationError::Status(status) => status,
47 RequestBodyDuplicationError::HttpBody(error) => tonic::Status::cancelled(format!(
48 "failed to read request body for request clone: {error}"
49 )),
50 }
51 }
52}
53
/// Result type returned by the request-duplication helpers in this module.
type RequestBodyDuplicationResult<T> = Result<T, RequestBodyDuplicationError>;
55
56fn make_stream_body(bytes: Bytes) -> tonic::body::Body {
60 let body = Frame::data(bytes);
61
62 let stream = futures_util::stream::once(std::future::ready(Ok::<_, Infallible>(body)));
64
65 let stream_body = StreamBody::new(stream);
66
67 tonic::body::Body::new(stream_body)
68}
69
70async fn build_duplicate_frame_bytes(
76 mut request: Request<tonic::body::Body>,
77) -> RequestBodyDuplicationResult<(tonic::body::Body, tonic::body::Body)> {
78 let mut bytes = Vec::new();
79
80 let body = request.body_mut();
81 pin_mut!(body);
82 while let Some(result) = std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
83 let frame_bytes = result?.into_data().map_err(|frame| {
84 tonic::Status::cancelled(format!(
85 "cannot duplicate a frame that is not a data frame: {frame:?}"
86 ))
87 })?;
88 bytes.extend(frame_bytes);
89 }
90
91 let bytes = Bytes::from(bytes);
92
93 Ok((make_stream_body(bytes.clone()), make_stream_body(bytes)))
94}
95
96async fn build_duplicate_request(
99 req: Request<Body>,
100) -> RequestBodyDuplicationResult<(Request<Body>, Request<Body>)> {
101 let mut builder_1 = Request::builder()
102 .method(req.method().clone())
103 .uri(req.uri().clone())
104 .version(req.version());
105
106 let mut builder_2 = Request::builder()
107 .method(req.method().clone())
108 .uri(req.uri().clone())
109 .version(req.version());
110
111 for (key, val) in req.headers() {
112 builder_1 = builder_1.header(key, val);
113 builder_2 = builder_2.header(key, val);
114 }
115
116 let (body_1, body_2) = build_duplicate_frame_bytes(req).await?;
117
118 let req_1 = builder_1.body(body_1)?;
119
120 let req_2 = builder_2.body(body_2)?;
121
122 Ok((req_1, req_2))
123}
124
#[cfg(test)]
/// Test helpers that serve a gRPC service over a Unix domain socket in a temporary directory and
/// hand a connected `tonic::transport::Channel` to the test body.
pub(crate) mod uds_grpc_stream {
    use hyper_util::rt::TokioIo;
    use opentelemetry::trace::FutureExt;
    use std::convert::Infallible;
    use tempfile::TempDir;
    use tokio::net::UnixStream;
    use tokio_stream::wrappers::UnixListenerStream;
    use tonic::server::NamedService;
    use tonic::transport::{Channel, Endpoint, Server};

    /// Placeholder URL for the client `Endpoint`. The custom connector ignores the URI and always
    /// dials the Unix socket instead.
    #[allow(dead_code)]
    static FAUX_URL: &str = "http://api.example.rigetti.com";

    /// Binds a Unix listener at `path` and exposes its incoming connections as a stream.
    fn build_server_stream(path: String) -> Result<UnixListenerStream, Error> {
        let listener =
            tokio::net::UnixListener::bind(&path).map_err(|_| Error::BindUnixPath(path))?;
        Ok(UnixListenerStream::new(listener))
    }

    /// Connects a `Channel` whose every connection is a new `UnixStream` to `path`.
    async fn build_client_channel(path: String) -> Result<Channel, Error> {
        let connector = tower::service_fn(move |_: tonic::transport::Uri| {
            let path = path.clone();
            async move {
                let connection = UnixStream::connect(path).await?;
                Ok::<_, std::io::Error>(TokioIo::new(connection))
            }
        });
        let endpoint = Endpoint::try_from(FAUX_URL).map_err(|source| Error::Endpoint {
            url: FAUX_URL.to_string(),
            source,
        })?;
        endpoint
            .connect_with_connector(connector)
            .await
            .map_err(|source| Error::Connect { source })
    }

    /// Failures while setting up or running the UDS test server.
    #[derive(thiserror::Error, Debug)]
    pub enum Error {
        /// The placeholder URL could not be turned into an `Endpoint`.
        #[error("failed to initialize endpoint for {url}: {source}")]
        Endpoint {
            /// The URL that was rejected.
            url: String,
            /// The underlying transport error.
            #[source]
            source: tonic::transport::Error,
        },
        /// The client could not connect to the socket.
        #[error("failed to connect to endpoint: {source}")]
        Connect {
            /// The underlying transport error.
            #[source]
            source: tonic::transport::Error,
        },
        /// A Unix listener could not be bound at the given path.
        #[error("failed to bind path as unix listener: {0}")]
        BindUnixPath(String),
        /// The temporary directory holding the socket could not be created.
        #[error("failed to initialize tempfile: {0}")]
        TempFile(#[from] std::io::Error),
        /// The temporary directory's path is not valid UTF-8.
        #[error("failed to convert tempfile to OsString")]
        TempFileOsString,
        /// The gRPC server returned an error while serving.
        #[error("gRPC server failed while serving over UDS: {0}")]
        TonicTransport(#[from] tonic::transport::Error),
    }

    /// Serves `service` over a Unix domain socket and runs `f` with a client `Channel` to it.
    ///
    /// The socket is created in a fresh temporary directory owned by this call. Returns as soon
    /// as either `f` finishes (yielding `Ok(())`) or the server stops (yielding its result); the
    /// other future is dropped. `f` runs with the current OpenTelemetry context attached.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the temporary directory cannot be created;
    /// - its path is not valid UTF-8;
    /// - the listener cannot be bound;
    /// - the client cannot connect;
    /// - the server exits with an error.
    // `directory` must outlive both the server and the client; dropping the `TempDir` deletes
    // the directory that contains the socket file.
    #[allow(clippy::significant_drop_tightening)]
    pub async fn serve<S, F, R, B>(service: S, f: F) -> Result<(), Error>
    where
        S: tower::Service<
                http::Request<tonic::body::Body>,
                Response = http::Response<B>,
                Error = Infallible,
            > + NamedService
            + Clone
            + Send
            + Sync
            + 'static,
        S::Future: Send + 'static,
        B: http_body::Body<Data = tonic::codegen::Bytes> + Send + 'static,
        B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
        F: FnOnce(Channel) -> R + Send,
        R: std::future::Future<Output = ()> + Send,
    {
        let directory = TempDir::new()?;
        let socket_dir = directory.path().to_str().ok_or(Error::TempFileOsString)?;
        let socket_path = format!("{socket_dir}/test.sock");
        let incoming = build_server_stream(socket_path.clone())?;

        let channel = build_client_channel(socket_path).await?;
        let server = Server::builder()
            .add_service(service)
            .serve_with_incoming(incoming);

        tokio::select! {
            result = server => result.map_err(Error::TonicTransport),
            () = f(channel).with_current_context() => Ok(()),
        }
    }
}
247
#[cfg(test)]
#[cfg(feature = "tracing-opentelemetry")]
mod otel_tests {
    use opentelemetry::propagation::TextMapPropagator;
    use opentelemetry::trace::{TraceContextExt, TraceId};
    use opentelemetry_http::HeaderExtractor;
    use opentelemetry_sdk::propagation::TraceContextPropagator;
    use qcs_api_client_common::configuration::secrets::{SecretAccessToken, SecretRefreshToken};
    use qcs_api_client_common::configuration::tokens::RefreshToken;
    use serde::{Deserialize, Serialize};
    use std::time::{Duration, SystemTime};
    use tonic::codegen::http::{HeaderMap, HeaderValue};
    use tonic::server::NamedService;
    use tonic::Request;
    use tonic_health::pb::health_check_response::ServingStatus;
    use tonic_health::pb::health_server::{Health, HealthServer};
    use tonic_health::{pb::health_client::HealthClient, server::HealthService};

    use crate::tonic::{uds_grpc_stream, wrap_channel_with_tracing};
    use qcs_api_client_common::configuration::ClientConfiguration;
    use qcs_api_client_common::configuration::{settings::AuthServer, tokens::OAuthSession};

    /// Full gRPC method path of the health check RPC. The tests expect the client's gRPC span to
    /// be named after it.
    static HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check";

    /// Placeholder base URL passed to `wrap_channel_with_tracing`. The channel itself is
    /// connected over a Unix socket.
    const FAUX_BASE_URL: &str = "http://api.example.rigetti.com";

    /// Builds a `ClientConfiguration` with `tracing_configuration` and an OAuth session holding a
    /// dummy refresh token and a freshly minted access token.
    fn client_config_with_tracing(
        tracing_configuration: qcs_api_client_common::tracing_configuration::TracingConfiguration,
    ) -> ClientConfiguration {
        ClientConfiguration::builder()
            .tracing_configuration(Some(tracing_configuration))
            .oauth_session(Some(OAuthSession::from_refresh_token(
                RefreshToken::new(SecretRefreshToken::from("refresh_token")),
                AuthServer::default(),
                Some(create_jwt()),
            )))
            .build()
            .expect("should be able to build client config")
    }

    /// With no filter and OTel context propagation enabled, the health check is traced and the
    /// client's trace id reaches the server.
    #[tokio::test]
    async fn test_tracing_enabled_no_filter() {
        use qcs_api_client_common::tracing_configuration::TracingConfiguration;

        let tracing_configuration = TracingConfiguration::builder()
            .set_propagate_otel_context(true)
            .build();
        assert_grpc_health_check_traced(client_config_with_tracing(tracing_configuration)).await;
    }

    /// With the default tracing configuration, the health check is still traced. Whether OTel
    /// context headers are expected at the server follows the default `propagate_otel_context`.
    #[tokio::test]
    async fn test_tracing_enabled_no_filter_nor_otel_context_propagation() {
        use qcs_api_client_common::tracing_configuration::TracingConfiguration;

        let tracing_configuration = TracingConfiguration::default();
        assert_grpc_health_check_traced(client_config_with_tracing(tracing_configuration)).await;
    }

    /// A filter that matches the health check path lets the request be traced.
    #[tokio::test]
    async fn test_tracing_enabled_filter_passed() {
        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};

        let tracing_filter = TracingFilter::builder()
            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
            .expect("gRPC healthcheck path should be valid filter path")
            .build();

        let tracing_configuration = TracingConfiguration::builder()
            .set_filter(Some(tracing_filter))
            .set_propagate_otel_context(true)
            .build();

        assert_grpc_health_check_traced(client_config_with_tracing(tracing_configuration)).await;
    }

    /// Runs a health check through a traced channel and asserts that the client emitted a gRPC
    /// span for it.
    ///
    /// The server sleeps for 50ms, so the span must last at least that long. It must also record
    /// an OK `rpc.grpc.status_code`. The server interceptor depends on the configuration's
    /// `propagate_otel_context`:
    /// - when set, it requires the test's trace id in `traceparent`;
    /// - otherwise, it rejects any OTel context headers.
    ///
    /// A rejected request makes the client-side `unwrap` panic.
    #[allow(clippy::future_not_send)]
    async fn assert_grpc_health_check_traced(client_configuration: ClientConfiguration) {
        use opentelemetry::trace::FutureExt;

        let propagate_otel_context = client_configuration.tracing_configuration().is_some_and(
            qcs_api_client_common::tracing_configuration::TracingConfiguration::propagate_otel_context,
        );
        let spans: Vec<opentelemetry_sdk::trace::SpanData> = tracing_test::start(
            "test_trace_id_propagation",
            |trace_id, _span_id| async move {
                let sleepy_health_service = SleepyHealthService {
                    sleep_time: Duration::from_millis(50),
                };

                let interceptor = move |req| {
                    if propagate_otel_context {
                        validate_trace_id_propagated(trace_id, req)
                    } else {
                        validate_otel_context_not_propagated(req)
                    }
                };
                let health_server =
                    HealthServer::with_interceptor(sleepy_health_service, interceptor);

                uds_grpc_stream::serve(health_server, |channel| {
                    async {
                        let response = HealthClient::new(wrap_channel_with_tracing(
                            channel,
                            FAUX_BASE_URL.to_string(),
                            client_configuration
                                .tracing_configuration()
                                .unwrap()
                                .clone(),
                        ))
                        .check(Request::new(tonic_health::pb::HealthCheckRequest {
                            service: <HealthServer<HealthService> as NamedService>::NAME
                                .to_string(),
                        }))
                        .await
                        .unwrap();
                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
                    }
                    .with_current_context()
                })
                .await
                .unwrap();
            },
        )
        .await
        .unwrap();

        let grpc_span = spans
            .iter()
            .find(|span| span.name == *HEALTH_CHECK_PATH)
            .expect("failed to find gRPC span");
        let duration = grpc_span
            .end_time
            .duration_since(grpc_span.start_time)
            .expect("span should have valid timestamps");
        assert!(duration.as_millis() >= 50u128);

        let status_code_attribute =
            tracing_test::get_span_attribute(grpc_span, "rpc.grpc.status_code")
                .expect("gRPC span should have status code attribute");
        assert_eq!(status_code_attribute, (tonic::Code::Ok as u8).to_string());
    }

    /// A negated filter on the health check path prevents the request from being traced.
    #[tokio::test]
    async fn test_tracing_enabled_filter_not_passed() {
        use qcs_api_client_common::tracing_configuration::{TracingConfiguration, TracingFilter};

        let tracing_filter = TracingFilter::builder()
            .parse_strs_and_set_paths(&[HEALTH_CHECK_PATH])
            .expect("healthcheck path should be a valid filter path")
            .set_is_negated(true)
            .build();

        let tracing_configuration = TracingConfiguration::builder()
            .set_filter(Some(tracing_filter))
            .set_propagate_otel_context(true)
            .build();

        assert_grpc_health_check_not_traced(client_config_with_tracing(tracing_configuration))
            .await;
    }

    /// Runs a health check through a traced channel and asserts that no span named after the
    /// health check path was recorded. The server also rejects any OTel context headers.
    #[allow(clippy::future_not_send)]
    async fn assert_grpc_health_check_not_traced(client_configuration: ClientConfiguration) {
        use opentelemetry::trace::FutureExt;

        let spans: Vec<opentelemetry_sdk::trace::SpanData> =
            tracing_test::start("test_tracing_disabled", |_trace_id, _span_id| async move {
                let interceptor = validate_otel_context_not_propagated;
                let health_server = HealthServer::with_interceptor(
                    SleepyHealthService {
                        sleep_time: Duration::from_millis(0),
                    },
                    interceptor,
                );

                uds_grpc_stream::serve(health_server, |channel| {
                    async {
                        let response = HealthClient::new(wrap_channel_with_tracing(
                            channel,
                            FAUX_BASE_URL.to_string(),
                            client_configuration
                                .tracing_configuration()
                                .unwrap()
                                .clone(),
                        ))
                        .check(Request::new(tonic_health::pb::HealthCheckRequest {
                            service: <HealthServer<HealthService> as NamedService>::NAME
                                .to_string(),
                        }))
                        .await
                        .unwrap();
                        assert_eq!(response.into_inner().status(), ServingStatus::Serving);
                    }
                    .with_current_context()
                })
                .await
                .unwrap();
            })
            .await
            .unwrap();

        assert!(spans.iter().all(|span| { span.name != *HEALTH_CHECK_PATH }));
    }

    /// HMAC key used to sign the test access token.
    const JWT_SECRET: &[u8] = b"top-secret";

    /// JWT claims for the test access token.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct Claims {
        sub: String,
        exp: u64,
    }

    /// Creates an HS256-signed access token for `test-uid` that expires 60 seconds from now.
    pub(crate) fn create_jwt() -> SecretAccessToken {
        use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
        let expiration = SystemTime::now()
            .checked_add(Duration::from_secs(60))
            .unwrap()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let claims = Claims {
            sub: "test-uid".to_string(),
            exp: expiration,
        };
        let header = Header::new(Algorithm::HS256);
        encode(&header, &claims, &EncodingKey::from_secret(JWT_SECRET))
            .map(SecretAccessToken::from)
            .unwrap()
    }

    /// Reasons the server-side interceptor rejects a request during the tests.
    #[derive(Debug, thiserror::Error)]
    #[allow(variant_size_differences)]
    enum ServerAssertionError {
        #[error("trace id did not propagate to server: {0}")]
        UnexpectedTraceId(String),
        #[error("otel context headers unexpectedly sent to server")]
        UnexpectedOTelContextHeaders,
    }

    /// Surfaces an assertion failure to the client as an `INVALID_ARGUMENT` status.
    impl From<ServerAssertionError> for tonic::Status {
        fn from(server_assertion_error: ServerAssertionError) -> Self {
            Self::invalid_argument(server_assertion_error.to_string())
        }
    }

    /// Server interceptor: accepts `req` only if its `traceparent` metadata decodes (via the W3C
    /// trace-context propagator) to `trace_id`.
    #[allow(clippy::result_large_err)]
    fn validate_trace_id_propagated(
        trace_id: TraceId,
        req: Request<()>,
    ) -> Result<Request<()>, tonic::Status> {
        req.metadata()
            .get("traceparent")
            .ok_or_else(|| {
                ServerAssertionError::UnexpectedTraceId(
                    "request traceparent metadata not present".to_string(),
                )
            })
            .and_then(|traceparent| {
                // Re-wrap the gRPC metadata value as an HTTP header so `HeaderExtractor` can read it.
                let mut headers = HeaderMap::new();
                headers.insert(
                    "traceparent",
                    HeaderValue::from_str(traceparent.to_str().map_err(|_| {
                        ServerAssertionError::UnexpectedTraceId(
                            "failed to deserialize trace parent".to_string(),
                        )
                    })?)
                    .map_err(|_| {
                        ServerAssertionError::UnexpectedTraceId(
                            "failed to serialize trace parent as HeaderValue".to_string(),
                        )
                    })?,
                );
                Ok(headers)
            })
            .and_then(|headers| {
                let extractor = HeaderExtractor(&headers);
                let propagator = TraceContextPropagator::new();
                let context = propagator.extract(&extractor);
                let propagated_trace_id = context.span().span_context().trace_id();
                if propagated_trace_id == trace_id {
                    Ok(req)
                } else {
                    Err(ServerAssertionError::UnexpectedTraceId(format!(
                        "expected trace id {trace_id}, got {propagated_trace_id}",
                    )))
                }
            })
            .map_err(Into::into)
    }

    /// Server interceptor: rejects `req` if it carries `traceparent` or `tracestate` metadata.
    #[allow(clippy::result_large_err)]
    fn validate_otel_context_not_propagated(
        req: Request<()>,
    ) -> Result<Request<()>, tonic::Status> {
        if req.metadata().get("traceparent").is_some() || req.metadata().get("tracestate").is_some()
        {
            Err(ServerAssertionError::UnexpectedOTelContextHeaders.into())
        } else {
            Ok(req)
        }
    }

    /// Health service that waits `sleep_time` before reporting `SERVING`, so span durations can
    /// be asserted.
    #[derive(Clone)]
    struct SleepyHealthService {
        sleep_time: Duration,
    }

    #[tonic::async_trait]
    impl Health for SleepyHealthService {
        async fn check(
            &self,
            _request: Request<tonic_health::pb::HealthCheckRequest>,
        ) -> Result<tonic::Response<tonic_health::pb::HealthCheckResponse>, tonic::Status> {
            tokio::time::sleep(self.sleep_time).await;
            let response = tonic_health::pb::HealthCheckResponse {
                status: ServingStatus::Serving as i32,
            };
            Ok(tonic::Response::new(response))
        }

        type WatchStream = tokio_stream::wrappers::ReceiverStream<
            Result<tonic_health::pb::HealthCheckResponse, tonic::Status>,
        >;

        /// Streams a single `SERVING` response. The sender is dropped afterwards, which ends
        /// the stream.
        async fn watch(
            &self,
            _request: Request<tonic_health::pb::HealthCheckRequest>,
        ) -> Result<tonic::Response<Self::WatchStream>, tonic::Status> {
            let (tx, rx) = tokio::sync::mpsc::channel(1);
            let response = tonic_health::pb::HealthCheckResponse {
                status: ServingStatus::Serving as i32,
            };
            tx.send(Ok(response)).await.unwrap();
            Ok(tonic::Response::new(
                tokio_stream::wrappers::ReceiverStream::new(rx),
            ))
        }
    }

    /// Minimal in-process OpenTelemetry harness. On first use it installs a global tracer
    /// provider and `tracing` subscriber backed by an in-memory span cache, then collects
    /// finished spans per trace id.
    mod tracing_test {
        use futures_util::Future;
        use opentelemetry::global::BoxedTracer;
        use opentelemetry::trace::{
            mark_span_as_active, FutureExt, Span, SpanId, TraceId, Tracer, TracerProvider,
        };
        use opentelemetry_sdk::error::OTelSdkError;
        use opentelemetry_sdk::trace::{SpanData, SpanProcessor};
        use std::collections::HashMap;
        use std::sync::{Arc, RwLock};
        use tokio::sync::oneshot;

        /// Runs `f` inside a new root span named `name`. Once that root span has ended, it
        /// returns every span recorded under the root span's trace id.
        ///
        /// `f` receives the root span's trace id and span id.
        ///
        /// # Errors
        ///
        /// Returns the lock error, formatted as a string, if the span cache lock is poisoned.
        #[allow(clippy::future_not_send)]
        pub(crate) async fn start<F, R>(name: &'static str, f: F) -> Result<Vec<SpanData>, String>
        where
            F: FnOnce(TraceId, SpanId) -> R + Send,
            R: Future<Output = ()> + Send,
        {
            let tracer = CacheProcessor::tracer();
            let span = tracer.start(name);
            let span_id = span.span_context().span_id();
            let trace_id = span.span_context().trace_id();
            let cache = CACHE
                .get()
                .expect("cache should be initialized with cache tracer");
            // Subscribe before the span can end so the notification cannot be missed.
            let span_rx = cache.subscribe(span_id);
            async {
                let _guard = mark_span_as_active(span);
                f(trace_id, span_id).with_current_context().await;
            }
            .await;

            // Dropping the guard above ends the root span; wait until the processor has stored it.
            span_rx.await.unwrap();

            let mut data = cache.data.write().map_err(|e| format!("{e:?}"))?;
            Ok(data.remove(&trace_id).unwrap_or_default())
        }

        /// Process-wide span cache, initialized by `CacheProcessor::tracer`.
        static CACHE: once_cell::sync::OnceCell<CacheProcessor> = once_cell::sync::OnceCell::new();

        /// Span processor that stores every finished span by trace id and wakes any waiter
        /// subscribed to that span's id.
        #[derive(Debug, Clone, Default)]
        struct CacheProcessor {
            data: Arc<RwLock<HashMap<TraceId, Vec<SpanData>>>>,
            notifications: Arc<RwLock<HashMap<SpanId, tokio::sync::oneshot::Sender<()>>>>,
        }

        impl CacheProcessor {
            /// On the first call, installs the global tracer provider and `tracing` subscriber
            /// backed by this processor. It then returns a tracer from the global provider.
            ///
            /// Panics if another global `tracing` subscriber was already installed.
            fn tracer() -> BoxedTracer {
                use tracing_subscriber::layer::SubscriberExt;

                CACHE.get_or_init(|| {
                    let processor = Self::default();
                    let provider = opentelemetry_sdk::trace::SdkTracerProvider::builder()
                        .with_span_processor(processor.clone())
                        .build();
                    opentelemetry::global::set_tracer_provider(provider.clone());
                    let tracer = provider.tracer("test_channel");
                    let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
                    let subscriber = tracing_subscriber::Registry::default().with(telemetry);
                    tracing::subscriber::set_global_default(subscriber)
                        .expect("tracing subscriber already set");
                    processor
                });
                opentelemetry::global::tracer("test_channel")
            }

            /// Registers interest in `span_id`. The returned receiver fires when that span ends.
            fn subscribe(&self, span_id: SpanId) -> oneshot::Receiver<()> {
                let (tx, rx) = oneshot::channel();
                self.notifications.write().unwrap().insert(span_id, tx);
                rx
            }
        }

        impl SpanProcessor for CacheProcessor {
            fn on_start(
                &self,
                _span: &mut opentelemetry_sdk::trace::Span,
                _cx: &opentelemetry::Context,
            ) {
            }

            fn on_end(&self, span: SpanData) {
                let trace_id = span.span_context.trace_id();
                let span_id = span.span_context.span_id();
                // Store the span (and release the lock) before notifying, so the waiter in
                // `start` is guaranteed to find it.
                {
                    let mut data = self
                        .data
                        .write()
                        .expect("failed to write access cache span data");
                    data.entry(trace_id).or_default().push(span);
                }

                if let Some(notify) = self
                    .notifications
                    .write()
                    .expect("failed to write access span notifications during span processing")
                    .remove(&span_id)
                {
                    notify.send(()).unwrap();
                }
            }

            fn force_flush(&self) -> Result<(), OTelSdkError> {
                Ok(())
            }

            fn shutdown(&self) -> Result<(), OTelSdkError> {
                Ok(())
            }

            fn shutdown_with_timeout(
                &self,
                _timeout: std::time::Duration,
            ) -> Result<(), OTelSdkError> {
                Ok(())
            }
        }

        /// Returns the string form of the attribute named `key` on `span_data`, if present.
        #[must_use]
        pub(super) fn get_span_attribute(
            span_data: &SpanData,
            key: &'static str,
        ) -> Option<String> {
            span_data
                .attributes
                .iter()
                // Compare borrowed key text rather than allocating a `String` per attribute.
                .find(|attr| attr.key.as_str() == key)
                .map(|kv| kv.value.to_string())
        }
    }
}