Skip to main content

shunt/
sync.rs

1//! Credential bundle encryption and relay upload/download for `shunt push` / `shunt login`.
2//!
//! Security model:
//! - Transfer code = 9 random bytes encoded as 18 hex chars, prefixed with "SH-"
//! - Encryption key = SHA-256(code) — 32 bytes. The key itself is never sent to the relay,
//!   but the code is (it is the relay's lookup id), so the relay could derive the key.
//! - Cipher: AES-256-GCM with a random 12-byte nonce
//! - Wire payload = base64(nonce_12B ‖ ciphertext_with_tag)
//! - The relay receives only the code and the ciphertext; the bundle is deleted after the
//!   first download
9
10use std::collections::HashMap;
11
12use aes_gcm::{
13    aead::{Aead, KeyInit},
14    Aes256Gcm, Key, Nonce,
15};
16use anyhow::{bail, Context, Result};
17use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
18use serde::{Deserialize, Serialize};
19use sha2::{Digest, Sha256};
20
21use crate::oauth::OAuthCredential;
22
23// ---------------------------------------------------------------------------
24// Bundle
25// ---------------------------------------------------------------------------
26
27#[derive(Debug, Serialize, Deserialize)]
28pub struct SyncBundle {
29    pub config_toml: String,
30    pub accounts: HashMap<String, OAuthCredential>,
31}
32
33// ---------------------------------------------------------------------------
34// Code generation
35// ---------------------------------------------------------------------------
36
37/// Generate a random transfer code like `SH-a3f2b1c4d5e6f7a8b9`.
38pub fn generate_code() -> String {
39    let bytes = crate::oauth::rand_bytes::<9>();
40    format!("SH-{}", hex::encode(bytes))
41}
42
43/// Generate a random remote-watch code like `RM-a3f2b1c4d5e6f7a8b9`.
44pub fn generate_remote_code() -> String {
45    let bytes = crate::oauth::rand_bytes::<9>();
46    format!("RM-{}", hex::encode(bytes))
47}
48
49/// Validate that a remote-watch code looks like what we generated.
50pub fn validate_remote_code(code: &str) -> Result<()> {
51    if !code.starts_with("RM-") || code.len() != 21 {
52        anyhow::bail!("Invalid remote code format. Expected RM-<18 hex chars>.");
53    }
54    if !code[3..].chars().all(|c| c.is_ascii_hexdigit()) {
55        anyhow::bail!("Invalid remote code — must be hex characters after 'RM-'.");
56    }
57    Ok(())
58}
59
60/// Validate that a code looks like what we generated.
61pub fn validate_code(code: &str) -> Result<()> {
62    if !code.starts_with("SH-") || code.len() != 21 {
63        bail!("Invalid transfer code format. Expected SH-<18 hex chars> (e.g. SH-a3f2b1c4d5e6f7a8b9).");
64    }
65    if !code[3..].chars().all(|c| c.is_ascii_hexdigit()) {
66        bail!("Invalid transfer code — must be hex characters after 'SH-'.");
67    }
68    Ok(())
69}
70
71// ---------------------------------------------------------------------------
72// Encryption / decryption
73// ---------------------------------------------------------------------------
74
75fn derive_key(code: &str) -> [u8; 32] {
76    let hash = Sha256::digest(code.as_bytes());
77    hash.into()
78}
79
80/// Encrypt a `SyncBundle` and return a base64-encoded payload string.
81pub fn encrypt_bundle(bundle: &SyncBundle, code: &str) -> Result<String> {
82    let json = serde_json::to_vec(bundle).context("failed to serialize bundle")?;
83
84    let key_bytes = derive_key(code);
85    let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
86    let cipher = Aes256Gcm::new(key);
87
88    let nonce_bytes = crate::oauth::rand_bytes::<12>();
89    let nonce = Nonce::from_slice(&nonce_bytes);
90
91    let ciphertext = cipher
92        .encrypt(nonce, json.as_slice())
93        .map_err(|e| anyhow::anyhow!("encryption failed: {e}"))?;
94
95    // wire: nonce(12) ‖ ciphertext
96    let mut wire = Vec::with_capacity(12 + ciphertext.len());
97    wire.extend_from_slice(&nonce_bytes);
98    wire.extend_from_slice(&ciphertext);
99
100    Ok(B64.encode(wire))
101}
102
103/// Decrypt a base64-encoded payload into a `SyncBundle`.
104pub fn decrypt_bundle(payload_b64: &str, code: &str) -> Result<SyncBundle> {
105    let wire = B64
106        .decode(payload_b64)
107        .context("invalid base64 in payload")?;
108
109    if wire.len() < 12 {
110        bail!("payload too short");
111    }
112
113    let (nonce_bytes, ciphertext) = wire.split_at(12);
114
115    let key_bytes = derive_key(code);
116    let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
117    let cipher = Aes256Gcm::new(key);
118    let nonce = Nonce::from_slice(nonce_bytes);
119
120    let plaintext = cipher
121        .decrypt(nonce, ciphertext)
122        .map_err(|_| anyhow::anyhow!("decryption failed — wrong code or corrupted payload"))?;
123
124    serde_json::from_slice::<SyncBundle>(&plaintext).context("failed to deserialize bundle")
125}
126
127// ---------------------------------------------------------------------------
128// Relay HTTP
129// ---------------------------------------------------------------------------
130
131/// Upload an encrypted payload to the relay under the given code.
132pub async fn push_to_relay(code: &str, payload: &str, relay_url: &str) -> Result<()> {
133    let client = reqwest::Client::builder()
134        .timeout(std::time::Duration::from_secs(15))
135        .build()?;
136
137    let body = serde_json::json!({ "code": code, "payload": payload });
138
139    let resp = client
140        .post(format!("{relay_url}/bundle"))
141        .json(&body)
142        .send()
143        .await
144        .context("failed to reach relay")?;
145
146    if !resp.status().is_success() {
147        let status = resp.status();
148        let text = resp.text().await.unwrap_or_default();
149        bail!("relay returned {status}: {text}");
150    }
151
152    Ok(())
153}
154
155/// Encrypt arbitrary bytes with the given code; returns a base64 payload string.
156/// Uses the same AES-256-GCM scheme as `encrypt_bundle`.
157pub fn encrypt_bytes(data: &[u8], code: &str) -> Result<String> {
158    let key_bytes = derive_key(code);
159    let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
160    let cipher = Aes256Gcm::new(key);
161    let nonce_bytes = crate::oauth::rand_bytes::<12>();
162    let nonce = Nonce::from_slice(&nonce_bytes);
163    let ciphertext = cipher
164        .encrypt(nonce, data)
165        .map_err(|e| anyhow::anyhow!("encryption failed: {e}"))?;
166    let mut wire = Vec::with_capacity(12 + ciphertext.len());
167    wire.extend_from_slice(&nonce_bytes);
168    wire.extend_from_slice(&ciphertext);
169    Ok(B64.encode(wire))
170}
171
172/// Decrypt a base64 payload into bytes using the given code.
173pub fn decrypt_bytes(payload_b64: &str, code: &str) -> Result<Vec<u8>> {
174    let wire = B64.decode(payload_b64).context("invalid base64 in payload")?;
175    if wire.len() < 12 { anyhow::bail!("payload too short"); }
176    let (nonce_bytes, ciphertext) = wire.split_at(12);
177    let key_bytes = derive_key(code);
178    let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
179    let cipher = Aes256Gcm::new(key);
180    let nonce = Nonce::from_slice(nonce_bytes);
181    cipher
182        .decrypt(nonce, ciphertext)
183        .map_err(|_| anyhow::anyhow!("decryption failed — wrong code or corrupted payload"))
184}
185
186/// Download and delete the encrypted payload for the given code from the relay.
187/// Returns the base64 payload string.
188pub async fn pull_from_relay(code: &str, relay_url: &str) -> Result<String> {
189    let client = reqwest::Client::builder()
190        .timeout(std::time::Duration::from_secs(15))
191        .build()?;
192
193    let resp = client
194        .get(format!("{relay_url}/bundle/{code}"))
195        .send()
196        .await
197        .context("failed to reach relay")?;
198
199    if resp.status() == reqwest::StatusCode::NOT_FOUND {
200        bail!("Code not found or already used. Codes are one-time use — run `shunt push` again to get a new one.");
201    }
202
203    if !resp.status().is_success() {
204        let status = resp.status();
205        let text = resp.text().await.unwrap_or_default();
206        bail!("relay returned {status}: {text}");
207    }
208
209    let json: serde_json::Value = resp.json().await.context("invalid response from relay")?;
210    json["payload"]
211        .as_str()
212        .map(|s| s.to_owned())
213        .context("relay response missing 'payload' field")
214}