use std::io;
use std::fmt::Display;
use futures::{Future, Poll, Async};
use tk_bufstream::{WriteBuf, WriteRaw, FutureWriteRaw};
use tokio_io::AsyncWrite;
use base_serializer::{MessageState, HeaderError};
use enums::{Version, Status};
use super::headers::Head;
/// Incremental writer for a single HTTP response.
///
/// Status line, headers and body are serialized into the output buffer of
/// `io`; `state` tracks which part of the message has been written so far.
pub struct Encoder<S> {
/// Serialization state machine for the message being written.
state: MessageState,
/// Buffered writer whose `out_buf` receives the serialized bytes.
io: WriteBuf<S>,
}
/// Token returned once a response has been handed off, either by
/// `Encoder::done` or `RawBody::done`.
///
/// It only carries the buffered writer, which `get_inner` extracts.
pub struct EncoderDone<S> {
buf: WriteBuf<S>,
}
/// Per-response settings used by `new` to initialize an `Encoder`.
///
/// Usually built from the request head with `ResponseConfig::from`.
#[derive(Debug, Clone, Copy)]
pub struct ResponseConfig {
/// The request was `HEAD`; `new` selects `Body::Head` instead of `Body::Normal`.
pub is_head: bool,
/// Ask for the connection to be closed after this response
/// (`new` also forces this for HTTP/1.0).
pub do_close: bool,
/// HTTP version used for the response.
pub version: Version,
}
/// Future returned by `Encoder::raw_body`; resolves to a `RawBody` when the
/// wrapped `FutureWriteRaw` (from `WriteBuf::borrow_raw`) completes.
/// NOTE(review): presumably that is once buffered output has been written out — confirm in tk_bufstream.
pub struct FutureRawBody<S>(FutureWriteRaw<S>);
/// Future returned by `Encoder::wait_flush`.
///
/// Field `0` holds the encoder until the future completes (it is taken out
/// and returned then); field `1` is the watermark in bytes.
pub struct WaitFlush<S>(Option<Encoder<S>>, usize);
/// Writer for response body bytes that go straight to the underlying
/// stream (its `io::Write` impl writes through `WriteRaw::get_mut`).
/// Obtained by resolving a `FutureRawBody`.
pub struct RawBody<S> {
io: WriteRaw<S>,
}
impl<S> Encoder<S> {
    /// Writes an interim "continue" response into the output buffer
    /// (delegates to `MessageState::response_continue`).
    pub fn response_continue(&mut self) {
        self.state.response_continue(&mut self.io.out_buf)
    }

    /// Writes the status line for `status`, using its standard code and
    /// reason phrase.
    pub fn status(&mut self, status: Status) {
        self.state.response_status(&mut self.io.out_buf,
            status.code(), status.reason())
    }

    /// Writes a status line with an arbitrary numeric `code` and `reason`
    /// phrase.
    pub fn custom_status(&mut self, code: u16, reason: &str) {
        self.state.response_status(&mut self.io.out_buf, code, reason)
    }

    /// Adds the header `name: value`.
    ///
    /// # Errors
    /// Returns the `HeaderError` reported by `MessageState::add_header`.
    pub fn add_header<V: AsRef<[u8]>>(&mut self, name: &str, value: V)
        -> Result<(), HeaderError>
    {
        self.state.add_header(&mut self.io.out_buf, name, value.as_ref())
    }

    /// Adds a header whose value is `value` rendered through `Display`.
    ///
    /// # Errors
    /// Returns the `HeaderError` reported by `MessageState::format_header`.
    pub fn format_header<D: Display>(&mut self, name: &str, value: D)
        -> Result<(), HeaderError>
    {
        self.state.format_header(&mut self.io.out_buf, name, value)
    }

    /// Declares a body of exactly `n` bytes (presumably written as a
    /// `Content-Length` header — see `MessageState::add_length`).
    ///
    /// # Errors
    /// Returns the `HeaderError` reported by `MessageState::add_length`.
    pub fn add_length(&mut self, n: u64)
        -> Result<(), HeaderError>
    {
        self.state.add_length(&mut self.io.out_buf, n)
    }

    /// Declares a chunked body (see `MessageState::add_chunked`).
    ///
    /// # Errors
    /// Returns the `HeaderError` reported by `MessageState::add_chunked`.
    pub fn add_chunked(&mut self)
        -> Result<(), HeaderError>
    {
        self.state.add_chunked(&mut self.io.out_buf)
    }

    /// Adds a `Date` header holding the current system time.
    ///
    /// # Panics
    /// Panics if `format_header` rejects the header.
    #[cfg(feature="date_header")]
    pub fn add_date(&mut self) {
        use httpdate::HttpDate;
        use std::time::SystemTime;
        self.format_header("Date", HttpDate::from(SystemTime::now()))
            .expect("always valid to add a date")
    }

    /// Returns whether writing of the message has started
    /// (per `MessageState::is_started`).
    pub fn is_started(&self) -> bool {
        self.state.is_started()
    }

    /// Finishes the header section.
    ///
    /// The meaning of the returned `bool` is defined by
    /// `MessageState::done_headers`.
    /// NOTE(review): likely "a body is expected" — confirm there.
    ///
    /// # Errors
    /// Returns the `HeaderError` reported by `MessageState::done_headers`.
    pub fn done_headers(&mut self) -> Result<bool, HeaderError> {
        self.state.done_headers(&mut self.io.out_buf)
    }

    /// Appends `data` to the response body in the output buffer; the
    /// encoding is delegated to `MessageState::write_body`.
    pub fn write_body(&mut self, data: &[u8]) {
        self.state.write_body(&mut self.io.out_buf, data)
    }

    /// Returns whether the message is complete
    /// (per `MessageState::is_complete`).
    pub fn is_complete(&self) -> bool {
        self.state.is_complete()
    }

    /// Finalizes the message and returns the buffer wrapped in
    /// `EncoderDone`.
    pub fn done(mut self) -> EncoderDone<S> {
        self.state.done(&mut self.io.out_buf);
        EncoderDone { buf: self.io }
    }

    /// Switches to writing the body directly to the underlying stream.
    ///
    /// # Panics
    /// Panics unless the header section is finished
    /// (`MessageState::is_after_headers`).
    pub fn raw_body(self) -> FutureRawBody<S> {
        assert!(self.state.is_after_headers(),
            "raw_body() requires the headers to be finished \
             (call done_headers() first)");
        FutureRawBody(self.io.borrow_raw())
    }

    /// Writes as much buffered data to the underlying stream as it accepts
    /// (`WriteBuf::flush`).
    ///
    /// # Errors
    /// Propagates I/O errors from the stream.
    pub fn flush(&mut self) -> Result<(), io::Error>
        where S: AsyncWrite
    {
        self.io.flush()
    }

    /// Number of bytes sitting in the output buffer that have not yet been
    /// written to the stream.
    ///
    /// This is a read-only query, so it only needs `&self`.
    pub fn bytes_buffered(&self) -> usize {
        self.io.out_buf.len()
    }

    /// Returns a future that flushes this encoder and yields it back once
    /// the buffered byte count drops below `watermark`.
    pub fn wait_flush(self, watermark: usize) -> WaitFlush<S> {
        WaitFlush(Some(self), watermark)
    }
}
impl<S> RawBody<S> {
    /// Ends raw body output: converts the raw writer back into its buffered
    /// form (`WriteRaw::into_buf`) and wraps it as `EncoderDone`.
    pub fn done(self) -> EncoderDone<S> {
        let buffered = self.io.into_buf();
        EncoderDone { buf: buffered }
    }
}
impl<S> io::Write for Encoder<S> {
    /// Appends `buf` to the response body in the output buffer and reports
    /// the whole slice as accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let accepted = buf.len();
        self.write_body(buf);
        Ok(accepted)
    }

    /// Does nothing: this impl has no `S: AsyncWrite` bound, so it cannot
    /// touch the stream. Use the inherent `Encoder::flush` to push buffered
    /// bytes out.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// The encoder hands its buffer back through `done()`/`EncoderDone` rather
/// than ending the connection, so shutting the stream down through it is
/// treated as a programming error.
impl<S: AsyncWrite> AsyncWrite for Encoder<S> {
    /// # Panics
    /// Always panics.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        // This is the *response* encoder; the old message wrongly said
        // "request", which misdirected anyone debugging the panic.
        panic!("Can't shutdown response encoder");
    }
}
impl<S: AsyncWrite> io::Write for RawBody<S> {
    /// Writes `buf` directly to the underlying stream.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let stream = self.io.get_mut();
        stream.write(buf)
    }

    /// Flushes the underlying stream.
    fn flush(&mut self) -> io::Result<()> {
        let stream = self.io.get_mut();
        stream.flush()
    }
}
/// A raw body is returned through `RawBody::done`, not by shutting the
/// stream down, so calling `shutdown` is a programming error.
impl<S: AsyncWrite> AsyncWrite for RawBody<S> {
    /// # Panics
    /// Always panics.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        // This writes a *response* body; the old message said "request".
        panic!("Can't shutdown response body");
    }
}
pub fn get_inner<S>(e: EncoderDone<S>) -> WriteBuf<S> {
e.buf
}
pub fn new<S>(io: WriteBuf<S>, cfg: ResponseConfig) -> Encoder<S> {
use base_serializer::Body::*;
Encoder {
state: MessageState::ResponseStart {
body: if cfg.is_head { Head } else { Normal },
version: cfg.version,
close: cfg.do_close || cfg.version == Version::Http10,
},
io: io,
}
}
impl ResponseConfig {
    /// Builds response settings from the request head `req`.
    ///
    /// The response copies the request's version. It is treated as a HEAD
    /// response when the method is exactly `"HEAD"`, and requests a close
    /// whenever `req.connection_close()` is true.
    pub fn from(req: &Head) -> ResponseConfig {
        let req_version = req.version();
        let head_request = req.method() == "HEAD";
        let wants_close = req.connection_close();
        ResponseConfig {
            is_head: head_request,
            do_close: wants_close,
            version: req_version,
        }
    }
}
impl<S: AsyncWrite> Future for FutureRawBody<S> {
type Item = RawBody<S>;
type Error = io::Error;
fn poll(&mut self) -> Poll<RawBody<S>, io::Error> {
self.0.poll().map(|x| x.map(|y| RawBody { io: y }))
}
}
impl<S: AsyncWrite> Future for WaitFlush<S> {
type Item = Encoder<S>;
type Error = io::Error;
fn poll(&mut self) -> Result<Async<Encoder<S>>, io::Error> {
let bytes_left = {
let enc = self.0.as_mut().expect("future is polled twice");
enc.flush()?;
enc.io.out_buf.len()
};
if bytes_left < self.1 {
Ok(Async::Ready(self.0.take().unwrap()))
} else {
Ok(Async::NotReady)
}
}
}
#[cfg(feature="sendfile")]
mod sendfile {
    extern crate tk_sendfile;

    use std::io;
    use futures::{Async};
    use self::tk_sendfile::{Destination, FileOpener, Sendfile};
    use super::RawBody;

    /// Lets a `RawBody` act as a sendfile destination by forwarding both
    /// operations to the stream it wraps.
    impl<T: Destination> Destination for RawBody<T> {
        fn write_file<O: FileOpener>(&mut self, file: &mut Sendfile<O>)
            -> Result<usize, io::Error>
        {
            let stream = self.io.get_mut();
            stream.write_file(file)
        }

        fn poll_write(&self) -> Async<()> {
            let stream = self.io.get_ref();
            stream.poll_write()
        }
    }
}
#[cfg(test)]
mod test {
    use tk_bufstream::{MockData, IoBuf};
    // Only the feature-gated `date_header` test uses `Status`.
    #[cfg(feature="date_header")]
    use {Status};
    use base_serializer::{MessageState, Body};
    use super::{Encoder, EncoderDone};
    use enums::Version;

    /// Runs `fun` against a fresh HTTP/1.1 response encoder backed by mock
    /// I/O, flushes the finished buffer, and returns everything written as
    /// a (lossily decoded) string.
    // Without the `date_header` feature no test calls this helper yet;
    // allow dead code so such builds stay warning-free.
    #[allow(dead_code)]
    fn do_response11_str<F>(fun: F) -> String
        where F: FnOnce(Encoder<MockData>) -> EncoderDone<MockData>
    {
        let mock = MockData::new();
        let done = fun(Encoder {
            state: MessageState::ResponseStart {
                body: Body::Normal,
                version: Version::Http11,
                close: false,
            },
            io: IoBuf::new(mock.clone()).split().0,
        });
        {done}.buf.flush().unwrap();
        String::from_utf8_lossy(&mock.output(..)).to_string()
    }

    // `Encoder::add_date` only exists with the `date_header` feature, so the
    // test must be gated the same way or `cargo test` fails to compile
    // without it.
    #[cfg(feature="date_header")]
    #[test]
    fn date_header() {
        assert!(do_response11_str(|mut enc| {
            enc.status(Status::Ok);
            enc.add_date();
            enc.add_length(0).unwrap();
            enc.done_headers().unwrap();
            enc.done()
        }).starts_with("HTTP/1.1 200 OK\r\nDate: "));
    }
}