1use std::collections::HashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use bubbletea::{
12 KeyMsg, Message, WindowSizeMsg,
13 key::{is_sequence_prefix, parse_sequence_prefix},
14};
15use parking_lot::RwLock;
16use russh::MethodSet;
17use russh::server::{Auth, Handler as RusshHandler, Msg, Session as RusshSession};
18use russh::{Channel, ChannelId};
19use russh_keys::PublicKeyBase64;
20use tokio::sync::{broadcast, mpsc};
21use tracing::{debug, info, trace, warn};
22
23use crate::{
24 AuthContext, AuthMethod, AuthResult, Context, Error, Handler, Pty, PublicKey, ServerOptions,
25 Session, SessionOutput, Window, compose_middleware, noop_handler,
26};
27
28pub use russh::server::{Config as RusshConfig, run_stream};
30
31pub struct ServerState {
33 pub options: ServerOptions,
35 pub handler: Handler,
37 pub connection_counter: RwLock<u64>,
39}
40
41impl ServerState {
42 pub fn new(options: ServerOptions) -> Self {
44 let base_handler = options.handler.clone().unwrap_or_else(noop_handler);
46 let handler = if options.middlewares.is_empty() {
47 base_handler
48 } else {
49 let composed = compose_middleware(options.middlewares.clone());
50 composed(base_handler)
51 };
52
53 Self {
54 options,
55 handler,
56 connection_counter: RwLock::new(0),
57 }
58 }
59
60 pub fn next_connection_id(&self) -> u64 {
62 let mut counter = self.connection_counter.write();
63 *counter += 1;
64 *counter
65 }
66}
67
68struct ChannelState {
70 session: Session,
72 input_tx: mpsc::Sender<Vec<u8>>,
74 started: bool,
76 input_buffer: Vec<u8>,
78}
79
80struct KeyboardInteractiveState {
82 prompts: Vec<String>,
83 echos: Vec<bool>,
84}
85
/// Splits an exec request's command line into arguments using POSIX-shell-like quoting.
///
/// Rules:
/// - Unquoted whitespace separates arguments, and runs of whitespace collapse.
/// - Single quotes keep everything literally up to the closing `'`.
/// - Double quotes group text. Inside them a backslash only escapes `"`, `\`, `$` and
///   `` ` `` (as in POSIX `sh`); any other backslash is kept, so `"C:\tmp"` yields `C:\tmp`.
/// - Outside quotes, a backslash escapes the next character (e.g. `one\ two` is one argument).
/// - Empty quoted strings (`""`, `''`) produce empty arguments.
/// - A trailing unpaired backslash is kept literally.
///
/// Returns `None` when a single or double quote is left unterminated.
fn parse_exec_command_args(command: &str) -> Option<Vec<String>> {
    let mut args = Vec::new();
    let mut current = String::new();
    // Tracks whether `current` holds an argument, which may be the empty string `""`.
    let mut token_in_progress = false;
    let mut in_single_quotes = false;
    let mut in_double_quotes = false;
    let mut escaped = false;

    for ch in command.chars() {
        if escaped {
            // Inside double quotes, POSIX only treats `\` as an escape before these
            // characters. Before anything else the backslash belongs to the argument.
            if in_double_quotes && !matches!(ch, '"' | '\\' | '$' | '`') {
                current.push('\\');
            }
            current.push(ch);
            token_in_progress = true;
            escaped = false;
            continue;
        }

        match ch {
            '\\' if !in_single_quotes => {
                escaped = true;
                token_in_progress = true;
            }
            '\'' if !in_double_quotes => {
                in_single_quotes = !in_single_quotes;
                token_in_progress = true;
            }
            '"' if !in_single_quotes => {
                in_double_quotes = !in_double_quotes;
                token_in_progress = true;
            }
            _ if ch.is_whitespace() && !in_single_quotes && !in_double_quotes => {
                if token_in_progress {
                    args.push(std::mem::take(&mut current));
                    token_in_progress = false;
                }
            }
            _ => {
                current.push(ch);
                token_in_progress = true;
            }
        }
    }

    // A dangling backslash has nothing to escape, so keep it literally.
    if escaped {
        current.push('\\');
        token_in_progress = true;
    }

    if in_single_quotes || in_double_quotes {
        return None;
    }

    if token_in_progress {
        args.push(current);
    }

    Some(args)
}
151
152pub struct WishHandler {
157 connection_id: u64,
159 remote_addr: SocketAddr,
161 local_addr: SocketAddr,
163 user: Option<String>,
165 public_key: Option<russh_keys::key::PublicKey>,
167 pty: Option<Pty>,
169 window: Window,
171 server_state: Arc<ServerState>,
173 channels: HashMap<ChannelId, ChannelState>,
175 #[allow(dead_code)]
177 shutdown_rx: broadcast::Receiver<()>,
178 auth_attempts: u32,
180 keyboard_interactive: Option<KeyboardInteractiveState>,
182}
183
184impl WishHandler {
185 pub fn new(
187 remote_addr: SocketAddr,
188 local_addr: SocketAddr,
189 server_state: Arc<ServerState>,
190 shutdown_rx: broadcast::Receiver<()>,
191 ) -> Self {
192 let connection_id = server_state.next_connection_id();
193 debug!(
194 connection_id,
195 remote_addr = %remote_addr,
196 "New connection handler created"
197 );
198
199 Self {
200 connection_id,
201 remote_addr,
202 local_addr,
203 user: None,
204 public_key: None,
205 pty: None,
206 window: Window::default(),
207 server_state,
208 channels: HashMap::new(),
209 shutdown_rx,
210 auth_attempts: 0,
211 keyboard_interactive: None,
212 }
213 }
214
215 fn make_context(&self, user: &str) -> Context {
217 let ctx = Context::new(user, self.remote_addr, self.local_addr);
218 ctx.set_value("connection_id", self.connection_id.to_string());
219 ctx
220 }
221
222 fn next_auth_context(&mut self, user: &str) -> AuthContext {
223 self.auth_attempts = self.auth_attempts.saturating_add(1);
224 AuthContext::new(user, self.remote_addr, crate::SessionId(self.connection_id))
225 .with_attempt(self.auth_attempts)
226 }
227
228 fn method_set_from(methods: &[AuthMethod]) -> Option<MethodSet> {
229 let mut set = MethodSet::empty();
230 for method in methods {
231 match method {
232 AuthMethod::None => set |= MethodSet::NONE,
233 AuthMethod::Password => set |= MethodSet::PASSWORD,
234 AuthMethod::PublicKey => set |= MethodSet::PUBLICKEY,
235 AuthMethod::KeyboardInteractive => set |= MethodSet::KEYBOARD_INTERACTIVE,
236 AuthMethod::HostBased => set |= MethodSet::HOSTBASED,
237 }
238 }
239 if set.is_empty() { None } else { Some(set) }
240 }
241
242 fn map_auth_result(result: AuthResult) -> Auth {
243 match result {
244 AuthResult::Accept => Auth::Accept,
245 AuthResult::Reject => Auth::Reject {
246 proceed_with_methods: None,
247 },
248 AuthResult::Partial { next_methods } => Auth::Reject {
249 proceed_with_methods: Self::method_set_from(&next_methods),
250 },
251 }
252 }
253
254 fn convert_public_key(key: &russh_keys::key::PublicKey) -> PublicKey {
256 let key_name = key.name();
257 let key_type = match key_name {
258 "ssh-ed25519" => "ssh-ed25519",
259 "rsa-sha2-256" | "rsa-sha2-512" | "ssh-rsa" => "ssh-rsa",
260 "ecdsa-sha2-nistp256" => "ecdsa-sha2-nistp256",
261 "ecdsa-sha2-nistp384" => "ecdsa-sha2-nistp384",
262 "ecdsa-sha2-nistp521" => "ecdsa-sha2-nistp521",
263 other => other,
264 };
265
266 let key_bytes = key.public_key_bytes();
267 PublicKey::new(key_type, key_bytes)
268 }
269
270 fn default_keyboard_interactive_state() -> KeyboardInteractiveState {
271 KeyboardInteractiveState {
272 prompts: vec!["Password: ".to_string()],
273 echos: vec![false],
274 }
275 }
276}
277
278#[async_trait]
279impl RusshHandler for WishHandler {
280 type Error = Error;
281
282 async fn auth_publickey(
284 &mut self,
285 user: &str,
286 public_key: &russh_keys::key::PublicKey,
287 ) -> std::result::Result<Auth, Self::Error> {
288 debug!(
289 connection_id = self.connection_id,
290 user = user,
291 key_type = public_key.name(),
292 "Public key auth attempt"
293 );
294
295 if let Some(handler) = self.server_state.options.auth_handler.clone() {
296 let ctx = self.next_auth_context(user);
297 let pk = Self::convert_public_key(public_key);
298 let result = handler.auth_publickey(&ctx, &pk).await;
299 if result.is_accepted() {
300 info!(
301 connection_id = self.connection_id,
302 user = user,
303 "Public key auth accepted"
304 );
305 self.user = Some(user.to_string());
306 self.public_key = Some(public_key.clone());
307 }
308 return Ok(Self::map_auth_result(result));
309 }
310
311 if let Some(handler) = &self.server_state.options.public_key_handler {
313 let ctx = self.make_context(user);
314 let pk = Self::convert_public_key(public_key);
315
316 if handler(&ctx, &pk) {
317 info!(
318 connection_id = self.connection_id,
319 user = user,
320 "Public key auth accepted"
321 );
322 self.user = Some(user.to_string());
323 self.public_key = Some(public_key.clone());
324 return Ok(Auth::Accept);
325 }
326 }
327
328 debug!(
330 connection_id = self.connection_id,
331 user = user,
332 "Public key auth rejected"
333 );
334 Ok(Auth::Reject {
335 proceed_with_methods: None,
336 })
337 }
338
339 async fn auth_password(
341 &mut self,
342 user: &str,
343 password: &str,
344 ) -> std::result::Result<Auth, Self::Error> {
345 debug!(
346 connection_id = self.connection_id,
347 user = user,
348 "Password auth attempt"
349 );
350
351 if let Some(handler) = self.server_state.options.auth_handler.clone() {
352 let ctx = self.next_auth_context(user);
353 let result = handler.auth_password(&ctx, password).await;
354 if result.is_accepted() {
355 info!(
356 connection_id = self.connection_id,
357 user = user,
358 "Password auth accepted"
359 );
360 self.user = Some(user.to_string());
361 }
362 return Ok(Self::map_auth_result(result));
363 }
364
365 if let Some(handler) = &self.server_state.options.password_handler {
367 let ctx = self.make_context(user);
368
369 if handler(&ctx, password) {
370 info!(
371 connection_id = self.connection_id,
372 user = user,
373 "Password auth accepted"
374 );
375 self.user = Some(user.to_string());
376 return Ok(Auth::Accept);
377 }
378 }
379
380 debug!(
381 connection_id = self.connection_id,
382 user = user,
383 "Password auth rejected"
384 );
385 Ok(Auth::Reject {
386 proceed_with_methods: None,
387 })
388 }
389
390 async fn auth_none(&mut self, user: &str) -> std::result::Result<Auth, Self::Error> {
392 if let Some(handler) = self.server_state.options.auth_handler.clone() {
393 let ctx = self.next_auth_context(user);
394 let result = handler.auth_none(&ctx).await;
395 if result.is_accepted() {
396 info!(
397 connection_id = self.connection_id,
398 user = user,
399 "Auth handler accepted none authentication"
400 );
401 self.user = Some(user.to_string());
402 }
403 return Ok(Self::map_auth_result(result));
404 }
405
406 let has_auth = self.server_state.options.public_key_handler.is_some()
410 || self.server_state.options.password_handler.is_some()
411 || self
412 .server_state
413 .options
414 .keyboard_interactive_handler
415 .is_some();
416
417 if !has_auth && self.server_state.options.allow_no_auth {
418 info!(
419 connection_id = self.connection_id,
420 user = user,
421 "No auth configured and allow_no_auth is set, accepting connection"
422 );
423 self.user = Some(user.to_string());
424 return Ok(Auth::Accept);
425 }
426
427 if !has_auth {
428 warn!(
429 connection_id = self.connection_id,
430 user = user,
431 "No auth handlers configured — rejecting auth_none. \
432 Set allow_no_auth=true to allow unauthenticated access."
433 );
434 }
435
436 Ok(Auth::Reject {
437 proceed_with_methods: None,
438 })
439 }
440
441 async fn auth_keyboard_interactive(
443 &mut self,
444 user: &str,
445 submethods: &str,
446 response: Option<russh::server::Response<'async_trait>>,
447 ) -> std::result::Result<Auth, Self::Error> {
448 debug!(
449 connection_id = self.connection_id,
450 user = user,
451 submethods = submethods,
452 "Keyboard-interactive auth attempt"
453 );
454
455 let has_handler = self.server_state.options.auth_handler.is_some()
456 || self
457 .server_state
458 .options
459 .keyboard_interactive_handler
460 .is_some();
461
462 if !has_handler {
463 return Ok(Auth::Reject {
464 proceed_with_methods: None,
465 });
466 }
467
468 if response.is_none() {
469 let state = self
470 .keyboard_interactive
471 .get_or_insert_with(Self::default_keyboard_interactive_state);
472 let prompts: Vec<(std::borrow::Cow<'static, str>, bool)> = state
473 .prompts
474 .iter()
475 .enumerate()
476 .map(|(index, prompt)| {
477 let echo = state.echos.get(index).copied().unwrap_or(false);
478 (std::borrow::Cow::Owned(prompt.clone()), echo)
479 })
480 .collect();
481
482 return Ok(Auth::Partial {
483 name: std::borrow::Cow::Borrowed("keyboard-interactive"),
484 instructions: std::borrow::Cow::Borrowed(""),
485 prompts: std::borrow::Cow::Owned(prompts),
486 });
487 }
488
489 let responses: Vec<String> = response
490 .into_iter()
491 .flatten()
492 .map(|bytes| String::from_utf8_lossy(bytes).to_string())
493 .collect();
494
495 if let Some(handler) = self.server_state.options.auth_handler.clone() {
496 let ctx = self.next_auth_context(user);
497 let response_text = responses.join("\n");
498 let result = handler
499 .auth_keyboard_interactive(&ctx, &response_text)
500 .await;
501 if result.is_accepted() {
502 info!(
503 connection_id = self.connection_id,
504 user = user,
505 "Keyboard-interactive auth accepted"
506 );
507 self.user = Some(user.to_string());
508 }
509 self.keyboard_interactive = None;
510 return Ok(Self::map_auth_result(result));
511 }
512
513 if let Some(handler) = &self.server_state.options.keyboard_interactive_handler {
514 let ctx = self.make_context(user);
515 let state = self
516 .keyboard_interactive
517 .take()
518 .unwrap_or_else(Self::default_keyboard_interactive_state);
519 let expected = handler(&ctx, submethods, &state.prompts, &state.echos);
520 if expected == responses {
521 info!(
522 connection_id = self.connection_id,
523 user = user,
524 "Keyboard-interactive auth accepted"
525 );
526 self.user = Some(user.to_string());
527 self.keyboard_interactive = None;
528 return Ok(Auth::Accept);
529 }
530 }
531
532 self.keyboard_interactive = None;
533 Ok(Auth::Reject {
534 proceed_with_methods: None,
535 })
536 }
537
538 async fn channel_open_session(
540 &mut self,
541 channel: Channel<Msg>,
542 session: &mut RusshSession,
543 ) -> std::result::Result<bool, Self::Error> {
544 let channel_id = channel.id();
545 debug!(
546 connection_id = self.connection_id,
547 channel = ?channel_id,
548 "Session channel opened"
549 );
550
551 let (input_tx, input_rx) = mpsc::channel(1024);
553 let (output_tx, mut output_rx) = mpsc::unbounded_channel::<SessionOutput>();
554
555 let user = self.user.clone().unwrap_or_default();
556 let mut ctx = self.make_context(&user);
557 let client_version = String::from_utf8_lossy(session.remote_sshid()).to_string();
558 ctx.set_client_version(client_version);
559 let mut wish_session = Session::new(ctx);
560 wish_session.set_output_sender(output_tx);
561 wish_session.set_input_receiver(input_rx).await;
562
563 let handle = session.handle();
565
566 let connection_id = self.connection_id;
568 tokio::spawn(async move {
569 debug!(connection_id, channel = ?channel_id, "Starting output pump");
570 while let Some(msg) = output_rx.recv().await {
571 match msg {
572 SessionOutput::Stdout(data) => {
573 let _ = channel.data(&data[..]).await;
574 }
575 SessionOutput::Stderr(data) => {
576 let _ = channel.extended_data(1, &data[..]).await;
577 }
578 SessionOutput::Exit(code) => {
579 let _ = handle.exit_status_request(channel_id, code).await;
580 let _ = channel.close().await;
581 break;
582 }
583 SessionOutput::Close => {
584 let _ = channel.close().await;
585 break;
586 }
587 }
588 }
589 debug!(connection_id, channel = ?channel_id, "Output pump finished");
590 });
591
592 if let Some(ref pk) = self.public_key {
594 wish_session = wish_session.with_public_key(Self::convert_public_key(pk));
595 }
596
597 wish_session
599 .context()
600 .set_value("channel_id", format!("{channel_id:?}"));
601
602 self.channels.insert(
603 channel_id,
604 ChannelState {
605 session: wish_session,
606 input_tx,
607 started: false,
608 input_buffer: Vec::new(),
609 },
610 );
611
612 Ok(true)
613 }
614
615 async fn pty_request(
617 &mut self,
618 channel: ChannelId,
619 term: &str,
620 col_width: u32,
621 row_height: u32,
622 _pix_width: u32,
623 _pix_height: u32,
624 _modes: &[(russh::Pty, u32)],
625 session: &mut RusshSession,
626 ) -> std::result::Result<(), Self::Error> {
627 debug!(
628 connection_id = self.connection_id,
629 channel = ?channel,
630 term = term,
631 width = col_width,
632 height = row_height,
633 "PTY request"
634 );
635
636 let pty = Pty {
637 term: term.to_string(),
638 window: Window {
639 width: col_width,
640 height: row_height,
641 },
642 };
643 self.pty = Some(pty.clone());
644 self.window = Window {
645 width: col_width,
646 height: row_height,
647 };
648
649 if let Some(state) = self.channels.get_mut(&channel) {
651 state.session = state.session.clone().with_pty(pty);
652 }
653
654 session.channel_success(channel);
655 Ok(())
656 }
657
658 async fn shell_request(
660 &mut self,
661 channel: ChannelId,
662 session: &mut RusshSession,
663 ) -> std::result::Result<(), Self::Error> {
664 debug!(
665 connection_id = self.connection_id,
666 channel = ?channel,
667 "Shell request"
668 );
669
670 if let Some(state) = self.channels.get_mut(&channel) {
671 if state.started {
672 warn!(
673 connection_id = self.connection_id,
674 channel = ?channel,
675 "Shell already started"
676 );
677 session.channel_failure(channel);
678 return Ok(());
679 }
680
681 state.started = true;
682 let wish_session = state.session.clone();
683 let handler = self.server_state.handler.clone();
684 let connection_id = self.connection_id;
685
686 tokio::spawn(async move {
688 debug!(connection_id, "Starting handler");
689 handler(wish_session).await;
690 debug!(connection_id, "Handler completed");
691 });
692
693 session.channel_success(channel);
694 } else {
695 session.channel_failure(channel);
696 }
697
698 Ok(())
699 }
700
701 async fn exec_request(
703 &mut self,
704 channel: ChannelId,
705 data: &[u8],
706 session: &mut RusshSession,
707 ) -> std::result::Result<(), Self::Error> {
708 let command = String::from_utf8_lossy(data).to_string();
709 debug!(
710 connection_id = self.connection_id,
711 channel = ?channel,
712 command = %command,
713 "Exec request"
714 );
715
716 if let Some(state) = self.channels.get_mut(&channel) {
717 if state.started {
718 session.channel_failure(channel);
719 return Ok(());
720 }
721
722 let args = parse_exec_command_args(&command).unwrap_or_else(|| {
724 warn!(
725 connection_id = self.connection_id,
726 channel = ?channel,
727 command = %command,
728 "Malformed quoted exec command; falling back to whitespace split"
729 );
730 command.split_whitespace().map(String::from).collect()
731 });
732 state.session = state.session.clone().with_command(args);
733 state.started = true;
734
735 let wish_session = state.session.clone();
736 let handler = self.server_state.handler.clone();
737 let connection_id = self.connection_id;
738
739 tokio::spawn(async move {
740 debug!(connection_id, "Starting exec handler");
741 handler(wish_session).await;
742 debug!(connection_id, "Exec handler completed");
743 });
744
745 session.channel_success(channel);
746 } else {
747 session.channel_failure(channel);
748 }
749
750 Ok(())
751 }
752
753 async fn env_request(
755 &mut self,
756 channel: ChannelId,
757 variable_name: &str,
758 variable_value: &str,
759 session: &mut RusshSession,
760 ) -> std::result::Result<(), Self::Error> {
761 trace!(
762 connection_id = self.connection_id,
763 channel = ?channel,
764 name = variable_name,
765 value = variable_value,
766 "Environment variable request"
767 );
768
769 if let Some(state) = self.channels.get_mut(&channel) {
770 state.session = state
771 .session
772 .clone()
773 .with_env(variable_name, variable_value);
774 }
775
776 session.channel_success(channel);
777 Ok(())
778 }
779
780 async fn subsystem_request(
782 &mut self,
783 channel: ChannelId,
784 name: &str,
785 session: &mut RusshSession,
786 ) -> std::result::Result<(), Self::Error> {
787 debug!(
788 connection_id = self.connection_id,
789 channel = ?channel,
790 subsystem = name,
791 "Subsystem request"
792 );
793
794 if let Some(handler) = self.server_state.options.subsystem_handlers.get(name)
796 && let Some(state) = self.channels.get_mut(&channel)
797 {
798 if state.started {
799 session.channel_failure(channel);
800 return Ok(());
801 }
802
803 state.session = state.session.clone().with_subsystem(name);
804 state.started = true;
805
806 let wish_session = state.session.clone();
807 let handler = handler.clone();
808 let connection_id = self.connection_id;
809 let subsystem_name = name.to_string();
810
811 tokio::spawn(async move {
812 debug!(
813 connection_id,
814 subsystem = %subsystem_name,
815 "Starting subsystem handler"
816 );
817 handler(wish_session).await;
818 debug!(connection_id, "Subsystem handler completed");
819 });
820
821 session.channel_success(channel);
822 return Ok(());
823 }
824
825 session.channel_failure(channel);
826 Ok(())
827 }
828
829 async fn window_change_request(
831 &mut self,
832 channel: ChannelId,
833 col_width: u32,
834 row_height: u32,
835 _pix_width: u32,
836 _pix_height: u32,
837 _session: &mut RusshSession,
838 ) -> std::result::Result<(), Self::Error> {
839 trace!(
840 connection_id = self.connection_id,
841 channel = ?channel,
842 width = col_width,
843 height = row_height,
844 "Window change request"
845 );
846
847 self.window = Window {
848 width: col_width,
849 height: row_height,
850 };
851
852 if let Some(ref mut pty) = self.pty {
854 pty.window = self.window;
855 }
856
857 if let Some(state) = self.channels.get(&channel) {
859 state.session.send_message(Message::new(WindowSizeMsg {
860 width: col_width as u16,
861 height: row_height as u16,
862 }));
863 }
864
865 Ok(())
866 }
867
868 async fn data(
870 &mut self,
871 channel: ChannelId,
872 data: &[u8],
873 _session: &mut RusshSession,
874 ) -> std::result::Result<(), Self::Error> {
875 trace!(
876 connection_id = self.connection_id,
877 channel = ?channel,
878 len = data.len(),
879 "Data received"
880 );
881
882 if let Some(state) = self.channels.get_mut(&channel) {
883 if let Err(mpsc::error::TrySendError::Full(_)) = state.input_tx.try_send(data.to_vec())
886 {
887 warn!(
888 connection_id = self.connection_id,
889 channel = ?channel,
890 "Input buffer full, dropping data (app not reading input?)"
891 );
892 }
893
894 const MAX_INPUT_BUFFER: usize = 64 * 1024;
897 if state.input_buffer.len() + data.len() > MAX_INPUT_BUFFER {
898 warn!(
899 connection_id = self.connection_id,
900 channel = ?channel,
901 buffer_len = state.input_buffer.len(),
902 data_len = data.len(),
903 "Input buffer exceeded 64KB limit, draining buffer"
904 );
905 state.input_buffer.clear();
906 }
907 state.input_buffer.extend_from_slice(data);
908
909 let mut i = 0;
911 let mut consumed_until = 0;
912
913 while i < state.input_buffer.len() {
914 let slice = &state.input_buffer[i..];
915
916 if let Some((key, len)) = parse_sequence_prefix(slice) {
918 state.session.send_message(Message::new(key));
919 i += len;
920 consumed_until = i;
921 continue;
922 }
923
924 if is_sequence_prefix(slice) {
927 break;
928 }
929
930 let b = slice[0];
933 let char_len = if b < 128 {
934 1
935 } else if (b & 0xE0) == 0xC0 {
936 2
937 } else if (b & 0xF0) == 0xE0 {
938 3
939 } else if (b & 0xF8) == 0xF0 {
940 4
941 } else {
942 let key = KeyMsg::from_char(std::char::REPLACEMENT_CHARACTER);
946 state.session.send_message(Message::new(key));
947 i += 1;
948 consumed_until = i;
949 continue;
950 };
951
952 if slice.len() < char_len {
953 break;
955 }
956
957 match std::str::from_utf8(&slice[..char_len]) {
959 Ok(s) => {
960 if let Some(c) = s.chars().next() {
962 let key = KeyMsg::from_char(c);
963 state.session.send_message(Message::new(key));
964 }
965 i += char_len;
966 consumed_until = i;
967 }
968 Err(_) => {
969 let key = KeyMsg::from_char(std::char::REPLACEMENT_CHARACTER);
971 state.session.send_message(Message::new(key));
972 i += 1;
973 consumed_until = i;
974 }
975 }
976 }
977
978 if consumed_until > 0 {
980 state.input_buffer.drain(0..consumed_until);
981 }
982 }
983
984 Ok(())
985 }
986
987 async fn channel_eof(
989 &mut self,
990 channel: ChannelId,
991 _session: &mut RusshSession,
992 ) -> std::result::Result<(), Self::Error> {
993 debug!(
994 connection_id = self.connection_id,
995 channel = ?channel,
996 "Channel EOF"
997 );
998 Ok(())
999 }
1000
1001 async fn channel_close(
1003 &mut self,
1004 channel: ChannelId,
1005 _session: &mut RusshSession,
1006 ) -> std::result::Result<(), Self::Error> {
1007 debug!(
1008 connection_id = self.connection_id,
1009 channel = ?channel,
1010 "Channel closed"
1011 );
1012
1013 self.channels.remove(&channel);
1014 Ok(())
1015 }
1016}
1017
1018pub struct WishHandlerFactory {
1020 server_state: Arc<ServerState>,
1021 shutdown_tx: broadcast::Sender<()>,
1022}
1023
1024impl WishHandlerFactory {
1025 pub fn new(options: ServerOptions) -> Self {
1027 let (shutdown_tx, _) = broadcast::channel(1);
1028 Self {
1029 server_state: Arc::new(ServerState::new(options)),
1030 shutdown_tx,
1031 }
1032 }
1033
1034 pub fn create_handler(&self, remote_addr: SocketAddr, local_addr: SocketAddr) -> WishHandler {
1036 WishHandler::new(
1037 remote_addr,
1038 local_addr,
1039 self.server_state.clone(),
1040 self.shutdown_tx.subscribe(),
1041 )
1042 }
1043
1044 pub fn shutdown(&self) {
1046 let _ = self.shutdown_tx.send(());
1047 }
1048}
1049
1050#[cfg(test)]
1051mod tests {
1052 use super::*;
1053
1054 #[test]
1055 fn test_parse_exec_command_args_basic() {
1056 let args = parse_exec_command_args("echo hello world").expect("parse should succeed");
1057 assert_eq!(args, vec!["echo", "hello", "world"]);
1058 }
1059
1060 #[test]
1061 fn test_parse_exec_command_args_preserves_quotes() {
1062 let args = parse_exec_command_args(r#"cmd --name "foo bar" --path '/tmp/one two'"#)
1063 .expect("parse should succeed");
1064 assert_eq!(
1065 args,
1066 vec!["cmd", "--name", "foo bar", "--path", "/tmp/one two"]
1067 );
1068 }
1069
1070 #[test]
1071 fn test_parse_exec_command_args_supports_escapes() {
1072 let args =
1073 parse_exec_command_args(r#"cmd one\ two \"quoted\""#).expect("parse should succeed");
1074 assert_eq!(args, vec!["cmd", "one two", "\"quoted\""]);
1075 }
1076
1077 #[test]
1078 fn test_parse_exec_command_args_rejects_unterminated_quotes() {
1079 assert!(parse_exec_command_args(r#"cmd "unterminated"#).is_none());
1080 assert!(parse_exec_command_args("cmd 'unterminated").is_none());
1081 }
1082
1083 #[test]
1084 fn test_parse_exec_command_args_preserves_empty_quoted_args() {
1085 let args = parse_exec_command_args(r#"cmd "" '' tail"#).expect("parse should succeed");
1086 assert_eq!(args, vec!["cmd", "", "", "tail"]);
1087 }
1088
1089 #[test]
1090 fn test_server_state_new() {
1091 let options = ServerOptions::default();
1092 let state = ServerState::new(options);
1093 assert_eq!(state.next_connection_id(), 1);
1094 assert_eq!(state.next_connection_id(), 2);
1095 }
1096}