use std::io::{self, Read, Write};
use rustls::ClientConnection;
use rustls::pki_types::ServerName;
use super::{TlsConfig, TlsError};
use crate::ws::FrameReader;
/// Client-side TLS state machine wrapping a rustls [`ClientConnection`].
///
/// The codec owns no transport. Callers feed received ciphertext in with
/// `read_tls` / `read_tls_from`, and drain outgoing ciphertext with
/// `write_tls_to`. Plaintext enters through `encrypt` and leaves through
/// `read_plaintext` or `process_into`.
pub struct TlsCodec {
/// The wrapped rustls client connection that holds all TLS state.
inner: ClientConnection,
}
impl TlsCodec {
    /// Creates a client-side TLS codec that will authenticate `hostname`
    /// using the shared `config`.
    ///
    /// # Errors
    ///
    /// Returns [`TlsError::InvalidHostname`] if `hostname` cannot be parsed as
    /// a server name. Returns the converted rustls error if the connection
    /// cannot be created.
    pub fn new(config: &TlsConfig, hostname: &str) -> Result<Self, TlsError> {
        let server_name = ServerName::try_from(hostname.to_owned())
            .map_err(|_| TlsError::InvalidHostname(hostname.to_owned()))?;
        let conn = ClientConnection::new(config.inner.clone(), server_name)?;
        Ok(Self { inner: conn })
    }

    /// Feeds received ciphertext from `src` into the TLS engine.
    ///
    /// Returns the number of bytes consumed. This may be less than
    /// `src.len()`; the caller must offer the unconsumed tail again later.
    /// Call [`Self::process_new_packets`] or [`Self::process_into`] after each
    /// call that consumed data.
    ///
    /// An empty `src` is a no-op that returns `Ok(0)`. It is deliberately
    /// *not* forwarded to rustls, for two reasons:
    /// - rustls interprets a zero-byte read as transport EOF.
    /// - After that, plaintext reads report `UnexpectedEof` instead of
    ///   `WouldBlock`, even though the connection is healthy.
    ///
    /// To signal a real end of stream, use
    /// `read_tls_from(&mut io::empty())` instead.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors reported by rustls, for example when its
    /// internal buffers are full.
    pub fn read_tls(&mut self, src: &[u8]) -> Result<usize, TlsError> {
        if src.is_empty() {
            // Guard: an empty cursor would read Ok(0), which rustls records as EOF.
            return Ok(0);
        }
        let mut cursor = io::Cursor::new(src);
        Ok(self.inner.read_tls(&mut cursor)?)
    }

    /// Feeds ciphertext from an arbitrary reader into the TLS engine.
    ///
    /// Unlike [`Self::read_tls`], this hands `src` straight to rustls. A
    /// reader that yields `Ok(0)` (such as `io::empty()`) is therefore
    /// recorded as transport EOF.
    pub fn read_tls_from<R: Read>(&mut self, src: &mut R) -> io::Result<usize> {
        self.inner.read_tls(src)
    }

    /// Decrypts and processes any TLS records buffered by `read_tls*`.
    ///
    /// # Errors
    ///
    /// Returns the converted rustls error if processing fails. rustls may
    /// still have an alert queued after an error; flush it with
    /// [`Self::write_tls_to`] so the peer learns why the connection failed.
    pub fn process_new_packets(&mut self) -> Result<(), TlsError> {
        self.inner.process_new_packets()?;
        Ok(())
    }

    /// Processes pending TLS records, then hands the next buffered plaintext
    /// chunk to `reader`.
    ///
    /// Returns the number of plaintext bytes passed to `reader`. `Ok(0)` is
    /// returned in two cases, which this method does not distinguish:
    /// - rustls reports `WouldBlock` (no plaintext yet).
    /// - rustls yields an empty chunk (e.g. after the peer closed cleanly).
    ///
    /// At most one `fill_buf` chunk is moved per call. To drain all pending
    /// plaintext, call repeatedly until `Ok(0)`.
    ///
    /// # Errors
    ///
    /// Returns one of:
    /// - a rustls error from record processing;
    /// - any non-`WouldBlock` I/O error from the plaintext reader;
    /// - an `Io` error if `reader` rejects the chunk.
    ///
    /// A rejected chunk is not consumed, so it stays buffered in rustls and
    /// is offered again on the next call.
    pub fn process_into(&mut self, reader: &mut FrameReader) -> Result<usize, TlsError> {
        self.inner.process_new_packets()?;
        let mut rd = self.inner.reader();
        let chunk = match io::BufRead::fill_buf(&mut rd) {
            Ok(chunk) => chunk,
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(0),
            Err(e) => return Err(TlsError::Io(e)),
        };
        if chunk.is_empty() {
            return Ok(0);
        }
        let n = chunk.len();
        // NOTE(review): assumes `FrameReader::read` either accepts all of `chunk`
        // or fails. Any success value is ignored and the whole chunk is consumed
        // below; confirm against `crate::ws::FrameReader`.
        if let Err(e) = reader.read(chunk) {
            return Err(TlsError::Io(io::Error::other(format!(
                "FrameReader buffer full: {e}"
            ))));
        }
        io::BufRead::consume(&mut rd, n);
        Ok(n)
    }

    /// Copies decrypted plaintext into `dst`.
    ///
    /// Returns the number of bytes written. `WouldBlock` from rustls (no
    /// plaintext available yet) is mapped to `Ok(0)`.
    ///
    /// # Errors
    ///
    /// Any other I/O error from the rustls reader is returned as
    /// [`TlsError::Io`].
    pub fn read_plaintext(&mut self, dst: &mut [u8]) -> Result<usize, TlsError> {
        match self.inner.reader().read(dst) {
            Ok(n) => Ok(n),
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(0),
            Err(e) => Err(TlsError::Io(e)),
        }
    }

    /// Queues all of `plaintext` for encryption.
    ///
    /// The resulting ciphertext is emitted by [`Self::write_tls_to`].
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the rustls writer.
    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<(), TlsError> {
        self.inner.writer().write_all(plaintext)?;
        Ok(())
    }

    /// Writes pending outgoing ciphertext to `dst` and returns the byte count.
    pub fn write_tls_to<W: Write>(&mut self, dst: &mut W) -> io::Result<usize> {
        self.inner.write_tls(dst)
    }

    /// Returns `true` while the TLS handshake is still in progress.
    pub fn is_handshaking(&self) -> bool {
        self.inner.is_handshaking()
    }

    /// Returns `true` if rustls wants more ciphertext via `read_tls*`.
    pub fn wants_read(&self) -> bool {
        self.inner.wants_read()
    }

    /// Returns `true` if rustls has ciphertext waiting for [`Self::write_tls_to`].
    pub fn wants_write(&self) -> bool {
        self.inner.wants_write()
    }
}
impl std::fmt::Debug for TlsCodec {
    /// Renders the codec as `TlsCodec { handshaking: <bool> }`.
    ///
    /// Only the handshake state is shown; the wrapped connection itself is
    /// not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let handshaking = self.inner.is_handshaking();
        let mut builder = f.debug_struct("TlsCodec");
        builder.field("handshaking", &handshaking);
        builder.finish()
    }
}