Skip to main content

shunt/
sync.rs

//! Credential bundle encryption and relay upload/download for `shunt push` / `shunt login`.
//!
//! Security model:
//! - Transfer code = 9 random bytes encoded as 18 lowercase hex chars, prefixed with "SH-"
//! - Encryption key = SHA-256(code), 32 bytes. The key itself is never sent, but the code is:
//!   it is the relay's lookup id, so the relay (or anyone reading its traffic) can derive the key
//! - Cipher: AES-256-GCM with a random 12-byte nonce
//! - Wire payload = base64(nonce_12B ‖ ciphertext_with_tag)
//! - Relay stores the ciphertext under the code; the bundle is deleted after first download
9
10use std::collections::HashMap;
11
12use aes_gcm::{
13    aead::{Aead, KeyInit},
14    Aes256Gcm, Key, Nonce,
15};
16use anyhow::{bail, Context, Result};
17use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
18use serde::{Deserialize, Serialize};
19use sha2::{Digest, Sha256};
20
21use crate::oauth::OAuthCredential;
22
23// ---------------------------------------------------------------------------
24// Bundle
25// ---------------------------------------------------------------------------
26
27#[derive(Debug, Serialize, Deserialize)]
28pub struct SyncBundle {
29    pub config_toml: String,
30    pub accounts: HashMap<String, OAuthCredential>,
31}
32
33// ---------------------------------------------------------------------------
34// Code generation
35// ---------------------------------------------------------------------------
36
37/// Generate a random transfer code like `SH-a3f2b1c4d5e6f7a8b9`.
38pub fn generate_code() -> String {
39    let bytes = crate::oauth::rand_bytes::<9>();
40    format!("SH-{}", hex::encode(bytes))
41}
42
43/// Validate that a code looks like what we generated.
44pub fn validate_code(code: &str) -> Result<()> {
45    if !code.starts_with("SH-") || code.len() != 21 {
46        bail!("Invalid transfer code format. Expected SH-<18 hex chars> (e.g. SH-a3f2b1c4d5e6f7a8b9).");
47    }
48    if !code[3..].chars().all(|c| c.is_ascii_hexdigit()) {
49        bail!("Invalid transfer code — must be hex characters after 'SH-'.");
50    }
51    Ok(())
52}
53
54// ---------------------------------------------------------------------------
55// Encryption / decryption
56// ---------------------------------------------------------------------------
57
58fn derive_key(code: &str) -> [u8; 32] {
59    let hash = Sha256::digest(code.as_bytes());
60    hash.into()
61}
62
63/// Encrypt a `SyncBundle` and return a base64-encoded payload string.
64pub fn encrypt_bundle(bundle: &SyncBundle, code: &str) -> Result<String> {
65    let json = serde_json::to_vec(bundle).context("failed to serialize bundle")?;
66
67    let key_bytes = derive_key(code);
68    let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
69    let cipher = Aes256Gcm::new(key);
70
71    let nonce_bytes = crate::oauth::rand_bytes::<12>();
72    let nonce = Nonce::from_slice(&nonce_bytes);
73
74    let ciphertext = cipher
75        .encrypt(nonce, json.as_slice())
76        .map_err(|e| anyhow::anyhow!("encryption failed: {e}"))?;
77
78    // wire: nonce(12) ‖ ciphertext
79    let mut wire = Vec::with_capacity(12 + ciphertext.len());
80    wire.extend_from_slice(&nonce_bytes);
81    wire.extend_from_slice(&ciphertext);
82
83    Ok(B64.encode(wire))
84}
85
86/// Decrypt a base64-encoded payload into a `SyncBundle`.
87pub fn decrypt_bundle(payload_b64: &str, code: &str) -> Result<SyncBundle> {
88    let wire = B64
89        .decode(payload_b64)
90        .context("invalid base64 in payload")?;
91
92    if wire.len() < 12 {
93        bail!("payload too short");
94    }
95
96    let (nonce_bytes, ciphertext) = wire.split_at(12);
97
98    let key_bytes = derive_key(code);
99    let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
100    let cipher = Aes256Gcm::new(key);
101    let nonce = Nonce::from_slice(nonce_bytes);
102
103    let plaintext = cipher
104        .decrypt(nonce, ciphertext)
105        .map_err(|_| anyhow::anyhow!("decryption failed — wrong code or corrupted payload"))?;
106
107    serde_json::from_slice::<SyncBundle>(&plaintext).context("failed to deserialize bundle")
108}
109
110// ---------------------------------------------------------------------------
111// Relay HTTP
112// ---------------------------------------------------------------------------
113
114/// Upload an encrypted payload to the relay under the given code.
115pub async fn push_to_relay(code: &str, payload: &str, relay_url: &str) -> Result<()> {
116    let client = reqwest::Client::builder()
117        .timeout(std::time::Duration::from_secs(15))
118        .build()?;
119
120    let body = serde_json::json!({ "code": code, "payload": payload });
121
122    let resp = client
123        .post(format!("{relay_url}/bundle"))
124        .json(&body)
125        .send()
126        .await
127        .context("failed to reach relay")?;
128
129    if !resp.status().is_success() {
130        let status = resp.status();
131        let text = resp.text().await.unwrap_or_default();
132        bail!("relay returned {status}: {text}");
133    }
134
135    Ok(())
136}
137
138/// Download and delete the encrypted payload for the given code from the relay.
139/// Returns the base64 payload string.
140pub async fn pull_from_relay(code: &str, relay_url: &str) -> Result<String> {
141    let client = reqwest::Client::builder()
142        .timeout(std::time::Duration::from_secs(15))
143        .build()?;
144
145    let resp = client
146        .get(format!("{relay_url}/bundle/{code}"))
147        .send()
148        .await
149        .context("failed to reach relay")?;
150
151    if resp.status() == reqwest::StatusCode::NOT_FOUND {
152        bail!("Code not found or already used. Codes are one-time use — run `shunt push` again to get a new one.");
153    }
154
155    if !resp.status().is_success() {
156        let status = resp.status();
157        let text = resp.text().await.unwrap_or_default();
158        bail!("relay returned {status}: {text}");
159    }
160
161    let json: serde_json::Value = resp.json().await.context("invalid response from relay")?;
162    json["payload"]
163        .as_str()
164        .map(|s| s.to_owned())
165        .context("relay response missing 'payload' field")
166}