use smux_rust::{client, server};
use std::time::Duration;
use tokio::net::{TcpListener, TcpStream};
/// Checks that `poll_wait` reports the stream that actually has pending data.
///
/// The client opens three streams and writes only to the second one. The
/// server-side `poll_wait` over `[stream1, stream2, stream3]` is therefore
/// expected to report index 1. The subsequent read must yield a non-empty
/// prefix of the payload. An error from `poll_wait` fails the test instead of
/// being logged and ignored.
#[tokio::test]
async fn test_poll_wait() -> Result<(), Box<dyn std::error::Error>> {
    // Bytes the client sends on stream 2; shared so the server can verify them.
    const PAYLOAD: &[u8] = b"Hello from stream 2";

    let listener = TcpListener::bind("127.0.0.1:0").await?;
    let addr = listener.local_addr()?;

    let server_handle = tokio::spawn(async move {
        let (conn, _) = listener.accept().await.unwrap();
        let session = server(Box::new(conn), None).await.unwrap();
        // NOTE(review): assumes streams are accepted in the order the client
        // opened them (single underlying TCP connection) — confirm in smux_rust.
        let stream1 = session.accept_stream().await.unwrap();
        let stream2 = session.accept_stream().await.unwrap();
        let stream3 = session.accept_stream().await.unwrap();
        let streams = vec![&stream1, &stream2, &stream3];

        // Panic rather than just print: a panic inside the spawned task is
        // surfaced by `server_handle.await?`, so a poll_wait failure fails the test.
        let idx = session
            .poll_wait(&streams)
            .await
            .unwrap_or_else(|e| panic!("poll_wait 错误: {:?}", e));
        println!("流 {} 有数据可读", idx);
        assert_eq!(idx, 1, "only stream 2 was written to");

        let mut buf = [0u8; 1024];
        let n = streams[idx].read(&mut buf).await.unwrap();
        println!("读取了 {} 字节", n);
        // A single read may return fewer bytes than were written, so only
        // require a non-empty prefix of the payload.
        assert!(n > 0, "stream {} reported ready but read returned 0 bytes", idx);
        assert!(PAYLOAD.starts_with(&buf[..n]));

        // Keep the session alive while the client closes its streams.
        tokio::time::sleep(Duration::from_millis(200)).await;
    });

    let conn = TcpStream::connect(addr).await?;
    let session = client(Box::new(conn), None).await?;
    let stream1 = session.open_stream().await?;
    let stream2 = session.open_stream().await?;
    let stream3 = session.open_stream().await?;

    // Give the server time to accept all three streams and enter poll_wait.
    tokio::time::sleep(Duration::from_millis(100)).await;
    stream2.write_all(PAYLOAD).await?;
    tokio::time::sleep(Duration::from_millis(100)).await;

    stream1.close().await?;
    stream2.close().await?;
    stream3.close().await?;
    server_handle.await?;
    Ok(())
}
/// Smoke test: `local_addr` and `remote_addr` can be queried on both the
/// server-side (accepted) and client-side (opened) stream, and both streams
/// can then be closed.
#[tokio::test]
async fn test_stream_addresses() -> Result<(), Box<dyn std::error::Error>> {
    let tcp_listener = TcpListener::bind("127.0.0.1:0").await?;
    let bound_addr = tcp_listener.local_addr()?;

    let acceptor = tokio::spawn(async move {
        let (socket, _) = tcp_listener.accept().await.unwrap();
        let srv_session = server(Box::new(socket), None).await.unwrap();
        let accepted = srv_session.accept_stream().await.unwrap();

        // The values are discarded; only the calls themselves are exercised.
        let _local = accepted.local_addr();
        let _remote = accepted.remote_addr();

        // Hold the stream open briefly, then close it, ignoring any close error.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let _ = accepted.close().await;
    });

    let socket = TcpStream::connect(bound_addr).await?;
    let cli_session = client(Box::new(socket), None).await?;
    let opened = cli_session.open_stream().await?;

    let _local = opened.local_addr();
    let _remote = opened.remote_addr();

    opened.close().await?;
    acceptor.await?;
    Ok(())
}
/// Verifies that `copy_to` drains a stream into a writer until the peer closes.
///
/// The client writes a short payload and closes the stream. The server must
/// report exactly that many bytes copied and end up with identical contents.
#[tokio::test]
async fn test_copy_to() -> Result<(), Box<dyn std::error::Error>> {
    let tcp_listener = TcpListener::bind("127.0.0.1:0").await?;
    let bound_addr = tcp_listener.local_addr()?;

    let payload = b"Hello, this is test data for copy_to!";
    // Owned copy moved into the server task for comparison.
    let expected = payload.to_vec();

    let receiver = tokio::spawn(async move {
        let (socket, _) = tcp_listener.accept().await.unwrap();
        let srv_session = server(Box::new(socket), None).await.unwrap();
        let accepted = srv_session.accept_stream().await.unwrap();

        let mut sink: Vec<u8> = Vec::new();
        let copied = accepted.copy_to(&mut sink).await.unwrap();
        println!("复制了 {} 字节", copied);

        assert_eq!(copied, expected.len() as u64);
        assert_eq!(sink, expected);

        let _ = accepted.close().await;
    });

    let socket = TcpStream::connect(bound_addr).await?;
    let cli_session = client(Box::new(socket), None).await?;
    let opened = cli_session.open_stream().await?;

    opened.write_all(payload).await?;
    // Closing signals EOF so the server's copy_to can finish.
    opened.close().await?;

    receiver.await?;
    Ok(())
}
/// Checks that `poll_wait` returns a valid index when several streams are
/// readable at the same time.
///
/// Both client streams are written to before the server polls, so either index
/// is acceptable. The test only requires that the result is in range.
#[tokio::test]
async fn test_poll_wait_with_multiple_ready_streams() -> Result<(), Box<dyn std::error::Error>> {
    let tcp_listener = TcpListener::bind("127.0.0.1:0").await?;
    let bound_addr = tcp_listener.local_addr()?;

    let poller = tokio::spawn(async move {
        let (socket, _) = tcp_listener.accept().await.unwrap();
        let srv_session = server(Box::new(socket), None).await.unwrap();
        let first = srv_session.accept_stream().await.unwrap();
        let second = srv_session.accept_stream().await.unwrap();

        // Wait long enough for the client's writes to land on both streams.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let candidates = vec![&first, &second];
        let ready = srv_session.poll_wait(&candidates).await.unwrap();
        assert!(ready < candidates.len());

        let _ = first.close().await;
        let _ = second.close().await;
    });

    let socket = TcpStream::connect(bound_addr).await?;
    let cli_session = client(Box::new(socket), None).await?;
    let first = cli_session.open_stream().await?;
    let second = cli_session.open_stream().await?;

    tokio::time::sleep(Duration::from_millis(100)).await;
    first.write_all(b"data1").await?;
    second.write_all(b"data2").await?;
    tokio::time::sleep(Duration::from_millis(200)).await;

    first.close().await?;
    second.close().await?;
    poller.await?;
    Ok(())
}
/// Verifies that `copy_to` handles a payload much larger than a single frame.
///
/// The payload is 100 000 bytes cycling through 0..=255. The server must copy
/// every byte and reproduce the payload exactly.
#[tokio::test]
async fn test_copy_to_with_large_data() -> Result<(), Box<dyn std::error::Error>> {
    let tcp_listener = TcpListener::bind("127.0.0.1:0").await?;
    let bound_addr = tcp_listener.local_addr()?;

    // Same byte pattern as `(i % 256) as u8` for i in 0..100_000.
    let payload: Vec<u8> = (0..=255u8).cycle().take(100_000).collect();
    let expected = payload.clone();

    let receiver = tokio::spawn(async move {
        let (socket, _) = tcp_listener.accept().await.unwrap();
        let srv_session = server(Box::new(socket), None).await.unwrap();
        let accepted = srv_session.accept_stream().await.unwrap();

        let mut sink: Vec<u8> = Vec::new();
        let copied = accepted.copy_to(&mut sink).await.unwrap();

        assert_eq!(copied, expected.len() as u64);
        assert_eq!(sink, expected);
        println!("成功复制 {} 字节的大数据", copied);
    });

    let socket = TcpStream::connect(bound_addr).await?;
    let cli_session = client(Box::new(socket), None).await?;
    let opened = cli_session.open_stream().await?;

    opened.write_all(&payload).await?;
    // Closing signals EOF so the server's copy_to can finish.
    opened.close().await?;

    receiver.await?;
    Ok(())
}