1use std::io;
2use std::io::Error;
3use crate::codec::codec_trait::TfCodec;
4use crate::structures::temp_transport::TempTransport;
5use crate::structures::transport::{AsyncReadWrite, Transport};
6use aes_gcm::{
7 Aes256Gcm, Key, Nonce,
8 aead::{Aead, AeadCore, KeyInit, OsRng},
9};
10use async_trait::async_trait;
11use bytes::{Bytes, BytesMut};
12use futures_util::{SinkExt, StreamExt};
13use hkdf::Hkdf;
14use sha2::Sha256;
15use spake2::{Ed25519Group, Identity, Password, Spake2};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18use aead::AeadInPlace;
19use tokio_util::codec::{Decoder, Encoder, Framed, LengthDelimitedCodec};
20
21pub struct Spake2Encrypted {
22 server_provider: Option<Arc<dyn ServerCredentialProvider>>,
23 client_provider: Option<Arc<dyn ClientCredentialProvider>>,
24 is_server: bool,
25 server_id: Vec<u8>,
26 length_codec: LengthDelimitedCodec,
27 keys: Option<SessionKeys>,
28}
29
30impl Spake2Encrypted {
31 pub fn create_server(
32 server_provider: Arc<dyn ServerCredentialProvider>,
33 server_id: String,
34 codec: LengthDelimitedCodec,
35 ) -> Self {
36 Self {
37 server_provider: Some(server_provider),
38 client_provider: None,
39 is_server: true,
40 server_id: server_id.as_bytes().to_vec(),
41 length_codec: codec,
42 keys: None,
43 }
44 }
45
46 pub fn create_client(
47 client_provider: Arc<dyn ClientCredentialProvider>,
48 server_id: String,
49 codec: LengthDelimitedCodec,
50 ) -> Self {
51 Self {
52 server_provider: None,
53 client_provider: Some(client_provider),
54 is_server: false,
55 server_id: server_id.as_bytes().to_vec(),
56 length_codec: codec,
57 keys: None,
58 }
59 }
60}
61impl Decoder for Spake2Encrypted {
62 type Item = BytesMut;
63 type Error = io::Error;
64
65 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
66 let mut frame = match self.length_codec.decode(src)? {
67 Some(f) => f,
68 None => return Ok(None),
69 };
70 if let Some(keys) = &self.keys {
71 keys.open_in_place(&mut frame)
72 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "decryption failed"))?;
73 }
74 Ok(Some(frame))
75 }
76}
77
78impl Encoder<Bytes> for Spake2Encrypted {
79 type Error = io::Error;
80
81 fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
82 if let Some(keys) = &self.keys {
83 let mut buf = BytesMut::from(item);
84 keys.seal_in_place(&mut buf)
85 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "encryption failed"))?;
86 self.length_codec.encode(buf.freeze(), dst)
87 } else {
88 self.length_codec.encode(item, dst)
89 }
90 }
91}
92impl Clone for Spake2Encrypted {
93 fn clone(&self) -> Self {
94 Self{
95 server_provider: self.server_provider.clone(),
96 client_provider: self.client_provider.clone(),
97 is_server: self.is_server.clone(),
98 server_id: self.server_id.clone(),
99 length_codec: self.length_codec.clone(),
100 keys: None
101 }
102 }
103}
104
105#[async_trait]
106impl TfCodec for Spake2Encrypted {
107 async fn initial_setup(&mut self, tr: &mut Transport) -> bool {
108 let length_codec = LengthDelimitedCodec::builder().max_frame_length(2048).new_codec();
110 let mut framed = Framed::new(TempTransport::new(tr), length_codec);
111 if self.is_server{
112 let res = server_handshake(&mut framed, self.server_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
113 if let Some(keys) = res {
114 self.keys = Some(keys);
115 return true;
116 } else {
117 return false;
118 }
119 } else {
120 let res = client_handshake(&mut framed, self.client_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
121 if let Some(keys) = res {
122 self.keys = Some(keys);
123 return true;
124 }
125 return false;
126 }
127 }
128}
129
130
131#[async_trait]
132pub trait ServerCredentialProvider: Send+Sync+'static {
133 async fn get_client_password(&self, client_identity: &str) -> Option<Vec<u8>>;
134}
135
136#[async_trait]
137pub trait ClientCredentialProvider: Send+Sync+'static {
138 async fn get_client_credentials(&self) -> Option<(Vec<u8>, Vec<u8>)>;
140}
141
142pub struct SessionKeys {
143 pub send: Aes256Gcm,
144 pub recv: Aes256Gcm,
145
146 send_counter: AtomicU64,
148
149 recv_counter: AtomicU64,
151}
152
153
154struct BytesMutBuffer(pub BytesMut);
155
156impl AsRef<[u8]> for BytesMutBuffer {
157 fn as_ref(&self) -> &[u8] { &self.0 }
158}
159
160impl AsMut<[u8]> for BytesMutBuffer {
161 fn as_mut(&mut self) -> &mut [u8] { &mut self.0 }
162}
163
164impl aead::Buffer for BytesMutBuffer {
165 fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> {
166 self.0.extend_from_slice(other);
167 Ok(())
168 }
169
170 fn truncate(&mut self, len: usize) {
171 self.0.truncate(len);
172 }
173}
174impl SessionKeys {
175
176 fn derive_session_keys(shared: &[u8], is_server: bool) -> Option<Self> {
177 let hk = Hkdf::<Sha256>::new(None, shared);
178
179 let mut key_a = [0u8; 32];
180 let mut key_b = [0u8; 32];
181
182 hk.expand(b"aes-tunnel-key-a", &mut key_a).ok()?;
183 hk.expand(b"aes-tunnel-key-b", &mut key_b).ok()?;
184
185 let (send_key, recv_key) = if is_server {
186 (key_b, key_a)
187 } else {
188 (key_a, key_b)
189 };
190
191 Some(Self {
192 send: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&send_key)),
193 recv: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&recv_key)),
194 send_counter: AtomicU64::new(1),
195 recv_counter: AtomicU64::new(0),
196 })
197 }
198
199 #[inline]
200 fn nonce_from_counter(counter: u64) -> [u8; 12] {
201 let mut nonce = [0u8; 12];
202 nonce[4..].copy_from_slice(&counter.to_be_bytes());
203 nonce
204 }
205
206 pub fn seal_in_place(&self, buf: &mut BytesMut) -> Option<()> {
207 let counter = self.send_counter.fetch_add(1, Ordering::Relaxed);
208
209 if counter == u64::MAX {
210 return None;
211 }
212
213 let counter_bytes = counter.to_be_bytes();
214 let nonce_bytes = Self::nonce_from_counter(counter);
215 let nonce = Nonce::from_slice(&nonce_bytes);
216
217 let mut wrapped = BytesMutBuffer(buf.split());
218
219 self.send
221 .encrypt_in_place(nonce, &counter_bytes, &mut wrapped)
222 .ok()?;
223
224 buf.clear();
225 buf.reserve(8 + wrapped.0.len());
226 buf.extend_from_slice(&counter_bytes);
227 buf.unsplit(wrapped.0);
228
229 Some(())
230 }
231
232 pub fn open_in_place(&self, buf: &mut BytesMut) -> Option<()> {
233 const COUNTER_LEN: usize = 8;
234
235 if buf.len() < COUNTER_LEN {
236 return None;
237 }
238
239 let counter = u64::from_be_bytes(buf[..COUNTER_LEN].try_into().ok()?);
240
241 if counter == u64::MAX {
242 return None;
243 }
244
245 let mut last = self.recv_counter.load(Ordering::Acquire);
247 loop {
248 if counter <= last {
249 return None; }
251 match self.recv_counter.compare_exchange_weak(
252 last,
253 counter,
254 Ordering::AcqRel,
255 Ordering::Acquire,
256 ) {
257 Ok(_) => break,
258 Err(current) => last = current, }
260 }
261
262 let counter_bytes = counter.to_be_bytes();
263 let nonce_bytes = Self::nonce_from_counter(counter);
264 let nonce = Nonce::from_slice(&nonce_bytes);
265
266 let ciphertext = buf.split_off(COUNTER_LEN);
267 let mut wrapped = BytesMutBuffer(ciphertext);
268
269 self.recv
271 .decrypt_in_place(nonce, &counter_bytes, &mut wrapped)
272 .ok()?;
273
274 *buf = wrapped.0;
275
276 Some(())
277 }
278 }
279
280pub async fn client_handshake<'a, IO: AsyncReadWrite>(
281 io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
282 cred: Arc<dyn ClientCredentialProvider>,
283 server_id: &[u8],
284) -> Option<SessionKeys> {
285 let creds = cred.get_client_credentials().await?;
286 let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_a(
287 &Password::new(creds.1.as_slice()),
288 &Identity::new(creds.0.as_slice()),
289 &Identity::new(server_id),
290 );
291 io.send(Bytes::from(creds.0.clone())).await.ok()?;
292 io.send(Bytes::from(outbound_msg)).await.ok()?;
293
294 let peer_msg = io.next().await?.ok()?;
295
296 let shared = spake.finish(&peer_msg).ok()?;
297
298 SessionKeys::derive_session_keys(&shared, false)
299}
300
301pub async fn server_handshake<'a, IO: AsyncReadWrite>(
302 io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
303 cred_provider: Arc<dyn ServerCredentialProvider>,
304 server_id: &[u8],
305) -> Option<SessionKeys>
306where
307 IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
308{
309 let client_identity = io.next().await?.ok()?;
310 let client_identity = String::from_utf8_lossy(client_identity.as_ref());
311 let password = cred_provider.get_client_password(&client_identity).await?;
312 let client_identity = client_identity.as_bytes();
313 let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_b(
314 &Password::new(password),
315 &Identity::new(client_identity),
316 &Identity::new(server_id),
317 );
318 let peer_msg = io.next().await?.ok()?;
319
320 io.send(Bytes::from(outbound_msg)).await.ok()?;
321
322 let shared = spake.finish(&peer_msg).ok()?;
323
324 SessionKeys::derive_session_keys(&shared, true)
325}
326
327