use std::{
borrow::Cow,
collections::HashMap,
env,
future::IntoFuture,
path::PathBuf,
str::FromStr,
sync::Arc,
time::{Duration, Instant},
};
use clap::ValueEnum;
use indicatif::{ProgressBar, ProgressStyle};
use itertools::Itertools;
use miette::{Context, IntoDiagnostic};
use rattler::{
default_cache_dir,
install::{IndicatifReporter, Installer, Transaction, TransactionOperation},
package_cache::PackageCache,
};
use rattler_conda_types::{
Channel, ChannelConfig, GenericVirtualPackage, MatchSpec, Matches, PackageName,
ParseMatchSpecOptions, Platform, PrefixRecord, RepoDataRecord, Version,
};
use rattler_networking::AuthenticationMiddleware;
#[cfg(feature = "s3")]
use rattler_networking::AuthenticationStorage;
use rattler_repodata_gateway::{Gateway, RepoData, SourceConfig};
use rattler_solve::{
libsolv_c::{self},
resolvo, SolverImpl, SolverTask,
};
use reqwest::Client;
use crate::{exclude_newer::ExcludeNewer, global_multi_progress};
/// Solve a set of package specs and install the result into a prefix
#[derive(Debug, clap::Parser)]
pub struct Opt {
    /// Channels to search for packages; `conda-forge` is used when none are given
    #[clap(short, long = "channel")]
    channels: Option<Vec<String>>,

    /// Package specs (match specs) to install
    #[clap(required = true)]
    specs: Vec<String>,

    /// Only print the operations that would be performed, without installing anything
    #[clap(long)]
    dry_run: bool,

    /// Platform to install packages for; defaults to the current platform
    #[clap(long)]
    platform: Option<String>,

    /// Virtual packages to use instead of detecting them, as `name[=version[=build]]`
    #[clap(long)]
    virtual_package: Option<Vec<String>>,

    /// Solver backend to use; defaults to resolvo
    #[clap(long)]
    solver: Option<Solver>,

    /// Timeout for the solver, in milliseconds
    #[clap(long)]
    timeout: Option<u64>,

    /// Prefix to install into; defaults to `.prefix` in the current directory
    #[clap(short = 'p', long = "prefix", visible_alias = "target-prefix")]
    target_prefix: Option<PathBuf>,

    /// Strategy the solver uses to select package versions
    #[clap(long)]
    strategy: Option<SolveStrategy>,

    /// Install only the dependencies, leaving out packages that match the given specs
    #[clap(long, group = "deps_mode")]
    only_deps: bool,

    /// Install only packages that match the given specs, leaving out their dependencies
    #[clap(long, group = "deps_mode")]
    no_deps: bool,

    /// Value forwarded to the solver's `exclude_newer` setting
    #[clap(long)]
    exclude_newer: Option<ExcludeNewer>,
}
// Version-selection strategy accepted by the `--strategy` option.
//
// Each variant is converted to a `rattler_solve::SolveStrategy` variant by the
// `From` impl in this file.
// NOTE(review): plain `//` comments are used on purpose; clap's `ValueEnum`
// derive would turn `///` doc comments on the variants into help text.
#[derive(Debug, Clone, Copy, ValueEnum)]
pub enum SolveStrategy {
// Converted to `rattler_solve::SolveStrategy::Highest`.
Highest,
// Converted to `rattler_solve::SolveStrategy::LowestVersion`.
Lowest,
// Converted to `rattler_solve::SolveStrategy::LowestVersionDirect`.
LowestDirect,
}
// Solver backend accepted by the `--solver` option.
//
// NOTE(review): plain `//` comments are used on purpose; clap's `ValueEnum`
// derive would turn `///` doc comments on the variants into help text.
#[derive(Default, Debug, Clone, Copy, ValueEnum)]
pub enum Solver {
// Default backend; `create` runs `resolvo::Solver` for this variant.
#[default]
Resolvo,
// Spelled `libsolv` on the command line; `create` runs `libsolv_c::Solver`.
#[value(name = "libsolv")]
LibSolv,
}
/// Maps the command-line strategy onto the corresponding solver strategy.
impl From<SolveStrategy> for rattler_solve::SolveStrategy {
    fn from(strategy: SolveStrategy) -> Self {
        // `Self` is `rattler_solve::SolveStrategy` inside this impl.
        match strategy {
            SolveStrategy::Highest => Self::Highest,
            SolveStrategy::Lowest => Self::LowestVersion,
            SolveStrategy::LowestDirect => Self::LowestVersionDirect,
        }
    }
}
pub async fn create(opt: Opt) -> miette::Result<()> {
let channel_config =
ChannelConfig::default_with_root_dir(env::current_dir().into_diagnostic()?);
let current_dir = env::current_dir().into_diagnostic()?;
let target_prefix = opt
.target_prefix
.unwrap_or_else(|| current_dir.join(".prefix"));
let target_prefix = std::path::absolute(target_prefix).into_diagnostic()?;
println!("Target prefix: {}", target_prefix.display());
let install_platform = if let Some(platform) = opt.platform {
Platform::from_str(&platform).into_diagnostic()?
} else {
Platform::current()
};
println!("Installing for platform: {install_platform:?}");
let match_spec_options = ParseMatchSpecOptions::strict()
.with_experimental_extras(true)
.with_experimental_conditionals(true);
let specs = opt
.specs
.iter()
.map(|spec| MatchSpec::from_str(spec, match_spec_options))
.collect::<Result<Vec<_>, _>>()
.into_diagnostic()?;
let cache_dir = default_cache_dir()
.map_err(|e| miette::miette!("could not determine default cache directory: {}", e))?;
rattler_cache::ensure_cache_dir(&cache_dir)
.map_err(|e| miette::miette!("could not create cache directory: {}", e))?;
let channels = opt
.channels
.unwrap_or_else(|| vec![String::from("conda-forge")])
.into_iter()
.map(|channel_str| Channel::from_str(channel_str, &channel_config))
.collect::<Result<Vec<_>, _>>()
.into_diagnostic()?;
let installed_packages =
PrefixRecord::collect_from_prefix::<PrefixRecord>(&target_prefix).into_diagnostic()?;
let download_client = Client::builder()
.no_gzip()
.build()
.expect("failed to create client");
let download_client = reqwest_middleware::ClientBuilder::new(download_client.clone())
.with_arc(Arc::new(
AuthenticationMiddleware::from_env_and_defaults().into_diagnostic()?,
))
.with(rattler_networking::OciMiddleware::new(download_client));
#[cfg(feature = "s3")]
let download_client = download_client.with(rattler_networking::S3Middleware::new(
HashMap::new(),
AuthenticationStorage::from_env_and_defaults().into_diagnostic()?,
));
#[cfg(feature = "gcs")]
let download_client = download_client.with(rattler_networking::GCSMiddleware::default());
let download_client = download_client.build();
let gateway = Gateway::builder()
.with_cache_dir(cache_dir.join(rattler_cache::REPODATA_CACHE_DIR))
.with_package_cache(PackageCache::new(
cache_dir.join(rattler_cache::PACKAGE_CACHE_DIR),
))
.with_client(download_client.clone())
.with_channel_config(rattler_repodata_gateway::ChannelConfig {
default: SourceConfig {
sharded_enabled: true,
..SourceConfig::default()
},
per_channel: HashMap::new(),
})
.finish();
let start_load_repo_data = Instant::now();
let repo_data = wrap_in_async_progress(
"loading repodata",
gateway
.query(
channels,
[install_platform, Platform::NoArch],
specs.clone(),
)
.recursive(true),
)
.await
.into_diagnostic()
.context("failed to load repodata")?;
let total_records: usize = repo_data.iter().map(RepoData::len).sum();
println!(
"Loaded {} records in {:?}",
total_records,
start_load_repo_data.elapsed()
);
let virtual_packages = wrap_in_progress("determining virtual packages", move || {
if let Some(virtual_packages) = opt.virtual_package {
Ok(virtual_packages
.iter()
.map(|virt_pkg| {
let elems = virt_pkg.split('=').collect::<Vec<&str>>();
Ok(GenericVirtualPackage {
name: elems[0].try_into().into_diagnostic()?,
version: elems
.get(1)
.map_or(Version::from_str("0"), |s| Version::from_str(s))
.expect("Could not parse virtual package version"),
build_string: (*elems.get(2).unwrap_or(&"")).to_string(),
})
})
.collect::<miette::Result<Vec<_>>>()?)
} else {
rattler_virtual_packages::VirtualPackage::detect(
&rattler_virtual_packages::VirtualPackageOverrides::default(),
)
.map(|vpkgs| {
vpkgs
.iter()
.map(|vpkg| GenericVirtualPackage::from(vpkg.clone()))
.collect::<Vec<_>>()
})
.into_diagnostic()
}
})?;
println!(
"Virtual packages:\n{}\n",
virtual_packages
.iter()
.format_with("\n", |i, f| f(&format_args!(" - {i}",)))
);
let locked_packages = installed_packages
.iter()
.map(|record| record.repodata_record.clone())
.collect();
let solver_task = SolverTask {
locked_packages,
virtual_packages,
specs: specs.clone(),
timeout: opt.timeout.map(Duration::from_millis),
strategy: opt.strategy.map_or_else(Default::default, Into::into),
exclude_newer: opt.exclude_newer.map(Into::into),
..SolverTask::from_iter(&repo_data)
};
let solver_result = wrap_in_progress("solving", move || match opt.solver.unwrap_or_default() {
Solver::Resolvo => resolvo::Solver.solve(solver_task),
Solver::LibSolv => libsolv_c::Solver.solve(solver_task),
})
.into_diagnostic()?;
let mut required_packages: Vec<RepoDataRecord> = solver_result.records;
if opt.no_deps {
required_packages.retain(|r| specs.iter().any(|s| s.matches(&r.package_record)));
} else if opt.only_deps {
required_packages.retain(|r| !specs.iter().any(|s| s.matches(&r.package_record)));
};
if opt.dry_run {
let transaction = Transaction::from_current_and_desired(
installed_packages,
required_packages,
None,
None, install_platform,
)
.into_diagnostic()?;
if transaction.operations.is_empty() {
println!("No operations necessary");
} else {
print_transaction(&transaction, solver_result.extras);
}
return Ok(());
}
let install_start = Instant::now();
let result = Installer::new()
.with_download_client(download_client)
.with_target_platform(install_platform)
.with_installed_packages(installed_packages)
.with_execute_link_scripts(true)
.with_requested_specs(specs)
.with_reporter(
IndicatifReporter::builder()
.with_multi_progress(global_multi_progress())
.finish(),
)
.install(&target_prefix, required_packages)
.await
.into_diagnostic()?;
if result.transaction.operations.is_empty() {
println!(
"{} Already up to date",
console::style(console::Emoji("✔", "")).green(),
);
} else {
println!(
"{} Successfully updated the environment in {:?}",
console::style(console::Emoji("✔", "")).green(),
install_start.elapsed()
);
let transaction = result
.transaction
.into_prefix_record(target_prefix)
.unwrap();
print_transaction(&transaction, solver_result.extras);
}
Ok(())
}
/// Prints one line per operation of `transaction` to stdout.
///
/// Installs are prefixed with a green `+`, changes and reinstalls with a
/// yellow `~`, and removals with a red `-`. When `features` has an entry for a
/// package name, its values are shown in square brackets after the name.
fn print_transaction(
    transaction: &Transaction<PrefixRecord, RepoDataRecord>,
    features: HashMap<PackageName, Vec<String>>,
) {
    // Renders a record as `name[features] version build channel`; the channel
    // is rendered as an empty string when the record has none.
    let describe = |record: &RepoDataRecord| {
        let package = &record.package_record;
        let channel = record.channel.as_deref().unwrap_or("");
        match features.get(&package.name) {
            Some(extras) => format!(
                "{}[{}] {} {} {}",
                package.name.as_normalized(),
                extras.join(", "),
                package.version,
                package.build,
                channel,
            ),
            None => format!(
                "{} {} {} {}",
                package.name.as_normalized(),
                package.version,
                package.build,
                channel,
            ),
        }
    };

    for operation in &transaction.operations {
        // Build the complete line first, then emit it with a single print.
        let line = match operation {
            TransactionOperation::Install(new) => {
                format!("{} {}", console::style("+").green(), describe(new))
            }
            TransactionOperation::Change { old, new } => format!(
                "{} {} -> {}",
                console::style("~").yellow(),
                describe(&old.repodata_record),
                describe(new)
            ),
            TransactionOperation::Reinstall { old, .. } => format!(
                "{} {}",
                console::style("~").yellow(),
                describe(&old.repodata_record)
            ),
            TransactionOperation::Remove(old) => format!(
                "{} {}",
                console::style("-").red(),
                describe(&old.repodata_record)
            ),
        };
        println!("{line}");
    }
}
/// Runs `func` while a spinner labelled `msg` is shown, then clears the
/// spinner and returns whatever `func` returned.
fn wrap_in_progress<T, F>(msg: impl Into<Cow<'static, str>>, func: F) -> T
where
    F: FnOnce() -> T,
{
    // Interval passed to the spinner's steady tick.
    const TICK_INTERVAL: Duration = Duration::from_millis(100);

    let spinner = ProgressBar::new_spinner();
    spinner.enable_steady_tick(TICK_INTERVAL);
    spinner.set_style(long_running_progress_style());
    spinner.set_message(msg);

    let output = func();

    spinner.finish_and_clear();
    output
}
/// Awaits `fut` while a spinner labelled `msg` is shown, then clears the
/// spinner and returns the future's output.
async fn wrap_in_async_progress<T, F>(msg: impl Into<Cow<'static, str>>, fut: F) -> T
where
    F: IntoFuture<Output = T>,
{
    // Interval passed to the spinner's steady tick.
    const TICK_INTERVAL: Duration = Duration::from_millis(100);

    let spinner = ProgressBar::new_spinner();
    spinner.enable_steady_tick(TICK_INTERVAL);
    spinner.set_style(long_running_progress_style());
    spinner.set_message(msg);

    let output = fut.into_future().await;

    spinner.finish_and_clear();
    output
}
/// Returns the style shared by the spinners in this file: a green spinner
/// followed by the message.
fn long_running_progress_style() -> indicatif::ProgressStyle {
    let template = "{spinner:.green} {msg}";
    ProgressStyle::with_template(template).unwrap()
}