// pushwire_client/connection.rs

use std::collections::HashMap;
use std::time::Duration;

use futures_util::{SinkExt, StreamExt};
use pushwire_core::{ChannelKind, Frame, SystemOp};
use reqwest::Client as HttpClient;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tracing::{debug, warn};
use uuid::Uuid;

use crate::session::ConnectError;

/// Which transport the client uses when connecting.
///
/// The `*First` variants attempt one transport and fall back to the other if
/// the first attempt fails; the `*Only` variants never fall back.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransportPreference {
    /// WebSocket first; on failure, fall back to SSE.
    WsFirst,
    /// SSE first; on failure, fall back to WebSocket.
    SseFirst,
    /// WebSocket exclusively; a failed connection or rejected upgrade is an error.
    WsOnly,
    /// SSE exclusively; the only client→server traffic is ACKs, sent as
    /// `POST {url}/rps/ack`.
    SseOnly,
}

28/// Message sent from session to the writer task.
29#[derive(Debug)]
30pub(crate) enum OutboundMsg<C: ChannelKind> {
31    Frame(Frame<C>),
32    System(SystemOp<C>),
33    Close,
34}
35
36/// Message received from transport and forwarded to session processor.
37#[derive(Debug)]
38pub(crate) enum InboundMsg<C: ChannelKind> {
39    Frame(Frame<C>),
40    System(SystemOp<C>),
41    Closed,
42}
43
44/// Active transport handle. Owns background tasks for I/O.
45pub(crate) enum ActiveTransport<C: ChannelKind> {
46    WebSocket {
47        outbound_tx: mpsc::Sender<OutboundMsg<C>>,
48        reader_handle: JoinHandle<()>,
49        writer_handle: JoinHandle<()>,
50    },
51    Sse {
52        http: HttpClient,
53        ack_url: String,
54        client_id: Uuid,
55        reader_handle: JoinHandle<()>,
56    },
57}
58
59impl<C: ChannelKind> ActiveTransport<C> {
60    /// Send a frame to the server (WebSocket only).
61    pub(crate) async fn send_frame(
62        &self,
63        frame: Frame<C>,
64    ) -> Result<(), crate::session::SendError> {
65        match self {
66            ActiveTransport::WebSocket { outbound_tx, .. } => outbound_tx
67                .send(OutboundMsg::Frame(frame))
68                .await
69                .map_err(|_| crate::session::SendError::ChannelClosed),
70            ActiveTransport::Sse { .. } => Err(crate::session::SendError::NotConnected),
71        }
72    }
73
74    /// Send a system op to the server.
75    pub(crate) async fn send_system(
76        &self,
77        op: SystemOp<C>,
78    ) -> Result<(), crate::session::SendError> {
79        match self {
80            ActiveTransport::WebSocket { outbound_tx, .. } => outbound_tx
81                .send(OutboundMsg::System(op))
82                .await
83                .map_err(|_| crate::session::SendError::ChannelClosed),
84            ActiveTransport::Sse {
85                http,
86                ack_url,
87                client_id,
88                ..
89            } => {
90                // SSE mode: only ACKs are supported via POST.
91                if let SystemOp::Ack { channel, cursor } = &op {
92                    let body = serde_json::json!({
93                        "client_id": client_id,
94                        "channel": channel,
95                        "cursor": cursor,
96                    });
97                    let _ = http.post(ack_url).json(&body).send().await;
98                    Ok(())
99                } else {
100                    // Other system ops not supported in SSE mode.
101                    warn!("system op not supported in SSE mode, dropping");
102                    Ok(())
103                }
104            }
105        }
106    }
107
108    /// Send close signal and abort tasks.
109    pub(crate) async fn close(self) {
110        match self {
111            ActiveTransport::WebSocket {
112                outbound_tx,
113                reader_handle,
114                writer_handle,
115            } => {
116                let _ = outbound_tx.send(OutboundMsg::Close).await;
117                // Give writer a moment to send the close frame, then abort.
118                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
119                reader_handle.abort();
120                writer_handle.abort();
121            }
122            ActiveTransport::Sse { reader_handle, .. } => {
123                reader_handle.abort();
124            }
125        }
126    }
127}
128
// ---------------------------------------------------------------------------
// WebSocket transport
// ---------------------------------------------------------------------------

133/// Connect via WebSocket, perform auth, return transport + inbound channel.
134pub(crate) async fn connect_ws<C: ChannelKind>(
135    url: &str,
136    client_id: Uuid,
137    token: Option<&str>,
138    capabilities: &[C],
139    resume_cursors: HashMap<C, u64>,
140) -> Result<(ActiveTransport<C>, mpsc::Receiver<InboundMsg<C>>), ConnectError> {
141    // Convert HTTP URL to WebSocket URL.
142    let ws_url = http_to_ws_url(url);
143    let rps_url = format!("{ws_url}/rps");
144
145    let (ws_stream, _response) = tokio_tungstenite::connect_async(&rps_url)
146        .await
147        .map_err(|e| ConnectError::Transport(format!("WebSocket connect failed: {e}")))?;
148
149    let (mut ws_tx, mut ws_rx) = ws_stream.split();
150
151    // --- Auth handshake ---
152    let global_cursor = resume_cursors.values().copied().max();
153    let auth = SystemOp::<C>::Auth {
154        client_id,
155        token: token.map(String::from),
156        capabilities: capabilities.to_vec(),
157        resume_cursor: global_cursor,
158        resume_cursors: resume_cursors.clone(),
159    };
160    let auth_json =
161        serde_json::to_string(&auth).map_err(|e| ConnectError::Transport(e.to_string()))?;
162    ws_tx
163        .send(WsMessage::Text(auth_json))
164        .await
165        .map_err(|e| ConnectError::Transport(format!("failed to send auth: {e}")))?;
166
167    // Wait for AuthOk.
168    let auth_reply = ws_rx
169        .next()
170        .await
171        .ok_or(ConnectError::Transport(
172            "connection closed before auth reply".into(),
173        ))?
174        .map_err(|e| ConnectError::Transport(format!("auth reply read error: {e}")))?;
175
176    let auth_ok: SystemOp<C> = match auth_reply {
177        WsMessage::Text(text) => serde_json::from_str(&text)
178            .map_err(|e| ConnectError::AuthRejected(format!("invalid auth reply: {e}")))?,
179        WsMessage::Close(frame) => {
180            let reason = frame
181                .map(|f| f.reason.to_string())
182                .unwrap_or_else(|| "unknown".into());
183            return Err(ConnectError::AuthRejected(reason));
184        }
185        other => {
186            return Err(ConnectError::Transport(format!(
187                "unexpected auth reply type: {other:?}"
188            )));
189        }
190    };
191
192    match auth_ok {
193        SystemOp::AuthOk { .. } => {
194            debug!(?client_id, "auth handshake complete");
195        }
196        SystemOp::Error { message } => return Err(ConnectError::AuthRejected(message)),
197        other => {
198            return Err(ConnectError::AuthRejected(format!(
199                "expected AuthOk, got {other:?}"
200            )));
201        }
202    }
203
204    // --- Spawn reader + writer tasks ---
205    let (inbound_tx, inbound_rx) = mpsc::channel::<InboundMsg<C>>(256);
206    let (outbound_tx, mut outbound_rx) = mpsc::channel::<OutboundMsg<C>>(64);
207
208    // Reader: WS → inbound channel.
209    let reader_inbound_tx = inbound_tx.clone();
210    let reader_handle = tokio::spawn(async move {
211        while let Some(msg) = ws_rx.next().await {
212            match msg {
213                Ok(WsMessage::Text(text)) => {
214                    // Try parsing as SystemOp first (system channel frames wrap
215                    // SystemOp in the payload), then as regular Frame.
216                    if let Ok(frame) = serde_json::from_str::<Frame<C>>(&text) {
217                        if frame.channel.is_system() {
218                            // System channel: extract the SystemOp from payload.
219                            if let Ok(op) =
220                                serde_json::from_value::<SystemOp<C>>(frame.payload.clone())
221                            {
222                                if reader_inbound_tx
223                                    .send(InboundMsg::System(op))
224                                    .await
225                                    .is_err()
226                                {
227                                    break;
228                                }
229                            } else {
230                                // Might be a regular system-channel frame.
231                                if reader_inbound_tx
232                                    .send(InboundMsg::Frame(frame))
233                                    .await
234                                    .is_err()
235                                {
236                                    break;
237                                }
238                            }
239                        } else if reader_inbound_tx
240                            .send(InboundMsg::Frame(frame))
241                            .await
242                            .is_err()
243                        {
244                            break;
245                        }
246                    } else if let Ok(op) = serde_json::from_str::<SystemOp<C>>(&text) {
247                        // Server sends SystemOps directly (not wrapped in Frame)
248                        // for WebSocket connections.
249                        if reader_inbound_tx
250                            .send(InboundMsg::System(op))
251                            .await
252                            .is_err()
253                        {
254                            break;
255                        }
256                    } else {
257                        warn!("failed to parse inbound WS message");
258                    }
259                }
260                Ok(WsMessage::Close(_)) => {
261                    let _ = reader_inbound_tx.send(InboundMsg::Closed).await;
262                    break;
263                }
264                Ok(WsMessage::Ping(_) | WsMessage::Pong(_)) => {
265                    // tungstenite handles ping/pong at the protocol level.
266                }
267                Ok(_) => {}
268                Err(e) => {
269                    warn!(?e, "WS read error");
270                    let _ = reader_inbound_tx.send(InboundMsg::Closed).await;
271                    break;
272                }
273            }
274        }
275    });
276
277    // Writer: outbound channel → WS.
278    let writer_handle = tokio::spawn(async move {
279        while let Some(msg) = outbound_rx.recv().await {
280            let ws_msg = match msg {
281                OutboundMsg::Frame(frame) => match serde_json::to_string(&frame) {
282                    Ok(json) => WsMessage::Text(json),
283                    Err(e) => {
284                        warn!(?e, "failed to serialize outbound frame");
285                        continue;
286                    }
287                },
288                OutboundMsg::System(op) => match serde_json::to_string(&op) {
289                    Ok(json) => WsMessage::Text(json),
290                    Err(e) => {
291                        warn!(?e, "failed to serialize outbound system op");
292                        continue;
293                    }
294                },
295                OutboundMsg::Close => {
296                    let _ = ws_tx.send(WsMessage::Close(None)).await;
297                    break;
298                }
299            };
300            if ws_tx.send(ws_msg).await.is_err() {
301                break;
302            }
303        }
304    });
305
306    let transport = ActiveTransport::WebSocket {
307        outbound_tx,
308        reader_handle,
309        writer_handle,
310    };
311
312    Ok((transport, inbound_rx))
313}
314
315// ---------------------------------------------------------------------------
316// SSE transport
317// ---------------------------------------------------------------------------
318
319/// Connect via SSE, return transport + inbound channel.
320pub(crate) async fn connect_sse<C: ChannelKind>(
321    url: &str,
322    client_id: Uuid,
323    token: Option<&str>,
324    capabilities: &[C],
325    resume_cursor: Option<u64>,
326) -> Result<(ActiveTransport<C>, mpsc::Receiver<InboundMsg<C>>), ConnectError> {
327    let http = HttpClient::new();
328
329    let mut sse_url = format!("{url}/rps/sse?client_id={client_id}");
330    if let Some(tok) = token {
331        sse_url.push_str(&format!("&token={tok}"));
332    }
333    if !capabilities.is_empty() {
334        let caps: Vec<&str> = capabilities.iter().map(|c| c.name()).collect();
335        sse_url.push_str(&format!("&capabilities={}", caps.join(",")));
336        sse_url.push_str(&format!("&channels={}", caps.join(",")));
337    }
338    if let Some(cursor) = resume_cursor {
339        sse_url.push_str(&format!("&resume_cursor={cursor}"));
340    }
341
342    let response = http
343        .get(&sse_url)
344        .send()
345        .await
346        .map_err(|e| ConnectError::Transport(format!("SSE connect failed: {e}")))?;
347
348    if !response.status().is_success() {
349        return Err(ConnectError::AuthRejected(format!(
350            "SSE returned {}",
351            response.status()
352        )));
353    }
354
355    let (inbound_tx, inbound_rx) = mpsc::channel::<InboundMsg<C>>(256);
356    let ack_url = format!("{url}/rps/ack");
357
358    // Reader: parse SSE event stream.
359    let reader_handle = tokio::spawn(async move {
360        let mut stream = response.bytes_stream();
361        let mut buffer = String::new();
362        let mut event_type = String::new();
363        let mut data_lines = Vec::<String>::new();
364
365        while let Some(chunk) = stream.next().await {
366            let bytes = match chunk {
367                Ok(b) => b,
368                Err(e) => {
369                    warn!(?e, "SSE stream error");
370                    let _ = inbound_tx.send(InboundMsg::Closed).await;
371                    break;
372                }
373            };
374
375            buffer.push_str(&String::from_utf8_lossy(&bytes));
376
377            // Process complete lines.
378            while let Some(newline_pos) = buffer.find('\n') {
379                let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
380                buffer = buffer[newline_pos + 1..].to_string();
381
382                if line.is_empty() {
383                    // Empty line = end of event.
384                    if !data_lines.is_empty() && (event_type == "frame" || event_type.is_empty()) {
385                        let data = data_lines.join("\n");
386                        if let Ok(frame) = serde_json::from_str::<Frame<C>>(&data) {
387                            if frame.channel.is_system() {
388                                if let Ok(op) =
389                                    serde_json::from_value::<SystemOp<C>>(frame.payload.clone())
390                                {
391                                    let _ = inbound_tx.send(InboundMsg::System(op)).await;
392                                } else {
393                                    let _ = inbound_tx.send(InboundMsg::Frame(frame)).await;
394                                }
395                            } else {
396                                let _ = inbound_tx.send(InboundMsg::Frame(frame)).await;
397                            }
398                        }
399                    }
400                    event_type.clear();
401                    data_lines.clear();
402                } else if let Some(value) = line.strip_prefix("event:") {
403                    event_type = value.trim().to_string();
404                } else if let Some(value) = line.strip_prefix("data:") {
405                    data_lines.push(value.trim_start().to_string());
406                }
407                // Ignore id:, retry:, comments (:), etc.
408            }
409        }
410    });
411
412    let transport = ActiveTransport::Sse {
413        http: HttpClient::new(),
414        ack_url,
415        client_id,
416        reader_handle,
417    };
418
419    Ok((transport, inbound_rx))
420}
421
// ---------------------------------------------------------------------------
// Transport fallback
// ---------------------------------------------------------------------------

426/// Connect with fallback based on preference.
427pub(crate) async fn connect_with_preference<C: ChannelKind>(
428    preference: TransportPreference,
429    url: &str,
430    client_id: Uuid,
431    token: Option<&str>,
432    capabilities: &[C],
433    resume_cursors: HashMap<C, u64>,
434) -> Result<(ActiveTransport<C>, mpsc::Receiver<InboundMsg<C>>), ConnectError> {
435    let global_cursor = resume_cursors.values().copied().max();
436
437    match preference {
438        TransportPreference::WsOnly => {
439            connect_ws(url, client_id, token, capabilities, resume_cursors).await
440        }
441        TransportPreference::SseOnly => {
442            connect_sse(url, client_id, token, capabilities, global_cursor).await
443        }
444        TransportPreference::WsFirst => {
445            match connect_ws(url, client_id, token, capabilities, resume_cursors.clone()).await {
446                Ok(result) => Ok(result),
447                Err(ws_err) => {
448                    debug!(?ws_err, "WS failed, falling back to SSE");
449                    connect_sse(url, client_id, token, capabilities, global_cursor).await
450                }
451            }
452        }
453        TransportPreference::SseFirst => {
454            match connect_sse(url, client_id, token, capabilities, global_cursor).await {
455                Ok(result) => Ok(result),
456                Err(sse_err) => {
457                    debug!(?sse_err, "SSE failed, falling back to WS");
458                    connect_ws(url, client_id, token, capabilities, resume_cursors).await
459                }
460            }
461        }
462    }
463}
464
// ---------------------------------------------------------------------------
// Utilities
// ---------------------------------------------------------------------------

/// Map a server base URL onto the matching WebSocket scheme.
///
/// - `http://` becomes `ws://` and `https://` becomes `wss://`.
/// - URLs that already use `ws://` or `wss://` are returned unchanged.
/// - Anything without a recognised scheme gets `ws://` prepended.
fn http_to_ws_url(url: &str) -> String {
    const SCHEME_SWAPS: [(&str, &str); 2] = [("https://", "wss://"), ("http://", "ws://")];

    if url.starts_with("ws://") || url.starts_with("wss://") {
        return url.to_owned();
    }

    SCHEME_SWAPS
        .iter()
        .find_map(|&(from, to)| url.strip_prefix(from).map(|rest| format!("{to}{rest}")))
        .unwrap_or_else(|| format!("ws://{url}"))
}