Skip to main content

rns_cli/
rnsh.rs

1//! Reticulum Remote Shell Utility.
2//!
3//! Protocol and behaviour are based on the first-party Python `rnsh` utility,
4//! which itself credits Aaron Heise's original `rnsh` program.
5
6use std::collections::{HashMap, HashSet};
7use std::ffi::CString;
8use std::fs;
9use std::io::{self, IsTerminal, Read, Write};
10use std::os::fd::RawFd;
11use std::path::{Path, PathBuf};
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::mpsc;
14use std::time::{Duration, Instant};
15
16use rns_core::buffer::StreamDataMessage;
17use rns_core::types::{DestHash, IdentityHash};
18use rns_crypto::identity::Identity;
19use rns_crypto::OsRng;
20use rns_net::compressor::Bzip2Compressor;
21use rns_net::destination::Destination;
22use rns_net::{Callbacks, RnsNode, SendError};
23
24use crate::format::{prettyb256rep, prettyhexrep};
25
/// Application name of this utility.
/// NOTE(review): presumably also the Reticulum app name used for destinations — confirm at the use sites.
const APP_NAME: &str = "rnsh";
/// Service name a listener uses when `-s/--service` is not given.
const DEFAULT_SERVICE_NAME: &str = "default";
/// Software version, taken at compile time from the `FULL_VERSION` environment variable.
const VERSION: &str = env!("FULL_VERSION");

/// Magic value placed in the high byte of every rnsh message type.
const MSG_MAGIC: u16 = 0xac;
/// Version of the rnsh protocol; printed by `--version`.
const PROTOCOL_VERSION: u64 = 1;

// Message type identifiers: `MSG_MAGIC` in the high byte, the message number in the
// low byte. Message number 1 is not defined here.
const MSG_NOOP: u16 = (MSG_MAGIC << 8) | 0;
const MSG_WINDOW_SIZE: u16 = (MSG_MAGIC << 8) | 2;
const MSG_EXECUTE_COMMAND: u16 = (MSG_MAGIC << 8) | 3;
const MSG_STREAM_DATA: u16 = (MSG_MAGIC << 8) | 4;
const MSG_VERSION_INFO: u16 = (MSG_MAGIC << 8) | 5;
const MSG_ERROR: u16 = (MSG_MAGIC << 8) | 6;
const MSG_COMMAND_EXITED: u16 = (MSG_MAGIC << 8) | 7;

// Stream identifiers for the three standard streams.
const STREAM_STDIN: u16 = 0;
const STREAM_STDOUT: u16 = 1;
const STREAM_STDERR: u16 = 2;

/// Largest channel message payload: the link MDU minus the channel envelope overhead.
const CHANNEL_PAYLOAD_MAX: usize =
    rns_core::constants::LINK_MDU - rns_core::constants::CHANNEL_ENVELOPE_OVERHEAD;
/// Channel payload budget minus two bytes.
/// NOTE(review): presumably the two bytes are the stream-data header — confirm against `StreamDataMessage`.
const STREAM_CHUNK_MAX: usize = CHANNEL_PAYLOAD_MAX - 2;
/// Bound handed to `StreamDataMessage::unpack_bounded` when decoding stream data (64 KiB).
const MAX_DECOMPRESSED_STREAM_CHUNK: usize = 64 * 1024;
49
/// Flag raised by `sigwinch_handler`; starts out `false`.
static SIGWINCH_SEEN: AtomicBool = AtomicBool::new(false);

/// C-ABI signal handler that records a delivered signal by setting `SIGWINCH_SEEN`.
///
/// NOTE(review): presumably installed for `SIGWINCH` and polled elsewhere to forward
/// the new window size — the installation and polling sites are outside this view.
extern "C" fn sigwinch_handler(_: libc::c_int) {
    // The handler performs a single atomic store; the signal number is ignored.
    SIGWINCH_SEEN.store(true, Ordering::SeqCst);
}
55
56#[derive(Debug)]
57enum RnshError {
58    Io(io::Error),
59    Protocol(String),
60    Send,
61}
62
63impl From<io::Error> for RnshError {
64    fn from(value: io::Error) -> Self {
65        RnshError::Io(value)
66    }
67}
68
69impl From<SendError> for RnshError {
70    fn from(_: SendError) -> Self {
71        RnshError::Send
72    }
73}
74
75impl std::fmt::Display for RnshError {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            RnshError::Io(err) => write!(f, "{err}"),
79            RnshError::Protocol(err) => write!(f, "{err}"),
80            RnshError::Send => write!(f, "RNS send failed"),
81        }
82    }
83}
84
85pub fn main() -> i32 {
86    match CliOptions::parse(std::env::args().skip(1).collect()) {
87        Ok(opts) => {
88            if opts.help {
89                print_usage();
90                return 0;
91            }
92            if opts.version {
93                println!("rnsh {} (protocol {})", VERSION, PROTOCOL_VERSION);
94                return 0;
95            }
96            if let Err(err) = init_rnsh_logging(&opts) {
97                eprintln!("{err}");
98                return 1;
99            }
100            if opts.print_identity {
101                return match print_identity(&opts) {
102                    Ok(()) => 0,
103                    Err(err) => {
104                        eprintln!("{err}");
105                        1
106                    }
107                };
108            }
109            let result = if opts.listen {
110                listen(opts).map(|_| 0)
111            } else if opts.destination.is_some() {
112                let mirror = opts.mirror_exit;
113                initiate(opts).map(|code| if mirror { code } else { 0 })
114            } else {
115                print_usage();
116                Ok(1)
117            };
118            match result {
119                Ok(code) => code,
120                Err(err) => {
121                    eprintln!("{err}");
122                    1
123                }
124            }
125        }
126        Err(err) => {
127            eprintln!("{err}");
128            print_usage();
129            1
130        }
131    }
132}
133
/// Command-line options as filled in by `CliOptions::parse`.
#[derive(Debug, Clone, Default)]
struct CliOptions {
    /// `-c/--config`: Reticulum config directory.
    config: Option<String>,
    /// `-i/--identity`: identity file to use.
    identity: Option<String>,
    /// Number of `-v/--verbose` occurrences (saturating).
    verbose: u8,
    /// Number of `-q/--quiet` occurrences (saturating).
    quiet: u8,
    /// `-p/--print-identity`: print identity and destination info.
    print_identity: bool,
    /// `-Z/--base256`: also print the compact base256 display for hashes.
    base256: bool,
    /// `--version`: print version information and exit.
    version: bool,
    /// `-h/--help`: print usage and exit.
    help: bool,
    /// `-l/--listen`: listen for remote shell links.
    listen: bool,
    /// `-s/--service`: listener service name; `parse` fills in
    /// `DEFAULT_SERVICE_NAME` when listening without an explicit name.
    service: Option<String>,
    /// `-b/--announce`: announce period in seconds (0 = announce once).
    announce_period: Option<u64>,
    /// `-a/--allowed`: allowed initiator identity hashes (option is repeatable).
    allowed: Vec<String>,
    /// `-n/--no-auth`: allow any initiator identity.
    no_auth: bool,
    /// `-A/--remote-command-as-args`.
    remote_command_as_args: bool,
    /// `-C/--no-remote-command`.
    no_remote_command: bool,
    /// `-N/--no-id`: do not identify to the listener.
    no_id: bool,
    /// `-m/--mirror`: return the remote command's exit code.
    mirror_exit: bool,
    /// `-w/--timeout`: path/link/protocol timeout in seconds.
    timeout: Option<f64>,
    /// The single positional argument: the destination to connect to.
    destination: Option<String>,
    /// Everything after the first `--` on the command line.
    command: Vec<String>,
}
157
158impl CliOptions {
159    fn parse(argv: Vec<String>) -> Result<Self, String> {
160        let mut opts = CliOptions::default();
161        let (rnsh_argv, command) = match argv.iter().position(|arg| arg == "--") {
162            Some(idx) => (argv[..idx].to_vec(), argv[idx + 1..].to_vec()),
163            None => (argv, Vec::new()),
164        };
165        opts.command = command;
166
167        let mut i = 0;
168        while i < rnsh_argv.len() {
169            let arg = &rnsh_argv[i];
170            if !arg.starts_with('-') || arg == "-" {
171                if opts.destination.is_some() {
172                    return Err(format!("unexpected positional argument: {arg}"));
173                }
174                opts.destination = Some(arg.clone());
175                i += 1;
176                continue;
177            }
178
179            if let Some(name) = arg.strip_prefix("--") {
180                match name {
181                    "config" | "identity" | "service" | "announce" | "allowed" | "timeout" => {
182                        i += 1;
183                        let value = rnsh_argv
184                            .get(i)
185                            .ok_or_else(|| format!("--{name} requires a value"))?
186                            .clone();
187                        match name {
188                            "config" => opts.config = Some(value),
189                            "identity" => opts.identity = Some(value),
190                            "service" => opts.service = Some(value),
191                            "announce" => {
192                                opts.announce_period = Some(value.parse().map_err(|_| {
193                                    "--announce requires an integer period".to_string()
194                                })?)
195                            }
196                            "allowed" => opts.allowed.push(value),
197                            "timeout" => {
198                                opts.timeout = Some(value.parse().map_err(|_| {
199                                    "--timeout requires a numeric value".to_string()
200                                })?)
201                            }
202                            _ => {}
203                        }
204                    }
205                    "verbose" => opts.verbose = opts.verbose.saturating_add(1),
206                    "quiet" => opts.quiet = opts.quiet.saturating_add(1),
207                    "print-identity" => opts.print_identity = true,
208                    "base256" => opts.base256 = true,
209                    "version" => opts.version = true,
210                    "help" => opts.help = true,
211                    "listen" => opts.listen = true,
212                    "no-auth" => opts.no_auth = true,
213                    "remote-command-as-args" => opts.remote_command_as_args = true,
214                    "no-remote-command" => opts.no_remote_command = true,
215                    "no-id" => opts.no_id = true,
216                    "mirror" => opts.mirror_exit = true,
217                    _ => return Err(format!("unknown option --{name}")),
218                }
219                i += 1;
220                continue;
221            }
222
223            let chars: Vec<char> = arg[1..].chars().collect();
224            let mut pos = 0;
225            while pos < chars.len() {
226                match chars[pos] {
227                    'c' | 'i' | 's' | 'b' | 'a' | 'w' => {
228                        let key = chars[pos];
229                        let value = if pos + 1 < chars.len() {
230                            chars[pos + 1..].iter().collect::<String>()
231                        } else {
232                            i += 1;
233                            rnsh_argv
234                                .get(i)
235                                .ok_or_else(|| format!("-{key} requires a value"))?
236                                .clone()
237                        };
238                        match key {
239                            'c' => opts.config = Some(value),
240                            'i' => opts.identity = Some(value),
241                            's' => opts.service = Some(value),
242                            'b' => {
243                                opts.announce_period = Some(
244                                    value
245                                        .parse()
246                                        .map_err(|_| "-b requires an integer".to_string())?,
247                                )
248                            }
249                            'a' => opts.allowed.push(value),
250                            'w' => {
251                                opts.timeout = Some(
252                                    value
253                                        .parse()
254                                        .map_err(|_| "-w requires a number".to_string())?,
255                                )
256                            }
257                            _ => {}
258                        }
259                        break;
260                    }
261                    'v' => opts.verbose = opts.verbose.saturating_add(1),
262                    'q' => opts.quiet = opts.quiet.saturating_add(1),
263                    'p' => opts.print_identity = true,
264                    'Z' => opts.base256 = true,
265                    'l' => opts.listen = true,
266                    'n' => opts.no_auth = true,
267                    'A' => opts.remote_command_as_args = true,
268                    'C' => opts.no_remote_command = true,
269                    'N' => opts.no_id = true,
270                    'm' => opts.mirror_exit = true,
271                    'h' => opts.help = true,
272                    other => return Err(format!("unknown option -{other}")),
273                }
274                pos += 1;
275            }
276            i += 1;
277        }
278
279        if opts.listen && opts.service.is_none() {
280            opts.service = Some(DEFAULT_SERVICE_NAME.to_string());
281        }
282        Ok(opts)
283    }
284}
285
/// Prints the command-line usage summary to standard error.
///
/// The list covers every option accepted by `CliOptions::parse`, including the
/// verbosity, version and help switches.
fn print_usage() {
    const USAGE: &str = concat!(
        "Usage:\n",
        "  rnsh -l [options] [-- command...]\n",
        "  rnsh [options] <destination> [-- command...]\n",
        "\n",
        "Options:\n",
        "  -c, --config PATH        Reticulum config directory\n",
        "  -i, --identity PATH      Identity file to use\n",
        "  -p, --print-identity     Print identity and destination info\n",
        "  -Z, --base256            Also print compact base256 display for hashes\n",
        "  -l, --listen             Listen for remote shell links\n",
        "  -s, --service NAME       Listener identity service name\n",
        "  -b, --announce PERIOD    Announce on startup and every PERIOD seconds (0 = once)\n",
        "  -a, --allowed HASH       Allow initiator identity hash (repeatable)\n",
        "  -n, --no-auth            Allow any initiator identity\n",
        "  -A, --remote-command-as-args\n",
        "  -C, --no-remote-command\n",
        "  -N, --no-id              Do not identify to the listener\n",
        "  -m, --mirror             Return remote command exit code\n",
        "  -w, --timeout SECONDS    Path/link/protocol timeout\n",
        "  -v, --verbose            Increase log verbosity (repeatable)\n",
        "  -q, --quiet              Decrease log verbosity (repeatable)\n",
        "      --version            Print version information and exit\n",
        "  -h, --help               Show this help and exit",
    );
    eprintln!("{}", USAGE);
}
291
292fn init_rnsh_logging(opts: &CliOptions) -> Result<(), RnshError> {
293    let dir = rnsh_config_dir()?;
294    let file = std::fs::OpenOptions::new()
295        .create(true)
296        .append(true)
297        .open(dir.join("logfile"))?;
298    let mut builder = env_logger::Builder::new();
299    builder
300        .filter_level(rnsh_log_level(opts.listen, opts.verbose, opts.quiet))
301        .format_timestamp_secs()
302        .target(env_logger::Target::Pipe(Box::new(file)));
303    builder
304        .try_init()
305        .map_err(|err| RnshError::Protocol(format!("failed to initialize rnsh logging: {err}")))
306}
307
308fn rnsh_config_dir() -> Result<PathBuf, RnshError> {
309    let home = std::env::var("HOME").unwrap_or_else(|_| ".".into());
310    let xdg = PathBuf::from(&home).join(".config").join("rnsh");
311    if xdg.is_dir() {
312        return Ok(xdg);
313    }
314    let legacy = PathBuf::from(home).join(".rnsh");
315    fs::create_dir_all(&legacy)?;
316    Ok(legacy)
317}
318
319fn rnsh_log_level(listen: bool, verbose: u8, quiet: u8) -> log::LevelFilter {
320    let base: i16 = if listen { 3 } else { 1 };
321    match (base + verbose as i16 - quiet as i16).clamp(0, 5) {
322        0 => log::LevelFilter::Off,
323        1 => log::LevelFilter::Error,
324        2 => log::LevelFilter::Warn,
325        3 => log::LevelFilter::Info,
326        4 => log::LevelFilter::Debug,
327        _ => log::LevelFilter::Trace,
328    }
329}
330
/// In-memory form of the msgpack values this file can encode and decode.
#[derive(Debug, Clone, PartialEq)]
enum MsgValue {
    /// msgpack `nil`.
    Nil,
    /// msgpack `true` / `false`.
    Bool(bool),
    /// Any msgpack integer, held as `i64`.
    Int(i64),
    /// UTF-8 text (msgpack `str` family).
    String(String),
    /// Raw bytes (msgpack `bin` family).
    Bytes(Vec<u8>),
    /// Ordered list of values.
    Array(Vec<MsgValue>),
    /// Key/value pairs kept in wire order; keys are not checked for uniqueness.
    Map(Vec<(MsgValue, MsgValue)>),
}
341
342fn msgpack_pack(value: &MsgValue, out: &mut Vec<u8>) {
343    match value {
344        MsgValue::Nil => out.push(0xc0),
345        MsgValue::Bool(false) => out.push(0xc2),
346        MsgValue::Bool(true) => out.push(0xc3),
347        MsgValue::Int(v) if *v >= 0 && *v <= 0x7f => out.push(*v as u8),
348        MsgValue::Int(v) if *v >= -32 && *v < 0 => out.push((*v as i8) as u8),
349        MsgValue::Int(v) if *v >= i8::MIN as i64 && *v <= i8::MAX as i64 => {
350            out.extend_from_slice(&[0xd0, *v as i8 as u8]);
351        }
352        MsgValue::Int(v) if *v >= i16::MIN as i64 && *v <= i16::MAX as i64 => {
353            out.push(0xd1);
354            out.extend_from_slice(&(*v as i16).to_be_bytes());
355        }
356        MsgValue::Int(v) if *v >= i32::MIN as i64 && *v <= i32::MAX as i64 => {
357            out.push(0xd2);
358            out.extend_from_slice(&(*v as i32).to_be_bytes());
359        }
360        MsgValue::Int(v) => {
361            out.push(0xd3);
362            out.extend_from_slice(&v.to_be_bytes());
363        }
364        MsgValue::String(s) => pack_msgpack_str(s.as_bytes(), out, true),
365        MsgValue::Bytes(bytes) => pack_msgpack_str(bytes, out, false),
366        MsgValue::Array(items) => {
367            if items.len() < 16 {
368                out.push(0x90 | items.len() as u8);
369            } else if items.len() <= u16::MAX as usize {
370                out.push(0xdc);
371                out.extend_from_slice(&(items.len() as u16).to_be_bytes());
372            } else {
373                out.push(0xdd);
374                out.extend_from_slice(&(items.len() as u32).to_be_bytes());
375            }
376            for item in items {
377                msgpack_pack(item, out);
378            }
379        }
380        MsgValue::Map(items) => {
381            if items.len() < 16 {
382                out.push(0x80 | items.len() as u8);
383            } else {
384                out.push(0xde);
385                out.extend_from_slice(&(items.len() as u16).to_be_bytes());
386            }
387            for (key, value) in items {
388                msgpack_pack(key, out);
389                msgpack_pack(value, out);
390            }
391        }
392    }
393}
394
/// Appends a length-prefixed msgpack `str` (`utf8 == true`) or `bin`
/// (`utf8 == false`) header followed by `bytes` to `out`.
///
/// Strings shorter than 32 bytes use the one-byte fixstr header; otherwise the
/// 8-, 16- or 32-bit length form of the respective family is chosen by size.
fn pack_msgpack_str(bytes: &[u8], out: &mut Vec<u8>, utf8: bool) {
    let len = bytes.len();
    // Tags of the 8/16/32-bit length forms for the selected family.
    let (tag8, tag16, tag32): (u8, u8, u8) = if utf8 {
        (0xd9, 0xda, 0xdb)
    } else {
        (0xc4, 0xc5, 0xc6)
    };

    if utf8 && len < 32 {
        out.push(0xa0 | len as u8);
    } else if len <= u8::MAX as usize {
        out.push(tag8);
        out.push(len as u8);
    } else if len <= u16::MAX as usize {
        out.push(tag16);
        out.extend_from_slice(&(len as u16).to_be_bytes());
    } else {
        out.push(tag32);
        out.extend_from_slice(&(len as u32).to_be_bytes());
    }
    out.extend_from_slice(bytes);
}
419
420fn msgpack_unpack(raw: &[u8]) -> Result<MsgValue, RnshError> {
421    let (value, consumed) = unpack_at(raw, 0)?;
422    if consumed != raw.len() {
423        return Err(RnshError::Protocol("trailing msgpack data".into()));
424    }
425    Ok(value)
426}
427
428fn unpack_at(raw: &[u8], mut pos: usize) -> Result<(MsgValue, usize), RnshError> {
429    let tag = *raw
430        .get(pos)
431        .ok_or_else(|| RnshError::Protocol("truncated msgpack".into()))?;
432    pos += 1;
433    match tag {
434        0x00..=0x7f => Ok((MsgValue::Int(tag as i64), pos)),
435        0x80..=0x8f => unpack_map(raw, pos, (tag & 0x0f) as usize),
436        0x90..=0x9f => unpack_array(raw, pos, (tag & 0x0f) as usize),
437        0xa0..=0xbf => unpack_string(raw, pos, (tag & 0x1f) as usize),
438        0xc0 => Ok((MsgValue::Nil, pos)),
439        0xc2 => Ok((MsgValue::Bool(false), pos)),
440        0xc3 => Ok((MsgValue::Bool(true), pos)),
441        0xc4 => {
442            let len = read_u8(raw, &mut pos)? as usize;
443            unpack_bytes(raw, pos, len)
444        }
445        0xc5 => {
446            let len = read_u16(raw, &mut pos)? as usize;
447            unpack_bytes(raw, pos, len)
448        }
449        0xc6 => {
450            let len = read_u32(raw, &mut pos)? as usize;
451            unpack_bytes(raw, pos, len)
452        }
453        0xcc => Ok((MsgValue::Int(read_u8(raw, &mut pos)? as i64), pos)),
454        0xcd => Ok((MsgValue::Int(read_u16(raw, &mut pos)? as i64), pos)),
455        0xce => Ok((MsgValue::Int(read_u32(raw, &mut pos)? as i64), pos)),
456        0xcf => Ok((MsgValue::Int(read_u64(raw, &mut pos)? as i64), pos)),
457        0xd0 => Ok((MsgValue::Int(read_u8(raw, &mut pos)? as i8 as i64), pos)),
458        0xd1 => Ok((MsgValue::Int(read_u16(raw, &mut pos)? as i16 as i64), pos)),
459        0xd2 => Ok((MsgValue::Int(read_u32(raw, &mut pos)? as i32 as i64), pos)),
460        0xd3 => Ok((MsgValue::Int(read_u64(raw, &mut pos)? as i64), pos)),
461        0xd9 => {
462            let len = read_u8(raw, &mut pos)? as usize;
463            unpack_string(raw, pos, len)
464        }
465        0xda => {
466            let len = read_u16(raw, &mut pos)? as usize;
467            unpack_string(raw, pos, len)
468        }
469        0xdb => {
470            let len = read_u32(raw, &mut pos)? as usize;
471            unpack_string(raw, pos, len)
472        }
473        0xdc => {
474            let len = read_u16(raw, &mut pos)? as usize;
475            unpack_array(raw, pos, len)
476        }
477        0xdd => {
478            let len = read_u32(raw, &mut pos)? as usize;
479            unpack_array(raw, pos, len)
480        }
481        0xde => {
482            let len = read_u16(raw, &mut pos)? as usize;
483            unpack_map(raw, pos, len)
484        }
485        0xdf => {
486            let len = read_u32(raw, &mut pos)? as usize;
487            unpack_map(raw, pos, len)
488        }
489        0xe0..=0xff => Ok((MsgValue::Int((tag as i8) as i64), pos)),
490        _ => Err(RnshError::Protocol(format!(
491            "unsupported msgpack tag 0x{tag:02x}"
492        ))),
493    }
494}
495
496fn unpack_array(raw: &[u8], mut pos: usize, len: usize) -> Result<(MsgValue, usize), RnshError> {
497    let mut values = Vec::with_capacity(len);
498    for _ in 0..len {
499        let (value, next) = unpack_at(raw, pos)?;
500        values.push(value);
501        pos = next;
502    }
503    Ok((MsgValue::Array(values), pos))
504}
505
506fn unpack_map(raw: &[u8], mut pos: usize, len: usize) -> Result<(MsgValue, usize), RnshError> {
507    let mut values = Vec::with_capacity(len);
508    for _ in 0..len {
509        let (key, next) = unpack_at(raw, pos)?;
510        let (value, next) = unpack_at(raw, next)?;
511        values.push((key, value));
512        pos = next;
513    }
514    Ok((MsgValue::Map(values), pos))
515}
516
517fn unpack_string(raw: &[u8], pos: usize, len: usize) -> Result<(MsgValue, usize), RnshError> {
518    let bytes = raw
519        .get(pos..pos + len)
520        .ok_or_else(|| RnshError::Protocol("truncated msgpack string".into()))?;
521    let s = std::str::from_utf8(bytes)
522        .map_err(|_| RnshError::Protocol("invalid msgpack utf8".into()))?;
523    Ok((MsgValue::String(s.to_string()), pos + len))
524}
525
526fn unpack_bytes(raw: &[u8], pos: usize, len: usize) -> Result<(MsgValue, usize), RnshError> {
527    let bytes = raw
528        .get(pos..pos + len)
529        .ok_or_else(|| RnshError::Protocol("truncated msgpack bytes".into()))?;
530    Ok((MsgValue::Bytes(bytes.to_vec()), pos + len))
531}
532
533fn read_u8(raw: &[u8], pos: &mut usize) -> Result<u8, RnshError> {
534    let v = *raw
535        .get(*pos)
536        .ok_or_else(|| RnshError::Protocol("truncated msgpack integer".into()))?;
537    *pos += 1;
538    Ok(v)
539}
540
541fn read_u16(raw: &[u8], pos: &mut usize) -> Result<u16, RnshError> {
542    let bytes = raw
543        .get(*pos..*pos + 2)
544        .ok_or_else(|| RnshError::Protocol("truncated msgpack integer".into()))?;
545    *pos += 2;
546    Ok(u16::from_be_bytes([bytes[0], bytes[1]]))
547}
548
549fn read_u32(raw: &[u8], pos: &mut usize) -> Result<u32, RnshError> {
550    let bytes = raw
551        .get(*pos..*pos + 4)
552        .ok_or_else(|| RnshError::Protocol("truncated msgpack integer".into()))?;
553    *pos += 4;
554    Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
555}
556
557fn read_u64(raw: &[u8], pos: &mut usize) -> Result<u64, RnshError> {
558    let bytes = raw
559        .get(*pos..*pos + 8)
560        .ok_or_else(|| RnshError::Protocol("truncated msgpack integer".into()))?;
561    *pos += 8;
562    Ok(u64::from_be_bytes([
563        bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
564    ]))
565}
566
/// Payload of a window-size message: four optional `u16` values, packed on the
/// wire as a four-element array in the order rows, cols, hpix, vpix (`nil` for
/// `None`).
#[derive(Debug, Clone, PartialEq)]
struct WindowSize {
    /// Terminal rows, if known.
    rows: Option<u16>,
    /// Terminal columns, if known.
    cols: Option<u16>,
    /// NOTE(review): presumably the horizontal size in pixels — confirm with the producer.
    hpix: Option<u16>,
    /// NOTE(review): presumably the vertical size in pixels — confirm with the producer.
    vpix: Option<u16>,
}
574
/// Payload of an execute-command message.
///
/// On the wire this is a ten-element array; element 4 is always packed as `nil`
/// and ignored when unpacking, so it has no field here.
#[derive(Debug, Clone, PartialEq)]
struct ExecuteCommand {
    /// Command line as a list of strings.
    cmdline: Vec<String>,
    /// NOTE(review): presumably "use a pipe rather than a terminal for stdin" — confirm with the executor.
    pipe_stdin: bool,
    /// Same as `pipe_stdin`, for stdout.
    pipe_stdout: bool,
    /// Same as `pipe_stdin`, for stderr.
    pipe_stderr: bool,
    /// Optional terminal type string.
    /// NOTE(review): presumably the initiator's `TERM` value — confirm with the sender.
    term: Option<String>,
    /// Terminal rows, if known.
    rows: Option<u16>,
    /// Terminal columns, if known.
    cols: Option<u16>,
    /// NOTE(review): presumably the horizontal size in pixels — confirm with the producer.
    hpix: Option<u16>,
    /// NOTE(review): presumably the vertical size in pixels — confirm with the producer.
    vpix: Option<u16>,
}
587
/// A decoded rnsh channel message; one variant per `MSG_*` message type.
#[derive(Debug, Clone, PartialEq)]
enum RnshMessage {
    /// `MSG_NOOP`: carries no payload.
    Noop,
    /// `MSG_WINDOW_SIZE`: terminal dimensions.
    WindowSize(WindowSize),
    /// `MSG_EXECUTE_COMMAND`: request to run a command.
    ExecuteCommand(ExecuteCommand),
    /// `MSG_STREAM_DATA`: a chunk of stream data, encoded by `StreamDataMessage` itself.
    StreamData(StreamDataMessage),
    /// `MSG_VERSION_INFO`: packed as a two-element array `[sw_version, protocol_version]`.
    VersionInfo {
        /// Software version string of the sender.
        sw_version: String,
        /// Protocol version of the sender.
        protocol_version: u64,
    },
    /// `MSG_ERROR`: packed as a three-element array whose last element is always `nil`.
    Error {
        /// Error description.
        msg: String,
        /// Flag marking the error as fatal.
        fatal: bool,
    },
    /// `MSG_COMMAND_EXITED`: the exit code, packed as a single msgpack integer.
    CommandExited(i32),
}
604
605impl RnshMessage {
606    fn msgtype(&self) -> u16 {
607        match self {
608            RnshMessage::Noop => MSG_NOOP,
609            RnshMessage::WindowSize(_) => MSG_WINDOW_SIZE,
610            RnshMessage::ExecuteCommand(_) => MSG_EXECUTE_COMMAND,
611            RnshMessage::StreamData(_) => MSG_STREAM_DATA,
612            RnshMessage::VersionInfo { .. } => MSG_VERSION_INFO,
613            RnshMessage::Error { .. } => MSG_ERROR,
614            RnshMessage::CommandExited(_) => MSG_COMMAND_EXITED,
615        }
616    }
617
618    fn pack(&self) -> Vec<u8> {
619        match self {
620            RnshMessage::Noop => Vec::new(),
621            RnshMessage::StreamData(msg) => msg.pack(),
622            RnshMessage::VersionInfo {
623                sw_version,
624                protocol_version,
625            } => pack_msgpack_array(vec![
626                MsgValue::String(sw_version.clone()),
627                MsgValue::Int(*protocol_version as i64),
628            ]),
629            RnshMessage::WindowSize(size) => pack_msgpack_array(vec![
630                opt_u16(size.rows),
631                opt_u16(size.cols),
632                opt_u16(size.hpix),
633                opt_u16(size.vpix),
634            ]),
635            RnshMessage::ExecuteCommand(cmd) => pack_msgpack_array(vec![
636                MsgValue::Array(
637                    cmd.cmdline
638                        .iter()
639                        .map(|s| MsgValue::String(s.clone()))
640                        .collect(),
641                ),
642                MsgValue::Bool(cmd.pipe_stdin),
643                MsgValue::Bool(cmd.pipe_stdout),
644                MsgValue::Bool(cmd.pipe_stderr),
645                MsgValue::Nil,
646                cmd.term
647                    .as_ref()
648                    .map(|s| MsgValue::String(s.clone()))
649                    .unwrap_or(MsgValue::Nil),
650                opt_u16(cmd.rows),
651                opt_u16(cmd.cols),
652                opt_u16(cmd.hpix),
653                opt_u16(cmd.vpix),
654            ]),
655            RnshMessage::Error { msg, fatal } => pack_msgpack_array(vec![
656                MsgValue::String(msg.clone()),
657                MsgValue::Bool(*fatal),
658                MsgValue::Nil,
659            ]),
660            RnshMessage::CommandExited(code) => {
661                let mut out = Vec::new();
662                msgpack_pack(&MsgValue::Int(*code as i64), &mut out);
663                out
664            }
665        }
666    }
667
668    fn unpack(msgtype: u16, payload: &[u8]) -> Result<Self, RnshError> {
669        match msgtype {
670            MSG_NOOP => Ok(RnshMessage::Noop),
671            MSG_STREAM_DATA => Ok(RnshMessage::StreamData(
672                StreamDataMessage::unpack_bounded(
673                    payload,
674                    &Bzip2Compressor,
675                    MAX_DECOMPRESSED_STREAM_CHUNK,
676                )
677                .map_err(|_| RnshError::Protocol("invalid stream data message".into()))?,
678            )),
679            MSG_VERSION_INFO => {
680                let values = expect_array(msgpack_unpack(payload)?, 2)?;
681                Ok(RnshMessage::VersionInfo {
682                    sw_version: expect_string(&values[0])?,
683                    protocol_version: expect_int(&values[1])? as u64,
684                })
685            }
686            MSG_WINDOW_SIZE => {
687                let values = expect_array(msgpack_unpack(payload)?, 4)?;
688                Ok(RnshMessage::WindowSize(WindowSize {
689                    rows: opt_int_u16(&values[0])?,
690                    cols: opt_int_u16(&values[1])?,
691                    hpix: opt_int_u16(&values[2])?,
692                    vpix: opt_int_u16(&values[3])?,
693                }))
694            }
695            MSG_EXECUTE_COMMAND => {
696                let values = expect_array(msgpack_unpack(payload)?, 10)?;
697                let cmdline = expect_array_value(&values[0])?
698                    .iter()
699                    .map(expect_string)
700                    .collect::<Result<Vec<_>, _>>()?;
701                Ok(RnshMessage::ExecuteCommand(ExecuteCommand {
702                    cmdline,
703                    pipe_stdin: expect_bool(&values[1])?,
704                    pipe_stdout: expect_bool(&values[2])?,
705                    pipe_stderr: expect_bool(&values[3])?,
706                    term: opt_string(&values[5])?,
707                    rows: opt_int_u16(&values[6])?,
708                    cols: opt_int_u16(&values[7])?,
709                    hpix: opt_int_u16(&values[8])?,
710                    vpix: opt_int_u16(&values[9])?,
711                }))
712            }
713            MSG_ERROR => {
714                let values = expect_array(msgpack_unpack(payload)?, 3)?;
715                Ok(RnshMessage::Error {
716                    msg: expect_string(&values[0])?,
717                    fatal: expect_bool(&values[1])?,
718                })
719            }
720            MSG_COMMAND_EXITED => Ok(RnshMessage::CommandExited(expect_int(&msgpack_unpack(
721                payload,
722            )?)? as i32)),
723            _ => Err(RnshError::Protocol(format!(
724                "unknown rnsh message type 0x{msgtype:04x}"
725            ))),
726        }
727    }
728}
729
730fn pack_msgpack_array(values: Vec<MsgValue>) -> Vec<u8> {
731    let mut out = Vec::new();
732    msgpack_pack(&MsgValue::Array(values), &mut out);
733    out
734}
735
736fn opt_u16(value: Option<u16>) -> MsgValue {
737    value
738        .map(|v| MsgValue::Int(v as i64))
739        .unwrap_or(MsgValue::Nil)
740}
741
742fn expect_array(value: MsgValue, len: usize) -> Result<Vec<MsgValue>, RnshError> {
743    match value {
744        MsgValue::Array(values) if values.len() == len => Ok(values),
745        _ => Err(RnshError::Protocol("unexpected msgpack array".into())),
746    }
747}
748
749fn expect_array_value(value: &MsgValue) -> Result<&[MsgValue], RnshError> {
750    match value {
751        MsgValue::Array(values) => Ok(values),
752        _ => Err(RnshError::Protocol("expected msgpack array".into())),
753    }
754}
755
756fn expect_string(value: &MsgValue) -> Result<String, RnshError> {
757    match value {
758        MsgValue::String(s) => Ok(s.clone()),
759        _ => Err(RnshError::Protocol("expected msgpack string".into())),
760    }
761}
762
763fn opt_string(value: &MsgValue) -> Result<Option<String>, RnshError> {
764    match value {
765        MsgValue::Nil => Ok(None),
766        MsgValue::String(s) => Ok(Some(s.clone())),
767        _ => Err(RnshError::Protocol(
768            "expected optional msgpack string".into(),
769        )),
770    }
771}
772
773fn expect_bool(value: &MsgValue) -> Result<bool, RnshError> {
774    match value {
775        MsgValue::Bool(v) => Ok(*v),
776        _ => Err(RnshError::Protocol("expected msgpack bool".into())),
777    }
778}
779
780fn expect_int(value: &MsgValue) -> Result<i64, RnshError> {
781    match value {
782        MsgValue::Int(v) => Ok(*v),
783        _ => Err(RnshError::Protocol("expected msgpack int".into())),
784    }
785}
786
787fn opt_int_u16(value: &MsgValue) -> Result<Option<u16>, RnshError> {
788    match value {
789        MsgValue::Nil => Ok(None),
790        MsgValue::Int(v) if *v >= 0 && *v <= u16::MAX as i64 => Ok(Some(*v as u16)),
791        _ => Err(RnshError::Protocol("expected optional u16".into())),
792    }
793}
794
/// Events carried over the `mpsc` channel whose sender is held by `RnshCallbacks`.
///
/// The first five variants are sent by the `Callbacks` implementation on
/// `RnshCallbacks`; the producers of the remaining ones are outside this view.
#[derive(Debug)]
enum RnshEvent {
    /// An announce was received (`on_announce`).
    Announce(rns_net::AnnouncedIdentity),
    /// A link was established (`on_link_established`).
    LinkEstablished {
        /// Raw 16-byte link id.
        link_id: [u8; 16],
        /// `is_initiator` flag as passed to the callback.
        is_initiator: bool,
    },
    /// A link was closed (`on_link_closed`); the teardown reason is not forwarded.
    LinkClosed([u8; 16]),
    /// The remote side identified itself on a link (`on_remote_identified`); the
    /// public key is not forwarded.
    RemoteIdentified {
        /// Raw 16-byte link id.
        link_id: [u8; 16],
        /// Identity hash of the remote peer.
        identity_hash: IdentityHash,
    },
    /// A channel message arrived on a link (`on_channel_message`), still undecoded.
    ChannelMessage {
        /// Raw 16-byte link id.
        link_id: [u8; 16],
        /// Channel message type.
        msgtype: u16,
        /// Undecoded message payload.
        payload: Vec<u8>,
    },
    /// NOTE(review): presumably output read from a spawned process, tagged with a
    /// `STREAM_*` id — confirm with the producer.
    ProcessOutput {
        link_id: [u8; 16],
        stream_id: u16,
        data: Vec<u8>,
    },
    /// NOTE(review): presumably sent when a spawned process exits — confirm with the producer.
    ProcessExited {
        link_id: [u8; 16],
        code: i32,
    },
    /// NOTE(review): presumably bytes read from the local standard input — confirm with the producer.
    LocalStdin(Vec<u8>),
    /// NOTE(review): presumably end-of-file on the local standard input — confirm with the producer.
    LocalStdinEof,
}
824
825struct RnshCallbacks {
826    tx: mpsc::Sender<RnshEvent>,
827}
828
829impl Callbacks for RnshCallbacks {
830    fn on_announce(&mut self, announced: rns_net::AnnouncedIdentity) {
831        let _ = self.tx.send(RnshEvent::Announce(announced));
832    }
833
834    fn on_path_updated(&mut self, _dest_hash: DestHash, _hops: u8) {}
835
836    fn on_local_delivery(
837        &mut self,
838        _dest_hash: DestHash,
839        _raw: Vec<u8>,
840        _packet_hash: rns_net::PacketHash,
841    ) {
842    }
843
844    fn on_link_established(
845        &mut self,
846        link_id: rns_net::LinkId,
847        _dest_hash: DestHash,
848        _rtt: f64,
849        is_initiator: bool,
850    ) {
851        let _ = self.tx.send(RnshEvent::LinkEstablished {
852            link_id: link_id.0,
853            is_initiator,
854        });
855    }
856
857    fn on_link_closed(
858        &mut self,
859        link_id: rns_net::LinkId,
860        _reason: Option<rns_net::TeardownReason>,
861    ) {
862        let _ = self.tx.send(RnshEvent::LinkClosed(link_id.0));
863    }
864
865    fn on_remote_identified(
866        &mut self,
867        link_id: rns_net::LinkId,
868        identity_hash: IdentityHash,
869        _public_key: [u8; 64],
870    ) {
871        let _ = self.tx.send(RnshEvent::RemoteIdentified {
872            link_id: link_id.0,
873            identity_hash,
874        });
875    }
876
877    fn on_channel_message(&mut self, link_id: rns_net::LinkId, msgtype: u16, payload: Vec<u8>) {
878        let _ = self.tx.send(RnshEvent::ChannelMessage {
879            link_id: link_id.0,
880            msgtype,
881            payload,
882        });
883    }
884}
885
886trait RnshTransport {
887    fn send_rnsh_message(&self, link_id: [u8; 16], message: &RnshMessage) -> Result<(), RnshError>;
888
889    fn teardown_rnsh_link(&self, link_id: [u8; 16]) -> Result<(), RnshError>;
890}
891
892impl RnshTransport for RnsNode {
893    fn send_rnsh_message(&self, link_id: [u8; 16], message: &RnshMessage) -> Result<(), RnshError> {
894        self.send_channel_message(link_id, message.msgtype(), message.pack())?;
895        Ok(())
896    }
897
898    fn teardown_rnsh_link(&self, link_id: [u8; 16]) -> Result<(), RnshError> {
899        self.teardown_link(link_id)?;
900        Ok(())
901    }
902}
903
904struct ChildProcess {
905    pid: libc::pid_t,
906    stdin_fd: Option<RawFd>,
907    stdout_fd: Option<RawFd>,
908    stderr_fd: Option<RawFd>,
909}
910
911impl ChildProcess {
912    fn spawn(
913        link_id: [u8; 16],
914        argv: &[String],
915        env_overrides: &[(&str, String)],
916        flags: &ExecuteCommand,
917        event_tx: mpsc::Sender<RnshEvent>,
918    ) -> io::Result<Self> {
919        if argv.is_empty() {
920            return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty command"));
921        }
922
923        let use_pty = !(flags.pipe_stdin && flags.pipe_stdout && flags.pipe_stderr);
924        let mut pty_master = None;
925        let mut pty_child = None;
926        if use_pty {
927            let mut master: libc::c_int = -1;
928            let mut child: libc::c_int = -1;
929            let rc = unsafe {
930                libc::openpty(
931                    &mut master,
932                    &mut child,
933                    std::ptr::null_mut(),
934                    std::ptr::null(),
935                    std::ptr::null(),
936                )
937            };
938            if rc != 0 {
939                return Err(io::Error::last_os_error());
940            }
941            pty_master = Some(master);
942            pty_child = Some(child);
943        }
944
945        let stdin_pipe = if flags.pipe_stdin {
946            Some(pipe_pair()?)
947        } else {
948            None
949        };
950        let stdout_pipe = if flags.pipe_stdout {
951            Some(pipe_pair()?)
952        } else {
953            None
954        };
955        let stderr_pipe = if flags.pipe_stderr {
956            Some(pipe_pair()?)
957        } else {
958            None
959        };
960
961        let child_stdin = stdin_pipe.map(|p| p.0).or(pty_child).unwrap_or(-1);
962        let parent_stdin = stdin_pipe.map(|p| p.1).or(pty_master).unwrap_or(-1);
963        let parent_stdout = stdout_pipe.map(|p| p.0).or(pty_master).unwrap_or(-1);
964        let child_stdout = stdout_pipe.map(|p| p.1).or(pty_child).unwrap_or(-1);
965        let parent_stderr = stderr_pipe.map(|p| p.0).or(pty_master).unwrap_or(-1);
966        let child_stderr = stderr_pipe.map(|p| p.1).or(pty_child).unwrap_or(-1);
967
968        let pid = unsafe { libc::fork() };
969        if pid < 0 {
970            return Err(io::Error::last_os_error());
971        }
972        if pid == 0 {
973            unsafe {
974                if use_pty {
975                    libc::setsid();
976                }
977                libc::dup2(child_stdin, 0);
978                libc::dup2(child_stdout, 1);
979                libc::dup2(child_stderr, 2);
980                if use_pty {
981                    let tty_fd = if !flags.pipe_stdin {
982                        0
983                    } else if !flags.pipe_stdout {
984                        1
985                    } else {
986                        2
987                    };
988                    libc::ioctl(tty_fd, libc::TIOCSCTTY, 0);
989                }
990                for fd in 3..1024 {
991                    libc::close(fd);
992                }
993                for (key, value) in env_overrides {
994                    if let (Ok(k), Ok(v)) = (CString::new(*key), CString::new(value.as_str())) {
995                        libc::setenv(k.as_ptr(), v.as_ptr(), 1);
996                    }
997                }
998                let c_args = argv
999                    .iter()
1000                    .map(|arg| CString::new(arg.as_str()))
1001                    .collect::<Result<Vec<_>, _>>();
1002                if let Ok(c_args) = c_args {
1003                    let mut ptrs = c_args.iter().map(|s| s.as_ptr()).collect::<Vec<_>>();
1004                    ptrs.push(std::ptr::null());
1005                    libc::execvp(ptrs[0], ptrs.as_ptr());
1006                }
1007                libc::_exit(255);
1008            }
1009        }
1010
1011        close_unique(&[
1012            pty_child,
1013            stdin_pipe.map(|p| p.0),
1014            stdout_pipe.map(|p| p.1),
1015            stderr_pipe.map(|p| p.1),
1016        ]);
1017
1018        let stdout_fd = if parent_stdout >= 0 {
1019            Some(parent_stdout)
1020        } else {
1021            None
1022        };
1023        let stderr_fd = if parent_stderr >= 0 && Some(parent_stderr) != stdout_fd {
1024            Some(parent_stderr)
1025        } else {
1026            None
1027        };
1028
1029        let mut reader_handles = Vec::new();
1030        if let Some(fd) = stdout_fd {
1031            reader_handles.push(spawn_reader(link_id, STREAM_STDOUT, fd, event_tx.clone()));
1032        }
1033        if let Some(fd) = stderr_fd {
1034            reader_handles.push(spawn_reader(link_id, STREAM_STDERR, fd, event_tx.clone()));
1035        }
1036        spawn_waiter(link_id, pid, reader_handles, event_tx);
1037
1038        Ok(ChildProcess {
1039            pid,
1040            stdin_fd: (parent_stdin >= 0).then_some(parent_stdin),
1041            stdout_fd,
1042            stderr_fd,
1043        })
1044    }
1045
1046    fn write_stdin(&self, data: &[u8]) {
1047        if let Some(fd) = self.stdin_fd {
1048            let _ = write_all_fd(fd, data);
1049        }
1050    }
1051
1052    fn close_stdin(&mut self) {
1053        if let Some(fd) = self.stdin_fd.take() {
1054            if Some(fd) == self.stdout_fd || Some(fd) == self.stderr_fd {
1055                let _ = write_all_fd(fd, b"\x04");
1056                self.stdin_fd = Some(fd);
1057            } else {
1058                unsafe {
1059                    libc::close(fd);
1060                }
1061            }
1062        }
1063    }
1064
1065    fn set_winsize(&self, size: &WindowSize) {
1066        let Some(fd) = self.stdout_fd.or(self.stdin_fd) else {
1067            return;
1068        };
1069        let ws = libc::winsize {
1070            ws_row: size.rows.unwrap_or(0),
1071            ws_col: size.cols.unwrap_or(0),
1072            ws_xpixel: size.hpix.unwrap_or(0),
1073            ws_ypixel: size.vpix.unwrap_or(0),
1074        };
1075        unsafe {
1076            libc::ioctl(fd, libc::TIOCSWINSZ, &ws);
1077        }
1078    }
1079
1080    fn terminate(&mut self) {
1081        unsafe {
1082            libc::kill(self.pid, libc::SIGTERM);
1083        }
1084        self.close_stdin();
1085    }
1086}
1087
1088impl Drop for ChildProcess {
1089    fn drop(&mut self) {
1090        close_unique(&[
1091            self.stdin_fd.take(),
1092            self.stdout_fd.take(),
1093            self.stderr_fd.take(),
1094        ]);
1095    }
1096}
1097
1098fn pipe_pair() -> io::Result<(RawFd, RawFd)> {
1099    let mut fds = [-1; 2];
1100    if unsafe { libc::pipe(fds.as_mut_ptr()) } != 0 {
1101        return Err(io::Error::last_os_error());
1102    }
1103    Ok((fds[0], fds[1]))
1104}
1105
1106fn close_unique(fds: &[Option<RawFd>]) {
1107    let mut seen = HashSet::new();
1108    for fd in fds.iter().flatten().copied() {
1109        if fd >= 0 && seen.insert(fd) {
1110            unsafe {
1111                libc::close(fd);
1112            }
1113        }
1114    }
1115}
1116
1117fn spawn_reader(
1118    link_id: [u8; 16],
1119    stream_id: u16,
1120    fd: RawFd,
1121    event_tx: mpsc::Sender<RnshEvent>,
1122) -> std::thread::JoinHandle<()> {
1123    std::thread::spawn(move || {
1124        let mut buf = [0u8; 4096];
1125        loop {
1126            let n = unsafe { libc::read(fd, buf.as_mut_ptr().cast(), buf.len()) };
1127            if n > 0 {
1128                let _ = event_tx.send(RnshEvent::ProcessOutput {
1129                    link_id,
1130                    stream_id,
1131                    data: buf[..n as usize].to_vec(),
1132                });
1133            } else {
1134                break;
1135            }
1136        }
1137    })
1138}
1139
1140fn spawn_waiter(
1141    link_id: [u8; 16],
1142    pid: libc::pid_t,
1143    reader_handles: Vec<std::thread::JoinHandle<()>>,
1144    event_tx: mpsc::Sender<RnshEvent>,
1145) {
1146    std::thread::spawn(move || {
1147        let mut status = 0;
1148        let _ = unsafe { libc::waitpid(pid, &mut status, 0) };
1149        let code = if libc::WIFEXITED(status) {
1150            libc::WEXITSTATUS(status)
1151        } else if libc::WIFSIGNALED(status) {
1152            128 + libc::WTERMSIG(status)
1153        } else {
1154            255
1155        };
1156        for handle in reader_handles {
1157            let _ = handle.join();
1158        }
1159        let _ = event_tx.send(RnshEvent::ProcessExited { link_id, code });
1160    });
1161}
1162
1163fn write_all_fd(fd: RawFd, mut data: &[u8]) -> io::Result<()> {
1164    while !data.is_empty() {
1165        let n = unsafe { libc::write(fd, data.as_ptr().cast(), data.len()) };
1166        if n < 0 {
1167            return Err(io::Error::last_os_error());
1168        }
1169        data = &data[n as usize..];
1170    }
1171    Ok(())
1172}
1173
1174struct TtyRestorer {
1175    fd: RawFd,
1176    original: Option<libc::termios>,
1177}
1178
1179impl TtyRestorer {
1180    fn new(fd: RawFd) -> Self {
1181        let mut original = unsafe { std::mem::zeroed() };
1182        let original = if unsafe { libc::tcgetattr(fd, &mut original) } == 0 {
1183            Some(original)
1184        } else {
1185            None
1186        };
1187        TtyRestorer { fd, original }
1188    }
1189
1190    fn raw(&self) {
1191        let Some(mut raw) = self.original else {
1192            return;
1193        };
1194        unsafe {
1195            libc::cfmakeraw(&mut raw);
1196            libc::tcsetattr(self.fd, libc::TCSANOW, &raw);
1197        }
1198    }
1199}
1200
1201impl Drop for TtyRestorer {
1202    fn drop(&mut self) {
1203        if let Some(original) = self.original {
1204            unsafe {
1205                libc::tcsetattr(self.fd, libc::TCSADRAIN, &original);
1206            }
1207        }
1208    }
1209}
1210
1211fn current_winsize(fd: RawFd) -> WindowSize {
1212    let mut ws = libc::winsize {
1213        ws_row: 0,
1214        ws_col: 0,
1215        ws_xpixel: 0,
1216        ws_ypixel: 0,
1217    };
1218    if unsafe { libc::ioctl(fd, libc::TIOCGWINSZ, &mut ws) } == 0 {
1219        WindowSize {
1220            rows: nonzero_u16(ws.ws_row),
1221            cols: nonzero_u16(ws.ws_col),
1222            hpix: nonzero_u16(ws.ws_xpixel),
1223            vpix: nonzero_u16(ws.ws_ypixel),
1224        }
1225    } else {
1226        WindowSize {
1227            rows: None,
1228            cols: None,
1229            hpix: None,
1230            vpix: None,
1231        }
1232    }
1233}
1234
/// Maps `0` to `None` and every other value to `Some(value)`.
fn nonzero_u16(value: u16) -> Option<u16> {
    if value == 0 {
        None
    } else {
        Some(value)
    }
}
1238
/// Listener policy shared by every session; each session holds its own clone.
#[derive(Clone)]
struct ListenerConfig {
    /// Command line used when the initiator supplies none, and the prefix of
    /// the command line when `remote_command_as_args` is set.
    default_command: Vec<String>,
    /// Accept any initiator without checking its identity.
    allow_all: bool,
    /// Identity hashes that are accepted when `allow_all` is off.
    allowed: HashSet<[u8; 16]>,
    /// Whether initiators may supply a command line at all.
    allow_remote_command: bool,
    /// Append the initiator's command line to `default_command` instead of
    /// replacing it.
    remote_command_as_args: bool,
}
1247
/// Protocol phase of a listener-side session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ListenerState {
    /// Waiting for the initiator to identify itself; messages are ignored.
    WaitIdent,
    /// Waiting for the initiator's version info message.
    WaitVersion,
    /// Version exchange done; waiting for the execute-command message.
    WaitCommand,
    /// The command was started; stdin, window-size and no-op messages are served.
    Running,
    /// The session was ended by an error; nothing further is processed.
    Closed,
}
1256
1257struct ListenerSession {
1258    link_id: [u8; 16],
1259    state: ListenerState,
1260    remote_identity: Option<IdentityHash>,
1261    config: ListenerConfig,
1262    process: Option<ChildProcess>,
1263}
1264
1265impl ListenerSession {
1266    fn new(link_id: [u8; 16], config: ListenerConfig) -> Self {
1267        let state = if config.allow_all {
1268            ListenerState::WaitVersion
1269        } else {
1270            ListenerState::WaitIdent
1271        };
1272        ListenerSession {
1273            link_id,
1274            state,
1275            remote_identity: None,
1276            config,
1277            process: None,
1278        }
1279    }
1280
1281    fn remote_identified(&mut self, transport: &dyn RnshTransport, identity_hash: IdentityHash) {
1282        if !self.config.allow_all && !self.config.allowed.contains(&identity_hash.0) {
1283            let _ = send_message(
1284                transport,
1285                self.link_id,
1286                &RnshMessage::Error {
1287                    msg: "Identity is not allowed.".into(),
1288                    fatal: true,
1289                },
1290            );
1291            let _ = transport.teardown_rnsh_link(self.link_id);
1292            self.state = ListenerState::Closed;
1293            return;
1294        }
1295        self.remote_identity = Some(identity_hash);
1296        if self.state == ListenerState::WaitIdent {
1297            self.state = ListenerState::WaitVersion;
1298        }
1299    }
1300
1301    fn handle_message(
1302        &mut self,
1303        transport: &dyn RnshTransport,
1304        event_tx: &mpsc::Sender<RnshEvent>,
1305        msgtype: u16,
1306        payload: Vec<u8>,
1307    ) {
1308        if self.state == ListenerState::WaitIdent {
1309            return;
1310        }
1311        let message = match RnshMessage::unpack(msgtype, &payload) {
1312            Ok(message) => message,
1313            Err(err) => {
1314                self.protocol_error(transport, &err.to_string());
1315                return;
1316            }
1317        };
1318        match self.state {
1319            ListenerState::WaitVersion => match message {
1320                RnshMessage::VersionInfo {
1321                    protocol_version, ..
1322                } if protocol_version == PROTOCOL_VERSION => {
1323                    let _ = send_message(transport, self.link_id, &version_message());
1324                    self.state = ListenerState::WaitCommand;
1325                }
1326                RnshMessage::VersionInfo { .. } => {
1327                    self.protocol_error(transport, "Incompatible protocol");
1328                }
1329                _ => self.protocol_error(transport, "expected version info"),
1330            },
1331            ListenerState::WaitCommand => match message {
1332                RnshMessage::ExecuteCommand(command) => {
1333                    if let Err(err) = self.start_command(transport, event_tx, command) {
1334                        self.protocol_error(transport, &format!("Unable to start process: {err}"));
1335                    } else {
1336                        self.state = ListenerState::Running;
1337                    }
1338                }
1339                _ => self.protocol_error(transport, "expected execute command"),
1340            },
1341            ListenerState::Running => match message {
1342                RnshMessage::WindowSize(size) => {
1343                    if let Some(process) = &self.process {
1344                        process.set_winsize(&size);
1345                    }
1346                }
1347                RnshMessage::StreamData(data) if data.stream_id == STREAM_STDIN => {
1348                    if let Some(process) = &mut self.process {
1349                        if !data.data.is_empty() {
1350                            process.write_stdin(&data.data);
1351                        }
1352                        if data.eof {
1353                            process.close_stdin();
1354                        }
1355                    }
1356                }
1357                RnshMessage::Noop => {
1358                    let _ = send_message(transport, self.link_id, &RnshMessage::Noop);
1359                }
1360                _ => self.protocol_error(transport, "unexpected message while running"),
1361            },
1362            ListenerState::WaitIdent | ListenerState::Closed => {}
1363        }
1364    }
1365
1366    fn start_command(
1367        &mut self,
1368        transport: &dyn RnshTransport,
1369        event_tx: &mpsc::Sender<RnshEvent>,
1370        command: ExecuteCommand,
1371    ) -> Result<(), RnshError> {
1372        if !self.config.allow_remote_command && !command.cmdline.is_empty() {
1373            let _ = send_message(
1374                transport,
1375                self.link_id,
1376                &RnshMessage::Error {
1377                    msg: "Remote command line not allowed by listener".into(),
1378                    fatal: true,
1379                },
1380            );
1381            return Err(RnshError::Protocol(
1382                "remote command line not allowed by listener".into(),
1383            ));
1384        }
1385
1386        let mut argv = self.config.default_command.clone();
1387        if self.config.remote_command_as_args && !command.cmdline.is_empty() {
1388            argv.extend(command.cmdline.clone());
1389        } else if !command.cmdline.is_empty() {
1390            argv = command.cmdline.clone();
1391        }
1392
1393        let remote_identity = self
1394            .remote_identity
1395            .as_ref()
1396            .map(|ih| prettyhexrep(&ih.0))
1397            .unwrap_or_default();
1398        let env = [
1399            (
1400                "TERM",
1401                command
1402                    .term
1403                    .clone()
1404                    .or_else(|| std::env::var("TERM").ok())
1405                    .unwrap_or_else(|| "xterm".into()),
1406            ),
1407            ("RNS_REMOTE_IDENTITY", remote_identity),
1408        ];
1409        let process = ChildProcess::spawn(self.link_id, &argv, &env, &command, event_tx.clone())?;
1410        process.set_winsize(&WindowSize {
1411            rows: command.rows,
1412            cols: command.cols,
1413            hpix: command.hpix,
1414            vpix: command.vpix,
1415        });
1416        self.process = Some(process);
1417        Ok(())
1418    }
1419
1420    fn protocol_error(&mut self, transport: &dyn RnshTransport, message: &str) {
1421        let _ = send_message(
1422            transport,
1423            self.link_id,
1424            &RnshMessage::Error {
1425                msg: message.into(),
1426                fatal: true,
1427            },
1428        );
1429        let _ = transport.teardown_rnsh_link(self.link_id);
1430        if let Some(process) = &mut self.process {
1431            process.terminate();
1432        }
1433        self.state = ListenerState::Closed;
1434    }
1435}
1436
1437fn listen(opts: CliOptions) -> Result<(), RnshError> {
1438    let (event_tx, event_rx) = mpsc::channel();
1439    let node = RnsNode::connect_shared_from_config(
1440        opts.config.as_deref().map(Path::new),
1441        Box::new(RnshCallbacks {
1442            tx: event_tx.clone(),
1443        }),
1444    )?;
1445
1446    let service = opts.service.as_deref().unwrap_or(DEFAULT_SERVICE_NAME);
1447    let identity = prepare_identity(
1448        opts.config.as_deref(),
1449        opts.identity.as_deref(),
1450        Some(service),
1451    )?;
1452    let identity_hash = IdentityHash(*identity.hash());
1453    let dest = Destination::single_in(APP_NAME, &[], identity_hash);
1454    let (sig_prv, sig_pub) = extract_sig_keys(&identity)?;
1455    node.register_destination_with_proof(
1456        &dest,
1457        Some(
1458            identity.get_private_key().ok_or_else(|| {
1459                RnshError::Protocol("listener identity has no private key".into())
1460            })?,
1461        ),
1462    )?;
1463    node.register_link_destination(dest.hash.0, sig_prv, sig_pub, 0)?;
1464
1465    eprintln!("rnsh listening on {}", prettyhexrep(&dest.hash.0));
1466
1467    let allowed = load_allowed_identities(&opts)?;
1468    if allowed.is_empty() && !opts.no_auth {
1469        eprintln!("warning: no allowed identities configured; no initiators will be accepted");
1470    }
1471
1472    let default_command = if opts.command.is_empty() {
1473        vec![std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".into())]
1474    } else {
1475        opts.command.clone()
1476    };
1477    let config = ListenerConfig {
1478        default_command,
1479        allow_all: opts.no_auth,
1480        allowed,
1481        allow_remote_command: !opts.no_remote_command,
1482        remote_command_as_args: opts.remote_command_as_args,
1483    };
1484
1485    let mut sessions: HashMap<[u8; 16], ListenerSession> = HashMap::new();
1486    let mut last_announce = Instant::now() - Duration::from_secs(24 * 60 * 60);
1487    let mut announced_once = false;
1488
1489    loop {
1490        if let Some(period) = opts.announce_period {
1491            let due = period == 0 && !announced_once
1492                || period > 0 && last_announce.elapsed() >= Duration::from_secs(period);
1493            if due {
1494                node.announce(&dest, &identity, None)?;
1495                last_announce = Instant::now();
1496                announced_once = true;
1497            }
1498        }
1499
1500        match event_rx.recv_timeout(Duration::from_millis(100)) {
1501            Ok(RnshEvent::LinkEstablished {
1502                link_id,
1503                is_initiator: false,
1504                ..
1505            }) => {
1506                sessions
1507                    .entry(link_id)
1508                    .or_insert_with(|| ListenerSession::new(link_id, config.clone()));
1509            }
1510            Ok(RnshEvent::RemoteIdentified {
1511                link_id,
1512                identity_hash,
1513            }) => {
1514                if let Some(session) = sessions.get_mut(&link_id) {
1515                    session.remote_identified(&node, identity_hash);
1516                }
1517            }
1518            Ok(RnshEvent::ChannelMessage {
1519                link_id,
1520                msgtype,
1521                payload,
1522            }) => {
1523                if let Some(session) = sessions.get_mut(&link_id) {
1524                    session.handle_message(&node, &event_tx, msgtype, payload);
1525                }
1526            }
1527            Ok(RnshEvent::ProcessOutput {
1528                link_id,
1529                stream_id,
1530                data,
1531            }) => {
1532                send_stream_chunks(&node, link_id, stream_id, &data, false)?;
1533            }
1534            Ok(RnshEvent::ProcessExited { link_id, code }) => {
1535                send_stream_chunks(&node, link_id, STREAM_STDOUT, &[], true)?;
1536                let _ = send_message(&node, link_id, &RnshMessage::CommandExited(code));
1537                sessions.remove(&link_id);
1538            }
1539            Ok(RnshEvent::LinkClosed(link_id)) => {
1540                if let Some(mut session) = sessions.remove(&link_id) {
1541                    if let Some(process) = &mut session.process {
1542                        process.terminate();
1543                    }
1544                }
1545            }
1546            Ok(_) | Err(mpsc::RecvTimeoutError::Timeout) => {}
1547            Err(mpsc::RecvTimeoutError::Disconnected) => break,
1548        }
1549    }
1550    Ok(())
1551}
1552
1553fn initiate(opts: CliOptions) -> Result<i32, RnshError> {
1554    let dest_hash = parse_hash_16(
1555        opts.destination
1556            .as_deref()
1557            .ok_or_else(|| RnshError::Protocol("missing destination".into()))?,
1558    )
1559    .ok_or_else(|| RnshError::Protocol("destination must be 32 hexadecimal characters".into()))?;
1560    let timeout = Duration::from_secs_f64(opts.timeout.unwrap_or(15.0));
1561    let (event_tx, event_rx) = mpsc::channel();
1562    let node = RnsNode::connect_shared_from_config(
1563        opts.config.as_deref().map(Path::new),
1564        Box::new(RnshCallbacks {
1565            tx: event_tx.clone(),
1566        }),
1567    )?;
1568    let identity = prepare_identity(opts.config.as_deref(), opts.identity.as_deref(), None)?;
1569
1570    wait_for_path(&node, dest_hash, &event_rx, timeout)?;
1571    let recalled = node
1572        .recall_identity(&DestHash(dest_hash))?
1573        .ok_or_else(|| RnshError::Protocol("destination identity was not recalled".into()))?;
1574    let mut sig_pub = [0u8; 32];
1575    sig_pub.copy_from_slice(&recalled.public_key[32..64]);
1576
1577    let link_id = node.create_link(dest_hash, sig_pub)?;
1578    wait_for_link(&event_rx, link_id, timeout)?;
1579    if !opts.no_id {
1580        node.identify_on_link(
1581            link_id,
1582            identity
1583                .get_private_key()
1584                .ok_or_else(|| RnshError::Protocol("identity has no private key".into()))?,
1585        )?;
1586    }
1587
1588    send_message(&node, link_id, &version_message())?;
1589    wait_for_version(&event_rx, timeout)?;
1590
1591    let stdin_is_tty = io::stdin().is_terminal();
1592    let stdout_is_tty = io::stdout().is_terminal();
1593    let stderr_is_tty = io::stderr().is_terminal();
1594    let size = current_winsize(0);
1595    let execute = ExecuteCommand {
1596        cmdline: opts.command.clone(),
1597        pipe_stdin: !stdin_is_tty,
1598        pipe_stdout: !stdout_is_tty,
1599        pipe_stderr: !stderr_is_tty,
1600        term: std::env::var("TERM").ok(),
1601        rows: size.rows,
1602        cols: size.cols,
1603        hpix: size.hpix,
1604        vpix: size.vpix,
1605    };
1606    send_message(&node, link_id, &RnshMessage::ExecuteCommand(execute))?;
1607
1608    let tty = stdin_is_tty.then(|| {
1609        let restorer = TtyRestorer::new(0);
1610        restorer.raw();
1611        unsafe {
1612            libc::signal(libc::SIGWINCH, sigwinch_handler as *const () as usize);
1613        }
1614        restorer
1615    });
1616    let _keep_tty = tty;
1617    spawn_stdin_reader(event_tx);
1618
1619    loop {
1620        if SIGWINCH_SEEN.swap(false, Ordering::SeqCst) {
1621            let _ = send_message(&node, link_id, &RnshMessage::WindowSize(current_winsize(0)));
1622        }
1623        match event_rx.recv_timeout(Duration::from_millis(100)) {
1624            Ok(RnshEvent::ChannelMessage {
1625                msgtype, payload, ..
1626            }) => match RnshMessage::unpack(msgtype, &payload)? {
1627                RnshMessage::StreamData(data) if data.stream_id == STREAM_STDOUT => {
1628                    io::stdout().write_all(&data.data)?;
1629                    io::stdout().flush()?;
1630                }
1631                RnshMessage::StreamData(data) if data.stream_id == STREAM_STDERR => {
1632                    io::stderr().write_all(&data.data)?;
1633                    io::stderr().flush()?;
1634                }
1635                RnshMessage::CommandExited(code) => return Ok(code),
1636                RnshMessage::Error { msg, fatal } => {
1637                    eprintln!("remote error: {msg}");
1638                    if fatal {
1639                        return Ok(200);
1640                    }
1641                }
1642                _ => {}
1643            },
1644            Ok(RnshEvent::LocalStdin(data)) => {
1645                send_stream_chunks(&node, link_id, STREAM_STDIN, &data, false)?;
1646            }
1647            Ok(RnshEvent::LocalStdinEof) => {
1648                send_stream_chunks(&node, link_id, STREAM_STDIN, &[], true)?;
1649            }
1650            Ok(RnshEvent::LinkClosed(_)) => return Ok(0),
1651            Ok(_) | Err(mpsc::RecvTimeoutError::Timeout) => {}
1652            Err(mpsc::RecvTimeoutError::Disconnected) => return Ok(0),
1653        }
1654    }
1655}
1656
1657fn send_message(
1658    transport: &dyn RnshTransport,
1659    link_id: [u8; 16],
1660    message: &RnshMessage,
1661) -> Result<(), RnshError> {
1662    transport.send_rnsh_message(link_id, message)
1663}
1664
1665fn send_stream_chunks(
1666    transport: &dyn RnshTransport,
1667    link_id: [u8; 16],
1668    stream_id: u16,
1669    data: &[u8],
1670    eof: bool,
1671) -> Result<(), RnshError> {
1672    for chunk in data.chunks(STREAM_CHUNK_MAX) {
1673        let msg = RnshMessage::StreamData(StreamDataMessage::new(
1674            stream_id,
1675            chunk.to_vec(),
1676            false,
1677            false,
1678        ));
1679        send_message(transport, link_id, &msg)?;
1680    }
1681    if eof {
1682        let msg =
1683            RnshMessage::StreamData(StreamDataMessage::new(stream_id, Vec::new(), true, false));
1684        send_message(transport, link_id, &msg)?;
1685    }
1686    Ok(())
1687}
1688
1689fn version_message() -> RnshMessage {
1690    RnshMessage::VersionInfo {
1691        sw_version: VERSION.into(),
1692        protocol_version: PROTOCOL_VERSION,
1693    }
1694}
1695
1696fn wait_for_path(
1697    node: &RnsNode,
1698    dest_hash: [u8; 16],
1699    event_rx: &mpsc::Receiver<RnshEvent>,
1700    timeout: Duration,
1701) -> Result<(), RnshError> {
1702    let started = Instant::now();
1703    if !node.has_path(&DestHash(dest_hash))? {
1704        node.request_path(&DestHash(dest_hash))?;
1705    }
1706    while started.elapsed() < timeout {
1707        if node.has_path(&DestHash(dest_hash))? {
1708            return Ok(());
1709        }
1710        if let Ok(RnshEvent::Announce(announced)) =
1711            event_rx.recv_timeout(Duration::from_millis(250))
1712        {
1713            if announced.dest_hash.0 == dest_hash {
1714                return Ok(());
1715            }
1716        }
1717    }
1718    Err(RnshError::Protocol("path not found".into()))
1719}
1720
1721fn wait_for_link(
1722    event_rx: &mpsc::Receiver<RnshEvent>,
1723    expected_link: [u8; 16],
1724    timeout: Duration,
1725) -> Result<(), RnshError> {
1726    let started = Instant::now();
1727    while started.elapsed() < timeout {
1728        match event_rx.recv_timeout(Duration::from_millis(100)) {
1729            Ok(RnshEvent::LinkEstablished {
1730                link_id,
1731                is_initiator: true,
1732                ..
1733            }) if link_id == expected_link => return Ok(()),
1734            Ok(_) | Err(mpsc::RecvTimeoutError::Timeout) => {}
1735            Err(mpsc::RecvTimeoutError::Disconnected) => break,
1736        }
1737    }
1738    Err(RnshError::Protocol("link establishment timed out".into()))
1739}
1740
1741fn wait_for_version(
1742    event_rx: &mpsc::Receiver<RnshEvent>,
1743    timeout: Duration,
1744) -> Result<(), RnshError> {
1745    let started = Instant::now();
1746    while started.elapsed() < timeout {
1747        match event_rx.recv_timeout(Duration::from_millis(100)) {
1748            Ok(RnshEvent::ChannelMessage {
1749                msgtype, payload, ..
1750            }) => match RnshMessage::unpack(msgtype, &payload)? {
1751                RnshMessage::VersionInfo {
1752                    protocol_version, ..
1753                } if protocol_version == PROTOCOL_VERSION => return Ok(()),
1754                RnshMessage::Error { msg, .. } => return Err(RnshError::Protocol(msg)),
1755                _ => {}
1756            },
1757            Ok(_) | Err(mpsc::RecvTimeoutError::Timeout) => {}
1758            Err(mpsc::RecvTimeoutError::Disconnected) => break,
1759        }
1760    }
1761    Err(RnshError::Protocol(
1762        "protocol version exchange timed out".into(),
1763    ))
1764}
1765
1766fn spawn_stdin_reader(event_tx: mpsc::Sender<RnshEvent>) {
1767    std::thread::spawn(move || {
1768        let mut stdin = io::stdin();
1769        let mut buf = [0u8; 4096];
1770        loop {
1771            match stdin.read(&mut buf) {
1772                Ok(0) => {
1773                    let _ = event_tx.send(RnshEvent::LocalStdinEof);
1774                    break;
1775                }
1776                Ok(n) => {
1777                    let _ = event_tx.send(RnshEvent::LocalStdin(buf[..n].to_vec()));
1778                }
1779                Err(_) => {
1780                    let _ = event_tx.send(RnshEvent::LocalStdinEof);
1781                    break;
1782                }
1783            }
1784        }
1785    });
1786}
1787
1788fn prepare_identity(
1789    config: Option<&str>,
1790    explicit_path: Option<&str>,
1791    service: Option<&str>,
1792) -> Result<Identity, RnshError> {
1793    let path = if let Some(path) = explicit_path {
1794        PathBuf::from(path)
1795    } else {
1796        let config_dir = rns_net::storage::resolve_config_dir(config.map(Path::new));
1797        let paths = rns_net::storage::ensure_storage_dirs(&config_dir)?;
1798        let suffix = service.map(sanitize_service_name).unwrap_or_default();
1799        let filename = if suffix.is_empty() {
1800            APP_NAME.to_string()
1801        } else {
1802            format!("{APP_NAME}.{suffix}")
1803        };
1804        paths.identities.join(filename)
1805    };
1806    if let Some(parent) = path.parent() {
1807        fs::create_dir_all(parent)?;
1808    }
1809    if path.exists() {
1810        Ok(rns_net::storage::load_identity(&path)?)
1811    } else {
1812        let identity = Identity::new(&mut OsRng);
1813        rns_net::storage::save_identity(&identity, &path)?;
1814        Ok(identity)
1815    }
1816}
1817
1818fn print_identity(opts: &CliOptions) -> Result<(), RnshError> {
1819    let identity = prepare_identity(
1820        opts.config.as_deref(),
1821        opts.identity.as_deref(),
1822        opts.service.as_deref(),
1823    )?;
1824    println!("Identity     : {}", prettyhexrep(identity.hash()));
1825    if opts.base256 {
1826        println!("Identity b256: {}", prettyb256rep(identity.hash()));
1827    }
1828    if opts.listen {
1829        let dest = Destination::single_in(APP_NAME, &[], IdentityHash(*identity.hash()));
1830        println!("Listening on : {}", prettyhexrep(&dest.hash.0));
1831        if opts.base256 {
1832            println!("Listen b256  : {}", prettyb256rep(&dest.hash.0));
1833        }
1834    }
1835    Ok(())
1836}
1837
1838fn load_allowed_identities(opts: &CliOptions) -> Result<HashSet<[u8; 16]>, RnshError> {
1839    let mut allowed = HashSet::new();
1840    for entry in &opts.allowed {
1841        if let Some(hash) = parse_hash_16(entry) {
1842            allowed.insert(hash);
1843        } else {
1844            return Err(RnshError::Protocol(format!(
1845                "invalid allowed identity hash: {entry}"
1846            )));
1847        }
1848    }
1849    for path in allowed_identity_files() {
1850        if !path.exists() {
1851            continue;
1852        }
1853        let contents = fs::read_to_string(path)?;
1854        for line in contents
1855            .lines()
1856            .map(str::trim)
1857            .filter(|line| !line.is_empty())
1858        {
1859            if let Some(hash) = parse_hash_16(line) {
1860                allowed.insert(hash);
1861            }
1862        }
1863    }
1864    Ok(allowed)
1865}
1866
/// Lists the candidate allowed-identities files, in lookup order:
/// `$HOME/.config/rnsh/allowed_identities`, then
/// `$HOME/.rnsh/allowed_identities`.
///
/// When `HOME` cannot be read as a string, the current directory (`.`) is
/// used as the base instead.
fn allowed_identity_files() -> Vec<PathBuf> {
    let base = match std::env::var("HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => PathBuf::from("."),
    };
    let config_dir_file = base
        .join(".config")
        .join("rnsh")
        .join("allowed_identities");
    let dot_dir_file = base.join(".rnsh").join("allowed_identities");
    vec![config_dir_file, dot_dir_file]
}
1877
/// Returns `value` with every character that is not an ASCII letter or digit
/// removed.
fn sanitize_service_name(value: &str) -> String {
    let mut cleaned = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch.is_ascii_alphanumeric() {
            cleaned.push(ch);
        }
    }
    cleaned
}
1884
/// Parses a 16-byte hash written as exactly 32 hexadecimal digits.
///
/// Surrounding whitespace is ignored and both upper- and lower-case digits
/// are accepted. Returns `None` for any other input: wrong length, sign
/// characters, or non-ASCII text.
///
/// Decoding works on bytes rather than string slices, so input containing
/// multi-byte UTF-8 characters is rejected instead of panicking on a
/// misaligned slice boundary.
fn parse_hash_16(value: &str) -> Option<[u8; 16]> {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn hex_nibble(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let digits = value.trim().as_bytes();
    if digits.len() != 32 {
        return None;
    }
    let mut out = [0u8; 16];
    for (slot, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
        *slot = (hex_nibble(pair[0])? << 4) | hex_nibble(pair[1])?;
    }
    Some(out)
}
1896
1897fn extract_sig_keys(identity: &Identity) -> Result<([u8; 32], [u8; 32]), RnshError> {
1898    let private = identity
1899        .get_private_key()
1900        .ok_or_else(|| RnshError::Protocol("identity has no private key".into()))?;
1901    let public = identity
1902        .get_public_key()
1903        .ok_or_else(|| RnshError::Protocol("identity has no public key".into()))?;
1904    let mut sig_prv = [0u8; 32];
1905    let mut sig_pub = [0u8; 32];
1906    sig_prv.copy_from_slice(&private[32..64]);
1907    sig_pub.copy_from_slice(&public[32..64]);
1908    Ok((sig_prv, sig_pub))
1909}
1910
1911#[cfg(test)]
1912mod tests {
1913    use super::*;
1914    use std::sync::Mutex;
1915
1916    const TEST_LINK: [u8; 16] = [0x42; 16];
1917
    /// Test double for `RnshTransport` that records calls in memory instead
    /// of touching a real link.
    #[derive(Default)]
    struct FakeTransport {
        /// One `(link_id, msgtype, packed payload)` entry per
        /// `send_rnsh_message` call, in call order.
        sent: Mutex<Vec<([u8; 16], u16, Vec<u8>)>>,
        /// Link ids passed to `teardown_rnsh_link`, in call order.
        teardowns: Mutex<Vec<[u8; 16]>>,
    }
1923
1924    impl FakeTransport {
1925        fn sent_messages(&self) -> Vec<([u8; 16], RnshMessage)> {
1926            self.sent
1927                .lock()
1928                .unwrap()
1929                .iter()
1930                .map(|(link_id, msgtype, payload)| {
1931                    (
1932                        *link_id,
1933                        RnshMessage::unpack(*msgtype, payload)
1934                            .expect("fake transport stored decodable message"),
1935                    )
1936                })
1937                .collect()
1938        }
1939    }
1940
1941    impl RnshTransport for FakeTransport {
1942        fn send_rnsh_message(
1943            &self,
1944            link_id: [u8; 16],
1945            message: &RnshMessage,
1946        ) -> Result<(), RnshError> {
1947            self.sent
1948                .lock()
1949                .unwrap()
1950                .push((link_id, message.msgtype(), message.pack()));
1951            Ok(())
1952        }
1953
1954        fn teardown_rnsh_link(&self, link_id: [u8; 16]) -> Result<(), RnshError> {
1955            self.teardowns.lock().unwrap().push(link_id);
1956            Ok(())
1957        }
1958    }
1959
1960    fn test_config() -> ListenerConfig {
1961        ListenerConfig {
1962            default_command: vec!["/bin/cat".into()],
1963            allow_all: true,
1964            allowed: HashSet::new(),
1965            allow_remote_command: true,
1966            remote_command_as_args: false,
1967        }
1968    }
1969
1970    fn exec_msg(cmdline: Vec<&str>) -> RnshMessage {
1971        RnshMessage::ExecuteCommand(ExecuteCommand {
1972            cmdline: cmdline.into_iter().map(str::to_string).collect(),
1973            pipe_stdin: true,
1974            pipe_stdout: true,
1975            pipe_stderr: true,
1976            term: Some("xterm".into()),
1977            rows: Some(24),
1978            cols: Some(80),
1979            hpix: None,
1980            vpix: None,
1981        })
1982    }
1983
1984    #[test]
1985    fn msgpack_version_matches_upstream_shape() {
1986        let msg = RnshMessage::VersionInfo {
1987            sw_version: "1.2.0".into(),
1988            protocol_version: 1,
1989        };
1990        let packed = msg.pack();
1991        assert_eq!(packed, b"\x92\xa51.2.0\x01");
1992        assert_eq!(RnshMessage::unpack(MSG_VERSION_INFO, &packed).unwrap(), msg);
1993    }
1994
1995    #[test]
1996    fn execute_command_roundtrips() {
1997        let msg = RnshMessage::ExecuteCommand(ExecuteCommand {
1998            cmdline: vec!["/bin/sh".into(), "-lc".into(), "echo hi".into()],
1999            pipe_stdin: true,
2000            pipe_stdout: true,
2001            pipe_stderr: false,
2002            term: Some("xterm-256color".into()),
2003            rows: Some(24),
2004            cols: Some(80),
2005            hpix: None,
2006            vpix: None,
2007        });
2008        let packed = msg.pack();
2009        assert_eq!(
2010            RnshMessage::unpack(MSG_EXECUTE_COMMAND, &packed).unwrap(),
2011            msg
2012        );
2013    }
2014
2015    #[test]
2016    fn stream_data_uses_upstream_header_bits() {
2017        let msg = RnshMessage::StreamData(StreamDataMessage::new(2, b"err".to_vec(), true, false));
2018        let packed = msg.pack();
2019        assert_eq!(&packed[..2], &0x8002u16.to_be_bytes());
2020        assert_eq!(RnshMessage::unpack(MSG_STREAM_DATA, &packed).unwrap(), msg);
2021    }
2022
2023    #[test]
2024    fn cli_splits_command_after_double_dash() {
2025        let args = CliOptions::parse(vec![
2026            "-l".into(),
2027            "-s".into(),
2028            "ops".into(),
2029            "--".into(),
2030            "/bin/sh".into(),
2031            "-l".into(),
2032        ])
2033        .unwrap();
2034        assert!(args.listen);
2035        assert_eq!(args.service.as_deref(), Some("ops"));
2036        assert_eq!(args.command, vec!["/bin/sh", "-l"]);
2037    }
2038
2039    #[test]
2040    fn cli_parses_base256_display_flag() {
2041        let short = CliOptions::parse(vec!["-Zp".into()]).unwrap();
2042        assert!(short.base256);
2043        assert!(short.print_identity);
2044
2045        let long = CliOptions::parse(vec!["--base256".into(), "--print-identity".into()]).unwrap();
2046        assert!(long.base256);
2047        assert!(long.print_identity);
2048    }
2049
2050    #[test]
2051    fn service_name_is_sanitized_like_upstream() {
2052        assert_eq!(sanitize_service_name("dev-shell_1!"), "devshell1");
2053    }
2054
2055    #[test]
2056    fn rnsh_logging_uses_file_oriented_levels() {
2057        assert_eq!(rnsh_log_level(false, 0, 0), log::LevelFilter::Error);
2058        assert_eq!(rnsh_log_level(true, 0, 0), log::LevelFilter::Info);
2059        assert_eq!(rnsh_log_level(true, 2, 0), log::LevelFilter::Trace);
2060        assert_eq!(rnsh_log_level(true, 0, 4), log::LevelFilter::Off);
2061    }
2062
2063    #[test]
2064    fn listener_rejects_unallowed_identity_and_tears_down() {
2065        let mut allowed = HashSet::new();
2066        allowed.insert([0x11; 16]);
2067        let config = ListenerConfig {
2068            allow_all: false,
2069            allowed,
2070            ..test_config()
2071        };
2072        let fake = FakeTransport::default();
2073        let mut session = ListenerSession::new(TEST_LINK, config);
2074
2075        assert_eq!(session.state, ListenerState::WaitIdent);
2076        session.remote_identified(&fake, IdentityHash([0x22; 16]));
2077
2078        assert_eq!(session.state, ListenerState::Closed);
2079        assert_eq!(fake.teardowns.lock().unwrap().as_slice(), &[TEST_LINK]);
2080        let messages = fake.sent_messages();
2081        assert_eq!(messages.len(), 1);
2082        assert!(matches!(
2083            &messages[0].1,
2084            RnshMessage::Error { msg, fatal: true } if msg == "Identity is not allowed."
2085        ));
2086    }
2087
2088    #[test]
2089    fn listener_accepts_allowed_identity_and_completes_version_handshake() {
2090        let mut allowed = HashSet::new();
2091        allowed.insert([0x11; 16]);
2092        let config = ListenerConfig {
2093            allow_all: false,
2094            allowed,
2095            ..test_config()
2096        };
2097        let fake = FakeTransport::default();
2098        let (tx, _rx) = mpsc::channel();
2099        let mut session = ListenerSession::new(TEST_LINK, config);
2100
2101        session.remote_identified(&fake, IdentityHash([0x11; 16]));
2102        assert_eq!(session.state, ListenerState::WaitVersion);
2103        let version = version_message();
2104        session.handle_message(&fake, &tx, version.msgtype(), version.pack());
2105
2106        assert_eq!(session.state, ListenerState::WaitCommand);
2107        let messages = fake.sent_messages();
2108        assert_eq!(messages.len(), 1);
2109        assert!(matches!(
2110            &messages[0].1,
2111            RnshMessage::VersionInfo {
2112                protocol_version: PROTOCOL_VERSION,
2113                ..
2114            }
2115        ));
2116    }
2117
2118    #[test]
2119    fn listener_rejects_incompatible_protocol_version() {
2120        let fake = FakeTransport::default();
2121        let (tx, _rx) = mpsc::channel();
2122        let mut session = ListenerSession::new(TEST_LINK, test_config());
2123        let msg = RnshMessage::VersionInfo {
2124            sw_version: "future".into(),
2125            protocol_version: PROTOCOL_VERSION + 1,
2126        };
2127
2128        session.handle_message(&fake, &tx, msg.msgtype(), msg.pack());
2129
2130        assert_eq!(session.state, ListenerState::Closed);
2131        assert_eq!(fake.teardowns.lock().unwrap().as_slice(), &[TEST_LINK]);
2132        assert!(matches!(
2133            &fake.sent_messages()[0].1,
2134            RnshMessage::Error { msg, fatal: true } if msg == "Incompatible protocol"
2135        ));
2136    }
2137
2138    #[test]
2139    fn listener_rejects_remote_command_when_disabled() {
2140        let fake = FakeTransport::default();
2141        let (tx, _rx) = mpsc::channel();
2142        let config = ListenerConfig {
2143            allow_remote_command: false,
2144            ..test_config()
2145        };
2146        let mut session = ListenerSession::new(TEST_LINK, config);
2147        let version = version_message();
2148        session.handle_message(&fake, &tx, version.msgtype(), version.pack());
2149        let exec = exec_msg(vec!["/bin/echo", "nope"]);
2150
2151        session.handle_message(&fake, &tx, exec.msgtype(), exec.pack());
2152
2153        assert_eq!(session.state, ListenerState::Closed);
2154        assert_eq!(fake.teardowns.lock().unwrap().as_slice(), &[TEST_LINK]);
2155        assert!(fake.sent_messages().iter().any(|(_, msg)| matches!(
2156            msg,
2157            RnshMessage::Error { msg, fatal: true }
2158                if msg.contains("Remote command line not allowed")
2159        )));
2160    }
2161
2162    #[test]
2163    fn listener_executes_default_command_and_forwards_stdin_to_process() {
2164        let fake = FakeTransport::default();
2165        let (tx, rx) = mpsc::channel();
2166        let mut session = ListenerSession::new(TEST_LINK, test_config());
2167        let version = version_message();
2168        session.handle_message(&fake, &tx, version.msgtype(), version.pack());
2169        let exec = exec_msg(Vec::new());
2170        session.handle_message(&fake, &tx, exec.msgtype(), exec.pack());
2171        assert_eq!(session.state, ListenerState::Running);
2172
2173        let stdin = RnshMessage::StreamData(StreamDataMessage::new(
2174            STREAM_STDIN,
2175            b"hello over stdin".to_vec(),
2176            true,
2177            false,
2178        ));
2179        session.handle_message(&fake, &tx, stdin.msgtype(), stdin.pack());
2180
2181        let started = Instant::now();
2182        let mut stdout = Vec::new();
2183        let mut exit = None;
2184        while started.elapsed() < Duration::from_secs(5) && exit.is_none() {
2185            match rx.recv_timeout(Duration::from_millis(100)).unwrap() {
2186                RnshEvent::ProcessOutput {
2187                    stream_id: STREAM_STDOUT,
2188                    data,
2189                    ..
2190                } => stdout.extend(data),
2191                RnshEvent::ProcessExited { code, .. } => exit = Some(code),
2192                _ => {}
2193            }
2194        }
2195
2196        assert_eq!(stdout, b"hello over stdin");
2197        assert_eq!(exit, Some(0));
2198    }
2199
2200    #[test]
2201    fn send_stream_chunks_splits_large_payload_and_appends_eof() {
2202        let fake = FakeTransport::default();
2203        let data = vec![0x55; STREAM_CHUNK_MAX * 2 + 3];
2204
2205        send_stream_chunks(&fake, TEST_LINK, STREAM_STDOUT, &data, true).unwrap();
2206
2207        let messages = fake.sent_messages();
2208        assert_eq!(messages.len(), 4);
2209        let mut payload = Vec::new();
2210        for (_, message) in &messages[..3] {
2211            match message {
2212                RnshMessage::StreamData(stream) => {
2213                    assert_eq!(stream.stream_id, STREAM_STDOUT);
2214                    assert!(!stream.eof);
2215                    payload.extend_from_slice(&stream.data);
2216                }
2217                other => panic!("expected stream data, got {other:?}"),
2218            }
2219        }
2220        assert_eq!(payload, data);
2221        assert!(matches!(
2222            messages.last().unwrap().1,
2223            RnshMessage::StreamData(ref stream)
2224                if stream.stream_id == STREAM_STDOUT && stream.eof && stream.data.is_empty()
2225        ));
2226    }
2227
2228    #[test]
2229    fn process_pipe_mode_reports_stdout_stderr_and_exit() {
2230        let (tx, rx) = mpsc::channel();
2231        let link_id = [7u8; 16];
2232        let command = ExecuteCommand {
2233            cmdline: Vec::new(),
2234            pipe_stdin: true,
2235            pipe_stdout: true,
2236            pipe_stderr: true,
2237            term: None,
2238            rows: None,
2239            cols: None,
2240            hpix: None,
2241            vpix: None,
2242        };
2243        let _process = ChildProcess::spawn(
2244            link_id,
2245            &[
2246                "/bin/sh".into(),
2247                "-c".into(),
2248                "printf out; printf err >&2; exit 13".into(),
2249            ],
2250            &[],
2251            &command,
2252            tx,
2253        )
2254        .unwrap();
2255
2256        let started = Instant::now();
2257        let mut stdout = Vec::new();
2258        let mut stderr = Vec::new();
2259        let mut exit = None;
2260        while started.elapsed() < Duration::from_secs(5) && exit.is_none() {
2261            match rx.recv_timeout(Duration::from_millis(100)).unwrap() {
2262                RnshEvent::ProcessOutput {
2263                    stream_id: STREAM_STDOUT,
2264                    data,
2265                    ..
2266                } => stdout.extend(data),
2267                RnshEvent::ProcessOutput {
2268                    stream_id: STREAM_STDERR,
2269                    data,
2270                    ..
2271                } => stderr.extend(data),
2272                RnshEvent::ProcessExited { code, .. } => exit = Some(code),
2273                _ => {}
2274            }
2275        }
2276        assert_eq!(stdout, b"out");
2277        assert_eq!(stderr, b"err");
2278        assert_eq!(exit, Some(13));
2279    }
2280}