1use crate::registry::{NpmRegistry, RegistryError};
4use std::collections::{HashMap, HashSet};
5
/// A single package pinned to a concrete version, as produced by the resolver.
///
/// Derives `PartialEq`/`Eq` (every field supports them) so callers and tests can compare
/// resolution results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedPackage {
    /// Package name.
    pub name: String,
    /// Concrete version selected for this package.
    pub version: String,
    /// URL of the package tarball.
    pub tarball_url: String,
    /// Integrity string for the tarball, if the registry provided one.
    pub integrity: Option<String>,
    /// Direct dependencies of this version: package name -> version requirement.
    pub dependencies: HashMap<String, String>,
}
15
16pub struct Resolver {
18 registry: NpmRegistry,
19 resolved: HashMap<String, ResolvedPackage>,
20 in_progress: HashSet<String>,
21}
22
23impl Resolver {
24 pub fn new(registry: NpmRegistry) -> Self {
25 Self {
26 registry,
27 resolved: HashMap::new(),
28 in_progress: HashSet::new(),
29 }
30 }
31
32 pub async fn resolve(
34 &mut self,
35 dependencies: &HashMap<String, String>,
36 ) -> Result<Vec<ResolvedPackage>, ResolverError> {
37 for (name, version_req) in dependencies {
38 self.resolve_package(name, version_req).await?;
39 }
40
41 Ok(self.resolved.values().cloned().collect())
42 }
43
44 pub async fn resolve_package(
46 &mut self,
47 name: &str,
48 version_req: &str,
49 ) -> Result<(), ResolverError> {
50 if self.resolved.contains_key(name) {
52 return Ok(());
53 }
54
55 if self.in_progress.contains(name) {
57 return Err(ResolverError::CircularDependency(name.to_string()));
58 }
59
60 self.in_progress.insert(name.to_string());
61
62 let version = self
64 .registry
65 .resolve_version(name, version_req)
66 .await
67 .map_err(ResolverError::Registry)?;
68
69 let metadata = self
71 .registry
72 .get_package(name)
73 .await
74 .map_err(ResolverError::Registry)?;
75
76 let version_info =
77 metadata
78 .versions
79 .get(&version)
80 .ok_or_else(|| ResolverError::VersionNotFound {
81 name: name.to_string(),
82 version: version.clone(),
83 })?;
84
85 let deps = version_info.dependencies.clone().unwrap_or_default();
87 let tarball_url = version_info.dist.tarball.clone();
88 let integrity = version_info.dist.integrity.clone();
89
90 for (dep_name, dep_version) in &deps {
92 Box::pin(self.resolve_package(dep_name, dep_version)).await?;
94 }
95
96 self.resolved.insert(
98 name.to_string(),
99 ResolvedPackage {
100 name: name.to_string(),
101 version,
102 tarball_url,
103 integrity,
104 dependencies: deps,
105 },
106 );
107
108 self.in_progress.remove(name);
109 Ok(())
110 }
111
112 pub fn get_resolved(&self) -> &HashMap<String, ResolvedPackage> {
114 &self.resolved
115 }
116
117 pub fn get_package(&self, name: &str) -> Option<&ResolvedPackage> {
119 self.resolved.get(name)
120 }
121
122 pub fn clear(&mut self) {
124 self.resolved.clear();
125 self.in_progress.clear();
126 }
127
128 pub fn into_registry(self) -> NpmRegistry {
130 self.registry
131 }
132}
133
134#[derive(Debug, thiserror::Error)]
135pub enum ResolverError {
136 #[error("Registry error: {0}")]
137 Registry(#[from] RegistryError),
138
139 #[error("Circular dependency: {0}")]
140 CircularDependency(String),
141
142 #[error("Version not found: {name}@{version}")]
143 VersionNotFound { name: String, version: String },
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal package record for tests that do not touch the registry.
    fn sample_package(name: &str) -> ResolvedPackage {
        ResolvedPackage {
            name: name.to_string(),
            version: "1.0.0".to_string(),
            tarball_url: format!("https://example.com/{name}-1.0.0.tgz"),
            integrity: None,
            dependencies: HashMap::new(),
        }
    }

    #[test]
    fn test_resolver_new() {
        let registry = NpmRegistry::new();
        let resolver = Resolver::new(registry);
        assert!(resolver.resolved.is_empty());
        assert!(resolver.in_progress.is_empty());
    }

    #[test]
    fn test_clear_resets_state() {
        let mut resolver = Resolver::new(NpmRegistry::new());
        resolver
            .resolved
            .insert("a".to_string(), sample_package("a"));
        resolver.in_progress.insert("b".to_string());

        assert_eq!(resolver.get_package("a").map(|p| p.name.as_str()), Some("a"));
        assert!(resolver.get_package("missing").is_none());
        assert_eq!(resolver.get_resolved().len(), 1);

        resolver.clear();
        assert!(resolver.get_resolved().is_empty());
        assert!(resolver.in_progress.is_empty());
    }

    // The ignored tests below resolve real package names and presumably need network
    // access to the npm registry. When run explicitly they must fail on errors rather
    // than pass silently.
    #[tokio::test]
    #[ignore]
    async fn test_resolve_simple() {
        let mut resolver = Resolver::new(NpmRegistry::new());

        let mut deps = HashMap::new();
        deps.insert("is-odd".to_string(), "^3.0.0".to_string());

        let packages = resolver
            .resolve(&deps)
            .await
            .expect("resolving is-odd@^3.0.0 should succeed");
        assert!(!packages.is_empty());
        assert!(packages.iter().any(|p| p.name == "is-odd"));
    }

    #[tokio::test]
    #[ignore]
    async fn test_resolve_with_transitive() {
        let mut resolver = Resolver::new(NpmRegistry::new());

        let mut deps = HashMap::new();
        deps.insert("chalk".to_string(), "^4.0.0".to_string());

        let packages = resolver
            .resolve(&deps)
            .await
            .expect("resolving chalk@^4.0.0 should succeed");
        // chalk@4 has its own dependencies, so more than one package must be resolved.
        assert!(packages.len() > 1);
        assert!(packages.iter().any(|p| p.name == "chalk"));
    }
}