1use std::io::{Read, Write, ErrorKind};
6use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use rustls::{ServerConfig, ServerConnection};
10use rustls::pki_types::{CertificateDer, PrivateKeyDer};
11
12use crate::{ProtocolError, Result};
13
14pub struct TlsHandler {
16 conn: ServerConnection,
18 incoming: BytesMut,
20 outgoing: BytesMut,
22 handshake_complete: bool,
24}
25
26impl TlsHandler {
27 pub fn new(config: Arc<ServerConfig>) -> Result<Self> {
29 let conn = ServerConnection::new(config)
30 .map_err(|e| ProtocolError::TlsError(e.to_string()))?;
31
32 Ok(Self {
33 conn,
34 incoming: BytesMut::with_capacity(16384),
35 outgoing: BytesMut::with_capacity(16384),
36 handshake_complete: false,
37 })
38 }
39
40 pub fn process_incoming(&mut self, data: &[u8]) -> Result<()> {
42 self.incoming.extend_from_slice(data);
43 self.process_tls()
44 }
45
46 pub fn process_tls_records(&mut self, records: Vec<Bytes>) -> Result<()> {
48 for record in records {
49 self.incoming.extend_from_slice(&record);
50 }
51 self.process_tls()
52 }
53
54 fn process_tls(&mut self) -> Result<()> {
56 let mut reader = &self.incoming[..];
58
59 match self.conn.read_tls(&mut reader) {
60 Ok(0) => {
61 }
63 Ok(n) => {
64 let _ = self.incoming.split_to(n);
66 }
67 Err(e) if e.kind() == ErrorKind::WouldBlock => {
68 }
70 Err(e) => {
71 return Err(ProtocolError::TlsError(e.to_string()));
72 }
73 }
74
75 match self.conn.process_new_packets() {
77 Ok(_state) => {
78 if !self.handshake_complete && !self.conn.is_handshaking() {
79 self.handshake_complete = true;
80 }
81 }
82 Err(e) => {
83 return Err(ProtocolError::TlsError(e.to_string()));
84 }
85 }
86
87 Ok(())
88 }
89
90 pub fn get_outgoing(&mut self) -> Result<Option<Bytes>> {
92 self.outgoing.clear();
93
94 match self.conn.write_tls(&mut VecWriter(&mut self.outgoing)) {
95 Ok(0) => Ok(None),
96 Ok(_) => Ok(Some(self.outgoing.clone().freeze())),
97 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None),
98 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
99 }
100 }
101
102 pub fn is_handshake_complete(&self) -> bool {
104 self.handshake_complete
105 }
106
107 pub fn is_handshaking(&self) -> bool {
109 self.conn.is_handshaking()
110 }
111
112 pub fn wants_write(&self) -> bool {
114 self.conn.wants_write()
115 }
116
117 pub fn read_plaintext(&mut self, buf: &mut [u8]) -> Result<usize> {
119 let mut reader = self.conn.reader();
120 match reader.read(buf) {
121 Ok(n) => Ok(n),
122 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(0),
123 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
124 }
125 }
126
127 pub fn write_plaintext(&mut self, data: &[u8]) -> Result<usize> {
129 let mut writer = self.conn.writer();
130 match writer.write(data) {
131 Ok(n) => Ok(n),
132 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
133 }
134 }
135
136 pub fn export_keying_material(
140 &self,
141 output: &mut [u8],
142 label: &[u8],
143 context: Option<&[u8]>,
144 ) -> Result<()> {
145 self.conn.export_keying_material(output, label, context)
146 .map_err(|e| ProtocolError::TlsError(format!("EKM export failed: {}", e)))?;
147 Ok(())
148 }
149
150 pub fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
152 self.conn.peer_certificates().map(|certs| {
153 certs.iter().map(|c| c.clone().into_owned()).collect()
154 })
155 }
156
157 pub fn cipher_suite(&self) -> Option<&'static str> {
159 self.conn.negotiated_cipher_suite().map(|cs| cs.suite().as_str().unwrap_or("unknown"))
160 }
161}
162
163struct VecWriter<'a>(&'a mut BytesMut);
165
166impl<'a> Write for VecWriter<'a> {
167 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
168 self.0.extend_from_slice(buf);
169 Ok(buf.len())
170 }
171
172 fn flush(&mut self) -> std::io::Result<()> {
173 Ok(())
174 }
175}
176
177pub fn create_server_config(
179 cert_chain: Vec<CertificateDer<'static>>,
180 key: PrivateKeyDer<'static>,
181 client_cert_verifier: Option<Arc<dyn rustls::server::danger::ClientCertVerifier>>,
182) -> Result<Arc<ServerConfig>> {
183 let config = if let Some(verifier) = client_cert_verifier {
191 ServerConfig::builder()
192 .with_client_cert_verifier(verifier)
193 .with_single_cert(cert_chain, key)
194 .map_err(|e: rustls::Error| ProtocolError::TlsError(e.to_string()))?
195 } else {
196 ServerConfig::builder()
197 .with_no_client_auth()
198 .with_single_cert(cert_chain, key)
199 .map_err(|e: rustls::Error| ProtocolError::TlsError(e.to_string()))?
200 };
201
202 Ok(Arc::new(config))
203}
204
205pub fn load_certs_from_pem(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
207 let mut certs = Vec::new();
208 for cert in rustls_pemfile::certs(&mut pem.as_bytes()) {
209 match cert {
210 Ok(c) => certs.push(c),
211 Err(e) => return Err(ProtocolError::TlsError(format!("Failed to parse cert: {}", e))),
212 }
213 }
214 Ok(certs)
215}
216
217pub fn load_key_from_pem(pem: &str) -> Result<PrivateKeyDer<'static>> {
219 for item in rustls_pemfile::read_all(&mut pem.as_bytes()) {
221 match item {
222 Ok(rustls_pemfile::Item::Pkcs8Key(key)) => return Ok(PrivateKeyDer::Pkcs8(key)),
223 Ok(rustls_pemfile::Item::Pkcs1Key(key)) => return Ok(PrivateKeyDer::Pkcs1(key)),
224 Ok(rustls_pemfile::Item::Sec1Key(key)) => return Ok(PrivateKeyDer::Sec1(key)),
225 _ => continue,
226 }
227 }
228 Err(ProtocolError::TlsError("No private key found in PEM".into()))
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_handler_creation() {
        // TODO: constructing a `TlsHandler` needs a certificate/key fixture;
        // none is available here, so this test currently asserts nothing.
    }

    /// Input without any PEM section contains no key and must be rejected.
    #[test]
    fn load_key_from_pem_rejects_empty_input() {
        assert!(load_key_from_pem("").is_err());
    }

    /// Input without any PEM section yields an empty certificate list.
    #[test]
    fn load_certs_from_pem_accepts_empty_input() {
        assert!(matches!(
            load_certs_from_pem(""),
            Ok(ref certs) if certs.is_empty()
        ));
    }
}