use base64::engine::general_purpose::STANDARD as B64;
use base64::Engine;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::WebSocketStream;
use tracing::warn;
use crate::error::{HuddleError, Result};
/// Frames the client sends to the huddle-server.
///
/// Serialized as a JSON object whose `type` field carries the snake_case
/// variant name (e.g. `{"type":"subscribe","room":"..."}`).
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientMsg {
    /// Initial greeting identifying this client and the rooms it wants.
    Hello {
        fingerprint: String,
        rooms: Vec<String>,
    },
    /// Start receiving messages for `room`.
    Subscribe {
        room: String,
    },
    /// Stop receiving messages for `room`.
    Unsubscribe {
        room: String,
    },
    /// Post a message to `room`; `payload_b64` is the base64-encoded body.
    Publish {
        room: String,
        id: String,
        payload_b64: String,
    },
    /// Request delivery of queued messages.
    Fetch,
    /// Liveness probe; the server is expected to answer with `pong`.
    Ping,
}
/// Frames received from the huddle-server.
///
/// Parsed from a JSON object whose `type` field names the snake_case variant.
/// Any other `type` fails to parse and is logged by the reader task.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMsg {
    /// The server accepted the session.
    Ready,
    /// A message for a room; `payload_b64` must be standard base64.
    Message {
        room: String,
        id: String,
        payload_b64: String,
    },
    /// Acknowledgement of a publish: how many recipients got it now vs. later.
    Sent {
        id: String,
        delivered: usize,
        queued: usize,
    },
    /// Reply to a client `ping`.
    Pong,
    /// Server-reported problem; only logged by this client.
    Error {
        message: String,
    },
}
/// Events surfaced to the owner of a `ServerClient` connection.
///
/// Delivered on the receiver returned by `ServerClient::connect`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerEvent {
    /// The server sent `ready`.
    Ready,
    /// Acknowledgement for a published message `id`: `delivered` recipients
    /// received it immediately and `queued` will receive it later.
    Sent { id: String, delivered: usize, queued: usize },
    /// A room message; `payload` is the already base64-decoded body.
    Message { room: String, id: String, payload: Vec<u8> },
    /// The websocket closed or failed. No further events follow.
    Disconnected,
}
/// Cheaply clonable handle for sending requests to a huddle-server connection.
///
/// Every clone feeds the same outgoing queue. The socket's writer task keeps
/// running until all clones are dropped or a write fails.
#[derive(Clone, Debug)]
pub struct ServerClient {
    /// Outgoing queue drained by the writer task spawned in `spawn`.
    out_tx: mpsc::UnboundedSender<ClientMsg>,
}
impl ServerClient {
pub async fn connect(
url: &str,
dial: &crate::network::transport::DialMode,
fingerprint: String,
rooms: Vec<String>,
) -> Result<(Self, mpsc::UnboundedReceiver<ServerEvent>)> {
use crate::network::transport::DialMode;
match dial {
DialMode::Socks5 { proxy } => {
let proxy: std::net::SocketAddr = proxy
.parse()
.map_err(|e| HuddleError::Network(format!("bad socks address: {e}")))?;
let target = host_port_from_ws_url(url)?;
let stream = tokio_socks::tcp::Socks5Stream::connect(proxy, target.as_str())
.await
.map_err(|e| HuddleError::Network(format!("tor socks connect: {e}")))?;
let (ws, _resp) = tokio_tungstenite::client_async(url, stream)
.await
.map_err(|e| HuddleError::Network(format!("ws handshake: {e}")))?;
Ok(Self::spawn(ws, fingerprint, rooms))
}
DialMode::Direct | DialMode::Tls { pinned_cert_der: None } => {
let (ws, _resp) = tokio_tungstenite::connect_async(url)
.await
.map_err(|e| HuddleError::Network(format!("ws connect: {e}")))?;
Ok(Self::spawn(ws, fingerprint, rooms))
}
DialMode::Tls {
pinned_cert_der: Some(_),
} => Err(HuddleError::Network(
"pinned-certificate wss is not supported in this build — use a real cert (Caddy/Let's Encrypt) or an onion door".into(),
)),
#[cfg(feature = "arti")]
DialMode::Arti { bridge } => {
let client =
crate::network::transport::arti_client(bridge.as_deref()).await?;
let hp = host_port_from_ws_url(url)?;
let (host, port_s) = hp.rsplit_once(':').ok_or_else(|| {
HuddleError::Network(format!("bad host:port from {url}"))
})?;
let port: u16 = port_s
.parse()
.map_err(|_| HuddleError::Network(format!("bad port in {url}")))?;
let stream = client
.connect((host, port))
.await
.map_err(|e| HuddleError::Network(format!("arti connect: {e}")))?;
let (ws, _resp) = tokio_tungstenite::client_async(url, stream)
.await
.map_err(|e| HuddleError::Network(format!("ws handshake: {e}")))?;
Ok(Self::spawn(ws, fingerprint, rooms))
}
}
}
fn spawn<S>(
ws: WebSocketStream<S>,
fingerprint: String,
rooms: Vec<String>,
) -> (Self, mpsc::UnboundedReceiver<ServerEvent>)
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
let (mut sink, mut stream) = ws.split();
let (out_tx, mut out_rx) = mpsc::unbounded_channel::<ClientMsg>();
let (ev_tx, ev_rx) = mpsc::unbounded_channel::<ServerEvent>();
let _ = out_tx.send(ClientMsg::Hello { fingerprint, rooms });
tokio::spawn(async move {
while let Some(msg) = out_rx.recv().await {
let json = match serde_json::to_string(&msg) {
Ok(j) => j,
Err(_) => continue,
};
if sink.send(WsMessage::Text(json.into())).await.is_err() {
return;
}
}
let _ = sink.close().await;
});
tokio::spawn(async move {
while let Some(frame) = stream.next().await {
let frame = match frame {
Ok(f) => f,
Err(_) => break,
};
let text = match frame {
WsMessage::Text(t) => t.as_str().to_string(),
WsMessage::Binary(b) => String::from_utf8_lossy(&b).into_owned(),
WsMessage::Close(_) => break,
_ => continue,
};
match serde_json::from_str::<ServerMsg>(&text) {
Ok(ServerMsg::Ready) => {
let _ = ev_tx.send(ServerEvent::Ready);
}
Ok(ServerMsg::Sent { id, delivered, queued }) => {
let _ = ev_tx.send(ServerEvent::Sent { id, delivered, queued });
}
Ok(ServerMsg::Message { room, id, payload_b64 }) => {
match B64.decode(payload_b64.as_bytes()) {
Ok(payload) => {
let _ = ev_tx.send(ServerEvent::Message { room, id, payload });
}
Err(e) => warn!(error = %e, "server sent undecodable payload"),
}
}
Ok(ServerMsg::Error { message }) => warn!(%message, "huddle-server error"),
Ok(ServerMsg::Pong) => {}
Err(e) => warn!(error = %e, "unparseable server message"),
}
}
let _ = ev_tx.send(ServerEvent::Disconnected);
});
(Self { out_tx }, ev_rx)
}
pub fn publish(&self, room: &str, id: &str, payload: &[u8]) -> Result<()> {
self.send(ClientMsg::Publish {
room: room.to_string(),
id: id.to_string(),
payload_b64: B64.encode(payload),
})
}
pub fn subscribe(&self, room: &str) -> Result<()> {
self.send(ClientMsg::Subscribe { room: room.to_string() })
}
pub fn unsubscribe(&self, room: &str) -> Result<()> {
self.send(ClientMsg::Unsubscribe { room: room.to_string() })
}
pub fn fetch(&self) -> Result<()> {
self.send(ClientMsg::Fetch)
}
pub fn ping(&self) -> Result<()> {
self.send(ClientMsg::Ping)
}
fn send(&self, msg: ClientMsg) -> Result<()> {
self.out_tx
.send(msg)
.map_err(|_| HuddleError::Network("server connection closed".to_string()))
}
}
/// Extracts the `host:port` TCP target from a `ws://` or `wss://` URL.
///
/// Parsing rules:
/// - The authority ends at the first `/`, `?` or `#`.
/// - Any `userinfo@` prefix is dropped.
/// - A missing or empty port defaults to 80 (`ws`) or 443 (`wss`).
/// - Bracketed IPv6 literals are kept intact: `ws://[::1]/x` → `[::1]:80`.
///
/// # Errors
/// `HuddleError::Network` when the scheme is not `ws`/`wss` or the host is empty.
fn host_port_from_ws_url(url: &str) -> Result<String> {
    let (rest, default_port) = if let Some(r) = url.strip_prefix("wss://") {
        (r, 443)
    } else if let Some(r) = url.strip_prefix("ws://") {
        (r, 80)
    } else {
        return Err(HuddleError::Network(format!("expected ws:// url, got {url}")));
    };
    // The authority ends at the first path, query, or fragment delimiter.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or(rest);
    // Credentials are not part of the TCP target.
    let authority = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host_port)| host_port);
    // In an IPv6 literal, colons inside `[...]` are not the port separator.
    let search_from = authority.rfind(']').map_or(0, |i| i + 1);
    let (host, port) = match authority[search_from..].rfind(':') {
        Some(i) => {
            let colon = search_from + i;
            (&authority[..colon], &authority[colon + 1..])
        }
        None => (authority, ""),
    };
    if host.is_empty() {
        return Err(HuddleError::Network(format!("no host in url: {url}")));
    }
    if port.is_empty() {
        Ok(format!("{host}:{default_port}"))
    } else {
        Ok(format!("{host}:{port}"))
    }
}
#[cfg(test)]
mod tests {
    use super::host_port_from_ws_url as parse;

    /// Table of well-formed URLs and the `host:port` each must resolve to,
    /// plus one rejected scheme.
    #[test]
    fn parses_host_port() {
        let cases = [
            ("ws://abc.onion/ws", "abc.onion:80"),
            ("ws://127.0.0.1:8787/ws", "127.0.0.1:8787"),
            ("wss://h:443", "h:443"),
            ("wss://relay.example/ws", "relay.example:443"),
        ];
        for (url, want) in cases {
            assert_eq!(parse(url).unwrap(), want, "url = {url}");
        }
        assert!(parse("http://x").is_err());
    }
}