use std::convert::Infallible;
use std::error::Error as StdError;
use std::future::{self, Future};
use std::sync::Arc;
use hyper::header::HeaderValue;
use hyper::server::accept::Accept;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use tokio::io::{AsyncRead, AsyncWrite};
use crate::headers::{HEADER_FILE_LENGTH, HEADER_RANGE, HEADER_SECRET};
use crate::provider::Provider;
use crate::utils::{body_stream, get_hash, stream_body};
use crate::{make_resp, BlobRange, Error};
/// Configuration for the blobnet HTTP server started by [`listen`] or
/// [`listen_with_shutdown`].
pub struct Config {
/// Storage backend that answers the `HEAD`, `GET`, and `PUT` routes.
pub provider: Box<dyn Provider>,
/// Shared secret that clients must send in the [`HEADER_SECRET`] header on
/// every request except the unauthenticated `GET /` health check.
pub secret: String,
}
/// Serves blobnet HTTP requests accepted from `incoming`.
///
/// No graceful shutdown is ever requested: the shutdown signal handed to
/// [`listen_with_shutdown`] never resolves, so this runs until the hyper
/// server future itself completes. Use [`listen_with_shutdown`] to stop the
/// server on demand.
pub async fn listen<I>(config: Config, incoming: I) -> hyper::Result<()>
where
    I: Accept,
    I::Error: Into<Box<dyn StdError + Send + Sync>>,
    I::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    // A future that is never ready, i.e. a shutdown signal that never fires.
    let never = future::pending::<()>();
    listen_with_shutdown(config, incoming, never).await
}
pub async fn listen_with_shutdown<I>(
config: Config,
incoming: I,
shutdown: impl Future<Output = ()>,
) -> hyper::Result<()>
where
I: Accept,
I::Error: Into<Box<dyn StdError + Send + Sync>>,
I::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let config = Arc::new(config);
let make_svc = make_service_fn(move |_conn| {
let config = Arc::clone(&config);
async {
Ok::<_, Infallible>(service_fn(move |req| {
let config = Arc::clone(&config);
async {
let resp = handle(config, req).await;
Ok::<_, Infallible>(resp.unwrap_or_else(|err_resp| err_resp))
}
}))
}
});
Server::builder(incoming)
.serve(make_svc)
.with_graceful_shutdown(shutdown)
.await
}
/// Routes a single HTTP request to the configured [`Provider`].
///
/// Routes:
/// - `GET /` is an unauthenticated health check that replies `200 blobnet ok`.
/// - Every other request must carry a [`HEADER_SECRET`] header equal to
///   `config.secret`. Otherwise it is rejected with `401 unauthorized`.
/// - `HEAD /<hash>` replies with an empty body and the blob length in the
///   [`HEADER_FILE_LENGTH`] header.
/// - `GET /<hash>` streams the blob. An optional [`HEADER_RANGE`] header of the
///   form `start-end` is forwarded to the provider. A missing or unparsable
///   range header is forwarded as `None` rather than rejected.
/// - `PUT /` stores the request body and replies with the resulting hash.
/// - Anything else gets `404 invalid request path`.
///
/// # Errors
///
/// Returns `Err` holding a ready-to-send error response. Errors raised by the
/// hash parser or the provider are converted into responses by `?` (through a
/// `From` conversion into `Response<Body>`).
async fn handle(config: Arc<Config>, req: Request<Body>) -> Result<Response<Body>, Response<Body>> {
    // A missing header, or one that `to_str` rejects, counts as "no secret".
    let secret = req
        .headers()
        .get(HEADER_SECRET)
        .and_then(|s| s.to_str().ok());
    match (req.method(), req.uri().path()) {
        (&Method::GET, "/") => Ok(make_resp(StatusCode::OK, "blobnet ok")),
        // The guard runs after the health check, so only `GET /` skips auth.
        _ if !secret_matches(secret, &config.secret) => {
            Err(make_resp(StatusCode::UNAUTHORIZED, "unauthorized"))
        }
        (&Method::HEAD, path) => {
            let hash = get_hash(path)?;
            let len = config.provider.head(hash).await?;
            Response::builder()
                .header(HEADER_FILE_LENGTH, len.to_string())
                .body(Body::empty())
                .map_err(|e| Error::Internal(e.into()).into())
        }
        (&Method::GET, path) => {
            let range = req.headers().get(HEADER_RANGE).and_then(parse_range_header);
            let hash = get_hash(path)?;
            let reader = config.provider.get(hash, range).await?;
            Ok(Response::new(stream_body(reader)))
        }
        (&Method::PUT, "/") => {
            let body = req.into_body();
            let hash = config.provider.put(body_stream(body)).await?;
            Ok(Response::new(Body::from(hash)))
        }
        _ => Err(make_resp(StatusCode::NOT_FOUND, "invalid request path")),
    }
}

/// Reports whether the client-supplied secret equals the expected one,
/// without short-circuiting on the first mismatching byte.
///
/// A plain `==` on strings stops at the first differing byte. Response
/// timings can then reveal to a remote client how long the correct prefix
/// is, which lets the secret be recovered byte by byte. Here only the
/// secret's length can leak. This is a best-effort defence: nothing formally
/// stops the optimizer from adding an early exit.
fn secret_matches(provided: Option<&str>, expected: &str) -> bool {
    let provided = match provided {
        Some(p) => p.as_bytes(),
        None => return false,
    };
    let expected = expected.as_bytes();
    if provided.len() != expected.len() {
        return false;
    }
    // OR together the XOR of every byte pair; the result is zero iff all match.
    let diff = provided
        .iter()
        .zip(expected)
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    diff == 0
}
/// Parses a `start-end` range header value into a [`BlobRange`].
///
/// Returns `None` in three cases:
/// - the value is rejected by `HeaderValue::to_str`;
/// - the value contains no `-`;
/// - either bound fails to parse.
///
/// The value is split at the first `-` only.
fn parse_range_header(s: &HeaderValue) -> BlobRange {
    let (lo, hi) = s.to_str().ok().and_then(|text| text.split_once('-'))?;
    let start = lo.parse().ok()?;
    let end = hi.parse().ok()?;
    Some((start, end))
}