Skip to main content

client_core/
ws.rs

1//! Slim WebSocket subscriber for the relay's `/ws` endpoint.
2//!
3//! Phase 5 ships this for `cinch pull --watch`. It is intentionally
4//! narrower than `desktop/src-tauri/src/ws.rs` (which still owns its
5//! tauri-coupled lifecycle, tray status, db inserts, image fetch, and
6//! key-exchange responder). Once that desktop logic gets refactored to
7//! consume callbacks from this module, the duplicated reconnect/decrypt
8//! plumbing inside desktop's ws.rs can shrink to a thin event bridge.
9//!
//! The client exchanges its bearer token for a short-lived ticket via
//! `POST /ws/ticket`, connects with that ticket in the URL query string
//! (`?ticket=...`), reads frames, decodes `WSMessage`, decrypts clip
//! content when `encrypted=true`, and forwards every interesting message
//! through an `mpsc::Sender<WsEvent>` provided by the caller.
14
15use std::time::Duration;
16
17use base64::engine::general_purpose::STANDARD;
18use base64::Engine;
19use futures_util::{SinkExt, StreamExt};
20use tokio::sync::mpsc;
21use tokio::time::sleep;
22use tokio_tungstenite::tungstenite::Message;
23use tokio_tungstenite::{connect_async, tungstenite};
24use tracing::{debug, info, warn};
25
26use reqwest;
27use serde_json;
28
29use crate::crypto;
30use crate::protocol::{
31    Clip, WSMessage, ACTION_CLIP_DELETED, ACTION_KEY_EXCHANGE_REQUESTED, ACTION_NEW_CLIP,
32    ACTION_PING, ACTION_REVOKED, ACTION_TOKEN_ROTATED,
33};
34
/// Events forwarded to the caller over the `mpsc::Sender<WsEvent>` handed
/// to [`run`].
#[derive(Debug, Clone)]
pub enum WsEvent {
    /// Connection state transitions — emitted on connect, disconnect, retry.
    Status(WsStatus),
    /// New clip received. `plaintext` is a byte copy of `clip.content` taken
    /// after any in-place decryption:
    ///
    /// - `encrypted=false`: the content exactly as received.
    /// - encrypted text clip: the decrypted UTF-8 body.
    /// - encrypted binary clip (non-empty `media_path`, or a content type
    ///   starting with `image`): the base64 *text* of the decrypted bytes,
    ///   because `clip.content` is re-encoded to stay a valid `String`.
    ///
    /// Clips that cannot be decrypted are never delivered through this
    /// variant; they are reported as [`WsEvent::ClipDecryptFailed`].
    // NOTE(review): callers that need raw binary bytes must base64-decode
    // `plaintext` themselves — confirm this is the intended contract.
    NewClip { clip: Box<Clip>, plaintext: Vec<u8> },
    /// Clip deleted on the relay (e.g., retention sweep, manual delete).
    /// `clip_id` is the empty string when the message carried no clip.
    ClipDeleted { clip_id: String },
    /// The caller's device was revoked. Future reconnects will 401.
    Revoked { reason: Option<String> },
    /// Server rotated this device's token. Caller should persist the new
    /// token and reconnect with it.
    TokenRotated {
        token: String,
        device_id: Option<String>,
    },
    /// Another device asked for a key bundle. Desktop handles the ECDH
    /// responder; CLI watchers can ignore.
    KeyExchangeRequested { device_id: Option<String> },
    /// Incoming clip could not be decrypted (missing key or wrong key).
    /// The clip was NOT inserted as plaintext. Callers should surface this
    /// to the user and fire `retry_key_bundle`.
    ClipDecryptFailed {
        clip_id: String,
        reason: DecryptFailReason,
    },
}
64
/// Connection lifecycle states reported through [`WsEvent::Status`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsStatus {
    /// Emitted by `run` before each connection attempt.
    Connecting,
    /// Emitted once the WebSocket handshake has succeeded.
    Connected,
    /// Emitted after a connection attempt ends, whether it closed cleanly
    /// or failed with an error.
    Disconnected,
}
71
/// Errors that end a single connection attempt.
#[derive(Debug, thiserror::Error)]
pub enum WsError {
    /// Handshake, read, or write failure reported by tungstenite.
    #[error("ws: {0}")]
    Tungstenite(#[from] tungstenite::Error),
    /// Free-form failure description. Despite the name, this variant also
    /// carries every ticket-fetch failure (HTTP client build, request
    /// error, non-success status, unparsable body, missing `ticket` field).
    #[error("decode: {0}")]
    Decode(String),
}
79
/// Connection settings for the WebSocket subscriber.
#[derive(Debug, Clone)]
pub struct WsConfig {
    /// Base relay URL (http/https). The ws module exchanges a short-lived
    /// ticket from POST /ws/ticket before connecting; the bearer token
    /// never appears in the WebSocket URL or server access logs.
    pub relay_url: String,
    /// Bearer token used to obtain a WS ticket from POST /ws/ticket.
    pub token: String,
    /// 32-byte AES key used to decrypt incoming `encrypted=true` clips.
    /// With `None`, encrypted clips are not forwarded as `NewClip`; each
    /// one is reported as `WsEvent::ClipDecryptFailed` with
    /// `DecryptFailReason::MissingKey`. Unencrypted clips are unaffected.
    pub encryption_key: Option<[u8; 32]>,
}
93
/// Outcome of a single decrypt attempt on an incoming clip.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecryptOutcome {
    /// Clip was not flagged encrypted; nothing to do.
    Plaintext,
    /// Successfully decrypted; `clip.content` + `clip.encrypted` mutated in place.
    Decoded,
    /// No AES key was supplied; clip left untouched.
    MissingKey,
    /// Decryption failed; clip left untouched. Typically an AES-GCM tag
    /// mismatch (wrong key), but this variant is also returned for any
    /// other `crypto::decrypt` error and when a decrypted text clip is not
    /// valid UTF-8. `error` holds the human-readable cause.
    TagFailed { error: String },
}
106
/// Why an incoming clip could not be decrypted, as carried by
/// `WsEvent::ClipDecryptFailed`. Mirrors the failure variants of
/// `DecryptOutcome`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecryptFailReason {
    /// No decryption key was configured for this subscriber.
    MissingKey,
    /// Decryption (or post-decrypt UTF-8 validation) failed; the payload
    /// is the error message.
    TagFailed(String),
}
112
113/// Connect to the relay and forward decoded events to `tx` until the
114/// caller drops the receiver. Reconnects on socket error with exponential
115/// backoff (1s, 2s, 4s … capped at 30s). Returns when `tx` is closed.
116pub async fn run(cfg: WsConfig, tx: mpsc::Sender<WsEvent>) {
117    let mut attempt = 0u32;
118    loop {
119        if tx.is_closed() {
120            return;
121        }
122        let _ = tx.send(WsEvent::Status(WsStatus::Connecting)).await;
123        match connect_and_listen(&cfg, &tx).await {
124            Ok(()) => {
125                debug!("ws: closed cleanly");
126                attempt = 0;
127            }
128            Err(e) => {
129                warn!("ws error: {}", e);
130                attempt = attempt.saturating_add(1);
131            }
132        }
133        let _ = tx.send(WsEvent::Status(WsStatus::Disconnected)).await;
134        let backoff_secs = 1u64 << attempt.min(5); // 1, 2, 4, 8, 16, 32 ...
135        sleep(Duration::from_secs(backoff_secs.min(30))).await;
136    }
137}
138
139/// Fetch a short-lived single-use WebSocket ticket from the relay.
140/// Calls POST /ws/ticket with a Bearer auth header; returns the hex ticket string.
141async fn fetch_ws_ticket(relay_url: &str, token: &str) -> Result<String, WsError> {
142    let ticket_url = format!("{}/ws/ticket", relay_url.trim_end_matches('/'));
143    let client = reqwest::Client::builder()
144        .timeout(std::time::Duration::from_secs(10))
145        .build()
146        .map_err(|e| WsError::Decode(format!("build http client: {}", e)))?;
147    let resp = client
148        .post(&ticket_url)
149        .bearer_auth(token)
150        .send()
151        .await
152        .map_err(|e| WsError::Decode(format!("ticket request: {}", e)))?;
153    if !resp.status().is_success() {
154        return Err(WsError::Decode(format!(
155            "ticket endpoint returned {}",
156            resp.status()
157        )));
158    }
159    let body: serde_json::Value = resp
160        .json()
161        .await
162        .map_err(|e| WsError::Decode(format!("parse ticket response: {}", e)))?;
163    body["ticket"]
164        .as_str()
165        .map(|s| s.to_string())
166        .ok_or_else(|| WsError::Decode("no ticket in response".into()))
167}
168
169async fn connect_and_listen(cfg: &WsConfig, tx: &mpsc::Sender<WsEvent>) -> Result<(), WsError> {
170    let ticket = fetch_ws_ticket(&cfg.relay_url, &cfg.token).await?;
171    let ws_base = cfg
172        .relay_url
173        .replace("https://", "wss://")
174        .replace("http://", "ws://");
175    let ws_url = format!("{}/ws?ticket={}", ws_base.trim_end_matches('/'), ticket);
176    let (ws_stream, _) = connect_async(&ws_url).await?;
177    info!("ws connected");
178    let _ = tx.send(WsEvent::Status(WsStatus::Connected)).await;
179
180    let (mut write, mut read) = ws_stream.split();
181
182    while let Some(frame) = read.next().await {
183        let msg = frame?;
184        match msg {
185            Message::Text(text) => {
186                if let Some(event) = decode_message(text.as_str(), cfg.encryption_key) {
187                    if tx.send(event).await.is_err() {
188                        return Ok(());
189                    }
190                }
191            }
192            Message::Ping(data) => {
193                write.send(Message::Pong(data)).await?;
194            }
195            Message::Close(_) => {
196                debug!("relay sent close");
197                return Ok(());
198            }
199            _ => {}
200        }
201    }
202    Ok(())
203}
204
205fn decode_message(text: &str, key: Option<[u8; 32]>) -> Option<WsEvent> {
206    let msg: WSMessage = match serde_json::from_str(text) {
207        Ok(m) => m,
208        Err(e) => {
209            warn!("ws: bad message: {}", e);
210            return None;
211        }
212    };
213
214    match msg.action.as_str() {
215        ACTION_NEW_CLIP => {
216            let mut clip = msg.clip?;
217            match decrypt_clip_content(&mut clip, key) {
218                DecryptOutcome::Plaintext => {
219                    let plaintext = clip.content.as_bytes().to_vec();
220                    Some(WsEvent::NewClip {
221                        clip: Box::new(clip),
222                        plaintext,
223                    })
224                }
225                DecryptOutcome::Decoded => {
226                    let plaintext = clip.content.as_bytes().to_vec();
227                    Some(WsEvent::NewClip {
228                        clip: Box::new(clip),
229                        plaintext,
230                    })
231                }
232                DecryptOutcome::MissingKey => Some(WsEvent::ClipDecryptFailed {
233                    clip_id: clip.clip_id,
234                    reason: DecryptFailReason::MissingKey,
235                }),
236                DecryptOutcome::TagFailed { error } => Some(WsEvent::ClipDecryptFailed {
237                    clip_id: clip.clip_id,
238                    reason: DecryptFailReason::TagFailed(error),
239                }),
240            }
241        }
242        ACTION_CLIP_DELETED => Some(WsEvent::ClipDeleted {
243            clip_id: msg.clip.map(|c| c.clip_id).unwrap_or_default(),
244        }),
245        ACTION_REVOKED => Some(WsEvent::Revoked { reason: msg.reason }),
246        ACTION_TOKEN_ROTATED => msg.token.map(|t| WsEvent::TokenRotated {
247            token: t,
248            device_id: msg.device_id,
249        }),
250        ACTION_KEY_EXCHANGE_REQUESTED => Some(WsEvent::KeyExchangeRequested {
251            device_id: msg.device_id,
252        }),
253        ACTION_PING => None, // server pings handled by tungstenite Pong frames
254        _ => None,
255    }
256}
257
258/// Decrypt `clip.content` in place if `clip.encrypted` and a key is available.
259/// Returns a typed outcome — never silently returns ciphertext as plaintext.
260pub fn decrypt_clip_content(clip: &mut Clip, key: Option<[u8; 32]>) -> DecryptOutcome {
261    if !clip.encrypted {
262        return DecryptOutcome::Plaintext;
263    }
264    let Some(key) = key else {
265        return DecryptOutcome::MissingKey;
266    };
267    let plaintext = match crypto::decrypt(&key, &clip.content) {
268        Ok(p) => p,
269        Err(e) => {
270            return DecryptOutcome::TagFailed {
271                error: e.to_string(),
272            }
273        }
274    };
275    let is_binary = clip
276        .media_path
277        .as_deref()
278        .filter(|p| !p.is_empty())
279        .is_some()
280        || clip.content_type.starts_with("image");
281    if is_binary {
282        // Re-encode as base64 so the struct stays a valid String.
283        clip.content = STANDARD.encode(&plaintext);
284    } else {
285        match String::from_utf8(plaintext) {
286            Ok(s) => clip.content = s,
287            Err(e) => {
288                return DecryptOutcome::TagFailed {
289                    error: format!("post-decrypt utf-8 invalid: {e}"),
290                }
291            }
292        }
293    }
294    clip.encrypted = false;
295    DecryptOutcome::Decoded
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize `body` (which must be a JSON object) as a relay frame
    /// stamped with the given `action`.
    fn make_msg(action: &str, body: serde_json::Value) -> String {
        let mut frame = body;
        let fields = frame.as_object_mut().unwrap();
        fields.insert("action".into(), serde_json::Value::String(action.into()));
        serde_json::to_string(&frame).unwrap()
    }

    /// Build a new-clip frame for a text clip sent by "remote:host".
    fn new_clip_frame(clip_id: &str, content: &str, encrypted: bool) -> String {
        make_msg(
            ACTION_NEW_CLIP,
            serde_json::json!({
                "clip": {
                    "clip_id": clip_id,
                    "user_id": "u1",
                    "content": content,
                    "content_type": "text",
                    "source": "remote:host",
                    "created_at": "2026-04-30T00:00:00Z",
                    "encrypted": encrypted
                }
            }),
        )
    }

    /// Build an in-memory encrypted clip whose content is `blob`.
    fn encrypted_clip(clip_id: &str, blob: &str) -> Clip {
        Clip {
            clip_id: clip_id.into(),
            user_id: "u1".into(),
            content: blob.to_owned(),
            content_type: String::new(),
            encrypted: true,
            ..Default::default()
        }
    }

    #[test]
    fn decodes_new_clip_unencrypted() {
        let frame = new_clip_frame("01H", "hello", false);
        match decode_message(&frame, None).unwrap() {
            WsEvent::NewClip { clip, plaintext } => {
                assert_eq!(clip.clip_id, "01H");
                assert_eq!(plaintext, b"hello");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_revoked() {
        let frame = make_msg(
            ACTION_REVOKED,
            serde_json::json!({"reason": "device removed"}),
        );
        match decode_message(&frame, None).unwrap() {
            WsEvent::Revoked { reason } => assert_eq!(reason.as_deref(), Some("device removed")),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_clip_deleted() {
        // Deliberately omits "encrypted" to exercise the field's default.
        let frame = make_msg(
            ACTION_CLIP_DELETED,
            serde_json::json!({
                "clip": {
                    "clip_id": "delme",
                    "user_id": "u1",
                    "content": "",
                    "content_type": "text",
                    "source": "local",
                    "created_at": "2026-04-30T00:00:00Z"
                }
            }),
        );
        match decode_message(&frame, None).unwrap() {
            WsEvent::ClipDeleted { clip_id } => assert_eq!(clip_id, "delme"),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypts_text_clip_with_key() {
        let key = [0x42u8; 32];
        let ciphertext = crypto::encrypt(&key, b"secret payload").unwrap();
        let frame = new_clip_frame("01H", &ciphertext, true);
        match decode_message(&frame, Some(key)).unwrap() {
            WsEvent::NewClip { clip, plaintext } => {
                assert_eq!(plaintext, b"secret payload");
                assert!(!clip.encrypted);
                assert_eq!(clip.content, "secret payload");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypt_failure_does_not_silently_return_ciphertext() {
        let sender_key = [0x11u8; 32];
        let receiver_key = [0x22u8; 32];
        let blob = crypto::encrypt(&sender_key, b"hello from remote cli").unwrap();
        let mut clip = encrypted_clip("c1", &blob);

        let outcome = decrypt_clip_content(&mut clip, Some(receiver_key));

        assert!(
            matches!(outcome, DecryptOutcome::TagFailed { .. }),
            "wrong-key decrypt must return TagFailed, got {:?}",
            outcome
        );
        assert!(clip.encrypted, "encrypted flag must remain true on failure");
        assert_eq!(
            clip.content, blob,
            "content must not be replaced with garbage plaintext"
        );
    }

    #[test]
    fn decrypt_missing_key_returns_missing_key_outcome() {
        let sender_key = [0x33u8; 32];
        let blob = crypto::encrypt(&sender_key, b"secret").unwrap();
        let mut clip = encrypted_clip("c2", &blob);

        let outcome = decrypt_clip_content(&mut clip, None);

        assert_eq!(outcome, DecryptOutcome::MissingKey);
        assert!(
            clip.encrypted,
            "clip must remain encrypted when key is missing"
        );
        assert_eq!(
            clip.content, blob,
            "content must be untouched when key is missing"
        );
    }

    #[test]
    fn wrong_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x44u8; 32];
        let receiver_key = [0x55u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();
        let frame = new_clip_frame("bad-clip", &blob, true);

        match decode_message(&frame, Some(receiver_key)).unwrap() {
            WsEvent::ClipDecryptFailed { clip_id, reason } => {
                assert_eq!(clip_id, "bad-clip");
                assert!(matches!(reason, DecryptFailReason::TagFailed(_)));
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }

    #[test]
    fn missing_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x66u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();
        let frame = new_clip_frame("no-key-clip", &blob, true);

        match decode_message(&frame, None).unwrap() {
            WsEvent::ClipDecryptFailed { clip_id, reason } => {
                assert_eq!(clip_id, "no-key-clip");
                assert_eq!(reason, DecryptFailReason::MissingKey);
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }
}
507}