// chai_framework/server.rs

1use crate::chai::ChaiApp;
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::Arc;
6
7use ratatui::backend::CrosstermBackend;
8use ratatui::layout::Rect;
9
10use ratatui::{Terminal, TerminalOptions, Viewport};
11use russh::keys::ssh_key::PublicKey;
12use russh::server::*;
13use russh::{Channel, ChannelId, Pty};
14use tokio::sync::Mutex;
15use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
16
// Terminal control sequences sent to the client around a TUI session.

/// Escape sequence that switches the client's terminal to the alternate screen.
const ENTER_ALT_SCREEN: &[u8] = b"\x1b[?1049h";
/// Escape sequence that leaves the alternate screen again.
const EXIT_ALT_SCREEN: &[u8] = b"\x1b[?1049l";
/// Escape sequence that hides the cursor.
const HIDE_CURSOR: &[u8] = b"\x1b[?25l";
/// Escape sequence that makes the cursor visible again.
const SHOW_CURSOR: &[u8] = b"\x1b[?25h";
21
/// A ratatui terminal whose crossterm backend writes into a [`TerminalHandle`]
/// rather than directly into a local terminal.
type SshTerminal = Terminal<CrosstermBackend<TerminalHandle>>;
23
/// Write target used as the output of a client's terminal backend.
///
/// Bytes passed to `write` are accumulated in `sink`; `flush` hands the
/// accumulated bytes to `sender` as a single chunk.
struct TerminalHandle {
    /// Sending half of the channel that carries flushed chunks to the
    /// background task spawned in `start`.
    sender: UnboundedSender<Vec<u8>>,
    // The sink collects the data which is finally sent to sender.
    sink: Vec<u8>,
}
29
30impl TerminalHandle {
31    async fn start(handle: Handle, channel_id: ChannelId, username: String, id: usize) -> Self {
32        let (sender, mut receiver) = unbounded_channel::<Vec<u8>>();
33        let username_clone = username.clone();
34        let id_clone = id;
35        tokio::spawn(async move {
36            while let Some(data) = receiver.recv().await {
37                let result = handle.data(channel_id, data.into()).await;
38                if result.is_err() {
39                    tracing::error!(
40                        "failed to send data for user {} (id: {}): {result:?}",
41                        username_clone,
42                        id_clone
43                    );
44                }
45            }
46        });
47        Self {
48            sender,
49            sink: Vec::new(),
50        }
51    }
52}
53
54impl std::io::Write for TerminalHandle {
55    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
56        self.sink.extend_from_slice(buf);
57        Ok(buf.len())
58    }
59
60    fn flush(&mut self) -> std::io::Result<()> {
61        let result = self.sender.send(self.sink.clone());
62        if result.is_err() {
63            return Err(std::io::Error::new(
64                std::io::ErrorKind::BrokenPipe,
65                result.unwrap_err(),
66            ));
67        }
68
69        self.sink.clear();
70        Ok(())
71    }
72}
73
/// SSH server that gives each connected client its own terminal and `T` app.
///
/// The same type serves as the listening server and, through `new_client`,
/// as the per-connection handler; all clones share the `clients` map.
#[derive(Clone)]
pub struct ChaiServer<T: ChaiApp + Send + 'static> {
    /// Terminal and app state of every open session, keyed by client id.
    clients: Arc<Mutex<HashMap<usize, (SshTerminal, T)>>>,
    /// TCP port the server listens on.
    port: u16,
    /// On a per-connection handler: the id of that client. On the listening
    /// server: the id the next client will be given (incremented in `new_client`).
    id: usize,
    /// User name recorded by the authentication callbacks; empty until then.
    username: String,
}
81
82impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
83    pub fn new(port: u16) -> Self {
84        Self {
85            clients: Arc::new(Mutex::new(HashMap::new())),
86            port,
87            id: 0,
88            username: String::new(),
89        }
90    }
91
92    pub async fn run(&mut self, config: Config) -> Result<(), anyhow::Error> {
93        let subscriber = tracing_subscriber::fmt()
94            .compact()
95            .with_file(true)
96            .with_line_number(true)
97            .with_target(true)
98            .finish();
99
100        tracing::subscriber::set_global_default(subscriber).unwrap();
101
102        let clients = self.clients.clone();
103        tokio::spawn(async move {
104            loop {
105                tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
106
107                for (_, (terminal, app)) in clients.lock().await.iter_mut() {
108                    terminal
109                        .draw(|f| {
110                            app.update();
111                            app.draw(f);
112                        })
113                        .unwrap();
114                }
115            }
116        });
117
118        tracing::info!("starting server on 0.0.0.0:{}", self.port);
119        self.run_on_address(Arc::new(config), ("0.0.0.0", self.port))
120            .await?;
121        Ok(())
122    }
123}
124
125impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
126    fn send_data_or_log(
127        &mut self,
128        session: &mut Session,
129        channel: ChannelId,
130        data: &[u8],
131        description: &str,
132    ) {
133        if let Err(e) = session.data(channel, data.into()) {
134            tracing::error!(
135                "failed to {} for user {} (id: {}): {:?}",
136                description,
137                self.username,
138                self.id,
139                e
140            );
141        }
142    }
143}
144
145impl<T: ChaiApp + Send + 'static> Server for ChaiServer<T> {
146    type Handler = Self;
147    fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> Self {
148        let s = self.clone();
149        self.id += 1;
150        s
151    }
152}
153
154impl<T: ChaiApp + Send + 'static> Handler for ChaiServer<T> {
155    type Error = anyhow::Error;
156
157    async fn channel_open_session(
158        &mut self,
159        channel: Channel<Msg>,
160        session: &mut Session,
161    ) -> Result<bool, Self::Error> {
162        tracing::info!("{} (id: {}) opened a channel", self.username, self.id);
163        let terminal_handle = TerminalHandle::start(
164            session.handle(),
165            channel.id(),
166            self.username.clone(),
167            self.id,
168        )
169        .await;
170
171        let backend = CrosstermBackend::new(terminal_handle);
172
173        // the correct viewport area will be set when the client request a pty
174        let options = TerminalOptions {
175            viewport: Viewport::Fixed(Rect::default()),
176        };
177
178        let terminal = Terminal::with_options(backend, options)?;
179        let app = T::new();
180
181        let mut clients = self.clients.lock().await;
182        clients.insert(self.id, (terminal, app));
183
184        Ok(true)
185    }
186
187    async fn auth_publickey(&mut self, user: &str, _: &PublicKey) -> Result<Auth, Self::Error> {
188        self.username = user.to_string();
189        Ok(Auth::Accept)
190    }
191
192    async fn auth_none(&mut self, user: &str) -> Result<Auth, Self::Error> {
193        self.username = user.to_string();
194        Ok(Auth::Accept)
195    }
196
197    async fn data(
198        &mut self,
199        channel: ChannelId,
200        data: &[u8],
201        session: &mut Session,
202    ) -> Result<(), Self::Error> {
203        match data {
204            // Pressing 'q' closes the connection.
205            b"q" => {
206                self.send_data_or_log(session, channel, EXIT_ALT_SCREEN, "exit alternate screen");
207                self.send_data_or_log(session, channel, SHOW_CURSOR, "show cursor");
208
209                self.clients.lock().await.remove(&self.id);
210                session.close(channel)?;
211            }
212            _ => {
213                let mut clients = self.clients.lock().await;
214                let (_, app) = clients.get_mut(&self.id).unwrap();
215                app.handle_input(data);
216            }
217        }
218
219        Ok(())
220    }
221
222    async fn window_change_request(
223        &mut self,
224        _: ChannelId,
225        col_width: u32,
226        row_height: u32,
227        _: u32,
228        _: u32,
229        _: &mut Session,
230    ) -> Result<(), Self::Error> {
231        let rect = Rect {
232            x: 0,
233            y: 0,
234            width: col_width as u16,
235            height: row_height as u16,
236        };
237
238        let mut clients = self.clients.lock().await;
239        let (terminal, _) = clients.get_mut(&self.id).unwrap();
240        terminal.resize(rect)?;
241
242        Ok(())
243    }
244
245    async fn pty_request(
246        &mut self,
247        channel: ChannelId,
248        _: &str,
249        col_width: u32,
250        row_height: u32,
251        _: u32,
252        _: u32,
253        _: &[(Pty, u32)],
254        session: &mut Session,
255    ) -> Result<(), Self::Error> {
256        let rect = Rect {
257            x: 0,
258            y: 0,
259            width: col_width as u16,
260            height: row_height as u16,
261        };
262
263        {
264            let mut clients = self.clients.lock().await;
265            let (terminal, _) = clients.get_mut(&self.id).unwrap();
266            terminal.resize(rect)?;
267        }
268
269        session.channel_success(channel)?;
270
271        self.send_data_or_log(session, channel, ENTER_ALT_SCREEN, "enter alternate screen");
272        self.send_data_or_log(session, channel, HIDE_CURSOR, "hide cursor");
273
274        Ok(())
275    }
276
277    async fn channel_close(
278        &mut self,
279        channel: ChannelId,
280        session: &mut Session,
281    ) -> Result<(), Self::Error> {
282        tracing::info!("{} (id: {}) closed a channel", self.username, self.id);
283        let reset_sequence = [EXIT_ALT_SCREEN, SHOW_CURSOR].concat();
284        let _ = session.data(channel, reset_sequence.into());
285
286        self.clients.lock().await.remove(&self.id);
287        Ok(())
288    }
289}
290
291impl<T: ChaiApp + Send + 'static> Drop for ChaiServer<T> {
292    fn drop(&mut self) {
293        let id = self.id;
294        let clients = self.clients.clone();
295        tokio::spawn(async move {
296            let mut clients = clients.lock().await;
297            clients.remove(&id);
298        });
299    }
300}
301
302pub fn load_system_host_keys(key_name: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
303    let key_path = Path::new("/.ssh").join(key_name);
304
305    if !key_path.exists() {
306        return Err(anyhow::anyhow!(
307            "Host key not found at {}. Please generate host keys first.",
308            key_path.display()
309        ));
310    }
311
312    let key = russh::keys::PrivateKey::read_openssh_file(&key_path)
313        .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
314
315    Ok(key)
316}
317
318pub fn load_host_keys(path: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
319    let key_path = Path::new(path);
320
321    if !key_path.exists() {
322        return Err(anyhow::anyhow!(
323            "Host key not found at {}. Please generate host keys first.",
324            key_path.display()
325        ));
326    }
327
328    let key = russh::keys::PrivateKey::read_openssh_file(&key_path)
329        .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
330
331    Ok(key)
332}