// rs_zero/rpc/streaming.rs

1use std::{
2    sync::Arc,
3    time::{Duration, Instant},
4};
5
6use tokio::sync::Mutex;
7use tonic::Code;
8
9use crate::rpc::RpcStreamingConfig;
10
/// Mutable per-stream bookkeeping kept behind the observer's mutex.
///
/// The derived `Default` is the initial state of a fresh observer: both
/// counters at zero, not completed, and no status code.
#[derive(Debug, Default)]
struct StreamingState {
    /// Incremented once per observed `record_send` call.
    sent_messages: u64,
    /// Incremented once per observed `record_recv` call.
    received_messages: u64,
    /// Set to `true` by `finish`; never reset in this file.
    completed: bool,
    /// Status code stored by `finish`; `None` until `finish` has run.
    code: Option<Code>,
}
18
/// Point-in-time streaming RPC observation snapshot.
///
/// A plain copy of the observer's counters and status at the moment
/// `RpcStreamingObserver::snapshot` was called; it does not update afterwards.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RpcStreamingSnapshot {
    /// Number of send events.
    pub sent_messages: u64,
    /// Number of receive events.
    pub received_messages: u64,
    /// Whether the stream finished.
    pub completed: bool,
    /// Final tonic code when known.
    ///
    /// `None` until the observer's `finish` has been called; a successful
    /// finish is recorded as `Code::Ok`.
    pub code: Option<Code>,
    /// Stream duration.
    ///
    /// Measured from the observer's creation to the moment the snapshot was
    /// taken, so it keeps growing across snapshots even after the stream
    /// has completed.
    pub duration: Duration,
}
33
/// Lightweight observer for streaming RPC adapters.
///
/// Clones share the same underlying state through the `Arc`, so an event
/// recorded through any clone is visible in snapshots taken from every clone.
#[derive(Debug, Clone)]
pub struct RpcStreamingObserver {
    /// Service name, emitted as the `rpc.service` field on tracing events.
    service: String,
    /// Method name, emitted as the `rpc.method` field on tracing events.
    method: String,
    /// Streaming settings; its `observe` flag gates `record_send` and `record_recv`.
    config: RpcStreamingConfig,
    /// Captured when the observer is constructed; basis for the snapshot duration.
    started_at: Instant,
    /// Counters and completion status, shared between clones.
    state: Arc<Mutex<StreamingState>>,
}
43
44impl RpcStreamingObserver {
45    /// Creates an observer for one stream.
46    pub fn new(
47        service: impl Into<String>,
48        method: impl Into<String>,
49        config: RpcStreamingConfig,
50    ) -> Self {
51        Self {
52            service: service.into(),
53            method: method.into(),
54            config,
55            started_at: Instant::now(),
56            state: Arc::new(Mutex::new(StreamingState::default())),
57        }
58    }
59
60    /// Records a send event.
61    pub async fn record_send(&self) {
62        if !self.config.observe {
63            return;
64        }
65        self.state.lock().await.sent_messages += 1;
66        tracing::debug!(
67            rpc.service = %self.service,
68            rpc.method = %self.method,
69            direction = "send",
70            "rpc stream message"
71        );
72    }
73
74    /// Records a receive event.
75    pub async fn record_recv(&self) {
76        if !self.config.observe {
77            return;
78        }
79        self.state.lock().await.received_messages += 1;
80        tracing::debug!(
81            rpc.service = %self.service,
82            rpc.method = %self.method,
83            direction = "recv",
84            "rpc stream message"
85        );
86    }
87
88    /// Finishes the stream with a final result.
89    pub async fn finish<T>(&self, result: Result<T, tonic::Status>) {
90        let code = result
91            .as_ref()
92            .err()
93            .map(tonic::Status::code)
94            .unwrap_or(Code::Ok);
95        let mut state = self.state.lock().await;
96        state.completed = true;
97        state.code = Some(code);
98        tracing::info!(
99            rpc.service = %self.service,
100            rpc.method = %self.method,
101            code = ?code,
102            "rpc stream finished"
103        );
104    }
105
106    /// Returns a snapshot.
107    pub async fn snapshot(&self) -> RpcStreamingSnapshot {
108        let state = self.state.lock().await;
109        RpcStreamingSnapshot {
110            sent_messages: state.sent_messages,
111            received_messages: state.received_messages,
112            completed: state.completed,
113            code: state.code,
114            duration: self.started_at.elapsed(),
115        }
116    }
117}
118
/// Wrapper that records receive boundaries for a streaming response.
///
/// Items from `inner` are passed through unchanged; the wrapper only notifies
/// `observer` about what it sees.
#[derive(Debug, Clone)]
pub struct ObservedRecvStream<S> {
    /// The wrapped stream whose items are forwarded to the caller.
    inner: S,
    /// Observer notified about received items and stream termination.
    observer: RpcStreamingObserver,
}
125
126impl<S> ObservedRecvStream<S> {
127    /// Creates a receive-observed stream wrapper.
128    pub fn new(inner: S, observer: RpcStreamingObserver) -> Self {
129        Self { inner, observer }
130    }
131
132    /// Returns the wrapped stream.
133    pub fn into_inner(self) -> S {
134        self.inner
135    }
136}
137
138impl<S, T> futures::Stream for ObservedRecvStream<S>
139where
140    S: futures::Stream<Item = Result<T, tonic::Status>> + Unpin,
141{
142    type Item = Result<T, tonic::Status>;
143
144    fn poll_next(
145        mut self: std::pin::Pin<&mut Self>,
146        cx: &mut std::task::Context<'_>,
147    ) -> std::task::Poll<Option<Self::Item>> {
148        match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
149            std::task::Poll::Ready(Some(Ok(value))) => {
150                let observer = self.observer.clone();
151                tokio::spawn(async move {
152                    observer.record_recv().await;
153                });
154                std::task::Poll::Ready(Some(Ok(value)))
155            }
156            std::task::Poll::Ready(Some(Err(status))) => {
157                let observer = self.observer.clone();
158                let code = status.code();
159                tokio::spawn(async move {
160                    observer
161                        .finish::<()>(Err(tonic::Status::new(code, "stream receive failed")))
162                        .await;
163                });
164                std::task::Poll::Ready(Some(Err(status)))
165            }
166            std::task::Poll::Ready(None) => {
167                let observer = self.observer.clone();
168                tokio::spawn(async move {
169                    observer.finish::<()>(Ok(())).await;
170                });
171                std::task::Poll::Ready(None)
172            }
173            std::task::Poll::Pending => std::task::Poll::Pending,
174        }
175    }
176}
177
178/// Runs one streaming send operation with send observation.
179pub async fn record_stream_send<F, T>(
180    observer: &RpcStreamingObserver,
181    send: F,
182) -> Result<T, tonic::Status>
183where
184    F: std::future::Future<Output = Result<T, tonic::Status>>,
185{
186    let result = send.await;
187    match &result {
188        Ok(_) => observer.record_send().await,
189        Err(status) => {
190            observer
191                .finish::<()>(Err(tonic::Status::new(status.code(), "stream send failed")))
192                .await;
193        }
194    }
195    result
196}
197
198/// Runs a streaming future with optional timeout and final observation.
199pub async fn run_observed_stream<F, T>(
200    observer: &RpcStreamingObserver,
201    config: &RpcStreamingConfig,
202    stream: F,
203) -> Result<T, tonic::Status>
204where
205    F: std::future::Future<Output = Result<T, tonic::Status>>,
206{
207    let result = if let Some(timeout) = config.timeout {
208        match tokio::time::timeout(timeout, stream).await {
209            Ok(result) => result,
210            Err(_) => Err(tonic::Status::deadline_exceeded("rpc stream timed out")),
211        }
212    } else {
213        stream.await
214    };
215    let final_result = result
216        .as_ref()
217        .map(|_| ())
218        .map_err(|status| tonic::Status::new(status.code(), status.message().to_string()));
219    observer.finish(final_result).await;
220    result
221}