use crate::room::{ChatRoom, RoomEvent, Message};
use crate::tui::{render_frame, ChatTui};
use async_trait::async_trait;
use russh::server::{Auth, Handle, Handler, Msg, Session};
use russh::{Channel, ChannelId, CryptoVec, Pty};
use russh_keys::key::KeyPair;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex, RwLock};
use tracing::{debug, error, info, warn};
/// SSH front-end that serves a single shared [`ChatRoom`].
///
/// Cloning is cheap: every field is reference-counted, so clones share the same room,
/// input-sender registry and client-id counter.
#[derive(Clone)]
pub struct ChatServer {
    /// Room every authenticated client joins.
    room: Arc<ChatRoom>,
    /// Per-connection input channels keyed by client id; shared with every `ChatHandler`.
    data_senders: Arc<Mutex<HashMap<usize, mpsc::Sender<Vec<u8>>>>>,
    /// Next client id to hand out. A plain atomic because `new_client` is a synchronous
    /// callback and must not block on an async lock.
    next_id: Arc<std::sync::atomic::AtomicUsize>,
}

impl ChatServer {
    /// Creates a server for `room`. Client ids are assigned starting at 1.
    pub fn new(room: Arc<ChatRoom>) -> Self {
        Self {
            room,
            data_senders: Arc::new(Mutex::new(HashMap::new())),
            next_id: Arc::new(std::sync::atomic::AtomicUsize::new(1)),
        }
    }

    /// Generates a fresh ed25519 host key and serves SSH on `0.0.0.0:port`.
    ///
    /// A new host key is generated on every call, so the logged fingerprint changes across
    /// restarts.
    ///
    /// # Errors
    /// Propagates any error returned by `russh::server::run` (e.g. failure to bind).
    ///
    /// # Panics
    /// Panics if host-key generation fails.
    pub async fn run(self, port: u16) -> anyhow::Result<()> {
        let key_pair = KeyPair::generate_ed25519().expect("Failed to generate host key");
        let fingerprint = key_pair
            .clone_public_key()
            .map(|k| k.fingerprint())
            .unwrap_or_else(|_| "unknown".to_string());
        info!("Host key fingerprint: {}", fingerprint);

        let mut config = russh::server::Config::default();
        config.keys.push(key_pair);
        config.auth_rejection_time = Duration::from_secs(3);
        // No delay on the very first rejection (e.g. the initial "none" auth probe).
        config.auth_rejection_time_initial = Some(Duration::from_secs(0));
        config.inactivity_timeout = Some(Duration::from_secs(3600));
        let config = Arc::new(config);

        let addr = format!("0.0.0.0:{}", port);
        info!("Starting SSH server on {}", addr);
        russh::server::run(config, &addr, self).await?;
        Ok(())
    }
}

impl russh::server::Server for ChatServer {
    type Handler = ChatHandler;

    /// Builds the per-connection handler and assigns it a unique client id.
    fn new_client(&mut self, peer_addr: Option<std::net::SocketAddr>) -> Self::Handler {
        // `fetch_add` is a single atomic read-modify-write, so ids stay unique even with
        // Relaxed ordering; no other memory is published through this counter.
        let client_id = self
            .next_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        info!("New connection from {:?} (client_id: {})", peer_addr, client_id);
        ChatHandler {
            client_id,
            room: self.room.clone(),
            data_senders: self.data_senders.clone(),
            session_id: None,
            username: None,
            channel_id: None,
            pty_size: (80, 24),
            authenticated: false,
        }
    }
}
/// Per-connection state for one SSH client.
pub struct ChatHandler {
    /// Shared chat room (same `Arc` as the server holds).
    room: Arc<ChatRoom>,
    /// Shared registry of input channels; this connection's entry is keyed by `client_id`.
    data_senders: Arc<Mutex<HashMap<usize, mpsc::Sender<Vec<u8>>>>>,
    /// Server-assigned connection id.
    client_id: usize,
    /// Set once password authentication succeeds.
    authenticated: bool,
    /// SSH login name, recorded on successful password authentication.
    username: Option<String>,
    /// Session channel accepted in `channel_open_session`.
    channel_id: Option<ChannelId>,
    /// Id returned by `ChatRoom::join` once a shell has been started.
    session_id: Option<u64>,
    /// Last reported terminal size as `(columns, rows)`.
    pty_size: (u16, u16),
}
#[async_trait]
impl Handler for ChatHandler {
    type Error = anyhow::Error;

    /// Accepts the login if the password passes `ChatRoom::verify_password`.
    ///
    /// The username itself is not checked; it is only recorded as the display name.
    async fn auth_password(
        mut self,
        user: &str,
        password: &str,
    ) -> Result<(Self, Auth), Self::Error> {
        info!("Password auth attempt for user: {}", user);
        if self.room.verify_password(password) {
            self.username = Some(user.to_string());
            self.authenticated = true;
            info!("User {} authenticated successfully", user);
            Ok((self, Auth::Accept))
        } else {
            warn!("Authentication failed for user: {}", user);
            Ok((self, Auth::Reject { proceed_with_methods: None }))
        }
    }

    /// Opens a session channel, but only for authenticated clients.
    async fn channel_open_session(
        mut self,
        channel: Channel<Msg>,
        session: Session,
    ) -> Result<(Self, bool, Session), Self::Error> {
        if !self.authenticated {
            return Ok((self, false, session));
        }
        self.channel_id = Some(channel.id());
        Ok((self, true, session))
    }

    /// Records the requested terminal size.
    ///
    /// Dimensions above `u16::MAX` are clamped rather than wrapped, so a bogus size cannot
    /// turn into a tiny or zero-sized terminal.
    async fn pty_request(
        mut self,
        _channel_id: ChannelId,
        term: &str,
        col_width: u32,
        row_height: u32,
        _pix_width: u32,
        _pix_height: u32,
        _modes: &[(Pty, u32)],
        session: Session,
    ) -> Result<(Self, Session), Self::Error> {
        info!(
            "PTY request: term={}, size={}x{}",
            term, col_width, row_height
        );
        self.pty_size = (
            u16::try_from(col_width).unwrap_or(u16::MAX),
            u16::try_from(row_height).unwrap_or(u16::MAX),
        );
        Ok((self, session))
    }

    /// Joins the room and spawns the task that runs the interactive session.
    ///
    /// Input received by `data` is forwarded to that task through an mpsc channel that is
    /// registered in `data_senders`; the task unregisters it when it finishes.
    async fn shell_request(
        mut self,
        channel_id: ChannelId,
        session: Session,
    ) -> Result<(Self, Session), Self::Error> {
        let username = self.username.clone().unwrap_or_else(|| "anonymous".to_string());
        let (cols, rows) = self.pty_size;
        let session_id = self.room.join(username.clone(), cols, rows).await?;
        self.session_id = Some(session_id);
        info!("User {} joined room (session_id: {})", username, session_id);

        // Bounded: when the session task falls behind, `data` waits for it.
        let (data_tx, data_rx) = mpsc::channel::<Vec<u8>>(64);
        self.data_senders.lock().await.insert(self.client_id, data_tx);

        let handle = session.handle();
        let client_id = self.client_id;
        let data_senders = self.data_senders.clone();
        let ctx = SessionContext {
            handle,
            channel_id,
            room: self.room.clone(),
            session_id,
            username,
            cols,
            rows,
        };
        tokio::spawn(async move {
            if let Err(e) = run_chat_session(ctx, data_rx).await {
                error!("Chat session error: {}", e);
            }
            data_senders.lock().await.remove(&client_id);
        });
        Ok((self, session))
    }

    /// Tracks terminal resizes and forwards them to the room once a shell is running.
    ///
    /// Dimensions are clamped to `u16::MAX` like in `pty_request`.
    async fn window_change_request(
        mut self,
        _channel_id: ChannelId,
        col_width: u32,
        row_height: u32,
        _pix_width: u32,
        _pix_height: u32,
        session: Session,
    ) -> Result<(Self, Session), Self::Error> {
        debug!("Window resize: {}x{}", col_width, row_height);
        let cols = u16::try_from(col_width).unwrap_or(u16::MAX);
        let rows = u16::try_from(row_height).unwrap_or(u16::MAX);
        self.pty_size = (cols, rows);
        if let Some(session_id) = self.session_id {
            self.room.update_terminal_size(session_id, cols, rows).await;
        }
        Ok((self, session))
    }

    /// Forwards raw client input to this connection's session task, if one is running.
    async fn data(
        self,
        _channel_id: ChannelId,
        data: &[u8],
        session: Session,
    ) -> Result<(Self, Session), Self::Error> {
        // Clone the sender out so the shared registry lock is released *before* awaiting a
        // potentially-blocking send; holding it would stall every other connection.
        let sender = self.data_senders.lock().await.get(&self.client_id).cloned();
        if let Some(sender) = sender {
            // A send error only means the session task has already exited; drop the input.
            let _ = sender.send(data.to_vec()).await;
        }
        Ok((self, session))
    }

    /// Leaves the room and unregisters this connection's input channel.
    ///
    /// NOTE(review): `run_chat_session` also calls `ChatRoom::leave` for the same session when
    /// it exits, so `leave` may run twice — confirm it is idempotent.
    async fn channel_close(
        self,
        _channel_id: ChannelId,
        session: Session,
    ) -> Result<(Self, Session), Self::Error> {
        info!("Channel closed for client {}", self.client_id);
        if let Some(session_id) = self.session_id {
            self.room.leave(session_id).await;
        }
        self.data_senders.lock().await.remove(&self.client_id);
        Ok((self, session))
    }
}
/// State handed from `shell_request` to the spawned per-session task.
struct SessionContext {
    /// Room this session belongs to.
    room: Arc<ChatRoom>,
    /// Id returned by `ChatRoom::join` for this session.
    session_id: u64,
    /// Display name for this session.
    username: String,
    /// russh session handle used to write frames and close the channel from the task.
    handle: Handle,
    /// Channel that frames are written to.
    channel_id: ChannelId,
    /// Terminal width when the session started.
    cols: u16,
    /// Terminal height when the session started.
    rows: u16,
}
/// Drives one interactive chat session until the user quits or the SSH channel goes away.
///
/// The full frame is redrawn whenever the room broadcasts an event and after every chunk of
/// keyboard input. Completed input lines starting with `/` are dispatched to
/// [`handle_command`]; any other line is posted to the room unchanged.
///
/// # Errors
/// Returns an error only if the initial frame cannot be written. That early return skips
/// `ChatRoom::leave`; every other exit path calls it.
async fn run_chat_session(
    ctx: SessionContext,
    mut data_rx: mpsc::Receiver<Vec<u8>>,
) -> anyhow::Result<()> {
    let mut tui = ChatTui::new(ctx.room.clone(), ctx.username.clone(), ctx.cols, ctx.rows);
    let mut rx = ctx.room.subscribe();
    let mut input_buffer = Vec::new();

    tui.refresh_cache().await;
    let frame = render_frame(&tui);
    ctx.handle
        .data(ctx.channel_id, CryptoVec::from_slice(&frame))
        .await
        .map_err(|_| anyhow::anyhow!("Failed to send initial frame"))?;

    loop {
        tokio::select! {
            event = rx.recv() => {
                match event {
                    Ok(RoomEvent::NewMessage)
                    | Ok(RoomEvent::UserJoined)
                    | Ok(RoomEvent::UserLeft)
                    | Ok(RoomEvent::Refresh) => {
                        tui.refresh_cache().await;
                        let frame = render_frame(&tui);
                        if ctx.handle.data(ctx.channel_id, CryptoVec::from_slice(&frame)).await.is_err() {
                            break;
                        }
                    }
                    // NOTE(review): if `subscribe()` returns a tokio broadcast receiver, this
                    // also ends the session when the client merely lags behind the channel;
                    // consider refreshing on `RecvError::Lagged` instead.
                    Err(_) => break,
                }
            }
            input = data_rx.recv() => {
                // `None` means every sender was dropped (channel closed or replaced): stop
                // now instead of idling until the next room event.
                let data = match input {
                    Some(data) => data,
                    None => break,
                };
                input_buffer.extend_from_slice(&data);

                let mut should_quit = false;
                // Drain the whole buffer: one read can carry several keystrokes (e.g. a
                // paste). `process_input` consumes at least one byte per call, so this ends.
                while !input_buffer.is_empty() {
                    let msg = match process_input(&mut input_buffer, &mut tui).await {
                        Some(msg) => msg,
                        None => continue,
                    };
                    debug!("Processed full input line: {:?}", msg);
                    if msg == "__QUIT__" {
                        should_quit = true;
                        break;
                    }
                    let trimmed_msg = msg.trim();
                    if trimmed_msg.starts_with('/') {
                        handle_command(trimmed_msg, &ctx).await;
                    } else {
                        ctx.room.send_message(ctx.session_id, msg).await;
                    }
                }

                if should_quit {
                    ctx.room.leave(ctx.session_id).await;
                    let _ = ctx.handle.eof(ctx.channel_id).await;
                    let _ = ctx.handle.close(ctx.channel_id).await;
                    return Ok(());
                }

                let frame = render_frame(&tui);
                if ctx.handle.data(ctx.channel_id, CryptoVec::from_slice(&frame)).await.is_err() {
                    break;
                }
            }
        }
    }

    ctx.room.leave(ctx.session_id).await;
    Ok(())
}
/// Executes a slash command typed by the session's user.
///
/// Supported commands:
/// - `/greet [name]` posts "hey <name>" as `System` (defaults to the caller's username).
/// - `/joke` fetches a joke with [`fetch_joke`] and posts it as `Joker`.
///
/// Unknown commands are only logged.
async fn handle_command(cmd: &str, ctx: &SessionContext) {
    info!("Inside handle_command with: '{}'", cmd);
    // `split_whitespace` already skips leading/trailing whitespace.
    let parts: Vec<&str> = cmd.split_whitespace().collect();
    let (name, args) = match parts.split_first() {
        Some((name, args)) => (*name, args),
        None => {
            warn!("Command parts empty after split");
            return;
        }
    };
    info!("Command parts: {:?}", parts);

    match name {
        "/greet" => {
            let target = match args.first() {
                Some(arg) => arg.to_string(),
                None => ctx.username.clone(),
            };
            info!("Executing /greet for target: {}", target);
            let greeting = Message::text("System", format!("hey {}", target));
            info!("Adding system message to room");
            ctx.room.add_message(greeting).await;
            info!("System message added");
        }
        "/joke" => {
            let joke = fetch_joke().await;
            ctx.room.add_message(Message::text("Joker", joke)).await;
            info!("Joke added to room");
        }
        unknown => warn!("Unknown command: {}", unknown),
    }
}
#[derive(serde::Deserialize)]
struct Flags {
nsfw: bool,
religious: bool,
political: bool,
racist: bool,
sexist: bool,
#[serde(rename = "explicit")]
explicit_flag: bool,
}
/// JSON body returned by JokeAPI, as deserialized by `fetch_joke`.
///
/// Field names follow the JSON keys; `joke_type` is read from the key `type`, which is a Rust
/// keyword.
#[derive(serde::Deserialize)]
struct JokeResponse {
    /// API-level error indicator.
    error: bool,
    /// Joke text (presumably only present for `type=single` jokes — confirm against the API).
    joke: String,
    /// Numeric joke id.
    id: u64,
    category: String,
    #[serde(rename = "type")]
    joke_type: String,
    lang: String,
    safe: bool,
    flags: Flags,
}
async fn fetch_joke() -> String {
let response = reqwest::get("https://v2.jokeapi.dev/joke/Any?type=single")
.await
.unwrap()
.json::<JokeResponse>()
.await
.unwrap();
if response.error {
return "Error fetching joke".to_string();
}
response.joke
}
async fn process_input(buffer: &mut Vec<u8>, tui: &mut ChatTui) -> Option<String> {
if buffer.is_empty() {
return None;
}
let byte = buffer.remove(0);
let key_event = match byte {
3 => {
return Some("__QUIT__".to_string());
}
13 => {
crossterm::event::KeyEvent::new(
crossterm::event::KeyCode::Enter,
crossterm::event::KeyModifiers::NONE,
)
}
127 | 8 => {
crossterm::event::KeyEvent::new(
crossterm::event::KeyCode::Backspace,
crossterm::event::KeyModifiers::NONE,
)
}
27 => {
if buffer.len() >= 2 && buffer[0] == b'[' {
buffer.remove(0); match buffer.remove(0) {
b'A' => crossterm::event::KeyEvent::new(
crossterm::event::KeyCode::Up,
crossterm::event::KeyModifiers::NONE,
),
b'B' => crossterm::event::KeyEvent::new(
crossterm::event::KeyCode::Down,
crossterm::event::KeyModifiers::NONE,
),
b'C' => crossterm::event::KeyEvent::new(
crossterm::event::KeyCode::Right,
crossterm::event::KeyModifiers::NONE,
),
b'D' => crossterm::event::KeyEvent::new(
crossterm::event::KeyCode::Left,
crossterm::event::KeyModifiers::NONE,
),
b'5' if !buffer.is_empty() && buffer[0] == b'~' => {
buffer.remove(0);
crossterm::event::KeyEvent::new(
crossterm::event::KeyCode::PageUp,
crossterm::event::KeyModifiers::NONE,
)
}
b'6' if !buffer.is_empty() && buffer[0] == b'~' => {
buffer.remove(0);
crossterm::event::KeyEvent::new(
crossterm::event::KeyCode::PageDown,
crossterm::event::KeyModifiers::NONE,
)
}
_ => return None,
}
} else {
return None;
}
}
c if (32..127).contains(&c) => {
crossterm::event::KeyEvent::new(
crossterm::event::KeyCode::Char(c as char),
crossterm::event::KeyModifiers::NONE,
)
}
_ => return None,
};
tui.handle_key(key_event).await
}