1use base64::engine::general_purpose::STANDARD as B64;
18use base64::Engine;
19use futures::{SinkExt, StreamExt};
20use serde::{Deserialize, Serialize};
21use tokio::sync::mpsc;
22use tokio_tungstenite::tungstenite::Message as WsMessage;
23use tokio_tungstenite::WebSocketStream;
24use tracing::warn;
25
26use crate::error::{HuddleError, Result};
27
28#[derive(Debug, Serialize)]
30#[serde(tag = "type", rename_all = "snake_case")]
31enum ClientMsg {
32 Hello { fingerprint: String, rooms: Vec<String> },
33 Subscribe { room: String },
34 Unsubscribe { room: String },
35 Publish { room: String, id: String, payload_b64: String },
36 Fetch,
37 Ping,
38}
39
40#[derive(Debug, Deserialize)]
42#[serde(tag = "type", rename_all = "snake_case")]
43enum ServerMsg {
44 Ready,
48 Message { room: String, id: String, payload_b64: String },
49 Sent { id: String, delivered: usize, queued: usize },
50 Pong,
51 Error { message: String },
52}
53
/// Events delivered on the receiver returned by `ServerClient::connect`.
///
/// Derives `PartialEq`/`Eq` so callers can compare events directly
/// (e.g. in tests or when de-duplicating acknowledgements).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerEvent {
    /// The server reported `ready`.
    Ready,
    /// Server acknowledgement for the published message `id`, carrying the
    /// server-reported `delivered` and `queued` counts.
    Sent { id: String, delivered: usize, queued: usize },
    /// A message received for `room`; `payload` is the already base64-decoded body.
    Message { room: String, id: String, payload: Vec<u8> },
    /// The read side of the connection ended (end of stream, read error, or a
    /// close frame). This is the last event emitted for the connection.
    Disconnected,
}
70
71#[derive(Clone)]
74pub struct ServerClient {
75 out_tx: mpsc::UnboundedSender<ClientMsg>,
76}
77
78impl ServerClient {
79 pub async fn connect(
86 url: &str,
87 socks: Option<&str>,
88 fingerprint: String,
89 rooms: Vec<String>,
90 ) -> Result<(Self, mpsc::UnboundedReceiver<ServerEvent>)> {
91 match socks {
92 Some(proxy) => {
93 let proxy: std::net::SocketAddr = proxy
94 .parse()
95 .map_err(|e| HuddleError::Network(format!("bad socks address: {e}")))?;
96 let target = host_port_from_ws_url(url)?;
97 let stream = tokio_socks::tcp::Socks5Stream::connect(proxy, target.as_str())
98 .await
99 .map_err(|e| HuddleError::Network(format!("tor socks connect: {e}")))?;
100 let (ws, _resp) = tokio_tungstenite::client_async(url, stream)
101 .await
102 .map_err(|e| HuddleError::Network(format!("ws handshake: {e}")))?;
103 Ok(Self::spawn(ws, fingerprint, rooms))
104 }
105 None => {
106 let (ws, _resp) = tokio_tungstenite::connect_async(url)
107 .await
108 .map_err(|e| HuddleError::Network(format!("ws connect: {e}")))?;
109 Ok(Self::spawn(ws, fingerprint, rooms))
110 }
111 }
112 }
113
114 fn spawn<S>(
118 ws: WebSocketStream<S>,
119 fingerprint: String,
120 rooms: Vec<String>,
121 ) -> (Self, mpsc::UnboundedReceiver<ServerEvent>)
122 where
123 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
124 {
125 let (mut sink, mut stream) = ws.split();
126 let (out_tx, mut out_rx) = mpsc::unbounded_channel::<ClientMsg>();
127 let (ev_tx, ev_rx) = mpsc::unbounded_channel::<ServerEvent>();
128
129 let _ = out_tx.send(ClientMsg::Hello { fingerprint, rooms });
132
133 tokio::spawn(async move {
141 while let Some(msg) = out_rx.recv().await {
142 let json = match serde_json::to_string(&msg) {
143 Ok(j) => j,
144 Err(_) => continue,
145 };
146 if sink.send(WsMessage::Text(json.into())).await.is_err() {
147 return;
148 }
149 }
150 let _ = sink.close().await;
151 });
152
153 tokio::spawn(async move {
155 while let Some(frame) = stream.next().await {
156 let frame = match frame {
157 Ok(f) => f,
158 Err(_) => break,
159 };
160 let text = match frame {
161 WsMessage::Text(t) => t.as_str().to_string(),
162 WsMessage::Binary(b) => String::from_utf8_lossy(&b).into_owned(),
163 WsMessage::Close(_) => break,
164 _ => continue,
165 };
166 match serde_json::from_str::<ServerMsg>(&text) {
167 Ok(ServerMsg::Ready) => {
168 let _ = ev_tx.send(ServerEvent::Ready);
169 }
170 Ok(ServerMsg::Sent { id, delivered, queued }) => {
171 let _ = ev_tx.send(ServerEvent::Sent { id, delivered, queued });
172 }
173 Ok(ServerMsg::Message { room, id, payload_b64 }) => {
174 match B64.decode(payload_b64.as_bytes()) {
175 Ok(payload) => {
176 let _ = ev_tx.send(ServerEvent::Message { room, id, payload });
177 }
178 Err(e) => warn!(error = %e, "server sent undecodable payload"),
179 }
180 }
181 Ok(ServerMsg::Error { message }) => warn!(%message, "huddle-server error"),
182 Ok(ServerMsg::Pong) => {}
183 Err(e) => warn!(error = %e, "unparseable server message"),
184 }
185 }
186 let _ = ev_tx.send(ServerEvent::Disconnected);
187 });
188
189 (Self { out_tx }, ev_rx)
190 }
191
192 pub fn publish(&self, room: &str, id: &str, payload: &[u8]) -> Result<()> {
194 self.send(ClientMsg::Publish {
195 room: room.to_string(),
196 id: id.to_string(),
197 payload_b64: B64.encode(payload),
198 })
199 }
200
201 pub fn subscribe(&self, room: &str) -> Result<()> {
203 self.send(ClientMsg::Subscribe { room: room.to_string() })
204 }
205
206 pub fn unsubscribe(&self, room: &str) -> Result<()> {
207 self.send(ClientMsg::Unsubscribe { room: room.to_string() })
208 }
209
210 pub fn fetch(&self) -> Result<()> {
212 self.send(ClientMsg::Fetch)
213 }
214
215 pub fn ping(&self) -> Result<()> {
216 self.send(ClientMsg::Ping)
217 }
218
219 fn send(&self, msg: ClientMsg) -> Result<()> {
220 self.out_tx
221 .send(msg)
222 .map_err(|_| HuddleError::Network("server connection closed".to_string()))
223 }
224}
225
226fn host_port_from_ws_url(url: &str) -> Result<String> {
230 let rest = url
231 .strip_prefix("ws://")
232 .or_else(|| url.strip_prefix("wss://"))
233 .ok_or_else(|| HuddleError::Network(format!("expected ws:// url, got {url}")))?;
234 let authority = rest.split('/').next().unwrap_or(rest);
235 if authority.is_empty() {
236 return Err(HuddleError::Network(format!("no host in url: {url}")));
237 }
238 if authority.contains(':') {
239 Ok(authority.to_string())
240 } else {
241 Ok(format!("{authority}:80"))
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::host_port_from_ws_url;

    /// Table of ws/wss URLs and the `host:port` they must resolve to, plus a
    /// non-WebSocket scheme that must be rejected.
    #[test]
    fn parses_host_port() {
        let cases = [
            ("ws://abc.onion/ws", "abc.onion:80"),
            ("ws://127.0.0.1:8787/ws", "127.0.0.1:8787"),
            ("wss://h:443", "h:443"),
        ];
        for (url, expected) in cases {
            let got = host_port_from_ws_url(url).unwrap();
            assert_eq!(got, expected, "url: {url}");
        }

        let rejected = host_port_from_ws_url("http://x");
        assert!(rejected.is_err(), "non-ws scheme must be rejected");
    }
}