Skip to main content

shunt/
sync.rs

1//! Encryption helpers used by the `remote` command for device-to-device notification relay.
2
3use aes_gcm::{
4    aead::{Aead, KeyInit},
5    Aes256Gcm, Key, Nonce,
6};
7use anyhow::{Context, Result};
8use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
9use sha2::{Digest, Sha256};
10use serde_json;
11
12// ---------------------------------------------------------------------------
13// Code generation
14// ---------------------------------------------------------------------------
15
16// ---------------------------------------------------------------------------
17// Encryption / decryption
18// ---------------------------------------------------------------------------
19
20fn derive_key(code: &str) -> [u8; 32] {
21    let hash = Sha256::digest(code.as_bytes());
22    hash.into()
23}
24
25/// Encrypt arbitrary bytes with the given code; returns a base64 payload string.
26pub fn encrypt_bytes(data: &[u8], code: &str) -> Result<String> {
27    let key_bytes = derive_key(code);
28    let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
29    let cipher = Aes256Gcm::new(key);
30    let nonce_bytes = crate::oauth::rand_bytes::<12>();
31    let nonce = Nonce::from_slice(&nonce_bytes);
32    let ciphertext = cipher
33        .encrypt(nonce, data)
34        .map_err(|e| anyhow::anyhow!("encryption failed: {e}"))?;
35    let mut wire = Vec::with_capacity(12 + ciphertext.len());
36    wire.extend_from_slice(&nonce_bytes);
37    wire.extend_from_slice(&ciphertext);
38    Ok(B64.encode(wire))
39}
40
41// ---------------------------------------------------------------------------
42// Share code helpers (SC- prefix — one-time relay handshake for shunt connect)
43// ---------------------------------------------------------------------------
44
45/// Generate a random share code like `SC-a3f2b1c4d5e6f7a8b9`.
46pub fn generate_share_code() -> String {
47    let bytes = crate::oauth::rand_bytes::<9>();
48    format!("SC-{}", hex::encode(bytes))
49}
50
51/// Validate that a share code has the expected format.
52pub fn validate_share_code(code: &str) -> Result<()> {
53    if !code.starts_with("SC-") || code.len() != 21 {
54        anyhow::bail!("Invalid share code format. Expected SC-<18 hex chars>.");
55    }
56    if !code[3..].chars().all(|c| c.is_ascii_hexdigit()) {
57        anyhow::bail!("Invalid share code — must be hex characters after 'SC-'.");
58    }
59    Ok(())
60}
61
62/// Push {base_url, api_key} to the relay under `code`.
63/// base_url is sent plaintext (not sensitive — it's just an IP/URL).
64/// api_key is encrypted with the share code before sending — the relay never sees it.
65pub async fn push_share(code: &str, base_url: &str, api_key: &str, relay_url: &str) -> Result<()> {
66    let encrypted_key = encrypt_bytes(api_key.as_bytes(), code)?;
67    let client = reqwest::Client::new();
68    let url = format!("{relay_url}/share/{code}");
69    let res = client
70        .put(&url)
71        .json(&serde_json::json!({ "base_url": base_url, "api_key": encrypted_key }))
72        .send()
73        .await
74        .context("Failed to reach relay")?;
75    if !res.status().is_success() {
76        let body = res.text().await.unwrap_or_default();
77        anyhow::bail!("Relay rejected share push ({}): {}", url, body);
78    }
79    Ok(())
80}
81
82/// Pull {base_url, api_key} from the relay for `code`. api_key is decrypted with the code.
83/// Deletes the entry on success.
84pub async fn pull_share(code: &str, relay_url: &str) -> Result<(String, String)> {
85    let client = reqwest::Client::new();
86    let url = format!("{relay_url}/share/{code}");
87    let res = client
88        .get(&url)
89        .send()
90        .await
91        .context("Failed to reach relay")?;
92    if res.status() == reqwest::StatusCode::NOT_FOUND {
93        anyhow::bail!("Share code not found, expired, or already used. Ask the host to run `shunt share` again.");
94    }
95    if !res.status().is_success() {
96        let body = res.text().await.unwrap_or_default();
97        anyhow::bail!("Relay error: {body}");
98    }
99    let json: serde_json::Value = res.json().await.context("Invalid JSON from relay")?;
100    let base_url = json["base_url"].as_str().context("Missing base_url")?.to_owned();
101    let encrypted_key = json["api_key"].as_str().context("Missing api_key")?;
102    let key_bytes = decrypt_bytes(encrypted_key, code)?;
103    let api_key = String::from_utf8(key_bytes).context("api_key is not valid UTF-8")?;
104    Ok((base_url, api_key))
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Encrypting and then decrypting with the same code yields the original bytes.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let code = "SC-aabbccddeeff001122";
        let plaintext: &[u8] = b"sk-ant-testkey-0000111122223333";

        let payload = encrypt_bytes(plaintext, code).unwrap();
        let recovered = decrypt_bytes(&payload, code).unwrap();

        assert_eq!(recovered, plaintext);
    }

    /// Decrypting with a code other than the one used to encrypt must fail.
    #[test]
    fn test_wrong_code_fails() {
        let payload = encrypt_bytes(b"hello", "SC-aabbccddeeff001122").unwrap();
        let outcome = decrypt_bytes(&payload, "SC-wrongcodewrongco");
        assert!(outcome.is_err());
    }

    /// Full relay roundtrip — requires network, skipped by default.
    /// Run with: cargo test --lib sync::tests::test_relay_roundtrip -- --ignored --nocapture
    #[tokio::test]
    #[ignore]
    async fn test_relay_roundtrip() {
        let relay = "https://relay.ramcharan.shop";
        let base_url = "http://192.168.1.100:8082";
        let api_key = "sk-ant-test-relay-roundtrip";
        let code = generate_share_code();

        push_share(&code, base_url, api_key, relay)
            .await
            .expect("push_share failed");
        let (got_url, got_key) = pull_share(&code, relay)
            .await
            .expect("pull_share failed");

        assert_eq!(got_url, base_url);
        assert_eq!(got_key, api_key);
        println!("Relay roundtrip OK — code={code}");
    }
}
146
147/// Decrypt a base64 payload into bytes using the given code.
148pub fn decrypt_bytes(payload_b64: &str, code: &str) -> Result<Vec<u8>> {
149    let wire = B64.decode(payload_b64).context("invalid base64 in payload")?;
150    if wire.len() < 12 {
151        anyhow::bail!("payload too short");
152    }
153    let (nonce_bytes, ciphertext) = wire.split_at(12);
154    let key_bytes = derive_key(code);
155    let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
156    let cipher = Aes256Gcm::new(key);
157    let nonce = Nonce::from_slice(nonce_bytes);
158    cipher
159        .decrypt(nonce, ciphertext)
160        .map_err(|_| anyhow::anyhow!("decryption failed — wrong code or corrupted payload"))
161}