1use std::collections::HashMap;
2
3use futures_util::{SinkExt, StreamExt};
4use pushwire_core::{ChannelKind, Frame, SystemOp};
5use reqwest::Client as HttpClient;
6use tokio::sync::mpsc;
7use tokio::task::JoinHandle;
8use tokio_tungstenite::tungstenite::Message as WsMessage;
9use tracing::{debug, warn};
10use uuid::Uuid;
11
12use crate::session::ConnectError;
13
/// Strategy used by `connect_with_preference` to choose a transport.
///
/// The enum is `#[non_exhaustive]`, so downstream `match` expressions need a
/// wildcard arm.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransportPreference {
    /// Attempt WebSocket first; if that fails, fall back to SSE.
    WsFirst,
    /// Attempt SSE first; if that fails, fall back to WebSocket.
    SseFirst,
    /// Use WebSocket only, with no fallback.
    WsOnly,
    /// Use SSE only, with no fallback.
    SseOnly,
}
27
28#[derive(Debug)]
30pub(crate) enum OutboundMsg<C: ChannelKind> {
31 Frame(Frame<C>),
32 System(SystemOp<C>),
33 Close,
34}
35
36#[derive(Debug)]
38pub(crate) enum InboundMsg<C: ChannelKind> {
39 Frame(Frame<C>),
40 System(SystemOp<C>),
41 Closed,
42}
43
44pub(crate) enum ActiveTransport<C: ChannelKind> {
46 WebSocket {
47 outbound_tx: mpsc::Sender<OutboundMsg<C>>,
48 reader_handle: JoinHandle<()>,
49 writer_handle: JoinHandle<()>,
50 },
51 Sse {
52 http: HttpClient,
53 ack_url: String,
54 client_id: Uuid,
55 reader_handle: JoinHandle<()>,
56 },
57}
58
59impl<C: ChannelKind> ActiveTransport<C> {
60 pub(crate) async fn send_frame(
62 &self,
63 frame: Frame<C>,
64 ) -> Result<(), crate::session::SendError> {
65 match self {
66 ActiveTransport::WebSocket { outbound_tx, .. } => outbound_tx
67 .send(OutboundMsg::Frame(frame))
68 .await
69 .map_err(|_| crate::session::SendError::ChannelClosed),
70 ActiveTransport::Sse { .. } => Err(crate::session::SendError::NotConnected),
71 }
72 }
73
74 pub(crate) async fn send_system(
76 &self,
77 op: SystemOp<C>,
78 ) -> Result<(), crate::session::SendError> {
79 match self {
80 ActiveTransport::WebSocket { outbound_tx, .. } => outbound_tx
81 .send(OutboundMsg::System(op))
82 .await
83 .map_err(|_| crate::session::SendError::ChannelClosed),
84 ActiveTransport::Sse {
85 http,
86 ack_url,
87 client_id,
88 ..
89 } => {
90 if let SystemOp::Ack { channel, cursor } = &op {
92 let body = serde_json::json!({
93 "client_id": client_id,
94 "channel": channel,
95 "cursor": cursor,
96 });
97 let _ = http.post(ack_url).json(&body).send().await;
98 Ok(())
99 } else {
100 warn!("system op not supported in SSE mode, dropping");
102 Ok(())
103 }
104 }
105 }
106 }
107
108 pub(crate) async fn close(self) {
110 match self {
111 ActiveTransport::WebSocket {
112 outbound_tx,
113 reader_handle,
114 writer_handle,
115 } => {
116 let _ = outbound_tx.send(OutboundMsg::Close).await;
117 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
119 reader_handle.abort();
120 writer_handle.abort();
121 }
122 ActiveTransport::Sse { reader_handle, .. } => {
123 reader_handle.abort();
124 }
125 }
126 }
127}
128
129pub(crate) async fn connect_ws<C: ChannelKind>(
135 url: &str,
136 client_id: Uuid,
137 token: Option<&str>,
138 capabilities: &[C],
139 resume_cursors: HashMap<C, u64>,
140) -> Result<(ActiveTransport<C>, mpsc::Receiver<InboundMsg<C>>), ConnectError> {
141 let ws_url = http_to_ws_url(url);
143 let rps_url = format!("{ws_url}/rps");
144
145 let (ws_stream, _response) = tokio_tungstenite::connect_async(&rps_url)
146 .await
147 .map_err(|e| ConnectError::Transport(format!("WebSocket connect failed: {e}")))?;
148
149 let (mut ws_tx, mut ws_rx) = ws_stream.split();
150
151 let global_cursor = resume_cursors.values().copied().max();
153 let auth = SystemOp::<C>::Auth {
154 client_id,
155 token: token.map(String::from),
156 capabilities: capabilities.to_vec(),
157 resume_cursor: global_cursor,
158 resume_cursors: resume_cursors.clone(),
159 };
160 let auth_json =
161 serde_json::to_string(&auth).map_err(|e| ConnectError::Transport(e.to_string()))?;
162 ws_tx
163 .send(WsMessage::Text(auth_json))
164 .await
165 .map_err(|e| ConnectError::Transport(format!("failed to send auth: {e}")))?;
166
167 let auth_reply = ws_rx
169 .next()
170 .await
171 .ok_or(ConnectError::Transport(
172 "connection closed before auth reply".into(),
173 ))?
174 .map_err(|e| ConnectError::Transport(format!("auth reply read error: {e}")))?;
175
176 let auth_ok: SystemOp<C> = match auth_reply {
177 WsMessage::Text(text) => serde_json::from_str(&text)
178 .map_err(|e| ConnectError::AuthRejected(format!("invalid auth reply: {e}")))?,
179 WsMessage::Close(frame) => {
180 let reason = frame
181 .map(|f| f.reason.to_string())
182 .unwrap_or_else(|| "unknown".into());
183 return Err(ConnectError::AuthRejected(reason));
184 }
185 other => {
186 return Err(ConnectError::Transport(format!(
187 "unexpected auth reply type: {other:?}"
188 )));
189 }
190 };
191
192 match auth_ok {
193 SystemOp::AuthOk { .. } => {
194 debug!(?client_id, "auth handshake complete");
195 }
196 SystemOp::Error { message } => return Err(ConnectError::AuthRejected(message)),
197 other => {
198 return Err(ConnectError::AuthRejected(format!(
199 "expected AuthOk, got {other:?}"
200 )));
201 }
202 }
203
204 let (inbound_tx, inbound_rx) = mpsc::channel::<InboundMsg<C>>(256);
206 let (outbound_tx, mut outbound_rx) = mpsc::channel::<OutboundMsg<C>>(64);
207
208 let reader_inbound_tx = inbound_tx.clone();
210 let reader_handle = tokio::spawn(async move {
211 while let Some(msg) = ws_rx.next().await {
212 match msg {
213 Ok(WsMessage::Text(text)) => {
214 if let Ok(frame) = serde_json::from_str::<Frame<C>>(&text) {
217 if frame.channel.is_system() {
218 if let Ok(op) =
220 serde_json::from_value::<SystemOp<C>>(frame.payload.clone())
221 {
222 if reader_inbound_tx
223 .send(InboundMsg::System(op))
224 .await
225 .is_err()
226 {
227 break;
228 }
229 } else {
230 if reader_inbound_tx
232 .send(InboundMsg::Frame(frame))
233 .await
234 .is_err()
235 {
236 break;
237 }
238 }
239 } else if reader_inbound_tx
240 .send(InboundMsg::Frame(frame))
241 .await
242 .is_err()
243 {
244 break;
245 }
246 } else if let Ok(op) = serde_json::from_str::<SystemOp<C>>(&text) {
247 if reader_inbound_tx
250 .send(InboundMsg::System(op))
251 .await
252 .is_err()
253 {
254 break;
255 }
256 } else {
257 warn!("failed to parse inbound WS message");
258 }
259 }
260 Ok(WsMessage::Close(_)) => {
261 let _ = reader_inbound_tx.send(InboundMsg::Closed).await;
262 break;
263 }
264 Ok(WsMessage::Ping(_) | WsMessage::Pong(_)) => {
265 }
267 Ok(_) => {}
268 Err(e) => {
269 warn!(?e, "WS read error");
270 let _ = reader_inbound_tx.send(InboundMsg::Closed).await;
271 break;
272 }
273 }
274 }
275 });
276
277 let writer_handle = tokio::spawn(async move {
279 while let Some(msg) = outbound_rx.recv().await {
280 let ws_msg = match msg {
281 OutboundMsg::Frame(frame) => match serde_json::to_string(&frame) {
282 Ok(json) => WsMessage::Text(json),
283 Err(e) => {
284 warn!(?e, "failed to serialize outbound frame");
285 continue;
286 }
287 },
288 OutboundMsg::System(op) => match serde_json::to_string(&op) {
289 Ok(json) => WsMessage::Text(json),
290 Err(e) => {
291 warn!(?e, "failed to serialize outbound system op");
292 continue;
293 }
294 },
295 OutboundMsg::Close => {
296 let _ = ws_tx.send(WsMessage::Close(None)).await;
297 break;
298 }
299 };
300 if ws_tx.send(ws_msg).await.is_err() {
301 break;
302 }
303 }
304 });
305
306 let transport = ActiveTransport::WebSocket {
307 outbound_tx,
308 reader_handle,
309 writer_handle,
310 };
311
312 Ok((transport, inbound_rx))
313}
314
315pub(crate) async fn connect_sse<C: ChannelKind>(
321 url: &str,
322 client_id: Uuid,
323 token: Option<&str>,
324 capabilities: &[C],
325 resume_cursor: Option<u64>,
326) -> Result<(ActiveTransport<C>, mpsc::Receiver<InboundMsg<C>>), ConnectError> {
327 let http = HttpClient::new();
328
329 let mut sse_url = format!("{url}/rps/sse?client_id={client_id}");
330 if let Some(tok) = token {
331 sse_url.push_str(&format!("&token={tok}"));
332 }
333 if !capabilities.is_empty() {
334 let caps: Vec<&str> = capabilities.iter().map(|c| c.name()).collect();
335 sse_url.push_str(&format!("&capabilities={}", caps.join(",")));
336 sse_url.push_str(&format!("&channels={}", caps.join(",")));
337 }
338 if let Some(cursor) = resume_cursor {
339 sse_url.push_str(&format!("&resume_cursor={cursor}"));
340 }
341
342 let response = http
343 .get(&sse_url)
344 .send()
345 .await
346 .map_err(|e| ConnectError::Transport(format!("SSE connect failed: {e}")))?;
347
348 if !response.status().is_success() {
349 return Err(ConnectError::AuthRejected(format!(
350 "SSE returned {}",
351 response.status()
352 )));
353 }
354
355 let (inbound_tx, inbound_rx) = mpsc::channel::<InboundMsg<C>>(256);
356 let ack_url = format!("{url}/rps/ack");
357
358 let reader_handle = tokio::spawn(async move {
360 let mut stream = response.bytes_stream();
361 let mut buffer = String::new();
362 let mut event_type = String::new();
363 let mut data_lines = Vec::<String>::new();
364
365 while let Some(chunk) = stream.next().await {
366 let bytes = match chunk {
367 Ok(b) => b,
368 Err(e) => {
369 warn!(?e, "SSE stream error");
370 let _ = inbound_tx.send(InboundMsg::Closed).await;
371 break;
372 }
373 };
374
375 buffer.push_str(&String::from_utf8_lossy(&bytes));
376
377 while let Some(newline_pos) = buffer.find('\n') {
379 let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
380 buffer = buffer[newline_pos + 1..].to_string();
381
382 if line.is_empty() {
383 if !data_lines.is_empty() && (event_type == "frame" || event_type.is_empty()) {
385 let data = data_lines.join("\n");
386 if let Ok(frame) = serde_json::from_str::<Frame<C>>(&data) {
387 if frame.channel.is_system() {
388 if let Ok(op) =
389 serde_json::from_value::<SystemOp<C>>(frame.payload.clone())
390 {
391 let _ = inbound_tx.send(InboundMsg::System(op)).await;
392 } else {
393 let _ = inbound_tx.send(InboundMsg::Frame(frame)).await;
394 }
395 } else {
396 let _ = inbound_tx.send(InboundMsg::Frame(frame)).await;
397 }
398 }
399 }
400 event_type.clear();
401 data_lines.clear();
402 } else if let Some(value) = line.strip_prefix("event:") {
403 event_type = value.trim().to_string();
404 } else if let Some(value) = line.strip_prefix("data:") {
405 data_lines.push(value.trim_start().to_string());
406 }
407 }
409 }
410 });
411
412 let transport = ActiveTransport::Sse {
413 http: HttpClient::new(),
414 ack_url,
415 client_id,
416 reader_handle,
417 };
418
419 Ok((transport, inbound_rx))
420}
421
422pub(crate) async fn connect_with_preference<C: ChannelKind>(
428 preference: TransportPreference,
429 url: &str,
430 client_id: Uuid,
431 token: Option<&str>,
432 capabilities: &[C],
433 resume_cursors: HashMap<C, u64>,
434) -> Result<(ActiveTransport<C>, mpsc::Receiver<InboundMsg<C>>), ConnectError> {
435 let global_cursor = resume_cursors.values().copied().max();
436
437 match preference {
438 TransportPreference::WsOnly => {
439 connect_ws(url, client_id, token, capabilities, resume_cursors).await
440 }
441 TransportPreference::SseOnly => {
442 connect_sse(url, client_id, token, capabilities, global_cursor).await
443 }
444 TransportPreference::WsFirst => {
445 match connect_ws(url, client_id, token, capabilities, resume_cursors.clone()).await {
446 Ok(result) => Ok(result),
447 Err(ws_err) => {
448 debug!(?ws_err, "WS failed, falling back to SSE");
449 connect_sse(url, client_id, token, capabilities, global_cursor).await
450 }
451 }
452 }
453 TransportPreference::SseFirst => {
454 match connect_sse(url, client_id, token, capabilities, global_cursor).await {
455 Ok(result) => Ok(result),
456 Err(sse_err) => {
457 debug!(?sse_err, "SSE failed, falling back to WS");
458 connect_ws(url, client_id, token, capabilities, resume_cursors).await
459 }
460 }
461 }
462 }
463}
464
/// Maps a base URL onto its WebSocket equivalent.
///
/// `http://` becomes `ws://` and `https://` becomes `wss://`. URLs that
/// already use `ws://` or `wss://` are returned unchanged. Anything else is
/// treated as scheme-less and gets `ws://` prepended.
fn http_to_ws_url(url: &str) -> String {
    // (HTTP scheme prefix, WebSocket scheme prefix)
    const SCHEME_REWRITES: [(&str, &str); 2] = [("http://", "ws://"), ("https://", "wss://")];

    for (http_scheme, ws_scheme) in SCHEME_REWRITES {
        if let Some(tail) = url.strip_prefix(http_scheme) {
            return [ws_scheme, tail].concat();
        }
    }

    let already_ws = ["ws://", "wss://"]
        .iter()
        .any(|scheme| url.starts_with(scheme));
    if already_ws {
        url.to_owned()
    } else {
        ["ws://", url].concat()
    }
}