use std::default::Default;
use std::io::{self, copy, Read};
use std::iter::Extend;
use url::UrlParser;
use url::ParseError as UrlError;
use header::{Headers, Header, HeaderFormat};
use header::{ContentLength, Location};
use method::Method;
use net::{NetworkConnector, NetworkStream, ContextVerifier};
use status::StatusClass::Redirection;
use {Url};
use Error;
pub use self::pool::Pool;
pub use self::request::Request;
pub use self::response::Response;
pub mod pool;
pub mod request;
pub mod response;
pub struct Client {
connector: Connector,
redirect_policy: RedirectPolicy,
}
impl Client {
pub fn new() -> Client {
Client::with_pool_config(Default::default())
}
pub fn with_pool_config(config: pool::Config) -> Client {
Client::with_connector(Pool::new(config))
}
pub fn with_connector<C, S>(connector: C) -> Client
where C: NetworkConnector<Stream=S> + Send + 'static, S: NetworkStream + Send {
Client {
connector: with_connector(connector),
redirect_policy: Default::default()
}
}
pub fn set_ssl_verifier(&mut self, verifier: ContextVerifier) {
self.connector.set_ssl_verifier(verifier);
}
pub fn set_redirect_policy(&mut self, policy: RedirectPolicy) {
self.redirect_policy = policy;
}
pub fn get<U: IntoUrl>(&mut self, url: U) -> RequestBuilder<U> {
self.request(Method::Get, url)
}
pub fn head<U: IntoUrl>(&mut self, url: U) -> RequestBuilder<U> {
self.request(Method::Head, url)
}
pub fn post<U: IntoUrl>(&mut self, url: U) -> RequestBuilder<U> {
self.request(Method::Post, url)
}
pub fn put<U: IntoUrl>(&mut self, url: U) -> RequestBuilder<U> {
self.request(Method::Put, url)
}
pub fn delete<U: IntoUrl>(&mut self, url: U) -> RequestBuilder<U> {
self.request(Method::Delete, url)
}
pub fn request<U: IntoUrl>(&mut self, method: Method, url: U) -> RequestBuilder<U> {
RequestBuilder {
client: self,
method: method,
url: url,
body: None,
headers: None,
}
}
}
/// Erase the concrete type of `connector`.
///
/// The connector is wrapped in `ConnAdapter`, which boxes each stream it
/// produces, and then placed in a `Connector` trait object.
fn with_connector<C, S>(connector: C) -> Connector
    where C: NetworkConnector<Stream=S> + Send + 'static,
          S: NetworkStream + Send
{
    let adapter = ConnAdapter(connector);
    Connector(Box::new(adapter))
}
impl Default for Client {
    /// Same as `Client::new()`.
    fn default() -> Client {
        Client::new()
    }
}
/// Adapts a connector with a concrete stream type `S` into one whose streams
/// are `Box<NetworkStream + Send>`, so it can be stored inside `Connector`.
struct ConnAdapter<C: NetworkConnector + Send>(C);

impl<C, S> NetworkConnector for ConnAdapter<C>
    where C: NetworkConnector<Stream=S> + Send,
          S: NetworkStream + Send
{
    type Stream = Box<NetworkStream + Send>;

    /// Connect with the wrapped connector and box the resulting stream.
    #[inline]
    fn connect(&self, host: &str, port: u16, scheme: &str)
        -> ::Result<Box<NetworkStream + Send>> {
        let stream = try!(self.0.connect(host, port, scheme));
        Ok(stream.into())
    }

    /// Pass the verifier straight through to the wrapped connector.
    #[inline]
    fn set_ssl_verifier(&mut self, verifier: ContextVerifier) {
        self.0.set_ssl_verifier(verifier)
    }
}
/// Type-erased connector held by `Client`; all streams it yields are boxed.
struct Connector(Box<NetworkConnector<Stream=Box<NetworkStream + Send>> + Send>);

impl NetworkConnector for Connector {
    type Stream = Box<NetworkStream + Send>;

    /// Delegate to the boxed connector.
    #[inline]
    fn connect(&self, host: &str, port: u16, scheme: &str)
        -> ::Result<Box<NetworkStream + Send>> {
        // The inner connector already yields the boxed stream type, so its
        // result can be returned as-is.
        self.0.connect(host, port, scheme)
    }

    /// Delegate to the boxed connector.
    #[inline]
    fn set_ssl_verifier(&mut self, verifier: ContextVerifier) {
        self.0.set_ssl_verifier(verifier)
    }
}
/// A builder for a single request issued through a `Client`.
///
/// Created by `Client::request` and its `get`/`head`/`post`/`put`/`delete`
/// shorthands. Nothing is sent until `send` is called.
pub struct RequestBuilder<'a, U: IntoUrl> {
    client: &'a Client,
    url: U,
    headers: Option<Headers>,
    method: Method,
    body: Option<Body<'a>>,
}

impl<'a, U: IntoUrl> RequestBuilder<'a, U> {
    /// Set the request body, replacing any body set earlier.
    ///
    /// `send` drops the body for `GET` and `HEAD` requests.
    pub fn body<B: Into<Body<'a>>>(mut self, body: B) -> RequestBuilder<'a, U> {
        self.body = Some(body.into());
        self
    }

    /// Replace all custom headers with `headers`.
    pub fn headers(mut self, headers: Headers) -> RequestBuilder<'a, U> {
        self.headers = Some(headers);
        self
    }

    /// Set a single header, creating the header collection on first use.
    pub fn header<H: Header + HeaderFormat>(mut self, header: H) -> RequestBuilder<'a, U> {
        let mut headers = self.headers.take().unwrap_or_else(Headers::new);
        headers.set(header);
        self.headers = Some(headers);
        self
    }

    /// Send the request, following redirects as permitted by the client's
    /// `RedirectPolicy`, and return the final response.
    ///
    /// A redirect response is returned as-is when any of these hold:
    /// - it has no `Location` header;
    /// - the `Location` value does not parse relative to the current URL;
    /// - the policy declines to follow it.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the URL conversion fails;
    /// - creating, starting, or sending a request fails;
    /// - writing the request body to the connection fails.
    pub fn send(self) -> ::Result<Response> {
        let RequestBuilder { client, method, url, headers, body } = self;
        let mut url = try!(url.into_url());
        trace!("send {:?} {:?}", method, url);

        // GET and HEAD never carry a body; any body supplied is dropped.
        let can_have_body = match &method {
            &Method::Get | &Method::Head => false,
            _ => true
        };
        let mut body = if can_have_body { body } else { None };

        // NOTE(review): no cap on the number of redirects followed; a redirect
        // cycle (A -> B -> A) loops forever under `FollowAll`.
        loop {
            let mut req = try!(Request::with_connector(method.clone(), url.clone(), &client.connector));
            if let Some(ref headers) = headers {
                req.headers_mut().extend(headers.iter());
            }
            match (can_have_body, body.as_ref()) {
                (true, Some(body)) => match body.size() {
                    Some(size) => req.headers_mut().set(ContentLength(size)),
                    None => (),
                },
                // Advertise a zero length when no body is present. This also
                // applies to re-sends after a redirect, because the body reader
                // is consumed by the first attempt.
                (true, None) => req.headers_mut().set(ContentLength(0)),
                _ => ()
            }
            let mut streaming = try!(req.start());
            // The body is written at most once (`take`). A failed write is
            // propagated instead of being discarded. Discarding it would let a
            // truncated body go out while we still wait for a response.
            if let Some(mut rdr) = body.take() {
                try!(copy(&mut rdr, &mut streaming));
            }
            let res = try!(streaming.send());
            if res.status.class() != Redirection {
                return Ok(res)
            }

            debug!("redirect code {:?} for {}", res.status, url);
            // Two steps so the borrow of `res.headers` ends before `res` can be
            // returned by value.
            let loc = {
                let loc = match res.headers.get::<Location>() {
                    Some(&Location(ref loc)) => {
                        Some(UrlParser::new().base_url(&url).parse(&loc[..]))
                    }
                    None => {
                        debug!("no Location header");
                        None
                    }
                };
                match loc {
                    Some(r) => r,
                    None => return Ok(res)
                }
            };
            url = match loc {
                Ok(u) => u,
                Err(e) => {
                    debug!("Location header had invalid URI: {:?}", e);
                    return Ok(res);
                }
            };
            match client.redirect_policy {
                RedirectPolicy::FollowAll => (),
                RedirectPolicy::FollowIf(cond) if cond(&url) => (),
                _ => return Ok(res),
            }
        }
    }
}
/// The body of an outgoing request.
pub enum Body<'a> {
    /// A reader of unknown length. `size()` is `None`, so no Content-Length
    /// is set for it.
    ChunkedBody(&'a mut (Read + 'a)),
    /// A reader together with the number of bytes it will yield.
    SizedBody(&'a mut (Read + 'a), u64),
    /// An in-memory byte slice together with its length.
    BufBody(&'a [u8] , usize),
}

impl<'a> Body<'a> {
    /// Number of bytes this body will produce, when known in advance.
    fn size(&self) -> Option<u64> {
        match *self {
            Body::ChunkedBody(_) => None,
            Body::SizedBody(_, declared) => Some(declared),
            Body::BufBody(_, byte_len) => Some(byte_len as u64),
        }
    }
}
impl<'a> Read for Body<'a> {
    /// Read from whichever source backs this body.
    ///
    /// For `BufBody` this goes through `&[u8]`'s `Read` impl, which advances
    /// the stored slice past the bytes read.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match *self {
            Body::ChunkedBody(ref mut reader) |
            Body::SizedBody(ref mut reader, _) => reader.read(buf),
            Body::BufBody(ref mut bytes, _) => Read::read(bytes, buf),
        }
    }
}
// Conversions into `Body`. These implement `From` rather than `Into`: the
// standard blanket impl derives `Into<Body>` from each of them, so `.into()`
// callers (e.g. `RequestBuilder::body`) keep working, and `Body::from(..)`
// becomes available too.

/// An in-memory byte slice becomes a `BufBody` of known length.
impl<'a> From<&'a [u8]> for Body<'a> {
    #[inline]
    fn from(bytes: &'a [u8]) -> Body<'a> {
        Body::BufBody(bytes, bytes.len())
    }
}

/// A string slice is sent as its UTF-8 bytes.
impl<'a> From<&'a str> for Body<'a> {
    #[inline]
    fn from(text: &'a str) -> Body<'a> {
        Body::from(text.as_bytes())
    }
}

/// A borrowed `String` is sent as its UTF-8 bytes.
impl<'a> From<&'a String> for Body<'a> {
    #[inline]
    fn from(text: &'a String) -> Body<'a> {
        Body::from(text.as_bytes())
    }
}

/// An arbitrary reader has no known length and becomes a `ChunkedBody`.
impl<'a, R: Read> From<&'a mut R> for Body<'a> {
    #[inline]
    fn from(r: &'a mut R) -> Body<'a> {
        Body::ChunkedBody(r)
    }
}
/// Conversion into a `Url`.
///
/// Lets request-building methods accept either an already-parsed `Url` or a
/// string that is parsed on demand.
pub trait IntoUrl {
    /// Consume `self` and produce a `Url`, or the parse error if it is invalid.
    fn into_url(self) -> Result<Url, UrlError>;
}

impl IntoUrl for Url {
    /// Already a `Url`: returned unchanged.
    fn into_url(self) -> Result<Url, UrlError> {
        Ok(self)
    }
}

impl<'a> IntoUrl for &'a str {
    /// Parse the string as an absolute URL.
    fn into_url(self) -> Result<Url, UrlError> {
        Url::parse(self)
    }
}

impl<'a> IntoUrl for &'a String {
    /// Parse the string's contents; delegates to the `&str` impl.
    fn into_url(self) -> Result<Url, UrlError> {
        (&self[..]).into_url()
    }
}
/// What `RequestBuilder::send` does when it receives a 3xx (Redirection) response.
pub enum RedirectPolicy {
    /// Never follow; the redirect response itself is returned.
    FollowNone,
    /// Follow every redirect whose `Location` header parses as a URL.
    FollowAll,
    /// Follow only when the predicate returns `true` for the resolved target URL.
    FollowIf(fn(&Url) -> bool),
}

impl Copy for RedirectPolicy {}

// Clone is written by hand rather than derived. NOTE(review): presumably
// because `#[derive(Clone)]` did not support the higher-ranked `fn(&Url)`
// pointer on the compiler this targets. Confirm before switching to a derive.
impl Clone for RedirectPolicy {
    fn clone(&self) -> RedirectPolicy {
        *self
    }
}

impl Default for RedirectPolicy {
    /// The default is `FollowAll`.
    fn default() -> RedirectPolicy {
        RedirectPolicy::FollowAll
    }
}
fn get_host_and_port(url: &Url) -> ::Result<(String, u16)> {
let host = match url.serialize_host() {
Some(host) => host,
None => return Err(Error::Uri(UrlError::EmptyHost))
};
trace!("host={:?}", host);
let port = match url.port_or_default() {
Some(port) => port,
None => return Err(Error::Uri(UrlError::InvalidPort))
};
trace!("port={:?}", port);
Ok((host, port))
}
#[cfg(test)]
mod tests {
    use header::Server;
    use super::{Client, RedirectPolicy};
    use url::Url;
    use mock::ChannelMockConnector;
    use std::sync::mpsc::{self, TryRecvError};

    // Canned responses keyed by URL, forming the redirect chain
    //   http://127.0.0.1 --301--> http://127.0.0.2 --302--> https://127.0.0.3 (200).
    // Each response carries a distinct `Server` header (mock1/mock2/mock3), so
    // the tests can tell which hop the client stopped at.
    mock_connector!(MockRedirectPolicy {
        "http://127.0.0.1" => "HTTP/1.1 301 Redirect\r\n\
                              Location: http://127.0.0.2\r\n\
                              Server: mock1\r\n\
                              \r\n\
                             "
        "http://127.0.0.2" => "HTTP/1.1 302 Found\r\n\
                              Location: https://127.0.0.3\r\n\
                              Server: mock2\r\n\
                              \r\n\
                             "
        "https://127.0.0.3" => "HTTP/1.1 200 OK\r\n\
                               Server: mock3\r\n\
                               \r\n\
                              "
    });

    /// `FollowAll` follows both redirects and ends on the final 200 response.
    #[test]
    fn test_redirect_followall() {
        let mut client = Client::with_connector(MockRedirectPolicy);
        client.set_redirect_policy(RedirectPolicy::FollowAll);
        let res = client.get("http://127.0.0.1").send().unwrap();
        assert_eq!(res.headers.get(), Some(&Server("mock3".to_owned())));
    }

    /// `FollowNone` returns the first (301) response unchanged.
    #[test]
    fn test_redirect_dontfollow() {
        let mut client = Client::with_connector(MockRedirectPolicy);
        client.set_redirect_policy(RedirectPolicy::FollowNone);
        let res = client.get("http://127.0.0.1").send().unwrap();
        assert_eq!(res.headers.get(), Some(&Server("mock1".to_owned())));
    }

    /// `FollowIf` follows the first hop but refuses the one to 127.0.0.3,
    /// so the 302 response from the second hop is returned.
    #[test]
    fn test_redirect_followif() {
        fn follow_if(url: &Url) -> bool {
            !url.serialize().contains("127.0.0.3")
        }
        let mut client = Client::with_connector(MockRedirectPolicy);
        client.set_redirect_policy(RedirectPolicy::FollowIf(follow_if));
        let res = client.get("http://127.0.0.1").send().unwrap();
        assert_eq!(res.headers.get(), Some(&Server("mock2".to_owned())));
    }

    /// `Client::set_ssl_verifier` reaches the underlying connector exactly
    /// once, and the connector's channel is still connected afterwards.
    #[test]
    fn test_client_set_ssl_verifer() {
        let (tx, rx) = mpsc::channel();
        let mut client = Client::with_connector(ChannelMockConnector::new(tx));
        client.set_ssl_verifier(Box::new(|_| {}));
        // First message must record the forwarded call.
        match rx.try_recv() {
            Ok(meth) => {
                assert_eq!(meth, "set_ssl_verifier");
            },
            _ => panic!("Expected a call to `set_ssl_verifier`"),
        };
        // No further calls, and the sender (held by the connector) is still alive.
        match rx.try_recv() {
            Err(TryRecvError::Empty) => {},
            Err(TryRecvError::Disconnected) => {
                panic!("Expected the connector to still be alive.");
            },
            Ok(_) => panic!("Did not expect any more method calls."),
        };
    }
}