use std::{
sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::Mutex;
use tonic::Code;
use crate::rpc::RpcStreamingConfig;
/// Mutable bookkeeping shared by every clone of an [`RpcStreamingObserver`].
#[derive(Debug, Default)]
struct StreamingState {
    /// Messages counted by `record_send`.
    sent_messages: u64,
    /// Messages counted by `record_recv`.
    received_messages: u64,
    /// Set once `finish` has recorded a terminal outcome.
    completed: bool,
    /// Recorded terminal status code; `None` until the first `finish`.
    code: Option<Code>,
    /// Time from observer creation until the stored outcome was recorded.
    /// `None` while the stream is still running.
    finished_after: Option<Duration>,
}

/// Point-in-time copy of an observer's counters and outcome, as returned by
/// [`RpcStreamingObserver::snapshot`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RpcStreamingSnapshot {
    /// Number of outbound messages recorded.
    pub sent_messages: u64,
    /// Number of inbound messages recorded.
    pub received_messages: u64,
    /// Whether a terminal outcome has been recorded.
    pub completed: bool,
    /// Recorded terminal status code, if any.
    pub code: Option<Code>,
    /// Once completed: time from observer creation until the recorded outcome
    /// (frozen). While running: time elapsed so far.
    pub duration: Duration,
}

/// Tracks message counts and the terminal status of one streaming RPC.
///
/// All clones share the same counters and outcome (the state lives behind an
/// `Arc<Mutex<_>>`), so a clone can be moved into spawned tasks or stream
/// wrappers and still update the same record.
#[derive(Debug, Clone)]
pub struct RpcStreamingObserver {
    service: String,
    method: String,
    config: RpcStreamingConfig,
    started_at: Instant,
    state: Arc<Mutex<StreamingState>>,
}

impl RpcStreamingObserver {
    /// Creates an observer for `service`/`method`; the duration clock starts now.
    pub fn new(
        service: impl Into<String>,
        method: impl Into<String>,
        config: RpcStreamingConfig,
    ) -> Self {
        Self {
            service: service.into(),
            method: method.into(),
            config,
            started_at: Instant::now(),
            state: Arc::new(Mutex::new(StreamingState::default())),
        }
    }

    /// Counts one outbound message and emits a debug event.
    ///
    /// Does nothing when `config.observe` is false.
    pub async fn record_send(&self) {
        if !self.config.observe {
            return;
        }
        self.state.lock().await.sent_messages += 1;
        tracing::debug!(
            rpc.service = %self.service,
            rpc.method = %self.method,
            direction = "send",
            "rpc stream message"
        );
    }

    /// Counts one inbound message and emits a debug event.
    ///
    /// Does nothing when `config.observe` is false.
    pub async fn record_recv(&self) {
        if !self.config.observe {
            return;
        }
        self.state.lock().await.received_messages += 1;
        tracing::debug!(
            rpc.service = %self.service,
            rpc.method = %self.method,
            direction = "recv",
            "rpc stream message"
        );
    }

    /// Records the terminal outcome of the stream and emits an info event.
    ///
    /// `Ok(_)` is recorded as [`Code::Ok`]; `Err(status)` as `status.code()`.
    /// Unlike the `record_*` methods, this is not gated on `config.observe`.
    ///
    /// A stream may be reported finished from several places: the receive
    /// wrapper, a failed send, and the overall driver. The outcome is therefore
    /// sticky:
    /// - the first call records it;
    /// - a later error may replace a recorded `Ok`;
    /// - a recorded error is never replaced.
    ///
    /// Calls that do not change the recorded outcome are ignored and emit no
    /// event.
    pub async fn finish<T>(&self, result: Result<T, tonic::Status>) {
        let code = match &result {
            Ok(_) => Code::Ok,
            Err(status) => status.code(),
        };
        // Measure before locking so lock contention does not inflate the duration.
        let elapsed = self.started_at.elapsed();
        {
            let mut state = self.state.lock().await;
            let accept = match state.code {
                None => true,
                Some(Code::Ok) => code != Code::Ok,
                Some(_) => false,
            };
            if !accept {
                return;
            }
            state.completed = true;
            state.code = Some(code);
            state.finished_after = Some(elapsed);
        } // Release the lock before logging.
        tracing::info!(
            rpc.service = %self.service,
            rpc.method = %self.method,
            code = ?code,
            duration = ?elapsed,
            "rpc stream finished"
        );
    }

    /// Returns a copy of the current counters and outcome.
    ///
    /// The reported duration is frozen once an outcome has been recorded.
    pub async fn snapshot(&self) -> RpcStreamingSnapshot {
        let state = self.state.lock().await;
        RpcStreamingSnapshot {
            sent_messages: state.sent_messages,
            received_messages: state.received_messages,
            completed: state.completed,
            code: state.code,
            duration: state
                .finished_after
                .unwrap_or_else(|| self.started_at.elapsed()),
        }
    }
}
/// Wraps an inbound gRPC message stream and reports its items to an
/// [`RpcStreamingObserver`].
///
/// - Each successful item is counted with `record_recv`.
/// - The first error, or the end of the stream, is reported once via `finish`.
/// - Reporting runs on detached Tokio tasks, so `poll_next` never awaits the
///   observer's lock.
#[derive(Debug, Clone)]
pub struct ObservedRecvStream<S> {
    inner: S,
    observer: RpcStreamingObserver,
    /// Set once a terminal outcome (first error or end of stream) has been
    /// reported. This stops a trailing `None` after an error from reporting a
    /// second, racing `Ok` outcome.
    finish_reported: bool,
}

impl<S> ObservedRecvStream<S> {
    /// Wraps `inner`, reporting its items to `observer`.
    pub fn new(inner: S, observer: RpcStreamingObserver) -> Self {
        Self {
            inner,
            observer,
            finish_reported: false,
        }
    }

    /// Unwraps the inner stream; the observer is dropped and receives no
    /// further reports.
    pub fn into_inner(self) -> S {
        self.inner
    }

    /// Runs `task` on the current Tokio runtime without awaiting it.
    ///
    /// Observation is best-effort. If the stream is polled outside a Tokio
    /// runtime, the event is dropped rather than panicking as `tokio::spawn`
    /// would.
    fn spawn_observation<F>(task: F)
    where
        F: std::future::Future<Output = ()> + Send + 'static,
    {
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            // Detached on purpose: dropping the JoinHandle does not cancel the task.
            handle.spawn(task);
        }
    }
}

impl<S, T> futures::Stream for ObservedRecvStream<S>
where
    S: futures::Stream<Item = Result<T, tonic::Status>> + Unpin,
{
    type Item = Result<T, tonic::Status>;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        // The wrapper is Unpin (S: Unpin), so plain `&mut` field access is sound.
        let this = self.get_mut();
        let item = match std::pin::Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Ready(item) => item,
            Poll::Pending => return Poll::Pending,
        };

        match &item {
            Some(Ok(_)) => {
                let observer = this.observer.clone();
                Self::spawn_observation(async move {
                    observer.record_recv().await;
                });
            }
            Some(Err(status)) => {
                if !this.finish_reported {
                    this.finish_reported = true;
                    let observer = this.observer.clone();
                    let code = status.code();
                    Self::spawn_observation(async move {
                        observer
                            .finish::<()>(Err(tonic::Status::new(code, "stream receive failed")))
                            .await;
                    });
                }
            }
            None => {
                if !this.finish_reported {
                    this.finish_reported = true;
                    let observer = this.observer.clone();
                    Self::spawn_observation(async move {
                        observer.finish::<()>(Ok(())).await;
                    });
                }
            }
        }

        Poll::Ready(item)
    }
}
/// Awaits a single send operation and reports its outcome to `observer`.
///
/// - A successful send is counted via `record_send`.
/// - A failed send is reported as the stream's terminal outcome via `finish`,
///   carrying the failure's status code.
///
/// Whatever the send produced is handed back unchanged.
pub async fn record_stream_send<F, T>(
    observer: &RpcStreamingObserver,
    send: F,
) -> Result<T, tonic::Status>
where
    F: std::future::Future<Output = Result<T, tonic::Status>>,
{
    let outcome = send.await;
    if let Err(failure) = &outcome {
        let reported = tonic::Status::new(failure.code(), "stream send failed");
        observer.finish::<()>(Err(reported)).await;
    } else {
        observer.record_send().await;
    }
    outcome
}
/// Drives a whole streaming RPC and reports its final outcome to `observer`.
///
/// When `config.timeout` is set and elapses first, the `stream` future is
/// dropped and the result becomes a `DEADLINE_EXCEEDED` status. The
/// (possibly timed-out) result is returned to the caller.
pub async fn run_observed_stream<F, T>(
    observer: &RpcStreamingObserver,
    config: &RpcStreamingConfig,
    stream: F,
) -> Result<T, tonic::Status>
where
    F: std::future::Future<Output = Result<T, tonic::Status>>,
{
    let result = match config.timeout {
        Some(limit) => tokio::time::timeout(limit, stream).await.unwrap_or_else(|_elapsed| {
            Err(tonic::Status::deadline_exceeded("rpc stream timed out"))
        }),
        None => stream.await,
    };

    // `finish` takes its result by value, so build a detached copy of the
    // status and keep `result` to return.
    let outcome: Result<(), tonic::Status> = match &result {
        Ok(_) => Ok(()),
        Err(status) => Err(tonic::Status::new(
            status.code(),
            status.message().to_string(),
        )),
    };
    observer.finish(outcome).await;

    result
}