1use super::shared::{derive_aes_key, ecdh_compute_shared_secret, EphemeralKeyPair};
2use crate::client::ClientReq;
3use aes_gcm::{
4    aead::{Aead, KeyInit},
5    Aes256Gcm, Nonce,
6};
7use rand::RngCore;
8use ring::error::Unspecified;
9use serde::{Deserialize, Serialize};
10
11pub struct EphemeralServer {
13    pair: EphemeralKeyPair,
14}
15
16impl EphemeralServer {
17    pub fn new() -> Result<Self, Unspecified> {
18        Ok(Self {
19            pair: EphemeralKeyPair::new()?,
20        })
21    }
22
23    pub fn encrypt_secret(
25        self,
26        req: &ClientReq,
27        plaintext: &[u8],
28    ) -> Result<ServerEncryptedRes, Unspecified> {
29        let shared_secret = ecdh_compute_shared_secret(self.pair._pk, &req.pubk)?;
30
31        let mut salt_bytes = [0u8; 16];
33        rand::rngs::OsRng.fill_bytes(&mut salt_bytes);
34
35        let aes_key = derive_aes_key(salt_bytes, &shared_secret);
36
37        let mut nonce = [0u8; 12];
39        rand::rngs::OsRng.fill_bytes(&mut nonce);
40
41        let encrypted = aes_gcm_encrypt(&aes_key, &nonce, plaintext);
42
43        Ok(ServerEncryptedRes {
44            ciphertext: encrypted,
45            nonce,
46            salt: salt_bytes,
47            pubk: self.pair.pubk.as_ref().to_vec(),
48        })
49    }
50}
51
52fn aes_gcm_encrypt(key: &[u8], nonce: &[u8], plaintext: &[u8]) -> Vec<u8> {
54    let cipher = Aes256Gcm::new(key.into());
55    let nonce = Nonce::from_slice(nonce);
56    cipher
57        .encrypt(nonce, plaintext)
58        .expect("AES-GCM encryption failed")
59}
60
61#[derive(Debug, Clone, Deserialize, Serialize)]
62pub struct ServerEncryptedRes {
63    #[serde(with = "crate::shared::bytes_hex")]
64    pub ciphertext: Vec<u8>,
65    #[serde(with = "crate::shared::bytes_hex")]
66    pub pubk: Vec<u8>,
67    #[serde(with = "hex_12")]
68    pub nonce: [u8; 12],
69    #[serde(with = "hex_16")]
70    pub salt: [u8; 16],
71}
72
73mod hex_16 {
74    use crate::{from_hex_str, to_hex_str};
75    use serde::{Deserialize, Deserializer, Serializer};
76
77    pub fn serialize<S>(bytes: &[u8; 16], serializer: S) -> Result<S::Ok, S::Error>
78    where
79        S: Serializer,
80    {
81        serializer.serialize_str(&to_hex_str(bytes))
82    }
83
84    pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; 16], D::Error>
85    where
86        D: Deserializer<'de>,
87    {
88        let s = <String>::deserialize(deserializer)?;
89        let decoded = from_hex_str(&s).ok_or(serde::de::Error::custom("fail decode"))?;
90        if decoded.len() != 16 {
91            return Err(serde::de::Error::custom(format!(
92                "expected 16 bytes, got {}",
93                decoded.len()
94            )));
95        }
96        let mut arr = [0u8; 16];
97        arr.copy_from_slice(&decoded);
98        Ok(arr)
99    }
100}
101mod hex_12 {
102    use crate::{from_hex_str, to_hex_str};
103    use serde::{Deserialize, Deserializer, Serializer};
104
105    pub fn serialize<S>(bytes: &[u8; 12], serializer: S) -> Result<S::Ok, S::Error>
106    where
107        S: Serializer,
108    {
109        serializer.serialize_str(&to_hex_str(bytes))
110    }
111
112    pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; 12], D::Error>
113    where
114        D: Deserializer<'de>,
115    {
116        let s = <String>::deserialize(deserializer)?;
117        let decoded = from_hex_str(&s).ok_or(serde::de::Error::custom("fail decode"))?;
118        if decoded.len() != 12 {
119            return Err(serde::de::Error::custom(format!(
120                "expected 12 bytes, got {}",
121                decoded.len()
122            )));
123        }
124        let mut arr = [0u8; 12];
125        arr.copy_from_slice(&decoded);
126        Ok(arr)
127    }
128}
129
#[cfg(test)]
mod test {
    use super::*;

    /// A response with a distinct byte pattern in every field.
    fn sample_res() -> ServerEncryptedRes {
        ServerEncryptedRes {
            pubk: vec![1, 2, 3],
            ciphertext: vec![4, 5, 6],
            nonce: [1; 12],
            salt: [20; 16],
        }
    }

    /// A response survives a JSON serialize -> deserialize round trip field-for-field.
    #[test]
    fn test_ser_server_res() {
        let res = sample_res();
        let ser = serde_json::to_string(&res).expect("serialize ServerEncryptedRes");
        let deser: ServerEncryptedRes =
            serde_json::from_str(&ser).expect("deserialize ServerEncryptedRes");
        assert_eq!(res.pubk, deser.pubk);
        assert_eq!(res.ciphertext, deser.ciphertext);
        assert_eq!(res.nonce, deser.nonce);
        assert_eq!(res.salt, deser.salt);
    }

    /// The fixed-size hex adapters reject values that decode to the wrong length.
    ///
    /// Wrong-length values are taken from the other field's valid encoding, so the test
    /// does not depend on the exact hex format.
    #[test]
    fn test_deser_rejects_wrong_lengths() {
        let valid = serde_json::to_value(sample_res()).expect("serialize ServerEncryptedRes");

        // 16-byte salt encoding placed in the 12-byte nonce slot.
        let mut bad_nonce = valid.clone();
        bad_nonce["nonce"] = valid["salt"].clone();
        assert!(serde_json::from_value::<ServerEncryptedRes>(bad_nonce).is_err());

        // 12-byte nonce encoding placed in the 16-byte salt slot.
        let mut bad_salt = valid.clone();
        bad_salt["salt"] = valid["nonce"].clone();
        assert!(serde_json::from_value::<ServerEncryptedRes>(bad_salt).is_err());
    }
}