use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::procserv::telnet::{TelnetEvent, TelnetParser, iac_escape};
#[derive(Debug)]
pub struct IncomingClient {
pub stream: ClientStream,
pub peer: ClientPeer,
pub readonly: bool,
}
#[derive(Debug)]
pub enum ClientStream {
Tcp(TcpStream),
#[cfg(unix)]
Unix(tokio::net::UnixStream),
}
#[derive(Debug, Clone)]
pub enum ClientPeer {
Tcp(SocketAddr),
#[cfg(unix)]
Unix(Option<std::path::PathBuf>),
}
#[derive(Debug, Clone)]
pub enum OutboundFrame {
Bytes(Vec<u8>),
RawIac(Vec<u8>),
Disconnect,
}
#[derive(Debug)]
pub enum InboundEvent {
Data { bytes: Vec<u8> },
TelnetReply { bytes: Vec<u8> },
Disconnected,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ClientId(u64);
impl ClientId {
pub fn new() -> Self {
use std::sync::atomic::{AtomicU64, Ordering};
static NEXT: AtomicU64 = AtomicU64::new(1);
Self(NEXT.fetch_add(1, Ordering::Relaxed))
}
pub fn raw(self) -> u64 {
self.0
}
}
impl Default for ClientId {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct ClientMeta {
pub id: ClientId,
pub peer: ClientPeer,
pub readonly: bool,
}
pub fn spawn_client(
incoming: IncomingClient,
inbound_tx: mpsc::Sender<(ClientId, InboundEvent)>,
) -> (ClientMeta, mpsc::Sender<OutboundFrame>) {
let id = ClientId::new();
let meta = ClientMeta {
id,
peer: incoming.peer,
readonly: incoming.readonly,
};
let (out_tx, out_rx) = mpsc::channel::<OutboundFrame>(64);
match incoming.stream {
ClientStream::Tcp(s) => spawn_split(s, id, incoming.readonly, inbound_tx, out_rx),
#[cfg(unix)]
ClientStream::Unix(s) => spawn_split(s, id, incoming.readonly, inbound_tx, out_rx),
}
(meta, out_tx)
}
fn spawn_split<S>(
stream: S,
id: ClientId,
readonly: bool,
inbound_tx: mpsc::Sender<(ClientId, InboundEvent)>,
mut outbound_rx: mpsc::Receiver<OutboundFrame>,
) where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + 'static,
{
let (mut reader, mut writer) = tokio::io::split(stream);
let inbound = inbound_tx.clone();
tokio::spawn(async move {
let mut parser = TelnetParser::new();
let mut buf = vec![0u8; 1024];
loop {
match reader.read(&mut buf).await {
Ok(0) => break, Ok(n) => {
for ev in parser.feed(&buf[..n]) {
match ev {
TelnetEvent::Data(d) => {
if !readonly
&& inbound
.send((id, InboundEvent::Data { bytes: d }))
.await
.is_err()
{
return;
}
}
TelnetEvent::Reply(r) => {
if inbound
.send((id, InboundEvent::TelnetReply { bytes: r }))
.await
.is_err()
{
return;
}
}
}
}
}
Err(e) => {
tracing::debug!(client = id.raw(), error = %e, "procserv-rs: client read error");
break;
}
}
}
let _ = inbound.send((id, InboundEvent::Disconnected)).await;
});
tokio::spawn(async move {
let init = crate::procserv::telnet::initial_negotiation();
if writer.write_all(&init).await.is_err() {
return;
}
if writer.flush().await.is_err() {
return;
}
while let Some(frame) = outbound_rx.recv().await {
match frame {
OutboundFrame::Bytes(b) => {
let escaped = iac_escape(&b);
if writer.write_all(&escaped).await.is_err() {
break;
}
}
OutboundFrame::RawIac(b) => {
if writer.write_all(&b).await.is_err() {
break;
}
}
OutboundFrame::Disconnect => break,
}
if writer.flush().await.is_err() {
break;
}
}
let _ = writer.shutdown().await;
});
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::timeout;
async fn paired_streams() -> (TcpStream, TcpStream) {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let connect = tokio::spawn(async move { TcpStream::connect(addr).await.unwrap() });
let (server, _) = listener.accept().await.unwrap();
let client = connect.await.unwrap();
(server, client)
}
#[tokio::test]
async fn read_data_propagates_inbound_event() {
let (server, mut client) = paired_streams().await;
let (in_tx, mut in_rx) = mpsc::channel(8);
let (_meta, out_tx) = spawn_client(
IncomingClient {
stream: ClientStream::Tcp(server),
peer: ClientPeer::Tcp("127.0.0.1:1".parse().unwrap()),
readonly: false,
},
in_tx,
);
let mut neg = [0u8; 6];
client.read_exact(&mut neg).await.unwrap();
client.write_all(b"hi\n").await.unwrap();
let event = timeout(Duration::from_secs(1), in_rx.recv())
.await
.unwrap()
.unwrap();
match event {
(_, InboundEvent::Data { bytes }) => assert_eq!(bytes, b"hi\n"),
other => panic!("unexpected event: {other:?}"),
}
drop(out_tx);
}
#[tokio::test]
async fn readonly_drops_input() {
let (server, mut client) = paired_streams().await;
let (in_tx, mut in_rx) = mpsc::channel(8);
let (_meta, _out_tx) = spawn_client(
IncomingClient {
stream: ClientStream::Tcp(server),
peer: ClientPeer::Tcp("127.0.0.1:1".parse().unwrap()),
readonly: true,
},
in_tx,
);
let mut neg = [0u8; 6];
client.read_exact(&mut neg).await.unwrap();
client.write_all(b"ignored\n").await.unwrap();
let res = timeout(Duration::from_millis(200), in_rx.recv()).await;
assert!(res.is_err(), "readonly client must not produce Data events");
}
#[tokio::test]
async fn write_iac_escapes_payload_bytes() {
let (server, mut client) = paired_streams().await;
let (in_tx, _in_rx) = mpsc::channel(8);
let (_meta, out_tx) = spawn_client(
IncomingClient {
stream: ClientStream::Tcp(server),
peer: ClientPeer::Tcp("127.0.0.1:1".parse().unwrap()),
readonly: false,
},
in_tx,
);
let mut neg = [0u8; 6];
client.read_exact(&mut neg).await.unwrap();
out_tx
.send(OutboundFrame::Bytes(vec![0x41, 0xFF, 0x42]))
.await
.unwrap();
let mut got = [0u8; 4];
client.read_exact(&mut got).await.unwrap();
assert_eq!(got, [0x41, 0xFF, 0xFF, 0x42]);
}
}