Skip to main content

client_core/
ws.rs

1//! Slim WebSocket subscriber for the relay's `/ws` endpoint.
2//!
3//! Phase 5 ships this for `cinch pull --watch`. It is intentionally
4//! narrower than `desktop/src-tauri/src/ws.rs` (which still owns its
5//! tauri-coupled lifecycle, tray status, db inserts, image fetch, and
6//! key-exchange responder). Once that desktop logic gets refactored to
7//! consume callbacks from this module, the duplicated reconnect/decrypt
8//! plumbing inside desktop's ws.rs can shrink to a thin event bridge.
9//!
//! The client exchanges its bearer token for a short-lived ticket via
//! `POST /ws/ticket` and connects with `?ticket=...` in the URL query string
//! (the bearer token itself never appears in the WebSocket URL). It then reads
//! frames, decodes `WSMessage`, decrypts clip content when `encrypted=true`,
//! and forwards every interesting message through an `mpsc::Sender<WsEvent>`
//! provided by the caller.
14
15use std::time::Duration;
16
17use base64::engine::general_purpose::STANDARD;
18use base64::Engine;
19use futures_util::{SinkExt, StreamExt};
20use tokio::sync::mpsc;
21use tokio::time::sleep;
22use tokio_tungstenite::tungstenite::Message;
23use tokio_tungstenite::{connect_async, tungstenite};
24use tracing::{debug, info, warn};
25
26use reqwest;
27use serde_json;
28
29use crate::crypto;
30use crate::protocol::{
31    Clip, WSMessage, ACTION_CLIP_DELETED, ACTION_KEY_EXCHANGE_REQUESTED, ACTION_NEW_CLIP,
32    ACTION_PING, ACTION_REVOKED, ACTION_TOKEN_ROTATED,
33};
34use crate::version::ClientInfo;
35
36#[derive(Debug, Clone)]
37pub enum WsEvent {
38    /// Connection state transitions — emitted on connect, disconnect, retry.
39    Status(WsStatus),
40    /// New clip received. `plaintext` is the decrypted body for encrypted
41    /// clips (already base64-decoded for binary), or the raw `clip.content`
42    /// when no encryption key was available or `encrypted=false`.
43    NewClip { clip: Box<Clip>, plaintext: Vec<u8> },
44    /// Clip deleted on the relay (e.g., retention sweep, manual delete).
45    ClipDeleted { clip_id: String },
46    /// The caller's device was revoked. Future reconnects will 401.
47    Revoked { reason: Option<String> },
48    /// Server rotated this device's token. Caller should persist the new
49    /// token and reconnect with it.
50    TokenRotated {
51        token: String,
52        device_id: Option<String>,
53    },
54    /// Another device asked for a key bundle. Desktop handles the ECDH
55    /// responder; CLI watchers can ignore.
56    KeyExchangeRequested { device_id: Option<String> },
57    /// Incoming clip could not be decrypted (missing key or wrong key).
58    /// The clip was NOT inserted as plaintext. Callers should surface this
59    /// to the user and fire `retry_key_bundle`.
60    ClipDecryptFailed {
61        clip_id: String,
62        reason: DecryptFailReason,
63    },
64}
65
/// Coarse connection state reported through `WsEvent::Status`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsStatus {
    /// A connect attempt (ticket fetch + WebSocket handshake) is starting.
    Connecting,
    /// The WebSocket handshake succeeded and frames are being read.
    Connected,
    /// The session ended; a reconnect follows after a backoff delay.
    Disconnected,
}
72
73#[derive(Debug, thiserror::Error)]
74pub enum WsError {
75    #[error("ws: {0}")]
76    Tungstenite(#[from] tungstenite::Error),
77    #[error("decode: {0}")]
78    Decode(String),
79}
80
81#[derive(Debug, Clone)]
82pub struct WsConfig {
83    /// Base relay URL (http/https). The ws module exchanges a short-lived
84    /// ticket from POST /ws/ticket before connecting; the bearer token
85    /// never appears in the WebSocket URL or server access logs.
86    pub relay_url: String,
87    /// Bearer token used to obtain a WS ticket from POST /ws/ticket.
88    pub token: String,
89    /// 32-byte AES key used to decrypt incoming `encrypted=true` clips.
90    /// Pass `None` to skip decryption (encrypted clips reach the caller
91    /// with `clip.content` set to ciphertext).
92    pub encryption_key: Option<[u8; 32]>,
93    /// Optional self-identification. When `Some`, the client sends a single
94    /// `client_hello` WS frame right after connect (before entering the read
95    /// loop) so the relay can record the caller's binary type + version
96    /// without a follow-up REST call. `None` preserves the legacy behavior
97    /// of staying silent until the first server frame arrives.
98    pub client_info: Option<ClientInfo>,
99}
100
/// Outcome of a single decrypt attempt on an incoming clip.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecryptOutcome {
    /// Clip was not flagged encrypted; nothing to do.
    Plaintext,
    /// Successfully decrypted. `clip.content` now holds the plaintext (UTF-8
    /// text, or base64 of the decrypted bytes for binary clips) and
    /// `clip.encrypted` is `false`.
    Decoded,
    /// No AES key was supplied; clip left untouched.
    MissingKey,
    /// Decryption did not produce usable plaintext; clip left untouched.
    /// Despite the name, `error` carries whatever the decrypt step reported
    /// (e.g. an AES-GCM tag mismatch from a wrong key). For a text clip it
    /// can also be a `post-decrypt utf-8 invalid: ...` message when the
    /// decrypted bytes were not valid UTF-8.
    TagFailed { error: String },
}
113
/// Why an encrypted clip was reported via `WsEvent::ClipDecryptFailed`
/// rather than delivered as `WsEvent::NewClip`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecryptFailReason {
    /// No AES key was configured for this session.
    MissingKey,
    /// Decryption (or UTF-8 validation of the result) failed; the payload
    /// is the underlying error message.
    TagFailed(String),
}
119
120/// Connect to the relay and forward decoded events to `tx` until the
121/// caller drops the receiver. Reconnects on socket error with exponential
122/// backoff (1s, 2s, 4s … capped at 30s). Returns when `tx` is closed.
123pub async fn run(cfg: WsConfig, tx: mpsc::Sender<WsEvent>) {
124    let mut attempt = 0u32;
125    loop {
126        if tx.is_closed() {
127            return;
128        }
129        let _ = tx.send(WsEvent::Status(WsStatus::Connecting)).await;
130        match connect_and_listen(&cfg, &tx).await {
131            Ok(()) => {
132                debug!("ws: closed cleanly");
133                attempt = 0;
134            }
135            Err(e) => {
136                warn!("ws error: {}", e);
137                attempt = attempt.saturating_add(1);
138            }
139        }
140        let _ = tx.send(WsEvent::Status(WsStatus::Disconnected)).await;
141        let backoff_secs = 1u64 << attempt.min(5); // 1, 2, 4, 8, 16, 32 ...
142        sleep(Duration::from_secs(backoff_secs.min(30))).await;
143    }
144}
145
146/// Fetch a short-lived single-use WebSocket ticket from the relay.
147/// Calls POST /ws/ticket with a Bearer auth header; returns the hex ticket string.
148async fn fetch_ws_ticket(relay_url: &str, token: &str) -> Result<String, WsError> {
149    let ticket_url = format!("{}/ws/ticket", relay_url.trim_end_matches('/'));
150    let client = reqwest::Client::builder()
151        .timeout(std::time::Duration::from_secs(10))
152        .build()
153        .map_err(|e| WsError::Decode(format!("build http client: {}", e)))?;
154    let resp = client
155        .post(&ticket_url)
156        .bearer_auth(token)
157        .send()
158        .await
159        .map_err(|e| WsError::Decode(format!("ticket request: {}", e)))?;
160    if !resp.status().is_success() {
161        return Err(WsError::Decode(format!(
162            "ticket endpoint returned {}",
163            resp.status()
164        )));
165    }
166    let body: serde_json::Value = resp
167        .json()
168        .await
169        .map_err(|e| WsError::Decode(format!("parse ticket response: {}", e)))?;
170    body["ticket"]
171        .as_str()
172        .map(|s| s.to_string())
173        .ok_or_else(|| WsError::Decode("no ticket in response".into()))
174}
175
176async fn connect_and_listen(cfg: &WsConfig, tx: &mpsc::Sender<WsEvent>) -> Result<(), WsError> {
177    let ticket = fetch_ws_ticket(&cfg.relay_url, &cfg.token).await?;
178    let ws_base = cfg
179        .relay_url
180        .replace("https://", "wss://")
181        .replace("http://", "ws://");
182    let ws_url = format!("{}/ws?ticket={}", ws_base.trim_end_matches('/'), ticket);
183    let (ws_stream, _) = connect_async(&ws_url).await?;
184    info!("ws connected");
185    log::info!("ws connected"); // also via log crate so env_logger captures it
186    let _ = tx.send(WsEvent::Status(WsStatus::Connected)).await;
187
188    let (mut write, mut read) = ws_stream.split();
189
190    // Self-identify before the read loop. The relay's hub dispatches by
191    // action — `client_hello` is just one of the actions it expects to see
192    // and records into the devices table. On marshal/send failure we let the
193    // outer `run` reconnect with backoff; if the payload itself is malformed
194    // the warn-level log surfaces it.
195    if let Some(info) = cfg.client_info.as_ref() {
196        let hello = info.client_hello_message();
197        match serde_json::to_string(&hello) {
198            Ok(text) => {
199                write.send(Message::Text(text.into())).await?;
200            }
201            Err(e) => {
202                warn!("ws: failed to serialize client_hello: {}", e);
203            }
204        }
205    }
206
207    while let Some(frame) = read.next().await {
208        let msg = frame?;
209        match msg {
210            Message::Text(text) => {
211                if let Some(event) = decode_message(text.as_str(), cfg.encryption_key) {
212                    if tx.send(event).await.is_err() {
213                        return Ok(());
214                    }
215                }
216            }
217            Message::Ping(data) => {
218                write.send(Message::Pong(data)).await?;
219            }
220            Message::Close(_) => {
221                debug!("relay sent close");
222                return Ok(());
223            }
224            _ => {}
225        }
226    }
227    Ok(())
228}
229
230fn decode_message(text: &str, key: Option<[u8; 32]>) -> Option<WsEvent> {
231    let msg: WSMessage = match serde_json::from_str(text) {
232        Ok(m) => m,
233        Err(e) => {
234            warn!("ws: bad message: {}", e);
235            return None;
236        }
237    };
238
239    match msg.action.as_str() {
240        ACTION_NEW_CLIP => {
241            let mut clip = msg.clip?;
242            match decrypt_clip_content(&mut clip, key) {
243                DecryptOutcome::Plaintext => {
244                    let plaintext = clip.content.as_bytes().to_vec();
245                    Some(WsEvent::NewClip {
246                        clip: Box::new(clip),
247                        plaintext,
248                    })
249                }
250                DecryptOutcome::Decoded => {
251                    let plaintext = clip.content.as_bytes().to_vec();
252                    Some(WsEvent::NewClip {
253                        clip: Box::new(clip),
254                        plaintext,
255                    })
256                }
257                DecryptOutcome::MissingKey => Some(WsEvent::ClipDecryptFailed {
258                    clip_id: clip.clip_id,
259                    reason: DecryptFailReason::MissingKey,
260                }),
261                DecryptOutcome::TagFailed { error } => Some(WsEvent::ClipDecryptFailed {
262                    clip_id: clip.clip_id,
263                    reason: DecryptFailReason::TagFailed(error),
264                }),
265            }
266        }
267        ACTION_CLIP_DELETED => Some(WsEvent::ClipDeleted {
268            clip_id: msg.clip.map(|c| c.clip_id).unwrap_or_default(),
269        }),
270        ACTION_REVOKED => Some(WsEvent::Revoked { reason: msg.reason }),
271        ACTION_TOKEN_ROTATED => msg.token.map(|t| WsEvent::TokenRotated {
272            token: t,
273            device_id: msg.device_id,
274        }),
275        ACTION_KEY_EXCHANGE_REQUESTED => {
276            log::info!(
277                "ws: decoded key_exchange_requested device_id={:?}",
278                msg.device_id
279            );
280            Some(WsEvent::KeyExchangeRequested {
281                device_id: msg.device_id,
282            })
283        }
284        ACTION_PING => None, // server pings handled by tungstenite Pong frames
285        _ => None,
286    }
287}
288
289/// Decrypt `clip.content` in place if `clip.encrypted` and a key is available.
290/// Returns a typed outcome — never silently returns ciphertext as plaintext.
291pub fn decrypt_clip_content(clip: &mut Clip, key: Option<[u8; 32]>) -> DecryptOutcome {
292    if !clip.encrypted {
293        return DecryptOutcome::Plaintext;
294    }
295    let Some(key) = key else {
296        return DecryptOutcome::MissingKey;
297    };
298    let plaintext = match crypto::decrypt(&key, &clip.content) {
299        Ok(p) => p,
300        Err(e) => {
301            return DecryptOutcome::TagFailed {
302                error: e.to_string(),
303            }
304        }
305    };
306    let is_binary = clip
307        .media_path
308        .as_deref()
309        .filter(|p| !p.is_empty())
310        .is_some()
311        || clip.content_type.starts_with("image");
312    if is_binary {
313        // Re-encode as base64 so the struct stays a valid String.
314        clip.content = STANDARD.encode(&plaintext);
315    } else {
316        match String::from_utf8(plaintext) {
317            Ok(s) => clip.content = s,
318            Err(e) => {
319                return DecryptOutcome::TagFailed {
320                    error: format!("post-decrypt utf-8 invalid: {e}"),
321                }
322            }
323        }
324    }
325    clip.encrypted = false;
326    DecryptOutcome::Decoded
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a WS frame body by injecting `action` into the given JSON object.
    fn make_msg(action: &str, body: serde_json::Value) -> String {
        let mut v = body;
        v.as_object_mut()
            .unwrap()
            .insert("action".into(), serde_json::Value::String(action.into()));
        serde_json::to_string(&v).unwrap()
    }

    #[test]
    fn decodes_new_clip_unencrypted() {
        let json = make_msg(
            ACTION_NEW_CLIP,
            serde_json::json!({
                "clip": {
                    "clip_id": "01H",
                    "user_id": "u1",
                    "content": "hello",
                    "content_type": "text",
                    "source": "remote:host",
                    "created_at": "2026-04-30T00:00:00Z",
                    "encrypted": false
                }
            }),
        );
        match decode_message(&json, None).unwrap() {
            WsEvent::NewClip { clip, plaintext } => {
                assert_eq!(clip.clip_id, "01H");
                assert_eq!(plaintext, b"hello");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_revoked() {
        let json = make_msg(
            ACTION_REVOKED,
            serde_json::json!({"reason": "device removed"}),
        );
        match decode_message(&json, None).unwrap() {
            WsEvent::Revoked { reason } => assert_eq!(reason.as_deref(), Some("device removed")),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_clip_deleted() {
        let json = make_msg(
            ACTION_CLIP_DELETED,
            serde_json::json!({
                "clip": {
                    "clip_id": "delme",
                    "user_id": "u1",
                    "content": "",
                    "content_type": "text",
                    "source": "local",
                    "created_at": "2026-04-30T00:00:00Z"
                }
            }),
        );
        match decode_message(&json, None).unwrap() {
            WsEvent::ClipDeleted { clip_id } => assert_eq!(clip_id, "delme"),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypts_text_clip_with_key() {
        let key = [0x42u8; 32];
        let ciphertext = crypto::encrypt(&key, b"secret payload").unwrap();
        let json = make_msg(
            ACTION_NEW_CLIP,
            serde_json::json!({
                "clip": {
                    "clip_id": "01H",
                    "user_id": "u1",
                    "content": ciphertext,
                    "content_type": "text",
                    "source": "remote:host",
                    "created_at": "2026-04-30T00:00:00Z",
                    "encrypted": true
                }
            }),
        );
        match decode_message(&json, Some(key)).unwrap() {
            WsEvent::NewClip { clip, plaintext } => {
                assert_eq!(plaintext, b"secret payload");
                assert!(!clip.encrypted);
                assert_eq!(clip.content, "secret payload");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypt_failure_does_not_silently_return_ciphertext() {
        let sender_key = [0x11u8; 32];
        let receiver_key = [0x22u8; 32];
        let blob = crypto::encrypt(&sender_key, b"hello from remote cli").unwrap();

        let mut clip = Clip {
            clip_id: "c1".into(),
            user_id: "u1".into(),
            content: blob.clone(),
            content_type: String::new(),
            encrypted: true,
            ..Default::default()
        };

        let outcome = decrypt_clip_content(&mut clip, Some(receiver_key));

        assert!(
            matches!(outcome, DecryptOutcome::TagFailed { .. }),
            "wrong-key decrypt must return TagFailed, got {:?}",
            outcome
        );
        assert!(clip.encrypted, "encrypted flag must remain true on failure");
        assert_eq!(
            clip.content, blob,
            "content must not be replaced with garbage plaintext"
        );
    }

    #[test]
    fn decrypt_missing_key_returns_missing_key_outcome() {
        let sender_key = [0x33u8; 32];
        let blob = crypto::encrypt(&sender_key, b"secret").unwrap();

        let mut clip = Clip {
            clip_id: "c2".into(),
            user_id: "u1".into(),
            content: blob.clone(),
            content_type: String::new(),
            encrypted: true,
            ..Default::default()
        };

        let outcome = decrypt_clip_content(&mut clip, None);
        assert_eq!(outcome, DecryptOutcome::MissingKey);
        assert!(
            clip.encrypted,
            "clip must remain encrypted when key is missing"
        );
        assert_eq!(
            clip.content, blob,
            "content must be untouched when key is missing"
        );
    }

    #[test]
    fn wrong_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x44u8; 32];
        let receiver_key = [0x55u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();

        let json = make_msg(
            ACTION_NEW_CLIP,
            serde_json::json!({
                "clip": {
                    "clip_id": "bad-clip",
                    "user_id": "u1",
                    "content": blob,
                    "content_type": "text",
                    "source": "remote:host",
                    "created_at": "2026-04-30T00:00:00Z",
                    "encrypted": true
                }
            }),
        );
        match decode_message(&json, Some(receiver_key)).unwrap() {
            WsEvent::ClipDecryptFailed { clip_id, reason } => {
                assert_eq!(clip_id, "bad-clip");
                assert!(matches!(reason, DecryptFailReason::TagFailed(_)));
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }

    #[test]
    fn missing_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x66u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();

        let json = make_msg(
            ACTION_NEW_CLIP,
            serde_json::json!({
                "clip": {
                    "clip_id": "no-key-clip",
                    "user_id": "u1",
                    "content": blob,
                    "content_type": "text",
                    "source": "remote:host",
                    "created_at": "2026-04-30T00:00:00Z",
                    "encrypted": true
                }
            }),
        );
        match decode_message(&json, None).unwrap() {
            WsEvent::ClipDecryptFailed { clip_id, reason } => {
                assert_eq!(clip_id, "no-key-clip");
                assert_eq!(reason, DecryptFailReason::MissingKey);
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }

    #[test]
    fn unencrypted_clip_is_left_alone() {
        let mut clip = Clip {
            content: "plain".into(),
            encrypted: false,
            ..Default::default()
        };
        assert_eq!(
            decrypt_clip_content(&mut clip, Some([0u8; 32])),
            DecryptOutcome::Plaintext
        );
        assert_eq!(clip.content, "plain");
        assert!(!clip.encrypted);
    }

    #[test]
    fn decrypts_image_clip_to_base64() {
        let key = [0x77u8; 32];
        // Deliberately not valid UTF-8: binary clips must not go through the
        // text path.
        let raw: [u8; 4] = [0x89, b'P', 0xff, 0x00];
        let blob = crypto::encrypt(&key, &raw[..]).unwrap();

        let mut clip = Clip {
            clip_id: "img".into(),
            user_id: "u1".into(),
            content: blob,
            content_type: "image/png".into(),
            encrypted: true,
            ..Default::default()
        };

        assert_eq!(
            decrypt_clip_content(&mut clip, Some(key)),
            DecryptOutcome::Decoded
        );
        assert!(!clip.encrypted);
        assert_eq!(STANDARD.decode(&clip.content).unwrap(), raw.to_vec());
    }

    #[test]
    fn non_utf8_text_clip_is_reported_and_left_untouched() {
        let key = [0x88u8; 32];
        let raw: [u8; 3] = [0xff, 0xfe, 0xfd];
        let blob = crypto::encrypt(&key, &raw[..]).unwrap();

        let mut clip = Clip {
            clip_id: "bin-as-text".into(),
            user_id: "u1".into(),
            content: blob.clone(),
            content_type: "text".into(),
            encrypted: true,
            ..Default::default()
        };

        match decrypt_clip_content(&mut clip, Some(key)) {
            DecryptOutcome::TagFailed { error } => {
                assert!(error.starts_with("post-decrypt utf-8 invalid"), "{error}");
            }
            other => panic!("expected TagFailed, got {:?}", other),
        }
        assert!(clip.encrypted, "failed decode must keep the encrypted flag");
        assert_eq!(clip.content, blob, "failed decode must keep the ciphertext");
    }

    #[test]
    fn bad_json_yields_no_event() {
        assert!(decode_message("not json", None).is_none());
    }

    #[test]
    fn ping_and_unknown_actions_yield_no_event() {
        let ping = make_msg(ACTION_PING, serde_json::json!({}));
        assert!(decode_message(&ping, None).is_none());
        let unknown = make_msg("some_future_action", serde_json::json!({}));
        assert!(decode_message(&unknown, None).is_none());
    }

    #[test]
    fn new_clip_without_clip_yields_no_event() {
        let json = make_msg(ACTION_NEW_CLIP, serde_json::json!({}));
        assert!(decode_message(&json, None).is_none());
    }

    #[test]
    fn token_rotated_without_token_yields_no_event() {
        let json = make_msg(ACTION_TOKEN_ROTATED, serde_json::json!({}));
        assert!(decode_message(&json, None).is_none());
    }

    #[test]
    fn decodes_key_exchange_requested_without_device_id() {
        let json = make_msg(ACTION_KEY_EXCHANGE_REQUESTED, serde_json::json!({}));
        match decode_message(&json, None).unwrap() {
            WsEvent::KeyExchangeRequested { device_id } => assert!(device_id.is_none()),
            other => panic!("unexpected event: {:?}", other),
        }
    }
}