1use std::io::{Read, Write, ErrorKind};
6use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use rustls::{ServerConfig, ServerConnection};
10use rustls::pki_types::{CertificateDer, PrivateKeyDer};
11
12use crate::{ProtocolError, Result};
13
14pub struct TlsHandler {
16 conn: ServerConnection,
18 incoming: BytesMut,
20 outgoing: BytesMut,
22 handshake_complete: bool,
24}
25
26impl TlsHandler {
27 pub fn new(config: Arc<ServerConfig>) -> Result<Self> {
29 let conn = ServerConnection::new(config)
30 .map_err(|e| ProtocolError::TlsError(e.to_string()))?;
31
32 Ok(Self {
33 conn,
34 incoming: BytesMut::with_capacity(16384),
35 outgoing: BytesMut::with_capacity(16384),
36 handshake_complete: false,
37 })
38 }
39
40 pub fn process_incoming(&mut self, data: &[u8]) -> Result<()> {
42 self.incoming.extend_from_slice(data);
43 self.process_tls()
44 }
45
46 pub fn process_tls_records(&mut self, records: Vec<Bytes>) -> Result<()> {
48 for record in records {
49 self.incoming.extend_from_slice(&record);
50 }
51 self.process_tls()
52 }
53
54 fn process_tls(&mut self) -> Result<()> {
56 let mut reader = &self.incoming[..];
58
59 match self.conn.read_tls(&mut reader) {
60 Ok(0) => {
61 }
63 Ok(n) => {
64 let _ = self.incoming.split_to(n);
66 }
67 Err(e) if e.kind() == ErrorKind::WouldBlock => {
68 }
70 Err(e) => {
71 return Err(ProtocolError::TlsError(e.to_string()));
72 }
73 }
74
75 match self.conn.process_new_packets() {
77 Ok(_state) => {
78 if !self.handshake_complete && !self.conn.is_handshaking() {
79 self.handshake_complete = true;
80 }
81 }
82 Err(e) => {
83 return Err(ProtocolError::TlsError(e.to_string()));
84 }
85 }
86
87 Ok(())
88 }
89
90 pub fn get_outgoing(&mut self) -> Result<Option<Bytes>> {
92 self.outgoing.clear();
93
94 match self.conn.write_tls(&mut VecWriter(&mut self.outgoing)) {
95 Ok(0) => Ok(None),
96 Ok(_) => Ok(Some(self.outgoing.clone().freeze())),
97 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None),
98 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
99 }
100 }
101
102 pub fn is_handshake_complete(&self) -> bool {
104 self.handshake_complete
105 }
106
107 pub fn is_handshaking(&self) -> bool {
109 self.conn.is_handshaking()
110 }
111
112 pub fn wants_write(&self) -> bool {
114 self.conn.wants_write()
115 }
116
117 pub fn read_plaintext(&mut self, buf: &mut [u8]) -> Result<usize> {
119 let mut reader = self.conn.reader();
120 match reader.read(buf) {
121 Ok(n) => Ok(n),
122 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(0),
123 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
124 }
125 }
126
127 pub fn write_plaintext(&mut self, data: &[u8]) -> Result<usize> {
129 let mut writer = self.conn.writer();
130 match writer.write(data) {
131 Ok(n) => Ok(n),
132 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
133 }
134 }
135
136 pub fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
138 self.conn.peer_certificates().map(|certs| {
139 certs.iter().map(|c| c.clone().into_owned()).collect()
140 })
141 }
142
143 pub fn cipher_suite(&self) -> Option<&'static str> {
145 self.conn.negotiated_cipher_suite().map(|cs| cs.suite().as_str().unwrap_or("unknown"))
146 }
147}
148
149struct VecWriter<'a>(&'a mut BytesMut);
151
152impl<'a> Write for VecWriter<'a> {
153 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
154 self.0.extend_from_slice(buf);
155 Ok(buf.len())
156 }
157
158 fn flush(&mut self) -> std::io::Result<()> {
159 Ok(())
160 }
161}
162
163pub fn create_server_config(
165 cert_chain: Vec<CertificateDer<'static>>,
166 key: PrivateKeyDer<'static>,
167 client_cert_verifier: Option<Arc<dyn rustls::server::danger::ClientCertVerifier>>,
168) -> Result<Arc<ServerConfig>> {
169 let config = if let Some(verifier) = client_cert_verifier {
177 ServerConfig::builder()
178 .with_client_cert_verifier(verifier)
179 .with_single_cert(cert_chain, key)
180 .map_err(|e: rustls::Error| ProtocolError::TlsError(e.to_string()))?
181 } else {
182 ServerConfig::builder()
183 .with_no_client_auth()
184 .with_single_cert(cert_chain, key)
185 .map_err(|e: rustls::Error| ProtocolError::TlsError(e.to_string()))?
186 };
187
188 Ok(Arc::new(config))
189}
190
191pub fn load_certs_from_pem(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
193 let mut certs = Vec::new();
194 for cert in rustls_pemfile::certs(&mut pem.as_bytes()) {
195 match cert {
196 Ok(c) => certs.push(c),
197 Err(e) => return Err(ProtocolError::TlsError(format!("Failed to parse cert: {}", e))),
198 }
199 }
200 Ok(certs)
201}
202
203pub fn load_key_from_pem(pem: &str) -> Result<PrivateKeyDer<'static>> {
205 for item in rustls_pemfile::read_all(&mut pem.as_bytes()) {
207 match item {
208 Ok(rustls_pemfile::Item::Pkcs8Key(key)) => return Ok(PrivateKeyDer::Pkcs8(key)),
209 Ok(rustls_pemfile::Item::Pkcs1Key(key)) => return Ok(PrivateKeyDer::Pkcs1(key)),
210 Ok(rustls_pemfile::Item::Sec1Key(key)) => return Ok(PrivateKeyDer::Sec1(key)),
211 _ => continue,
212 }
213 }
214 Err(ProtocolError::TlsError("No private key found in PEM".into()))
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal PEM fixtures. The payload "AAAA" base64-decodes to three zero
    // bytes; these tests assume the PEM parser does not validate DER contents.
    const CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nAAAA\n-----END CERTIFICATE-----\n";
    const PKCS8_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----\nAAAA\n-----END PRIVATE KEY-----\n";

    #[test]
    fn load_certs_from_empty_pem_is_empty() {
        assert!(matches!(load_certs_from_pem(""), Ok(ref certs) if certs.is_empty()));
    }

    #[test]
    fn load_certs_decodes_each_certificate() {
        let certs = load_certs_from_pem(CERT_PEM)
            .ok()
            .expect("certificate PEM should parse");
        assert_eq!(certs.len(), 1);
        assert_eq!(&*certs[0], &[0u8, 0, 0][..]);
    }

    #[test]
    fn load_key_reads_pkcs8() {
        assert!(matches!(
            load_key_from_pem(PKCS8_KEY_PEM),
            Ok(PrivateKeyDer::Pkcs8(_))
        ));
    }

    #[test]
    fn load_key_without_key_is_error() {
        assert!(load_key_from_pem("").is_err());
        assert!(load_key_from_pem(CERT_PEM).is_err());
    }
}