use crate::{
consts,
environment::{get_up_to_date_prefix, verify_prefix_location_unchanged, LockFileUsage},
project::{manifest::PyPiRequirement, DependencyType, Project, SpecType},
FeatureName,
};
use clap::Parser;
use itertools::{Either, Itertools};
use crate::project::grouped_environment::GroupedEnvironment;
use indexmap::IndexMap;
use miette::{IntoDiagnostic, WrapErr};
use rattler_conda_types::{
version_spec::{LogicalOperator, RangeOperator},
Channel, MatchSpec, NamelessMatchSpec, PackageName, Platform, Version, VersionBumpType,
VersionSpec,
};
use rattler_repodata_gateway::sparse::SparseRepoData;
use rattler_solve::{resolvo, SolverImpl};
use std::{
collections::{HashMap, HashSet},
path::PathBuf,
str::FromStr,
};
#[derive(Parser, Debug, Default)]
#[clap(arg_required_else_help = true)]
pub struct Args {
#[arg(required = true)]
pub specs: Vec<String>,
#[arg(long)]
pub manifest_path: Option<PathBuf>,
#[arg(long, conflicts_with = "build")]
pub host: bool,
#[arg(long, conflicts_with = "host")]
pub build: bool,
#[arg(long, conflicts_with_all = ["host", "build"])]
pub pypi: bool,
#[clap(long, conflicts_with = "no_install")]
pub no_lockfile_update: bool,
#[arg(long)]
pub no_install: bool,
#[arg(long, short)]
pub platform: Vec<Platform>,
#[arg(long, short)]
pub feature: Option<String>,
}
impl DependencyType {
    /// Derives the kind of dependency to add from the command-line flags.
    ///
    /// `--pypi` wins over everything else; otherwise `--host` selects host dependencies,
    /// `--build` selects build dependencies, and with neither flag the specs become
    /// regular run dependencies.
    pub fn from_args(args: &Args) -> Self {
        if args.pypi {
            return Self::PypiDependency;
        }

        let spec_type = if args.host {
            SpecType::Host
        } else if args.build {
            SpecType::Build
        } else {
            SpecType::Run
        };
        Self::CondaDependency(spec_type)
    }
}
pub async fn execute(args: Args) -> miette::Result<()> {
let mut project = Project::load_or_else_discover(args.manifest_path.as_deref())?;
let dependency_type = DependencyType::from_args(&args);
let spec_platforms = &args.platform;
verify_prefix_location_unchanged(
project
.default_environment()
.dir()
.join(consts::PREFIX_FILE_NAME)
.as_path(),
)?;
let platforms_to_add = spec_platforms
.iter()
.filter(|p| !project.platforms().contains(p))
.cloned()
.collect::<Vec<Platform>>();
project
.manifest
.add_platforms(platforms_to_add.iter(), &FeatureName::Default)?;
let feature_name = args
.feature
.map_or(FeatureName::Default, FeatureName::Named);
match dependency_type {
DependencyType::CondaDependency(spec_type) => {
let specs = args
.specs
.clone()
.into_iter()
.map(|s| MatchSpec::from_str(&s))
.collect::<Result<Vec<_>, _>>()
.into_diagnostic()?;
add_conda_specs_to_project(
&mut project,
&feature_name,
specs,
spec_type,
args.no_install,
args.no_lockfile_update,
spec_platforms,
)
.await
}
DependencyType::PypiDependency => {
let pep508_requirements = args
.specs
.clone()
.into_iter()
.map(|input| pep508_rs::Requirement::from_str(input.as_ref()).into_diagnostic())
.collect::<miette::Result<Vec<_>>>()?;
let specs = pep508_requirements
.into_iter()
.map(|req| {
let name = rip::types::PackageName::from_str(req.name.as_str())?;
let requirement = PyPiRequirement::from(req);
Ok((name, requirement))
})
.collect::<Result<Vec<_>, rip::types::ParsePackageNameError>>()
.into_diagnostic()?;
add_pypi_specs_to_project(
&mut project,
specs,
spec_platforms,
args.no_lockfile_update,
args.no_install,
)
.await
}
}?;
for package in args.specs {
eprintln!(
"{}Added {}",
console::style(console::Emoji("✔ ", "")).green(),
console::style(package).bold(),
);
}
if !matches!(
dependency_type,
DependencyType::CondaDependency(SpecType::Run)
) {
eprintln!(
"Added these as {}.",
console::style(dependency_type.name()).bold()
);
}
if !args.platform.is_empty() {
eprintln!(
"Added these only for platform(s): {}",
console::style(args.platform.iter().join(", ")).bold()
)
}
Ok(())
}
/// Writes the given PyPI requirements to the manifest, brings the default environment's
/// lock file and prefix up to date, and saves the manifest.
///
/// Each requirement is added once per entry of `specs_platforms`, or once without a
/// platform when `specs_platforms` is empty. With `no_update_lockfile` the lock file is
/// used frozen; `no_install` is forwarded to the prefix update. The manifest is only
/// saved after the prefix update succeeded.
pub async fn add_pypi_specs_to_project(
    project: &mut Project,
    specs: Vec<(rip::types::PackageName, PyPiRequirement)>,
    specs_platforms: &Vec<Platform>,
    no_update_lockfile: bool,
    no_install: bool,
) -> miette::Result<()> {
    // `None` adds the requirement without a platform selector.
    let targets: Vec<Option<Platform>> = if specs_platforms.is_empty() {
        vec![None]
    } else {
        specs_platforms.iter().copied().map(Some).collect()
    };

    for (name, spec) in &specs {
        for &target in &targets {
            project.manifest.add_pypi_dependency(name, spec, target)?;
        }
    }

    let lock_file_usage = match no_update_lockfile {
        true => LockFileUsage::Frozen,
        false => LockFileUsage::Update,
    };
    get_up_to_date_prefix(
        &project.default_environment(),
        lock_file_usage,
        no_install,
        IndexMap::default(),
    )
    .await?;

    project.save()?;
    Ok(())
}
/// Adds conda `specs` as `spec_type` dependencies of `feature_name`, brings the default
/// environment's lock file and prefix up to date, and saves the manifest.
///
/// Specs without a version constraint get one derived from the versions the solver picks
/// for them in every grouped environment that contains `feature_name`. Solving uses
/// `specs_platforms`, or each environment's own platforms when that list is empty. If no
/// solved version was recorded for a package, the newest versions available in the
/// project's channels are used instead. Each spec is written once per entry of
/// `specs_platforms`, or once without a platform when it is empty.
///
/// # Errors
///
/// Fails if a spec has no package name, repodata cannot be fetched, solving fails for any
/// environment/platform, or updating the manifest, lock file or prefix fails.
pub async fn add_conda_specs_to_project(
    project: &mut Project,
    feature_name: &FeatureName,
    specs: Vec<MatchSpec>,
    spec_type: SpecType,
    no_install: bool,
    no_update_lockfile: bool,
    specs_platforms: &Vec<Platform>,
) -> miette::Result<()> {
    // Split every spec into its package name and the nameless remainder.
    let new_specs = specs
        .into_iter()
        .map(|spec| match &spec.name {
            Some(name) => Ok((name.clone(), spec.into())),
            None => Err(miette::miette!("missing package name for spec '{spec}'")),
        })
        .collect::<miette::Result<HashMap<PackageName, NamelessMatchSpec>>>()?;

    let sparse_repo_data = project.fetch_sparse_repodata().await?;

    // Record every version the solver selects for the new packages, across all
    // environments that include the target feature and all relevant platforms.
    let mut package_versions = HashMap::<PackageName, HashSet<Version>>::new();
    let grouped_environments: Vec<GroupedEnvironment> = project
        .grouped_environments()
        .iter()
        .filter(|env| {
            env.features()
                .map(|feat| &feat.name)
                .contains(&feature_name)
        })
        .cloned()
        .collect();
    for grouped_environment in grouped_environments {
        let platforms = if specs_platforms.is_empty() {
            Either::Left(grouped_environment.platforms().into_iter())
        } else {
            Either::Right(specs_platforms.iter().copied())
        };
        for platform in platforms {
            let solved_versions = determine_best_version(
                &grouped_environment,
                &new_specs,
                spec_type,
                &sparse_repo_data,
                platform,
            )
            .wrap_err_with(|| {
                format!(
                    "could not determine any available versions for {} on {platform}. Either the package could not be found or version constraints on other dependencies result in a conflict.",
                    new_specs.keys().map(|s| s.as_source()).join(", ")
                )
            })?;
            for (name, version) in solved_versions {
                package_versions.entry(name).or_default().insert(version);
            }
        }
    }

    for (name, spec) in new_specs {
        // Only specs without an explicit version get a derived constraint.
        let updated_spec = if spec.version.is_none() {
            let mut updated_spec = spec;
            updated_spec.version = match package_versions.get(&name) {
                Some(versions_seen) => determine_version_constraint(versions_seen),
                None => determine_version_constraint(&determine_latest_versions(
                    project,
                    specs_platforms,
                    &sparse_repo_data,
                    &name,
                )?),
            };
            updated_spec
        } else {
            spec
        };
        let spec = MatchSpec::from_nameless(updated_spec, Some(name));
        if specs_platforms.is_empty() {
            project
                .manifest
                .add_dependency(&spec, spec_type, None, feature_name)?;
        } else {
            for platform in specs_platforms.iter() {
                project
                    .manifest
                    .add_dependency(&spec, spec_type, Some(*platform), feature_name)?;
            }
        }
    }

    let lock_file_usage = if no_update_lockfile {
        LockFileUsage::Frozen
    } else {
        LockFileUsage::Update
    };
    get_up_to_date_prefix(
        &project.default_environment(),
        lock_file_usage,
        no_install,
        sparse_repo_data,
    )
    .await?;
    project.save()?;
    Ok(())
}
/// Collects the newest available version of `name` for every (channel, platform) pair
/// of the project in which the package can be found.
///
/// `platforms` restricts the search; when it is empty, all project platforms are used.
/// `noarch` is always searched in addition. One version is returned per pair that has
/// records for `name`, so the result may contain duplicates. It is empty when the package
/// is found nowhere.
///
/// # Errors
///
/// Returns an error when records cannot be loaded from the sparse repodata.
fn determine_latest_versions(
    project: &Project,
    platforms: &[Platform],
    sparse_repo_data: &IndexMap<(Channel, Platform), SparseRepoData>,
    name: &PackageName,
) -> miette::Result<Vec<Version>> {
    let mut search_platforms = if platforms.is_empty() {
        project.platforms().into_iter().collect_vec()
    } else {
        platforms.to_vec()
    };
    search_platforms.push(Platform::NoArch);

    let mut latest_versions = Vec::new();
    for channel in project.channels() {
        for platform in &search_platforms {
            if let Some(repo_data) = sparse_repo_data.get(&(channel.clone(), *platform)) {
                let records = repo_data.load_records(name).into_diagnostic()?;
                // Compare versions by reference so only the winning one is cloned.
                if let Some(latest) = records
                    .iter()
                    .map(|record| record.package_record.version.version())
                    .max()
                {
                    latest_versions.push(latest.clone());
                }
            }
        }
    }
    Ok(latest_versions)
}
/// Solves `environment` for `platform` with `new_specs` added as `new_specs_type`
/// dependencies, and returns the version picked for each package named in `new_specs`.
///
/// The environment's dependencies of every [`SpecType`] are folded together with
/// `overwrite`. The result is solved with `resolvo::Solver` against the repodata of the
/// environment's channels for `platform` and `noarch`.
///
/// # Errors
///
/// Returns an error if loading records fails or the specs cannot be solved.
pub fn determine_best_version(
    environment: &GroupedEnvironment,
    new_specs: &HashMap<PackageName, NamelessMatchSpec>,
    new_specs_type: SpecType,
    sparse_repo_data: &IndexMap<(Channel, Platform), SparseRepoData>,
    platform: Platform,
) -> miette::Result<HashMap<PackageName, Version>> {
    let combined = SpecType::all()
        .map(|kind| {
            let mut per_kind = environment.dependencies(Some(kind), Some(platform));
            if kind == new_specs_type {
                for (name, spec) in new_specs {
                    // Remove first so the requested spec takes the place of any existing
                    // entry. NOTE(review): presumably `insert` alone would keep/merge the
                    // old spec — confirm against the dependencies type.
                    per_kind.remove(name);
                    per_kind.insert(name.clone(), spec.clone());
                }
            }
            per_kind
        })
        .reduce(|merged, next| merged.overwrite(&next))
        .unwrap_or_default();

    let package_names = combined.names().cloned().collect_vec();

    // Repodata of every (channel, platform) pair, including `noarch`.
    let repodata = environment
        .channels()
        .into_iter()
        .cloned()
        .cartesian_product([platform, Platform::NoArch])
        .filter_map(|key| sparse_repo_data.get(&key));
    let available_packages =
        SparseRepoData::load_records_recursive(repodata, package_names.iter().cloned(), None)
            .into_diagnostic()?;

    let task = rattler_solve::SolverTask {
        specs: combined
            .iter_specs()
            .map(|(name, spec)| MatchSpec::from_nameless(spec.clone(), Some(name.clone())))
            .collect(),
        available_packages: &available_packages,
        virtual_packages: environment.virtual_packages(platform),
        locked_packages: vec![],
        pinned_packages: vec![],
        timeout: None,
    };
    let solution = resolvo::Solver.solve(task).into_diagnostic()?;

    let mut best_versions = HashMap::<PackageName, Version>::new();
    for record in solution {
        let package = record.package_record;
        if new_specs.contains_key(&package.name) {
            best_versions.insert(package.name, package.version.into());
        }
    }
    Ok(best_versions)
}
/// Derives a version range covering every version in `versions`.
///
/// The lower bound (inclusive) is the smallest version. The upper bound (exclusive) is
/// the largest version with its last segment popped and the new last segment bumped. For
/// example, `1.2.0` and `1.3.0` give `>=1.2.0,<1.4`. Returns `None` for empty input or
/// when bumping fails.
fn determine_version_constraint<'a>(
    versions: impl IntoIterator<Item = &'a Version>,
) -> Option<VersionSpec> {
    let (lowest, highest) = versions.into_iter().minmax().into_option()?;

    // Fall back to the full version when `pop_segments` yields nothing.
    let truncated = match highest.pop_segments(1) {
        Some(shorter) => shorter,
        None => highest.clone(),
    };
    let exclusive_upper = truncated.bump(VersionBumpType::Last).ok()?;

    let bounds = vec![
        VersionSpec::Range(RangeOperator::GreaterEquals, lowest.clone()),
        VersionSpec::Range(RangeOperator::Less, exclusive_upper),
    ];
    Some(VersionSpec::Group(LogicalOperator::And, bounds))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a version literal used by the snapshot assertions below.
    fn parse_version(text: &str) -> Version {
        text.parse().unwrap()
    }

    #[test]
    fn test_determine_version_constraint() {
        let single = determine_version_constraint(&[parse_version("1.2.0")]).unwrap();
        insta::assert_snapshot!(single.to_string(), @">=1.2.0,<1.3");

        let spanning =
            determine_version_constraint(&[parse_version("1.2.0"), parse_version("1.3.0")])
                .unwrap();
        insta::assert_snapshot!(spanning.to_string(), @">=1.2.0,<1.4");
    }
}