1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
use rsa::{PaddingScheme, RsaPrivateKey, RsaPublicKey};
use super::{
crypto::{CryptoError, CryptoStore, EncryptType, KeyEncryptType},
layer::SecureLayer,
stream::SecureStream,
};
use crate::secure::{SecureHandshakeHeader, SECURE_HANDSHAKE_HEADER_SIZE};
use std::io::{self, Read, Write};
/// Errors that can occur while performing a secure handshake.
#[derive(Debug)]
pub enum SecureHandshakeError {
    /// Serializing or deserializing the handshake header with bincode failed.
    Bincode(bincode::Error),
    /// Reading from or writing to the underlying stream failed.
    Io(io::Error),
    /// Encrypting, decrypting or validating the session key failed.
    Crypto(CryptoError),
}

impl std::fmt::Display for SecureHandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Bincode(err) => write!(f, "handshake header encoding error: {}", err),
            Self::Io(err) => write!(f, "handshake I/O error: {}", err),
            // `CryptoError` is only known to implement `Debug` (required by the derive above).
            Self::Crypto(err) => write!(f, "handshake crypto error: {:?}", err),
        }
    }
}

impl std::error::Error for SecureHandshakeError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Bincode(err) => Some(err),
            Self::Io(err) => Some(err),
            // NOTE(review): expose the inner error here if `CryptoError` implements `Error`.
            Self::Crypto(_) => None,
        }
    }
}
impl From<bincode::Error> for SecureHandshakeError {
fn from(err: bincode::Error) -> Self {
Self::Bincode(err)
}
}
impl From<io::Error> for SecureHandshakeError {
fn from(err: io::Error) -> Self {
Self::Io(err)
}
}
impl From<CryptoError> for SecureHandshakeError {
fn from(err: CryptoError) -> Self {
Self::Crypto(err)
}
}
/// A session over stream `S` that has not yet completed its handshake.
///
/// Both methods consume the session. On success the stream comes back wrapped in a
/// secure transport type.
pub trait SecureSession<S>: Sized {
    /// Performs the handshake and returns the resulting [`SecureLayer`].
    fn handshake(self) -> Result<SecureLayer<S>, SecureHandshakeError>;

    /// Performs the handshake via [`SecureSession::handshake`] and converts the
    /// resulting layer into a [`SecureStream`].
    fn handshake_stream(self) -> Result<SecureStream<S>, SecureHandshakeError> {
        self.handshake().map(Into::into)
    }
}
/// Client side of the secure handshake.
///
/// Holds the stream, the [`CryptoStore`] whose key is encrypted and sent to the server
/// during the handshake, and the server's RSA public key used for that encryption.
#[derive(Debug)]
pub struct SecureClientSession<S> {
    stream: S,
    crypto: CryptoStore,
    key: RsaPublicKey,
}

impl<S> SecureClientSession<S> {
    /// Creates a client session. No I/O is performed until the handshake runs.
    pub fn new(key: RsaPublicKey, crypto: CryptoStore, stream: S) -> Self {
        SecureClientSession {
            key,
            crypto,
            stream,
        }
    }
}
impl<S: Write> SecureSession<S> for SecureClientSession<S> {
    /// Sends the client half of the handshake and wraps the stream in a [`SecureLayer`].
    ///
    /// The session key held by `crypto` is encrypted with the server's RSA public key.
    /// The client then writes a bincode-serialized [`SecureHandshakeHeader`], followed by
    /// the encrypted key bytes.
    ///
    /// # Errors
    /// - [`SecureHandshakeError::Crypto`] if encrypting the session key fails.
    /// - [`SecureHandshakeError::Bincode`] if the header cannot be serialized.
    /// - [`SecureHandshakeError::Io`] if writing either part to the stream fails.
    fn handshake(mut self) -> Result<SecureLayer<S>, SecureHandshakeError> {
        let encrypted_key = self.crypto.encrypt_key(&self.key)?;
        let handshake_header = SecureHandshakeHeader {
            // NOTE(review): assumes the ciphertext is RSA-modulus sized, so it fits in u32.
            encrypted_key_len: encrypted_key.len() as u32,
            key_encrypt_type: KeyEncryptType::RsaOaepSha1Mgf1Sha1 as u32,
            encrypt_type: EncryptType::AesCfb128 as u32,
        };
        let header_data = bincode::serialize(&handshake_header)?;

        // Write strictly in sequence. `Result::and` evaluates its argument eagerly, which
        // would still send the key after a failed header write and desync the peer.
        self.stream.write_all(&header_data)?;
        self.stream.write_all(&encrypted_key)?;

        Ok(SecureLayer::new(self.crypto, self.stream))
    }
}
/// Server side of the secure handshake.
///
/// Holds the stream and the RSA private key used to decrypt the session key that the
/// client sends.
#[derive(Debug)]
pub struct SecureServerSession<S> {
    stream: S,
    key: RsaPrivateKey,
    /// A header that has already been read, if any. When present, the handshake takes
    /// it instead of reading a new header from the stream. `new` starts it as `None`.
    current_header: Option<SecureHandshakeHeader>,
}

impl<S> SecureServerSession<S> {
    /// Creates a server session with no header read yet. No I/O is performed here.
    pub fn new(key: RsaPrivateKey, stream: S) -> Self {
        let current_header = None;
        SecureServerSession {
            key,
            stream,
            current_header,
        }
    }
}
impl<S: Read> SecureSession<S> for SecureServerSession<S> {
    /// Receives the client half of the handshake and wraps the stream in a [`SecureLayer`].
    ///
    /// The server reads (or reuses) the [`SecureHandshakeHeader`], then reads
    /// `encrypted_key_len` bytes of RSA-OAEP (SHA-1) ciphertext. It decrypts them with the
    /// private key and builds a [`CryptoStore`] from the recovered 16-byte session key.
    ///
    /// # Errors
    /// - [`SecureHandshakeError::Io`] if reading the header or key from the stream fails.
    /// - [`SecureHandshakeError::Bincode`] if the header cannot be deserialized.
    /// - [`SecureHandshakeError::Crypto`] with `CorruptedData` in three cases: the
    ///   announced key length is implausibly large, RSA decryption fails, or the
    ///   decrypted key is not exactly 16 bytes.
    fn handshake(mut self) -> Result<SecureLayer<S>, SecureHandshakeError> {
        // Sanity cap on the peer-announced ciphertext length. A 1024-byte ciphertext
        // matches an 8192-bit RSA modulus. Rejecting larger values prevents a peer from
        // forcing an allocation of up to 4 GiB.
        const MAX_ENCRYPTED_KEY_LEN: usize = 1024;
        // Size of the key array handed to `CryptoStore::new_with_key`.
        const SESSION_KEY_LEN: usize = 16;

        let handshake_header = match self.current_header.take() {
            Some(header) => header,
            None => {
                let mut handshake_header_buf = [0_u8; SECURE_HANDSHAKE_HEADER_SIZE];
                self.stream.read_exact(&mut handshake_header_buf)?;
                bincode::deserialize::<SecureHandshakeHeader>(&handshake_header_buf)?
            }
        };
        // NOTE(review): `key_encrypt_type` / `encrypt_type` from the header are not checked;
        // this side always assumes RSA-OAEP-SHA1 for the key.

        let encrypted_key_len = handshake_header.encrypted_key_len as usize;
        if encrypted_key_len > MAX_ENCRYPTED_KEY_LEN {
            return Err(SecureHandshakeError::Crypto(CryptoError::CorruptedData));
        }
        let mut encrypted_key = vec![0_u8; encrypted_key_len];
        // `self` is consumed by this method, so there is nowhere to keep the header for a
        // retry. A failed read simply propagates.
        self.stream.read_exact(&mut encrypted_key)?;

        let decrypted = self
            .key
            .decrypt(PaddingScheme::new_oaep::<sha1::Sha1>(), &encrypted_key)
            .map_err(|_| CryptoError::CorruptedData)?;
        if decrypted.len() != SESSION_KEY_LEN {
            return Err(SecureHandshakeError::Crypto(CryptoError::CorruptedData));
        }
        // Use the key the client actually sent. Previously the decrypted bytes were
        // discarded and an all-zero key was used.
        let mut key = [0_u8; SESSION_KEY_LEN];
        key.copy_from_slice(&decrypted);

        let crypto = CryptoStore::new_with_key(key);
        Ok(SecureLayer::new(crypto, self.stream))
    }
}