Skip to main content

client_core/
ws.rs

//! Slim WebSocket subscriber for the relay's `/ws` endpoint.
//!
//! Phase 5 ships this for `cinch pull --watch`. It is intentionally
//! narrower than `desktop/src-tauri/src/ws.rs` (which still owns its
//! tauri-coupled lifecycle, tray status, db inserts, image fetch, and
//! key-exchange responder). Once that desktop logic is refactored to
//! consume callbacks from this module, the duplicated reconnect/decrypt
//! plumbing inside desktop's ws.rs can shrink to a thin event bridge.
//!
//! The client first exchanges its bearer token for a short-lived,
//! single-use ticket (`POST /ws/ticket`), then connects with
//! `?ticket=...`, so the token itself never appears in the WebSocket URL.
//! It reads frames, decodes `WSMessage`, decrypts clip content when
//! `encrypted=true`, and forwards every interesting message through an
//! `mpsc::Sender<WsEvent>` provided by the caller.

15use std::time::Duration;
16
17use base64::engine::general_purpose::STANDARD;
18use base64::Engine;
19use futures_util::{SinkExt, StreamExt};
20use tokio::sync::mpsc;
21use tokio::time::sleep;
22use tokio_tungstenite::tungstenite::Message;
23use tokio_tungstenite::{connect_async, tungstenite};
24use tracing::{debug, info, warn};
25
26use reqwest;
27use serde_json;
28
29use crate::crypto;
30use crate::protocol::{
31    Clip, WSMessage, ACTION_CLIP_DELETED, ACTION_KEY_EXCHANGE_REQUESTED, ACTION_NEW_CLIP,
32    ACTION_PING, ACTION_REVOKED, ACTION_TOKEN_ROTATED,
33};
34
35#[derive(Debug, Clone)]
36pub enum WsEvent {
37    /// Connection state transitions — emitted on connect, disconnect, retry.
38    Status(WsStatus),
39    /// New clip received. `plaintext` is the decrypted body for encrypted
40    /// clips (already base64-decoded for binary), or the raw `clip.content`
41    /// when no encryption key was available or `encrypted=false`.
42    NewClip { clip: Box<Clip>, plaintext: Vec<u8> },
43    /// Clip deleted on the relay (e.g., retention sweep, manual delete).
44    ClipDeleted { clip_id: String },
45    /// The caller's device was revoked. Future reconnects will 401.
46    Revoked { reason: Option<String> },
47    /// Server rotated this device's token. Caller should persist the new
48    /// token and reconnect with it.
49    TokenRotated {
50        token: String,
51        device_id: Option<String>,
52    },
53    /// Another device asked for a key bundle. Desktop handles the ECDH
54    /// responder; CLI watchers can ignore.
55    KeyExchangeRequested { device_id: Option<String> },
56    /// Incoming clip could not be decrypted (missing key or wrong key).
57    /// The clip was NOT inserted as plaintext. Callers should surface this
58    /// to the user and fire `retry_key_bundle`.
59    ClipDecryptFailed {
60        clip_id: String,
61        reason: DecryptFailReason,
62    },
63}
64
/// Connection lifecycle reported through [`WsEvent::Status`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsStatus {
    /// A session is starting (ticket fetch followed by the WebSocket upgrade).
    Connecting,
    /// The WebSocket handshake completed.
    Connected,
    /// The session ended; a reconnect attempt may follow.
    Disconnected,
}
71
72#[derive(Debug, thiserror::Error)]
73pub enum WsError {
74    #[error("ws: {0}")]
75    Tungstenite(#[from] tungstenite::Error),
76    #[error("decode: {0}")]
77    Decode(String),
78}
79
/// Connection settings for [`run`].
///
/// `Debug` is implemented by hand so the bearer token and the AES key are
/// redacted if a config is ever logged with `{:?}`.
#[derive(Clone)]
pub struct WsConfig {
    /// Base relay URL (http/https). The ws module exchanges a short-lived
    /// ticket from POST /ws/ticket before connecting, so the bearer token
    /// never appears in the WebSocket URL.
    pub relay_url: String,
    /// Bearer token used to obtain a WS ticket from POST /ws/ticket.
    pub token: String,
    /// 32-byte AES key used to decrypt incoming `encrypted=true` clips.
    /// With `None`, encrypted clips are not delivered as `NewClip`; each one
    /// is reported as `WsEvent::ClipDecryptFailed` with
    /// `DecryptFailReason::MissingKey`.
    pub encryption_key: Option<[u8; 32]>,
}

impl std::fmt::Debug for WsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is present, never its bytes.
        f.debug_struct("WsConfig")
            .field("relay_url", &self.relay_url)
            .field("token", &"<redacted>")
            .field("encryption_key", &self.encryption_key.map(|_| "<redacted>"))
            .finish()
    }
}
93
/// Result of one `decrypt_clip_content` call on an incoming clip.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecryptOutcome {
    /// The clip was not flagged `encrypted`; it was left as-is.
    Plaintext,
    /// Decryption succeeded: `clip.content` now holds the plaintext (base64
    /// text for binary clips) and `clip.encrypted` was cleared.
    Decoded,
    /// The clip is encrypted but no AES key was supplied; clip untouched.
    MissingKey,
    /// Decryption failed (AES-GCM tag mismatch, likely a wrong key) or the
    /// decrypted bytes of a text clip were not valid UTF-8; clip untouched.
    /// `error` carries the underlying message.
    TagFailed { error: String },
}
106
/// Why an incoming clip was reported as `WsEvent::ClipDecryptFailed`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecryptFailReason {
    /// No AES key was configured.
    MissingKey,
    /// Decryption (or the post-decrypt UTF-8 check) failed; holds the message.
    TagFailed(String),
}
112
113/// Connect to the relay and forward decoded events to `tx` until the
114/// caller drops the receiver. Reconnects on socket error with exponential
115/// backoff (1s, 2s, 4s … capped at 30s). Returns when `tx` is closed.
116pub async fn run(cfg: WsConfig, tx: mpsc::Sender<WsEvent>) {
117    let mut attempt = 0u32;
118    loop {
119        if tx.is_closed() {
120            return;
121        }
122        let _ = tx.send(WsEvent::Status(WsStatus::Connecting)).await;
123        match connect_and_listen(&cfg, &tx).await {
124            Ok(()) => {
125                debug!("ws: closed cleanly");
126                attempt = 0;
127            }
128            Err(e) => {
129                warn!("ws error: {}", e);
130                attempt = attempt.saturating_add(1);
131            }
132        }
133        let _ = tx.send(WsEvent::Status(WsStatus::Disconnected)).await;
134        let backoff_secs = 1u64 << attempt.min(5); // 1, 2, 4, 8, 16, 32 ...
135        sleep(Duration::from_secs(backoff_secs.min(30))).await;
136    }
137}
138
139/// Fetch a short-lived single-use WebSocket ticket from the relay.
140/// Calls POST /ws/ticket with a Bearer auth header; returns the hex ticket string.
141async fn fetch_ws_ticket(relay_url: &str, token: &str) -> Result<String, WsError> {
142    let ticket_url = format!("{}/ws/ticket", relay_url.trim_end_matches('/'));
143    let client = reqwest::Client::builder()
144        .timeout(std::time::Duration::from_secs(10))
145        .build()
146        .map_err(|e| WsError::Decode(format!("build http client: {}", e)))?;
147    let resp = client
148        .post(&ticket_url)
149        .bearer_auth(token)
150        .send()
151        .await
152        .map_err(|e| WsError::Decode(format!("ticket request: {}", e)))?;
153    if !resp.status().is_success() {
154        return Err(WsError::Decode(format!(
155            "ticket endpoint returned {}",
156            resp.status()
157        )));
158    }
159    let body: serde_json::Value = resp
160        .json()
161        .await
162        .map_err(|e| WsError::Decode(format!("parse ticket response: {}", e)))?;
163    body["ticket"]
164        .as_str()
165        .map(|s| s.to_string())
166        .ok_or_else(|| WsError::Decode("no ticket in response".into()))
167}
168
169async fn connect_and_listen(cfg: &WsConfig, tx: &mpsc::Sender<WsEvent>) -> Result<(), WsError> {
170    let ticket = fetch_ws_ticket(&cfg.relay_url, &cfg.token).await?;
171    let ws_base = cfg
172        .relay_url
173        .replace("https://", "wss://")
174        .replace("http://", "ws://");
175    let ws_url = format!("{}/ws?ticket={}", ws_base.trim_end_matches('/'), ticket);
176    let (ws_stream, _) = connect_async(&ws_url).await?;
177    info!("ws connected");
178    log::info!("ws connected"); // also via log crate so env_logger captures it
179    let _ = tx.send(WsEvent::Status(WsStatus::Connected)).await;
180
181    let (mut write, mut read) = ws_stream.split();
182
183    while let Some(frame) = read.next().await {
184        let msg = frame?;
185        match msg {
186            Message::Text(text) => {
187                if let Some(event) = decode_message(text.as_str(), cfg.encryption_key) {
188                    if tx.send(event).await.is_err() {
189                        return Ok(());
190                    }
191                }
192            }
193            Message::Ping(data) => {
194                write.send(Message::Pong(data)).await?;
195            }
196            Message::Close(_) => {
197                debug!("relay sent close");
198                return Ok(());
199            }
200            _ => {}
201        }
202    }
203    Ok(())
204}
205
206fn decode_message(text: &str, key: Option<[u8; 32]>) -> Option<WsEvent> {
207    let msg: WSMessage = match serde_json::from_str(text) {
208        Ok(m) => m,
209        Err(e) => {
210            warn!("ws: bad message: {}", e);
211            return None;
212        }
213    };
214
215    match msg.action.as_str() {
216        ACTION_NEW_CLIP => {
217            let mut clip = msg.clip?;
218            match decrypt_clip_content(&mut clip, key) {
219                DecryptOutcome::Plaintext => {
220                    let plaintext = clip.content.as_bytes().to_vec();
221                    Some(WsEvent::NewClip {
222                        clip: Box::new(clip),
223                        plaintext,
224                    })
225                }
226                DecryptOutcome::Decoded => {
227                    let plaintext = clip.content.as_bytes().to_vec();
228                    Some(WsEvent::NewClip {
229                        clip: Box::new(clip),
230                        plaintext,
231                    })
232                }
233                DecryptOutcome::MissingKey => Some(WsEvent::ClipDecryptFailed {
234                    clip_id: clip.clip_id,
235                    reason: DecryptFailReason::MissingKey,
236                }),
237                DecryptOutcome::TagFailed { error } => Some(WsEvent::ClipDecryptFailed {
238                    clip_id: clip.clip_id,
239                    reason: DecryptFailReason::TagFailed(error),
240                }),
241            }
242        }
243        ACTION_CLIP_DELETED => Some(WsEvent::ClipDeleted {
244            clip_id: msg.clip.map(|c| c.clip_id).unwrap_or_default(),
245        }),
246        ACTION_REVOKED => Some(WsEvent::Revoked { reason: msg.reason }),
247        ACTION_TOKEN_ROTATED => msg.token.map(|t| WsEvent::TokenRotated {
248            token: t,
249            device_id: msg.device_id,
250        }),
251        ACTION_KEY_EXCHANGE_REQUESTED => {
252            log::info!(
253                "ws: decoded key_exchange_requested device_id={:?}",
254                msg.device_id
255            );
256            Some(WsEvent::KeyExchangeRequested {
257                device_id: msg.device_id,
258            })
259        }
260        ACTION_PING => None, // server pings handled by tungstenite Pong frames
261        _ => None,
262    }
263}
264
265/// Decrypt `clip.content` in place if `clip.encrypted` and a key is available.
266/// Returns a typed outcome — never silently returns ciphertext as plaintext.
267pub fn decrypt_clip_content(clip: &mut Clip, key: Option<[u8; 32]>) -> DecryptOutcome {
268    if !clip.encrypted {
269        return DecryptOutcome::Plaintext;
270    }
271    let Some(key) = key else {
272        return DecryptOutcome::MissingKey;
273    };
274    let plaintext = match crypto::decrypt(&key, &clip.content) {
275        Ok(p) => p,
276        Err(e) => {
277            return DecryptOutcome::TagFailed {
278                error: e.to_string(),
279            }
280        }
281    };
282    let is_binary = clip
283        .media_path
284        .as_deref()
285        .filter(|p| !p.is_empty())
286        .is_some()
287        || clip.content_type.starts_with("image");
288    if is_binary {
289        // Re-encode as base64 so the struct stays a valid String.
290        clip.content = STANDARD.encode(&plaintext);
291    } else {
292        match String::from_utf8(plaintext) {
293            Ok(s) => clip.content = s,
294            Err(e) => {
295                return DecryptOutcome::TagFailed {
296                    error: format!("post-decrypt utf-8 invalid: {e}"),
297                }
298            }
299        }
300    }
301    clip.encrypted = false;
302    DecryptOutcome::Decoded
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialise `body` as a relay frame, adding the `"action"` field.
    fn make_msg(action: &str, body: serde_json::Value) -> String {
        let mut frame = body;
        frame
            .as_object_mut()
            .expect("frame body must be a JSON object")
            .insert("action".to_string(), serde_json::Value::from(action));
        serde_json::to_string(&frame).expect("frame must serialise")
    }

    /// Build an `ACTION_NEW_CLIP` frame for a text clip from `remote:host`.
    fn new_text_clip_frame(clip_id: &str, content: &str, encrypted: bool) -> String {
        make_msg(
            ACTION_NEW_CLIP,
            serde_json::json!({
                "clip": {
                    "clip_id": clip_id,
                    "user_id": "u1",
                    "content": content,
                    "content_type": "text",
                    "source": "remote:host",
                    "created_at": "2026-04-30T00:00:00Z",
                    "encrypted": encrypted
                }
            }),
        )
    }

    /// Build an encrypted `Clip` directly, bypassing JSON decoding.
    fn encrypted_clip(clip_id: &str, ciphertext: &str) -> Clip {
        Clip {
            clip_id: clip_id.to_string(),
            user_id: "u1".to_string(),
            content: ciphertext.to_string(),
            content_type: String::new(),
            encrypted: true,
            ..Default::default()
        }
    }

    #[test]
    fn decodes_new_clip_unencrypted() {
        let frame = new_text_clip_frame("01H", "hello", false);
        match decode_message(&frame, None) {
            Some(WsEvent::NewClip { clip, plaintext }) => {
                assert_eq!(clip.clip_id, "01H");
                assert_eq!(plaintext, b"hello");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_revoked() {
        let frame = make_msg(
            ACTION_REVOKED,
            serde_json::json!({"reason": "device removed"}),
        );
        match decode_message(&frame, None) {
            Some(WsEvent::Revoked { reason }) => {
                assert_eq!(reason.as_deref(), Some("device removed"))
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_clip_deleted() {
        let frame = make_msg(
            ACTION_CLIP_DELETED,
            serde_json::json!({
                "clip": {
                    "clip_id": "delme",
                    "user_id": "u1",
                    "content": "",
                    "content_type": "text",
                    "source": "local",
                    "created_at": "2026-04-30T00:00:00Z"
                }
            }),
        );
        match decode_message(&frame, None) {
            Some(WsEvent::ClipDeleted { clip_id }) => assert_eq!(clip_id, "delme"),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypts_text_clip_with_key() {
        let key = [0x42u8; 32];
        let ciphertext = crypto::encrypt(&key, b"secret payload").unwrap();
        let frame = new_text_clip_frame("01H", &ciphertext, true);
        match decode_message(&frame, Some(key)) {
            Some(WsEvent::NewClip { clip, plaintext }) => {
                assert_eq!(plaintext, b"secret payload");
                assert!(!clip.encrypted);
                assert_eq!(clip.content, "secret payload");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypt_failure_does_not_silently_return_ciphertext() {
        let sender_key = [0x11u8; 32];
        let receiver_key = [0x22u8; 32];
        let blob = crypto::encrypt(&sender_key, b"hello from remote cli").unwrap();
        let mut clip = encrypted_clip("c1", &blob);

        let outcome = decrypt_clip_content(&mut clip, Some(receiver_key));

        assert!(
            matches!(outcome, DecryptOutcome::TagFailed { .. }),
            "wrong-key decrypt must return TagFailed, got {:?}",
            outcome
        );
        assert!(clip.encrypted, "encrypted flag must remain true on failure");
        assert_eq!(
            clip.content, blob,
            "content must not be replaced with garbage plaintext"
        );
    }

    #[test]
    fn decrypt_missing_key_returns_missing_key_outcome() {
        let sender_key = [0x33u8; 32];
        let blob = crypto::encrypt(&sender_key, b"secret").unwrap();
        let mut clip = encrypted_clip("c2", &blob);

        let outcome = decrypt_clip_content(&mut clip, None);

        assert_eq!(outcome, DecryptOutcome::MissingKey);
        assert!(
            clip.encrypted,
            "clip must remain encrypted when key is missing"
        );
        assert_eq!(
            clip.content, blob,
            "content must be untouched when key is missing"
        );
    }

    #[test]
    fn wrong_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x44u8; 32];
        let receiver_key = [0x55u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();
        let frame = new_text_clip_frame("bad-clip", &blob, true);

        match decode_message(&frame, Some(receiver_key)) {
            Some(WsEvent::ClipDecryptFailed { clip_id, reason }) => {
                assert_eq!(clip_id, "bad-clip");
                assert!(matches!(reason, DecryptFailReason::TagFailed(_)));
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }

    #[test]
    fn missing_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x66u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();
        let frame = new_text_clip_frame("no-key-clip", &blob, true);

        match decode_message(&frame, None) {
            Some(WsEvent::ClipDecryptFailed { clip_id, reason }) => {
                assert_eq!(clip_id, "no-key-clip");
                assert_eq!(reason, DecryptFailReason::MissingKey);
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }
}