use crate::{parse, Body, Error};
use crossbeam_channel::{Receiver, Sender, TryRecvError};
use curl::easy::{InfoType, ReadError, SeekResult, WriteError};
use futures_io::{AsyncRead, AsyncWrite};
use futures_util::task::AtomicWaker;
use http::Response;
use sluice::pipe;
use std::ascii;
use std::fmt;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
/// Per-request state driven by curl's callbacks.
///
/// Collects the response status line and headers in `header`, feeds the
/// request body to curl in `read`, and pushes response body bytes into a pipe
/// in `write`. The response head is delivered to the paired
/// `RequestHandlerFuture` through a bounded(1) channel created in `new`.
pub(crate) struct RequestHandler {
    /// Request identifier; `None` until `init` is called.
    id: Option<usize>,
    /// Sending half of the result channel; taken (set to `None`) once a
    /// result has been sent by `complete`.
    sender: Option<Sender<Result<http::response::Builder, Error>>>,
    /// State shared with the future. `is_disconnected` uses its strong count
    /// to detect that the future has been dropped.
    shared: Arc<Shared>,
    /// Body of the outgoing request, polled from `read`.
    request_body: Body,
    /// Waker passed to the request body's `poll_read`; set by `init`.
    request_body_waker: Option<Waker>,
    /// Status code from the most recent status line seen in `header`.
    response_status_code: Option<http::StatusCode>,
    /// HTTP version from the most recent status line seen in `header`.
    response_version: Option<http::Version>,
    /// Headers collected since the most recent status line.
    response_headers: http::HeaderMap,
    /// Writing half of the pipe the response body is streamed through.
    response_body_writer: pipe::PipeWriter,
    /// Waker passed to the pipe writer's `poll_write`; set by `init`.
    response_body_waker: Option<Waker>,
}
/// State shared between a `RequestHandler` and its `RequestHandlerFuture`.
///
/// Besides the waker, the `Arc` around this value doubles as a liveness
/// signal: the handler treats a strong count of 1 as "the future is gone".
#[derive(Debug, Default)]
struct Shared {
    /// Registered by the future on every poll; woken by the handler after it
    /// sends a result.
    waker: AtomicWaker,
}
impl RequestHandler {
    /// Creates a handler for a request with the given body, together with the
    /// future that resolves to the response.
    ///
    /// The two halves are linked by a single-slot channel carrying the
    /// response head, a shared waker, and a pipe carrying the response body.
    pub(crate) fn new(request_body: Body) -> (Self, RequestHandlerFuture) {
        let (sender, receiver) = crossbeam_channel::bounded(1);
        let shared = Arc::new(Shared::default());
        let (response_body_reader, response_body_writer) = pipe::pipe();

        (
            Self {
                id: None,
                sender: Some(sender),
                shared: shared.clone(),
                request_body,
                request_body_waker: None,
                response_status_code: None,
                response_version: None,
                response_headers: http::HeaderMap::new(),
                response_body_writer,
                response_body_waker: None,
            },
            RequestHandlerFuture {
                receiver,
                shared,
                response_body_reader: Some(response_body_reader),
            },
        )
    }

    /// Returns `true` when the paired `RequestHandlerFuture` no longer exists,
    /// i.e. this handler holds the only remaining reference to the shared
    /// state.
    fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }

    /// Assigns the request id and the wakers handed to the request body
    /// (`read`) and the response body pipe (`write`) when they return
    /// `Pending`.
    ///
    /// Must be called exactly once; the debug assertions check that.
    pub(crate) fn init(&mut self, id: usize, request_waker: Waker, response_waker: Waker) {
        debug_assert!(self.id.is_none());
        debug_assert!(self.request_body_waker.is_none());
        debug_assert!(self.response_body_waker.is_none());
        log::debug!("initializing handler for request [id={}]", id);

        self.id = Some(id);
        self.request_body_waker = Some(request_waker);
        self.response_body_waker = Some(response_waker);
    }

    /// Receives the final outcome of the curl transfer.
    ///
    /// On success, any response head not yet delivered is sent to the future.
    /// On failure, the curl error is converted and delivered instead.
    pub(crate) fn on_result(&mut self, result: Result<(), curl::Error>) {
        match result {
            Ok(()) => self.flush_response_headers(),
            Err(e) => {
                log::debug!("curl error: {}", e);
                self.complete(Err(e.into()));
            }
        }
    }

    /// Builds a response head from the collected status, version, and headers
    /// and delivers it, unless a result has already been delivered.
    ///
    /// Repeated calls are cheap no-ops after the first delivery.
    fn flush_response_headers(&mut self) {
        if self.sender.is_some() {
            let mut builder = http::Response::builder();

            if let Some(status) = self.response_status_code.take() {
                builder.status(status);
            }

            if let Some(version) = self.response_version.take() {
                builder.version(version);
            }

            for (name, values) in self.response_headers.drain() {
                for value in values {
                    builder.header(&name, value);
                }
            }

            self.complete(Ok(builder));
        }
    }

    /// Sends `result` to the future and wakes it. Only the first call has any
    /// effect, because the sender is taken.
    ///
    /// A send failure means the receiving future was dropped, which is logged
    /// as a user cancellation.
    fn complete(&mut self, result: Result<http::response::Builder, Error>) {
        if let Some(sender) = self.sender.take() {
            if let Err(e) = result.as_ref() {
                log::warn!("request completed with error [id={:?}]: {}", self.id, e);
            }

            match sender.send(result) {
                Ok(()) => {
                    self.shared.waker.wake();
                }
                Err(_) => {
                    log::debug!("request canceled by user [id={:?}]", self.id);
                }
            }
        }
    }
}

impl Drop for RequestHandler {
    /// Makes sure a pending future observes the handler going away.
    ///
    /// If no result was ever sent, the future's only wake-up source
    /// (`complete`) never ran, so an asynchronously polled future would stay
    /// `Pending` forever. This drops the sender explicitly *before* waking, so
    /// the re-polled future sees a disconnected channel and resolves to
    /// `Error::Aborted`, rather than seeing an empty channel and going back to
    /// sleep.
    fn drop(&mut self) {
        if let Some(sender) = self.sender.take() {
            drop(sender);
            self.shared.waker.wake();
        }
    }
}
impl curl::easy::Handler for RequestHandler {
    /// Handles one raw response header line from curl.
    ///
    /// A status line records the version and status code and discards any
    /// headers collected so far, so headers from an earlier response in the
    /// same transfer do not leak into the final one. Header fields are
    /// accumulated. The blank `\r\n` line ending the header block is accepted.
    /// Returning `false` makes curl abort the transfer. That happens when the
    /// future is gone or the line cannot be parsed.
    fn header(&mut self, data: &[u8]) -> bool {
        if self.is_disconnected() {
            return false;
        }

        if let Some((version, status)) = parse::parse_status_line(data) {
            self.response_version = Some(version);
            self.response_status_code = Some(status);
            self.response_headers.clear();
            return true;
        }

        if let Some((name, value)) = parse::parse_header(data) {
            // `append`, not `insert`: a field may legitimately appear several
            // times (e.g. `Set-Cookie`), and `insert` would keep only the last
            // occurrence. `flush_response_headers` already forwards every value
            // stored under a name.
            self.response_headers.append(name, value);
            return true;
        }

        if data == b"\r\n" {
            return true;
        }

        false
    }

    /// Supplies the next chunk of the request body to curl.
    ///
    /// If the body is not ready, the transfer is paused. The waker registered
    /// in `init` is what the body uses to signal readiness. Errors, a dropped
    /// future, or a handler that was never initialized abort the transfer.
    fn read(&mut self, data: &mut [u8]) -> Result<usize, ReadError> {
        if self.is_disconnected() {
            return Err(ReadError::Abort);
        }

        if let Some(waker) = self.request_body_waker.as_ref() {
            let mut context = Context::from_waker(waker);

            match Pin::new(&mut self.request_body).poll_read(&mut context, data) {
                Poll::Pending => Err(ReadError::Pause),
                Poll::Ready(Ok(len)) => Ok(len),
                Poll::Ready(Err(e)) => {
                    log::error!("error reading request body: {}", e);
                    Err(ReadError::Abort)
                }
            }
        } else {
            log::error!("request has not been initialized!");
            Err(ReadError::Abort)
        }
    }

    /// Rewinds the request body when curl needs to resend it.
    ///
    /// Only seeking to the very start is supported, and only if the body can
    /// be reset.
    fn seek(&mut self, whence: io::SeekFrom) -> SeekResult {
        if whence == io::SeekFrom::Start(0) && self.request_body.reset() {
            SeekResult::Ok
        } else {
            log::warn!("seek requested for request body, but it is not supported");
            SeekResult::CantSeek
        }
    }

    /// Receives a chunk of response body data from curl.
    ///
    /// Before the first body byte is forwarded, the response head is handed to
    /// the future. The chunk is then written to the body pipe. If the pipe is
    /// full, the transfer is paused. Returning `Ok(0)` for a non-empty chunk
    /// (fewer bytes than given) tells curl to abort. That path is used when the
    /// reader was dropped or the handler was never initialized.
    fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
        log::trace!("received {} bytes of data", data.len());

        self.flush_response_headers();

        if let Some(waker) = self.response_body_waker.as_ref() {
            let mut context = Context::from_waker(waker);

            match Pin::new(&mut self.response_body_writer).poll_write(&mut context, data) {
                Poll::Pending => Err(WriteError::Pause),
                Poll::Ready(Ok(len)) => Ok(len),
                Poll::Ready(Err(e)) => {
                    if e.kind() == io::ErrorKind::BrokenPipe {
                        log::warn!(
                            "failed to write response body because the response reader was dropped"
                        );
                    } else {
                        log::error!("error writing response body to buffer: {}", e);
                    }

                    Ok(0)
                }
            }
        } else {
            log::error!("request has not been initialized!");
            Ok(0)
        }
    }

    /// Forwards curl's verbose output to the logger.
    ///
    /// Informational text goes to `chttp::curl` at debug level. Wire traffic
    /// goes to `chttp::wire` at trace level, escaped and prefixed with `<<` for
    /// incoming data and `>>` for outgoing data.
    fn debug(&mut self, kind: InfoType, data: &[u8]) {
        /// Escapes arbitrary bytes into a printable string.
        fn format_byte_string(bytes: impl AsRef<[u8]>) -> String {
            // `escape_default` only ever yields ASCII bytes, so mapping each
            // one straight to a `char` is lossless and cannot fail.
            bytes
                .as_ref()
                .iter()
                .flat_map(|byte| ascii::escape_default(*byte))
                .map(char::from)
                .collect()
        }

        match kind {
            InfoType::Text => {
                log::debug!(target: "chttp::curl", "{}", String::from_utf8_lossy(data).trim_end())
            }
            InfoType::HeaderIn | InfoType::DataIn => {
                log::trace!(target: "chttp::wire", "<< {}", format_byte_string(data))
            }
            InfoType::HeaderOut | InfoType::DataOut => {
                log::trace!(target: "chttp::wire", ">> {}", format_byte_string(data))
            }
            _ => (),
        }
    }
}
impl fmt::Debug for RequestHandler {
    /// Renders the handler as `RequestHandler(<id>)`, where `<id>` is the
    /// `Debug` form of the optional request id.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("RequestHandler(")?;
        write!(f, "{:?}", self.id)?;
        f.write_str(")")
    }
}
/// Future half of a request: resolves to the response once the paired
/// `RequestHandler` delivers the response head.
///
/// Dropping this value releases its reference to the shared state, which the
/// handler treats as cancellation.
#[derive(Debug)]
pub(crate) struct RequestHandlerFuture {
    /// Receiving half of the channel the handler sends its result on.
    receiver: Receiver<Result<http::response::Builder, Error>>,
    /// State shared with the handler; holds the waker registered on poll.
    shared: Arc<Shared>,
    /// Reading half of the response body pipe; moved into the response body
    /// by `complete`.
    response_body_reader: Option<pipe::PipeReader>,
}
impl RequestHandlerFuture {
pub(crate) fn join(mut self) -> Result<Response<Body>, Error> {
match self.receiver.recv() {
Ok(Ok(builder)) => self.complete(builder),
Ok(Err(e)) => Err(e),
Err(_) => Err(Error::Aborted),
}
}
fn complete(&mut self, mut builder: http::response::Builder) -> Result<Response<Body>, Error> {
let reader = self.response_body_reader.take().unwrap();
let content_length = builder
.headers_ref()
.unwrap()
.get(http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse().ok());
let body = match content_length {
Some(len) => Body::reader_sized(reader, len),
None => Body::reader(reader),
};
match builder.body(body) {
Ok(response) => Ok(response),
Err(e) => Err(Error::InvalidHttpFormat(e)),
}
}
}
impl Future for RequestHandlerFuture {
type Output = Result<Response<Body>, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.shared.waker.register(cx.waker());
match self.receiver.try_recv() {
Err(TryRecvError::Empty) => Poll::Pending,
Ok(Ok(builder)) => Poll::Ready(self.complete(builder)),
Ok(Err(e)) => Poll::Ready(Err(e)),
Err(TryRecvError::Disconnected) => Poll::Ready(Err(Error::Aborted)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time assertion: this only type-checks when `T: Send`.
    fn assert_send<T: Send>() {}

    #[test]
    fn traits() {
        // The future presumably needs to move across threads (e.g. when
        // spawned on an executor), so keep it `Send`.
        assert_send::<RequestHandlerFuture>();
    }
}