use std::io;
use std::fmt::Display;
#[allow(unused_imports)]
use std::ascii::AsciiExt;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, AtomicBool, Ordering};
use tk_bufstream::WriteBuf;
use futures::{Future, Async};
use tokio_io::AsyncWrite;
use enums::Version;
use headers::is_close;
use base_serializer::{MessageState, HeaderError};
/// Request progress, stored as its `usize` discriminant in the
/// `Arc<AtomicUsize>` that the encoder shares with its creator.
///
/// `Encoder::request_line` expects the value to be `Empty` and replaces it
/// with `StartedHead` for a `HEAD` request or `StartedNormal` otherwise.
pub enum RequestState {
    /// Expected value before a request line is written.
    Empty = 0,
    /// A request line with method `HEAD` has been written.
    StartedHead = 1,
    /// A request line with any method other than `HEAD` has been written.
    StartedNormal = 2,
}
/// Serializer for one outgoing HTTP request.
///
/// The request line, headers and body are written into `buf.out_buf`
/// through the `MessageState` serializer.
pub struct Encoder<S> {
    /// Serializer state machine for the message being written.
    message: MessageState,
    /// Buffered writer wrapping the transport `S`.
    buf: WriteBuf<S>,
    /// Shared request progress; holds a `RequestState` discriminant.
    state: Arc<AtomicUsize>,
    /// Shared flag raised when a `Connection` header asks to close.
    close_signal: Arc<AtomicBool>,
}
/// Token returned by `Encoder::done` after the message has been finished.
///
/// It owns the underlying buffered writer, which `get_inner` hands back.
pub struct EncoderDone<S> {
    /// The writer the finished request was serialized into.
    buf: WriteBuf<S>,
}
/// Future returned by `Encoder::wait_flush`.
pub struct WaitFlush<S>(
    // The encoder; taken (left as `None`) once the future resolves.
    Option<Encoder<S>>,
    // Flush watermark, in bytes of buffered output.
    usize,
);
pub fn get_inner<S>(e: EncoderDone<S>) -> WriteBuf<S> {
e.buf
}
impl<S> Encoder<S> {
    /// Writes the request line (`method`, `path`, `version`) into the
    /// output buffer and records the request kind in the shared `state`.
    ///
    /// The shared state becomes `RequestState::StartedHead` when `method`
    /// equals `HEAD` (ASCII case-insensitive) and
    /// `RequestState::StartedNormal` otherwise.
    ///
    /// # Panics
    /// Panics if the shared state is not `RequestState::Empty`. In that case
    /// the state is left untouched and nothing is written to the buffer.
    pub fn request_line(&mut self, method: &str, path: &str, version: Version)
    {
        let nstatus = if method.eq_ignore_ascii_case("HEAD") {
            RequestState::StartedHead as usize
        } else {
            RequestState::StartedNormal as usize
        };
        // Claim the state *before* writing. A misuse then neither clobbers
        // the shared value nor leaves a stray request line in the buffer.
        if self.state.compare_exchange(RequestState::Empty as usize, nstatus,
            Ordering::SeqCst, Ordering::SeqCst).is_err()
        {
            panic!("Request line in wrong state");
        }
        self.message.request_line(&mut self.buf.out_buf,
            method, path, version);
    }

    /// Adds a header with a raw byte value.
    ///
    /// If `name` is `Connection` (ASCII case-insensitive) and `is_close`
    /// accepts the value, the shared `close_signal` flag is set.
    ///
    /// # Errors
    /// Propagates any `HeaderError` from `MessageState::add_header`.
    pub fn add_header<V: AsRef<[u8]>>(&mut self, name: &str, value: V)
        -> Result<(), HeaderError>
    {
        if name.eq_ignore_ascii_case("Connection") && is_close(value.as_ref())
        {
            self.close_signal.store(true, Ordering::SeqCst);
        }
        self.message.add_header(&mut self.buf.out_buf, name, value.as_ref())
    }

    /// Adds a header whose value is produced by a `Display` implementation.
    ///
    /// A `Connection` header (ASCII case-insensitive) is rendered to a
    /// string first and passed to `add_header`, so that `close_signal`
    /// detection works the same as for raw values.
    ///
    /// # Errors
    /// Propagates any `HeaderError` from the underlying serializer.
    pub fn format_header<D: Display>(&mut self, name: &str, value: D)
        -> Result<(), HeaderError>
    {
        if name.eq_ignore_ascii_case("Connection") {
            // The value has to be inspected for `close`, so it cannot be
            // streamed straight into the buffer; render it once and reuse
            // the raw-value path, which performs that check.
            let rendered = value.to_string();
            return self.add_header(name, rendered);
        }
        self.message.format_header(&mut self.buf.out_buf, name, value)
    }

    /// Adds a `Content-Length: n` body-length header via the serializer.
    ///
    /// # Errors
    /// Propagates any `HeaderError` from `MessageState::add_length`.
    pub fn add_length(&mut self, n: u64)
        -> Result<(), HeaderError>
    {
        self.message.add_length(&mut self.buf.out_buf, n)
    }

    /// Selects chunked transfer encoding via the serializer.
    ///
    /// # Errors
    /// Propagates any `HeaderError` from `MessageState::add_chunked`.
    pub fn add_chunked(&mut self)
        -> Result<(), HeaderError>
    {
        self.message.add_chunked(&mut self.buf.out_buf)
    }

    /// Terminates the header section.
    ///
    /// # Errors
    /// Propagates any `HeaderError` from `MessageState::done_headers`.
    ///
    /// # Panics
    /// Asserts that the serializer reports the message as able to carry a
    /// body (the `bool` it returns), which a request is expected to be.
    pub fn done_headers(&mut self) -> Result<(), HeaderError> {
        self.message.done_headers(&mut self.buf.out_buf)
            .map(|always_support_body| assert!(always_support_body))
    }

    /// Appends body bytes through `MessageState::write_body`.
    pub fn write_body(&mut self, data: &[u8]) {
        self.message.write_body(&mut self.buf.out_buf, data)
    }

    /// Finishes the message and returns the writer wrapped in `EncoderDone`.
    pub fn done(mut self) -> EncoderDone<S> {
        self.message.done(&mut self.buf.out_buf);
        EncoderDone { buf: self.buf }
    }

    /// Delegates to `WriteBuf::flush` to push buffered bytes to `S`.
    ///
    /// # Errors
    /// Returns any I/O error reported by the underlying writer.
    pub fn flush(&mut self) -> Result<(), io::Error>
        where S: AsyncWrite
    {
        self.buf.flush()
    }

    /// Number of bytes currently held in the output buffer.
    pub fn bytes_buffered(&self) -> usize {
        self.buf.out_buf.len()
    }

    /// Converts the encoder into a future that keeps flushing and yields the
    /// encoder back once fewer than `watermark` bytes remain buffered.
    pub fn wait_flush(self, watermark: usize) -> WaitFlush<S> {
        WaitFlush(Some(self), watermark)
    }
}
impl<S: AsyncWrite> Future for WaitFlush<S> {
    type Item = Encoder<S>;
    type Error = io::Error;
    /// Flushes the encoder and resolves with it once the buffered byte count
    /// drops below the watermark, or once the buffer is completely empty.
    ///
    /// The empty-buffer condition makes a watermark of `0` mean "wait for a
    /// full flush". Without it, `bytes_left < 0` can never hold for a
    /// `usize`, and the future would never resolve.
    ///
    /// # Errors
    /// Returns any I/O error produced while flushing.
    ///
    /// # Panics
    /// Panics if polled again after it has already resolved.
    fn poll(&mut self) -> Result<Async<Encoder<S>>, io::Error> {
        let watermark = self.1;
        let bytes_left = {
            let enc = self.0.as_mut().expect("future is polled twice");
            enc.flush()?;
            enc.buf.out_buf.len()
        };
        if bytes_left == 0 || bytes_left < watermark {
            Ok(Async::Ready(self.0.take().unwrap()))
        } else {
            Ok(Async::NotReady)
        }
    }
}
/// Builds a request `Encoder` that serializes into `io`.
///
/// `state` is the shared `RequestState` slot that `Encoder::request_line`
/// updates. `close_signal` is the shared flag raised when a `Connection`
/// header value is recognized by `is_close`. The serializer starts in
/// `MessageState::RequestStart`.
pub fn new<S>(io: WriteBuf<S>,
    state: Arc<AtomicUsize>, close_signal: Arc<AtomicBool>)
    -> Encoder<S>
{
    let message = MessageState::RequestStart;
    Encoder {
        message,
        buf: io,
        state,
        close_signal,
    }
}
impl<S> io::Write for Encoder<S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.write_body(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}