1use crate::chai::ChaiApp;
2
3use std::collections::HashMap;
4use std::path::Path;
5use std::sync::Arc;
6
7use ratatui::backend::CrosstermBackend;
8use ratatui::layout::Rect;
9
10use ratatui::{Terminal, TerminalOptions, Viewport};
11use russh::keys::ssh_key::PublicKey;
12use russh::server::*;
13use russh::{Channel, ChannelId, Pty};
14use tokio::sync::Mutex;
15use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
16
/// Switches the client to the alternate screen buffer (xterm private mode 1049, set).
const ENTER_ALT_SCREEN: &[u8] = b"\x1B[?1049h";
/// Returns the client to the normal screen buffer (xterm private mode 1049, reset).
const EXIT_ALT_SCREEN: &[u8] = b"\x1B[?1049l";
/// Makes the client's text cursor invisible (private mode 25, reset).
const HIDE_CURSOR: &[u8] = b"\x1B[?25l";
/// Makes the client's text cursor visible again (private mode 25, set).
const SHOW_CURSOR: &[u8] = b"\x1B[?25h";
21
22type SshTerminal = Terminal<CrosstermBackend<TerminalHandle>>;
23
24struct TerminalHandle {
25 sender: UnboundedSender<Vec<u8>>,
26 sink: Vec<u8>,
28}
29
30impl TerminalHandle {
31 async fn start(handle: Handle, channel_id: ChannelId, username: String, id: usize) -> Self {
32 let (sender, mut receiver) = unbounded_channel::<Vec<u8>>();
33 let username_clone = username.clone();
34 let id_clone = id;
35 tokio::spawn(async move {
36 while let Some(data) = receiver.recv().await {
37 let result = handle.data(channel_id, data.into()).await;
38 if result.is_err() {
39 tracing::error!(
40 "failed to send data for user {} (id: {}): {result:?}",
41 username_clone,
42 id_clone
43 );
44 }
45 }
46 });
47 Self {
48 sender,
49 sink: Vec::new(),
50 }
51 }
52}
53
54impl std::io::Write for TerminalHandle {
55 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
56 self.sink.extend_from_slice(buf);
57 Ok(buf.len())
58 }
59
60 fn flush(&mut self) -> std::io::Result<()> {
61 let result = self.sender.send(self.sink.clone());
62 if result.is_err() {
63 return Err(std::io::Error::new(
64 std::io::ErrorKind::BrokenPipe,
65 result.unwrap_err(),
66 ));
67 }
68
69 self.sink.clear();
70 Ok(())
71 }
72}
73
74#[derive(Clone)]
75pub struct ChaiServer<T: ChaiApp + Send + 'static> {
76 clients: Arc<Mutex<HashMap<usize, (SshTerminal, T)>>>,
77 port: u16,
78 id: usize,
79 username: String,
80}
81
82impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
83 pub fn new(port: u16) -> Self {
84 Self {
85 clients: Arc::new(Mutex::new(HashMap::new())),
86 port,
87 id: 0,
88 username: String::new(),
89 }
90 }
91
92 pub async fn run(&mut self, config: Config) -> Result<(), anyhow::Error> {
93 let subscriber = tracing_subscriber::fmt()
94 .compact()
95 .with_file(true)
96 .with_line_number(true)
97 .with_target(true)
98 .finish();
99
100 tracing::subscriber::set_global_default(subscriber).unwrap();
101
102 let clients = self.clients.clone();
103 tokio::spawn(async move {
104 loop {
105 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
106
107 for (_, (terminal, app)) in clients.lock().await.iter_mut() {
108 terminal
109 .draw(|f| {
110 app.update();
111 app.draw(f);
112 })
113 .unwrap();
114 }
115 }
116 });
117
118 tracing::info!("starting server on 0.0.0.0:{}", self.port);
119 self.run_on_address(Arc::new(config), ("0.0.0.0", self.port))
120 .await?;
121 Ok(())
122 }
123}
124
125impl<T: ChaiApp + Send + 'static> ChaiServer<T> {
126 fn send_data_or_log(
127 &mut self,
128 session: &mut Session,
129 channel: ChannelId,
130 data: &[u8],
131 description: &str,
132 ) {
133 if let Err(e) = session.data(channel, data.into()) {
134 tracing::error!(
135 "failed to {} for user {} (id: {}): {:?}",
136 description,
137 self.username,
138 self.id,
139 e
140 );
141 }
142 }
143}
144
145impl<T: ChaiApp + Send + 'static> Server for ChaiServer<T> {
146 type Handler = Self;
147 fn new_client(&mut self, _: Option<std::net::SocketAddr>) -> Self {
148 let s = self.clone();
149 self.id += 1;
150 s
151 }
152}
153
154impl<T: ChaiApp + Send + 'static> Handler for ChaiServer<T> {
155 type Error = anyhow::Error;
156
157 async fn channel_open_session(
158 &mut self,
159 channel: Channel<Msg>,
160 session: &mut Session,
161 ) -> Result<bool, Self::Error> {
162 tracing::info!("{} (id: {}) opened a channel", self.username, self.id);
163 let terminal_handle = TerminalHandle::start(
164 session.handle(),
165 channel.id(),
166 self.username.clone(),
167 self.id,
168 )
169 .await;
170
171 let backend = CrosstermBackend::new(terminal_handle);
172
173 let options = TerminalOptions {
175 viewport: Viewport::Fixed(Rect::default()),
176 };
177
178 let terminal = Terminal::with_options(backend, options)?;
179 let app = T::new();
180
181 let mut clients = self.clients.lock().await;
182 clients.insert(self.id, (terminal, app));
183
184 Ok(true)
185 }
186
187 async fn auth_publickey(&mut self, user: &str, _: &PublicKey) -> Result<Auth, Self::Error> {
188 self.username = user.to_string();
189 Ok(Auth::Accept)
190 }
191
192 async fn auth_none(&mut self, user: &str) -> Result<Auth, Self::Error> {
193 self.username = user.to_string();
194 Ok(Auth::Accept)
195 }
196
197 async fn data(
198 &mut self,
199 channel: ChannelId,
200 data: &[u8],
201 session: &mut Session,
202 ) -> Result<(), Self::Error> {
203 match data {
204 b"q" => {
206 self.send_data_or_log(session, channel, EXIT_ALT_SCREEN, "exit alternate screen");
207 self.send_data_or_log(session, channel, SHOW_CURSOR, "show cursor");
208
209 self.clients.lock().await.remove(&self.id);
210 session.close(channel)?;
211 }
212 _ => {
213 let mut clients = self.clients.lock().await;
214 let (_, app) = clients.get_mut(&self.id).unwrap();
215 app.handle_input(data);
216 }
217 }
218
219 Ok(())
220 }
221
222 async fn window_change_request(
223 &mut self,
224 _: ChannelId,
225 col_width: u32,
226 row_height: u32,
227 _: u32,
228 _: u32,
229 _: &mut Session,
230 ) -> Result<(), Self::Error> {
231 let rect = Rect {
232 x: 0,
233 y: 0,
234 width: col_width as u16,
235 height: row_height as u16,
236 };
237
238 let mut clients = self.clients.lock().await;
239 let (terminal, _) = clients.get_mut(&self.id).unwrap();
240 terminal.resize(rect)?;
241
242 Ok(())
243 }
244
245 async fn pty_request(
246 &mut self,
247 channel: ChannelId,
248 _: &str,
249 col_width: u32,
250 row_height: u32,
251 _: u32,
252 _: u32,
253 _: &[(Pty, u32)],
254 session: &mut Session,
255 ) -> Result<(), Self::Error> {
256 let rect = Rect {
257 x: 0,
258 y: 0,
259 width: col_width as u16,
260 height: row_height as u16,
261 };
262
263 {
264 let mut clients = self.clients.lock().await;
265 let (terminal, _) = clients.get_mut(&self.id).unwrap();
266 terminal.resize(rect)?;
267 }
268
269 session.channel_success(channel)?;
270
271 self.send_data_or_log(session, channel, ENTER_ALT_SCREEN, "enter alternate screen");
272 self.send_data_or_log(session, channel, HIDE_CURSOR, "hide cursor");
273
274 Ok(())
275 }
276
277 async fn channel_close(
278 &mut self,
279 channel: ChannelId,
280 session: &mut Session,
281 ) -> Result<(), Self::Error> {
282 tracing::info!("{} (id: {}) closed a channel", self.username, self.id);
283 let reset_sequence = [EXIT_ALT_SCREEN, SHOW_CURSOR].concat();
284 let _ = session.data(channel, reset_sequence.into());
285
286 self.clients.lock().await.remove(&self.id);
287 Ok(())
288 }
289}
290
291impl<T: ChaiApp + Send + 'static> Drop for ChaiServer<T> {
292 fn drop(&mut self) {
293 let id = self.id;
294 let clients = self.clients.clone();
295 tokio::spawn(async move {
296 let mut clients = clients.lock().await;
297 clients.remove(&id);
298 });
299 }
300}
301
302pub fn load_system_host_keys(key_name: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
303 let key_path = Path::new("/.ssh").join(key_name);
304
305 if !key_path.exists() {
306 return Err(anyhow::anyhow!(
307 "Host key not found at {}. Please generate host keys first.",
308 key_path.display()
309 ));
310 }
311
312 let key = russh::keys::PrivateKey::read_openssh_file(&key_path)
313 .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
314
315 Ok(key)
316}
317
318pub fn load_host_keys(path: &str) -> Result<russh::keys::PrivateKey, anyhow::Error> {
319 let key_path = Path::new(path);
320
321 if !key_path.exists() {
322 return Err(anyhow::anyhow!(
323 "Host key not found at {}. Please generate host keys first.",
324 key_path.display()
325 ));
326 }
327
328 let key = russh::keys::PrivateKey::read_openssh_file(&key_path)
329 .map_err(|e| anyhow::anyhow!("Failed to read host key: {}", e))?;
330
331 Ok(key)
332}