1use std::time::Duration;
16
17use base64::engine::general_purpose::STANDARD;
18use base64::Engine;
19use futures_util::{SinkExt, StreamExt};
20use tokio::sync::mpsc;
21use tokio::time::sleep;
22use tokio_tungstenite::tungstenite::Message;
23use tokio_tungstenite::{connect_async, tungstenite};
24use tracing::{debug, info, warn};
25
26use reqwest;
27use serde_json;
28
29use crate::crypto;
30use crate::protocol::{
31 Clip, WSMessage, ACTION_CLIP_DELETED, ACTION_KEY_EXCHANGE_REQUESTED, ACTION_NEW_CLIP,
32 ACTION_PING, ACTION_REVOKED, ACTION_TOKEN_ROTATED,
33};
34use crate::version::ClientInfo;
35
36#[derive(Debug, Clone)]
37pub enum WsEvent {
38 Status(WsStatus),
40 NewClip { clip: Box<Clip>, plaintext: Vec<u8> },
44 ClipDeleted { clip_id: String },
46 Revoked { reason: Option<String> },
48 TokenRotated {
51 token: String,
52 device_id: Option<String>,
53 },
54 KeyExchangeRequested { device_id: Option<String> },
57 ClipDecryptFailed {
61 clip_id: String,
62 reason: DecryptFailReason,
63 },
64}
65
/// Connection state reported through `WsEvent::Status`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsStatus {
    /// A connection attempt (ticket fetch + WebSocket handshake) is starting.
    Connecting,
    /// The WebSocket handshake completed.
    Connected,
    /// The session ended; the task may reconnect after a backoff delay.
    Disconnected,
}
72
73#[derive(Debug, thiserror::Error)]
74pub enum WsError {
75 #[error("ws: {0}")]
76 Tungstenite(#[from] tungstenite::Error),
77 #[error("decode: {0}")]
78 Decode(String),
79}
80
81#[derive(Debug, Clone)]
82pub struct WsConfig {
83 pub relay_url: String,
87 pub token: String,
89 pub encryption_key: Option<[u8; 32]>,
93 pub client_info: Option<ClientInfo>,
99}
100
/// Result of attempting to decrypt a clip in place with `decrypt_clip_content`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecryptOutcome {
    /// The clip was not marked encrypted; nothing was changed.
    Plaintext,
    /// The clip was decrypted in place and its `encrypted` flag cleared.
    Decoded,
    /// The clip is encrypted but no key was supplied; the clip is untouched.
    MissingKey,
    /// Decryption failed (or a text clip's plaintext was not UTF-8); the clip is untouched.
    TagFailed {
        /// Human-readable failure detail.
        error: String,
    },
}
113
/// Why a clip was reported as `WsEvent::ClipDecryptFailed` instead of `WsEvent::NewClip`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecryptFailReason {
    /// No encryption key was configured.
    MissingKey,
    /// Decryption (or post-decrypt UTF-8 validation) failed; carries the error text.
    TagFailed(String),
}
119
120pub async fn run(cfg: WsConfig, tx: mpsc::Sender<WsEvent>) {
124 let mut attempt = 0u32;
125 loop {
126 if tx.is_closed() {
127 return;
128 }
129 let _ = tx.send(WsEvent::Status(WsStatus::Connecting)).await;
130 match connect_and_listen(&cfg, &tx).await {
131 Ok(()) => {
132 debug!("ws: closed cleanly");
133 attempt = 0;
134 }
135 Err(e) => {
136 warn!("ws error: {}", e);
137 attempt = attempt.saturating_add(1);
138 }
139 }
140 let _ = tx.send(WsEvent::Status(WsStatus::Disconnected)).await;
141 let backoff_secs = 1u64 << attempt.min(5); sleep(Duration::from_secs(backoff_secs.min(30))).await;
143 }
144}
145
146async fn fetch_ws_ticket(relay_url: &str, token: &str) -> Result<String, WsError> {
149 let ticket_url = format!("{}/ws/ticket", relay_url.trim_end_matches('/'));
150 let client = reqwest::Client::builder()
151 .timeout(std::time::Duration::from_secs(10))
152 .build()
153 .map_err(|e| WsError::Decode(format!("build http client: {}", e)))?;
154 let resp = client
155 .post(&ticket_url)
156 .bearer_auth(token)
157 .send()
158 .await
159 .map_err(|e| WsError::Decode(format!("ticket request: {}", e)))?;
160 if !resp.status().is_success() {
161 return Err(WsError::Decode(format!(
162 "ticket endpoint returned {}",
163 resp.status()
164 )));
165 }
166 let body: serde_json::Value = resp
167 .json()
168 .await
169 .map_err(|e| WsError::Decode(format!("parse ticket response: {}", e)))?;
170 body["ticket"]
171 .as_str()
172 .map(|s| s.to_string())
173 .ok_or_else(|| WsError::Decode("no ticket in response".into()))
174}
175
176async fn connect_and_listen(cfg: &WsConfig, tx: &mpsc::Sender<WsEvent>) -> Result<(), WsError> {
177 let ticket = fetch_ws_ticket(&cfg.relay_url, &cfg.token).await?;
178 let ws_base = cfg
179 .relay_url
180 .replace("https://", "wss://")
181 .replace("http://", "ws://");
182 let ws_url = format!("{}/ws?ticket={}", ws_base.trim_end_matches('/'), ticket);
183 let (ws_stream, _) = connect_async(&ws_url).await?;
184 info!("ws connected");
185 log::info!("ws connected"); let _ = tx.send(WsEvent::Status(WsStatus::Connected)).await;
187
188 let (mut write, mut read) = ws_stream.split();
189
190 if let Some(info) = cfg.client_info.as_ref() {
196 let hello = info.client_hello_message();
197 match serde_json::to_string(&hello) {
198 Ok(text) => {
199 write.send(Message::Text(text.into())).await?;
200 }
201 Err(e) => {
202 warn!("ws: failed to serialize client_hello: {}", e);
203 }
204 }
205 }
206
207 while let Some(frame) = read.next().await {
208 let msg = frame?;
209 match msg {
210 Message::Text(text) => {
211 if let Some(event) = decode_message(text.as_str(), cfg.encryption_key) {
212 if tx.send(event).await.is_err() {
213 return Ok(());
214 }
215 }
216 }
217 Message::Ping(data) => {
218 write.send(Message::Pong(data)).await?;
219 }
220 Message::Close(_) => {
221 debug!("relay sent close");
222 return Ok(());
223 }
224 _ => {}
225 }
226 }
227 Ok(())
228}
229
230fn decode_message(text: &str, key: Option<[u8; 32]>) -> Option<WsEvent> {
231 let msg: WSMessage = match serde_json::from_str(text) {
232 Ok(m) => m,
233 Err(e) => {
234 warn!("ws: bad message: {}", e);
235 return None;
236 }
237 };
238
239 match msg.action.as_str() {
240 ACTION_NEW_CLIP => {
241 let mut clip = msg.clip?;
242 match decrypt_clip_content(&mut clip, key) {
243 DecryptOutcome::Plaintext => {
244 let plaintext = clip.content.as_bytes().to_vec();
245 Some(WsEvent::NewClip {
246 clip: Box::new(clip),
247 plaintext,
248 })
249 }
250 DecryptOutcome::Decoded => {
251 let plaintext = clip.content.as_bytes().to_vec();
252 Some(WsEvent::NewClip {
253 clip: Box::new(clip),
254 plaintext,
255 })
256 }
257 DecryptOutcome::MissingKey => Some(WsEvent::ClipDecryptFailed {
258 clip_id: clip.clip_id,
259 reason: DecryptFailReason::MissingKey,
260 }),
261 DecryptOutcome::TagFailed { error } => Some(WsEvent::ClipDecryptFailed {
262 clip_id: clip.clip_id,
263 reason: DecryptFailReason::TagFailed(error),
264 }),
265 }
266 }
267 ACTION_CLIP_DELETED => Some(WsEvent::ClipDeleted {
268 clip_id: msg.clip.map(|c| c.clip_id).unwrap_or_default(),
269 }),
270 ACTION_REVOKED => Some(WsEvent::Revoked { reason: msg.reason }),
271 ACTION_TOKEN_ROTATED => msg.token.map(|t| WsEvent::TokenRotated {
272 token: t,
273 device_id: msg.device_id,
274 }),
275 ACTION_KEY_EXCHANGE_REQUESTED => {
276 log::info!(
277 "ws: decoded key_exchange_requested device_id={:?}",
278 msg.device_id
279 );
280 Some(WsEvent::KeyExchangeRequested {
281 device_id: msg.device_id,
282 })
283 }
284 ACTION_PING => None, _ => None,
286 }
287}
288
289pub fn decrypt_clip_content(clip: &mut Clip, key: Option<[u8; 32]>) -> DecryptOutcome {
292 if !clip.encrypted {
293 return DecryptOutcome::Plaintext;
294 }
295 let Some(key) = key else {
296 return DecryptOutcome::MissingKey;
297 };
298 let plaintext = match crypto::decrypt(&key, &clip.content) {
299 Ok(p) => p,
300 Err(e) => {
301 return DecryptOutcome::TagFailed {
302 error: e.to_string(),
303 }
304 }
305 };
306 let is_binary = clip
307 .media_path
308 .as_deref()
309 .filter(|p| !p.is_empty())
310 .is_some()
311 || clip.content_type.starts_with("image");
312 if is_binary {
313 clip.content = STANDARD.encode(&plaintext);
315 } else {
316 match String::from_utf8(plaintext) {
317 Ok(s) => clip.content = s,
318 Err(e) => {
319 return DecryptOutcome::TagFailed {
320 error: format!("post-decrypt utf-8 invalid: {e}"),
321 }
322 }
323 }
324 }
325 clip.encrypted = false;
326 DecryptOutcome::Decoded
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Timestamp shared by every clip fixture.
    const CREATED_AT: &str = "2026-04-30T00:00:00Z";

    /// Serialise `body` (a JSON object) after stamping it with the `action` field.
    fn make_msg(action: &str, body: serde_json::Value) -> String {
        let mut envelope = body;
        let fields = envelope
            .as_object_mut()
            .expect("message body must be a JSON object");
        fields.insert("action".into(), serde_json::Value::String(action.into()));
        serde_json::to_string(&envelope).expect("a JSON value always serialises")
    }

    /// A `new_clip` frame for a text clip from `remote:host`, owned by user `u1`.
    fn new_text_clip_msg(clip_id: &str, content: &str, encrypted: bool) -> String {
        make_msg(
            ACTION_NEW_CLIP,
            serde_json::json!({
                "clip": {
                    "clip_id": clip_id,
                    "user_id": "u1",
                    "content": content,
                    "content_type": "text",
                    "source": "remote:host",
                    "created_at": CREATED_AT,
                    "encrypted": encrypted
                }
            }),
        )
    }

    /// An encrypted clip with an empty content type whose body is `blob`.
    fn encrypted_clip(clip_id: &str, blob: &str) -> Clip {
        Clip {
            clip_id: clip_id.into(),
            user_id: "u1".into(),
            content: blob.to_owned(),
            content_type: String::new(),
            encrypted: true,
            ..Default::default()
        }
    }

    #[test]
    fn decodes_new_clip_unencrypted() {
        let frame = new_text_clip_msg("01H", "hello", false);
        match decode_message(&frame, None).unwrap() {
            WsEvent::NewClip { clip, plaintext } => {
                assert_eq!(clip.clip_id, "01H");
                assert_eq!(plaintext, b"hello");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_revoked() {
        let frame = make_msg(
            ACTION_REVOKED,
            serde_json::json!({ "reason": "device removed" }),
        );
        match decode_message(&frame, None).unwrap() {
            WsEvent::Revoked { reason } => {
                assert_eq!(reason.as_deref(), Some("device removed"));
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_clip_deleted() {
        let frame = make_msg(
            ACTION_CLIP_DELETED,
            serde_json::json!({
                "clip": {
                    "clip_id": "delme",
                    "user_id": "u1",
                    "content": "",
                    "content_type": "text",
                    "source": "local",
                    "created_at": CREATED_AT
                }
            }),
        );
        match decode_message(&frame, None).unwrap() {
            WsEvent::ClipDeleted { clip_id } => assert_eq!(clip_id, "delme"),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypts_text_clip_with_key() {
        let key = [0x42u8; 32];
        let ciphertext = crypto::encrypt(&key, b"secret payload").unwrap();
        let frame = new_text_clip_msg("01H", &ciphertext, true);
        match decode_message(&frame, Some(key)).unwrap() {
            WsEvent::NewClip { clip, plaintext } => {
                assert_eq!(plaintext, b"secret payload");
                assert!(!clip.encrypted);
                assert_eq!(clip.content, "secret payload");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypt_failure_does_not_silently_return_ciphertext() {
        let sender_key = [0x11u8; 32];
        let receiver_key = [0x22u8; 32];
        let blob = crypto::encrypt(&sender_key, b"hello from remote cli").unwrap();
        let mut clip = encrypted_clip("c1", &blob);

        let outcome = decrypt_clip_content(&mut clip, Some(receiver_key));

        assert!(
            matches!(outcome, DecryptOutcome::TagFailed { .. }),
            "wrong-key decrypt must return TagFailed, got {:?}",
            outcome
        );
        assert!(clip.encrypted, "encrypted flag must remain true on failure");
        assert_eq!(
            clip.content, blob,
            "content must not be replaced with garbage plaintext"
        );
    }

    #[test]
    fn decrypt_missing_key_returns_missing_key_outcome() {
        let sender_key = [0x33u8; 32];
        let blob = crypto::encrypt(&sender_key, b"secret").unwrap();
        let mut clip = encrypted_clip("c2", &blob);

        let outcome = decrypt_clip_content(&mut clip, None);

        assert_eq!(outcome, DecryptOutcome::MissingKey);
        assert!(
            clip.encrypted,
            "clip must remain encrypted when key is missing"
        );
        assert_eq!(
            clip.content, blob,
            "content must be untouched when key is missing"
        );
    }

    #[test]
    fn wrong_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x44u8; 32];
        let receiver_key = [0x55u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();
        let frame = new_text_clip_msg("bad-clip", &blob, true);

        match decode_message(&frame, Some(receiver_key)).unwrap() {
            WsEvent::ClipDecryptFailed { clip_id, reason } => {
                assert_eq!(clip_id, "bad-clip");
                assert!(matches!(reason, DecryptFailReason::TagFailed(_)));
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }

    #[test]
    fn missing_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x66u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();
        let frame = new_text_clip_msg("no-key-clip", &blob, true);

        match decode_message(&frame, None).unwrap() {
            WsEvent::ClipDecryptFailed { clip_id, reason } => {
                assert_eq!(clip_id, "no-key-clip");
                assert_eq!(reason, DecryptFailReason::MissingKey);
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }
}