use std::io::{Read, Write};
use std::{cell::Cell, io, net, time::Duration};
use ntex::codec::BytesCodec;
use ntex::io::{FilterBuf, FilterLayer, Io};
use ntex::server::test_server;
use ntex::service::fn_service;
use ntex::util::Bytes;
/// Number of `b'x'` bytes `BurstWriteFilter` injects into the write buffer
/// (16 KiB) the first time it processes non-empty incoming data.
const BURST_SIZE: usize = 16_384;
/// Test filter that forwards data unchanged, but additionally queues a burst
/// of `BURST_SIZE` bytes into the write buffer on its first non-empty read.
#[derive(Default, Debug)]
struct BurstWriteFilter {
    /// Becomes `true` once the burst has been queued, so it happens only once.
    sent: Cell<bool>,
}
/// Pass-through filter layer with one side effect.
///
/// Read data is copied from the source read buffer to the destination read
/// buffer, and pending write data is moved from the source write buffer to
/// the destination write buffer. The first time a read delivers non-empty
/// data, `BURST_SIZE` bytes of `b'x'` are also appended straight to the
/// destination write buffer. Because that append happens outside
/// `process_write_buf`, it exercises writes queued during read processing.
impl FilterLayer for BurstWriteFilter {
    fn process_read_buf(&self, buf: &FilterBuf<'_>) -> io::Result<()> {
        // Forward whatever arrived and remember whether it carried any bytes.
        let got_data = buf.with_read_buffers(|src, dst| match src.take() {
            Some(chunk) => {
                dst.extend_from_slice(&chunk);
                !chunk.is_empty()
            }
            None => false,
        });

        // `replace` sets the flag and returns its previous value, so the
        // burst is queued at most once; `&&` skips it entirely on empty reads.
        if got_data && !self.sent.replace(true) {
            buf.with_write_buffers(|_, out| {
                out.extend_from_slice(&[b'x'; BURST_SIZE]);
            });
        }
        Ok(())
    }

    fn process_write_buf(&self, buf: &FilterBuf<'_>) -> io::Result<()> {
        buf.with_write_buffers(|pending, out| {
            // Nothing to forward on this pass.
            if pending.is_empty() {
                return;
            }
            pending.move_to(out);
        });
        Ok(())
    }
}
/// Regression test: data a filter queues into the write buffer *while
/// processing a read* must still be flushed to the peer.
///
/// Sequence:
/// 1. The server installs `BurstWriteFilter` and sends `"hi"`. Once the
///    client has read it, the connection and filter are known to be set up.
/// 2. The client sends `"ping"`. While processing that read, the filter
///    appends `BURST_SIZE` bytes of `b'x'` to the write buffer.
/// 3. The server echoes `"ping"`, so the client must receive the burst
///    followed by `"ping"`.
///
/// The client's 5-second read timeout keeps the test from hanging forever if
/// the burst is never written out.
#[ntex::test]
async fn test_filter_large_write_during_read_processing() {
    let srv = test_server(async || {
        fn_service(|io: Io| async move {
            let io = io.add_filter(BurstWriteFilter::default());
            io.send(Bytes::from_static(b"hi"), &BytesCodec)
                .await
                .unwrap();
            // Echo each received frame until `recv` yields no frame or errors.
            while let Ok(Some(msg)) = io.recv(&BytesCodec).await {
                io.send(msg, &BytesCodec).await.unwrap();
            }
            Ok::<_, io::Error>(())
        })
    });

    let mut client = net::TcpStream::connect(srv.addr()).unwrap();
    client
        .set_read_timeout(Some(Duration::from_secs(5)))
        .unwrap();

    let mut greeting = [0u8; 2];
    client.read_exact(&mut greeting).unwrap();
    assert_eq!(&greeting, b"hi");

    client.write_all(b"ping").unwrap();

    let mut data = vec![0u8; BURST_SIZE + 4];
    client
        .read_exact(&mut data)
        .expect("server did not send data, io driver is broken");

    // Report the first offending offset and byte, so a failure shows whether
    // the burst was reordered behind the echo (e.g. b'p' at offset 0) or
    // corrupted, instead of a bare "assertion failed".
    let (burst, tail) = data.split_at(BURST_SIZE);
    if let Some(pos) = burst.iter().position(|b| *b != b'x') {
        panic!(
            "burst byte at offset {pos} is {:#04x}, expected b'x'; \
             the burst must arrive before the \"ping\" echo",
            burst[pos]
        );
    }
    assert_eq!(tail, b"ping");
}