Skip to main content

viiper_client/
auth.rs

1// This file is auto-generated by VIIPER codegen. DO NOT EDIT.
2
3use crate::error::ViiperError;
4use chacha20poly1305::{
5    aead::{Aead, KeyInit},
6    ChaCha20Poly1305, Nonce,
7};
8use hmac::{Hmac, Mac};
9use pbkdf2::pbkdf2_hmac;
10use rand::RngCore;
11use sha2::{Digest, Sha256};
12use std::io::{Read, Write};
13use std::net::TcpStream;
14
15#[cfg(feature = "async")]
16use std::pin::Pin;
17#[cfg(feature = "async")]
18use std::task::{Context, Poll};
19#[cfg(feature = "async")]
20use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
21#[cfg(feature = "async")]
22use tokio::net::TcpStream as AsyncTcpStream;
23
/// Prefix that opens every client handshake message.
const HANDSHAKE_MAGIC: &[u8] = b"eVI1\x00";
/// Length in bytes of the client nonce and of the server nonce in the handshake reply.
const NONCE_SIZE: usize = 32;

/// Context label mixed into the client's HMAC authentication tag.
const AUTH_CONTEXT: &[u8] = b"VIIPER-Auth-v1";
/// Context label appended to the inputs hashed into the session key.
const SESSION_CONTEXT: &[u8] = b"VIIPER-Session-v1";

/// Fixed salt used when stretching the password with PBKDF2-HMAC-SHA256.
const PBKDF2_SALT: &[u8] = b"VIIPER-Key-v1";
/// PBKDF2 round count used when stretching the password.
const PBKDF2_ITERATIONS: u32 = 100_000;
30
31/// Derive a 32-byte key from password using PBKDF2-SHA256
32fn derive_key(password: &str) -> Result<[u8; 32], ViiperError> {
33    if password.is_empty() {
34        return Err(ViiperError::UnexpectedResponse("Password cannot be empty".into()));
35    }
36    let mut key = [0u8; 32];
37    pbkdf2_hmac::<Sha256>(password.as_bytes(), PBKDF2_SALT, PBKDF2_ITERATIONS, &mut key);
38    Ok(key)
39}
40
41/// Derive session key from key and nonces using SHA-256
42fn derive_session_key(key: &[u8], server_nonce: &[u8], client_nonce: &[u8]) -> [u8; 32] {
43    let mut hasher = Sha256::new();
44    hasher.update(key);
45    hasher.update(server_nonce);
46    hasher.update(client_nonce);
47    hasher.update(SESSION_CONTEXT);
48    hasher.finalize().into()
49}
50
51/// Perform authentication handshake with VIIPER server (synchronous)
52pub fn perform_handshake(mut stream: TcpStream, password: &str) -> Result<EncryptedStream, ViiperError> {
53    let key = derive_key(password)?;
54    let mut client_nonce = [0u8; NONCE_SIZE];
55    rand::thread_rng().fill_bytes(&mut client_nonce);
56    
57    let mut mac = <Hmac::<Sha256> as KeyInit>::new_from_slice(&key)
58        .map_err(|_| ViiperError::UnexpectedResponse("Invalid key length".into()))?;
59    mac.update(AUTH_CONTEXT);
60    mac.update(&client_nonce);
61    let auth_tag = mac.finalize().into_bytes();
62    
63    let mut handshake_msg = Vec::with_capacity(HANDSHAKE_MAGIC.len() + NONCE_SIZE + 32);
64    handshake_msg.extend_from_slice(HANDSHAKE_MAGIC);
65    handshake_msg.extend_from_slice(&client_nonce);
66    handshake_msg.extend_from_slice(&auth_tag);
67    
68    stream.write_all(&handshake_msg)?;
69    
70    let mut response = vec![0u8; 3 + NONCE_SIZE];
71    stream.read_exact(&mut response)?;
72    
73    if &response[0..3] != b"OK\x00" {
74        let mut error_buf = Vec::new();
75        let _ = stream.read_to_end(&mut error_buf);
76        let full_response = [response, error_buf].concat();
77        let error_str = String::from_utf8_lossy(&full_response);
78        
79        if let Ok(problem) = serde_json::from_str::<crate::error::ProblemJson>(&error_str) {
80            return Err(ViiperError::Protocol(problem));
81        }
82        return Err(ViiperError::UnexpectedResponse(format!("Invalid handshake response: {}", error_str)));
83    }
84    
85    let server_nonce = &response[3..];
86    
87    let session_key = derive_session_key(&key, server_nonce, &client_nonce);
88    
89    Ok(EncryptedStream::new(stream, session_key)?)
90}
91
/// Perform the authentication handshake with a VIIPER server over a Tokio TCP stream.
///
/// This is the asynchronous counterpart of `perform_handshake` and uses the same exchange.
/// It sends `HANDSHAKE_MAGIC || client_nonce || HMAC-SHA256(key, AUTH_CONTEXT || client_nonce)`
/// and expects `b"OK\0"` followed by a `NONCE_SIZE`-byte server nonce. It then wraps the
/// stream in an [`AsyncEncryptedStream`] keyed with the derived session key.
///
/// # Errors
///
/// * The password is empty (from `derive_key`).
/// * An I/O error occurs while sending the handshake or reading the reply.
/// * The status is not `OK`. The remainder of the reply (capped in size) is read. If it
///   parses as `ProblemJson`, `ViiperError::Protocol` is returned; otherwise
///   `ViiperError::UnexpectedResponse` is returned with the reply text.
#[cfg(feature = "async")]
pub async fn perform_handshake_async(mut stream: AsyncTcpStream, password: &str) -> Result<AsyncEncryptedStream, ViiperError> {
    // Upper bound on how much of a rejection message is buffered.
    const MAX_ERROR_RESPONSE: u64 = 64 * 1024;

    let key = derive_key(password)?;

    let mut client_nonce = [0u8; NONCE_SIZE];
    // The thread-local RNG handle is a temporary dropped at the end of this statement,
    // so it is never held across an `.await`.
    rand::thread_rng().fill_bytes(&mut client_nonce);

    let mut mac = <Hmac::<Sha256> as KeyInit>::new_from_slice(&key)
        .map_err(|_| ViiperError::UnexpectedResponse("Invalid key length".into()))?;
    mac.update(AUTH_CONTEXT);
    mac.update(&client_nonce);
    let auth_tag = mac.finalize().into_bytes();

    let mut handshake_msg = Vec::with_capacity(HANDSHAKE_MAGIC.len() + NONCE_SIZE + auth_tag.len());
    handshake_msg.extend_from_slice(HANDSHAKE_MAGIC);
    handshake_msg.extend_from_slice(&client_nonce);
    handshake_msg.extend_from_slice(&auth_tag);
    stream.write_all(&handshake_msg).await?;

    // Read only the status first. A rejection message shorter than status + nonce is then
    // still reported, instead of failing with `UnexpectedEof` on a fixed-size read.
    let mut status = [0u8; 3];
    stream.read_exact(&mut status).await?;

    if &status != b"OK\x00" {
        let mut full_response = status.to_vec();
        // Best effort: read until the peer closes, but cap the amount so a misbehaving
        // peer cannot make us buffer unbounded data.
        let _ = (&mut stream).take(MAX_ERROR_RESPONSE).read_to_end(&mut full_response).await;
        let error_str = String::from_utf8_lossy(&full_response);

        if let Ok(problem) = serde_json::from_str::<crate::error::ProblemJson>(&error_str) {
            return Err(ViiperError::Protocol(problem));
        }
        return Err(ViiperError::UnexpectedResponse(format!("Invalid handshake response: {}", error_str)));
    }

    let mut server_nonce = [0u8; NONCE_SIZE];
    stream.read_exact(&mut server_nonce).await?;

    let session_key = derive_session_key(&key, &server_nonce, &client_nonce);
    Ok(AsyncEncryptedStream::new(stream, session_key))
}
134
135/// Encrypted stream wrapper using ChaCha20-Poly1305 (synchronous)
136/// Read and write paths are independently locked to avoid blocking writes
137/// while a read thread is waiting for output.
138pub struct EncryptedStream {
139    read: std::sync::Arc<std::sync::Mutex<EncryptedReadState>>,
140    write: std::sync::Arc<std::sync::Mutex<EncryptedWriteState>>,
141}
142
143struct EncryptedReadState {
144    stream: TcpStream,
145    cipher: ChaCha20Poly1305,
146    recv_buffer: Vec<u8>,
147}
148
149struct EncryptedWriteState {
150    stream: TcpStream,
151    cipher: ChaCha20Poly1305,
152    send_counter: u64,
153}
154
155impl EncryptedStream {
156    fn new(inner: TcpStream, session_key: [u8; 32]) -> Result<Self, ViiperError> {
157        let read_stream = inner.try_clone()?;
158        let read_cipher = ChaCha20Poly1305::new(&session_key.into());
159        let write_cipher = ChaCha20Poly1305::new(&session_key.into());
160        Ok(Self {
161            read: std::sync::Arc::new(std::sync::Mutex::new(EncryptedReadState {
162                stream: read_stream,
163                cipher: read_cipher,
164                recv_buffer: Vec::new(),
165            })),
166            write: std::sync::Arc::new(std::sync::Mutex::new(EncryptedWriteState {
167                stream: inner,
168                cipher: write_cipher,
169                send_counter: 0,
170            })),
171        })
172    }
173    
174    pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
175        let read = self.read.lock().unwrap();
176        let write = self.write.lock().unwrap();
177        read.stream.set_nodelay(nodelay)?;
178        write.stream.set_nodelay(nodelay)
179    }
180    
181    pub fn try_clone(&self) -> std::io::Result<Self> {
182        Ok(Self {
183            read: std::sync::Arc::clone(&self.read),
184            write: std::sync::Arc::clone(&self.write),
185        })
186    }
187    
188    pub fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
189        let read = self.read.lock().unwrap();
190        let write = self.write.lock().unwrap();
191        let _ = read.stream.shutdown(how);
192        write.stream.shutdown(how)
193    }
194}
195
196impl Read for EncryptedStream {
197    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
198        let mut inner = self.read.lock().unwrap();
199        
200        if inner.recv_buffer.is_empty() {
201            let mut first_byte = [0u8; 1];
202            let n = inner.stream.read(&mut first_byte)?;
203            if n == 0 {
204                return Ok(0);
205            }
206            
207            let mut len_buf = [0u8; 4];
208            len_buf[0] = first_byte[0];
209            inner.stream.read_exact(&mut len_buf[1..])?;
210            let packet_len = u32::from_be_bytes(len_buf) as usize;
211            
212            if packet_len > 2 * 1024 * 1024 {
213                return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Packet too large"));
214            }
215            
216            let mut packet = vec![0u8; packet_len];
217            inner.stream.read_exact(&mut packet)?;
218            
219            let nonce = Nonce::from_slice(&packet[0..12]);
220            let ciphertext_and_tag = &packet[12..];
221            
222            let plaintext = inner.cipher.decrypt(nonce, ciphertext_and_tag)
223                .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Decryption failed"))?;
224            
225            inner.recv_buffer = plaintext;
226        }
227        
228        let to_copy = buf.len().min(inner.recv_buffer.len());
229        buf[..to_copy].copy_from_slice(&inner.recv_buffer[..to_copy]);
230        inner.recv_buffer.drain(..to_copy);
231        Ok(to_copy)
232    }
233}
234
235impl Write for EncryptedStream {
236    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
237        let mut inner = self.write.lock().unwrap();
238        
239        let mut nonce_bytes = [0u8; 12];
240        nonce_bytes[4..].copy_from_slice(&inner.send_counter.to_be_bytes());
241        inner.send_counter += 1;
242        let nonce = Nonce::from_slice(&nonce_bytes);
243        
244        let ciphertext = inner.cipher.encrypt(nonce, buf)
245            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Encryption failed"))?;
246        
247        let packet = [&nonce_bytes[..], ciphertext.as_slice()].concat();
248        let len_buf = (packet.len() as u32).to_be_bytes();
249        
250        inner.stream.write_all(&len_buf)?;
251        inner.stream.write_all(&packet)?;
252        
253        Ok(buf.len())
254    }
255    
256    fn flush(&mut self) -> std::io::Result<()> {
257        let mut inner = self.write.lock().unwrap();
258        inner.stream.flush()
259    }
260}
261
/// Encrypted stream wrapper using ChaCha20-Poly1305 (asynchronous).
///
/// Owns both halves of a Tokio TCP connection. Use
/// [`into_split`](AsyncEncryptedStream::into_split) to read and write from separate tasks.
#[cfg(feature = "async")]
pub struct AsyncEncryptedStream {
    read: AsyncEncryptedRead,
    write: AsyncEncryptedWrite,
}

/// Read half of an [`AsyncEncryptedStream`]; decodes and decrypts incoming frames.
///
/// Wire frame: a 4-byte big-endian length `N`, then `N` bytes made of a 12-byte nonce
/// followed by the ChaCha20-Poly1305 ciphertext and tag.
#[cfg(feature = "async")]
pub struct AsyncEncryptedRead {
    inner: tokio::net::tcp::OwnedReadHalf,
    cipher: ChaCha20Poly1305,
    /// Decrypted bytes not yet handed to the caller.
    recv_buffer: Vec<u8>,
    /// Progress through the frame currently being received.
    read_state: ReadState,
}

/// Write half of an [`AsyncEncryptedStream`]; encrypts writes into length-prefixed frames.
#[cfg(feature = "async")]
pub struct AsyncEncryptedWrite {
    inner: tokio::net::tcp::OwnedWriteHalf,
    cipher: ChaCha20Poly1305,
    /// Counter placed in the nonce of the next frame.
    send_counter: u64,
    /// Unsent tail of a frame that the socket only partially accepted. It must reach the
    /// wire before any new frame, and is pushed by `poll_write`, `poll_flush` and
    /// `poll_shutdown`.
    pending: Vec<u8>,
}

/// Incremental frame-decoding state for [`AsyncEncryptedRead`].
#[cfg(feature = "async")]
enum ReadState {
    /// Collecting the 4-byte big-endian frame length; `pos` bytes have been filled so far.
    ReadingLength { buf: [u8; 4], pos: usize },
    /// Collecting a frame body of `expected_len` bytes; `pos` bytes have been filled so far.
    ReadingPacket { expected_len: usize, buf: Vec<u8>, pos: usize },
    /// Placeholder used while the current state is temporarily moved out in `poll_read`.
    /// A later poll that finds it starts reading a fresh length prefix.
    Ready,
}

#[cfg(feature = "async")]
impl AsyncEncryptedStream {
    /// Split `inner` into halves and key each direction with `session_key`.
    fn new(inner: AsyncTcpStream, session_key: [u8; 32]) -> Self {
        let (read_half, write_half) = inner.into_split();
        let read_cipher = ChaCha20Poly1305::new(&session_key.into());
        let write_cipher = ChaCha20Poly1305::new(&session_key.into());
        Self {
            read: AsyncEncryptedRead {
                inner: read_half,
                cipher: read_cipher,
                recv_buffer: Vec::new(),
                read_state: ReadState::ReadingLength { buf: [0; 4], pos: 0 },
            },
            write: AsyncEncryptedWrite {
                inner: write_half,
                cipher: write_cipher,
                send_counter: 0,
                pending: Vec::new(),
            },
        }
    }

    /// Separate the stream into independently owned read and write halves.
    pub fn into_split(self) -> (AsyncEncryptedRead, AsyncEncryptedWrite) {
        (self.read, self.write)
    }
}

#[cfg(feature = "async")]
impl AsyncEncryptedRead {
    /// Move as much buffered plaintext into `buf` as it has room for.
    fn drain_into(&mut self, buf: &mut ReadBuf<'_>) {
        let to_copy = buf.remaining().min(self.recv_buffer.len());
        buf.put_slice(&self.recv_buffer[..to_copy]);
        self.recv_buffer.drain(..to_copy);
    }
}

#[cfg(feature = "async")]
impl AsyncRead for AsyncEncryptedRead {
    /// Fill `buf` with decrypted bytes, reading and decrypting the next frame when no
    /// buffered plaintext is left.
    ///
    /// Completing with nothing written into a non-empty `buf` signals a clean EOF at a
    /// frame boundary. Frames that decrypt to an empty payload are skipped, so they are
    /// not mistaken for EOF.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        const MAX_PACKET_LEN: usize = 2 * 1024 * 1024;
        const NONCE_LEN: usize = 12;

        // `AsyncEncryptedRead` is `Unpin`, so a plain `&mut` lets fields be borrowed separately.
        let this = &mut *self;

        if !this.recv_buffer.is_empty() {
            this.drain_into(buf);
            return Poll::Ready(Ok(()));
        }
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        loop {
            match std::mem::replace(&mut this.read_state, ReadState::Ready) {
                ReadState::ReadingLength { buf: mut len_buf, pos } => {
                    let mut read_buf = ReadBuf::new(&mut len_buf[pos..]);
                    let bytes_read = match Pin::new(&mut this.inner).poll_read(cx, &mut read_buf) {
                        Poll::Ready(Ok(())) => read_buf.filled().len(),
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => {
                            this.read_state = ReadState::ReadingLength { buf: len_buf, pos };
                            return Poll::Pending;
                        }
                    };

                    if bytes_read == 0 {
                        if pos == 0 {
                            return Poll::Ready(Ok(())); // Normal EOF between frames
                        }
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "Connection closed while reading length"
                        )));
                    }

                    let pos = pos + bytes_read;
                    if pos < len_buf.len() {
                        this.read_state = ReadState::ReadingLength { buf: len_buf, pos };
                        continue;
                    }

                    let packet_len = u32::from_be_bytes(len_buf) as usize;
                    if packet_len > MAX_PACKET_LEN {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "Packet too large"
                        )));
                    }
                    if packet_len < NONCE_LEN {
                        // A malformed length must produce an error, not a panic on the nonce slice.
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            "Packet too short"
                        )));
                    }
                    this.read_state = ReadState::ReadingPacket {
                        expected_len: packet_len,
                        buf: vec![0u8; packet_len],
                        pos: 0,
                    };
                }
                ReadState::ReadingPacket { expected_len, buf: mut packet_buf, pos } => {
                    let mut read_buf = ReadBuf::new(&mut packet_buf[pos..]);
                    let bytes_read = match Pin::new(&mut this.inner).poll_read(cx, &mut read_buf) {
                        Poll::Ready(Ok(())) => read_buf.filled().len(),
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => {
                            this.read_state = ReadState::ReadingPacket {
                                expected_len,
                                buf: packet_buf,
                                pos,
                            };
                            return Poll::Pending;
                        }
                    };

                    if bytes_read == 0 {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "Connection closed while reading packet"
                        )));
                    }

                    let pos = pos + bytes_read;
                    if pos < expected_len {
                        this.read_state = ReadState::ReadingPacket {
                            expected_len,
                            buf: packet_buf,
                            pos,
                        };
                        continue;
                    }

                    // Whole frame received; the next frame starts with a fresh length prefix.
                    this.read_state = ReadState::ReadingLength { buf: [0; 4], pos: 0 };

                    let (nonce_bytes, ciphertext_and_tag) = packet_buf.split_at(NONCE_LEN);
                    let plaintext = match this
                        .cipher
                        .decrypt(Nonce::from_slice(nonce_bytes), ciphertext_and_tag)
                    {
                        Ok(plaintext) => plaintext,
                        Err(_) => {
                            return Poll::Ready(Err(std::io::Error::new(
                                std::io::ErrorKind::InvalidData,
                                "Decryption failed"
                            )));
                        }
                    };

                    if plaintext.is_empty() {
                        // Returning now would look like EOF to the caller; read the next frame.
                        continue;
                    }
                    this.recv_buffer = plaintext;
                    this.drain_into(buf);
                    return Poll::Ready(Ok(()));
                }
                ReadState::Ready => {
                    this.read_state = ReadState::ReadingLength { buf: [0; 4], pos: 0 };
                }
            }
        }
    }
}

#[cfg(feature = "async")]
impl AsyncEncryptedWrite {
    /// Push the unsent frame tail in `pending` to the socket.
    ///
    /// Completes with `Ok(())` once `pending` is empty.
    fn poll_drain_pending(&mut self, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        while !self.pending.is_empty() {
            match Pin::new(&mut self.inner).poll_write(cx, &self.pending) {
                Poll::Ready(Ok(0)) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "Failed to write complete packet"
                    )));
                }
                Poll::Ready(Ok(n)) => {
                    self.pending.drain(..n);
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => return Poll::Pending,
            }
        }
        Poll::Ready(Ok(()))
    }
}

#[cfg(feature = "async")]
impl AsyncWrite for AsyncEncryptedWrite {
    /// Encrypt up to one frame's worth of `buf` and send it as one length-prefixed frame.
    ///
    /// * Any tail left over from an earlier partial write is sent first.
    /// * At most `MAX_PLAINTEXT_LEN` bytes are consumed per call, so frames stay within the
    ///   2 MiB limit that the read half enforces. `write_all` handles the resulting short
    ///   writes.
    /// * `Pending` means nothing from `buf` was consumed and the nonce counter did not advance.
    /// * If the socket accepts only part of a frame, the rest is kept in `pending` and the
    ///   chunk is reported as written. The tail is sent by the next `poll_write`,
    ///   `poll_flush` or `poll_shutdown`, so flush after the last write.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        const MAX_PACKET_LEN: usize = 2 * 1024 * 1024;
        const NONCE_LEN: usize = 12;
        const TAG_LEN: usize = 16;
        const MAX_PLAINTEXT_LEN: usize = MAX_PACKET_LEN - NONCE_LEN - TAG_LEN;

        // `AsyncEncryptedWrite` is `Unpin`, so a plain `&mut` lets fields be borrowed separately.
        let this = &mut *self;

        match this.poll_drain_pending(cx) {
            Poll::Ready(Ok(())) => {}
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
            Poll::Pending => return Poll::Pending,
        }
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let chunk = &buf[..buf.len().min(MAX_PLAINTEXT_LEN)];

        // Nonce = 4 zero bytes || 64-bit big-endian counter. The counter advances only once
        // the frame is committed to the socket. It must never wrap, because reusing a
        // nonce under one key breaks ChaCha20-Poly1305.
        let counter = this.send_counter;
        let next_counter = counter.checked_add(1).ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::Other, "Nonce counter exhausted")
        })?;
        let mut nonce_bytes = [0u8; NONCE_LEN];
        nonce_bytes[4..].copy_from_slice(&counter.to_be_bytes());

        let ciphertext = this
            .cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), chunk)
            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Encryption failed"))?;

        let packet_len = NONCE_LEN + ciphertext.len();
        let mut frame = Vec::with_capacity(4 + packet_len);
        frame.extend_from_slice(&(packet_len as u32).to_be_bytes());
        frame.extend_from_slice(&nonce_bytes);
        frame.extend_from_slice(&ciphertext);

        match Pin::new(&mut this.inner).poll_write(cx, &frame) {
            // Nothing reached the wire, so dropping the frame and reusing the counter on the
            // next attempt is safe.
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Ready(Ok(0)) => Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::WriteZero,
                "Failed to write complete packet"
            ))),
            Poll::Ready(Ok(n)) => {
                this.send_counter = next_counter;
                if n < frame.len() {
                    // Part of the frame is already on the wire: keep the rest and try to
                    // push it now. If the socket is still full, it stays buffered.
                    this.pending = frame.split_off(n);
                    if let Poll::Ready(Err(e)) = this.poll_drain_pending(cx) {
                        return Poll::Ready(Err(e));
                    }
                }
                Poll::Ready(Ok(chunk.len()))
            }
        }
    }

    /// Send any pending frame tail, then flush the socket.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = &mut *self;
        match this.poll_drain_pending(cx) {
            Poll::Ready(Ok(())) => Pin::new(&mut this.inner).poll_flush(cx),
            other => other,
        }
    }

    /// Send any pending frame tail, then shut down the write side of the socket.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        let this = &mut *self;
        match this.poll_drain_pending(cx) {
            Poll::Ready(Ok(())) => Pin::new(&mut this.inner).poll_shutdown(cx),
            other => other,
        }
    }
}
476
/// Reads are served by the stream's read half ([`AsyncEncryptedRead`]).
#[cfg(feature = "async")]
impl AsyncRead for AsyncEncryptedStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Both halves are `Unpin`, so the pin can be unwrapped to a plain `&mut`.
        let reader = &mut self.get_mut().read;
        Pin::new(reader).poll_read(cx, buf)
    }
}
487
/// Writes, flushes and shutdowns are forwarded to the stream's write half
/// ([`AsyncEncryptedWrite`]).
#[cfg(feature = "async")]
impl AsyncWrite for AsyncEncryptedStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let writer = &mut self.get_mut().write;
        Pin::new(writer).poll_write(cx, buf)
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let writer = &mut self.get_mut().write;
        Pin::new(writer).poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let writer = &mut self.get_mut().write;
        Pin::new(writer).poll_shutdown(cx)
    }
}