use std::time::{Duration, Instant};
use std::{fmt, io, net, thread};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, BufReader, BufWriter};
use tokio::net::{TcpStream, ToSocketAddrs};
use base64;
use serde;
use serde_json;
use super::{Request, Response};
use crate::client::Transport;
/// Default TCP port: used by `SimpleHttpTransport::default()` and by `Builder::url` when the
/// URL has neither an explicit port nor an `http`/`https` scheme.
pub const DEFAULT_PORT: u16 = 8332;
/// A minimal JSON-RPC transport that sends each request as an HTTP/1.1 `POST` over a
/// freshly opened TCP connection.
#[derive(Clone, Debug)]
pub struct SimpleHttpTransport {
    /// Already-resolved socket address of the server.
    addr: net::SocketAddr,
    /// Request target written into the `POST` line (e.g. `/` or `/wallet/foo`).
    path: String,
    /// Connect timeout, and also the overall deadline for reading the response.
    timeout: Duration,
    /// Complete `Authorization` header value (e.g. `Basic <base64>`), if any.
    basic_auth: Option<String>,
}
impl Default for SimpleHttpTransport {
    /// Targets `127.0.0.1:DEFAULT_PORT` with path `/`, a 15 second timeout and no auth.
    fn default() -> Self {
        let loopback = net::Ipv4Addr::new(127, 0, 0, 1);
        Self {
            addr: net::SocketAddr::from((loopback, DEFAULT_PORT)),
            path: String::from("/"),
            timeout: Duration::from_secs(15),
            basic_auth: None,
        }
    }
}
use tokio::io::{AsyncReadExt, AsyncWriteExt};
impl SimpleHttpTransport {
pub fn new() -> Self {
SimpleHttpTransport::default()
}
pub fn builder() -> Builder {
Builder::new()
}
async fn request<R>(&self, req: impl serde::Serialize) -> Result<R, Error>
where
R: for<'a> serde::de::Deserialize<'a>,
{
let request_deadline = Instant::now() + self.timeout;
let sock = tokio::time::timeout(self.timeout, TcpStream::connect(self.addr)).await??;
let (read, write) = sock.into_split();
let mut writer = BufWriter::new(write);
let body = serde_json::to_vec(&req)?;
writer.write_all(b"POST ").await?;
writer.write_all(self.path.as_bytes()).await?;
writer.write_all(b" HTTP/1.1\r\n").await?;
writer.write_all(b"Content-Type: application/json-rpc\r\n").await?;
writer.write_all(b"Content-Length: ").await?;
writer.write_all(body.len().to_string().as_bytes()).await?;
writer.write_all(b"\r\n").await?;
if let Some(ref auth) = self.basic_auth {
writer.write_all(b"Authorization: ").await?;
writer.write_all(auth.as_ref()).await?;
writer.write_all(b"\r\n").await?;
}
writer.write_all(b"\r\n").await?;
writer.write_all(&body).await?;
writer.flush().await?;
let mut reader = BufReader::new(read);
let http_response = get_line(&mut reader, request_deadline).await?;
if http_response.len() < 12 || !http_response.starts_with("HTTP/1.1 ") {
return Err(Error::HttpParseError);
}
let response_code = match http_response[9..12].parse::<u16>() {
Ok(n) => n,
Err(_) => return Err(Error::HttpParseError),
};
while get_line(&mut reader, request_deadline).await? != "\r\n" {}
let resp_body = get_line(&mut reader, request_deadline).await?;
match serde_json::from_str(&resp_body) {
Ok(s) => Ok(s),
Err(e) => {
if response_code != 200 {
Err(Error::HttpErrorCode(response_code))
} else {
Err(e.into())
}
}
}
}
}
/// Errors produced by the simple HTTP transport.
#[derive(Debug)]
pub enum Error {
    /// The URL given to `Builder::url` was rejected.
    InvalidUrl {
        /// The URL as supplied by the caller.
        url: String,
        /// Why it was rejected.
        reason: &'static str,
    },
    /// An I/O error while resolving, connecting, writing or reading.
    SocketError(io::Error),
    /// The HTTP response could not be parsed.
    HttpParseError,
    /// The server returned a non-200 status and a body that did not deserialize.
    HttpErrorCode(u16),
    /// The operation did not complete within the configured timeout.
    Timeout,
    /// JSON serialization or deserialization failed.
    Json(serde_json::Error),
}

impl Error {
    /// Shorthand for building an [`Error::InvalidUrl`].
    fn url<U: Into<String>>(url: U, reason: &'static str) -> Error {
        Error::InvalidUrl {
            url: url.into(),
            reason,
        }
    }
}

impl std::error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        match self {
            Error::InvalidUrl { url, reason } => write!(f, "invalid URL '{}': {}", url, reason),
            Error::SocketError(e) => write!(f, "Couldn't connect to host: {}", e),
            Error::HttpParseError => f.write_str("Couldn't parse response header."),
            Error::HttpErrorCode(c) => write!(f, "unexpected HTTP code: {}", c),
            Error::Timeout => f.write_str("Didn't receive response data in time, timed out."),
            Error::Json(e) => write!(f, "JSON error: {}", e),
        }
    }
}
use tokio::time::error::Elapsed;
impl From<Elapsed> for Error {
fn from(e: Elapsed) -> Error {
Error::Timeout
}
}
impl From<io::Error> for Error {
fn from(e: io::Error) -> Self {
Error::SocketError(e)
}
}
impl From<serde_json::Error> for Error {
fn from(e: serde_json::Error) -> Self {
Error::Json(e)
}
}
impl From<Error> for crate::Error {
fn from(e: Error) -> Self {
match e {
Error::Json(e) => crate::Error::Json(e),
e => crate::Error::Transport(Box::new(e)),
}
}
}
async fn get_line<R: AsyncBufRead + Unpin>(
reader: &mut R,
deadline: Instant,
) -> Result<String, Error> {
let mut line = String::new();
while deadline > Instant::now() {
match reader.read_line(&mut line).await {
Ok(0) => thread::sleep(Duration::from_millis(5)),
Ok(_) => return Ok(line),
Err(e) => return Err(Error::SocketError(e)),
}
}
Err(Error::Timeout)
}
use async_trait::async_trait;
#[async_trait]
impl Transport for SimpleHttpTransport {
    /// Sends a single JSON-RPC request and returns its response.
    async fn send_request(&self, req: Request<'_>) -> Result<Response, crate::Error> {
        Ok(self.request(req).await?)
    }

    /// Sends all `reqs` as one JSON array in a single POST and returns the responses.
    async fn send_batch(&self, reqs: &[Request<'_>]) -> Result<Vec<Response>, crate::Error> {
        Ok(self.request(reqs).await?)
    }

    /// Writes the target as `http://<addr><path>`.
    ///
    /// `SocketAddr`'s `Display` renders IPv6 as `[ip]:port`; formatting `ip()` and `port()`
    /// separately produced ambiguous output such as `http://::1:8332/`.
    fn fmt_target(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "http://{}{}", self.addr, self.path)
    }
}
#[derive(Clone, Debug)]
pub struct Builder {
tp: SimpleHttpTransport,
}
impl Builder {
pub fn new() -> Builder {
Builder {
tp: SimpleHttpTransport::new(),
}
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.tp.timeout = timeout;
self
}
pub async fn url(mut self, url: &str) -> Result<Self, Error> {
let mut fallback_port = DEFAULT_PORT;
let after_scheme = {
let mut split = url.splitn(2, "://");
let s = split.next().unwrap();
match split.next() {
None => s,
Some(after) => {
if s == "http" {
fallback_port = 80;
} else if s == "https" {
fallback_port = 443;
} else {
return Err(Error::url(url, "scheme schould be http or https"));
}
after
}
}
};
let (before_path, path) = {
if let Some(slash) = after_scheme.find("/") {
(&after_scheme[0..slash], &after_scheme[slash..])
} else {
(after_scheme, "/")
}
};
let after_auth = {
let mut split = before_path.splitn(2, "@");
let s = split.next().unwrap();
split.next().unwrap_or(s)
};
let mut split = after_auth.split(":");
let hostname = split.next().unwrap();
let port: u16 = match split.next() {
Some(port_str) => match port_str.parse() {
Ok(port) => port,
Err(_) => return Err(Error::url(url, "invalid port")),
},
None => fallback_port,
};
if split.next().is_some() {
return Err(Error::url(url, "unexpected extra colon"));
}
self.tp.addr = match tokio::net::lookup_host((hostname, port)).await?.next() {
Some(a) => a,
None => {
return Err(Error::url(url, "invalid hostname: error extracting socket address"))
}
};
self.tp.path = path.to_owned();
Ok(self)
}
pub fn auth<S: AsRef<str>>(mut self, user: S, pass: Option<S>) -> Self {
let mut auth = user.as_ref().to_owned();
auth.push(':');
if let Some(ref pass) = pass {
auth.push_str(&pass.as_ref()[..]);
}
self.tp.basic_auth = Some(format!("Basic {}", &base64::encode(auth.as_bytes())));
self
}
pub fn cookie_auth<S: AsRef<str>>(mut self, cookie: S) -> Self {
self.tp.basic_auth = Some(format!("Basic {}", &base64::encode(cookie.as_ref().as_bytes())));
self
}
pub fn build(self) -> SimpleHttpTransport {
self.tp
}
}
use crate::client::Client;
impl Client {
    /// Builds a `Client` over a `SimpleHttpTransport` pointed at `url`.
    ///
    /// Basic auth is configured only when `user` is `Some`; a `pass` without a `user` is
    /// ignored.
    pub async fn simple_http(
        url: &str,
        user: Option<String>,
        pass: Option<String>,
    ) -> Result<Client, Error> {
        let builder = Builder::new().url(url).await?;
        let builder = match user {
            Some(user) => builder.auth(user, pass),
            None => builder,
        };
        Ok(Client::with_transport(builder.build()))
    }
}
#[cfg(test)]
mod tests {
    use std::net;

    // `super::*` also brings `Client`, `Builder`, `Duration` and `DEFAULT_PORT` into scope,
    // so no separate `use Client;` is needed.
    use super::*;

    /// Resolves `localhost:port` the same way `Builder::url` does (first resolved address).
    async fn localhost(port: u16) -> net::SocketAddr {
        tokio::net::lookup_host(("localhost", port)).await.unwrap().next().unwrap()
    }

    #[tokio::test]
    async fn test_urls() {
        let addr = localhost(22).await;
        let urls = [
            "localhost:22",
            "http://localhost:22/",
            "https://localhost:22/walletname/stuff?it=working",
            "http://me:weak@localhost:22/wallet",
        ];
        for u in &urls {
            let tp = Builder::new().url(*u).await.unwrap().build();
            assert_eq!(tp.addr, addr);
        }

        // Scheme-dependent default ports.
        let tp = Builder::new().url("http://localhost/").await.unwrap().build();
        assert_eq!(tp.addr, localhost(80).await);
        let tp = Builder::new().url("https://localhost/").await.unwrap().build();
        assert_eq!(tp.addr, localhost(443).await);
        let tp = Builder::new().url("localhost").await.unwrap().build();
        assert_eq!(tp.addr, localhost(super::DEFAULT_PORT).await);

        let valid_urls = [
            "localhost",
            "127.0.0.1:8080",
            "http://127.0.0.1:8080/",
            "http://127.0.0.1:8080/rpc/test",
            "https://127.0.0.1/rpc/test",
        ];
        for u in &valid_urls {
            if let Err(e) = Builder::new().url(*u).await {
                panic!("error for: {}: {}", u, e);
            }
        }

        let invalid_urls = [
            "127.0.0.1.0:8080",
            "httpx://127.0.0.1:8080/",
            "ftp://127.0.0.1:8080/rpc/test",
            "http://127.0.0./rpc/test",
        ];
        for u in &invalid_urls {
            if let Ok(b) = Builder::new().url(*u).await {
                let tp = b.build();
                panic!("expected error for url {}, got {:?}", u, tp);
            }
        }
    }

    #[tokio::test]
    async fn construct() {
        let tp = Builder::new()
            .timeout(Duration::from_millis(100))
            .url("localhost:22")
            .await
            .unwrap()
            .auth("user", None)
            .build();
        let _ = Client::with_transport(tp);
        let _ = Client::simple_http("localhost:22", None, None).await.unwrap();
    }
}