1use std::io::{Read, Write, ErrorKind};
6use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use rustls::{ServerConfig, ServerConnection};
10use rustls::pki_types::{CertificateDer, PrivateKeyDer};
11
12use crate::{ProtocolError, Result};
13
14pub struct TlsHandler {
16 conn: ServerConnection,
18 incoming: BytesMut,
20 outgoing: BytesMut,
22 handshake_complete: bool,
24}
25
26impl TlsHandler {
27 pub fn new(config: Arc<ServerConfig>) -> Result<Self> {
29 let conn = ServerConnection::new(config)
30 .map_err(|e| ProtocolError::TlsError(e.to_string()))?;
31
32 Ok(Self {
33 conn,
34 incoming: BytesMut::with_capacity(16384),
35 outgoing: BytesMut::with_capacity(16384),
36 handshake_complete: false,
37 })
38 }
39
40 pub fn process_incoming(&mut self, data: &[u8]) -> Result<()> {
42 self.incoming.extend_from_slice(data);
43 self.process_tls()
44 }
45
46 pub fn process_tls_records(&mut self, records: Vec<Bytes>) -> Result<()> {
48 for record in records {
49 self.incoming.extend_from_slice(&record);
50 }
51 self.process_tls()
52 }
53
54 fn process_tls(&mut self) -> Result<()> {
56 let mut reader = &self.incoming[..];
58
59 match self.conn.read_tls(&mut reader) {
60 Ok(0) => {
61 }
63 Ok(n) => {
64 let _ = self.incoming.split_to(n);
66 }
67 Err(e) if e.kind() == ErrorKind::WouldBlock => {
68 }
70 Err(e) => {
71 return Err(ProtocolError::TlsError(e.to_string()));
72 }
73 }
74
75 match self.conn.process_new_packets() {
77 Ok(_state) => {
78 if !self.handshake_complete && !self.conn.is_handshaking() {
79 self.handshake_complete = true;
80 }
81 }
82 Err(e) => {
83 return Err(ProtocolError::TlsError(e.to_string()));
84 }
85 }
86
87 Ok(())
88 }
89
90 pub fn get_outgoing(&mut self) -> Result<Option<Bytes>> {
92 self.outgoing.clear();
93
94 match self.conn.write_tls(&mut VecWriter(&mut self.outgoing)) {
95 Ok(0) => Ok(None),
96 Ok(_) => Ok(Some(self.outgoing.clone().freeze())),
97 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None),
98 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
99 }
100 }
101
102 pub fn is_handshake_complete(&self) -> bool {
104 self.handshake_complete
105 }
106
107 pub fn is_handshaking(&self) -> bool {
109 self.conn.is_handshaking()
110 }
111
112 pub fn wants_write(&self) -> bool {
114 self.conn.wants_write()
115 }
116
117 pub fn read_plaintext(&mut self, buf: &mut [u8]) -> Result<usize> {
119 let mut reader = self.conn.reader();
120 match reader.read(buf) {
121 Ok(n) => Ok(n),
122 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(0),
123 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
124 }
125 }
126
127 pub fn write_plaintext(&mut self, data: &[u8]) -> Result<usize> {
129 let mut writer = self.conn.writer();
130 match writer.write(data) {
131 Ok(n) => Ok(n),
132 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
133 }
134 }
135
136 pub fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
138 self.conn.peer_certificates().map(|certs| {
139 certs.iter().map(|c| c.clone().into_owned()).collect()
140 })
141 }
142
143 pub fn cipher_suite(&self) -> Option<&'static str> {
145 self.conn.negotiated_cipher_suite().map(|cs| cs.suite().as_str().unwrap_or("unknown"))
146 }
147}
148
149struct VecWriter<'a>(&'a mut BytesMut);
151
152impl<'a> Write for VecWriter<'a> {
153 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
154 self.0.extend_from_slice(buf);
155 Ok(buf.len())
156 }
157
158 fn flush(&mut self) -> std::io::Result<()> {
159 Ok(())
160 }
161}
162
163pub fn create_server_config(
165 cert_chain: Vec<CertificateDer<'static>>,
166 key: PrivateKeyDer<'static>,
167 client_cert_verifier: Option<Arc<dyn rustls::server::danger::ClientCertVerifier>>,
168) -> Result<Arc<ServerConfig>> {
169
170
171
172 let config = if let Some(verifier) = client_cert_verifier {
173 ServerConfig::builder()
174 .with_client_cert_verifier(verifier)
175 .with_single_cert(cert_chain, key)
176 .map_err(|e| ProtocolError::TlsError(e.to_string()))?
177 } else {
178 ServerConfig::builder()
179 .with_no_client_auth()
180 .with_single_cert(cert_chain, key)
181 .map_err(|e| ProtocolError::TlsError(e.to_string()))?
182 };
183
184 Ok(Arc::new(config))
188}
189
190pub fn load_certs_from_pem(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
192 let mut certs = Vec::new();
193 for cert in rustls_pemfile::certs(&mut pem.as_bytes()) {
194 match cert {
195 Ok(c) => certs.push(c),
196 Err(e) => return Err(ProtocolError::TlsError(format!("Failed to parse cert: {}", e))),
197 }
198 }
199 Ok(certs)
200}
201
202pub fn load_key_from_pem(pem: &str) -> Result<PrivateKeyDer<'static>> {
204 for item in rustls_pemfile::read_all(&mut pem.as_bytes()) {
206 match item {
207 Ok(rustls_pemfile::Item::Pkcs8Key(key)) => return Ok(PrivateKeyDer::Pkcs8(key)),
208 Ok(rustls_pemfile::Item::Pkcs1Key(key)) => return Ok(PrivateKeyDer::Pkcs1(key)),
209 Ok(rustls_pemfile::Item::Sec1Key(key)) => return Ok(PrivateKeyDer::Sec1(key)),
210 _ => continue,
211 }
212 }
213 Err(ProtocolError::TlsError("No private key found in PEM".into()))
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    // TODO: add a `TlsHandler::new` test once a test certificate/key fixture exists;
    // building a `ServerConfig` requires real key material.

    #[test]
    fn vec_writer_appends_all_bytes() {
        let mut buf = BytesMut::new();
        {
            let mut writer = VecWriter(&mut buf);
            assert_eq!(writer.write(b"hello ").unwrap(), 6);
            writer.write_all(b"world").unwrap();
            writer.flush().unwrap();
        }
        assert_eq!(&buf[..], &b"hello world"[..]);
    }

    #[test]
    fn empty_pem_yields_no_certificates() {
        assert!(matches!(load_certs_from_pem(""), Ok(ref certs) if certs.is_empty()));
    }

    #[test]
    fn empty_pem_yields_no_private_key() {
        assert!(matches!(load_key_from_pem(""), Err(ProtocolError::TlsError(_))));
    }
}