use bssh::pty::PtyMessage;
use smallvec::SmallVec;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::{Instant, timeout};
/// Builds a deterministic buffer of `size` bytes that counts upward and wraps:
/// `0, 1, ..., 255, 0, 1, ...`. Despite the name, nothing here is random.
// Not called by any test shown here; `dead_code` is allowed so the unused helper
// does not produce a compiler warning.
#[allow(dead_code)]
fn generate_random_data(size: usize) -> Vec<u8> {
    (0..=u8::MAX).cycle().take(size).collect()
}
/// Pushes 10,000 small `LocalInput` messages from one producer task to one
/// consumer task through a bounded channel, then checks that every message
/// arrived and that end-to-end throughput exceeded 1,000 messages per second.
#[tokio::test]
async fn test_high_throughput_message_processing() {
    let total_messages = 10000;
    let (sender, mut receiver) = mpsc::channel::<PtyMessage>(10000);
    let started = Instant::now();

    let producer = tokio::spawn(async move {
        for n in 0..total_messages {
            let payload = format!("High throughput message {n}");
            let message = PtyMessage::LocalInput(SmallVec::from_slice(payload.as_bytes()));
            // A send error means the receiver is gone, so there is no point continuing.
            if sender.send(message).await.is_err() {
                break;
            }
        }
    });

    let consumer = tokio::spawn(async move {
        let mut seen = 0;
        // Stop as soon as the expected number has arrived, or earlier if the channel closes.
        while seen < total_messages {
            match receiver.recv().await {
                Some(_) => seen += 1,
                None => break,
            }
        }
        seen
    });

    let (_, received_count) = tokio::try_join!(producer, consumer).unwrap();
    let elapsed = started.elapsed();
    let throughput = f64::from(received_count) / elapsed.as_secs_f64();
    println!("Processed {received_count} messages in {elapsed:?} ({throughput:.2} msg/s)");

    assert_eq!(received_count, total_messages);
    assert!(
        throughput > 1000.0,
        "Should process at least 1000 messages/second"
    );
}
/// Runs ten rounds of building, sending and fully draining 1,000 messages, and
/// checks that the average growth reported by `get_approximate_memory_usage`
/// stays below 1 MiB per round.
///
/// NOTE(review): `get_approximate_memory_usage` returns a value that is constant
/// for the process (`pid * 1024`), so every recorded delta is 0 and the growth
/// assertion cannot fail as written.
#[tokio::test]
async fn test_memory_usage_under_load() {
    let iterations = 1000;
    let mut memory_samples = Vec::new();

    for round in 0..10 {
        let start_memory = get_approximate_memory_usage();

        let batch: Vec<PtyMessage> = (0..iterations)
            .map(|i| {
                let data = format!("Memory test message {i} in round {round}");
                PtyMessage::LocalInput(SmallVec::from_slice(data.as_bytes()))
            })
            .collect();

        let (tx, mut rx) = mpsc::channel::<PtyMessage>(iterations);
        for msg in batch {
            // Capacity equals the batch size, so these sends never wait on the receiver.
            let _ = tx.send(msg).await;
        }
        // Dropping the only sender lets the drain loop below observe channel closure.
        drop(tx);

        let mut drained = 0;
        while rx.recv().await.is_some() {
            drained += 1;
        }
        assert_eq!(drained, iterations);

        let end_memory = get_approximate_memory_usage();
        memory_samples.push(end_memory.saturating_sub(start_memory));
        tokio::task::yield_now().await;
    }

    let total_growth: usize = memory_samples.iter().sum();
    let avg_growth = total_growth / memory_samples.len();
    println!("Average memory growth per round: {avg_growth} bytes");
    assert!(
        avg_growth < 1024 * 1024,
        "Memory growth should be less than 1MB per round"
    );
}
/// Placeholder "memory usage" probe.
///
/// It does not measure memory. It returns the current process id multiplied by
/// 1024, which never changes while the process is alive, so the difference
/// between any two calls is always zero.
fn get_approximate_memory_usage() -> usize {
    let pid = std::process::id() as usize;
    pid * 1024
}
/// Floods a capacity-2 channel with 50 non-blocking sends while nothing else is
/// consuming. It checks that the channel reports "full" and accepts sends again
/// once a message has been drained.
///
/// The receiver is only read inside the failure branch, so the buffer is full
/// after the first two sends. Failures are therefore guaranteed, and the test
/// asserts them unconditionally.
#[tokio::test]
async fn test_resource_exhaustion_recovery() {
    let (tx, mut rx) = mpsc::channel::<PtyMessage>(2);
    let mut successful_sends = 0;
    let mut failed_sends = 0;
    // Retries that succeeded after draining one message: evidence of recovery.
    let mut recovered_sends = 0;

    for i in 0..50 {
        let data = format!("Fill buffer {i}");
        let msg = PtyMessage::LocalInput(SmallVec::from_slice(data.as_bytes()));
        match tx.try_send(msg) {
            Ok(()) => successful_sends += 1,
            Err(_) => {
                failed_sends += 1;
                // Free one slot, then retry; a drained bounded channel should accept again.
                if rx.try_recv().is_ok() {
                    let retry_data = format!("Retry after drain {i}");
                    let retry_msg =
                        PtyMessage::LocalInput(SmallVec::from_slice(retry_data.as_bytes()));
                    if tx.try_send(retry_msg).is_ok() {
                        successful_sends += 1;
                        recovered_sends += 1;
                    }
                }
            }
        }
    }

    println!("Resource exhaustion test: {successful_sends} successful, {failed_sends} failed");
    assert!(successful_sends > 0, "Some sends should succeed");
    // Previously this sat in an `else` branch of `failed_sends == 0`, which made it
    // a tautology. With capacity 2 and 50 attempts, failures must occur.
    assert!(
        failed_sends > 0,
        "Expected some failures with very small buffer"
    );
    assert!(
        recovered_sends > 0,
        "Sends should succeed again after draining the buffer"
    );
}
/// Spawns 20 producer tasks that each try to send 100 tagged messages into one
/// shared channel. It checks that the single consumer sees traffic from more
/// than one producer.
///
/// A producer gives up on its first send that errors or does not complete
/// within 100ms, so the test asserts that "some" messages were delivered rather
/// than an exact total.
#[tokio::test]
async fn test_concurrent_message_producers() {
    // Every message starts with this tag, followed by the producer id and a space.
    const PREFIX: &str = "Producer ";

    let (tx, mut rx) = mpsc::channel::<PtyMessage>(1000);
    let producers = 20;
    let messages_per_producer = 100;

    let mut handles = Vec::with_capacity(producers);
    for producer_id in 0..producers {
        let producer_tx = tx.clone();
        handles.push(tokio::spawn(async move {
            for i in 0..messages_per_producer {
                let data = format!("Producer {producer_id} message {i}");
                let msg = PtyMessage::LocalInput(SmallVec::from_slice(data.as_bytes()));
                match timeout(Duration::from_millis(100), producer_tx.send(msg)).await {
                    Ok(Ok(())) => {}
                    // Closed channel, or a send stuck for 100ms: stop this producer.
                    Ok(Err(_)) | Err(_) => break,
                }
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
            producer_id
        }));
    }
    // Only the per-producer clones remain, so the channel closes once they all finish.
    drop(tx);

    let consumer = tokio::spawn(async move {
        let mut total_received = 0;
        let mut producer_counts = vec![0; producers];
        while let Some(msg) = rx.recv().await {
            let PtyMessage::LocalInput(data) = msg else {
                continue;
            };
            let content = String::from_utf8_lossy(&data);
            // Pull out the id between the tag and the next space, ignoring anything malformed.
            let parsed_id = content
                .find(PREFIX)
                .map(|start| &content[start + PREFIX.len()..])
                .and_then(|rest| rest.split_once(" "))
                .and_then(|(id, _)| id.parse::<usize>().ok())
                .filter(|&id| id < producers);
            if let Some(id) = parsed_id {
                producer_counts[id] += 1;
            }
            total_received += 1;
        }
        (total_received, producer_counts)
    });

    let mut completed_producers = 0;
    for handle in handles {
        match handle.await {
            Ok(_) => completed_producers += 1,
            Err(_) => {}
        }
    }

    let (total_received, producer_counts) = consumer.await.unwrap();
    println!(
        "Concurrent test: {completed_producers} producers completed, {total_received} total messages received"
    );
    assert!(completed_producers > 0, "Some producers should complete");
    assert!(total_received > 0, "Should receive some messages");
    let active_producers = producer_counts.iter().copied().filter(|&n| n > 0).count();
    assert!(
        active_producers > 1,
        "Should receive messages from multiple producers"
    );
}
/// Streams roughly one message every 10ms for two seconds while a consumer
/// drains them under a 2.5-second deadline. It then checks that the run lasted
/// at least the configured duration and that no more than 10% of the sent
/// messages went unreceived.
#[tokio::test]
async fn test_long_running_message_stream() {
    let (tx, mut rx) = mpsc::channel::<PtyMessage>(1000);
    let duration = Duration::from_secs(2);
    let start_time = Instant::now();

    let producer = tokio::spawn(async move {
        let mut count = 0;
        while start_time.elapsed() < duration {
            let data = format!("Long running message {count}");
            let msg = PtyMessage::LocalInput(SmallVec::from_slice(data.as_bytes()));
            match tx.send(msg).await {
                Ok(()) => count += 1,
                Err(_) => break,
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        count
    });

    let consumer = tokio::spawn(async move {
        let mut received = 0;
        let consumer_start = Instant::now();
        // 500ms of slack beyond the producer's run to drain whatever is still buffered.
        let deadline = duration + Duration::from_millis(500);
        while consumer_start.elapsed() < deadline {
            match timeout(Duration::from_millis(100), rx.recv()).await {
                Ok(Some(_)) => received += 1,
                // Sender dropped and buffer empty: the stream is over.
                Ok(None) => break,
                // Nothing within 100ms; loop back and re-check the deadline.
                Err(_) => {}
            }
        }
        received
    });

    let (sent, received) = tokio::try_join!(producer, consumer).unwrap();
    let actual_duration = start_time.elapsed();
    println!("Long running stream: {sent} sent, {received} received in {actual_duration:?}");

    assert!(sent > 0, "Should send some messages");
    assert!(received > 0, "Should receive some messages");
    assert!(
        actual_duration >= duration,
        "Should run for at least the specified duration"
    );

    let message_loss = (sent - received).max(0);
    // Compare `loss * 10 <= sent` instead of `loss < sent / 10`. Integer division
    // rounds `sent / 10` down to 0 when fewer than 10 messages were sent, which
    // made even a loss-free run fail.
    assert!(
        message_loss * 10 <= sent,
        "Should not lose more than 10% of messages"
    );
}
/// For batches of 1,000, 5,000 and 10,000 pre-built messages, measures creation,
/// send and receive time through a channel sized to the batch. It checks that
/// every message is sent and received and that each batch finishes within 10 seconds.
#[tokio::test]
async fn test_massive_message_batches() {
    for batch_size in [1000, 5000, 10000] {
        let start_time = Instant::now();
        let messages: Vec<PtyMessage> = (0..batch_size)
            .map(|i| {
                let data = format!("Batch message {i} of {batch_size}");
                PtyMessage::LocalInput(SmallVec::from_slice(data.as_bytes()))
            })
            .collect();
        let creation_time = start_time.elapsed();

        let (tx, mut rx) = mpsc::channel::<PtyMessage>(batch_size);

        // Reports how many messages went out before any failure, plus the send duration.
        let sender = tokio::spawn(async move {
            let send_start = Instant::now();
            let mut delivered = 0;
            for msg in messages {
                if tx.send(msg).await.is_err() {
                    return (delivered, send_start.elapsed());
                }
                delivered += 1;
            }
            (batch_size, send_start.elapsed())
        });

        // Stops after exactly `batch_size` messages, or earlier if the channel closes.
        let receiver = tokio::spawn(async move {
            let recv_start = Instant::now();
            let mut count = 0;
            while count < batch_size && rx.recv().await.is_some() {
                count += 1;
            }
            (count, recv_start.elapsed())
        });

        let ((sent_count, send_time), (recv_count, recv_time)) =
            tokio::try_join!(sender, receiver).unwrap();
        let total_time = start_time.elapsed();
        println!(
            "Batch size {batch_size}: created in {creation_time:?}, sent {sent_count} in {send_time:?}, received {recv_count} in {recv_time:?}, total {total_time:?}"
        );

        assert_eq!(sent_count, batch_size, "Should send all messages");
        assert_eq!(recv_count, batch_size, "Should receive all messages");
        assert!(
            total_time < Duration::from_secs(10),
            "Should complete within 10 seconds"
        );
    }
}
/// Interleaves one `PtyMessage::Error` at every index divisible by ten into a
/// 500-message stream whose other entries are `LocalInput`. It checks that the
/// consumer sees exactly the same mix and no other message variants.
#[tokio::test]
async fn test_error_propagation_under_stress() {
    let (tx, mut rx) = mpsc::channel::<PtyMessage>(100);
    let total_messages = 500;
    let error_frequency = 10;

    let producer = tokio::spawn(async move {
        let (mut sent_normal, mut sent_errors) = (0, 0);
        for i in 0..total_messages {
            let msg = match i % error_frequency {
                0 => {
                    sent_errors += 1;
                    PtyMessage::Error(format!("Error message {}", i / error_frequency))
                }
                _ => {
                    sent_normal += 1;
                    let data = format!("Normal message {i}");
                    PtyMessage::LocalInput(SmallVec::from_slice(data.as_bytes()))
                }
            };
            if tx.send(msg).await.is_err() {
                break;
            }
        }
        (sent_normal, sent_errors)
    });

    let consumer = tokio::spawn(async move {
        let (mut received_normal, mut received_errors, mut received_other) = (0, 0, 0);
        // Stop once the full stream is accounted for, or earlier if the channel closes.
        while received_normal + received_errors + received_other < total_messages {
            let Some(msg) = rx.recv().await else {
                break;
            };
            match msg {
                PtyMessage::LocalInput(_) => received_normal += 1,
                PtyMessage::Error(_) => received_errors += 1,
                _ => received_other += 1,
            }
        }
        (received_normal, received_errors, received_other)
    });

    let ((sent_normal, sent_errors), (received_normal, received_errors, received_other)) =
        tokio::try_join!(producer, consumer).unwrap();
    println!(
        "Error propagation test: sent {sent_normal}N/{sent_errors}E, received {received_normal}N/{received_errors}E/{received_other}O"
    );

    assert_eq!(
        sent_normal, received_normal,
        "All normal messages should be received"
    );
    assert_eq!(
        sent_errors, received_errors,
        "All error messages should be received"
    );
    assert_eq!(
        received_other, 0,
        "Should not receive unexpected message types"
    );
    assert_eq!(
        sent_errors,
        total_messages / error_frequency,
        "Should send expected number of errors"
    );
}
/// Drives a capacity-5 channel with 50 sends, each given a 10ms window, against
/// a consumer that reads at most 30 messages with a 20ms pause after each one.
/// It checks that the slow consumer causes a meaningful share of sends to time
/// out. A timed-out send's message is dropped.
#[tokio::test]
async fn test_channel_backpressure_behavior() {
    let (tx, mut rx) = mpsc::channel::<PtyMessage>(5);

    let producer = tokio::spawn(async move {
        let (mut send_attempts, mut successful_sends, mut blocked_sends) = (0, 0, 0);
        for i in 0..50 {
            send_attempts += 1;
            let data = format!("Backpressure test {i}");
            let msg = PtyMessage::LocalInput(SmallVec::from_slice(data.as_bytes()));
            match timeout(Duration::from_millis(10), tx.send(msg)).await {
                Ok(Ok(())) => successful_sends += 1,
                // Receiver gone: nothing further can be delivered.
                Ok(Err(_)) => break,
                // The channel stayed full for the whole 10ms window.
                Err(_) => blocked_sends += 1,
            }
        }
        (send_attempts, successful_sends, blocked_sends)
    });

    let consumer = tokio::spawn(async move {
        let mut received = 0;
        for _ in 0..30 {
            // Quit on channel closure or on 100ms of silence.
            let Ok(Some(_)) = timeout(Duration::from_millis(100), rx.recv()).await else {
                break;
            };
            received += 1;
            // Deliberately slow consumption so the producer's sends back up.
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        received
    });

    let ((attempts, successful, blocked), received) = tokio::try_join!(producer, consumer).unwrap();
    println!(
        "Backpressure test: {attempts} attempts, {successful} successful, {blocked} blocked, {received} received"
    );

    assert!(attempts > 0, "Should attempt to send messages");
    assert!(successful > 0, "Some sends should succeed");
    assert!(blocked > 0, "Some sends should be blocked by backpressure");
    assert!(received > 0, "Consumer should receive some messages");
    assert!(
        blocked > successful / 2,
        "Backpressure should cause significant blocking"
    );
}
/// For payload sizes from 1 byte to 100 KiB, streams 50 messages of that size
/// through a capacity-100 channel. It checks that every message and every
/// payload byte arrives, and that each size finishes within 5 seconds.
#[tokio::test]
async fn test_message_size_stress() {
    for size in [1, 100, 1024, 10240, 102400] {
        let message_count: usize = 50;
        let (tx, mut rx) = mpsc::channel::<PtyMessage>(100);
        let start_time = Instant::now();

        let producer = tokio::spawn(async move {
            for i in 0..message_count {
                // Each payload repeats one letter, cycling A..Z across messages.
                let data = vec![b'A' + (i % 26) as u8; size];
                let msg = PtyMessage::LocalInput(SmallVec::from_slice(&data));
                if tx.send(msg).await.is_err() {
                    break;
                }
            }
        });

        let consumer = tokio::spawn(async move {
            let mut received = 0;
            let mut total_bytes = 0;
            while let Some(msg) = rx.recv().await {
                if let PtyMessage::LocalInput(data) = msg {
                    total_bytes += data.len();
                    received += 1;
                    if received >= message_count {
                        break;
                    }
                }
            }
            (received, total_bytes)
        });

        // Previously the consumer's counts were discarded, so delivery was never verified.
        let (_, (received, total_bytes)) = tokio::try_join!(producer, consumer).unwrap();
        let elapsed = start_time.elapsed();
        println!("Message size {size} bytes: {message_count} messages in {elapsed:?}");

        assert_eq!(received, message_count, "Should receive every message");
        assert_eq!(
            total_bytes,
            size * message_count,
            "Should receive every payload byte"
        );
        assert!(
            elapsed < Duration::from_secs(5),
            "Should complete within 5 seconds"
        );
    }
}
/// Repeats five times:
/// - spawn a sender task that sends up to 51 messages, pausing 1ms between them
///   and returning right after index 50;
/// - abort the task after 100ms;
/// - drain the channel and check that no more than 100 messages come back.
///
/// NOTE(review): despite the name, nothing panics here. The task either
/// returns early or is aborted.
#[tokio::test]
async fn test_stress_cleanup_after_panic_simulation() {
    for round in 0..5 {
        let (tx, mut rx) = mpsc::channel::<PtyMessage>(100);

        let task = tokio::spawn(async move {
            for i in 0..1000 {
                let data = format!("Cleanup test {i} round {round}");
                let msg = PtyMessage::LocalInput(SmallVec::from_slice(data.as_bytes()));
                if tx.send(msg).await.is_err() {
                    break;
                }
                // Early exit after the 51st message, standing in for a task that dies midway.
                if i == 50 {
                    return i;
                }
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
            1000
        });

        tokio::time::sleep(Duration::from_millis(100)).await;
        task.abort();

        let mut received = 0;
        loop {
            match timeout(Duration::from_millis(10), rx.recv()).await {
                Ok(Some(_)) => {
                    received += 1;
                    // Bail out once the count is already past the asserted limit.
                    if received > 100 {
                        break;
                    }
                }
                // Channel closed, or no message within 10ms: stop draining.
                Ok(None) | Err(_) => break,
            }
        }

        println!(
            "Cleanup test round {round}: received {received} messages after task cancellation"
        );
        assert!(
            received <= 100,
            "Should not receive excessive messages after cancellation"
        );
    }
}