1use base64::engine::general_purpose::STANDARD as B64;
18use base64::Engine;
19use futures::{SinkExt, StreamExt};
20use serde::{Deserialize, Serialize};
21use tokio::sync::mpsc;
22use tokio_tungstenite::tungstenite::Message as WsMessage;
23use tokio_tungstenite::WebSocketStream;
24use tracing::warn;
25
26use crate::error::{HuddleError, Result};
27
28#[derive(Debug, Serialize)]
30#[serde(tag = "type", rename_all = "snake_case")]
31enum ClientMsg {
32 Hello { fingerprint: String, rooms: Vec<String> },
33 Subscribe { room: String },
34 Unsubscribe { room: String },
35 Publish { room: String, id: String, payload_b64: String },
36 Fetch,
37 Ping,
38}
39
40#[derive(Debug, Deserialize)]
42#[serde(tag = "type", rename_all = "snake_case")]
43enum ServerMsg {
44 Ready,
48 Message { room: String, id: String, payload_b64: String },
49 Sent { id: String, delivered: usize, queued: usize },
50 Pong,
51 Error { message: String },
52}
53
/// Events delivered on the receiver returned by `ServerClient::connect`.
///
/// Derives `PartialEq`/`Eq` so callers can compare events directly (e.g. in
/// tests or when deduplicating acknowledgements).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerEvent {
    /// Forwarded from the server's `ready` frame (presumably its answer to the
    /// initial `hello` — confirm against the server).
    Ready,
    /// Acknowledgement for the publish with message id `id`. `delivered` and
    /// `queued` are forwarded verbatim from the server's `sent` frame
    /// (presumably live deliveries vs. stored for later — confirm).
    Sent { id: String, delivered: usize, queued: usize },
    /// A message for `room`, with its base64 payload already decoded.
    Message { room: String, id: String, payload: Vec<u8> },
    /// Emitted once by the reader after it stops (close frame, read error, or
    /// end of stream). No further events follow.
    Disconnected,
}
70
71#[derive(Clone)]
74pub struct ServerClient {
75 out_tx: mpsc::UnboundedSender<ClientMsg>,
76}
77
78impl ServerClient {
79 pub async fn connect(
87 url: &str,
88 dial: &crate::network::transport::DialMode,
89 fingerprint: String,
90 rooms: Vec<String>,
91 ) -> Result<(Self, mpsc::UnboundedReceiver<ServerEvent>)> {
92 use crate::network::transport::DialMode;
93 match dial {
94 DialMode::Socks5 { proxy } => {
95 let proxy: std::net::SocketAddr = proxy
96 .parse()
97 .map_err(|e| HuddleError::Network(format!("bad socks address: {e}")))?;
98 let target = host_port_from_ws_url(url)?;
99 let stream = tokio_socks::tcp::Socks5Stream::connect(proxy, target.as_str())
100 .await
101 .map_err(|e| HuddleError::Network(format!("tor socks connect: {e}")))?;
102 let (ws, _resp) = tokio_tungstenite::client_async(url, stream)
103 .await
104 .map_err(|e| HuddleError::Network(format!("ws handshake: {e}")))?;
105 Ok(Self::spawn(ws, fingerprint, rooms))
106 }
107 DialMode::Direct | DialMode::Tls { pinned_cert_der: None } => {
111 let (ws, _resp) = tokio_tungstenite::connect_async(url)
112 .await
113 .map_err(|e| HuddleError::Network(format!("ws connect: {e}")))?;
114 Ok(Self::spawn(ws, fingerprint, rooms))
115 }
116 DialMode::Tls {
121 pinned_cert_der: Some(_),
122 } => Err(HuddleError::Network(
123 "pinned-certificate wss is not supported in this build — use a real cert (Caddy/Let's Encrypt) or an onion door".into(),
124 )),
125 #[cfg(feature = "arti")]
129 DialMode::Arti { bridge } => {
130 let client =
131 crate::network::transport::arti_client(bridge.as_deref()).await?;
132 let hp = host_port_from_ws_url(url)?;
133 let (host, port_s) = hp.rsplit_once(':').ok_or_else(|| {
134 HuddleError::Network(format!("bad host:port from {url}"))
135 })?;
136 let port: u16 = port_s
137 .parse()
138 .map_err(|_| HuddleError::Network(format!("bad port in {url}")))?;
139 let stream = client
140 .connect((host, port))
141 .await
142 .map_err(|e| HuddleError::Network(format!("arti connect: {e}")))?;
143 let (ws, _resp) = tokio_tungstenite::client_async(url, stream)
144 .await
145 .map_err(|e| HuddleError::Network(format!("ws handshake: {e}")))?;
146 Ok(Self::spawn(ws, fingerprint, rooms))
147 }
148 }
149 }
150
151 fn spawn<S>(
155 ws: WebSocketStream<S>,
156 fingerprint: String,
157 rooms: Vec<String>,
158 ) -> (Self, mpsc::UnboundedReceiver<ServerEvent>)
159 where
160 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
161 {
162 let (mut sink, mut stream) = ws.split();
163 let (out_tx, mut out_rx) = mpsc::unbounded_channel::<ClientMsg>();
164 let (ev_tx, ev_rx) = mpsc::unbounded_channel::<ServerEvent>();
165
166 let _ = out_tx.send(ClientMsg::Hello { fingerprint, rooms });
169
170 tokio::spawn(async move {
178 while let Some(msg) = out_rx.recv().await {
179 let json = match serde_json::to_string(&msg) {
180 Ok(j) => j,
181 Err(_) => continue,
182 };
183 if sink.send(WsMessage::Text(json.into())).await.is_err() {
184 return;
185 }
186 }
187 let _ = sink.close().await;
188 });
189
190 tokio::spawn(async move {
192 while let Some(frame) = stream.next().await {
193 let frame = match frame {
194 Ok(f) => f,
195 Err(_) => break,
196 };
197 let text = match frame {
198 WsMessage::Text(t) => t.as_str().to_string(),
199 WsMessage::Binary(b) => String::from_utf8_lossy(&b).into_owned(),
200 WsMessage::Close(_) => break,
201 _ => continue,
202 };
203 match serde_json::from_str::<ServerMsg>(&text) {
204 Ok(ServerMsg::Ready) => {
205 let _ = ev_tx.send(ServerEvent::Ready);
206 }
207 Ok(ServerMsg::Sent { id, delivered, queued }) => {
208 let _ = ev_tx.send(ServerEvent::Sent { id, delivered, queued });
209 }
210 Ok(ServerMsg::Message { room, id, payload_b64 }) => {
211 match B64.decode(payload_b64.as_bytes()) {
212 Ok(payload) => {
213 let _ = ev_tx.send(ServerEvent::Message { room, id, payload });
214 }
215 Err(e) => warn!(error = %e, "server sent undecodable payload"),
216 }
217 }
218 Ok(ServerMsg::Error { message }) => warn!(%message, "huddle-server error"),
219 Ok(ServerMsg::Pong) => {}
220 Err(e) => warn!(error = %e, "unparseable server message"),
221 }
222 }
223 let _ = ev_tx.send(ServerEvent::Disconnected);
224 });
225
226 (Self { out_tx }, ev_rx)
227 }
228
229 pub fn publish(&self, room: &str, id: &str, payload: &[u8]) -> Result<()> {
231 self.send(ClientMsg::Publish {
232 room: room.to_string(),
233 id: id.to_string(),
234 payload_b64: B64.encode(payload),
235 })
236 }
237
238 pub fn subscribe(&self, room: &str) -> Result<()> {
240 self.send(ClientMsg::Subscribe { room: room.to_string() })
241 }
242
243 pub fn unsubscribe(&self, room: &str) -> Result<()> {
244 self.send(ClientMsg::Unsubscribe { room: room.to_string() })
245 }
246
247 pub fn fetch(&self) -> Result<()> {
249 self.send(ClientMsg::Fetch)
250 }
251
252 pub fn ping(&self) -> Result<()> {
253 self.send(ClientMsg::Ping)
254 }
255
256 fn send(&self, msg: ClientMsg) -> Result<()> {
257 self.out_tx
258 .send(msg)
259 .map_err(|_| HuddleError::Network("server connection closed".to_string()))
260 }
261}
262
263fn host_port_from_ws_url(url: &str) -> Result<String> {
267 let (rest, default_port) = if let Some(r) = url.strip_prefix("wss://") {
268 (r, 443)
269 } else if let Some(r) = url.strip_prefix("ws://") {
270 (r, 80)
271 } else {
272 return Err(HuddleError::Network(format!("expected ws:// url, got {url}")));
273 };
274 let authority = rest.split('/').next().unwrap_or(rest);
275 if authority.is_empty() {
276 return Err(HuddleError::Network(format!("no host in url: {url}")));
277 }
278 if authority.contains(':') {
279 Ok(authority.to_string())
280 } else {
281 Ok(format!("{authority}:{default_port}"))
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::host_port_from_ws_url;

    /// Table-driven check of the dial-target extraction, plus scheme rejection.
    #[test]
    fn parses_host_port() {
        let cases = [
            ("ws://abc.onion/ws", "abc.onion:80"),
            ("ws://127.0.0.1:8787/ws", "127.0.0.1:8787"),
            ("wss://h:443", "h:443"),
            ("wss://relay.example/ws", "relay.example:443"),
        ];
        for (url, expected) in cases {
            let got = host_port_from_ws_url(url).unwrap();
            assert_eq!(got, expected, "url: {url}");
        }

        // Non-websocket schemes are refused.
        assert!(host_port_from_ws_url("http://x").is_err());
    }
}