1use super::shared::{derive_aes_key, ecdh_compute_shared_secret, EphemeralKeyPair};
2use crate::client::ClientReq;
3use aes_gcm::{
4 aead::{Aead, KeyInit},
5 Aes256Gcm, Nonce,
6};
7use rand::RngCore;
8use ring::error::Unspecified;
9use serde::{Deserialize, Serialize};
10
/// Server side of an ephemeral ECDH + AES-256-GCM exchange.
///
/// Each instance owns one ephemeral key pair. `EphemeralServer::encrypt_secret`
/// takes `self` by value, so an instance performs at most one encryption.
pub struct EphemeralServer {
    /// Ephemeral key pair. Its private half (`_pk`) feeds the ECDH computation.
    /// Its public half (`pubk`) is copied into the response for the client.
    pair: EphemeralKeyPair,
}
15
16impl EphemeralServer {
17 pub fn new() -> Result<Self, Unspecified> {
18 Ok(Self {
19 pair: EphemeralKeyPair::new()?,
20 })
21 }
22
23 pub fn encrypt_secret(
25 self,
26 req: &ClientReq,
27 plaintext: &[u8],
28 ) -> Result<ServerEncryptedRes, Unspecified> {
29 let shared_secret = ecdh_compute_shared_secret(self.pair._pk, &req.pubk)?;
30
31 let mut salt_bytes = [0u8; 16];
33 rand::rngs::OsRng.fill_bytes(&mut salt_bytes);
34
35 let aes_key = derive_aes_key(salt_bytes, &shared_secret);
36
37 let mut nonce = [0u8; 12];
39 rand::rngs::OsRng.fill_bytes(&mut nonce);
40
41 let encrypted = aes_gcm_encrypt(&aes_key, &nonce, plaintext);
42
43 Ok(ServerEncryptedRes {
44 ciphertext: encrypted,
45 nonce,
46 salt: salt_bytes,
47 pubk: self.pair.pubk.as_ref().to_vec(),
48 })
49 }
50}
51
52fn aes_gcm_encrypt(key: &[u8], nonce: &[u8], plaintext: &[u8]) -> Vec<u8> {
54 let cipher = Aes256Gcm::new(key.into());
55 let nonce = Nonce::from_slice(nonce);
56 cipher
57 .encrypt(nonce, plaintext)
58 .expect("AES-GCM encryption failed")
59}
60
/// Response produced by `EphemeralServer::encrypt_secret`.
///
/// How the fields are encoded:
/// - `nonce` and `salt` are written as strings produced by `to_hex_str`.
/// - `ciphertext` and `pubk` go through `crate::shared::bytes_hex`, presumably hex as
///   well (TODO confirm).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServerEncryptedRes {
    /// AES-256-GCM output for the encrypted plaintext.
    #[serde(with = "crate::shared::bytes_hex")]
    pub ciphertext: Vec<u8>,
    /// The server's ephemeral public key, needed by the receiver for ECDH.
    #[serde(with = "crate::shared::bytes_hex")]
    pub pubk: Vec<u8>,
    /// The random 12-byte AES-GCM nonce used for this message.
    #[serde(with = "hex_12")]
    pub nonce: [u8; 12],
    /// The random 16-byte salt passed to `derive_aes_key` with the shared secret.
    #[serde(with = "hex_16")]
    pub salt: [u8; 16],
}
72
73mod hex_16 {
74 use crate::{from_hex_str, to_hex_str};
75 use serde::{Deserialize, Deserializer, Serializer};
76
77 pub fn serialize<S>(bytes: &[u8; 16], serializer: S) -> Result<S::Ok, S::Error>
78 where
79 S: Serializer,
80 {
81 serializer.serialize_str(&to_hex_str(bytes))
82 }
83
84 pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; 16], D::Error>
85 where
86 D: Deserializer<'de>,
87 {
88 let s = <String>::deserialize(deserializer)?;
89 let decoded = from_hex_str(&s).ok_or(serde::de::Error::custom("fail decode"))?;
90 if decoded.len() != 16 {
91 return Err(serde::de::Error::custom(format!(
92 "expected 16 bytes, got {}",
93 decoded.len()
94 )));
95 }
96 let mut arr = [0u8; 16];
97 arr.copy_from_slice(&decoded);
98 Ok(arr)
99 }
100}
101mod hex_12 {
102 use crate::{from_hex_str, to_hex_str};
103 use serde::{Deserialize, Deserializer, Serializer};
104
105 pub fn serialize<S>(bytes: &[u8; 12], serializer: S) -> Result<S::Ok, S::Error>
106 where
107 S: Serializer,
108 {
109 serializer.serialize_str(&to_hex_str(bytes))
110 }
111
112 pub fn deserialize<'de, D>(deserializer: D) -> Result<[u8; 12], D::Error>
113 where
114 D: Deserializer<'de>,
115 {
116 let s = <String>::deserialize(deserializer)?;
117 let decoded = from_hex_str(&s).ok_or(serde::de::Error::custom("fail decode"))?;
118 if decoded.len() != 12 {
119 return Err(serde::de::Error::custom(format!(
120 "expected 12 bytes, got {}",
121 decoded.len()
122 )));
123 }
124 let mut arr = [0u8; 12];
125 arr.copy_from_slice(&decoded);
126 Ok(arr)
127 }
128}
129
#[cfg(test)]
mod test {
    use super::*;

    /// Response fixture with a distinct byte pattern in every field.
    fn sample_res() -> ServerEncryptedRes {
        ServerEncryptedRes {
            pubk: vec![1, 2, 3],
            ciphertext: vec![4, 5, 6],
            nonce: [1; 12],
            salt: [20; 16],
        }
    }

    /// Every field must survive a JSON round trip unchanged.
    #[test]
    fn test_ser_server_res() {
        let res = sample_res();
        let ser = serde_json::to_string(&res).unwrap();
        println!("{}", ser);
        let deser: ServerEncryptedRes = serde_json::from_str(&ser).unwrap();
        assert_eq!(res.pubk, deser.pubk);
        assert_eq!(res.ciphertext, deser.ciphertext);
        assert_eq!(res.nonce, deser.nonce);
        assert_eq!(res.salt, deser.salt);
    }

    /// The fixed-size hex adapters must reject well-formed hex of the wrong length.
    ///
    /// The bad input is built by swapping the serialized salt and nonce strings, so
    /// the test does not depend on the exact hex format `to_hex_str` emits.
    #[test]
    fn test_deser_rejects_wrong_lengths() {
        let valid = serde_json::to_value(sample_res()).unwrap();

        // 16-byte salt string placed where a 12-byte nonce is expected.
        let mut bad_nonce = valid.clone();
        bad_nonce["nonce"] = valid["salt"].clone();
        let err = serde_json::from_value::<ServerEncryptedRes>(bad_nonce).unwrap_err();
        assert!(err.to_string().contains("expected 12 bytes, got 16"));

        // 12-byte nonce string placed where a 16-byte salt is expected.
        let mut bad_salt = valid.clone();
        bad_salt["salt"] = valid["nonce"].clone();
        let err = serde_json::from_value::<ServerEncryptedRes>(bad_salt).unwrap_err();
        assert!(err.to_string().contains("expected 16 bytes, got 12"));
    }
}