// room_cli/broker/mod.rs

1pub(crate) mod admin;
2pub(crate) mod auth;
3pub(crate) mod commands;
4pub mod daemon;
5pub(crate) mod fanout;
6pub(crate) mod handshake;
7pub(crate) mod service;
8pub(crate) mod state;
9pub(crate) mod ws;
10
11use std::{
12    collections::HashMap,
13    path::PathBuf,
14    sync::{
15        atomic::{AtomicU64, Ordering},
16        Arc,
17    },
18};
19
20use crate::{
21    history,
22    message::{make_join, make_leave, make_system, parse_client_line, Message},
23    plugin::{self, PluginRegistry},
24};
25use auth::{handle_oneshot_join, validate_token};
26use commands::{route_command, CommandResult};
27use fanout::{broadcast_and_persist, dm_and_persist};
28use room_protocol::SubscriptionTier;
29use state::RoomState;
30use tokio::{
31    io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader},
32    net::{
33        unix::{OwnedReadHalf, OwnedWriteHalf},
34        UnixListener, UnixStream,
35    },
36    sync::{broadcast, watch, Mutex},
37};
38
/// Upper bound, in bytes, on a single line read from a client connection (64 KiB).
///
/// Enforced by `read_line_limited` so that a misbehaving or malicious client
/// cannot exhaust broker memory by sending an arbitrarily long line.
pub const MAX_LINE_BYTES: usize = 65_536;
42
43/// Read a single newline-terminated line, rejecting lines that exceed `MAX_LINE_BYTES`.
44///
45/// Returns `Ok(n)` where `n` is the number of bytes read (0 = EOF).
46/// Returns an error if the accumulated bytes before a newline exceed the limit.
47///
48/// The line (including the trailing `\n`) is appended to `buf`, matching the
49/// behaviour of `AsyncBufReadExt::read_line`.
50pub(crate) async fn read_line_limited<R: AsyncBufRead + Unpin>(
51    reader: &mut R,
52    buf: &mut String,
53) -> anyhow::Result<usize> {
54    let mut total = 0usize;
55    loop {
56        let available = reader.fill_buf().await?;
57        if available.is_empty() {
58            // EOF
59            return Ok(total);
60        }
61        // Look for a newline in the buffered data.
62        let (chunk, found_newline) = match available.iter().position(|&b| b == b'\n') {
63            Some(pos) => (&available[..=pos], true),
64            None => (available, false),
65        };
66        let chunk_len = chunk.len();
67        if total + chunk_len > MAX_LINE_BYTES {
68            anyhow::bail!("line exceeds maximum size of {} bytes", MAX_LINE_BYTES);
69        }
70        // Safety: we validate UTF-8 before appending.
71        let text = std::str::from_utf8(chunk)
72            .map_err(|e| anyhow::anyhow!("invalid UTF-8 in client line: {e}"))?;
73        buf.push_str(text);
74        total += chunk_len;
75        reader.consume(chunk_len);
76        if found_newline {
77            return Ok(total);
78        }
79    }
80}
81
/// A single-room chat broker.
///
/// Holds everything [`Broker::run`] needs:
/// - the room identity;
/// - the endpoints to listen on;
/// - the on-disk locations of the chat history and of the persisted token and
///   subscription maps.
pub struct Broker {
    /// Identifier of the room served by this broker.
    room_id: String,
    /// Unix socket path that clients connect to.
    socket_path: PathBuf,
    /// TCP port for the WebSocket/REST server; `None` disables that server.
    ws_port: Option<u16>,
    /// Chat history file for the room.
    chat_path: PathBuf,
    /// Path to the persisted token-map file (e.g. `~/.room/state/<room_id>.tokens`).
    token_map_path: PathBuf,
    /// Path to the persisted subscription-map file (e.g. `~/.room/state/<room_id>.subscriptions`).
    subscription_map_path: PathBuf,
}
92
93impl Broker {
94    pub fn new(
95        room_id: &str,
96        chat_path: PathBuf,
97        token_map_path: PathBuf,
98        subscription_map_path: PathBuf,
99        socket_path: PathBuf,
100        ws_port: Option<u16>,
101    ) -> Self {
102        Self {
103            room_id: room_id.to_owned(),
104            chat_path,
105            token_map_path,
106            subscription_map_path,
107            socket_path,
108            ws_port,
109        }
110    }
111
112    pub async fn run(self) -> anyhow::Result<()> {
113        // Remove stale socket synchronously — using tokio::fs here is dangerous
114        // because the blocking thread pool may be shutting down if the broker
115        // is starting up inside a dying process.
116        if self.socket_path.exists() {
117            std::fs::remove_file(&self.socket_path)?;
118        }
119
120        let listener = UnixListener::bind(&self.socket_path)?;
121        eprintln!("[broker] listening on {}", self.socket_path.display());
122
123        let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
124
125        let mut registry = PluginRegistry::new();
126        registry.register(Box::new(plugin::help::HelpPlugin))?;
127        registry.register(Box::new(plugin::stats::StatsPlugin))?;
128        registry.register(Box::new(plugin::status::StatusPlugin))?;
129
130        // Load persisted state from a previous broker session (if any).
131        let persisted_tokens = auth::load_token_map(&self.token_map_path);
132        if !persisted_tokens.is_empty() {
133            eprintln!(
134                "[broker] loaded {} persisted token(s)",
135                persisted_tokens.len()
136            );
137        }
138        let persisted_subs = commands::load_subscription_map(&self.subscription_map_path);
139        if !persisted_subs.is_empty() {
140            eprintln!(
141                "[broker] loaded {} persisted subscription(s)",
142                persisted_subs.len()
143            );
144        }
145
146        let state = Arc::new(RoomState {
147            clients: Arc::new(Mutex::new(HashMap::new())),
148            status_map: Arc::new(Mutex::new(HashMap::new())),
149            host_user: Arc::new(Mutex::new(None)),
150            token_map: Arc::new(Mutex::new(persisted_tokens)),
151            claim_map: Arc::new(Mutex::new(HashMap::new())),
152            subscription_map: Arc::new(Mutex::new(persisted_subs)),
153            chat_path: Arc::new(self.chat_path.clone()),
154            token_map_path: Arc::new(self.token_map_path.clone()),
155            subscription_map_path: Arc::new(self.subscription_map_path.clone()),
156            room_id: Arc::new(self.room_id.clone()),
157            shutdown: Arc::new(shutdown_tx),
158            seq_counter: Arc::new(AtomicU64::new(0)),
159            plugin_registry: Arc::new(registry),
160            config: None,
161            registry: std::sync::OnceLock::new(),
162        });
163        let next_client_id = Arc::new(AtomicU64::new(0));
164
165        // Start WebSocket/REST server if a port was configured.
166        if let Some(port) = self.ws_port {
167            let ws_state = ws::WsAppState {
168                room_state: state.clone(),
169                next_client_id: next_client_id.clone(),
170                user_registry: None,
171            };
172            let app = ws::create_router(ws_state);
173            let tcp = tokio::net::TcpListener::bind(("0.0.0.0", port)).await?;
174            eprintln!("[broker] WebSocket/REST listening on port {port}");
175            tokio::spawn(async move {
176                if let Err(e) = axum::serve(tcp, app).await {
177                    eprintln!("[broker] WS server error: {e}");
178                }
179            });
180        }
181
182        loop {
183            tokio::select! {
184                accept = listener.accept() => {
185                    let (stream, _) = accept?;
186                    let cid = next_client_id.fetch_add(1, Ordering::SeqCst) + 1;
187
188                    let (tx, _) = broadcast::channel::<String>(256);
189                    // Insert with empty username; handle_client updates it after handshake.
190                    state
191                        .clients
192                        .lock()
193                        .await
194                        .insert(cid, (String::new(), tx.clone()));
195
196                    let state_clone = state.clone();
197
198                    tokio::spawn(async move {
199                        if let Err(e) = handle_client(cid, stream, tx, &state_clone).await {
200                            eprintln!("[broker] client {cid} error: {e:#}");
201                        }
202                        state_clone.clients.lock().await.remove(&cid);
203                    });
204                }
205                _ = shutdown_rx.changed() => {
206                    eprintln!("[broker] shutdown requested, exiting");
207                    break Ok(());
208                }
209            }
210        }
211    }
212}
213
214async fn handle_client(
215    cid: u64,
216    stream: UnixStream,
217    own_tx: broadcast::Sender<String>,
218    state: &Arc<RoomState>,
219) -> anyhow::Result<()> {
220    let token_map = state.token_map.clone();
221
222    let (read_half, mut write_half) = stream.into_split();
223    let mut reader = BufReader::new(read_half);
224
225    // First line: username handshake, or one of the one-shot prefixes.
226    let mut first = String::new();
227    read_line_limited(&mut reader, &mut first).await?;
228    let first_line = first.trim();
229
230    use handshake::{parse_client_handshake, ClientHandshake};
231    let username = match parse_client_handshake(first_line) {
232        ClientHandshake::Send(u) => {
233            eprintln!(
234                "[broker] DEPRECATED: SEND:{u} handshake used — \
235                 migrate to TOKEN:<uuid> (SEND: will be removed in a future version)"
236            );
237            return handle_oneshot_send(u, reader, write_half, state).await;
238        }
239        ClientHandshake::Token(token) => {
240            return match validate_token(&token, &token_map).await {
241                Some(u) => handle_oneshot_send(u, reader, write_half, state).await,
242                None => {
243                    let err = serde_json::json!({"type":"error","code":"invalid_token"});
244                    write_half
245                        .write_all(format!("{err}\n").as_bytes())
246                        .await
247                        .map_err(Into::into)
248                }
249            };
250        }
251        ClientHandshake::Join(u) => {
252            let result = handle_oneshot_join(
253                u,
254                write_half,
255                &token_map,
256                &state.subscription_map,
257                state.config.as_ref(),
258                Some(&state.token_map_path),
259            )
260            .await;
261            // Persist auto-subscription from join so it survives broker restart.
262            commands::persist_subscriptions(state).await;
263            return result;
264        }
265        ClientHandshake::Session(token) => {
266            return match validate_token(&token, &token_map).await {
267                Some(u) => {
268                    if let Err(reason) = auth::check_join_permission(&u, state.config.as_ref()) {
269                        let err = serde_json::json!({
270                            "type": "error",
271                            "code": "join_denied",
272                            "message": reason,
273                            "username": u
274                        });
275                        write_half.write_all(format!("{err}\n").as_bytes()).await?;
276                        return Ok(());
277                    }
278                    run_interactive_session(cid, &u, reader, write_half, own_tx, state).await
279                }
280                None => {
281                    let err = serde_json::json!({"type":"error","code":"invalid_token"});
282                    write_half
283                        .write_all(format!("{err}\n").as_bytes())
284                        .await
285                        .map_err(Into::into)
286                }
287            };
288        }
289        ClientHandshake::Interactive(u) => {
290            eprintln!(
291                "[broker] DEPRECATED: unauthenticated interactive join for '{u}' — \
292                 migrate to SESSION:<token> (plain username joins will be removed in a future version)"
293            );
294            u
295        }
296    };
297
298    // Remaining path: deprecated unauthenticated interactive join.
299    if username.is_empty() {
300        return Ok(());
301    }
302
303    // Check join permission before entering interactive session.
304    if let Err(reason) = auth::check_join_permission(&username, state.config.as_ref()) {
305        let err = serde_json::json!({
306            "type": "error",
307            "code": "join_denied",
308            "message": reason,
309            "username": username
310        });
311        write_half.write_all(format!("{err}\n").as_bytes()).await?;
312        return Ok(());
313    }
314
315    run_interactive_session(cid, &username, reader, write_half, own_tx, state).await
316}
317
318/// Run an interactive client session after the username has been determined.
319///
320/// Shared by both single-room (`handle_client`) and daemon (`dispatch_connection`)
321/// paths. Handles: client registration, host election, history replay, join/leave
322/// events, inbound/outbound message loops, and cleanup.
323pub(crate) async fn run_interactive_session(
324    cid: u64,
325    username: &str,
326    reader: BufReader<OwnedReadHalf>,
327    mut write_half: OwnedWriteHalf,
328    own_tx: broadcast::Sender<String>,
329    state: &Arc<RoomState>,
330) -> anyhow::Result<()> {
331    let username = username.to_owned();
332
333    // Register username in the client map
334    {
335        let mut map = state.clients.lock().await;
336        if let Some(entry) = map.get_mut(&cid) {
337            entry.0 = username.clone();
338        }
339    }
340
341    // Register as host if no host has been set yet (first to complete handshake).
342    // Persist the host username to the room meta file so oneshot commands (poll,
343    // pull, query) can apply the same DM visibility rules without a live broker.
344    {
345        let mut host = state.host_user.lock().await;
346        if host.is_none() {
347            *host = Some(username.clone());
348            let meta_path = crate::paths::room_meta_path(&state.room_id);
349            if meta_path.exists() {
350                if let Ok(data) = std::fs::read_to_string(&meta_path) {
351                    if let Ok(mut v) = serde_json::from_str::<serde_json::Value>(&data) {
352                        v["host"] = serde_json::Value::String(username.clone());
353                        let _ = std::fs::write(&meta_path, v.to_string());
354                    }
355                }
356            }
357        }
358    }
359
360    eprintln!("[broker] {username} joined (cid={cid})");
361
362    // Track this user in the status map (empty status by default)
363    state
364        .status_map
365        .lock()
366        .await
367        .insert(username.clone(), String::new());
368
369    // Subscribe before sending history so we don't miss concurrent messages
370    let mut rx = own_tx.subscribe();
371
372    // Send chat history directly to this client's socket, filtering DMs the
373    // client is not party to (sender, recipient, or host).
374    // If the client disconnects mid-replay, treat it as a clean exit.
375    let host_name = state.host_user.lock().await.clone();
376    let history = history::load(&state.chat_path).await.unwrap_or_default();
377    for msg in &history {
378        if msg.is_visible_to(&username, host_name.as_deref()) {
379            let line = format!("{}\n", serde_json::to_string(msg)?);
380            if write_half.write_all(line.as_bytes()).await.is_err() {
381                return Ok(());
382            }
383        }
384    }
385
386    // Broadcast join event (also persists it)
387    let join_msg = make_join(&state.room_id, &username);
388    if let Err(e) = broadcast_and_persist(
389        &join_msg,
390        &state.clients,
391        &state.chat_path,
392        &state.seq_counter,
393    )
394    .await
395    {
396        eprintln!("[broker] broadcast_and_persist(join) failed: {e:#}");
397        return Ok(());
398    }
399    state.plugin_registry.notify_join(&username);
400
401    // Wrap write half in Arc<Mutex> for shared use by outbound and inbound tasks
402    let write_half = Arc::new(Mutex::new(write_half));
403
404    // Outbound: receive from broadcast channel, forward to client socket.
405    // Also listens for the shutdown signal; drains the channel first so the
406    // client sees the shutdown system message before receiving EOF.
407    let write_half_out = write_half.clone();
408    let mut shutdown_rx = state.shutdown.subscribe();
409    let outbound = tokio::spawn(async move {
410        loop {
411            tokio::select! {
412                result = rx.recv() => {
413                    match result {
414                        Ok(line) => {
415                            let mut wh = write_half_out.lock().await;
416                            if wh.write_all(line.as_bytes()).await.is_err() {
417                                break;
418                            }
419                        }
420                        Err(broadcast::error::RecvError::Lagged(n)) => {
421                            eprintln!("[broker] cid={cid} lagged by {n}");
422                        }
423                        Err(broadcast::error::RecvError::Closed) => break,
424                    }
425                }
426                _ = shutdown_rx.changed() => {
427                    // Drain any messages already queued (e.g. the shutdown notice)
428                    // before closing so the client sees them before receiving EOF.
429                    while let Ok(line) = rx.try_recv() {
430                        let mut wh = write_half_out.lock().await;
431                        let _ = wh.write_all(line.as_bytes()).await;
432                    }
433                    // Explicitly shut down the write side to send EOF to the client,
434                    // even though write_half_in (in the inbound task) still holds
435                    // the Arc — without this, the socket stays open.
436                    let _ = write_half_out.lock().await.shutdown().await;
437                    break;
438                }
439            }
440        }
441    });
442
443    // Inbound: read lines from client socket, parse and broadcast
444    let username_in = username.clone();
445    let room_id_in = state.room_id.clone();
446    let write_half_in = write_half.clone();
447    let state_in = state.clone();
448    let inbound = tokio::spawn(async move {
449        let mut reader = reader;
450        let mut line = String::new();
451        loop {
452            line.clear();
453            match read_line_limited(&mut reader, &mut line).await {
454                Ok(0) => break,
455                Ok(_) => {
456                    let trimmed = line.trim();
457                    if trimmed.is_empty() {
458                        continue;
459                    }
460                    match parse_client_line(trimmed, &room_id_in, &username_in) {
461                        Ok(msg) => match route_command(msg, &username_in, &state_in).await {
462                            Ok(CommandResult::Handled | CommandResult::HandledWithReply(_)) => {}
463                            Ok(CommandResult::Reply(json)) => {
464                                let _ = write_half_in
465                                    .lock()
466                                    .await
467                                    .write_all(format!("{json}\n").as_bytes())
468                                    .await;
469                            }
470                            Ok(CommandResult::Shutdown) => break,
471                            Ok(CommandResult::Passthrough(msg)) => {
472                                // DM privacy: reject sends from non-participants
473                                if let Err(reason) = auth::check_send_permission(
474                                    &username_in,
475                                    state_in.config.as_ref(),
476                                ) {
477                                    let err = serde_json::json!({
478                                        "type": "error",
479                                        "code": "send_denied",
480                                        "message": reason
481                                    });
482                                    let _ = write_half_in
483                                        .lock()
484                                        .await
485                                        .write_all(format!("{err}\n").as_bytes())
486                                        .await;
487                                    continue;
488                                }
489                                let is_broadcast = !matches!(&msg, Message::DirectMessage { .. });
490                                // Subscribe @mentioned users BEFORE broadcast so the
491                                // subscription is on disk before the message (#481).
492                                let newly_subscribed = if is_broadcast {
493                                    subscribe_mentioned(&msg, &state_in).await
494                                } else {
495                                    Vec::new()
496                                };
497                                let result = match &msg {
498                                    Message::DirectMessage { .. } => {
499                                        dm_and_persist(
500                                            &msg,
501                                            &state_in.host_user,
502                                            &state_in.clients,
503                                            &state_in.chat_path,
504                                            &state_in.seq_counter,
505                                        )
506                                        .await
507                                    }
508                                    _ => {
509                                        broadcast_and_persist(
510                                            &msg,
511                                            &state_in.clients,
512                                            &state_in.chat_path,
513                                            &state_in.seq_counter,
514                                        )
515                                        .await
516                                    }
517                                };
518                                if let Err(e) = &result {
519                                    eprintln!("[broker] persist error: {e:#}");
520                                }
521                                if !newly_subscribed.is_empty() && result.is_ok() {
522                                    broadcast_subscribe_notices(&newly_subscribed, &state_in).await;
523                                }
524                            }
525                            Err(e) => eprintln!("[broker] route error: {e:#}"),
526                        },
527                        Err(e) => eprintln!("[broker] bad message from {username_in}: {e}"),
528                    }
529                }
530                Err(e) => {
531                    eprintln!("[broker] read error from {username_in}: {e:#}");
532                    let err = serde_json::json!({
533                        "type": "error",
534                        "code": "line_too_long",
535                        "message": format!("{e}")
536                    });
537                    let _ = write_half_in
538                        .lock()
539                        .await
540                        .write_all(format!("{err}\n").as_bytes())
541                        .await;
542                    break;
543                }
544            }
545        }
546    });
547
548    tokio::select! {
549        _ = outbound => {},
550        _ = inbound => {},
551    }
552
553    // Remove user from status map on disconnect
554    state.status_map.lock().await.remove(&username);
555
556    // Broadcast leave event
557    let leave_msg = make_leave(&state.room_id, &username);
558    let _ = broadcast_and_persist(
559        &leave_msg,
560        &state.clients,
561        &state.chat_path,
562        &state.seq_counter,
563    )
564    .await;
565    state.plugin_registry.notify_leave(&username);
566    eprintln!("[broker] {username} left (cid={cid})");
567
568    Ok(())
569}
570
571/// Handle a one-shot SEND connection: read one message line, route it, echo it back, close.
572/// The sender is never registered in ClientMap/StatusMap and generates no join/leave events.
573/// DM envelopes are routed via `dm_and_persist`; all other messages are broadcast.
574pub(crate) async fn handle_oneshot_send(
575    username: String,
576    mut reader: BufReader<OwnedReadHalf>,
577    mut write_half: OwnedWriteHalf,
578    state: &RoomState,
579) -> anyhow::Result<()> {
580    let mut line = String::new();
581    read_line_limited(&mut reader, &mut line).await?;
582    let trimmed = line.trim();
583    if trimmed.is_empty() {
584        return Ok(());
585    }
586    let msg = parse_client_line(trimmed, &state.room_id, &username)?;
587    match route_command(msg, &username, state).await? {
588        CommandResult::Handled | CommandResult::Shutdown => {}
589        CommandResult::HandledWithReply(json) | CommandResult::Reply(json) => {
590            write_half.write_all(format!("{json}\n").as_bytes()).await?;
591        }
592        CommandResult::Passthrough(msg) => {
593            // DM privacy: reject sends from non-participants
594            if let Err(reason) = auth::check_send_permission(&username, state.config.as_ref()) {
595                let err = serde_json::json!({
596                    "type": "error",
597                    "code": "send_denied",
598                    "message": reason
599                });
600                write_half.write_all(format!("{err}\n").as_bytes()).await?;
601                return Ok(());
602            }
603            let is_broadcast = !matches!(&msg, Message::DirectMessage { .. });
604            // Subscribe @mentioned users BEFORE broadcast so the
605            // subscription is on disk before the message (#481).
606            let newly_subscribed = if is_broadcast {
607                subscribe_mentioned(&msg, state).await
608            } else {
609                Vec::new()
610            };
611            let seq_msg = match &msg {
612                Message::DirectMessage { .. } => {
613                    dm_and_persist(
614                        &msg,
615                        &state.host_user,
616                        &state.clients,
617                        &state.chat_path,
618                        &state.seq_counter,
619                    )
620                    .await?
621                }
622                _ => {
623                    broadcast_and_persist(
624                        &msg,
625                        &state.clients,
626                        &state.chat_path,
627                        &state.seq_counter,
628                    )
629                    .await?
630                }
631            };
632            if !newly_subscribed.is_empty() {
633                broadcast_subscribe_notices(&newly_subscribed, state).await;
634            }
635            let echo = format!("{}\n", serde_json::to_string(&seq_msg)?);
636            write_half.write_all(echo.as_bytes()).await?;
637        }
638    }
639    Ok(())
640}
641
642/// Subscribe @mentioned users who are not already subscribed (or are `Unsubscribed`).
643///
644/// Must be called BEFORE `broadcast_and_persist` so that the subscription exists
645/// on disk before the message is persisted to the chat file. This ensures poll-based
646/// room discovery (`discover_joined_rooms`) finds the room before the mention message
647/// is written, closing the race window described in #481.
648///
649/// Returns the list of newly subscribed usernames. Callers should pass this to
650/// [`broadcast_subscribe_notices`] after the message has been broadcast.
651async fn subscribe_mentioned(msg: &Message, state: &RoomState) -> Vec<String> {
652    let mentioned = msg.mentions();
653    if mentioned.is_empty() {
654        return Vec::new();
655    }
656
657    // Collect users to auto-subscribe (brief lock hold).
658    let newly_subscribed = {
659        let token_map = state.token_map.lock().await;
660        let registered: std::collections::HashSet<&str> =
661            token_map.values().map(String::as_str).collect();
662
663        let mut sub_map = state.subscription_map.lock().await;
664        let mut newly = Vec::new();
665
666        for username in &mentioned {
667            if !registered.contains(username.as_str()) {
668                continue;
669            }
670            let dominated = match sub_map.get(username.as_str()) {
671                None | Some(SubscriptionTier::Unsubscribed) => true,
672                Some(_) => false,
673            };
674            if dominated {
675                sub_map.insert(username.clone(), SubscriptionTier::MentionsOnly);
676                newly.push(username.clone());
677            }
678        }
679        newly
680    };
681
682    if !newly_subscribed.is_empty() {
683        // Persist the updated subscription map to disk so that
684        // `discover_joined_rooms` picks up the new room immediately.
685        commands::persist_subscriptions(state).await;
686    }
687
688    newly_subscribed
689}
690
691/// Broadcast system notices for users that were auto-subscribed by [`subscribe_mentioned`].
692///
693/// Call this AFTER the original message has been broadcast so that the notice
694/// appears after the mention in chat history.
695async fn broadcast_subscribe_notices(newly_subscribed: &[String], state: &RoomState) {
696    for username in newly_subscribed {
697        let notice = format!(
698            "{username} auto-subscribed at mentions_only (mentioned in {})",
699            state.room_id
700        );
701        let sys = make_system(&state.room_id, "broker", notice);
702        let _ =
703            broadcast_and_persist(&sys, &state.clients, &state.chat_path, &state.seq_counter).await;
704    }
705}
706
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::make_message;
    use std::collections::HashMap;
    use tokio::sync::watch;

    /// Builds a fresh `RoomState` whose chat history lives at `chat_path`.
    ///
    /// The token and subscription files are derived from `chat_path` with
    /// `Path::with_extension`, so they are created as *siblings* of the chat
    /// file. Callers should place `chat_path` inside a directory they own and
    /// clean up (see `setup`).
    fn make_test_state(chat_path: std::path::PathBuf) -> Arc<RoomState> {
        // The receiver is dropped immediately; these tests never signal shutdown.
        let (shutdown_tx, _) = watch::channel(false);
        // Derive sibling paths before moving `chat_path` into its Arc (no clone).
        let token_map_path = chat_path.with_extension("tokens");
        let subscription_map_path = chat_path.with_extension("subscriptions");
        Arc::new(RoomState {
            clients: Arc::new(Mutex::new(HashMap::new())),
            status_map: Arc::new(Mutex::new(HashMap::new())),
            host_user: Arc::new(Mutex::new(None)),
            token_map: Arc::new(Mutex::new(HashMap::new())),
            claim_map: Arc::new(Mutex::new(HashMap::new())),
            subscription_map: Arc::new(Mutex::new(HashMap::new())),
            chat_path: Arc::new(chat_path),
            token_map_path: Arc::new(token_map_path),
            subscription_map_path: Arc::new(subscription_map_path),
            room_id: Arc::new("test-room".to_owned()),
            shutdown: Arc::new(shutdown_tx),
            seq_counter: Arc::new(AtomicU64::new(0)),
            plugin_registry: Arc::new(PluginRegistry::new()),
            config: None,
            registry: std::sync::OnceLock::new(),
        })
    }

    /// Creates a private temp directory holding an empty chat file and returns
    /// it together with a `RoomState` rooted at that file.
    ///
    /// Keep the returned `TempDir` alive for the whole test (bind it to a named
    /// variable such as `_dir`, not `_`). Dropping it removes the chat file *and*
    /// the sibling `.tokens` / `.subscriptions` files. A bare `NamedTempFile`
    /// only deletes itself, so those siblings would leak into the shared
    /// system temp directory.
    fn setup() -> (tempfile::TempDir, Arc<RoomState>) {
        let dir = tempfile::tempdir().expect("failed to create temp dir");
        let chat_path = dir.path().join("test-room.chat");
        // Create the chat file up front: some tests read it before anything is
        // broadcast and expect it to exist and be empty.
        std::fs::File::create(&chat_path).expect("failed to create chat file");
        let state = make_test_state(chat_path);
        (dir, state)
    }

    /// Inserts a `tok-<user>` -> `user` entry into the token map, making `user`
    /// a registered user.
    async fn register(state: &RoomState, user: &str) {
        state
            .token_map
            .lock()
            .await
            .insert(format!("tok-{user}"), user.to_owned());
    }

    /// Reads the entire chat history file backing `state`.
    fn read_chat(state: &RoomState) -> String {
        std::fs::read_to_string(state.chat_path.as_path()).expect("failed to read chat file")
    }

    #[tokio::test]
    async fn auto_subscribe_skips_unregistered_users() {
        let (_dir, state) = setup();
        // Message mentions @alice but alice has no token — should not auto-subscribe.
        let msg = make_message("test-room", "bob", "hey @alice check this");
        subscribe_mentioned(&msg, &state).await;
        assert!(state.subscription_map.lock().await.is_empty());
    }

    #[tokio::test]
    async fn auto_subscribe_registers_mentions_only_for_unsubscribed() {
        let (_dir, state) = setup();
        register(&state, "alice").await;
        let msg = make_message("test-room", "bob", "hey @alice check this");
        subscribe_mentioned(&msg, &state).await;
        assert_eq!(
            state.subscription_map.lock().await.get("alice"),
            Some(&SubscriptionTier::MentionsOnly)
        );
    }

    #[tokio::test]
    async fn auto_subscribe_skips_already_subscribed_full() {
        let (_dir, state) = setup();
        register(&state, "alice").await;
        state
            .subscription_map
            .lock()
            .await
            .insert("alice".to_owned(), SubscriptionTier::Full);
        let msg = make_message("test-room", "bob", "hey @alice check this");
        subscribe_mentioned(&msg, &state).await;
        // Should remain Full, not downgraded to MentionsOnly.
        assert_eq!(
            state.subscription_map.lock().await.get("alice"),
            Some(&SubscriptionTier::Full)
        );
    }

    #[tokio::test]
    async fn auto_subscribe_skips_already_subscribed_mentions_only() {
        let (_dir, state) = setup();
        register(&state, "alice").await;
        state
            .subscription_map
            .lock()
            .await
            .insert("alice".to_owned(), SubscriptionTier::MentionsOnly);
        let msg = make_message("test-room", "bob", "@alice ping");
        subscribe_mentioned(&msg, &state).await;
        assert_eq!(
            state.subscription_map.lock().await.get("alice"),
            Some(&SubscriptionTier::MentionsOnly)
        );
    }

    #[tokio::test]
    async fn auto_subscribe_upgrades_unsubscribed_to_mentions_only() {
        let (_dir, state) = setup();
        register(&state, "alice").await;
        state
            .subscription_map
            .lock()
            .await
            .insert("alice".to_owned(), SubscriptionTier::Unsubscribed);
        let msg = make_message("test-room", "bob", "@alice come back");
        subscribe_mentioned(&msg, &state).await;
        assert_eq!(
            state.subscription_map.lock().await.get("alice"),
            Some(&SubscriptionTier::MentionsOnly)
        );
    }

    #[tokio::test]
    async fn auto_subscribe_handles_multiple_mentions() {
        let (_dir, state) = setup();
        register(&state, "alice").await;
        register(&state, "carol").await;
        let msg = make_message("test-room", "bob", "@alice @carol @unknown review this");
        subscribe_mentioned(&msg, &state).await;
        let sub_map = state.subscription_map.lock().await;
        assert_eq!(sub_map.get("alice"), Some(&SubscriptionTier::MentionsOnly));
        assert_eq!(sub_map.get("carol"), Some(&SubscriptionTier::MentionsOnly));
        // `contains_key` rather than `get(..).is_none()` (clippy: unnecessary_get_then_check).
        assert!(!sub_map.contains_key("unknown"));
    }

    #[tokio::test]
    async fn auto_subscribe_no_op_for_message_without_mentions() {
        let (_dir, state) = setup();
        register(&state, "alice").await;
        let msg = make_message("test-room", "bob", "hello everyone");
        subscribe_mentioned(&msg, &state).await;
        assert!(state.subscription_map.lock().await.is_empty());
    }

    #[tokio::test]
    async fn auto_subscribe_broadcasts_notice() {
        let (_dir, state) = setup();
        register(&state, "alice").await;
        let msg = make_message("test-room", "bob", "hey @alice");
        let newly = subscribe_mentioned(&msg, &state).await;
        broadcast_subscribe_notices(&newly, &state).await;
        // Verify the auto-subscribe notice was persisted to chat history.
        let history = read_chat(&state);
        assert!(history.contains("auto-subscribed"));
        assert!(history.contains("alice"));
        assert!(history.contains("mentions_only"));
    }

    #[tokio::test]
    async fn auto_subscribe_persists_to_disk() {
        let (_dir, state) = setup();
        register(&state, "alice").await;
        let msg = make_message("test-room", "bob", "hey @alice");
        subscribe_mentioned(&msg, &state).await;
        // Verify subscriptions were persisted to the .subscriptions file.
        let loaded = commands::load_subscription_map(&state.subscription_map_path);
        assert_eq!(loaded.get("alice"), Some(&SubscriptionTier::MentionsOnly));
    }

    /// Regression test for #481: subscription must be persisted to disk BEFORE
    /// the message is written to the chat file. This ensures `discover_joined_rooms`
    /// finds the room before the mention message appears in history.
    #[tokio::test]
    async fn subscribe_mentioned_returns_newly_subscribed_before_broadcast() {
        let (_dir, state) = setup();
        register(&state, "alice").await;
        let msg = make_message("test-room", "bob", "hey @alice check this");

        // Step 1: subscribe_mentioned runs before broadcast — subscription is on disk.
        let newly = subscribe_mentioned(&msg, &state).await;
        assert_eq!(newly, vec!["alice"]);

        // Verify subscription is persisted BEFORE any message is in the chat file.
        let loaded = commands::load_subscription_map(&state.subscription_map_path);
        assert_eq!(loaded.get("alice"), Some(&SubscriptionTier::MentionsOnly));
        // Chat file should still be empty (broadcast hasn't happened yet).
        assert!(
            read_chat(&state).is_empty(),
            "chat file must be empty before broadcast — subscription should precede message"
        );

        // Step 2: broadcast the message (simulating the real flow).
        let seq_msg =
            broadcast_and_persist(&msg, &state.clients, &state.chat_path, &state.seq_counter)
                .await
                .unwrap();
        assert!(seq_msg.seq().is_some());

        // Step 3: broadcast notices after the message.
        broadcast_subscribe_notices(&newly, &state).await;

        // Verify ordering: chat file has the message, then the notice.
        let history = read_chat(&state);
        let lines: Vec<&str> = history.trim().lines().collect();
        assert_eq!(lines.len(), 2, "expected message + notice");
        assert!(lines[0].contains("hey @alice check this"));
        assert!(lines[1].contains("auto-subscribed"));
    }

    // --- read_line_limited tests ---

    /// Wraps an in-memory byte slice in a tokio `BufReader` for feeding
    /// `read_line_limited`.
    fn reader_for(data: &[u8]) -> tokio::io::BufReader<std::io::Cursor<Vec<u8>>> {
        tokio::io::BufReader::new(std::io::Cursor::new(data.to_vec()))
    }

    #[tokio::test]
    async fn read_line_limited_reads_normal_line() {
        let mut reader = reader_for(b"hello world\n");
        let mut buf = String::new();
        let n = read_line_limited(&mut reader, &mut buf).await.unwrap();
        assert_eq!(n, 12);
        assert_eq!(buf, "hello world\n");
    }

    #[tokio::test]
    async fn read_line_limited_reads_line_without_trailing_newline() {
        let mut reader = reader_for(b"no newline");
        let mut buf = String::new();
        let n = read_line_limited(&mut reader, &mut buf).await.unwrap();
        assert_eq!(n, 10);
        assert_eq!(buf, "no newline");
    }

    #[tokio::test]
    async fn read_line_limited_returns_zero_on_eof() {
        let mut reader = reader_for(b"");
        let mut buf = String::new();
        let n = read_line_limited(&mut reader, &mut buf).await.unwrap();
        assert_eq!(n, 0);
        assert!(buf.is_empty());
    }

    #[tokio::test]
    async fn read_line_limited_rejects_oversized_line() {
        // Create a line that exceeds the limit (no newline, so it keeps reading).
        let data = vec![b'A'; MAX_LINE_BYTES + 1];
        let mut reader = reader_for(&data);
        let mut buf = String::new();
        let result = read_line_limited(&mut reader, &mut buf).await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("exceeds maximum size"),
            "unexpected error: {err_msg}"
        );
    }

    #[tokio::test]
    async fn read_line_limited_accepts_line_at_exact_limit() {
        // Line of exactly MAX_LINE_BYTES (including the newline).
        let mut data = vec![b'A'; MAX_LINE_BYTES - 1];
        data.push(b'\n');
        let mut reader = reader_for(&data);
        let mut buf = String::new();
        let n = read_line_limited(&mut reader, &mut buf).await.unwrap();
        assert_eq!(n, MAX_LINE_BYTES);
        assert!(buf.ends_with('\n'));
    }

    #[tokio::test]
    async fn read_line_limited_reads_multiple_lines() {
        let mut reader = reader_for(b"line one\nline two\n");

        let mut buf = String::new();
        let n = read_line_limited(&mut reader, &mut buf).await.unwrap();
        assert_eq!(n, 9);
        assert_eq!(buf, "line one\n");

        buf.clear();
        let n = read_line_limited(&mut reader, &mut buf).await.unwrap();
        assert_eq!(n, 9);
        assert_eq!(buf, "line two\n");

        buf.clear();
        let n = read_line_limited(&mut reader, &mut buf).await.unwrap();
        assert_eq!(n, 0);
    }

    #[tokio::test]
    async fn read_line_limited_rejects_invalid_utf8() {
        let mut reader = reader_for(&[0xFF, 0xFE, b'\n']);
        let mut buf = String::new();
        let result = read_line_limited(&mut reader, &mut buf).await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("invalid UTF-8"),
            "unexpected error: {err_msg}"
        );
    }
}