use std::{fmt, sync::atomic::AtomicU64, time::Instant};
use bytes::Bytes;
use futures::SinkExt;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::Instrument;
use super::{
ClientCodecConfig,
ClientError,
WireframeClientBuilder,
hooks::{ClientConnectionTeardownHandler, ClientErrorHandler, RequestHooks},
tracing_config::TracingConfig,
tracing_helpers::{call_span, close_span, emit_timing_event, send_span},
};
use crate::{
message::{DecodeWith, EncodeWith},
rewind_stream::RewindStream,
serializer::{BincodeSerializer, Serializer},
};
/// Transport accepted by [`WireframeClient`]: any asynchronous byte stream
/// that can be both read from and written to, and is `Unpin`.
pub trait ClientStream
where
    Self: AsyncRead + AsyncWrite + Unpin,
{
}

/// Blanket implementation: every type satisfying the bounds is automatically
/// a [`ClientStream`], so callers never implement it by hand.
impl<T: AsyncRead + AsyncWrite + Unpin> ClientStream for T {}
/// Client that exchanges serialized messages over a length-delimited framed
/// transport.
///
/// Type parameters:
/// - `S`: serializer implementation (defaults to [`BincodeSerializer`]).
/// - `T`: underlying transport, any [`ClientStream`] (defaults to [`TcpStream`]).
/// - `C`: per-connection state handed to the disconnect handler when the
///   client is closed (defaults to `()`).
pub struct WireframeClient<S = BincodeSerializer, T: ClientStream = TcpStream, C = ()> {
    /// The transport wrapped in a length-delimited frame codec.
    pub(crate) framed: Framed<T, LengthDelimitedCodec>,
    /// Serializer applied to outgoing messages by `send`.
    pub(crate) serializer: S,
    /// Codec settings, exposed read-only through `codec_config()`.
    pub(crate) codec_config: ClientCodecConfig,
    /// Connection state; taken and passed to `on_disconnect` during `close`.
    pub(crate) connection_state: Option<C>,
    /// Optional handler awaited with the connection state during `close`.
    pub(crate) on_disconnect: Option<ClientConnectionTeardownHandler<C>>,
    /// Optional error handler; presumably what `invoke_error_hook` (defined
    /// elsewhere) calls — confirm.
    pub(crate) on_error: Option<ClientErrorHandler>,
    /// Request hooks; presumably what `invoke_before_send_hooks` (defined
    /// elsewhere) runs — confirm.
    pub(crate) request_hooks: RequestHooks,
    /// Controls span creation and per-operation timing events.
    pub(crate) tracing_config: TracingConfig,
    /// NOTE(review): not used in this chunk; presumably the source of
    /// request correlation IDs — confirm.
    pub(crate) correlation_counter: AtomicU64,
}
impl<S, T: ClientStream, C> fmt::Debug for WireframeClient<S, T, C> {
    /// Formats the client for diagnostics.
    ///
    /// Only the codec configuration is printed. The output is marked
    /// non-exhaustive (`..`) because the transport, serializer, hooks and
    /// state are deliberately left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut view = f.debug_struct("WireframeClient");
        view.field("codec_config", &self.codec_config);
        view.finish_non_exhaustive()
    }
}
impl WireframeClient<BincodeSerializer, TcpStream, ()> {
    /// Returns a builder for configuring and connecting a new client.
    ///
    /// The builder starts out parameterised with [`BincodeSerializer`] and
    /// unit types for its remaining type parameters.
    #[must_use]
    pub fn builder() -> WireframeClientBuilder<BincodeSerializer, (), ()> {
        WireframeClientBuilder::<BincodeSerializer, (), ()>::new()
    }
}
impl<S, T, C> WireframeClient<S, T, C>
where
    S: Serializer + Send + Sync,
    T: ClientStream,
{
    /// Serializes `message` and writes it to the transport as a single frame.
    ///
    /// The before-send request hooks may mutate the serialized bytes before
    /// they are framed. The write runs inside the send span. When
    /// `send_timing` is enabled in the tracing configuration, a timing event
    /// is emitted once the attempt finishes, on success and on failure alike.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Serialize`] if the serializer rejects the
    /// message. If the frame cannot be written, returns the transport error
    /// converted into a [`ClientError`]. In both cases the error hook is
    /// invoked before the error is returned.
    pub async fn send<M: EncodeWith<S>>(&mut self, message: &M) -> Result<(), ClientError> {
        let started = self.tracing_config.send_timing.then(Instant::now);

        let mut payload = match self.serializer.serialize(message) {
            Ok(serialized) => serialized,
            Err(cause) => {
                let failure = ClientError::Serialize(cause);
                emit_timing_event(started);
                self.invoke_error_hook(&failure).await;
                return Err(failure);
            }
        };

        self.invoke_before_send_hooks(&mut payload);

        let frame_span = send_span(&self.tracing_config, payload.len());
        // The timing event is emitted inside the instrumented future so it is
        // attributed to the send span.
        let write_outcome = async {
            let written = self.framed.send(Bytes::from(payload)).await;
            emit_timing_event(started);
            written
        }
        .instrument(frame_span)
        .await;

        match write_outcome {
            Ok(()) => Ok(()),
            Err(transport_err) => {
                let failure = ClientError::from(transport_err);
                self.invoke_error_hook(&failure).await;
                Err(failure)
            }
        }
    }

    /// Receives the next message from the transport, decoded as `M`.
    ///
    /// # Errors
    ///
    /// Propagates whatever `receive_internal` reports. This is presumably
    /// transport and decode failures; confirm against that helper.
    pub async fn receive<M: DecodeWith<S>>(&mut self) -> Result<M, ClientError> {
        self.receive_internal().await
    }

    /// Sends `request` and then waits for a single response of type `Resp`.
    ///
    /// The whole exchange runs inside the call span. On success the span is
    /// finalised through `traced_ok`. On failure its `result` field is set
    /// to `"err"` and, when `call_timing` is enabled, a timing event is
    /// emitted.
    ///
    /// # Errors
    ///
    /// Returns the error from [`send`](Self::send) if the request cannot be
    /// sent; the response is not awaited in that case. Otherwise returns the
    /// error from [`receive`](Self::receive).
    pub async fn call<Req: EncodeWith<S>, Resp: DecodeWith<S>>(
        &mut self,
        request: &Req,
    ) -> Result<Resp, ClientError> {
        let span = call_span(&self.tracing_config);
        let started = self.tracing_config.call_timing.then(Instant::now);
        async {
            // A failed send short-circuits: only wait for a reply once the
            // request has actually been written.
            let outcome = match self.send(request).await {
                Ok(()) => self.receive::<Resp>().await,
                Err(send_err) => Err(send_err),
            };
            match &outcome {
                Ok(_) => {
                    Self::traced_ok(&span, started);
                }
                Err(_) => {
                    span.record("result", "err");
                    emit_timing_event(started);
                }
            }
            outcome
        }
        .instrument(span.clone())
        .await
    }

    /// Returns the client's codec configuration.
    #[must_use]
    pub const fn codec_config(&self) -> &ClientCodecConfig {
        &self.codec_config
    }

    /// Borrows the underlying transport beneath the framing layer.
    #[must_use]
    pub fn stream(&self) -> &T {
        self.framed.get_ref()
    }
}
impl<S, C> WireframeClient<S, RewindStream<TcpStream>, C>
where
S: Serializer + Send + Sync,
C: Send + 'static,
{
#[must_use]
pub fn tcp_stream(&self) -> &TcpStream { self.framed.get_ref().inner() }
#[must_use]
pub fn rewind_stream(&self) -> &RewindStream<TcpStream> { self.framed.get_ref() }
pub async fn close(mut self) {
let span = close_span(&self.tracing_config);
let timing_start = self.tracing_config.close_timing.then(Instant::now);
async {
let _ = self.framed.close().await;
if let (Some(state), Some(handler)) =
(self.connection_state.take(), &self.on_disconnect)
{
handler(state).await;
}
emit_timing_event(timing_start);
}
.instrument(span)
.await;
}
}