1use std::{
2 sync::Arc,
3 time::{Duration, Instant},
4};
5
6use tokio::sync::Mutex;
7use tonic::Code;
8
9use crate::rpc::RpcStreamingConfig;
10
11#[derive(Debug, Default)]
12struct StreamingState {
13 sent_messages: u64,
14 received_messages: u64,
15 completed: bool,
16 code: Option<Code>,
17}
18
19#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct RpcStreamingSnapshot {
22 pub sent_messages: u64,
24 pub received_messages: u64,
26 pub completed: bool,
28 pub code: Option<Code>,
30 pub duration: Duration,
32}
33
34#[derive(Debug, Clone)]
36pub struct RpcStreamingObserver {
37 service: String,
38 method: String,
39 config: RpcStreamingConfig,
40 started_at: Instant,
41 state: Arc<Mutex<StreamingState>>,
42}
43
44impl RpcStreamingObserver {
45 pub fn new(
47 service: impl Into<String>,
48 method: impl Into<String>,
49 config: RpcStreamingConfig,
50 ) -> Self {
51 Self {
52 service: service.into(),
53 method: method.into(),
54 config,
55 started_at: Instant::now(),
56 state: Arc::new(Mutex::new(StreamingState::default())),
57 }
58 }
59
60 pub async fn record_send(&self) {
62 if !self.config.observe {
63 return;
64 }
65 self.state.lock().await.sent_messages += 1;
66 tracing::debug!(
67 rpc.service = %self.service,
68 rpc.method = %self.method,
69 direction = "send",
70 "rpc stream message"
71 );
72 }
73
74 pub async fn record_recv(&self) {
76 if !self.config.observe {
77 return;
78 }
79 self.state.lock().await.received_messages += 1;
80 tracing::debug!(
81 rpc.service = %self.service,
82 rpc.method = %self.method,
83 direction = "recv",
84 "rpc stream message"
85 );
86 }
87
88 pub async fn finish<T>(&self, result: Result<T, tonic::Status>) {
90 let code = result
91 .as_ref()
92 .err()
93 .map(tonic::Status::code)
94 .unwrap_or(Code::Ok);
95 let mut state = self.state.lock().await;
96 state.completed = true;
97 state.code = Some(code);
98 tracing::info!(
99 rpc.service = %self.service,
100 rpc.method = %self.method,
101 code = ?code,
102 "rpc stream finished"
103 );
104 }
105
106 pub async fn snapshot(&self) -> RpcStreamingSnapshot {
108 let state = self.state.lock().await;
109 RpcStreamingSnapshot {
110 sent_messages: state.sent_messages,
111 received_messages: state.received_messages,
112 completed: state.completed,
113 code: state.code,
114 duration: self.started_at.elapsed(),
115 }
116 }
117}
118
119#[derive(Debug, Clone)]
121pub struct ObservedRecvStream<S> {
122 inner: S,
123 observer: RpcStreamingObserver,
124}
125
126impl<S> ObservedRecvStream<S> {
127 pub fn new(inner: S, observer: RpcStreamingObserver) -> Self {
129 Self { inner, observer }
130 }
131
132 pub fn into_inner(self) -> S {
134 self.inner
135 }
136}
137
138impl<S, T> futures::Stream for ObservedRecvStream<S>
139where
140 S: futures::Stream<Item = Result<T, tonic::Status>> + Unpin,
141{
142 type Item = Result<T, tonic::Status>;
143
144 fn poll_next(
145 mut self: std::pin::Pin<&mut Self>,
146 cx: &mut std::task::Context<'_>,
147 ) -> std::task::Poll<Option<Self::Item>> {
148 match std::pin::Pin::new(&mut self.inner).poll_next(cx) {
149 std::task::Poll::Ready(Some(Ok(value))) => {
150 let observer = self.observer.clone();
151 tokio::spawn(async move {
152 observer.record_recv().await;
153 });
154 std::task::Poll::Ready(Some(Ok(value)))
155 }
156 std::task::Poll::Ready(Some(Err(status))) => {
157 let observer = self.observer.clone();
158 let code = status.code();
159 tokio::spawn(async move {
160 observer
161 .finish::<()>(Err(tonic::Status::new(code, "stream receive failed")))
162 .await;
163 });
164 std::task::Poll::Ready(Some(Err(status)))
165 }
166 std::task::Poll::Ready(None) => {
167 let observer = self.observer.clone();
168 tokio::spawn(async move {
169 observer.finish::<()>(Ok(())).await;
170 });
171 std::task::Poll::Ready(None)
172 }
173 std::task::Poll::Pending => std::task::Poll::Pending,
174 }
175 }
176}
177
178pub async fn record_stream_send<F, T>(
180 observer: &RpcStreamingObserver,
181 send: F,
182) -> Result<T, tonic::Status>
183where
184 F: std::future::Future<Output = Result<T, tonic::Status>>,
185{
186 let result = send.await;
187 match &result {
188 Ok(_) => observer.record_send().await,
189 Err(status) => {
190 observer
191 .finish::<()>(Err(tonic::Status::new(status.code(), "stream send failed")))
192 .await;
193 }
194 }
195 result
196}
197
198pub async fn run_observed_stream<F, T>(
200 observer: &RpcStreamingObserver,
201 config: &RpcStreamingConfig,
202 stream: F,
203) -> Result<T, tonic::Status>
204where
205 F: std::future::Future<Output = Result<T, tonic::Status>>,
206{
207 let result = if let Some(timeout) = config.timeout {
208 match tokio::time::timeout(timeout, stream).await {
209 Ok(result) => result,
210 Err(_) => Err(tonic::Status::deadline_exceeded("rpc stream timed out")),
211 }
212 } else {
213 stream.await
214 };
215 let final_result = result
216 .as_ref()
217 .map(|_| ())
218 .map_err(|status| tonic::Status::new(status.code(), status.message().to_string()));
219 observer.finish(final_result).await;
220 result
221}