1use std::time::Duration;
16
17use base64::engine::general_purpose::STANDARD;
18use base64::Engine;
19use futures_util::{SinkExt, StreamExt};
20use tokio::sync::mpsc;
21use tokio::time::sleep;
22use tokio_tungstenite::tungstenite::Message;
23use tokio_tungstenite::{connect_async, tungstenite};
24use tracing::{debug, info, warn};
25
26use reqwest;
27use serde_json;
28
29use crate::crypto;
30use crate::protocol::{
31 Clip, WSMessage, ACTION_CLIP_DELETED, ACTION_KEY_EXCHANGE_REQUESTED, ACTION_NEW_CLIP,
32 ACTION_PING, ACTION_REVOKED, ACTION_TOKEN_ROTATED,
33};
34
/// Events pushed to the consumer of the WebSocket client over the `mpsc` channel.
#[derive(Debug, Clone)]
pub enum WsEvent {
    /// Connection lifecycle change (connecting / connected / disconnected).
    Status(WsStatus),
    /// A new clip arrived and its content is readable: it was either not
    /// flagged as encrypted or was decrypted successfully. `plaintext` holds
    /// the bytes of `clip.content` as it stands after decryption.
    NewClip { clip: Box<Clip>, plaintext: Vec<u8> },
    /// The relay announced a clip deletion. `clip_id` is the empty string when
    /// the message carried no clip object.
    ClipDeleted { clip_id: String },
    /// The relay sent a revocation message, optionally with a reason.
    Revoked { reason: Option<String> },
    /// The relay sent a token-rotation message that included a token.
    TokenRotated {
        /// The token value carried by the message.
        token: String,
        /// Device id carried by the message, if any.
        device_id: Option<String>,
    },
    /// The relay sent a key-exchange request, optionally naming a device.
    KeyExchangeRequested { device_id: Option<String> },
    /// A new clip arrived flagged as encrypted but could not be decrypted, so
    /// no `NewClip` event is produced for it.
    ClipDecryptFailed {
        /// Id of the clip that could not be decrypted.
        clip_id: String,
        /// Why decryption did not succeed.
        reason: DecryptFailReason,
    },
}
64
/// Connection state reported through [`WsEvent::Status`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsStatus {
    /// Sent by `run` before every connection attempt.
    Connecting,
    /// Sent once the WebSocket connection has been established.
    Connected,
    /// Sent after a session ends (cleanly or with an error), before the
    /// reconnect delay.
    Disconnected,
}
71
/// Errors that terminate a single WebSocket session.
#[derive(Debug, thiserror::Error)]
pub enum WsError {
    /// Error raised by the WebSocket layer (connect, read or write).
    #[error("ws: {0}")]
    Tungstenite(#[from] tungstenite::Error),
    /// Free-form failure message. In this file it is used for every failure
    /// of the ticket request: building the HTTP client, sending the request,
    /// a non-success status, an unparsable body, or a missing `ticket` field.
    #[error("decode: {0}")]
    Decode(String),
}
79
/// Settings needed to open a session with the relay.
#[derive(Debug, Clone)]
pub struct WsConfig {
    /// Base URL of the relay. The ticket is requested from `{relay_url}/ws/ticket`
    /// and the socket is opened at `{relay_url}/ws` with an `http(s)://` scheme
    /// mapped to `ws(s)://`.
    pub relay_url: String,
    /// Bearer token sent with the ticket request.
    pub token: String,
    /// 32-byte key used to decrypt clips flagged as encrypted. When `None`,
    /// such clips are reported as undecryptable because the key is missing.
    pub encryption_key: Option<[u8; 32]>,
}
93
/// Result of `decrypt_clip_content`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecryptOutcome {
    /// The clip was not flagged as encrypted; it was left untouched.
    Plaintext,
    /// The clip was decrypted: its content was replaced and its `encrypted`
    /// flag cleared.
    Decoded,
    /// The clip is flagged as encrypted but no key was supplied; the clip was
    /// left untouched.
    MissingKey,
    /// Decryption returned an error, or the decrypted bytes of a non-binary
    /// clip were not valid UTF-8. The clip was left untouched; `error` holds
    /// the failure message.
    TagFailed { error: String },
}
106
/// Reason attached to [`WsEvent::ClipDecryptFailed`]; mirrors the failing
/// variants of [`DecryptOutcome`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecryptFailReason {
    /// No decryption key was configured.
    MissingKey,
    /// Decryption (or the UTF-8 check that follows it) failed; the string is
    /// the failure message.
    TagFailed(String),
}
112
113pub async fn run(cfg: WsConfig, tx: mpsc::Sender<WsEvent>) {
117 let mut attempt = 0u32;
118 loop {
119 if tx.is_closed() {
120 return;
121 }
122 let _ = tx.send(WsEvent::Status(WsStatus::Connecting)).await;
123 match connect_and_listen(&cfg, &tx).await {
124 Ok(()) => {
125 debug!("ws: closed cleanly");
126 attempt = 0;
127 }
128 Err(e) => {
129 warn!("ws error: {}", e);
130 attempt = attempt.saturating_add(1);
131 }
132 }
133 let _ = tx.send(WsEvent::Status(WsStatus::Disconnected)).await;
134 let backoff_secs = 1u64 << attempt.min(5); sleep(Duration::from_secs(backoff_secs.min(30))).await;
136 }
137}
138
139async fn fetch_ws_ticket(relay_url: &str, token: &str) -> Result<String, WsError> {
142 let ticket_url = format!("{}/ws/ticket", relay_url.trim_end_matches('/'));
143 let client = reqwest::Client::builder()
144 .timeout(std::time::Duration::from_secs(10))
145 .build()
146 .map_err(|e| WsError::Decode(format!("build http client: {}", e)))?;
147 let resp = client
148 .post(&ticket_url)
149 .bearer_auth(token)
150 .send()
151 .await
152 .map_err(|e| WsError::Decode(format!("ticket request: {}", e)))?;
153 if !resp.status().is_success() {
154 return Err(WsError::Decode(format!(
155 "ticket endpoint returned {}",
156 resp.status()
157 )));
158 }
159 let body: serde_json::Value = resp
160 .json()
161 .await
162 .map_err(|e| WsError::Decode(format!("parse ticket response: {}", e)))?;
163 body["ticket"]
164 .as_str()
165 .map(|s| s.to_string())
166 .ok_or_else(|| WsError::Decode("no ticket in response".into()))
167}
168
169async fn connect_and_listen(cfg: &WsConfig, tx: &mpsc::Sender<WsEvent>) -> Result<(), WsError> {
170 let ticket = fetch_ws_ticket(&cfg.relay_url, &cfg.token).await?;
171 let ws_base = cfg
172 .relay_url
173 .replace("https://", "wss://")
174 .replace("http://", "ws://");
175 let ws_url = format!("{}/ws?ticket={}", ws_base.trim_end_matches('/'), ticket);
176 let (ws_stream, _) = connect_async(&ws_url).await?;
177 info!("ws connected");
178 let _ = tx.send(WsEvent::Status(WsStatus::Connected)).await;
179
180 let (mut write, mut read) = ws_stream.split();
181
182 while let Some(frame) = read.next().await {
183 let msg = frame?;
184 match msg {
185 Message::Text(text) => {
186 if let Some(event) = decode_message(text.as_str(), cfg.encryption_key) {
187 if tx.send(event).await.is_err() {
188 return Ok(());
189 }
190 }
191 }
192 Message::Ping(data) => {
193 write.send(Message::Pong(data)).await?;
194 }
195 Message::Close(_) => {
196 debug!("relay sent close");
197 return Ok(());
198 }
199 _ => {}
200 }
201 }
202 Ok(())
203}
204
205fn decode_message(text: &str, key: Option<[u8; 32]>) -> Option<WsEvent> {
206 let msg: WSMessage = match serde_json::from_str(text) {
207 Ok(m) => m,
208 Err(e) => {
209 warn!("ws: bad message: {}", e);
210 return None;
211 }
212 };
213
214 match msg.action.as_str() {
215 ACTION_NEW_CLIP => {
216 let mut clip = msg.clip?;
217 match decrypt_clip_content(&mut clip, key) {
218 DecryptOutcome::Plaintext => {
219 let plaintext = clip.content.as_bytes().to_vec();
220 Some(WsEvent::NewClip {
221 clip: Box::new(clip),
222 plaintext,
223 })
224 }
225 DecryptOutcome::Decoded => {
226 let plaintext = clip.content.as_bytes().to_vec();
227 Some(WsEvent::NewClip {
228 clip: Box::new(clip),
229 plaintext,
230 })
231 }
232 DecryptOutcome::MissingKey => Some(WsEvent::ClipDecryptFailed {
233 clip_id: clip.clip_id,
234 reason: DecryptFailReason::MissingKey,
235 }),
236 DecryptOutcome::TagFailed { error } => Some(WsEvent::ClipDecryptFailed {
237 clip_id: clip.clip_id,
238 reason: DecryptFailReason::TagFailed(error),
239 }),
240 }
241 }
242 ACTION_CLIP_DELETED => Some(WsEvent::ClipDeleted {
243 clip_id: msg.clip.map(|c| c.clip_id).unwrap_or_default(),
244 }),
245 ACTION_REVOKED => Some(WsEvent::Revoked { reason: msg.reason }),
246 ACTION_TOKEN_ROTATED => msg.token.map(|t| WsEvent::TokenRotated {
247 token: t,
248 device_id: msg.device_id,
249 }),
250 ACTION_KEY_EXCHANGE_REQUESTED => Some(WsEvent::KeyExchangeRequested {
251 device_id: msg.device_id,
252 }),
253 ACTION_PING => None, _ => None,
255 }
256}
257
258pub fn decrypt_clip_content(clip: &mut Clip, key: Option<[u8; 32]>) -> DecryptOutcome {
261 if !clip.encrypted {
262 return DecryptOutcome::Plaintext;
263 }
264 let Some(key) = key else {
265 return DecryptOutcome::MissingKey;
266 };
267 let plaintext = match crypto::decrypt(&key, &clip.content) {
268 Ok(p) => p,
269 Err(e) => {
270 return DecryptOutcome::TagFailed {
271 error: e.to_string(),
272 }
273 }
274 };
275 let is_binary = clip
276 .media_path
277 .as_deref()
278 .filter(|p| !p.is_empty())
279 .is_some()
280 || clip.content_type.starts_with("image");
281 if is_binary {
282 clip.content = STANDARD.encode(&plaintext);
284 } else {
285 match String::from_utf8(plaintext) {
286 Ok(s) => clip.content = s,
287 Err(e) => {
288 return DecryptOutcome::TagFailed {
289 error: format!("post-decrypt utf-8 invalid: {e}"),
290 }
291 }
292 }
293 }
294 clip.encrypted = false;
295 DecryptOutcome::Decoded
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `body` (which must be a JSON object) with its `action` field
    /// set to `action`.
    fn make_msg(action: &str, mut body: serde_json::Value) -> String {
        let fields = body.as_object_mut().unwrap();
        fields.insert("action".into(), serde_json::Value::String(action.into()));
        serde_json::to_string(&body).unwrap()
    }

    /// Builds an `ACTION_NEW_CLIP` frame for a text clip with the given id,
    /// content and `encrypted` flag; all other fields use fixed test values.
    fn new_clip_frame(clip_id: &str, content: &str, encrypted: bool) -> String {
        make_msg(
            ACTION_NEW_CLIP,
            serde_json::json!({
                "clip": {
                    "clip_id": clip_id,
                    "user_id": "u1",
                    "content": content,
                    "content_type": "text",
                    "source": "remote:host",
                    "created_at": "2026-04-30T00:00:00Z",
                    "encrypted": encrypted
                }
            }),
        )
    }

    /// Builds a `Clip` flagged as encrypted whose content is `blob`.
    fn encrypted_clip(clip_id: &str, blob: &str) -> Clip {
        Clip {
            clip_id: clip_id.into(),
            user_id: "u1".into(),
            content: blob.to_owned(),
            content_type: String::new(),
            encrypted: true,
            ..Default::default()
        }
    }

    #[test]
    fn decodes_new_clip_unencrypted() {
        let frame = new_clip_frame("01H", "hello", false);
        match decode_message(&frame, None).unwrap() {
            WsEvent::NewClip { clip, plaintext } => {
                assert_eq!(clip.clip_id, "01H");
                assert_eq!(plaintext, b"hello");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_revoked() {
        let frame = make_msg(
            ACTION_REVOKED,
            serde_json::json!({"reason": "device removed"}),
        );
        match decode_message(&frame, None).unwrap() {
            WsEvent::Revoked { reason } => assert_eq!(reason.as_deref(), Some("device removed")),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decodes_clip_deleted() {
        // Built inline: this frame deliberately omits the `encrypted` field.
        let frame = make_msg(
            ACTION_CLIP_DELETED,
            serde_json::json!({
                "clip": {
                    "clip_id": "delme",
                    "user_id": "u1",
                    "content": "",
                    "content_type": "text",
                    "source": "local",
                    "created_at": "2026-04-30T00:00:00Z"
                }
            }),
        );
        match decode_message(&frame, None).unwrap() {
            WsEvent::ClipDeleted { clip_id } => assert_eq!(clip_id, "delme"),
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypts_text_clip_with_key() {
        let key = [0x42u8; 32];
        let ciphertext = crypto::encrypt(&key, b"secret payload").unwrap();
        let frame = new_clip_frame("01H", &ciphertext, true);
        match decode_message(&frame, Some(key)).unwrap() {
            WsEvent::NewClip { clip, plaintext } => {
                assert_eq!(plaintext, b"secret payload");
                assert!(!clip.encrypted);
                assert_eq!(clip.content, "secret payload");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn decrypt_failure_does_not_silently_return_ciphertext() {
        let sender_key = [0x11u8; 32];
        let receiver_key = [0x22u8; 32];
        let blob = crypto::encrypt(&sender_key, b"hello from remote cli").unwrap();
        let mut clip = encrypted_clip("c1", &blob);

        let outcome = decrypt_clip_content(&mut clip, Some(receiver_key));

        assert!(
            matches!(outcome, DecryptOutcome::TagFailed { .. }),
            "wrong-key decrypt must return TagFailed, got {:?}",
            outcome
        );
        assert!(clip.encrypted, "encrypted flag must remain true on failure");
        assert_eq!(
            clip.content, blob,
            "content must not be replaced with garbage plaintext"
        );
    }

    #[test]
    fn decrypt_missing_key_returns_missing_key_outcome() {
        let sender_key = [0x33u8; 32];
        let blob = crypto::encrypt(&sender_key, b"secret").unwrap();
        let mut clip = encrypted_clip("c2", &blob);

        let outcome = decrypt_clip_content(&mut clip, None);

        assert_eq!(outcome, DecryptOutcome::MissingKey);
        assert!(
            clip.encrypted,
            "clip must remain encrypted when key is missing"
        );
        assert_eq!(
            clip.content, blob,
            "content must be untouched when key is missing"
        );
    }

    #[test]
    fn wrong_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x44u8; 32];
        let receiver_key = [0x55u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();
        let frame = new_clip_frame("bad-clip", &blob, true);

        match decode_message(&frame, Some(receiver_key)).unwrap() {
            WsEvent::ClipDecryptFailed { clip_id, reason } => {
                assert_eq!(clip_id, "bad-clip");
                assert!(matches!(reason, DecryptFailReason::TagFailed(_)));
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }

    #[test]
    fn missing_key_via_decode_message_emits_clip_decrypt_failed() {
        let sender_key = [0x66u8; 32];
        let blob = crypto::encrypt(&sender_key, b"payload").unwrap();
        let frame = new_clip_frame("no-key-clip", &blob, true);

        match decode_message(&frame, None).unwrap() {
            WsEvent::ClipDecryptFailed { clip_id, reason } => {
                assert_eq!(clip_id, "no-key-clip");
                assert_eq!(reason, DecryptFailReason::MissingKey);
            }
            other => panic!("expected ClipDecryptFailed, got {:?}", other),
        }
    }
}