Skip to main content

tfserver/codec/
spake2_encrypted.rs

1use std::io;
2use std::io::Error;
3use crate::codec::codec_trait::TfCodec;
4use crate::structures::temp_transport::TempTransport;
5use crate::structures::transport::{AsyncReadWrite, Transport};
6use aes_gcm::{
7    Aes256Gcm, Key, Nonce,
8    aead::{Aead, AeadCore, KeyInit, OsRng},
9};
10use async_trait::async_trait;
11use bytes::{Bytes, BytesMut};
12use futures_util::{SinkExt, StreamExt};
13use hkdf::Hkdf;
14use sha2::Sha256;
15use spake2::{Ed25519Group, Identity, Password, Spake2};
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18use aead::AeadInPlace;
19use tokio_util::codec::{Decoder, Encoder, Framed, LengthDelimitedCodec};
20
21pub struct Spake2Encrypted {
22    server_provider: Option<Arc<dyn ServerCredentialProvider>>,
23    client_provider: Option<Arc<dyn ClientCredentialProvider>>,
24    is_server: bool,
25    server_id: Vec<u8>,
26    length_codec: LengthDelimitedCodec,
27    keys: Option<SessionKeys>,
28}
29
30impl Spake2Encrypted {
31    pub fn create_server(
32        server_provider: Arc<dyn ServerCredentialProvider>,
33        server_id: String,
34        codec: LengthDelimitedCodec,
35    ) -> Self {
36        Self {
37            server_provider: Some(server_provider),
38            client_provider: None,
39            is_server: true,
40            server_id: server_id.as_bytes().to_vec(),
41            length_codec: codec,
42            keys: None,
43        }
44    }
45
46    pub fn create_client(
47        client_provider: Arc<dyn ClientCredentialProvider>,
48        server_id: String,
49        codec: LengthDelimitedCodec,
50    ) -> Self {
51        Self {
52            server_provider: None,
53            client_provider: Some(client_provider),
54            is_server: false,
55            server_id: server_id.as_bytes().to_vec(),
56            length_codec: codec,
57            keys: None,
58        }
59    }
60}
61impl Decoder for Spake2Encrypted {
62    type Item = BytesMut;
63    type Error = io::Error;
64
65    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
66        let mut frame = match self.length_codec.decode(src)? {
67            Some(f) => f,
68            None => return Ok(None),
69        };
70        if let Some(keys) = &self.keys {
71            keys.open_in_place(&mut frame)
72                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "decryption failed"))?;
73        } else {
74            return Err(io::Error::new(io::ErrorKind::Other, "decryption failed"));
75        }
76        Ok(Some(frame))
77    }
78}
79
80impl Encoder<Bytes> for Spake2Encrypted {
81    type Error = io::Error;
82
83    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), Self::Error> {
84        if let Some(keys) = &self.keys {
85            let mut buf = BytesMut::from(item);
86            keys.seal_in_place(&mut buf)
87                .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "encryption failed"))?;
88            self.length_codec.encode(buf.freeze(), dst)
89        } else {
90            return Err(io::Error::new(io::ErrorKind::Other, "encryption failed"));
91        }
92    }
93}
94impl Clone for Spake2Encrypted {
95    fn clone(&self) -> Self {
96        Self{
97            server_provider: self.server_provider.clone(),
98            client_provider: self.client_provider.clone(),
99            is_server: self.is_server.clone(),
100            server_id: self.server_id.clone(),
101            length_codec: self.length_codec.clone(),
102            keys: None
103        }
104    }
105}
106
107#[async_trait]
108impl TfCodec for Spake2Encrypted {
109    async fn initial_setup(&mut self, tr: &mut Transport) -> bool {
110        ///Safe limitation to prevent dos
111        let length_codec = LengthDelimitedCodec::builder().max_frame_length(2048).new_codec();
112        let mut framed = Framed::new(TempTransport::new(tr), length_codec);
113        if self.is_server{
114            let res = server_handshake(&mut framed, self.server_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
115            if let Some(keys) = res {
116                self.keys = Some(keys);
117                return true;
118            } else {
119                return false;
120            }
121        } else {
122            let res = client_handshake(&mut framed, self.client_provider.as_ref().unwrap().clone(), self.server_id.as_slice()).await;
123            if let Some(keys) = res {
124                self.keys = Some(keys);
125                return true;
126            }
127            return false;
128        }
129    }
130}
131
132
133#[async_trait]
134pub trait ServerCredentialProvider: Send+Sync+'static  {
135    async fn get_client_password(&self, client_identity: &str) -> Option<Vec<u8>>;
136}
137
138#[async_trait]
139pub trait ClientCredentialProvider: Send+Sync+'static {
140    ///Return 0 - client identity, 1 - client password
141    async fn get_client_credentials(&self) -> Option<(Vec<u8>, Vec<u8>)>;
142}
143
144pub struct SessionKeys {
145    pub send: Aes256Gcm,
146    pub recv: Aes256Gcm,
147
148    /// Local outbound packet counter
149    send_counter: AtomicU64,
150
151    /// Highest accepted inbound counter
152    recv_counter: AtomicU64,
153}
154
155
156struct BytesMutBuffer(pub BytesMut);
157
158impl AsRef<[u8]> for BytesMutBuffer {
159    fn as_ref(&self) -> &[u8] { &self.0 }
160}
161
162impl AsMut<[u8]> for BytesMutBuffer {
163    fn as_mut(&mut self) -> &mut [u8] { &mut self.0 }
164}
165
166impl aead::Buffer for BytesMutBuffer {
167    fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> {
168        self.0.extend_from_slice(other);
169        Ok(())
170    }
171
172    fn truncate(&mut self, len: usize) {
173        self.0.truncate(len);
174    }
175}
176impl SessionKeys {
177
178        fn derive_session_keys(shared: &[u8], is_server: bool) -> Option<Self> {
179            let hk = Hkdf::<Sha256>::new(None, shared);
180
181            let mut key_a = [0u8; 32];
182            let mut key_b = [0u8; 32];
183
184            hk.expand(b"aes-tunnel-key-a", &mut key_a).ok()?;
185            hk.expand(b"aes-tunnel-key-b", &mut key_b).ok()?;
186
187            let (send_key, recv_key) = if is_server {
188                (key_b, key_a)
189            } else {
190                (key_a, key_b)
191            };
192
193            Some(Self {
194                send: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&send_key)),
195                recv: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&recv_key)),
196                send_counter: AtomicU64::new(1),
197                recv_counter: AtomicU64::new(0),
198            })
199        }
200
201        #[inline]
202        fn nonce_from_counter(counter: u64) -> [u8; 12] {
203            let mut nonce = [0u8; 12];
204            nonce[4..].copy_from_slice(&counter.to_be_bytes());
205            nonce
206        }
207
208        pub fn seal_in_place(&self, buf: &mut BytesMut) -> Option<()> {
209            let counter = self.send_counter.fetch_add(1, Ordering::Relaxed);
210
211            if counter == u64::MAX {
212                return None;
213            }
214
215            let counter_bytes = counter.to_be_bytes();
216            let nonce_bytes = Self::nonce_from_counter(counter);
217            let nonce = Nonce::from_slice(&nonce_bytes);
218
219            let mut wrapped = BytesMutBuffer(buf.split());
220
221            // counter is included in AAD so any wire tampering fails tag verification
222            self.send
223                .encrypt_in_place(nonce, &counter_bytes, &mut wrapped)
224                .ok()?;
225
226            buf.clear();
227            buf.reserve(8 + wrapped.0.len());
228            buf.extend_from_slice(&counter_bytes);
229            buf.unsplit(wrapped.0);
230
231            Some(())
232        }
233
234        pub fn open_in_place(&self, buf: &mut BytesMut) -> Option<()> {
235            const COUNTER_LEN: usize = 8;
236
237            if buf.len() < COUNTER_LEN {
238                return None;
239            }
240
241            let counter = u64::from_be_bytes(buf[..COUNTER_LEN].try_into().ok()?);
242
243            if counter == u64::MAX {
244                return None;
245            }
246
247            // compare-exchange loop — prevents TOCTOU race if called concurrently
248            let mut last = self.recv_counter.load(Ordering::Acquire);
249            loop {
250                if counter <= last {
251                    return None; // replay or reorder
252                }
253                match self.recv_counter.compare_exchange_weak(
254                    last,
255                    counter,
256                    Ordering::AcqRel,
257                    Ordering::Acquire,
258                ) {
259                    Ok(_) => break,
260                    Err(current) => last = current, // another thread advanced it, retry
261                }
262            }
263
264            let counter_bytes = counter.to_be_bytes();
265            let nonce_bytes = Self::nonce_from_counter(counter);
266            let nonce = Nonce::from_slice(&nonce_bytes);
267
268            let ciphertext = buf.split_off(COUNTER_LEN);
269            let mut wrapped = BytesMutBuffer(ciphertext);
270
271            // AAD must match what seal used, otherwise tag fails
272            self.recv
273                .decrypt_in_place(nonce, &counter_bytes, &mut wrapped)
274                .ok()?;
275
276            *buf = wrapped.0;
277
278            Some(())
279        }
280    }
281
282pub async fn client_handshake<'a, IO: AsyncReadWrite>(
283    io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
284    cred: Arc<dyn ClientCredentialProvider>,
285    server_id: &[u8],
286) -> Option<SessionKeys> {
287    let creds = cred.get_client_credentials().await?;
288    let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_a(
289        &Password::new(creds.1.as_slice()),
290        &Identity::new(creds.0.as_slice()),
291        &Identity::new(server_id),
292    );
293    io.send(Bytes::from(creds.0.clone())).await.ok()?;
294    io.send(Bytes::from(outbound_msg)).await.ok()?;
295
296    let peer_msg = io.next().await?.ok()?;
297
298    let shared = spake.finish(&peer_msg).ok()?;
299
300    SessionKeys::derive_session_keys(&shared, false)
301}
302
303pub async fn server_handshake<'a, IO: AsyncReadWrite>(
304    io: &mut Framed<TempTransport<'a, IO>, LengthDelimitedCodec>,
305    cred_provider: Arc<dyn ServerCredentialProvider>,
306    server_id: &[u8],
307) -> Option<SessionKeys>
308where
309    IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
310{
311    let client_identity = io.next().await?.ok()?;
312    let client_identity = String::from_utf8_lossy(client_identity.as_ref());
313    let password = cred_provider.get_client_password(&client_identity).await?;
314    let client_identity = client_identity.as_bytes();
315    let (spake, outbound_msg) = Spake2::<Ed25519Group>::start_b(
316        &Password::new(password),
317        &Identity::new(client_identity),
318        &Identity::new(server_id),
319    );
320    let peer_msg = io.next().await?.ok()?;
321
322    io.send(Bytes::from(outbound_msg)).await.ok()?;
323
324    let shared = spake.finish(&peer_msg).ok()?;
325
326    SessionKeys::derive_session_keys(&shared, true)
327}
328
329