tonic_side_effect/
lib.rs

1use hyper::body::{Body, Frame, SizeHint};
2use pin_project_lite::pin_project;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use tonic::body::Body as TonicBody;
8use tonic::transport::Channel;
9use tower_service::Service;
10
/// Cloneable, resettable flag recording whether a request frame has been produced.
///
/// Clones share one underlying [`AtomicBool`] through an [`Arc`], so a clone kept by
/// the caller observes updates made through any other clone. A fresh signal starts
/// unset.
#[derive(Clone, Debug, Default)]
pub struct FrameSignal(Arc<AtomicBool>);

impl FrameSignal {
    /// Creates a new, unset signal that is not shared with any existing handle.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if any handle sharing this flag has signalled it and it has
    /// not been reset since.
    pub fn is_signalled(&self) -> bool {
        self.0.load(Ordering::Acquire)
    }

    /// Clears the flag for every handle sharing it.
    pub fn reset(&self) {
        self.write(false)
    }

    /// Sets the flag; invoked when a monitored body yields a frame.
    fn signal(&self) {
        self.write(true)
    }

    /// Stores `value` with release ordering, pairing with the acquire load in
    /// [`is_signalled`](Self::is_signalled).
    fn write(&self, value: bool) {
        self.0.store(value, Ordering::Release)
    }
}
32
33pin_project! {
34    struct RequestFrameMonitorBody<B> {
35        #[pin]
36        inner: B,
37        frame_signal: FrameSignal,
38    }
39}
40
41impl<B> Body for RequestFrameMonitorBody<B>
42where
43    B: Body,
44{
45    type Data = B::Data;
46    type Error = B::Error;
47
48    fn poll_frame(
49        self: Pin<&mut Self>,
50        cx: &mut Context<'_>,
51    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
52        let this = self.project();
53        match this.inner.poll_frame(cx) {
54            Poll::Ready(Some(res)) => match res {
55                Ok(frame) => {
56                    this.frame_signal.signal();
57                    Poll::Ready(Some(Ok(frame)))
58                }
59                Err(status) => Poll::Ready(Some(Err(status))),
60            },
61            Poll::Ready(None) => Poll::Ready(None),
62            Poll::Pending => Poll::Pending,
63        }
64    }
65
66    fn is_end_stream(&self) -> bool {
67        self.inner.is_end_stream()
68    }
69
70    fn size_hint(&self) -> SizeHint {
71        self.inner.size_hint()
72    }
73}
74
75/// Service for monitoring if an HTTP request frame was ever emitted.
76#[derive(Clone, Debug)]
77pub struct RequestFrameMonitor<S = Channel>
78where
79    S: Clone,
80{
81    /// Wrapped channel to monitor.
82    inner: S,
83
84    /// Signal indicating if request frame has been produced.
85    frame_signal: FrameSignal,
86}
87
88impl<S: Clone> RequestFrameMonitor<S> {
89    pub fn new(inner: S, frame_signal: FrameSignal) -> Self {
90        Self {
91            inner,
92            frame_signal: frame_signal.clone(),
93        }
94    }
95}
96
97impl<S> Service<http::Request<TonicBody>> for RequestFrameMonitor<S>
98where
99    S: Service<http::Request<TonicBody>> + Clone,
100{
101    type Response = S::Response;
102    type Error = S::Error;
103    type Future = S::Future;
104
105    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106        self.inner.poll_ready(cx)
107    }
108
109    fn call(&mut self, req: http::Request<TonicBody>) -> Self::Future {
110        let (head, body) = req.into_parts();
111        let body = TonicBody::new(RequestFrameMonitorBody {
112            inner: body,
113            frame_signal: self.frame_signal.clone(),
114        });
115        // See <https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services>
116        let clone = self.inner.clone();
117        let mut inner = std::mem::replace(&mut self.inner, clone);
118        inner.call(http::Request::from_parts(head, body))
119    }
120}