use std::fmt;
use std::io::{Read, Write};
use std::mem::ManuallyDrop;
#[cfg(unix)]
use std::net::Shutdown;
use std::net::{SocketAddr, TcpStream};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;
use nu_protocol::HandlerGuard;
use ureq::Error;
use ureq::unversioned::transport::{
Buffers, ConnectionDetails, Connector, Either, LazyBuffers, NextTimeout, Transport,
};
/// Hook invoked with every freshly connected [`TcpStream`].
///
/// Returning `Some((guard, closed))` gives the transport a [`HandlerGuard`], which the
/// transport stores for its whole lifetime (presumably keeping the interrupt handler
/// registered until then), and a shared `closed` flag. The transport's `Drop` only closes the
/// socket if it is the first to swap that flag to `true`. Anything else that closes the
/// socket must therefore set the flag first. Returning `None` leaves the connection without
/// an interrupt handler.
pub type OnConnect =
    Arc<dyn Fn(&TcpStream) -> Option<(HandlerGuard, Arc<AtomicBool>)> + Send + Sync>;
/// A ureq [`Connector`] that opens plain TCP connections.
///
/// Each connection can optionally be armed for interruption through an [`OnConnect`] hook.
pub struct InterruptibleTcpConnector {
    /// Called with each new stream. `None` means connections get no interrupt handler.
    on_connect: Option<OnConnect>,
}
impl InterruptibleTcpConnector {
pub fn new(on_connect: Option<OnConnect>) -> Self {
Self { on_connect }
}
}
impl fmt::Debug for InterruptibleTcpConnector {
    /// The hook is an opaque closure, so only whether one is installed is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hook_marker: Option<&str> = self.on_connect.is_some().then_some("...");
        f.debug_struct("InterruptibleTcpConnector")
            .field("on_connect", &hook_marker)
            .finish()
    }
}
impl<In: Transport> Connector<In> for InterruptibleTcpConnector {
    type Out = Either<In, InterruptibleTcpTransport>;

    /// Opens a TCP connection, unless an earlier connector in the chain already produced one.
    ///
    /// A chained connection is passed through unchanged as `Either::A`. Otherwise the first
    /// reachable address from `details` is connected, the `on_connect` hook (if any) is run
    /// on the new stream, and the result is wrapped as `Either::B`.
    fn connect(
        &self,
        details: &ConnectionDetails,
        chained: Option<In>,
    ) -> Result<Option<Self::Out>, Error> {
        // An upstream connection takes precedence; this connector does nothing further.
        if let Some(upstream) = chained {
            return Ok(Some(Either::A(upstream)));
        }

        let stream = try_connect(details)?;

        // Let the hook arm an interrupt handler. Without one, the transport still needs its
        // own `closed` flag so that `Drop` knows to close the socket.
        let hook_result = match self.on_connect.as_ref() {
            Some(hook) => hook(&stream),
            None => None,
        };
        let (guard, closed) = match hook_result {
            Some((handler_guard, closed_flag)) => (Some(handler_guard), closed_flag),
            None => (None, Arc::new(AtomicBool::new(false))),
        };

        let buffers = LazyBuffers::new(
            details.config.input_buffer_size(),
            details.config.output_buffer_size(),
        );
        let transport = InterruptibleTcpTransport::new(stream, buffers, guard, closed);
        Ok(Some(Either::B(transport)))
    }
}
fn try_connect(details: &ConnectionDetails) -> Result<TcpStream, Error> {
let timeout = details.timeout;
let mut last_error = None;
for addr in details.addrs.iter() {
match try_connect_single(*addr, timeout) {
Ok(stream) => return Ok(stream),
Err(Error::Io(e)) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
last_error = Some(Error::Io(e));
continue;
}
Err(e) => return Err(e),
}
}
Err(last_error.unwrap_or(Error::ConnectionFailed))
}
fn try_connect_single(addr: SocketAddr, timeout: NextTimeout) -> Result<TcpStream, Error> {
let maybe_stream = if let Some(t) = timeout.not_zero() {
TcpStream::connect_timeout(&addr, *t)
} else {
TcpStream::connect(addr)
};
let stream = match maybe_stream {
Ok(s) => s,
Err(e)
if e.kind() == std::io::ErrorKind::TimedOut
|| e.kind() == std::io::ErrorKind::WouldBlock =>
{
return Err(Error::Timeout(timeout.reason));
}
Err(e) => return Err(Error::Io(e)),
};
let _ = stream.set_nodelay(true);
Ok(stream)
}
/// A ureq [`Transport`] over a [`TcpStream`] whose socket may be closed by an interrupt
/// handler while the transport is still alive.
pub struct InterruptibleTcpTransport {
    /// Wrapped in `ManuallyDrop` so that `Drop` can skip closing a socket whose `closed`
    /// flag was already set by someone else.
    stream: ManuallyDrop<TcpStream>,
    /// Input/output buffers exposed to ureq through `Transport::buffers`.
    buffers: LazyBuffers,
    /// Last write timeout applied to `stream`, cached to skip redundant `set_write_timeout`
    /// calls.
    timeout_write: Option<Duration>,
    /// Last read timeout applied to `stream`, cached to skip redundant `set_read_timeout`
    /// calls.
    timeout_read: Option<Duration>,
    /// Claimed (swapped to `true`) by whoever closes the socket first. That is either this
    /// transport's `Drop` or, on Windows, the handler installed by `register_close_handler`.
    closed: Arc<AtomicBool>,
    /// Interrupt handler registration. It is held, not read, and dropped together with the
    /// transport (presumably unregistering the handler — confirm against
    /// `nu_protocol::HandlerGuard`).
    _guard: Option<HandlerGuard>,
}
impl InterruptibleTcpTransport {
pub fn new(
stream: TcpStream,
buffers: LazyBuffers,
guard: Option<HandlerGuard>,
closed: Arc<AtomicBool>,
) -> Self {
Self {
stream: ManuallyDrop::new(stream),
buffers,
timeout_read: None,
timeout_write: None,
closed,
_guard: guard,
}
}
}
impl Drop for InterruptibleTcpTransport {
    /// Closes the socket unless something else already claimed `closed`.
    fn drop(&mut self) {
        // Atomically claim the close. If the flag was already set (e.g. by the Windows
        // interrupt handler, which calls `closesocket` itself), closing again would be a
        // double close, so the `TcpStream` is intentionally leaked instead. Setting the flag
        // here also disarms that handler for the short window before `_guard` is dropped.
        if self.closed.swap(true, Ordering::AcqRel) {
            return;
        }
        // SAFETY: `stream` is dropped exactly once, here. This runs only when this call won
        // the swap above, and no code touches `self.stream` after this point.
        unsafe { ManuallyDrop::drop(&mut self.stream) };
    }
}
impl Transport for InterruptibleTcpTransport {
    fn buffers(&mut self) -> &mut dyn Buffers {
        &mut self.buffers
    }

    /// Writes the first `amount` bytes of the output buffer to the socket.
    ///
    /// The write timeout is only re-applied when it changed since the last call.
    /// `TimedOut`/`WouldBlock` map to `Error::Timeout(timeout.reason)`; other I/O failures
    /// map to `Error::Io`.
    fn transmit_output(&mut self, amount: usize, timeout: NextTimeout) -> Result<(), Error> {
        maybe_update_timeout(
            timeout,
            &mut self.timeout_write,
            &self.stream,
            TcpStream::set_write_timeout,
        )?;
        let output = &self.buffers.output()[..amount];
        match self.stream.write_all(output) {
            Ok(()) => Ok(()),
            Err(e)
                if e.kind() == std::io::ErrorKind::TimedOut
                    || e.kind() == std::io::ErrorKind::WouldBlock =>
            {
                Err(Error::Timeout(timeout.reason))
            }
            Err(e) => Err(Error::Io(e)),
        }
    }

    /// Performs one `read` into the input buffer.
    ///
    /// Returns `Ok(false)` when the read returned 0 bytes, which is either EOF or the result
    /// of an interrupt shutting the socket down. Timeouts are mapped the same way as in
    /// `transmit_output`.
    fn await_input(&mut self, timeout: NextTimeout) -> Result<bool, Error> {
        maybe_update_timeout(
            timeout,
            &mut self.timeout_read,
            &self.stream,
            TcpStream::set_read_timeout,
        )?;
        let input = self.buffers.input_append_buf();
        let amount = match self.stream.read(input) {
            Ok(n) => n,
            Err(e)
                if e.kind() == std::io::ErrorKind::TimedOut
                    || e.kind() == std::io::ErrorKind::WouldBlock =>
            {
                return Err(Error::Timeout(timeout.reason));
            }
            Err(e) => return Err(Error::Io(e)),
        };
        self.buffers.input_appended(amount);
        Ok(amount > 0)
    }

    /// Reports whether this idle connection is still usable.
    ///
    /// The socket is probed with a non-blocking `peek`. Only "would block" (nothing pending,
    /// connection up) counts as open. EOF (`Ok(0)`, which is also what an interrupt's
    /// `shutdown` produces on Unix), unsolicited bytes, and errors all count as closed.
    /// Bytes arriving unprompted would be misread as the start of the next response.
    fn is_open(&mut self) -> bool {
        // `closed` is only set before `Drop` by the Windows interrupt handler, which has
        // already called `closesocket` on this socket. Don't operate on that handle.
        if self.closed.load(Ordering::Acquire) {
            return false;
        }
        // If the socket cannot be made non-blocking, the `peek` below could block
        // indefinitely, so treat the connection as unusable instead.
        if self.stream.set_nonblocking(true).is_err() {
            return false;
        }
        let mut probe = [0u8; 1];
        let idle_and_open = match self.stream.peek(&mut probe) {
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => true,
            Ok(_) | Err(_) => false,
        };
        // The read/write paths rely on blocking mode; if it can't be restored, don't reuse.
        let restored = self.stream.set_nonblocking(false).is_ok();
        idle_and_open && restored
    }
}
impl fmt::Debug for InterruptibleTcpTransport {
    /// Shows the peer address, or `None` when it cannot be queried.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let peer_addr = self.stream.peer_addr().ok();
        let mut builder = f.debug_struct("InterruptibleTcpTransport");
        builder.field("peer_addr", &peer_addr);
        builder.finish()
    }
}
fn maybe_update_timeout(
timeout: NextTimeout,
previous: &mut Option<Duration>,
stream: &TcpStream,
f: impl Fn(&TcpStream, Option<Duration>) -> std::io::Result<()>,
) -> Result<(), Error> {
let maybe_timeout = timeout.not_zero().map(|t| *t);
if maybe_timeout != *previous {
f(stream, maybe_timeout).map_err(Error::Io)?;
*previous = maybe_timeout;
}
Ok(())
}
/// Registers an interrupt handler that closes `raw_socket` directly (Windows only).
///
/// On `SignalAction::Interrupt`, the handler calls WinSock `closesocket` on the socket, but
/// only if it is the first to swap `closed` from `false` to `true`. The transport's `Drop`
/// checks the same flag, so at most one of the two closes the socket. Returns `None` if
/// `handlers.register` fails.
#[cfg(windows)]
pub(crate) fn register_close_handler(
    handlers: &nu_protocol::Handlers,
    raw_socket: std::os::windows::io::RawSocket,
    closed: &Arc<AtomicBool>,
) -> Option<HandlerGuard> {
    let closed = Arc::clone(closed);
    // Stored as `usize` because `SOCKET` is constructed from a `usize` below.
    let raw = raw_socket as usize;
    handlers
        .register(Box::new(move |action| {
            if matches!(action, nu_protocol::SignalAction::Interrupt)
                && !closed.swap(true, Ordering::AcqRel)
            {
                // SAFETY: the `closed` swap above guarantees this is the only close of this
                // socket (the transport's `Drop` skips closing once the flag is set). The
                // close is deliberate while another thread may be blocked on the socket; it
                // is what aborts that blocking call. The return value is ignored.
                unsafe {
                    windows::Win32::Networking::WinSock::closesocket(
                        windows::Win32::Networking::WinSock::SOCKET(raw),
                    );
                }
            }
        }))
        .ok()
}
/// Builds an [`OnConnect`] hook that registers an interrupt handler for each new connection.
///
/// When `SignalAction::Interrupt` fires, the handler aborts the connection. On Unix it calls
/// `shutdown(Shutdown::Both)` on a duplicate of the stream. On Windows it uses
/// `register_close_handler`. The hook returns `None` if the handler cannot be registered
/// (or, on Unix, the stream cannot be duplicated); such a connection is not interruptible.
pub fn make_on_connect(handlers: &nu_protocol::Handlers) -> OnConnect {
    let registry = handlers.clone();
    Arc::new(move |stream: &TcpStream| {
        let closed_flag = Arc::new(AtomicBool::new(false));

        // Unix: shut down a duplicated descriptor, leaving the original one for the transport
        // to close in its `Drop`. The Unix handler never sets `closed_flag`.
        #[cfg(unix)]
        let guard = {
            let shutdown_handle = stream.try_clone().ok()?;
            registry
                .register(Box::new(move |action| {
                    if matches!(action, nu_protocol::SignalAction::Interrupt) {
                        let _ = shutdown_handle.shutdown(Shutdown::Both);
                    }
                }))
                .ok()?
        };

        // Windows: the handler closes the raw socket itself and claims `closed_flag` first.
        #[cfg(windows)]
        let guard = register_close_handler(&registry, stream.as_raw_socket(), &closed_flag)?;

        Some((guard, closed_flag))
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use nu_protocol::{Handlers, SignalAction};
    use nu_utils::time::Instant;
    use std::io::Write;
    use std::net::TcpListener;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    /// An `Interrupt` delivered through `Handlers` must unblock a read that is waiting on a
    /// server that has not responded yet.
    #[test]
    fn test_interrupt_unblocks_read() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        // The server accepts and then stays silent until it is released (or 10s pass), so
        // the client read below can only return quickly if the interrupt works. A channel is
        // used instead of a plain sleep so the thread can be released and joined at the end
        // rather than left running detached.
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let server_thread = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let _ = release_rx.recv_timeout(Duration::from_secs(10));
            let _ = stream.write_all(b"delayed response");
        });

        let handlers = Handlers::new();
        let on_connect = make_on_connect(&handlers);
        let stream = TcpStream::connect(addr).unwrap();
        let (guard, closed) = on_connect(&stream).unwrap();
        let transport = InterruptibleTcpTransport::new(
            stream,
            LazyBuffers::new(8192, 8192),
            Some(guard),
            closed,
        );

        let handlers_clone = handlers.clone();
        let interrupter = thread::spawn(move || {
            thread::sleep(Duration::from_millis(100));
            handlers_clone.run(SignalAction::Interrupt);
        });

        let start = Instant::now();
        let mut buf = [0u8; 1024];
        // Read directly from the socket, bypassing ureq, so that only the interrupt can end
        // this read early.
        let result = std::io::Read::read(&mut &*transport.stream, &mut buf);
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_secs(2),
            "Read took too long ({elapsed:?}), interrupt may not have worked",
        );
        // EOF and an error are both accepted, because how the read ends depends on how the
        // platform handler aborts the socket. Real data would mean the server answered first.
        match result {
            Ok(0) => {}
            Err(_) => {}
            Ok(n) => panic!("Unexpected data received: {n} bytes"),
        }

        interrupter.join().unwrap();
        drop(transport);
        // Release the server instead of leaving it sleeping, then reap the thread.
        let _ = release_tx.send(());
        server_thread.join().unwrap();
    }

    #[test]
    fn test_connector_creates_transport() {
        let connector = InterruptibleTcpConnector::new(None);
        let debug_str = format!("{connector:?}");
        assert!(debug_str.contains("InterruptibleTcpConnector"));
    }
}