Skip to main content

shunt/
sync.rs

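//! Helpers for device-to-device relaying: remote-code generation and AES-256-GCM payload
//! encryption used by the `remote` command, plus the one-time share-code handshake used by
//! `shunt connect`.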
2
3use aes_gcm::{
4    aead::{Aead, KeyInit},
5    Aes256Gcm, Key, Nonce,
6};
7use anyhow::{Context, Result};
8use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
9use sha2::{Digest, Sha256};
10use serde_json;
11
// ---------------------------------------------------------------------------
// Remote-code generation and validation (RM- prefix)
// ---------------------------------------------------------------------------
15
16/// Generate a random remote-watch code like `RM-a3f2b1c4d5e6f7a8b9`.
17pub fn generate_remote_code() -> String {
18    let bytes = crate::oauth::rand_bytes::<9>();
19    format!("RM-{}", hex::encode(bytes))
20}
21
22/// Validate that a remote-watch code looks like what we generated.
23pub fn validate_remote_code(code: &str) -> Result<()> {
24    if !code.starts_with("RM-") || code.len() != 21 {
25        anyhow::bail!("Invalid remote code format. Expected RM-<18 hex chars>.");
26    }
27    if !code[3..].chars().all(|c| c.is_ascii_hexdigit()) {
28        anyhow::bail!("Invalid remote code — must be hex characters after 'RM-'.");
29    }
30    Ok(())
31}
32
// ---------------------------------------------------------------------------
// Key derivation and payload encryption (`decrypt_bytes` is at the end of the file)
// ---------------------------------------------------------------------------
36
37fn derive_key(code: &str) -> [u8; 32] {
38    let hash = Sha256::digest(code.as_bytes());
39    hash.into()
40}
41
42/// Encrypt arbitrary bytes with the given code; returns a base64 payload string.
43pub fn encrypt_bytes(data: &[u8], code: &str) -> Result<String> {
44    let key_bytes = derive_key(code);
45    let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
46    let cipher = Aes256Gcm::new(key);
47    let nonce_bytes = crate::oauth::rand_bytes::<12>();
48    let nonce = Nonce::from_slice(&nonce_bytes);
49    let ciphertext = cipher
50        .encrypt(nonce, data)
51        .map_err(|e| anyhow::anyhow!("encryption failed: {e}"))?;
52    let mut wire = Vec::with_capacity(12 + ciphertext.len());
53    wire.extend_from_slice(&nonce_bytes);
54    wire.extend_from_slice(&ciphertext);
55    Ok(B64.encode(wire))
56}
57
// ---------------------------------------------------------------------------
// Share-code helpers (`SC-` prefix): one-time relay handshake used by `shunt connect`
// ---------------------------------------------------------------------------
61
62/// Generate a random share code like `SC-a3f2b1c4d5e6f7a8b9`.
63pub fn generate_share_code() -> String {
64    let bytes = crate::oauth::rand_bytes::<9>();
65    format!("SC-{}", hex::encode(bytes))
66}
67
68/// Validate that a share code has the expected format.
69pub fn validate_share_code(code: &str) -> Result<()> {
70    if !code.starts_with("SC-") || code.len() != 21 {
71        anyhow::bail!("Invalid share code format. Expected SC-<18 hex chars>.");
72    }
73    if !code[3..].chars().all(|c| c.is_ascii_hexdigit()) {
74        anyhow::bail!("Invalid share code — must be hex characters after 'SC-'.");
75    }
76    Ok(())
77}
78
79/// Push {base_url, api_key} to the relay under `code`. TTL 10 minutes, one-time read.
80pub async fn push_share(code: &str, base_url: &str, api_key: &str, relay_url: &str) -> Result<()> {
81    let client = reqwest::Client::new();
82    let url = format!("{relay_url}/share/{code}");
83    let res = client
84        .put(&url)
85        .json(&serde_json::json!({ "base_url": base_url, "api_key": api_key }))
86        .send()
87        .await
88        .context("Failed to reach relay")?;
89    if !res.status().is_success() {
90        let body = res.text().await.unwrap_or_default();
91        anyhow::bail!("Relay rejected share push ({}): {}", url, body);
92    }
93    Ok(())
94}
95
96/// Pull {base_url, api_key} from the relay for `code`. Deletes the entry on success.
97pub async fn pull_share(code: &str, relay_url: &str) -> Result<(String, String)> {
98    let client = reqwest::Client::new();
99    let url = format!("{relay_url}/share/{code}");
100    let res = client
101        .get(&url)
102        .send()
103        .await
104        .context("Failed to reach relay")?;
105    if res.status() == reqwest::StatusCode::NOT_FOUND {
106        anyhow::bail!("Share code not found, expired, or already used. Ask the host to run `shunt share` again.");
107    }
108    if !res.status().is_success() {
109        let body = res.text().await.unwrap_or_default();
110        anyhow::bail!("Relay error: {body}");
111    }
112    let json: serde_json::Value = res.json().await.context("Invalid JSON from relay")?;
113    let base_url = json["base_url"].as_str().context("Missing base_url")?.to_owned();
114    let api_key = json["api_key"].as_str().context("Missing api_key")?.to_owned();
115    Ok((base_url, api_key))
116}
117
118/// Decrypt a base64 payload into bytes using the given code.
119pub fn decrypt_bytes(payload_b64: &str, code: &str) -> Result<Vec<u8>> {
120    let wire = B64.decode(payload_b64).context("invalid base64 in payload")?;
121    if wire.len() < 12 {
122        anyhow::bail!("payload too short");
123    }
124    let (nonce_bytes, ciphertext) = wire.split_at(12);
125    let key_bytes = derive_key(code);
126    let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
127    let cipher = Aes256Gcm::new(key);
128    let nonce = Nonce::from_slice(nonce_bytes);
129    cipher
130        .decrypt(nonce, ciphertext)
131        .map_err(|_| anyhow::anyhow!("decryption failed — wrong code or corrupted payload"))
132}