use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::time::Duration;
use eyre::{bail, Report, Result};
use reqwest::header::HeaderMap;
use reqwest::{ClientBuilder, IntoUrl, RequestBuilder, Response};
use std::sync::LazyLock as Lazy;
use url::Url;
use crate::cli::version;
use crate::config::SETTINGS;
use crate::file::display_path;
use crate::tokio::RUNTIME;
use crate::ui::progress_report::SingleReport;
use crate::{env, file};
/// Shared client for version checks, built on first use with a fixed
/// 3-second read/connect timeout. Not compiled into test builds.
#[cfg(not(test))]
pub static HTTP_VERSION_CHECK: Lazy<Client> = Lazy::new(|| {
    let timeout = Duration::from_secs(3);
    Client::new(timeout).unwrap()
});
/// General-purpose shared client; its timeout comes from the `http_timeout` setting.
pub static HTTP: Lazy<Client> = Lazy::new(|| {
    Client::new(parse_timeout_setting("http_timeout", &SETTINGS.http_timeout)).unwrap()
});

/// Shared client for fetching remote versions; its timeout comes from the
/// `fetch_remote_versions_timeout` setting.
pub static HTTP_FETCH: Lazy<Client> = Lazy::new(|| {
    Client::new(parse_timeout_setting(
        "fetch_remote_versions_timeout",
        &SETTINGS.fetch_remote_versions_timeout,
    ))
    .unwrap()
});

/// Parses a timeout setting given either as a humantime duration (e.g. `"30s"`)
/// or as a bare number of seconds (e.g. `"30"`).
///
/// Both shared clients use this, so every timeout setting accepts the same
/// formats.
///
/// # Panics
/// Panics, naming the setting and its value, if neither form parses. The
/// shared clients cannot be built without a valid timeout.
fn parse_timeout_setting(name: &str, value: &str) -> Duration {
    humantime::parse_duration(value)
        .ok()
        .or_else(|| value.parse::<u64>().ok().map(Duration::from_secs))
        .unwrap_or_else(|| {
            panic!(
                "invalid {name} setting {value:?}: expected a duration like \"30s\" or a number of seconds"
            )
        })
}
#[derive(Debug)]
pub struct Client {
reqwest: reqwest::Client,
}
impl Client {
fn new(timeout: Duration) -> Result<Self> {
Ok(Self {
reqwest: Self::_new()
.read_timeout(timeout)
.connect_timeout(timeout)
.build()?,
})
}
fn _new() -> ClientBuilder {
let v = &*version::VERSION;
let shell = env::MISE_SHELL.map(|s| s.to_string()).unwrap_or_default();
ClientBuilder::new()
.user_agent(format!("mise/{v} {shell}").trim())
.gzip(true)
.zstd(true)
}
pub fn get_bytes<U: IntoUrl>(&self, url: U) -> Result<impl AsRef<[u8]>> {
let url = url.into_url().unwrap();
RUNTIME.block_on(async {
let resp = self.get_async(url.clone()).await?;
Ok(resp.bytes().await?)
})
}
pub async fn get_async<U: IntoUrl>(&self, url: U) -> Result<Response> {
let get = |url: Url| async move {
debug!("GET {}", &url);
let mut req = self.reqwest.get(url.clone());
req = with_github_auth(&url, req);
let resp = req.send().await?;
debug!("GET {url} {}", resp.status());
display_github_rate_limit(&resp);
resp.error_for_status_ref()?;
Ok(resp)
};
let mut url = url.into_url().unwrap();
let resp = match get(url.clone()).await {
Ok(resp) => resp,
Err(_) if url.scheme() == "http" => {
url.set_scheme("https").unwrap();
get(url).await?
}
Err(err) => return Err(err),
};
resp.error_for_status_ref()?;
Ok(resp)
}
pub fn head<U: IntoUrl>(&self, url: U) -> Result<Response> {
let url = url.into_url().unwrap();
RUNTIME.block_on(self.head_async(url))
}
pub async fn head_async<U: IntoUrl>(&self, url: U) -> Result<Response> {
let head = |url: Url| async move {
debug!("HEAD {}", &url);
let mut req = self.reqwest.head(url.clone());
req = with_github_auth(&url, req);
let resp = req.send().await?;
debug!("HEAD {url} {}", resp.status());
display_github_rate_limit(&resp);
resp.error_for_status_ref()?;
Ok(resp)
};
let mut url = url.into_url().unwrap();
let resp = match head(url.clone()).await {
Ok(resp) => resp,
Err(_) if url.scheme() == "http" => {
url.set_scheme("https").unwrap();
head(url).await?
}
Err(err) => return Err(err),
};
resp.error_for_status_ref()?;
Ok(resp)
}
pub fn get_text<U: IntoUrl>(&self, url: U) -> Result<String> {
let mut url = url.into_url().unwrap();
let text = RUNTIME.block_on(async {
let resp = self.get_async(url.clone()).await?;
Ok::<String, eyre::Error>(resp.text().await?)
})?;
if text.starts_with("<!DOCTYPE html>") {
if url.scheme() == "http" {
url.set_scheme("https").unwrap();
return self.get_text(url);
}
bail!("Got HTML instead of text from {}", url);
}
Ok(text)
}
pub fn json_headers<T, U: IntoUrl>(&self, url: U) -> Result<(T, HeaderMap)>
where
T: serde::de::DeserializeOwned,
{
let url = url.into_url().unwrap();
let (json, headers) = RUNTIME.block_on(async {
let resp = self.get_async(url).await?;
let headers = resp.headers().clone();
Ok::<(T, HeaderMap), eyre::Error>((resp.json().await?, headers))
})?;
Ok((json, headers))
}
pub fn json<T, U: IntoUrl>(&self, url: U) -> Result<T>
where
T: serde::de::DeserializeOwned,
{
self.json_headers(url).map(|(json, _)| json)
}
pub fn download_file<U: IntoUrl>(
&self,
url: U,
path: &Path,
pr: Option<&Box<dyn SingleReport>>,
) -> Result<()> {
let url = url.into_url()?;
debug!("GET Downloading {} to {}", &url, display_path(path));
RUNTIME.block_on(async {
let mut resp = self.get_async(url).await?;
if let Some(length) = resp.content_length() {
if let Some(pr) = pr {
pr.set_length(length);
}
}
file::create_dir_all(path.parent().unwrap())?;
let mut file = File::create(path)?;
while let Some(chunk) = resp.chunk().await? {
file.write_all(&chunk)?;
if let Some(pr) = pr {
pr.inc(chunk.len() as u64);
}
}
Ok::<(), eyre::Error>(())
})?;
Ok(())
}
}
pub fn error_code(e: &Report) -> Option<u16> {
if e.to_string().contains("404") {
return Some(404);
}
if let Some(err) = e.downcast_ref::<reqwest::Error>() {
err.status().map(|s| s.as_u16())
} else {
None
}
}
/// Attaches GitHub API credentials to `req` when `url` targets `api.github.com`
/// and `GITHUB_TOKEN` is set.
///
/// It adds an `authorization: token <TOKEN>` header and pins
/// `x-github-api-version`. In every other case `req` is returned untouched.
fn with_github_auth(url: &Url, req: RequestBuilder) -> RequestBuilder {
    if url.host_str() != Some("api.github.com") {
        return req;
    }
    let Some(token) = &*env::GITHUB_TOKEN else {
        return req;
    };
    req.header("authorization", format!("token {token}"))
        .header("x-github-api-version", "2022-11-28")
}
/// Warns when a 403/429 response reports an exhausted GitHub rate limit
/// (`x-ratelimit-remaining: 0`).
///
/// The reset time from `x-ratelimit-reset` is included when the header parses
/// as a Unix timestamp. This is best-effort: missing or malformed headers are
/// ignored and never cause a panic.
fn display_github_rate_limit(resp: &Response) {
    let status = resp.status().as_u16();
    if status != 403 && status != 429 {
        return;
    }
    let headers = resp.headers();
    // Only warn when the quota is reported as fully used up.
    if headers
        .get("x-ratelimit-remaining")
        .is_none_or(|r| r != "0")
    {
        return;
    }
    let reset = headers
        .get("x-ratelimit-reset")
        .and_then(|r| r.to_str().ok())
        .and_then(|r| r.trim().parse::<i64>().ok())
        .and_then(|secs| chrono::DateTime::from_timestamp(secs, 0));
    if let Some(reset) = reset {
        // `from_timestamp` yields a UTC DateTime. Its Display includes the
        // "UTC" suffix, so the time is not mistaken for local time.
        warn!("GitHub rate limit exceeded. Resets at {}", reset);
    }
}