1impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
2 fn get_client_mut<'a>(
4 &self,
5 clients: &'a mut HashMap<usize, (SshTerminal, T)>,
6 context: &str,
7 id: usize,
8 ) -> Option<(&'a mut SshTerminal, &'a mut T)> {
9 match clients.get_mut(&id) {
10 Some((terminal, app)) => Some((terminal, app)),
11 None => {
12 tracing::warn!("No client found for id {} in {}", id, context);
13 None
14 }
15 }
16 }
17
18 fn get_terminal_mut<'a>(
20 &self,
21 clients: &'a mut HashMap<usize, (SshTerminal, T)>,
22 context: &str,
23 id: usize,
24 ) -> Option<&'a mut SshTerminal> {
25 match clients.get_mut(&id) {
26 Some((terminal, _)) => Some(terminal),
27 None => {
28 tracing::warn!("No client found for id {} in {}", id, context);
29 None
30 }
31 }
32 }
33
34 fn try_draw(&self, terminal: &mut SshTerminal, app: &mut T) {
36 if let Err(e) = terminal.draw(|f| app.draw(f)) {
37 tracing::error!(
38 "Terminal draw error for user {} (id: {}): {:?}",
39 self.username,
40 self.id,
41 e
42 );
43 }
44 }
45}
46use crate::chai::ChaiApp;
47
48use std::collections::HashMap;
49use std::path::Path;
50use std::sync::Arc;
51
52use ratatui::backend::CrosstermBackend;
53use ratatui::layout::Rect;
54
55use ratatui::{Terminal, TerminalOptions, Viewport};
56use russh::keys::ssh_key::PublicKey;
57use russh::server::*;
58use russh::{Channel, ChannelId, Pty};
59use tokio::sync::Mutex;
60use tokio::sync::mpsc::{Sender, channel, error::TrySendError};
61
/// Switches the client to the alternate screen buffer (`CSI ? 1049 h`).
const ENTER_ALT_SCREEN: &[u8] = "\x1B[?1049h".as_bytes();
/// Returns the client to the normal screen buffer (`CSI ? 1049 l`).
const EXIT_ALT_SCREEN: &[u8] = "\x1B[?1049l".as_bytes();
/// Hides the client's text cursor (`CSI ? 25 l`).
const HIDE_CURSOR: &[u8] = "\x1B[?25l".as_bytes();
/// Shows the client's text cursor again (`CSI ? 25 h`).
const SHOW_CURSOR: &[u8] = "\x1B[?25h".as_bytes();
66
/// The ratatui terminal used for every client: the crossterm backend writes its
/// escape sequences into a [`TerminalHandle`], which forwards them over SSH.
type SshTerminal = Terminal<CrosstermBackend<TerminalHandle>>;

/// `std::io::Write` sink for one client's terminal output.
///
/// Writes accumulate locally; on flush they are queued to the forwarding task
/// created by `TerminalHandle::start`, which sends them on the client's SSH channel.
struct TerminalHandle {
    // Bounded queue feeding the forwarding task spawned in `TerminalHandle::start`.
    sender: Sender<Vec<u8>>,
    // Bytes accumulated by `write` until the next `flush`.
    sink: Vec<u8>,
}
74
75impl TerminalHandle {
76 async fn start(
77 handle: Handle,
78 channel_id: ChannelId,
79 username: String,
80 id: usize,
81 buffer: usize,
82 ) -> Self {
83 let (sender, mut receiver) = channel::<Vec<u8>>(buffer);
84 let username_clone = username.clone();
85 let id_clone = id;
86 tokio::spawn(async move {
87 while let Some(data) = receiver.recv().await {
88 let result = handle.data(channel_id, data.into()).await;
89 if result.is_err() {
90 tracing::error!(
91 "failed to send data for user {} (id: {}): {result:?}",
92 username_clone,
93 id_clone
94 );
95 }
96 }
97 });
98 Self {
99 sender,
100 sink: Vec::new(),
101 }
102 }
103}
104
105impl std::io::Write for TerminalHandle {
106 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
107 self.sink.extend_from_slice(buf);
108 Ok(buf.len())
109 }
110
111 fn flush(&mut self) -> std::io::Result<()> {
112 match self.sender.try_send(self.sink.clone()) {
113 Ok(()) => {}
114 Err(TrySendError::Full(_)) => {
115 tracing::debug!("terminal output buffer full, dropping frame");
117 }
118 Err(TrySendError::Closed(_)) => {
119 return Err(std::io::Error::new(
120 std::io::ErrorKind::BrokenPipe,
121 "terminal channel closed",
122 ));
123 }
124 }
125
126 self.sink.clear();
127 Ok(())
128 }
129}
130
/// Default cap on concurrently registered clients (override with `ChaiServer::with_max_connections`).
const DEFAULT_MAX_CONNECTIONS: usize = 100;
/// Default capacity, in flushed frames, of each client's output queue
/// (override with `ChaiServer::with_channel_buffer`).
const DEFAULT_CHANNEL_BUFFER: usize = 64;

/// SSH server that gives every connected client its own `T` app, rendered with ratatui.
///
/// The same type acts as the listening server and, via `Clone`, as each per-connection
/// handler; all clones share one `clients` registry.
pub struct ChaiServer<T: ChaiApp + Send + 'static> {
    // Live clients keyed by connection id; the `Arc` is shared by every clone.
    clients: Arc<Mutex<HashMap<usize, (SshTerminal, T)>>>,
    // TCP port `run` listens on (bound on 0.0.0.0).
    port: u16,
    // On a handler: this connection's id. On the listening server: the next id to hand out.
    id: usize,
    // Username recorded during authentication; empty until then.
    username: String,
    // Session-open requests are rejected once `clients` holds this many entries.
    max_connections: usize,
    // Capacity passed to each client's output queue in `TerminalHandle::start`.
    channel_buffer: usize,
}
142
143impl<T: ChaiApp + Send + 'static> Clone for ChaiServer<T> {
144 fn clone(&self) -> Self {
145 Self {
146 clients: self.clients.clone(),
147 port: self.port,
148 id: self.id,
149 username: self.username.clone(),
150 max_connections: self.max_connections,
151 channel_buffer: self.channel_buffer,
152 }
153 }
154}
155
156impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
157 pub fn new(port: u16) -> Self {
158 Self {
159 clients: Arc::new(Mutex::new(HashMap::new())),
160 port,
161 id: 0,
162 username: String::new(),
163 max_connections: DEFAULT_MAX_CONNECTIONS,
164 channel_buffer: DEFAULT_CHANNEL_BUFFER,
165 }
166 }
167
168 pub fn with_max_connections(mut self, max: usize) -> Self {
170 self.max_connections = max;
171 self
172 }
173
174 pub fn with_channel_buffer(mut self, size: usize) -> Self {
177 self.channel_buffer = size;
178 self
179 }
180
181 pub async fn run(&mut self, config: Config) -> Result<(), anyhow::Error> {
182 let subscriber = tracing_subscriber::fmt()
183 .compact()
184 .with_file(true)
185 .with_line_number(true)
186 .with_target(true)
187 .with_env_filter(
188 tracing_subscriber::EnvFilter::try_from_default_env()
189 .unwrap_or_else(|_| "info".into()),
190 )
191 .finish();
192
193 let _ = tracing::subscriber::set_global_default(subscriber);
195
196 let clients = self.clients.clone();
197 tokio::spawn(async move {
198 loop {
199 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
200
201 for (_, (terminal, app)) in clients.lock().await.iter_mut() {
202 terminal
203 .draw(|f| {
204 app.update();
205 app.draw(f);
206 })
207 .unwrap();
208 }
209 }
210 });
211
212 tracing::info!("starting server on 0.0.0.0:{}", self.port);
213 self.run_on_address(Arc::new(config), ("0.0.0.0", self.port))
214 .await?;
215 Ok(())
216 }
217}
218
219impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
220 fn send_data_or_log(
221 &mut self,
222 session: &mut Session,
223 channel: ChannelId,
224 data: &[u8],
225 description: &str,
226 ) {
227 if let Err(e) = session.data(channel, data.into()) {
228 tracing::error!(
229 "failed to {} for user {} (id: {}): {:?}",
230 description,
231 self.username,
232 self.id,
233 e
234 );
235 }
236 }
237}
238
239impl<T: ChaiApp + Send + 'static> Server for ChaiServer<T> {
240 type Handler = Self;
241 fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> Self {
242 let s = self.clone();
243 self.id += 1;
244 s
245 }
246}
247
248impl<T: ChaiApp + Send + 'static> Handler for ChaiServer<T> {
249 type Error = anyhow::Error;
250
251 async fn channel_open_session(
252 &mut self,
253 channel: Channel<Msg>,
254 session: &mut Session,
255 ) -> Result<bool, Self::Error> {
256 {
257 let clients = self.clients.lock().await;
258 if clients.len() >= self.max_connections {
259 tracing::warn!(
260 "max connections ({}) reached, rejecting session for {}",
261 self.max_connections,
262 self.username
263 );
264 return Ok(false);
265 }
266 }
267
268 tracing::info!("{} (id: {}) opened a channel", self.username, self.id);
269 let terminal_handle = TerminalHandle::start(
270 session.handle(),
271 channel.id(),
272 self.username.clone(),
273 self.id,
274 self.channel_buffer,
275 )
276 .await;
277
278 let backend = CrosstermBackend::new(terminal_handle);
279
280 let options = TerminalOptions {
282 viewport: Viewport::Fixed(Rect::default()),
283 };
284
285 let terminal = Terminal::with_options(backend, options)?;
286 let app = T::new();
287
288 let mut clients = self.clients.lock().await;
289 clients.insert(self.id, (terminal, app));
290
291 Ok(true)
292 }
293
294 async fn auth_publickey(&mut self, user: &str, _: &PublicKey) -> Result<Auth, Self::Error> {
295 self.username = user.to_string();
296 Ok(Auth::Accept)
297 }
298
299 async fn auth_none(&mut self, user: &str) -> Result<Auth, Self::Error> {
300 self.username = user.to_string();
301 Ok(Auth::Accept)
302 }
303
304 async fn data(
305 &mut self,
306 channel: ChannelId,
307 data: &[u8],
308 session: &mut Session,
309 ) -> Result<(), Self::Error> {
310 if !data
312 .iter()
313 .all(|&b| b == b'\n' || b == b'\r' || (b >= 0x20 && b <= 0x7e))
314 {
315 tracing::warn!(
316 "Received invalid input data from user {} (id: {})",
317 self.username,
318 self.id
319 );
320 return Ok(());
321 }
322
323 let should_quit = {
324 let mut clients = self.clients.lock().await;
325 if let Some((terminal, app)) =
326 self.get_client_mut(&mut clients, "data handler", self.id)
327 {
328 app.handle_input(data);
329 let quit = app.should_quit();
330 if !quit {
331 self.try_draw(terminal, app);
332 }
333 quit
334 } else {
335 false
336 }
337 };
338
339 if should_quit {
340 self.send_data_or_log(session, channel, EXIT_ALT_SCREEN, "exit alternate screen");
341 self.send_data_or_log(session, channel, SHOW_CURSOR, "show cursor");
342 self.clients.lock().await.remove(&self.id);
343 session.close(channel)?;
344 }
345
346 Ok(())
347 }
348
349 async fn window_change_request(
350 &mut self,
351 _: ChannelId,
352 col_width: u32,
353 row_height: u32,
354 _: u32,
355 _: u32,
356 _: &mut Session,
357 ) -> Result<(), Self::Error> {
358 let rect = Rect {
359 x: 0,
360 y: 0,
361 width: col_width as u16,
362 height: row_height as u16,
363 };
364
365 let mut clients = self.clients.lock().await;
366 if let Some(terminal) =
367 self.get_terminal_mut(&mut clients, "window_change_request", self.id)
368 {
369 terminal.resize(rect)?;
370 }
371
372 Ok(())
373 }
374
375 async fn pty_request(
376 &mut self,
377 channel: ChannelId,
378 _: &str,
379 col_width: u32,
380 row_height: u32,
381 _: u32,
382 _: u32,
383 _: &[(Pty, u32)],
384 session: &mut Session,
385 ) -> Result<(), Self::Error> {
386 let rect = Rect {
387 x: 0,
388 y: 0,
389 width: col_width as u16,
390 height: row_height as u16,
391 };
392
393 {
394 let mut clients = self.clients.lock().await;
395 if let Some(terminal) =
396 self.get_terminal_mut(&mut clients, "pty_request (resize)", self.id)
397 {
398 terminal.resize(rect)?;
399 }
400 }
401
402 session.channel_success(channel)?;
403
404 self.send_data_or_log(session, channel, ENTER_ALT_SCREEN, "enter alternate screen");
405 self.send_data_or_log(session, channel, HIDE_CURSOR, "hide cursor");
406
407 {
408 let mut clients = self.clients.lock().await;
409 if let Some((terminal, app)) =
410 self.get_client_mut(&mut clients, "pty_request (draw)", self.id)
411 {
412 self.try_draw(terminal, app);
413 }
414 }
415
416 Ok(())
417 }
418
419 async fn channel_close(
420 &mut self,
421 channel: ChannelId,
422 session: &mut Session,
423 ) -> Result<(), Self::Error> {
424 tracing::info!("{} (id: {}) closed a channel", self.username, self.id);
425 let reset_sequence = [EXIT_ALT_SCREEN, SHOW_CURSOR].concat();
426 let _ = session.data(channel, reset_sequence.into());
427
428 let mut clients = self.clients.lock().await;
429 if clients.remove(&self.id).is_none() {
430 tracing::warn!("No client found for id {} in channel_close", self.id);
431 }
432 Ok(())
433 }
434}
435
436impl<T: ChaiApp + Send + 'static> Drop for ChaiServer<T> {
437 fn drop(&mut self) {
438 let id = self.id;
440 let clients = self.clients.clone();
441 if let Ok(mut guard) = clients.try_lock() {
442 guard.remove(&id);
443 } else {
444 tracing::warn!("Could not lock clients for cleanup in Drop for id {}", id);
446 }
447 }
448}
449
450pub fn load_system_host_keys(key_name: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
451 let key_path = Path::new("/.ssh").join(key_name);
452
453 if !key_path.exists() {
454 return Err(anyhow::anyhow!(
455 "Host key not found at {}. Please generate host keys first.",
456 key_path.display()
457 ));
458 }
459
460 let key = russh::keys::PrivateKey::read_openssh_file(&key_path)
461 .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
462
463 Ok(key)
464}
465
466pub fn load_host_keys(path: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
467 let key_path = Path::new(path);
468
469 if !key_path.exists() {
470 return Err(anyhow::anyhow!(
471 "Host key not found at {}. Please generate host keys first.",
472 key_path.display()
473 ));
474 }
475
476 let key = russh::keys::PrivateKey::read_openssh_file(&key_path)
477 .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
478
479 Ok(key)
480}