use std::{sync::Arc, time::Duration};
use regex::Regex;
use soar_config::config::get_config;
use soar_core::{database::models::Package, package::query::PackageQuery, SoarResult};
use soar_db::repository::metadata::MetadataRepository;
use soar_dl::{
download::Download,
error::DownloadError,
filter::Filter,
github::Github,
gitlab::GitLab,
oci::OciDownload,
platform::PlatformUrl,
traits::{Asset, Platform as _, Release as _},
types::{OverwriteMode, Progress},
};
use soar_utils::bytes::format_bytes;
use tokio::time::sleep;
use tracing::{error, info};
use crate::utils::{interactive_ask, select_package_interactively};
/// Options shared by every download path in this module (direct links,
/// package queries, GitHub/GitLab releases and OCI references).
pub struct DownloadContext {
    /// Regex patterns forwarded to the asset [`Filter`].
    pub regexes: Vec<Regex>,
    /// Glob patterns forwarded to the asset [`Filter`].
    pub globs: Vec<String>,
    /// Keywords forwarded to the asset [`Filter`] as `include`.
    pub match_keywords: Vec<String>,
    /// Keywords forwarded to the asset [`Filter`] as `exclude`.
    pub exclude_keywords: Vec<String>,
    /// Output location handed to the downloaders, when set.
    pub output: Option<String>,
    /// Non-interactive mode: take the first candidate instead of prompting, and
    /// overwrite existing files unless `skip_existing` is set.
    pub yes: bool,
    /// Receives every [`Progress`] event emitted by the downloaders.
    pub progress_callback: Arc<dyn Fn(Progress) + Send + Sync>,
    /// Forwarded to the asset [`Filter`] as `case_sensitive`.
    pub exact_case: bool,
    /// Whether direct/release downloads should be extracted.
    pub extract: bool,
    /// Extraction destination for direct/release downloads, when set.
    pub extract_dir: Option<String>,
    /// Skip files that already exist instead of prompting.
    pub skip_existing: bool,
    /// Always overwrite files that already exist.
    pub force_overwrite: bool,
}

impl DownloadContext {
    /// Maps the CLI flags onto the [`OverwriteMode`] used by every downloader.
    ///
    /// Precedence, highest first:
    /// 1. `force_overwrite` → `Force`
    /// 2. `skip_existing` → `Skip`
    /// 3. `yes` → `Force`
    /// 4. otherwise → `Prompt`
    ///
    /// `skip_existing` is checked before `yes` because `yes` only means "don't
    /// prompt". When the user explicitly asked to skip existing files there is
    /// nothing to prompt about. Letting `yes` win would silently overwrite the
    /// files the user asked to keep.
    fn get_overwrite_mode(&self) -> OverwriteMode {
        if self.force_overwrite {
            OverwriteMode::Force
        } else if self.skip_existing {
            OverwriteMode::Skip
        } else if self.yes {
            OverwriteMode::Force
        } else {
            OverwriteMode::Prompt
        }
    }

    /// Builds the asset [`Filter`] from the context's regexes, globs,
    /// include/exclude keywords and case-sensitivity flag.
    fn create_filter(&self) -> Filter {
        Filter {
            regexes: self.regexes.clone(),
            globs: self.globs.clone(),
            include: self.match_keywords.clone(),
            exclude: self.exclude_keywords.clone(),
            case_sensitive: self.exact_case,
        }
    }
}
/// Entry point for the download command: processes each source kind in turn.
///
/// - `links` are always handed to [`handle_direct_downloads`], with `ctx.output`
///   as the output location.
/// - The GitHub, GitLab and GHCR/OCI lists are processed only when non-empty.
///
/// The first handler that returns an error stops the remaining ones.
pub async fn download(
    ctx: DownloadContext,
    links: Vec<String>,
    github: Vec<String>,
    gitlab: Vec<String>,
    ghcr: Vec<String>,
) -> SoarResult<()> {
    // Shadow with a borrow; the owned context stays alive until the end of the fn.
    let ctx = &ctx;

    let direct_output = ctx.output.clone();
    handle_direct_downloads(ctx, links, direct_output).await?;

    if !github.is_empty() {
        handle_github_downloads(ctx, github).await?;
    }
    if !gitlab.is_empty() {
        handle_gitlab_downloads(ctx, gitlab).await?;
    }
    if ghcr.is_empty() {
        return Ok(());
    }
    handle_oci_downloads(ctx, ghcr).await
}
/// Downloads every entry in `links`, dispatching on the kind of resource each
/// one names.
///
/// Each link is parsed with [`PlatformUrl::parse`]:
/// - Direct URLs are downloaded as-is, honouring `ctx.extract` and `ctx.extract_dir`.
/// - GitHub and GitLab release URLs go through the release-asset flow.
/// - OCI references go through [`handle_oci_download`].
/// - Anything else is treated as a package query and resolved against the
///   repository metadata. The chosen package is fetched from its `ghcr_blob`
///   when set, otherwise from its `download_url`.
///
/// `output`, when set, is passed to every downloader as the output location.
///
/// Per-link download failures are logged and the remaining links are still
/// processed.
///
/// # Errors
///
/// Returns an error when:
/// - the metadata manager cannot be obtained,
/// - a link cannot be parsed as a [`PackageQuery`],
/// - the metadata query fails, or
/// - the interactive package prompt fails.
pub async fn handle_direct_downloads(
    ctx: &DownloadContext,
    links: Vec<String>,
    output: Option<String>,
) -> SoarResult<()> {
    for link in &links {
        match PlatformUrl::parse(link) {
            Some(PlatformUrl::Direct {
                url,
            }) => {
                info!("Downloading using direct link: {}", url);
                let mut dl = Download::new(url)
                    .overwrite(ctx.get_overwrite_mode())
                    .extract(ctx.extract);
                if let Some(ref out) = output {
                    dl = dl.output(out);
                }
                if let Some(extract_dir) = ctx.extract_dir.clone() {
                    dl = dl.extract_to(extract_dir);
                }
                let cb = ctx.progress_callback.clone();
                dl = dl.progress(move |p| {
                    cb(p);
                });
                if let Err(err) = dl.execute() {
                    error!("{}", err);
                }
            }
            Some(PlatformUrl::Github {
                project,
                tag,
            }) => {
                info!("Detected GitHub URL, processing as GitHub release");
                if let Err(err) = handle_github_release(ctx, &project, tag.as_deref()) {
                    error!("{}", err);
                }
            }
            Some(PlatformUrl::Gitlab {
                project,
                tag,
            }) => {
                info!("Detected GitLab URL, processing as GitLab release");
                if let Err(err) = handle_gitlab_release(ctx, &project, tag.as_deref()) {
                    error!("{}", err);
                }
            }
            Some(PlatformUrl::Oci {
                reference,
            }) => {
                if let Err(err) = handle_oci_download(ctx, &reference).await {
                    error!("{}", err);
                }
            }
            None => {
                // Not a recognised URL: treat the link as a package query.
                let (ctx_inner, _) = crate::create_context();
                let metadata_mgr = ctx_inner.metadata_manager().await?;
                let query = PackageQuery::try_from(link.as_str())?;
                let packages: Vec<Package> = if let Some(ref repo_name) = query.repo_name {
                    metadata_mgr
                        .query_repo(repo_name, |conn| {
                            MetadataRepository::find_filtered(
                                conn,
                                query.name.as_deref(),
                                query.pkg_id.as_deref(),
                                None,
                                None,
                                None,
                            )
                        })?
                        .unwrap_or_default()
                        .into_iter()
                        .map(|p| {
                            let mut pkg: Package = p.into();
                            pkg.repo_name = repo_name.clone();
                            pkg
                        })
                        .collect()
                } else {
                    metadata_mgr.query_all_flat(|repo_name, conn| {
                        let pkgs = MetadataRepository::find_filtered(
                            conn,
                            query.name.as_deref(),
                            query.pkg_id.as_deref(),
                            None,
                            None,
                            None,
                        )?;
                        Ok(pkgs
                            .into_iter()
                            .map(|p| {
                                let mut pkg: Package = p.into();
                                pkg.repo_name = repo_name.to_string();
                                pkg
                            })
                            .collect())
                    })?
                };
                let packages: Vec<Package> = if let Some(ref version) = query.version {
                    packages
                        .into_iter()
                        .filter(|p| p.has_version(version))
                        .collect()
                } else {
                    packages
                };
                if packages.is_empty() {
                    error!("Invalid download resource '{}'", link);
                    // Keep going: one unresolvable link must not drop the
                    // remaining ones (previously this was `break`).
                    continue;
                }
                let selected = if packages.len() == 1 || ctx.yes {
                    packages.into_iter().next()
                } else {
                    select_package_interactively(packages, link)?
                };
                // NOTE(review): presumably `None` means no package was chosen at
                // the prompt. Skip this link instead of panicking on `unwrap()`.
                let Some(package) = selected else {
                    continue;
                };
                let package = package.resolve(query.version.as_deref());
                info!(
                    "Downloading package: {}#{}",
                    package.pkg_name, package.pkg_id
                );
                if let Some(ref url) = package.ghcr_blob {
                    let mut dl = OciDownload::new(url.as_str()).overwrite(ctx.get_overwrite_mode());
                    if let Some(ref out) = output {
                        dl = dl.output(out);
                    }
                    let cb = ctx.progress_callback.clone();
                    dl = dl.progress(move |p| {
                        cb(p);
                    });
                    if let Err(err) = dl.execute() {
                        error!("{}", err);
                    }
                } else {
                    let mut dl =
                        Download::new(&package.download_url).overwrite(ctx.get_overwrite_mode());
                    if let Some(ref out) = output {
                        dl = dl.output(out);
                    }
                    let cb = ctx.progress_callback.clone();
                    dl = dl.progress(move |p| {
                        cb(p);
                    });
                    // Log and continue, like every other download path here,
                    // instead of aborting the remaining links.
                    if let Err(err) = dl.execute() {
                        error!("{}", err);
                    }
                }
            }
        }
    }
    Ok(())
}
async fn handle_oci_download(ctx: &DownloadContext, reference: &str) -> SoarResult<()> {
info!("Downloading using OCI reference: {}", reference);
let mut dl = OciDownload::new(reference)
.filter(ctx.create_filter())
.parallel(get_config().ghcr_concurrency.unwrap_or(8))
.overwrite(ctx.get_overwrite_mode());
if let Some(ref output) = ctx.output {
dl = dl.output(output);
}
let cb = ctx.progress_callback.clone();
dl = dl.progress(move |p| {
cb(p);
});
let mut retries = 0;
let max_retries = 5;
loop {
match dl.clone().execute() {
Ok(_) => {
info!("Download completed successfully");
break;
}
Err(err)
if matches!(
err,
DownloadError::HttpError {
status: 429,
..
} | DownloadError::Network(_)
) && retries < max_retries =>
{
retries += 1;
info!("Retrying... ({}/{})", retries, max_retries);
ctx.progress_callback.clone()(Progress::Recovered);
sleep(Duration::from_secs(5)).await;
}
Err(err) => {
ctx.progress_callback.clone()(Progress::Error);
error!("Download failed: {}", err);
return Err(err.into());
}
}
}
Ok(())
}
/// Downloads each OCI reference in order via [`handle_oci_download`], stopping
/// at (and returning) the first error.
pub async fn handle_oci_downloads(
    ctx: &DownloadContext,
    references: Vec<String>,
) -> SoarResult<()> {
    for reference in references.iter().map(String::as_str) {
        handle_oci_download(ctx, reference).await?;
    }
    Ok(())
}
fn handle_github_release(
ctx: &DownloadContext,
project: &str,
tag: Option<&str>,
) -> SoarResult<()> {
let releases = Github::fetch_releases(project, tag)?;
let release = if let Some(tag) = tag {
releases.iter().find(|r| r.tag() == tag)
} else {
releases
.iter()
.find(|r| !r.is_prerelease())
.or_else(|| releases.first())
};
let release = release.ok_or_else(|| DownloadError::InvalidResponse)?;
info!("Found release: {}", release.tag());
let filter = ctx.create_filter();
let assets: Vec<_> = release
.assets()
.iter()
.filter(|a| filter.matches(a.name()))
.collect();
if assets.is_empty() {
let available = release
.assets()
.iter()
.map(|a| a.name().to_string())
.collect::<Vec<String>>();
Err(DownloadError::NoMatch {
available,
})?
}
let selected_asset = if assets.len() == 1 || ctx.yes {
assets[0]
} else {
&select_asset_interactively(assets)?
};
info!("Downloading asset: {}", selected_asset.name());
let mut dl = Download::new(selected_asset.url())
.overwrite(ctx.get_overwrite_mode())
.extract(ctx.extract);
if let Some(ref out) = ctx.output {
dl = dl.output(out);
}
if let Some(ref extract_dir) = ctx.extract_dir {
dl = dl.extract_to(extract_dir);
}
let cb = ctx.progress_callback.clone();
dl = dl.progress(move |p| {
cb(p);
});
dl.execute()?;
Ok(())
}
fn handle_gitlab_release(
ctx: &DownloadContext,
project: &str,
tag: Option<&str>,
) -> SoarResult<()> {
let releases = GitLab::fetch_releases(project, tag)?;
let release = if let Some(tag) = tag {
releases.iter().find(|r| r.tag() == tag)
} else {
releases
.iter()
.find(|r| !r.is_prerelease())
.or_else(|| releases.first())
};
let release = release.ok_or_else(|| DownloadError::InvalidResponse)?;
info!("Found release: {}", release.tag());
let filter = ctx.create_filter();
let assets: Vec<_> = release
.assets()
.iter()
.filter(|a| filter.matches(a.name()))
.collect();
if assets.is_empty() {
let available = release
.assets()
.iter()
.map(|a| a.name().to_string())
.collect::<Vec<String>>();
Err(DownloadError::NoMatch {
available,
})?
}
let selected_asset = if assets.len() == 1 || ctx.yes {
assets[0]
} else {
&select_asset_interactively(assets)?
};
info!("Downloading asset: {}", selected_asset.name());
let mut dl = Download::new(selected_asset.url())
.overwrite(ctx.get_overwrite_mode())
.extract(ctx.extract);
if let Some(ref out) = ctx.output {
dl = dl.output(out);
}
if let Some(ref extract_dir) = ctx.extract_dir {
dl = dl.extract_to(extract_dir);
}
let cb = ctx.progress_callback.clone();
dl = dl.progress(move |p| {
cb(p);
});
dl.execute()?;
Ok(())
}
/// Compiles each pattern in `regex_patterns` into a [`Regex`].
///
/// `None` yields an empty list.
///
/// # Errors
///
/// Returns the first pattern compilation error, converted into the crate's
/// error type.
pub fn create_regex_patterns(regex_patterns: Option<Vec<String>>) -> SoarResult<Vec<Regex>> {
    let Some(patterns) = regex_patterns else {
        return Ok(Vec::new());
    };
    patterns
        .into_iter()
        .map(|pattern| Regex::new(&pattern).map_err(Into::into))
        .collect()
}
/// Splits a `project[@tag]` spec into its project path and optional tag.
///
/// Whitespace around the spec, the project and the tag is ignored. An empty or
/// whitespace-only tag (`owner/repo@`, `owner/repo@ `) means "no tag". Only the
/// first `@` separates project from tag.
fn split_project_tag(spec: &str) -> (&str, Option<&str>) {
    let spec = spec.trim();
    match spec.split_once('@') {
        Some((project, tag)) if !tag.trim().is_empty() => (project.trim(), Some(tag.trim())),
        // Operate on the trimmed spec so a trailing `@` is always removed,
        // even when it was followed by whitespace.
        _ => (spec.trim_end_matches('@').trim_end(), None),
    }
}

/// Downloads one release asset for each `project[@tag]` spec from GitHub.
///
/// Per-project failures are logged and do not stop the remaining projects.
pub async fn handle_github_downloads(
    ctx: &DownloadContext,
    projects: Vec<String>,
) -> SoarResult<()> {
    for spec in &projects {
        info!("Fetching releases from GitHub: {}", spec);
        let (project, tag) = split_project_tag(spec);
        if let Err(err) = handle_github_release(ctx, project, tag) {
            error!("{}", err);
        }
    }
    Ok(())
}

/// Downloads one release asset for each `project[@tag]` spec from GitLab.
///
/// Per-project failures are logged and do not stop the remaining projects.
pub async fn handle_gitlab_downloads(
    ctx: &DownloadContext,
    projects: Vec<String>,
) -> SoarResult<()> {
    for spec in &projects {
        info!("Fetching releases from GitLab: {}", spec);
        let (project, tag) = split_project_tag(spec);
        if let Err(err) = handle_gitlab_release(ctx, project, tag) {
            error!("{}", err);
        }
    }
    Ok(())
}
fn select_asset_interactively<A>(assets: Vec<&A>) -> SoarResult<A>
where
A: Asset + Clone,
{
info!("\nAvailable assets:");
for (i, asset) in assets.iter().enumerate() {
let size = asset
.size()
.map(|s| format!(" ({})", format_bytes(s, 2)))
.unwrap_or_default();
info!(" {}. {}{}", i + 1, asset.name(), size);
}
loop {
let max = assets.len();
let response = interactive_ask(&format!("Select an asset (1-{}): ", max))?;
match response.trim().parse::<usize>() {
Ok(n) if n > 0 && n <= max => return Ok(assets[n - 1].clone()),
_ => error!("Invalid selection, please try again."),
}
}
}