Skip to main content

chai_framework/
server.rs

1impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
2    /// Helper to get a mutable reference to a client, logging a warning if not found.
3    fn get_client_mut<'a>(
4        &self,
5        clients: &'a mut HashMap<usize, (SshTerminal, T)>,
6        context: &str,
7        id: usize,
8    ) -> Option<(&'a mut SshTerminal, &'a mut T)> {
9        match clients.get_mut(&id) {
10            Some((terminal, app)) => Some((terminal, app)),
11            None => {
12                tracing::warn!("No client found for id {} in {}", id, context);
13                None
14            }
15        }
16    }
17
18    /// Helper to draw the app, logging error if it fails.
19    fn try_draw(&self, terminal: &mut SshTerminal, app: &mut T) {
20        if let Err(e) = terminal.draw(|f| app.draw(f)) {
21            tracing::error!(
22                "Terminal draw error for user {} (id: {}): {:?}",
23                self.username,
24                self.id,
25                e
26            );
27        }
28    }
29}
30use crate::chai::ChaiApp;
31
32use std::collections::HashMap;
33use std::path::Path;
34use std::sync::Arc;
35
36use ratatui::layout::Rect;
37
38use ratatui::{Terminal, TerminalOptions, Viewport};
39use russh::keys::ssh_key::PublicKey;
40use russh::server::*;
41use russh::{Channel, ChannelId, Pty};
42use tokio::sync::Mutex;
43use tokio::sync::mpsc::{Sender, channel, error::TrySendError};
44
/// Switch the client terminal to the alternate screen buffer (`CSI ? 1049 h`).
const ENTER_ALT_SCREEN: &[u8] = "\x1b[?1049h".as_bytes();
/// Leave the alternate screen buffer and restore the normal one (`CSI ? 1049 l`).
const EXIT_ALT_SCREEN: &[u8] = "\x1b[?1049l".as_bytes();
/// Make the text cursor invisible (`CSI ? 25 l`).
const HIDE_CURSOR: &[u8] = "\x1b[?25l".as_bytes();
/// Make the text cursor visible again (`CSI ? 25 h`).
const SHOW_CURSOR: &[u8] = "\x1b[?25h".as_bytes();
/// Erase the whole display (`CSI 2 J`).
const CLEAR_ALL: &[u8] = "\x1b[2J".as_bytes();
50
51use ratatui::backend::CrosstermBackend;
52
53type SshTerminal = Terminal<CrosstermBackend<TerminalHandle>>;
54
55struct TerminalHandle {
56    sender: Sender<Vec<u8>>,
57    // The sink collects the data which is finally sent to sender.
58    sink: Vec<u8>,
59}
60
61impl TerminalHandle {
62    async fn start(
63        handle: Handle,
64        channel_id: ChannelId,
65        username: String,
66        id: usize,
67        buffer: usize,
68    ) -> Self {
69        let (sender, mut receiver) = channel::<Vec<u8>>(buffer);
70        let username_clone = username.clone();
71        let id_clone = id;
72        tokio::spawn(async move {
73            while let Some(data) = receiver.recv().await {
74                let result = handle.data(channel_id, data.into()).await;
75                if result.is_err() {
76                    tracing::debug!(
77                        "channel closed for user {} (id: {}), stopping data forwarding",
78                        username_clone,
79                        id_clone
80                    );
81                    break;
82                }
83            }
84        });
85        Self {
86            sender,
87            sink: Vec::new(),
88        }
89    }
90}
91
92impl std::io::Write for TerminalHandle {
93    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
94        self.sink.extend_from_slice(buf);
95        Ok(buf.len())
96    }
97
98    fn flush(&mut self) -> std::io::Result<()> {
99        match self.sender.try_send(self.sink.clone()) {
100            Ok(()) => {}
101            Err(TrySendError::Full(_)) => {
102                // Consumer is slow; drop this frame rather than block the async runtime.
103                tracing::debug!("terminal output buffer full, dropping frame");
104            }
105            Err(TrySendError::Closed(_)) => {
106                return Err(std::io::Error::new(
107                    std::io::ErrorKind::BrokenPipe,
108                    "terminal channel closed",
109                ));
110            }
111        }
112
113        self.sink.clear();
114        Ok(())
115    }
116}
117
/// Default capacity, in frames, of each connection's terminal output channel.
const DEFAULT_CHANNEL_BUFFER: usize = 64;
/// Default limit on the number of clients tracked at the same time.
const DEFAULT_MAX_CONNECTIONS: usize = 100;
120
121pub struct ChaiServer<T: ChaiApp + Send + 'static> {
122    clients: Arc<Mutex<HashMap<usize, (SshTerminal, T)>>>,
123    port: u16,
124    id: usize,
125    username: String,
126    max_connections: usize,
127    channel_buffer: usize,
128    cols: u32,
129    rows: u32,
130    tick_rate: std::time::Duration,
131}
132
133impl<T: ChaiApp + Send + 'static> Clone for ChaiServer<T> {
134    fn clone(&self) -> Self {
135        Self {
136            clients: self.clients.clone(),
137            port: self.port,
138            id: self.id,
139            username: self.username.clone(),
140            max_connections: self.max_connections,
141            channel_buffer: self.channel_buffer,
142            cols: 80,
143            rows: 24,
144            tick_rate: self.tick_rate,
145        }
146    }
147}
148
149impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
150    pub fn new(port: u16) -> Self {
151        Self {
152            clients: Arc::new(Mutex::new(HashMap::new())),
153            port,
154            id: 0,
155            username: String::new(),
156            max_connections: DEFAULT_MAX_CONNECTIONS,
157            channel_buffer: DEFAULT_CHANNEL_BUFFER,
158            cols: 80,
159            rows: 24,
160            tick_rate: std::time::Duration::from_secs(1),
161        }
162    }
163
164    /// Set the maximum number of concurrent SSH connections. Default: 100.
165    pub fn with_max_connections(mut self, max: usize) -> Self {
166        self.max_connections = max;
167        self
168    }
169
170    /// Set the per-connection terminal output channel buffer size.
171    /// Frames are dropped (not buffered) when the buffer is full. Default: 64.
172    pub fn with_channel_buffer(mut self, size: usize) -> Self {
173        self.channel_buffer = size;
174        self
175    }
176    // sets tick rate for the periodic update loop, default 1sec
177    pub fn with_tick_rate(mut self, duration: std::time::Duration) -> Self {
178        self.tick_rate = duration;
179        self
180    }
181
182    pub async fn run(&mut self, config: Config) -> Result<(), anyhow::Error> {
183        let subscriber = tracing_subscriber::fmt()
184            .compact()
185            .with_file(true)
186            .with_line_number(true)
187            .with_target(true)
188            .with_env_filter(
189                tracing_subscriber::EnvFilter::try_from_default_env()
190                    .unwrap_or_else(|_| "info".into()),
191            )
192            .finish();
193
194        // Silently ignore the error if a global subscriber is already set.
195        let _ = tracing::subscriber::set_global_default(subscriber);
196
197        let clients = self.clients.clone();
198        let tick_rate = self.tick_rate;
199        tokio::spawn(async move {
200            loop {
201                tokio::time::sleep(tick_rate).await;
202
203                for (_, (terminal, app)) in clients.lock().await.iter_mut() {
204                    if let Err(e) = terminal.draw(|f| {
205                        app.update();
206                        app.draw(f);
207                    }) {
208                        tracing::error!(
209                            "Background terminal draw error in periodic updater: {:?}",
210                            e
211                        );
212                    }
213                }
214            }
215        });
216
217        tracing::info!("starting server on 0.0.0.0:{}", self.port);
218        self.run_on_address(Arc::new(config), ("0.0.0.0", self.port))
219            .await?;
220        Ok(())
221    }
222}
223
224impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
225    fn send_data_or_log(
226        &mut self,
227        session: &mut Session,
228        channel: ChannelId,
229        data: &[u8],
230        description: &str,
231    ) {
232        if let Err(e) = session.data(channel, data.into()) {
233            tracing::error!(
234                "failed to {} for user {} (id: {}): {:?}",
235                description,
236                self.username,
237                self.id,
238                e
239            );
240        }
241    }
242}
243
244impl<T: ChaiApp + Send + 'static> Server for ChaiServer<T> {
245    type Handler = Self;
246    fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> Self {
247        let s = self.clone();
248        self.id += 1;
249        s
250    }
251}
252
253impl<T: ChaiApp + Send + 'static> Handler for ChaiServer<T> {
254    type Error = anyhow::Error;
255
256    async fn channel_open_session(
257        &mut self,
258        _channel: Channel<Msg>,
259        _session: &mut Session,
260    ) -> Result<bool, Self::Error> {
261        let clients = self.clients.lock().await;
262        if clients.len() >= self.max_connections {
263            tracing::warn!(
264                "max connections ({}) reached, rejecting session for {}",
265                self.max_connections,
266                self.username
267            );
268            return Ok(false);
269        }
270
271        tracing::info!("{} (id: {}) opened a channel", self.username, self.id);
272        Ok(true)
273    }
274
275    async fn auth_publickey(&mut self, user: &str, _: &PublicKey) -> Result<Auth, Self::Error> {
276        self.username = user.to_string();
277        Ok(Auth::Accept)
278    }
279
280    async fn auth_none(&mut self, user: &str) -> Result<Auth, Self::Error> {
281        self.username = user.to_string();
282        Ok(Auth::Accept)
283    }
284
285    async fn data(
286        &mut self,
287        channel: ChannelId,
288        data: &[u8],
289        session: &mut Session,
290    ) -> Result<(), Self::Error> {
291        // Input validation: only allow ASCII bytes (including control characters like Ctrl+C);
292        // reject any non-ASCII input.
293        if !data.iter().all(|&b| b.is_ascii()) {
294            tracing::warn!(
295                "Received invalid input data from user {} (id: {})",
296                self.username,
297                self.id
298            );
299            return Ok(());
300        }
301
302        let should_quit = {
303            let mut clients = self.clients.lock().await;
304            if let Some((terminal, app)) =
305                self.get_client_mut(&mut clients, "data handler", self.id)
306            {
307                // ctrl+c: quit immediately unless the app wants to handle it itself
308                if data == b"\x03" && !app.capture_ctrl_c() {
309                    true
310                } else {
311                    app.handle_input(data);
312                    let quit = app.should_quit();
313                    if !quit {
314                        self.try_draw(terminal, app);
315                    }
316                    quit
317                }
318            } else {
319                false
320            }
321        };
322
323        if should_quit {
324            self.send_data_or_log(session, channel, EXIT_ALT_SCREEN, "exit alternate screen");
325            self.send_data_or_log(session, channel, SHOW_CURSOR, "show cursor");
326            session.close(channel)?;
327        }
328
329        Ok(())
330    }
331
332    async fn window_change_request(
333        &mut self,
334        channel: ChannelId,
335        col_width: u32,
336        row_height: u32,
337        _: u32,
338        _: u32,
339        session: &mut Session,
340    ) -> Result<(), Self::Error> {
341        let rect = Rect {
342            x: 0,
343            y: 0,
344            width: col_width as u16,
345            height: row_height as u16,
346        };
347
348        let mut clients = self.clients.lock().await;
349        if let Some((_, mut app)) = clients.remove(&self.id) {
350            // Recreate the terminal at the new size instead of calling
351            // terminal.resize()
352            let terminal_handle = TerminalHandle::start(
353                session.handle(),
354                channel,
355                self.username.clone(),
356                self.id,
357                self.channel_buffer,
358            )
359            .await;
360
361            let backend = CrosstermBackend::new(terminal_handle);
362            let options = TerminalOptions {
363                viewport: Viewport::Fixed(rect),
364            };
365            let mut terminal = Terminal::with_options(backend, options)?;
366
367            // Queue a full screen clear into the sink (no flush yet) so that
368            // it ships in the same message as the first draw frame.
369            {
370                use std::io::Write;
371                terminal.backend_mut().write_all(CLEAR_ALL)?;
372            }
373
374            self.try_draw(&mut terminal, &mut app);
375            clients.insert(self.id, (terminal, app));
376        }
377
378        Ok(())
379    }
380
381    async fn pty_request(
382        &mut self,
383        _channel: ChannelId,
384        _: &str,
385        col_width: u32,
386        row_height: u32,
387        _: u32,
388        _: u32,
389        _: &[(Pty, u32)],
390        _session: &mut Session,
391    ) -> Result<(), Self::Error> {
392        self.cols = col_width;
393        self.rows = row_height;
394        Ok(())
395    }
396
397    async fn shell_request(
398        &mut self,
399        channel: ChannelId,
400        session: &mut Session,
401    ) -> Result<(), Self::Error> {
402        let handle = session.handle();
403
404        let terminal_handle = TerminalHandle::start(
405            handle,
406            channel,
407            self.username.clone(),
408            self.id,
409            self.channel_buffer,
410        )
411        .await;
412
413        let rect = Rect {
414            x: 0,
415            y: 0,
416            width: self.cols as u16,
417            height: self.rows as u16,
418        };
419
420        let backend = CrosstermBackend::new(terminal_handle);
421        let options = TerminalOptions {
422            viewport: Viewport::Fixed(rect),
423        };
424        let mut terminal = Terminal::with_options(backend, options)?;
425        let mut app = T::new();
426        app.on_connect(&self.username);
427
428        // Send escape sequences through the terminal backend so they share
429        // the same ordered mpsc channel as draw output.
430        {
431            use std::io::Write;
432            let backend = terminal.backend_mut();
433            backend.write_all(ENTER_ALT_SCREEN)?;
434            backend.write_all(HIDE_CURSOR)?;
435            backend.flush()?;
436        }
437
438        self.try_draw(&mut terminal, &mut app);
439        self.clients.lock().await.insert(self.id, (terminal, app));
440
441        session.channel_success(channel)?;
442
443        Ok(())
444    }
445
446    async fn channel_close(
447        &mut self,
448        channel: ChannelId,
449        session: &mut Session,
450    ) -> Result<(), Self::Error> {
451        tracing::info!("{} (id: {}) closed a channel", self.username, self.id);
452        let reset_sequence = [EXIT_ALT_SCREEN, SHOW_CURSOR].concat();
453        let _ = session.data(channel, reset_sequence.into());
454
455        let mut clients = self.clients.lock().await;
456        if let Some((_, app)) = clients.get_mut(&self.id) {
457            app.on_disconnect(&self.username);
458        }
459        if clients.remove(&self.id).is_none() {
460            tracing::warn!("No client found for id {} in channel_close", self.id);
461        }
462        Ok(())
463    }
464}
465
466impl<T: ChaiApp + Send + 'static> Drop for ChaiServer<T> {
467    fn drop(&mut self) {
468        let id = self.id;
469        let clients = self.clients.clone();
470        if let Ok(mut guard) = clients.try_lock() {
471            guard.remove(&id);
472        } else {
473            // If we can't lock, log and skip
474            tracing::warn!("Could not lock clients for cleanup in Drop for id {}", id);
475        }
476    }
477}
478
479pub fn load_system_host_keys(key_name: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
480    let key_path = Path::new("/.ssh").join(key_name);
481
482    if !key_path.exists() {
483        return Err(anyhow::anyhow!(
484            "Host key not found at {}. Please generate host keys first.",
485            key_path.display()
486        ));
487    }
488
489    let key = russh::keys::PrivateKey::read_openssh_file(&key_path)
490        .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
491
492    Ok(key)
493}
494
495pub fn load_host_keys(path: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
496    let key_path = Path::new(path);
497
498    if !key_path.exists() {
499        return Err(anyhow::anyhow!(
500            "Host key not found at {}. Please generate host keys first.",
501            key_path.display()
502        ));
503    }
504
505    let key = russh::keys::PrivateKey::read_openssh_file(key_path)
506        .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
507
508    Ok(key)
509}
510