1use std::time::Duration;
16
17use base64::engine::general_purpose::STANDARD;
18use base64::Engine;
19use futures_util::{SinkExt, StreamExt};
20use tokio::sync::mpsc;
21use tokio::time::sleep;
22use tokio_tungstenite::tungstenite::Message;
23use tokio_tungstenite::{connect_async, tungstenite};
24use tracing::{debug, info, warn};
25
26use reqwest;
27use serde_json;
28
29use crate::crypto;
30use crate::protocol::{
31 Clip, WSMessage, ACTION_CLIP_DELETED, ACTION_KEY_EXCHANGE_REQUESTED, ACTION_NEW_CLIP,
32 ACTION_PING, ACTION_REVOKED, ACTION_TOKEN_ROTATED,
33};
34
35#[derive(Debug, Clone)]
36pub enum WsEvent {
37 Status(WsStatus),
39 NewClip { clip: Box<Clip>, plaintext: Vec<u8> },
43 ClipDeleted { clip_id: String },
45 Revoked { reason: Option<String> },
47 TokenRotated {
50 token: String,
51 device_id: Option<String>,
52 },
53 KeyExchangeRequested { device_id: Option<String> },
56 ClipDecryptFailed {
60 clip_id: String,
61 reason: DecryptFailReason,
62 },
63}
64
/// Connection state reported through [`WsEvent::Status`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsStatus {
    /// A new session is starting (ticket fetch + WebSocket handshake).
    Connecting,
    /// The WebSocket handshake completed.
    Connected,
    /// The session ended; [`run`] reconnects after a backoff delay.
    Disconnected,
}
71
72#[derive(Debug, thiserror::Error)]
73pub enum WsError {
74 #[error("ws: {0}")]
75 Tungstenite(#[from] tungstenite::Error),
76 #[error("decode: {0}")]
77 Decode(String),
78}
79
/// Connection settings for [`run`].
///
/// `Debug` is implemented by hand so that credentials never end up in logs: the token and
/// the encryption key are printed as `<redacted>`.
#[derive(Clone)]
pub struct WsConfig {
    /// Base URL of the relay (`http://` or `https://`); `/ws/ticket` and `/ws` are appended.
    pub relay_url: String,
    /// Bearer token used to request a one-time WebSocket ticket.
    pub token: String,
    /// Key used to decrypt clips marked `encrypted`; `None` makes such clips fail with
    /// `MissingKey`.
    pub encryption_key: Option<[u8; 32]>,
}

impl std::fmt::Debug for WsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a key is configured, never its bytes.
        let key_state = self.encryption_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("WsConfig")
            .field("relay_url", &self.relay_url)
            .field("token", &"<redacted>")
            .field("encryption_key", &key_state)
            .finish()
    }
}
93
/// What [`decrypt_clip_content`] did to a clip.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecryptOutcome {
    /// The clip was not marked encrypted; it was left unchanged.
    Plaintext,
    /// The clip was decrypted in place and its `encrypted` flag cleared.
    Decoded,
    /// The clip is encrypted but no key was supplied; it was left unchanged.
    MissingKey,
    /// Decryption failed, or the decrypted text was not valid UTF-8; the clip was left
    /// unchanged.
    TagFailed {
        /// Human-readable cause of the failure.
        error: String,
    },
}
106
/// Why a clip was reported via [`WsEvent::ClipDecryptFailed`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DecryptFailReason {
    /// The clip was encrypted but no encryption key is configured.
    MissingKey,
    /// Decryption (or UTF-8 decoding of the result) failed; carries the error text.
    TagFailed(String),
}
112
113pub async fn run(cfg: WsConfig, tx: mpsc::Sender<WsEvent>) {
117 let mut attempt = 0u32;
118 loop {
119 if tx.is_closed() {
120 return;
121 }
122 let _ = tx.send(WsEvent::Status(WsStatus::Connecting)).await;
123 match connect_and_listen(&cfg, &tx).await {
124 Ok(()) => {
125 debug!("ws: closed cleanly");
126 attempt = 0;
127 }
128 Err(e) => {
129 warn!("ws error: {}", e);
130 attempt = attempt.saturating_add(1);
131 }
132 }
133 let _ = tx.send(WsEvent::Status(WsStatus::Disconnected)).await;
134 let backoff_secs = 1u64 << attempt.min(5); sleep(Duration::from_secs(backoff_secs.min(30))).await;
136 }
137}
138
139async fn fetch_ws_ticket(relay_url: &str, token: &str) -> Result<String, WsError> {
142 let ticket_url = format!("{}/ws/ticket", relay_url.trim_end_matches('/'));
143 let client = reqwest::Client::builder()
144 .timeout(std::time::Duration::from_secs(10))
145 .build()
146 .map_err(|e| WsError::Decode(format!("build http client: {}", e)))?;
147 let resp = client
148 .post(&ticket_url)
149 .bearer_auth(token)
150 .send()
151 .await
152 .map_err(|e| WsError::Decode(format!("ticket request: {}", e)))?;
153 if !resp.status().is_success() {
154 return Err(WsError::Decode(format!(
155 "ticket endpoint returned {}",
156 resp.status()
157 )));
158 }
159 let body: serde_json::Value = resp
160 .json()
161 .await
162 .map_err(|e| WsError::Decode(format!("parse ticket response: {}", e)))?;
163 body["ticket"]
164 .as_str()
165 .map(|s| s.to_string())
166 .ok_or_else(|| WsError::Decode("no ticket in response".into()))
167}
168
169async fn connect_and_listen(cfg: &WsConfig, tx: &mpsc::Sender<WsEvent>) -> Result<(), WsError> {
170 let ticket = fetch_ws_ticket(&cfg.relay_url, &cfg.token).await?;
171 let ws_base = cfg
172 .relay_url
173 .replace("https://", "wss://")
174 .replace("http://", "ws://");
175 let ws_url = format!("{}/ws?ticket={}", ws_base.trim_end_matches('/'), ticket);
176 let (ws_stream, _) = connect_async(&ws_url).await?;
177 info!("ws connected");
178 log::info!("ws connected"); let _ = tx.send(WsEvent::Status(WsStatus::Connected)).await;
180
181 let (mut write, mut read) = ws_stream.split();
182
183 while let Some(frame) = read.next().await {
184 let msg = frame?;
185 match msg {
186 Message::Text(text) => {
187 if let Some(event) = decode_message(text.as_str(), cfg.encryption_key) {
188 if tx.send(event).await.is_err() {
189 return Ok(());
190 }
191 }
192 }
193 Message::Ping(data) => {
194 write.send(Message::Pong(data)).await?;
195 }
196 Message::Close(_) => {
197 debug!("relay sent close");
198 return Ok(());
199 }
200 _ => {}
201 }
202 }
203 Ok(())
204}
205
206fn decode_message(text: &str, key: Option<[u8; 32]>) -> Option<WsEvent> {
207 let msg: WSMessage = match serde_json::from_str(text) {
208 Ok(m) => m,
209 Err(e) => {
210 warn!("ws: bad message: {}", e);
211 return None;
212 }
213 };
214
215 match msg.action.as_str() {
216 ACTION_NEW_CLIP => {
217 let mut clip = msg.clip?;
218 match decrypt_clip_content(&mut clip, key) {
219 DecryptOutcome::Plaintext => {
220 let plaintext = clip.content.as_bytes().to_vec();
221 Some(WsEvent::NewClip {
222 clip: Box::new(clip),
223 plaintext,
224 })
225 }
226 DecryptOutcome::Decoded => {
227 let plaintext = clip.content.as_bytes().to_vec();
228 Some(WsEvent::NewClip {
229 clip: Box::new(clip),
230 plaintext,
231 })
232 }
233 DecryptOutcome::MissingKey => Some(WsEvent::ClipDecryptFailed {
234 clip_id: clip.clip_id,
235 reason: DecryptFailReason::MissingKey,
236 }),
237 DecryptOutcome::TagFailed { error } => Some(WsEvent::ClipDecryptFailed {
238 clip_id: clip.clip_id,
239 reason: DecryptFailReason::TagFailed(error),
240 }),
241 }
242 }
243 ACTION_CLIP_DELETED => Some(WsEvent::ClipDeleted {
244 clip_id: msg.clip.map(|c| c.clip_id).unwrap_or_default(),
245 }),
246 ACTION_REVOKED => Some(WsEvent::Revoked { reason: msg.reason }),
247 ACTION_TOKEN_ROTATED => msg.token.map(|t| WsEvent::TokenRotated {
248 token: t,
249 device_id: msg.device_id,
250 }),
251 ACTION_KEY_EXCHANGE_REQUESTED => {
252 log::info!(
253 "ws: decoded key_exchange_requested device_id={:?}",
254 msg.device_id
255 );
256 Some(WsEvent::KeyExchangeRequested {
257 device_id: msg.device_id,
258 })
259 }
260 ACTION_PING => None, _ => None,
262 }
263}
264
265pub fn decrypt_clip_content(clip: &mut Clip, key: Option<[u8; 32]>) -> DecryptOutcome {
268 if !clip.encrypted {
269 return DecryptOutcome::Plaintext;
270 }
271 let Some(key) = key else {
272 return DecryptOutcome::MissingKey;
273 };
274 let plaintext = match crypto::decrypt(&key, &clip.content) {
275 Ok(p) => p,
276 Err(e) => {
277 return DecryptOutcome::TagFailed {
278 error: e.to_string(),
279 }
280 }
281 };
282 let is_binary = clip
283 .media_path
284 .as_deref()
285 .filter(|p| !p.is_empty())
286 .is_some()
287 || clip.content_type.starts_with("image");
288 if is_binary {
289 clip.content = STANDARD.encode(&plaintext);
291 } else {
292 match String::from_utf8(plaintext) {
293 Ok(s) => clip.content = s,
294 Err(e) => {
295 return DecryptOutcome::TagFailed {
296 error: format!("post-decrypt utf-8 invalid: {e}"),
297 }
298 }
299 }
300 }
301 clip.encrypted = false;
302 DecryptOutcome::Decoded
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `body` (a JSON object) with its `action` field set to `action`.
    fn make_msg(action: &str, mut body: serde_json::Value) -> String {
        body.as_object_mut()
            .expect("message body must be a JSON object")
            .insert("action".to_string(), serde_json::Value::from(action));
        serde_json::to_string(&body).expect("test JSON always serializes")
    }

    /// A new-clip frame carrying the fixture clip shared by the decode tests.
    fn new_clip_msg(clip_id: &str, content: &str, encrypted: bool) -> String {
        make_msg(
            ACTION_NEW_CLIP,
            serde_json::json!({
                "clip": {
                    "clip_id": clip_id,
                    "user_id": "u1",
                    "content": content,
                    "content_type": "text",
                    "source": "remote:host",
                    "created_at": "2026-04-30T00:00:00Z",
                    "encrypted": encrypted
                }
            }),
        )
    }

    /// An encrypted clip with an empty content type, for calling `decrypt_clip_content` directly.
    fn encrypted_clip(clip_id: &str, content: &str) -> Clip {
        Clip {
            clip_id: clip_id.into(),
            user_id: "u1".into(),
            content: content.into(),
            content_type: String::new(),
            encrypted: true,
            ..Default::default()
        }
    }

    #[test]
    fn decodes_new_clip_unencrypted() {
        let json = new_clip_msg("01H", "hello", false);
        match decode_message(&json, None).expect("frame should decode") {
            WsEvent::NewClip { clip, plaintext } => {
                assert_eq!(clip.clip_id, "01H");
                assert_eq!(plaintext, b"hello");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_revoked() {
        let json = make_msg(
            ACTION_REVOKED,
            serde_json::json!({ "reason": "device removed" }),
        );
        match decode_message(&json, None).expect("frame should decode") {
            WsEvent::Revoked { reason } => assert_eq!(reason.as_deref(), Some("device removed")),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_clip_deleted() {
        let json = make_msg(
            ACTION_CLIP_DELETED,
            serde_json::json!({
                "clip": {
                    "clip_id": "delme",
                    "user_id": "u1",
                    "content": "",
                    "content_type": "text",
                    "source": "local",
                    "created_at": "2026-04-30T00:00:00Z"
                }
            }),
        );
        match decode_message(&json, None).expect("frame should decode") {
            WsEvent::ClipDeleted { clip_id } => assert_eq!(clip_id, "delme"),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypts_text_clip_with_key() {
        let key = [0x42u8; 32];
        let ciphertext = crypto::encrypt(&key, b"secret payload").unwrap();
        let json = new_clip_msg("01H", &ciphertext, true);
        match decode_message(&json, Some(key)).expect("frame should decode") {
            WsEvent::NewClip { clip, plaintext } => {
                assert_eq!(plaintext, b"secret payload");
                assert!(!clip.encrypted);
                assert_eq!(clip.content, "secret payload");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypt_failure_does_not_silently_return_ciphertext() {
        let sender_key = [0x11u8; 32];
        let receiver_key = [0x22u8; 32];
        let blob = crypto::encrypt(&sender_key, b"hello from remote cli").unwrap();
        let mut clip = encrypted_clip("c1", &blob);

        let outcome = decrypt_clip_content(&mut clip, Some(receiver_key));

        assert!(
            matches!(outcome, DecryptOutcome::TagFailed { .. }),
            "wrong-key decrypt must return TagFailed, got {:?}",
            outcome
        );
        assert!(clip.encrypted, "encrypted flag must remain true on failure");
        assert_eq!(
            clip.content, blob,
            "content must not be replaced with garbage plaintext"
        );
    }

    #[test]
    fn decrypt_missing_key_returns_missing_key_outcome() {
        let sender_key = [0x33u8; 32];
        let blob = crypto::encrypt(&sender_key, b"secret").unwrap();
        let mut clip = encrypted_clip("c2", &blob);

        let outcome = decrypt_clip_content(&mut clip, None);

        assert_eq!(outcome, DecryptOutcome::MissingKey);
        assert!(
            clip.encrypted,
            "clip must remain encrypted when key is missing"
        );
        assert_eq!(
            clip.content, blob,
            "content must be untouched when key is missing"
        );
    }

    #[test]
    fn wrong_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x44u8; 32];
        let receiver_key = [0x55u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();
        let json = new_clip_msg("bad-clip", &blob, true);

        match decode_message(&json, Some(receiver_key)).expect("frame should decode") {
            WsEvent::ClipDecryptFailed { clip_id, reason } => {
                assert_eq!(clip_id, "bad-clip");
                assert!(matches!(reason, DecryptFailReason::TagFailed(_)));
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }

    #[test]
    fn missing_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x66u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();
        let json = new_clip_msg("no-key-clip", &blob, true);

        match decode_message(&json, None).expect("frame should decode") {
            WsEvent::ClipDecryptFailed { clip_id, reason } => {
                assert_eq!(clip_id, "no-key-clip");
                assert_eq!(reason, DecryptFailReason::MissingKey);
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }
}