use std::fmt;
use std::io;
use std::io::{ErrorKind, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::str::FromStr;
use std::time::Duration;
use log::*;
use native_tls::TlsConnector;
use crate::raw::compression::{Compression, Decoder};
use crate::raw::error::Result;
use crate::raw::parse::{is_end_of_datablock, parse_data_block_line, parse_first_line};
use crate::raw::response::{DataBlocks, RawResponse};
use crate::raw::stream::NntpStream;
use crate::types::command::NntpCommand;
use crate::types::prelude::*;
/// TLS settings used to wrap an NNTP connection: the domain name handed to the connector and
/// the `native_tls` connector that performs the handshake.
#[derive(Clone)]
pub struct TlsConfig {
    /// Domain name passed to [`TlsConnector::connect`] when the stream is wrapped.
    domain: String,
    /// Connector used to establish the TLS session over the TCP stream.
    connector: TlsConnector,
}
impl TlsConfig {
    /// Create a TLS configuration from an explicit `connector` and the `domain` that will be
    /// passed to it when connecting.
    pub fn new(domain: String, connector: TlsConnector) -> Self {
        Self { connector, domain }
    }

    /// Create a TLS configuration for `domain` using `native_tls`'s default connector.
    ///
    /// # Errors
    ///
    /// Fails if `TlsConnector::new()` cannot build a default connector.
    pub fn default_connector(domain: impl AsRef<str>) -> Result<Self> {
        let connector = TlsConnector::new()?;
        Ok(Self::new(domain.as_ref().to_string(), connector))
    }

    /// The connector used to perform the TLS handshake.
    pub fn connector(&self) -> &TlsConnector {
        &self.connector
    }

    /// The domain name passed to the connector when the TCP stream is wrapped.
    pub fn domain(&self) -> &str {
        &self.domain
    }
}
impl fmt::Debug for TlsConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsConfig")
.field("domain", &self.domain)
.finish()
}
}
#[derive(Debug)]
pub struct NntpConnection {
stream: BufNntpStream,
first_line_buf: Vec<u8>,
data_blocks_buf: Vec<u8>,
config: ConnectionConfig,
}
impl NntpConnection {
    /// Connect to an NNTP server at `addr` and read its initial greeting.
    ///
    /// The TCP socket gets the configured read and write timeouts. If a [`TlsConfig`] is set,
    /// the stream is wrapped in TLS before anything is read.
    ///
    /// Returns the new connection together with the server's greeting response.
    ///
    /// # Errors
    ///
    /// Fails if the TCP connection cannot be opened, a timeout cannot be applied (for example
    /// a zero `Duration`), the TLS handshake fails, or the greeting cannot be read and parsed.
    pub fn connect(
        addr: impl ToSocketAddrs,
        config: ConnectionConfig,
    ) -> Result<(Self, RawResponse)> {
        // Destructure exhaustively (by reference, to avoid cloning the config) so that a newly
        // added field fails to compile here until it is consciously handled.
        let ConnectionConfig {
            compression: _,
            tls_config,
            read_timeout,
            write_timeout,
            first_line_buf_size,
            data_blocks_buf_size,
        } = &config;

        trace!("Opening TcpStream...");
        let tcp_stream = TcpStream::connect(&addr)?;
        tcp_stream.set_read_timeout(*read_timeout)?;
        // The write timeout is part of the config and must reach the socket as well.
        tcp_stream.set_write_timeout(*write_timeout)?;

        let nntp_stream: NntpStream = match tls_config {
            Some(TlsConfig { connector, domain }) => {
                trace!("Wrapping TcpStream w/ TlsConnector");
                connector.connect(domain, tcp_stream)?.into()
            }
            None => {
                trace!("No TLS config providing, continuing with plain text");
                tcp_stream.into()
            }
        };

        let first_line_buf = Vec::with_capacity(*first_line_buf_size);
        let data_blocks_buf = Vec::with_capacity(*data_blocks_buf_size);
        let mut conn = Self {
            stream: io::BufReader::new(nntp_stream),
            first_line_buf,
            data_blocks_buf,
            config,
        };
        let initial_resp = conn.read_response_auto()?;
        Ok((conn, initial_resp))
    }

    /// Connect to `addr` with [`ConnectionConfig::default`]: no TLS, no compression and no
    /// timeouts.
    pub fn with_defaults(addr: impl ToSocketAddrs) -> Result<(Self, RawResponse)> {
        Self::connect(addr, Default::default())
    }

    /// Send `command` and read the response. Whether data blocks follow the first line is
    /// decided from the response code.
    pub fn command<C: NntpCommand>(&mut self, command: &C) -> Result<RawResponse> {
        self.send(command)?;
        self.read_response_auto()
    }

    /// Send `command` and read the response, with `is_multiline` overriding the code-based
    /// guess.
    ///
    /// This matters for codes whose multiline-ness depends on the command. For example, RFC
    /// 3977 code 211 is single-line for GROUP but multiline for LISTGROUP.
    pub fn command_multiline<C: NntpCommand>(
        &mut self,
        command: &C,
        is_multiline: bool,
    ) -> Result<RawResponse> {
        self.send(command)?;
        self.read_response(Some(is_multiline))
    }

    /// Send `command` (its `Display` form) followed by CRLF.
    ///
    /// Returns the number of bytes written, including the CRLF.
    pub fn send<C: NntpCommand>(&mut self, command: &C) -> Result<usize> {
        self.send_bytes(command.to_string().as_bytes())
    }

    /// Write `command` followed by `\r\n` to the stream and flush it.
    ///
    /// Returns the number of bytes written, including the trailing CRLF.
    ///
    /// # Errors
    ///
    /// Fails if the bytes cannot be written completely or the flush fails.
    pub fn send_bytes(&mut self, command: impl AsRef<[u8]>) -> Result<usize> {
        let command = command.as_ref();
        // Assemble the whole line so it goes out in a single write. A separate tiny CRLF write
        // can be held back by Nagle's algorithm. `write_all` (unlike `write`) also guarantees
        // that every byte is sent rather than silently accepting a partial write.
        let mut line = Vec::with_capacity(command.len() + 2);
        line.extend_from_slice(command);
        line.extend_from_slice(b"\r\n");

        let writer = self.stream.get_mut();
        writer.write_all(&line)?;
        writer.flush()?;
        Ok(line.len())
    }

    /// Read a response, using the response code to decide whether data blocks follow.
    pub fn read_response_auto(&mut self) -> Result<RawResponse> {
        self.read_response(None)
    }

    /// Read one response from the stream.
    ///
    /// `is_multiline` controls whether data blocks are read after the first line:
    /// - `Some(true)` forces reading them.
    /// - `Some(false)` skips them.
    /// - `None` defers to the response code.
    ///
    /// If compression is configured and accepts the first line, the data blocks are read
    /// through its decoder. Otherwise they are read through a passthrough decoder.
    pub fn read_response(&mut self, is_multiline: Option<bool>) -> Result<RawResponse> {
        self.first_line_buf.clear();
        self.data_blocks_buf.clear();

        let resp_code = read_initial_response(&mut self.stream, &mut self.first_line_buf)?;

        // An explicit caller choice must win over the code-based guess. Otherwise a
        // single-line reply carrying a "multiline" code would block waiting for data blocks.
        let read_data = match is_multiline {
            Some(forced) => forced,
            None => resp_code.is_multiline(),
        };

        let data_blocks = if read_data {
            trace!("Parsing data blocks for response {}", u16::from(resp_code));
            let mut line_boundaries = Vec::with_capacity(10);
            let mut stream = match self.config.compression {
                Some(c) if c.use_decoder(&self.first_line_buf) => {
                    trace!("Compression enabled, wrapping stream with decoder");
                    c.decoder(&mut self.stream)
                }
                _ => {
                    trace!("Using passthrough decoder");
                    Decoder::Passthrough(&mut self.stream)
                }
            };
            read_data_blocks(&mut stream, &mut self.data_blocks_buf, &mut line_boundaries)?;
            Some(DataBlocks {
                payload: self.data_blocks_buf.clone(),
                line_boundaries,
            })
        } else {
            None
        };

        let resp = RawResponse {
            code: resp_code,
            first_line: self.first_line_buf.clone(),
            data_blocks,
        };
        self.reset_buffers();
        Ok(resp)
    }

    /// Empty both scratch buffers and cap their capacity at the configured sizes.
    ///
    /// A buffer that grew beyond its configured size (for example, for a large article) is
    /// replaced with a fresh one of the configured capacity. Memory is released, but the
    /// preallocation is kept for normal-sized responses.
    fn reset_buffers(&mut self) {
        // Clears `buf` and reallocates it only when it outgrew `capacity`.
        fn reset(buf: &mut Vec<u8>, capacity: usize) {
            buf.clear();
            if buf.capacity() > capacity {
                *buf = Vec::with_capacity(capacity);
            }
        }
        reset(&mut self.first_line_buf, self.config.first_line_buf_size);
        reset(&mut self.data_blocks_buf, self.config.data_blocks_buf_size);
    }

    /// Shared access to the underlying buffered stream.
    pub fn stream(&self) -> &io::BufReader<NntpStream> {
        &self.stream
    }

    /// Mutable access to the underlying buffered stream.
    ///
    /// Reading or writing through it bypasses this type's response framing.
    pub fn stream_mut(&mut self) -> &mut io::BufReader<NntpStream> {
        &mut self.stream
    }

    /// The configuration this connection was opened with.
    pub fn config(&self) -> &ConnectionConfig {
        &self.config
    }
}
pub type BufNntpStream = io::BufReader<NntpStream>;
/// Settings used when opening an [`NntpConnection`].
///
/// Create one with [`ConnectionConfig::new`] (or `Default`) and adjust it with the chainable
/// setters.
#[derive(Clone, Debug)]
pub struct ConnectionConfig {
    /// Compression that may wrap the stream in a decoder while data blocks are read.
    pub(crate) compression: Option<Compression>,
    /// When present, the TCP stream is wrapped in TLS using these settings.
    pub(crate) tls_config: Option<TlsConfig>,
    /// Read timeout for the TCP socket; `None` means no timeout.
    pub(crate) read_timeout: Option<std::time::Duration>,
    /// Requested write timeout for the TCP socket; `None` means no timeout.
    pub(crate) write_timeout: Option<std::time::Duration>,
    /// Initial capacity of the first-line buffer, and its cap after each response.
    pub(crate) first_line_buf_size: usize,
    /// Initial capacity of the data-blocks buffer, and its cap after each response.
    pub(crate) data_blocks_buf_size: usize,
}
impl Default for ConnectionConfig {
fn default() -> Self {
ConnectionConfig {
compression: None,
tls_config: None,
read_timeout: None,
write_timeout: None,
first_line_buf_size: 128,
data_blocks_buf_size: 16 * 1024,
}
}
}
impl ConnectionConfig {
    /// Create a config with default settings; equivalent to `ConnectionConfig::default()`.
    pub fn new() -> ConnectionConfig {
        Default::default()
    }

    /// Set the compression used while reading data blocks, or clear it with `None`.
    pub fn compression(&mut self, compression: Option<Compression>) -> &mut Self {
        self.compression = compression;
        self
    }

    /// Set the TLS configuration, or clear it with `None` for a plain-text connection.
    pub fn tls_config(&mut self, config: Option<TlsConfig>) -> &mut Self {
        self.tls_config = config;
        self
    }

    /// Enable TLS for `domain` using the default `native_tls` connector.
    ///
    /// # Errors
    ///
    /// Fails if the default connector cannot be created. The config is left unchanged in
    /// that case.
    pub fn default_tls(&mut self, domain: impl AsRef<str>) -> Result<&mut Self> {
        // `default_connector` accepts `AsRef<str>` itself, so forward `domain` as-is instead
        // of allocating an intermediate `String` that would be copied again.
        let tls_config = TlsConfig::default_connector(domain)?;
        self.tls_config = Some(tls_config);
        Ok(self)
    }

    /// Set the socket read timeout (`None` for no timeout).
    pub fn read_timeout(&mut self, dur: Option<Duration>) -> &mut Self {
        self.read_timeout = dur;
        self
    }

    /// Set the configured write timeout (`None` for no timeout).
    ///
    /// Previously the `write_timeout` field had no public setter at all.
    pub fn write_timeout(&mut self, dur: Option<Duration>) -> &mut Self {
        self.write_timeout = dur;
        self
    }

    /// Set the initial capacity of the first-line buffer (also its cap after each response).
    pub fn first_line_buf_size(&mut self, s: usize) -> &mut Self {
        self.first_line_buf_size = s;
        self
    }

    /// Set the initial capacity of the data-blocks buffer (also its cap after each response).
    pub fn data_blocks_buf_size(&mut self, s: usize) -> &mut Self {
        self.data_blocks_buf_size = s;
        self
    }

    /// Open an [`NntpConnection`] to `addr` using a clone of this config.
    ///
    /// Returns the connection together with the server's greeting.
    pub fn connect(&self, addr: impl ToSocketAddrs) -> Result<(NntpConnection, RawResponse)> {
        NntpConnection::connect(addr, self.clone())
    }
}
/// Read the first line of a response into `buffer` and parse its numeric status code.
///
/// # Errors
///
/// - `UnexpectedEof` if the stream ends before any byte of the response arrives.
/// - `InvalidData` if the line cannot be parsed or its code is not a valid `u16`.
/// - Any I/O error raised while reading.
fn read_initial_response<S: io::BufRead>(
    stream: &mut S,
    buffer: &mut Vec<u8>,
) -> Result<ResponseCode> {
    // `read_until` returning 0 means EOF. Report that directly instead of letting it surface
    // as a misleading parse failure.
    if stream.read_until(b'\n', buffer)? == 0 {
        return Err(io::Error::new(
            ErrorKind::UnexpectedEof,
            "Connection closed before a response line was received",
        )
        .into());
    }
    let (_initial_line_buffer, resp) = parse_first_line(buffer.as_slice()).map_err(|_e| {
        io::Error::new(
            ErrorKind::InvalidData,
            "Failed to parse first line of response",
        )
    })?;
    // Turn a malformed code into an error instead of panicking on server-controlled input.
    let code_u16 = std::str::from_utf8(resp.code)
        .ok()
        .and_then(|code| u16::from_str(code).ok())
        .ok_or_else(|| {
            io::Error::new(
                ErrorKind::InvalidData,
                "Response code is not a valid number",
            )
        })?;
    Ok(code_u16.into())
}
/// Read the data blocks of a multiline response, line by line, up to and including the
/// terminating line.
///
/// Each line is appended to `buffer`, and its `(start, end)` byte range within `buffer` is
/// pushed to `line_boundaries`. The terminating line's range is recorded too.
///
/// # Errors
///
/// - `UnexpectedEof` if the stream ends before the terminating line is read.
/// - `InvalidData` if a line cannot be parsed.
/// - Any I/O error raised while reading.
fn read_data_blocks<S: io::BufRead>(
    stream: &mut S,
    buffer: &mut Vec<u8>,
    line_boundaries: &mut Vec<(usize, usize)>,
) -> Result<()> {
    // Start after any bytes already in `buffer`. Parsing then only sees newly read lines, and
    // the recorded boundaries index correctly into the whole buffer.
    let start = buffer.len();
    let mut read_head = start;
    trace!("Reading data blocks...");
    loop {
        let bytes_read = stream.read_until(b'\n', buffer)?;
        // 0 bytes means EOF. Without this check a closed connection shows up as a confusing
        // parse error and, depending on how the parser treats empty input, could even loop.
        if bytes_read == 0 {
            return Err(io::Error::new(
                ErrorKind::UnexpectedEof,
                format!(
                    "Connection closed after {} lines of data blocks, before the terminating line",
                    line_boundaries.len()
                ),
            )
            .into());
        }
        let (_empty, line) = parse_data_block_line(&buffer[read_head..]).map_err(|e| {
            trace!("parse_data_block_line failed -- {:?}", e);
            io::Error::new(
                ErrorKind::InvalidData,
                format!(
                    "Failed to parse line {} of data blocks",
                    line_boundaries.len() + 1
                ),
            )
        })?;
        line_boundaries.push((read_head, read_head + bytes_read));
        read_head += bytes_read;
        if is_end_of_datablock(line) {
            trace!(
                "Read {} bytes of data across {} lines",
                read_head - start,
                line_boundaries.len()
            );
            break;
        }
    }
    Ok(())
}