1impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
2 fn get_client_mut<'a>(
4 &self,
5 clients: &'a mut HashMap<usize, (SshTerminal, T)>,
6 context: &str,
7 id: usize,
8 ) -> Option<(&'a mut SshTerminal, &'a mut T)> {
9 match clients.get_mut(&id) {
10 Some((terminal, app)) => Some((terminal, app)),
11 None => {
12 tracing::warn!("No client found for id {} in {}", id, context);
13 None
14 }
15 }
16 }
17
18 fn try_draw(&self, terminal: &mut SshTerminal, app: &mut T) {
20 if let Err(e) = terminal.draw(|f| app.draw(f)) {
21 tracing::error!(
22 "Terminal draw error for user {} (id: {}): {:?}",
23 self.username,
24 self.id,
25 e
26 );
27 }
28 }
29}
30use crate::chai::ChaiApp;
31
32use std::collections::HashMap;
33use std::path::Path;
34use std::sync::Arc;
35
36use ratatui::layout::Rect;
37
38use ratatui::{Terminal, TerminalOptions, Viewport};
39use russh::keys::ssh_key::PublicKey;
40use russh::server::*;
41use russh::{Channel, ChannelId, Pty};
42use tokio::sync::Mutex;
43use tokio::sync::mpsc::{Sender, channel, error::TrySendError};
44
/// Switch the client's terminal to the alternate screen buffer (xterm private mode 1049, set).
const ENTER_ALT_SCREEN: &[u8] = b"\x1b[?1049h";
/// Leave the alternate screen buffer, restoring the client's normal screen (mode 1049, reset).
const EXIT_ALT_SCREEN: &[u8] = b"\x1b[?1049l";
/// Hide the text cursor (private mode 25, reset).
const HIDE_CURSOR: &[u8] = b"\x1b[?25l";
/// Show the text cursor again (private mode 25, set).
const SHOW_CURSOR: &[u8] = b"\x1b[?25h";
/// Erase the whole display (ED with parameter 2); written before redrawing at a new size.
const CLEAR_ALL: &[u8] = b"\x1b[2J";
50
51use ratatui::backend::CrosstermBackend;
52
/// ratatui terminal whose output is collected by a [`TerminalHandle`] and forwarded to one SSH channel.
type SshTerminal = Terminal<CrosstermBackend<TerminalHandle>>;
54
55struct TerminalHandle {
56 sender: Sender<Vec<u8>>,
57 sink: Vec<u8>,
59}
60
61impl TerminalHandle {
62 async fn start(
63 handle: Handle,
64 channel_id: ChannelId,
65 username: String,
66 id: usize,
67 buffer: usize,
68 ) -> Self {
69 let (sender, mut receiver) = channel::<Vec<u8>>(buffer);
70 let username_clone = username.clone();
71 let id_clone = id;
72 tokio::spawn(async move {
73 while let Some(data) = receiver.recv().await {
74 let result = handle.data(channel_id, data.into()).await;
75 if result.is_err() {
76 tracing::debug!(
77 "channel closed for user {} (id: {}), stopping data forwarding",
78 username_clone,
79 id_clone
80 );
81 break;
82 }
83 }
84 });
85 Self {
86 sender,
87 sink: Vec::new(),
88 }
89 }
90}
91
92impl std::io::Write for TerminalHandle {
93 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
94 self.sink.extend_from_slice(buf);
95 Ok(buf.len())
96 }
97
98 fn flush(&mut self) -> std::io::Result<()> {
99 match self.sender.try_send(self.sink.clone()) {
100 Ok(()) => {}
101 Err(TrySendError::Full(_)) => {
102 tracing::debug!("terminal output buffer full, dropping frame");
104 }
105 Err(TrySendError::Closed(_)) => {
106 return Err(std::io::Error::new(
107 std::io::ErrorKind::BrokenPipe,
108 "terminal channel closed",
109 ));
110 }
111 }
112
113 self.sink.clear();
114 Ok(())
115 }
116}
117
/// Default cap on concurrently registered clients (override with `ChaiServer::with_max_connections`).
const DEFAULT_MAX_CONNECTIONS: usize = 100;
/// Default capacity of each client's output-frame queue (override with `ChaiServer::with_channel_buffer`).
const DEFAULT_CHANNEL_BUFFER: usize = 64;
120
121pub struct ChaiServer<T: ChaiApp + Send + 'static> {
122 clients: Arc<Mutex<HashMap<usize, (SshTerminal, T)>>>,
123 port: u16,
124 id: usize,
125 username: String,
126 max_connections: usize,
127 channel_buffer: usize,
128 cols: u32,
129 rows: u32,
130 tick_rate: std::time::Duration,
131}
132
133impl<T: ChaiApp + Send + 'static> Clone for ChaiServer<T> {
134 fn clone(&self) -> Self {
135 Self {
136 clients: self.clients.clone(),
137 port: self.port,
138 id: self.id,
139 username: self.username.clone(),
140 max_connections: self.max_connections,
141 channel_buffer: self.channel_buffer,
142 cols: 80,
143 rows: 24,
144 tick_rate: self.tick_rate,
145 }
146 }
147}
148
149impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
150 pub fn new(port: u16) -> Self {
151 Self {
152 clients: Arc::new(Mutex::new(HashMap::new())),
153 port,
154 id: 0,
155 username: String::new(),
156 max_connections: DEFAULT_MAX_CONNECTIONS,
157 channel_buffer: DEFAULT_CHANNEL_BUFFER,
158 cols: 80,
159 rows: 24,
160 tick_rate: std::time::Duration::from_secs(1),
161 }
162 }
163
164 pub fn with_max_connections(mut self, max: usize) -> Self {
166 self.max_connections = max;
167 self
168 }
169
170 pub fn with_channel_buffer(mut self, size: usize) -> Self {
173 self.channel_buffer = size;
174 self
175 }
176 pub fn with_tick_rate(mut self, duration: std::time::Duration) -> Self {
178 self.tick_rate = duration;
179 self
180 }
181
182 pub async fn run(&mut self, config: Config) -> Result<(), anyhow::Error> {
183 let subscriber = tracing_subscriber::fmt()
184 .compact()
185 .with_file(true)
186 .with_line_number(true)
187 .with_target(true)
188 .with_env_filter(
189 tracing_subscriber::EnvFilter::try_from_default_env()
190 .unwrap_or_else(|_| "info".into()),
191 )
192 .finish();
193
194 let _ = tracing::subscriber::set_global_default(subscriber);
196
197 let clients = self.clients.clone();
198 let tick_rate = self.tick_rate;
199 tokio::spawn(async move {
200 loop {
201 tokio::time::sleep(tick_rate).await;
202
203 for (_, (terminal, app)) in clients.lock().await.iter_mut() {
204 if let Err(e) = terminal.draw(|f| {
205 app.update();
206 app.draw(f);
207 }) {
208 tracing::error!(
209 "Background terminal draw error in periodic updater: {:?}",
210 e
211 );
212 }
213 }
214 }
215 });
216
217 tracing::info!("starting server on 0.0.0.0:{}", self.port);
218 self.run_on_address(Arc::new(config), ("0.0.0.0", self.port))
219 .await?;
220 Ok(())
221 }
222}
223
224impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
225 fn send_data_or_log(
226 &mut self,
227 session: &mut Session,
228 channel: ChannelId,
229 data: &[u8],
230 description: &str,
231 ) {
232 if let Err(e) = session.data(channel, data.into()) {
233 tracing::error!(
234 "failed to {} for user {} (id: {}): {:?}",
235 description,
236 self.username,
237 self.id,
238 e
239 );
240 }
241 }
242}
243
244impl<T: ChaiApp + Send + 'static> Server for ChaiServer<T> {
245 type Handler = Self;
246 fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> Self {
247 let s = self.clone();
248 self.id += 1;
249 s
250 }
251}
252
253impl<T: ChaiApp + Send + 'static> Handler for ChaiServer<T> {
254 type Error = anyhow::Error;
255
256 async fn channel_open_session(
257 &mut self,
258 _channel: Channel<Msg>,
259 _session: &mut Session,
260 ) -> Result<bool, Self::Error> {
261 let clients = self.clients.lock().await;
262 if clients.len() >= self.max_connections {
263 tracing::warn!(
264 "max connections ({}) reached, rejecting session for {}",
265 self.max_connections,
266 self.username
267 );
268 return Ok(false);
269 }
270
271 tracing::info!("{} (id: {}) opened a channel", self.username, self.id);
272 Ok(true)
273 }
274
275 async fn auth_publickey(&mut self, user: &str, _: &PublicKey) -> Result<Auth, Self::Error> {
276 self.username = user.to_string();
277 Ok(Auth::Accept)
278 }
279
280 async fn auth_none(&mut self, user: &str) -> Result<Auth, Self::Error> {
281 self.username = user.to_string();
282 Ok(Auth::Accept)
283 }
284
285 async fn data(
286 &mut self,
287 channel: ChannelId,
288 data: &[u8],
289 session: &mut Session,
290 ) -> Result<(), Self::Error> {
291 if !data.iter().all(|&b| b.is_ascii()) {
294 tracing::warn!(
295 "Received invalid input data from user {} (id: {})",
296 self.username,
297 self.id
298 );
299 return Ok(());
300 }
301
302 let should_quit = {
303 let mut clients = self.clients.lock().await;
304 if let Some((terminal, app)) =
305 self.get_client_mut(&mut clients, "data handler", self.id)
306 {
307 if data == b"\x03" && !app.capture_ctrl_c() {
309 true
310 } else {
311 app.handle_input(data);
312 let quit = app.should_quit();
313 if !quit {
314 self.try_draw(terminal, app);
315 }
316 quit
317 }
318 } else {
319 false
320 }
321 };
322
323 if should_quit {
324 self.send_data_or_log(session, channel, EXIT_ALT_SCREEN, "exit alternate screen");
325 self.send_data_or_log(session, channel, SHOW_CURSOR, "show cursor");
326 session.close(channel)?;
327 }
328
329 Ok(())
330 }
331
332 async fn window_change_request(
333 &mut self,
334 channel: ChannelId,
335 col_width: u32,
336 row_height: u32,
337 _: u32,
338 _: u32,
339 session: &mut Session,
340 ) -> Result<(), Self::Error> {
341 let rect = Rect {
342 x: 0,
343 y: 0,
344 width: col_width as u16,
345 height: row_height as u16,
346 };
347
348 let mut clients = self.clients.lock().await;
349 if let Some((_, mut app)) = clients.remove(&self.id) {
350 let terminal_handle = TerminalHandle::start(
353 session.handle(),
354 channel,
355 self.username.clone(),
356 self.id,
357 self.channel_buffer,
358 )
359 .await;
360
361 let backend = CrosstermBackend::new(terminal_handle);
362 let options = TerminalOptions {
363 viewport: Viewport::Fixed(rect),
364 };
365 let mut terminal = Terminal::with_options(backend, options)?;
366
367 {
370 use std::io::Write;
371 terminal.backend_mut().write_all(CLEAR_ALL)?;
372 }
373
374 self.try_draw(&mut terminal, &mut app);
375 clients.insert(self.id, (terminal, app));
376 }
377
378 Ok(())
379 }
380
381 async fn pty_request(
382 &mut self,
383 _channel: ChannelId,
384 _: &str,
385 col_width: u32,
386 row_height: u32,
387 _: u32,
388 _: u32,
389 _: &[(Pty, u32)],
390 _session: &mut Session,
391 ) -> Result<(), Self::Error> {
392 self.cols = col_width;
393 self.rows = row_height;
394 Ok(())
395 }
396
397 async fn shell_request(
398 &mut self,
399 channel: ChannelId,
400 session: &mut Session,
401 ) -> Result<(), Self::Error> {
402 let handle = session.handle();
403
404 let terminal_handle = TerminalHandle::start(
405 handle,
406 channel,
407 self.username.clone(),
408 self.id,
409 self.channel_buffer,
410 )
411 .await;
412
413 let rect = Rect {
414 x: 0,
415 y: 0,
416 width: self.cols as u16,
417 height: self.rows as u16,
418 };
419
420 let backend = CrosstermBackend::new(terminal_handle);
421 let options = TerminalOptions {
422 viewport: Viewport::Fixed(rect),
423 };
424 let mut terminal = Terminal::with_options(backend, options)?;
425 let mut app = T::new();
426 app.on_connect(&self.username);
427
428 {
431 use std::io::Write;
432 let backend = terminal.backend_mut();
433 backend.write_all(ENTER_ALT_SCREEN)?;
434 backend.write_all(HIDE_CURSOR)?;
435 backend.flush()?;
436 }
437
438 self.try_draw(&mut terminal, &mut app);
439 self.clients.lock().await.insert(self.id, (terminal, app));
440
441 session.channel_success(channel)?;
442
443 Ok(())
444 }
445
446 async fn channel_close(
447 &mut self,
448 channel: ChannelId,
449 session: &mut Session,
450 ) -> Result<(), Self::Error> {
451 tracing::info!("{} (id: {}) closed a channel", self.username, self.id);
452 let reset_sequence = [EXIT_ALT_SCREEN, SHOW_CURSOR].concat();
453 let _ = session.data(channel, reset_sequence.into());
454
455 let mut clients = self.clients.lock().await;
456 if let Some((_, app)) = clients.get_mut(&self.id) {
457 app.on_disconnect(&self.username);
458 }
459 if clients.remove(&self.id).is_none() {
460 tracing::warn!("No client found for id {} in channel_close", self.id);
461 }
462 Ok(())
463 }
464}
465
466impl<T: ChaiApp + Send + 'static> Drop for ChaiServer<T> {
467 fn drop(&mut self) {
468 let id = self.id;
469 let clients = self.clients.clone();
470 if let Ok(mut guard) = clients.try_lock() {
471 guard.remove(&id);
472 } else {
473 tracing::warn!("Could not lock clients for cleanup in Drop for id {}", id);
475 }
476 }
477}
478
479pub fn load_system_host_keys(key_name: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
480 let key_path = Path::new("/.ssh").join(key_name);
481
482 if !key_path.exists() {
483 return Err(anyhow::anyhow!(
484 "Host key not found at {}. Please generate host keys first.",
485 key_path.display()
486 ));
487 }
488
489 let key = russh::keys::PrivateKey::read_openssh_file(&key_path)
490 .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
491
492 Ok(key)
493}
494
495pub fn load_host_keys(path: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
496 let key_path = Path::new(path);
497
498 if !key_path.exists() {
499 return Err(anyhow::anyhow!(
500 "Host key not found at {}. Please generate host keys first.",
501 key_path.display()
502 ));
503 }
504
505 let key = russh::keys::PrivateKey::read_openssh_file(key_path)
506 .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
507
508 Ok(key)
509}
510