use reqwest::header::ToStrError;
use reqwest::Client;
use std::path::{Path, PathBuf};
use std::string::FromUtf8Error;
use std::time::SystemTime;
use thiserror::Error;
use tokio::fs::{File, OpenOptions};
use tokio::io::{AsyncReadExt, AsyncSeekExt};
use tokio::{fs, io};
use url::Url;
use zip::ZipArchive;
use crate::config::Config;
use crate::lua_version::{LuaVersion, LuaVersionUnset};
use crate::progress::{Progress, ProgressBar};
#[derive(Error, Debug)]
pub enum ManifestFromServerError {
#[error(transparent)]
Io(#[from] io::Error),
#[error("failed to pull manifest:\n{0}")]
Request(#[from] reqwest::Error),
#[error("failed to parse manifest:\n{0}")]
FromUtf8(#[from] FromUtf8Error),
#[error("invalidate date received from server:\n{0}")]
InvalidDate(#[from] httpdate::Error),
#[error("non-ASCII characters returned in response header:\n{0}")]
InvalidHeader(#[from] ToStrError),
#[error("error parsing manifest URL:\n{0}")]
Url(#[from] url::ParseError),
#[error("failed to read manifest archive {0}:\n{1}")]
ZipRead(Url, zip::result::ZipError),
#[error("failed to unzip manifest file {0}:\n{1}")]
ZipExtract(Url, zip::result::ZipError),
#[error(transparent)]
LuaVersion(#[from] LuaVersionUnset),
}
/// Downloads the manifest from `url`, writes it to `target`, and returns its contents.
///
/// `url` is expected to point at a zipped manifest. If the server answers that request with a
/// client error (4xx), the unzipped manifest at the same URL minus its `.zip` suffix is
/// fetched and written verbatim instead. Otherwise the archive is extracted into a temporary
/// directory (via `zip`'s `extract_unwrapped_root_dir`), and the file named
/// `manifest-{manifest_version}` is copied into `target` and read back.
///
/// # Errors
///
/// Fails if a request errors or returns an error status, if the unzipped manifest is not
/// valid UTF-8, if the archive cannot be read or extracted, or if any file operation fails.
pub(super) async fn get_manifest(
    url: Url,
    manifest_version: String,
    target: &Path,
    client: &Client,
) -> Result<String, ManifestFromServerError> {
    let zipped_response = client.get(url.clone()).send().await?;

    if zipped_response.status().is_client_error() {
        // No zipped manifest on this server: fall back to the plain-text manifest.
        let plain_url = fallback_unzipped_url(&url)?;
        let plain_bytes = client
            .get(plain_url)
            .send()
            .await?
            .error_for_status()?
            .bytes()
            .await?;
        let manifest = String::from_utf8(plain_bytes.to_vec())?;
        tokio::fs::write(&target, &manifest).await?;
        return Ok(manifest);
    }

    let archive_bytes = zipped_response.error_for_status()?.bytes().await?;
    let mut archive = ZipArchive::new(std::io::Cursor::new(archive_bytes))
        .map_err(|e| ManifestFromServerError::ZipRead(url.clone(), e))?;

    let scratch = tempfile::tempdir()?;
    archive
        .extract_unwrapped_root_dir(&scratch, zip::read::root_dir_common_filter)
        .map_err(|e| ManifestFromServerError::ZipExtract(url.clone(), e))?;

    let extracted_path = scratch.path().join(format!("manifest-{manifest_version}"));
    let mut source = File::open(extracted_path).await?;
    // Opened read+write so the copied contents can be read back from the same handle.
    let mut cache_file = OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .truncate(true)
        .open(target)
        .await?;
    io::copy(&mut source, &mut cache_file).await?;

    cache_file.seek(io::SeekFrom::Start(0)).await?;
    let mut manifest = String::new();
    cache_file.read_to_string(&mut manifest).await?;
    Ok(manifest)
}
/// Returns the manifest for the configured Lua version, preferring the on-disk cache.
///
/// When a cached copy exists, a `HEAD` request is sent for the manifest URL (retrying the
/// unzipped URL if the zipped one yields a client error). If the server's `Last-Modified`
/// date is newer than the cache file's modification time, the manifest is downloaded again.
/// Otherwise the cached copy is returned. If there is no cache file, or the server sends no
/// `Last-Modified` header, the manifest is downloaded.
///
/// # Errors
///
/// Fails if the Lua version is unset, the URL cannot be built, a request fails, the
/// `Last-Modified` header is malformed, or a file operation fails.
pub(crate) async fn manifest_from_cache_or_server(
    server_url: &Url,
    config: &Config,
    bar: &Progress<ProgressBar>,
) -> Result<String, ManifestFromServerError> {
    let manifest_version = LuaVersion::from(config)?.version_compatibility_str();
    let url = mk_manifest_url(server_url, &manifest_version, config)?;
    let cache = mk_manifest_cache(&url, config).await?;
    let client = Client::new();

    if let Ok(cache_metadata) = fs::metadata(&cache).await {
        let cache_modified: SystemTime = cache_metadata.modified()?;
        // Ask the server when the manifest last changed, without downloading it.
        let head = match client.head(url.clone()).send().await? {
            zipped if zipped.status().is_client_error() => client
                .head(fallback_unzipped_url(&url)?)
                .send()
                .await?
                .error_for_status()?,
            zipped => zipped.error_for_status()?,
        };

        if let Some(header) = head.headers().get("Last-Modified") {
            let server_modified = httpdate::parse_http_date(header.to_str()?)?;
            if server_modified <= cache_modified {
                // The cached copy is at least as recent as the server's.
                return Ok(fs::read_to_string(&cache).await?);
            }
            bar.map(|bar| {
                bar.set_message(format!("📥 Downloading updated manifest from {}", &url))
            });
            return get_manifest(url, manifest_version, &cache, &client).await;
        }
    }

    bar.map(|bar| bar.set_message(format!("📥 Downloading manifest from {}", &url)));
    get_manifest(url, manifest_version, &cache, &client).await
}
/// Downloads the manifest for the configured Lua version, ignoring any cached copy.
///
/// The downloaded manifest is still written to the cache path derived from its URL.
///
/// # Errors
///
/// Fails if the Lua version is unset, the URL cannot be built, the download fails, or a
/// file operation fails.
pub(crate) async fn manifest_from_server_only(
    server_url: &Url,
    config: &Config,
    bar: &Progress<ProgressBar>,
) -> Result<String, ManifestFromServerError> {
    let lua_version = LuaVersion::from(config)?;
    let version = lua_version.version_compatibility_str();
    let manifest_url = mk_manifest_url(server_url, &version, config)?;
    let cache_path = mk_manifest_cache(&manifest_url, config).await?;
    let client = Client::new();
    bar.map(|bar| bar.set_message(format!("📥 Downloading manifest from {}", &manifest_url)));
    get_manifest(manifest_url, version, &cache_path, &client).await
}
/// Builds the URL of the zipped manifest (`manifest-{manifest_version}.zip`) on `server_url`.
///
/// The server URL is always treated as a directory. [`Url::join`] replaces the last path
/// segment when the base has no trailing slash, so without normalisation
/// `https://example.org/dev` would resolve to `https://example.org/manifest-5.1.zip` (the
/// wrong repository) instead of `https://example.org/dev/manifest-5.1.zip`.
/// When the configuration has a namespace, the manifest is looked up under
/// `manifests/{namespace}/` relative to the server URL.
///
/// # Errors
///
/// Returns [`ManifestFromServerError::Url`] if joining the path segments fails.
fn mk_manifest_url(
    server_url: &Url,
    manifest_version: &str,
    config: &Config,
) -> Result<Url, ManifestFromServerError> {
    let manifest_filename = format!("manifest-{manifest_version}.zip");

    let mut base = server_url.clone();
    if !base.path().ends_with('/') {
        let dir_path = format!("{}/", base.path());
        base.set_path(&dir_path);
    }

    let url = match config.namespace() {
        Some(namespace) => base
            .join(&format!("manifests/{namespace}/"))?
            .join(&manifest_filename)?,
        None => base.join(&manifest_filename)?,
    };
    Ok(url)
}
/// Returns the cache file path for the manifest at `url`, creating its parent directory.
///
/// The file name is the full URL with each of `: * ? " < > | / \` replaced by `_` and any
/// trailing `.zip` suffixes removed, so the whole URL becomes a single path component
/// inside the configured cache directory.
///
/// # Errors
///
/// Fails if the parent directory cannot be created.
async fn mk_manifest_cache(url: &Url, config: &Config) -> io::Result<PathBuf> {
    let sanitized = url.to_string().replace(
        |c: char| matches!(c, ':' | '*' | '?' | '"' | '<' | '>' | '|' | '/' | '\\'),
        "_",
    );
    let file_name = sanitized.trim_end_matches(".zip");
    let cache = config.cache_dir().join(file_name);

    if let Some(parent) = cache.parent() {
        fs::create_dir_all(parent).await?;
    }
    Ok(cache)
}
fn fallback_unzipped_url(url: &Url) -> Result<Url, url::ParseError> {
url.to_string().trim_end_matches(".zip").parse()
}
#[cfg(test)]
mod tests {
    use httptest::{matchers::request, responders::status_code, Expectation, Server};
    use serial_test::serial;

    use crate::{config::ConfigBuilder, progress::MultiProgress};

    use super::*;

    /// Starts a mock server that serves the `resources/test/manifest-5.1.zip` fixture at
    /// `/{manifest_name}.zip`, with a fixed `Last-Modified` header of
    /// Sat, 20 Jan 2024 13:14:12 GMT. It expects at least one matching request.
    fn start_test_server(manifest_name: String) -> Server {
        let server = Server::run();
        let manifest_path = format!("/{manifest_name}");
        server.expect(
            Expectation::matching(request::path(manifest_path + ".zip"))
                .times(1..)
                .respond_with(
                    status_code(200)
                        .append_header("Last-Modified", "Sat, 20 Jan 2024 13:14:12 GMT")
                        .body(
                            std::fs::read(
                                format!(
                                    "{}/resources/test/manifest-5.1.zip",
                                    env!("CARGO_MANIFEST_DIR")
                                )
                                .as_str(),
                            )
                            .unwrap(),
                        ),
                ),
        );
        server
    }

    /// Returns the mock server's base URL (`server.url_str("")` without its trailing slash).
    fn server_base_url(server: &Server) -> Url {
        let mut url_str = server.url_str("");
        url_str.pop();
        Url::parse(&url_str).unwrap()
    }

    /// Builds a configuration that caches into `cache_dir`, targets `lua_version`, and
    /// disables progress output.
    fn test_config(cache_dir: &assert_fs::TempDir, lua_version: LuaVersion) -> Config {
        ConfigBuilder::new()
            .unwrap()
            .cache_dir(Some(cache_dir.to_path_buf()))
            .lua_version(Some(lua_version))
            .no_progress(Some(true))
            .build()
            .unwrap()
    }

    #[tokio::test]
    #[serial]
    async fn get_manifest_luajit() {
        // Keep the TempDir bound for the whole test. `TempDir::new().unwrap().to_path_buf()`
        // would drop (and delete) it immediately, and the code under test would then recreate
        // the directory outside of any cleanup.
        let cache_dir = assert_fs::TempDir::new().unwrap();
        let server = start_test_server("manifest-5.1".into());
        let config = test_config(&cache_dir, LuaVersion::LuaJIT);
        let progress = MultiProgress::new(&config);
        let bar = progress.map(MultiProgress::new_bar);
        manifest_from_cache_or_server(&server_base_url(&server), &config, &bar)
            .await
            .unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn get_manifest_for_5_1() {
        // See `get_manifest_luajit` for why the TempDir stays bound.
        let cache_dir = assert_fs::TempDir::new().unwrap();
        let server = start_test_server("manifest-5.1".into());
        let config = test_config(&cache_dir, LuaVersion::Lua51);
        let progress = MultiProgress::new(&config);
        let bar = progress.map(MultiProgress::new_bar);
        manifest_from_cache_or_server(&server_base_url(&server), &config, &bar)
            .await
            .unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn get_cached_manifest() {
        let server = start_test_server("manifest-5.1".into());
        let server_url = server_base_url(&server);
        let cache_dir = assert_fs::TempDir::new().unwrap();
        let config = test_config(&cache_dir, LuaVersion::Lua51);

        // Seed the cache at the path `manifest_from_cache_or_server` actually reads, which is
        // derived from the manifest URL, not at an arbitrary name inside the cache directory.
        let manifest_version = LuaVersion::from(&config)
            .unwrap()
            .version_compatibility_str();
        let manifest_url = mk_manifest_url(&server_url, &manifest_version, &config).unwrap();
        let cache = mk_manifest_cache(&manifest_url, &config).await.unwrap();

        // The marker line makes the assertion fail if the manifest is re-downloaded from the
        // server instead of being served from the cache.
        let manifest_content = format!(
            "{}\n-- served from cache\n",
            std::fs::read_to_string(
                format!("{}/resources/test/manifest-5.1", env!("CARGO_MANIFEST_DIR")).as_str(),
            )
            .unwrap()
        );
        // Written now, so the cache file is newer than the server's `Last-Modified` (Jan 2024).
        fs::write(&cache, &manifest_content).await.unwrap();

        let progress = MultiProgress::new(&config);
        let bar = progress.map(MultiProgress::new_bar);
        let result = manifest_from_cache_or_server(&server_url, &config, &bar)
            .await
            .unwrap();
        assert_eq!(result, manifest_content);
    }
}