use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::{anyhow, bail, Result};
use base64::engine::general_purpose::STANDARD as B64;
use base64::Engine;
use ed25519_dalek::{Signature, VerifyingKey};
use futures_util::{SinkExt, StreamExt};
use rand::RngCore;
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot, Mutex, Semaphore};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::WebSocketStream;
use tracing::{debug, info, warn};
// Domain-separation prefix: a client proves its key by signing these bytes
// followed by the server's 32-byte challenge nonce (see `verify_client_auth`).
const RELAY_AUTH_DOMAIN: &[u8] = b"huddle-relay-auth-v1";
// Mailbox rows older than this (30 days, in seconds) are deleted by the hourly
// GC task spawned in `main`.
const MAILBOX_TTL_SECS: i64 = 30 * 24 * 60 * 60;
// Capacity of each socket's outbound channel. Live fan-out uses `try_send`, so a
// full channel drops the frame (counted in `outbound_cap_drops`).
const OUTBOUND_CAP: usize = 256;
// Deadline shared by the initial peek + WebSocket handshake (`handle_conn`) and,
// separately, by the wait for an authenticating hello (`serve_ws`).
const PRE_AUTH_TIMEOUT_SECS: u64 = 20;
// Lifetime of a connect code, in seconds.
const CONNECT_TOKEN_TTL_SECS: i64 = 5 * 60;
// 32 symbols: digits plus upper-case letters without I, L, O and U
// (presumably Crockford base32 — confirm with the client implementation).
const CONNECT_TOKEN_ALPHABET: &[u8] = b"0123456789ABCDEFGHJKMNPQRSTVWXYZ";
// Number of symbols in a connect code.
const CONNECT_TOKEN_LEN: usize = 8;
// Upper bound on live connect codes; creation fails once reached.
const MAX_CONNECT_TOKENS: usize = 50_000;
// Maximum length (bytes of the base64 string) of a publish/direct payload.
const MAX_PAYLOAD_B64: usize = 256 * 1024;
// Maximum length of a room or recipient id accepted by `clean_id`.
const MAX_ID_LEN: usize = 128;
// Maximum length of a client-chosen message id.
const MAX_MSG_ID_LEN: usize = 256;
// `enqueue` keeps only this many newest mailbox rows per recipient.
const MAX_MAILBOX_PER_FP: usize = 500;
// Only the first this-many entries of a hello's `rooms` list are considered.
const MAX_ROOMS_PER_HELLO: usize = 1000;
// Maximum simultaneously registered sockets for one fingerprint.
const MAX_CONNECTIONS_PER_FP: usize = 16;
// Global cap on concurrently handled sockets (semaphore in `main`).
const MAX_TOTAL_CONNECTIONS: usize = 4096;
/// Frames a client may send. Each is a JSON object whose `type` field selects
/// the variant in snake_case (e.g. `{"type":"send_direct", ...}`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientMsg {
/// Must be the first frame on a socket: proves ownership of `fingerprint` by
/// signing the server's challenge, then registers the socket and joins `rooms`.
Hello {
/// Claimed identity; must equal the fingerprint derived from `pubkey_b64`.
fingerprint: String,
/// Base64 of the 32-byte Ed25519 public key.
#[serde(default)]
pubkey_b64: String,
/// Base64 of the 64-byte signature over `RELAY_AUTH_DOMAIN` followed by the nonce.
#[serde(default)]
signature_b64: String,
/// Rooms to join; at most `MAX_ROOMS_PER_HELLO` are considered, invalid ids skipped.
#[serde(default)]
rooms: Vec<String>,
/// Opt-in to acknowledged mailbox delivery: queued messages carry a
/// `mailbox_id` and stay queued until an `ack`. When false (the default),
/// rows are deleted once the writer confirms they were sent.
#[serde(default)]
acks: bool,
},
/// Adds this identity as a member of `room`.
Subscribe {
room: String,
},
/// Removes this identity's membership in `room`.
Unsubscribe {
room: String,
},
/// Sends `payload_b64` to every other member of `room` (live when connected,
/// otherwise via mailbox). The sender is added to the room as a side effect.
Publish {
room: String,
/// Client-chosen message id, forwarded to recipients and echoed in `sent`.
id: String,
payload_b64: String,
},
/// Sends `payload_b64` to the single fingerprint `to`, live or via mailbox.
SendDirect {
to: String,
room: String,
id: String,
payload_b64: String,
},
/// Requests a fresh connect code for this identity (replacing any previous one).
CreateConnectToken,
/// Resolves a connect code to the fingerprint and public key that created it.
RedeemConnectToken {
token: String,
},
/// Re-sends everything currently queued in this identity's mailbox.
Fetch,
/// Deletes mailbox row `mailbox_id` if it belongs to this identity.
Ack { mailbox_id: i64 },
/// Liveness probe; answered with `pong`.
Ping,
}
/// Frames the server sends, serialized as JSON tagged by a snake_case `type`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMsg {
/// First frame on every WebSocket: a random 32-byte nonce (base64) to sign.
Challenge { nonce_b64: String },
/// Sent after a hello registers the socket, before the mailbox is flushed.
Ready { fingerprint: String },
/// A delivered payload, either live or from the mailbox.
Message {
room: String,
id: String,
payload_b64: String,
/// Present only for mailbox deliveries to clients that enabled `acks`;
/// omitted from the JSON otherwise.
#[serde(default, skip_serializing_if = "Option::is_none")]
mailbox_id: Option<i64>,
},
/// Receipt for a publish/direct send: recipients reached live vs. queued.
Sent { id: String, delivered: usize, queued: usize },
/// A newly issued connect code and its lifetime in seconds.
ConnectToken { token: String, ttl_secs: u64 },
/// Result of a redeem; `fingerprint`/`pubkey_b64` are `None` when the code is
/// unknown, malformed or expired.
ConnectTokenResolved {
token: String,
fingerprint: Option<String>,
pubkey_b64: Option<String>,
},
/// Reply to `ping`.
Pong,
/// Human-readable error for the preceding client frame.
Error { message: String },
}
/// Work items for a socket's writer task.
enum OutEvent {
/// Serialize and send as a WebSocket text frame.
Msg(ServerMsg),
/// Write barrier: the writer fires the sender once every earlier event has
/// been processed, letting callers wait until prior frames were sent.
Flush(oneshot::Sender<()>),
}
/// Handle used to queue frames for one socket's writer task.
type Tx = mpsc::Sender<OutEvent>;
/// A live connect code and the identity it resolves to.
struct ConnectTokenEntry {
    /// Fingerprint of the identity that created the code.
    fingerprint: String,
    /// That identity's proven public key (base64).
    pubkey_b64: String,
    /// Unix time (seconds) at which the code stops resolving.
    expires_at: i64,
}

/// In-memory registry of connect codes, indexed both ways:
/// code -> entry, and fingerprint -> that fingerprint's current code.
#[derive(Default)]
struct ConnectTokens {
    by_token: HashMap<String, ConnectTokenEntry>,
    by_fp: HashMap<String, String>,
}

impl ConnectTokens {
    /// Removes every code whose `expires_at` is at or before `now`.
    ///
    /// The reverse `by_fp` link of a removed code is dropped only if it still
    /// points at that very code. A fingerprint that has since been re-linked
    /// elsewhere keeps its link.
    fn gc(&mut self, now: i64) {
        let mut removed: Vec<(String, String)> = Vec::new();
        self.by_token.retain(|token, entry| {
            let still_valid = entry.expires_at > now;
            if !still_valid {
                removed.push((token.clone(), entry.fingerprint.clone()));
            }
            still_valid
        });
        for (token, owner) in removed {
            let links_here = self.by_fp.get(&owner).map_or(false, |t| *t == token);
            if links_here {
                self.by_fp.remove(&owner);
            }
        }
    }
}
/// Process-wide counters exposed on `/metrics` (all updated with `Relaxed`).
#[derive(Default)]
struct Metrics {
// Incremented when a connection task starts, decremented when `handle_conn` returns.
active_connections: AtomicU64,
// One per recipient reached live by publish/send_direct.
publish_delivered: AtomicU64,
// One per recipient whose copy went to the mailbox instead.
publish_queued: AtomicU64,
// Live fan-out frames dropped because a socket's outbound channel was full.
outbound_cap_drops: AtomicU64,
// Sockets closed for missing the peek/handshake or hello deadline.
pre_auth_timeouts: AtomicU64,
}
/// State shared by every connection task (held in an `Arc`).
struct Shared {
// The single SQLite connection; all DB access is serialized through this lock.
db: Mutex<Connection>,
// Fingerprint -> outbound senders of that identity's registered sockets.
conns: Mutex<HashMap<String, Vec<Tx>>>,
// Connect-code registry.
tokens: Mutex<ConnectTokens>,
metrics: Metrics,
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
let bind = std::env::var("HUDDLE_SERVER_BIND").unwrap_or_else(|_| "127.0.0.1:8787".to_string());
let db_path = std::env::var("HUDDLE_SERVER_DB").unwrap_or_else(|_| "huddle-server.db".to_string());
let conn = Connection::open(&db_path)?;
migrate(&conn)?;
let shared = Arc::new(Shared {
db: Mutex::new(conn),
conns: Mutex::new(HashMap::new()),
tokens: Mutex::new(ConnectTokens::default()),
metrics: Metrics::default(),
});
{
let shared = shared.clone();
tokio::spawn(async move {
let mut tick = tokio::time::interval(std::time::Duration::from_secs(3600));
loop {
tick.tick().await;
let cutoff = now_unix() - MAILBOX_TTL_SECS;
let db = shared.db.lock().await;
match db.execute("DELETE FROM mailbox WHERE created_at < ?1", params![cutoff]) {
Ok(n) if n > 0 => {
debug!(removed = n, "mailbox GC: dropped expired ciphertext")
}
Ok(_) => {}
Err(e) => debug!(error = %e, "mailbox GC failed"),
}
}
});
}
let listener = TcpListener::bind(&bind).await?;
info!(%bind, db = %db_path, "huddle-server listening (WebSocket + /health + /metrics)");
let conn_limit = Arc::new(Semaphore::new(MAX_TOTAL_CONNECTIONS));
loop {
let (stream, _peer) = listener.accept().await?;
let permit = match conn_limit.clone().try_acquire_owned() {
Ok(p) => p,
Err(_) => {
warn!(
cap = MAX_TOTAL_CONNECTIONS,
"connection limit reached — dropping new connection"
);
drop(stream);
continue;
}
};
let shared = shared.clone();
tokio::spawn(async move {
let _permit = permit;
shared.metrics.active_connections.fetch_add(1, Ordering::Relaxed);
if let Err(e) = handle_conn(stream, shared.clone()).await {
debug!(error = %e, "connection ended");
}
shared.metrics.active_connections.fetch_sub(1, Ordering::Relaxed);
});
}
}
/// Classifies a freshly accepted socket and dispatches it.
///
/// It peeks, without consuming, at the first bytes the client sends. A request
/// carrying an `Upgrade` header that names `websocket` gets the WebSocket
/// handshake and is handed to `serve_ws`. Anything else is answered as plain
/// HTTP by path: `/health`, `/metrics`, and every other path gets the landing
/// page.
///
/// The initial peek and the WebSocket handshake share one
/// `PRE_AUTH_TIMEOUT_SECS` deadline. Missing it bumps `pre_auth_timeouts` and
/// closes the socket.
async fn handle_conn(stream: TcpStream, shared: Arc<Shared>) -> Result<()> {
    let mut buf = [0u8; 1024];
    let dur = std::time::Duration::from_secs(PRE_AUTH_TIMEOUT_SECS);
    let pre_ws_deadline = tokio::time::Instant::now() + dur;
    let n = match tokio::time::timeout_at(pre_ws_deadline, stream.peek(&mut buf)).await {
        Ok(r) => r?,
        Err(_) => {
            shared.metrics.pre_auth_timeouts.fetch_add(1, Ordering::Relaxed);
            return Ok(());
        }
    };
    if n == 0 {
        // EOF before any request bytes: the peer is already gone, nobody to answer.
        return Ok(());
    }
    // Extract everything needed from the peeked head up front so `buf` is free
    // to be reused below.
    let (wants_ws, target) = {
        let head = String::from_utf8_lossy(&buf[..n]);
        (is_websocket_upgrade(&head), request_target(&head).to_owned())
    };
    if wants_ws {
        let config = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
            max_message_size: Some(512 * 1024),
            max_frame_size: Some(512 * 1024),
            ..Default::default()
        };
        let ws = match tokio::time::timeout_at(
            pre_ws_deadline,
            tokio_tungstenite::accept_async_with_config(stream, Some(config)),
        )
        .await
        {
            Ok(r) => r?,
            Err(_) => {
                shared.metrics.pre_auth_timeouts.fetch_add(1, Ordering::Relaxed);
                return Ok(());
            }
        };
        serve_ws(ws, shared).await
    } else {
        // The request was only peeked. Closing a socket whose receive buffer
        // still holds unread bytes makes the kernel send RST instead of FIN,
        // which can make the client discard our response. Consume what we saw.
        // This is best effort; failure only means we fall back to the old behavior.
        let _ = stream.try_read(&mut buf[..n]);
        match target.as_str() {
            "/health" => serve_health(stream).await,
            "/metrics" => serve_metrics(stream, &shared).await,
            _ => serve_landing(stream).await,
        }
    }
}

/// Returns true if the request head has an `Upgrade` header whose value
/// mentions `websocket`. The header name and value are compared
/// case-insensitively, and optional whitespace around the colon is allowed
/// (so `Upgrade:websocket` is accepted too). The request line and anything
/// after the blank line ending the head are ignored.
fn is_websocket_upgrade(head: &str) -> bool {
    head.lines()
        .skip(1)
        .take_while(|line| !line.is_empty())
        .filter_map(|line| line.split_once(':'))
        .any(|(name, value)| {
            name.trim().eq_ignore_ascii_case("upgrade")
                && value.to_ascii_lowercase().contains("websocket")
        })
}
/// Answers a `/health` probe with a fixed JSON body and closes the connection.
async fn serve_health(mut stream: TcpStream) -> Result<()> {
    const BODY: &str = r#"{"ok":true,"service":"huddle-server"}"#;
    let head = format!(
        "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
        BODY.len()
    );
    // Head and body go out in a single write.
    let response = head + BODY;
    stream.write_all(response.as_bytes()).await?;
    stream.flush().await?;
    Ok(())
}
/// Writes the relay's metrics as one `Connection: close` HTTP response in the
/// Prometheus text exposition format.
///
/// The response covers the `Metrics` counters, the live mailbox row count and
/// the configured connection ceiling. Counters are sampled before the DB lock
/// is taken. A failed row count is reported as 0 rather than failing the scrape.
async fn serve_metrics(mut stream: TcpStream, shared: &Arc<Shared>) -> Result<()> {
    let counters = &shared.metrics;
    let active = counters.active_connections.load(Ordering::Relaxed);
    let delivered = counters.publish_delivered.load(Ordering::Relaxed);
    let queued = counters.publish_queued.load(Ordering::Relaxed);
    let cap_drops = counters.outbound_cap_drops.load(Ordering::Relaxed);
    let pre_auth = counters.pre_auth_timeouts.load(Ordering::Relaxed);
    let mailbox_rows: i64 = {
        let db = shared.db.lock().await;
        db.query_row("SELECT COUNT(*) FROM mailbox", [], |r| r.get(0))
            .unwrap_or(0)
    };
    // (metric name, HELP text, TYPE, sample value), in exposition order.
    let series = [
        (
            "huddle_active_connections",
            "Accepted sockets currently being handled.",
            "gauge",
            active.to_string(),
        ),
        (
            "huddle_max_connections",
            "Configured ceiling on concurrent connections.",
            "gauge",
            MAX_TOTAL_CONNECTIONS.to_string(),
        ),
        (
            "huddle_mailbox_rows",
            "Queued mailbox rows across all recipients.",
            "gauge",
            mailbox_rows.to_string(),
        ),
        (
            "huddle_publish_delivered_total",
            "Messages fanned out live to a connected recipient.",
            "counter",
            delivered.to_string(),
        ),
        (
            "huddle_publish_queued_total",
            "Messages queued to an offline recipient's mailbox.",
            "counter",
            queued.to_string(),
        ),
        (
            "huddle_outbound_cap_drops_total",
            "Fan-out sends dropped because a recipient's outbound queue was full.",
            "counter",
            cap_drops.to_string(),
        ),
        (
            "huddle_pre_auth_timeouts_total",
            "Connections dropped for not authenticating before the deadline.",
            "counter",
            pre_auth.to_string(),
        ),
    ];
    let mut body = String::new();
    for (name, help, kind, value) in &series {
        body.push_str(&format!(
            "# HELP {name} {help}\n# TYPE {name} {kind}\n{name} {value}\n"
        ));
    }
    let resp = format!(
        "HTTP/1.1 200 OK\r\nContent-Type: text/plain; version=0.0.4; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
        body.len(),
        body
    );
    stream.write_all(resp.as_bytes()).await?;
    stream.flush().await?;
    Ok(())
}
// Landing page served for any plain-HTTP path other than `/health` and
// `/metrics`; embedded at compile time from `landing.html` next to this file.
const LANDING_HTML: &str = include_str!("landing.html");
/// Returns the path part of the HTTP request target found on the first line of
/// `head` (`METHOD TARGET VERSION`), with any `?query` removed.
///
/// Falls back to `"/"` when there is no request line or it has no second token.
fn request_target(head: &str) -> &str {
    let request_line = match head.lines().next() {
        Some(line) => line,
        None => return "/",
    };
    match request_line.split_whitespace().nth(1) {
        Some(target) => match target.split_once('?') {
            Some((path, _query)) => path,
            None => target,
        },
        None => "/",
    }
}
/// Serves the static landing page and closes the connection.
///
/// The page is sent with a locked-down CSP (only inline styles and `data:`
/// images allowed), `nosniff`, and no referrer.
async fn serve_landing(mut stream: TcpStream) -> Result<()> {
    let header_lines = [
        "HTTP/1.1 200 OK".to_string(),
        "Content-Type: text/html; charset=utf-8".to_string(),
        format!("Content-Length: {}", LANDING_HTML.len()),
        "Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; img-src data:; base-uri 'none'; form-action 'none'; frame-ancestors 'none'".to_string(),
        "X-Content-Type-Options: nosniff".to_string(),
        "Referrer-Policy: no-referrer".to_string(),
        "Connection: close".to_string(),
    ];
    let mut response = header_lines.join("\r\n");
    // Terminate the last header line, then the blank line that ends the head.
    response.push_str("\r\n\r\n");
    response.push_str(LANDING_HTML);
    stream.write_all(response.as_bytes()).await?;
    stream.flush().await?;
    Ok(())
}
async fn serve_ws(ws: WebSocketStream<TcpStream>, shared: Arc<Shared>) -> Result<()> {
let (mut sink, mut stream) = ws.split();
let (tx, mut rx) = mpsc::channel::<OutEvent>(OUTBOUND_CAP);
let writer = tokio::spawn(async move {
while let Some(evt) = rx.recv().await {
match evt {
OutEvent::Msg(msg) => {
let json = match serde_json::to_string(&msg) {
Ok(j) => j,
Err(_) => continue,
};
if sink.send(WsMessage::Text(json.into())).await.is_err() {
break;
}
}
OutEvent::Flush(done) => {
let _ = done.send(());
}
}
}
});
let mut nonce = [0u8; 32];
rand::rngs::OsRng.fill_bytes(&mut nonce);
if tx
.send(OutEvent::Msg(ServerMsg::Challenge {
nonce_b64: B64.encode(nonce),
}))
.await
.is_err()
{
writer.abort();
return Ok(());
}
let mut fingerprint: Option<String> = None;
let mut proven_pubkey: Option<String> = None;
let mut authenticated = false;
let mut acks_enabled = false;
let auth_deadline = tokio::time::sleep(std::time::Duration::from_secs(PRE_AUTH_TIMEOUT_SECS));
tokio::pin!(auth_deadline);
loop {
let frame = tokio::select! {
_ = &mut auth_deadline, if !authenticated => {
debug!("client did not authenticate within the timeout; dropping");
shared.metrics.pre_auth_timeouts.fetch_add(1, Ordering::Relaxed);
break;
}
f = stream.next() => match f {
Some(Ok(fr)) => fr,
Some(Err(_)) | None => break,
},
};
let text = match frame {
WsMessage::Text(t) => t.as_str().to_string(),
WsMessage::Binary(b) => String::from_utf8_lossy(&b).into_owned(),
WsMessage::Close(_) => break,
WsMessage::Ping(_) | WsMessage::Pong(_) | WsMessage::Frame(_) => continue,
};
let msg: ClientMsg = match serde_json::from_str(&text) {
Ok(m) => m,
Err(e) => {
let _ = tx
.send(OutEvent::Msg(ServerMsg::Error {
message: format!("bad message: {e}"),
}))
.await;
continue;
}
};
if !authenticated {
match &msg {
ClientMsg::Hello {
fingerprint: claimed,
pubkey_b64,
signature_b64,
..
} => match verify_client_auth(claimed, pubkey_b64, signature_b64, &nonce) {
Ok(proven_fp) => {
fingerprint = Some(proven_fp);
proven_pubkey = Some(pubkey_b64.clone());
authenticated = true;
}
Err(e) => {
debug!(error = %e, "client auth failed; dropping connection");
let _ = tx
.send(OutEvent::Msg(ServerMsg::Error {
message: format!("auth failed: {e}"),
}))
.await;
break;
}
},
_ => {
let _ = tx
.send(OutEvent::Msg(ServerMsg::Error {
message: "authenticate with hello first".into(),
}))
.await;
break;
}
}
}
if let Err(e) =
handle_client_msg(msg, &mut fingerprint, &proven_pubkey, &mut acks_enabled, &tx, &shared)
.await
{
let _ = tx
.send(OutEvent::Msg(ServerMsg::Error {
message: e.to_string(),
}))
.await;
}
}
if let Some(fp) = &fingerprint {
let mut conns = shared.conns.lock().await;
if let Some(v) = conns.get_mut(fp) {
v.retain(|s| !s.same_channel(&tx));
if v.is_empty() {
conns.remove(fp);
}
}
}
writer.abort();
Ok(())
}
/// Executes one decoded client frame for an authenticated socket.
///
/// `fingerprint` is the identity proven by `serve_ws`, and `proven_pubkey` is
/// the key that proved it. `acks_enabled` stores the `acks` capability from the
/// latest hello. Replies are queued on `tx`.
///
/// Returns `Err` for validation and DB failures. The caller forwards the error
/// text to the client as an `error` frame.
async fn handle_client_msg(
msg: ClientMsg,
fingerprint: &mut Option<String>,
proven_pubkey: &Option<String>,
acks_enabled: &mut bool,
tx: &Tx,
shared: &Arc<Shared>,
) -> Result<()> {
match msg {
ClientMsg::Hello { rooms, acks, .. } => {
let fp = require_fp(fingerprint)?;
// Register this socket under its identity unless it already is, or the
// identity already has MAX_CONNECTIONS_PER_FP sockets.
let registered = {
let mut conns = shared.conns.lock().await;
let entry = conns.entry(fp.clone()).or_default();
if entry.iter().any(|s| s.same_channel(tx)) {
true
} else if entry.len() >= MAX_CONNECTIONS_PER_FP {
false
} else {
entry.push(tx.clone());
true
}
};
if !registered {
let _ = tx
.send(OutEvent::Msg(ServerMsg::Error {
message: "too many concurrent connections for this identity".into(),
}))
.await;
bail!("per-fingerprint connection limit ({MAX_CONNECTIONS_PER_FP}) exceeded");
}
{
let db = shared.db.lock().await;
// Invalid room ids are skipped silently; only the first MAX_ROOMS_PER_HELLO are looked at.
for room in rooms.iter().take(MAX_ROOMS_PER_HELLO) {
if let Some(room) = clean_id(room) {
add_membership(&db, &fp, &room)?;
}
}
}
*acks_enabled = acks;
let _ = tx.send(OutEvent::Msg(ServerMsg::Ready { fingerprint: fp.clone() })).await;
// Deliver anything queued while this identity was offline.
flush_mailbox(&fp, tx, shared, acks).await?;
}
ClientMsg::Subscribe { room } => {
let fp = require_fp(fingerprint)?;
let room = clean_id(&room).ok_or_else(|| anyhow!("invalid room"))?;
let db = shared.db.lock().await;
add_membership(&db, &fp, &room)?;
}
ClientMsg::Unsubscribe { room } => {
let fp = require_fp(fingerprint)?;
let room = clean_id(&room).ok_or_else(|| anyhow!("invalid room"))?;
let db = shared.db.lock().await;
db.execute(
"DELETE FROM memberships WHERE fingerprint = ?1 AND room = ?2",
params![fp, room],
)?;
}
ClientMsg::Publish { room, id, payload_b64 } => {
let fp = require_fp(fingerprint)?;
let room = clean_id(&room).ok_or_else(|| anyhow!("invalid room"))?;
if id.is_empty() || id.len() > MAX_MSG_ID_LEN {
bail!("invalid message id");
}
if payload_b64.len() > MAX_PAYLOAD_B64 {
bail!("payload too large");
}
// Publishing implicitly joins the sender to the room.
let members = {
let db = shared.db.lock().await;
add_membership(&db, &fp, &room)?;
room_members(&db, &room)?
};
let mut delivered = 0usize;
let mut queued = 0usize;
for member in members {
if member == fp {
continue;
}
// A member counts as delivered if at least one of its sockets accepted
// the frame. If none did (offline, or every queue full), it is mailboxed.
let online = {
let conns = shared.conns.lock().await;
match conns.get(&member) {
Some(senders) => {
let out = ServerMsg::Message {
room: room.clone(),
id: id.clone(),
payload_b64: payload_b64.clone(),
mailbox_id: None,
};
fan_out(senders, &out, &shared.metrics)
}
None => false,
}
};
if online {
delivered += 1;
shared.metrics.publish_delivered.fetch_add(1, Ordering::Relaxed);
} else {
let db = shared.db.lock().await;
enqueue(&db, &member, &room, &id, &payload_b64)?;
queued += 1;
shared.metrics.publish_queued.fetch_add(1, Ordering::Relaxed);
}
}
let _ = tx.send(OutEvent::Msg(ServerMsg::Sent { id, delivered, queued })).await;
}
// NOTE(review): no membership or relationship check between sender and `to`
// happens here; any authenticated identity can queue to any fingerprint.
ClientMsg::SendDirect {
to,
room,
id,
payload_b64,
} => {
let _from = require_fp(fingerprint)?;
let to = clean_id(&to).ok_or_else(|| anyhow!("invalid recipient"))?;
let room = clean_id(&room).ok_or_else(|| anyhow!("invalid room"))?;
if id.is_empty() || id.len() > MAX_MSG_ID_LEN {
bail!("invalid message id");
}
if payload_b64.len() > MAX_PAYLOAD_B64 {
bail!("payload too large");
}
let online = {
let conns = shared.conns.lock().await;
match conns.get(&to) {
Some(senders) => {
let out = ServerMsg::Message {
room: room.clone(),
id: id.clone(),
payload_b64: payload_b64.clone(),
mailbox_id: None,
};
fan_out(senders, &out, &shared.metrics)
}
None => false,
}
};
let (delivered, queued) = if online {
shared.metrics.publish_delivered.fetch_add(1, Ordering::Relaxed);
(1usize, 0usize)
} else {
let db = shared.db.lock().await;
enqueue(&db, &to, &room, &id, &payload_b64)?;
shared.metrics.publish_queued.fetch_add(1, Ordering::Relaxed);
(0usize, 1usize)
};
let _ = tx.send(OutEvent::Msg(ServerMsg::Sent { id, delivered, queued })).await;
}
ClientMsg::CreateConnectToken => {
let fp = require_fp(fingerprint)?;
let pubkey = proven_pubkey.clone().unwrap_or_default();
let now = now_unix();
let token = {
let mut t = shared.tokens.lock().await;
t.gc(now);
// One live code per identity: revoke the previous one first.
if let Some(old) = t.by_fp.remove(&fp) {
t.by_token.remove(&old);
}
if t.by_token.len() >= MAX_CONNECT_TOKENS {
bail!("connect-code registry full; try again shortly");
}
// NOTE(review): if all 9 candidates collide, the insert below overwrites a
// code owned by another identity. With 32^8 codes and at most
// MAX_CONNECT_TOKENS live entries this is improbable.
let mut tok = gen_connect_token();
let mut tries = 0;
while t.by_token.contains_key(&tok) && tries < 8 {
tok = gen_connect_token();
tries += 1;
}
t.by_token.insert(
tok.clone(),
ConnectTokenEntry {
fingerprint: fp.clone(),
pubkey_b64: pubkey,
expires_at: now + CONNECT_TOKEN_TTL_SECS,
},
);
t.by_fp.insert(fp.clone(), tok.clone());
tok
};
let _ = tx
.send(OutEvent::Msg(ServerMsg::ConnectToken {
token,
ttl_secs: CONNECT_TOKEN_TTL_SECS as u64,
}))
.await;
}
// NOTE(review): redemption is not rate-limited in this block; confirm that
// guessing codes is throttled elsewhere if that matters.
ClientMsg::RedeemConnectToken { token } => {
let _fp = require_fp(fingerprint)?;
let now = now_unix();
let resolved = {
let mut t = shared.tokens.lock().await;
t.gc(now);
normalize_connect_token(&token)
.and_then(|norm| t.by_token.get(&norm))
.filter(|e| e.expires_at > now)
.map(|e| (e.fingerprint.clone(), e.pubkey_b64.clone()))
};
let (fingerprint, pubkey_b64) = match resolved {
Some((fp, pk)) => (Some(fp), Some(pk)),
None => (None, None),
};
// The token echoed back is the normalized form ("" if it was malformed).
let _ = tx
.send(OutEvent::Msg(ServerMsg::ConnectTokenResolved {
token: normalize_connect_token(&token).unwrap_or_default(),
fingerprint,
pubkey_b64,
}))
.await;
}
ClientMsg::Fetch => {
let fp = require_fp(fingerprint)?;
flush_mailbox(&fp, tx, shared, *acks_enabled).await?;
}
ClientMsg::Ack { mailbox_id } => {
let fp = require_fp(fingerprint)?;
let db = shared.db.lock().await;
// Scoped by fingerprint so a client can only delete its own rows.
db.execute(
"DELETE FROM mailbox WHERE id = ?1 AND fingerprint = ?2",
params![mailbox_id, fp],
)?;
}
ClientMsg::Ping => {
let _ = tx.send(OutEvent::Msg(ServerMsg::Pong)).await;
}
}
Ok(())
}
fn fan_out(senders: &[Tx], out: &ServerMsg, metrics: &Metrics) -> bool {
let mut any_ok = false;
for s in senders {
match s.try_send(OutEvent::Msg(out.clone())) {
Ok(()) => any_ok = true,
Err(mpsc::error::TrySendError::Full(_)) => {
metrics.outbound_cap_drops.fetch_add(1, Ordering::Relaxed);
}
Err(mpsc::error::TrySendError::Closed(_)) => {}
}
}
any_ok
}
/// Sends every queued mailbox row for `fp` to this socket, oldest first.
///
/// With `acks == true` each frame carries its `mailbox_id` and nothing is
/// deleted here; the client's `ack` removes rows. With `acks == false` (legacy
/// clients) the sent rows are deleted, but only after a `Flush` barrier
/// confirms the writer got past them. If the writer goes away at any point,
/// every row is kept.
async fn flush_mailbox(fp: &str, tx: &Tx, shared: &Arc<Shared>, acks: bool) -> Result<()> {
    let pending = {
        let db = shared.db.lock().await;
        peek_mailbox(&db, fp)?
    };
    let mut sent_legacy: Vec<i64> = Vec::new();
    for (row_id, room, msg_id, payload_b64) in pending {
        let frame = ServerMsg::Message {
            room,
            id: msg_id,
            payload_b64,
            mailbox_id: if acks { Some(row_id) } else { None },
        };
        if tx.send(OutEvent::Msg(frame)).await.is_err() {
            // Writer is gone: nothing can be confirmed as sent, so keep all rows.
            return Ok(());
        }
        if !acks {
            sent_legacy.push(row_id);
        }
    }
    if acks || sent_legacy.is_empty() {
        return Ok(());
    }
    let (done_tx, done_rx) = oneshot::channel();
    if tx.send(OutEvent::Flush(done_tx)).await.is_err() {
        return Ok(());
    }
    if done_rx.await.is_err() {
        return Ok(());
    }
    let db = shared.db.lock().await;
    delete_mailbox_ids(&db, &sent_legacy)?;
    Ok(())
}
/// Returns an owned copy of the authenticated fingerprint, or an error asking
/// the client to send `hello` first.
fn require_fp(fp: &Option<String>) -> Result<String> {
    match fp {
        Some(proven) => Ok(proven.clone()),
        None => Err(anyhow!("send hello first")),
    }
}
/// Derives the identity fingerprint of an Ed25519 public key.
///
/// It takes the first 12 bytes of SHA-256(key) as 24 hex digits and inserts
/// `-` every 4 digits (six groups).
fn compute_fingerprint(public_key: &[u8; 32]) -> String {
    let digest = Sha256::digest(public_key);
    let digits = hex::encode(&digest[..12]);
    let mut grouped = String::with_capacity(digits.len() + digits.len() / 4);
    for (idx, ch) in digits.chars().enumerate() {
        if idx > 0 && idx % 4 == 0 {
            grouped.push('-');
        }
        grouped.push(ch);
    }
    grouped
}
/// Checks a hello's proof of key ownership and returns the proven fingerprint.
///
/// The key must decode to 32 bytes and the signature to 64 bytes. The key's
/// fingerprint must equal `claimed_fp`. The signature must verify (strictly)
/// over `RELAY_AUTH_DOMAIN || nonce`.
///
/// # Errors
/// Describes the first check that failed, in the order listed above.
fn verify_client_auth(
    claimed_fp: &str,
    pubkey_b64: &str,
    signature_b64: &str,
    nonce: &[u8; 32],
) -> Result<String> {
    let pubkey_raw = B64
        .decode(pubkey_b64)
        .map_err(|e| anyhow!("bad pubkey base64: {e}"))?;
    let pubkey = <[u8; 32]>::try_from(pubkey_raw.as_slice())
        .map_err(|_| anyhow!("pubkey must be 32 bytes"))?;
    let sig_raw = B64
        .decode(signature_b64)
        .map_err(|e| anyhow!("bad signature base64: {e}"))?;
    let sig = <[u8; 64]>::try_from(sig_raw.as_slice())
        .map_err(|_| anyhow!("signature must be 64 bytes"))?;

    let proven_fp = compute_fingerprint(&pubkey);
    if proven_fp != claimed_fp {
        bail!("fingerprint does not match pubkey");
    }

    let verifying_key =
        VerifyingKey::from_bytes(&pubkey).map_err(|e| anyhow!("bad ed25519 pubkey: {e}"))?;
    // The domain prefix stops a signature made for another protocol from being replayed here.
    let transcript = [RELAY_AUTH_DOMAIN, nonce.as_slice()].concat();
    verifying_key
        .verify_strict(&transcript, &Signature::from_bytes(&sig))
        .map_err(|e| anyhow!("signature verification failed: {e}"))?;
    Ok(proven_fp)
}
/// Validates a room/recipient id and returns its trimmed form.
///
/// The trimmed id must be 1..=`MAX_ID_LEN` bytes of ASCII letters, digits,
/// or `-`, `_`, `:`, `.`. Anything else yields `None`.
fn clean_id(s: &str) -> Option<String> {
    let candidate = s.trim();
    let allowed = |b: u8| b.is_ascii_alphanumeric() || b"-_:.".contains(&b);
    let valid = !candidate.is_empty()
        && candidate.len() <= MAX_ID_LEN
        && candidate.bytes().all(allowed);
    valid.then(|| candidate.to_owned())
}
/// Current wall-clock time as whole seconds since the Unix epoch. Returns 0 if
/// the system clock reads earlier than the epoch.
fn now_unix() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
/// Generates a random connect code of `CONNECT_TOKEN_LEN` symbols drawn from
/// `CONNECT_TOKEN_ALPHABET` using the OS RNG.
fn gen_connect_token() -> String {
    let mut rng = rand::rngs::OsRng;
    let mut token = String::with_capacity(CONNECT_TOKEN_LEN);
    for _ in 0..CONNECT_TOKEN_LEN {
        // The alphabet has 32 symbols, which divides 2^32, so the modulo is unbiased.
        let idx = (rng.next_u32() as usize) % CONNECT_TOKEN_ALPHABET.len();
        token.push(char::from(CONNECT_TOKEN_ALPHABET[idx]));
    }
    token
}
/// Canonicalizes user-typed connect codes.
///
/// Trims the input, upper-cases ASCII letters, and removes `-` and spaces. The
/// result is accepted only if it is exactly `CONNECT_TOKEN_LEN` symbols, all
/// from `CONNECT_TOKEN_ALPHABET`.
fn normalize_connect_token(s: &str) -> Option<String> {
    let mut compact = String::with_capacity(CONNECT_TOKEN_LEN);
    for ch in s.trim().chars() {
        let ch = ch.to_ascii_uppercase();
        if ch != '-' && ch != ' ' {
            compact.push(ch);
        }
    }
    let well_formed = compact.len() == CONNECT_TOKEN_LEN
        && compact.bytes().all(|b| CONNECT_TOKEN_ALPHABET.contains(&b));
    if well_formed {
        Some(compact)
    } else {
        None
    }
}
/// Creates the schema if it does not exist yet. It is idempotent, so it is
/// safe to run on every start.
///
/// - `memberships`: which fingerprints belong to which room.
/// - `mailbox`: queued payloads for recipients that were not reached live.
///
/// Indexes:
/// - `mailbox(fingerprint)` serves per-recipient fetch and trimming.
/// - `memberships(room)` serves `room_members`, which runs on every publish.
///   The `(fingerprint, room)` primary key cannot efficiently serve a
///   `room = ?` lookup, so without this index each publish scans the table.
fn migrate(c: &Connection) -> Result<()> {
    c.execute_batch(
        "CREATE TABLE IF NOT EXISTS memberships (
            fingerprint TEXT NOT NULL,
            room TEXT NOT NULL,
            PRIMARY KEY (fingerprint, room)
        );
        CREATE TABLE IF NOT EXISTS mailbox (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            fingerprint TEXT NOT NULL,
            room TEXT NOT NULL,
            msg_id TEXT NOT NULL,
            payload_b64 TEXT NOT NULL,
            created_at INTEGER NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_mailbox_fp ON mailbox(fingerprint);
        CREATE INDEX IF NOT EXISTS idx_memberships_room ON memberships(room);",
    )?;
    Ok(())
}
/// Records that `fp` is a member of `room`; a no-op if it already is.
fn add_membership(c: &Connection, fp: &str, room: &str) -> Result<()> {
    let sql = "INSERT OR IGNORE INTO memberships(fingerprint, room) VALUES(?1, ?2)";
    c.execute(sql, params![fp, room])?;
    Ok(())
}
/// Lists the fingerprints of every member of `room`.
///
/// Rows that fail to decode are skipped rather than failing the whole lookup.
fn room_members(c: &Connection, room: &str) -> Result<Vec<String>> {
    let mut stmt = c.prepare("SELECT fingerprint FROM memberships WHERE room = ?1")?;
    let mut members = Vec::new();
    for row in stmt.query_map(params![room], |r| r.get::<_, String>(0))? {
        if let Ok(fp) = row {
            members.push(fp);
        }
    }
    Ok(members)
}
fn enqueue(c: &Connection, fp: &str, room: &str, id: &str, payload_b64: &str) -> Result<()> {
c.execute(
"INSERT INTO mailbox(fingerprint, room, msg_id, payload_b64, created_at)
VALUES(?1, ?2, ?3, ?4, ?5)",
params![fp, room, id, payload_b64, now_unix()],
)?;
c.execute(
"DELETE FROM mailbox WHERE fingerprint = ?1 AND id NOT IN (
SELECT id FROM mailbox WHERE fingerprint = ?1 ORDER BY id DESC LIMIT ?2
)",
params![fp, MAX_MAILBOX_PER_FP as i64],
)?;
Ok(())
}
/// Reads, without deleting, every mailbox row for `fp` in insertion (id) order.
/// Each row is returned as `(row id, room, msg_id, payload_b64)`.
///
/// # Errors
/// Fails on the first row that cannot be read.
fn peek_mailbox(c: &Connection, fp: &str) -> Result<Vec<(i64, String, String, String)>> {
    let mut stmt = c.prepare(
        "SELECT id, room, msg_id, payload_b64 FROM mailbox WHERE fingerprint = ?1 ORDER BY id ASC",
    )?;
    let rows = stmt.query_map(params![fp], |r| {
        let row_id = r.get::<_, i64>(0)?;
        let room = r.get::<_, String>(1)?;
        let msg_id = r.get::<_, String>(2)?;
        let payload = r.get::<_, String>(3)?;
        Ok((row_id, room, msg_id, payload))
    })?;
    let collected: std::result::Result<Vec<(i64, String, String, String)>, _> = rows.collect();
    Ok(collected?)
}
/// Deletes the given mailbox rows (ids not present are ignored).
///
/// All deletes run in a single transaction with one prepared statement.
/// Previously each row was its own autocommit transaction, i.e. one journal
/// sync per row for a flush of up to `MAX_MAILBOX_PER_FP` rows. On error the
/// transaction rolls back when dropped and no rows are removed. That keeps the
/// at-least-once guarantee: rows are redelivered rather than half-deleted.
fn delete_mailbox_ids(c: &Connection, ids: &[i64]) -> Result<()> {
    if ids.is_empty() {
        return Ok(());
    }
    // `unchecked_transaction` works from `&Connection`. Callers hold the DB
    // mutex, and no other transaction is open on this connection.
    let txn = c.unchecked_transaction()?;
    {
        let mut stmt = txn.prepare("DELETE FROM mailbox WHERE id = ?1")?;
        for id in ids {
            stmt.execute(params![id])?;
        }
    }
    txn.commit()?;
    Ok(())
}
// Unit tests for wire (de)serialization, HTTP path parsing and the two mailbox
// flush modes (legacy delete-after-barrier vs. ack-based).
#[cfg(test)]
mod tests {
use super::{
enqueue, flush_mailbox, migrate, peek_mailbox, request_target, ClientMsg, ConnectTokens,
Connection, Metrics, OutEvent, ServerMsg, Shared, OUTBOUND_CAP,
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
// Builds a `Shared` backed by a migrated in-memory SQLite database.
fn test_shared() -> Arc<Shared> {
let conn = Connection::open_in_memory().unwrap();
migrate(&conn).unwrap();
Arc::new(Shared {
db: Mutex::new(conn),
conns: Mutex::new(HashMap::new()),
tokens: Mutex::new(ConnectTokens::default()),
metrics: Metrics::default(),
})
}
// Number of mailbox rows currently queued for `fp`.
async fn mailbox_count(shared: &Arc<Shared>, fp: &str) -> usize {
let db = shared.db.lock().await;
peek_mailbox(&db, fp).unwrap().len()
}
// Live messages must omit `mailbox_id`; queued (ack-mode) ones must carry it.
#[test]
fn message_serializes_mailbox_id_only_when_present() {
let live = ServerMsg::Message {
room: "r".into(),
id: "i".into(),
payload_b64: "p".into(),
mailbox_id: None,
};
let s = serde_json::to_string(&live).unwrap();
assert!(
!s.contains("mailbox_id"),
"live message must omit mailbox_id for old-client compat: {s}"
);
let queued = ServerMsg::Message {
room: "r".into(),
id: "i".into(),
payload_b64: "p".into(),
mailbox_id: Some(42),
};
let s = serde_json::to_string(&queued).unwrap();
assert!(
s.contains("\"mailbox_id\":42"),
"queued message must carry its mailbox_id: {s}"
);
}
// The `ack` frame decodes into `ClientMsg::Ack` with its row id.
#[test]
fn parses_ack() {
let m: ClientMsg = serde_json::from_str(r#"{"type":"ack","mailbox_id":7}"#).unwrap();
match m {
ClientMsg::Ack { mailbox_id } => assert_eq!(mailbox_id, 7),
other => panic!("expected ack, got {other:?}"),
}
}
// `acks` is opt-in: absent means false, explicit true is honoured.
#[test]
fn hello_acks_capability_defaults_false() {
let m: ClientMsg =
serde_json::from_str(r#"{"type":"hello","fingerprint":"fp"}"#).unwrap();
match m {
ClientMsg::Hello { acks, .. } => assert!(!acks),
other => panic!("expected hello, got {other:?}"),
}
let m: ClientMsg =
serde_json::from_str(r#"{"type":"hello","fingerprint":"fp","acks":true}"#).unwrap();
match m {
ClientMsg::Hello { acks, .. } => assert!(acks),
other => panic!("expected hello, got {other:?}"),
}
}
// Request-line target extraction for ordinary paths.
#[test]
fn parses_plain_paths() {
assert_eq!(request_target("GET / HTTP/1.1\r\nHost: x.onion\r\n\r\n"), "/");
assert_eq!(request_target("GET /health HTTP/1.1\r\n\r\n"), "/health");
assert_eq!(request_target("GET /ws HTTP/1.1\r\n"), "/ws");
}
// Query strings do not affect routing.
#[test]
fn strips_query_string() {
assert_eq!(request_target("GET /health?probe=1 HTTP/1.1\r\n"), "/health");
assert_eq!(request_target("GET /?x HTTP/1.1\r\n"), "/");
}
// Routing looks only at the target, not the method.
#[test]
fn other_methods_keep_their_target() {
assert_eq!(request_target("HEAD /health HTTP/1.1\r\n"), "/health");
assert_eq!(request_target("POST /anything HTTP/1.1\r\n"), "/anything");
}
// Anything without a second request-line token routes to "/".
#[test]
fn malformed_requests_fall_back_to_root() {
assert_eq!(request_target(""), "/"); assert_eq!(request_target("GET"), "/"); assert_eq!(request_target("garbage\r\n"), "/"); }
// Legacy mode: rows are deleted once the writer answers the Flush barrier.
#[tokio::test]
async fn legacy_flush_deletes_only_after_writer_confirms() {
let shared = test_shared();
let fp = "fp-legacy";
{
let db = shared.db.lock().await;
for i in 0..3 {
enqueue(&db, fp, "room", &format!("m{i}"), "p").unwrap();
}
}
// Fake writer: counts messages and acknowledges the barrier.
let (tx, mut rx) = mpsc::channel::<OutEvent>(OUTBOUND_CAP);
let writer = tokio::spawn(async move {
let mut written = 0usize;
while let Some(evt) = rx.recv().await {
match evt {
OutEvent::Msg(_) => written += 1,
OutEvent::Flush(done) => {
let _ = done.send(());
}
}
}
written
});
flush_mailbox(fp, &tx, &shared, false).await.unwrap();
drop(tx); let written = writer.await.unwrap();
assert_eq!(written, 3, "all three legacy messages reached the writer");
assert_eq!(
mailbox_count(&shared, fp).await,
0,
"rows deleted only after the write barrier confirmed delivery"
);
}
// Legacy mode: if the writer disappears before the barrier, nothing is deleted.
#[tokio::test]
async fn legacy_flush_keeps_rows_when_socket_dies_before_barrier() {
let shared = test_shared();
let fp = "fp-legacy-drop";
{
let db = shared.db.lock().await;
for i in 0..3 {
enqueue(&db, fp, "room", &format!("m{i}"), "p").unwrap();
}
}
// Fake writer: consumes the three messages, then exits without answering the barrier.
let (tx, mut rx) = mpsc::channel::<OutEvent>(OUTBOUND_CAP);
let writer = tokio::spawn(async move {
for _ in 0..3 {
let _ = rx.recv().await;
}
});
flush_mailbox(fp, &tx, &shared, false).await.unwrap();
writer.await.unwrap();
assert_eq!(
mailbox_count(&shared, fp).await,
3,
"no rows deleted: the messages may never have reached the socket"
);
}
// Ack mode: every delivery carries its mailbox_id and rows stay until acked.
#[tokio::test]
async fn at_least_once_tags_rows_and_keeps_them_until_ack() {
let shared = test_shared();
let fp = "fp-ack";
{
let db = shared.db.lock().await;
enqueue(&db, fp, "room", "m0", "p").unwrap();
enqueue(&db, fp, "room", "m1", "p").unwrap();
}
let (tx, mut rx) = mpsc::channel::<OutEvent>(OUTBOUND_CAP);
let collector = tokio::spawn(async move {
let mut ids = Vec::new();
while let Some(evt) = rx.recv().await {
if let OutEvent::Msg(ServerMsg::Message { mailbox_id, .. }) = evt {
ids.push(mailbox_id);
}
}
ids
});
flush_mailbox(fp, &tx, &shared, true).await.unwrap();
drop(tx);
let ids = collector.await.unwrap();
assert_eq!(ids.len(), 2, "both queued messages were delivered");
assert!(
ids.iter().all(|m| m.is_some()),
"acks=true tags every delivery with its mailbox_id to ACK"
);
assert_eq!(
mailbox_count(&shared, fp).await,
2,
"acks=true never deletes in flush; it waits for the client's Ack"
);
}
}