Skip to main content

room_cli/broker/
mod.rs

1pub(crate) mod auth;
2pub(crate) mod commands;
3pub(crate) mod fanout;
4pub(crate) mod state;
5pub(crate) mod ws;
6
7use std::{
8    collections::HashMap,
9    path::PathBuf,
10    sync::{
11        atomic::{AtomicU64, Ordering},
12        Arc,
13    },
14};
15
16use crate::{
17    history,
18    message::{make_join, make_leave, parse_client_line, Message},
19    plugin::{self, PluginRegistry},
20};
21use auth::{handle_oneshot_join, validate_token};
22use commands::{route_command, CommandResult};
23use fanout::{broadcast_and_persist, dm_and_persist};
24use state::RoomState;
25use tokio::{
26    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
27    net::{
28        unix::{OwnedReadHalf, OwnedWriteHalf},
29        UnixListener, UnixStream,
30    },
31    sync::{broadcast, watch, Mutex},
32};
33
/// A chat-room broker: serves clients over a Unix socket and, optionally,
/// over a WebSocket/REST listener.
pub struct Broker {
    /// Identifier of the room this broker serves.
    room_id: String,
    /// Filesystem path of the Unix listening socket.
    socket_path: PathBuf,
    /// Path handed to history loading and message persistence.
    chat_path: PathBuf,
    /// TCP port for the WebSocket/REST server; `None` leaves it disabled.
    ws_port: Option<u16>,
}
40
41impl Broker {
42    pub fn new(
43        room_id: &str,
44        chat_path: PathBuf,
45        socket_path: PathBuf,
46        ws_port: Option<u16>,
47    ) -> Self {
48        Self {
49            room_id: room_id.to_owned(),
50            chat_path,
51            socket_path,
52            ws_port,
53        }
54    }
55
56    pub async fn run(self) -> anyhow::Result<()> {
57        // Remove stale socket synchronously — using tokio::fs here is dangerous
58        // because the blocking thread pool may be shutting down if the broker
59        // is starting up inside a dying process.
60        if self.socket_path.exists() {
61            std::fs::remove_file(&self.socket_path)?;
62        }
63
64        let listener = UnixListener::bind(&self.socket_path)?;
65        eprintln!("[broker] listening on {}", self.socket_path.display());
66
67        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
68
69        let mut registry = PluginRegistry::new();
70        registry.register(Box::new(plugin::help::HelpPlugin))?;
71        registry.register(Box::new(plugin::stats::StatsPlugin))?;
72
73        let state = Arc::new(RoomState {
74            clients: Arc::new(Mutex::new(HashMap::new())),
75            status_map: Arc::new(Mutex::new(HashMap::new())),
76            host_user: Arc::new(Mutex::new(None)),
77            token_map: Arc::new(Mutex::new(HashMap::new())),
78            chat_path: Arc::new(self.chat_path.clone()),
79            room_id: Arc::new(self.room_id.clone()),
80            shutdown: Arc::new(shutdown_tx),
81            seq_counter: Arc::new(AtomicU64::new(0)),
82            plugin_registry: Arc::new(registry),
83        });
84        let next_client_id = Arc::new(AtomicU64::new(0));
85
86        // Start WebSocket/REST server if a port was configured.
87        if let Some(port) = self.ws_port {
88            let ws_state = ws::WsAppState {
89                room_state: state.clone(),
90                next_client_id: next_client_id.clone(),
91            };
92            let app = ws::create_router(ws_state);
93            let tcp = tokio::net::TcpListener::bind(("0.0.0.0", port)).await?;
94            eprintln!("[broker] WebSocket/REST listening on port {port}");
95            tokio::spawn(async move {
96                if let Err(e) = axum::serve(tcp, app).await {
97                    eprintln!("[broker] WS server error: {e}");
98                }
99            });
100        }
101
102        loop {
103            tokio::select! {
104                accept = listener.accept() => {
105                    let (stream, _) = accept?;
106                    let cid = next_client_id.fetch_add(1, Ordering::SeqCst) + 1;
107
108                    let (tx, _) = broadcast::channel::<String>(256);
109                    // Insert with empty username; handle_client updates it after handshake.
110                    state
111                        .clients
112                        .lock()
113                        .await
114                        .insert(cid, (String::new(), tx.clone()));
115
116                    let state_clone = state.clone();
117
118                    tokio::spawn(async move {
119                        if let Err(e) = handle_client(cid, stream, tx, &state_clone).await {
120                            eprintln!("[broker] client {cid} error: {e:#}");
121                        }
122                        state_clone.clients.lock().await.remove(&cid);
123                    });
124                }
125                _ = shutdown_rx.changed() => {
126                    eprintln!("[broker] shutdown requested, exiting");
127                    break Ok(());
128                }
129            }
130        }
131    }
132}
133
134async fn handle_client(
135    cid: u64,
136    stream: UnixStream,
137    own_tx: broadcast::Sender<String>,
138    state: &Arc<RoomState>,
139) -> anyhow::Result<()> {
140    // Clone the Arc fields up-front so spawned tasks can capture owned handles.
141    let clients = state.clients.clone();
142    let status_map = state.status_map.clone();
143    let host_user = state.host_user.clone();
144    let token_map = state.token_map.clone();
145    let chat_path = state.chat_path.clone();
146    let room_id = state.room_id.clone();
147    let seq_counter = state.seq_counter.clone();
148
149    let (read_half, mut write_half) = stream.into_split();
150    let mut reader = BufReader::new(read_half);
151
152    // First line: username handshake, or one of the one-shot prefixes:
153    //   SEND:<username>  — legacy one-shot send
154    //   TOKEN:<uuid>     — token-authenticated one-shot send
155    //   JOIN:<username>  — register username, receive a session token
156    let mut first = String::new();
157    reader.read_line(&mut first).await?;
158    let first_line = first.trim();
159
160    if let Some(send_user) = first_line.strip_prefix("SEND:") {
161        return handle_oneshot_send(send_user.to_owned(), reader, write_half, state).await;
162    }
163
164    if let Some(token) = first_line.strip_prefix("TOKEN:") {
165        return match validate_token(token, &token_map).await {
166            Some(u) => handle_oneshot_send(u, reader, write_half, state).await,
167            None => {
168                let err = serde_json::json!({"type":"error","code":"invalid_token"});
169                write_half
170                    .write_all(format!("{err}\n").as_bytes())
171                    .await
172                    .map_err(Into::into)
173            }
174        };
175    }
176
177    if let Some(join_user) = first_line.strip_prefix("JOIN:") {
178        return handle_oneshot_join(join_user.to_owned(), write_half, &token_map).await;
179    }
180
181    // Remaining path: full interactive join — first_line is the username.
182    let username = first_line.to_owned();
183    if username.is_empty() {
184        return Ok(());
185    }
186
187    // Register username in the client map
188    {
189        let mut map = clients.lock().await;
190        if let Some(entry) = map.get_mut(&cid) {
191            entry.0 = username.clone();
192        }
193    }
194
195    // Register as host if no host has been set yet (first to complete handshake)
196    {
197        let mut host = host_user.lock().await;
198        if host.is_none() {
199            *host = Some(username.clone());
200        }
201    }
202
203    eprintln!("[broker] {username} joined (cid={cid})");
204
205    // Track this user in the status map (empty status by default)
206    status_map
207        .lock()
208        .await
209        .insert(username.clone(), String::new());
210
211    // Subscribe before sending history so we don't miss concurrent messages
212    let mut rx = own_tx.subscribe();
213
214    // Send chat history directly to this client's socket, filtering DMs the
215    // client is not party to (sender, recipient, or host).
216    // If the client disconnects mid-replay, treat it as a clean exit.
217    let host_name = host_user.lock().await.clone();
218    let is_host = host_name.as_deref() == Some(username.as_str());
219    let history = history::load(&chat_path).await.unwrap_or_default();
220    for msg in &history {
221        let visible = match msg {
222            Message::DirectMessage { user, to, .. } => {
223                is_host || user == &username || to == &username
224            }
225            _ => true,
226        };
227        if visible {
228            let line = format!("{}\n", serde_json::to_string(msg)?);
229            if write_half.write_all(line.as_bytes()).await.is_err() {
230                return Ok(());
231            }
232        }
233    }
234
235    // Broadcast join event (also persists it)
236    let join_msg = make_join(room_id.as_str(), &username);
237    if let Err(e) = broadcast_and_persist(&join_msg, &clients, &chat_path, &seq_counter).await {
238        eprintln!("[broker] broadcast_and_persist(join) failed: {e:#}");
239        return Ok(());
240    }
241
242    // Wrap write half in Arc<Mutex> for shared use by outbound and inbound tasks
243    let write_half = Arc::new(Mutex::new(write_half));
244
245    // Outbound: receive from broadcast channel, forward to client socket.
246    // Also listens for the shutdown signal; drains the channel first so the
247    // client sees the shutdown system message before receiving EOF.
248    let write_half_out = write_half.clone();
249    let mut shutdown_rx = state.shutdown.subscribe();
250    let outbound = tokio::spawn(async move {
251        loop {
252            tokio::select! {
253                result = rx.recv() => {
254                    match result {
255                        Ok(line) => {
256                            let mut wh = write_half_out.lock().await;
257                            if wh.write_all(line.as_bytes()).await.is_err() {
258                                break;
259                            }
260                        }
261                        Err(broadcast::error::RecvError::Lagged(n)) => {
262                            eprintln!("[broker] cid={cid} lagged by {n}");
263                        }
264                        Err(broadcast::error::RecvError::Closed) => break,
265                    }
266                }
267                _ = shutdown_rx.changed() => {
268                    // Drain any messages already queued (e.g. the shutdown notice)
269                    // before closing so the client sees them before receiving EOF.
270                    while let Ok(line) = rx.try_recv() {
271                        let mut wh = write_half_out.lock().await;
272                        let _ = wh.write_all(line.as_bytes()).await;
273                    }
274                    // Explicitly shut down the write side to send EOF to the client,
275                    // even though write_half_in (in the inbound task) still holds
276                    // the Arc — without this, the socket stays open.
277                    let _ = write_half_out.lock().await.shutdown().await;
278                    break;
279                }
280            }
281        }
282    });
283
284    // Inbound: read lines from client socket, parse and broadcast
285    let username_in = username.clone();
286    let room_id_in = room_id.clone();
287    let write_half_in = write_half.clone();
288    let state_in = state.clone();
289    let inbound = tokio::spawn(async move {
290        let mut line = String::new();
291        loop {
292            line.clear();
293            match reader.read_line(&mut line).await {
294                Ok(0) => break,
295                Ok(_) => {
296                    let trimmed = line.trim();
297                    if trimmed.is_empty() {
298                        continue;
299                    }
300                    match parse_client_line(trimmed, &room_id_in, &username_in) {
301                        Ok(msg) => match route_command(msg, &username_in, &state_in).await {
302                            Ok(CommandResult::Handled) => {}
303                            Ok(CommandResult::Reply(json)) => {
304                                let _ = write_half_in
305                                    .lock()
306                                    .await
307                                    .write_all(format!("{json}\n").as_bytes())
308                                    .await;
309                            }
310                            Ok(CommandResult::Shutdown) => break,
311                            Ok(CommandResult::Passthrough(msg)) => {
312                                let result = match &msg {
313                                    Message::DirectMessage { to, .. } => {
314                                        dm_and_persist(
315                                            &msg,
316                                            &username_in,
317                                            to,
318                                            &state_in.host_user,
319                                            &state_in.clients,
320                                            &state_in.chat_path,
321                                            &state_in.seq_counter,
322                                        )
323                                        .await
324                                    }
325                                    _ => {
326                                        broadcast_and_persist(
327                                            &msg,
328                                            &state_in.clients,
329                                            &state_in.chat_path,
330                                            &state_in.seq_counter,
331                                        )
332                                        .await
333                                    }
334                                };
335                                if let Err(e) = result {
336                                    eprintln!("[broker] persist error: {e:#}");
337                                }
338                            }
339                            Err(e) => eprintln!("[broker] route error: {e:#}"),
340                        },
341                        Err(e) => eprintln!("[broker] bad message from {username_in}: {e}"),
342                    }
343                }
344                Err(_) => break,
345            }
346        }
347    });
348
349    tokio::select! {
350        _ = outbound => {},
351        _ = inbound => {},
352    }
353
354    // Remove user from status map on disconnect
355    status_map.lock().await.remove(&username);
356
357    // Broadcast leave event
358    let leave_msg = make_leave(room_id.as_str(), &username);
359    let _ = broadcast_and_persist(&leave_msg, &clients, &chat_path, &seq_counter).await;
360    eprintln!("[broker] {username} left (cid={cid})");
361
362    Ok(())
363}
364
365/// Handle a one-shot SEND connection: read one message line, route it, echo it back, close.
366/// The sender is never registered in ClientMap/StatusMap and generates no join/leave events.
367/// DM envelopes are routed via `dm_and_persist`; all other messages are broadcast.
368async fn handle_oneshot_send(
369    username: String,
370    mut reader: BufReader<OwnedReadHalf>,
371    mut write_half: OwnedWriteHalf,
372    state: &RoomState,
373) -> anyhow::Result<()> {
374    let mut line = String::new();
375    reader.read_line(&mut line).await?;
376    let trimmed = line.trim();
377    if trimmed.is_empty() {
378        return Ok(());
379    }
380    let msg = parse_client_line(trimmed, &state.room_id, &username)?;
381    match route_command(msg, &username, state).await? {
382        CommandResult::Handled | CommandResult::Shutdown => {}
383        CommandResult::Reply(json) => {
384            write_half.write_all(format!("{json}\n").as_bytes()).await?;
385        }
386        CommandResult::Passthrough(msg) => {
387            let seq_msg = match &msg {
388                Message::DirectMessage { to, .. } => {
389                    dm_and_persist(
390                        &msg,
391                        &username,
392                        to,
393                        &state.host_user,
394                        &state.clients,
395                        &state.chat_path,
396                        &state.seq_counter,
397                    )
398                    .await?
399                }
400                _ => {
401                    broadcast_and_persist(
402                        &msg,
403                        &state.clients,
404                        &state.chat_path,
405                        &state.seq_counter,
406                    )
407                    .await?
408                }
409            };
410            let echo = format!("{}\n", serde_json::to_string(&seq_msg)?);
411            write_half.write_all(echo.as_bytes()).await?;
412        }
413    }
414    Ok(())
415}