use mio::net::TcpStream;
use chunked_transfer::Decoder as ChunkedDecoder;
use std::{io::{self, Write, Read}, time::{Duration, Instant}, collections::HashMap, net::{SocketAddr, Ipv4Addr}, mem::replace};
use crate::{dns, util::{make_socket_addr, notconnected, register_all, wouldblock, hash}, ResponseHead, ReqId, Response, ResponseState, Mode, Status, OwnedHeader, RawRequest};
#[cfg(feature = "tls")]
use std::sync::Arc;
/// Non-blocking HTTP client driven by a caller-owned `mio::Poll`.
///
/// Requests are queued with `Client::send` and advanced with `Client::pump`.
pub struct Client {
    /// Resolver used when a host has no fresh entry in `dns_cache`.
    dns: dns::DnsClient,
    /// Resolved IPv4 addresses keyed by `hash(host)`, each carrying its own TTL.
    dns_cache: HashMap<u64, CachedAddr>,
    /// In-flight requests; finished ones are removed at the end of every `pump`.
    requests: Vec<InternalReq>,
    /// Id assigned to the next request (wraps around on overflow).
    next_id: usize,
    /// TLS settings shared by every HTTPS connection.
    #[cfg(feature = "tls")]
    tls_config: Arc<rustls::ClientConfig>,
    /// Placeholder so the constructors look the same with or without the `tls` feature.
    #[cfg(not(feature = "tls"))]
    tls_config: (),
}
impl Client {
    /// Creates a client that uses the default TLS configuration.
    ///
    /// `token` is handed to the internal DNS client. With the `tls` feature enabled, the
    /// default configuration trusts the `webpki_roots` certificate set.
    #[inline(always)]
    pub fn new(token: mio::Token) -> Self {
        let tls_config = Self::default_tls_config();
        Self {
            dns: dns::DnsClient::new(token),
            dns_cache: HashMap::new(),
            requests: Vec::new(),
            next_id: 0,
            tls_config,
        }
    }

    /// Creates a client that uses `tls_config` for every HTTPS connection.
    ///
    /// `token` is handed to the internal DNS client.
    #[cfg(feature = "tls")]
    #[inline(always)]
    pub fn with_tls_config(token: mio::Token, tls_config: Arc<rustls::ClientConfig>) -> Self {
        Self {
            dns: dns::DnsClient::new(token),
            dns_cache: HashMap::new(),
            requests: Vec::new(),
            next_id: 0,
            tls_config,
        }
    }

    /// Queues a request and returns the id that its `Response`s will carry.
    ///
    /// If a non-expired cached address exists for the host, the connection is opened and
    /// registered with `io` under `token` immediately. Otherwise a DNS lookup is started,
    /// and `pump` opens the connection once the answer arrives.
    ///
    /// # Errors
    /// Returns any error from starting the DNS lookup, connecting, or registering the socket.
    ///
    /// # Panics
    /// With the `tls` feature, for HTTPS requests whose host is not a valid TLS server name.
    pub fn send(&mut self, io: &mio::Poll, token: mio::Token, input: impl Into<RawRequest>) -> io::Result<ReqId> {
        let request = input.into();
        let id = self.next_id;
        self.next_id = self.next_id.wrapping_add(1);
        let mode = InternalMode::from_mode(request.mode, &self.tls_config, request.host());
        // The same hash is the cache lookup key now and the cache insert key after resolving.
        let host = hash(request.host());
        let state = match self.dns_cache.get(&host) {
            Some(cached_addr) if !cached_addr.is_outdated() => {
                let mut connection = Connection::new(cached_addr.ip_addr, mode)?;
                register_all(io, &mut connection, token)?;
                InternalReqState::Sending {
                    body: request.bytes,
                    connection,
                }
            },
            _not_cached_or_old => {
                let dns_id = self.dns.resolve(io, request.host(), request.timeout)?;
                InternalReqState::Resolving {
                    host,
                    body: request.bytes,
                    dns_id,
                    mode,
                }
            },
        };
        self.requests.push(InternalReq {
            id,
            token,
            state,
            time_created: Instant::now(),
            timeout: request.timeout,
        });
        Ok(ReqId { inner: id })
    }

    /// Advances every in-flight request, using the readiness `events` from the last poll.
    ///
    /// Returns the `Response`s produced during this call. For each request these are a
    /// `Head`, zero or more `Data` chunks, then `Done`, or else a single terminal error
    /// state (`TimedOut`, `UnknownHost`, `ProtocolError`, `Aborted`). Finished requests are
    /// removed from the client before returning.
    ///
    /// # Errors
    /// Unexpected I/O errors from the DNS client or from a request's socket are returned
    /// immediately. Responses gathered earlier in the same call are lost in that case.
    pub fn pump(&mut self, io: &mio::Poll, events: &mio::Events) -> io::Result<Vec<Response>> {
        let mut responses = Vec::new();
        let dns_resps = self.dns.pump(io, events)?;
        'rq: for request in self.requests.iter_mut() {
            if request.timeout.unwrap_or(Duration::MAX) <= request.time_created.elapsed() {
                responses.push(Response::new(request.id, ResponseState::TimedOut));
                request.deregister(io)?;
                request.finish_error();
                continue 'rq;
            }

            // DNS answers do not depend on socket readiness, so they are matched once per
            // pump rather than once per event. Inside the event loop, an answer produced
            // while `events` is empty (e.g. a resolver timeout) would never be seen.
            if let InternalReqState::Resolving { dns_id, .. } = &request.state {
                let resp = match dns_resps.iter().find(|resp| &resp.id == dns_id) {
                    Some(resp) => resp,
                    None => continue 'rq,
                };
                let (addr, ttl) = match resp.outcome {
                    dns::DnsOutcome::Known { addr, ttl } => (addr, ttl),
                    dns::DnsOutcome::Unknown => {
                        responses.push(Response::new(request.id, ResponseState::UnknownHost));
                        request.finish_error();
                        continue 'rq;
                    },
                    dns::DnsOutcome::ProtocolError => {
                        responses.push(Response::new(request.id, ResponseState::ProtocolError));
                        request.finish_error();
                        continue 'rq;
                    },
                    dns::DnsOutcome::TimedOut => {
                        responses.push(Response::new(request.id, ResponseState::TimedOut));
                        request.finish_error();
                        continue 'rq;
                    },
                };
                let state = replace(&mut request.state, InternalReqState::Unspecified);
                if let InternalReqState::Resolving { body, host, mode, .. } = state {
                    self.dns_cache.insert(host, CachedAddr {
                        ip_addr: addr,
                        time_created: Instant::now(),
                        ttl,
                    });
                    let mut connection = Connection::new(addr, mode)?;
                    register_all(io, &mut connection, request.token)?;
                    request.state = InternalReqState::Sending { body, connection };
                } else {
                    unreachable!()
                }
                continue 'rq;
            }

            if let Some(connection) = request.state.connection_mut() {
                connection.complete_io()?;
            }
            for event in events.iter() {
                // Only readiness on this request's own token can advance its socket.
                if event.token() != request.token {
                    continue;
                }
                match &mut request.state {
                    InternalReqState::Sending { body, connection } => {
                        match connection.peer_addr() {
                            Ok(..) => (),
                            // The non-blocking connect has not finished yet.
                            Err(err) if notconnected(&err) => continue 'rq,
                            Err(other) => return Err(other),
                        }
                        // `write` may accept only part of the buffer. Keep the unsent tail in
                        // `body` so the next writable event resumes where this one stopped.
                        while !body.is_empty() {
                            match connection.write(body.as_slice()) {
                                Ok(0) => return Err(io::ErrorKind::WriteZero.into()),
                                Ok(written) => {
                                    body.drain(..written);
                                },
                                Err(err) if wouldblock(&err) => continue 'rq,
                                Err(other) => return Err(other),
                            }
                        }
                        let state = replace(&mut request.state, InternalReqState::Unspecified);
                        if let InternalReqState::Sending { connection, .. } = state {
                            request.state = InternalReqState::RecvHead {
                                connection,
                                buffer: Vec::with_capacity(1024),
                            };
                        } else {
                            unreachable!()
                        }
                    },
                    InternalReqState::RecvHead { .. } | InternalReqState::RecvBody { .. } => {
                        if event.is_readable() {
                            if let InternalReqState::RecvHead { connection, buffer } = &mut request.state {
                                // Drain the socket until it would block (or closes).
                                let mut bytes_read = buffer.len();
                                let mut closed = false;
                                loop {
                                    buffer.resize(bytes_read + 2048, 0u8);
                                    bytes_read += match connection.read(&mut buffer[bytes_read..]) {
                                        Ok(0) => { closed = true; break },
                                        Ok(num) => num,
                                        Err(err) if wouldblock(&err) => break,
                                        Err(other) => return Err(other),
                                    };
                                }
                                buffer.truncate(bytes_read);
                                let mut headers = [httparse::EMPTY_HEADER; 4096];
                                let mut head = httparse::Response::new(&mut headers);
                                let status = match head.parse(&buffer) {
                                    Ok(val) => val,
                                    Err(_err) => {
                                        responses.push(Response::new(request.id, ResponseState::ProtocolError));
                                        request.deregister(io)?;
                                        request.finish_error();
                                        continue 'rq;
                                    }
                                };
                                if let httparse::Status::Complete(body_start) = status {
                                    let content_length = match Self::parse_content_length(&*head.headers) {
                                        Some(len) => len,
                                        None => {
                                            responses.push(Response::new(request.id, ResponseState::ProtocolError));
                                            request.deregister(io)?;
                                            request.finish_error();
                                            continue 'rq;
                                        }
                                    };
                                    let transfer_chunked = Self::is_transfer_chunked(&*head.headers);
                                    responses.push(Response {
                                        id: ReqId { inner: request.id },
                                        state: ResponseState::Head(ResponseHead {
                                            status: Status {
                                                code: head.code.expect("missing status code"),
                                                reason: head.reason.expect("missing reason").to_string(),
                                            },
                                            content_length,
                                            transfer_chunked,
                                            headers: head.headers.iter().map(OwnedHeader::from).collect(),
                                        })
                                    });
                                    // Whatever followed the head is the start of the body. It is
                                    // replayed ahead of the socket through the cursor chain.
                                    buffer.drain(..body_start);
                                    let state = replace(&mut request.state, InternalReqState::Unspecified);
                                    if let InternalReqState::RecvHead { connection, buffer } = state {
                                        let chain = io::Cursor::new(buffer).chain(connection);
                                        let recv = if transfer_chunked {
                                            RecvBody::Chunked(ChunkedDecoder::new(chain))
                                        } else {
                                            RecvBody::Plain(chain)
                                        };
                                        request.state = InternalReqState::RecvBody {
                                            recv,
                                            bytes_read_total: 0,
                                            content_length,
                                        };
                                    } else {
                                        unreachable!()
                                    }
                                } else if closed {
                                    responses.push(Response::new(request.id, ResponseState::Aborted));
                                    request.deregister(io)?;
                                    request.finish_error();
                                    continue 'rq;
                                }
                            }
                        }
                        if let InternalReqState::RecvBody { recv, bytes_read_total, content_length } = &mut request.state {
                            let mut data = Vec::new();
                            let mut bytes_read = 0;
                            let mut closed = false;
                            loop {
                                data.resize(bytes_read + 2048, 0u8);
                                bytes_read += match recv.read(&mut data[bytes_read..]) {
                                    Ok(0) => { closed = true; break },
                                    Ok(num) => num,
                                    Err(err) if wouldblock(&err) => break,
                                    Err(other) => return Err(other),
                                };
                            }
                            data.truncate(bytes_read);
                            if bytes_read > 0 {
                                responses.push(Response {
                                    id: ReqId { inner: request.id },
                                    state: ResponseState::Data(data),
                                });
                                *bytes_read_total += bytes_read;
                            }
                            let finished = if recv.is_chunked() {
                                // The chunked decoder reports EOF once the terminating chunk is read.
                                closed
                            } else {
                                *bytes_read_total >= *content_length
                            };
                            if finished {
                                responses.push(Response {
                                    id: ReqId { inner: request.id },
                                    state: ResponseState::Done,
                                });
                                request.deregister(io)?;
                                request.finish_done();
                                continue 'rq;
                            } else if closed {
                                responses.push(Response::new(request.id, ResponseState::Aborted));
                                request.deregister(io)?;
                                request.finish_error();
                                continue 'rq;
                            }
                        }
                    },
                    // Resolving is fully handled before the event loop (every path there
                    // continues 'rq). The terminal and placeholder states never reach here.
                    InternalReqState::Resolving { .. }
                    | InternalReqState::Unspecified
                    | InternalReqState::Error
                    | InternalReqState::Done => unreachable!(),
                }
            }
        }
        self.requests.retain(|request| !request.is_finished());
        Ok(responses)
    }

    /// Returns the time left until the earliest request deadline (zero if one has already
    /// passed), or `None` when no in-flight request has a timeout.
    #[inline(always)]
    pub fn timeout(&self) -> Option<Duration> {
        let now = Instant::now();
        self.requests
            .iter()
            .filter_map(|request| {
                let elapsed = now.saturating_duration_since(request.time_created);
                request.timeout.map(|timeout| timeout.saturating_sub(elapsed))
            })
            .min()
    }

    /// Reads the `Content-Length` of a parsed response head.
    ///
    /// Returns `Some(0)` when the header is absent. Returns `None` when it is present but not
    /// a valid decimal length. Field names are matched case-insensitively, as HTTP requires.
    fn parse_content_length(headers: &[httparse::Header<'_>]) -> Option<usize> {
        match headers.iter().find(|header| header.name.eq_ignore_ascii_case("Content-Length")) {
            None => Some(0),
            Some(header) => std::str::from_utf8(header.value).ok()?.trim().parse().ok(),
        }
    }

    /// Returns `true` when the final coding in the last `Transfer-Encoding` header is `chunked`.
    ///
    /// This covers both `chunked` and stacked codings such as `gzip, chunked`. Names and
    /// values are compared case-insensitively.
    fn is_transfer_chunked(headers: &[httparse::Header<'_>]) -> bool {
        headers
            .iter()
            .rev()
            .find(|header| header.name.eq_ignore_ascii_case("Transfer-Encoding"))
            .and_then(|header| std::str::from_utf8(header.value).ok())
            .and_then(|value| value.rsplit(',').next())
            .map_or(false, |coding| coding.trim().eq_ignore_ascii_case("chunked"))
    }

    /// Builds the default TLS configuration: the `webpki_roots` trust anchors, safe
    /// protocol defaults, and no client certificate.
    #[cfg(feature = "tls")]
    #[inline(always)]
    fn default_tls_config() -> Arc<rustls::ClientConfig> {
        let mut root_store = rustls::RootCertStore::empty();
        root_store.add_trust_anchors(
            webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta|
                rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)
            )
        );
        let config = rustls::ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(root_store)
            .with_no_client_auth();
        Arc::new(config)
    }

    /// Without the `tls` feature there is no TLS configuration to build.
    #[cfg(not(feature = "tls"))]
    fn default_tls_config() {}
}
struct InternalReq {
id: usize,
token: mio::Token,
time_created: Instant,
timeout: Option<Duration>,
state: InternalReqState,
}
impl InternalReq {
pub fn deregister(&mut self, io: &mio::Poll) -> io::Result<()> {
if let Some(conn) = self.state.connection_mut() {
io.registry().deregister(conn)
} else {
Ok(())
}
}
pub fn finish_done(&mut self) {
let _unused = replace(&mut self.state, InternalReqState::Done);
}
pub fn finish_error(&mut self) {
let _unused = replace(&mut self.state, InternalReqState::Error);
}
pub fn is_finished(&self) -> bool {
matches!(&self.state, InternalReqState::Done | InternalReqState::Error)
}
}
/// Step of a request's lifecycle.
enum InternalReqState {
    /// Transient placeholder used while the previous state is moved out with `replace`.
    Unspecified,
    /// Terminal: the request failed.
    Error,
    /// Terminal: the response was fully delivered.
    Done,
    /// Waiting for the DNS answer identified by `dns_id`.
    Resolving {
        body: Vec<u8>,
        dns_id: dns::DnsId,
        /// Hash of the host name; used as the DNS cache key.
        host: u64,
        mode: InternalMode,
    },
    /// Writing the request bytes in `body` to `connection`.
    Sending {
        body: Vec<u8>,
        connection: Connection,
    },
    /// Accumulating bytes in `buffer` until the response head parses.
    RecvHead {
        connection: Connection,
        buffer: Vec<u8>,
    },
    /// Streaming the response body through `recv`.
    RecvBody {
        recv: RecvBody,
        bytes_read_total: usize,
        content_length: usize,
    },
}
impl InternalReqState {
    /// Borrows the live connection for the states that own one.
    pub fn connection_mut(&mut self) -> Option<&mut Connection> {
        match self {
            Self::Sending { connection, .. } | Self::RecvHead { connection, .. } => Some(connection),
            Self::RecvBody { recv, .. } => Some(recv.connection_mut()),
            Self::Unspecified | Self::Error | Self::Done | Self::Resolving { .. } => None,
        }
    }
}
/// Body reader: the bytes already buffered after the head, chained in front of the live
/// connection, and optionally wrapped in a chunked-transfer decoder.
enum RecvBody {
    Plain(io::Chain<io::Cursor<Vec<u8>>, Connection>),
    Chunked(ChunkedDecoder<io::Chain<io::Cursor<Vec<u8>>, Connection>>)
}
impl RecvBody {
    /// Borrows the connection, which is the second half of the chain.
    pub fn connection_mut(&mut self) -> &mut Connection {
        let chain = match self {
            Self::Plain(chain) => chain,
            Self::Chunked(decoder) => decoder.get_mut(),
        };
        let (_buffered, connection) = chain.get_mut();
        connection
    }
    /// `true` when the body is read through the chunked decoder.
    pub fn is_chunked(&self) -> bool {
        matches!(self, Self::Chunked(..))
    }
}
impl io::Read for RecvBody {
    /// Reads body bytes; chunked bodies are decoded before being returned.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            Self::Chunked(decoder) => decoder.read(buf),
            Self::Plain(chain) => chain.read(buf),
        }
    }
}
/// A resolved IPv4 address kept in the DNS cache until its TTL runs out.
struct CachedAddr {
    pub ip_addr: Ipv4Addr,
    pub time_created: Instant,
    pub ttl: Duration,
}
impl CachedAddr {
    /// `true` once at least `ttl` has elapsed since the entry was stored.
    pub fn is_outdated(&self) -> bool {
        let age = self.time_created.elapsed();
        age >= self.ttl
    }
}
/// Transport chosen for a request, carrying the TLS parameters when HTTPS is used.
enum InternalMode {
    Plain,
    #[cfg(feature = "tls")]
    Secure { tls_config: Arc<rustls::ClientConfig>, server_name: rustls::ServerName }
}
impl InternalMode {
    /// Chooses the transport for `mode`. For `Mode::Secure`, `host` becomes the TLS
    /// server name and the shared `tls_config` is cloned (reference count only).
    ///
    /// # Panics
    /// If `host` cannot be converted into a `rustls::ServerName`.
    #[cfg(feature = "tls")]
    pub(crate) fn from_mode(mode: Mode, tls_config: &Arc<rustls::ClientConfig>, host: &str) -> Self {
        match mode {
            Mode::Secure => {
                let config = Arc::clone(tls_config);
                let server_name: rustls::ServerName = host.try_into().expect("invalid host name");
                Self::Secure { tls_config: config, server_name }
            },
            Mode::Plain => Self::Plain,
        }
    }
    /// Without the `tls` feature every request uses a plain connection.
    #[cfg(not(feature = "tls"))]
    pub(crate) fn from_mode(_mode: Mode, _tls_config: &(), _host: &str) -> Self {
        Self::Plain
    }
}
/// Transport for a single request: a bare TCP socket, or (with the `tls` feature) a rustls
/// client session layered over one.
enum Connection {
    Plain { tcp_stream: TcpStream },
    #[cfg(feature = "tls")]
    Secure { stream: rustls::StreamOwned<rustls::ClientConnection, TcpStream> },
}
impl Connection {
    /// Starts connecting to `ip_addr`: port 80 for plain HTTP, port 443 for TLS.
    ///
    /// The connect may still be in progress when this returns; `peer_addr` reports
    /// `NotConnected` until it completes.
    ///
    /// # Errors
    /// Fails if the socket cannot be created or the rustls session cannot be initialised.
    pub(crate) fn new(ip_addr: Ipv4Addr, mode: InternalMode) -> io::Result<Self> {
        match mode {
            InternalMode::Plain => {
                let tcp_stream = TcpStream::connect(make_socket_addr(ip_addr, 80))?;
                Ok(Self::Plain { tcp_stream })
            },
            #[cfg(feature = "tls")]
            InternalMode::Secure { tls_config, server_name } => {
                let tcp_stream = TcpStream::connect(make_socket_addr(ip_addr, 443))?;
                let tls_connection = rustls::ClientConnection::new(tls_config, server_name)
                    .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
                let stream = rustls::StreamOwned::new(tls_connection, tcp_stream);
                Ok(Self::Secure { stream })
            }
        }
    }
    /// Address of the remote peer; errors with `NotConnected` while the connect is pending.
    pub(crate) fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.tcp_stream().peer_addr()
    }
    /// Borrows the underlying TCP socket.
    fn tcp_stream(&self) -> &TcpStream {
        match self {
            Self::Plain { tcp_stream } => tcp_stream,
            #[cfg(feature = "tls")]
            Self::Secure { stream } => &stream.sock,
        }
    }
    /// Mutably borrows the underlying TCP socket.
    fn tcp_stream_mut(&mut self) -> &mut TcpStream {
        match self {
            Self::Plain { tcp_stream } => tcp_stream,
            #[cfg(feature = "tls")]
            Self::Secure { stream } => &mut stream.sock,
        }
    }
    /// Drives pending TLS I/O (handshake and buffered records) as far as the socket allows.
    /// This is a no-op for plain connections.
    ///
    /// # Errors
    /// Returns I/O or TLS errors other than "would block" and "not connected".
    pub(crate) fn complete_io(&mut self) -> io::Result<()> {
        #[cfg(feature = "tls")]
        if let Connection::Secure { stream } = self {
            match stream.conn.complete_io(&mut stream.sock) {
                Ok(..) => (),
                // "Would block" means the session has progressed as far as it can for now.
                // "Not connected" means the non-blocking connect has not finished: some
                // platforms report ENOTCONN rather than EAGAIN for I/O on such a socket.
                // Both are retried on the next pump, matching how the Sending step treats
                // `NotConnected`.
                Err(err) if wouldblock(&err) || notconnected(&err) => (),
                Err(other) => return Err(other),
            };
        }
        Ok(())
    }
}
impl mio::event::Source for Connection {
fn register(&mut self, registry: &mio::Registry, token: mio::Token, interests: mio::Interest) -> io::Result<()> {
self.tcp_stream_mut().register(registry, token, interests)
}
fn reregister(&mut self, registry: &mio::Registry, token: mio::Token, interests: mio::Interest) -> io::Result<()> {
self.tcp_stream_mut().reregister(registry, token, interests)
}
fn deregister(&mut self, registry: &mio::Registry) -> io::Result<()> {
self.tcp_stream_mut().deregister(registry)
}
}
impl Read for Connection {
    /// Reads from the socket directly, or through the TLS session for secure connections.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let source: &mut dyn Read = match self {
            Self::Plain { tcp_stream } => tcp_stream,
            #[cfg(feature = "tls")]
            Self::Secure { stream } => stream,
        };
        source.read(buf)
    }
}
impl Write for Connection {
    /// Writes to the socket directly, or through the TLS session for secure connections.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let sink: &mut dyn Write = match self {
            Self::Plain { tcp_stream } => tcp_stream,
            #[cfg(feature = "tls")]
            Self::Secure { stream } => stream,
        };
        sink.write(buf)
    }
    /// Flushes the socket, or the TLS session for secure connections.
    fn flush(&mut self) -> io::Result<()> {
        let sink: &mut dyn Write = match self {
            Self::Plain { tcp_stream } => tcp_stream,
            #[cfg(feature = "tls")]
            Self::Secure { stream } => stream,
        };
        sink.flush()
    }
}