1pub(crate) mod admin;
2pub(crate) mod auth;
3pub(crate) mod commands;
4pub mod daemon;
5pub(crate) mod fanout;
6pub(crate) mod handshake;
7pub mod persistence;
8pub(crate) mod service;
9pub(crate) mod session;
10pub(crate) mod state;
11pub(crate) mod token_store;
12pub(crate) mod ws;
13
14use std::{
15 collections::HashMap,
16 path::PathBuf,
17 sync::{
18 atomic::{AtomicU64, Ordering},
19 Arc,
20 },
21};
22
23use crate::plugin::PluginRegistry;
24use auth::{handle_oneshot_join, validate_token};
25use state::RoomState;
26use tokio::{
27 io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader},
28 net::{
29 unix::{OwnedReadHalf, OwnedWriteHalf},
30 UnixListener, UnixStream,
31 },
32 sync::{broadcast, watch, Mutex},
33};
34
35pub const MAX_LINE_BYTES: usize = 64 * 1024; pub(crate) async fn read_line_limited<R: AsyncBufRead + Unpin>(
47 reader: &mut R,
48 buf: &mut String,
49) -> anyhow::Result<usize> {
50 let mut total = 0usize;
51 loop {
52 let available = reader.fill_buf().await?;
53 if available.is_empty() {
54 return Ok(total);
56 }
57 let (chunk, found_newline) = match available.iter().position(|&b| b == b'\n') {
59 Some(pos) => (&available[..=pos], true),
60 None => (available, false),
61 };
62 let chunk_len = chunk.len();
63 if total + chunk_len > MAX_LINE_BYTES {
64 anyhow::bail!("line exceeds maximum size of {} bytes", MAX_LINE_BYTES);
65 }
66 let text = std::str::from_utf8(chunk)
68 .map_err(|e| anyhow::anyhow!("invalid UTF-8 in client line: {e}"))?;
69 buf.push_str(text);
70 total += chunk_len;
71 reader.consume(chunk_len);
72 if found_newline {
73 return Ok(total);
74 }
75 }
76}
77
/// Settings for a single-room broker; consumed by [`Broker::run`].
#[derive(Debug)]
pub struct Broker {
    /// Identifier of the room, copied into the shared room state.
    room_id: String,
    /// Chat history path; handed to the plugin registry and used to seed the
    /// message sequence counter.
    chat_path: PathBuf,
    /// File from which persisted tokens are loaded at start-up.
    token_map_path: PathBuf,
    /// File from which persisted subscriptions are loaded at start-up. The
    /// event-filter file is this path with its extension replaced by
    /// `event_filters`.
    subscription_map_path: PathBuf,
    /// Filesystem path of the Unix socket the broker listens on.
    socket_path: PathBuf,
    /// When set, a WebSocket/REST listener is also started on this TCP port.
    ws_port: Option<u16>,
}
88
89impl Broker {
90 pub fn new(
91 room_id: &str,
92 chat_path: PathBuf,
93 token_map_path: PathBuf,
94 subscription_map_path: PathBuf,
95 socket_path: PathBuf,
96 ws_port: Option<u16>,
97 ) -> Self {
98 Self {
99 room_id: room_id.to_owned(),
100 chat_path,
101 token_map_path,
102 subscription_map_path,
103 socket_path,
104 ws_port,
105 }
106 }
107
108 pub async fn run(self) -> anyhow::Result<()> {
109 if self.socket_path.exists() {
113 std::fs::remove_file(&self.socket_path)?;
114 }
115
116 let listener = UnixListener::bind(&self.socket_path)?;
117 eprintln!("[broker] listening on {}", self.socket_path.display());
118
119 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
120
121 let registry = PluginRegistry::with_all_plugins(&self.chat_path)?;
122
123 let persisted_tokens = token_store::load_token_map(&self.token_map_path);
125 if !persisted_tokens.is_empty() {
126 eprintln!(
127 "[broker] loaded {} persisted token(s)",
128 persisted_tokens.len()
129 );
130 }
131 let persisted_subs = persistence::load_subscription_map(&self.subscription_map_path);
132 if !persisted_subs.is_empty() {
133 eprintln!(
134 "[broker] loaded {} persisted subscription(s)",
135 persisted_subs.len()
136 );
137 }
138
139 let state = Arc::new(RoomState {
140 clients: Arc::new(Mutex::new(HashMap::new())),
141 status_map: Arc::new(Mutex::new(HashMap::new())),
142 status_timestamps: Arc::new(Mutex::new(HashMap::new())),
143 last_message_times: Arc::new(Mutex::new(HashMap::new())),
144 host_user: Arc::new(Mutex::new(None)),
145 auth: state::AuthState {
146 token_map: Arc::new(Mutex::new(persisted_tokens)),
147 token_map_path: Arc::new(self.token_map_path.clone()),
148 registry: std::sync::OnceLock::new(),
149 },
150 filters: state::FilterState {
151 subscription_map: Arc::new(Mutex::new(persisted_subs)),
152 subscription_map_path: Arc::new(self.subscription_map_path.clone()),
153 event_filter_state: std::sync::OnceLock::new(),
154 },
155 chat_path: Arc::new(self.chat_path.clone()),
156 room_id: Arc::new(self.room_id.clone()),
157 shutdown: Arc::new(shutdown_tx),
158 seq_counter: Arc::new(AtomicU64::new(crate::history::max_seq_from_history(
159 &self.chat_path,
160 ))),
161 plugin_registry: Arc::new(registry),
162 config: None,
163 cross_room_resolver: std::sync::OnceLock::new(),
164 });
165 {
167 let ef_path = self.subscription_map_path.with_extension("event_filters");
168 let persisted_ef = persistence::load_event_filter_map(&ef_path);
169 if !persisted_ef.is_empty() {
170 eprintln!(
171 "[broker] loaded {} persisted event filter(s)",
172 persisted_ef.len()
173 );
174 }
175 state.set_event_filter_map(Arc::new(Mutex::new(persisted_ef)), ef_path);
176 }
177
178 let next_client_id = Arc::new(AtomicU64::new(0));
179
180 if let Some(port) = self.ws_port {
182 let ws_state = ws::WsAppState {
183 room_state: state.clone(),
184 next_client_id: next_client_id.clone(),
185 user_registry: None,
186 };
187 let app = ws::create_router(ws_state);
188 let tcp = tokio::net::TcpListener::bind(("0.0.0.0", port)).await?;
189 eprintln!("[broker] WebSocket/REST listening on port {port}");
190 tokio::spawn(async move {
191 if let Err(e) = axum::serve(tcp, app).await {
192 eprintln!("[broker] WS server error: {e}");
193 }
194 });
195 }
196
197 loop {
198 tokio::select! {
199 accept = listener.accept() => {
200 let (stream, _) = accept?;
201 let cid = next_client_id.fetch_add(1, Ordering::SeqCst) + 1;
202
203 let (tx, _) = broadcast::channel::<String>(256);
204 state
206 .clients
207 .lock()
208 .await
209 .insert(cid, (String::new(), tx.clone()));
210
211 let state_clone = state.clone();
212
213 tokio::spawn(async move {
214 if let Err(e) = handle_client(cid, stream, tx, &state_clone).await {
215 eprintln!("[broker] client {cid} error: {e:#}");
216 }
217 state_clone.clients.lock().await.remove(&cid);
218 });
219 }
220 _ = shutdown_rx.changed() => {
221 eprintln!("[broker] shutdown requested, exiting");
222 break Ok(());
223 }
224 }
225 }
226 }
227}
228
229async fn handle_client(
230 cid: u64,
231 stream: UnixStream,
232 own_tx: broadcast::Sender<String>,
233 state: &Arc<RoomState>,
234) -> anyhow::Result<()> {
235 let token_map = state.auth.token_map.clone();
236
237 let (read_half, mut write_half) = stream.into_split();
238 let mut reader = BufReader::new(read_half);
239
240 let mut first = String::new();
242 read_line_limited(&mut reader, &mut first).await?;
243 let first_line = first.trim();
244
245 use handshake::{parse_client_handshake, ClientHandshake};
246 let username = match parse_client_handshake(first_line) {
247 ClientHandshake::Send(u) => {
248 eprintln!(
249 "[broker] DEPRECATED: SEND:{u} handshake used — \
250 migrate to TOKEN:<uuid> (SEND: will be removed in a future version)"
251 );
252 return handle_oneshot_send(u, reader, write_half, state).await;
253 }
254 ClientHandshake::Token(token) => {
255 return match validate_token(&token, &token_map).await {
256 Some(u) => handle_oneshot_send(u, reader, write_half, state).await,
257 None => {
258 let err = serde_json::json!({"type":"error","code":"invalid_token"});
259 write_half
260 .write_all(format!("{err}\n").as_bytes())
261 .await
262 .map_err(Into::into)
263 }
264 };
265 }
266 ClientHandshake::Join(u) => {
267 let result = handle_oneshot_join(
268 u,
269 write_half,
270 &token_map,
271 &state.filters.subscription_map,
272 state.config.as_ref(),
273 Some(&state.auth.token_map_path),
274 )
275 .await;
276 persistence::persist_subscriptions(state).await;
278 return result;
279 }
280 ClientHandshake::Session(token) => {
281 return match validate_token(&token, &token_map).await {
282 Some(u) => {
283 if let Err(reason) = auth::check_join_permission(&u, state.config.as_ref()) {
284 let err = serde_json::json!({
285 "type": "error",
286 "code": "join_denied",
287 "message": reason,
288 "username": u
289 });
290 write_half.write_all(format!("{err}\n").as_bytes()).await?;
291 return Ok(());
292 }
293 run_interactive_session(cid, &u, reader, write_half, own_tx, state).await
294 }
295 None => {
296 let err = serde_json::json!({"type":"error","code":"invalid_token"});
297 write_half
298 .write_all(format!("{err}\n").as_bytes())
299 .await
300 .map_err(Into::into)
301 }
302 };
303 }
304 ClientHandshake::Interactive(u) => {
305 eprintln!(
306 "[broker] DEPRECATED: unauthenticated interactive join for '{u}' — \
307 migrate to SESSION:<token> (plain username joins will be removed in a future version)"
308 );
309 u
310 }
311 };
312
313 if username.is_empty() {
315 return Ok(());
316 }
317
318 if let Err(reason) = auth::check_join_permission(&username, state.config.as_ref()) {
320 let err = serde_json::json!({
321 "type": "error",
322 "code": "join_denied",
323 "message": reason,
324 "username": username
325 });
326 write_half.write_all(format!("{err}\n").as_bytes()).await?;
327 return Ok(());
328 }
329
330 run_interactive_session(cid, &username, reader, write_half, own_tx, state).await
331}
332
333pub(crate) async fn run_interactive_session(
340 cid: u64,
341 username: &str,
342 reader: BufReader<OwnedReadHalf>,
343 mut write_half: OwnedWriteHalf,
344 own_tx: broadcast::Sender<String>,
345 state: &Arc<RoomState>,
346) -> anyhow::Result<()> {
347 let username = username.to_owned();
348
349 let mut rx = own_tx.subscribe();
351
352 let history_lines = match session::session_setup(cid, &username, state).await {
354 Ok(lines) => lines,
355 Err(e) => {
356 eprintln!("[broker] session_setup failed: {e:#}");
357 return Ok(());
358 }
359 };
360
361 for line in &history_lines {
363 if write_half
364 .write_all(format!("{line}\n").as_bytes())
365 .await
366 .is_err()
367 {
368 return Ok(());
369 }
370 }
371
372 let write_half = Arc::new(Mutex::new(write_half));
374
375 let write_half_out = write_half.clone();
377 let mut shutdown_rx = state.shutdown.subscribe();
378 let outbound = tokio::spawn(async move {
379 loop {
380 tokio::select! {
381 result = rx.recv() => {
382 match result {
383 Ok(line) => {
384 let mut wh = write_half_out.lock().await;
385 if wh.write_all(line.as_bytes()).await.is_err() {
386 break;
387 }
388 }
389 Err(broadcast::error::RecvError::Lagged(n)) => {
390 eprintln!("[broker] cid={cid} lagged by {n}");
391 }
392 Err(broadcast::error::RecvError::Closed) => break,
393 }
394 }
395 _ = shutdown_rx.changed() => {
396 while let Ok(line) = rx.try_recv() {
397 let mut wh = write_half_out.lock().await;
398 let _ = wh.write_all(line.as_bytes()).await;
399 }
400 let _ = write_half_out.lock().await.shutdown().await;
401 break;
402 }
403 }
404 }
405 });
406
407 let username_in = username.clone();
409 let write_half_in = write_half.clone();
410 let state_in = state.clone();
411 let inbound = tokio::spawn(async move {
412 let mut reader = reader;
413 let mut line = String::new();
414 loop {
415 line.clear();
416 match read_line_limited(&mut reader, &mut line).await {
417 Ok(0) => break,
418 Ok(_) => {
419 let trimmed = line.trim();
420 if trimmed.is_empty() {
421 continue;
422 }
423 match session::process_inbound_message(trimmed, &username_in, &state_in).await {
424 session::InboundResult::Ok => {}
425 session::InboundResult::Reply(json) => {
426 let _ = write_half_in
427 .lock()
428 .await
429 .write_all(format!("{json}\n").as_bytes())
430 .await;
431 }
432 session::InboundResult::Shutdown => break,
433 }
434 }
435 Err(e) => {
436 eprintln!("[broker] read error from {username_in}: {e:#}");
437 let err = serde_json::json!({
438 "type": "error",
439 "code": "line_too_long",
440 "message": format!("{e}")
441 });
442 let _ = write_half_in
443 .lock()
444 .await
445 .write_all(format!("{err}\n").as_bytes())
446 .await;
447 break;
448 }
449 }
450 }
451 });
452
453 tokio::select! {
454 _ = outbound => {},
455 _ = inbound => {},
456 }
457
458 session::session_teardown(cid, &username, state).await;
460
461 Ok(())
462}
463
464pub(crate) async fn handle_oneshot_send(
468 username: String,
469 mut reader: BufReader<OwnedReadHalf>,
470 mut write_half: OwnedWriteHalf,
471 state: &RoomState,
472) -> anyhow::Result<()> {
473 let mut line = String::new();
474 read_line_limited(&mut reader, &mut line).await?;
475 let trimmed = line.trim();
476 if trimmed.is_empty() {
477 return Ok(());
478 }
479 let session::OneshotResult::Reply(reply) =
480 session::process_oneshot_send(trimmed, &username, state).await?;
481 write_half
482 .write_all(format!("{reply}\n").as_bytes())
483 .await?;
484 Ok(())
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Buffered async reader over an in-memory byte vector.
    type MemReader = tokio::io::BufReader<std::io::Cursor<Vec<u8>>>;

    /// Wraps `bytes` in the reader type that `read_line_limited` is tested with.
    fn reader_over(bytes: Vec<u8>) -> MemReader {
        tokio::io::BufReader::new(std::io::Cursor::new(bytes))
    }

    /// A newline-terminated line is returned whole, newline included.
    #[tokio::test]
    async fn read_line_limited_reads_normal_line() {
        let mut reader = reader_over(b"hello world\n".to_vec());
        let mut line = String::new();
        let read = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(read, 12);
        assert_eq!(line, "hello world\n");
    }

    /// End of stream without a newline still yields the partial line.
    #[tokio::test]
    async fn read_line_limited_reads_line_without_trailing_newline() {
        let mut reader = reader_over(b"no newline".to_vec());
        let mut line = String::new();
        let read = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(read, 10);
        assert_eq!(line, "no newline");
    }

    /// An empty stream reads as zero bytes and leaves the buffer empty.
    #[tokio::test]
    async fn read_line_limited_returns_zero_on_eof() {
        let mut reader = reader_over(Vec::new());
        let mut line = String::new();
        let read = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(read, 0);
        assert!(line.is_empty());
    }

    /// One byte over the limit is rejected with the size error.
    #[tokio::test]
    async fn read_line_limited_rejects_oversized_line() {
        let mut reader = reader_over(vec![b'A'; MAX_LINE_BYTES + 1]);
        let mut line = String::new();
        let err_msg = read_line_limited(&mut reader, &mut line)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("exceeds maximum size"),
            "unexpected error: {err_msg}"
        );
    }

    /// A line of exactly the limit, newline included, is accepted.
    #[tokio::test]
    async fn read_line_limited_accepts_line_at_exact_limit() {
        let mut bytes = vec![b'A'; MAX_LINE_BYTES - 1];
        bytes.push(b'\n');
        let mut reader = reader_over(bytes);
        let mut line = String::new();
        let read = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(read, MAX_LINE_BYTES);
        assert!(line.ends_with('\n'));
    }

    /// Successive calls return successive lines, then zero at end of stream.
    #[tokio::test]
    async fn read_line_limited_reads_multiple_lines() {
        let mut reader = reader_over(b"line one\nline two\n".to_vec());
        let mut line = String::new();

        let first = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(first, 9);
        assert_eq!(line, "line one\n");

        line.clear();
        let second = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(second, 9);
        assert_eq!(line, "line two\n");

        line.clear();
        let third = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(third, 0);
    }

    /// Bytes that are not valid UTF-8 produce the UTF-8 error.
    #[tokio::test]
    async fn read_line_limited_rejects_invalid_utf8() {
        let mut reader = reader_over(vec![0xFF, 0xFE, b'\n']);
        let mut line = String::new();
        let err_msg = read_line_limited(&mut reader, &mut line)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err_msg.contains("invalid UTF-8"),
            "unexpected error: {err_msg}"
        );
    }

    /// Exactly the limit with no newline before end of stream is accepted.
    #[tokio::test]
    async fn read_line_limited_exact_limit_no_newline_accepted() {
        let mut reader = reader_over(vec![b'X'; MAX_LINE_BYTES]);
        let mut line = String::new();
        let read = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(read, MAX_LINE_BYTES);
        assert_eq!(line.len(), MAX_LINE_BYTES);
    }

    /// One byte over the limit with no newline is rejected.
    #[tokio::test]
    async fn read_line_limited_just_over_limit_no_newline_rejected() {
        let mut reader = reader_over(vec![b'Y'; MAX_LINE_BYTES + 1]);
        let mut line = String::new();
        let outcome = read_line_limited(&mut reader, &mut line).await;
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("exceeds maximum"));
    }

    /// Existing buffer content is kept; the count covers only the new bytes.
    #[tokio::test]
    async fn read_line_limited_appends_to_existing_buffer() {
        let mut reader = reader_over(b"world\n".to_vec());
        let mut line = String::from("hello ");
        let read = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(read, 6);
        assert_eq!(line, "hello world\n");
    }

    /// NUL bytes are ordinary characters and pass through untouched.
    #[tokio::test]
    async fn read_line_limited_embedded_null_bytes() {
        let mut reader = reader_over(vec![b'a', 0x00, b'b', b'\n']);
        let mut line = String::new();
        let read = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(read, 4);
        assert_eq!(line, "a\0b\n");
    }

    /// The carriage return of a CRLF ending is preserved in the output.
    #[tokio::test]
    async fn read_line_limited_crlf_line_ending() {
        let mut reader = reader_over(b"line\r\n".to_vec());
        let mut line = String::new();
        let read = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(read, 6);
        assert_eq!(line, "line\r\n");
    }

    /// A maximum-length line does not disturb reading of the line after it.
    #[tokio::test]
    async fn read_line_limited_long_line_with_newline_at_boundary() {
        let mut bytes = vec![b'Z'; MAX_LINE_BYTES - 1];
        bytes.push(b'\n');
        bytes.extend_from_slice(b"next\n");
        let mut reader = reader_over(bytes);
        let mut line = String::new();

        let first = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(first, MAX_LINE_BYTES);
        assert!(line.ends_with('\n'));
        assert_eq!(line.len(), MAX_LINE_BYTES);

        line.clear();
        let second = read_line_limited(&mut reader, &mut line).await.unwrap();
        assert_eq!(second, 5);
        assert_eq!(line, "next\n");
    }
}