use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use async_trait::async_trait;
use semver::Version;
use tokio::sync::{OnceCell, RwLock};
use crate::error::{ResolutionError, ResolutionResult};
use crate::version::VersionSet;
/// Metadata describing one specific version of one package.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VersionMetadata {
/// Name of the package this metadata belongs to.
pub name: String,
/// The version this metadata describes.
pub version: Version,
/// Regular dependencies: dependency name mapped to the `VersionSet` recorded for it.
pub dependencies: BTreeMap<String, VersionSet>,
/// Peer dependencies, keyed by dependency name.
/// NOTE(review): the constructors visible here only ever create this map empty;
/// confirm how callers populate and interpret it.
pub peer_dependencies: BTreeMap<String, VersionSet>,
/// Optional dependencies, keyed by dependency name.
/// NOTE(review): same caveat as `peer_dependencies` — semantics are not shown here.
pub optional_dependencies: BTreeMap<String, VersionSet>,
}
impl VersionMetadata {
    /// Builds metadata for `name` at `version` with all three dependency maps empty.
    pub fn new<N>(name: N, version: Version) -> Self
    where
        N: Into<String>,
    {
        let name = name.into();
        Self {
            name,
            version,
            dependencies: BTreeMap::default(),
            peer_dependencies: BTreeMap::default(),
            optional_dependencies: BTreeMap::default(),
        }
    }

    /// Builder-style helper that records a regular dependency and returns the
    /// updated metadata.
    ///
    /// If a requirement for `name` was already present it is overwritten.
    pub fn with_dependency<N>(self, name: N, version_set: VersionSet) -> Self
    where
        N: Into<String>,
    {
        let mut updated = self;
        let dependency = name.into();
        updated.dependencies.insert(dependency, version_set);
        updated
    }
}
/// Asynchronous source of package version listings and per-version metadata.
///
/// Implementors must provide `get_package_versions` and
/// `get_version_metadata`; `get_satisfying_versions` and `package_exists`
/// have default implementations built on top of `get_package_versions`.
#[async_trait]
pub trait PackageRegistry: Send + Sync {
    /// Lists the versions available for the package called `name`.
    async fn get_package_versions(&self, name: &str) -> ResolutionResult<Vec<Version>>;

    /// Lists the versions of `name` that `version_set` accepts.
    ///
    /// The default implementation fetches the full listing and keeps, in
    /// their original order, the entries for which `version_set.satisfies`
    /// returns `true`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `get_package_versions`.
    async fn get_satisfying_versions(
        &self,
        name: &str,
        version_set: &VersionSet,
    ) -> ResolutionResult<Vec<Version>> {
        let mut versions = self.get_package_versions(name).await?;
        versions.retain(|candidate| version_set.satisfies(candidate));
        Ok(versions)
    }

    /// Fetches the metadata recorded for `name` at exactly `version`.
    async fn get_version_metadata(
        &self,
        name: &str,
        version: &Version,
    ) -> ResolutionResult<VersionMetadata>;

    /// Reports whether `name` has at least one version.
    ///
    /// The default implementation treats `ResolutionError::PackageNotFound`
    /// as "does not exist" (`Ok(false)`); an empty listing also yields
    /// `Ok(false)`. Every other error is passed through unchanged.
    async fn package_exists(&self, name: &str) -> ResolutionResult<bool> {
        match self.get_package_versions(name).await {
            Err(ResolutionError::PackageNotFound { .. }) => Ok(false),
            listing => listing.map(|versions| !versions.is_empty()),
        }
    }
}
/// Shared, initialise-once slot holding the version list of one package.
type VersionsCell = Arc<OnceCell<Vec<Version>>>;
/// Shared, initialise-once slot holding the metadata of one package version.
type MetadataCell = Arc<OnceCell<VersionMetadata>>;
/// Lock-protected map from package name to its `VersionsCell`.
type VersionsMap = Arc<RwLock<HashMap<String, VersionsCell>>>;
/// Lock-protected map from `(package name, version)` to its `MetadataCell`.
type MetadataMap = Arc<RwLock<HashMap<(String, Version), MetadataCell>>>;
/// Lock-protected map from `(package name, version set)` to the already-filtered version list.
type ConstraintMap = Arc<RwLock<HashMap<(String, VersionSet), Vec<Version>>>>;
/// Wrapper around a registry `R` that keeps three in-memory caches next to it:
/// version listings, per-version metadata, and constraint-filtered listings.
#[derive(Debug)]
pub struct CachedRegistry<R> {
/// The wrapped registry that is consulted when a cache has no answer.
inner: R,
/// Version listings, keyed by package name.
versions_cache: VersionsMap,
/// Metadata, keyed by `(package name, version)`.
metadata_cache: MetadataMap,
/// Filtered version listings, keyed by `(package name, version set)`.
constraint_cache: ConstraintMap,
}
impl<R> CachedRegistry<R> {
    /// Wraps `inner`, starting with all three caches empty.
    pub fn new(inner: R) -> Self {
        let versions_cache = Arc::new(RwLock::new(HashMap::new()));
        let metadata_cache = Arc::new(RwLock::new(HashMap::new()));
        let constraint_cache = Arc::new(RwLock::new(HashMap::new()));
        Self {
            inner,
            versions_cache,
            metadata_cache,
            constraint_cache,
        }
    }

    /// Borrows the wrapped registry.
    pub fn inner(&self) -> &R {
        let Self { inner, .. } = self;
        inner
    }

    /// Empties the versions, metadata and constraint caches, in that order.
    ///
    /// Each write guard lives in its own scope, so only one lock is held at
    /// any moment.
    pub async fn clear_cache(&self) {
        {
            let mut versions = self.versions_cache.write().await;
            versions.clear();
        }
        {
            let mut metadata = self.metadata_cache.write().await;
            metadata.clear();
        }
        {
            let mut constraints = self.constraint_cache.write().await;
            constraints.clear();
        }
    }

    /// Reports how many entries each cache currently holds.
    ///
    /// The three read guards are acquired in the order versions → metadata →
    /// constraints and are all kept until the snapshot has been built.
    pub async fn cache_stats(&self) -> CacheStats {
        let versions = self.versions_cache.read().await;
        let metadata = self.metadata_cache.read().await;
        let constraints = self.constraint_cache.read().await;
        CacheStats {
            versions_entries: versions.len(),
            metadata_entries: metadata.len(),
            constraint_entries: constraints.len(),
        }
    }
}
/// Entry counts of the three caches kept by `CachedRegistry`, as reported by
/// `CachedRegistry::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
/// Number of keys in the versions cache.
pub versions_entries: usize,
/// Number of keys in the metadata cache.
pub metadata_entries: usize,
/// Number of keys in the constraint cache.
pub constraint_entries: usize,
}
/// Caching implementation of `PackageRegistry`: answers are taken from the
/// in-memory caches when present and fetched from the wrapped registry
/// otherwise.
#[async_trait]
impl<R: PackageRegistry> PackageRegistry for CachedRegistry<R> {
    /// Returns the version listing for `name`, fetching it from the wrapped
    /// registry at most once per successful initialisation of its cell.
    ///
    /// # Errors
    ///
    /// Propagates the wrapped registry's error. A failed fetch leaves the
    /// cell uninitialised, so a later call retries.
    async fn get_package_versions(&self, name: &str) -> ResolutionResult<Vec<Version>> {
        // Fast path: a shared read lock is enough when the cell already
        // exists, and it avoids allocating an owned key. The guard is a
        // temporary of this statement, so it is released before the write
        // lock below is requested.
        let cached = self.versions_cache.read().await.get(name).cloned();
        let cell = match cached {
            Some(cell) => cell,
            None => {
                // Slow path: `entry` re-checks under the write lock, so a cell
                // inserted by a concurrent task in the meantime is reused.
                let mut cache = self.versions_cache.write().await;
                cache
                    .entry(name.to_string())
                    .or_insert_with(|| Arc::new(OnceCell::new()))
                    .clone()
            }
        };
        // No map lock is held while the wrapped registry is awaited.
        let versions = cell
            .get_or_try_init(|| async { self.inner.get_package_versions(name).await })
            .await?;
        Ok(versions.clone())
    }

    /// Returns the versions of `name` accepted by `version_set`, caching the
    /// filtered list per `(name, version_set)` pair.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_package_versions`; nothing is cached in
    /// that case.
    async fn get_satisfying_versions(
        &self,
        name: &str,
        version_set: &VersionSet,
    ) -> ResolutionResult<Vec<Version>> {
        let key = (name.to_string(), version_set.clone());
        if let Some(versions) = self.constraint_cache.read().await.get(&key).cloned() {
            return Ok(versions);
        }
        let all = self.get_package_versions(name).await?;
        let satisfying: Vec<Version> = all
            .into_iter()
            .filter(|v| version_set.satisfies(v))
            .collect();
        self.constraint_cache
            .write()
            .await
            .insert(key, satisfying.clone());
        Ok(satisfying)
    }

    /// Returns the metadata for `name` at `version`, fetching it from the
    /// wrapped registry at most once per successful initialisation of its
    /// cell.
    ///
    /// # Errors
    ///
    /// Propagates the wrapped registry's error. A failed fetch leaves the
    /// cell uninitialised, so a later call retries.
    async fn get_version_metadata(
        &self,
        name: &str,
        version: &Version,
    ) -> ResolutionResult<VersionMetadata> {
        let key = (name.to_string(), version.clone());
        // Fast path under the shared read lock; the guard is released at the
        // end of this statement, before any write lock is requested.
        let cached = self.metadata_cache.read().await.get(&key).cloned();
        let cell = match cached {
            Some(cell) => cell,
            None => {
                // Slow path: insert-or-reuse under the exclusive lock.
                let mut cache = self.metadata_cache.write().await;
                cache
                    .entry(key)
                    .or_insert_with(|| Arc::new(OnceCell::new()))
                    .clone()
            }
        };
        let metadata = cell
            .get_or_try_init(|| async { self.inner.get_version_metadata(name, version).await })
            .await?;
        Ok(metadata.clone())
    }

    /// Reports whether `name` exists.
    ///
    /// If a version listing for `name` has already been cached, the answer is
    /// derived from it (non-empty means the package exists). Otherwise the
    /// question is forwarded to the wrapped registry; that answer is not
    /// cached.
    async fn package_exists(&self, name: &str) -> ResolutionResult<bool> {
        if let Some(cell) = self.versions_cache.read().await.get(name).cloned() {
            if let Some(versions) = cell.get() {
                return Ok(!versions.is_empty());
            }
        }
        self.inner.package_exists(name).await
    }
}
/// In-memory registry backed by nested ordered maps.
#[derive(Debug, Default, Clone)]
pub struct MockRegistry {
/// Package name → (version → metadata for that version).
packages: BTreeMap<String, BTreeMap<Version, VersionMetadata>>,
}
impl MockRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn with_versions<N: Into<String>>(mut self, name: N, versions: &[&str]) -> Self {
let name = name.into();
let entry = self.packages.entry(name.clone()).or_default();
for v in versions {
if let Ok(parsed) = Version::parse(v) {
entry
.entry(parsed.clone())
.or_insert_with(|| VersionMetadata::new(name.clone(), parsed));
}
}
self
}
pub fn with_dependency<N: Into<String>, D: Into<String>>(
mut self,
name: N,
version: &str,
dep_name: D,
dep_set: VersionSet,
) -> Self {
let name = name.into();
let version =
Version::parse(version).expect("valid version in MockRegistry::with_dependency");
let entry = self
.packages
.entry(name.clone())
.or_default()
.entry(version.clone())
.or_insert_with(|| VersionMetadata::new(name, version));
entry.dependencies.insert(dep_name.into(), dep_set);
self
}
}
/// `PackageRegistry` backed entirely by the in-memory `packages` map.
#[async_trait]
impl PackageRegistry for MockRegistry {
    /// Returns the registered versions of `name`, highest first.
    ///
    /// # Errors
    ///
    /// `ResolutionError::PackageNotFound` (with version `"*"`) when `name`
    /// has never been registered.
    async fn get_package_versions(&self, name: &str) -> ResolutionResult<Vec<Version>> {
        self.packages
            .get(name)
            // The map iterates its keys in ascending order, so walking them
            // backwards yields the descending listing directly.
            .map(|by_version| by_version.keys().rev().cloned().collect())
            .ok_or_else(|| ResolutionError::PackageNotFound {
                package: name.to_string(),
                version: "*".to_string(),
            })
    }

    /// Returns a copy of the metadata stored for `name` at `version`.
    ///
    /// # Errors
    ///
    /// `ResolutionError::PackageNotFound` when either the package or that
    /// particular version is not registered.
    async fn get_version_metadata(
        &self,
        name: &str,
        version: &Version,
    ) -> ResolutionResult<VersionMetadata> {
        let found = self
            .packages
            .get(name)
            .and_then(|by_version| by_version.get(version));
        match found {
            Some(metadata) => Ok(metadata.clone()),
            None => Err(ResolutionError::PackageNotFound {
                package: name.to_string(),
                version: version.to_string(),
            }),
        }
    }

    /// Reports whether `name` has an entry in the map, regardless of how many
    /// versions that entry holds. Never fails.
    async fn package_exists(&self, name: &str) -> ResolutionResult<bool> {
        let registered = self.packages.contains_key(name);
        Ok(registered)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a version-set literal, panicking on malformed input.
    fn vs(s: &str) -> VersionSet {
        s.parse().unwrap()
    }

    // Existence checks, descending listing and constraint filtering on the mock.
    #[tokio::test]
    async fn mock_registry_basic_lookup() {
        let registry = MockRegistry::new().with_versions("pkg", &["1.0.0", "1.1.0", "2.0.0"]);

        let known = registry.package_exists("pkg").await.unwrap();
        assert!(known);
        let unknown = registry.package_exists("nope").await.unwrap();
        assert!(!unknown);

        let expected = vec![
            Version::new(2, 0, 0),
            Version::new(1, 1, 0),
            Version::new(1, 0, 0),
        ];
        let versions = registry.get_package_versions("pkg").await.unwrap();
        assert_eq!(versions, expected);

        let caret_one = vs("^1.0.0");
        let satisfying = registry
            .get_satisfying_versions("pkg", &caret_one)
            .await
            .unwrap();
        assert_eq!(satisfying.len(), 2);
        for wanted in &[Version::new(1, 0, 0), Version::new(1, 1, 0)] {
            assert!(satisfying.contains(wanted));
        }
    }

    // `with_dependency` creates the package/version and records the requirement.
    #[tokio::test]
    async fn mock_registry_dependencies() {
        let registry = MockRegistry::new()
            .with_versions("left-pad", &["1.0.0"])
            .with_dependency("my-app", "1.0.0", "left-pad", vs("^1.0.0"));

        let app_version = Version::new(1, 0, 0);
        let metadata = registry
            .get_version_metadata("my-app", &app_version)
            .await
            .unwrap();

        assert_eq!(metadata.name, "my-app");
        assert_eq!(metadata.version, app_version);
        assert_eq!(metadata.dependencies.len(), 1);
        let requirement = metadata.dependencies.get("left-pad");
        assert_eq!(requirement, Some(&vs("^1.0.0")));
    }

    // Two sequential lookups agree and leave exactly one versions entry.
    #[tokio::test]
    async fn cached_registry_single_flights_versions() {
        let registry =
            CachedRegistry::new(MockRegistry::new().with_versions("pkg", &["1.0.0", "2.0.0"]));

        let first = registry.get_package_versions("pkg").await.unwrap();
        let second = registry.get_package_versions("pkg").await.unwrap();
        assert_eq!(first, second);

        let CacheStats {
            versions_entries,
            constraint_entries,
            ..
        } = registry.cache_stats().await;
        assert_eq!(versions_entries, 1);
        assert_eq!(constraint_entries, 0);
    }

    // Repeating a constraint query yields the same list and one constraint entry.
    #[tokio::test]
    async fn cached_registry_caches_constraint_filter() {
        let registry =
            CachedRegistry::new(MockRegistry::new().with_versions("pkg", &["1.0.0", "2.0.0"]));
        let constraint = vs("^1.0.0");

        let one = registry
            .get_satisfying_versions("pkg", &constraint)
            .await
            .unwrap();
        let two = registry
            .get_satisfying_versions("pkg", &constraint)
            .await
            .unwrap();
        assert_eq!(one, two);
        assert_eq!(one, vec![Version::new(1, 0, 0)]);

        let stats = registry.cache_stats().await;
        assert_eq!(stats.versions_entries, 1);
        assert_eq!(stats.constraint_entries, 1);
    }

    // Repeating a metadata query yields equal values and one metadata entry.
    #[tokio::test]
    async fn cached_registry_caches_per_version_metadata() {
        let mock = MockRegistry::new()
            .with_versions("pkg", &["1.0.0"])
            .with_dependency("pkg", "1.0.0", "dep", vs("^1.0.0"));
        let registry = CachedRegistry::new(mock);
        let target = Version::new(1, 0, 0);

        let a = registry.get_version_metadata("pkg", &target).await.unwrap();
        let b = registry.get_version_metadata("pkg", &target).await.unwrap();
        assert_eq!(a, b);

        let stats = registry.cache_stats().await;
        assert_eq!(stats.metadata_entries, 1);
    }

    // An unknown package surfaces `PackageNotFound` and is reported as absent.
    #[tokio::test]
    async fn cached_registry_propagates_not_found() {
        let registry = CachedRegistry::new(MockRegistry::new());

        let outcome = registry.get_package_versions("nope").await;
        assert!(matches!(
            outcome,
            Err(ResolutionError::PackageNotFound { .. })
        ));

        let exists = registry.package_exists("nope").await.unwrap();
        assert!(!exists);
    }
}