1#![warn(missing_docs)]
45
46mod crypto;
47use std::{convert::TryInto, fmt::{self, Debug, Display, Formatter}, io::{self, ErrorKind}, net::SocketAddr};
48use tokio::{io::{AsyncReadExt, AsyncWriteExt}, net::TcpStream};
49#[cfg(feature = "split")]
50use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
51use async_trait::async_trait;
52use ed25519_dalek::{ed25519::signature::Signature, Keypair, Signer, Verifier, SIGNATURE_LENGTH};
53use rand::{RngCore, rngs::OsRng};
54use sha2::{Sha384, Digest};
55use aes_gcm::{Aes128Gcm, aead::Aead, NewAead, aead::Payload, Nonce};
56use crypto::{HandshakeKeys, ApplicationKeys};
57
/// Length in bytes of the AES-GCM authentication tag appended to every ciphertext.
const AES_TAG_LEN: usize = 16;
/// Number of random bytes placed at the start of each handshake message.
const RANDOM_LEN: usize = 64;
/// Size in bytes of the big-endian length prefix used to frame messages.
const MESSAGE_LEN_LEN: usize = 4;
/// Integer type stored in the length prefix; its byte width must equal `MESSAGE_LEN_LEN`.
type MessageLenType = u32;

/// Length in bytes of an Ed25519 public key (also used for the X25519 ephemeral keys exchanged
/// during the handshake).
pub const PUBLIC_KEY_LENGTH: usize = ed25519_dalek::PUBLIC_KEY_LENGTH;

/// Long-term identity of a peer: an Ed25519 key pair used to sign the handshake.
pub type Identity = Keypair;
70
/// Errors returned by PSEC session operations.
#[derive(Debug, PartialEq, Eq)]
pub enum PsecError {
    /// A write failed because the connection was closed, or a read completed with zero bytes.
    BrokenPipe,
    /// The connection was reset by the peer.
    ConnectionReset,
    /// A received message failed AES-GCM authentication, or the handshake could not be verified.
    TransmissionCorrupted,
    /// A received frame is larger than the limit set with `PsecReader::set_max_recv_size`.
    BufferTooLarge,
    /// The stream ended before the expected number of bytes could be read.
    UnexpectedEof,
    /// Any other I/O error, identified by its kind.
    IoError {
        /// Kind of the underlying I/O error.
        error_kind: ErrorKind,
    },
    /// The length-prefix framing of a decrypted message is invalid.
    BadPadding,
}
92
93impl Display for PsecError {
94 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
95 match self {
96 PsecError::BrokenPipe => f.write_str("Broken pipe"),
97 PsecError::ConnectionReset => f.write_str("Connection reset"),
98 PsecError::TransmissionCorrupted => f.write_str("Transmission corrupted"),
99 PsecError::BufferTooLarge => f.write_str("Received buffer is too large"),
100 PsecError::UnexpectedEof => f.write_str("Unexpected EOF"),
101 PsecError::IoError { error_kind } => f.write_str(&format!("{:?}", error_kind)),
102 PsecError::BadPadding => f.write_str("Bad Padding"),
103 }
104 }
105}
106
107fn slice_to_public_key(buff: &[u8]) -> x25519_dalek::PublicKey {
108 let array: [u8; PUBLIC_KEY_LENGTH] = buff.try_into().unwrap();
109 x25519_dalek::PublicKey::from(array)
110}
111
112async fn receive<T: AsyncReadExt + Unpin>(reader: &mut T, buff: &mut [u8]) -> Result<usize, PsecError> {
113 match reader.read_exact(buff).await {
114 Ok(read) => {
115 if read > 0 {
116 Ok(read)
117 } else {
118 Err(PsecError::BrokenPipe)
119 }
120 }
121 Err(e) => {
122 match e.kind() {
123 ErrorKind::UnexpectedEof => Err(PsecError::UnexpectedEof),
124 ErrorKind::ConnectionReset => Err(PsecError::ConnectionReset),
125 _ => Err(PsecError::IoError { error_kind: e.kind() })
126 }
127 }
128 }
129}
130
131async fn send<T: AsyncWriteExt + Unpin>(writer: &mut T, buff: &[u8]) -> Result<(), PsecError> {
132 match writer.write_all(buff).await {
133 Ok(_) => Ok(()),
134 Err(e) => Err(match e.kind() {
135 ErrorKind::BrokenPipe => PsecError::BrokenPipe,
136 ErrorKind::ConnectionReset => PsecError::ConnectionReset,
137 _ => PsecError::IoError { error_kind: e.kind() }
138 })
139 }
140}
141
142fn pad(plain_text: &[u8], use_padding: bool) -> Vec<u8> {
143 let encoded_msg_len = (plain_text.len() as MessageLenType).to_be_bytes();
144 let msg_len = plain_text.len()+encoded_msg_len.len();
145 let mut output = Vec::from(encoded_msg_len);
146 if use_padding {
147 let mut len = 1000;
148 while len < msg_len {
149 len *= 2;
150 }
151 output.reserve(len);
152 output.extend(plain_text);
153 output.resize(len, 0);
154 OsRng.fill_bytes(&mut output[msg_len..]);
155 } else {
156 output.extend(plain_text);
157 }
158 output
159}
160
161fn unpad(input: Vec<u8>) -> Result<Vec<u8>, PsecError> {
162 if input.len() < 4 {
163 Err(PsecError::BadPadding)
164 } else {
165 let msg_len = MessageLenType::from_be_bytes(input[0..MESSAGE_LEN_LEN].try_into().unwrap()) as usize;
166 Ok(Vec::from(&input[MESSAGE_LEN_LEN..MESSAGE_LEN_LEN+msg_len]))
167 }
168}
169
170fn encrypt(local_cipher: &Aes128Gcm, local_iv: &[u8], local_counter: &mut usize, plain_text: &[u8], use_padding: bool) -> Vec<u8> {
171 let padded_msg = pad(plain_text, use_padding);
172 let cipher_len = ((padded_msg.len() + AES_TAG_LEN) as MessageLenType).to_be_bytes();
173 let payload = Payload {
174 msg: &padded_msg,
175 aad: &cipher_len
176 };
177 let nonce = crypto::iv_to_nonce(local_iv, local_counter);
178 let cipher_text = local_cipher.encrypt(Nonce::from_slice(&nonce), payload).unwrap();
179 [&cipher_len, cipher_text.as_slice()].concat()
180}
181
182async fn encrypt_and_send<T: AsyncWriteExt + Unpin>(writer: &mut T, local_cipher: &Aes128Gcm, local_iv: &[u8], local_counter: &mut usize, plain_text: &[u8], use_padding: bool) -> Result<(), PsecError> {
183 let cipher_text = encrypt(local_cipher, local_iv, local_counter, plain_text, use_padding);
184 send(writer, &cipher_text).await
185}
186
187async fn receive_and_decrypt<T: AsyncReadExt + Unpin>(reader: &mut T, peer_cipher: &Aes128Gcm, peer_iv: &[u8], peer_counter: &mut usize, max_recv_size: Option<usize>) -> Result<Vec<u8>, PsecError> {
188 let mut message_len = [0; MESSAGE_LEN_LEN];
189 receive(reader, &mut message_len).await?;
190 let recv_len = MessageLenType::from_be_bytes(message_len) as usize;
191 if let Some(max_recv_size) = max_recv_size {
192 if recv_len > max_recv_size {
193 return Err(PsecError::BufferTooLarge);
194 }
195 }
196 let mut cipher_text = vec![0; recv_len];
197 let mut read = 0;
198 while read < recv_len {
199 read += receive(reader, &mut cipher_text[read..]).await?;
200 }
201 let peer_nonce = crypto::iv_to_nonce(peer_iv, peer_counter);
202 let payload = Payload {
203 msg: &cipher_text,
204 aad: &message_len
205 };
206 match peer_cipher.decrypt(Nonce::from_slice(&peer_nonce), payload) {
207 Ok(plain_text) => unpad(plain_text),
208 Err(_) => Err(PsecError::TransmissionCorrupted)
209 }
210}
211
212fn compute_max_recv_size(size: usize, is_raw_size: bool) -> usize {
213 if is_raw_size {
214 size
215 } else {
216 let max_not_padded_size = size+MESSAGE_LEN_LEN;
217 let mut max_padded_size = 1000;
218 while max_padded_size < max_not_padded_size {
219 max_padded_size *= 2;
220 }
221 max_padded_size+AES_TAG_LEN
222 }
223}
224
#[async_trait]
/// Receiving side of a PSEC channel, implemented by [`Session`] (and by `SessionReadHalf` when
/// the `split` feature is enabled).
pub trait PsecReader {
    /// Limits the size of frames accepted by [`PsecReader::receive_and_decrypt`].
    ///
    /// If `is_raw_size` is `true`, `size` is the maximum ciphertext length as written in a
    /// frame's length prefix (AES-GCM tag included). Otherwise `size` is a plaintext length. It
    /// is converted to the largest padded ciphertext such a message can produce. Larger frames
    /// are rejected with [`PsecError::BufferTooLarge`] before their body is read. Until this
    /// method is called, no limit is applied.
    fn set_max_recv_size(&mut self, size: usize, is_raw_size: bool);

    /// Reads one length-prefixed frame, authenticates and decrypts it, and returns the plaintext
    /// with its framing and padding removed.
    ///
    /// # Errors
    ///
    /// [`PsecError::BufferTooLarge`] if the frame exceeds the configured limit,
    /// [`PsecError::TransmissionCorrupted`] if authentication fails,
    /// [`PsecError::BadPadding`] if the decrypted framing is invalid, and the I/O variants
    /// when reading from the socket fails.
    async fn receive_and_decrypt(&mut self) -> Result<Vec<u8>, PsecError>;

    /// Same as [`PsecReader::receive_and_decrypt`], but takes the reader by value and hands it
    /// back alongside the result, so the returned future owns the reader.
    async fn into_receive_and_decrypt(self) -> (Result<Vec<u8>, PsecError>, Self);
}
275
#[async_trait]
/// Sending side of a PSEC channel, implemented by [`Session`] (and by `SessionWriteHalf` when
/// the `split` feature is enabled).
pub trait PsecWriter {
    /// Encrypts `plain_text` and writes the resulting frame to the stream.
    ///
    /// With `use_padding`, the message is padded with random bytes up to `1000 * 2^k` bytes
    /// (length prefix included), so the ciphertext length only reveals a coarse size bucket.
    ///
    /// # Errors
    ///
    /// Returns the I/O variants of [`PsecError`] if writing fails.
    async fn encrypt_and_send(&mut self, plain_text: &[u8], use_padding: bool) -> Result<(), PsecError>;

    /// Encrypts `plain_text` into a complete frame (length prefix included) without sending it.
    ///
    /// The frame can be written later with [`PsecWriter::send`]. Each call passes the local
    /// message counter to `crypto::iv_to_nonce` by `&mut`, presumably advancing it. Frames
    /// therefore presumably have to be sent in the order they were produced for the peer to
    /// decrypt them (TODO confirm against `crypto::iv_to_nonce`).
    fn encrypt(&mut self, plain_text: &[u8], use_padding: bool) -> Vec<u8>;

    /// Writes `cipher_text` to the stream unchanged (typically a frame returned by
    /// [`PsecWriter::encrypt`]).
    ///
    /// # Errors
    ///
    /// Returns the I/O variants of [`PsecError`] if writing fails.
    async fn send(&mut self, cipher_text: &[u8]) -> Result<(), PsecError>;
}
316
#[cfg(feature = "split")]
/// Receiving half of a [`Session`], produced by [`Session::into_split`].
pub struct SessionReadHalf {
    /// Read half of the underlying TCP stream.
    read_half: OwnedReadHalf,
    /// Cipher used to decrypt incoming frames.
    peer_cipher: Aes128Gcm,
    /// IV combined with `peer_counter` (via `crypto::iv_to_nonce`) to build incoming nonces.
    peer_iv: [u8; crypto::IV_LEN],
    /// Counter passed to `crypto::iv_to_nonce` for incoming frames.
    peer_counter: usize,
    /// Largest accepted frame length prefix, if a limit was set.
    max_recv_size: Option<usize>,
}
326
#[cfg(feature = "split")]
impl Debug for SessionReadHalf {
    /// Prints the stream, limit, counter and IV (hex). The cipher itself is never printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("SessionReadHalf");
        out.field("read_half", &self.read_half);
        out.field("max_recv_size", &self.max_recv_size);
        out.field("peer_counter", &self.peer_counter);
        out.field("peer_iv", &hex_encode(&self.peer_iv));
        out.finish()
    }
}
338
#[cfg(feature = "split")]
#[async_trait]
impl PsecReader for SessionReadHalf {
    fn set_max_recv_size(&mut self, size: usize, is_raw_size: bool) {
        self.max_recv_size = compute_max_recv_size(size, is_raw_size).into();
    }
    async fn receive_and_decrypt(&mut self) -> Result<Vec<u8>, PsecError> {
        let limit = self.max_recv_size;
        receive_and_decrypt(&mut self.read_half, &self.peer_cipher, &self.peer_iv, &mut self.peer_counter, limit).await
    }
    async fn into_receive_and_decrypt(mut self) -> (Result<Vec<u8>, PsecError>, Self) {
        let result = self.receive_and_decrypt().await;
        (result, self)
    }
}
352
#[cfg(feature = "split")]
/// Sending half of a [`Session`], produced by [`Session::into_split`].
pub struct SessionWriteHalf {
    /// Write half of the underlying TCP stream.
    write_half: OwnedWriteHalf,
    /// Cipher used to encrypt outgoing frames.
    local_cipher: Aes128Gcm,
    /// IV combined with `local_counter` (via `crypto::iv_to_nonce`) to build outgoing nonces.
    local_iv: [u8; crypto::IV_LEN],
    /// Counter passed to `crypto::iv_to_nonce` for outgoing frames.
    local_counter: usize,
}
361
#[cfg(feature = "split")]
#[async_trait]
impl PsecWriter for SessionWriteHalf {
    async fn encrypt_and_send(&mut self, plain_text: &[u8], use_padding: bool) -> Result<(), PsecError> {
        let stream = &mut self.write_half;
        encrypt_and_send(stream, &self.local_cipher, &self.local_iv, &mut self.local_counter, plain_text, use_padding).await
    }
    fn encrypt(&mut self, plain_text: &[u8], use_padding: bool) -> Vec<u8> {
        let counter = &mut self.local_counter;
        encrypt(&self.local_cipher, &self.local_iv, counter, plain_text, use_padding)
    }
    async fn send(&mut self, cipher_text: &[u8]) -> Result<(), PsecError> {
        let stream = &mut self.write_half;
        send(stream, cipher_text).await
    }
}
375
#[cfg(feature = "split")]
impl Debug for SessionWriteHalf {
    /// Prints the stream, counter and IV (hex). The cipher itself is never printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let local_iv_hex = hex_encode(&self.local_iv);
        let mut out = f.debug_struct("SessionWriteHalf");
        out.field("write_half", &self.write_half);
        out.field("local_counter", &self.local_counter);
        out.field("local_iv", &local_iv_hex);
        out.finish()
    }
}
386
/// An encrypted, authenticated PSEC session running over a [`TcpStream`].
///
/// Build one with `Session::from(stream)`, then call [`Session::do_handshake`]. The
/// [`PsecReader`]/[`PsecWriter`] methods that encrypt or decrypt unwrap the application ciphers,
/// so they panic if used before a handshake has succeeded.
pub struct Session {
    /// Underlying TCP connection.
    stream: TcpStream,
    /// Cipher for outgoing messages; `None` until the handshake installs it.
    local_cipher: Option<Aes128Gcm>,
    /// IV combined with `local_counter` (via `crypto::iv_to_nonce`) to build outgoing nonces.
    local_iv: Option<[u8; crypto::IV_LEN]>,
    /// Counter passed to `crypto::iv_to_nonce` for outgoing messages.
    local_counter: usize,
    /// Cipher for incoming messages; `None` until the handshake installs it.
    peer_cipher: Option<Aes128Gcm>,
    /// IV combined with `peer_counter` (via `crypto::iv_to_nonce`) to build incoming nonces.
    peer_iv: Option<[u8; crypto::IV_LEN]>,
    /// Counter passed to `crypto::iv_to_nonce` for incoming messages.
    peer_counter: usize,
    /// Largest accepted frame length prefix, if a limit was set with `set_max_recv_size`.
    max_recv_size: Option<usize>,
    /// The peer's long-term Ed25519 public key, filled in by [`Session::do_handshake`].
    /// Only rely on it after `do_handshake` has returned `Ok`.
    pub peer_public_key: Option<[u8; PUBLIC_KEY_LENGTH]>,
}
418
419impl Session {
420 #[cfg(feature = "split")]
444 pub fn into_split(self) -> Option<(SessionReadHalf, SessionWriteHalf)> {
445 let (read_half, write_half) = self.stream.into_split();
446 Some((
447 SessionReadHalf {
448 read_half,
449 peer_cipher: self.peer_cipher?,
450 peer_iv: self.peer_iv?,
451 peer_counter: self.peer_counter,
452 max_recv_size: self.max_recv_size,
453 },
454 SessionWriteHalf {
455 write_half,
456 local_cipher: self.local_cipher?,
457 local_iv: self.local_iv?,
458 local_counter: self.local_counter,
459 }
460 ))
461 }
462
463 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
482 self.stream.peer_addr()
483 }
484
485 async fn receive(&mut self, buff: &mut [u8]) -> Result<usize, PsecError> {
486 receive(&mut self.stream, buff).await
487 }
488
489 async fn send(&mut self, buff: &[u8]) -> Result<(), PsecError> {
490 send(&mut self.stream, buff).await
491 }
492
493 async fn handshake_read(&mut self, buff: &mut [u8], handshake_recv_buff: &mut Vec<u8>) -> Result<(), PsecError> {
494 self.receive(buff).await?;
495 handshake_recv_buff.extend(buff.as_ref());
496 Ok(())
497 }
498
499 async fn handshake_write(&mut self, buff: &[u8], handshake_sent_buff: &mut Vec<u8>) -> Result<(), PsecError> {
500 self.send(buff).await?;
501 handshake_sent_buff.extend(buff);
502 Ok(())
503 }
504
505 fn hash_handshake(i_am_bob: bool, handshake_sent_buff: &[u8], handshake_recv_buff: &[u8]) -> [u8; 48] {
506 let handshake_bytes = if i_am_bob {
507 [handshake_sent_buff, handshake_recv_buff].concat()
508 } else {
509 [handshake_recv_buff, handshake_sent_buff].concat()
510 };
511 let mut hasher = Sha384::new();
512 hasher.update(handshake_bytes);
513 let handshake_hash = hasher.finalize();
514 handshake_hash.as_slice().try_into().unwrap()
515 }
516
517 fn init_ciphers(&mut self, application_keys: ApplicationKeys){
518 self.local_cipher = Some(Aes128Gcm::new_from_slice(&application_keys.local_key).unwrap());
519 self.local_iv = Some(application_keys.local_iv);
520 self.peer_cipher = Some(Aes128Gcm::new_from_slice(&application_keys.peer_key).unwrap());
521 self.peer_iv = Some(application_keys.peer_iv);
522 }
523
524 pub async fn do_handshake(&mut self, identity: &Identity) -> Result<(), PsecError> {
528 let mut handshake_sent_buff = Vec::new();
529 let mut handshake_recv_buff = Vec::new();
530
531 let mut handshake_buffer = [0; RANDOM_LEN+PUBLIC_KEY_LENGTH];
534 OsRng.fill_bytes(&mut handshake_buffer[..RANDOM_LEN]);
535 let ephemeral_secret = x25519_dalek::EphemeralSecret::new(OsRng);
537 let ephemeral_public_key = x25519_dalek::PublicKey::from(&ephemeral_secret);
538 handshake_buffer[RANDOM_LEN..].copy_from_slice(&ephemeral_public_key.to_bytes());
539 self.handshake_write(&handshake_buffer, &mut handshake_sent_buff).await?;
541 self.handshake_read(&mut handshake_buffer, &mut handshake_recv_buff).await?;
542 let peer_ephemeral_public_key = slice_to_public_key(&handshake_buffer[RANDOM_LEN..]);
543 let i_am_bob = handshake_sent_buff < handshake_recv_buff; let handshake_hash = Session::hash_handshake(i_am_bob, &handshake_sent_buff, &handshake_recv_buff);
546 let shared_secret = ephemeral_secret.diffie_hellman(&peer_ephemeral_public_key);
547 let handshake_keys = HandshakeKeys::derive_keys(shared_secret.to_bytes(), handshake_hash, i_am_bob);
548
549 let mut auth_msg = [0; RANDOM_LEN+PUBLIC_KEY_LENGTH+SIGNATURE_LENGTH];
552 OsRng.fill_bytes(&mut auth_msg[..RANDOM_LEN]);
553 auth_msg[RANDOM_LEN..RANDOM_LEN+PUBLIC_KEY_LENGTH].copy_from_slice(&identity.public.to_bytes());
554 auth_msg[RANDOM_LEN+PUBLIC_KEY_LENGTH..].copy_from_slice(&identity.sign(ephemeral_public_key.as_bytes()).to_bytes());
555 let local_cipher = Aes128Gcm::new_from_slice(&handshake_keys.local_key).unwrap();
557 let encrypted_auth_msg = local_cipher.encrypt(Nonce::from_slice(&handshake_keys.local_iv), auth_msg.as_ref()).unwrap();
558 self.handshake_write(&encrypted_auth_msg, &mut handshake_sent_buff).await?;
559
560 let mut encrypted_peer_auth_msg = [0; RANDOM_LEN+PUBLIC_KEY_LENGTH+SIGNATURE_LENGTH+AES_TAG_LEN];
561 self.handshake_read(&mut encrypted_peer_auth_msg, &mut handshake_recv_buff).await?;
562 let peer_cipher = Aes128Gcm::new_from_slice(&handshake_keys.peer_key).unwrap();
564 match peer_cipher.decrypt(Nonce::from_slice(&handshake_keys.peer_iv), encrypted_peer_auth_msg.as_ref()) {
565 Ok(peer_auth_msg) => {
566 self.peer_public_key = Some(peer_auth_msg[RANDOM_LEN..RANDOM_LEN+PUBLIC_KEY_LENGTH].try_into().unwrap());
568 let peer_public_key = ed25519_dalek::PublicKey::from_bytes(&self.peer_public_key.unwrap()).unwrap();
569 let peer_signature = Signature::from_bytes(&peer_auth_msg[RANDOM_LEN+PUBLIC_KEY_LENGTH..]).unwrap();
570 if peer_public_key.verify(peer_ephemeral_public_key.as_bytes(), &peer_signature).is_ok() {
571 let handshake_hash = Session::hash_handshake(i_am_bob, &handshake_sent_buff, &handshake_recv_buff);
573 let handshake_finished = crypto::compute_handshake_finished(handshake_keys.local_handshake_traffic_secret, handshake_hash);
574 self.send(&handshake_finished).await?;
575 let mut peer_handshake_finished = [0; crypto::HASH_OUTPUT_LEN];
577 self.receive(&mut peer_handshake_finished).await?;
578 if crypto::verify_handshake_finished(peer_handshake_finished, handshake_keys.peer_handshake_traffic_secret, handshake_hash) {
579 let application_keys = ApplicationKeys::derive_keys(handshake_keys.handshake_secret, handshake_hash, i_am_bob);
581 self.init_ciphers(application_keys);
582 return Ok(());
583 }
584 }
585 }
586 Err(_) => {}
587 }
588 Err(PsecError::TransmissionCorrupted)
589 }
590}
591
592#[async_trait]
593impl PsecWriter for Session {
594 async fn encrypt_and_send(&mut self, plain_text: &[u8], use_padding: bool) -> Result<(), PsecError> {
595 encrypt_and_send(&mut self.stream, self.local_cipher.as_ref().unwrap(), self.local_iv.as_ref().unwrap(), &mut self.local_counter, plain_text, use_padding).await
596 }
597
598 fn encrypt(&mut self, plain_text: &[u8], use_padding: bool) -> Vec<u8> {
599 encrypt(self.local_cipher.as_ref().unwrap(), &self.local_iv.unwrap(), &mut self.local_counter, plain_text, use_padding)
600 }
601
602 async fn send(&mut self, cipher_text: &[u8]) -> Result<(), PsecError> {
603 send(&mut self.stream, cipher_text).await
604 }
605}
606
607#[async_trait]
608impl PsecReader for Session {
609 fn set_max_recv_size(&mut self, size: usize, is_raw_size: bool) {
610 self.max_recv_size = Some(compute_max_recv_size(size, is_raw_size));
611 }
612 async fn receive_and_decrypt(&mut self) -> Result<Vec<u8>, PsecError> {
613 receive_and_decrypt(&mut self.stream, &self.peer_cipher.as_ref().unwrap(), &self.peer_iv.unwrap(), &mut self.peer_counter, self.max_recv_size).await
614 }
615 async fn into_receive_and_decrypt(mut self) -> (Result<Vec<u8>, PsecError>, Self) {
616 (self.receive_and_decrypt().await, self)
617 }
618}
619
620impl From<TcpStream> for Session {
621 fn from(stream: TcpStream) -> Self {
622 Session {
623 stream: stream,
624 local_cipher: None,
625 local_iv: None,
626 local_counter: 0,
627 peer_cipher: None,
628 peer_iv: None,
629 peer_counter: 0,
630 peer_public_key: None,
631 max_recv_size: None,
632 }
633 }
634}
635
636impl Debug for Session {
637 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
638 let handshake_successful = self.peer_cipher.is_some();
639 let mut debug_struct = f.debug_struct("PSEC Session");
640 debug_struct
641 .field("stream", &self.stream)
642 .field("max_recv_size", &self.max_recv_size)
643 .field("handshake_successful", &handshake_successful);
644 if let Some(peer_public_key) = self.peer_public_key {
645 debug_struct.field("peer_public_key", &hex_encode(&peer_public_key));
646 }
647 if handshake_successful {
648 debug_struct.field("local_counter", &self.local_counter)
649 .field("local_iv", &hex_encode(&self.local_iv.unwrap()))
650 .field("peer_counter", &self.peer_counter)
651 .field("peer_iv", &hex_encode(&self.peer_iv.unwrap()));
652 }
653 debug_struct.finish()
654 }
655}
656
/// Encodes `buff` as lower-case hexadecimal, always two digits per byte.
///
/// Zero-padding matters: without it, bytes below 0x10 lose their leading zero and the output
/// becomes ambiguous (`[0x1, 0x23]` and `[0x12, 0x3]` would both print "123").
fn hex_encode(buff: &[u8]) -> String {
    use std::fmt::Write as _;

    let mut s = String::with_capacity(buff.len() * 2);
    for byte in buff {
        // Formatting into a String cannot fail, and writing in place avoids a
        // temporary allocation per byte.
        let _ = write!(s, "{:02x}", byte);
    }
    s
}
664
#[cfg(test)]
mod tests {
    use super::{compute_max_recv_size, pad, unpad, MESSAGE_LEN_LEN};

    /// Sample message shared by the padding checks.
    const GREETING: &[u8] = b"Hello world!";

    #[test]
    fn padding() {
        let padded = pad(GREETING, true);
        assert_eq!(padded.len(), 1000);
        let not_padded = pad(GREETING, false);
        assert_eq!(not_padded.len(), GREETING.len() + MESSAGE_LEN_LEN);

        // Padded and unpadded frames must decode to the same original message.
        let from_padded = unpad(padded).unwrap();
        let from_plain = unpad(not_padded).unwrap();
        assert_eq!(from_padded, from_plain);
        assert_eq!(from_padded, GREETING);

        // 5000 + 4 bytes falls into the 8000-byte bucket (1000 -> 2000 -> 4000 -> 8000).
        let large_msg = vec![b'a'; 5000];
        assert_eq!(pad(&large_msg, true).len(), 8000);
    }

    #[test]
    fn compute_max_size() {
        // (plaintext limit, expected max frame length) around the padding bucket boundaries.
        let cases = [
            (5, 1016),
            (996, 1016),
            (997, 2016),
            (16383996, 16384016),
            (16383997, 32768016),
        ];
        for &(size, expected) in cases.iter() {
            assert_eq!(compute_max_recv_size(size, false), expected, "size = {}", size);
        }
    }
}