1use aes_gcm::{
4 aead::{Aead, KeyInit},
5 Aes256Gcm, Key, Nonce,
6};
7use anyhow::{Context, Result};
8use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
9use sha2::{Digest, Sha256};
10use serde_json;
11
12fn derive_key(code: &str) -> [u8; 32] {
21 let hash = Sha256::digest(code.as_bytes());
22 hash.into()
23}
24
25pub fn encrypt_bytes(data: &[u8], code: &str) -> Result<String> {
27 let key_bytes = derive_key(code);
28 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
29 let cipher = Aes256Gcm::new(key);
30 let nonce_bytes = crate::oauth::rand_bytes::<12>();
31 let nonce = Nonce::from_slice(&nonce_bytes);
32 let ciphertext = cipher
33 .encrypt(nonce, data)
34 .map_err(|e| anyhow::anyhow!("encryption failed: {e}"))?;
35 let mut wire = Vec::with_capacity(12 + ciphertext.len());
36 wire.extend_from_slice(&nonce_bytes);
37 wire.extend_from_slice(&ciphertext);
38 Ok(B64.encode(wire))
39}
40
41pub fn generate_share_code() -> String {
47 let bytes = crate::oauth::rand_bytes::<9>();
48 format!("SC-{}", hex::encode(bytes))
49}
50
51pub fn validate_share_code(code: &str) -> Result<()> {
53 if !code.starts_with("SC-") || code.len() != 21 {
54 anyhow::bail!("Invalid share code format. Expected SC-<18 hex chars>.");
55 }
56 if !code[3..].chars().all(|c| c.is_ascii_hexdigit()) {
57 anyhow::bail!("Invalid share code — must be hex characters after 'SC-'.");
58 }
59 Ok(())
60}
61
62pub async fn push_share(code: &str, base_url: &str, api_key: &str, relay_url: &str) -> Result<()> {
66 let encrypted_key = encrypt_bytes(api_key.as_bytes(), code)?;
67 let client = reqwest::Client::new();
68 let url = format!("{relay_url}/share/{code}");
69 let res = client
70 .put(&url)
71 .json(&serde_json::json!({ "base_url": base_url, "api_key": encrypted_key }))
72 .send()
73 .await
74 .context("Failed to reach relay")?;
75 if !res.status().is_success() {
76 let body = res.text().await.unwrap_or_default();
77 anyhow::bail!("Relay rejected share push ({}): {}", url, body);
78 }
79 Ok(())
80}
81
82pub async fn pull_share(code: &str, relay_url: &str) -> Result<(String, String)> {
85 let client = reqwest::Client::new();
86 let url = format!("{relay_url}/share/{code}");
87 let res = client
88 .get(&url)
89 .send()
90 .await
91 .context("Failed to reach relay")?;
92 if res.status() == reqwest::StatusCode::NOT_FOUND {
93 anyhow::bail!("Share code not found, expired, or already used. Ask the host to run `shunt share` again.");
94 }
95 if !res.status().is_success() {
96 let body = res.text().await.unwrap_or_default();
97 anyhow::bail!("Relay error: {body}");
98 }
99 let json: serde_json::Value = res.json().await.context("Invalid JSON from relay")?;
100 let base_url = json["base_url"].as_str().context("Missing base_url")?.to_owned();
101 let encrypted_key = json["api_key"].as_str().context("Missing api_key")?;
102 let key_bytes = decrypt_bytes(encrypted_key, code)?;
103 let api_key = String::from_utf8(key_bytes).context("api_key is not valid UTF-8")?;
104 Ok((base_url, api_key))
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Share code in the `SC-<18 hex chars>` shape, used by the offline tests.
    const SAMPLE_CODE: &str = "SC-aabbccddeeff001122";

    /// Encrypting and then decrypting with the same code returns the input.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let plaintext: &[u8] = b"sk-ant-testkey-0000111122223333";
        let sealed = encrypt_bytes(plaintext, SAMPLE_CODE).unwrap();
        let opened = decrypt_bytes(&sealed, SAMPLE_CODE).unwrap();
        assert_eq!(plaintext, opened.as_slice());
    }

    /// Decrypting with a different code must be rejected.
    #[test]
    fn test_wrong_code_fails() {
        let sealed = encrypt_bytes(b"hello", SAMPLE_CODE).unwrap();
        let outcome = decrypt_bytes(&sealed, "SC-wrongcodewrongco");
        assert!(outcome.is_err());
    }

    /// Pushes a share to the relay URL below and pulls it back.
    ///
    /// Marked `#[ignore]` so it only runs when requested explicitly; it
    /// performs real HTTP requests.
    #[tokio::test]
    #[ignore]
    async fn test_relay_roundtrip() {
        let relay = "https://relay.ramcharan.shop";
        let base_url = "http://192.168.1.100:8082";
        let api_key = "sk-ant-test-relay-roundtrip";
        let code = generate_share_code();

        push_share(&code, base_url, api_key, relay)
            .await
            .expect("push_share failed");
        let (got_url, got_key) = pull_share(&code, relay)
            .await
            .expect("pull_share failed");

        assert_eq!(got_url, base_url);
        assert_eq!(got_key, api_key);
        println!("Relay roundtrip OK — code={code}");
    }
}
146
147pub fn decrypt_bytes(payload_b64: &str, code: &str) -> Result<Vec<u8>> {
149 let wire = B64.decode(payload_b64).context("invalid base64 in payload")?;
150 if wire.len() < 12 {
151 anyhow::bail!("payload too short");
152 }
153 let (nonce_bytes, ciphertext) = wire.split_at(12);
154 let key_bytes = derive_key(code);
155 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
156 let cipher = Aes256Gcm::new(key);
157 let nonce = Nonce::from_slice(nonce_bytes);
158 cipher
159 .decrypt(nonce, ciphertext)
160 .map_err(|_| anyhow::anyhow!("decryption failed — wrong code or corrupted payload"))
161}