Skip to main content

tfserver/codec/
spake2_encrypted.rs

1use std::io;
2use std::io::Error;
3use crate::codec::codec_trait::TfCodec;
4use crate::structures::temp_transport::TempTransport;
5use crate::structures::transport::{AsyncReadWrite, Transport};
6use aes_gcm::{
7    Aes256Gcm, Key, Nonce,
8    aead::{Aead, AeadCore, KeyInit, OsRng},
9};
10use async_trait::async_trait;
11use bytes::{Bytes, BytesMut};
12use futures_util::{SinkExt, StreamExt};
13use hkdf::Hkdf;
14use sha2::Sha256;
15use spake2::{Ed25519Group, Identity, Password, Spake2};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18use aead::AeadInPlace;
19use tokio_util::codec::{Decoder, Encoder, Framed, LengthDelimitedCodec};
20
21pub struct Spake2Encrypted {
22    server_provider: Option<Arc<dyn ServerCredentialProvider>>,
23    client_provider: Option<Arc<dyn ClientCredentialProvider>>,
24    is_server: bool,
25    server_id: Vec<u8>,
26    length_codec: LengthDelimitedCodec,
27    keys: Option<SessionKeys>,
28}
29
30impl Spake2Encrypted {
31    pub fn create_server(
32        server_provider: Arc<dyn ServerCredentialProvider>,
33        server_id: String,
34        codec: LengthDelimitedCodec,
35    ) -> Self {
36        Self {
37            server_provider: Some(server_provider),
38            client_provider: None,
39            is_server: true,
40            server_id: server_id.as_bytes().to_vec(),
41            length_codec: codec,
42            keys: None,
43        }
44    }
45
46    pub fn create_client(
47        client_provider: Arc<dyn ClientCredentialProvider>,
48        server_id: String,
49        codec: LengthDelimitedCodec,
50    ) -> Self {
51        Self {
52            server_provider: None,
53            client_provider: Some(client_provider),
54            is_server: false,
55            server_id: server_id.as_bytes().to_vec(),
56            length_codec: codec,
57            keys: None,
58        }
59    }
60}
61impl Decoder for Spake2Encrypted {
62    type Item = BytesMut;
63    type Error = io::Error;
64
65    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
66        let mut frame = match self.length_codec.decode(src)? {
67            Some(f) => f,
68            None => return Ok(None),
69        };
70        if let Some(keys) = &self.keys {
71            keys.open_in_place(&mut frame)
72                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "decryption failed"))?;
73        }
74        Ok(Some(frame))
75    }
76}
77
78impl Encoder<Bytes> for Spake2Encrypted {
79    type Error = io::Error;
80
81    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
82        if let Some(keys) = &self.keys {
83            let mut buf = BytesMut::from(item);
84            keys.seal_in_place(&mut buf)
85                .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "encryption failed"))?;
86            self.length_codec.encode(buf.freeze(), dst)
87        } else {
88            self.length_codec.encode(item, dst)
89        }
90    }
91}
92impl Clone for Spake2Encrypted {
93    fn clone(&self) -> Self {
94        Self{
95            server_provider: self.server_provider.clone(),
96            client_provider: self.client_provider.clone(),
97            is_server: self.is_server.clone(),
98            server_id: self.server_id.clone(),
99            length_codec: self.length_codec.clone(),
100            keys: None
101        }
102    }
103}
104
105#[async_trait]
106impl TfCodec for Spake2Encrypted {
107    async fn initial_setup(&mut self, tr: &mut Transport) -> bool {
108        ///Safe limitation to prevent dos
109        let length_codec = LengthDelimitedCodec::builder().max_frame_length(2048).new_codec();
110        let mut framed = Framed::new(TempTransport::new(tr), length_codec);
111        if self.is_server{
112            let res = server_handshake(&mut framed, self.server_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
113            if let Some(keys) = res {
114                self.keys = Some(keys);
115                return true;
116            } else {
117                return false;
118            }
119        } else {
120            let res = client_handshake(&mut framed, self.client_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
121            if let Some(keys) = res {
122                self.keys = Some(keys);
123                return true;
124            }
125            return false;
126        }
127    }
128}
129
130
131#[async_trait]
132pub trait ServerCredentialProvider: Send+Sync+'static  {
133    async fn get_client_password(&self, client_identity: &str) -> Option<Vec<u8>>;
134}
135
136#[async_trait]
137pub trait ClientCredentialProvider: Send+Sync+'static {
138    ///Return 0 - client identity, 1 - client password
139    async fn get_client_credentials(&self) -> Option<(Vec<u8>, Vec<u8>)>;
140}
141
142pub struct SessionKeys {
143    pub send: Aes256Gcm,
144    pub recv: Aes256Gcm,
145
146    /// Local outbound packet counter
147    send_counter: AtomicU64,
148
149    /// Highest accepted inbound counter
150    recv_counter: AtomicU64,
151}
152
153
154struct BytesMutBuffer(pub BytesMut);
155
156impl AsRef<[u8]> for BytesMutBuffer {
157    fn as_ref(&self) -> &[u8] { &self.0 }
158}
159
160impl AsMut<[u8]> for BytesMutBuffer {
161    fn as_mut(&mut self) -> &mut [u8] { &mut self.0 }
162}
163
164impl aead::Buffer for BytesMutBuffer {
165    fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> {
166        self.0.extend_from_slice(other);
167        Ok(())
168    }
169
170    fn truncate(&mut self, len: usize) {
171        self.0.truncate(len);
172    }
173}
174impl SessionKeys {
175
176        fn derive_session_keys(shared: &[u8], is_server: bool) -> Option<Self> {
177            let hk = Hkdf::<Sha256>::new(None, shared);
178
179            let mut key_a = [0u8; 32];
180            let mut key_b = [0u8; 32];
181
182            hk.expand(b"aes-tunnel-key-a", &mut key_a).ok()?;
183            hk.expand(b"aes-tunnel-key-b", &mut key_b).ok()?;
184
185            let (send_key, recv_key) = if is_server {
186                (key_b, key_a)
187            } else {
188                (key_a, key_b)
189            };
190
191            Some(Self {
192                send: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&send_key)),
193                recv: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&recv_key)),
194                send_counter: AtomicU64::new(1),
195                recv_counter: AtomicU64::new(0),
196            })
197        }
198
199        #[inline]
200        fn nonce_from_counter(counter: u64) -> [u8; 12] {
201            let mut nonce = [0u8; 12];
202            nonce[4..].copy_from_slice(&counter.to_be_bytes());
203            nonce
204        }
205
206        pub fn seal_in_place(&self, buf: &mut BytesMut) -> Option<()> {
207            let counter = self.send_counter.fetch_add(1, Ordering::Relaxed);
208
209            if counter == u64::MAX {
210                return None;
211            }
212
213            let counter_bytes = counter.to_be_bytes();
214            let nonce_bytes = Self::nonce_from_counter(counter);
215            let nonce = Nonce::from_slice(&nonce_bytes);
216
217            let mut wrapped = BytesMutBuffer(buf.split());
218
219            // counter is included in AAD so any wire tampering fails tag verification
220            self.send
221                .encrypt_in_place(nonce, &counter_bytes, &mut wrapped)
222                .ok()?;
223
224            buf.clear();
225            buf.reserve(8 + wrapped.0.len());
226            buf.extend_from_slice(&counter_bytes);
227            buf.unsplit(wrapped.0);
228
229            Some(())
230        }
231
232        pub fn open_in_place(&self, buf: &mut BytesMut) -> Option<()> {
233            const COUNTER_LEN: usize = 8;
234
235            if buf.len() < COUNTER_LEN {
236                return None;
237            }
238
239            let counter = u64::from_be_bytes(buf[..COUNTER_LEN].try_into().ok()?);
240
241            if counter == u64::MAX {
242                return None;
243            }
244
245            // compare-exchange loop — prevents TOCTOU race if called concurrently
246            let mut last = self.recv_counter.load(Ordering::Acquire);
247            loop {
248                if counter <= last {
249                    return None; // replay or reorder
250                }
251                match self.recv_counter.compare_exchange_weak(
252                    last,
253                    counter,
254                    Ordering::AcqRel,
255                    Ordering::Acquire,
256                ) {
257                    Ok(_) => break,
258                    Err(current) => last = current, // another thread advanced it, retry
259                }
260            }
261
262            let counter_bytes = counter.to_be_bytes();
263            let nonce_bytes = Self::nonce_from_counter(counter);
264            let nonce = Nonce::from_slice(&nonce_bytes);
265
266            let ciphertext = buf.split_off(COUNTER_LEN);
267            let mut wrapped = BytesMutBuffer(ciphertext);
268
269            // AAD must match what seal used, otherwise tag fails
270            self.recv
271                .decrypt_in_place(nonce, &counter_bytes, &mut wrapped)
272                .ok()?;
273
274            *buf = wrapped.0;
275
276            Some(())
277        }
278    }
279
280pub async fn client_handshake<'a, IO: AsyncReadWrite>(
281    io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
282    cred: Arc<dyn ClientCredentialProvider>,
283    server_id: &[u8],
284) -> Option<SessionKeys> {
285    let creds = cred.get_client_credentials().await?;
286    let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_a(
287        &Password::new(creds.1.as_slice()),
288        &Identity::new(creds.0.as_slice()),
289        &Identity::new(server_id),
290    );
291    io.send(Bytes::from(creds.0.clone())).await.ok()?;
292    io.send(Bytes::from(outbound_msg)).await.ok()?;
293
294    let peer_msg = io.next().await?.ok()?;
295
296    let shared = spake.finish(&peer_msg).ok()?;
297
298    SessionKeys::derive_session_keys(&shared, false)
299}
300
301pub async fn server_handshake<'a, IO: AsyncReadWrite>(
302    io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
303    cred_provider: Arc<dyn ServerCredentialProvider>,
304    server_id: &[u8],
305) -> Option<SessionKeys>
306where
307    IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
308{
309    let client_identity = io.next().await?.ok()?;
310    let client_identity = String::from_utf8_lossy(client_identity.as_ref());
311    let password = cred_provider.get_client_password(&client_identity).await?;
312    let client_identity = client_identity.as_bytes();
313    let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_b(
314        &Password::new(password),
315        &Identity::new(client_identity),
316        &Identity::new(server_id),
317    );
318    let peer_msg = io.next().await?.ok()?;
319
320    io.send(Bytes::from(outbound_msg)).await.ok()?;
321
322    let shared = spake.finish(&peer_msg).ok()?;
323
324    SessionKeys::derive_session_keys(&shared, true)
325}
326
327