use std::env;
use std::env::consts;
use std::io::Cursor;
use std::path::{Path, PathBuf};
use bon::bon;
use bytes::Bytes;
use current_platform::CURRENT_PLATFORM;
use hex_simd::AsciiCase;
use http::{HeaderMap, HeaderValue};
use osc94::Progress;
use semver::{Version, VersionReq};
use serde::Deserialize;
use serde::de::DeserializeOwned;
use url::Url;
use crate::{Error, HTTPClient};
/// Self-updater that replaces the running executable with the matching asset of the latest
/// GitHub release of `owner/repo`.
///
/// Construct it through the builder generated for [`SelfUpdate::new`] (via `#[builder]`) and
/// then call [`SelfUpdate::update`].
pub struct SelfUpdate {
/// Repository owner, used in the `/repos/{owner}/{repo}/releases/latest` API path.
owner: String,
/// Repository name, used in the `/repos/{owner}/{repo}/releases/latest` API path.
repo: String,
/// Optional token sent as bearer auth on both API and asset-download requests.
auth_token: Option<String>,
/// Executable name (without `EXE_SUFFIX`); printed in the status report and used to locate the
/// extracted binary inside the downloaded archive.
bin_name: String,
/// Version of the running binary; an update happens only if the latest release is newer.
current_version: Version,
/// Whether to print the release notes (`body`) of a newer release.
show_release_body: bool,
/// Client for JSON GitHub API calls (`Accept: application/vnd.github+json`).
client: HTTPClient,
/// Client for downloading release assets (`Accept: application/octet-stream`).
client_rss: HTTPClient,
}
#[derive(Deserialize)]
pub struct LatestReleases {
pub tag_name: String,
pub body: String,
pub assets: Vec<ReleaseAsset>,
}
#[derive(Deserialize)]
pub struct ReleaseAsset {
pub name: String,
pub url: Url,
pub digest: String,
}
#[bon]
impl SelfUpdate {
const HOST: &'static str = "https://api.github.com";
const API_VERSION: &'static str = "2022-11-28";
#[builder]
pub async fn new(
owner: &str,
repo: &str,
auth_token: Option<&str>,
bin_name: &str,
current_version: Version,
show_release_body: bool,
proxy: Option<Url>,
no_proxy: Option<bool>,
cert_path: Option<PathBuf>,
) -> Result<Self, Error> {
let mut headers = HeaderMap::new();
headers.insert(
"X-GitHub-Api-Version",
HeaderValue::from_static(SelfUpdate::API_VERSION),
);
let client = HTTPClient::builder()
.app_name("self-update")
.accept(HeaderValue::from_static("application/vnd.github+json"))
.headers(headers)
.maybe_proxy(proxy.clone())
.maybe_no_proxy(no_proxy)
.maybe_cert_path(cert_path.clone())
.retry_url(Url::parse(SelfUpdate::HOST)?)
.build()
.await?;
let client_rss = HTTPClient::builder()
.app_name("self-update-rss")
.accept(HeaderValue::from_static("application/octet-stream"))
.maybe_proxy(proxy)
.maybe_no_proxy(no_proxy)
.maybe_cert_path(cert_path.clone())
.build()
.await?;
Ok(Self {
owner: owner.to_string(),
repo: repo.to_string(),
auth_token: auth_token.map(str::to_owned),
bin_name: bin_name.to_string(),
current_version,
show_release_body,
client,
client_rss,
})
}
pub async fn update(self) -> Result<(), Error> {
println!("Checking target-arch... {CURRENT_PLATFORM}");
println!("Checking current version... {}", self.current_version);
let latest_releases = self.get_latest_releases().await?;
let latest_version = Version::parse(latest_releases.tag_name.trim_start_matches("v"))?;
println!("Checking latest released version... {latest_version}");
if !self.need_update(&latest_version)? {
return Ok(());
}
println!(
"New release found! {} --> {latest_version}",
self.current_version
);
if self.show_release_body {
println!("Release notes...\n\n{}", latest_releases.body);
}
let Some(target_asset) = self.get_target_asset(&latest_releases)? else {
return Err(Error::NovelApi(String::from(
"no assets available for download",
)));
};
println!(
r#"
{} release status:
* Current exe: "{}"
* New exe release: "{}"
* New exe download url: "{}"
"#,
self.bin_name,
env::current_exe()?.display(),
target_asset.name,
target_asset.url
);
println!(
"The new release will be downloaded/extracted and the existing binary will be replaced"
);
if !crate::confirm("Do you want to continue?")? {
return Ok(());
}
let bytes = if crate::support_osc94() {
let mut progress = Progress::default();
progress.indeterminate().flush()?;
let bytes = self.download_with_verify(target_asset).await?;
progress.hidden();
bytes
} else {
self.download_with_verify(target_asset).await?
};
let temp_dir = tempfile::tempdir()?;
crate::unzip(Cursor::new(bytes), &temp_dir)?;
let new_release_exe =
temp_dir
.path()
.join(format!("{}{}", self.bin_name, consts::EXE_SUFFIX));
self_replace::self_replace(new_release_exe)?;
Ok(())
}
async fn get_latest_releases(&self) -> Result<LatestReleases, Error> {
let url = format!("/repos/{}/{}/releases/latest", self.owner, self.repo);
let latest_releases: LatestReleases = self.get(url).await?;
Ok(latest_releases)
}
fn need_update(&self, latest_version: &Version) -> Result<bool, Error> {
let req = VersionReq::parse(&format!(">{}", self.current_version))?;
Ok(req.matches(latest_version))
}
fn get_target_asset<'a>(
&self,
latest_releases: &'a LatestReleases,
) -> Result<Option<&'a ReleaseAsset>, Error> {
for asset in &latest_releases.assets {
if asset.name.contains(CURRENT_PLATFORM)
&& Path::new(&asset.name)
.extension()
.is_some_and(|ext| ext == "zip")
{
return Ok(Some(asset));
}
}
Ok(None)
}
async fn download_with_verify(&self, asset: &ReleaseAsset) -> Result<Bytes, Error> {
let bytes = self.get_rss(asset.url.to_string()).await?;
let hash = crate::sha256_hex(&bytes, AsciiCase::Lower);
let asset_hash = asset.digest.trim_start_matches("sha256:");
if hash != asset_hash {
return Err(Error::NovelApi(format!(
"incorrect hash value: `{hash}` vs `{asset_hash}`"
)));
}
Ok(bytes)
}
async fn get<T, R>(&self, url: T) -> Result<R, Error>
where
T: AsRef<str>,
R: DeserializeOwned,
{
let mut builder = self.client.get(SelfUpdate::HOST.to_string() + url.as_ref());
if let Some(auth_token) = &self.auth_token {
builder = builder.bearer_auth(auth_token);
}
let response = builder.send().await?.error_for_status()?;
Ok(sonic_rs::from_slice(&response.bytes().await?)?)
}
async fn get_rss<T>(&self, url: T) -> Result<Bytes, Error>
where
T: AsRef<str>,
{
let mut builder = self.client_rss.get(url.as_ref());
if let Some(auth_token) = &self.auth_token {
builder = builder.bearer_auth(auth_token);
}
Ok(builder.send().await?.error_for_status()?.bytes().await?)
}
}