1use std::io;
2use std::io::Error;
3use crate::codec::codec_trait::TfCodec;
4use crate::structures::temp_transport::TempTransport;
5use crate::structures::transport::{AsyncReadWrite, Transport};
6use aes_gcm::{
7 Aes256Gcm, Key, Nonce,
8 aead::{Aead, AeadCore, KeyInit, OsRng},
9};
10use async_trait::async_trait;
11use bytes::{Bytes, BytesMut};
12use futures_util::{SinkExt, StreamExt};
13use hkdf::Hkdf;
14use sha2::Sha256;
15use spake2::{Ed25519Group, Identity, Password, Spake2};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18use aead::AeadInPlace;
19use tokio_util::codec::{Decoder, Encoder, Framed, LengthDelimitedCodec};
20
/// Length-delimited codec that encrypts every frame with keys negotiated through a
/// SPAKE2 (Ed25519 group) password handshake.
///
/// Build one with `Spake2Encrypted::create_server` or `Spake2Encrypted::create_client`.
/// Encoding and decoding return errors until `TfCodec::initial_setup` has completed the
/// handshake and stored the session keys.
pub struct Spake2Encrypted {
    // Password lookup used by the server side; `Some` only for codecs built as servers.
    server_provider: Option<Arc<dyn ServerCredentialProvider>>,
    // Credential source used by the client side; `Some` only for codecs built as clients.
    client_provider: Option<Arc<dyn ClientCredentialProvider>>,
    // Selects which side of the handshake `initial_setup` runs.
    is_server: bool,
    // Server identity bytes fed into SPAKE2 by both sides.
    server_id: Vec<u8>,
    // Outer framing applied around each encrypted frame.
    length_codec: LengthDelimitedCodec,
    // Directional session ciphers; `None` until the handshake succeeds.
    keys: Option<SessionKeys>,
}
29
30impl Spake2Encrypted {
31 pub fn create_server(
32 server_provider: Arc<dyn ServerCredentialProvider>,
33 server_id: String,
34 codec: LengthDelimitedCodec,
35 ) -> Self {
36 Self {
37 server_provider: Some(server_provider),
38 client_provider: None,
39 is_server: true,
40 server_id: server_id.as_bytes().to_vec(),
41 length_codec: codec,
42 keys: None,
43 }
44 }
45
46 pub fn create_client(
47 client_provider: Arc<dyn ClientCredentialProvider>,
48 server_id: String,
49 codec: LengthDelimitedCodec,
50 ) -> Self {
51 Self {
52 server_provider: None,
53 client_provider: Some(client_provider),
54 is_server: false,
55 server_id: server_id.as_bytes().to_vec(),
56 length_codec: codec,
57 keys: None,
58 }
59 }
60}
61impl Decoder for Spake2Encrypted {
62 type Item = BytesMut;
63 type Error = io::Error;
64
65 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
66 let mut frame = match self.length_codec.decode(src)? {
67 Some(f) => f,
68 None => return Ok(None),
69 };
70 if let Some(keys) = &self.keys {
71 keys.open_in_place(&mut frame)
72 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "decryption failed"))?;
73 } else {
74 return Err(io::Error::new(io::ErrorKind::Other, "decryption failed"));
75 }
76 Ok(Some(frame))
77 }
78}
79
80impl Encoder<Bytes> for Spake2Encrypted {
81 type Error = io::Error;
82
83 fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
84 if let Some(keys) = &self.keys {
85 let mut buf = BytesMut::from(item);
86 keys.seal_in_place(&mut buf)
87 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "encryption failed"))?;
88 self.length_codec.encode(buf.freeze(), dst)
89 } else {
90 return Err(io::Error::new(io::ErrorKind::Other, "encryption failed"));
91 }
92 }
93}
94impl Clone for Spake2Encrypted {
95 fn clone(&self) -> Self {
96 Self{
97 server_provider: self.server_provider.clone(),
98 client_provider: self.client_provider.clone(),
99 is_server: self.is_server.clone(),
100 server_id: self.server_id.clone(),
101 length_codec: self.length_codec.clone(),
102 keys: None
103 }
104 }
105}
106
107#[async_trait]
108impl TfCodec for Spake2Encrypted {
109 async fn initial_setup(&mut self, tr: &mut Transport) -> bool {
110 let length_codec = LengthDelimitedCodec::builder().max_frame_length(2048).new_codec();
112 let mut framed = Framed::new(TempTransport::new(tr), length_codec);
113 if self.is_server{
114 let res = server_handshake(&mut framed, self.server_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
115 if let Some(keys) = res {
116 self.keys = Some(keys);
117 return true;
118 } else {
119 return false;
120 }
121 } else {
122 let res = client_handshake(&mut framed, self.client_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
123 if let Some(keys) = res {
124 self.keys = Some(keys);
125 return true;
126 }
127 return false;
128 }
129 }
130}
131
132
/// Server-side source of per-client passwords for the SPAKE2 handshake.
#[async_trait]
pub trait ServerCredentialProvider: Send+Sync+'static {
    /// Returns the password bytes registered for `client_identity`.
    ///
    /// `client_identity` is the identity frame sent by the client, interpreted as UTF-8.
    /// Returning `None` aborts the server handshake.
    async fn get_client_password(&self, client_identity: &str) -> Option<Vec<u8>>;
}
137
/// Client-side source of the identity and password used in the SPAKE2 handshake.
#[async_trait]
pub trait ClientCredentialProvider: Send+Sync+'static {
    /// Returns `(identity, password)` for this client, or `None` to abort the handshake.
    ///
    /// The identity is sent to the server unencrypted as the first handshake frame. The
    /// server interprets it as UTF-8 to look up the password, so it should be valid
    /// UTF-8.
    async fn get_client_credentials(&self) -> Option<(Vec<u8>, Vec<u8>)>;
}
143
/// Directional AES-256-GCM state for one established session.
///
/// Every sealed frame carries a 64-bit counter in front of the ciphertext. The counter
/// is used to build the GCM nonce and as associated data.
pub struct SessionKeys {
    /// Cipher for outgoing frames.
    // NOTE(review): public — encrypting with it directly bypasses the counter-based
    // nonce scheme; confirm no caller does that.
    pub send: Aes256Gcm,
    /// Cipher for incoming frames.
    pub recv: Aes256Gcm,

    // Counter value the next sealed frame will use.
    send_counter: AtomicU64,

    // Highest frame counter recorded by the receive-side replay check; frames with a
    // counter at or below it are rejected.
    recv_counter: AtomicU64,
}
154
155
/// Newtype around `BytesMut` that implements `aead::Buffer`, so `SessionKeys` can
/// encrypt and decrypt frames in place.
struct BytesMutBuffer(pub BytesMut);
157
158impl AsRef<[u8]> for BytesMutBuffer {
159 fn as_ref(&self) -> &[u8] { &self.0 }
160}
161
162impl AsMut<[u8]> for BytesMutBuffer {
163 fn as_mut(&mut self) -> &mut [u8] { &mut self.0 }
164}
165
166impl aead::Buffer for BytesMutBuffer {
167 fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> {
168 self.0.extend_from_slice(other);
169 Ok(())
170 }
171
172 fn truncate(&mut self, len: usize) {
173 self.0.truncate(len);
174 }
175}
176impl SessionKeys {
177
178 fn derive_session_keys(shared: &[u8], is_server: bool) -> Option<Self> {
179 let hk = Hkdf::<Sha256>::new(None, shared);
180
181 let mut key_a = [0u8; 32];
182 let mut key_b = [0u8; 32];
183
184 hk.expand(b"aes-tunnel-key-a", &mut key_a).ok()?;
185 hk.expand(b"aes-tunnel-key-b", &mut key_b).ok()?;
186
187 let (send_key, recv_key) = if is_server {
188 (key_b, key_a)
189 } else {
190 (key_a, key_b)
191 };
192
193 Some(Self {
194 send: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&send_key)),
195 recv: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&recv_key)),
196 send_counter: AtomicU64::new(1),
197 recv_counter: AtomicU64::new(0),
198 })
199 }
200
201 #[inline]
202 fn nonce_from_counter(counter: u64) -> [u8; 12] {
203 let mut nonce = [0u8; 12];
204 nonce[4..].copy_from_slice(&counter.to_be_bytes());
205 nonce
206 }
207
208 pub fn seal_in_place(&self, buf: &mut BytesMut) -> Option<()> {
209 let counter = self.send_counter.fetch_add(1, Ordering::Relaxed);
210
211 if counter == u64::MAX {
212 return None;
213 }
214
215 let counter_bytes = counter.to_be_bytes();
216 let nonce_bytes = Self::nonce_from_counter(counter);
217 let nonce = Nonce::from_slice(&nonce_bytes);
218
219 let mut wrapped = BytesMutBuffer(buf.split());
220
221 self.send
223 .encrypt_in_place(nonce, &counter_bytes, &mut wrapped)
224 .ok()?;
225
226 buf.clear();
227 buf.reserve(8 + wrapped.0.len());
228 buf.extend_from_slice(&counter_bytes);
229 buf.unsplit(wrapped.0);
230
231 Some(())
232 }
233
234 pub fn open_in_place(&self, buf: &mut BytesMut) -> Option<()> {
235 const COUNTER_LEN: usize = 8;
236
237 if buf.len() < COUNTER_LEN {
238 return None;
239 }
240
241 let counter = u64::from_be_bytes(buf[..COUNTER_LEN].try_into().ok()?);
242
243 if counter == u64::MAX {
244 return None;
245 }
246
247 let mut last = self.recv_counter.load(Ordering::Acquire);
249 loop {
250 if counter <= last {
251 return None; }
253 match self.recv_counter.compare_exchange_weak(
254 last,
255 counter,
256 Ordering::AcqRel,
257 Ordering::Acquire,
258 ) {
259 Ok(_) => break,
260 Err(current) => last = current, }
262 }
263
264 let counter_bytes = counter.to_be_bytes();
265 let nonce_bytes = Self::nonce_from_counter(counter);
266 let nonce = Nonce::from_slice(&nonce_bytes);
267
268 let ciphertext = buf.split_off(COUNTER_LEN);
269 let mut wrapped = BytesMutBuffer(ciphertext);
270
271 self.recv
273 .decrypt_in_place(nonce, &counter_bytes, &mut wrapped)
274 .ok()?;
275
276 *buf = wrapped.0;
277
278 Some(())
279 }
280 }
281
282pub async fn client_handshake<'a, IO: AsyncReadWrite>(
283 io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
284 cred: Arc<dyn ClientCredentialProvider>,
285 server_id: &[u8],
286) -> Option<SessionKeys> {
287 let creds = cred.get_client_credentials().await?;
288 let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_a(
289 &Password::new(creds.1.as_slice()),
290 &Identity::new(creds.0.as_slice()),
291 &Identity::new(server_id),
292 );
293 io.send(Bytes::from(creds.0.clone())).await.ok()?;
294 io.send(Bytes::from(outbound_msg)).await.ok()?;
295
296 let peer_msg = io.next().await?.ok()?;
297
298 let shared = spake.finish(&peer_msg).ok()?;
299
300 SessionKeys::derive_session_keys(&shared, false)
301}
302
303pub async fn server_handshake<'a, IO: AsyncReadWrite>(
304 io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
305 cred_provider: Arc<dyn ServerCredentialProvider>,
306 server_id: &[u8],
307) -> Option<SessionKeys>
308where
309 IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
310{
311 let client_identity = io.next().await?.ok()?;
312 let client_identity = String::from_utf8_lossy(client_identity.as_ref());
313 let password = cred_provider.get_client_password(&client_identity).await?;
314 let client_identity = client_identity.as_bytes();
315 let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_b(
316 &Password::new(password),
317 &Identity::new(client_identity),
318 &Identity::new(server_id),
319 );
320 let peer_msg = io.next().await?.ok()?;
321
322 io.send(Bytes::from(outbound_msg)).await.ok()?;
323
324 let shared = spake.finish(&peer_msg).ok()?;
325
326 SessionKeys::derive_session_keys(&shared, true)
327}
328
329