Skip to main content

room_daemon/broker/
mod.rs

1pub(crate) mod admin;
2pub(crate) mod auth;
3pub(crate) mod commands;
4pub mod daemon;
5pub(crate) mod fanout;
6pub(crate) mod handshake;
7pub mod persistence;
8pub(crate) mod service;
9pub(crate) mod session;
10pub(crate) mod state;
11pub(crate) mod token_store;
12pub(crate) mod ws;
13
14use std::{
15    collections::HashMap,
16    path::PathBuf,
17    sync::{
18        atomic::{AtomicU64, Ordering},
19        Arc,
20    },
21};
22
23use crate::plugin::PluginRegistry;
24use auth::{handle_oneshot_join, validate_token};
25use state::RoomState;
26use tokio::{
27    io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader},
28    net::{
29        unix::{OwnedReadHalf, OwnedWriteHalf},
30        UnixListener, UnixStream,
31    },
32    sync::{broadcast, watch, Mutex},
33};
34
/// Upper bound, in bytes, on a single client line, trailing newline included.
///
/// `read_line_limited` fails as soon as a line would grow past this value, so a
/// peer that never sends a newline cannot make the broker buffer unbounded data.
pub const MAX_LINE_BYTES: usize = 64 << 10; // 64 KiB
38
39/// Read a single newline-terminated line, rejecting lines that exceed `MAX_LINE_BYTES`.
40///
41/// Returns `Ok(n)` where `n` is the number of bytes read (0 = EOF).
42/// Returns an error if the accumulated bytes before a newline exceed the limit.
43///
44/// The line (including the trailing `\n`) is appended to `buf`, matching the
45/// behaviour of `AsyncBufReadExt::read_line`.
46pub(crate) async fn read_line_limited<R: AsyncBufRead + Unpin>(
47    reader: &mut R,
48    buf: &mut String,
49) -> anyhow::Result<usize> {
50    let mut total = 0usize;
51    loop {
52        let available = reader.fill_buf().await?;
53        if available.is_empty() {
54            // EOF
55            return Ok(total);
56        }
57        // Look for a newline in the buffered data.
58        let (chunk, found_newline) = match available.iter().position(|&b| b == b'\n') {
59            Some(pos) => (&available[..=pos], true),
60            None => (available, false),
61        };
62        let chunk_len = chunk.len();
63        if total + chunk_len > MAX_LINE_BYTES {
64            anyhow::bail!("line exceeds maximum size of {} bytes", MAX_LINE_BYTES);
65        }
66        // Safety: we validate UTF-8 before appending.
67        let text = std::str::from_utf8(chunk)
68            .map_err(|e| anyhow::anyhow!("invalid UTF-8 in client line: {e}"))?;
69        buf.push_str(text);
70        total += chunk_len;
71        reader.consume(chunk_len);
72        if found_newline {
73            return Ok(total);
74        }
75    }
76}
77
78pub struct Broker {
79    room_id: String,
80    chat_path: PathBuf,
81    /// Path to the persisted token-map file (e.g. `~/.room/state/<room_id>.tokens`).
82    token_map_path: PathBuf,
83    /// Path to the persisted subscription-map file (e.g. `~/.room/state/<room_id>.subscriptions`).
84    subscription_map_path: PathBuf,
85    socket_path: PathBuf,
86    ws_port: Option<u16>,
87}
88
89impl Broker {
90    pub fn new(
91        room_id: &str,
92        chat_path: PathBuf,
93        token_map_path: PathBuf,
94        subscription_map_path: PathBuf,
95        socket_path: PathBuf,
96        ws_port: Option<u16>,
97    ) -> Self {
98        Self {
99            room_id: room_id.to_owned(),
100            chat_path,
101            token_map_path,
102            subscription_map_path,
103            socket_path,
104            ws_port,
105        }
106    }
107
108    pub async fn run(self) -> anyhow::Result<()> {
109        // Remove stale socket synchronously — using tokio::fs here is dangerous
110        // because the blocking thread pool may be shutting down if the broker
111        // is starting up inside a dying process.
112        if self.socket_path.exists() {
113            std::fs::remove_file(&self.socket_path)?;
114        }
115
116        let listener = UnixListener::bind(&self.socket_path)?;
117        eprintln!("[broker] listening on {}", self.socket_path.display());
118
119        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
120
121        let registry = PluginRegistry::with_all_plugins(&self.chat_path)?;
122
123        // Load persisted state from a previous broker session (if any).
124        let persisted_tokens = token_store::load_token_map(&self.token_map_path);
125        if !persisted_tokens.is_empty() {
126            eprintln!(
127                "[broker] loaded {} persisted token(s)",
128                persisted_tokens.len()
129            );
130        }
131        let persisted_subs = persistence::load_subscription_map(&self.subscription_map_path);
132        if !persisted_subs.is_empty() {
133            eprintln!(
134                "[broker] loaded {} persisted subscription(s)",
135                persisted_subs.len()
136            );
137        }
138
139        let state = Arc::new(RoomState {
140            clients: Arc::new(Mutex::new(HashMap::new())),
141            status_map: Arc::new(Mutex::new(HashMap::new())),
142            host_user: Arc::new(Mutex::new(None)),
143            auth: state::AuthState {
144                token_map: Arc::new(Mutex::new(persisted_tokens)),
145                token_map_path: Arc::new(self.token_map_path.clone()),
146                registry: std::sync::OnceLock::new(),
147            },
148            filters: state::FilterState {
149                subscription_map: Arc::new(Mutex::new(persisted_subs)),
150                subscription_map_path: Arc::new(self.subscription_map_path.clone()),
151                event_filter_state: std::sync::OnceLock::new(),
152            },
153            chat_path: Arc::new(self.chat_path.clone()),
154            room_id: Arc::new(self.room_id.clone()),
155            shutdown: Arc::new(shutdown_tx),
156            seq_counter: Arc::new(AtomicU64::new(0)),
157            plugin_registry: Arc::new(registry),
158            config: None,
159            cross_room_resolver: std::sync::OnceLock::new(),
160        });
161        // Attach event filter map (parallel to subscription map).
162        {
163            let ef_path = self.subscription_map_path.with_extension("event_filters");
164            let persisted_ef = persistence::load_event_filter_map(&ef_path);
165            if !persisted_ef.is_empty() {
166                eprintln!(
167                    "[broker] loaded {} persisted event filter(s)",
168                    persisted_ef.len()
169                );
170            }
171            state.set_event_filter_map(Arc::new(Mutex::new(persisted_ef)), ef_path);
172        }
173
174        let next_client_id = Arc::new(AtomicU64::new(0));
175
176        // Start WebSocket/REST server if a port was configured.
177        if let Some(port) = self.ws_port {
178            let ws_state = ws::WsAppState {
179                room_state: state.clone(),
180                next_client_id: next_client_id.clone(),
181                user_registry: None,
182            };
183            let app = ws::create_router(ws_state);
184            let tcp = tokio::net::TcpListener::bind(("0.0.0.0", port)).await?;
185            eprintln!("[broker] WebSocket/REST listening on port {port}");
186            tokio::spawn(async move {
187                if let Err(e) = axum::serve(tcp, app).await {
188                    eprintln!("[broker] WS server error: {e}");
189                }
190            });
191        }
192
193        loop {
194            tokio::select! {
195                accept = listener.accept() => {
196                    let (stream, _) = accept?;
197                    let cid = next_client_id.fetch_add(1, Ordering::SeqCst) + 1;
198
199                    let (tx, _) = broadcast::channel::<String>(256);
200                    // Insert with empty username; handle_client updates it after handshake.
201                    state
202                        .clients
203                        .lock()
204                        .await
205                        .insert(cid, (String::new(), tx.clone()));
206
207                    let state_clone = state.clone();
208
209                    tokio::spawn(async move {
210                        if let Err(e) = handle_client(cid, stream, tx, &state_clone).await {
211                            eprintln!("[broker] client {cid} error: {e:#}");
212                        }
213                        state_clone.clients.lock().await.remove(&cid);
214                    });
215                }
216                _ = shutdown_rx.changed() => {
217                    eprintln!("[broker] shutdown requested, exiting");
218                    break Ok(());
219                }
220            }
221        }
222    }
223}
224
225async fn handle_client(
226    cid: u64,
227    stream: UnixStream,
228    own_tx: broadcast::Sender<String>,
229    state: &Arc<RoomState>,
230) -> anyhow::Result<()> {
231    let token_map = state.auth.token_map.clone();
232
233    let (read_half, mut write_half) = stream.into_split();
234    let mut reader = BufReader::new(read_half);
235
236    // First line: username handshake, or one of the one-shot prefixes.
237    let mut first = String::new();
238    read_line_limited(&mut reader, &mut first).await?;
239    let first_line = first.trim();
240
241    use handshake::{parse_client_handshake, ClientHandshake};
242    let username = match parse_client_handshake(first_line) {
243        ClientHandshake::Send(u) => {
244            eprintln!(
245                "[broker] DEPRECATED: SEND:{u} handshake used — \
246                 migrate to TOKEN:<uuid> (SEND: will be removed in a future version)"
247            );
248            return handle_oneshot_send(u, reader, write_half, state).await;
249        }
250        ClientHandshake::Token(token) => {
251            return match validate_token(&token, &token_map).await {
252                Some(u) => handle_oneshot_send(u, reader, write_half, state).await,
253                None => {
254                    let err = serde_json::json!({"type":"error","code":"invalid_token"});
255                    write_half
256                        .write_all(format!("{err}\n").as_bytes())
257                        .await
258                        .map_err(Into::into)
259                }
260            };
261        }
262        ClientHandshake::Join(u) => {
263            let result = handle_oneshot_join(
264                u,
265                write_half,
266                &token_map,
267                &state.filters.subscription_map,
268                state.config.as_ref(),
269                Some(&state.auth.token_map_path),
270            )
271            .await;
272            // Persist auto-subscription from join so it survives broker restart.
273            persistence::persist_subscriptions(state).await;
274            return result;
275        }
276        ClientHandshake::Session(token) => {
277            return match validate_token(&token, &token_map).await {
278                Some(u) => {
279                    if let Err(reason) = auth::check_join_permission(&u, state.config.as_ref()) {
280                        let err = serde_json::json!({
281                            "type": "error",
282                            "code": "join_denied",
283                            "message": reason,
284                            "username": u
285                        });
286                        write_half.write_all(format!("{err}\n").as_bytes()).await?;
287                        return Ok(());
288                    }
289                    run_interactive_session(cid, &u, reader, write_half, own_tx, state).await
290                }
291                None => {
292                    let err = serde_json::json!({"type":"error","code":"invalid_token"});
293                    write_half
294                        .write_all(format!("{err}\n").as_bytes())
295                        .await
296                        .map_err(Into::into)
297                }
298            };
299        }
300        ClientHandshake::Interactive(u) => {
301            eprintln!(
302                "[broker] DEPRECATED: unauthenticated interactive join for '{u}' — \
303                 migrate to SESSION:<token> (plain username joins will be removed in a future version)"
304            );
305            u
306        }
307    };
308
309    // Remaining path: deprecated unauthenticated interactive join.
310    if username.is_empty() {
311        return Ok(());
312    }
313
314    // Check join permission before entering interactive session.
315    if let Err(reason) = auth::check_join_permission(&username, state.config.as_ref()) {
316        let err = serde_json::json!({
317            "type": "error",
318            "code": "join_denied",
319            "message": reason,
320            "username": username
321        });
322        write_half.write_all(format!("{err}\n").as_bytes()).await?;
323        return Ok(());
324    }
325
326    run_interactive_session(cid, &username, reader, write_half, own_tx, state).await
327}
328
329/// Run an interactive client session after the username has been determined.
330///
331/// Shared by both single-room (`handle_client`) and daemon (`dispatch_connection`)
332/// paths. Delegates setup, message processing, and teardown to
333/// [`session`](super::session) — this function only handles UDS-specific I/O
334/// (reading lines, writing bytes, shutdown signaling).
335pub(crate) async fn run_interactive_session(
336    cid: u64,
337    username: &str,
338    reader: BufReader<OwnedReadHalf>,
339    mut write_half: OwnedWriteHalf,
340    own_tx: broadcast::Sender<String>,
341    state: &Arc<RoomState>,
342) -> anyhow::Result<()> {
343    let username = username.to_owned();
344
345    // Subscribe before setup so we don't miss concurrent messages.
346    let mut rx = own_tx.subscribe();
347
348    // Shared setup: register client, elect host, load history, broadcast join.
349    let history_lines = match session::session_setup(cid, &username, state).await {
350        Ok(lines) => lines,
351        Err(e) => {
352            eprintln!("[broker] session_setup failed: {e:#}");
353            return Ok(());
354        }
355    };
356
357    // Send history to client over UDS.
358    for line in &history_lines {
359        if write_half
360            .write_all(format!("{line}\n").as_bytes())
361            .await
362            .is_err()
363        {
364            return Ok(());
365        }
366    }
367
368    // Wrap write half in Arc<Mutex> for shared use by outbound and inbound tasks.
369    let write_half = Arc::new(Mutex::new(write_half));
370
371    // Outbound: receive from broadcast channel, forward to client socket.
372    let write_half_out = write_half.clone();
373    let mut shutdown_rx = state.shutdown.subscribe();
374    let outbound = tokio::spawn(async move {
375        loop {
376            tokio::select! {
377                result = rx.recv() => {
378                    match result {
379                        Ok(line) => {
380                            let mut wh = write_half_out.lock().await;
381                            if wh.write_all(line.as_bytes()).await.is_err() {
382                                break;
383                            }
384                        }
385                        Err(broadcast::error::RecvError::Lagged(n)) => {
386                            eprintln!("[broker] cid={cid} lagged by {n}");
387                        }
388                        Err(broadcast::error::RecvError::Closed) => break,
389                    }
390                }
391                _ = shutdown_rx.changed() => {
392                    while let Ok(line) = rx.try_recv() {
393                        let mut wh = write_half_out.lock().await;
394                        let _ = wh.write_all(line.as_bytes()).await;
395                    }
396                    let _ = write_half_out.lock().await.shutdown().await;
397                    break;
398                }
399            }
400        }
401    });
402
403    // Inbound: read lines from client socket, delegate to shared processing.
404    let username_in = username.clone();
405    let write_half_in = write_half.clone();
406    let state_in = state.clone();
407    let inbound = tokio::spawn(async move {
408        let mut reader = reader;
409        let mut line = String::new();
410        loop {
411            line.clear();
412            match read_line_limited(&mut reader, &mut line).await {
413                Ok(0) => break,
414                Ok(_) => {
415                    let trimmed = line.trim();
416                    if trimmed.is_empty() {
417                        continue;
418                    }
419                    match session::process_inbound_message(trimmed, &username_in, &state_in).await {
420                        session::InboundResult::Ok => {}
421                        session::InboundResult::Reply(json) => {
422                            let _ = write_half_in
423                                .lock()
424                                .await
425                                .write_all(format!("{json}\n").as_bytes())
426                                .await;
427                        }
428                        session::InboundResult::Shutdown => break,
429                    }
430                }
431                Err(e) => {
432                    eprintln!("[broker] read error from {username_in}: {e:#}");
433                    let err = serde_json::json!({
434                        "type": "error",
435                        "code": "line_too_long",
436                        "message": format!("{e}")
437                    });
438                    let _ = write_half_in
439                        .lock()
440                        .await
441                        .write_all(format!("{err}\n").as_bytes())
442                        .await;
443                    break;
444                }
445            }
446        }
447    });
448
449    tokio::select! {
450        _ = outbound => {},
451        _ = inbound => {},
452    }
453
454    // Shared teardown: remove status, broadcast leave.
455    session::session_teardown(cid, &username, state).await;
456
457    Ok(())
458}
459
460/// Handle a one-shot SEND connection: read one message line, route it, echo it back, close.
461/// The sender is never registered in ClientMap/StatusMap and generates no join/leave events.
462/// DM envelopes are routed via `dm_and_persist`; all other messages are broadcast.
463pub(crate) async fn handle_oneshot_send(
464    username: String,
465    mut reader: BufReader<OwnedReadHalf>,
466    mut write_half: OwnedWriteHalf,
467    state: &RoomState,
468) -> anyhow::Result<()> {
469    let mut line = String::new();
470    read_line_limited(&mut reader, &mut line).await?;
471    let trimmed = line.trim();
472    if trimmed.is_empty() {
473        return Ok(());
474    }
475    let session::OneshotResult::Reply(reply) =
476        session::process_oneshot_send(trimmed, &username, state).await?;
477    write_half
478        .write_all(format!("{reply}\n").as_bytes())
479        .await?;
480    Ok(())
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory buffered reader used by every test in this module.
    type MemReader = tokio::io::BufReader<std::io::Cursor<Vec<u8>>>;

    /// Build a buffered async reader over a copy of `bytes`.
    fn reader_over(bytes: &[u8]) -> MemReader {
        tokio::io::BufReader::new(std::io::Cursor::new(bytes.to_vec()))
    }

    // --- read_line_limited tests ---

    #[tokio::test]
    async fn read_line_limited_reads_normal_line() {
        let mut src = reader_over(b"hello world\n");
        let mut line = String::new();
        let read = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(read, 12);
        assert_eq!(line, "hello world\n");
    }

    #[tokio::test]
    async fn read_line_limited_reads_line_without_trailing_newline() {
        let mut src = reader_over(b"no newline");
        let mut line = String::new();
        let read = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(read, 10);
        assert_eq!(line, "no newline");
    }

    #[tokio::test]
    async fn read_line_limited_returns_zero_on_eof() {
        let mut src = reader_over(b"");
        let mut line = String::new();
        let read = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(read, 0);
        assert!(line.is_empty());
    }

    #[tokio::test]
    async fn read_line_limited_rejects_oversized_line() {
        let mut src = reader_over(&vec![b'A'; MAX_LINE_BYTES + 1]);
        let mut line = String::new();
        let outcome = read_line_limited(&mut src, &mut line).await;
        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("exceeds maximum size"),
            "unexpected error: {err_msg}"
        );
    }

    #[tokio::test]
    async fn read_line_limited_accepts_line_at_exact_limit() {
        let mut payload = vec![b'A'; MAX_LINE_BYTES - 1];
        payload.push(b'\n');
        let mut src = reader_over(&payload);
        let mut line = String::new();
        let read = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(read, MAX_LINE_BYTES);
        assert!(line.ends_with('\n'));
    }

    #[tokio::test]
    async fn read_line_limited_reads_multiple_lines() {
        let mut src = reader_over(b"line one\nline two\n");
        let mut line = String::new();

        // Two full lines, read one after the other from the same reader.
        for expected in ["line one\n", "line two\n"] {
            line.clear();
            let read = read_line_limited(&mut src, &mut line).await.unwrap();
            assert_eq!(read, 9);
            assert_eq!(line, expected);
        }

        // Then EOF.
        line.clear();
        let read = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(read, 0);
    }

    #[tokio::test]
    async fn read_line_limited_rejects_invalid_utf8() {
        let mut src = reader_over(&[0xFF, 0xFE, b'\n']);
        let mut line = String::new();
        let outcome = read_line_limited(&mut src, &mut line).await;
        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("invalid UTF-8"),
            "unexpected error: {err_msg}"
        );
    }

    #[tokio::test]
    async fn read_line_limited_exact_limit_no_newline_accepted() {
        // Exactly MAX_LINE_BYTES of data with no trailing newline → EOF returns Ok.
        let mut src = reader_over(&vec![b'X'; MAX_LINE_BYTES]);
        let mut line = String::new();
        let read = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(read, MAX_LINE_BYTES);
        assert_eq!(line.len(), MAX_LINE_BYTES);
    }

    #[tokio::test]
    async fn read_line_limited_just_over_limit_no_newline_rejected() {
        // MAX_LINE_BYTES + 1 bytes without newline → error before EOF.
        let mut src = reader_over(&vec![b'Y'; MAX_LINE_BYTES + 1]);
        let mut line = String::new();
        let outcome = read_line_limited(&mut src, &mut line).await;
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("exceeds maximum"));
    }

    #[tokio::test]
    async fn read_line_limited_appends_to_existing_buffer() {
        // Buffer already has content — read_line_limited appends, does not overwrite.
        let mut src = reader_over(b"world\n");
        let mut line = String::from("hello ");
        let read = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(read, 6);
        assert_eq!(line, "hello world\n");
    }

    #[tokio::test]
    async fn read_line_limited_embedded_null_bytes() {
        // Null bytes are valid UTF-8 — should be accepted.
        let mut src = reader_over(&[b'a', 0x00, b'b', b'\n']);
        let mut line = String::new();
        let read = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(read, 4);
        assert_eq!(line, "a\0b\n");
    }

    #[tokio::test]
    async fn read_line_limited_crlf_line_ending() {
        // CRLF: the \r is part of the line content, \n terminates.
        let mut src = reader_over(b"line\r\n");
        let mut line = String::new();
        let read = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(read, 6);
        assert_eq!(line, "line\r\n");
    }

    #[tokio::test]
    async fn read_line_limited_long_line_with_newline_at_boundary() {
        // Line of MAX_LINE_BYTES - 1 chars + newline = exactly at limit.
        let mut payload = vec![b'Z'; MAX_LINE_BYTES - 1];
        payload.push(b'\n');
        // Add trailing data to verify only one line is consumed.
        payload.extend_from_slice(b"next\n");
        let mut src = reader_over(&payload);

        let mut line = String::new();
        let first = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(first, MAX_LINE_BYTES);
        assert!(line.ends_with('\n'));
        assert_eq!(line.len(), MAX_LINE_BYTES);

        // Second line should be readable independently.
        line.clear();
        let second = read_line_limited(&mut src, &mut line).await.unwrap();
        assert_eq!(second, 5);
        assert_eq!(line, "next\n");
    }
}