// room_daemon/broker/mod.rs

1pub(crate) mod admin;
2pub(crate) mod auth;
3pub(crate) mod commands;
4pub mod daemon;
5pub(crate) mod fanout;
6pub(crate) mod handshake;
7pub mod persistence;
8pub(crate) mod service;
9pub(crate) mod session;
10pub(crate) mod state;
11pub(crate) mod token_store;
12pub(crate) mod ws;
13
14use std::{
15    collections::HashMap,
16    path::PathBuf,
17    sync::{
18        atomic::{AtomicU64, Ordering},
19        Arc,
20    },
21};
22
23use crate::plugin::PluginRegistry;
24use auth::{handle_oneshot_join, validate_token};
25use state::RoomState;
26use tokio::{
27    io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader},
28    net::{
29        unix::{OwnedReadHalf, OwnedWriteHalf},
30        UnixListener, UnixStream,
31    },
32    sync::{broadcast, watch, Mutex},
33};
34
/// Upper bound, in bytes, on one line read from a client connection (64 KiB,
/// counting the terminating newline).
///
/// [`read_line_limited`] refuses lines that grow past this size, so a
/// misbehaving or malicious client cannot make the broker buffer an
/// arbitrarily large line in memory.
pub const MAX_LINE_BYTES: usize = 1 << 16;
38
39/// Read a single newline-terminated line, rejecting lines that exceed `MAX_LINE_BYTES`.
40///
41/// Returns `Ok(n)` where `n` is the number of bytes read (0 = EOF).
42/// Returns an error if the accumulated bytes before a newline exceed the limit.
43///
44/// The line (including the trailing `\n`) is appended to `buf`, matching the
45/// behaviour of `AsyncBufReadExt::read_line`.
46pub(crate) async fn read_line_limited<R: AsyncBufRead + Unpin>(
47    reader: &mut R,
48    buf: &mut String,
49) -> anyhow::Result<usize> {
50    let mut total = 0usize;
51    loop {
52        let available = reader.fill_buf().await?;
53        if available.is_empty() {
54            // EOF
55            return Ok(total);
56        }
57        // Look for a newline in the buffered data.
58        let (chunk, found_newline) = match available.iter().position(|&b| b == b'\n') {
59            Some(pos) => (&available[..=pos], true),
60            None => (available, false),
61        };
62        let chunk_len = chunk.len();
63        if total + chunk_len > MAX_LINE_BYTES {
64            anyhow::bail!("line exceeds maximum size of {} bytes", MAX_LINE_BYTES);
65        }
66        // Safety: we validate UTF-8 before appending.
67        let text = std::str::from_utf8(chunk)
68            .map_err(|e| anyhow::anyhow!("invalid UTF-8 in client line: {e}"))?;
69        buf.push_str(text);
70        total += chunk_len;
71        reader.consume(chunk_len);
72        if found_newline {
73            return Ok(total);
74        }
75    }
76}
77
/// Broker for a single room: owns the room's Unix socket (plus an optional
/// WebSocket/REST port) and the on-disk paths its state is loaded from and
/// persisted to.
#[derive(Debug)]
pub struct Broker {
    /// Identifier of the room this broker serves.
    room_id: String,
    /// Room chat file; the history sequence counter and plugins are initialised from it.
    chat_path: PathBuf,
    /// Path to the persisted token-map file (e.g. `~/.room/state/<room_id>.tokens`).
    token_map_path: PathBuf,
    /// Path to the persisted subscription-map file (e.g. `~/.room/state/<room_id>.subscriptions`).
    subscription_map_path: PathBuf,
    /// Unix socket path clients connect to; a stale file here is removed on startup.
    socket_path: PathBuf,
    /// When set, a WebSocket/REST server is also started on this TCP port.
    ws_port: Option<u16>,
}
88
89impl Broker {
90    pub fn new(
91        room_id: &str,
92        chat_path: PathBuf,
93        token_map_path: PathBuf,
94        subscription_map_path: PathBuf,
95        socket_path: PathBuf,
96        ws_port: Option<u16>,
97    ) -> Self {
98        Self {
99            room_id: room_id.to_owned(),
100            chat_path,
101            token_map_path,
102            subscription_map_path,
103            socket_path,
104            ws_port,
105        }
106    }
107
108    pub async fn run(self) -> anyhow::Result<()> {
109        // Remove stale socket synchronously — using tokio::fs here is dangerous
110        // because the blocking thread pool may be shutting down if the broker
111        // is starting up inside a dying process.
112        if self.socket_path.exists() {
113            std::fs::remove_file(&self.socket_path)?;
114        }
115
116        let listener = UnixListener::bind(&self.socket_path)?;
117        eprintln!("[broker] listening on {}", self.socket_path.display());
118
119        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
120
121        let registry = PluginRegistry::with_all_plugins(&self.chat_path)?;
122
123        // Load persisted state from a previous broker session (if any).
124        let persisted_tokens = token_store::load_token_map(&self.token_map_path);
125        if !persisted_tokens.is_empty() {
126            eprintln!(
127                "[broker] loaded {} persisted token(s)",
128                persisted_tokens.len()
129            );
130        }
131        let persisted_subs = persistence::load_subscription_map(&self.subscription_map_path);
132        if !persisted_subs.is_empty() {
133            eprintln!(
134                "[broker] loaded {} persisted subscription(s)",
135                persisted_subs.len()
136            );
137        }
138
139        let state = Arc::new(RoomState {
140            clients: Arc::new(Mutex::new(HashMap::new())),
141            status_map: Arc::new(Mutex::new(HashMap::new())),
142            status_timestamps: Arc::new(Mutex::new(HashMap::new())),
143            last_message_times: Arc::new(Mutex::new(HashMap::new())),
144            host_user: Arc::new(Mutex::new(None)),
145            auth: state::AuthState {
146                token_map: Arc::new(Mutex::new(persisted_tokens)),
147                token_map_path: Arc::new(self.token_map_path.clone()),
148                registry: std::sync::OnceLock::new(),
149            },
150            filters: state::FilterState {
151                subscription_map: Arc::new(Mutex::new(persisted_subs)),
152                subscription_map_path: Arc::new(self.subscription_map_path.clone()),
153                event_filter_state: std::sync::OnceLock::new(),
154            },
155            chat_path: Arc::new(self.chat_path.clone()),
156            room_id: Arc::new(self.room_id.clone()),
157            shutdown: Arc::new(shutdown_tx),
158            seq_counter: Arc::new(AtomicU64::new(crate::history::max_seq_from_history(
159                &self.chat_path,
160            ))),
161            plugin_registry: Arc::new(registry),
162            config: None,
163            cross_room_resolver: std::sync::OnceLock::new(),
164        });
165        // Attach event filter map (parallel to subscription map).
166        {
167            let ef_path = self.subscription_map_path.with_extension("event_filters");
168            let persisted_ef = persistence::load_event_filter_map(&ef_path);
169            if !persisted_ef.is_empty() {
170                eprintln!(
171                    "[broker] loaded {} persisted event filter(s)",
172                    persisted_ef.len()
173                );
174            }
175            state.set_event_filter_map(Arc::new(Mutex::new(persisted_ef)), ef_path);
176        }
177
178        let next_client_id = Arc::new(AtomicU64::new(0));
179
180        // Start WebSocket/REST server if a port was configured.
181        if let Some(port) = self.ws_port {
182            let ws_state = ws::WsAppState {
183                room_state: state.clone(),
184                next_client_id: next_client_id.clone(),
185                user_registry: None,
186            };
187            let app = ws::create_router(ws_state);
188            let tcp = tokio::net::TcpListener::bind(("0.0.0.0", port)).await?;
189            eprintln!("[broker] WebSocket/REST listening on port {port}");
190            tokio::spawn(async move {
191                if let Err(e) = axum::serve(tcp, app).await {
192                    eprintln!("[broker] WS server error: {e}");
193                }
194            });
195        }
196
197        loop {
198            tokio::select! {
199                accept = listener.accept() => {
200                    let (stream, _) = accept?;
201                    let cid = next_client_id.fetch_add(1, Ordering::SeqCst) + 1;
202
203                    let (tx, _) = broadcast::channel::<String>(256);
204                    // Insert with empty username; handle_client updates it after handshake.
205                    state
206                        .clients
207                        .lock()
208                        .await
209                        .insert(cid, (String::new(), tx.clone()));
210
211                    let state_clone = state.clone();
212
213                    tokio::spawn(async move {
214                        if let Err(e) = handle_client(cid, stream, tx, &state_clone).await {
215                            eprintln!("[broker] client {cid} error: {e:#}");
216                        }
217                        state_clone.clients.lock().await.remove(&cid);
218                    });
219                }
220                _ = shutdown_rx.changed() => {
221                    eprintln!("[broker] shutdown requested, exiting");
222                    break Ok(());
223                }
224            }
225        }
226    }
227}
228
229async fn handle_client(
230    cid: u64,
231    stream: UnixStream,
232    own_tx: broadcast::Sender<String>,
233    state: &Arc<RoomState>,
234) -> anyhow::Result<()> {
235    let token_map = state.auth.token_map.clone();
236
237    let (read_half, mut write_half) = stream.into_split();
238    let mut reader = BufReader::new(read_half);
239
240    // First line: username handshake, or one of the one-shot prefixes.
241    let mut first = String::new();
242    read_line_limited(&mut reader, &mut first).await?;
243    let first_line = first.trim();
244
245    use handshake::{parse_client_handshake, ClientHandshake};
246    let username = match parse_client_handshake(first_line) {
247        ClientHandshake::Send(u) => {
248            eprintln!(
249                "[broker] DEPRECATED: SEND:{u} handshake used — \
250                 migrate to TOKEN:<uuid> (SEND: will be removed in a future version)"
251            );
252            return handle_oneshot_send(u, reader, write_half, state).await;
253        }
254        ClientHandshake::Token(token) => {
255            return match validate_token(&token, &token_map).await {
256                Some(u) => handle_oneshot_send(u, reader, write_half, state).await,
257                None => {
258                    let err = serde_json::json!({"type":"error","code":"invalid_token"});
259                    write_half
260                        .write_all(format!("{err}\n").as_bytes())
261                        .await
262                        .map_err(Into::into)
263                }
264            };
265        }
266        ClientHandshake::Join(u) => {
267            let result = handle_oneshot_join(
268                u,
269                write_half,
270                &token_map,
271                &state.filters.subscription_map,
272                state.config.as_ref(),
273                Some(&state.auth.token_map_path),
274            )
275            .await;
276            // Persist auto-subscription from join so it survives broker restart.
277            persistence::persist_subscriptions(state).await;
278            return result;
279        }
280        ClientHandshake::Session(token) => {
281            return match validate_token(&token, &token_map).await {
282                Some(u) => {
283                    if let Err(reason) = auth::check_join_permission(&u, state.config.as_ref()) {
284                        let err = serde_json::json!({
285                            "type": "error",
286                            "code": "join_denied",
287                            "message": reason,
288                            "username": u
289                        });
290                        write_half.write_all(format!("{err}\n").as_bytes()).await?;
291                        return Ok(());
292                    }
293                    run_interactive_session(cid, &u, reader, write_half, own_tx, state).await
294                }
295                None => {
296                    let err = serde_json::json!({"type":"error","code":"invalid_token"});
297                    write_half
298                        .write_all(format!("{err}\n").as_bytes())
299                        .await
300                        .map_err(Into::into)
301                }
302            };
303        }
304        ClientHandshake::Interactive(u) => {
305            eprintln!(
306                "[broker] DEPRECATED: unauthenticated interactive join for '{u}' — \
307                 migrate to SESSION:<token> (plain username joins will be removed in a future version)"
308            );
309            u
310        }
311    };
312
313    // Remaining path: deprecated unauthenticated interactive join.
314    if username.is_empty() {
315        return Ok(());
316    }
317
318    // Check join permission before entering interactive session.
319    if let Err(reason) = auth::check_join_permission(&username, state.config.as_ref()) {
320        let err = serde_json::json!({
321            "type": "error",
322            "code": "join_denied",
323            "message": reason,
324            "username": username
325        });
326        write_half.write_all(format!("{err}\n").as_bytes()).await?;
327        return Ok(());
328    }
329
330    run_interactive_session(cid, &username, reader, write_half, own_tx, state).await
331}
332
333/// Run an interactive client session after the username has been determined.
334///
335/// Shared by both single-room (`handle_client`) and daemon (`dispatch_connection`)
336/// paths. Delegates setup, message processing, and teardown to
337/// [`session`](super::session) — this function only handles UDS-specific I/O
338/// (reading lines, writing bytes, shutdown signaling).
339pub(crate) async fn run_interactive_session(
340    cid: u64,
341    username: &str,
342    reader: BufReader<OwnedReadHalf>,
343    mut write_half: OwnedWriteHalf,
344    own_tx: broadcast::Sender<String>,
345    state: &Arc<RoomState>,
346) -> anyhow::Result<()> {
347    let username = username.to_owned();
348
349    // Subscribe before setup so we don't miss concurrent messages.
350    let mut rx = own_tx.subscribe();
351
352    // Shared setup: register client, elect host, load history, broadcast join.
353    let history_lines = match session::session_setup(cid, &username, state).await {
354        Ok(lines) => lines,
355        Err(e) => {
356            eprintln!("[broker] session_setup failed: {e:#}");
357            return Ok(());
358        }
359    };
360
361    // Send history to client over UDS.
362    for line in &history_lines {
363        if write_half
364            .write_all(format!("{line}\n").as_bytes())
365            .await
366            .is_err()
367        {
368            return Ok(());
369        }
370    }
371
372    // Wrap write half in Arc<Mutex> for shared use by outbound and inbound tasks.
373    let write_half = Arc::new(Mutex::new(write_half));
374
375    // Outbound: receive from broadcast channel, forward to client socket.
376    let write_half_out = write_half.clone();
377    let mut shutdown_rx = state.shutdown.subscribe();
378    let outbound = tokio::spawn(async move {
379        loop {
380            tokio::select! {
381                result = rx.recv() => {
382                    match result {
383                        Ok(line) => {
384                            let mut wh = write_half_out.lock().await;
385                            if wh.write_all(line.as_bytes()).await.is_err() {
386                                break;
387                            }
388                        }
389                        Err(broadcast::error::RecvError::Lagged(n)) => {
390                            eprintln!("[broker] cid={cid} lagged by {n}");
391                        }
392                        Err(broadcast::error::RecvError::Closed) => break,
393                    }
394                }
395                _ = shutdown_rx.changed() => {
396                    while let Ok(line) = rx.try_recv() {
397                        let mut wh = write_half_out.lock().await;
398                        let _ = wh.write_all(line.as_bytes()).await;
399                    }
400                    let _ = write_half_out.lock().await.shutdown().await;
401                    break;
402                }
403            }
404        }
405    });
406
407    // Inbound: read lines from client socket, delegate to shared processing.
408    let username_in = username.clone();
409    let write_half_in = write_half.clone();
410    let state_in = state.clone();
411    let inbound = tokio::spawn(async move {
412        let mut reader = reader;
413        let mut line = String::new();
414        loop {
415            line.clear();
416            match read_line_limited(&mut reader, &mut line).await {
417                Ok(0) => break,
418                Ok(_) => {
419                    let trimmed = line.trim();
420                    if trimmed.is_empty() {
421                        continue;
422                    }
423                    match session::process_inbound_message(trimmed, &username_in, &state_in).await {
424                        session::InboundResult::Ok => {}
425                        session::InboundResult::Reply(json) => {
426                            let _ = write_half_in
427                                .lock()
428                                .await
429                                .write_all(format!("{json}\n").as_bytes())
430                                .await;
431                        }
432                        session::InboundResult::Shutdown => break,
433                    }
434                }
435                Err(e) => {
436                    eprintln!("[broker] read error from {username_in}: {e:#}");
437                    let err = serde_json::json!({
438                        "type": "error",
439                        "code": "line_too_long",
440                        "message": format!("{e}")
441                    });
442                    let _ = write_half_in
443                        .lock()
444                        .await
445                        .write_all(format!("{err}\n").as_bytes())
446                        .await;
447                    break;
448                }
449            }
450        }
451    });
452
453    tokio::select! {
454        _ = outbound => {},
455        _ = inbound => {},
456    }
457
458    // Shared teardown: remove status, broadcast leave.
459    session::session_teardown(cid, &username, state).await;
460
461    Ok(())
462}
463
464/// Handle a one-shot SEND connection: read one message line, route it, echo it back, close.
465/// The sender is never registered in ClientMap/StatusMap and generates no join/leave events.
466/// DM envelopes are routed via `dm_and_persist`; all other messages are broadcast.
467pub(crate) async fn handle_oneshot_send(
468    username: String,
469    mut reader: BufReader<OwnedReadHalf>,
470    mut write_half: OwnedWriteHalf,
471    state: &RoomState,
472) -> anyhow::Result<()> {
473    let mut line = String::new();
474    read_line_limited(&mut reader, &mut line).await?;
475    let trimmed = line.trim();
476    if trimmed.is_empty() {
477        return Ok(());
478    }
479    let session::OneshotResult::Reply(reply) =
480        session::process_oneshot_send(trimmed, &username, state).await?;
481    write_half
482        .write_all(format!("{reply}\n").as_bytes())
483        .await?;
484    Ok(())
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap `bytes` in the same kind of buffered async reader the broker reads from.
    fn reader_over(bytes: &[u8]) -> tokio::io::BufReader<std::io::Cursor<Vec<u8>>> {
        tokio::io::BufReader::new(std::io::Cursor::new(bytes.to_vec()))
    }

    /// Read one line into a fresh `String`, returning `(bytes_read, line)`.
    async fn read_one<R: AsyncBufRead + Unpin>(
        reader: &mut R,
    ) -> anyhow::Result<(usize, String)> {
        let mut line = String::new();
        let n = read_line_limited(reader, &mut line).await?;
        Ok((n, line))
    }

    // --- read_line_limited tests ---

    #[tokio::test]
    async fn read_line_limited_reads_normal_line() {
        let (n, line) = read_one(&mut reader_over(b"hello world\n")).await.unwrap();
        assert_eq!(n, 12);
        assert_eq!(line, "hello world\n");
    }

    #[tokio::test]
    async fn read_line_limited_reads_line_without_trailing_newline() {
        let (n, line) = read_one(&mut reader_over(b"no newline")).await.unwrap();
        assert_eq!(n, 10);
        assert_eq!(line, "no newline");
    }

    #[tokio::test]
    async fn read_line_limited_returns_zero_on_eof() {
        let (n, line) = read_one(&mut reader_over(b"")).await.unwrap();
        assert_eq!(n, 0);
        assert!(line.is_empty());
    }

    #[tokio::test]
    async fn read_line_limited_rejects_oversized_line() {
        let oversized = vec![b'A'; MAX_LINE_BYTES + 1];
        let err_msg = read_one(&mut reader_over(&oversized))
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("exceeds maximum size"),
            "unexpected error: {err_msg}"
        );
    }

    #[tokio::test]
    async fn read_line_limited_accepts_line_at_exact_limit() {
        let mut at_limit = vec![b'A'; MAX_LINE_BYTES - 1];
        at_limit.push(b'\n');
        let (n, line) = read_one(&mut reader_over(&at_limit)).await.unwrap();
        assert_eq!(n, MAX_LINE_BYTES);
        assert!(line.ends_with('\n'));
    }

    #[tokio::test]
    async fn read_line_limited_reads_multiple_lines() {
        let mut reader = reader_over(b"line one\nline two\n");

        let (n, line) = read_one(&mut reader).await.unwrap();
        assert_eq!(n, 9);
        assert_eq!(line, "line one\n");

        let (n, line) = read_one(&mut reader).await.unwrap();
        assert_eq!(n, 9);
        assert_eq!(line, "line two\n");

        let (n, _) = read_one(&mut reader).await.unwrap();
        assert_eq!(n, 0);
    }

    #[tokio::test]
    async fn read_line_limited_rejects_invalid_utf8() {
        let err_msg = read_one(&mut reader_over(&[0xFF, 0xFE, b'\n']))
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("invalid UTF-8"),
            "unexpected error: {err_msg}"
        );
    }

    #[tokio::test]
    async fn read_line_limited_exact_limit_no_newline_accepted() {
        // Exactly MAX_LINE_BYTES of data with no trailing newline → EOF returns Ok.
        let payload = vec![b'X'; MAX_LINE_BYTES];
        let (n, line) = read_one(&mut reader_over(&payload)).await.unwrap();
        assert_eq!(n, MAX_LINE_BYTES);
        assert_eq!(line.len(), MAX_LINE_BYTES);
    }

    #[tokio::test]
    async fn read_line_limited_just_over_limit_no_newline_rejected() {
        // MAX_LINE_BYTES + 1 bytes without newline → error before EOF.
        let payload = vec![b'Y'; MAX_LINE_BYTES + 1];
        let err = read_one(&mut reader_over(&payload)).await.unwrap_err();
        assert!(err.to_string().contains("exceeds maximum"));
    }

    #[tokio::test]
    async fn read_line_limited_appends_to_existing_buffer() {
        // Buffer already has content — read_line_limited appends, does not overwrite.
        let mut reader = reader_over(b"world\n");
        let mut buf = String::from("hello ");
        let n = read_line_limited(&mut reader, &mut buf).await.unwrap();
        assert_eq!(n, 6);
        assert_eq!(buf, "hello world\n");
    }

    #[tokio::test]
    async fn read_line_limited_embedded_null_bytes() {
        // Null bytes are valid UTF-8 — should be accepted.
        let (n, line) = read_one(&mut reader_over(&[b'a', 0x00, b'b', b'\n']))
            .await
            .unwrap();
        assert_eq!(n, 4);
        assert_eq!(line, "a\0b\n");
    }

    #[tokio::test]
    async fn read_line_limited_crlf_line_ending() {
        // CRLF: the \r is part of the line content, \n terminates.
        let (n, line) = read_one(&mut reader_over(b"line\r\n")).await.unwrap();
        assert_eq!(n, 6);
        assert_eq!(line, "line\r\n");
    }

    #[tokio::test]
    async fn read_line_limited_long_line_with_newline_at_boundary() {
        // Line of MAX_LINE_BYTES - 1 chars + newline = exactly at limit,
        // followed by a second line to verify only one line is consumed.
        let mut payload = vec![b'Z'; MAX_LINE_BYTES - 1];
        payload.push(b'\n');
        payload.extend_from_slice(b"next\n");
        let mut reader = reader_over(&payload);

        let (n, line) = read_one(&mut reader).await.unwrap();
        assert_eq!(n, MAX_LINE_BYTES);
        assert!(line.ends_with('\n'));
        assert_eq!(line.len(), MAX_LINE_BYTES);

        // Second line should be readable independently.
        let (n2, next) = read_one(&mut reader).await.unwrap();
        assert_eq!(n2, 5);
        assert_eq!(next, "next\n");
    }
}