use crate::{Error, Source};
use log::info;
use sha2::{Digest, Sha256};
use std::io::{self, Cursor, Read};
use std::path::{Path, PathBuf};
use std::time::Duration;
use std::{fs, thread};
use tar::Archive;
use ureq::Agent;
/// Value of the `User-Agent` header sent by `download_url`.
const USER_AGENT: &str = "https://github.com/rust-osdev/ovmf-prebuilt";
/// Upper bound (10 MiB) on the number of response-body bytes read per download.
const MAX_DOWNLOAD_SIZE_IN_BYTES: usize = 10 << 20;
/// Number of extra attempts made after the first failed download.
const MAX_DOWNLOAD_RETRIES: usize = 4;
pub(crate) fn update_cache(source: Source, prebuilt_dir: &Path) -> Result<(), Error> {
let hash_path = prebuilt_dir.join("sha256");
if let Ok(current_hash) = fs::read_to_string(&hash_path) {
if current_hash == source.sha256 {
return Ok(());
}
}
let base_url = "https://github.com/rust-osdev/ovmf-prebuilt/releases/download";
let url = format!(
"{base_url}/{release}/{release}-bin.tar.xz",
release = source.tag
);
let data = retry(MAX_DOWNLOAD_RETRIES, || download_url(&url))?;
let actual_hash = format!("{:x}", Sha256::digest(&data));
if actual_hash != source.sha256 {
return Err(Error::HashMismatch {
actual: actual_hash,
expected: source.sha256.to_owned(),
});
}
let decompressed = decompress(&data)?;
let _ = fs::remove_dir_all(prebuilt_dir);
extract(&decompressed, prebuilt_dir).map_err(Error::Extract)?;
fs::write(&hash_path, actual_hash).map_err(Error::HashWrite)?;
Ok(())
}
fn download_url(url: &str) -> Result<Vec<u8>, Error> {
let config = Agent::config_builder().user_agent(USER_AGENT).build();
let agent = Agent::new_with_config(config);
info!("downloading {url}");
let resp = agent
.get(url)
.call()
.map_err(|err| Error::Request(Box::new(err)))?;
let mut data = Vec::with_capacity(MAX_DOWNLOAD_SIZE_IN_BYTES);
resp.into_body()
.into_reader()
.take(MAX_DOWNLOAD_SIZE_IN_BYTES.try_into().unwrap())
.read_to_end(&mut data)
.map_err(Error::Download)?;
info!("received {} bytes", data.len());
Ok(data)
}
/// Calls `f` until it succeeds, making up to `max_retries` additional attempts
/// after the first one.
///
/// Between attempts the thread sleeps. The delay starts at one second and
/// doubles after each failure (1 s, 2 s, 4 s, ...). The first `Ok` is returned
/// immediately. If every attempt fails, the error from the last attempt is
/// returned and earlier errors are discarded.
fn retry<F>(max_retries: usize, mut f: F) -> Result<Vec<u8>, Error>
where
    F: FnMut() -> Result<Vec<u8>, Error>,
{
    let mut delay = Duration::from_secs(1);
    // Count down the remaining retries rather than computing `1 + max_retries`
    // total attempts, which overflows (and then panics) for `usize::MAX`.
    let mut retries_left = max_retries;
    loop {
        match f() {
            Ok(value) => return Ok(value),
            Err(err) if retries_left == 0 => return Err(err),
            Err(_) => {
                retries_left -= 1;
                info!("sleeping for {delay:?} before retrying...");
                thread::sleep(delay);
                // Saturate instead of panicking if the delay ever overflows.
                delay = delay.saturating_mul(2);
            }
        }
    }
}
fn decompress(data: &[u8]) -> Result<Vec<u8>, Error> {
info!("decompressing tarball");
let mut decompressed = Vec::new();
let mut compressed = Cursor::new(data);
lzma_rs::xz_decompress(&mut compressed, &mut decompressed).map_err(Error::Decompress)?;
Ok(decompressed)
}
/// Unpacks an uncompressed tar archive held in memory into `prebuilt_dir`.
///
/// The first path component of every entry is stripped (like
/// `tar --strip-components=1`). For example, `<release>/x64/code.fd` is written
/// to `prebuilt_dir/x64/code.fd`. Parent directories are created as needed.
///
/// The following entries are skipped:
/// - zero-sized entries that are not regular files (directories, symlinks,
///   hard links, ...);
/// - entries whose path is empty once the first component is stripped.
///
/// Empty regular files are still extracted.
///
/// # Errors
///
/// Returns any I/O error from reading the archive or writing files. Returns an
/// `InvalidData` error if an entry's path contains a `..` component, which would
/// otherwise place the file outside `prebuilt_dir`.
fn extract(tarball_data: &[u8], prebuilt_dir: &Path) -> Result<(), io::Error> {
    let cursor = Cursor::new(tarball_data);
    let mut archive = Archive::new(cursor);
    for entry in archive.entries()? {
        let mut entry = entry?;

        // Directories are implied by the files they contain (created below);
        // links and other special entries are not materialized.
        if entry.size() == 0 && !entry.header().entry_type().is_file() {
            continue;
        }

        let path: PathBuf = entry.path()?.components().skip(1).collect();

        // Never let `prebuilt_dir.join(path)` escape the target directory.
        if path
            .components()
            .any(|component| component == std::path::Component::ParentDir)
        {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "refusing to extract tar entry outside the target directory: {}",
                    path.display()
                ),
            ));
        }

        // A file at the archive root has nothing left after stripping the
        // first component, so `parent()` is `None`. There is nowhere to put it;
        // skip it (as `tar --strip-components` does) instead of panicking.
        let dir = match path.parent() {
            Some(dir) => dir,
            None => continue,
        };
        let dst_dir = prebuilt_dir.join(dir);
        let dst_path = prebuilt_dir.join(&path);
        info!("unpacking to {}", dst_path.display());
        fs::create_dir_all(dst_dir)?;
        entry.unpack(dst_path)?;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    //! Tests for `retry`. The failing cases take a few seconds because `retry`
    //! sleeps (1 s, then 2 s, ...) between attempts.
    use super::*;

    /// A transient failure used to make the closures below fail.
    fn interrupted() -> Error {
        Error::Download(io::ErrorKind::Interrupted.into())
    }

    /// A closure that succeeds right away is called exactly once.
    #[test]
    fn test_retry_immediate_success() {
        let mut calls = 0;
        let outcome = retry(4, || {
            calls += 1;
            Ok(Vec::new())
        });
        outcome.unwrap();
        assert_eq!(calls, 1);
    }

    /// One failure followed by a success takes exactly two calls.
    #[test]
    fn test_retry_one_retry() {
        let mut calls = 0;
        let outcome = retry(4, || {
            calls += 1;
            match calls {
                1 => Err(interrupted()),
                _ => Ok(Vec::new()),
            }
        });
        outcome.unwrap();
        assert_eq!(calls, 2);
    }

    /// A closure that always fails is tried `1 + max_retries` times, then the
    /// error is returned.
    #[test]
    fn test_retry_failure() {
        let mut calls = 0;
        let outcome = retry(2, || {
            calls += 1;
            Err(interrupted())
        });
        assert!(outcome.is_err());
        assert_eq!(calls, 3);
    }
}