use std::borrow::Cow;
use either::Either;
use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet};
use serde::de::IntoDeserializer;
use uv_distribution_types::{Requirement, RequirementSource};
use uv_normalize::PackageName;
use uv_pep440::Version;
use uv_pep508::MarkerTree;
/// Overrides the dependencies of one specific package.
///
/// When applied, each entry in `dependencies` replaces the target package's dependencies of
/// the same name, and entries for names the package does not already depend on are added.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(
    rename_all = "kebab-case",
    deny_unknown_fields,
    bound(
        serialize = "T: serde::Serialize",
        deserialize = "T: serde::Deserialize<'de>"
    )
)]
pub struct PackageOverride<T> {
    /// The package (and, optionally, the version) whose dependencies are overridden.
    pub package: PackageOverrideTarget,
    /// The requirements to apply to the target package's dependencies.
    pub dependencies: Box<[T]>,
}
/// Selects the package whose dependencies a package override applies to.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[serde(rename_all = "kebab-case", deny_unknown_fields)]
pub struct PackageOverrideTarget {
    /// The name of the targeted package.
    pub name: PackageName,
    /// The exact version to target. When omitted, the override applies to every version of
    /// the package that has no version-specific override.
    #[cfg_attr(
        feature = "schemars",
        schemars(
            with = "Option<String>",
            description = "PEP 440-style package version, e.g., `1.2.3`"
        )
    )]
    pub version: Option<Version>,
}
/// A single override entry.
///
/// An entry is either a set of dependency overrides scoped to one package, or a plain
/// requirement that overrides that dependency regardless of which package declares it.
/// It is serialized without a tag. Deserialization is implemented manually, in the
/// `Deserialize` impl for this type, so that bare strings are parsed directly as `T`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema), schemars(untagged))]
#[serde(untagged, bound(serialize = "T: serde::Serialize"))]
pub enum Override<T> {
    /// Overrides that apply only to the dependencies of one package.
    Package(PackageOverride<T>),
    /// A global override requirement.
    Requirement(T),
}
impl<'de, T> serde::Deserialize<'de> for Override<T>
where
T: serde::Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum MapOverride<T> {
Package(PackageOverride<T>),
Requirement(T),
}
serde_untagged::UntaggedEnumVisitor::new()
.string(|string| T::deserialize(string.into_deserializer()).map(Self::Requirement))
.map(|map| {
map.deserialize::<MapOverride<T>>()
.map(|entry| match entry {
MapOverride::Package(package) => Self::Package(package),
MapOverride::Requirement(requirement) => Self::Requirement(requirement),
})
})
.deserialize(deserializer)
}
}
/// A collection of dependency overrides, both global and package-scoped.
#[derive(Debug, Default, Clone)]
pub struct Overrides {
    /// Overrides that apply no matter which package declares the dependency, keyed by the
    /// overridden package name.
    global: FxHashMap<PackageName, Vec<Requirement>>,
    /// Overrides that apply only to the dependencies of a given package, keyed by that
    /// package's name. Each entry holds one set of overrides per targeted version, where a
    /// `None` version matches any version without a more specific set.
    scoped: FxHashMap<PackageName, Vec<ScopedOverrides>>,
}
/// The overrides registered for one (package, version) scope.
#[derive(Debug, Clone)]
struct ScopedOverrides {
    /// The targeted version, or `None` to match any version without an exact-version scope.
    version: Option<Version>,
    /// Replacement requirements, keyed by the name of the dependency they override.
    overrides: FxHashMap<PackageName, Vec<Requirement>>,
}
/// Raised when a package-scoped override uses a dependency source other than a plain
/// registry requirement without an explicit index.
#[derive(Debug, thiserror::Error)]
pub enum ScopedOverrideSourceError {
    /// The dependency uses a URL, Git, path, or directory source.
    #[error(
        "Scoped override for `{package}` cannot use a URL or path source for `{dependency}`; scoped overrides currently support version specifiers only"
    )]
    Url {
        /// The package targeted by the scoped override.
        package: PackageName,
        /// The dependency whose source was rejected.
        dependency: PackageName,
    },
    /// The dependency pins an explicit index.
    #[error(
        "Scoped override for `{package}` cannot use an explicit index for `{dependency}`; scoped overrides currently support version specifiers only"
    )]
    Index {
        /// The package targeted by the scoped override.
        package: PackageName,
        /// The dependency whose index was rejected.
        dependency: PackageName,
    },
}
impl Overrides {
pub fn from_requirements(requirements: Vec<Requirement>) -> Self {
let mut global: FxHashMap<PackageName, Vec<Requirement>> =
FxHashMap::with_capacity_and_hasher(requirements.len(), FxBuildHasher);
for requirement in requirements {
global
.entry(requirement.name.clone())
.or_default()
.push(requirement);
}
Self {
global,
scoped: FxHashMap::default(),
}
}
pub fn from_entries(
entries: Vec<Override<Requirement>>,
) -> Result<Self, ScopedOverrideSourceError> {
let mut global: FxHashMap<PackageName, Vec<Requirement>> =
FxHashMap::with_capacity_and_hasher(entries.len(), FxBuildHasher);
let mut scoped: FxHashMap<PackageName, Vec<ScopedOverrides>> = FxHashMap::default();
for entry in entries {
match entry {
Override::Requirement(requirement) => {
global
.entry(requirement.name.clone())
.or_default()
.push(requirement);
}
Override::Package(package) => {
for requirement in &package.dependencies {
match &requirement.source {
RequirementSource::Registry { index: Some(_), .. } => {
return Err(ScopedOverrideSourceError::Index {
package: package.package.name.clone(),
dependency: requirement.name.clone(),
});
}
RequirementSource::Registry { index: None, .. } => {}
RequirementSource::Url { .. }
| RequirementSource::GitDirectory { .. }
| RequirementSource::GitPath { .. }
| RequirementSource::Path { .. }
| RequirementSource::Directory { .. } => {
return Err(ScopedOverrideSourceError::Url {
package: package.package.name.clone(),
dependency: requirement.name.clone(),
});
}
}
}
let packages = scoped.entry(package.package.name.clone()).or_default();
let position = packages
.iter()
.position(|overrides| overrides.version == package.package.version)
.unwrap_or_else(|| {
let position = packages.len();
packages.push(ScopedOverrides {
version: package.package.version,
overrides: FxHashMap::default(),
});
position
});
let overrides = &mut packages[position].overrides;
for requirement in package.dependencies {
overrides
.entry(requirement.name.clone())
.or_default()
.push(requirement);
}
}
}
}
Ok(Self { global, scoped })
}
pub fn global_requirements(&self) -> impl Iterator<Item = &Requirement> {
self.global
.values()
.flat_map(|requirements| requirements.iter())
}
pub fn scoped_requirements(
&self,
) -> impl Iterator<Item = (&PackageName, Option<&Version>, &Requirement)> {
self.scoped.iter().flat_map(|(package, entries)| {
entries.iter().flat_map(move |entry| {
entry
.overrides
.values()
.flatten()
.map(move |requirement| (package, entry.version.as_ref(), requirement))
})
})
}
pub fn scoped_requirements_for(
&self,
package: &PackageName,
version: &Version,
) -> impl Iterator<Item = &Requirement> {
self.scoped_for(package, version)
.into_iter()
.flat_map(|scoped| scoped.overrides.values().flatten())
}
pub(crate) fn has_exact_scope(&self, package: &PackageName, version: &Version) -> bool {
self.scoped.get(package).is_some_and(|entries| {
entries
.iter()
.any(|entry| entry.version.as_ref() == Some(version))
})
}
fn get(&self, name: &PackageName) -> Option<&Vec<Requirement>> {
self.global.get(name)
}
fn scoped_for(&self, package: &PackageName, version: &Version) -> Option<&ScopedOverrides> {
self.scoped.get(package).and_then(|entries| {
entries
.iter()
.find(|entry| entry.version.as_ref() == Some(version))
.or_else(|| entries.iter().find(|entry| entry.version.is_none()))
})
}
pub fn apply<'a, I>(
&'a self,
requirements: I,
) -> impl Iterator<Item = Cow<'a, Requirement>> + use<'a, I>
where
I: IntoIterator<Item = &'a Requirement>,
{
self.apply_inner(requirements, None)
}
pub fn apply_for<'a, I>(
&'a self,
package: &PackageName,
version: &Version,
requirements: I,
) -> impl Iterator<Item = Cow<'a, Requirement>> + use<'a, I>
where
I: IntoIterator<Item = &'a Requirement>,
{
self.apply_inner(requirements, Some((package, version)))
}
pub fn apply_for_package<'a, I>(
&'a self,
package: Option<(&PackageName, &Version)>,
requirements: I,
) -> impl Iterator<Item = Cow<'a, Requirement>> + use<'a, I>
where
I: IntoIterator<Item = &'a Requirement>,
{
self.apply_inner(requirements, package)
}
fn apply_inner<'a, I>(
&'a self,
requirements: I,
package: Option<(&PackageName, &Version)>,
) -> impl Iterator<Item = Cow<'a, Requirement>> + use<'a, I>
where
I: IntoIterator<Item = &'a Requirement>,
{
let scoped = package.and_then(|(package, version)| self.scoped_for(package, version));
if let Some(scoped) = scoped {
let requirements = requirements.into_iter().collect::<Vec<_>>();
let names = requirements
.iter()
.map(|requirement| requirement.name.clone())
.collect::<FxHashSet<_>>();
let mut additions = scoped
.overrides
.iter()
.filter(|(name, _)| !names.contains(*name))
.flat_map(|(_, requirements)| requirements)
.collect::<Vec<_>>();
additions.sort_unstable();
return Either::Left(
requirements
.into_iter()
.flat_map(move |requirement| self.apply_requirement(requirement, Some(scoped)))
.chain(additions.into_iter().map(Cow::Borrowed)),
);
}
if self.global.is_empty() {
return Either::Right(Either::Left(requirements.into_iter().map(Cow::Borrowed)));
}
Either::Right(Either::Right(requirements.into_iter().flat_map(
move |requirement| self.apply_requirement(requirement, None),
)))
}
fn apply_requirement<'a>(
&'a self,
requirement: &'a Requirement,
scoped: Option<&'a ScopedOverrides>,
) -> impl Iterator<Item = Cow<'a, Requirement>> {
let overrides = scoped
.and_then(|scoped| scoped.overrides.get(&requirement.name))
.or_else(|| self.get(&requirement.name));
let Some(overrides) = overrides else {
return Either::Left(std::iter::once(Cow::Borrowed(requirement)));
};
let Some(extra_expression) = requirement.marker.top_level_extra() else {
return Either::Right(Either::Right(overrides.iter().map(Cow::Borrowed)));
};
Either::Right(Either::Left(overrides.iter().map(
move |override_requirement| {
let mut joint_marker = MarkerTree::expression(extra_expression.clone());
joint_marker.and(override_requirement.marker);
Cow::Owned(Requirement {
marker: joint_marker,
..override_requirement.clone()
})
},
)))
}
}