#[cfg(test)]
mod tests;
use std::{io::Read, net::SocketAddr, str::FromStr, sync::RwLock};
use crate::{digest::Digest, EventHandler, Reference};
/// `User-Agent` header value sent with every registry request: `<crate name>/<crate version>`,
/// both taken from Cargo's package metadata at compile time.
const USER_AGENT: &str = concat!(
    env!("CARGO_PKG_NAME"),
    "/",
    env!("CARGO_PKG_VERSION")
);
#[derive(thiserror::Error, Debug)]
pub enum HttpError {
#[error("{0}")]
Client(#[from] Box<ureq::Error>),
#[error("Missing authentication tokens.")]
MissingTokens,
#[error("Invalid JSON: {0}")]
Json(#[from] serde_json::Error),
}
impl From<ureq::Error> for HttpError {
fn from(value: ureq::Error) -> Self {
HttpError::Client(Box::new(value))
}
}
/// HTTP client bound to a single repository of a registry's `/v2/` API.
///
/// Requests are authenticated lazily: the bearer token is only fetched after the
/// registry first rejects a request, and is then shared by all later requests.
pub(super) struct Client<'a, E> {
    /// Repository base URL, `<scheme><registry>/v2/<repository>`; request paths are
    /// appended to it.
    host: String,
    /// Complete `Authorization` header value (`"Bearer …"`) once obtained; `None`
    /// until a registry challenge has been answered.
    auth_token: RwLock<Option<String>>,
    /// Notified about each outgoing registry request and token request.
    event_handler: &'a E,
}
impl<'a, E> Client<'a, E>
where
E: EventHandler,
{
pub fn new(reference: &Reference, event_handler: &'a E) -> Self {
let host = format!(
"{}{}/v2/{}",
guess_scheme(reference.registry),
reference.registry,
reference.repository
);
Client {
event_handler,
auth_token: Default::default(),
host,
}
}
pub fn get(&self, path: &str, accept: Option<&str>) -> Result<ureq::Response, HttpError> {
let url = format!("{}/{}", self.host, path);
let mut request = ureq::get(&url);
if let Some(accept) = accept {
request = request.set("Accept", accept);
}
self.send(request)
}
pub fn download_blob(&self, blob: &Digest) -> Result<impl Read, HttpError> {
let response = self.get(&format!("blobs/{}", blob.source()), None)?;
Ok(blob.wrap_reader(response.into_reader()))
}
fn send(&self, request: ureq::Request) -> Result<ureq::Response, HttpError> {
let request = request.set("User-Agent", USER_AGENT);
self.event_handler.registry_request(request.url());
let auth_token = self.auth_token.read().unwrap();
if let Some(auth) = auth_token.as_deref() {
return Ok(request.set("Authorization", auth).call()?);
}
drop(auth_token);
let mut auth_token = self.auth_token.write().unwrap();
let response = match request.clone().call() {
Ok(r) => return Ok(r),
Err(ureq::Error::Status(401, r)) => r,
Err(e) => return Err(e.into()),
};
let Some(auth_request) = response
.header("www-authenticate")
.and_then(build_auth_request)
else {
return Err(ureq::Error::Status(401, response).into());
};
self.event_handler.registry_auth(auth_request.url());
#[derive(serde::Deserialize, Debug)]
struct Tokens {
token: Option<String>,
access_token: Option<String>,
}
let mut token = match serde_json::from_reader(auth_request.call()?.into_reader())? {
Tokens { token: Some(t), .. } => t,
Tokens {
access_token: Some(t),
..
} => t,
_ => return Err(HttpError::MissingTokens),
};
token.insert_str(0, "Bearer ");
*auth_token = Some(token);
drop(auth_token);
self.send(request)
}
}
/// Guesses the URL scheme (`"http://"` or `"https://"`) for `registry`.
///
/// `registry` is the registry part of an image reference: a host with an optional
/// `:port`. Plain HTTP is chosen in two cases:
/// - the port is `80`;
/// - the host is the local machine (`localhost`, a loopback IPv4/IPv6 address, with
///   or without a port), which is how local development registries are usually run.
///
/// Everything else uses HTTPS.
fn guess_scheme(registry: &str) -> &'static str {
    const HTTP: &str = "http://";
    const HTTPS: &str = "https://";
    if registry.ends_with(":80") || is_loopback_registry(registry) {
        HTTP
    } else {
        HTTPS
    }
}

/// Returns `true` when the host part of `registry` (`host`, `host:port`, `[v6]` or
/// `[v6]:port`) is `localhost` or a loopback IP address.
fn is_loopback_registry(registry: &str) -> bool {
    // `ipv4:port` or `[ipv6]:port`.
    if let Ok(address) = SocketAddr::from_str(registry) {
        return address.ip().is_loopback();
    }
    // Drop a trailing numeric `:port`; a host that still contains `:` is a bare IPv6
    // address, whose last group must not be mistaken for a port.
    let host = match registry.rsplit_once(':') {
        Some((host, port)) if !host.contains(':') && port.parse::<u16>().is_ok() => host,
        _ => registry,
    };
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    match host.parse::<std::net::IpAddr>() {
        Ok(ip) => ip.is_loopback(),
        Err(_) => host.eq_ignore_ascii_case("localhost"),
    }
}
/// Builds the token request described by a `WWW-Authenticate` challenge.
///
/// `auth_spec` is expected to look like
/// `Bearer realm="https://auth.example/token",service="registry.example",scope="repository:foo:pull"`.
/// The `realm` value becomes the URL of a `GET` request. Every other parameter is added
/// to it as a query parameter, in header order.
///
/// Following RFC 7235, the parser is lenient in four ways:
/// - the scheme (`Bearer`) is matched case-insensitively;
/// - parameter names are matched case-insensitively;
/// - whitespace around `=` and `,` is tolerated;
/// - values may be quoted strings or bare tokens.
///
/// Escaped quotes inside quoted strings are not supported.
///
/// Returns `None` if any of these hold:
/// - the scheme is not `Bearer`;
/// - the parameter list is malformed;
/// - no `realm` is present.
fn build_auth_request(auth_spec: &str) -> Option<ureq::Request> {
    let (scheme, mut tail) = auth_spec.trim_ascii_start().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let mut request = None;
    // Parameters seen before `realm` are held until the request URL is known.
    let mut pending_params = vec![];
    loop {
        let (key, rest) = tail.split_once('=')?;
        let key = key.trim_ascii();
        let rest = rest.trim_ascii_start();
        let (value, after) = match rest.strip_prefix('"') {
            // quoted-string: runs up to the next `"`.
            Some(quoted) => quoted.split_once('"')?,
            // token: runs up to the next `,` or the end of the header.
            None => {
                let end = rest.find(',').unwrap_or(rest.len());
                let (token, after) = rest.split_at(end);
                let token = token.trim_ascii_end();
                if token.is_empty() {
                    return None;
                }
                (token, after)
            }
        };
        if key.eq_ignore_ascii_case("realm") {
            request = Some(
                pending_params
                    .drain(..)
                    .fold(ureq::get(value), |r, (k, v)| r.query(k, v)),
            );
        } else {
            match request.take() {
                Some(r) => request = Some(r.query(key, value)),
                None => pending_params.push((key, value)),
            }
        }
        tail = match after.trim_ascii_start() {
            "" => return request,
            t => t.strip_prefix(',')?,
        };
    }
}