use std::io::BufWriter;
use std::os::unix::net::UnixStream;
use std::sync::mpsc;
use std::thread::JoinHandle;
use std::time::Duration;
use crate::protocol;
use super::state::ClientMsg;
/// Upper bound on how long one socket write may block before it fails.
///
/// `spawn_writer` installs this as the client socket's write timeout; every write
/// that exceeds it counts towards [`MAX_WOULDBLOCKS`].
pub(crate) const WRITE_TIMEOUT: Duration = Duration::from_millis(50);

/// Capacity of a client's bounded outbound queue (the tests build that queue with
/// `mpsc::sync_channel(QUEUE_CAP)`).
pub(crate) const QUEUE_CAP: usize = 64;

/// Number of write timeouts in a row after which the writer evicts the client.
pub(crate) const MAX_WOULDBLOCKS: u32 = 3;
/// A message queued for a client's writer thread.
///
/// Every variant except `Shutdown` is encoded as exactly one protocol message on
/// the client socket.
#[derive(Debug)]
pub(crate) enum OutboundMsg {
    /// Frame bytes, written with the `S_OUTPUT` tag.
    Frame(Vec<u8>),
    /// Output bytes, also written with the `S_OUTPUT` tag.
    Output(Vec<u8>),
    /// Written with the `S_DETACHED` tag and an empty payload.
    Detached,
    /// Written with the `S_EXIT` tag and an empty payload.
    Exit,
    /// An arbitrary message with a caller-chosen protocol tag.
    // `dead_code` is allowed because nothing constructs this variant yet; it is
    // presumably kept as an escape hatch for new message types — confirm.
    #[allow(dead_code)]
    Raw { tag: u8, payload: Vec<u8> },
    /// Stops the writer thread; nothing is written for this message.
    Shutdown,
}
/// Spawns the dedicated writer thread for one client connection.
///
/// Installs [`WRITE_TIMEOUT`] on `socket` so that a stalled client shows up as a
/// write timeout instead of blocking the thread, then runs the writer loop on a
/// thread named `ezpn-writer`. Messages are taken from `rx`. If the writer gives up
/// on the client (fatal write error or too many timeouts), it sends
/// `ClientMsg::Disconnected` on `wake_on_drop`.
///
/// # Panics
///
/// Panics if the OS refuses to spawn the thread.
pub(crate) fn spawn_writer(
    socket: UnixStream,
    rx: mpsc::Receiver<OutboundMsg>,
    wake_on_drop: mpsc::Sender<ClientMsg>,
) -> JoinHandle<()> {
    if let Err(e) = socket.set_write_timeout(Some(WRITE_TIMEOUT)) {
        // Without a write timeout a stalled client can block the writer forever and
        // slow-client eviction never triggers. Keep serving, but make it visible.
        eprintln!("ezpn: failed to set client write timeout: {e}");
    }
    std::thread::Builder::new()
        .name("ezpn-writer".to_string())
        .spawn(move || run(socket, rx, wake_on_drop))
        .expect("spawn ezpn-writer thread")
}
/// Body of the writer thread: drains `rx` and encodes each message onto the socket.
///
/// Writes go through a 64 KiB `BufWriter`. The buffer is flushed explicitly whenever
/// the queue runs empty, so a burst of messages is coalesced but nothing is left
/// sitting in the buffer while the thread blocks waiting for more work.
///
/// The thread exits when:
/// - `OutboundMsg::Shutdown` arrives (buffered bytes are flushed best-effort first);
/// - all senders are gone and the queue is drained (dropping the `BufWriter`
///   flushes best-effort);
/// - a write or flush fails with a non-timeout error, or `MAX_WOULDBLOCKS` timeouts
///   occur in a row. In both cases `ClientMsg::Disconnected` is sent on `wake` and
///   the main loop is woken.
///
/// NOTE(review): a timeout can interrupt a message part-way through (`write_all`
/// may already have sent a prefix); the next message then follows a truncated one
/// on the wire. Confirm `protocol::write_msg` / the client tolerate this.
fn run(socket: UnixStream, rx: mpsc::Receiver<OutboundMsg>, wake: mpsc::Sender<ClientMsg>) {
    let mut bw = BufWriter::with_capacity(64 * 1024, socket);
    let mut consecutive_wouldblocks: u32 = 0;
    loop {
        let msg = match rx.try_recv() {
            Ok(msg) => msg,
            Err(mpsc::TryRecvError::Disconnected) => return,
            Err(mpsc::TryRecvError::Empty) => {
                // Queue drained: push buffered bytes to the client before blocking.
                // A successful flush does not reset the counter; flushing an empty
                // buffer says nothing about whether the client is keeping up.
                if let Err(e) = std::io::Write::flush(&mut bw) {
                    if !survive_write_error(&e, &mut consecutive_wouldblocks) {
                        notify_disconnected(&wake);
                        return;
                    }
                }
                match rx.recv() {
                    Ok(msg) => msg,
                    Err(_) => return,
                }
            }
        };
        let result = match &msg {
            OutboundMsg::Shutdown => {
                // Best-effort delivery of anything still buffered before closing.
                let _ = std::io::Write::flush(&mut bw);
                return;
            }
            OutboundMsg::Frame(b) | OutboundMsg::Output(b) => {
                protocol::write_msg(&mut bw, protocol::S_OUTPUT, b)
            }
            OutboundMsg::Detached => protocol::write_msg(&mut bw, protocol::S_DETACHED, &[]),
            OutboundMsg::Exit => protocol::write_msg(&mut bw, protocol::S_EXIT, &[]),
            OutboundMsg::Raw { tag, payload } => protocol::write_msg(&mut bw, *tag, payload),
        };
        match result {
            Ok(()) => consecutive_wouldblocks = 0,
            Err(e) => {
                if !survive_write_error(&e, &mut consecutive_wouldblocks) {
                    notify_disconnected(&wake);
                    return;
                }
            }
        }
    }
}

/// Decides whether the writer survives a failed write or flush.
///
/// Timeouts (`WouldBlock` / `TimedOut`) are tolerated until `MAX_WOULDBLOCKS` occur in
/// a row; any other error is fatal. Returns `true` if the writer should keep going.
fn survive_write_error(e: &std::io::Error, consecutive_wouldblocks: &mut u32) -> bool {
    if !matches!(
        e.kind(),
        std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
    ) {
        return false;
    }
    *consecutive_wouldblocks += 1;
    if *consecutive_wouldblocks >= MAX_WOULDBLOCKS {
        eprintln!(
            "ezpn: evicted slow client after {consecutive_wouldblocks} consecutive write timeouts"
        );
        return false;
    }
    true
}

/// Tells the main loop that this client's writer has given up.
fn notify_disconnected(wake: &mpsc::Sender<ClientMsg>) {
    let _ = wake.send(ClientMsg::Disconnected);
    crate::pane::wake_main_loop();
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read;
    use std::os::unix::net::UnixStream;
    use std::time::Duration;

    /// Bytes the protocol adds in front of each payload (a 1024-byte payload goes
    /// out as a 1029-byte message).
    const HEADER_LEN: usize = 5;

    fn pair() -> (UnixStream, UnixStream) {
        UnixStream::pair().expect("UnixStream::pair")
    }

    /// Reads everything the writer sends until the writer closes its end.
    fn spawn_drainer(mut peer: UnixStream) -> std::thread::JoinHandle<Vec<u8>> {
        std::thread::spawn(move || {
            let mut buf = Vec::new();
            let _ = peer.read_to_end(&mut buf);
            buf
        })
    }

    #[test]
    fn writer_passes_through_under_normal_load() {
        let (writer_sock, peer) = pair();
        let drainer = spawn_drainer(peer);
        let (tx, rx) = mpsc::sync_channel::<OutboundMsg>(QUEUE_CAP);
        let (wake_tx, _wake_rx) = mpsc::channel();
        let handle = spawn_writer(writer_sock, rx, wake_tx);
        for i in 0..10u8 {
            tx.send(OutboundMsg::Frame(vec![i; 1024]))
                .expect("send frame");
        }
        drop(tx);
        let _ = handle.join();
        let buf = drainer.join().expect("drainer thread");
        let frame_len = HEADER_LEN + 1024;
        assert_eq!(buf.len(), 10 * frame_len, "all frames must reach peer");
        for (i, frame) in buf.chunks_exact(frame_len).enumerate() {
            assert!(
                frame[HEADER_LEN..].iter().all(|&b| usize::from(b) == i),
                "frame {i} arrived out of order"
            );
        }
    }

    #[test]
    fn writer_drops_socket_on_shutdown_msg() {
        let (writer_sock, peer) = pair();
        let drainer = spawn_drainer(peer);
        let (tx, rx) = mpsc::sync_channel::<OutboundMsg>(QUEUE_CAP);
        let (wake_tx, _wake_rx) = mpsc::channel();
        let handle = spawn_writer(writer_sock, rx, wake_tx);
        tx.send(OutboundMsg::Frame(b"hi".to_vec())).unwrap();
        tx.send(OutboundMsg::Shutdown).unwrap();
        drop(tx);
        let _ = handle.join();
        // `read_to_end` only returns once the writer has closed the socket.
        let buf = drainer.join().expect("drainer thread");
        assert_eq!(
            buf.len(),
            HEADER_LEN + 2,
            "peer must receive exactly the queued frame"
        );
        assert!(buf.ends_with(b"hi"));
    }

    #[test]
    fn writer_evicts_after_three_wouldblocks() {
        // `peer` stays open but is never read, so the socket buffer fills and every
        // further write hits the WRITE_TIMEOUT that `spawn_writer` installs.
        let (writer_sock, peer) = pair();
        let (tx, rx) = mpsc::sync_channel::<OutboundMsg>(QUEUE_CAP);
        let (wake_tx, wake_rx) = mpsc::channel();
        let handle = spawn_writer(writer_sock, rx, wake_tx);
        for _ in 0..32 {
            let _ = tx.try_send(OutboundMsg::Frame(vec![0u8; 256 * 1024]));
        }
        let evicted = wake_rx
            .recv_timeout(Duration::from_secs(2))
            .map(|m| matches!(m, ClientMsg::Disconnected))
            .unwrap_or(false);
        assert!(
            evicted,
            "writer must signal Disconnected on repeated timeout"
        );
        drop(tx);
        let _ = handle.join();
        drop(peer);
    }
}