1use std::time::{Duration, Instant};
6use std::{fmt, io, net, thread};
7use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, BufReader, BufWriter};
8use tokio::net::{TcpStream, ToSocketAddrs};
9
10use base64;
11use serde;
12use serde_json;
13
14use super::{Request, Response};
15use crate::client::Transport;
16
/// Port used when a URL has neither a scheme nor an explicit port, and by
/// [`SimpleHttpTransport::default`].
///
/// NOTE(review): 8332 is presumably Bitcoin Core's default mainnet RPC port — confirm.
pub const DEFAULT_PORT: u16 = 8332;
20
/// A JSON-RPC [`Transport`] that sends each request as an HTTP/1.1 `POST`
/// over a newly opened TCP connection, using a minimal built-in HTTP client.
///
/// Construct one with [`SimpleHttpTransport::new`] (local defaults) or
/// configure it through [`SimpleHttpTransport::builder`].
#[derive(Clone, Debug)]
pub struct SimpleHttpTransport {
    /// Socket address every request connects to.
    addr: net::SocketAddr,
    /// Request target written into the `POST` line (e.g. `/` or `/wallet/name`).
    path: String,
    /// Time limit for a single request.
    timeout: Duration,
    /// Complete `Authorization` header value (e.g. `Basic <base64>`), if any.
    basic_auth: Option<String>,
}
31
32impl Default for SimpleHttpTransport {
33 fn default() -> Self {
34 SimpleHttpTransport {
35 addr: net::SocketAddr::new(
36 net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)),
37 DEFAULT_PORT,
38 ),
39 path: "/".to_owned(),
40 timeout: Duration::from_secs(15),
41 basic_auth: None,
42 }
43 }
44}
45
46use tokio::io::{AsyncReadExt, AsyncWriteExt};
47impl SimpleHttpTransport {
48 pub fn new() -> Self {
50 SimpleHttpTransport::default()
51 }
52
53 pub fn builder() -> Builder {
55 Builder::new()
56 }
57
58 async fn request<R>(&self, req: impl serde::Serialize) -> Result<R, Error>
59 where
60 R: for<'a> serde::de::Deserialize<'a>,
61 {
62 let request_deadline = Instant::now() + self.timeout;
64 let sock = tokio::time::timeout(self.timeout, TcpStream::connect(self.addr)).await??;
65 let (read, write) = sock.into_split();
66 let mut writer = BufWriter::new(write);
67
68 let body = serde_json::to_vec(&req)?;
70
71 writer.write_all(b"POST ").await?;
73 writer.write_all(self.path.as_bytes()).await?;
74 writer.write_all(b" HTTP/1.1\r\n").await?;
75 writer.write_all(b"Content-Type: application/json-rpc\r\n").await?;
77 writer.write_all(b"Content-Length: ").await?;
78 writer.write_all(body.len().to_string().as_bytes()).await?;
79 writer.write_all(b"\r\n").await?;
80 if let Some(ref auth) = self.basic_auth {
81 writer.write_all(b"Authorization: ").await?;
82 writer.write_all(auth.as_ref()).await?;
83 writer.write_all(b"\r\n").await?;
84 }
85 writer.write_all(b"\r\n").await?;
87 writer.write_all(&body).await?;
88 writer.flush().await?;
89
90 let mut reader = BufReader::new(read);
92
93 let http_response = get_line(&mut reader, request_deadline).await?;
95 if http_response.len() < 12 || !http_response.starts_with("HTTP/1.1 ") {
96 return Err(Error::HttpParseError);
97 }
98 let response_code = match http_response[9..12].parse::<u16>() {
99 Ok(n) => n,
100 Err(_) => return Err(Error::HttpParseError),
101 };
102
103 while get_line(&mut reader, request_deadline).await? != "\r\n" {}
105
106 if response_code == 401 {
107 return Err(Error::HttpErrorCode(response_code));
109 }
110
111 let resp_body = get_line(&mut reader, request_deadline).await?;
114 match serde_json::from_str(&resp_body) {
115 Ok(s) => Ok(s),
116 Err(e) => {
117 if response_code != 200 {
118 Err(Error::HttpErrorCode(response_code))
119 } else {
120 Err(e.into())
122 }
123 }
124 }
125 }
126}
127
/// Errors produced by [`SimpleHttpTransport`] and its [`Builder`].
#[derive(Debug)]
pub enum Error {
    /// The URL passed to the builder could not be used.
    InvalidUrl {
        /// The URL as supplied by the caller.
        url: String,
        /// Human-readable explanation of what is wrong with it.
        reason: &'static str,
    },
    /// An I/O error from the socket or from hostname resolution.
    SocketError(io::Error),
    /// The head of the HTTP response could not be parsed.
    HttpParseError,
    /// The server answered with an HTTP status code the transport does not accept.
    HttpErrorCode(u16),
    /// The request did not complete before the configured timeout.
    Timeout,
    /// JSON serialization of the request or deserialization of the response failed.
    Json(serde_json::Error),
}
149
150impl Error {
151 fn url<U: Into<String>>(url: U, reason: &'static str) -> Error {
153 Error::InvalidUrl {
154 url: url.into(),
155 reason: reason,
156 }
157 }
158}
159
160impl std::error::Error for Error {}
161
162impl fmt::Display for Error {
163 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
164 match *self {
165 Error::InvalidUrl {
166 ref url,
167 ref reason,
168 } => write!(f, "invalid URL '{}': {}", url, reason),
169 Error::SocketError(ref e) => write!(f, "Couldn't connect to host: {}", e),
170 Error::HttpParseError => f.write_str("Couldn't parse response header."),
171 Error::HttpErrorCode(c) => write!(f, "unexpected HTTP code: {}", c),
172 Error::Timeout => f.write_str("Didn't receive response data in time, timed out."),
173 Error::Json(ref e) => write!(f, "JSON error: {}", e),
174 }
175 }
176}
177use tokio::time::error::Elapsed;
178impl From<Elapsed> for Error {
179 fn from(e: Elapsed) -> Error {
180 Error::Timeout
181 }
182}
183
184impl From<io::Error> for Error {
185 fn from(e: io::Error) -> Self {
186 Error::SocketError(e)
187 }
188}
189
190impl From<serde_json::Error> for Error {
191 fn from(e: serde_json::Error) -> Self {
192 Error::Json(e)
193 }
194}
195
196impl From<Error> for crate::Error {
197 fn from(e: Error) -> Self {
198 match e {
199 Error::Json(e) => crate::Error::Json(e),
200 e => crate::Error::Transport(Box::new(e)),
201 }
202 }
203}
204
205async fn get_line<R: AsyncBufRead + Unpin>(
208 reader: &mut R,
209 deadline: Instant,
210) -> Result<String, Error> {
211 let mut line = String::new();
212 while deadline > Instant::now() {
213 match reader.read_line(&mut line).await {
214 Ok(0) => thread::sleep(Duration::from_millis(5)),
216 Ok(_) => return Ok(line),
218 Err(e) => return Err(Error::SocketError(e)),
220 }
221 }
222 Err(Error::Timeout)
223}
224
225use async_trait::async_trait;
226#[async_trait]
227impl Transport for SimpleHttpTransport {
228 async fn send_request(&self, req: Request<'_>) -> Result<Response, crate::Error> {
229 Ok(self.request(req).await?)
230 }
231
232 async fn send_batch(&self, reqs: &[Request<'_>]) -> Result<Vec<Response>, crate::Error> {
233 Ok(self.request(reqs).await?)
234 }
235
236 fn fmt_target(&self, f: &mut fmt::Formatter) -> fmt::Result {
237 write!(f, "http://{}:{}{}", self.addr.ip(), self.addr.port(), self.path)
238 }
239}
240
/// Builder for [`SimpleHttpTransport`].
///
/// Obtain one with [`Builder::new`] or [`SimpleHttpTransport::builder`],
/// configure it, and finish with [`Builder::build`].
#[derive(Clone, Debug)]
pub struct Builder {
    // The transport being configured; `build` returns it as-is.
    tp: SimpleHttpTransport,
}
246
247impl Builder {
248 pub fn new() -> Builder {
250 Builder {
251 tp: SimpleHttpTransport::new(),
252 }
253 }
254
255 pub fn timeout(mut self, timeout: Duration) -> Self {
257 self.tp.timeout = timeout;
258 self
259 }
260
261 pub async fn url(mut self, url: &str) -> Result<Self, Error> {
263 let mut fallback_port = DEFAULT_PORT;
269
270 let after_scheme = {
273 let mut split = url.splitn(2, "://");
274 let s = split.next().unwrap();
275 match split.next() {
276 None => s, Some(after) => {
278 if s == "http" {
280 fallback_port = 80;
281 } else if s == "https" {
282 fallback_port = 443;
283 } else {
284 return Err(Error::url(url, "scheme schould be http or https"));
285 }
286 after
287 }
288 }
289 };
290 let (before_path, path) = {
292 if let Some(slash) = after_scheme.find("/") {
293 (&after_scheme[0..slash], &after_scheme[slash..])
294 } else {
295 (after_scheme, "/")
296 }
297 };
298 let after_auth = {
300 let mut split = before_path.splitn(2, "@");
301 let s = split.next().unwrap();
302 split.next().unwrap_or(s)
303 };
304 let mut split = after_auth.split(":");
306 let hostname = split.next().unwrap();
307 let port: u16 = match split.next() {
308 Some(port_str) => match port_str.parse() {
309 Ok(port) => port,
310 Err(_) => return Err(Error::url(url, "invalid port")),
311 },
312 None => fallback_port,
313 };
314 if split.next().is_some() {
316 return Err(Error::url(url, "unexpected extra colon"));
317 }
318
319 self.tp.addr = match tokio::net::lookup_host((hostname, port)).await?.next() {
320 Some(a) => a,
321 None => {
322 return Err(Error::url(url, "invalid hostname: error extracting socket address"))
323 }
324 };
325 self.tp.path = path.to_owned();
326 Ok(self)
327 }
328
329 pub fn auth<S: AsRef<str>>(mut self, user: S, pass: Option<S>) -> Self {
331 let mut auth = user.as_ref().to_owned();
332 auth.push(':');
333 if let Some(ref pass) = pass {
334 auth.push_str(&pass.as_ref()[..]);
335 }
336 self.tp.basic_auth = Some(format!("Basic {}", &base64::encode(auth.as_bytes())));
337 self
338 }
339
340 pub fn cookie_auth<S: AsRef<str>>(mut self, cookie: S) -> Self {
342 self.tp.basic_auth = Some(format!("Basic {}", &base64::encode(cookie.as_ref().as_bytes())));
343 self
344 }
345
346 pub fn build(self) -> SimpleHttpTransport {
348 self.tp
349 }
350}
351
352use crate::client::Client;
353impl Client {
354 pub async fn simple_http(
356 url: &str,
357 user: Option<String>,
358 pass: Option<String>,
359 ) -> Result<Client, Error> {
360 let mut builder = Builder::new().url(&url).await?;
361 if let Some(user) = user {
362 builder = builder.auth(user, pass);
363 }
364 Ok(Client::with_transport(builder.build()))
365 }
366}
367
// Tests for URL parsing in `Builder::url` and for client construction.
// NOTE(review): these tests resolve hostnames through the system resolver, so
// their outcome depends on the machine running them.
#[cfg(test)]
mod tests {
    use std::net;

    use super::*;
    use Client;

    // Checks that `Builder::url` resolves the expected address for several URL
    // shapes, and rejects malformed URLs.
    #[tokio::test]
    async fn test_urls() {
        // Every URL in this list names localhost:22, so each must resolve to
        // the same address that a direct lookup of ("localhost", 22) gives.
        let addr: net::SocketAddr =
            tokio::net::lookup_host(("localhost", 22)).await.unwrap().next().unwrap();
        let urls = [
            "localhost:22",
            "http://localhost:22/",
            "https://localhost:22/walletname/stuff?it=working",
            "http://me:weak@localhost:22/wallet",
        ];
        for u in &urls {
            let tp = Builder::new().url(*u).await.unwrap().build();
            assert_eq!(tp.addr, addr);
        }

        // Without an explicit port, `http` defaults to 80 ...
        let addr: net::SocketAddr =
            tokio::net::lookup_host(("localhost", 80)).await.unwrap().next().unwrap();
        let tp = Builder::new().url("http://localhost/").await.unwrap().build();
        assert_eq!(tp.addr, addr);
        // ... `https` defaults to 443 ...
        let addr: net::SocketAddr =
            tokio::net::lookup_host(("localhost", 443)).await.unwrap().next().unwrap();
        let tp = Builder::new().url("https://localhost/").await.unwrap().build();
        assert_eq!(tp.addr, addr);
        // ... and a URL with no scheme defaults to DEFAULT_PORT.
        let addr: net::SocketAddr = tokio::net::lookup_host(("localhost", super::DEFAULT_PORT))
            .await
            .unwrap()
            .next()
            .unwrap();
        let tp = Builder::new().url("localhost").await.unwrap().build();
        assert_eq!(tp.addr, addr);

        let valid_urls = [
            "localhost",
            "127.0.0.1:8080",
            "http://127.0.0.1:8080/",
            "http://127.0.0.1:8080/rpc/test",
            "https://127.0.0.1/rpc/test",
        ];
        for u in &valid_urls {
            Builder::new().url(*u).await.expect(&format!("error for: {}", u));
        }

        // The two scheme cases fail URL parsing. The first and last entries
        // pass parsing and are rejected only if the resolver fails on them.
        let invalid_urls = [
            "127.0.0.1.0:8080",
            "httpx://127.0.0.1:8080/",
            "ftp://127.0.0.1:8080/rpc/test",
            "http://127.0.0./rpc/test",
        ];
        for u in &invalid_urls {
            if let Ok(b) = Builder::new().url(*u).await {
                let tp = b.build();
                panic!("expected error for url {}, got {:?}", u, tp);
            }
        }
    }

    // Smoke test: a client can be built both through the builder (with a
    // timeout and credentials) and through `Client::simple_http`.
    #[tokio::test]
    async fn construct() {
        let tp = Builder::new()
            .timeout(Duration::from_millis(100))
            .url("localhost:22")
            .await
            .unwrap()
            .auth("user", None)
            .build();
        let _ = Client::with_transport(tp);

        let _ = Client::simple_http("localhost:22", None, None).await.unwrap();
    }
}