1use std::collections::{HashMap, HashSet};
7use std::ffi::CString;
8use std::fs;
9use std::io::{self, IsTerminal, Read, Write};
10use std::os::fd::RawFd;
11use std::path::{Path, PathBuf};
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::mpsc;
14use std::time::{Duration, Instant};
15
16use rns_core::buffer::StreamDataMessage;
17use rns_core::types::{DestHash, IdentityHash};
18use rns_crypto::identity::Identity;
19use rns_crypto::OsRng;
20use rns_net::compressor::Bzip2Compressor;
21use rns_net::destination::Destination;
22use rns_net::{Callbacks, RnsNode, SendError};
23
24use crate::format::{prettyb256rep, prettyhexrep};
25
// Application name. NOTE(review): presumably used as the RNS destination app
// name; the use site is outside this part of the file.
const APP_NAME: &str = "rnsh";
// Service name applied by `CliOptions::parse` when listening without `--service`.
const DEFAULT_SERVICE_NAME: &str = "default";
// Version string injected at compile time through the `FULL_VERSION` env var.
const VERSION: &str = env!("FULL_VERSION");

// High byte shared by every rnsh channel message type.
const MSG_MAGIC: u16 = 0xac;
// Protocol revision reported by `--version`.
const PROTOCOL_VERSION: u64 = 1;

// Channel message types: `MSG_MAGIC` in the high byte, message id in the low
// byte. No constant is defined here for id 1.
const MSG_NOOP: u16 = (MSG_MAGIC << 8) | 0;
const MSG_WINDOW_SIZE: u16 = (MSG_MAGIC << 8) | 2;
const MSG_EXECUTE_COMMAND: u16 = (MSG_MAGIC << 8) | 3;
const MSG_STREAM_DATA: u16 = (MSG_MAGIC << 8) | 4;
const MSG_VERSION_INFO: u16 = (MSG_MAGIC << 8) | 5;
const MSG_ERROR: u16 = (MSG_MAGIC << 8) | 6;
const MSG_COMMAND_EXITED: u16 = (MSG_MAGIC << 8) | 7;

// Stream identifiers carried in stream-data messages.
const STREAM_STDIN: u16 = 0;
const STREAM_STDOUT: u16 = 1;
const STREAM_STDERR: u16 = 2;

// Largest payload that fits in one channel envelope on a link.
const CHANNEL_PAYLOAD_MAX: usize =
    rns_core::constants::LINK_MDU - rns_core::constants::CHANNEL_ENVELOPE_OVERHEAD;
// Payload budget minus two bytes.
// NOTE(review): presumably the 2-byte stream-data header; confirm against
// `StreamDataMessage`.
const STREAM_CHUNK_MAX: usize = CHANNEL_PAYLOAD_MAX - 2;
// Upper bound handed to `StreamDataMessage::unpack_bounded` when decoding
// incoming stream data.
const MAX_DECOMPRESSED_STREAM_CHUNK: usize = 64 * 1024;
49
50static SIGWINCH_SEEN: AtomicBool = AtomicBool::new(false);
51
52extern "C" fn sigwinch_handler(_: libc::c_int) {
53 SIGWINCH_SEEN.store(true, Ordering::SeqCst);
54}
55
56#[derive(Debug)]
57enum RnshError {
58 Io(io::Error),
59 Protocol(String),
60 Send,
61}
62
63impl From<io::Error> for RnshError {
64 fn from(value: io::Error) -> Self {
65 RnshError::Io(value)
66 }
67}
68
69impl From<SendError> for RnshError {
70 fn from(_: SendError) -> Self {
71 RnshError::Send
72 }
73}
74
75impl std::fmt::Display for RnshError {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 RnshError::Io(err) => write!(f, "{err}"),
79 RnshError::Protocol(err) => write!(f, "{err}"),
80 RnshError::Send => write!(f, "RNS send failed"),
81 }
82 }
83}
84
85pub fn main() -> i32 {
86 match CliOptions::parse(std::env::args().skip(1).collect()) {
87 Ok(opts) => {
88 if opts.help {
89 print_usage();
90 return 0;
91 }
92 if opts.version {
93 println!("rnsh {} (protocol {})", VERSION, PROTOCOL_VERSION);
94 return 0;
95 }
96 if let Err(err) = init_rnsh_logging(&opts) {
97 eprintln!("{err}");
98 return 1;
99 }
100 if opts.print_identity {
101 return match print_identity(&opts) {
102 Ok(()) => 0,
103 Err(err) => {
104 eprintln!("{err}");
105 1
106 }
107 };
108 }
109 let result = if opts.listen {
110 listen(opts).map(|_| 0)
111 } else if opts.destination.is_some() {
112 let mirror = opts.mirror_exit;
113 initiate(opts).map(|code| if mirror { code } else { 0 })
114 } else {
115 print_usage();
116 Ok(1)
117 };
118 match result {
119 Ok(code) => code,
120 Err(err) => {
121 eprintln!("{err}");
122 1
123 }
124 }
125 }
126 Err(err) => {
127 eprintln!("{err}");
128 print_usage();
129 1
130 }
131 }
132}
133
134#[derive(Debug, Clone, Default)]
135struct CliOptions {
136 config: Option<String>,
137 identity: Option<String>,
138 verbose: u8,
139 quiet: u8,
140 print_identity: bool,
141 base256: bool,
142 version: bool,
143 help: bool,
144 listen: bool,
145 service: Option<String>,
146 announce_period: Option<u64>,
147 allowed: Vec<String>,
148 no_auth: bool,
149 remote_command_as_args: bool,
150 no_remote_command: bool,
151 no_id: bool,
152 mirror_exit: bool,
153 timeout: Option<f64>,
154 destination: Option<String>,
155 command: Vec<String>,
156}
157
158impl CliOptions {
159 fn parse(argv: Vec<String>) -> Result<Self, String> {
160 let mut opts = CliOptions::default();
161 let (rnsh_argv, command) = match argv.iter().position(|arg| arg == "--") {
162 Some(idx) => (argv[..idx].to_vec(), argv[idx + 1..].to_vec()),
163 None => (argv, Vec::new()),
164 };
165 opts.command = command;
166
167 let mut i = 0;
168 while i < rnsh_argv.len() {
169 let arg = &rnsh_argv[i];
170 if !arg.starts_with('-') || arg == "-" {
171 if opts.destination.is_some() {
172 return Err(format!("unexpected positional argument: {arg}"));
173 }
174 opts.destination = Some(arg.clone());
175 i += 1;
176 continue;
177 }
178
179 if let Some(name) = arg.strip_prefix("--") {
180 match name {
181 "config" | "identity" | "service" | "announce" | "allowed" | "timeout" => {
182 i += 1;
183 let value = rnsh_argv
184 .get(i)
185 .ok_or_else(|| format!("--{name} requires a value"))?
186 .clone();
187 match name {
188 "config" => opts.config = Some(value),
189 "identity" => opts.identity = Some(value),
190 "service" => opts.service = Some(value),
191 "announce" => {
192 opts.announce_period = Some(value.parse().map_err(|_| {
193 "--announce requires an integer period".to_string()
194 })?)
195 }
196 "allowed" => opts.allowed.push(value),
197 "timeout" => {
198 opts.timeout = Some(value.parse().map_err(|_| {
199 "--timeout requires a numeric value".to_string()
200 })?)
201 }
202 _ => {}
203 }
204 }
205 "verbose" => opts.verbose = opts.verbose.saturating_add(1),
206 "quiet" => opts.quiet = opts.quiet.saturating_add(1),
207 "print-identity" => opts.print_identity = true,
208 "base256" => opts.base256 = true,
209 "version" => opts.version = true,
210 "help" => opts.help = true,
211 "listen" => opts.listen = true,
212 "no-auth" => opts.no_auth = true,
213 "remote-command-as-args" => opts.remote_command_as_args = true,
214 "no-remote-command" => opts.no_remote_command = true,
215 "no-id" => opts.no_id = true,
216 "mirror" => opts.mirror_exit = true,
217 _ => return Err(format!("unknown option --{name}")),
218 }
219 i += 1;
220 continue;
221 }
222
223 let chars: Vec<char> = arg[1..].chars().collect();
224 let mut pos = 0;
225 while pos < chars.len() {
226 match chars[pos] {
227 'c' | 'i' | 's' | 'b' | 'a' | 'w' => {
228 let key = chars[pos];
229 let value = if pos + 1 < chars.len() {
230 chars[pos + 1..].iter().collect::<String>()
231 } else {
232 i += 1;
233 rnsh_argv
234 .get(i)
235 .ok_or_else(|| format!("-{key} requires a value"))?
236 .clone()
237 };
238 match key {
239 'c' => opts.config = Some(value),
240 'i' => opts.identity = Some(value),
241 's' => opts.service = Some(value),
242 'b' => {
243 opts.announce_period = Some(
244 value
245 .parse()
246 .map_err(|_| "-b requires an integer".to_string())?,
247 )
248 }
249 'a' => opts.allowed.push(value),
250 'w' => {
251 opts.timeout = Some(
252 value
253 .parse()
254 .map_err(|_| "-w requires a number".to_string())?,
255 )
256 }
257 _ => {}
258 }
259 break;
260 }
261 'v' => opts.verbose = opts.verbose.saturating_add(1),
262 'q' => opts.quiet = opts.quiet.saturating_add(1),
263 'p' => opts.print_identity = true,
264 'Z' => opts.base256 = true,
265 'l' => opts.listen = true,
266 'n' => opts.no_auth = true,
267 'A' => opts.remote_command_as_args = true,
268 'C' => opts.no_remote_command = true,
269 'N' => opts.no_id = true,
270 'm' => opts.mirror_exit = true,
271 'h' => opts.help = true,
272 other => return Err(format!("unknown option -{other}")),
273 }
274 pos += 1;
275 }
276 i += 1;
277 }
278
279 if opts.listen && opts.service.is_none() {
280 opts.service = Some(DEFAULT_SERVICE_NAME.to_string());
281 }
282 Ok(opts)
283 }
284}
285
/// Writes the command synopsis and option summary to stderr.
fn print_usage() {
    // One literal per output line; `concat!` joins them at compile time.
    const USAGE: &str = concat!(
        "Usage:\n",
        " rnsh -l [options] [-- command...]\n",
        " rnsh [options] <destination> [-- command...]\n",
        "\n",
        "Options:\n",
        " -c, --config PATH Reticulum config directory\n",
        " -i, --identity PATH Identity file to use\n",
        " -p, --print-identity Print identity and destination info\n",
        " -Z, --base256 Also print compact base256 display for hashes\n",
        " -l, --listen Listen for remote shell links\n",
        " -s, --service NAME Listener identity service name\n",
        " -b, --announce PERIOD Announce on startup and every PERIOD seconds (0 = once)\n",
        " -a, --allowed HASH Allow initiator identity hash (repeatable)\n",
        " -n, --no-auth Allow any initiator identity\n",
        " -A, --remote-command-as-args\n",
        " -C, --no-remote-command\n",
        " -N, --no-id Do not identify to the listener\n",
        " -m, --mirror Return remote command exit code\n",
        " -w, --timeout SECONDS Path/link/protocol timeout",
    );
    eprintln!("{}", USAGE);
}
291
292fn init_rnsh_logging(opts: &CliOptions) -> Result<(), RnshError> {
293 let dir = rnsh_config_dir()?;
294 let file = std::fs::OpenOptions::new()
295 .create(true)
296 .append(true)
297 .open(dir.join("logfile"))?;
298 let mut builder = env_logger::Builder::new();
299 builder
300 .filter_level(rnsh_log_level(opts.listen, opts.verbose, opts.quiet))
301 .format_timestamp_secs()
302 .target(env_logger::Target::Pipe(Box::new(file)));
303 builder
304 .try_init()
305 .map_err(|err| RnshError::Protocol(format!("failed to initialize rnsh logging: {err}")))
306}
307
308fn rnsh_config_dir() -> Result<PathBuf, RnshError> {
309 let home = std::env::var("HOME").unwrap_or_else(|_| ".".into());
310 let xdg = PathBuf::from(&home).join(".config").join("rnsh");
311 if xdg.is_dir() {
312 return Ok(xdg);
313 }
314 let legacy = PathBuf::from(home).join(".rnsh");
315 fs::create_dir_all(&legacy)?;
316 Ok(legacy)
317}
318
319fn rnsh_log_level(listen: bool, verbose: u8, quiet: u8) -> log::LevelFilter {
320 let base: i16 = if listen { 3 } else { 1 };
321 match (base + verbose as i16 - quiet as i16).clamp(0, 5) {
322 0 => log::LevelFilter::Off,
323 1 => log::LevelFilter::Error,
324 2 => log::LevelFilter::Warn,
325 3 => log::LevelFilter::Info,
326 4 => log::LevelFilter::Debug,
327 _ => log::LevelFilter::Trace,
328 }
329}
330
/// The subset of MessagePack values used by the rnsh wire messages.
#[derive(Debug, Clone, PartialEq)]
enum MsgValue {
    Nil,
    Bool(bool),
    /// Any integer; always held as `i64`.
    Int(i64),
    /// UTF-8 text (MessagePack `str` family).
    String(String),
    /// Raw bytes (MessagePack `bin` family).
    Bytes(Vec<u8>),
    Array(Vec<MsgValue>),
    /// Key/value pairs in wire order; keys are not deduplicated.
    Map(Vec<(MsgValue, MsgValue)>),
}

/// Appends the MessagePack encoding of `value` to `out`.
///
/// Integers use the fixint forms where they fit and otherwise the smallest
/// *signed* format (int8/16/32/64). Containers use the smallest header that
/// can hold their length.
fn msgpack_pack(value: &MsgValue, out: &mut Vec<u8>) {
    match value {
        MsgValue::Nil => out.push(0xc0),
        MsgValue::Bool(false) => out.push(0xc2),
        MsgValue::Bool(true) => out.push(0xc3),
        MsgValue::Int(v) if *v >= 0 && *v <= 0x7f => out.push(*v as u8),
        MsgValue::Int(v) if *v >= -32 && *v < 0 => out.push((*v as i8) as u8),
        MsgValue::Int(v) if *v >= i8::MIN as i64 && *v <= i8::MAX as i64 => {
            out.extend_from_slice(&[0xd0, *v as i8 as u8]);
        }
        MsgValue::Int(v) if *v >= i16::MIN as i64 && *v <= i16::MAX as i64 => {
            out.push(0xd1);
            out.extend_from_slice(&(*v as i16).to_be_bytes());
        }
        MsgValue::Int(v) if *v >= i32::MIN as i64 && *v <= i32::MAX as i64 => {
            out.push(0xd2);
            out.extend_from_slice(&(*v as i32).to_be_bytes());
        }
        MsgValue::Int(v) => {
            out.push(0xd3);
            out.extend_from_slice(&v.to_be_bytes());
        }
        MsgValue::String(s) => pack_msgpack_str(s.as_bytes(), out, true),
        MsgValue::Bytes(bytes) => pack_msgpack_str(bytes, out, false),
        MsgValue::Array(items) => {
            pack_container_header(items.len(), 0x90, 0xdc, 0xdd, out);
            for item in items {
                msgpack_pack(item, out);
            }
        }
        MsgValue::Map(items) => {
            // Maps above 65535 entries need the map32 header (0xdf); writing
            // them as map16 would truncate the count and corrupt the stream.
            pack_container_header(items.len(), 0x80, 0xde, 0xdf, out);
            for (key, value) in items {
                msgpack_pack(key, out);
                msgpack_pack(value, out);
            }
        }
    }
}

/// Writes an array/map header: the fix form (`fix_tag | len`) below 16
/// entries, then the 16-bit form, then the 32-bit form.
fn pack_container_header(len: usize, fix_tag: u8, tag16: u8, tag32: u8, out: &mut Vec<u8>) {
    if len < 16 {
        out.push(fix_tag | len as u8);
    } else if len <= u16::MAX as usize {
        out.push(tag16);
        out.extend_from_slice(&(len as u16).to_be_bytes());
    } else {
        out.push(tag32);
        out.extend_from_slice(&(len as u32).to_be_bytes());
    }
}

/// Appends a length-prefixed byte run to `out`.
///
/// With `utf8 == true` the `str` family is used (fixstr / str8 / str16 /
/// str32); otherwise the `bin` family (bin8 / bin16 / bin32).
fn pack_msgpack_str(bytes: &[u8], out: &mut Vec<u8>, utf8: bool) {
    if utf8 {
        if bytes.len() < 32 {
            out.push(0xa0 | bytes.len() as u8);
        } else if bytes.len() <= u8::MAX as usize {
            out.extend_from_slice(&[0xd9, bytes.len() as u8]);
        } else if bytes.len() <= u16::MAX as usize {
            out.push(0xda);
            out.extend_from_slice(&(bytes.len() as u16).to_be_bytes());
        } else {
            out.push(0xdb);
            out.extend_from_slice(&(bytes.len() as u32).to_be_bytes());
        }
    } else if bytes.len() <= u8::MAX as usize {
        out.extend_from_slice(&[0xc4, bytes.len() as u8]);
    } else if bytes.len() <= u16::MAX as usize {
        out.push(0xc5);
        out.extend_from_slice(&(bytes.len() as u16).to_be_bytes());
    } else {
        out.push(0xc6);
        out.extend_from_slice(&(bytes.len() as u32).to_be_bytes());
    }
    out.extend_from_slice(bytes);
}
419
420fn msgpack_unpack(raw: &[u8]) -> Result<MsgValue, RnshError> {
421 let (value, consumed) = unpack_at(raw, 0)?;
422 if consumed != raw.len() {
423 return Err(RnshError::Protocol("trailing msgpack data".into()));
424 }
425 Ok(value)
426}
427
428fn unpack_at(raw: &[u8], mut pos: usize) -> Result<(MsgValue, usize), RnshError> {
429 let tag = *raw
430 .get(pos)
431 .ok_or_else(|| RnshError::Protocol("truncated msgpack".into()))?;
432 pos += 1;
433 match tag {
434 0x00..=0x7f => Ok((MsgValue::Int(tag as i64), pos)),
435 0x80..=0x8f => unpack_map(raw, pos, (tag & 0x0f) as usize),
436 0x90..=0x9f => unpack_array(raw, pos, (tag & 0x0f) as usize),
437 0xa0..=0xbf => unpack_string(raw, pos, (tag & 0x1f) as usize),
438 0xc0 => Ok((MsgValue::Nil, pos)),
439 0xc2 => Ok((MsgValue::Bool(false), pos)),
440 0xc3 => Ok((MsgValue::Bool(true), pos)),
441 0xc4 => {
442 let len = read_u8(raw, &mut pos)? as usize;
443 unpack_bytes(raw, pos, len)
444 }
445 0xc5 => {
446 let len = read_u16(raw, &mut pos)? as usize;
447 unpack_bytes(raw, pos, len)
448 }
449 0xc6 => {
450 let len = read_u32(raw, &mut pos)? as usize;
451 unpack_bytes(raw, pos, len)
452 }
453 0xcc => Ok((MsgValue::Int(read_u8(raw, &mut pos)? as i64), pos)),
454 0xcd => Ok((MsgValue::Int(read_u16(raw, &mut pos)? as i64), pos)),
455 0xce => Ok((MsgValue::Int(read_u32(raw, &mut pos)? as i64), pos)),
456 0xcf => Ok((MsgValue::Int(read_u64(raw, &mut pos)? as i64), pos)),
457 0xd0 => Ok((MsgValue::Int(read_u8(raw, &mut pos)? as i8 as i64), pos)),
458 0xd1 => Ok((MsgValue::Int(read_u16(raw, &mut pos)? as i16 as i64), pos)),
459 0xd2 => Ok((MsgValue::Int(read_u32(raw, &mut pos)? as i32 as i64), pos)),
460 0xd3 => Ok((MsgValue::Int(read_u64(raw, &mut pos)? as i64), pos)),
461 0xd9 => {
462 let len = read_u8(raw, &mut pos)? as usize;
463 unpack_string(raw, pos, len)
464 }
465 0xda => {
466 let len = read_u16(raw, &mut pos)? as usize;
467 unpack_string(raw, pos, len)
468 }
469 0xdb => {
470 let len = read_u32(raw, &mut pos)? as usize;
471 unpack_string(raw, pos, len)
472 }
473 0xdc => {
474 let len = read_u16(raw, &mut pos)? as usize;
475 unpack_array(raw, pos, len)
476 }
477 0xdd => {
478 let len = read_u32(raw, &mut pos)? as usize;
479 unpack_array(raw, pos, len)
480 }
481 0xde => {
482 let len = read_u16(raw, &mut pos)? as usize;
483 unpack_map(raw, pos, len)
484 }
485 0xdf => {
486 let len = read_u32(raw, &mut pos)? as usize;
487 unpack_map(raw, pos, len)
488 }
489 0xe0..=0xff => Ok((MsgValue::Int((tag as i8) as i64), pos)),
490 _ => Err(RnshError::Protocol(format!(
491 "unsupported msgpack tag 0x{tag:02x}"
492 ))),
493 }
494}
495
496fn unpack_array(raw: &[u8], mut pos: usize, len: usize) -> Result<(MsgValue, usize), RnshError> {
497 let mut values = Vec::with_capacity(len);
498 for _ in 0..len {
499 let (value, next) = unpack_at(raw, pos)?;
500 values.push(value);
501 pos = next;
502 }
503 Ok((MsgValue::Array(values), pos))
504}
505
506fn unpack_map(raw: &[u8], mut pos: usize, len: usize) -> Result<(MsgValue, usize), RnshError> {
507 let mut values = Vec::with_capacity(len);
508 for _ in 0..len {
509 let (key, next) = unpack_at(raw, pos)?;
510 let (value, next) = unpack_at(raw, next)?;
511 values.push((key, value));
512 pos = next;
513 }
514 Ok((MsgValue::Map(values), pos))
515}
516
517fn unpack_string(raw: &[u8], pos: usize, len: usize) -> Result<(MsgValue, usize), RnshError> {
518 let bytes = raw
519 .get(pos..pos + len)
520 .ok_or_else(|| RnshError::Protocol("truncated msgpack string".into()))?;
521 let s = std::str::from_utf8(bytes)
522 .map_err(|_| RnshError::Protocol("invalid msgpack utf8".into()))?;
523 Ok((MsgValue::String(s.to_string()), pos + len))
524}
525
526fn unpack_bytes(raw: &[u8], pos: usize, len: usize) -> Result<(MsgValue, usize), RnshError> {
527 let bytes = raw
528 .get(pos..pos + len)
529 .ok_or_else(|| RnshError::Protocol("truncated msgpack bytes".into()))?;
530 Ok((MsgValue::Bytes(bytes.to_vec()), pos + len))
531}
532
533fn read_u8(raw: &[u8], pos: &mut usize) -> Result<u8, RnshError> {
534 let v = *raw
535 .get(*pos)
536 .ok_or_else(|| RnshError::Protocol("truncated msgpack integer".into()))?;
537 *pos += 1;
538 Ok(v)
539}
540
541fn read_u16(raw: &[u8], pos: &mut usize) -> Result<u16, RnshError> {
542 let bytes = raw
543 .get(*pos..*pos + 2)
544 .ok_or_else(|| RnshError::Protocol("truncated msgpack integer".into()))?;
545 *pos += 2;
546 Ok(u16::from_be_bytes([bytes[0], bytes[1]]))
547}
548
549fn read_u32(raw: &[u8], pos: &mut usize) -> Result<u32, RnshError> {
550 let bytes = raw
551 .get(*pos..*pos + 4)
552 .ok_or_else(|| RnshError::Protocol("truncated msgpack integer".into()))?;
553 *pos += 4;
554 Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
555}
556
557fn read_u64(raw: &[u8], pos: &mut usize) -> Result<u64, RnshError> {
558 let bytes = raw
559 .get(*pos..*pos + 8)
560 .ok_or_else(|| RnshError::Protocol("truncated msgpack integer".into()))?;
561 *pos += 8;
562 Ok(u64::from_be_bytes([
563 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
564 ]))
565}
566
/// Terminal dimensions carried by a window-size message.
///
/// Each field is optional; `ChildProcess::set_winsize` substitutes 0 for a
/// missing value.
#[derive(Debug, Clone, PartialEq)]
struct WindowSize {
    // Written to `winsize.ws_row`.
    rows: Option<u16>,
    // Written to `winsize.ws_col`.
    cols: Option<u16>,
    // Written to `winsize.ws_xpixel`.
    hpix: Option<u16>,
    // Written to `winsize.ws_ypixel`.
    vpix: Option<u16>,
}
574
/// Payload of an execute-command message.
///
/// On the wire this is a 10-element array; element 4 is always written as nil
/// and is ignored when decoding.
#[derive(Debug, Clone, PartialEq)]
struct ExecuteCommand {
    // Command and arguments, one string per argument.
    cmdline: Vec<String>,
    // For each stream: `true` selects a pipe. `ChildProcess::spawn` allocates
    // a pty unless all three flags are set, and attaches every non-piped
    // stream to it.
    pipe_stdin: bool,
    pipe_stdout: bool,
    pipe_stderr: bool,
    // NOTE(review): presumably the initiator's TERM value; the consumer is not
    // in this part of the file.
    term: Option<String>,
    // Optional terminal dimensions, same meaning as in `WindowSize`.
    rows: Option<u16>,
    cols: Option<u16>,
    hpix: Option<u16>,
    vpix: Option<u16>,
}
587
/// One rnsh channel message. Each variant maps to one `MSG_*` type code (see
/// `RnshMessage::msgtype`).
#[derive(Debug, Clone, PartialEq)]
enum RnshMessage {
    // Empty payload.
    Noop,
    // Terminal dimensions.
    WindowSize(WindowSize),
    // Request to run a command.
    ExecuteCommand(ExecuteCommand),
    // Stream payload; encoding is delegated to `StreamDataMessage`.
    StreamData(StreamDataMessage),
    // Software version string plus protocol version.
    VersionInfo {
        sw_version: String,
        protocol_version: u64,
    },
    // Error text plus a flag marking it fatal. The wire form has a third
    // element that is always written as nil and ignored on decode.
    Error {
        msg: String,
        fatal: bool,
    },
    // Exit code of the command.
    CommandExited(i32),
}
604
605impl RnshMessage {
606 fn msgtype(&self) -> u16 {
607 match self {
608 RnshMessage::Noop => MSG_NOOP,
609 RnshMessage::WindowSize(_) => MSG_WINDOW_SIZE,
610 RnshMessage::ExecuteCommand(_) => MSG_EXECUTE_COMMAND,
611 RnshMessage::StreamData(_) => MSG_STREAM_DATA,
612 RnshMessage::VersionInfo { .. } => MSG_VERSION_INFO,
613 RnshMessage::Error { .. } => MSG_ERROR,
614 RnshMessage::CommandExited(_) => MSG_COMMAND_EXITED,
615 }
616 }
617
618 fn pack(&self) -> Vec<u8> {
619 match self {
620 RnshMessage::Noop => Vec::new(),
621 RnshMessage::StreamData(msg) => msg.pack(),
622 RnshMessage::VersionInfo {
623 sw_version,
624 protocol_version,
625 } => pack_msgpack_array(vec![
626 MsgValue::String(sw_version.clone()),
627 MsgValue::Int(*protocol_version as i64),
628 ]),
629 RnshMessage::WindowSize(size) => pack_msgpack_array(vec![
630 opt_u16(size.rows),
631 opt_u16(size.cols),
632 opt_u16(size.hpix),
633 opt_u16(size.vpix),
634 ]),
635 RnshMessage::ExecuteCommand(cmd) => pack_msgpack_array(vec![
636 MsgValue::Array(
637 cmd.cmdline
638 .iter()
639 .map(|s| MsgValue::String(s.clone()))
640 .collect(),
641 ),
642 MsgValue::Bool(cmd.pipe_stdin),
643 MsgValue::Bool(cmd.pipe_stdout),
644 MsgValue::Bool(cmd.pipe_stderr),
645 MsgValue::Nil,
646 cmd.term
647 .as_ref()
648 .map(|s| MsgValue::String(s.clone()))
649 .unwrap_or(MsgValue::Nil),
650 opt_u16(cmd.rows),
651 opt_u16(cmd.cols),
652 opt_u16(cmd.hpix),
653 opt_u16(cmd.vpix),
654 ]),
655 RnshMessage::Error { msg, fatal } => pack_msgpack_array(vec![
656 MsgValue::String(msg.clone()),
657 MsgValue::Bool(*fatal),
658 MsgValue::Nil,
659 ]),
660 RnshMessage::CommandExited(code) => {
661 let mut out = Vec::new();
662 msgpack_pack(&MsgValue::Int(*code as i64), &mut out);
663 out
664 }
665 }
666 }
667
668 fn unpack(msgtype: u16, payload: &[u8]) -> Result<Self, RnshError> {
669 match msgtype {
670 MSG_NOOP => Ok(RnshMessage::Noop),
671 MSG_STREAM_DATA => Ok(RnshMessage::StreamData(
672 StreamDataMessage::unpack_bounded(
673 payload,
674 &Bzip2Compressor,
675 MAX_DECOMPRESSED_STREAM_CHUNK,
676 )
677 .map_err(|_| RnshError::Protocol("invalid stream data message".into()))?,
678 )),
679 MSG_VERSION_INFO => {
680 let values = expect_array(msgpack_unpack(payload)?, 2)?;
681 Ok(RnshMessage::VersionInfo {
682 sw_version: expect_string(&values[0])?,
683 protocol_version: expect_int(&values[1])? as u64,
684 })
685 }
686 MSG_WINDOW_SIZE => {
687 let values = expect_array(msgpack_unpack(payload)?, 4)?;
688 Ok(RnshMessage::WindowSize(WindowSize {
689 rows: opt_int_u16(&values[0])?,
690 cols: opt_int_u16(&values[1])?,
691 hpix: opt_int_u16(&values[2])?,
692 vpix: opt_int_u16(&values[3])?,
693 }))
694 }
695 MSG_EXECUTE_COMMAND => {
696 let values = expect_array(msgpack_unpack(payload)?, 10)?;
697 let cmdline = expect_array_value(&values[0])?
698 .iter()
699 .map(expect_string)
700 .collect::<Result<Vec<_>, _>>()?;
701 Ok(RnshMessage::ExecuteCommand(ExecuteCommand {
702 cmdline,
703 pipe_stdin: expect_bool(&values[1])?,
704 pipe_stdout: expect_bool(&values[2])?,
705 pipe_stderr: expect_bool(&values[3])?,
706 term: opt_string(&values[5])?,
707 rows: opt_int_u16(&values[6])?,
708 cols: opt_int_u16(&values[7])?,
709 hpix: opt_int_u16(&values[8])?,
710 vpix: opt_int_u16(&values[9])?,
711 }))
712 }
713 MSG_ERROR => {
714 let values = expect_array(msgpack_unpack(payload)?, 3)?;
715 Ok(RnshMessage::Error {
716 msg: expect_string(&values[0])?,
717 fatal: expect_bool(&values[1])?,
718 })
719 }
720 MSG_COMMAND_EXITED => Ok(RnshMessage::CommandExited(expect_int(&msgpack_unpack(
721 payload,
722 )?)? as i32)),
723 _ => Err(RnshError::Protocol(format!(
724 "unknown rnsh message type 0x{msgtype:04x}"
725 ))),
726 }
727 }
728}
729
730fn pack_msgpack_array(values: Vec<MsgValue>) -> Vec<u8> {
731 let mut out = Vec::new();
732 msgpack_pack(&MsgValue::Array(values), &mut out);
733 out
734}
735
736fn opt_u16(value: Option<u16>) -> MsgValue {
737 value
738 .map(|v| MsgValue::Int(v as i64))
739 .unwrap_or(MsgValue::Nil)
740}
741
742fn expect_array(value: MsgValue, len: usize) -> Result<Vec<MsgValue>, RnshError> {
743 match value {
744 MsgValue::Array(values) if values.len() == len => Ok(values),
745 _ => Err(RnshError::Protocol("unexpected msgpack array".into())),
746 }
747}
748
749fn expect_array_value(value: &MsgValue) -> Result<&[MsgValue], RnshError> {
750 match value {
751 MsgValue::Array(values) => Ok(values),
752 _ => Err(RnshError::Protocol("expected msgpack array".into())),
753 }
754}
755
756fn expect_string(value: &MsgValue) -> Result<String, RnshError> {
757 match value {
758 MsgValue::String(s) => Ok(s.clone()),
759 _ => Err(RnshError::Protocol("expected msgpack string".into())),
760 }
761}
762
763fn opt_string(value: &MsgValue) -> Result<Option<String>, RnshError> {
764 match value {
765 MsgValue::Nil => Ok(None),
766 MsgValue::String(s) => Ok(Some(s.clone())),
767 _ => Err(RnshError::Protocol(
768 "expected optional msgpack string".into(),
769 )),
770 }
771}
772
773fn expect_bool(value: &MsgValue) -> Result<bool, RnshError> {
774 match value {
775 MsgValue::Bool(v) => Ok(*v),
776 _ => Err(RnshError::Protocol("expected msgpack bool".into())),
777 }
778}
779
780fn expect_int(value: &MsgValue) -> Result<i64, RnshError> {
781 match value {
782 MsgValue::Int(v) => Ok(*v),
783 _ => Err(RnshError::Protocol("expected msgpack int".into())),
784 }
785}
786
787fn opt_int_u16(value: &MsgValue) -> Result<Option<u16>, RnshError> {
788 match value {
789 MsgValue::Nil => Ok(None),
790 MsgValue::Int(v) if *v >= 0 && *v <= u16::MAX as i64 => Ok(Some(*v as u16)),
791 _ => Err(RnshError::Protocol("expected optional u16".into())),
792 }
793}
794
/// Events funnelled through a single `mpsc` channel. They are sent by the RNS
/// callbacks (`RnshCallbacks`) and by the child-process helper threads.
#[derive(Debug)]
enum RnshEvent {
    // Sent from `Callbacks::on_announce`.
    Announce(rns_net::AnnouncedIdentity),
    // Sent from `Callbacks::on_link_established`.
    LinkEstablished {
        link_id: [u8; 16],
        is_initiator: bool,
    },
    // Sent from `Callbacks::on_link_closed`; the teardown reason is dropped.
    LinkClosed([u8; 16]),
    // Sent from `Callbacks::on_remote_identified`; the public key is dropped.
    RemoteIdentified {
        link_id: [u8; 16],
        identity_hash: IdentityHash,
    },
    // Sent from `Callbacks::on_channel_message`; `payload` is still encoded.
    ChannelMessage {
        link_id: [u8; 16],
        msgtype: u16,
        payload: Vec<u8>,
    },
    // Sent by `spawn_reader` for each chunk read from a child's output fd.
    ProcessOutput {
        link_id: [u8; 16],
        stream_id: u16,
        data: Vec<u8>,
    },
    // Sent by `spawn_waiter` after the child was reaped and all reader
    // threads were joined.
    ProcessExited {
        link_id: [u8; 16],
        code: i32,
    },
    // NOTE(review): presumably produced by a local stdin reader on the
    // initiator side; the producer is outside this part of the file.
    LocalStdin(Vec<u8>),
    LocalStdinEof,
}
824
825struct RnshCallbacks {
826 tx: mpsc::Sender<RnshEvent>,
827}
828
829impl Callbacks for RnshCallbacks {
830 fn on_announce(&mut self, announced: rns_net::AnnouncedIdentity) {
831 let _ = self.tx.send(RnshEvent::Announce(announced));
832 }
833
834 fn on_path_updated(&mut self, _dest_hash: DestHash, _hops: u8) {}
835
836 fn on_local_delivery(
837 &mut self,
838 _dest_hash: DestHash,
839 _raw: Vec<u8>,
840 _packet_hash: rns_net::PacketHash,
841 ) {
842 }
843
844 fn on_link_established(
845 &mut self,
846 link_id: rns_net::LinkId,
847 _dest_hash: DestHash,
848 _rtt: f64,
849 is_initiator: bool,
850 ) {
851 let _ = self.tx.send(RnshEvent::LinkEstablished {
852 link_id: link_id.0,
853 is_initiator,
854 });
855 }
856
857 fn on_link_closed(
858 &mut self,
859 link_id: rns_net::LinkId,
860 _reason: Option<rns_net::TeardownReason>,
861 ) {
862 let _ = self.tx.send(RnshEvent::LinkClosed(link_id.0));
863 }
864
865 fn on_remote_identified(
866 &mut self,
867 link_id: rns_net::LinkId,
868 identity_hash: IdentityHash,
869 _public_key: [u8; 64],
870 ) {
871 let _ = self.tx.send(RnshEvent::RemoteIdentified {
872 link_id: link_id.0,
873 identity_hash,
874 });
875 }
876
877 fn on_channel_message(&mut self, link_id: rns_net::LinkId, msgtype: u16, payload: Vec<u8>) {
878 let _ = self.tx.send(RnshEvent::ChannelMessage {
879 link_id: link_id.0,
880 msgtype,
881 payload,
882 });
883 }
884}
885
/// The two link operations rnsh needs from the network node. `RnsNode`
/// implements this trait in this file.
trait RnshTransport {
    /// Sends `message` on the link identified by `link_id`.
    fn send_rnsh_message(&self, link_id: [u8; 16], message: &RnshMessage) -> Result<(), RnshError>;

    /// Tears down the link identified by `link_id`.
    fn teardown_rnsh_link(&self, link_id: [u8; 16]) -> Result<(), RnshError>;
}
891
892impl RnshTransport for RnsNode {
893 fn send_rnsh_message(&self, link_id: [u8; 16], message: &RnshMessage) -> Result<(), RnshError> {
894 self.send_channel_message(link_id, message.msgtype(), message.pack())?;
895 Ok(())
896 }
897
898 fn teardown_rnsh_link(&self, link_id: [u8; 16]) -> Result<(), RnshError> {
899 self.teardown_link(link_id)?;
900 Ok(())
901 }
902}
903
904struct ChildProcess {
905 pid: libc::pid_t,
906 stdin_fd: Option<RawFd>,
907 stdout_fd: Option<RawFd>,
908 stderr_fd: Option<RawFd>,
909}
910
911impl ChildProcess {
912 fn spawn(
913 link_id: [u8; 16],
914 argv: &[String],
915 env_overrides: &[(&str, String)],
916 flags: &ExecuteCommand,
917 event_tx: mpsc::Sender<RnshEvent>,
918 ) -> io::Result<Self> {
919 if argv.is_empty() {
920 return Err(io::Error::new(io::ErrorKind::InvalidInput, "empty command"));
921 }
922
923 let use_pty = !(flags.pipe_stdin && flags.pipe_stdout && flags.pipe_stderr);
924 let mut pty_master = None;
925 let mut pty_child = None;
926 if use_pty {
927 let mut master: libc::c_int = -1;
928 let mut child: libc::c_int = -1;
929 let rc = unsafe {
930 libc::openpty(
931 &mut master,
932 &mut child,
933 std::ptr::null_mut(),
934 std::ptr::null(),
935 std::ptr::null(),
936 )
937 };
938 if rc != 0 {
939 return Err(io::Error::last_os_error());
940 }
941 pty_master = Some(master);
942 pty_child = Some(child);
943 }
944
945 let stdin_pipe = if flags.pipe_stdin {
946 Some(pipe_pair()?)
947 } else {
948 None
949 };
950 let stdout_pipe = if flags.pipe_stdout {
951 Some(pipe_pair()?)
952 } else {
953 None
954 };
955 let stderr_pipe = if flags.pipe_stderr {
956 Some(pipe_pair()?)
957 } else {
958 None
959 };
960
961 let child_stdin = stdin_pipe.map(|p| p.0).or(pty_child).unwrap_or(-1);
962 let parent_stdin = stdin_pipe.map(|p| p.1).or(pty_master).unwrap_or(-1);
963 let parent_stdout = stdout_pipe.map(|p| p.0).or(pty_master).unwrap_or(-1);
964 let child_stdout = stdout_pipe.map(|p| p.1).or(pty_child).unwrap_or(-1);
965 let parent_stderr = stderr_pipe.map(|p| p.0).or(pty_master).unwrap_or(-1);
966 let child_stderr = stderr_pipe.map(|p| p.1).or(pty_child).unwrap_or(-1);
967
968 let pid = unsafe { libc::fork() };
969 if pid < 0 {
970 return Err(io::Error::last_os_error());
971 }
972 if pid == 0 {
973 unsafe {
974 if use_pty {
975 libc::setsid();
976 }
977 libc::dup2(child_stdin, 0);
978 libc::dup2(child_stdout, 1);
979 libc::dup2(child_stderr, 2);
980 if use_pty {
981 let tty_fd = if !flags.pipe_stdin {
982 0
983 } else if !flags.pipe_stdout {
984 1
985 } else {
986 2
987 };
988 libc::ioctl(tty_fd, libc::TIOCSCTTY, 0);
989 }
990 for fd in 3..1024 {
991 libc::close(fd);
992 }
993 for (key, value) in env_overrides {
994 if let (Ok(k), Ok(v)) = (CString::new(*key), CString::new(value.as_str())) {
995 libc::setenv(k.as_ptr(), v.as_ptr(), 1);
996 }
997 }
998 let c_args = argv
999 .iter()
1000 .map(|arg| CString::new(arg.as_str()))
1001 .collect::<Result<Vec<_>, _>>();
1002 if let Ok(c_args) = c_args {
1003 let mut ptrs = c_args.iter().map(|s| s.as_ptr()).collect::<Vec<_>>();
1004 ptrs.push(std::ptr::null());
1005 libc::execvp(ptrs[0], ptrs.as_ptr());
1006 }
1007 libc::_exit(255);
1008 }
1009 }
1010
1011 close_unique(&[
1012 pty_child,
1013 stdin_pipe.map(|p| p.0),
1014 stdout_pipe.map(|p| p.1),
1015 stderr_pipe.map(|p| p.1),
1016 ]);
1017
1018 let stdout_fd = if parent_stdout >= 0 {
1019 Some(parent_stdout)
1020 } else {
1021 None
1022 };
1023 let stderr_fd = if parent_stderr >= 0 && Some(parent_stderr) != stdout_fd {
1024 Some(parent_stderr)
1025 } else {
1026 None
1027 };
1028
1029 let mut reader_handles = Vec::new();
1030 if let Some(fd) = stdout_fd {
1031 reader_handles.push(spawn_reader(link_id, STREAM_STDOUT, fd, event_tx.clone()));
1032 }
1033 if let Some(fd) = stderr_fd {
1034 reader_handles.push(spawn_reader(link_id, STREAM_STDERR, fd, event_tx.clone()));
1035 }
1036 spawn_waiter(link_id, pid, reader_handles, event_tx);
1037
1038 Ok(ChildProcess {
1039 pid,
1040 stdin_fd: (parent_stdin >= 0).then_some(parent_stdin),
1041 stdout_fd,
1042 stderr_fd,
1043 })
1044 }
1045
1046 fn write_stdin(&self, data: &[u8]) {
1047 if let Some(fd) = self.stdin_fd {
1048 let _ = write_all_fd(fd, data);
1049 }
1050 }
1051
1052 fn close_stdin(&mut self) {
1053 if let Some(fd) = self.stdin_fd.take() {
1054 if Some(fd) == self.stdout_fd || Some(fd) == self.stderr_fd {
1055 let _ = write_all_fd(fd, b"\x04");
1056 self.stdin_fd = Some(fd);
1057 } else {
1058 unsafe {
1059 libc::close(fd);
1060 }
1061 }
1062 }
1063 }
1064
1065 fn set_winsize(&self, size: &WindowSize) {
1066 let Some(fd) = self.stdout_fd.or(self.stdin_fd) else {
1067 return;
1068 };
1069 let ws = libc::winsize {
1070 ws_row: size.rows.unwrap_or(0),
1071 ws_col: size.cols.unwrap_or(0),
1072 ws_xpixel: size.hpix.unwrap_or(0),
1073 ws_ypixel: size.vpix.unwrap_or(0),
1074 };
1075 unsafe {
1076 libc::ioctl(fd, libc::TIOCSWINSZ, &ws);
1077 }
1078 }
1079
1080 fn terminate(&mut self) {
1081 unsafe {
1082 libc::kill(self.pid, libc::SIGTERM);
1083 }
1084 self.close_stdin();
1085 }
1086}
1087
1088impl Drop for ChildProcess {
1089 fn drop(&mut self) {
1090 close_unique(&[
1091 self.stdin_fd.take(),
1092 self.stdout_fd.take(),
1093 self.stderr_fd.take(),
1094 ]);
1095 }
1096}
1097
1098fn pipe_pair() -> io::Result<(RawFd, RawFd)> {
1099 let mut fds = [-1; 2];
1100 if unsafe { libc::pipe(fds.as_mut_ptr()) } != 0 {
1101 return Err(io::Error::last_os_error());
1102 }
1103 Ok((fds[0], fds[1]))
1104}
1105
1106fn close_unique(fds: &[Option<RawFd>]) {
1107 let mut seen = HashSet::new();
1108 for fd in fds.iter().flatten().copied() {
1109 if fd >= 0 && seen.insert(fd) {
1110 unsafe {
1111 libc::close(fd);
1112 }
1113 }
1114 }
1115}
1116
1117fn spawn_reader(
1118 link_id: [u8; 16],
1119 stream_id: u16,
1120 fd: RawFd,
1121 event_tx: mpsc::Sender<RnshEvent>,
1122) -> std::thread::JoinHandle<()> {
1123 std::thread::spawn(move || {
1124 let mut buf = [0u8; 4096];
1125 loop {
1126 let n = unsafe { libc::read(fd, buf.as_mut_ptr().cast(), buf.len()) };
1127 if n > 0 {
1128 let _ = event_tx.send(RnshEvent::ProcessOutput {
1129 link_id,
1130 stream_id,
1131 data: buf[..n as usize].to_vec(),
1132 });
1133 } else {
1134 break;
1135 }
1136 }
1137 })
1138}
1139
1140fn spawn_waiter(
1141 link_id: [u8; 16],
1142 pid: libc::pid_t,
1143 reader_handles: Vec<std::thread::JoinHandle<()>>,
1144 event_tx: mpsc::Sender<RnshEvent>,
1145) {
1146 std::thread::spawn(move || {
1147 let mut status = 0;
1148 let _ = unsafe { libc::waitpid(pid, &mut status, 0) };
1149 let code = if libc::WIFEXITED(status) {
1150 libc::WEXITSTATUS(status)
1151 } else if libc::WIFSIGNALED(status) {
1152 128 + libc::WTERMSIG(status)
1153 } else {
1154 255
1155 };
1156 for handle in reader_handles {
1157 let _ = handle.join();
1158 }
1159 let _ = event_tx.send(RnshEvent::ProcessExited { link_id, code });
1160 });
1161}
1162
1163fn write_all_fd(fd: RawFd, mut data: &[u8]) -> io::Result<()> {
1164 while !data.is_empty() {
1165 let n = unsafe { libc::write(fd, data.as_ptr().cast(), data.len()) };
1166 if n < 0 {
1167 return Err(io::Error::last_os_error());
1168 }
1169 data = &data[n as usize..];
1170 }
1171 Ok(())
1172}
1173
/// Guard that remembers a terminal's attributes and restores them on drop.
struct TtyRestorer {
    /// Descriptor whose terminal attributes are saved and restored.
    fd: RawFd,
    /// Attributes captured by `tcgetattr` in `new`; `None` when that call failed,
    /// in which case `raw` and `drop` do nothing.
    original: Option<libc::termios>,
}
1178
1179impl TtyRestorer {
1180 fn new(fd: RawFd) -> Self {
1181 let mut original = unsafe { std::mem::zeroed() };
1182 let original = if unsafe { libc::tcgetattr(fd, &mut original) } == 0 {
1183 Some(original)
1184 } else {
1185 None
1186 };
1187 TtyRestorer { fd, original }
1188 }
1189
1190 fn raw(&self) {
1191 let Some(mut raw) = self.original else {
1192 return;
1193 };
1194 unsafe {
1195 libc::cfmakeraw(&mut raw);
1196 libc::tcsetattr(self.fd, libc::TCSANOW, &raw);
1197 }
1198 }
1199}
1200
1201impl Drop for TtyRestorer {
1202 fn drop(&mut self) {
1203 if let Some(original) = self.original {
1204 unsafe {
1205 libc::tcsetattr(self.fd, libc::TCSADRAIN, &original);
1206 }
1207 }
1208 }
1209}
1210
1211fn current_winsize(fd: RawFd) -> WindowSize {
1212 let mut ws = libc::winsize {
1213 ws_row: 0,
1214 ws_col: 0,
1215 ws_xpixel: 0,
1216 ws_ypixel: 0,
1217 };
1218 if unsafe { libc::ioctl(fd, libc::TIOCGWINSZ, &mut ws) } == 0 {
1219 WindowSize {
1220 rows: nonzero_u16(ws.ws_row),
1221 cols: nonzero_u16(ws.ws_col),
1222 hpix: nonzero_u16(ws.ws_xpixel),
1223 vpix: nonzero_u16(ws.ws_ypixel),
1224 }
1225 } else {
1226 WindowSize {
1227 rows: None,
1228 cols: None,
1229 hpix: None,
1230 vpix: None,
1231 }
1232 }
1233}
1234
/// Maps `0` to `None` and any other value to `Some(value)`.
fn nonzero_u16(value: u16) -> Option<u16> {
    match value {
        0 => None,
        present => Some(present),
    }
}
1238
/// Listener-wide policy that is cloned into every `ListenerSession`.
#[derive(Clone)]
struct ListenerConfig {
    /// Command line used when the initiator sends no command of its own, and
    /// the prefix of the final command when `remote_command_as_args` is set.
    default_command: Vec<String>,
    /// When true, any identity is accepted and sessions skip the `WaitIdent` state.
    allow_all: bool,
    /// Identity hashes permitted to connect when `allow_all` is false.
    allowed: HashSet<[u8; 16]>,
    /// When false, a non-empty command line from the initiator is rejected.
    allow_remote_command: bool,
    /// When true, the initiator's command line is appended to `default_command`
    /// instead of replacing it.
    remote_command_as_args: bool,
}
1247
/// Protocol phase of a single listener-side session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ListenerState {
    /// Initial state when authentication is required; channel messages are
    /// ignored until the remote identity has been accepted.
    WaitIdent,
    /// Waiting for the initiator's `VersionInfo` message.
    WaitVersion,
    /// Version exchange done; waiting for `ExecuteCommand`.
    WaitCommand,
    /// A child process has been started for this session.
    Running,
    /// The session was rejected or hit a protocol error.
    Closed,
}
1256
/// Per-link state kept by the listener.
struct ListenerSession {
    /// Identifier of the link this session is bound to.
    link_id: [u8; 16],
    /// Current protocol phase.
    state: ListenerState,
    /// Identity hash of the initiator once it has identified and been accepted.
    remote_identity: Option<IdentityHash>,
    /// Copy of the listener policy this session was created with.
    config: ListenerConfig,
    /// Child process started by `start_command`, if any.
    process: Option<ChildProcess>,
}
1264
1265impl ListenerSession {
1266 fn new(link_id: [u8; 16], config: ListenerConfig) -> Self {
1267 let state = if config.allow_all {
1268 ListenerState::WaitVersion
1269 } else {
1270 ListenerState::WaitIdent
1271 };
1272 ListenerSession {
1273 link_id,
1274 state,
1275 remote_identity: None,
1276 config,
1277 process: None,
1278 }
1279 }
1280
1281 fn remote_identified(&mut self, transport: &dyn RnshTransport, identity_hash: IdentityHash) {
1282 if !self.config.allow_all && !self.config.allowed.contains(&identity_hash.0) {
1283 let _ = send_message(
1284 transport,
1285 self.link_id,
1286 &RnshMessage::Error {
1287 msg: "Identity is not allowed.".into(),
1288 fatal: true,
1289 },
1290 );
1291 let _ = transport.teardown_rnsh_link(self.link_id);
1292 self.state = ListenerState::Closed;
1293 return;
1294 }
1295 self.remote_identity = Some(identity_hash);
1296 if self.state == ListenerState::WaitIdent {
1297 self.state = ListenerState::WaitVersion;
1298 }
1299 }
1300
1301 fn handle_message(
1302 &mut self,
1303 transport: &dyn RnshTransport,
1304 event_tx: &mpsc::Sender<RnshEvent>,
1305 msgtype: u16,
1306 payload: Vec<u8>,
1307 ) {
1308 if self.state == ListenerState::WaitIdent {
1309 return;
1310 }
1311 let message = match RnshMessage::unpack(msgtype, &payload) {
1312 Ok(message) => message,
1313 Err(err) => {
1314 self.protocol_error(transport, &err.to_string());
1315 return;
1316 }
1317 };
1318 match self.state {
1319 ListenerState::WaitVersion => match message {
1320 RnshMessage::VersionInfo {
1321 protocol_version, ..
1322 } if protocol_version == PROTOCOL_VERSION => {
1323 let _ = send_message(transport, self.link_id, &version_message());
1324 self.state = ListenerState::WaitCommand;
1325 }
1326 RnshMessage::VersionInfo { .. } => {
1327 self.protocol_error(transport, "Incompatible protocol");
1328 }
1329 _ => self.protocol_error(transport, "expected version info"),
1330 },
1331 ListenerState::WaitCommand => match message {
1332 RnshMessage::ExecuteCommand(command) => {
1333 if let Err(err) = self.start_command(transport, event_tx, command) {
1334 self.protocol_error(transport, &format!("Unable to start process: {err}"));
1335 } else {
1336 self.state = ListenerState::Running;
1337 }
1338 }
1339 _ => self.protocol_error(transport, "expected execute command"),
1340 },
1341 ListenerState::Running => match message {
1342 RnshMessage::WindowSize(size) => {
1343 if let Some(process) = &self.process {
1344 process.set_winsize(&size);
1345 }
1346 }
1347 RnshMessage::StreamData(data) if data.stream_id == STREAM_STDIN => {
1348 if let Some(process) = &mut self.process {
1349 if !data.data.is_empty() {
1350 process.write_stdin(&data.data);
1351 }
1352 if data.eof {
1353 process.close_stdin();
1354 }
1355 }
1356 }
1357 RnshMessage::Noop => {
1358 let _ = send_message(transport, self.link_id, &RnshMessage::Noop);
1359 }
1360 _ => self.protocol_error(transport, "unexpected message while running"),
1361 },
1362 ListenerState::WaitIdent | ListenerState::Closed => {}
1363 }
1364 }
1365
1366 fn start_command(
1367 &mut self,
1368 transport: &dyn RnshTransport,
1369 event_tx: &mpsc::Sender<RnshEvent>,
1370 command: ExecuteCommand,
1371 ) -> Result<(), RnshError> {
1372 if !self.config.allow_remote_command && !command.cmdline.is_empty() {
1373 let _ = send_message(
1374 transport,
1375 self.link_id,
1376 &RnshMessage::Error {
1377 msg: "Remote command line not allowed by listener".into(),
1378 fatal: true,
1379 },
1380 );
1381 return Err(RnshError::Protocol(
1382 "remote command line not allowed by listener".into(),
1383 ));
1384 }
1385
1386 let mut argv = self.config.default_command.clone();
1387 if self.config.remote_command_as_args && !command.cmdline.is_empty() {
1388 argv.extend(command.cmdline.clone());
1389 } else if !command.cmdline.is_empty() {
1390 argv = command.cmdline.clone();
1391 }
1392
1393 let remote_identity = self
1394 .remote_identity
1395 .as_ref()
1396 .map(|ih| prettyhexrep(&ih.0))
1397 .unwrap_or_default();
1398 let env = [
1399 (
1400 "TERM",
1401 command
1402 .term
1403 .clone()
1404 .or_else(|| std::env::var("TERM").ok())
1405 .unwrap_or_else(|| "xterm".into()),
1406 ),
1407 ("RNS_REMOTE_IDENTITY", remote_identity),
1408 ];
1409 let process = ChildProcess::spawn(self.link_id, &argv, &env, &command, event_tx.clone())?;
1410 process.set_winsize(&WindowSize {
1411 rows: command.rows,
1412 cols: command.cols,
1413 hpix: command.hpix,
1414 vpix: command.vpix,
1415 });
1416 self.process = Some(process);
1417 Ok(())
1418 }
1419
1420 fn protocol_error(&mut self, transport: &dyn RnshTransport, message: &str) {
1421 let _ = send_message(
1422 transport,
1423 self.link_id,
1424 &RnshMessage::Error {
1425 msg: message.into(),
1426 fatal: true,
1427 },
1428 );
1429 let _ = transport.teardown_rnsh_link(self.link_id);
1430 if let Some(process) = &mut self.process {
1431 process.terminate();
1432 }
1433 self.state = ListenerState::Closed;
1434 }
1435}
1436
1437fn listen(opts: CliOptions) -> Result<(), RnshError> {
1438 let (event_tx, event_rx) = mpsc::channel();
1439 let node = RnsNode::connect_shared_from_config(
1440 opts.config.as_deref().map(Path::new),
1441 Box::new(RnshCallbacks {
1442 tx: event_tx.clone(),
1443 }),
1444 )?;
1445
1446 let service = opts.service.as_deref().unwrap_or(DEFAULT_SERVICE_NAME);
1447 let identity = prepare_identity(
1448 opts.config.as_deref(),
1449 opts.identity.as_deref(),
1450 Some(service),
1451 )?;
1452 let identity_hash = IdentityHash(*identity.hash());
1453 let dest = Destination::single_in(APP_NAME, &[], identity_hash);
1454 let (sig_prv, sig_pub) = extract_sig_keys(&identity)?;
1455 node.register_destination_with_proof(
1456 &dest,
1457 Some(
1458 identity.get_private_key().ok_or_else(|| {
1459 RnshError::Protocol("listener identity has no private key".into())
1460 })?,
1461 ),
1462 )?;
1463 node.register_link_destination(dest.hash.0, sig_prv, sig_pub, 0)?;
1464
1465 eprintln!("rnsh listening on {}", prettyhexrep(&dest.hash.0));
1466
1467 let allowed = load_allowed_identities(&opts)?;
1468 if allowed.is_empty() && !opts.no_auth {
1469 eprintln!("warning: no allowed identities configured; no initiators will be accepted");
1470 }
1471
1472 let default_command = if opts.command.is_empty() {
1473 vec![std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".into())]
1474 } else {
1475 opts.command.clone()
1476 };
1477 let config = ListenerConfig {
1478 default_command,
1479 allow_all: opts.no_auth,
1480 allowed,
1481 allow_remote_command: !opts.no_remote_command,
1482 remote_command_as_args: opts.remote_command_as_args,
1483 };
1484
1485 let mut sessions: HashMap<[u8; 16], ListenerSession> = HashMap::new();
1486 let mut last_announce = Instant::now() - Duration::from_secs(24 * 60 * 60);
1487 let mut announced_once = false;
1488
1489 loop {
1490 if let Some(period) = opts.announce_period {
1491 let due = period == 0 && !announced_once
1492 || period > 0 && last_announce.elapsed() >= Duration::from_secs(period);
1493 if due {
1494 node.announce(&dest, &identity, None)?;
1495 last_announce = Instant::now();
1496 announced_once = true;
1497 }
1498 }
1499
1500 match event_rx.recv_timeout(Duration::from_millis(100)) {
1501 Ok(RnshEvent::LinkEstablished {
1502 link_id,
1503 is_initiator: false,
1504 ..
1505 }) => {
1506 sessions
1507 .entry(link_id)
1508 .or_insert_with(|| ListenerSession::new(link_id, config.clone()));
1509 }
1510 Ok(RnshEvent::RemoteIdentified {
1511 link_id,
1512 identity_hash,
1513 }) => {
1514 if let Some(session) = sessions.get_mut(&link_id) {
1515 session.remote_identified(&node, identity_hash);
1516 }
1517 }
1518 Ok(RnshEvent::ChannelMessage {
1519 link_id,
1520 msgtype,
1521 payload,
1522 }) => {
1523 if let Some(session) = sessions.get_mut(&link_id) {
1524 session.handle_message(&node, &event_tx, msgtype, payload);
1525 }
1526 }
1527 Ok(RnshEvent::ProcessOutput {
1528 link_id,
1529 stream_id,
1530 data,
1531 }) => {
1532 send_stream_chunks(&node, link_id, stream_id, &data, false)?;
1533 }
1534 Ok(RnshEvent::ProcessExited { link_id, code }) => {
1535 send_stream_chunks(&node, link_id, STREAM_STDOUT, &[], true)?;
1536 let _ = send_message(&node, link_id, &RnshMessage::CommandExited(code));
1537 sessions.remove(&link_id);
1538 }
1539 Ok(RnshEvent::LinkClosed(link_id)) => {
1540 if let Some(mut session) = sessions.remove(&link_id) {
1541 if let Some(process) = &mut session.process {
1542 process.terminate();
1543 }
1544 }
1545 }
1546 Ok(_) | Err(mpsc::RecvTimeoutError::Timeout) => {}
1547 Err(mpsc::RecvTimeoutError::Disconnected) => break,
1548 }
1549 }
1550 Ok(())
1551}
1552
1553fn initiate(opts: CliOptions) -> Result<i32, RnshError> {
1554 let dest_hash = parse_hash_16(
1555 opts.destination
1556 .as_deref()
1557 .ok_or_else(|| RnshError::Protocol("missing destination".into()))?,
1558 )
1559 .ok_or_else(|| RnshError::Protocol("destination must be 32 hexadecimal characters".into()))?;
1560 let timeout = Duration::from_secs_f64(opts.timeout.unwrap_or(15.0));
1561 let (event_tx, event_rx) = mpsc::channel();
1562 let node = RnsNode::connect_shared_from_config(
1563 opts.config.as_deref().map(Path::new),
1564 Box::new(RnshCallbacks {
1565 tx: event_tx.clone(),
1566 }),
1567 )?;
1568 let identity = prepare_identity(opts.config.as_deref(), opts.identity.as_deref(), None)?;
1569
1570 wait_for_path(&node, dest_hash, &event_rx, timeout)?;
1571 let recalled = node
1572 .recall_identity(&DestHash(dest_hash))?
1573 .ok_or_else(|| RnshError::Protocol("destination identity was not recalled".into()))?;
1574 let mut sig_pub = [0u8; 32];
1575 sig_pub.copy_from_slice(&recalled.public_key[32..64]);
1576
1577 let link_id = node.create_link(dest_hash, sig_pub)?;
1578 wait_for_link(&event_rx, link_id, timeout)?;
1579 if !opts.no_id {
1580 node.identify_on_link(
1581 link_id,
1582 identity
1583 .get_private_key()
1584 .ok_or_else(|| RnshError::Protocol("identity has no private key".into()))?,
1585 )?;
1586 }
1587
1588 send_message(&node, link_id, &version_message())?;
1589 wait_for_version(&event_rx, timeout)?;
1590
1591 let stdin_is_tty = io::stdin().is_terminal();
1592 let stdout_is_tty = io::stdout().is_terminal();
1593 let stderr_is_tty = io::stderr().is_terminal();
1594 let size = current_winsize(0);
1595 let execute = ExecuteCommand {
1596 cmdline: opts.command.clone(),
1597 pipe_stdin: !stdin_is_tty,
1598 pipe_stdout: !stdout_is_tty,
1599 pipe_stderr: !stderr_is_tty,
1600 term: std::env::var("TERM").ok(),
1601 rows: size.rows,
1602 cols: size.cols,
1603 hpix: size.hpix,
1604 vpix: size.vpix,
1605 };
1606 send_message(&node, link_id, &RnshMessage::ExecuteCommand(execute))?;
1607
1608 let tty = stdin_is_tty.then(|| {
1609 let restorer = TtyRestorer::new(0);
1610 restorer.raw();
1611 unsafe {
1612 libc::signal(libc::SIGWINCH, sigwinch_handler as *const () as usize);
1613 }
1614 restorer
1615 });
1616 let _keep_tty = tty;
1617 spawn_stdin_reader(event_tx);
1618
1619 loop {
1620 if SIGWINCH_SEEN.swap(false, Ordering::SeqCst) {
1621 let _ = send_message(&node, link_id, &RnshMessage::WindowSize(current_winsize(0)));
1622 }
1623 match event_rx.recv_timeout(Duration::from_millis(100)) {
1624 Ok(RnshEvent::ChannelMessage {
1625 msgtype, payload, ..
1626 }) => match RnshMessage::unpack(msgtype, &payload)? {
1627 RnshMessage::StreamData(data) if data.stream_id == STREAM_STDOUT => {
1628 io::stdout().write_all(&data.data)?;
1629 io::stdout().flush()?;
1630 }
1631 RnshMessage::StreamData(data) if data.stream_id == STREAM_STDERR => {
1632 io::stderr().write_all(&data.data)?;
1633 io::stderr().flush()?;
1634 }
1635 RnshMessage::CommandExited(code) => return Ok(code),
1636 RnshMessage::Error { msg, fatal } => {
1637 eprintln!("remote error: {msg}");
1638 if fatal {
1639 return Ok(200);
1640 }
1641 }
1642 _ => {}
1643 },
1644 Ok(RnshEvent::LocalStdin(data)) => {
1645 send_stream_chunks(&node, link_id, STREAM_STDIN, &data, false)?;
1646 }
1647 Ok(RnshEvent::LocalStdinEof) => {
1648 send_stream_chunks(&node, link_id, STREAM_STDIN, &[], true)?;
1649 }
1650 Ok(RnshEvent::LinkClosed(_)) => return Ok(0),
1651 Ok(_) | Err(mpsc::RecvTimeoutError::Timeout) => {}
1652 Err(mpsc::RecvTimeoutError::Disconnected) => return Ok(0),
1653 }
1654 }
1655}
1656
1657fn send_message(
1658 transport: &dyn RnshTransport,
1659 link_id: [u8; 16],
1660 message: &RnshMessage,
1661) -> Result<(), RnshError> {
1662 transport.send_rnsh_message(link_id, message)
1663}
1664
1665fn send_stream_chunks(
1666 transport: &dyn RnshTransport,
1667 link_id: [u8; 16],
1668 stream_id: u16,
1669 data: &[u8],
1670 eof: bool,
1671) -> Result<(), RnshError> {
1672 for chunk in data.chunks(STREAM_CHUNK_MAX) {
1673 let msg = RnshMessage::StreamData(StreamDataMessage::new(
1674 stream_id,
1675 chunk.to_vec(),
1676 false,
1677 false,
1678 ));
1679 send_message(transport, link_id, &msg)?;
1680 }
1681 if eof {
1682 let msg =
1683 RnshMessage::StreamData(StreamDataMessage::new(stream_id, Vec::new(), true, false));
1684 send_message(transport, link_id, &msg)?;
1685 }
1686 Ok(())
1687}
1688
1689fn version_message() -> RnshMessage {
1690 RnshMessage::VersionInfo {
1691 sw_version: VERSION.into(),
1692 protocol_version: PROTOCOL_VERSION,
1693 }
1694}
1695
1696fn wait_for_path(
1697 node: &RnsNode,
1698 dest_hash: [u8; 16],
1699 event_rx: &mpsc::Receiver<RnshEvent>,
1700 timeout: Duration,
1701) -> Result<(), RnshError> {
1702 let started = Instant::now();
1703 if !node.has_path(&DestHash(dest_hash))? {
1704 node.request_path(&DestHash(dest_hash))?;
1705 }
1706 while started.elapsed() < timeout {
1707 if node.has_path(&DestHash(dest_hash))? {
1708 return Ok(());
1709 }
1710 if let Ok(RnshEvent::Announce(announced)) =
1711 event_rx.recv_timeout(Duration::from_millis(250))
1712 {
1713 if announced.dest_hash.0 == dest_hash {
1714 return Ok(());
1715 }
1716 }
1717 }
1718 Err(RnshError::Protocol("path not found".into()))
1719}
1720
1721fn wait_for_link(
1722 event_rx: &mpsc::Receiver<RnshEvent>,
1723 expected_link: [u8; 16],
1724 timeout: Duration,
1725) -> Result<(), RnshError> {
1726 let started = Instant::now();
1727 while started.elapsed() < timeout {
1728 match event_rx.recv_timeout(Duration::from_millis(100)) {
1729 Ok(RnshEvent::LinkEstablished {
1730 link_id,
1731 is_initiator: true,
1732 ..
1733 }) if link_id == expected_link => return Ok(()),
1734 Ok(_) | Err(mpsc::RecvTimeoutError::Timeout) => {}
1735 Err(mpsc::RecvTimeoutError::Disconnected) => break,
1736 }
1737 }
1738 Err(RnshError::Protocol("link establishment timed out".into()))
1739}
1740
1741fn wait_for_version(
1742 event_rx: &mpsc::Receiver<RnshEvent>,
1743 timeout: Duration,
1744) -> Result<(), RnshError> {
1745 let started = Instant::now();
1746 while started.elapsed() < timeout {
1747 match event_rx.recv_timeout(Duration::from_millis(100)) {
1748 Ok(RnshEvent::ChannelMessage {
1749 msgtype, payload, ..
1750 }) => match RnshMessage::unpack(msgtype, &payload)? {
1751 RnshMessage::VersionInfo {
1752 protocol_version, ..
1753 } if protocol_version == PROTOCOL_VERSION => return Ok(()),
1754 RnshMessage::Error { msg, .. } => return Err(RnshError::Protocol(msg)),
1755 _ => {}
1756 },
1757 Ok(_) | Err(mpsc::RecvTimeoutError::Timeout) => {}
1758 Err(mpsc::RecvTimeoutError::Disconnected) => break,
1759 }
1760 }
1761 Err(RnshError::Protocol(
1762 "protocol version exchange timed out".into(),
1763 ))
1764}
1765
1766fn spawn_stdin_reader(event_tx: mpsc::Sender<RnshEvent>) {
1767 std::thread::spawn(move || {
1768 let mut stdin = io::stdin();
1769 let mut buf = [0u8; 4096];
1770 loop {
1771 match stdin.read(&mut buf) {
1772 Ok(0) => {
1773 let _ = event_tx.send(RnshEvent::LocalStdinEof);
1774 break;
1775 }
1776 Ok(n) => {
1777 let _ = event_tx.send(RnshEvent::LocalStdin(buf[..n].to_vec()));
1778 }
1779 Err(_) => {
1780 let _ = event_tx.send(RnshEvent::LocalStdinEof);
1781 break;
1782 }
1783 }
1784 }
1785 });
1786}
1787
1788fn prepare_identity(
1789 config: Option<&str>,
1790 explicit_path: Option<&str>,
1791 service: Option<&str>,
1792) -> Result<Identity, RnshError> {
1793 let path = if let Some(path) = explicit_path {
1794 PathBuf::from(path)
1795 } else {
1796 let config_dir = rns_net::storage::resolve_config_dir(config.map(Path::new));
1797 let paths = rns_net::storage::ensure_storage_dirs(&config_dir)?;
1798 let suffix = service.map(sanitize_service_name).unwrap_or_default();
1799 let filename = if suffix.is_empty() {
1800 APP_NAME.to_string()
1801 } else {
1802 format!("{APP_NAME}.{suffix}")
1803 };
1804 paths.identities.join(filename)
1805 };
1806 if let Some(parent) = path.parent() {
1807 fs::create_dir_all(parent)?;
1808 }
1809 if path.exists() {
1810 Ok(rns_net::storage::load_identity(&path)?)
1811 } else {
1812 let identity = Identity::new(&mut OsRng);
1813 rns_net::storage::save_identity(&identity, &path)?;
1814 Ok(identity)
1815 }
1816}
1817
1818fn print_identity(opts: &CliOptions) -> Result<(), RnshError> {
1819 let identity = prepare_identity(
1820 opts.config.as_deref(),
1821 opts.identity.as_deref(),
1822 opts.service.as_deref(),
1823 )?;
1824 println!("Identity : {}", prettyhexrep(identity.hash()));
1825 if opts.base256 {
1826 println!("Identity b256: {}", prettyb256rep(identity.hash()));
1827 }
1828 if opts.listen {
1829 let dest = Destination::single_in(APP_NAME, &[], IdentityHash(*identity.hash()));
1830 println!("Listening on : {}", prettyhexrep(&dest.hash.0));
1831 if opts.base256 {
1832 println!("Listen b256 : {}", prettyb256rep(&dest.hash.0));
1833 }
1834 }
1835 Ok(())
1836}
1837
1838fn load_allowed_identities(opts: &CliOptions) -> Result<HashSet<[u8; 16]>, RnshError> {
1839 let mut allowed = HashSet::new();
1840 for entry in &opts.allowed {
1841 if let Some(hash) = parse_hash_16(entry) {
1842 allowed.insert(hash);
1843 } else {
1844 return Err(RnshError::Protocol(format!(
1845 "invalid allowed identity hash: {entry}"
1846 )));
1847 }
1848 }
1849 for path in allowed_identity_files() {
1850 if !path.exists() {
1851 continue;
1852 }
1853 let contents = fs::read_to_string(path)?;
1854 for line in contents
1855 .lines()
1856 .map(str::trim)
1857 .filter(|line| !line.is_empty())
1858 {
1859 if let Some(hash) = parse_hash_16(line) {
1860 allowed.insert(hash);
1861 }
1862 }
1863 }
1864 Ok(allowed)
1865}
1866
/// Lists the files that may contain allowed identity hashes.
///
/// Both paths are rooted at `$HOME`, falling back to `.` when the variable
/// cannot be read: `.config/rnsh/allowed_identities` first, then
/// `.rnsh/allowed_identities`.
fn allowed_identity_files() -> Vec<PathBuf> {
    let home = match std::env::var("HOME") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => PathBuf::from("."),
    };
    let under_dot_config = home
        .join(".config")
        .join("rnsh")
        .join("allowed_identities");
    let under_dot_rnsh = home.join(".rnsh").join("allowed_identities");
    vec![under_dot_config, under_dot_rnsh]
}
1877
/// Returns `value` with every character that is not an ASCII letter or digit
/// removed.
fn sanitize_service_name(value: &str) -> String {
    let mut cleaned = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch.is_ascii_alphanumeric() {
            cleaned.push(ch);
        }
    }
    cleaned
}
1884
/// Parses a 16-byte hash from exactly 32 hexadecimal digits.
///
/// Surrounding whitespace is ignored and both upper- and lower-case digits
/// are accepted. Returns `None` for any other input.
///
/// The digits are validated byte by byte. That avoids two problems with
/// slicing the string and calling `u8::from_str_radix`: a panic when a
/// 32-byte input containing multibyte characters is cut off a character
/// boundary, and acceptance of sign-prefixed pairs such as `"+f"`.
fn parse_hash_16(value: &str) -> Option<[u8; 16]> {
    let digits = value.trim().as_bytes();
    if digits.len() != 32 {
        return None;
    }
    let mut out = [0u8; 16];
    for (byte, pair) in out.iter_mut().zip(digits.chunks_exact(2)) {
        // Non-ASCII bytes map to non-hex characters here and are rejected.
        let hi = char::from(pair[0]).to_digit(16)?;
        let lo = char::from(pair[1]).to_digit(16)?;
        *byte = u8::try_from((hi << 4) | lo).ok()?;
    }
    Some(out)
}
1896
1897fn extract_sig_keys(identity: &Identity) -> Result<([u8; 32], [u8; 32]), RnshError> {
1898 let private = identity
1899 .get_private_key()
1900 .ok_or_else(|| RnshError::Protocol("identity has no private key".into()))?;
1901 let public = identity
1902 .get_public_key()
1903 .ok_or_else(|| RnshError::Protocol("identity has no public key".into()))?;
1904 let mut sig_prv = [0u8; 32];
1905 let mut sig_pub = [0u8; 32];
1906 sig_prv.copy_from_slice(&private[32..64]);
1907 sig_pub.copy_from_slice(&public[32..64]);
1908 Ok((sig_prv, sig_pub))
1909}
1910
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    const TEST_LINK: [u8; 16] = [0x42; 16];

    /// Transport double that records every sent message and teardown request.
    #[derive(Default)]
    struct FakeTransport {
        sent: Mutex<Vec<([u8; 16], u16, Vec<u8>)>>,
        teardowns: Mutex<Vec<[u8; 16]>>,
    }

    impl FakeTransport {
        /// Decodes the recorded messages back into `RnshMessage` values.
        fn sent_messages(&self) -> Vec<([u8; 16], RnshMessage)> {
            self.sent
                .lock()
                .unwrap()
                .iter()
                .map(|(link_id, msgtype, payload)| {
                    (
                        *link_id,
                        RnshMessage::unpack(*msgtype, payload)
                            .expect("fake transport stored decodable message"),
                    )
                })
                .collect()
        }
    }

    impl RnshTransport for FakeTransport {
        fn send_rnsh_message(
            &self,
            link_id: [u8; 16],
            message: &RnshMessage,
        ) -> Result<(), RnshError> {
            self.sent
                .lock()
                .unwrap()
                .push((link_id, message.msgtype(), message.pack()));
            Ok(())
        }

        fn teardown_rnsh_link(&self, link_id: [u8; 16]) -> Result<(), RnshError> {
            self.teardowns.lock().unwrap().push(link_id);
            Ok(())
        }
    }

    /// Permissive listener configuration that runs `/bin/cat` by default.
    fn test_config() -> ListenerConfig {
        ListenerConfig {
            default_command: vec!["/bin/cat".into()],
            allow_all: true,
            allowed: HashSet::new(),
            allow_remote_command: true,
            remote_command_as_args: false,
        }
    }

    /// Builds an `ExecuteCommand` message with all three streams piped.
    fn exec_msg(cmdline: Vec<&str>) -> RnshMessage {
        RnshMessage::ExecuteCommand(ExecuteCommand {
            cmdline: cmdline.into_iter().map(str::to_string).collect(),
            pipe_stdin: true,
            pipe_stdout: true,
            pipe_stderr: true,
            term: Some("xterm".into()),
            rows: Some(24),
            cols: Some(80),
            hpix: None,
            vpix: None,
        })
    }

    #[test]
    fn msgpack_version_matches_upstream_shape() {
        let msg = RnshMessage::VersionInfo {
            sw_version: "1.2.0".into(),
            protocol_version: 1,
        };
        let packed = msg.pack();
        assert_eq!(packed, b"\x92\xa51.2.0\x01");
        assert_eq!(RnshMessage::unpack(MSG_VERSION_INFO, &packed).unwrap(), msg);
    }

    #[test]
    fn execute_command_roundtrips() {
        let msg = RnshMessage::ExecuteCommand(ExecuteCommand {
            cmdline: vec!["/bin/sh".into(), "-lc".into(), "echo hi".into()],
            pipe_stdin: true,
            pipe_stdout: true,
            pipe_stderr: false,
            term: Some("xterm-256color".into()),
            rows: Some(24),
            cols: Some(80),
            hpix: None,
            vpix: None,
        });
        let packed = msg.pack();
        assert_eq!(
            RnshMessage::unpack(MSG_EXECUTE_COMMAND, &packed).unwrap(),
            msg
        );
    }

    #[test]
    fn stream_data_uses_upstream_header_bits() {
        let msg = RnshMessage::StreamData(StreamDataMessage::new(2, b"err".to_vec(), true, false));
        let packed = msg.pack();
        assert_eq!(&packed[..2], &0x8002u16.to_be_bytes());
        assert_eq!(RnshMessage::unpack(MSG_STREAM_DATA, &packed).unwrap(), msg);
    }

    #[test]
    fn cli_splits_command_after_double_dash() {
        let args = CliOptions::parse(vec![
            "-l".into(),
            "-s".into(),
            "ops".into(),
            "--".into(),
            "/bin/sh".into(),
            "-l".into(),
        ])
        .unwrap();
        assert!(args.listen);
        assert_eq!(args.service.as_deref(), Some("ops"));
        assert_eq!(args.command, vec!["/bin/sh", "-l"]);
    }

    #[test]
    fn cli_parses_base256_display_flag() {
        let short = CliOptions::parse(vec!["-Zp".into()]).unwrap();
        assert!(short.base256);
        assert!(short.print_identity);

        let long = CliOptions::parse(vec!["--base256".into(), "--print-identity".into()]).unwrap();
        assert!(long.base256);
        assert!(long.print_identity);
    }

    #[test]
    fn service_name_is_sanitized_like_upstream() {
        assert_eq!(sanitize_service_name("dev-shell_1!"), "devshell1");
    }

    #[test]
    fn rnsh_logging_uses_file_oriented_levels() {
        assert_eq!(rnsh_log_level(false, 0, 0), log::LevelFilter::Error);
        assert_eq!(rnsh_log_level(true, 0, 0), log::LevelFilter::Info);
        assert_eq!(rnsh_log_level(true, 2, 0), log::LevelFilter::Trace);
        assert_eq!(rnsh_log_level(true, 0, 4), log::LevelFilter::Off);
    }

    #[test]
    fn listener_rejects_unallowed_identity_and_tears_down() {
        let mut allowed = HashSet::new();
        allowed.insert([0x11; 16]);
        let config = ListenerConfig {
            allow_all: false,
            allowed,
            ..test_config()
        };
        let fake = FakeTransport::default();
        let mut session = ListenerSession::new(TEST_LINK, config);

        assert_eq!(session.state, ListenerState::WaitIdent);
        session.remote_identified(&fake, IdentityHash([0x22; 16]));

        assert_eq!(session.state, ListenerState::Closed);
        assert_eq!(fake.teardowns.lock().unwrap().as_slice(), &[TEST_LINK]);
        let messages = fake.sent_messages();
        assert_eq!(messages.len(), 1);
        assert!(matches!(
            &messages[0].1,
            RnshMessage::Error { msg, fatal: true } if msg == "Identity is not allowed."
        ));
    }

    #[test]
    fn listener_accepts_allowed_identity_and_completes_version_handshake() {
        let mut allowed = HashSet::new();
        allowed.insert([0x11; 16]);
        let config = ListenerConfig {
            allow_all: false,
            allowed,
            ..test_config()
        };
        let fake = FakeTransport::default();
        let (tx, _rx) = mpsc::channel();
        let mut session = ListenerSession::new(TEST_LINK, config);

        session.remote_identified(&fake, IdentityHash([0x11; 16]));
        assert_eq!(session.state, ListenerState::WaitVersion);
        let version = version_message();
        session.handle_message(&fake, &tx, version.msgtype(), version.pack());

        assert_eq!(session.state, ListenerState::WaitCommand);
        let messages = fake.sent_messages();
        assert_eq!(messages.len(), 1);
        assert!(matches!(
            &messages[0].1,
            RnshMessage::VersionInfo {
                protocol_version: PROTOCOL_VERSION,
                ..
            }
        ));
    }

    #[test]
    fn listener_rejects_incompatible_protocol_version() {
        let fake = FakeTransport::default();
        let (tx, _rx) = mpsc::channel();
        let mut session = ListenerSession::new(TEST_LINK, test_config());
        let msg = RnshMessage::VersionInfo {
            sw_version: "future".into(),
            protocol_version: PROTOCOL_VERSION + 1,
        };

        session.handle_message(&fake, &tx, msg.msgtype(), msg.pack());

        assert_eq!(session.state, ListenerState::Closed);
        assert_eq!(fake.teardowns.lock().unwrap().as_slice(), &[TEST_LINK]);
        assert!(matches!(
            &fake.sent_messages()[0].1,
            RnshMessage::Error { msg, fatal: true } if msg == "Incompatible protocol"
        ));
    }

    #[test]
    fn listener_rejects_remote_command_when_disabled() {
        let fake = FakeTransport::default();
        let (tx, _rx) = mpsc::channel();
        let config = ListenerConfig {
            allow_remote_command: false,
            ..test_config()
        };
        let mut session = ListenerSession::new(TEST_LINK, config);
        let version = version_message();
        session.handle_message(&fake, &tx, version.msgtype(), version.pack());
        let exec = exec_msg(vec!["/bin/echo", "nope"]);

        session.handle_message(&fake, &tx, exec.msgtype(), exec.pack());

        assert_eq!(session.state, ListenerState::Closed);
        assert_eq!(fake.teardowns.lock().unwrap().as_slice(), &[TEST_LINK]);
        assert!(fake.sent_messages().iter().any(|(_, msg)| matches!(
            msg,
            RnshMessage::Error { msg, fatal: true }
                if msg.contains("Remote command line not allowed")
        )));
    }

    #[test]
    fn listener_executes_default_command_and_forwards_stdin_to_process() {
        let fake = FakeTransport::default();
        let (tx, rx) = mpsc::channel();
        let mut session = ListenerSession::new(TEST_LINK, test_config());
        let version = version_message();
        session.handle_message(&fake, &tx, version.msgtype(), version.pack());
        let exec = exec_msg(Vec::new());
        session.handle_message(&fake, &tx, exec.msgtype(), exec.pack());
        assert_eq!(session.state, ListenerState::Running);

        let stdin = RnshMessage::StreamData(StreamDataMessage::new(
            STREAM_STDIN,
            b"hello over stdin".to_vec(),
            true,
            false,
        ));
        session.handle_message(&fake, &tx, stdin.msgtype(), stdin.pack());

        let started = Instant::now();
        let mut stdout = Vec::new();
        let mut exit = None;
        while started.elapsed() < Duration::from_secs(5) && exit.is_none() {
            // A quiet 100 ms window is not a failure: keep polling until the
            // overall deadline instead of panicking on the first timeout.
            match rx.recv_timeout(Duration::from_millis(100)) {
                Ok(RnshEvent::ProcessOutput {
                    stream_id: STREAM_STDOUT,
                    data,
                    ..
                }) => stdout.extend(data),
                Ok(RnshEvent::ProcessExited { code, .. }) => exit = Some(code),
                Ok(_) | Err(mpsc::RecvTimeoutError::Timeout) => {}
                Err(mpsc::RecvTimeoutError::Disconnected) => break,
            }
        }

        assert_eq!(stdout, b"hello over stdin");
        assert_eq!(exit, Some(0));
    }

    #[test]
    fn send_stream_chunks_splits_large_payload_and_appends_eof() {
        let fake = FakeTransport::default();
        let data = vec![0x55; STREAM_CHUNK_MAX * 2 + 3];

        send_stream_chunks(&fake, TEST_LINK, STREAM_STDOUT, &data, true).unwrap();

        let messages = fake.sent_messages();
        assert_eq!(messages.len(), 4);
        let mut payload = Vec::new();
        for (_, message) in &messages[..3] {
            match message {
                RnshMessage::StreamData(stream) => {
                    assert_eq!(stream.stream_id, STREAM_STDOUT);
                    assert!(!stream.eof);
                    payload.extend_from_slice(&stream.data);
                }
                other => panic!("expected stream data, got {other:?}"),
            }
        }
        assert_eq!(payload, data);
        assert!(matches!(
            messages.last().unwrap().1,
            RnshMessage::StreamData(ref stream)
                if stream.stream_id == STREAM_STDOUT && stream.eof && stream.data.is_empty()
        ));
    }

    #[test]
    fn process_pipe_mode_reports_stdout_stderr_and_exit() {
        let (tx, rx) = mpsc::channel();
        let link_id = [7u8; 16];
        let command = ExecuteCommand {
            cmdline: Vec::new(),
            pipe_stdin: true,
            pipe_stdout: true,
            pipe_stderr: true,
            term: None,
            rows: None,
            cols: None,
            hpix: None,
            vpix: None,
        };
        let _process = ChildProcess::spawn(
            link_id,
            &[
                "/bin/sh".into(),
                "-c".into(),
                "printf out; printf err >&2; exit 13".into(),
            ],
            &[],
            &command,
            tx,
        )
        .unwrap();

        let started = Instant::now();
        let mut stdout = Vec::new();
        let mut stderr = Vec::new();
        let mut exit = None;
        while started.elapsed() < Duration::from_secs(5) && exit.is_none() {
            // Tolerate slow process start-up: only the 5 s deadline fails the test.
            match rx.recv_timeout(Duration::from_millis(100)) {
                Ok(RnshEvent::ProcessOutput {
                    stream_id: STREAM_STDOUT,
                    data,
                    ..
                }) => stdout.extend(data),
                Ok(RnshEvent::ProcessOutput {
                    stream_id: STREAM_STDERR,
                    data,
                    ..
                }) => stderr.extend(data),
                Ok(RnshEvent::ProcessExited { code, .. }) => exit = Some(code),
                Ok(_) | Err(mpsc::RecvTimeoutError::Timeout) => {}
                Err(mpsc::RecvTimeoutError::Disconnected) => break,
            }
        }
        assert_eq!(stdout, b"out");
        assert_eq!(stderr, b"err");
        assert_eq!(exit, Some(13));
    }
}