1#![forbid(unsafe_code)]
2#![allow(clippy::doc_markdown)]
4#![allow(clippy::nursery)]
5#![allow(clippy::pedantic)]
6
7use std::collections::HashMap;
57use std::fmt;
58use std::future::Future;
59use std::io::{self, Write};
60use std::net::SocketAddr;
61use std::pin::Pin;
62use std::sync::Arc;
63use std::sync::mpsc::Sender;
64use std::time::Duration;
65
66use bubbletea::Message;
67use parking_lot::RwLock;
68use thiserror::Error;
69use tokio::net::TcpListener;
70use tracing::{debug, error, info, warn};
71
72pub mod auth;
73mod handler;
74pub mod session;
75
76pub use auth::{
77 AcceptAllAuth, AsyncCallbackAuth, AsyncPublicKeyAuth, AuthContext, AuthHandler, AuthMethod,
78 AuthResult, AuthorizedKey, AuthorizedKeysAuth, CallbackAuth, CompositeAuth, PasswordAuth,
79 PublicKeyAuth, PublicKeyCallbackAuth, RateLimitedAuth, SessionId, parse_authorized_keys,
80};
81pub use handler::{RusshConfig, ServerState, WishHandler, WishHandlerFactory, run_stream};
82
83pub use bubbletea;
85pub use lipgloss;
86
/// Errors produced by the server, its configuration helpers, and sessions.
#[derive(Error, Debug)]
pub enum Error {
    /// An I/O failure (e.g. binding the listener or reading a host key file).
    /// Converted automatically from `std::io::Error` by `?` via `#[from]`.
    #[error("io error: {0}")]
    Io(#[from] io::Error),

    /// An SSH-level failure described by a free-form message.
    #[error("ssh error: {0}")]
    Ssh(String),

    /// An error reported by the `russh` library; converted via `#[from]`.
    #[error("russh error: {0}")]
    Russh(#[from] russh::Error),

    /// A host key could not be parsed or converted into a russh key pair.
    #[error("key error: {0}")]
    Key(String),

    /// An error reported by `russh_keys`; converted via `#[from]`.
    #[error("key loading error: {0}")]
    KeyLoad(#[from] russh_keys::Error),

    /// Authentication failed.
    #[error("authentication failed")]
    AuthenticationFailed,

    /// The session limit was hit.
    #[error("maximum sessions reached ({current}/{max})")]
    MaxSessionsReached {
        /// The configured maximum number of sessions.
        max: usize,
        /// The number of sessions observed when the limit was checked.
        current: usize,
    },

    /// The server configuration is invalid.
    #[error("configuration error: {0}")]
    Configuration(String),

    /// A failure inside a session, described by a free-form message.
    #[error("session error: {0}")]
    Session(String),

    /// The listen address is not a valid `SocketAddr` (e.g. from
    /// `address.parse()?` when the server starts); converted via `#[from]`.
    #[error("address parse error: {0}")]
    AddrParse(#[from] std::net::AddrParseError),
}
197
/// Shorthand for `std::result::Result` with this crate's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
202
/// Terminal window size.
///
/// NOTE(review): units are presumably character cells, as reported in the
/// client's PTY request — confirm against the request handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Window {
    /// Number of columns.
    pub width: u32,
    /// Number of rows.
    pub height: u32,
}
215
216impl Default for Window {
217 fn default() -> Self {
218 Self {
219 width: 80,
220 height: 24,
221 }
222 }
223}
224
/// PTY parameters requested by the client.
#[derive(Debug, Clone)]
pub struct Pty {
    /// Terminal type name (the `TERM` value, e.g. `xterm-256color`).
    pub term: String,
    /// Current window size.
    pub window: Window,
}
233
234impl Default for Pty {
235 fn default() -> Self {
236 Self {
237 term: "xterm-256color".to_string(),
238 window: Window::default(),
239 }
240 }
241}
242
/// A client public key.
#[derive(Debug, Clone)]
pub struct PublicKey {
    /// Key algorithm name (presumably e.g. `ssh-ed25519`; set by the caller of `new`).
    pub key_type: String,
    /// Raw key bytes. NOTE(review): encoding not established here — presumably
    /// the SSH wire-format key blob; confirm with the auth layer.
    pub data: Vec<u8>,
    /// Optional free-form comment; ignored by equality.
    pub comment: Option<String>,
}
257
258impl PublicKey {
259 pub fn new(key_type: impl Into<String>, data: Vec<u8>) -> Self {
261 Self {
262 key_type: key_type.into(),
263 data,
264 comment: None,
265 }
266 }
267
268 pub fn with_comment(mut self, comment: impl Into<String>) -> Self {
270 self.comment = Some(comment.into());
271 self
272 }
273
274 pub fn fingerprint(&self) -> String {
279 use std::collections::hash_map::DefaultHasher;
280 use std::hash::{Hash, Hasher};
281 let mut hasher = DefaultHasher::new();
282 self.data.hash(&mut hasher);
283 format!("HASH:{:016x}", hasher.finish())
284 }
285}
286
287impl PartialEq for PublicKey {
288 fn eq(&self, other: &Self) -> bool {
289 self.key_type == other.key_type && self.data == other.data
290 }
291}
292
293impl Eq for PublicKey {}
294
/// Per-connection metadata, passed to auth callbacks (`&Context`) and
/// embedded in every [`Session`].
///
/// `values` is behind an `Arc`, so clones of a `Context` read and write the
/// same key/value map.
#[derive(Debug, Clone)]
pub struct Context {
    /// User name supplied by the client.
    user: String,
    /// Peer socket address.
    remote_addr: SocketAddr,
    /// Local socket address the connection arrived on.
    local_addr: SocketAddr,
    /// Client version string; empty until `set_client_version` is called.
    client_version: String,
    /// Arbitrary string key/value pairs guarded by a read/write lock.
    values: Arc<RwLock<HashMap<String, String>>>,
}
313
314impl Context {
315 pub fn new(user: impl Into<String>, remote_addr: SocketAddr, local_addr: SocketAddr) -> Self {
317 Self {
318 user: user.into(),
319 remote_addr,
320 local_addr,
321 client_version: String::new(),
322 values: Arc::new(RwLock::new(HashMap::new())),
323 }
324 }
325
326 pub fn user(&self) -> &str {
328 &self.user
329 }
330
331 pub fn remote_addr(&self) -> SocketAddr {
333 self.remote_addr
334 }
335
336 pub fn local_addr(&self) -> SocketAddr {
338 self.local_addr
339 }
340
341 pub fn client_version(&self) -> &str {
343 &self.client_version
344 }
345
346 pub fn set_client_version(&mut self, version: impl Into<String>) {
348 self.client_version = version.into();
349 }
350
351 pub fn set_value(&self, key: impl Into<String>, value: impl Into<String>) {
353 self.values.write().insert(key.into(), value.into());
354 }
355
356 pub fn get_value(&self, key: &str) -> Option<String> {
358 self.values.read().get(key).cloned()
359 }
360}
361
/// A single SSH session as seen by handlers and middleware.
///
/// `Clone` shares everything stored behind an `Arc` (stdout/stderr buffers,
/// exit code, closed flag, input receiver, message sender) between clones.
/// The remaining fields are copied.
#[derive(Clone)]
pub struct Session {
    /// Connection metadata (user, addresses, client version, value store).
    context: Context,
    /// PTY parameters, if the session has one (see `with_pty`).
    pty: Option<Pty>,
    /// The command requested by the client, as set via `with_command`.
    command: Vec<String>,
    /// Environment variables, as added via `with_env`.
    env: HashMap<String, String>,
    // NOTE(review): `stdout`/`stderr` are not written in this file —
    // `write`/`write_stderr` forward through `output_tx` instead, hence the
    // `dead_code` allowances. Confirm whether other modules still use them.
    #[allow(dead_code)]
    pub(crate) stdout: Arc<RwLock<Vec<u8>>>,
    #[allow(dead_code)]
    pub(crate) stderr: Arc<RwLock<Vec<u8>>>,
    /// Exit status recorded by `exit`; shared between clones.
    exit_code: Arc<RwLock<Option<i32>>>,
    /// Set by `close`; shared between clones.
    closed: Arc<RwLock<bool>>,
    /// The client's public key, if one was attached via `with_public_key`.
    public_key: Option<PublicKey>,
    /// Requested subsystem name (e.g. `sftp`), if any.
    subsystem: Option<String>,

    /// Channel carrying stdout/stderr/exit/close events to the transport;
    /// when `None`, output is silently discarded.
    output_tx: Option<tokio::sync::mpsc::UnboundedSender<SessionOutput>>,
    /// Client input channel, installed later via `set_input_receiver`.
    input_rx: Arc<tokio::sync::Mutex<Option<tokio::sync::mpsc::Receiver<Vec<u8>>>>>,
    /// Sender for bubbletea messages, installed via `set_message_sender`.
    message_tx: Arc<RwLock<Option<Sender<Message>>>>,
}
399
/// Events a [`Session`] sends to the transport over its output channel.
#[derive(Debug)]
pub enum SessionOutput {
    /// Bytes destined for the client's stdout (from `Session::write`).
    Stdout(Vec<u8>),
    /// Bytes destined for the client's stderr (from `Session::write_stderr`).
    Stderr(Vec<u8>),
    /// Exit status. `Session::exit` produces it with `code as u32`, so
    /// negative codes wrap around.
    Exit(u32),
    /// Close request (from `Session::close`).
    Close,
}
412
413impl fmt::Debug for Session {
414 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
415 f.debug_struct("Session")
416 .field("user", &self.context.user)
417 .field("remote_addr", &self.context.remote_addr)
418 .field("pty", &self.pty)
419 .field("command", &self.command)
420 .finish()
421 }
422}
423
424impl Session {
425 pub fn new(context: Context) -> Self {
427 Self {
428 context,
429 pty: None,
430 command: Vec::new(),
431 env: HashMap::new(),
432 stdout: Arc::new(RwLock::new(Vec::new())),
433 stderr: Arc::new(RwLock::new(Vec::new())),
434 exit_code: Arc::new(RwLock::new(None)),
435 closed: Arc::new(RwLock::new(false)),
436 public_key: None,
437 subsystem: None,
438 output_tx: None,
439 input_rx: Arc::new(tokio::sync::Mutex::new(None)),
440 message_tx: Arc::new(RwLock::new(None)),
441 }
442 }
443
444 pub fn set_output_sender(&mut self, tx: tokio::sync::mpsc::UnboundedSender<SessionOutput>) {
446 self.output_tx = Some(tx);
447 }
448
449 pub async fn set_input_receiver(&self, rx: tokio::sync::mpsc::Receiver<Vec<u8>>) {
451 *self.input_rx.lock().await = Some(rx);
452 }
453
454 pub async fn recv(&self) -> Option<Vec<u8>> {
456 let mut rx_guard = self.input_rx.lock().await;
457 if let Some(rx) = rx_guard.as_mut() {
458 rx.recv().await
459 } else {
460 None
461 }
462 }
463
464 pub fn set_message_sender(&self, tx: Sender<Message>) {
466 *self.message_tx.write() = Some(tx);
467 }
468
469 pub fn send_message(&self, msg: Message) {
471 if let Some(tx) = self.message_tx.read().as_ref() {
472 let _ = tx.send(msg);
474 }
475 }
476
477 pub fn user(&self) -> &str {
479 self.context.user()
480 }
481
482 pub fn remote_addr(&self) -> SocketAddr {
484 self.context.remote_addr()
485 }
486
487 pub fn local_addr(&self) -> SocketAddr {
489 self.context.local_addr()
490 }
491
492 pub fn context(&self) -> &Context {
494 &self.context
495 }
496
497 pub fn pty(&self) -> (Option<&Pty>, bool) {
499 (self.pty.as_ref(), self.pty.is_some())
500 }
501
502 pub fn command(&self) -> &[String] {
504 &self.command
505 }
506
507 pub fn get_env(&self, key: &str) -> Option<&String> {
509 self.env.get(key)
510 }
511
512 pub fn environ(&self) -> &HashMap<String, String> {
514 &self.env
515 }
516
517 pub fn public_key(&self) -> Option<&PublicKey> {
519 self.public_key.as_ref()
520 }
521
522 pub fn subsystem(&self) -> Option<&str> {
524 self.subsystem.as_deref()
525 }
526
527 pub fn write(&self, data: &[u8]) -> io::Result<usize> {
529 if let Some(tx) = &self.output_tx {
531 let _ = tx.send(SessionOutput::Stdout(data.to_vec()));
532 }
533
534 Ok(data.len())
535 }
536
537 pub fn write_stderr(&self, data: &[u8]) -> io::Result<usize> {
539 if let Some(tx) = &self.output_tx {
541 let _ = tx.send(SessionOutput::Stderr(data.to_vec()));
542 }
543
544 Ok(data.len())
545 }
546
547 pub fn exit(&self, code: i32) -> io::Result<()> {
549 *self.exit_code.write() = Some(code);
550 if let Some(tx) = &self.output_tx {
551 let _ = tx.send(SessionOutput::Exit(code as u32));
552 }
553 Ok(())
554 }
555
556 pub fn close(&self) -> io::Result<()> {
558 *self.closed.write() = true;
559 if let Some(tx) = &self.output_tx {
560 let _ = tx.send(SessionOutput::Close);
561 }
562 Ok(())
563 }
564
565 pub fn is_closed(&self) -> bool {
567 *self.closed.read()
568 }
569
570 pub fn window(&self) -> Window {
572 self.pty.as_ref().map(|p| p.window).unwrap_or_default()
573 }
574
575 pub fn with_pty(mut self, pty: Pty) -> Self {
579 self.pty = Some(pty);
580 self
581 }
582
583 pub fn with_command(mut self, command: Vec<String>) -> Self {
585 self.command = command;
586 self
587 }
588
589 pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
591 self.env.insert(key.into(), value.into());
592 self
593 }
594
595 pub fn with_public_key(mut self, key: PublicKey) -> Self {
597 self.public_key = Some(key);
598 self
599 }
600
601 pub fn with_subsystem(mut self, subsystem: impl Into<String>) -> Self {
603 self.subsystem = Some(subsystem.into());
604 self
605 }
606}
607
608impl Write for Session {
610 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
611 Session::write(self, buf)
612 }
613
614 fn flush(&mut self) -> io::Result<()> {
615 Ok(())
616 }
617}
618
619pub fn print(session: &Session, args: impl fmt::Display) {
625 let _ = session.write(args.to_string().as_bytes());
626}
627
628pub fn println(session: &Session, args: impl fmt::Display) {
630 let msg = format!("{}\r\n", args);
631 let _ = session.write(msg.as_bytes());
632}
633
634pub fn printf(session: &Session, format: impl fmt::Display, args: &[&dyn fmt::Display]) {
636 let mut msg = format.to_string();
637 for arg in args {
638 if let Some(pos) = msg.find("{}") {
639 msg.replace_range(pos..pos + 2, &arg.to_string());
640 }
641 }
642 let _ = session.write(msg.as_bytes());
643}
644
645pub fn error(session: &Session, args: impl fmt::Display) {
647 let _ = session.write_stderr(args.to_string().as_bytes());
648}
649
650pub fn errorln(session: &Session, args: impl fmt::Display) {
652 let msg = format!("{}\r\n", args);
653 let _ = session.write_stderr(msg.as_bytes());
654}
655
656pub fn errorf(session: &Session, format: impl fmt::Display, args: &[&dyn fmt::Display]) {
658 let mut msg = format.to_string();
659 for arg in args {
660 if let Some(pos) = msg.find("{}") {
661 msg.replace_range(pos..pos + 2, &arg.to_string());
662 }
663 }
664 let _ = session.write_stderr(msg.as_bytes());
665}
666
667pub fn fatal(session: &Session, args: impl fmt::Display) {
669 error(session, args);
670 let _ = session.exit(1);
671 let _ = session.close();
672}
673
674pub fn fatalln(session: &Session, args: impl fmt::Display) {
676 errorln(session, args);
677 let _ = session.exit(1);
678 let _ = session.close();
679}
680
681pub fn fatalf(session: &Session, format: impl fmt::Display, args: &[&dyn fmt::Display]) {
683 errorf(session, format, args);
684 let _ = session.exit(1);
685 let _ = session.close();
686}
687
688pub fn write_string(session: &Session, s: &str) -> io::Result<usize> {
690 session.write(s.as_bytes())
691}
692
/// A heap-allocated, pinned, `Send` future; the return type of session and
/// subsystem handlers.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// A shareable async session handler: it takes ownership of a [`Session`]
/// and returns a future that drives it to completion.
pub type Handler = Arc<dyn Fn(Session) -> BoxFuture<'static, ()> + Send + Sync>;

/// Wraps a [`Handler`] and returns a new one. See `compose_middleware` for
/// how several are chained.
pub type Middleware = Arc<dyn Fn(Handler) -> Handler + Send + Sync>;
705
706pub fn handler<F, Fut>(f: F) -> Handler
708where
709 F: Fn(Session) -> Fut + Send + Sync + 'static,
710 Fut: Future<Output = ()> + Send + 'static,
711{
712 Arc::new(move |session| Box::pin(f(session)))
713}
714
715pub fn noop_handler() -> Handler {
717 Arc::new(|_| Box::pin(async {}))
718}
719
720pub fn compose_middleware(middlewares: Vec<Middleware>) -> Middleware {
722 Arc::new(move |h| {
723 let mut handler = h;
724 for mw in middlewares.iter().rev() {
725 handler = mw(handler);
726 }
727 handler
728 })
729}
730
/// Public-key auth callback for a connection; the `bool` result presumably
/// means "accept this key".
pub type PublicKeyHandler = Arc<dyn Fn(&Context, &PublicKey) -> bool + Send + Sync>;

/// Password auth callback; the `bool` result presumably means "accept".
pub type PasswordHandler = Arc<dyn Fn(&Context, &str) -> bool + Send + Sync>;

/// Keyboard-interactive auth callback.
///
/// NOTE(review): the arguments after the context are presumably the
/// instruction text, the prompts, and per-prompt echo flags, and the result
/// the answers — confirm against the handler module.
pub type KeyboardInteractiveHandler =
    Arc<dyn Fn(&Context, &str, &[String], &[bool]) -> Vec<String> + Send + Sync>;

/// Produces a banner string for a given connection.
pub type BannerHandler = Arc<dyn Fn(&Context) -> String + Send + Sync>;

/// Handler for a named SSH subsystem (registered via `with_subsystem`).
pub type SubsystemHandler = Arc<dyn Fn(Session) -> BoxFuture<'static, ()> + Send + Sync>;
750
/// Complete server configuration, filled in by [`ServerOption`]s or by
/// `ServerBuilder`.
#[derive(Clone)]
pub struct ServerOptions {
    /// Listen address; parsed as a `SocketAddr` when the server starts.
    pub address: String,
    /// SSH identification string sent to clients.
    pub version: String,
    /// Static authentication banner (becomes russh's `auth_banner`).
    pub banner: Option<String>,
    /// Per-connection banner callback. Not consulted when building the russh
    /// config; presumably used by the handler module.
    pub banner_handler: Option<BannerHandler>,
    /// Path to an OpenSSH-format private host key; ignored when
    /// `host_key_pem` is set.
    pub host_key_path: Option<String>,
    /// In-memory OpenSSH-format private host key; takes precedence over
    /// `host_key_path`.
    pub host_key_pem: Option<Vec<u8>>,
    /// Middleware stack, in registration order.
    pub middlewares: Vec<Middleware>,
    /// Main session handler (presumably wrapped by `middlewares` in the
    /// handler module).
    pub handler: Option<Handler>,
    /// Pluggable authenticator. When set, its `supported_methods()` decide
    /// which SSH auth methods are advertised, and the callback fields below
    /// are not used for that decision.
    pub auth_handler: Option<Arc<dyn AuthHandler>>,
    /// Public-key auth callback (used when `auth_handler` is `None`).
    pub public_key_handler: Option<PublicKeyHandler>,
    /// Password auth callback (used when `auth_handler` is `None`).
    pub password_handler: Option<PasswordHandler>,
    /// Keyboard-interactive callback (used when `auth_handler` is `None`).
    pub keyboard_interactive_handler: Option<KeyboardInteractiveHandler>,
    /// Applied as russh's `inactivity_timeout`.
    pub idle_timeout: Option<Duration>,
    /// Maximum session lifetime. NOTE(review): not applied in this file;
    /// presumably enforced by the handler module.
    pub max_timeout: Option<Duration>,
    /// Subsystem handlers keyed by subsystem name (e.g. `sftp`).
    pub subsystem_handlers: HashMap<String, SubsystemHandler>,
    /// Forwarded to russh's `max_auth_attempts`.
    pub max_auth_attempts: u32,
    /// Forwarded to russh's `auth_rejection_time`, in milliseconds.
    pub auth_rejection_delay_ms: u64,
    /// Whether unauthenticated logins are permitted (default `false`).
    /// NOTE(review): `create_russh_config` advertises the `none` method when no
    /// auth callbacks are configured without reading this flag; presumably the
    /// handler module enforces it — confirm.
    pub allow_no_auth: bool,
}
801
802impl Default for ServerOptions {
803 fn default() -> Self {
804 Self {
805 address: "0.0.0.0:22".to_string(),
806 version: "SSH-2.0-Wish".to_string(),
807 banner: None,
808 banner_handler: None,
809 host_key_path: None,
810 host_key_pem: None,
811 middlewares: Vec::new(),
812 handler: None,
813 auth_handler: None,
814 public_key_handler: None,
815 password_handler: None,
816 keyboard_interactive_handler: None,
817 idle_timeout: None,
818 max_timeout: None,
819 subsystem_handlers: HashMap::new(),
820 max_auth_attempts: auth::DEFAULT_MAX_AUTH_ATTEMPTS,
821 auth_rejection_delay_ms: auth::DEFAULT_AUTH_REJECTION_DELAY_MS,
822 allow_no_auth: false,
823 }
824 }
825}
826
827impl fmt::Debug for ServerOptions {
828 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
829 f.debug_struct("ServerOptions")
830 .field("address", &self.address)
831 .field("version", &self.version)
832 .field("banner", &self.banner)
833 .field("host_key_path", &self.host_key_path)
834 .field("idle_timeout", &self.idle_timeout)
835 .field("max_timeout", &self.max_timeout)
836 .finish()
837 }
838}
839
/// A functional option for `Server::new`: a one-shot closure that edits
/// [`ServerOptions`] and may reject the configuration with an [`Error`].
pub type ServerOption = Box<dyn FnOnce(&mut ServerOptions) -> Result<()> + Send>;
846
847pub fn with_address(addr: impl Into<String>) -> ServerOption {
849 let addr = addr.into();
850 Box::new(move |opts| {
851 opts.address = addr;
852 Ok(())
853 })
854}
855
856pub fn with_version(version: impl Into<String>) -> ServerOption {
858 let version = version.into();
859 Box::new(move |opts| {
860 opts.version = version;
861 Ok(())
862 })
863}
864
865pub fn with_banner(banner: impl Into<String>) -> ServerOption {
867 let banner = banner.into();
868 Box::new(move |opts| {
869 opts.banner = Some(banner);
870 Ok(())
871 })
872}
873
874pub fn with_banner_handler<F>(handler: F) -> ServerOption
876where
877 F: Fn(&Context) -> String + Send + Sync + 'static,
878{
879 Box::new(move |opts| {
880 opts.banner_handler = Some(Arc::new(handler));
881 Ok(())
882 })
883}
884
885pub fn with_middleware(mw: Middleware) -> ServerOption {
887 Box::new(move |opts| {
888 opts.middlewares.push(mw);
889 Ok(())
890 })
891}
892
893pub fn with_host_key_path(path: impl Into<String>) -> ServerOption {
895 let path = path.into();
896 Box::new(move |opts| {
897 opts.host_key_path = Some(path);
898 Ok(())
899 })
900}
901
902pub fn with_host_key_pem(pem: Vec<u8>) -> ServerOption {
904 Box::new(move |opts| {
905 opts.host_key_pem = Some(pem);
906 Ok(())
907 })
908}
909
910pub fn with_auth_handler<H: AuthHandler + 'static>(handler: H) -> ServerOption {
914 Box::new(move |opts| {
915 opts.auth_handler = Some(Arc::new(handler));
916 Ok(())
917 })
918}
919
920pub fn with_max_auth_attempts(max: u32) -> ServerOption {
922 Box::new(move |opts| {
923 opts.max_auth_attempts = max;
924 Ok(())
925 })
926}
927
928pub fn with_auth_rejection_delay(delay_ms: u64) -> ServerOption {
930 Box::new(move |opts| {
931 opts.auth_rejection_delay_ms = delay_ms;
932 Ok(())
933 })
934}
935
936pub fn with_public_key_auth<F>(handler: F) -> ServerOption
938where
939 F: Fn(&Context, &PublicKey) -> bool + Send + Sync + 'static,
940{
941 Box::new(move |opts| {
942 opts.public_key_handler = Some(Arc::new(handler));
943 Ok(())
944 })
945}
946
947pub fn with_password_auth<F>(handler: F) -> ServerOption
949where
950 F: Fn(&Context, &str) -> bool + Send + Sync + 'static,
951{
952 Box::new(move |opts| {
953 opts.password_handler = Some(Arc::new(handler));
954 Ok(())
955 })
956}
957
958pub fn with_keyboard_interactive_auth<F>(handler: F) -> ServerOption
960where
961 F: Fn(&Context, &str, &[String], &[bool]) -> Vec<String> + Send + Sync + 'static,
962{
963 Box::new(move |opts| {
964 opts.keyboard_interactive_handler = Some(Arc::new(handler));
965 Ok(())
966 })
967}
968
969pub fn with_idle_timeout(duration: Duration) -> ServerOption {
971 Box::new(move |opts| {
972 opts.idle_timeout = Some(duration);
973 Ok(())
974 })
975}
976
977pub fn with_max_timeout(duration: Duration) -> ServerOption {
979 Box::new(move |opts| {
980 opts.max_timeout = Some(duration);
981 Ok(())
982 })
983}
984
985pub fn with_subsystem<F, Fut>(name: impl Into<String>, handler: F) -> ServerOption
987where
988 F: Fn(Session) -> Fut + Send + Sync + 'static,
989 Fut: Future<Output = ()> + Send + 'static,
990{
991 let name = name.into();
992 Box::new(move |opts| {
993 opts.subsystem_handlers
994 .insert(name, Arc::new(move |s| Box::pin(handler(s))));
995 Ok(())
996 })
997}
998
/// An SSH server ready to listen, built from [`ServerOptions`].
pub struct Server {
    /// Fully resolved configuration.
    options: ServerOptions,
}
1008
1009impl fmt::Debug for Server {
1010 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1011 f.debug_struct("Server")
1012 .field("options", &self.options)
1013 .finish()
1014 }
1015}
1016
1017impl Server {
1018 pub fn new(options: impl IntoIterator<Item = ServerOption>) -> Result<Self> {
1020 let mut opts = ServerOptions::default();
1021 for opt in options {
1022 opt(&mut opts)?;
1023 }
1024 Ok(Self { options: opts })
1025 }
1026
1027 pub fn options(&self) -> &ServerOptions {
1029 &self.options
1030 }
1031
1032 pub fn address(&self) -> &str {
1034 &self.options.address
1035 }
1036
1037 pub async fn listen(&self) -> Result<()> {
1042 info!("Starting SSH server on {}", self.options.address);
1043
1044 let addr: SocketAddr = self.options.address.parse()?;
1046 debug!("Parsed address: {:?}", addr);
1047
1048 let config = self.create_russh_config()?;
1050 let config = Arc::new(config);
1051
1052 let factory = WishHandlerFactory::new(self.options.clone());
1054
1055 let listener = TcpListener::bind(addr).await?;
1057 let local_addr = listener.local_addr().unwrap_or(addr);
1058 info!("Server listening on {}", local_addr);
1059
1060 self.listen_with_listener_inner(listener, config, factory, local_addr)
1061 .await
1062 }
1063
1064 pub async fn listen_with_listener(&self, listener: TcpListener) -> Result<()> {
1069 let local_addr = listener.local_addr()?;
1070
1071 let config = self.create_russh_config()?;
1073 let config = Arc::new(config);
1074
1075 let factory = WishHandlerFactory::new(self.options.clone());
1077
1078 info!("Server listening on {}", local_addr);
1079 self.listen_with_listener_inner(listener, config, factory, local_addr)
1080 .await
1081 }
1082
1083 async fn listen_with_listener_inner(
1084 &self,
1085 listener: TcpListener,
1086 config: Arc<RusshConfig>,
1087 factory: WishHandlerFactory,
1088 local_addr: SocketAddr,
1089 ) -> Result<()> {
1090 loop {
1092 match listener.accept().await {
1093 Ok((socket, peer_addr)) => {
1094 info!(peer_addr = %peer_addr, "Accepted connection");
1095
1096 let config = config.clone();
1097 let socket_local_addr = socket.local_addr().unwrap_or(local_addr);
1098 let handler = factory.create_handler(peer_addr, socket_local_addr);
1099
1100 tokio::spawn(async move {
1102 debug!(peer_addr = %peer_addr, "Running SSH session");
1103 match run_stream(config, socket, handler).await {
1104 Ok(session) => {
1105 match session.await {
1107 Ok(()) => {
1108 debug!(peer_addr = %peer_addr, "Connection closed cleanly");
1109 }
1110 Err(e) => {
1111 warn!(peer_addr = %peer_addr, error = %e, "Connection error");
1112 }
1113 }
1114 }
1115 Err(e) => {
1116 error!(peer_addr = %peer_addr, error = %e, "SSH handshake failed");
1117 }
1118 }
1119 });
1120 }
1121 Err(e) => {
1122 error!(error = %e, "Failed to accept connection");
1123 }
1124 }
1125 }
1126 }
1127
1128 #[allow(clippy::field_reassign_with_default)]
1130 fn create_russh_config(&self) -> Result<RusshConfig> {
1131 use russh::MethodSet;
1132 use russh::server::Config;
1133 use russh_keys::key::KeyPair;
1134
1135 let mut config = Config::default();
1136
1137 config.server_id = russh::SshId::Standard(self.options.version.clone());
1139
1140 if let Some(timeout) = self.options.idle_timeout {
1142 config.inactivity_timeout = Some(timeout);
1143 }
1144
1145 config.max_auth_attempts = self.options.max_auth_attempts as usize;
1146 config.auth_rejection_time = Duration::from_millis(self.options.auth_rejection_delay_ms);
1147
1148 let mut methods = MethodSet::empty();
1149 if let Some(handler) = &self.options.auth_handler {
1150 for method in handler.supported_methods() {
1151 if matches!(method, auth::AuthMethod::None) {
1154 methods |= MethodSet::NONE;
1155 } else if matches!(method, auth::AuthMethod::Password) {
1156 methods |= MethodSet::PASSWORD;
1157 } else if matches!(method, auth::AuthMethod::PublicKey) {
1158 methods |= MethodSet::PUBLICKEY;
1159 } else if matches!(method, auth::AuthMethod::KeyboardInteractive) {
1160 methods |= MethodSet::KEYBOARD_INTERACTIVE;
1161 } else if matches!(method, auth::AuthMethod::HostBased) {
1162 methods |= MethodSet::HOSTBASED;
1163 }
1164 }
1165 } else {
1166 if self.options.public_key_handler.is_some() {
1167 methods |= MethodSet::PUBLICKEY;
1168 }
1169 if self.options.password_handler.is_some() {
1170 methods |= MethodSet::PASSWORD;
1171 }
1172 if self.options.keyboard_interactive_handler.is_some() {
1173 methods |= MethodSet::KEYBOARD_INTERACTIVE;
1174 }
1175 if methods.is_empty() {
1176 methods |= MethodSet::NONE;
1177 }
1178 }
1179 config.methods = methods;
1180
1181 let key = if let Some(ref pem) = self.options.host_key_pem {
1183 let private_key = ssh_key::private::PrivateKey::from_openssh(pem)
1185 .map_err(|e| Error::Key(e.to_string()))?;
1186 KeyPair::try_from(&private_key).map_err(|e| Error::Key(e.to_string()))?
1187 } else if let Some(ref path) = self.options.host_key_path {
1188 let pem = std::fs::read(path)?;
1190 let private_key = ssh_key::private::PrivateKey::from_openssh(&pem)
1191 .map_err(|e| Error::Key(e.to_string()))?;
1192 KeyPair::try_from(&private_key).map_err(|e| Error::Key(e.to_string()))?
1193 } else {
1194 info!("Generating ephemeral Ed25519 host key");
1196 KeyPair::generate_ed25519()
1197 };
1198
1199 config.keys.push(key);
1200
1201 if let Some(ref banner) = self.options.banner {
1203 let banner: &'static str = Box::leak(banner.clone().into_boxed_str());
1206 config.auth_banner = Some(banner);
1207 }
1208
1209 Ok(config)
1210 }
1211
1212 pub async fn listen_and_serve(&self) -> Result<()> {
1214 self.listen().await
1215 }
1216}
1217
1218pub fn new_server(options: impl IntoIterator<Item = ServerOption>) -> Result<Server> {
1220 Server::new(options)
1221}
1222
/// Fluent builder for [`Server`]; an alternative to functional options.
#[derive(Default)]
pub struct ServerBuilder {
    /// Options accumulated so far; starts from `ServerOptions::default()`.
    options: ServerOptions,
}
1232
1233impl ServerBuilder {
1234 pub fn new() -> Self {
1236 Self::default()
1237 }
1238
1239 pub fn address(mut self, addr: impl Into<String>) -> Self {
1241 self.options.address = addr.into();
1242 self
1243 }
1244
1245 pub fn version(mut self, version: impl Into<String>) -> Self {
1247 self.options.version = version.into();
1248 self
1249 }
1250
1251 pub fn banner(mut self, banner: impl Into<String>) -> Self {
1253 self.options.banner = Some(banner.into());
1254 self
1255 }
1256
1257 pub fn banner_handler<F>(mut self, handler: F) -> Self
1259 where
1260 F: Fn(&Context) -> String + Send + Sync + 'static,
1261 {
1262 self.options.banner_handler = Some(Arc::new(handler));
1263 self
1264 }
1265
1266 pub fn host_key_path(mut self, path: impl Into<String>) -> Self {
1268 self.options.host_key_path = Some(path.into());
1269 self
1270 }
1271
1272 pub fn host_key_pem(mut self, pem: Vec<u8>) -> Self {
1274 self.options.host_key_pem = Some(pem);
1275 self
1276 }
1277
1278 pub fn with_middleware(mut self, mw: Middleware) -> Self {
1280 self.options.middlewares.push(mw);
1281 self
1282 }
1283
1284 pub fn handler<F, Fut>(mut self, handler: F) -> Self
1286 where
1287 F: Fn(Session) -> Fut + Send + Sync + 'static,
1288 Fut: Future<Output = ()> + Send + 'static,
1289 {
1290 self.options.handler = Some(Arc::new(move |session| Box::pin(handler(session))));
1291 self
1292 }
1293
1294 pub fn handler_arc(mut self, handler: Handler) -> Self {
1298 self.options.handler = Some(handler);
1299 self
1300 }
1301
1302 pub fn auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
1306 self.options.auth_handler = Some(Arc::new(handler));
1307 self
1308 }
1309
1310 pub fn max_auth_attempts(mut self, max: u32) -> Self {
1312 self.options.max_auth_attempts = max;
1313 self
1314 }
1315
1316 pub fn auth_rejection_delay(mut self, delay_ms: u64) -> Self {
1318 self.options.auth_rejection_delay_ms = delay_ms;
1319 self
1320 }
1321
1322 pub fn allow_no_auth(mut self) -> Self {
1328 self.options.allow_no_auth = true;
1329 self
1330 }
1331
1332 pub fn public_key_auth<F>(mut self, handler: F) -> Self
1334 where
1335 F: Fn(&Context, &PublicKey) -> bool + Send + Sync + 'static,
1336 {
1337 self.options.public_key_handler = Some(Arc::new(handler));
1338 self
1339 }
1340
1341 pub fn password_auth<F>(mut self, handler: F) -> Self
1343 where
1344 F: Fn(&Context, &str) -> bool + Send + Sync + 'static,
1345 {
1346 self.options.password_handler = Some(Arc::new(handler));
1347 self
1348 }
1349
1350 pub fn keyboard_interactive_auth<F>(mut self, handler: F) -> Self
1352 where
1353 F: Fn(&Context, &str, &[String], &[bool]) -> Vec<String> + Send + Sync + 'static,
1354 {
1355 self.options.keyboard_interactive_handler = Some(Arc::new(handler));
1356 self
1357 }
1358
1359 pub fn idle_timeout(mut self, duration: Duration) -> Self {
1361 self.options.idle_timeout = Some(duration);
1362 self
1363 }
1364
1365 pub fn max_timeout(mut self, duration: Duration) -> Self {
1367 self.options.max_timeout = Some(duration);
1368 self
1369 }
1370
1371 pub fn subsystem<F, Fut>(mut self, name: impl Into<String>, handler: F) -> Self
1373 where
1374 F: Fn(Session) -> Fut + Send + Sync + 'static,
1375 Fut: Future<Output = ()> + Send + 'static,
1376 {
1377 self.options
1378 .subsystem_handlers
1379 .insert(name.into(), Arc::new(move |s| Box::pin(handler(s))));
1380 self
1381 }
1382
1383 pub fn build(self) -> Result<Server> {
1385 Ok(Server {
1386 options: self.options,
1387 })
1388 }
1389}
1390
1391pub mod middleware {
1397 use super::*;
1398 use std::time::Instant;
1399
1400 pub mod activeterm {
1402 use super::*;
1403
1404 pub fn middleware() -> Middleware {
1406 Arc::new(|next| {
1407 Arc::new(move |session| {
1408 let next = next.clone();
1409 Box::pin(async move {
1410 let (_, active) = session.pty();
1411 if active {
1412 next(session).await;
1413 } else {
1414 println(&session, "Requires an active PTY");
1415 let _ = session.exit(1);
1416 }
1417 })
1418 })
1419 })
1420 }
1421 }
1422
1423 pub mod accesscontrol {
1425 use super::*;
1426
1427 pub fn middleware(allowed_commands: Vec<String>) -> Middleware {
1429 Arc::new(move |next| {
1430 let allowed = allowed_commands.clone();
1431 Arc::new(move |session| {
1432 let next = next.clone();
1433 let allowed = allowed.clone();
1434 Box::pin(async move {
1435 let cmd = session.command();
1436 if cmd.is_empty() {
1437 next(session).await;
1438 return;
1439 }
1440
1441 let first_cmd = &cmd[0];
1442 if allowed.iter().any(|c| c == first_cmd) {
1443 next(session).await;
1444 } else {
1445 println(&session, format!("Command is not allowed: {}", first_cmd));
1446 let _ = session.exit(1);
1447 }
1448 })
1449 })
1450 })
1451 }
1452 }
1453
1454 pub mod authentication {
1459 use super::*;
1460
1461 pub fn middleware() -> Middleware {
1463 middleware_with_checker(|session| !session.user().is_empty())
1464 }
1465
1466 pub fn middleware_with_checker<C>(checker: C) -> Middleware
1468 where
1469 C: Fn(&Session) -> bool + Send + Sync + 'static,
1470 {
1471 let checker = Arc::new(checker);
1472 Arc::new(move |next| {
1473 let checker = checker.clone();
1474 Arc::new(move |session| {
1475 let next = next.clone();
1476 let checker = checker.clone();
1477 Box::pin(async move {
1478 if checker(&session) {
1479 next(session).await;
1480 } else {
1481 fatalln(&session, "authentication required");
1482 }
1483 })
1484 })
1485 })
1486 }
1487 }
1488
1489 pub mod authorization {
1491 use super::*;
1492
1493 pub fn middleware() -> Middleware {
1497 middleware_with_checker(|_session| true)
1498 }
1499
1500 pub fn middleware_with_checker<C>(checker: C) -> Middleware
1502 where
1503 C: Fn(&Session) -> bool + Send + Sync + 'static,
1504 {
1505 let checker = Arc::new(checker);
1506 Arc::new(move |next| {
1507 let checker = checker.clone();
1508 Arc::new(move |session| {
1509 let next = next.clone();
1510 let checker = checker.clone();
1511 Box::pin(async move {
1512 if checker(&session) {
1513 next(session).await;
1514 } else {
1515 fatalln(&session, "permission denied");
1516 }
1517 })
1518 })
1519 })
1520 }
1521 }
1522
1523 pub mod session_handler {
1525 use super::*;
1526
1527 pub fn middleware() -> Middleware {
1529 Arc::new(|next| {
1530 Arc::new(move |session| {
1531 let next = next.clone();
1532 Box::pin(async move {
1533 next(session.clone()).await;
1534 if !session.is_closed() {
1535 let _ = session.close();
1536 }
1537 })
1538 })
1539 })
1540 }
1541 }
1542
1543 pub mod pty {
1548 use super::*;
1549
1550 pub fn middleware() -> Middleware {
1552 Arc::new(|next| {
1553 Arc::new(move |session| {
1554 let next = next.clone();
1555 Box::pin(async move {
1556 let (_, active) = session.pty();
1557 if active {
1558 next(session).await;
1559 } else {
1560 fatalln(&session, "pty required");
1561 }
1562 })
1563 })
1564 })
1565 }
1566 }
1567
1568 pub mod git {
1573 use super::*;
1574
1575 fn looks_like_git_command(cmd: &[String]) -> bool {
1576 cmd.first()
1577 .is_some_and(|c| c == "git" || c.starts_with("git-"))
1578 }
1579
1580 pub fn middleware() -> Middleware {
1585 middleware_with_handler(|session| async move {
1586 fatalln(&session, "git handler not configured");
1587 })
1588 }
1589
1590 pub fn middleware_with_handler<F, Fut>(handler: F) -> Middleware
1592 where
1593 F: Fn(Session) -> Fut + Send + Sync + 'static,
1594 Fut: Future<Output = ()> + Send + 'static,
1595 {
1596 let handler = Arc::new(handler);
1597 Arc::new(move |next| {
1598 let handler = handler.clone();
1599 Arc::new(move |session| {
1600 let next = next.clone();
1601 let handler = handler.clone();
1602 Box::pin(async move {
1603 if looks_like_git_command(session.command()) {
1604 handler(session).await;
1605 } else {
1606 next(session).await;
1607 }
1608 })
1609 })
1610 })
1611 }
1612 }
1613
1614 pub mod scp {
1616 use super::*;
1617
1618 fn looks_like_scp_command(cmd: &[String]) -> bool {
1619 cmd.first().is_some_and(|c| c == "scp")
1620 }
1621
1622 pub fn middleware() -> Middleware {
1627 middleware_with_handler(|session| async move {
1628 fatalln(&session, "scp handler not configured");
1629 })
1630 }
1631
1632 pub fn middleware_with_handler<F, Fut>(handler: F) -> Middleware
1634 where
1635 F: Fn(Session) -> Fut + Send + Sync + 'static,
1636 Fut: Future<Output = ()> + Send + 'static,
1637 {
1638 let handler = Arc::new(handler);
1639 Arc::new(move |next| {
1640 let handler = handler.clone();
1641 Arc::new(move |session| {
1642 let next = next.clone();
1643 let handler = handler.clone();
1644 Box::pin(async move {
1645 if looks_like_scp_command(session.command()) {
1646 handler(session).await;
1647 } else {
1648 next(session).await;
1649 }
1650 })
1651 })
1652 })
1653 }
1654 }
1655
1656 pub mod sftp {
1658 use super::*;
1659
1660 fn looks_like_sftp_session(session: &Session) -> bool {
1661 session.subsystem() == Some("sftp")
1662 || session.command().first().is_some_and(|c| c == "sftp")
1663 }
1664
1665 pub fn middleware() -> Middleware {
1670 middleware_with_handler(|session| async move {
1671 fatalln(&session, "sftp handler not configured");
1672 })
1673 }
1674
1675 pub fn middleware_with_handler<F, Fut>(handler: F) -> Middleware
1677 where
1678 F: Fn(Session) -> Fut + Send + Sync + 'static,
1679 Fut: Future<Output = ()> + Send + 'static,
1680 {
1681 let handler = Arc::new(handler);
1682 Arc::new(move |next| {
1683 let handler = handler.clone();
1684 Arc::new(move |session| {
1685 let next = next.clone();
1686 let handler = handler.clone();
1687 Box::pin(async move {
1688 if looks_like_sftp_session(&session) {
1689 handler(session).await;
1690 } else {
1691 next(session).await;
1692 }
1693 })
1694 })
1695 })
1696 }
1697 }
1698
1699 pub mod logging {
1701 use super::*;
1702
1703 pub trait Logger: Send + Sync {
1705 fn log(&self, format: &str, args: &[&dyn fmt::Display]);
1706 }
1707
1708 #[allow(clippy::too_many_arguments)]
1710 pub trait StructuredLogger: Send + Sync {
1711 fn log_connect(
1712 &self,
1713 level: tracing::Level,
1714 user: &str,
1715 remote_addr: &SocketAddr,
1716 public_key: bool,
1717 command: &[String],
1718 term: &str,
1719 width: u32,
1720 height: u32,
1721 client_version: &str,
1722 );
1723
1724 fn log_disconnect(
1725 &self,
1726 level: tracing::Level,
1727 user: &str,
1728 remote_addr: &SocketAddr,
1729 duration: Duration,
1730 );
1731 }
1732
1733 #[derive(Clone, Copy)]
1735 pub struct TracingLogger;
1736
1737 impl Logger for TracingLogger {
1738 fn log(&self, format: &str, args: &[&dyn fmt::Display]) {
1739 let mut msg = format.to_string();
1740 for arg in args {
1741 if let Some(pos) = msg.find("{}") {
1742 msg.replace_range(pos..pos + 2, &arg.to_string());
1743 }
1744 }
1745 info!("{}", msg);
1746 }
1747 }
1748
1749 #[derive(Clone, Copy)]
1751 pub struct TracingStructuredLogger;
1752
1753 impl StructuredLogger for TracingStructuredLogger {
1754 fn log_connect(
1755 &self,
1756 level: tracing::Level,
1757 user: &str,
1758 remote_addr: &SocketAddr,
1759 public_key: bool,
1760 command: &[String],
1761 term: &str,
1762 width: u32,
1763 height: u32,
1764 client_version: &str,
1765 ) {
1766 match level {
1767 tracing::Level::TRACE => tracing::event!(
1768 tracing::Level::TRACE,
1769 user = %user,
1770 remote_addr = %remote_addr,
1771 public_key = public_key,
1772 command = ?command,
1773 term = %term,
1774 width = width,
1775 height = height,
1776 client_version = %client_version,
1777 "connect"
1778 ),
1779 tracing::Level::DEBUG => tracing::event!(
1780 tracing::Level::DEBUG,
1781 user = %user,
1782 remote_addr = %remote_addr,
1783 public_key = public_key,
1784 command = ?command,
1785 term = %term,
1786 width = width,
1787 height = height,
1788 client_version = %client_version,
1789 "connect"
1790 ),
1791 tracing::Level::INFO => tracing::event!(
1792 tracing::Level::INFO,
1793 user = %user,
1794 remote_addr = %remote_addr,
1795 public_key = public_key,
1796 command = ?command,
1797 term = %term,
1798 width = width,
1799 height = height,
1800 client_version = %client_version,
1801 "connect"
1802 ),
1803 tracing::Level::WARN => tracing::event!(
1804 tracing::Level::WARN,
1805 user = %user,
1806 remote_addr = %remote_addr,
1807 public_key = public_key,
1808 command = ?command,
1809 term = %term,
1810 width = width,
1811 height = height,
1812 client_version = %client_version,
1813 "connect"
1814 ),
1815 tracing::Level::ERROR => tracing::event!(
1816 tracing::Level::ERROR,
1817 user = %user,
1818 remote_addr = %remote_addr,
1819 public_key = public_key,
1820 command = ?command,
1821 term = %term,
1822 width = width,
1823 height = height,
1824 client_version = %client_version,
1825 "connect"
1826 ),
1827 }
1828 }
1829
1830 fn log_disconnect(
1831 &self,
1832 level: tracing::Level,
1833 user: &str,
1834 remote_addr: &SocketAddr,
1835 duration: Duration,
1836 ) {
1837 match level {
1838 tracing::Level::TRACE => tracing::event!(
1839 tracing::Level::TRACE,
1840 user = %user,
1841 remote_addr = %remote_addr,
1842 duration = ?duration,
1843 "disconnect"
1844 ),
1845 tracing::Level::DEBUG => tracing::event!(
1846 tracing::Level::DEBUG,
1847 user = %user,
1848 remote_addr = %remote_addr,
1849 duration = ?duration,
1850 "disconnect"
1851 ),
1852 tracing::Level::INFO => tracing::event!(
1853 tracing::Level::INFO,
1854 user = %user,
1855 remote_addr = %remote_addr,
1856 duration = ?duration,
1857 "disconnect"
1858 ),
1859 tracing::Level::WARN => tracing::event!(
1860 tracing::Level::WARN,
1861 user = %user,
1862 remote_addr = %remote_addr,
1863 duration = ?duration,
1864 "disconnect"
1865 ),
1866 tracing::Level::ERROR => tracing::event!(
1867 tracing::Level::ERROR,
1868 user = %user,
1869 remote_addr = %remote_addr,
1870 duration = ?duration,
1871 "disconnect"
1872 ),
1873 }
1874 }
1875 }
1876
1877 pub fn middleware() -> Middleware {
1879 middleware_with_logger(TracingLogger)
1880 }
1881
1882 pub fn middleware_with_logger<L: Logger + 'static>(logger: L) -> Middleware {
1884 let logger = Arc::new(logger);
1885 Arc::new(move |next| {
1886 let logger = logger.clone();
1887 Arc::new(move |session| {
1888 let next = next.clone();
1889 let logger = logger.clone();
1890 let start = Instant::now();
1891
1892 let user = session.user().to_string();
1894 let remote_addr = session.remote_addr().to_string();
1895 let has_key = session.public_key().is_some();
1896 let command = session.command().to_vec();
1897 let (pty, _) = session.pty();
1898 let term = pty.map(|p| p.term.clone()).unwrap_or_default();
1899 let window = session.window();
1900 let client_version = session.context().client_version();
1901
1902 logger.log(
1903 "{} connect {} {} {} {} {} {} {}",
1904 &[
1905 &user as &dyn fmt::Display,
1906 &remote_addr,
1907 &has_key,
1908 &format!("{:?}", command),
1909 &term,
1910 &window.width,
1911 &window.height,
1912 &client_version,
1913 ],
1914 );
1915
1916 Box::pin(async move {
1917 next(session.clone()).await;
1918
1919 let duration = start.elapsed();
1921 logger.log(
1922 "{} disconnect {}",
1923 &[
1924 &remote_addr as &dyn fmt::Display,
1925 &format!("{:?}", duration),
1926 ],
1927 );
1928 })
1929 })
1930 })
1931 }
1932
1933 pub fn structured_middleware() -> Middleware {
1935 structured_middleware_with_logger(TracingStructuredLogger, tracing::Level::INFO)
1936 }
1937
1938 pub fn structured_middleware_with_logger<L: StructuredLogger + 'static>(
1940 logger: L,
1941 level: tracing::Level,
1942 ) -> Middleware {
1943 let logger = Arc::new(logger);
1944 Arc::new(move |next| {
1945 let logger = logger.clone();
1946 Arc::new(move |session| {
1947 let next = next.clone();
1948 let logger = logger.clone();
1949 let level = level;
1950 let start = Instant::now();
1951
1952 let user = session.user().to_string();
1953 let remote_addr = session.remote_addr();
1954 let has_key = session.public_key().is_some();
1955 let command = session.command().to_vec();
1956 let (pty, _) = session.pty();
1957 let term = pty.map(|p| p.term.clone()).unwrap_or_default();
1958 let window = session.window();
1959 let client_version = session.context().client_version().to_string();
1960
1961 logger.log_connect(
1962 level,
1963 &user,
1964 &remote_addr,
1965 has_key,
1966 &command,
1967 &term,
1968 window.width,
1969 window.height,
1970 &client_version,
1971 );
1972
1973 Box::pin(async move {
1974 next(session.clone()).await;
1975
1976 let duration = start.elapsed();
1977 logger.log_disconnect(level, &user, &remote_addr, duration);
1978 })
1979 })
1980 })
1981 }
1982 }
1983
1984 pub mod recover {
1986 use super::*;
1987
1988 pub fn middleware() -> Middleware {
1990 middleware_with_middlewares(vec![])
1991 }
1992
1993 pub fn middleware_with_middlewares(mws: Vec<Middleware>) -> Middleware {
1995 Arc::new(move |next| {
1996 let mws = mws.clone();
1997
1998 let mut inner_handler = noop_handler();
2000 for mw in mws.iter().rev() {
2001 inner_handler = mw(inner_handler);
2002 }
2003
2004 let inner = inner_handler;
2005 Arc::new(move |session| {
2006 let next = next.clone();
2007 let inner = inner.clone();
2008 Box::pin(async move {
2009 inner(session.clone()).await;
2012 next(session).await;
2013 })
2014 })
2015 })
2016 }
2017
2018 pub trait Logger: Send + Sync {
2020 fn log_panic(&self, error: &str, stack: &str);
2021 }
2022
2023 #[derive(Clone, Copy)]
2025 pub struct DefaultLogger;
2026
2027 impl Logger for DefaultLogger {
2028 fn log_panic(&self, error: &str, stack: &str) {
2029 error!("panic: {}\n{}", error, stack);
2030 }
2031 }
2032 }
2033
2034 pub mod ratelimiter {
2036 use super::*;
2037 use lru::LruCache;
2038 use std::num::NonZeroUsize;
2039 use std::time::Instant;
2040
2041 pub const ERR_RATE_LIMIT_EXCEEDED: &str = "rate limit exceeded, please try again later";
2043
2044 #[derive(Clone)]
2046 pub struct Config {
2047 pub rate_per_sec: f64,
2049 pub burst: usize,
2051 pub max_entries: usize,
2053 }
2054
2055 impl Default for Config {
2056 fn default() -> Self {
2057 Self {
2058 rate_per_sec: 1.0,
2059 burst: 10,
2060 max_entries: 1000,
2061 }
2062 }
2063 }
2064
2065 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
2067 pub struct RateLimitError;
2068
2069 impl fmt::Display for RateLimitError {
2070 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2071 write!(f, "{ERR_RATE_LIMIT_EXCEEDED}")
2072 }
2073 }
2074
2075 impl std::error::Error for RateLimitError {}
2076
2077 pub trait RateLimiter: Send + Sync {
2079 fn allow(&self, session: &Session) -> std::result::Result<(), RateLimitError>;
2080 }
2081
2082 #[derive(Debug, Clone)]
2083 struct TokenBucketState {
2084 tokens: f64,
2085 last: Instant,
2086 }
2087
2088 pub struct TokenBucketLimiter {
2090 rate_per_sec: f64,
2091 burst: f64,
2092 cache: RwLock<LruCache<String, TokenBucketState>>,
2093 }
2094
2095 impl TokenBucketLimiter {
2096 pub fn new(rate_per_sec: f64, burst: usize, max_entries: usize) -> Self {
2097 let max_entries = max_entries.max(1);
2098 let cache = LruCache::new(NonZeroUsize::new(max_entries).unwrap());
2099 Self {
2100 rate_per_sec: rate_per_sec.max(0.0),
2101 burst: burst.max(1) as f64,
2102 cache: RwLock::new(cache),
2103 }
2104 }
2105
2106 fn allow_key(&self, key: &str) -> bool {
2107 let now = Instant::now();
2108 let mut cache = self.cache.write();
2109
2110 #[allow(clippy::manual_inspect)]
2111 let state = cache
2112 .get_mut(key)
2113 .map(|state| {
2114 let elapsed = now.duration_since(state.last).as_secs_f64();
2115 state.tokens = (state.tokens + elapsed * self.rate_per_sec).min(self.burst);
2116 state.last = now;
2117 state
2118 })
2119 .cloned();
2120
2121 let mut state = state.unwrap_or(TokenBucketState {
2122 tokens: self.burst,
2123 last: now,
2124 });
2125
2126 let allowed = if state.tokens >= 1.0 {
2127 state.tokens -= 1.0;
2128 true
2129 } else {
2130 false
2131 };
2132
2133 cache.put(key.to_string(), state);
2134 allowed
2135 }
2136 }
2137
2138 impl RateLimiter for TokenBucketLimiter {
2139 fn allow(&self, session: &Session) -> std::result::Result<(), RateLimitError> {
2140 let key = session.remote_addr().ip().to_string();
2141 let allowed = self.allow_key(&key);
2142 debug!(key = %key, allowed, "rate limiter key");
2143 if allowed { Ok(()) } else { Err(RateLimitError) }
2144 }
2145 }
2146
2147 pub fn new_rate_limiter(
2149 rate_per_sec: f64,
2150 burst: usize,
2151 max_entries: usize,
2152 ) -> TokenBucketLimiter {
2153 TokenBucketLimiter::new(rate_per_sec, burst, max_entries)
2154 }
2155
2156 pub fn middleware<L: RateLimiter + 'static>(limiter: L) -> Middleware {
2158 let limiter = Arc::new(limiter);
2159 Arc::new(move |next| {
2160 let limiter = limiter.clone();
2161 Arc::new(move |session| {
2162 let next = next.clone();
2163 let limiter = limiter.clone();
2164 Box::pin(async move {
2165 match limiter.allow(&session) {
2166 Ok(()) => {
2167 next(session).await;
2168 }
2169 Err(err) => {
2170 warn!(remote_addr = %session.remote_addr(), "rate limited");
2171 fatal(&session, err);
2172 }
2173 }
2174 })
2175 })
2176 })
2177 }
2178
2179 pub fn middleware_with_config(config: Config) -> Middleware {
2181 middleware(new_rate_limiter(
2182 config.rate_per_sec,
2183 config.burst,
2184 config.max_entries,
2185 ))
2186 }
2187 }
2188
2189 pub mod elapsed {
2191 use super::*;
2192
2193 fn format_elapsed(format: &str, elapsed: Duration) -> String {
2194 if format.contains("%v") {
2195 format.replace("%v", &format!("{:?}", elapsed))
2196 } else {
2197 format.replace("{}", &format!("{:?}", elapsed)).to_string()
2198 }
2199 }
2200
2201 pub fn middleware_with_format(format: impl Into<String>) -> Middleware {
2203 let format = format.into();
2204 Arc::new(move |next| {
2205 let format = format.clone();
2206 Arc::new(move |session| {
2207 let next = next.clone();
2208 let format = format.clone();
2209 Box::pin(async move {
2210 let start = Instant::now();
2211 next(session.clone()).await;
2212 let msg = format_elapsed(&format, start.elapsed());
2213 print(&session, msg);
2214 })
2215 })
2216 })
2217 }
2218
2219 pub fn middleware() -> Middleware {
2221 middleware_with_format("elapsed time: %v\n")
2222 }
2223 }
2224
2225 pub mod comment {
2227 use super::*;
2228
2229 pub fn middleware(message: impl Into<String>) -> Middleware {
2231 let message = message.into();
2232 Arc::new(move |next| {
2233 let message = message.clone();
2234 Arc::new(move |session| {
2235 let next = next.clone();
2236 let message = message.clone();
2237 Box::pin(async move {
2238 next(session.clone()).await;
2239 println(&session, &message);
2240 })
2241 })
2242 })
2243 }
2244 }
2245}
2246
2247pub mod tea {
2253 use super::*;
2254 use bubbletea::{Model, Program};
2255
2256 pub type TeaHandler<M> = Arc<dyn Fn(&Session) -> M + Send + Sync>;
2258
2259 pub fn middleware<M, F>(handler: F) -> Middleware
2261 where
2262 M: Model + Send + Sync + 'static,
2263 F: Fn(&Session) -> M + Send + Sync + 'static,
2264 {
2265 let handler = Arc::new(handler);
2266 Arc::new(move |next| {
2267 let handler = handler.clone();
2268 Arc::new(move |session| {
2269 let next = next.clone();
2270 let handler = handler.clone();
2271 Box::pin(async move {
2272 let (_pty, active) = session.pty();
2273 if !active {
2274 fatalln(&session, "no active terminal, skipping");
2275 return;
2276 }
2277
2278 let model = handler(&session);
2280
2281 let (tx, rx) = std::sync::mpsc::channel();
2283 session.set_message_sender(tx);
2284
2285 let session_clone = session.clone();
2287 let run_result = tokio::task::spawn_blocking(move || {
2288 let _ = Program::new(model)
2289 .with_custom_io()
2290 .with_input_receiver(rx)
2291 .run_with_writer(session_clone);
2292 })
2293 .await;
2294 if let Err(err) = run_result {
2295 fatalln(&session, format!("bubbletea program crashed: {err}"));
2296 return;
2297 }
2298
2299 next(session).await;
2300 })
2301 })
2302 })
2303 }
2304
2305 pub fn make_renderer(session: &Session) -> lipgloss::Renderer {
2307 let (pty, _) = session.pty();
2308 let term = pty.map(|p| p.term.as_str()).unwrap_or("xterm-256color");
2309
2310 let profile = if term.contains("256color") || term.contains("truecolor") {
2312 lipgloss::ColorProfile::TrueColor
2313 } else if term.contains("color") {
2314 lipgloss::ColorProfile::Ansi256
2315 } else {
2316 lipgloss::ColorProfile::Ansi
2317 };
2318
2319 let mut renderer = lipgloss::Renderer::new();
2320 renderer.set_color_profile(profile);
2321 renderer
2322 }
2323}
2324
2325pub mod prelude {
2331 pub use crate::{
2332 Context, Error, Handler, Middleware, Pty, PublicKey, Result, Server, ServerBuilder,
2333 ServerOption, ServerOptions, Session, Window, compose_middleware, error, errorf, errorln,
2334 fatal, fatalf, fatalln, handler, new_server, noop_handler, print, printf, println,
2335 with_address, with_banner, with_banner_handler, with_host_key_path, with_host_key_pem,
2336 with_idle_timeout, with_keyboard_interactive_auth, with_max_timeout, with_middleware,
2337 with_password_auth, with_public_key_auth, with_subsystem, with_version, write_string,
2338 };
2339
2340 pub use crate::middleware::{
2341 accesscontrol, activeterm, comment, elapsed, logging, ratelimiter, recover,
2342 };
2343
2344 pub use crate::tea;
2345}
2346
2347#[cfg(test)]
2352mod tests {
2353 use super::*;
2354 use std::fmt;
2355 use std::sync::Arc;
2356 use std::sync::Mutex;
2357 use std::sync::atomic::{AtomicUsize, Ordering};
2358
2359 struct DenyLimiter;
2360
2361 impl middleware::ratelimiter::RateLimiter for DenyLimiter {
2362 fn allow(
2363 &self,
2364 _session: &Session,
2365 ) -> std::result::Result<(), middleware::ratelimiter::RateLimitError> {
2366 Err(middleware::ratelimiter::RateLimitError)
2367 }
2368 }
2369
2370 fn record_middleware(label: &'static str, events: Arc<Mutex<Vec<&'static str>>>) -> Middleware {
2371 Arc::new(move |next| {
2372 let events = events.clone();
2373 Arc::new(move |session| {
2374 let next = next.clone();
2375 let events = events.clone();
2376 Box::pin(async move {
2377 {
2378 let mut guard = events.lock().expect("events lock");
2379 guard.push(label);
2380 }
2381 next(session).await;
2382 })
2383 })
2384 })
2385 }
2386
2387 #[derive(Clone)]
2388 struct TestLogger {
2389 entries: Arc<Mutex<Vec<String>>>,
2390 }
2391
2392 impl middleware::logging::Logger for TestLogger {
2393 fn log(&self, format: &str, args: &[&dyn fmt::Display]) {
2394 let mut msg = format.to_string();
2395 for arg in args {
2396 if let Some(pos) = msg.find("{}") {
2397 msg.replace_range(pos..pos + 2, &arg.to_string());
2398 }
2399 }
2400 self.entries.lock().expect("logger entries").push(msg);
2401 }
2402 }
2403
2404 #[derive(Clone, Default)]
2405 struct TestStructuredLogger {
2406 connects: Arc<Mutex<Vec<(String, SocketAddr, bool)>>>,
2407 disconnects: Arc<Mutex<Vec<(String, SocketAddr)>>>,
2408 }
2409
2410 impl middleware::logging::StructuredLogger for TestStructuredLogger {
2411 fn log_connect(
2412 &self,
2413 _level: tracing::Level,
2414 user: &str,
2415 remote_addr: &SocketAddr,
2416 public_key: bool,
2417 _command: &[String],
2418 _term: &str,
2419 _width: u32,
2420 _height: u32,
2421 _client_version: &str,
2422 ) {
2423 self.connects.lock().expect("structured connects").push((
2424 user.to_string(),
2425 *remote_addr,
2426 public_key,
2427 ));
2428 }
2429
2430 fn log_disconnect(
2431 &self,
2432 _level: tracing::Level,
2433 user: &str,
2434 remote_addr: &SocketAddr,
2435 _duration: Duration,
2436 ) {
2437 self.disconnects
2438 .lock()
2439 .expect("structured disconnects")
2440 .push((user.to_string(), *remote_addr));
2441 }
2442 }
2443
2444 #[derive(Clone, Default)]
2445 struct PanicTeaModel;
2446
2447 impl bubbletea::Model for PanicTeaModel {
2448 fn init(&self) -> Option<bubbletea::Cmd> {
2449 None
2450 }
2451
2452 fn update(&mut self, _msg: Message) -> Option<bubbletea::Cmd> {
2453 None
2454 }
2455
2456 fn view(&self) -> String {
2457 std::panic::panic_any("panic from test tea model")
2458 }
2459 }
2460
2461 #[test]
2462 fn test_window_default() {
2463 let window = Window::default();
2464 assert_eq!(window.width, 80);
2465 assert_eq!(window.height, 24);
2466 }
2467
2468 #[test]
2469 fn test_pty_default() {
2470 let pty = Pty::default();
2471 assert_eq!(pty.term, "xterm-256color");
2472 assert_eq!(pty.window.width, 80);
2473 }
2474
2475 #[test]
2476 fn test_public_key() {
2477 let key = PublicKey::new("ssh-ed25519", vec![1, 2, 3, 4]);
2478 assert_eq!(key.key_type, "ssh-ed25519");
2479 assert_eq!(key.data, vec![1, 2, 3, 4]);
2480 assert!(key.comment.is_none());
2481
2482 let key = key.with_comment("test_key_comment");
2483 assert_eq!(key.comment, Some("test_key_comment".to_string()));
2484 }
2485
2486 #[test]
2487 fn test_public_key_fingerprint() {
2488 let key = PublicKey::new("ssh-ed25519", vec![1, 2, 3, 4]);
2489 let fp = key.fingerprint();
2490 assert!(fp.starts_with("HASH:"));
2491 }
2492
2493 #[test]
2494 fn test_public_key_equality() {
2495 let key1 = PublicKey::new("ssh-ed25519", vec![1, 2, 3, 4]);
2496 let key2 = PublicKey::new("ssh-ed25519", vec![1, 2, 3, 4]);
2497 let key3 = PublicKey::new("ssh-ed25519", vec![5, 6, 7, 8]);
2498
2499 assert_eq!(key1, key2);
2500 assert_ne!(key1, key3);
2501 }
2502
2503 #[test]
2504 fn test_context() {
2505 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2506 let ctx = Context::new("testuser", addr, addr);
2507
2508 assert_eq!(ctx.user(), "testuser");
2509 assert_eq!(ctx.remote_addr(), addr);
2510
2511 ctx.set_value("key", "value");
2512 assert_eq!(ctx.get_value("key"), Some("value".to_string()));
2513 assert_eq!(ctx.get_value("missing"), None);
2514 }
2515
2516 #[test]
2517 fn test_session_basic() {
2518 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2519 let ctx = Context::new("testuser", addr, addr);
2520 let session = Session::new(ctx);
2521
2522 assert_eq!(session.user(), "testuser");
2523 assert!(session.command().is_empty());
2524 assert!(session.public_key().is_none());
2525 }
2526
2527 #[test]
2528 fn test_session_builder() {
2529 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2530 let ctx = Context::new("testuser", addr, addr);
2531
2532 let pty = Pty {
2533 term: "xterm".to_string(),
2534 window: Window {
2535 width: 120,
2536 height: 40,
2537 },
2538 };
2539
2540 let session = Session::new(ctx)
2541 .with_pty(pty)
2542 .with_command(vec!["ls".to_string(), "-la".to_string()])
2543 .with_env("HOME", "/home/user");
2544
2545 let (pty_ref, active) = session.pty();
2546 assert!(active);
2547 assert_eq!(pty_ref.unwrap().term, "xterm");
2548 assert_eq!(session.command(), &["ls", "-la"]);
2549 assert_eq!(session.get_env("HOME"), Some(&"/home/user".to_string()));
2550 }
2551
2552 #[test]
2553 fn test_session_write() {
2554 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2555 let ctx = Context::new("testuser", addr, addr);
2556 let session = Session::new(ctx);
2557
2558 let n = session.write(b"hello").unwrap();
2559 assert_eq!(n, 5);
2560
2561 let n = session.write_stderr(b"error").unwrap();
2562 assert_eq!(n, 5);
2563 }
2564
2565 #[test]
2566 fn test_session_exit_close() {
2567 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2568 let ctx = Context::new("testuser", addr, addr);
2569 let session = Session::new(ctx);
2570
2571 assert!(!session.is_closed());
2572 session.exit(0).unwrap();
2573 session.close().unwrap();
2574 assert!(session.is_closed());
2575 }
2576
2577 #[test]
2578 fn test_server_options_default() {
2579 let opts = ServerOptions::default();
2580 assert_eq!(opts.address, "0.0.0.0:22");
2581 assert_eq!(opts.version, "SSH-2.0-Wish");
2582 assert!(opts.banner.is_none());
2583 }
2584
2585 #[test]
2586 fn test_server_builder() {
2587 let server = ServerBuilder::new()
2588 .address("0.0.0.0:2222")
2589 .version("SSH-2.0-MyApp")
2590 .banner("Welcome!")
2591 .idle_timeout(Duration::from_secs(300))
2592 .build()
2593 .unwrap();
2594
2595 assert_eq!(server.address(), "0.0.0.0:2222");
2596 assert_eq!(server.options().version, "SSH-2.0-MyApp");
2597 assert_eq!(server.options().banner, Some("Welcome!".to_string()));
2598 assert_eq!(
2599 server.options().idle_timeout,
2600 Some(Duration::from_secs(300))
2601 );
2602 }
2603
2604 #[test]
2605 fn test_option_functions() {
2606 let mut opts = ServerOptions::default();
2607
2608 with_address("localhost:22")(&mut opts).unwrap();
2609 assert_eq!(opts.address, "localhost:22");
2610
2611 with_version("SSH-2.0-Test")(&mut opts).unwrap();
2612 assert_eq!(opts.version, "SSH-2.0-Test");
2613
2614 with_banner("Hello")(&mut opts).unwrap();
2615 assert_eq!(opts.banner, Some("Hello".to_string()));
2616
2617 with_idle_timeout(Duration::from_secs(60))(&mut opts).unwrap();
2618 assert_eq!(opts.idle_timeout, Some(Duration::from_secs(60)));
2619
2620 with_max_timeout(Duration::from_secs(3600))(&mut opts).unwrap();
2621 assert_eq!(opts.max_timeout, Some(Duration::from_secs(3600)));
2622 }
2623
2624 #[test]
2625 fn test_new_server() {
2626 let server =
2627 new_server([with_address("0.0.0.0:2222"), with_version("SSH-2.0-Test")]).unwrap();
2628
2629 assert_eq!(server.address(), "0.0.0.0:2222");
2630 assert_eq!(server.options().version, "SSH-2.0-Test");
2631 }
2632
2633 #[test]
2634 fn test_noop_handler() {
2635 let h = noop_handler();
2636 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2638 let ctx = Context::new("test", addr, addr);
2639 let session = Session::new(ctx);
2640 drop(h(session));
2641 }
2642
2643 #[tokio::test]
2644 async fn test_handler_creation() {
2645 let h = handler(|_session| async {
2646 });
2648
2649 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2650 let ctx = Context::new("test", addr, addr);
2651 let session = Session::new(ctx);
2652 h(session).await;
2653 }
2654
2655 #[test]
2656 fn test_rate_limiter() {
2657 use middleware::ratelimiter::{RateLimiter, new_rate_limiter};
2658
2659 let limiter = new_rate_limiter(0.0, 3, 10);
2660 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2661 let ctx = Context::new("testuser", addr, addr);
2662 let session = Session::new(ctx);
2663
2664 assert!(limiter.allow(&session).is_ok());
2665 assert!(limiter.allow(&session).is_ok());
2666 assert!(limiter.allow(&session).is_ok());
2667 assert!(limiter.allow(&session).is_err()); }
2669
2670 #[test]
2671 fn test_output_helpers() -> std::result::Result<(), Box<dyn std::error::Error>> {
2672 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2673 let ctx = Context::new("test", addr, addr);
2674 let mut session = Session::new(ctx);
2675
2676 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
2677 session.set_output_sender(tx);
2678
2679 print(&session, "hello");
2680 println(&session, "world");
2681 error(&session, "err");
2682 errorln(&session, "error line");
2683
2684 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2687 match item {
2688 SessionOutput::Stdout(data) => assert_eq!(data, b"hello"),
2689 other => {
2690 return Err(io::Error::other(format!(
2691 "expected stdout for print(), got {other:?}"
2692 ))
2693 .into());
2694 }
2695 }
2696
2697 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2699 match item {
2700 SessionOutput::Stdout(data) => assert_eq!(data, b"world\r\n"),
2701 other => {
2702 return Err(io::Error::other(format!(
2703 "expected stdout for println(), got {other:?}"
2704 ))
2705 .into());
2706 }
2707 }
2708
2709 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2711 match item {
2712 SessionOutput::Stderr(data) => assert_eq!(data, b"err"),
2713 other => {
2714 return Err(io::Error::other(format!(
2715 "expected stderr for error(), got {other:?}"
2716 ))
2717 .into());
2718 }
2719 }
2720
2721 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2723 match item {
2724 SessionOutput::Stderr(data) => assert_eq!(data, b"error line\r\n"),
2725 other => {
2726 return Err(io::Error::other(format!(
2727 "expected stderr for errorln(), got {other:?}"
2728 ))
2729 .into());
2730 }
2731 }
2732
2733 Ok(())
2734 }
2735
2736 #[test]
2737 fn test_tea_make_renderer() {
2738 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2739 let ctx = Context::new("test", addr, addr);
2740 let pty = Pty {
2741 term: "xterm-256color".to_string(),
2742 window: Window::default(),
2743 };
2744 let session = Session::new(ctx).with_pty(pty);
2745
2746 let _renderer = tea::make_renderer(&session);
2747 }
2749
2750 #[tokio::test]
2751 async fn test_tea_middleware_handles_program_panic()
2752 -> std::result::Result<(), Box<dyn std::error::Error>> {
2753 let called = Arc::new(AtomicUsize::new(0));
2754 let mw = tea::middleware(|_session| PanicTeaModel);
2755 let next = handler({
2756 let called = called.clone();
2757 move |_session| {
2758 let called = called.clone();
2759 async move {
2760 called.fetch_add(1, Ordering::SeqCst);
2761 }
2762 }
2763 });
2764
2765 let addr: SocketAddr = "127.0.0.1:2222".parse().map_err(io::Error::other)?;
2766 let ctx = Context::new("test", addr, addr);
2767 let mut session = Session::new(ctx).with_pty(Pty::default());
2768
2769 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
2770 session.set_output_sender(tx);
2771
2772 mw(next)(session).await;
2773
2774 assert_eq!(called.load(Ordering::SeqCst), 0);
2775
2776 let mut saw_fatal = false;
2777 let mut saw_exit = false;
2778 let mut saw_close = false;
2779 loop {
2780 match rx.try_recv() {
2781 Ok(SessionOutput::Stderr(data)) => {
2782 let msg = String::from_utf8_lossy(&data);
2783 if msg.contains("bubbletea program crashed:") {
2784 saw_fatal = true;
2785 }
2786 }
2787 Ok(SessionOutput::Exit(1)) => saw_exit = true,
2788 Ok(SessionOutput::Close) => saw_close = true,
2789 Ok(_) => {}
2790 Err(tokio::sync::mpsc::error::TryRecvError::Empty)
2791 | Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
2792 }
2793 }
2794
2795 assert!(saw_fatal, "expected fatal stderr output for tea panic");
2796 assert!(saw_exit, "expected exit(1) for tea panic");
2797 assert!(saw_close, "expected close signal for tea panic");
2798
2799 Ok(())
2800 }
2801
2802 #[test]
2803 fn test_error_display() {
2804 let err = Error::Io(io::Error::other("test"));
2805 assert!(err.to_string().contains("io error"));
2806
2807 let err = Error::AuthenticationFailed;
2808 assert_eq!(err.to_string(), "authentication failed");
2809
2810 let err = Error::Configuration("bad config".to_string());
2811 assert!(err.to_string().contains("configuration error"));
2812 }
2813
2814 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
2815 async fn test_session_recv_with_input_channel() {
2816 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2817 let ctx = Context::new("testuser", addr, addr);
2818 let session = Session::new(ctx);
2819
2820 assert!(session.recv().await.is_none());
2821
2822 let (tx, rx) = tokio::sync::mpsc::channel(1);
2823 session.set_input_receiver(rx).await;
2824 tx.send(b"ping".to_vec()).await.unwrap();
2825
2826 let received = session.recv().await;
2827 assert_eq!(received, Some(b"ping".to_vec()));
2828 }
2829
2830 #[test]
2831 fn test_session_send_message() {
2832 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2833 let ctx = Context::new("testuser", addr, addr);
2834 let session = Session::new(ctx);
2835
2836 let (tx, rx) = std::sync::mpsc::channel();
2837 session.set_message_sender(tx);
2838 session.send_message(Message::new(42u32));
2839
2840 let msg = rx.recv_timeout(Duration::from_millis(50)).unwrap();
2841 assert!(msg.is::<u32>());
2842 assert_eq!(msg.downcast::<u32>().unwrap(), 42);
2843 }
2844
2845 #[tokio::test]
2846 async fn test_compose_middleware_order() {
2847 let events = Arc::new(Mutex::new(Vec::new()));
2848 let middlewares = vec![
2849 record_middleware("first", events.clone()),
2850 record_middleware("second", events.clone()),
2851 ];
2852 let composed = compose_middleware(middlewares);
2853
2854 let handler = handler({
2855 let events = events.clone();
2856 move |_session| {
2857 let events = events.clone();
2858 async move {
2859 let mut guard = events.lock().expect("events lock");
2860 guard.push("handler");
2861 }
2862 }
2863 });
2864
2865 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2866 let ctx = Context::new("test", addr, addr);
2867 let session = Session::new(ctx);
2868
2869 composed(handler)(session).await;
2870
2871 let events = events.lock().expect("events lock");
2872 assert_eq!(&*events, &["first", "second", "handler"]);
2873 }
2874
2875 #[tokio::test]
2876 async fn test_activeterm_middleware_blocks_without_pty()
2877 -> std::result::Result<(), Box<dyn std::error::Error>> {
2878 let called = Arc::new(AtomicUsize::new(0));
2879 let mw = middleware::activeterm::middleware();
2880 let handler = handler({
2881 let called = called.clone();
2882 move |_session| {
2883 let called = called.clone();
2884 async move {
2885 called.fetch_add(1, Ordering::SeqCst);
2886 }
2887 }
2888 });
2889
2890 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2891 let ctx = Context::new("test", addr, addr);
2892 let mut session = Session::new(ctx);
2893
2894 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
2895 session.set_output_sender(tx);
2896
2897 mw(handler)(session).await;
2898
2899 assert_eq!(called.load(Ordering::SeqCst), 0);
2900
2901 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2902 match item {
2903 SessionOutput::Stdout(data) => assert_eq!(data, b"Requires an active PTY\r\n"),
2904 other => {
2905 return Err(io::Error::other(format!(
2906 "expected stdout warning for activeterm, got {other:?}"
2907 ))
2908 .into());
2909 }
2910 }
2911
2912 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2913 match item {
2914 SessionOutput::Exit(code) => assert_eq!(code, 1),
2915 other => {
2916 return Err(io::Error::other(format!(
2917 "expected exit code for activeterm, got {other:?}"
2918 ))
2919 .into());
2920 }
2921 }
2922
2923 Ok(())
2924 }
2925
2926 #[tokio::test]
2927 async fn test_accesscontrol_middleware_allows_command() {
2928 let called = Arc::new(AtomicUsize::new(0));
2929 let mw = middleware::accesscontrol::middleware(vec!["git".to_string()]);
2930 let handler = handler({
2931 let called = called.clone();
2932 move |_session| {
2933 let called = called.clone();
2934 async move {
2935 called.fetch_add(1, Ordering::SeqCst);
2936 }
2937 }
2938 });
2939
2940 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2941 let ctx = Context::new("test", addr, addr);
2942 let session = Session::new(ctx).with_command(vec!["git".to_string()]);
2943
2944 mw(handler)(session).await;
2945
2946 assert_eq!(called.load(Ordering::SeqCst), 1);
2947 }
2948
2949 #[tokio::test]
2950 async fn test_accesscontrol_middleware_blocks_command()
2951 -> std::result::Result<(), Box<dyn std::error::Error>> {
2952 let called = Arc::new(AtomicUsize::new(0));
2953 let mw = middleware::accesscontrol::middleware(vec!["git".to_string()]);
2954 let handler = handler({
2955 let called = called.clone();
2956 move |_session| {
2957 let called = called.clone();
2958 async move {
2959 called.fetch_add(1, Ordering::SeqCst);
2960 }
2961 }
2962 });
2963
2964 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2965 let ctx = Context::new("test", addr, addr);
2966 let mut session = Session::new(ctx).with_command(vec!["rm".to_string()]);
2967
2968 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
2969 session.set_output_sender(tx);
2970
2971 mw(handler)(session).await;
2972
2973 assert_eq!(called.load(Ordering::SeqCst), 0);
2974
2975 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2976 match item {
2977 SessionOutput::Stdout(data) => assert_eq!(data, b"Command is not allowed: rm\r\n"),
2978 other => {
2979 return Err(io::Error::other(format!(
2980 "expected stdout message for accesscontrol, got {other:?}"
2981 ))
2982 .into());
2983 }
2984 }
2985
2986 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2987 match item {
2988 SessionOutput::Exit(code) => assert_eq!(code, 1),
2989 other => {
2990 return Err(io::Error::other(format!(
2991 "expected exit code for accesscontrol, got {other:?}"
2992 ))
2993 .into());
2994 }
2995 }
2996
2997 Ok(())
2998 }
2999
3000 #[tokio::test]
3001 async fn test_comment_middleware_appends_message()
3002 -> std::result::Result<(), Box<dyn std::error::Error>> {
3003 let mw = middleware::comment::middleware("done");
3004 let handler = handler(|session| async move {
3005 print(&session, "work");
3006 });
3007
3008 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3009 let ctx = Context::new("test", addr, addr);
3010 let mut session = Session::new(ctx);
3011
3012 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
3013 session.set_output_sender(tx);
3014
3015 mw(handler)(session).await;
3016
3017 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3018 match item {
3019 SessionOutput::Stdout(data) => assert_eq!(data, b"work"),
3020 other => {
3021 return Err(io::Error::other(format!(
3022 "expected stdout for handler output, got {other:?}"
3023 ))
3024 .into());
3025 }
3026 }
3027
3028 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3029 match item {
3030 SessionOutput::Stdout(data) => assert_eq!(data, b"done\r\n"),
3031 other => {
3032 return Err(io::Error::other(format!(
3033 "expected stdout for comment output, got {other:?}"
3034 ))
3035 .into());
3036 }
3037 }
3038
3039 Ok(())
3040 }
3041
3042 #[tokio::test]
3043 async fn test_elapsed_middleware_outputs_timing()
3044 -> std::result::Result<(), Box<dyn std::error::Error>> {
3045 let mw = middleware::elapsed::middleware_with_format("elapsed=%v");
3046 let handler = handler(|_session| async move {});
3047
3048 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3049 let ctx = Context::new("test", addr, addr);
3050 let mut session = Session::new(ctx);
3051
3052 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
3053 session.set_output_sender(tx);
3054
3055 mw(handler)(session).await;
3056
3057 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3058 match item {
3059 SessionOutput::Stdout(data) => {
3060 let msg = String::from_utf8_lossy(&data);
3061 assert!(msg.contains("elapsed="));
3062 }
3063 other => {
3064 return Err(io::Error::other(format!(
3065 "expected stdout for elapsed middleware, got {other:?}"
3066 ))
3067 .into());
3068 }
3069 }
3070
3071 Ok(())
3072 }
3073
3074 #[tokio::test]
3075 async fn test_ratelimiter_middleware_rejects()
3076 -> std::result::Result<(), Box<dyn std::error::Error>> {
3077 let called = Arc::new(AtomicUsize::new(0));
3078 let mw = middleware::ratelimiter::middleware(DenyLimiter);
3079 let handler = handler({
3080 let called = called.clone();
3081 move |_session| {
3082 let called = called.clone();
3083 async move {
3084 called.fetch_add(1, Ordering::SeqCst);
3085 }
3086 }
3087 });
3088
3089 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3090 let ctx = Context::new("test", addr, addr);
3091 let mut session = Session::new(ctx);
3092
3093 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
3094 session.set_output_sender(tx);
3095
3096 mw(handler)(session).await;
3097
3098 assert_eq!(called.load(Ordering::SeqCst), 0);
3099
3100 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3101 match item {
3102 SessionOutput::Stderr(data) => {
3103 assert_eq!(
3104 data,
3105 middleware::ratelimiter::ERR_RATE_LIMIT_EXCEEDED.as_bytes()
3106 );
3107 }
3108 other => {
3109 return Err(io::Error::other(format!(
3110 "expected stderr for ratelimiter, got {other:?}"
3111 ))
3112 .into());
3113 }
3114 }
3115
3116 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3117 match item {
3118 SessionOutput::Exit(code) => assert_eq!(code, 1),
3119 other => {
3120 return Err(io::Error::other(format!(
3121 "expected exit for ratelimiter, got {other:?}"
3122 ))
3123 .into());
3124 }
3125 }
3126
3127 let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3128 match item {
3129 SessionOutput::Close => {}
3130 other => {
3131 return Err(io::Error::other(format!(
3132 "expected close for ratelimiter, got {other:?}"
3133 ))
3134 .into());
3135 }
3136 }
3137
3138 Ok(())
3139 }
3140
3141 #[tokio::test]
3142 async fn test_logging_middleware_with_custom_logger() {
3143 let entries = Arc::new(Mutex::new(Vec::new()));
3144 let logger = TestLogger {
3145 entries: entries.clone(),
3146 };
3147
3148 let mw = middleware::logging::middleware_with_logger(logger);
3149 let handler = handler(|_session| async move {});
3150
3151 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3152 let ctx = Context::new("alice", addr, addr);
3153 let session = Session::new(ctx);
3154
3155 mw(handler)(session).await;
3156
3157 let entries = entries.lock().expect("logger entries");
3158 assert_eq!(entries.len(), 2);
3159 assert!(entries[0].contains("connect"));
3160 assert!(entries[1].contains("disconnect"));
3161 }
3162
3163 #[tokio::test]
3164 async fn test_structured_logging_middleware_with_custom_logger() {
3165 let logger = TestStructuredLogger::default();
3166 let mw = middleware::logging::structured_middleware_with_logger(
3167 logger.clone(),
3168 tracing::Level::INFO,
3169 );
3170 let handler = handler(|_session| async move {});
3171
3172 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3173 let ctx = Context::new("alice", addr, addr);
3174 let session = Session::new(ctx).with_public_key(PublicKey::new("ssh-ed25519", vec![1]));
3175
3176 mw(handler)(session).await;
3177
3178 let connects = logger.connects.lock().expect("connects");
3179 assert_eq!(connects.len(), 1);
3180 assert_eq!(connects[0].0, "alice");
3181 assert_eq!(connects[0].1, addr);
3182 assert!(connects[0].2);
3183
3184 let disconnects = logger.disconnects.lock().expect("disconnects");
3185 assert_eq!(disconnects.len(), 1);
3186 assert_eq!(disconnects[0].0, "alice");
3187 assert_eq!(disconnects[0].1, addr);
3188 }
3189
3190 #[tokio::test]
3191 async fn test_recover_middleware_runs_inner_before_next() {
3192 let events = Arc::new(Mutex::new(Vec::new()));
3193 let inner = record_middleware("inner", events.clone());
3194 let mw = middleware::recover::middleware_with_middlewares(vec![inner]);
3195
3196 let handler = handler({
3197 let events = events.clone();
3198 move |_session| {
3199 let events = events.clone();
3200 async move {
3201 let mut guard = events.lock().expect("events lock");
3202 guard.push("handler");
3203 }
3204 }
3205 });
3206
3207 let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3208 let ctx = Context::new("test", addr, addr);
3209 let session = Session::new(ctx);
3210
3211 mw(handler)(session).await;
3212
3213 let events = events.lock().expect("events lock");
3214 assert_eq!(&*events, &["inner", "handler"]);
3215 }
3216
#[test]
fn test_server_option_auth_and_subsystem() {
    // Apply every auth, host-key, banner, middleware and subsystem option to a
    // default `ServerOptions`, then check that each one landed in its field.
    let mut options = ServerOptions::default();

    with_auth_handler(AcceptAllAuth::new())(&mut options).unwrap();
    with_max_auth_attempts(3)(&mut options).unwrap();
    with_auth_rejection_delay(250)(&mut options).unwrap();
    with_public_key_auth(|_ctx, _key| true)(&mut options).unwrap();
    with_password_auth(|_ctx, _pw| true)(&mut options).unwrap();
    let keyboard_interactive =
        with_keyboard_interactive_auth(|_ctx, _resp, _prompts, _echos| vec!["ok".to_string()]);
    keyboard_interactive(&mut options).unwrap();
    with_host_key_path("/tmp/wish_host_file")(&mut options).unwrap();
    with_host_key_pem(b"test_key_data".to_vec())(&mut options).unwrap();
    with_banner_handler(|ctx| format!("hello {}", ctx.user()))(&mut options).unwrap();
    with_middleware(middleware::comment::middleware("hi"))(&mut options).unwrap();
    with_subsystem("sftp", |_session| async move {})(&mut options).unwrap();

    assert!(options.auth_handler.is_some());
    assert_eq!(options.max_auth_attempts, 3);
    assert_eq!(options.auth_rejection_delay_ms, 250);
    assert!(options.public_key_handler.is_some());
    assert!(options.password_handler.is_some());
    assert!(options.keyboard_interactive_handler.is_some());
    assert_eq!(options.host_key_path.as_deref(), Some("/tmp/wish_host_file"));
    assert_eq!(options.host_key_pem.as_deref(), Some(&b"test_key_data"[..]));
    assert!(options.banner_handler.is_some());
    assert_eq!(options.middlewares.len(), 1);
    assert!(options.subsystem_handlers.contains_key("sftp"));
}
3251
#[test]
fn test_server_builder_auth_settings() {
    // Auth-related builder settings should be reflected in the built server's options.
    let server = ServerBuilder::new()
        .address("127.0.0.1:2222")
        .max_auth_attempts(5)
        .auth_rejection_delay(123)
        .public_key_auth(|_ctx, _key| true)
        .password_auth(|_ctx, _pw| true)
        .keyboard_interactive_auth(|_ctx, _resp, _prompts, _echos| vec![])
        .subsystem("sftp", |_session| async move {})
        .build()
        .unwrap();

    let opts = server.options();
    assert_eq!(opts.max_auth_attempts, 5);
    assert_eq!(opts.auth_rejection_delay_ms, 123);
    assert!(opts.public_key_handler.is_some());
    assert!(opts.password_handler.is_some());
    assert!(opts.keyboard_interactive_handler.is_some());
    assert!(opts.subsystem_handlers.contains_key("sftp"));
}
3272
#[test]
fn test_create_russh_config_methods_from_auth_handler() {
    use russh::MethodSet;

    // Auth handler that advertises password authentication and nothing else.
    struct PasswordOnlyAuth;

    #[async_trait::async_trait]
    impl AuthHandler for PasswordOnlyAuth {
        fn supported_methods(&self) -> Vec<AuthMethod> {
            vec![AuthMethod::Password]
        }
    }

    // The russh method set should mirror what the auth handler advertises.
    let server = ServerBuilder::new()
        .auth_handler(PasswordOnlyAuth)
        .build()
        .unwrap();
    let config = server.create_russh_config().unwrap();
    let methods = &config.methods;

    assert!(methods.contains(MethodSet::PASSWORD));
    assert!(!methods.contains(MethodSet::PUBLICKEY));
}
3295
#[test]
fn test_create_russh_config_methods_from_callbacks() {
    use russh::MethodSet;

    // Registering public-key and password callbacks enables exactly those two
    // methods; keyboard-interactive stays off because no callback was given.
    let server = ServerBuilder::new()
        .public_key_auth(|_ctx, _key| true)
        .password_auth(|_ctx, _pw| true)
        .build()
        .unwrap();
    let config = server.create_russh_config().unwrap();
    let methods = &config.methods;

    assert!(methods.contains(MethodSet::PUBLICKEY));
    assert!(methods.contains(MethodSet::PASSWORD));
    assert!(!methods.contains(MethodSet::KEYBOARD_INTERACTIVE));
}
3312}