use std::io;
use std::net::SocketAddr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use socket2::{Domain, Protocol, Socket, Type};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use crate::{AutosshError, MonitorMode};
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum ProbeError {
#[error("probe round-trip timed out")]
Timeout,
#[error("probe io error: {0}")]
Io(#[from] io::Error),
}
/// Port numbers actually in use by a [`ProbeLoop`], with ephemeral (`0`) requests resolved.
///
/// Plain data, so it derives the full set of comparison traits: callers can compare or key on
/// the pair directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MonitorPortPair {
    /// Loopback port the probe payload is written to (`127.0.0.1:port_in`).
    pub port_in: u16,
    /// In loopback mode, the port the read-side listener is bound to. In echo mode, the
    /// configured echo port; no local listener is bound for it.
    pub port_out: u16,
}
/// A bound monitor loop: the loopback listener(s) plus what is needed to build probe payloads.
///
/// Created by [`ProbeLoop::bind`]; each call to [`ProbeLoop::probe`] performs one round trip.
#[derive(Debug)]
pub struct ProbeLoop {
    /// Listener bound to `127.0.0.1:ports.port_in`.
    ///
    /// NOTE(review): nothing in this file accepts on this listener, yet `probe` connects to
    /// this same port and expects the payload to be carried onward. Confirm which process is
    /// meant to own `port_in` once the tunnel is running.
    pub listener_in: TcpListener,
    /// Read-side listener bound to `127.0.0.1:ports.port_out`. `None` in echo mode, where the
    /// reply is read back on the probe connection itself.
    pub listener_out: Option<TcpListener>,
    /// Optional text appended (after one space) to every probe payload; never contains `'\n'`.
    pub message_suffix: Option<String>,
    /// Port numbers actually bound / configured.
    pub ports: MonitorPortPair,
}
impl ProbeLoop {
pub fn bind(mode: &MonitorMode, message: Option<&str>) -> Result<Self, AutosshError> {
if let Some(m) = message {
if m.contains('\n') {
return Err(AutosshError::Internal(
"AUTOSSH_MESSAGE contains embedded newline",
));
}
}
match mode {
MonitorMode::None => Err(AutosshError::Internal(
"ProbeLoop::bind called with MonitorMode::None",
)),
MonitorMode::Active { port, echo: None } => {
let listener_in = bind_reuseaddr(*port)?;
let in_port = listener_in
.local_addr()
.map_err(|source| AutosshError::MonitorBindFailed {
port: *port,
source,
})?
.port();
let out_target = if *port == 0 {
0
} else {
port.saturating_add(1)
};
let listener_out = bind_reuseaddr(out_target)?;
let out_port = listener_out
.local_addr()
.map_err(|source| AutosshError::MonitorBindFailed {
port: out_target,
source,
})?
.port();
Ok(Self {
listener_in,
listener_out: Some(listener_out),
message_suffix: message.map(String::from),
ports: MonitorPortPair {
port_in: in_port,
port_out: out_port,
},
})
}
MonitorMode::Active {
port,
echo: Some(echo),
} => {
let listener_in = bind_reuseaddr(*port)?;
let in_port = listener_in
.local_addr()
.map_err(|source| AutosshError::MonitorBindFailed {
port: *port,
source,
})?
.port();
Ok(Self {
listener_in,
listener_out: None,
message_suffix: message.map(String::from),
ports: MonitorPortPair {
port_in: in_port,
port_out: *echo,
},
})
}
}
}
pub async fn probe(&mut self, poll: Duration) -> Result<(), ProbeError> {
let unix_ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let payload = probe_payload(unix_ts, self.message_suffix.as_deref());
let port_in = self.ports.port_in;
match self.listener_out.as_mut() {
Some(out) => {
let roundtrip = async {
let write_payload = payload.clone();
let read_len = payload.len();
let writer = async move {
let mut stream = TcpStream::connect(("127.0.0.1", port_in)).await?;
stream.write_all(&write_payload).await?;
stream.flush().await?;
Ok::<(), io::Error>(())
};
let reader = async {
let (mut sock, _) = out.accept().await?;
let mut buf = vec![0u8; read_len];
sock.read_exact(&mut buf).await?;
Ok::<Vec<u8>, io::Error>(buf)
};
let (_, bytes) = tokio::try_join!(writer, reader)?;
Ok::<Vec<u8>, io::Error>(bytes)
};
match tokio::time::timeout(poll, roundtrip).await {
Ok(Ok(_)) => Ok(()),
Ok(Err(e)) => Err(ProbeError::Io(e)),
Err(_) => Err(ProbeError::Timeout),
}
}
None => {
let read_len = payload.len();
let roundtrip = async move {
let mut stream = TcpStream::connect(("127.0.0.1", port_in)).await?;
stream.write_all(&payload).await?;
stream.flush().await?;
let mut buf = vec![0u8; read_len];
stream.read_exact(&mut buf).await?;
Ok::<Vec<u8>, io::Error>(buf)
};
match tokio::time::timeout(poll, roundtrip).await {
Ok(Ok(_)) => Ok(()),
Ok(Err(e)) => Err(ProbeError::Io(e)),
Err(_) => Err(ProbeError::Timeout),
}
}
}
}
}
fn bind_reuseaddr(port: u16) -> Result<TcpListener, AutosshError> {
let addr: SocketAddr = format!("127.0.0.1:{port}")
.parse()
.map_err(|_| AutosshError::Internal("monitor-port socket address parse failed"))?;
let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))
.map_err(|source| AutosshError::MonitorBindFailed { port, source })?;
#[cfg(unix)]
{
socket
.set_reuse_address(true)
.map_err(|source| AutosshError::MonitorBindFailed { port, source })?;
}
socket
.set_nonblocking(true)
.map_err(|source| AutosshError::MonitorBindFailed { port, source })?;
socket
.bind(&addr.into())
.map_err(|source| AutosshError::MonitorBindFailed { port, source })?;
socket
.listen(128)
.map_err(|source| AutosshError::MonitorBindFailed { port, source })?;
let std_listener: std::net::TcpListener = socket.into();
TcpListener::from_std(std_listener)
.map_err(|source| AutosshError::MonitorBindFailed { port, source })
}
/// Builds one probe line as bytes.
///
/// The line is `unix_ts` zero-padded to at least 16 digits. When `message` is present, it is
/// followed by a single space and the message. The line always ends in `'\n'`.
pub fn probe_payload(unix_ts: u64, message: Option<&str>) -> Vec<u8> {
    let mut line = format!("{unix_ts:016}");
    if let Some(text) = message {
        line.push(' ');
        line.push_str(text);
    }
    line.push('\n');
    line.into_bytes()
}
#[cfg(test)]
mod tests {
    //! Unit tests for payload formatting and the early-rejection paths of `ProbeLoop::bind`.
    use super::*;

    #[test]
    fn probe_payload_without_message_is_17_bytes() {
        let payload = probe_payload(1_748_000_000, None);
        assert_eq!(payload.len(), 17);
        let (timestamp, rest) = payload.split_at(16);
        assert_eq!(timestamp, b"0000001748000000");
        assert_eq!(rest[0], b'\n');
    }

    #[test]
    fn probe_payload_with_message_uses_single_space_separator() {
        let payload = probe_payload(1_748_000_000, Some("hello"));
        assert_eq!(payload.as_slice(), b"0000001748000000 hello\n");
    }

    #[test]
    fn probe_payload_timestamp_left_pads_to_16_chars() {
        let payload = probe_payload(42, None);
        let (timestamp, rest) = payload.split_at(16);
        assert_eq!(timestamp, b"0000000000000042");
        assert_eq!(rest[0], b'\n');
    }

    #[test]
    fn probe_payload_timestamp_at_16_digit_width_does_not_overflow() {
        let payload = probe_payload(9_999_999_999_999_999, None);
        let (timestamp, _) = payload.split_at(16);
        assert_eq!(timestamp, b"9999999999999999");
        assert_eq!(payload.len(), 17);
    }

    #[test]
    fn probe_loop_bind_rejects_embedded_newline_in_message() {
        let mode = MonitorMode::Active {
            port: 0,
            echo: None,
        };
        let outcome = ProbeLoop::bind(&mode, Some("with\nnewline"));
        let err = outcome.expect_err("embedded newline must be rejected");
        assert!(matches!(err, AutosshError::Internal(_)));
    }

    #[test]
    fn probe_loop_bind_rejects_monitor_mode_none() {
        let outcome = ProbeLoop::bind(&MonitorMode::None, None);
        let err = outcome.expect_err("MonitorMode::None must be rejected");
        assert!(matches!(err, AutosshError::Internal(_)));
    }
}