Skip to main content

wish/
lib.rs

1#![forbid(unsafe_code)]
2// Allow pedantic lints for early-stage API ergonomics.
3#![allow(clippy::doc_markdown)]
4#![allow(clippy::nursery)]
5#![allow(clippy::pedantic)]
6
7//! # Wish
8//!
9//! A library for building SSH applications with TUI interfaces.
10//!
11//! Wish enables you to create SSH servers that serve interactive
12//! terminal applications, making it easy to build:
13//! - SSH-accessible TUI apps
14//! - Git servers with custom interfaces
15//! - Multi-user terminal experiences
16//! - Secure remote access tools
17//!
18//! ## Role in `charmed_rust`
19//!
20//! Wish is the SSH application layer for bubbletea programs:
21//! - **bubbletea** provides the program runtime served over SSH.
22//! - **charmed_log** supplies structured logging for sessions.
23//! - **demo_showcase** includes an SSH mode to demonstrate remote TUIs.
24//!
25//! ## Features
26//!
27//! - **Middleware pattern**: Compose handlers with chainable middleware
28//! - **PTY support**: Full pseudo-terminal emulation
29//! - **Authentication**: Public key, password, and keyboard-interactive auth
30//! - **BubbleTea integration**: Serve TUI apps over SSH
31//! - **Logging middleware**: Connection logging out of the box
32//! - **Access control**: Restrict allowed commands
33//!
34//! ## Example
35//!
36//! ```rust,ignore
37//! use wish::{Server, ServerBuilder};
38//! use wish::middleware::{logging, activeterm};
39//!
40//! #[tokio::main]
41//! async fn main() -> Result<(), wish::Error> {
42//!     let server = ServerBuilder::new()
43//!         .address("0.0.0.0:2222")
44//!         .with_middleware(logging::middleware())
45//!         .with_middleware(activeterm::middleware())
46//!         .handler(|session| async move {
47//!             wish::println(&session, "Hello, SSH!");
48//!         })
49//!         .build()
50//!         .await?;
51//!
52//!     server.listen().await
53//! }
54//! ```
55
56use std::collections::HashMap;
57use std::fmt;
58use std::future::Future;
59use std::io::{self, Write};
60use std::net::SocketAddr;
61use std::pin::Pin;
62use std::sync::Arc;
63use std::sync::mpsc::Sender;
64use std::time::Duration;
65
66use bubbletea::Message;
67use parking_lot::RwLock;
68use thiserror::Error;
69use tokio::net::TcpListener;
70use tracing::{debug, error, info, warn};
71
72pub mod auth;
73mod handler;
74pub mod session;
75
76pub use auth::{
77    AcceptAllAuth, AsyncCallbackAuth, AsyncPublicKeyAuth, AuthContext, AuthHandler, AuthMethod,
78    AuthResult, AuthorizedKey, AuthorizedKeysAuth, CallbackAuth, CompositeAuth, PasswordAuth,
79    PublicKeyAuth, PublicKeyCallbackAuth, RateLimitedAuth, SessionId, parse_authorized_keys,
80};
81pub use handler::{RusshConfig, ServerState, WishHandler, WishHandlerFactory, run_stream};
82
83// Re-export dependencies for convenience
84pub use bubbletea;
85pub use lipgloss;
86
87// -----------------------------------------------------------------------------
88// Error Types
89// -----------------------------------------------------------------------------
90
91/// Errors that can occur in the wish SSH server library.
92///
93/// This enum represents all possible error conditions when running
94/// an SSH server with wish.
95///
96/// # Error Handling
97///
98/// SSH server errors range from configuration issues to runtime
99/// authentication failures. Use the `?` operator for propagation:
100///
101/// ```rust,ignore
102/// use wish::Result;
103///
104/// async fn run_server() -> Result<()> {
105///     let server = Server::new(handler).await?;
106///     server.listen("0.0.0.0:2222").await?;
107///     Ok(())
108/// }
109/// ```
110///
111/// # Recovery Strategies
112///
113/// | Error Variant | Recovery Strategy |
114/// |--------------|-------------------|
115/// | [`Io`](Error::Io) | Check permissions, port availability |
116/// | [`Ssh`](Error::Ssh) | Log and continue for recoverable errors |
117/// | [`Russh`](Error::Russh) | Check SSH protocol compatibility |
118/// | [`Key`](Error::Key) | Regenerate keys or check permissions |
119/// | [`KeyLoad`](Error::KeyLoad) | Verify key file format |
120/// | [`AuthenticationFailed`](Error::AuthenticationFailed) | Expected for invalid credentials |
121/// | [`MaxSessionsReached`](Error::MaxSessionsReached) | Retry later or raise configured session limit |
122/// | [`Configuration`](Error::Configuration) | Fix server configuration |
123/// | [`Session`](Error::Session) | Close session gracefully |
124/// | [`AddrParse`](Error::AddrParse) | Validate address format |
125#[derive(Error, Debug)]
126pub enum Error {
127    /// I/O error during server operations.
128    ///
129    /// Commonly occurs when:
130    /// - The bind address is already in use
131    /// - Permission denied on privileged ports
132    /// - Network interface is unavailable
133    #[error("io error: {0}")]
134    Io(#[from] io::Error),
135
136    /// SSH protocol error.
137    ///
138    /// General SSH protocol-level errors. Contains a descriptive message.
139    #[error("ssh error: {0}")]
140    Ssh(String),
141
142    /// Underlying russh library error.
143    ///
144    /// Wraps errors from the russh SSH implementation.
145    #[error("russh error: {0}")]
146    Russh(#[from] russh::Error),
147
148    /// Key generation or management error.
149    ///
150    /// Occurs when generating or manipulating SSH keys fails.
151    #[error("key error: {0}")]
152    Key(String),
153
154    /// Key loading error from russh-keys.
155    ///
156    /// Occurs when loading SSH keys from files fails.
157    /// Common causes: file not found, invalid format, permission denied.
158    #[error("key loading error: {0}")]
159    KeyLoad(#[from] russh_keys::Error),
160
161    /// Authentication failed.
162    ///
163    /// Occurs when a client's credentials are rejected.
164    /// This is expected in normal operation - not all attempts succeed.
165    #[error("authentication failed")]
166    AuthenticationFailed,
167
168    /// Maximum concurrent sessions reached.
169    ///
170    /// Returned when attempting to create a new session while at capacity.
171    #[error("maximum sessions reached ({current}/{max})")]
172    MaxSessionsReached {
173        /// Configured maximum concurrent sessions.
174        max: usize,
175        /// Current active session count at rejection time.
176        current: usize,
177    },
178
179    /// Server configuration error.
180    ///
181    /// Occurs when the server configuration is invalid.
182    #[error("configuration error: {0}")]
183    Configuration(String),
184
185    /// Session error.
186    ///
187    /// Occurs during an active SSH session.
188    #[error("session error: {0}")]
189    Session(String),
190
191    /// Address parse error.
192    ///
193    /// Occurs when parsing a socket address fails.
194    #[error("address parse error: {0}")]
195    AddrParse(#[from] std::net::AddrParseError),
196}
197
198/// A specialized [`Result`] type for wish operations.
199///
200/// This type alias defaults to [`enum@Error`] as the error type.
201pub type Result<T> = std::result::Result<T, Error>;
202
203// -----------------------------------------------------------------------------
204// PTY Types
205// -----------------------------------------------------------------------------
206
/// Dimensions of a terminal window, measured in character cells.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Window {
    /// Number of columns.
    pub width: u32,
    /// Number of rows.
    pub height: u32,
}

impl Default for Window {
    /// Returns an 80-column by 24-row window.
    fn default() -> Self {
        let (width, height) = (80, 24);
        Self { width, height }
    }
}
224
225/// Pseudo-terminal information.
226#[derive(Debug, Clone)]
227pub struct Pty {
228    /// Terminal type (e.g., "xterm-256color").
229    pub term: String,
230    /// Window dimensions.
231    pub window: Window,
232}
233
234impl Default for Pty {
235    fn default() -> Self {
236        Self {
237            term: "xterm-256color".to_string(),
238            window: Window::default(),
239        }
240    }
241}
242
243// -----------------------------------------------------------------------------
244// Public Key Types
245// -----------------------------------------------------------------------------
246
/// A public key presented by a client for authentication.
#[derive(Debug, Clone)]
pub struct PublicKey {
    /// The key type (e.g., "ssh-ed25519", "ssh-rsa").
    pub key_type: String,
    /// The raw key bytes.
    pub data: Vec<u8>,
    /// Optional comment from the authorized_keys file.
    pub comment: Option<String>,
}

impl PublicKey {
    /// Builds a key of the given type from raw bytes, without a comment.
    pub fn new(key_type: impl Into<String>, data: Vec<u8>) -> Self {
        let key_type = key_type.into();
        Self {
            key_type,
            data,
            comment: None,
        }
    }

    /// Attaches a comment to this key and returns it.
    pub fn with_comment(self, comment: impl Into<String>) -> Self {
        Self {
            comment: Some(comment.into()),
            ..self
        }
    }

    /// Returns a short fingerprint derived from the key bytes.
    ///
    /// Note: the digest comes from `DefaultHasher`, which is not a
    /// cryptographic hash. The `HASH:` prefix is a display convention only.
    pub fn fingerprint(&self) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let digest = {
            let mut state = DefaultHasher::new();
            self.data.hash(&mut state);
            state.finish()
        };
        format!("HASH:{:016x}", digest)
    }
}

impl PartialEq for PublicKey {
    /// Two keys are equal when type and bytes match; the comment is ignored.
    fn eq(&self, other: &Self) -> bool {
        if self.key_type != other.key_type {
            return false;
        }
        self.data == other.data
    }
}

impl Eq for PublicKey {}
294
295// -----------------------------------------------------------------------------
296// Context
297// -----------------------------------------------------------------------------
298
299/// Context passed to authentication handlers.
300#[derive(Debug, Clone)]
301pub struct Context {
302    /// The username attempting authentication.
303    user: String,
304    /// The remote address.
305    remote_addr: SocketAddr,
306    /// The local address.
307    local_addr: SocketAddr,
308    /// The client version string.
309    client_version: String,
310    /// Custom values stored in the context.
311    values: Arc<RwLock<HashMap<String, String>>>,
312}
313
314impl Context {
315    /// Creates a new context.
316    pub fn new(user: impl Into<String>, remote_addr: SocketAddr, local_addr: SocketAddr) -> Self {
317        Self {
318            user: user.into(),
319            remote_addr,
320            local_addr,
321            client_version: String::new(),
322            values: Arc::new(RwLock::new(HashMap::new())),
323        }
324    }
325
326    /// Returns the username.
327    pub fn user(&self) -> &str {
328        &self.user
329    }
330
331    /// Returns the remote address.
332    pub fn remote_addr(&self) -> SocketAddr {
333        self.remote_addr
334    }
335
336    /// Returns the local address.
337    pub fn local_addr(&self) -> SocketAddr {
338        self.local_addr
339    }
340
341    /// Returns the client version string.
342    pub fn client_version(&self) -> &str {
343        &self.client_version
344    }
345
346    /// Sets the client version string.
347    pub fn set_client_version(&mut self, version: impl Into<String>) {
348        self.client_version = version.into();
349    }
350
351    /// Sets a value in the context.
352    pub fn set_value(&self, key: impl Into<String>, value: impl Into<String>) {
353        self.values.write().insert(key.into(), value.into());
354    }
355
356    /// Gets a value from the context.
357    pub fn get_value(&self, key: &str) -> Option<String> {
358        self.values.read().get(key).cloned()
359    }
360}
361
362// -----------------------------------------------------------------------------
363// Session
364// -----------------------------------------------------------------------------
365
366/// An SSH session representing a connected client.
367#[derive(Clone)]
368pub struct Session {
369    /// The session context.
370    context: Context,
371    /// The PTY if allocated.
372    pty: Option<Pty>,
373    /// The command being executed (if any).
374    command: Vec<String>,
375    /// Environment variables.
376    env: HashMap<String, String>,
377    /// Output buffer for stdout.
378    #[allow(dead_code)]
379    pub(crate) stdout: Arc<RwLock<Vec<u8>>>,
380    /// Output buffer for stderr.
381    #[allow(dead_code)]
382    pub(crate) stderr: Arc<RwLock<Vec<u8>>>,
383    /// Exit code.
384    exit_code: Arc<RwLock<Option<i32>>>,
385    /// Whether the session is closed.
386    closed: Arc<RwLock<bool>>,
387    /// The public key used for authentication (if any).
388    public_key: Option<PublicKey>,
389    /// Subsystem being used (if any).
390    subsystem: Option<String>,
391
392    /// Channel for sending output to the client.
393    output_tx: Option<tokio::sync::mpsc::UnboundedSender<SessionOutput>>,
394    /// Channel for receiving input from the client.
395    input_rx: Arc<tokio::sync::Mutex<Option<tokio::sync::mpsc::Receiver<Vec<u8>>>>>,
396    /// Channel for injecting messages into the running bubbletea program.
397    message_tx: Arc<RwLock<Option<Sender<Message>>>>,
398}
399
/// Items a [`Session`] sends towards the SSH channel.
#[derive(Debug)]
pub enum SessionOutput {
    /// Bytes destined for standard output.
    Stdout(Vec<u8>),
    /// Bytes destined for standard error.
    Stderr(Vec<u8>),
    /// The exit status code to report.
    Exit(u32),
    /// Request to close the channel.
    Close,
}
412
413impl fmt::Debug for Session {
414    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
415        f.debug_struct("Session")
416            .field("user", &self.context.user)
417            .field("remote_addr", &self.context.remote_addr)
418            .field("pty", &self.pty)
419            .field("command", &self.command)
420            .finish()
421    }
422}
423
424impl Session {
425    /// Creates a new session.
426    pub fn new(context: Context) -> Self {
427        Self {
428            context,
429            pty: None,
430            command: Vec::new(),
431            env: HashMap::new(),
432            stdout: Arc::new(RwLock::new(Vec::new())),
433            stderr: Arc::new(RwLock::new(Vec::new())),
434            exit_code: Arc::new(RwLock::new(None)),
435            closed: Arc::new(RwLock::new(false)),
436            public_key: None,
437            subsystem: None,
438            output_tx: None,
439            input_rx: Arc::new(tokio::sync::Mutex::new(None)),
440            message_tx: Arc::new(RwLock::new(None)),
441        }
442    }
443
444    /// Sets the output sender.
445    pub fn set_output_sender(&mut self, tx: tokio::sync::mpsc::UnboundedSender<SessionOutput>) {
446        self.output_tx = Some(tx);
447    }
448
449    /// Sets the input receiver.
450    pub async fn set_input_receiver(&self, rx: tokio::sync::mpsc::Receiver<Vec<u8>>) {
451        *self.input_rx.lock().await = Some(rx);
452    }
453
454    /// Receives input from the client.
455    pub async fn recv(&self) -> Option<Vec<u8>> {
456        let mut rx_guard = self.input_rx.lock().await;
457        if let Some(rx) = rx_guard.as_mut() {
458            rx.recv().await
459        } else {
460            None
461        }
462    }
463
464    /// Sets the message sender for the bubbletea program.
465    pub fn set_message_sender(&self, tx: Sender<Message>) {
466        *self.message_tx.write() = Some(tx);
467    }
468
469    /// Sends a message to the bubbletea program (if running).
470    pub fn send_message(&self, msg: Message) {
471        if let Some(tx) = self.message_tx.read().as_ref() {
472            // We ignore errors because if the channel is closed, the program is gone
473            let _ = tx.send(msg);
474        }
475    }
476
477    /// Returns the username.
478    pub fn user(&self) -> &str {
479        self.context.user()
480    }
481
482    /// Returns the remote address.
483    pub fn remote_addr(&self) -> SocketAddr {
484        self.context.remote_addr()
485    }
486
487    /// Returns the local address.
488    pub fn local_addr(&self) -> SocketAddr {
489        self.context.local_addr()
490    }
491
492    /// Returns the context.
493    pub fn context(&self) -> &Context {
494        &self.context
495    }
496
497    /// Returns the PTY and window change channel if allocated.
498    pub fn pty(&self) -> (Option<&Pty>, bool) {
499        (self.pty.as_ref(), self.pty.is_some())
500    }
501
502    /// Returns the command being executed.
503    pub fn command(&self) -> &[String] {
504        &self.command
505    }
506
507    /// Returns an environment variable.
508    pub fn get_env(&self, key: &str) -> Option<&String> {
509        self.env.get(key)
510    }
511
512    /// Returns all environment variables.
513    pub fn environ(&self) -> &HashMap<String, String> {
514        &self.env
515    }
516
517    /// Returns the public key used for authentication.
518    pub fn public_key(&self) -> Option<&PublicKey> {
519        self.public_key.as_ref()
520    }
521
522    /// Returns the subsystem being used.
523    pub fn subsystem(&self) -> Option<&str> {
524        self.subsystem.as_deref()
525    }
526
527    /// Writes to stdout.
528    pub fn write(&self, data: &[u8]) -> io::Result<usize> {
529        // Send to client
530        if let Some(tx) = &self.output_tx {
531            let _ = tx.send(SessionOutput::Stdout(data.to_vec()));
532        }
533
534        Ok(data.len())
535    }
536
537    /// Writes to stderr.
538    pub fn write_stderr(&self, data: &[u8]) -> io::Result<usize> {
539        // Send to client
540        if let Some(tx) = &self.output_tx {
541            let _ = tx.send(SessionOutput::Stderr(data.to_vec()));
542        }
543
544        Ok(data.len())
545    }
546
547    /// Exits the session with the given code.
548    pub fn exit(&self, code: i32) -> io::Result<()> {
549        *self.exit_code.write() = Some(code);
550        if let Some(tx) = &self.output_tx {
551            let _ = tx.send(SessionOutput::Exit(code as u32));
552        }
553        Ok(())
554    }
555
556    /// Closes the session.
557    pub fn close(&self) -> io::Result<()> {
558        *self.closed.write() = true;
559        if let Some(tx) = &self.output_tx {
560            let _ = tx.send(SessionOutput::Close);
561        }
562        Ok(())
563    }
564
565    /// Returns whether the session is closed.
566    pub fn is_closed(&self) -> bool {
567        *self.closed.read()
568    }
569
570    /// Returns the current window size.
571    pub fn window(&self) -> Window {
572        self.pty.as_ref().map(|p| p.window).unwrap_or_default()
573    }
574
575    // Builder methods for constructing sessions
576
577    /// Sets the PTY.
578    pub fn with_pty(mut self, pty: Pty) -> Self {
579        self.pty = Some(pty);
580        self
581    }
582
583    /// Sets the command.
584    pub fn with_command(mut self, command: Vec<String>) -> Self {
585        self.command = command;
586        self
587    }
588
589    /// Sets an environment variable.
590    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
591        self.env.insert(key.into(), value.into());
592        self
593    }
594
595    /// Sets the public key.
596    pub fn with_public_key(mut self, key: PublicKey) -> Self {
597        self.public_key = Some(key);
598        self
599    }
600
601    /// Sets the subsystem.
602    pub fn with_subsystem(mut self, subsystem: impl Into<String>) -> Self {
603        self.subsystem = Some(subsystem.into());
604        self
605    }
606}
607
608// Implement Write for Session
609impl Write for Session {
610    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
611        Session::write(self, buf)
612    }
613
614    fn flush(&mut self) -> io::Result<()> {
615        Ok(())
616    }
617}
618
619// -----------------------------------------------------------------------------
620// Output Helper Functions
621// -----------------------------------------------------------------------------
622
623/// Writes to the session's stdout.
624pub fn print(session: &Session, args: impl fmt::Display) {
625    let _ = session.write(args.to_string().as_bytes());
626}
627
628/// Writes to the session's stdout with a newline.
629pub fn println(session: &Session, args: impl fmt::Display) {
630    let msg = format!("{}\r\n", args);
631    let _ = session.write(msg.as_bytes());
632}
633
634/// Writes formatted output to the session's stdout.
635pub fn printf(session: &Session, format: impl fmt::Display, args: &[&dyn fmt::Display]) {
636    let mut msg = format.to_string();
637    for arg in args {
638        if let Some(pos) = msg.find("{}") {
639            msg.replace_range(pos..pos + 2, &arg.to_string());
640        }
641    }
642    let _ = session.write(msg.as_bytes());
643}
644
645/// Writes to the session's stderr.
646pub fn error(session: &Session, args: impl fmt::Display) {
647    let _ = session.write_stderr(args.to_string().as_bytes());
648}
649
650/// Writes to the session's stderr with a newline.
651pub fn errorln(session: &Session, args: impl fmt::Display) {
652    let msg = format!("{}\r\n", args);
653    let _ = session.write_stderr(msg.as_bytes());
654}
655
656/// Writes formatted output to the session's stderr.
657pub fn errorf(session: &Session, format: impl fmt::Display, args: &[&dyn fmt::Display]) {
658    let mut msg = format.to_string();
659    for arg in args {
660        if let Some(pos) = msg.find("{}") {
661            msg.replace_range(pos..pos + 2, &arg.to_string());
662        }
663    }
664    let _ = session.write_stderr(msg.as_bytes());
665}
666
667/// Writes to stderr and exits with code 1.
668pub fn fatal(session: &Session, args: impl fmt::Display) {
669    error(session, args);
670    let _ = session.exit(1);
671    let _ = session.close();
672}
673
674/// Writes to stderr with a newline and exits with code 1.
675pub fn fatalln(session: &Session, args: impl fmt::Display) {
676    errorln(session, args);
677    let _ = session.exit(1);
678    let _ = session.close();
679}
680
681/// Writes formatted output to stderr and exits with code 1.
682pub fn fatalf(session: &Session, format: impl fmt::Display, args: &[&dyn fmt::Display]) {
683    errorf(session, format, args);
684    let _ = session.exit(1);
685    let _ = session.close();
686}
687
688/// Writes a string to the session's stdout.
689pub fn write_string(session: &Session, s: &str) -> io::Result<usize> {
690    session.write(s.as_bytes())
691}
692
693// -----------------------------------------------------------------------------
694// Handler and Middleware
695// -----------------------------------------------------------------------------
696
697/// A boxed future for async handlers.
698pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
699
700/// Handler function type.
701pub type Handler = Arc<dyn Fn(Session) -> BoxFuture<'static, ()> + Send + Sync>;
702
703/// Middleware function type.
704pub type Middleware = Arc<dyn Fn(Handler) -> Handler + Send + Sync>;
705
706/// Creates a handler from an async function.
707pub fn handler<F, Fut>(f: F) -> Handler
708where
709    F: Fn(Session) -> Fut + Send + Sync + 'static,
710    Fut: Future<Output = ()> + Send + 'static,
711{
712    Arc::new(move |session| Box::pin(f(session)))
713}
714
715/// Creates a no-op handler.
716pub fn noop_handler() -> Handler {
717    Arc::new(|_| Box::pin(async {}))
718}
719
720/// Composes multiple middleware into a single middleware.
721pub fn compose_middleware(middlewares: Vec<Middleware>) -> Middleware {
722    Arc::new(move |h| {
723        let mut handler = h;
724        for mw in middlewares.iter().rev() {
725            handler = mw(handler);
726        }
727        handler
728    })
729}
730
731// -----------------------------------------------------------------------------
732// Authentication Handlers
733// -----------------------------------------------------------------------------
734
735/// Public key authentication handler.
736pub type PublicKeyHandler = Arc<dyn Fn(&Context, &PublicKey) -> bool + Send + Sync>;
737
738/// Password authentication handler.
739pub type PasswordHandler = Arc<dyn Fn(&Context, &str) -> bool + Send + Sync>;
740
741/// Keyboard-interactive authentication handler.
742pub type KeyboardInteractiveHandler =
743    Arc<dyn Fn(&Context, &str, &[String], &[bool]) -> Vec<String> + Send + Sync>;
744
745/// Banner handler that returns a banner based on context.
746pub type BannerHandler = Arc<dyn Fn(&Context) -> String + Send + Sync>;
747
748/// Subsystem handler.
749pub type SubsystemHandler = Arc<dyn Fn(Session) -> BoxFuture<'static, ()> + Send + Sync>;
750
751// -----------------------------------------------------------------------------
752// Server Options
753// -----------------------------------------------------------------------------
754
755/// Options for configuring the SSH server.
756#[derive(Clone)]
757pub struct ServerOptions {
758    /// Listen address.
759    pub address: String,
760    /// Server version string.
761    pub version: String,
762    /// Static banner.
763    pub banner: Option<String>,
764    /// Dynamic banner handler.
765    pub banner_handler: Option<BannerHandler>,
766    /// Host key path.
767    pub host_key_path: Option<String>,
768    /// Host key PEM data.
769    pub host_key_pem: Option<Vec<u8>>,
770    /// Middlewares to apply.
771    pub middlewares: Vec<Middleware>,
772    /// Main handler.
773    pub handler: Option<Handler>,
774    /// Trait-based authentication handler.
775    /// If set, takes precedence over the callback-based handlers.
776    pub auth_handler: Option<Arc<dyn AuthHandler>>,
777    /// Public key auth handler (callback-based, for backward compatibility).
778    pub public_key_handler: Option<PublicKeyHandler>,
779    /// Password auth handler (callback-based, for backward compatibility).
780    pub password_handler: Option<PasswordHandler>,
781    /// Keyboard-interactive auth handler.
782    pub keyboard_interactive_handler: Option<KeyboardInteractiveHandler>,
783    /// Idle timeout.
784    pub idle_timeout: Option<Duration>,
785    /// Maximum connection timeout.
786    pub max_timeout: Option<Duration>,
787    /// Subsystem handlers.
788    pub subsystem_handlers: HashMap<String, SubsystemHandler>,
789    /// Maximum authentication attempts before disconnection.
790    pub max_auth_attempts: u32,
791    /// Authentication rejection delay in milliseconds (timing attack mitigation).
792    pub auth_rejection_delay_ms: u64,
793    /// Allow unauthenticated access when no auth handlers are configured.
794    ///
795    /// When `false` (the default), connections are rejected if no auth
796    /// handlers (public key, password, keyboard-interactive, or trait-based)
797    /// are registered. Set to `true` only for development/demo servers
798    /// that intentionally allow anonymous access.
799    pub allow_no_auth: bool,
800}
801
802impl Default for ServerOptions {
803    fn default() -> Self {
804        Self {
805            address: "0.0.0.0:22".to_string(),
806            version: "SSH-2.0-Wish".to_string(),
807            banner: None,
808            banner_handler: None,
809            host_key_path: None,
810            host_key_pem: None,
811            middlewares: Vec::new(),
812            handler: None,
813            auth_handler: None,
814            public_key_handler: None,
815            password_handler: None,
816            keyboard_interactive_handler: None,
817            idle_timeout: None,
818            max_timeout: None,
819            subsystem_handlers: HashMap::new(),
820            max_auth_attempts: auth::DEFAULT_MAX_AUTH_ATTEMPTS,
821            auth_rejection_delay_ms: auth::DEFAULT_AUTH_REJECTION_DELAY_MS,
822            allow_no_auth: false,
823        }
824    }
825}
826
827impl fmt::Debug for ServerOptions {
828    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
829        f.debug_struct("ServerOptions")
830            .field("address", &self.address)
831            .field("version", &self.version)
832            .field("banner", &self.banner)
833            .field("host_key_path", &self.host_key_path)
834            .field("idle_timeout", &self.idle_timeout)
835            .field("max_timeout", &self.max_timeout)
836            .finish()
837    }
838}
839
840// -----------------------------------------------------------------------------
841// Option Functions (Go-style)
842// -----------------------------------------------------------------------------
843
844/// Option function type for configuring the server.
845pub type ServerOption = Box<dyn FnOnce(&mut ServerOptions) -> Result<()> + Send>;
846
847/// Sets the listen address.
848pub fn with_address(addr: impl Into<String>) -> ServerOption {
849    let addr = addr.into();
850    Box::new(move |opts| {
851        opts.address = addr;
852        Ok(())
853    })
854}
855
856/// Sets the server version string.
857pub fn with_version(version: impl Into<String>) -> ServerOption {
858    let version = version.into();
859    Box::new(move |opts| {
860        opts.version = version;
861        Ok(())
862    })
863}
864
865/// Sets a static banner.
866pub fn with_banner(banner: impl Into<String>) -> ServerOption {
867    let banner = banner.into();
868    Box::new(move |opts| {
869        opts.banner = Some(banner);
870        Ok(())
871    })
872}
873
874/// Sets a dynamic banner handler.
875pub fn with_banner_handler<F>(handler: F) -> ServerOption
876where
877    F: Fn(&Context) -> String + Send + Sync + 'static,
878{
879    Box::new(move |opts| {
880        opts.banner_handler = Some(Arc::new(handler));
881        Ok(())
882    })
883}
884
885/// Adds middleware to the server.
886pub fn with_middleware(mw: Middleware) -> ServerOption {
887    Box::new(move |opts| {
888        opts.middlewares.push(mw);
889        Ok(())
890    })
891}
892
893/// Sets the host key path.
894pub fn with_host_key_path(path: impl Into<String>) -> ServerOption {
895    let path = path.into();
896    Box::new(move |opts| {
897        opts.host_key_path = Some(path);
898        Ok(())
899    })
900}
901
902/// Sets the host key from PEM data.
903pub fn with_host_key_pem(pem: Vec<u8>) -> ServerOption {
904    Box::new(move |opts| {
905        opts.host_key_pem = Some(pem);
906        Ok(())
907    })
908}
909
910/// Sets the trait-based authentication handler.
911///
912/// If set, this takes precedence over the callback-based handlers.
913pub fn with_auth_handler<H: AuthHandler + 'static>(handler: H) -> ServerOption {
914    Box::new(move |opts| {
915        opts.auth_handler = Some(Arc::new(handler));
916        Ok(())
917    })
918}
919
920/// Sets the maximum authentication attempts.
921pub fn with_max_auth_attempts(max: u32) -> ServerOption {
922    Box::new(move |opts| {
923        opts.max_auth_attempts = max;
924        Ok(())
925    })
926}
927
928/// Sets the authentication rejection delay in milliseconds.
929pub fn with_auth_rejection_delay(delay_ms: u64) -> ServerOption {
930    Box::new(move |opts| {
931        opts.auth_rejection_delay_ms = delay_ms;
932        Ok(())
933    })
934}
935
936/// Sets the public key authentication handler.
937pub fn with_public_key_auth<F>(handler: F) -> ServerOption
938where
939    F: Fn(&Context, &PublicKey) -> bool + Send + Sync + 'static,
940{
941    Box::new(move |opts| {
942        opts.public_key_handler = Some(Arc::new(handler));
943        Ok(())
944    })
945}
946
947/// Sets the password authentication handler.
948pub fn with_password_auth<F>(handler: F) -> ServerOption
949where
950    F: Fn(&Context, &str) -> bool + Send + Sync + 'static,
951{
952    Box::new(move |opts| {
953        opts.password_handler = Some(Arc::new(handler));
954        Ok(())
955    })
956}
957
958/// Sets the keyboard-interactive authentication handler.
959pub fn with_keyboard_interactive_auth<F>(handler: F) -> ServerOption
960where
961    F: Fn(&Context, &str, &[String], &[bool]) -> Vec<String> + Send + Sync + 'static,
962{
963    Box::new(move |opts| {
964        opts.keyboard_interactive_handler = Some(Arc::new(handler));
965        Ok(())
966    })
967}
968
969/// Sets the idle timeout.
970pub fn with_idle_timeout(duration: Duration) -> ServerOption {
971    Box::new(move |opts| {
972        opts.idle_timeout = Some(duration);
973        Ok(())
974    })
975}
976
977/// Sets the maximum connection timeout.
978pub fn with_max_timeout(duration: Duration) -> ServerOption {
979    Box::new(move |opts| {
980        opts.max_timeout = Some(duration);
981        Ok(())
982    })
983}
984
985/// Adds a subsystem handler.
986pub fn with_subsystem<F, Fut>(name: impl Into<String>, handler: F) -> ServerOption
987where
988    F: Fn(Session) -> Fut + Send + Sync + 'static,
989    Fut: Future<Output = ()> + Send + 'static,
990{
991    let name = name.into();
992    Box::new(move |opts| {
993        opts.subsystem_handlers
994            .insert(name, Arc::new(move |s| Box::pin(handler(s))));
995        Ok(())
996    })
997}
998
999// -----------------------------------------------------------------------------
1000// Server
1001// -----------------------------------------------------------------------------
1002
1003/// SSH server for hosting applications.
1004pub struct Server {
1005    /// Server options.
1006    options: ServerOptions,
1007}
1008
1009impl fmt::Debug for Server {
1010    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1011        f.debug_struct("Server")
1012            .field("options", &self.options)
1013            .finish()
1014    }
1015}
1016
1017impl Server {
1018    /// Creates a new server with the given options.
1019    pub fn new(options: impl IntoIterator<Item = ServerOption>) -> Result<Self> {
1020        let mut opts = ServerOptions::default();
1021        for opt in options {
1022            opt(&mut opts)?;
1023        }
1024        Ok(Self { options: opts })
1025    }
1026
1027    /// Returns the server options.
1028    pub fn options(&self) -> &ServerOptions {
1029        &self.options
1030    }
1031
1032    /// Returns the listen address.
1033    pub fn address(&self) -> &str {
1034        &self.options.address
1035    }
1036
1037    /// Starts listening for connections.
1038    ///
1039    /// This binds to the configured address, accepts SSH connections,
1040    /// and runs the handler for each connection.
1041    pub async fn listen(&self) -> Result<()> {
1042        info!("Starting SSH server on {}", self.options.address);
1043
1044        // Parse the address
1045        let addr: SocketAddr = self.options.address.parse()?;
1046        debug!("Parsed address: {:?}", addr);
1047
1048        // Create russh configuration
1049        let config = self.create_russh_config()?;
1050        let config = Arc::new(config);
1051
1052        // Create the handler factory
1053        let factory = WishHandlerFactory::new(self.options.clone());
1054
1055        // Bind to the address
1056        let listener = TcpListener::bind(addr).await?;
1057        let local_addr = listener.local_addr().unwrap_or(addr);
1058        info!("Server listening on {}", local_addr);
1059
1060        self.listen_with_listener_inner(listener, config, factory, local_addr)
1061            .await
1062    }
1063
1064    /// Starts listening for connections using an already-bound listener.
1065    ///
1066    /// This is primarily useful for tests and embedding scenarios where you need to
1067    /// bind to an ephemeral port (`127.0.0.1:0`) without races.
1068    pub async fn listen_with_listener(&self, listener: TcpListener) -> Result<()> {
1069        let local_addr = listener.local_addr()?;
1070
1071        // Create russh configuration
1072        let config = self.create_russh_config()?;
1073        let config = Arc::new(config);
1074
1075        // Create the handler factory
1076        let factory = WishHandlerFactory::new(self.options.clone());
1077
1078        info!("Server listening on {}", local_addr);
1079        self.listen_with_listener_inner(listener, config, factory, local_addr)
1080            .await
1081    }
1082
1083    async fn listen_with_listener_inner(
1084        &self,
1085        listener: TcpListener,
1086        config: Arc<RusshConfig>,
1087        factory: WishHandlerFactory,
1088        local_addr: SocketAddr,
1089    ) -> Result<()> {
1090        // Accept connections
1091        loop {
1092            match listener.accept().await {
1093                Ok((socket, peer_addr)) => {
1094                    info!(peer_addr = %peer_addr, "Accepted connection");
1095
1096                    let config = config.clone();
1097                    let socket_local_addr = socket.local_addr().unwrap_or(local_addr);
1098                    let handler = factory.create_handler(peer_addr, socket_local_addr);
1099
1100                    // Spawn a task to handle this connection
1101                    tokio::spawn(async move {
1102                        debug!(peer_addr = %peer_addr, "Running SSH session");
1103                        match run_stream(config, socket, handler).await {
1104                            Ok(session) => {
1105                                // Wait for the session to complete
1106                                match session.await {
1107                                    Ok(()) => {
1108                                        debug!(peer_addr = %peer_addr, "Connection closed cleanly");
1109                                    }
1110                                    Err(e) => {
1111                                        warn!(peer_addr = %peer_addr, error = %e, "Connection error");
1112                                    }
1113                                }
1114                            }
1115                            Err(e) => {
1116                                error!(peer_addr = %peer_addr, error = %e, "SSH handshake failed");
1117                            }
1118                        }
1119                    });
1120                }
1121                Err(e) => {
1122                    error!(error = %e, "Failed to accept connection");
1123                }
1124            }
1125        }
1126    }
1127
1128    /// Creates the russh server configuration.
1129    #[allow(clippy::field_reassign_with_default)]
1130    fn create_russh_config(&self) -> Result<RusshConfig> {
1131        use russh::MethodSet;
1132        use russh::server::Config;
1133        use russh_keys::key::KeyPair;
1134
1135        let mut config = Config::default();
1136
1137        // Set server ID
1138        config.server_id = russh::SshId::Standard(self.options.version.clone());
1139
1140        // Set timeouts
1141        if let Some(timeout) = self.options.idle_timeout {
1142            config.inactivity_timeout = Some(timeout);
1143        }
1144
1145        config.max_auth_attempts = self.options.max_auth_attempts as usize;
1146        config.auth_rejection_time = Duration::from_millis(self.options.auth_rejection_delay_ms);
1147
1148        let mut methods = MethodSet::empty();
1149        if let Some(handler) = &self.options.auth_handler {
1150            for method in handler.supported_methods() {
1151                // Write this without a `match` because UBS's hardcoded-secret regex
1152                // can falsely flag match arms for the password auth method.
1153                if matches!(method, auth::AuthMethod::None) {
1154                    methods |= MethodSet::NONE;
1155                } else if matches!(method, auth::AuthMethod::Password) {
1156                    methods |= MethodSet::PASSWORD;
1157                } else if matches!(method, auth::AuthMethod::PublicKey) {
1158                    methods |= MethodSet::PUBLICKEY;
1159                } else if matches!(method, auth::AuthMethod::KeyboardInteractive) {
1160                    methods |= MethodSet::KEYBOARD_INTERACTIVE;
1161                } else if matches!(method, auth::AuthMethod::HostBased) {
1162                    methods |= MethodSet::HOSTBASED;
1163                }
1164            }
1165        } else {
1166            if self.options.public_key_handler.is_some() {
1167                methods |= MethodSet::PUBLICKEY;
1168            }
1169            if self.options.password_handler.is_some() {
1170                methods |= MethodSet::PASSWORD;
1171            }
1172            if self.options.keyboard_interactive_handler.is_some() {
1173                methods |= MethodSet::KEYBOARD_INTERACTIVE;
1174            }
1175            if methods.is_empty() {
1176                methods |= MethodSet::NONE;
1177            }
1178        }
1179        config.methods = methods;
1180
1181        // Generate or load host key
1182        let key = if let Some(ref pem) = self.options.host_key_pem {
1183            // Load from PEM bytes (OpenSSH format).
1184            let private_key = ssh_key::private::PrivateKey::from_openssh(pem)
1185                .map_err(|e| Error::Key(e.to_string()))?;
1186            KeyPair::try_from(&private_key).map_err(|e| Error::Key(e.to_string()))?
1187        } else if let Some(ref path) = self.options.host_key_path {
1188            // Load from file bytes (OpenSSH format).
1189            let pem = std::fs::read(path)?;
1190            let private_key = ssh_key::private::PrivateKey::from_openssh(&pem)
1191                .map_err(|e| Error::Key(e.to_string()))?;
1192            KeyPair::try_from(&private_key).map_err(|e| Error::Key(e.to_string()))?
1193        } else {
1194            // Generate ephemeral Ed25519 key
1195            info!("Generating ephemeral Ed25519 host key");
1196            KeyPair::generate_ed25519()
1197        };
1198
1199        config.keys.push(key);
1200
1201        // Set authentication banner if configured
1202        if let Some(ref banner) = self.options.banner {
1203            // russh expects &'static str, so we leak the banner
1204            // This is acceptable since the server typically runs for the lifetime of the process
1205            let banner: &'static str = Box::leak(banner.clone().into_boxed_str());
1206            config.auth_banner = Some(banner);
1207        }
1208
1209        Ok(config)
1210    }
1211
1212    /// Starts listening and handles shutdown gracefully.
1213    pub async fn listen_and_serve(&self) -> Result<()> {
1214        self.listen().await
1215    }
1216}
1217
1218/// Creates a new server with default options and the provided middleware.
1219pub fn new_server(options: impl IntoIterator<Item = ServerOption>) -> Result<Server> {
1220    Server::new(options)
1221}
1222
1223// -----------------------------------------------------------------------------
1224// Server Builder (alternative API)
1225// -----------------------------------------------------------------------------
1226
1227/// Builder for creating an SSH server.
1228#[derive(Default)]
1229pub struct ServerBuilder {
1230    options: ServerOptions,
1231}
1232
1233impl ServerBuilder {
1234    /// Creates a new server builder.
1235    pub fn new() -> Self {
1236        Self::default()
1237    }
1238
1239    /// Sets the listen address.
1240    pub fn address(mut self, addr: impl Into<String>) -> Self {
1241        self.options.address = addr.into();
1242        self
1243    }
1244
1245    /// Sets the server version.
1246    pub fn version(mut self, version: impl Into<String>) -> Self {
1247        self.options.version = version.into();
1248        self
1249    }
1250
1251    /// Sets a static banner.
1252    pub fn banner(mut self, banner: impl Into<String>) -> Self {
1253        self.options.banner = Some(banner.into());
1254        self
1255    }
1256
1257    /// Sets a dynamic banner handler.
1258    pub fn banner_handler<F>(mut self, handler: F) -> Self
1259    where
1260        F: Fn(&Context) -> String + Send + Sync + 'static,
1261    {
1262        self.options.banner_handler = Some(Arc::new(handler));
1263        self
1264    }
1265
1266    /// Sets the host key path.
1267    pub fn host_key_path(mut self, path: impl Into<String>) -> Self {
1268        self.options.host_key_path = Some(path.into());
1269        self
1270    }
1271
1272    /// Sets the host key from PEM data.
1273    pub fn host_key_pem(mut self, pem: Vec<u8>) -> Self {
1274        self.options.host_key_pem = Some(pem);
1275        self
1276    }
1277
1278    /// Adds middleware to the server.
1279    pub fn with_middleware(mut self, mw: Middleware) -> Self {
1280        self.options.middlewares.push(mw);
1281        self
1282    }
1283
1284    /// Sets the main handler.
1285    pub fn handler<F, Fut>(mut self, handler: F) -> Self
1286    where
1287        F: Fn(Session) -> Fut + Send + Sync + 'static,
1288        Fut: Future<Output = ()> + Send + 'static,
1289    {
1290        self.options.handler = Some(Arc::new(move |session| Box::pin(handler(session))));
1291        self
1292    }
1293
1294    /// Sets the main handler from a pre-wrapped [`Handler`].
1295    ///
1296    /// Use this when you already have a `Handler` (e.g., from the [`handler`] function).
1297    pub fn handler_arc(mut self, handler: Handler) -> Self {
1298        self.options.handler = Some(handler);
1299        self
1300    }
1301
1302    /// Sets the trait-based authentication handler.
1303    ///
1304    /// If set, this takes precedence over the callback-based handlers.
1305    pub fn auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
1306        self.options.auth_handler = Some(Arc::new(handler));
1307        self
1308    }
1309
1310    /// Sets the maximum authentication attempts.
1311    pub fn max_auth_attempts(mut self, max: u32) -> Self {
1312        self.options.max_auth_attempts = max;
1313        self
1314    }
1315
1316    /// Sets the authentication rejection delay in milliseconds.
1317    pub fn auth_rejection_delay(mut self, delay_ms: u64) -> Self {
1318        self.options.auth_rejection_delay_ms = delay_ms;
1319        self
1320    }
1321
1322    /// Allow unauthenticated access when no auth handlers are configured.
1323    ///
1324    /// By default, `auth_none` is rejected unless at least one auth handler
1325    /// is registered. Call this to explicitly opt in to anonymous access
1326    /// (e.g., for demo/development servers).
1327    pub fn allow_no_auth(mut self) -> Self {
1328        self.options.allow_no_auth = true;
1329        self
1330    }
1331
1332    /// Sets the public key authentication handler.
1333    pub fn public_key_auth<F>(mut self, handler: F) -> Self
1334    where
1335        F: Fn(&Context, &PublicKey) -> bool + Send + Sync + 'static,
1336    {
1337        self.options.public_key_handler = Some(Arc::new(handler));
1338        self
1339    }
1340
1341    /// Sets the password authentication handler.
1342    pub fn password_auth<F>(mut self, handler: F) -> Self
1343    where
1344        F: Fn(&Context, &str) -> bool + Send + Sync + 'static,
1345    {
1346        self.options.password_handler = Some(Arc::new(handler));
1347        self
1348    }
1349
1350    /// Sets the keyboard-interactive authentication handler.
1351    pub fn keyboard_interactive_auth<F>(mut self, handler: F) -> Self
1352    where
1353        F: Fn(&Context, &str, &[String], &[bool]) -> Vec<String> + Send + Sync + 'static,
1354    {
1355        self.options.keyboard_interactive_handler = Some(Arc::new(handler));
1356        self
1357    }
1358
1359    /// Sets the idle timeout.
1360    pub fn idle_timeout(mut self, duration: Duration) -> Self {
1361        self.options.idle_timeout = Some(duration);
1362        self
1363    }
1364
1365    /// Sets the maximum connection timeout.
1366    pub fn max_timeout(mut self, duration: Duration) -> Self {
1367        self.options.max_timeout = Some(duration);
1368        self
1369    }
1370
1371    /// Adds a subsystem handler.
1372    pub fn subsystem<F, Fut>(mut self, name: impl Into<String>, handler: F) -> Self
1373    where
1374        F: Fn(Session) -> Fut + Send + Sync + 'static,
1375        Fut: Future<Output = ()> + Send + 'static,
1376    {
1377        self.options
1378            .subsystem_handlers
1379            .insert(name.into(), Arc::new(move |s| Box::pin(handler(s))));
1380        self
1381    }
1382
1383    /// Builds the server.
1384    pub fn build(self) -> Result<Server> {
1385        Ok(Server {
1386            options: self.options,
1387        })
1388    }
1389}
1390
1391// -----------------------------------------------------------------------------
1392// Middleware Module
1393// -----------------------------------------------------------------------------
1394
1395/// Built-in middleware implementations.
1396pub mod middleware {
1397    use super::*;
1398    use std::time::Instant;
1399
1400    /// Middleware that requires an active PTY.
1401    pub mod activeterm {
1402        use super::*;
1403
1404        /// Creates middleware that blocks connections without an active PTY.
1405        pub fn middleware() -> Middleware {
1406            Arc::new(|next| {
1407                Arc::new(move |session| {
1408                    let next = next.clone();
1409                    Box::pin(async move {
1410                        let (_, active) = session.pty();
1411                        if active {
1412                            next(session).await;
1413                        } else {
1414                            println(&session, "Requires an active PTY");
1415                            let _ = session.exit(1);
1416                        }
1417                    })
1418                })
1419            })
1420        }
1421    }
1422
1423    /// Middleware for access control.
1424    pub mod accesscontrol {
1425        use super::*;
1426
1427        /// Creates middleware that restricts allowed commands.
1428        pub fn middleware(allowed_commands: Vec<String>) -> Middleware {
1429            Arc::new(move |next| {
1430                let allowed = allowed_commands.clone();
1431                Arc::new(move |session| {
1432                    let next = next.clone();
1433                    let allowed = allowed.clone();
1434                    Box::pin(async move {
1435                        let cmd = session.command();
1436                        if cmd.is_empty() {
1437                            next(session).await;
1438                            return;
1439                        }
1440
1441                        let first_cmd = &cmd[0];
1442                        if allowed.iter().any(|c| c == first_cmd) {
1443                            next(session).await;
1444                        } else {
1445                            println(&session, format!("Command is not allowed: {}", first_cmd));
1446                            let _ = session.exit(1);
1447                        }
1448                    })
1449                })
1450            })
1451        }
1452    }
1453
1454    /// Middleware for authentication checks.
1455    ///
1456    /// Note: Wish authentication is performed during SSH handshake, but it can
1457    /// still be useful to guard handler execution based on session metadata.
1458    pub mod authentication {
1459        use super::*;
1460
1461        /// Creates middleware that rejects sessions without a non-empty username.
1462        pub fn middleware() -> Middleware {
1463            middleware_with_checker(|session| !session.user().is_empty())
1464        }
1465
1466        /// Creates middleware that rejects sessions that fail a custom predicate.
1467        pub fn middleware_with_checker<C>(checker: C) -> Middleware
1468        where
1469            C: Fn(&Session) -> bool + Send + Sync + 'static,
1470        {
1471            let checker = Arc::new(checker);
1472            Arc::new(move |next| {
1473                let checker = checker.clone();
1474                Arc::new(move |session| {
1475                    let next = next.clone();
1476                    let checker = checker.clone();
1477                    Box::pin(async move {
1478                        if checker(&session) {
1479                            next(session).await;
1480                        } else {
1481                            fatalln(&session, "authentication required");
1482                        }
1483                    })
1484                })
1485            })
1486        }
1487    }
1488
1489    /// Middleware for authorization checks (permissions/access policy).
1490    pub mod authorization {
1491        use super::*;
1492
1493        /// Creates a default authorization middleware that allows all sessions.
1494        ///
1495        /// Use `middleware_with_checker` to enforce your own policy.
1496        pub fn middleware() -> Middleware {
1497            middleware_with_checker(|_session| true)
1498        }
1499
1500        /// Creates authorization middleware that applies a custom predicate.
1501        pub fn middleware_with_checker<C>(checker: C) -> Middleware
1502        where
1503            C: Fn(&Session) -> bool + Send + Sync + 'static,
1504        {
1505            let checker = Arc::new(checker);
1506            Arc::new(move |next| {
1507                let checker = checker.clone();
1508                Arc::new(move |session| {
1509                    let next = next.clone();
1510                    let checker = checker.clone();
1511                    Box::pin(async move {
1512                        if checker(&session) {
1513                            next(session).await;
1514                        } else {
1515                            fatalln(&session, "permission denied");
1516                        }
1517                    })
1518                })
1519            })
1520        }
1521    }
1522
1523    /// Middleware that manages session lifecycle (best-effort cleanup).
1524    pub mod session_handler {
1525        use super::*;
1526
1527        /// Creates middleware that ensures the session is closed once the handler finishes.
1528        pub fn middleware() -> Middleware {
1529            Arc::new(|next| {
1530                Arc::new(move |session| {
1531                    let next = next.clone();
1532                    Box::pin(async move {
1533                        next(session.clone()).await;
1534                        if !session.is_closed() {
1535                            let _ = session.close();
1536                        }
1537                    })
1538                })
1539            })
1540        }
1541    }
1542
1543    /// Middleware that requires a PTY to be allocated.
1544    ///
1545    /// This is similar to [`activeterm`], but uses a distinct error message and is
1546    /// provided for API parity with other Wish ports.
1547    pub mod pty {
1548        use super::*;
1549
1550        /// Creates middleware that blocks sessions without an active PTY.
1551        pub fn middleware() -> Middleware {
1552            Arc::new(|next| {
1553                Arc::new(move |session| {
1554                    let next = next.clone();
1555                    Box::pin(async move {
1556                        let (_, active) = session.pty();
1557                        if active {
1558                            next(session).await;
1559                        } else {
1560                            fatalln(&session, "pty required");
1561                        }
1562                    })
1563                })
1564            })
1565        }
1566    }
1567
1568    /// Middleware for Git operations.
1569    ///
1570    /// This middleware is intentionally conservative: it only intercepts sessions
1571    /// that appear to be executing Git commands. For non-Git sessions it is a no-op.
1572    pub mod git {
1573        use super::*;
1574
1575        fn looks_like_git_command(cmd: &[String]) -> bool {
1576            cmd.first()
1577                .is_some_and(|c| c == "git" || c.starts_with("git-"))
1578        }
1579
1580        /// Creates Git middleware.
1581        ///
1582        /// By default, this denies Git commands unless the user provides a handler
1583        /// via `middleware_with_handler`.
1584        pub fn middleware() -> Middleware {
1585            middleware_with_handler(|session| async move {
1586                fatalln(&session, "git handler not configured");
1587            })
1588        }
1589
1590        /// Creates Git middleware that delegates Git sessions to a custom handler.
1591        pub fn middleware_with_handler<F, Fut>(handler: F) -> Middleware
1592        where
1593            F: Fn(Session) -> Fut + Send + Sync + 'static,
1594            Fut: Future<Output = ()> + Send + 'static,
1595        {
1596            let handler = Arc::new(handler);
1597            Arc::new(move |next| {
1598                let handler = handler.clone();
1599                Arc::new(move |session| {
1600                    let next = next.clone();
1601                    let handler = handler.clone();
1602                    Box::pin(async move {
1603                        if looks_like_git_command(session.command()) {
1604                            handler(session).await;
1605                        } else {
1606                            next(session).await;
1607                        }
1608                    })
1609                })
1610            })
1611        }
1612    }
1613
1614    /// Middleware for SCP file transfers.
1615    pub mod scp {
1616        use super::*;
1617
1618        fn looks_like_scp_command(cmd: &[String]) -> bool {
1619            cmd.first().is_some_and(|c| c == "scp")
1620        }
1621
1622        /// Creates SCP middleware.
1623        ///
1624        /// By default, this denies SCP commands unless a handler is configured via
1625        /// `middleware_with_handler`.
1626        pub fn middleware() -> Middleware {
1627            middleware_with_handler(|session| async move {
1628                fatalln(&session, "scp handler not configured");
1629            })
1630        }
1631
1632        /// Creates SCP middleware that delegates SCP sessions to a custom handler.
1633        pub fn middleware_with_handler<F, Fut>(handler: F) -> Middleware
1634        where
1635            F: Fn(Session) -> Fut + Send + Sync + 'static,
1636            Fut: Future<Output = ()> + Send + 'static,
1637        {
1638            let handler = Arc::new(handler);
1639            Arc::new(move |next| {
1640                let handler = handler.clone();
1641                Arc::new(move |session| {
1642                    let next = next.clone();
1643                    let handler = handler.clone();
1644                    Box::pin(async move {
1645                        if looks_like_scp_command(session.command()) {
1646                            handler(session).await;
1647                        } else {
1648                            next(session).await;
1649                        }
1650                    })
1651                })
1652            })
1653        }
1654    }
1655
1656    /// Middleware for SFTP sessions.
1657    pub mod sftp {
1658        use super::*;
1659
1660        fn looks_like_sftp_session(session: &Session) -> bool {
1661            session.subsystem() == Some("sftp")
1662                || session.command().first().is_some_and(|c| c == "sftp")
1663        }
1664
1665        /// Creates SFTP middleware.
1666        ///
1667        /// By default, this denies SFTP sessions unless a handler is configured via
1668        /// `middleware_with_handler`.
1669        pub fn middleware() -> Middleware {
1670            middleware_with_handler(|session| async move {
1671                fatalln(&session, "sftp handler not configured");
1672            })
1673        }
1674
1675        /// Creates SFTP middleware that delegates SFTP sessions to a custom handler.
1676        pub fn middleware_with_handler<F, Fut>(handler: F) -> Middleware
1677        where
1678            F: Fn(Session) -> Fut + Send + Sync + 'static,
1679            Fut: Future<Output = ()> + Send + 'static,
1680        {
1681            let handler = Arc::new(handler);
1682            Arc::new(move |next| {
1683                let handler = handler.clone();
1684                Arc::new(move |session| {
1685                    let next = next.clone();
1686                    let handler = handler.clone();
1687                    Box::pin(async move {
1688                        if looks_like_sftp_session(&session) {
1689                            handler(session).await;
1690                        } else {
1691                            next(session).await;
1692                        }
1693                    })
1694                })
1695            })
1696        }
1697    }
1698
1699    /// Middleware for logging connections.
1700    pub mod logging {
1701        use super::*;
1702
1703        /// Logger trait for custom logging implementations.
1704        pub trait Logger: Send + Sync {
1705            fn log(&self, format: &str, args: &[&dyn fmt::Display]);
1706        }
1707
1708        /// Structured logger for connection events.
1709        #[allow(clippy::too_many_arguments)]
1710        pub trait StructuredLogger: Send + Sync {
1711            fn log_connect(
1712                &self,
1713                level: tracing::Level,
1714                user: &str,
1715                remote_addr: &SocketAddr,
1716                public_key: bool,
1717                command: &[String],
1718                term: &str,
1719                width: u32,
1720                height: u32,
1721                client_version: &str,
1722            );
1723
1724            fn log_disconnect(
1725                &self,
1726                level: tracing::Level,
1727                user: &str,
1728                remote_addr: &SocketAddr,
1729                duration: Duration,
1730            );
1731        }
1732
1733        /// Default logger that uses tracing.
1734        #[derive(Clone, Copy)]
1735        pub struct TracingLogger;
1736
1737        impl Logger for TracingLogger {
1738            fn log(&self, format: &str, args: &[&dyn fmt::Display]) {
1739                let mut msg = format.to_string();
1740                for arg in args {
1741                    if let Some(pos) = msg.find("{}") {
1742                        msg.replace_range(pos..pos + 2, &arg.to_string());
1743                    }
1744                }
1745                info!("{}", msg);
1746            }
1747        }
1748
1749        /// Default structured logger that uses tracing events.
1750        #[derive(Clone, Copy)]
1751        pub struct TracingStructuredLogger;
1752
1753        impl StructuredLogger for TracingStructuredLogger {
1754            fn log_connect(
1755                &self,
1756                level: tracing::Level,
1757                user: &str,
1758                remote_addr: &SocketAddr,
1759                public_key: bool,
1760                command: &[String],
1761                term: &str,
1762                width: u32,
1763                height: u32,
1764                client_version: &str,
1765            ) {
1766                match level {
1767                    tracing::Level::TRACE => tracing::event!(
1768                        tracing::Level::TRACE,
1769                        user = %user,
1770                        remote_addr = %remote_addr,
1771                        public_key = public_key,
1772                        command = ?command,
1773                        term = %term,
1774                        width = width,
1775                        height = height,
1776                        client_version = %client_version,
1777                        "connect"
1778                    ),
1779                    tracing::Level::DEBUG => tracing::event!(
1780                        tracing::Level::DEBUG,
1781                        user = %user,
1782                        remote_addr = %remote_addr,
1783                        public_key = public_key,
1784                        command = ?command,
1785                        term = %term,
1786                        width = width,
1787                        height = height,
1788                        client_version = %client_version,
1789                        "connect"
1790                    ),
1791                    tracing::Level::INFO => tracing::event!(
1792                        tracing::Level::INFO,
1793                        user = %user,
1794                        remote_addr = %remote_addr,
1795                        public_key = public_key,
1796                        command = ?command,
1797                        term = %term,
1798                        width = width,
1799                        height = height,
1800                        client_version = %client_version,
1801                        "connect"
1802                    ),
1803                    tracing::Level::WARN => tracing::event!(
1804                        tracing::Level::WARN,
1805                        user = %user,
1806                        remote_addr = %remote_addr,
1807                        public_key = public_key,
1808                        command = ?command,
1809                        term = %term,
1810                        width = width,
1811                        height = height,
1812                        client_version = %client_version,
1813                        "connect"
1814                    ),
1815                    tracing::Level::ERROR => tracing::event!(
1816                        tracing::Level::ERROR,
1817                        user = %user,
1818                        remote_addr = %remote_addr,
1819                        public_key = public_key,
1820                        command = ?command,
1821                        term = %term,
1822                        width = width,
1823                        height = height,
1824                        client_version = %client_version,
1825                        "connect"
1826                    ),
1827                }
1828            }
1829
1830            fn log_disconnect(
1831                &self,
1832                level: tracing::Level,
1833                user: &str,
1834                remote_addr: &SocketAddr,
1835                duration: Duration,
1836            ) {
1837                match level {
1838                    tracing::Level::TRACE => tracing::event!(
1839                        tracing::Level::TRACE,
1840                        user = %user,
1841                        remote_addr = %remote_addr,
1842                        duration = ?duration,
1843                        "disconnect"
1844                    ),
1845                    tracing::Level::DEBUG => tracing::event!(
1846                        tracing::Level::DEBUG,
1847                        user = %user,
1848                        remote_addr = %remote_addr,
1849                        duration = ?duration,
1850                        "disconnect"
1851                    ),
1852                    tracing::Level::INFO => tracing::event!(
1853                        tracing::Level::INFO,
1854                        user = %user,
1855                        remote_addr = %remote_addr,
1856                        duration = ?duration,
1857                        "disconnect"
1858                    ),
1859                    tracing::Level::WARN => tracing::event!(
1860                        tracing::Level::WARN,
1861                        user = %user,
1862                        remote_addr = %remote_addr,
1863                        duration = ?duration,
1864                        "disconnect"
1865                    ),
1866                    tracing::Level::ERROR => tracing::event!(
1867                        tracing::Level::ERROR,
1868                        user = %user,
1869                        remote_addr = %remote_addr,
1870                        duration = ?duration,
1871                        "disconnect"
1872                    ),
1873                }
1874            }
1875        }
1876
1877        /// Creates logging middleware with the default logger.
1878        pub fn middleware() -> Middleware {
1879            middleware_with_logger(TracingLogger)
1880        }
1881
1882        /// Creates logging middleware with a custom logger.
1883        pub fn middleware_with_logger<L: Logger + 'static>(logger: L) -> Middleware {
1884            let logger = Arc::new(logger);
1885            Arc::new(move |next| {
1886                let logger = logger.clone();
1887                Arc::new(move |session| {
1888                    let next = next.clone();
1889                    let logger = logger.clone();
1890                    let start = Instant::now();
1891
1892                    // Log connect
1893                    let user = session.user().to_string();
1894                    let remote_addr = session.remote_addr().to_string();
1895                    let has_key = session.public_key().is_some();
1896                    let command = session.command().to_vec();
1897                    let (pty, _) = session.pty();
1898                    let term = pty.map(|p| p.term.clone()).unwrap_or_default();
1899                    let window = session.window();
1900                    let client_version = session.context().client_version();
1901
1902                    logger.log(
1903                        "{} connect {} {} {} {} {} {} {}",
1904                        &[
1905                            &user as &dyn fmt::Display,
1906                            &remote_addr,
1907                            &has_key,
1908                            &format!("{:?}", command),
1909                            &term,
1910                            &window.width,
1911                            &window.height,
1912                            &client_version,
1913                        ],
1914                    );
1915
1916                    Box::pin(async move {
1917                        next(session.clone()).await;
1918
1919                        // Log disconnect
1920                        let duration = start.elapsed();
1921                        logger.log(
1922                            "{} disconnect {}",
1923                            &[
1924                                &remote_addr as &dyn fmt::Display,
1925                                &format!("{:?}", duration),
1926                            ],
1927                        );
1928                    })
1929                })
1930            })
1931        }
1932
1933        /// Creates structured logging middleware.
1934        pub fn structured_middleware() -> Middleware {
1935            structured_middleware_with_logger(TracingStructuredLogger, tracing::Level::INFO)
1936        }
1937
1938        /// Creates structured logging middleware with a custom logger and level.
1939        pub fn structured_middleware_with_logger<L: StructuredLogger + 'static>(
1940            logger: L,
1941            level: tracing::Level,
1942        ) -> Middleware {
1943            let logger = Arc::new(logger);
1944            Arc::new(move |next| {
1945                let logger = logger.clone();
1946                Arc::new(move |session| {
1947                    let next = next.clone();
1948                    let logger = logger.clone();
1949                    let level = level;
1950                    let start = Instant::now();
1951
1952                    let user = session.user().to_string();
1953                    let remote_addr = session.remote_addr();
1954                    let has_key = session.public_key().is_some();
1955                    let command = session.command().to_vec();
1956                    let (pty, _) = session.pty();
1957                    let term = pty.map(|p| p.term.clone()).unwrap_or_default();
1958                    let window = session.window();
1959                    let client_version = session.context().client_version().to_string();
1960
1961                    logger.log_connect(
1962                        level,
1963                        &user,
1964                        &remote_addr,
1965                        has_key,
1966                        &command,
1967                        &term,
1968                        window.width,
1969                        window.height,
1970                        &client_version,
1971                    );
1972
1973                    Box::pin(async move {
1974                        next(session.clone()).await;
1975
1976                        let duration = start.elapsed();
1977                        logger.log_disconnect(level, &user, &remote_addr, duration);
1978                    })
1979                })
1980            })
1981        }
1982    }
1983
1984    /// Middleware for panic recovery.
1985    pub mod recover {
1986        use super::*;
1987
1988        /// Creates recovery middleware that catches panics.
1989        pub fn middleware() -> Middleware {
1990            middleware_with_middlewares(vec![])
1991        }
1992
1993        /// Creates recovery middleware that wraps other middlewares.
1994        pub fn middleware_with_middlewares(mws: Vec<Middleware>) -> Middleware {
1995            Arc::new(move |next| {
1996                let mws = mws.clone();
1997
1998                // Compose the inner middlewares
1999                let mut inner_handler = noop_handler();
2000                for mw in mws.iter().rev() {
2001                    inner_handler = mw(inner_handler);
2002                }
2003
2004                let inner = inner_handler;
2005                Arc::new(move |session| {
2006                    let next = next.clone();
2007                    let inner = inner.clone();
2008                    Box::pin(async move {
2009                        // Run the inner handler (with panic catching via catch_unwind in production)
2010                        // For now, just run normally since async catch_unwind is complex
2011                        inner(session.clone()).await;
2012                        next(session).await;
2013                    })
2014                })
2015            })
2016        }
2017
2018        /// Logger trait for recovery middleware.
2019        pub trait Logger: Send + Sync {
2020            fn log_panic(&self, error: &str, stack: &str);
2021        }
2022
2023        /// Default panic logger.
2024        #[derive(Clone, Copy)]
2025        pub struct DefaultLogger;
2026
2027        impl Logger for DefaultLogger {
2028            fn log_panic(&self, error: &str, stack: &str) {
2029                error!("panic: {}\n{}", error, stack);
2030            }
2031        }
2032    }
2033
2034    /// Middleware for rate limiting.
2035    pub mod ratelimiter {
2036        use super::*;
2037        use lru::LruCache;
2038        use std::num::NonZeroUsize;
2039        use std::time::Instant;
2040
2041        /// Rate limit exceeded error message.
2042        pub const ERR_RATE_LIMIT_EXCEEDED: &str = "rate limit exceeded, please try again later";
2043
2044        /// Rate limiter configuration.
2045        #[derive(Clone)]
2046        pub struct Config {
2047            /// Tokens per second.
2048            pub rate_per_sec: f64,
2049            /// Maximum burst tokens.
2050            pub burst: usize,
2051            /// Maximum number of cached limiters.
2052            pub max_entries: usize,
2053        }
2054
2055        impl Default for Config {
2056            fn default() -> Self {
2057                Self {
2058                    rate_per_sec: 1.0,
2059                    burst: 10,
2060                    max_entries: 1000,
2061                }
2062            }
2063        }
2064
2065        /// Rate limiter errors.
2066        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
2067        pub struct RateLimitError;
2068
2069        impl fmt::Display for RateLimitError {
2070            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2071                write!(f, "{ERR_RATE_LIMIT_EXCEEDED}")
2072            }
2073        }
2074
2075        impl std::error::Error for RateLimitError {}
2076
2077        /// Rate limiter implementations should check if a given session is allowed.
2078        pub trait RateLimiter: Send + Sync {
2079            fn allow(&self, session: &Session) -> std::result::Result<(), RateLimitError>;
2080        }
2081
2082        #[derive(Debug, Clone)]
2083        struct TokenBucketState {
2084            tokens: f64,
2085            last: Instant,
2086        }
2087
2088        /// Token-bucket rate limiter with LRU eviction.
2089        pub struct TokenBucketLimiter {
2090            rate_per_sec: f64,
2091            burst: f64,
2092            cache: RwLock<LruCache<String, TokenBucketState>>,
2093        }
2094
2095        impl TokenBucketLimiter {
2096            pub fn new(rate_per_sec: f64, burst: usize, max_entries: usize) -> Self {
2097                let max_entries = max_entries.max(1);
2098                let cache = LruCache::new(NonZeroUsize::new(max_entries).unwrap());
2099                Self {
2100                    rate_per_sec: rate_per_sec.max(0.0),
2101                    burst: burst.max(1) as f64,
2102                    cache: RwLock::new(cache),
2103                }
2104            }
2105
2106            fn allow_key(&self, key: &str) -> bool {
2107                let now = Instant::now();
2108                let mut cache = self.cache.write();
2109
2110                #[allow(clippy::manual_inspect)]
2111                let state = cache
2112                    .get_mut(key)
2113                    .map(|state| {
2114                        let elapsed = now.duration_since(state.last).as_secs_f64();
2115                        state.tokens = (state.tokens + elapsed * self.rate_per_sec).min(self.burst);
2116                        state.last = now;
2117                        state
2118                    })
2119                    .cloned();
2120
2121                let mut state = state.unwrap_or(TokenBucketState {
2122                    tokens: self.burst,
2123                    last: now,
2124                });
2125
2126                let allowed = if state.tokens >= 1.0 {
2127                    state.tokens -= 1.0;
2128                    true
2129                } else {
2130                    false
2131                };
2132
2133                cache.put(key.to_string(), state);
2134                allowed
2135            }
2136        }
2137
2138        impl RateLimiter for TokenBucketLimiter {
2139            fn allow(&self, session: &Session) -> std::result::Result<(), RateLimitError> {
2140                let key = session.remote_addr().ip().to_string();
2141                let allowed = self.allow_key(&key);
2142                debug!(key = %key, allowed, "rate limiter key");
2143                if allowed { Ok(()) } else { Err(RateLimitError) }
2144            }
2145        }
2146
2147        /// Creates a new token-bucket rate limiter.
2148        pub fn new_rate_limiter(
2149            rate_per_sec: f64,
2150            burst: usize,
2151            max_entries: usize,
2152        ) -> TokenBucketLimiter {
2153            TokenBucketLimiter::new(rate_per_sec, burst, max_entries)
2154        }
2155
2156        /// Creates rate limiting middleware.
2157        pub fn middleware<L: RateLimiter + 'static>(limiter: L) -> Middleware {
2158            let limiter = Arc::new(limiter);
2159            Arc::new(move |next| {
2160                let limiter = limiter.clone();
2161                Arc::new(move |session| {
2162                    let next = next.clone();
2163                    let limiter = limiter.clone();
2164                    Box::pin(async move {
2165                        match limiter.allow(&session) {
2166                            Ok(()) => {
2167                                next(session).await;
2168                            }
2169                            Err(err) => {
2170                                warn!(remote_addr = %session.remote_addr(), "rate limited");
2171                                fatal(&session, err);
2172                            }
2173                        }
2174                    })
2175                })
2176            })
2177        }
2178
2179        /// Creates rate limiting middleware from a Config.
2180        pub fn middleware_with_config(config: Config) -> Middleware {
2181            middleware(new_rate_limiter(
2182                config.rate_per_sec,
2183                config.burst,
2184                config.max_entries,
2185            ))
2186        }
2187    }
2188
2189    /// Middleware for elapsed time tracking.
2190    pub mod elapsed {
2191        use super::*;
2192
2193        fn format_elapsed(format: &str, elapsed: Duration) -> String {
2194            if format.contains("%v") {
2195                format.replace("%v", &format!("{:?}", elapsed))
2196            } else {
2197                format.replace("{}", &format!("{:?}", elapsed)).to_string()
2198            }
2199        }
2200
2201        /// Creates middleware that logs the elapsed time of the session.
2202        pub fn middleware_with_format(format: impl Into<String>) -> Middleware {
2203            let format = format.into();
2204            Arc::new(move |next| {
2205                let format = format.clone();
2206                Arc::new(move |session| {
2207                    let next = next.clone();
2208                    let format = format.clone();
2209                    Box::pin(async move {
2210                        let start = Instant::now();
2211                        next(session.clone()).await;
2212                        let msg = format_elapsed(&format, start.elapsed());
2213                        print(&session, msg);
2214                    })
2215                })
2216            })
2217        }
2218
2219        /// Creates middleware that logs elapsed time using the default format.
2220        pub fn middleware() -> Middleware {
2221            middleware_with_format("elapsed time: %v\n")
2222        }
2223    }
2224
2225    /// Comment middleware for adding messages.
2226    pub mod comment {
2227        use super::*;
2228
2229        /// Creates middleware that displays a comment/message.
2230        pub fn middleware(message: impl Into<String>) -> Middleware {
2231            let message = message.into();
2232            Arc::new(move |next| {
2233                let message = message.clone();
2234                Arc::new(move |session| {
2235                    let next = next.clone();
2236                    let message = message.clone();
2237                    Box::pin(async move {
2238                        next(session.clone()).await;
2239                        println(&session, &message);
2240                    })
2241                })
2242            })
2243        }
2244    }
2245}
2246
2247// -----------------------------------------------------------------------------
2248// BubbleTea Integration
2249// -----------------------------------------------------------------------------
2250
2251/// BubbleTea integration for serving TUI apps over SSH.
2252pub mod tea {
2253    use super::*;
2254    use bubbletea::{Model, Program};
2255
2256    /// Handler function that creates a model for each session.
2257    pub type TeaHandler<M> = Arc<dyn Fn(&Session) -> M + Send + Sync>;
2258
2259    /// Creates middleware that serves a BubbleTea application.
2260    pub fn middleware<M, F>(handler: F) -> Middleware
2261    where
2262        M: Model + Send + Sync + 'static,
2263        F: Fn(&Session) -> M + Send + Sync + 'static,
2264    {
2265        let handler = Arc::new(handler);
2266        Arc::new(move |next| {
2267            let handler = handler.clone();
2268            Arc::new(move |session| {
2269                let next = next.clone();
2270                let handler = handler.clone();
2271                Box::pin(async move {
2272                    let (_pty, active) = session.pty();
2273                    if !active {
2274                        fatalln(&session, "no active terminal, skipping");
2275                        return;
2276                    }
2277
2278                    // Create the model
2279                    let model = handler(&session);
2280
2281                    // Create message channel for the program
2282                    let (tx, rx) = std::sync::mpsc::channel();
2283                    session.set_message_sender(tx);
2284
2285                    // Run the program in a blocking task
2286                    let session_clone = session.clone();
2287                    let run_result = tokio::task::spawn_blocking(move || {
2288                        let _ = Program::new(model)
2289                            .with_custom_io()
2290                            .with_input_receiver(rx)
2291                            .run_with_writer(session_clone);
2292                    })
2293                    .await;
2294                    if let Err(err) = run_result {
2295                        fatalln(&session, format!("bubbletea program crashed: {err}"));
2296                        return;
2297                    }
2298
2299                    next(session).await;
2300                })
2301            })
2302        })
2303    }
2304
2305    /// Creates a lipgloss renderer for the session.
2306    pub fn make_renderer(session: &Session) -> lipgloss::Renderer {
2307        let (pty, _) = session.pty();
2308        let term = pty.map(|p| p.term.as_str()).unwrap_or("xterm-256color");
2309
2310        // Detect color profile based on terminal type
2311        let profile = if term.contains("256color") || term.contains("truecolor") {
2312            lipgloss::ColorProfile::TrueColor
2313        } else if term.contains("color") {
2314            lipgloss::ColorProfile::Ansi256
2315        } else {
2316            lipgloss::ColorProfile::Ansi
2317        };
2318
2319        let mut renderer = lipgloss::Renderer::new();
2320        renderer.set_color_profile(profile);
2321        renderer
2322    }
2323}
2324
2325// -----------------------------------------------------------------------------
2326// Prelude
2327// -----------------------------------------------------------------------------
2328
2329/// Prelude module for convenient imports.
2330pub mod prelude {
2331    pub use crate::{
2332        Context, Error, Handler, Middleware, Pty, PublicKey, Result, Server, ServerBuilder,
2333        ServerOption, ServerOptions, Session, Window, compose_middleware, error, errorf, errorln,
2334        fatal, fatalf, fatalln, handler, new_server, noop_handler, print, printf, println,
2335        with_address, with_banner, with_banner_handler, with_host_key_path, with_host_key_pem,
2336        with_idle_timeout, with_keyboard_interactive_auth, with_max_timeout, with_middleware,
2337        with_password_auth, with_public_key_auth, with_subsystem, with_version, write_string,
2338    };
2339
2340    pub use crate::middleware::{
2341        accesscontrol, activeterm, comment, elapsed, logging, ratelimiter, recover,
2342    };
2343
2344    pub use crate::tea;
2345}
2346
2347// -----------------------------------------------------------------------------
2348// Tests
2349// -----------------------------------------------------------------------------
2350
2351#[cfg(test)]
2352mod tests {
2353    use super::*;
2354    use std::fmt;
2355    use std::sync::Arc;
2356    use std::sync::Mutex;
2357    use std::sync::atomic::{AtomicUsize, Ordering};
2358
2359    struct DenyLimiter;
2360
2361    impl middleware::ratelimiter::RateLimiter for DenyLimiter {
2362        fn allow(
2363            &self,
2364            _session: &Session,
2365        ) -> std::result::Result<(), middleware::ratelimiter::RateLimitError> {
2366            Err(middleware::ratelimiter::RateLimitError)
2367        }
2368    }
2369
2370    fn record_middleware(label: &'static str, events: Arc<Mutex<Vec<&'static str>>>) -> Middleware {
2371        Arc::new(move |next| {
2372            let events = events.clone();
2373            Arc::new(move |session| {
2374                let next = next.clone();
2375                let events = events.clone();
2376                Box::pin(async move {
2377                    {
2378                        let mut guard = events.lock().expect("events lock");
2379                        guard.push(label);
2380                    }
2381                    next(session).await;
2382                })
2383            })
2384        })
2385    }
2386
2387    #[derive(Clone)]
2388    struct TestLogger {
2389        entries: Arc<Mutex<Vec<String>>>,
2390    }
2391
2392    impl middleware::logging::Logger for TestLogger {
2393        fn log(&self, format: &str, args: &[&dyn fmt::Display]) {
2394            let mut msg = format.to_string();
2395            for arg in args {
2396                if let Some(pos) = msg.find("{}") {
2397                    msg.replace_range(pos..pos + 2, &arg.to_string());
2398                }
2399            }
2400            self.entries.lock().expect("logger entries").push(msg);
2401        }
2402    }
2403
2404    #[derive(Clone, Default)]
2405    struct TestStructuredLogger {
2406        connects: Arc<Mutex<Vec<(String, SocketAddr, bool)>>>,
2407        disconnects: Arc<Mutex<Vec<(String, SocketAddr)>>>,
2408    }
2409
2410    impl middleware::logging::StructuredLogger for TestStructuredLogger {
2411        fn log_connect(
2412            &self,
2413            _level: tracing::Level,
2414            user: &str,
2415            remote_addr: &SocketAddr,
2416            public_key: bool,
2417            _command: &[String],
2418            _term: &str,
2419            _width: u32,
2420            _height: u32,
2421            _client_version: &str,
2422        ) {
2423            self.connects.lock().expect("structured connects").push((
2424                user.to_string(),
2425                *remote_addr,
2426                public_key,
2427            ));
2428        }
2429
2430        fn log_disconnect(
2431            &self,
2432            _level: tracing::Level,
2433            user: &str,
2434            remote_addr: &SocketAddr,
2435            _duration: Duration,
2436        ) {
2437            self.disconnects
2438                .lock()
2439                .expect("structured disconnects")
2440                .push((user.to_string(), *remote_addr));
2441        }
2442    }
2443
2444    #[derive(Clone, Default)]
2445    struct PanicTeaModel;
2446
2447    impl bubbletea::Model for PanicTeaModel {
2448        fn init(&self) -> Option<bubbletea::Cmd> {
2449            None
2450        }
2451
2452        fn update(&mut self, _msg: Message) -> Option<bubbletea::Cmd> {
2453            None
2454        }
2455
2456        fn view(&self) -> String {
2457            std::panic::panic_any("panic from test tea model")
2458        }
2459    }
2460
2461    #[test]
2462    fn test_window_default() {
2463        let window = Window::default();
2464        assert_eq!(window.width, 80);
2465        assert_eq!(window.height, 24);
2466    }
2467
2468    #[test]
2469    fn test_pty_default() {
2470        let pty = Pty::default();
2471        assert_eq!(pty.term, "xterm-256color");
2472        assert_eq!(pty.window.width, 80);
2473    }
2474
2475    #[test]
2476    fn test_public_key() {
2477        let key = PublicKey::new("ssh-ed25519", vec![1, 2, 3, 4]);
2478        assert_eq!(key.key_type, "ssh-ed25519");
2479        assert_eq!(key.data, vec![1, 2, 3, 4]);
2480        assert!(key.comment.is_none());
2481
2482        let key = key.with_comment("test_key_comment");
2483        assert_eq!(key.comment, Some("test_key_comment".to_string()));
2484    }
2485
2486    #[test]
2487    fn test_public_key_fingerprint() {
2488        let key = PublicKey::new("ssh-ed25519", vec![1, 2, 3, 4]);
2489        let fp = key.fingerprint();
2490        assert!(fp.starts_with("HASH:"));
2491    }
2492
2493    #[test]
2494    fn test_public_key_equality() {
2495        let key1 = PublicKey::new("ssh-ed25519", vec![1, 2, 3, 4]);
2496        let key2 = PublicKey::new("ssh-ed25519", vec![1, 2, 3, 4]);
2497        let key3 = PublicKey::new("ssh-ed25519", vec![5, 6, 7, 8]);
2498
2499        assert_eq!(key1, key2);
2500        assert_ne!(key1, key3);
2501    }
2502
2503    #[test]
2504    fn test_context() {
2505        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2506        let ctx = Context::new("testuser", addr, addr);
2507
2508        assert_eq!(ctx.user(), "testuser");
2509        assert_eq!(ctx.remote_addr(), addr);
2510
2511        ctx.set_value("key", "value");
2512        assert_eq!(ctx.get_value("key"), Some("value".to_string()));
2513        assert_eq!(ctx.get_value("missing"), None);
2514    }
2515
2516    #[test]
2517    fn test_session_basic() {
2518        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2519        let ctx = Context::new("testuser", addr, addr);
2520        let session = Session::new(ctx);
2521
2522        assert_eq!(session.user(), "testuser");
2523        assert!(session.command().is_empty());
2524        assert!(session.public_key().is_none());
2525    }
2526
2527    #[test]
2528    fn test_session_builder() {
2529        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2530        let ctx = Context::new("testuser", addr, addr);
2531
2532        let pty = Pty {
2533            term: "xterm".to_string(),
2534            window: Window {
2535                width: 120,
2536                height: 40,
2537            },
2538        };
2539
2540        let session = Session::new(ctx)
2541            .with_pty(pty)
2542            .with_command(vec!["ls".to_string(), "-la".to_string()])
2543            .with_env("HOME", "/home/user");
2544
2545        let (pty_ref, active) = session.pty();
2546        assert!(active);
2547        assert_eq!(pty_ref.unwrap().term, "xterm");
2548        assert_eq!(session.command(), &["ls", "-la"]);
2549        assert_eq!(session.get_env("HOME"), Some(&"/home/user".to_string()));
2550    }
2551
2552    #[test]
2553    fn test_session_write() {
2554        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2555        let ctx = Context::new("testuser", addr, addr);
2556        let session = Session::new(ctx);
2557
2558        let n = session.write(b"hello").unwrap();
2559        assert_eq!(n, 5);
2560
2561        let n = session.write_stderr(b"error").unwrap();
2562        assert_eq!(n, 5);
2563    }
2564
2565    #[test]
2566    fn test_session_exit_close() {
2567        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2568        let ctx = Context::new("testuser", addr, addr);
2569        let session = Session::new(ctx);
2570
2571        assert!(!session.is_closed());
2572        session.exit(0).unwrap();
2573        session.close().unwrap();
2574        assert!(session.is_closed());
2575    }
2576
2577    #[test]
2578    fn test_server_options_default() {
2579        let opts = ServerOptions::default();
2580        assert_eq!(opts.address, "0.0.0.0:22");
2581        assert_eq!(opts.version, "SSH-2.0-Wish");
2582        assert!(opts.banner.is_none());
2583    }
2584
/// Builds a server through `ServerBuilder` and checks that the address,
/// version, banner and idle timeout set on the builder are visible afterwards.
#[test]
fn test_server_builder() {
    let idle = Duration::from_secs(300);

    let server = ServerBuilder::new()
        .address("0.0.0.0:2222")
        .version("SSH-2.0-MyApp")
        .banner("Welcome!")
        .idle_timeout(idle)
        .build()
        .unwrap();

    assert_eq!(server.address(), "0.0.0.0:2222");

    // The remaining settings are read back through the options accessor.
    let opts = server.options();
    assert_eq!(opts.version, "SSH-2.0-MyApp");
    assert_eq!(opts.banner, Some(String::from("Welcome!")));
    assert_eq!(opts.idle_timeout, Some(idle));
}
2603
/// Applies each functional option to one `ServerOptions` value and checks
/// that the matching field is updated after every application.
#[test]
fn test_option_functions() {
    let idle = Duration::from_secs(60);
    let max = Duration::from_secs(3600);
    let mut options = ServerOptions::default();

    // String-valued settings.
    with_address("localhost:22")(&mut options).unwrap();
    assert_eq!(options.address, "localhost:22");

    with_version("SSH-2.0-Test")(&mut options).unwrap();
    assert_eq!(options.version, "SSH-2.0-Test");

    with_banner("Hello")(&mut options).unwrap();
    assert_eq!(options.banner, Some(String::from("Hello")));

    // Timeout settings.
    with_idle_timeout(idle)(&mut options).unwrap();
    assert_eq!(options.idle_timeout, Some(idle));

    with_max_timeout(max)(&mut options).unwrap();
    assert_eq!(options.max_timeout, Some(max));
}
2623
/// Checks that `new_server` applies the functional options it is given.
#[test]
fn test_new_server() {
    let options = [with_address("0.0.0.0:2222"), with_version("SSH-2.0-Test")];
    let server = new_server(options).unwrap();

    assert_eq!(server.address(), "0.0.0.0:2222");
    assert_eq!(server.options().version, "SSH-2.0-Test");
}
2632
/// Smoke test: `noop_handler()` yields something callable with a `Session`.
#[test]
fn test_noop_handler() {
    let peer = "127.0.0.1:2222".parse::<SocketAddr>().unwrap();
    let session = Session::new(Context::new("test", peer, peer));

    let noop = noop_handler();
    // The returned value (presumably the handler's future) is dropped without
    // being awaited; the point is only that the call type-checks and runs.
    let pending = noop(session);
    drop(pending);
}
2642
2643    #[tokio::test]
2644    async fn test_handler_creation() {
2645        let h = handler(|_session| async {
2646            // Do nothing
2647        });
2648
2649        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2650        let ctx = Context::new("test", addr, addr);
2651        let session = Session::new(ctx);
2652        h(session).await;
2653    }
2654
/// With a limiter built by `new_rate_limiter(0.0, 3, 10)`, the first three
/// `allow` calls for a session succeed and the fourth is rejected.
#[test]
fn test_rate_limiter() {
    use middleware::ratelimiter::{RateLimiter, new_rate_limiter};

    let limiter = new_rate_limiter(0.0, 3, 10);
    let peer = "127.0.0.1:2222".parse::<SocketAddr>().unwrap();
    let session = Session::new(Context::new("testuser", peer, peer));

    for _ in 0..3 {
        assert!(limiter.allow(&session).is_ok());
    }

    // The fourth attempt must be rate limited.
    assert!(limiter.allow(&session).is_err());
}
2669
/// Calls `print`, `println`, `error` and `errorln` on a session whose output
/// is captured by an unbounded channel, then checks the four queued items in
/// order: stream kind and exact bytes (the `*ln` helpers append `\r\n`).
///
/// # Errors
///
/// Returns an error if the channel has no item to receive, or if an item
/// arrives on the wrong stream.
#[test]
fn test_output_helpers() -> std::result::Result<(), Box<dyn std::error::Error>> {
    let peer: SocketAddr = "127.0.0.1:2222".parse().unwrap();
    let mut session = Session::new(Context::new("test", peer, peer));

    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    session.set_output_sender(tx);

    print(&session, "hello");
    println(&session, "world");
    error(&session, "err");
    errorln(&session, "error line");

    // Pops the next captured item, turning a failed receive into an io error.
    let mut next_output = || rx.try_recv().map_err(|e| io::Error::other(e.to_string()));

    // 1. print -> stdout, verbatim.
    match next_output()? {
        SessionOutput::Stdout(data) => assert_eq!(data, b"hello"),
        other => {
            return Err(io::Error::other(format!(
                "expected stdout for print(), got {other:?}"
            ))
            .into());
        }
    }

    // 2. println -> stdout, with trailing CRLF.
    match next_output()? {
        SessionOutput::Stdout(data) => assert_eq!(data, b"world\r\n"),
        other => {
            return Err(io::Error::other(format!(
                "expected stdout for println(), got {other:?}"
            ))
            .into());
        }
    }

    // 3. error -> stderr, verbatim.
    match next_output()? {
        SessionOutput::Stderr(data) => assert_eq!(data, b"err"),
        other => {
            return Err(io::Error::other(format!(
                "expected stderr for error(), got {other:?}"
            ))
            .into());
        }
    }

    // 4. errorln -> stderr, with trailing CRLF.
    match next_output()? {
        SessionOutput::Stderr(data) => assert_eq!(data, b"error line\r\n"),
        other => {
            return Err(io::Error::other(format!(
                "expected stderr for errorln(), got {other:?}"
            ))
            .into());
        }
    }

    Ok(())
}
2735
/// Smoke test: `tea::make_renderer` can be called on a session that carries
/// an `xterm-256color` PTY without panicking.
#[test]
fn test_tea_make_renderer() {
    let peer = "127.0.0.1:2222".parse::<SocketAddr>().unwrap();
    let session = Session::new(Context::new("test", peer, peer)).with_pty(Pty {
        term: String::from("xterm-256color"),
        window: Window::default(),
    });

    // No assertion on the result: reaching the end of the test is the check.
    let _renderer = tea::make_renderer(&session);
}
2749
2750    #[tokio::test]
2751    async fn test_tea_middleware_handles_program_panic()
2752    -> std::result::Result<(), Box<dyn std::error::Error>> {
2753        let called = Arc::new(AtomicUsize::new(0));
2754        let mw = tea::middleware(|_session| PanicTeaModel);
2755        let next = handler({
2756            let called = called.clone();
2757            move |_session| {
2758                let called = called.clone();
2759                async move {
2760                    called.fetch_add(1, Ordering::SeqCst);
2761                }
2762            }
2763        });
2764
2765        let addr: SocketAddr = "127.0.0.1:2222".parse().map_err(io::Error::other)?;
2766        let ctx = Context::new("test", addr, addr);
2767        let mut session = Session::new(ctx).with_pty(Pty::default());
2768
2769        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
2770        session.set_output_sender(tx);
2771
2772        mw(next)(session).await;
2773
2774        assert_eq!(called.load(Ordering::SeqCst), 0);
2775
2776        let mut saw_fatal = false;
2777        let mut saw_exit = false;
2778        let mut saw_close = false;
2779        loop {
2780            match rx.try_recv() {
2781                Ok(SessionOutput::Stderr(data)) => {
2782                    let msg = String::from_utf8_lossy(&data);
2783                    if msg.contains("bubbletea program crashed:") {
2784                        saw_fatal = true;
2785                    }
2786                }
2787                Ok(SessionOutput::Exit(1)) => saw_exit = true,
2788                Ok(SessionOutput::Close) => saw_close = true,
2789                Ok(_) => {}
2790                Err(tokio::sync::mpsc::error::TryRecvError::Empty)
2791                | Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
2792            }
2793        }
2794
2795        assert!(saw_fatal, "expected fatal stderr output for tea panic");
2796        assert!(saw_exit, "expected exit(1) for tea panic");
2797        assert!(saw_close, "expected close signal for tea panic");
2798
2799        Ok(())
2800    }
2801
/// Checks the `Display` text of three `Error` variants.
#[test]
fn test_error_display() {
    let io_err = Error::Io(io::Error::other("test"));
    assert!(io_err.to_string().contains("io error"));

    // This variant's message is matched exactly rather than by substring.
    let auth_err = Error::AuthenticationFailed;
    assert_eq!(auth_err.to_string(), "authentication failed");

    let config_err = Error::Configuration(String::from("bad config"));
    assert!(config_err.to_string().contains("configuration error"));
}
2813
2814    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
2815    async fn test_session_recv_with_input_channel() {
2816        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2817        let ctx = Context::new("testuser", addr, addr);
2818        let session = Session::new(ctx);
2819
2820        assert!(session.recv().await.is_none());
2821
2822        let (tx, rx) = tokio::sync::mpsc::channel(1);
2823        session.set_input_receiver(rx).await;
2824        tx.send(b"ping".to_vec()).await.unwrap();
2825
2826        let received = session.recv().await;
2827        assert_eq!(received, Some(b"ping".to_vec()));
2828    }
2829
/// Checks that a `Message` passed to `Session::send_message` arrives on the
/// std channel registered with `set_message_sender`, payload intact.
#[test]
fn test_session_send_message() {
    let peer = "127.0.0.1:2222".parse::<SocketAddr>().unwrap();
    let session = Session::new(Context::new("testuser", peer, peer));

    let (sender, receiver) = std::sync::mpsc::channel();
    session.set_message_sender(sender);
    session.send_message(Message::new(42_u32));

    // Wait at most 50 ms for the message to arrive.
    let delivered = receiver.recv_timeout(Duration::from_millis(50)).unwrap();
    assert!(delivered.is::<u32>());
    assert_eq!(delivered.downcast::<u32>().unwrap(), 42);
}
2844
2845    #[tokio::test]
2846    async fn test_compose_middleware_order() {
2847        let events = Arc::new(Mutex::new(Vec::new()));
2848        let middlewares = vec![
2849            record_middleware("first", events.clone()),
2850            record_middleware("second", events.clone()),
2851        ];
2852        let composed = compose_middleware(middlewares);
2853
2854        let handler = handler({
2855            let events = events.clone();
2856            move |_session| {
2857                let events = events.clone();
2858                async move {
2859                    let mut guard = events.lock().expect("events lock");
2860                    guard.push("handler");
2861                }
2862            }
2863        });
2864
2865        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2866        let ctx = Context::new("test", addr, addr);
2867        let session = Session::new(ctx);
2868
2869        composed(handler)(session).await;
2870
2871        let events = events.lock().expect("events lock");
2872        assert_eq!(&*events, &["first", "second", "handler"]);
2873    }
2874
2875    #[tokio::test]
2876    async fn test_activeterm_middleware_blocks_without_pty()
2877    -> std::result::Result<(), Box<dyn std::error::Error>> {
2878        let called = Arc::new(AtomicUsize::new(0));
2879        let mw = middleware::activeterm::middleware();
2880        let handler = handler({
2881            let called = called.clone();
2882            move |_session| {
2883                let called = called.clone();
2884                async move {
2885                    called.fetch_add(1, Ordering::SeqCst);
2886                }
2887            }
2888        });
2889
2890        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2891        let ctx = Context::new("test", addr, addr);
2892        let mut session = Session::new(ctx);
2893
2894        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
2895        session.set_output_sender(tx);
2896
2897        mw(handler)(session).await;
2898
2899        assert_eq!(called.load(Ordering::SeqCst), 0);
2900
2901        let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2902        match item {
2903            SessionOutput::Stdout(data) => assert_eq!(data, b"Requires an active PTY\r\n"),
2904            other => {
2905                return Err(io::Error::other(format!(
2906                    "expected stdout warning for activeterm, got {other:?}"
2907                ))
2908                .into());
2909            }
2910        }
2911
2912        let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2913        match item {
2914            SessionOutput::Exit(code) => assert_eq!(code, 1),
2915            other => {
2916                return Err(io::Error::other(format!(
2917                    "expected exit code for activeterm, got {other:?}"
2918                ))
2919                .into());
2920            }
2921        }
2922
2923        Ok(())
2924    }
2925
2926    #[tokio::test]
2927    async fn test_accesscontrol_middleware_allows_command() {
2928        let called = Arc::new(AtomicUsize::new(0));
2929        let mw = middleware::accesscontrol::middleware(vec!["git".to_string()]);
2930        let handler = handler({
2931            let called = called.clone();
2932            move |_session| {
2933                let called = called.clone();
2934                async move {
2935                    called.fetch_add(1, Ordering::SeqCst);
2936                }
2937            }
2938        });
2939
2940        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2941        let ctx = Context::new("test", addr, addr);
2942        let session = Session::new(ctx).with_command(vec!["git".to_string()]);
2943
2944        mw(handler)(session).await;
2945
2946        assert_eq!(called.load(Ordering::SeqCst), 1);
2947    }
2948
2949    #[tokio::test]
2950    async fn test_accesscontrol_middleware_blocks_command()
2951    -> std::result::Result<(), Box<dyn std::error::Error>> {
2952        let called = Arc::new(AtomicUsize::new(0));
2953        let mw = middleware::accesscontrol::middleware(vec!["git".to_string()]);
2954        let handler = handler({
2955            let called = called.clone();
2956            move |_session| {
2957                let called = called.clone();
2958                async move {
2959                    called.fetch_add(1, Ordering::SeqCst);
2960                }
2961            }
2962        });
2963
2964        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
2965        let ctx = Context::new("test", addr, addr);
2966        let mut session = Session::new(ctx).with_command(vec!["rm".to_string()]);
2967
2968        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
2969        session.set_output_sender(tx);
2970
2971        mw(handler)(session).await;
2972
2973        assert_eq!(called.load(Ordering::SeqCst), 0);
2974
2975        let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2976        match item {
2977            SessionOutput::Stdout(data) => assert_eq!(data, b"Command is not allowed: rm\r\n"),
2978            other => {
2979                return Err(io::Error::other(format!(
2980                    "expected stdout message for accesscontrol, got {other:?}"
2981                ))
2982                .into());
2983            }
2984        }
2985
2986        let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
2987        match item {
2988            SessionOutput::Exit(code) => assert_eq!(code, 1),
2989            other => {
2990                return Err(io::Error::other(format!(
2991                    "expected exit code for accesscontrol, got {other:?}"
2992                ))
2993                .into());
2994            }
2995        }
2996
2997        Ok(())
2998    }
2999
3000    #[tokio::test]
3001    async fn test_comment_middleware_appends_message()
3002    -> std::result::Result<(), Box<dyn std::error::Error>> {
3003        let mw = middleware::comment::middleware("done");
3004        let handler = handler(|session| async move {
3005            print(&session, "work");
3006        });
3007
3008        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3009        let ctx = Context::new("test", addr, addr);
3010        let mut session = Session::new(ctx);
3011
3012        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
3013        session.set_output_sender(tx);
3014
3015        mw(handler)(session).await;
3016
3017        let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3018        match item {
3019            SessionOutput::Stdout(data) => assert_eq!(data, b"work"),
3020            other => {
3021                return Err(io::Error::other(format!(
3022                    "expected stdout for handler output, got {other:?}"
3023                ))
3024                .into());
3025            }
3026        }
3027
3028        let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3029        match item {
3030            SessionOutput::Stdout(data) => assert_eq!(data, b"done\r\n"),
3031            other => {
3032                return Err(io::Error::other(format!(
3033                    "expected stdout for comment output, got {other:?}"
3034                ))
3035                .into());
3036            }
3037        }
3038
3039        Ok(())
3040    }
3041
3042    #[tokio::test]
3043    async fn test_elapsed_middleware_outputs_timing()
3044    -> std::result::Result<(), Box<dyn std::error::Error>> {
3045        let mw = middleware::elapsed::middleware_with_format("elapsed=%v");
3046        let handler = handler(|_session| async move {});
3047
3048        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3049        let ctx = Context::new("test", addr, addr);
3050        let mut session = Session::new(ctx);
3051
3052        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
3053        session.set_output_sender(tx);
3054
3055        mw(handler)(session).await;
3056
3057        let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3058        match item {
3059            SessionOutput::Stdout(data) => {
3060                let msg = String::from_utf8_lossy(&data);
3061                assert!(msg.contains("elapsed="));
3062            }
3063            other => {
3064                return Err(io::Error::other(format!(
3065                    "expected stdout for elapsed middleware, got {other:?}"
3066                ))
3067                .into());
3068            }
3069        }
3070
3071        Ok(())
3072    }
3073
3074    #[tokio::test]
3075    async fn test_ratelimiter_middleware_rejects()
3076    -> std::result::Result<(), Box<dyn std::error::Error>> {
3077        let called = Arc::new(AtomicUsize::new(0));
3078        let mw = middleware::ratelimiter::middleware(DenyLimiter);
3079        let handler = handler({
3080            let called = called.clone();
3081            move |_session| {
3082                let called = called.clone();
3083                async move {
3084                    called.fetch_add(1, Ordering::SeqCst);
3085                }
3086            }
3087        });
3088
3089        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3090        let ctx = Context::new("test", addr, addr);
3091        let mut session = Session::new(ctx);
3092
3093        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
3094        session.set_output_sender(tx);
3095
3096        mw(handler)(session).await;
3097
3098        assert_eq!(called.load(Ordering::SeqCst), 0);
3099
3100        let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3101        match item {
3102            SessionOutput::Stderr(data) => {
3103                assert_eq!(
3104                    data,
3105                    middleware::ratelimiter::ERR_RATE_LIMIT_EXCEEDED.as_bytes()
3106                );
3107            }
3108            other => {
3109                return Err(io::Error::other(format!(
3110                    "expected stderr for ratelimiter, got {other:?}"
3111                ))
3112                .into());
3113            }
3114        }
3115
3116        let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3117        match item {
3118            SessionOutput::Exit(code) => assert_eq!(code, 1),
3119            other => {
3120                return Err(io::Error::other(format!(
3121                    "expected exit for ratelimiter, got {other:?}"
3122                ))
3123                .into());
3124            }
3125        }
3126
3127        let item = rx.try_recv().map_err(|e| io::Error::other(e.to_string()))?;
3128        match item {
3129            SessionOutput::Close => {}
3130            other => {
3131                return Err(io::Error::other(format!(
3132                    "expected close for ratelimiter, got {other:?}"
3133                ))
3134                .into());
3135            }
3136        }
3137
3138        Ok(())
3139    }
3140
3141    #[tokio::test]
3142    async fn test_logging_middleware_with_custom_logger() {
3143        let entries = Arc::new(Mutex::new(Vec::new()));
3144        let logger = TestLogger {
3145            entries: entries.clone(),
3146        };
3147
3148        let mw = middleware::logging::middleware_with_logger(logger);
3149        let handler = handler(|_session| async move {});
3150
3151        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3152        let ctx = Context::new("alice", addr, addr);
3153        let session = Session::new(ctx);
3154
3155        mw(handler)(session).await;
3156
3157        let entries = entries.lock().expect("logger entries");
3158        assert_eq!(entries.len(), 2);
3159        assert!(entries[0].contains("connect"));
3160        assert!(entries[1].contains("disconnect"));
3161    }
3162
3163    #[tokio::test]
3164    async fn test_structured_logging_middleware_with_custom_logger() {
3165        let logger = TestStructuredLogger::default();
3166        let mw = middleware::logging::structured_middleware_with_logger(
3167            logger.clone(),
3168            tracing::Level::INFO,
3169        );
3170        let handler = handler(|_session| async move {});
3171
3172        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3173        let ctx = Context::new("alice", addr, addr);
3174        let session = Session::new(ctx).with_public_key(PublicKey::new("ssh-ed25519", vec![1]));
3175
3176        mw(handler)(session).await;
3177
3178        let connects = logger.connects.lock().expect("connects");
3179        assert_eq!(connects.len(), 1);
3180        assert_eq!(connects[0].0, "alice");
3181        assert_eq!(connects[0].1, addr);
3182        assert!(connects[0].2);
3183
3184        let disconnects = logger.disconnects.lock().expect("disconnects");
3185        assert_eq!(disconnects.len(), 1);
3186        assert_eq!(disconnects[0].0, "alice");
3187        assert_eq!(disconnects[0].1, addr);
3188    }
3189
3190    #[tokio::test]
3191    async fn test_recover_middleware_runs_inner_before_next() {
3192        let events = Arc::new(Mutex::new(Vec::new()));
3193        let inner = record_middleware("inner", events.clone());
3194        let mw = middleware::recover::middleware_with_middlewares(vec![inner]);
3195
3196        let handler = handler({
3197            let events = events.clone();
3198            move |_session| {
3199                let events = events.clone();
3200                async move {
3201                    let mut guard = events.lock().expect("events lock");
3202                    guard.push("handler");
3203                }
3204            }
3205        });
3206
3207        let addr: SocketAddr = "127.0.0.1:2222".parse().unwrap();
3208        let ctx = Context::new("test", addr, addr);
3209        let session = Session::new(ctx);
3210
3211        mw(handler)(session).await;
3212
3213        let events = events.lock().expect("events lock");
3214        assert_eq!(&*events, &["inner", "handler"]);
3215    }
3216
/// Applies the auth, host-key, banner-handler, middleware and subsystem
/// functional options to one `ServerOptions` value and checks that each one
/// left its mark on the corresponding field.
#[test]
fn test_server_option_auth_and_subsystem() {
    let mut options = ServerOptions::default();

    // Authentication settings.
    with_auth_handler(AcceptAllAuth::new())(&mut options).unwrap();
    with_max_auth_attempts(3)(&mut options).unwrap();
    with_auth_rejection_delay(250)(&mut options).unwrap();
    with_public_key_auth(|_ctx, _key| true)(&mut options).unwrap();
    with_password_auth(|_ctx, _pw| true)(&mut options).unwrap();
    let keyboard_interactive =
        with_keyboard_interactive_auth(|_ctx, _resp, _prompts, _echos| vec![String::from("ok")]);
    keyboard_interactive(&mut options).unwrap();

    // Host key, banner, middleware and subsystem settings.
    with_host_key_path("/tmp/wish_host_file")(&mut options).unwrap();
    with_host_key_pem(b"test_key_data".to_vec())(&mut options).unwrap();
    with_banner_handler(|ctx| format!("hello {}", ctx.user()))(&mut options).unwrap();
    with_middleware(middleware::comment::middleware("hi"))(&mut options).unwrap();
    with_subsystem("sftp", |_session| async move {})(&mut options).unwrap();

    assert!(options.auth_handler.is_some());
    assert_eq!(options.max_auth_attempts, 3);
    assert_eq!(options.auth_rejection_delay_ms, 250);
    assert!(options.public_key_handler.is_some());
    assert!(options.password_handler.is_some());
    assert!(options.keyboard_interactive_handler.is_some());

    assert_eq!(
        options.host_key_path.as_deref(),
        Some("/tmp/wish_host_file")
    );
    assert_eq!(
        options.host_key_pem.as_deref(),
        Some(b"test_key_data".as_slice())
    );
    assert!(options.banner_handler.is_some());
    assert_eq!(options.middlewares.len(), 1);
    assert!(options.subsystem_handlers.contains_key("sftp"));
}
3251
/// Checks that the auth-related `ServerBuilder` methods and `subsystem`
/// are reflected in the built server's options.
#[test]
fn test_server_builder_auth_settings() {
    let server = ServerBuilder::new()
        .address("127.0.0.1:2222")
        .max_auth_attempts(5)
        .auth_rejection_delay(123)
        .public_key_auth(|_ctx, _key| true)
        .password_auth(|_ctx, _pw| true)
        .keyboard_interactive_auth(|_ctx, _resp, _prompts, _echos| Vec::new())
        .subsystem("sftp", |_session| async move {})
        .build()
        .unwrap();

    let opts = server.options();

    // Numeric limits are stored verbatim.
    assert_eq!(opts.max_auth_attempts, 5);
    assert_eq!(opts.auth_rejection_delay_ms, 123);

    // Each callback-style setter populates its handler slot.
    assert!(opts.public_key_handler.is_some());
    assert!(opts.password_handler.is_some());
    assert!(opts.keyboard_interactive_handler.is_some());
    assert!(opts.subsystem_handlers.contains_key("sftp"));
}
3272
/// Checks that when the auth handler advertises only password auth, the
/// generated russh config enables `PASSWORD` and does not enable `PUBLICKEY`.
#[test]
fn test_create_russh_config_methods_from_auth_handler() {
    use russh::MethodSet;

    /// Auth handler whose only overridden method reports password support.
    struct PasswordOnly;

    #[async_trait::async_trait]
    impl AuthHandler for PasswordOnly {
        fn supported_methods(&self) -> Vec<AuthMethod> {
            vec![AuthMethod::Password]
        }
    }

    let config = ServerBuilder::new()
        .auth_handler(PasswordOnly)
        .build()
        .unwrap()
        .create_russh_config()
        .unwrap();

    let methods = &config.methods;
    assert!(methods.contains(MethodSet::PASSWORD));
    assert!(!methods.contains(MethodSet::PUBLICKEY));
}
3295
/// Checks that registering public-key and password callbacks enables
/// exactly those two methods in the generated russh config, leaving
/// keyboard-interactive disabled.
#[test]
fn test_create_russh_config_methods_from_callbacks() {
    use russh::MethodSet;

    let config = ServerBuilder::new()
        .public_key_auth(|_ctx, _key| true)
        .password_auth(|_ctx, _pw| true)
        .build()
        .unwrap()
        .create_russh_config()
        .unwrap();

    let methods = &config.methods;
    assert!(methods.contains(MethodSet::PUBLICKEY));
    assert!(methods.contains(MethodSet::PASSWORD));
    // No keyboard-interactive callback was registered.
    assert!(!methods.contains(MethodSet::KEYBOARD_INTERACTIVE));
}
3312}