use super::core::QuicSession;
use super::core::WebTransportHandler;
use super::core::WebTransportStream;
use anyhow::Result;
use bytes::Bytes;
use std::sync::Arc;
use tracing::info;
/// Stateless WebTransport handler that answers each stream with
/// `echo(webtransport): <message>`, where `<message>` is the received payload
/// decoded as UTF-8 (or the literal `<binary>` when it is not valid UTF-8).
#[derive(Clone, Debug, Default)]
pub(crate) struct EchoHandler;
#[async_trait::async_trait]
impl WebTransportHandler for EchoHandler {
    /// Reads the whole stream, logs it, and replies with `echo(webtransport): <message>`.
    ///
    /// Every data chunk returned by `stream.recv_data()` is appended, in order,
    /// until it yields `None`. The assembled payload is decoded as UTF-8; if that
    /// fails, the literal `<binary>` is used as the message instead. The response
    /// is written back on the same stream, which is then finished.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `recv_data`, `send_data`, or `finish`.
    async fn handle(
        &self,
        session: Arc<QuicSession>,
        stream: &mut WebTransportStream,
    ) -> Result<()> {
        // Append into one growable buffer: this is amortised O(1) per byte, whereas
        // re-allocating and copying the entire payload for each chunk is quadratic
        // in the number of chunks.
        // NOTE(review): the buffer has no size cap; a peer can make it grow
        // arbitrarily large before the stream ends — consider enforcing a limit.
        let mut payload: Vec<u8> = Vec::new();
        while let Some(chunk) = stream.recv_data().await? {
            payload.extend_from_slice(&chunk);
        }
        // `String::from_utf8` takes ownership of the buffer, so no extra copy is made.
        let message = String::from_utf8(payload).unwrap_or_else(|_| "<binary>".to_string());
        info!(
            session_id = session.id(),
            remote = %session.remote_addr(),
            "收到 WebTransport 消息: {message}"
        );
        let response = format!("echo(webtransport): {message}");
        stream.send_data(Bytes::from(response)).await?;
        stream.finish().await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::quic::core::QuicSession;
    use std::net::SocketAddr;

    /// Concatenates stream chunks in order, the same way the handler assembles
    /// the payload from successive `recv_data` results.
    fn aggregate<I>(chunks: I) -> Bytes
    where
        I: IntoIterator<Item = Bytes>,
    {
        let mut buf = Vec::new();
        for chunk in chunks {
            buf.extend_from_slice(&chunk);
        }
        Bytes::from(buf)
    }

    /// Decodes a payload as UTF-8, substituting `<binary>` for invalid input.
    fn decode(payload: &[u8]) -> String {
        String::from_utf8(payload.to_vec()).unwrap_or_else(|_| "<binary>".to_string())
    }

    /// Builds the echo response text for `message`.
    fn respond(message: &str) -> String {
        format!("echo(webtransport): {message}")
    }

    #[test]
    fn test_echo_handler_empty_message() {
        // A stream that ends without any data yields an empty message and a bare prefix.
        let payload = aggregate(std::iter::empty::<Bytes>());
        assert!(payload.is_empty());
        let message = decode(&payload);
        assert_eq!(message, "");
        assert_eq!(respond(&message), "echo(webtransport): ");

        let remote: SocketAddr = "127.0.0.1:12345".parse().unwrap();
        let session = Arc::new(QuicSession::new(remote));
        assert_eq!(session.remote_addr(), remote);
    }

    #[test]
    fn test_echo_handler_binary_data() {
        let bytes = Bytes::from(vec![0xFF, 0xFE, 0xFD, 0xFC]);
        assert_eq!(bytes.len(), 4);
        assert_eq!(bytes[0], 0xFF);
        assert_eq!(decode(&bytes), "<binary>");
    }

    #[test]
    fn test_echo_handler_multiple_chunks() {
        let payload = aggregate(vec![
            Bytes::from("hello "),
            Bytes::from("world"),
            Bytes::from("!"),
        ]);
        assert_eq!(payload.len(), 12);
        assert_eq!(payload, Bytes::from("hello world!"));
        assert_eq!(decode(&payload), "hello world!");
    }

    #[test]
    fn test_echo_handler_binary_chunk_aggregation() {
        let payload = aggregate(vec![
            Bytes::from(&b"\xFF\xFE"[..]),
            Bytes::from(&b"\xFD\xFC"[..]),
        ]);
        assert_eq!(payload.len(), 4);
        assert_eq!(decode(&payload), "<binary>");
    }

    #[test]
    fn test_echo_handler_session_info() {
        let addr1: SocketAddr = "127.0.0.1:9999".parse().unwrap();
        let addr2: SocketAddr = "127.0.0.1:8888".parse().unwrap();
        let session1 = Arc::new(QuicSession::new(addr1));
        let session2 = Arc::new(QuicSession::new(addr2));
        assert!(!session1.id().is_empty());
        assert!(!session2.id().is_empty());
        assert_ne!(session1.id(), session2.id());
        assert_eq!(session1.remote_addr(), addr1);
        assert_eq!(session2.remote_addr(), addr2);
    }

    #[test]
    fn test_echo_handler_aggregates_empty_and_nonempty() {
        // Empty chunks anywhere in the sequence contribute nothing to the payload.
        let payload = aggregate(vec![Bytes::new(), Bytes::from("data"), Bytes::new()]);
        assert_eq!(payload, Bytes::from("data"));
    }

    #[test]
    fn test_echo_handler_single_chunk() {
        let payload = aggregate(vec![Bytes::from("single")]);
        assert_eq!(payload, Bytes::from("single"));
        assert_eq!(decode(&payload), "single");
    }

    #[test]
    fn test_echo_handler_response_format() {
        let test_cases = [
            ("hello", "echo(webtransport): hello"),
            ("", "echo(webtransport): "),
            ("测试中文", "echo(webtransport): 测试中文"),
        ];
        for (input, expected) in test_cases.iter() {
            assert_eq!(respond(input), *expected);
        }
    }
}