use std::io::{self, BufWriter, Write};
use std::net::{Shutdown, TcpStream, ToSocketAddrs, UdpSocket};
/// Capacity in bytes (1 MiB) of the userspace `BufWriter` that `SocketSink`
/// places in front of its TCP stream.
const TCP_BUF_CAPACITY: usize = 1024 * 1024;
/// Buffered byte sink that writes to a connected TCP stream.
///
/// Bytes written through the `Write` impl are staged in `buf`; `finish`
/// flushes them and then shuts down the write half of the connection.
pub struct SocketSink {
/// Buffered writer that owns the connected stream.
buf: BufWriter<TcpStream>,
/// Second handle to the same stream (obtained with `try_clone` in `connect`),
/// kept so `finish` can call `shutdown` without reaching through `buf`.
shutdown_handle: TcpStream,
}
impl SocketSink {
    /// Opens a TCP connection to `addr` and wraps it in a buffered sink.
    ///
    /// `TCP_NODELAY` is enabled on the stream. When `sndbuf_bytes` is `Some`,
    /// the value is handed to `set_send_buffer` (a no-op on platforms other
    /// than Linux and macOS).
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if connecting, configuring the socket,
    /// or cloning the stream handle fails.
    pub fn connect<A: ToSocketAddrs>(addr: A, sndbuf_bytes: Option<usize>) -> io::Result<Self> {
        let writer_half = TcpStream::connect(addr)?;
        writer_half.set_nodelay(true)?;
        sndbuf_bytes.map_or(Ok(()), |requested| set_send_buffer(&writer_half, requested))?;

        // Clone before the stream moves into the BufWriter so `finish` has a
        // handle it can shut down directly.
        let shutdown_handle = writer_half.try_clone()?;
        let buf = BufWriter::with_capacity(TCP_BUF_CAPACITY, writer_half);
        Ok(Self {
            buf,
            shutdown_handle,
        })
    }
}
impl Write for SocketSink {
    /// Hands `buf` to the internal `BufWriter` and returns how many bytes it
    /// accepted; data may stay buffered until `flush` or `finish`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.buf, buf)
    }

    /// Flushes the internal `BufWriter`, pushing staged bytes to the stream.
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.buf)
    }
}
impl SocketSink {
    /// Flushes all buffered bytes and then shuts down the write half of the
    /// connection.
    ///
    /// # Errors
    ///
    /// Returns the flush error if flushing fails (the shutdown is then not
    /// attempted), otherwise any error reported by the shutdown itself.
    pub fn finish(&mut self) -> io::Result<()> {
        self.buf
            .flush()
            .and_then(|()| self.shutdown_handle.shutdown(Shutdown::Write))
    }
}
/// Byte sink that forwards each `write` call to a connected UDP socket.
pub struct UdpSocketSink {
/// Locally bound socket, already `connect`ed to the peer by `UdpSocketSink::connect`.
socket: UdpSocket,
}
impl UdpSocketSink {
    /// Binds an ephemeral local UDP socket and connects it to `peer`.
    ///
    /// The local socket is bound to the wildcard address of the same address
    /// family as the peer, so IPv6 peers work as well as IPv4 ones. IPv4
    /// candidates are tried first; IPv6 candidates are only used when no IPv4
    /// candidate can be connected, so a name that resolves to both families
    /// keeps selecting the IPv4 address.
    ///
    /// When `sndbuf_bytes` is `Some`, the value is handed to
    /// `set_udp_send_buffer` (a no-op on platforms other than Linux and macOS).
    ///
    /// # Errors
    ///
    /// Returns the resolution error, the last bind/connect error, an
    /// `InvalidInput` error when `peer` resolves to no addresses, or the
    /// error from applying the send-buffer size.
    pub fn connect<A: ToSocketAddrs>(peer: A, sndbuf_bytes: Option<usize>) -> io::Result<Self> {
        use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

        let (v4_peers, v6_peers): (Vec<SocketAddr>, Vec<SocketAddr>) = peer
            .to_socket_addrs()?
            .partition(|candidate| candidate.is_ipv4());
        // Each entry pairs the wildcard local address of one family with the
        // peer candidates of that same family; IPv4 goes first on purpose.
        let families = [
            (SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)), v4_peers),
            (SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)), v6_peers),
        ];

        let mut outcome: io::Result<UdpSocket> = Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "could not resolve to any addresses",
        ));
        for (local, peers) in &families {
            if peers.is_empty() {
                continue;
            }
            outcome = UdpSocket::bind(local).and_then(|socket| {
                socket.connect(peers.as_slice())?;
                Ok(socket)
            });
            if outcome.is_ok() {
                break;
            }
        }
        let socket = outcome?;

        if let Some(n) = sndbuf_bytes {
            set_udp_send_buffer(&socket, n)?;
        }
        Ok(Self { socket })
    }
}
impl Write for UdpSocketSink {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.socket.send(buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl UdpSocketSink {
    /// Completes the sink. Always succeeds: there is no buffered data to flush
    /// and no shutdown step is performed here.
    pub fn finish(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Requests a kernel send buffer of `bytes` for a TCP stream (Linux/macOS build).
///
/// Delegates to `setsockopt_sndbuf` with the stream's raw file descriptor.
#[cfg(any(target_os = "linux", target_os = "macos"))]
fn set_send_buffer(stream: &TcpStream, bytes: usize) -> io::Result<()> {
    let fd = std::os::unix::io::AsRawFd::as_raw_fd(stream);
    setsockopt_sndbuf(fd, bytes)
}
/// Requests a kernel send buffer of `bytes` for a UDP socket (Linux/macOS build).
///
/// Delegates to `setsockopt_sndbuf` with the socket's raw file descriptor.
#[cfg(any(target_os = "linux", target_os = "macos"))]
fn set_udp_send_buffer(socket: &UdpSocket, bytes: usize) -> io::Result<()> {
    let fd = std::os::unix::io::AsRawFd::as_raw_fd(socket);
    setsockopt_sndbuf(fd, bytes)
}
/// Fallback for targets other than Linux and macOS: the requested TCP
/// send-buffer size is ignored and success is reported, so callers behave the
/// same on every platform.
#[cfg(not(any(target_os = "linux", target_os = "macos")))]
fn set_send_buffer(_stream: &TcpStream, _bytes: usize) -> io::Result<()> {
    // Deliberate no-op; both arguments are unused on this target.
    Ok(())
}
/// Fallback for targets other than Linux and macOS: the requested UDP
/// send-buffer size is ignored and success is reported, so callers behave the
/// same on every platform.
#[cfg(not(any(target_os = "linux", target_os = "macos")))]
fn set_udp_send_buffer(_socket: &UdpSocket, _bytes: usize) -> io::Result<()> {
    // Deliberate no-op; both arguments are unused on this target.
    Ok(())
}
/// Sets `SO_SNDBUF` on the socket behind `fd` to `bytes`.
///
/// Requests too large for a C `int` are clamped to `c_int::MAX` instead of
/// failing. The kernel may adjust the value it actually applies (see
/// socket(7)); this function only reports whether the call itself succeeded.
///
/// # Errors
///
/// Returns the OS error (from `errno`) when `setsockopt` reports failure.
#[cfg(any(target_os = "linux", target_os = "macos"))]
fn setsockopt_sndbuf(fd: std::os::unix::io::RawFd, bytes: usize) -> io::Result<()> {
    let requested = libc::c_int::try_from(bytes).unwrap_or(libc::c_int::MAX);
    let value_ptr = (&requested as *const libc::c_int).cast::<libc::c_void>();
    let value_len = std::mem::size_of_val(&requested) as libc::socklen_t;

    // SAFETY: `value_ptr` points at `requested`, a live local `c_int`, and
    // `value_len` is exactly its size, so the kernel reads only valid memory.
    // An invalid `fd` makes the call fail with an error code; it is not a
    // memory-safety hazard.
    let status = unsafe {
        libc::setsockopt(fd, libc::SOL_SOCKET, libc::SO_SNDBUF, value_ptr, value_len)
    };

    // Read errno immediately, before anything else can overwrite it.
    match status {
        0 => Ok(()),
        _ => Err(io::Error::last_os_error()),
    }
}
// Loopback integration tests for the TCP and UDP sinks.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Read;
use std::net::{TcpListener, UdpSocket};
use std::thread;
// Writes twice the BufWriter capacity plus a short tail through `SocketSink`
// and checks that a reader on the other end receives every byte in order.
#[test]
fn socket_sink_round_trips_bytes() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
// Reader thread: accept one connection and drain it until EOF.
let accept = thread::spawn(move || {
let (mut sock, _) = listener.accept().unwrap();
let mut buf = Vec::new();
sock.read_to_end(&mut buf).unwrap();
buf
});
let mut sink = SocketSink::connect(addr, Some(256 * 1024)).unwrap();
// Payload is larger than TCP_BUF_CAPACITY, so it cannot fit in the
// userspace buffer in one piece.
let big: Vec<u8> = (0..(2 * TCP_BUF_CAPACITY))
.map(|i| (i & 0xff) as u8)
.collect();
sink.write_all(&big).unwrap();
sink.write_all(b"tail\n").unwrap();
// `finish` flushes and shuts down the write half, which is what lets the
// reader's `read_to_end` return.
sink.finish().unwrap();
drop(sink);
let received = accept.join().unwrap();
assert_eq!(received.len(), big.len() + 5);
assert_eq!(&received[..big.len()], &big[..]);
assert_eq!(&received[big.len()..], b"tail\n");
}
// Type-level check: the call to `_assert_seq` compiles only if `SocketSink`
// can be used as `dyn SequentialSink` (a trait declared two modules up, not
// visible in this file).
// NOTE(review): the test name says "only", but the body checks just that the
// sequential trait is implemented; nothing here rules out other sink traits.
#[test]
fn socket_sink_is_sequential_only() {
fn _assert_seq(_: &mut dyn super::super::SequentialSink) {}
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
// Accept in the background so `connect` has a peer; the thread is not joined.
let _accept = thread::spawn(move || {
let _ = listener.accept();
});
let mut sink = SocketSink::connect(addr, None).unwrap();
_assert_seq(&mut sink);
}
// Sends two small payloads through `UdpSocketSink` and checks that the
// receiver gets them as two separate datagrams, in order.
#[test]
fn udp_socket_sink_delivers_datagrams() {
let receiver = UdpSocket::bind("127.0.0.1:0").unwrap();
// Bound the wait so a lost datagram fails the test instead of hanging it.
receiver
.set_read_timeout(Some(std::time::Duration::from_secs(2)))
.unwrap();
let addr = receiver.local_addr().unwrap();
let mut sink = UdpSocketSink::connect(addr, Some(128 * 1024)).unwrap();
sink.write_all(&[1, 2, 3, 4, 5]).unwrap();
sink.write_all(&[9, 9, 9]).unwrap();
sink.finish().unwrap();
let mut buf = [0u8; 64];
let n1 = receiver.recv(&mut buf).unwrap();
assert_eq!(&buf[..n1], &[1, 2, 3, 4, 5]);
let n2 = receiver.recv(&mut buf).unwrap();
assert_eq!(&buf[..n2], &[9, 9, 9]);
}
}