Skip to main content

chai_framework/
server.rs

1impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
2    /// Helper to get a mutable reference to a client, logging a warning if not found.
3    fn get_client_mut<'a>(
4        &self,
5        clients: &'a mut HashMap<usize, (SshTerminal, T)>,
6        context: &str,
7        id: usize,
8    ) -> Option<(&'a mut SshTerminal, &'a mut T)> {
9        match clients.get_mut(&id) {
10            Some((terminal, app)) => Some((terminal, app)),
11            None => {
12                tracing::warn!("No client found for id {} in {}", id, context);
13                None
14            }
15        }
16    }
17
18    /// Helper to get a mutable reference to a client, logging a warning if not found, but only needs terminal.
19    fn get_terminal_mut<'a>(
20        &self,
21        clients: &'a mut HashMap<usize, (SshTerminal, T)>,
22        context: &str,
23        id: usize,
24    ) -> Option<&'a mut SshTerminal> {
25        match clients.get_mut(&id) {
26            Some((terminal, _)) => Some(terminal),
27            None => {
28                tracing::warn!("No client found for id {} in {}", id, context);
29                None
30            }
31        }
32    }
33
34    /// Helper to draw the app, logging error if it fails.
35    fn try_draw(&self, terminal: &mut SshTerminal, app: &mut T) {
36        if let Err(e) = terminal.draw(|f| app.draw(f)) {
37            tracing::error!(
38                "Terminal draw error for user {} (id: {}): {:?}",
39                self.username,
40                self.id,
41                e
42            );
43        }
44    }
45}
46use crate::chai::ChaiApp;
47
48use std::collections::HashMap;
49use std::path::Path;
50use std::sync::Arc;
51
52use ratatui::backend::CrosstermBackend;
53use ratatui::layout::Rect;
54
55use ratatui::{Terminal, TerminalOptions, Viewport};
56use russh::keys::ssh_key::PublicKey;
57use russh::server::*;
58use russh::{Channel, ChannelId, Pty};
59use tokio::sync::Mutex;
60use tokio::sync::mpsc::{Sender, channel, error::TrySendError};
61
62const ENTER_ALT_SCREEN: &[u8] = b"\x1b[?1049h";
63const EXIT_ALT_SCREEN: &[u8] = b"\x1b[?1049l";
64const HIDE_CURSOR: &[u8] = b"\x1b[?25l";
65const SHOW_CURSOR: &[u8] = b"\x1b[?25h";
66
67type SshTerminal = Terminal<CrosstermBackend<TerminalHandle>>;
68
69struct TerminalHandle {
70    sender: Sender<Vec<u8>>,
71    // The sink collects the data which is finally sent to sender.
72    sink: Vec<u8>,
73}
74
75impl TerminalHandle {
76    async fn start(
77        handle: Handle,
78        channel_id: ChannelId,
79        username: String,
80        id: usize,
81        buffer: usize,
82    ) -> Self {
83        let (sender, mut receiver) = channel::<Vec<u8>>(buffer);
84        let username_clone = username.clone();
85        let id_clone = id;
86        tokio::spawn(async move {
87            while let Some(data) = receiver.recv().await {
88                let result = handle.data(channel_id, data.into()).await;
89                if result.is_err() {
90                    tracing::error!(
91                        "failed to send data for user {} (id: {}): {result:?}",
92                        username_clone,
93                        id_clone
94                    );
95                }
96            }
97        });
98        Self {
99            sender,
100            sink: Vec::new(),
101        }
102    }
103}
104
105impl std::io::Write for TerminalHandle {
106    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
107        self.sink.extend_from_slice(buf);
108        Ok(buf.len())
109    }
110
111    fn flush(&mut self) -> std::io::Result<()> {
112        match self.sender.try_send(self.sink.clone()) {
113            Ok(()) => {}
114            Err(TrySendError::Full(_)) => {
115                // Consumer is slow; drop this frame rather than block the async runtime.
116                tracing::debug!("terminal output buffer full, dropping frame");
117            }
118            Err(TrySendError::Closed(_)) => {
119                return Err(std::io::Error::new(
120                    std::io::ErrorKind::BrokenPipe,
121                    "terminal channel closed",
122                ));
123            }
124        }
125
126        self.sink.clear();
127        Ok(())
128    }
129}
130
131const DEFAULT_MAX_CONNECTIONS: usize = 100;
132const DEFAULT_CHANNEL_BUFFER: usize = 64;
133
134pub struct ChaiServer<T: ChaiApp + Send + 'static> {
135    clients: Arc<Mutex<HashMap<usize, (SshTerminal, T)>>>,
136    port: u16,
137    id: usize,
138    username: String,
139    max_connections: usize,
140    channel_buffer: usize,
141}
142
143impl<T: ChaiApp + Send + 'static> Clone for ChaiServer<T> {
144    fn clone(&self) -> Self {
145        Self {
146            clients: self.clients.clone(),
147            port: self.port,
148            id: self.id,
149            username: self.username.clone(),
150            max_connections: self.max_connections,
151            channel_buffer: self.channel_buffer,
152        }
153    }
154}
155
156impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
157    pub fn new(port: u16) -> Self {
158        Self {
159            clients: Arc::new(Mutex::new(HashMap::new())),
160            port,
161            id: 0,
162            username: String::new(),
163            max_connections: DEFAULT_MAX_CONNECTIONS,
164            channel_buffer: DEFAULT_CHANNEL_BUFFER,
165        }
166    }
167
168    /// Set the maximum number of concurrent SSH connections. Default: 100.
169    pub fn with_max_connections(mut self, max: usize) -> Self {
170        self.max_connections = max;
171        self
172    }
173
174    /// Set the per-connection terminal output channel buffer size.
175    /// Frames are dropped (not buffered) when the buffer is full. Default: 64.
176    pub fn with_channel_buffer(mut self, size: usize) -> Self {
177        self.channel_buffer = size;
178        self
179    }
180
181    pub async fn run(&mut self, config: Config) -> Result<(), anyhow::Error> {
182        let subscriber = tracing_subscriber::fmt()
183            .compact()
184            .with_file(true)
185            .with_line_number(true)
186            .with_target(true)
187            .with_env_filter(
188                tracing_subscriber::EnvFilter::try_from_default_env()
189                    .unwrap_or_else(|_| "info".into()),
190            )
191            .finish();
192
193        // Silently ignore the error if a global subscriber is already set.
194        let _ = tracing::subscriber::set_global_default(subscriber);
195
196        let clients = self.clients.clone();
197        tokio::spawn(async move {
198            loop {
199                tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
200
201                for (_, (terminal, app)) in clients.lock().await.iter_mut() {
202                    terminal
203                        .draw(|f| {
204                            app.update();
205                            app.draw(f);
206                        })
207                        .unwrap();
208                }
209            }
210        });
211
212        tracing::info!("starting server on 0.0.0.0:{}", self.port);
213        self.run_on_address(Arc::new(config), ("0.0.0.0", self.port))
214            .await?;
215        Ok(())
216    }
217}
218
219impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
220    fn send_data_or_log(
221        &mut self,
222        session: &mut Session,
223        channel: ChannelId,
224        data: &[u8],
225        description: &str,
226    ) {
227        if let Err(e) = session.data(channel, data.into()) {
228            tracing::error!(
229                "failed to {} for user {} (id: {}): {:?}",
230                description,
231                self.username,
232                self.id,
233                e
234            );
235        }
236    }
237}
238
239impl<T: ChaiApp + Send + 'static> Server for ChaiServer<T> {
240    type Handler = Self;
241    fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> Self {
242        let s = self.clone();
243        self.id += 1;
244        s
245    }
246}
247
248impl<T: ChaiApp + Send + 'static> Handler for ChaiServer<T> {
249    type Error = anyhow::Error;
250
251    async fn channel_open_session(
252        &mut self,
253        channel: Channel<Msg>,
254        session: &mut Session,
255    ) -> Result<bool, Self::Error> {
256        {
257            let clients = self.clients.lock().await;
258            if clients.len() >= self.max_connections {
259                tracing::warn!(
260                    "max connections ({}) reached, rejecting session for {}",
261                    self.max_connections,
262                    self.username
263                );
264                return Ok(false);
265            }
266        }
267
268        tracing::info!("{} (id: {}) opened a channel", self.username, self.id);
269        let terminal_handle = TerminalHandle::start(
270            session.handle(),
271            channel.id(),
272            self.username.clone(),
273            self.id,
274            self.channel_buffer,
275        )
276        .await;
277
278        let backend = CrosstermBackend::new(terminal_handle);
279
280        // the correct viewport area will be set when the client request a pty
281        let options = TerminalOptions {
282            viewport: Viewport::Fixed(Rect::default()),
283        };
284
285        let terminal = Terminal::with_options(backend, options)?;
286        let app = T::new();
287
288        let mut clients = self.clients.lock().await;
289        clients.insert(self.id, (terminal, app));
290
291        Ok(true)
292    }
293
294    async fn auth_publickey(&mut self, user: &str, _: &PublicKey) -> Result<Auth, Self::Error> {
295        self.username = user.to_string();
296        Ok(Auth::Accept)
297    }
298
299    async fn auth_none(&mut self, user: &str) -> Result<Auth, Self::Error> {
300        self.username = user.to_string();
301        Ok(Auth::Accept)
302    }
303
304    async fn data(
305        &mut self,
306        channel: ChannelId,
307        data: &[u8],
308        session: &mut Session,
309    ) -> Result<(), Self::Error> {
310        // Input validation: Only allow printable ASCII and control chars
311        if !data
312            .iter()
313            .all(|&b| b == b'\n' || b == b'\r' || (b >= 0x20 && b <= 0x7e))
314        {
315            tracing::warn!(
316                "Received invalid input data from user {} (id: {})",
317                self.username,
318                self.id
319            );
320            return Ok(());
321        }
322
323        let should_quit = {
324            let mut clients = self.clients.lock().await;
325            if let Some((terminal, app)) =
326                self.get_client_mut(&mut clients, "data handler", self.id)
327            {
328                app.handle_input(data);
329                let quit = app.should_quit();
330                if !quit {
331                    self.try_draw(terminal, app);
332                }
333                quit
334            } else {
335                false
336            }
337        };
338
339        if should_quit {
340            self.send_data_or_log(session, channel, EXIT_ALT_SCREEN, "exit alternate screen");
341            self.send_data_or_log(session, channel, SHOW_CURSOR, "show cursor");
342            self.clients.lock().await.remove(&self.id);
343            session.close(channel)?;
344        }
345
346        Ok(())
347    }
348
349    async fn window_change_request(
350        &mut self,
351        _: ChannelId,
352        col_width: u32,
353        row_height: u32,
354        _: u32,
355        _: u32,
356        _: &mut Session,
357    ) -> Result<(), Self::Error> {
358        let rect = Rect {
359            x: 0,
360            y: 0,
361            width: col_width as u16,
362            height: row_height as u16,
363        };
364
365        let mut clients = self.clients.lock().await;
366        if let Some(terminal) =
367            self.get_terminal_mut(&mut clients, "window_change_request", self.id)
368        {
369            terminal.resize(rect)?;
370        }
371
372        Ok(())
373    }
374
375    async fn pty_request(
376        &mut self,
377        channel: ChannelId,
378        _: &str,
379        col_width: u32,
380        row_height: u32,
381        _: u32,
382        _: u32,
383        _: &[(Pty, u32)],
384        session: &mut Session,
385    ) -> Result<(), Self::Error> {
386        let rect = Rect {
387            x: 0,
388            y: 0,
389            width: col_width as u16,
390            height: row_height as u16,
391        };
392
393        {
394            let mut clients = self.clients.lock().await;
395            if let Some(terminal) =
396                self.get_terminal_mut(&mut clients, "pty_request (resize)", self.id)
397            {
398                terminal.resize(rect)?;
399            }
400        }
401
402        session.channel_success(channel)?;
403
404        self.send_data_or_log(session, channel, ENTER_ALT_SCREEN, "enter alternate screen");
405        self.send_data_or_log(session, channel, HIDE_CURSOR, "hide cursor");
406
407        {
408            let mut clients = self.clients.lock().await;
409            if let Some((terminal, app)) =
410                self.get_client_mut(&mut clients, "pty_request (draw)", self.id)
411            {
412                self.try_draw(terminal, app);
413            }
414        }
415
416        Ok(())
417    }
418
419    async fn channel_close(
420        &mut self,
421        channel: ChannelId,
422        session: &mut Session,
423    ) -> Result<(), Self::Error> {
424        tracing::info!("{} (id: {}) closed a channel", self.username, self.id);
425        let reset_sequence = [EXIT_ALT_SCREEN, SHOW_CURSOR].concat();
426        let _ = session.data(channel, reset_sequence.into());
427
428        let mut clients = self.clients.lock().await;
429        if clients.remove(&self.id).is_none() {
430            tracing::warn!("No client found for id {} in channel_close", self.id);
431        }
432        Ok(())
433    }
434}
435
436impl<T: ChaiApp + Send + 'static> Drop for ChaiServer<T> {
437    fn drop(&mut self) {
438        // Synchronous cleanup: block on removing the client
439        let id = self.id;
440        let clients = self.clients.clone();
441        if let Ok(mut guard) = clients.try_lock() {
442            guard.remove(&id);
443        } else {
444            // If we can't lock, log and skip
445            tracing::warn!("Could not lock clients for cleanup in Drop for id {}", id);
446        }
447    }
448}
449
450pub fn load_system_host_keys(key_name: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
451    let key_path = Path::new("/.ssh").join(key_name);
452
453    if !key_path.exists() {
454        return Err(anyhow::anyhow!(
455            "Host key not found at {}. Please generate host keys first.",
456            key_path.display()
457        ));
458    }
459
460    let key = russh::keys::PrivateKey::read_openssh_file(&key_path)
461        .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
462
463    Ok(key)
464}
465
466pub fn load_host_keys(path: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
467    let key_path = Path::new(path);
468
469    if !key_path.exists() {
470        return Err(anyhow::anyhow!(
471            "Host key not found at {}. Please generate host keys first.",
472            key_path.display()
473        ));
474    }
475
476    let key = russh::keys::PrivateKey::read_openssh_file(&key_path)
477        .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
478
479    Ok(key)
480}