use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::{anyhow, bail, Result};
use base64::engine::general_purpose::STANDARD as B64;
use base64::Engine;
use ed25519_dalek::{Signature, VerifyingKey};
use futures_util::{SinkExt, StreamExt};
use rand::RngCore;
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::WebSocketStream;
use tracing::{debug, info};
// ---- Authentication -------------------------------------------------------

/// Domain-separation prefix: a client proves its key by signing
/// `RELAY_AUTH_DOMAIN || nonce` with Ed25519.
const RELAY_AUTH_DOMAIN: &[u8] = b"huddle-relay-auth-v1";
/// Seconds a new WebSocket has to send a valid `hello` after the challenge.
const PRE_AUTH_TIMEOUT_SECS: u64 = 20;

// ---- Per-connection plumbing ---------------------------------------------

/// Capacity of each connection's outbound `ServerMsg` channel.
const OUTBOUND_CAP: usize = 256;

// ---- Input limits ----------------------------------------------------------

/// Maximum length (bytes, after trimming) of a room id accepted by `clean_id`.
const MAX_ID_LEN: usize = 128;
/// Maximum length (bytes) of a client-chosen message id in `publish`.
const MAX_MSG_ID_LEN: usize = 256;
/// Maximum length of a `publish` payload, measured as base64 text.
const MAX_PAYLOAD_B64: usize = 1024 * 256;
/// Only this many rooms from a single `hello` are considered; the rest are ignored.
const MAX_ROOMS_PER_HELLO: usize = 1000;

// ---- Offline mailbox -------------------------------------------------------

/// Age in seconds after which the hourly GC deletes queued mailbox rows (30 days).
const MAILBOX_TTL_SECS: i64 = 60 * 60 * 24 * 30;
/// Per-recipient mailbox cap; `enqueue` keeps only the newest this-many rows.
const MAX_MAILBOX_PER_FP: usize = 500;
/// Frames a client may send. Decoded with `serde_json` from text frames (binary
/// frames are decoded as lossy UTF-8 first).
///
/// Internally tagged: the variant is chosen by a `"type"` field in snake_case,
/// e.g. `{"type":"subscribe","room":"general"}`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientMsg {
/// Must be the first frame: answers the server's `challenge`. Once the
/// connection is authenticated, this (and any later `hello`) registers the
/// connection for live delivery and joins `rooms`; the identity fields of
/// later `hello`s are ignored.
Hello {
/// Claimed identity; must equal `compute_fingerprint(pubkey)`.
fingerprint: String,
/// Base64 Ed25519 public key (32 bytes decoded). A missing field becomes an
/// empty string, which then fails the 32-byte length check.
#[serde(default)]
pubkey_b64: String,
/// Base64 Ed25519 signature (64 bytes decoded) over
/// `RELAY_AUTH_DOMAIN || nonce`.
#[serde(default)]
signature_b64: String,
/// Rooms to join; only the first `MAX_ROOMS_PER_HELLO` are considered and
/// ids rejected by `clean_id` are skipped silently.
#[serde(default)]
rooms: Vec<String>,
},
/// Add a persistent membership in `room`.
Subscribe {
room: String,
},
/// Remove this fingerprint's membership in `room`.
Unsubscribe {
room: String,
},
/// Relay an opaque payload to every other member of `room` (the publisher
/// is added to the room as a side effect).
Publish {
room: String,
/// Client-chosen id (1..=MAX_MSG_ID_LEN bytes), forwarded to recipients
/// and echoed back in the `sent` receipt.
id: String,
/// Payload as base64 text, relayed without decoding; at most
/// `MAX_PAYLOAD_B64` bytes long.
payload_b64: String,
},
/// Deliver any queued mailbox messages now.
Fetch,
/// Liveness probe; answered with `pong`.
Ping,
}
/// Frames the server sends; each is serialized to a JSON text frame by the
/// connection's writer task.
///
/// Internally tagged by a `"type"` field in snake_case, mirroring `ClientMsg`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMsg {
/// First frame on every WebSocket: 32 random bytes (base64) the client must sign.
Challenge { nonce_b64: String },
/// Sent after a `hello` has been processed; carries the proven fingerprint.
Ready { fingerprint: String },
/// A relayed payload, either live from a publisher or replayed from the mailbox.
Message { room: String, id: String, payload_b64: String },
/// Publish receipt. `delivered` counts other members for whom at least one
/// live connection accepted the frame; `queued` counts members whose copy
/// went to the mailbox instead.
Sent { id: String, delivered: usize, queued: usize },
/// Reply to `ping`.
Pong,
/// Human-readable error. Parse and request errors keep the connection open;
/// authentication errors are followed by the server ending the session.
Error { message: String },
}
/// Sending half of one connection's bounded outbound queue; the receiving half
/// is drained by that connection's writer task.
type Tx = mpsc::Sender<ServerMsg>;
/// State shared by every connection task.
struct Shared {
/// The single SQLite connection; all database access serializes on this async mutex.
db: Mutex<Connection>,
/// Live connections keyed by proven fingerprint. One fingerprint may have
/// several concurrent connections, hence the `Vec`.
conns: Mutex<HashMap<String, Vec<Tx>>>,
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
let bind = std::env::var("HUDDLE_SERVER_BIND").unwrap_or_else(|_| "127.0.0.1:8787".to_string());
let db_path = std::env::var("HUDDLE_SERVER_DB").unwrap_or_else(|_| "huddle-server.db".to_string());
let conn = Connection::open(&db_path)?;
migrate(&conn)?;
let shared = Arc::new(Shared {
db: Mutex::new(conn),
conns: Mutex::new(HashMap::new()),
});
{
let shared = shared.clone();
tokio::spawn(async move {
let mut tick = tokio::time::interval(std::time::Duration::from_secs(3600));
loop {
tick.tick().await;
let cutoff = now_unix() - MAILBOX_TTL_SECS;
let db = shared.db.lock().await;
match db.execute("DELETE FROM mailbox WHERE created_at < ?1", params![cutoff]) {
Ok(n) if n > 0 => {
debug!(removed = n, "mailbox GC: dropped expired ciphertext")
}
Ok(_) => {}
Err(e) => debug!(error = %e, "mailbox GC failed"),
}
}
});
}
let listener = TcpListener::bind(&bind).await?;
info!(%bind, db = %db_path, "huddle-server listening (WebSocket + /health)");
loop {
let (stream, _peer) = listener.accept().await?;
let shared = shared.clone();
tokio::spawn(async move {
if let Err(e) = handle_conn(stream, shared).await {
debug!(error = %e, "connection ended");
}
});
}
}
async fn handle_conn(stream: TcpStream, shared: Arc<Shared>) -> Result<()> {
let mut buf = [0u8; 1024];
let n = stream.peek(&mut buf).await?;
let head = String::from_utf8_lossy(&buf[..n]);
if head.to_ascii_lowercase().contains("upgrade: websocket") {
let ws = tokio_tungstenite::accept_async(stream).await?;
serve_ws(ws, shared).await
} else {
match request_target(&head) {
"/health" => serve_health(stream).await,
_ => serve_landing(stream).await,
}
}
}
/// Answers the health probe with a fixed JSON document and a
/// `Connection: close` response, then returns so the socket is dropped.
///
/// # Errors
/// Propagates write/flush failures on `stream`.
async fn serve_health(mut stream: TcpStream) -> Result<()> {
    const BODY: &str = r#"{"ok":true,"service":"huddle-server"}"#;
    let response = format!(
        "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
        BODY.len(),
        BODY
    );
    stream.write_all(response.as_bytes()).await?;
    stream.flush().await.map_err(Into::into)
}
/// Static landing page served to every non-WebSocket request whose target is
/// not `/health`. Embedded at compile time from `landing.html`, resolved
/// relative to this source file.
const LANDING_HTML: &str = include_str!("landing.html");
/// Extracts the path of an HTTP request-target from a raw request head.
///
/// Takes the second whitespace-separated token of the first line and cuts it
/// at the first `?`, discarding any query string. Returns `"/"` when the head
/// is empty or the request line has no target token.
fn request_target(head: &str) -> &str {
    let target = match head
        .lines()
        .next()
        .and_then(|request_line| request_line.split_whitespace().nth(1))
    {
        Some(token) => token,
        None => return "/",
    };
    match target.split_once('?') {
        Some((path, _query)) => path,
        None => target,
    }
}
/// Serves the embedded landing page with a restrictive Content-Security-Policy
/// (no scripts, inline styles and data: images only, no framing), `nosniff`,
/// and `no-referrer`, then returns so the connection is closed.
///
/// # Errors
/// Propagates write/flush failures on `stream`.
async fn serve_landing(mut stream: TcpStream) -> Result<()> {
    let page = LANDING_HTML;
    let response = format!(
        "HTTP/1.1 200 OK\r\n\
         Content-Type: text/html; charset=utf-8\r\n\
         Content-Length: {}\r\n\
         Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'; img-src data:; base-uri 'none'; form-action 'none'; frame-ancestors 'none'\r\n\
         X-Content-Type-Options: nosniff\r\n\
         Referrer-Policy: no-referrer\r\n\
         Connection: close\r\n\
         \r\n{}",
        page.len(),
        page
    );
    stream.write_all(response.as_bytes()).await?;
    stream.flush().await.map_err(Into::into)
}
async fn serve_ws(ws: WebSocketStream<TcpStream>, shared: Arc<Shared>) -> Result<()> {
let (mut sink, mut stream) = ws.split();
let (tx, mut rx) = mpsc::channel::<ServerMsg>(OUTBOUND_CAP);
let writer = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
let json = match serde_json::to_string(&msg) {
Ok(j) => j,
Err(_) => continue,
};
if sink.send(WsMessage::Text(json.into())).await.is_err() {
break;
}
}
});
let mut nonce = [0u8; 32];
rand::rngs::OsRng.fill_bytes(&mut nonce);
if tx
.send(ServerMsg::Challenge {
nonce_b64: B64.encode(nonce),
})
.await
.is_err()
{
writer.abort();
return Ok(());
}
let mut fingerprint: Option<String> = None;
let mut authenticated = false;
let auth_deadline = tokio::time::sleep(std::time::Duration::from_secs(PRE_AUTH_TIMEOUT_SECS));
tokio::pin!(auth_deadline);
loop {
let frame = tokio::select! {
_ = &mut auth_deadline, if !authenticated => {
debug!("client did not authenticate within the timeout; dropping");
break;
}
f = stream.next() => match f {
Some(Ok(fr)) => fr,
Some(Err(_)) | None => break,
},
};
let text = match frame {
WsMessage::Text(t) => t.as_str().to_string(),
WsMessage::Binary(b) => String::from_utf8_lossy(&b).into_owned(),
WsMessage::Close(_) => break,
WsMessage::Ping(_) | WsMessage::Pong(_) | WsMessage::Frame(_) => continue,
};
let msg: ClientMsg = match serde_json::from_str(&text) {
Ok(m) => m,
Err(e) => {
let _ = tx
.send(ServerMsg::Error {
message: format!("bad message: {e}"),
})
.await;
continue;
}
};
if !authenticated {
match &msg {
ClientMsg::Hello {
fingerprint: claimed,
pubkey_b64,
signature_b64,
..
} => match verify_client_auth(claimed, pubkey_b64, signature_b64, &nonce) {
Ok(proven_fp) => {
fingerprint = Some(proven_fp);
authenticated = true;
}
Err(e) => {
debug!(error = %e, "client auth failed; dropping connection");
let _ = tx
.send(ServerMsg::Error {
message: format!("auth failed: {e}"),
})
.await;
break;
}
},
_ => {
let _ = tx
.send(ServerMsg::Error {
message: "authenticate with hello first".into(),
})
.await;
break;
}
}
}
if let Err(e) = handle_client_msg(msg, &mut fingerprint, &tx, &shared).await {
let _ = tx
.send(ServerMsg::Error {
message: e.to_string(),
})
.await;
}
}
if let Some(fp) = &fingerprint {
let mut conns = shared.conns.lock().await;
if let Some(v) = conns.get_mut(fp) {
v.retain(|s| !s.same_channel(&tx));
if v.is_empty() {
conns.remove(fp);
}
}
}
writer.abort();
Ok(())
}
/// Executes one decoded client frame on behalf of a connection.
///
/// `fingerprint` is the identity proven during the challenge/response
/// handshake. Every variant except `Ping` requires it. Replies (`ready`,
/// `sent`, `pong`, mailbox messages) go out through `tx`, this connection's
/// own outbound queue.
///
/// Writes that touch many rows (joining the rooms of a `hello`, mailboxing a
/// publish for offline members) run in one SQLite transaction. Otherwise each
/// row is its own autocommit while every other connection waits on the
/// shared DB mutex.
///
/// # Errors
/// Fails when no identity is established, a room id, message id or payload
/// fails validation, or a database operation fails. The caller in this file
/// reports the error to the client as an `error` frame.
async fn handle_client_msg(
    msg: ClientMsg,
    fingerprint: &mut Option<String>,
    tx: &Tx,
    shared: &Arc<Shared>,
) -> Result<()> {
    match msg {
        ClientMsg::Hello { rooms, .. } => {
            let fp = require_fp(fingerprint)?;
            {
                let mut conns = shared.conns.lock().await;
                let entry = conns.entry(fp.clone()).or_default();
                if !entry.iter().any(|s| s.same_channel(tx)) {
                    entry.push(tx.clone());
                }
            }
            {
                let db = shared.db.lock().await;
                let txn = db.unchecked_transaction()?;
                for room in rooms.iter().take(MAX_ROOMS_PER_HELLO) {
                    if let Some(room) = clean_id(room) {
                        add_membership(&txn, &fp, &room)?;
                    }
                }
                txn.commit()?;
            }
            let _ = tx.send(ServerMsg::Ready { fingerprint: fp.clone() }).await;
            flush_mailbox(&fp, tx, shared).await?;
        }
        ClientMsg::Subscribe { room } => {
            let fp = require_fp(fingerprint)?;
            let room = clean_id(&room).ok_or_else(|| anyhow!("invalid room"))?;
            let db = shared.db.lock().await;
            add_membership(&db, &fp, &room)?;
        }
        ClientMsg::Unsubscribe { room } => {
            let fp = require_fp(fingerprint)?;
            let room = clean_id(&room).ok_or_else(|| anyhow!("invalid room"))?;
            let db = shared.db.lock().await;
            db.execute(
                "DELETE FROM memberships WHERE fingerprint = ?1 AND room = ?2",
                params![fp, room],
            )?;
        }
        ClientMsg::Publish { room, id, payload_b64 } => {
            let fp = require_fp(fingerprint)?;
            let room = clean_id(&room).ok_or_else(|| anyhow!("invalid room"))?;
            if id.is_empty() || id.len() > MAX_MSG_ID_LEN {
                bail!("invalid message id");
            }
            if payload_b64.len() > MAX_PAYLOAD_B64 {
                bail!("payload too large");
            }
            // Publishing implicitly joins the room, then fans out to everyone else in it.
            let members = {
                let db = shared.db.lock().await;
                add_membership(&db, &fp, &room)?;
                room_members(&db, &room)?
            };
            let out = ServerMsg::Message {
                room: room.clone(),
                id: id.clone(),
                payload_b64: payload_b64.clone(),
            };
            // Live fan-out under a single `conns` lock. `try_send` never awaits,
            // so the lock only covers in-memory work. A member counts as
            // delivered if any of its connections accepted the frame. `|` (not
            // `||`) makes sure every connection of that member gets a copy.
            let mut delivered = 0usize;
            let mut offline: Vec<String> = Vec::new();
            {
                let conns = shared.conns.lock().await;
                for member in members {
                    if member == fp {
                        continue;
                    }
                    let accepted = match conns.get(&member) {
                        Some(senders) => senders
                            .iter()
                            .fold(false, |acc, s| acc | s.try_send(out.clone()).is_ok()),
                        None => false,
                    };
                    if accepted {
                        delivered += 1;
                    } else {
                        offline.push(member);
                    }
                }
            }
            // Everyone not reached live gets a mailbox copy, written in one transaction.
            if !offline.is_empty() {
                let db = shared.db.lock().await;
                let txn = db.unchecked_transaction()?;
                for member in &offline {
                    enqueue(&txn, member, &room, &id, &payload_b64)?;
                }
                txn.commit()?;
            }
            let queued = offline.len();
            let _ = tx.send(ServerMsg::Sent { id, delivered, queued }).await;
        }
        ClientMsg::Fetch => {
            let fp = require_fp(fingerprint)?;
            flush_mailbox(&fp, tx, shared).await?;
        }
        ClientMsg::Ping => {
            let _ = tx.send(ServerMsg::Pong).await;
        }
    }
    Ok(())
}
/// Replays every queued mailbox row for `fp` through `tx` in row-id order, then
/// deletes the rows that were handed to the channel.
///
/// A row counts as delivered once the outbound channel accepts it, not once it
/// reaches the socket. If the channel closes part-way through, the remaining
/// rows stay queued for a later session.
///
/// # Errors
/// Propagates database errors from reading or deleting mailbox rows.
async fn flush_mailbox(fp: &str, tx: &Tx, shared: &Arc<Shared>) -> Result<()> {
    let pending = peek_mailbox(&*shared.db.lock().await, fp)?;
    let mut handed_off: Vec<i64> = Vec::with_capacity(pending.len());
    for (row_id, room, id, payload_b64) in pending {
        let frame = ServerMsg::Message {
            room,
            id,
            payload_b64,
        };
        if tx.send(frame).await.is_err() {
            // Receiver gone: leave this and all later rows in the mailbox.
            break;
        }
        handed_off.push(row_id);
    }
    if handed_off.is_empty() {
        return Ok(());
    }
    let db = shared.db.lock().await;
    delete_mailbox_ids(&db, &handed_off)
}
/// Returns an owned copy of the connection's proven fingerprint.
///
/// # Errors
/// Fails with "send hello first" when no identity has been established yet.
fn require_fp(fp: &Option<String>) -> Result<String> {
    match fp {
        Some(proven) => Ok(proven.clone()),
        None => Err(anyhow!("send hello first")),
    }
}
/// Derives the short public identity for an Ed25519 public key: the first 12
/// bytes of its SHA-256 digest as lowercase hex, grouped into dash-separated
/// blocks of four characters (e.g. `abcd-ef01-...`, six groups).
fn compute_fingerprint(public_key: &[u8; 32]) -> String {
    let digest = Sha256::digest(public_key);
    let hex_digits = hex::encode(&digest[..12]);
    let mut grouped = String::with_capacity(hex_digits.len() + hex_digits.len() / 4);
    for (index, digit) in hex_digits.chars().enumerate() {
        if index > 0 && index % 4 == 0 {
            grouped.push('-');
        }
        grouped.push(digit);
    }
    grouped
}
/// Verifies a client's answer to the connection challenge and returns the
/// fingerprint it proved.
///
/// Checks, in order:
/// 1. `pubkey_b64` decodes to exactly 32 bytes.
/// 2. `signature_b64` decodes to exactly 64 bytes.
/// 3. `claimed_fp` equals `compute_fingerprint` of that key.
/// 4. The key is a valid Ed25519 point.
/// 5. The signature strictly verifies over `RELAY_AUTH_DOMAIN || nonce`.
///
/// # Errors
/// Returns a descriptive error for the first check that fails.
fn verify_client_auth(
    claimed_fp: &str,
    pubkey_b64: &str,
    signature_b64: &str,
    nonce: &[u8; 32],
) -> Result<String> {
    /// Base64-decodes `text` into exactly `N` bytes; `what` names the field in errors.
    fn decode_exact<const N: usize>(text: &str, what: &str) -> Result<[u8; N]> {
        let raw = B64
            .decode(text)
            .map_err(|e| anyhow!("bad {} base64: {}", what, e))?;
        <[u8; N]>::try_from(raw.as_slice()).map_err(|_| anyhow!("{} must be {} bytes", what, N))
    }

    let pk: [u8; 32] = decode_exact(pubkey_b64, "pubkey")?;
    let sig: [u8; 64] = decode_exact(signature_b64, "signature")?;

    let derived_fp = compute_fingerprint(&pk);
    if derived_fp != claimed_fp {
        bail!("fingerprint does not match pubkey");
    }

    let verifying_key =
        VerifyingKey::from_bytes(&pk).map_err(|e| anyhow!("bad ed25519 pubkey: {e}"))?;
    let transcript = [RELAY_AUTH_DOMAIN, &nonce[..]].concat();
    verifying_key
        .verify_strict(&transcript, &Signature::from_bytes(&sig))
        .map_err(|e| anyhow!("signature verification failed: {e}"))?;
    Ok(derived_fp)
}
/// Normalizes a client-supplied room id.
///
/// Trims surrounding whitespace and accepts the result only if it is
/// 1..=`MAX_ID_LEN` bytes long and made solely of ASCII letters, digits, `-`,
/// `_`, `:` or `.`. Returns `None` otherwise.
fn clean_id(s: &str) -> Option<String> {
    let candidate = s.trim();
    let length_ok = (1..=MAX_ID_LEN).contains(&candidate.len());
    let charset_ok = candidate
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || b"-_:.".contains(&b));
    (length_ok && charset_ok).then(|| candidate.to_owned())
}
/// Current wall-clock time as whole seconds since the Unix epoch. Returns 0 if
/// the system clock reports a time before the epoch.
fn now_unix() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs() as i64,
        Err(_) => 0,
    }
}
/// Creates the relay's tables and indexes if they do not exist yet. Every
/// statement is `IF NOT EXISTS`, so this is safe to run on each start.
///
/// - `memberships`: one row per (fingerprint, room) pair.
/// - `mailbox`: payloads queued for recipients that were offline at publish
///   time, with `created_at` in Unix seconds for the TTL GC.
///
/// # Errors
/// Propagates any SQLite error from executing the schema batch.
fn migrate(c: &Connection) -> Result<()> {
    // `idx_memberships_room`: the (fingerprint, room) primary key leads with
    // `fingerprint`, so it cannot efficiently serve the `WHERE room = ?` lookup
    // that `room_members` runs on every publish. Without this index that lookup
    // scans the whole memberships table while holding the shared DB lock.
    c.execute_batch(
        "CREATE TABLE IF NOT EXISTS memberships (
            fingerprint TEXT NOT NULL,
            room TEXT NOT NULL,
            PRIMARY KEY (fingerprint, room)
        );
        CREATE INDEX IF NOT EXISTS idx_memberships_room ON memberships(room);
        CREATE TABLE IF NOT EXISTS mailbox (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            fingerprint TEXT NOT NULL,
            room TEXT NOT NULL,
            msg_id TEXT NOT NULL,
            payload_b64 TEXT NOT NULL,
            created_at INTEGER NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_mailbox_fp ON mailbox(fingerprint);",
    )?;
    Ok(())
}
fn add_membership(c: &Connection, fp: &str, room: &str) -> Result<()> {
c.execute(
"INSERT OR IGNORE INTO memberships(fingerprint, room) VALUES(?1, ?2)",
params![fp, room],
)?;
Ok(())
}
fn room_members(c: &Connection, room: &str) -> Result<Vec<String>> {
let mut stmt = c.prepare("SELECT fingerprint FROM memberships WHERE room = ?1")?;
let rows = stmt.query_map(params![room], |r| r.get::<_, String>(0))?;
Ok(rows.filter_map(|r| r.ok()).collect())
}
fn enqueue(c: &Connection, fp: &str, room: &str, id: &str, payload_b64: &str) -> Result<()> {
c.execute(
"INSERT INTO mailbox(fingerprint, room, msg_id, payload_b64, created_at)
VALUES(?1, ?2, ?3, ?4, ?5)",
params![fp, room, id, payload_b64, now_unix()],
)?;
c.execute(
"DELETE FROM mailbox WHERE fingerprint = ?1 AND id NOT IN (
SELECT id FROM mailbox WHERE fingerprint = ?1 ORDER BY id DESC LIMIT ?2
)",
params![fp, MAX_MAILBOX_PER_FP as i64],
)?;
Ok(())
}
fn peek_mailbox(c: &Connection, fp: &str) -> Result<Vec<(i64, String, String, String)>> {
let mut stmt = c.prepare(
"SELECT id, room, msg_id, payload_b64 FROM mailbox WHERE fingerprint = ?1 ORDER BY id ASC",
)?;
let rows = stmt.query_map(params![fp], |r| {
Ok((
r.get::<_, i64>(0)?,
r.get::<_, String>(1)?,
r.get::<_, String>(2)?,
r.get::<_, String>(3)?,
))
})?;
let mut out = Vec::new();
for row in rows {
out.push(row?);
}
Ok(out)
}
fn delete_mailbox_ids(c: &Connection, ids: &[i64]) -> Result<()> {
for id in ids {
c.execute("DELETE FROM mailbox WHERE id = ?1", params![id])?;
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::request_target;

    /// Asserts that `request_target` maps each raw request head to the expected path.
    fn assert_targets(cases: &[(&str, &str)]) {
        for &(head, want) in cases {
            assert_eq!(request_target(head), want, "request head: {head:?}");
        }
    }

    #[test]
    fn parses_plain_paths() {
        assert_targets(&[
            ("GET / HTTP/1.1\r\nHost: x.onion\r\n\r\n", "/"),
            ("GET /health HTTP/1.1\r\n\r\n", "/health"),
            ("GET /ws HTTP/1.1\r\n", "/ws"),
        ]);
    }

    #[test]
    fn strips_query_string() {
        assert_targets(&[
            ("GET /health?probe=1 HTTP/1.1\r\n", "/health"),
            ("GET /?x HTTP/1.1\r\n", "/"),
        ]);
    }

    #[test]
    fn other_methods_keep_their_target() {
        assert_targets(&[
            ("HEAD /health HTTP/1.1\r\n", "/health"),
            ("POST /anything HTTP/1.1\r\n", "/anything"),
        ]);
    }

    #[test]
    fn malformed_requests_fall_back_to_root() {
        assert_targets(&[("", "/"), ("GET", "/"), ("garbage\r\n", "/")]);
    }
}