use std::io::{self, Read, Write};
use super::codec::TlsCodec;
/// A transport `S` paired with a [`TlsCodec`], exposing `Read`/`Write` over the
/// decrypted application data.
///
/// Built by `TlsStream::connect`, which runs the handshake before returning.
pub struct TlsStream<S> {
    /// Underlying transport that carries the TLS records.
    stream: S,
    /// Codec that encrypts plaintext written through this stream and yields
    /// decrypted plaintext for reads.
    codec: TlsCodec,
}
impl<S> TlsStream<S> {
    /// Returns a shared reference to the wrapped transport.
    pub fn stream(&self) -> &S {
        &self.stream
    }

    /// Returns a mutable reference to the wrapped transport.
    ///
    /// Bytes read or written through this reference go straight to the transport
    /// and are not passed through the codec.
    pub fn stream_mut(&mut self) -> &mut S {
        &mut self.stream
    }

    /// Returns a shared reference to the TLS codec.
    pub fn codec(&self) -> &TlsCodec {
        &self.codec
    }

    /// Returns a mutable reference to the TLS codec.
    pub fn codec_mut(&mut self) -> &mut TlsCodec {
        &mut self.codec
    }

    /// Consumes the wrapper and hands back the transport and the codec.
    pub fn into_parts(self) -> (S, TlsCodec) {
        let Self { stream, codec } = self;
        (stream, codec)
    }

    /// Forwards `limit` to the codec's `set_buffer_limit`.
    ///
    /// NOTE(review): the meaning of `None` (presumably "no limit") is defined by
    /// `TlsCodec` — confirm there.
    pub fn set_buffer_limit(&mut self, limit: Option<usize>) {
        self.codec_mut().set_buffer_limit(limit);
    }
}
impl<S: Read + Write> TlsStream<S> {
pub fn connect(stream: S, codec: TlsCodec) -> Result<Self, super::TlsError> {
let mut s = Self { stream, codec };
s.handshake()?;
Ok(s)
}
fn handshake(&mut self) -> Result<(), super::TlsError> {
while self.codec.is_handshaking() {
while self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
if self.codec.wants_read() {
let n = self.codec.read_tls_from(&mut self.stream)?;
if n == 0 {
return Err(super::TlsError::Io(io::Error::new(
io::ErrorKind::UnexpectedEof,
"connection closed during TLS handshake",
)));
}
}
}
while self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
Ok(())
}
}
impl<S: Read + Write> Read for TlsStream<S> {
    /// Reads decrypted application data into `buf`.
    ///
    /// Plaintext already buffered in the codec is returned first. Otherwise TLS
    /// records are pulled from the transport until the codec yields plaintext.
    /// Returns `Ok(0)` for an empty `buf` without touching the transport, and
    /// `Ok(0)` when the transport reports end-of-file.
    ///
    /// # Errors
    /// Codec errors are converted with `tls_to_io`. Wrapped I/O errors are unwrapped;
    /// everything else becomes an `ErrorKind::Other` error.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // An empty slice can never receive plaintext. Without this guard the loop
        // below would keep pulling (and buffering) TLS records from the transport,
        // blocking on a blocking stream, until the peer closed the connection.
        if buf.is_empty() {
            return Ok(0);
        }
        loop {
            let n = self.codec.read_plaintext(buf).map_err(tls_to_io)?;
            if n > 0 {
                return Ok(n);
            }
            let tls_n = self
                .codec
                .read_tls_from(&mut self.stream)
                .map_err(tls_to_io)?;
            if tls_n == 0 {
                // NOTE(review): transport EOF is reported as a clean end of stream;
                // whether a missing close_notify (truncation) is detected depends on
                // TlsCodec — confirm.
                return Ok(0);
            }
        }
    }
}
impl<S: Read + Write> Write for TlsStream<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut written = 0;
while written < buf.len() {
let n = self.codec.encrypt(&buf[written..]).map_err(tls_to_io)?;
if n == 0 {
while self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
let n2 = self.codec.encrypt(&buf[written..]).map_err(tls_to_io)?;
if n2 == 0 {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"rustls plaintext queue limit is smaller than \
the remaining input — the buffer-limit may \
have been set too low (raise via \
TlsCodec::set_buffer_limit or \
TlsBufferCapacities::rustls_plaintext_limit), \
or chunk the write into smaller pieces",
));
}
written += n2;
} else {
written += n;
}
}
while self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
while self.codec.wants_write() {
self.codec.write_tls_to(&mut self.stream)?;
}
self.stream.flush()
}
}
/// Converts a codec error into an `io::Error` for the `Read`/`Write` impls.
///
/// A wrapped I/O error is unwrapped and returned unchanged, preserving its kind.
/// Any other variant is boxed into an `ErrorKind::Other` error via `io::Error::other`.
fn tls_to_io(e: super::TlsError) -> io::Error {
    use super::TlsError;

    match e {
        TlsError::Io(inner) => inner,
        non_io => io::Error::other(non_io),
    }
}