use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use rstest::rstest;
use tokio::net::TcpListener;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing_test::traced_test;
use wireframe_testing::{ServerMode, process_frame};
use crate::{
app::Envelope,
client::{ClientError, TracingConfig, WireframeClient},
rewind_stream::RewindStream,
serializer::BincodeSerializer,
};
/// Concrete client type handed to test scenarios: a [`WireframeClient`] using
/// [`BincodeSerializer`] over a Tokio TCP stream wrapped in [`RewindStream`].
type TestClient = WireframeClient<BincodeSerializer, RewindStream<tokio::net::TcpStream>>;
/// Start a single-connection echo server on an ephemeral localhost port.
///
/// Returns the bound address together with the handle of the serving task.
/// The task accepts exactly one client. For every length-delimited frame it
/// reads, it replies with whatever `process_frame(ServerMode::Echo, ..)`
/// produces. It stops on end of stream or a read error, when `process_frame`
/// yields `None`, or when a reply cannot be written.
async fn spawn_echo_server() -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind listener");
    let local = listener.local_addr().expect("listener addr");

    let serve = tokio::spawn(async move {
        let (socket, _) = listener.accept().await.expect("accept client");
        let mut frames = Framed::new(socket, LengthDelimitedCodec::new());
        loop {
            // Any read error is treated the same as a clean end of stream.
            let request = match frames.next().await {
                Some(Ok(frame)) => frame,
                _ => break,
            };
            let Some(reply) = process_frame(ServerMode::Echo, &request) else {
                break;
            };
            if frames.send(Bytes::from(reply)).await.is_err() {
                break;
            }
        }
    });

    (local, serve)
}
/// Run a test scenario against a freshly spawned echo server.
///
/// Spawns the server, connects a [`TestClient`] built with `config`, and
/// awaits the future returned by `f(client, server_addr)`. The server task is
/// aborted afterwards. The client is moved into `f`, so the scenario decides
/// whether to close it or simply let it drop.
async fn with_echo_client<F, Fut>(config: TracingConfig, f: F)
where
    F: FnOnce(TestClient, std::net::SocketAddr) -> Fut,
    Fut: std::future::Future<Output = ()>,
{
    let (server_addr, server_task) = spawn_echo_server().await;

    let connected = WireframeClient::builder()
        .tracing_config(config)
        .connect(server_addr)
        .await
        .expect("connect");

    let scenario = f(connected, server_addr);
    scenario.await;

    server_task.abort();
}
/// Build a `logs_assert` predicate that passes when some captured log line
/// mentions `span_name` and every entry of `required_fields` as whole tokens.
///
/// Plain substring matching is too loose here. `"client.call"` would match a
/// line that only mentions `"client.call_correlated"`. Captured lines may also
/// include the test function's own span name (e.g.
/// `..._with_correlation_id:`), so a search for `correlation_id` or
/// `elapsed_us` could be satisfied by the test name alone.
///
/// Token matching treats an identifier character (`[A-Za-z0-9_]`) at either
/// edge of a needle as requiring a non-identifier neighbour (or the line edge).
/// Punctuation edges, such as the trailing `"` in `result="ok"`, need no
/// boundary.
///
/// # Errors
///
/// The returned closure yields `Err` describing the span, the required
/// fields, and all captured lines when no single line satisfies every check.
pub(super) fn span_assertion(
    span_name: &str,
    required_fields: &[&str],
) -> impl Fn(&[&str]) -> Result<(), String> + 'static {
    let span = span_name.to_owned();
    let fields: Vec<String> = required_fields.iter().map(|s| (*s).to_owned()).collect();
    move |lines: &[&str]| {
        let matched = lines.iter().any(|line| {
            contains_token(line, &span) && fields.iter().all(|f| contains_token(line, f))
        });
        if matched {
            Ok(())
        } else {
            Err(format!(
                "{span} with fields {fields:?} not found in:\n{}",
                lines.join("\n")
            ))
        }
    }
}

/// Return `true` if `needle` occurs in `haystack` as a whole token.
///
/// When the first (last) character of `needle` is an identifier character,
/// the character directly before (after) the occurrence must not be one. An
/// empty `needle` always matches, as with `str::contains`.
fn contains_token(haystack: &str, needle: &str) -> bool {
    fn is_ident(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    let (Some(first), Some(last)) = (needle.chars().next(), needle.chars().next_back()) else {
        return true;
    };

    // Check every start position (not just non-overlapping matches) so an
    // invalid earlier occurrence cannot hide a valid overlapping one.
    haystack
        .char_indices()
        .map(|(start, _)| start)
        .filter(|&start| haystack[start..].starts_with(needle))
        .any(|start| {
            let before = haystack[..start].chars().next_back();
            let after = haystack[start + needle.len()..].chars().next();
            let clean_start = !is_ident(first) || !before.is_some_and(is_ident);
            let clean_end = !is_ident(last) || !after.is_some_and(is_ident);
            clean_start && clean_end
        })
}
/// Run a scenario against an echo-server client, then assert on the logs.
///
/// `$cfg` is the `TracingConfig` for the client and `$op` the scenario closure
/// handed to `with_echo_client`. Afterwards, `logs_assert` is checked with
/// `span_assertion($span, $fields)`.
///
/// NOTE(review): this is a macro rather than an async helper fn presumably
/// because `logs_assert` is only in scope inside `#[traced_test]` function
/// bodies — confirm against the tracing_test version in use.
macro_rules! test_span_emission {
    ($cfg:expr, $span:expr, $fields:expr, $op:expr $(,)?) => {{
        with_echo_client($cfg, $op).await;
        logs_assert(span_assertion($span, $fields));
    }};
}
/// Connecting must produce a `client.connect` log line that mentions the
/// address of the server the client dialed.
#[rstest]
#[traced_test]
#[tokio::test]
async fn connect_emits_span_with_peer_address() {
    let peer_addr = std::sync::OnceLock::new();
    let config = TracingConfig::default().with_connect_timing(true);

    with_echo_client(config, |_client, addr| {
        // Stash the server address so the log check below can search for it.
        peer_addr.set(addr.to_string()).expect("set addr");
        async {}
    })
    .await;

    let expected: &str = peer_addr.get().expect("addr captured");
    logs_assert(span_assertion("client.connect", &[expected]));
}
#[rstest]
#[traced_test]
#[tokio::test]
async fn send_emits_span_with_frame_bytes() {
test_span_emission!(
TracingConfig::default().with_send_timing(true),
"client.send",
&["frame.bytes"],
|mut client, _addr| async move {
let envelope = Envelope::new(1, None, vec![1, 2, 3]);
client.send(&envelope).await.expect("send");
},
);
}
#[rstest]
#[traced_test]
#[tokio::test]
async fn receive_emits_span_with_result() {
test_span_emission!(
TracingConfig::default().with_receive_timing(true),
"client.receive",
&["result=\"ok\""],
|mut client, _addr| async move {
let envelope = Envelope::new(1, None, vec![1, 2, 3]);
client.send(&envelope).await.expect("send");
let _response: Envelope = client.receive().await.expect("receive");
},
);
}
#[rstest]
#[traced_test]
#[tokio::test]
async fn call_emits_wrapping_span() {
test_span_emission!(
TracingConfig::default().with_call_timing(true),
"client.call",
&[],
|mut client, _addr| async move {
let envelope = Envelope::new(1, None, vec![1, 2, 3]);
let _response: Envelope = client.call(&envelope).await.expect("call");
},
);
}
#[rstest]
#[traced_test]
#[tokio::test]
async fn call_correlated_emits_span_with_correlation_id() {
test_span_emission!(
TracingConfig::default().with_call_timing(true),
"client.call_correlated",
&["correlation_id"],
|mut client, _addr| async move {
let request = Envelope::new(1, None, vec![1, 2, 3]);
let _response: Envelope = client
.call_correlated(request)
.await
.expect("call_correlated");
},
);
}
#[rstest]
#[traced_test]
#[tokio::test]
async fn send_envelope_emits_span_with_correlation_id_and_frame_bytes() {
test_span_emission!(
TracingConfig::default().with_send_timing(true),
"client.send_envelope",
&["correlation_id", "frame.bytes"],
|mut client, _addr| async move {
let envelope = Envelope::new(1, None, vec![1, 2, 3]);
let _cid = client.send_envelope(envelope).await.expect("send_envelope");
},
);
}
#[rstest]
#[traced_test]
#[tokio::test]
async fn close_emits_span() {
with_echo_client(
TracingConfig::default().with_close_timing(true),
|client, _addr| async move {
client.close().await;
},
)
.await;
logs_assert(span_assertion("client.close", &[]));
}
/// When the server hangs up after reading the request, `call_correlated`
/// must fail. Its log line must record `result="err"` and `elapsed_us`.
#[rstest]
#[traced_test]
#[tokio::test]
async fn call_correlated_error_records_result_err_and_emits_timing() {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind listener");
    let server_addr = listener.local_addr().expect("listener addr");

    // Read exactly one frame (the request), then close the connection
    // without replying.
    let hang_up = tokio::spawn(async move {
        let (socket, _) = listener.accept().await.expect("accept");
        let mut conn = Framed::new(socket, LengthDelimitedCodec::new());
        let _request = conn.next().await;
        drop(conn);
    });

    let config = TracingConfig::default().with_call_timing(true);
    let mut client = WireframeClient::builder()
        .tracing_config(config)
        .connect(server_addr)
        .await
        .expect("connect");

    let outcome: Result<Envelope, ClientError> = client
        .call_correlated(Envelope::new(1, None, vec![1, 2, 3]))
        .await;
    assert!(
        outcome.is_err(),
        "call_correlated should fail after disconnect"
    );

    hang_up.await.expect("join accept");
    logs_assert(span_assertion(
        "client.call_correlated",
        &["result=\"err\"", "elapsed_us"],
    ));
}
/// A `receive` on a connection the server has already dropped must fail and
/// record `result="err"` on the `client.receive` log line.
#[rstest]
#[traced_test]
#[tokio::test]
async fn receive_error_records_result_err() {
    let listener = TcpListener::bind("127.0.0.1:0")
        .await
        .expect("bind listener");
    let server_addr = listener.local_addr().expect("listener addr");

    // Accept the client and immediately drop the socket.
    let closer = tokio::spawn(async move {
        let (socket, _) = listener.accept().await.expect("accept");
        drop(socket);
    });

    let mut client = WireframeClient::builder()
        .tracing_config(TracingConfig::default().with_receive_timing(true))
        .connect(server_addr)
        .await
        .expect("connect");

    // Make sure the server side is gone before attempting to read.
    closer.await.expect("join accept");

    let outcome: Result<Envelope, ClientError> = client.receive().await;
    assert!(outcome.is_err(), "receive should fail after disconnect");
    logs_assert(span_assertion("client.receive", &["result=\"err\""]));
}
/// With a default `TracingConfig`, a full `call` must not log any
/// `elapsed_us` value.
#[rstest]
#[traced_test]
#[tokio::test]
async fn timing_disabled_by_default() {
    let config = TracingConfig::default();
    with_echo_client(config, |mut client, _addr| async move {
        let request = Envelope::new(1, None, vec![1, 2, 3]);
        let _reply: Envelope = client.call(&request).await.expect("call");
    })
    .await;

    let saw_timing = logs_contain("elapsed_us");
    assert!(
        !saw_timing,
        "elapsed_us should not appear when timing is disabled"
    );
}
/// With send timing enabled, the `client.send` log line must carry an
/// `elapsed_us` value.
///
/// The assertion is scoped to the `client.send` span. It does not look for a
/// bare `elapsed_us` anywhere, because this test's own name contains
/// `elapsed_us`, and captured lines may include the test span name. A bare
/// substring search could therefore pass without any timing being logged.
#[rstest]
#[traced_test]
#[tokio::test]
async fn timing_enabled_emits_elapsed_us_for_send() {
    test_span_emission!(
        TracingConfig::default().with_send_timing(true),
        "client.send",
        &["elapsed_us"],
        |mut client, _addr| async move {
            let envelope = Envelope::new(1, None, vec![1, 2, 3]);
            client.send(&envelope).await.expect("send");
        },
    );
}
#[rstest]
#[traced_test]
#[tokio::test]
async fn timing_enabled_for_connect() {
test_span_emission!(
TracingConfig::default().with_connect_timing(true),
"elapsed_us",
&[],
|_client, _addr| async {},
);
}
/// `with_all_timing(true)` must turn on timing for every operation involved
/// in connecting and performing a `call`.
#[rstest]
#[traced_test]
#[tokio::test]
async fn all_timing_convenience_enables_all_operations() {
    let config = TracingConfig::default().with_all_timing(true);
    with_echo_client(config, |mut client, _addr| async move {
        let request = Envelope::new(1, None, vec![1, 2, 3]);
        let _reply: Envelope = client.call(&request).await.expect("call");
    })
    .await;

    // At least four timed lines are expected. These are presumably connect,
    // call, and the send/receive that `call` performs internally; confirm
    // against the client implementation.
    logs_assert(|lines: &[&str]| {
        let count = lines
            .iter()
            .filter(|line| line.contains("elapsed_us"))
            .count();
        (count >= 4)
            .then_some(())
            .ok_or_else(|| format!("expected >=4 elapsed_us events, found {count}"))
    });
}
/// A client built without calling `tracing_config` at all must still complete
/// a request/response round trip against the echo server.
#[rstest]
#[traced_test]
#[tokio::test]
async fn default_config_is_backwards_compatible() {
    let (server_addr, echo_task) = spawn_echo_server().await;

    let mut client = WireframeClient::builder()
        .connect(server_addr)
        .await
        .expect("connect");
    let request = Envelope::new(1, None, vec![1, 2, 3]);
    let _reply: Envelope = client.call(&request).await.expect("call");

    echo_task.abort();
}