1use std::{
3 collections::{hash_map, HashMap},
4 fmt::Debug,
5 path::{Path, PathBuf},
6 str::FromStr,
7 sync::Arc,
8};
9
10use anyhow::{bail, Context, Result};
11use futures::TryStreamExt;
12use indexmap::IndexMap;
13use semver::{Comparator, Op, Version, VersionReq};
14use serde::{
15 de::{self, value::MapAccessDeserializer},
16 Deserialize, Serialize,
17};
18
19use tokio::io::AsyncReadExt;
20use wasm_pkg_client::{
21 caching::{CachingClient, FileCache},
22 Client, Config, ContentDigest, Error as WasmPkgError, PackageRef, Release, VersionInfo,
23};
24use wit_component::DecodedWasm;
25use wit_parser::{PackageId, PackageName, Resolve, UnresolvedPackageGroup, WorldId};
26
27use crate::lock::{LockFileResolver, LockedPackageVersion};
28
/// Registry key used for registry dependencies that do not name a `registry`.
///
/// Resolutions made against this registry report `registry: None`.
pub const DEFAULT_REGISTRY_NAME: &str = "default";
31
/// A dependency specification: either a registry package or a local path.
///
/// Deserializes from a bare version-requirement string (a registry package) or from a table
/// containing `path`, or `version` with optional `package` and `registry` fields.
#[derive(Debug, Clone)]
pub enum Dependency {
    /// A package resolved from a component registry.
    Package(RegistryPackage),

    /// A dependency on the local file system: a WIT directory, a WIT file, or a Wasm binary.
    Local(PathBuf),
}
41
42impl Serialize for Dependency {
43 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
44 where
45 S: serde::Serializer,
46 {
47 match self {
48 Self::Package(package) => {
49 if package.name.is_none() && package.registry.is_none() {
50 let version = package.version.to_string();
51 version.trim_start_matches('^').serialize(serializer)
52 } else {
53 #[derive(Serialize)]
54 struct Entry<'a> {
55 package: Option<&'a PackageRef>,
56 version: &'a str,
57 registry: Option<&'a str>,
58 }
59
60 Entry {
61 package: package.name.as_ref(),
62 version: package.version.to_string().trim_start_matches('^'),
63 registry: package.registry.as_deref(),
64 }
65 .serialize(serializer)
66 }
67 }
68 Self::Local(path) => {
69 #[derive(Serialize)]
70 struct Entry<'a> {
71 path: &'a PathBuf,
72 }
73
74 Entry { path }.serialize(serializer)
75 }
76 }
77 }
78}
79
80impl<'de> Deserialize<'de> for Dependency {
81 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
82 where
83 D: serde::Deserializer<'de>,
84 {
85 struct Visitor;
86
87 impl<'de> de::Visitor<'de> for Visitor {
88 type Value = Dependency;
89
90 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
91 write!(formatter, "a string or a table")
92 }
93
94 fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
95 where
96 E: de::Error,
97 {
98 Ok(Self::Value::Package(s.parse().map_err(de::Error::custom)?))
99 }
100
101 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
102 where
103 A: de::MapAccess<'de>,
104 {
105 #[derive(Default, Deserialize)]
106 #[serde(default, deny_unknown_fields)]
107 struct Entry {
108 path: Option<PathBuf>,
109 package: Option<PackageRef>,
110 version: Option<VersionReq>,
111 registry: Option<String>,
112 }
113
114 let entry = Entry::deserialize(MapAccessDeserializer::new(map))?;
115
116 match (entry.path, entry.package, entry.version, entry.registry) {
117 (Some(path), None, None, None) => Ok(Self::Value::Local(path)),
118 (None, name, Some(version), registry) => {
119 Ok(Self::Value::Package(RegistryPackage {
120 name,
121 version,
122 registry,
123 }))
124 }
125 (Some(_), None, Some(_), _) => Err(de::Error::custom(
126 "cannot specify both `path` and `version` fields in a dependency entry",
127 )),
128 (Some(_), None, None, Some(_)) => Err(de::Error::custom(
129 "cannot specify both `path` and `registry` fields in a dependency entry",
130 )),
131 (Some(_), Some(_), _, _) => Err(de::Error::custom(
132 "cannot specify both `path` and `package` fields in a dependency entry",
133 )),
134 (None, None, _, _) => Err(de::Error::missing_field("package")),
135 (None, Some(_), None, _) => Err(de::Error::missing_field("version")),
136 }
137 }
138 }
139
140 deserializer.deserialize_any(Visitor)
141 }
142}
143
144impl FromStr for Dependency {
145 type Err = anyhow::Error;
146
147 fn from_str(s: &str) -> Result<Self> {
148 Ok(Self::Package(s.parse()?))
149 }
150}
151
/// A dependency on a package published to a component registry.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RegistryPackage {
    /// Name of the package in the registry. When `None`, the name the dependency is declared
    /// under is used instead.
    pub name: Option<PackageRef>,

    /// Version requirement that the selected release must satisfy.
    pub version: VersionReq,

    /// Registry to resolve from. When `None`, `DEFAULT_REGISTRY_NAME` is used.
    pub registry: Option<String>,
}
169
170impl FromStr for RegistryPackage {
171 type Err = anyhow::Error;
172
173 fn from_str(s: &str) -> Result<Self> {
174 Ok(Self {
175 name: None,
176 version: s
177 .parse()
178 .with_context(|| format!("'{s}' is an invalid registry package version"))?,
179 registry: None,
180 })
181 }
182}
183
/// The outcome of resolving a registry dependency to one specific release.
#[derive(Clone)]
pub struct RegistryResolution {
    /// Name the dependency was declared under.
    pub name: PackageRef,
    /// Name of the package in the registry (may differ from `name`).
    pub package: PackageRef,
    /// Registry the package was resolved from; `None` means the default registry.
    pub registry: Option<String>,
    /// The version requirement that was resolved.
    pub requirement: VersionReq,
    /// Version of the selected release.
    pub version: Version,
    /// Content digest of the selected release.
    pub digest: ContentDigest,
    /// Client used to fetch the release content when the resolution is decoded.
    client: Arc<CachingClient<FileCache>>,
}
206
207impl Debug for RegistryResolution {
208 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
209 f.debug_struct("RegistryResolution")
210 .field("name", &self.name)
211 .field("package", &self.package)
212 .field("registry", &self.registry)
213 .field("requirement", &self.requirement)
214 .field("version", &self.version)
215 .field("digest", &self.digest)
216 .finish()
217 }
218}
219
/// A dependency resolved to a location on the local file system.
#[derive(Clone, Debug)]
pub struct LocalResolution {
    /// Name the dependency was declared under.
    pub name: PackageRef,
    /// Path to the dependency: a WIT directory, a WIT file, or a Wasm binary.
    pub path: PathBuf,
}
228
/// A resolved dependency, from either a registry or the local file system.
#[derive(Debug, Clone)]
// NOTE(review): presumably allowed because `RegistryResolution` is much larger than
// `LocalResolution` and boxing it was judged not worthwhile — confirm.
#[allow(clippy::large_enum_variant)]
pub enum DependencyResolution {
    /// Resolved to a specific release in a component registry.
    Registry(RegistryResolution),
    /// Resolved to a local path.
    Local(LocalResolution),
}
238
239impl DependencyResolution {
240 pub fn name(&self) -> &PackageRef {
242 match self {
243 Self::Registry(res) => &res.name,
244 Self::Local(res) => &res.name,
245 }
246 }
247
248 pub fn version(&self) -> Option<&Version> {
252 match self {
253 Self::Registry(res) => Some(&res.version),
254 Self::Local(_) => None,
255 }
256 }
257
258 pub fn key(&self) -> Option<(&PackageRef, Option<&str>)> {
262 match self {
263 DependencyResolution::Registry(pkg) => Some((&pkg.package, pkg.registry.as_deref())),
264 DependencyResolution::Local(_) => None,
265 }
266 }
267
268 pub async fn decode(&self) -> Result<DecodedDependency> {
270 let bytes = match self {
272 DependencyResolution::Local(LocalResolution { path, .. })
273 if tokio::fs::metadata(path).await?.is_dir() =>
274 {
275 return Ok(DecodedDependency::Wit {
276 resolution: self,
277 package: UnresolvedPackageGroup::parse_dir(path).with_context(|| {
278 format!("failed to parse dependency `{path}`", path = path.display())
279 })?,
280 });
281 }
282 DependencyResolution::Local(LocalResolution { path, .. }) => {
283 tokio::fs::read(path).await.with_context(|| {
284 format!(
285 "failed to read content of dependency `{name}` at path `{path}`",
286 name = self.name(),
287 path = path.display()
288 )
289 })?
290 }
291 DependencyResolution::Registry(res) => {
292 let stream = res
293 .client
294 .get_content(
295 &res.package,
296 &Release {
297 version: res.version.clone(),
298 content_digest: res.digest.clone(),
299 },
300 )
301 .await?;
302
303 let mut buf = Vec::new();
304 tokio_util::io::StreamReader::new(
305 stream.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)),
306 )
307 .read_to_end(&mut buf)
308 .await?;
309 buf
310 }
311 };
312
313 if &bytes[0..4] != b"\0asm" {
314 return Ok(DecodedDependency::Wit {
315 resolution: self,
316 package: UnresolvedPackageGroup::parse(
317 self.name().to_string(),
319 std::str::from_utf8(&bytes).with_context(|| {
320 format!(
321 "dependency `{name}` is not UTF-8 encoded",
322 name = self.name()
323 )
324 })?,
325 )?,
326 });
327 }
328
329 Ok(DecodedDependency::Wasm {
330 resolution: self,
331 decoded: wit_component::decode(&bytes).with_context(|| {
332 format!(
333 "failed to decode content of dependency `{name}`",
334 name = self.name(),
335 )
336 })?,
337 })
338 }
339}
340
/// The decoded content of a dependency, borrowing the resolution it was decoded from.
pub enum DecodedDependency<'a> {
    /// The dependency was WIT source: a directory, or a file without the Wasm magic.
    Wit {
        /// The resolution this content was decoded from.
        resolution: &'a DependencyResolution,
        /// The parsed, not yet resolved, WIT package group.
        package: UnresolvedPackageGroup,
    },
    /// The dependency was a WebAssembly binary.
    Wasm {
        /// The resolution this content was decoded from.
        resolution: &'a DependencyResolution,
        /// The result of `wit_component::decode` on the binary.
        decoded: DecodedWasm,
    },
}
358
359impl<'a> DecodedDependency<'a> {
360 pub fn resolve(self) -> Result<(Resolve, PackageId, Vec<PathBuf>)> {
365 match self {
366 Self::Wit { package, .. } => {
367 let mut resolve = Resolve::new();
368 let source_files = package
369 .source_map
370 .source_files()
371 .map(Path::to_path_buf)
372 .collect();
373 let pkg = resolve.push_group(package)?;
374 Ok((resolve, pkg, source_files))
375 }
376 Self::Wasm { decoded, .. } => match decoded {
377 DecodedWasm::WitPackage(resolve, pkg) => Ok((resolve, pkg, Vec::new())),
378 DecodedWasm::Component(resolve, world) => {
379 let pkg = resolve.worlds[world].package.unwrap();
380 Ok((resolve, pkg, Vec::new()))
381 }
382 },
383 }
384 }
385
386 pub fn package_name(&self) -> &PackageName {
388 match self {
389 Self::Wit { package, .. } => &package.main.name,
390 Self::Wasm { decoded, .. } => &decoded.resolve().packages[decoded.package()].name,
391 }
392 }
393
394 pub fn into_component_world(self) -> Result<(Resolve, WorldId)> {
398 match self {
399 Self::Wasm {
400 decoded: DecodedWasm::Component(resolve, world),
401 ..
402 } => Ok((resolve, world)),
403 _ => bail!("dependency is not a WebAssembly component"),
404 }
405 }
406}
407
/// Collects dependencies and resolves them to concrete local paths or registry releases.
pub struct DependencyResolver<'a> {
    /// Shared client used for registry lookups and content fetches.
    client: Arc<CachingClient<FileCache>>,
    /// Optional lock file consulted for previously locked versions.
    lock_file: Option<LockFileResolver<'a>>,
    /// Pending registry dependencies, grouped by registry name in insertion order.
    registries: IndexMap<&'a str, Registry<'a>>,
    /// Finished resolutions keyed by dependency name. Local dependencies are inserted when
    /// added; registry dependencies are merged in by `resolve`.
    resolutions: HashMap<PackageRef, DependencyResolution>,
}
415
416impl<'a> DependencyResolver<'a> {
417 pub fn new(
421 config: Option<Config>,
422 lock_file: Option<LockFileResolver<'a>>,
423 cache: FileCache,
424 ) -> anyhow::Result<Self> {
425 if config.is_none() && lock_file.is_none() {
426 anyhow::bail!("lock file must be provided when offline mode is enabled");
427 }
428 let client = CachingClient::new(config.map(Client::new), cache);
429 Ok(DependencyResolver {
430 client: Arc::new(client),
431 lock_file,
432 registries: Default::default(),
433 resolutions: Default::default(),
434 })
435 }
436
437 pub fn new_with_client(
441 client: Arc<CachingClient<FileCache>>,
442 lock_file: Option<LockFileResolver<'a>>,
443 ) -> anyhow::Result<Self> {
444 if client.is_readonly() && lock_file.is_none() {
445 anyhow::bail!("lock file must be provided when offline mode is enabled");
446 }
447 Ok(DependencyResolver {
448 client,
449 lock_file,
450 registries: Default::default(),
451 resolutions: Default::default(),
452 })
453 }
454
455 pub async fn add_dependency(
457 &mut self,
458 name: &'a PackageRef,
459 dependency: &'a Dependency,
460 ) -> Result<()> {
461 match dependency {
462 Dependency::Package(package) => {
463 let registry_name = package.registry.as_deref().unwrap_or(DEFAULT_REGISTRY_NAME);
465 let package_name = package.name.clone().unwrap_or_else(|| name.clone());
466
467 let locked = match self.lock_file.as_ref().and_then(|resolver| {
469 resolver
470 .resolve(registry_name, &package_name, &package.version)
471 .transpose()
472 }) {
473 Some(Ok(locked)) => Some(locked),
474 Some(Err(e)) => return Err(e),
475 _ => None,
476 };
477
478 let registry = match self.registries.entry(registry_name) {
479 indexmap::map::Entry::Occupied(e) => e.into_mut(),
480 indexmap::map::Entry::Vacant(e) => e.insert(Registry {
481 client: self.client.clone(),
482 packages: HashMap::new(),
483 dependencies: Vec::new(),
484 }),
485 };
486
487 registry
488 .add_dependency(name, package_name, &package.version, locked)
489 .await?;
490 }
491 Dependency::Local(p) => {
492 let res = DependencyResolution::Local(LocalResolution {
494 name: name.clone(),
495 path: p.clone(),
496 });
497
498 let prev = self.resolutions.insert(name.clone(), res);
499 assert!(prev.is_none());
500 }
501 }
502
503 Ok(())
504 }
505
506 pub async fn resolve(mut self) -> Result<DependencyResolutionMap> {
512 for (name, registry) in self.registries.iter_mut() {
514 registry.resolve(name).await?;
515 }
516
517 for resolution in self
518 .registries
519 .into_values()
520 .flat_map(|r| r.dependencies.into_iter())
521 .map(|d| {
522 DependencyResolution::Registry(
523 d.resolution.expect("dependency should have been resolved"),
524 )
525 })
526 {
527 let prev = self
528 .resolutions
529 .insert(resolution.name().clone(), resolution);
530 assert!(prev.is_none());
531 }
532
533 Ok(self.resolutions)
534 }
535}
536
/// Pending dependencies for a single registry, plus a cache of its package version listings.
struct Registry<'a> {
    /// Client used to list versions and fetch release metadata.
    client: Arc<CachingClient<FileCache>>,
    /// Version listings already fetched, keyed by package (filled by `load_package`).
    packages: HashMap<PackageRef, Vec<VersionInfo>>,
    /// Dependencies to resolve against this registry.
    dependencies: Vec<RegistryDependency<'a>>,
}
542
543impl<'a> Registry<'a> {
544 async fn add_dependency(
545 &mut self,
546 name: &'a PackageRef,
547 package: PackageRef,
548 version: &'a VersionReq,
549 locked: Option<&LockedPackageVersion>,
550 ) -> Result<()> {
551 let dep = RegistryDependency {
552 name,
553 package: package.clone(),
554 version,
555 locked: locked.map(|l| (l.version.clone(), l.digest.clone())),
556 resolution: None,
557 };
558
559 self.dependencies.push(dep);
560
561 Ok(())
562 }
563
564 async fn resolve(&mut self, registry: &'a str) -> Result<()> {
565 for dependency in self.dependencies.iter_mut() {
566 let client = self.client.clone();
569
570 let (selected_version, digest) = if client.is_readonly() {
571 dependency
572 .locked
573 .as_ref()
574 .map(|(ver, digest)| (ver, Some(digest)))
575 .ok_or_else(|| {
576 anyhow::anyhow!("Couldn't find locked dependency while in offline mode")
577 })?
578 } else {
579 let versions =
580 load_package(&mut self.packages, &self.client, dependency.package.clone())
581 .await?
582 .with_context(|| {
583 format!(
584 "package `{name}` was not found in component registry `{registry}`",
585 name = dependency.package
586 )
587 })?;
588
589 match &dependency.locked {
590 Some((version, digest)) => {
591 let exact_req = VersionReq {
593 comparators: vec![Comparator {
594 op: Op::Exact,
595 major: version.major,
596 minor: Some(version.minor),
597 patch: Some(version.patch),
598 pre: version.pre.clone(),
599 }],
600 };
601
602 find_latest_release(versions, &exact_req).map(|v| (&v.version, Some(digest))).or_else(|| find_latest_release(versions, dependency.version).map(|v| (&v.version, None)))
607 }
608 None => find_latest_release(versions, dependency.version).map(|v| (&v.version, None)),
609 }.with_context(|| format!("component registry package `{name}` has no release matching version requirement `{version}`", name = dependency.package, version = dependency.version))?
610 };
611
612 let release = client
615 .get_release(&dependency.package, selected_version)
616 .await?;
617 if let Some(digest) = digest {
618 if &release.content_digest != digest {
619 bail!(
620 "component registry package `{name}` (v`{version}`) has digest `{content}` but the lock file specifies digest `{digest}`",
621 name = dependency.package,
622 version = release.version,
623 content = release.content_digest,
624 );
625 }
626 }
627
628 dependency.resolution = Some(RegistryResolution {
629 name: dependency.name.clone(),
630 package: dependency.package.clone(),
631 registry: if registry == DEFAULT_REGISTRY_NAME {
632 None
633 } else {
634 Some(registry.to_string())
635 },
636 requirement: dependency.version.clone(),
637 version: release.version.clone(),
638 digest: release.content_digest.clone(),
639 client: self.client.clone(),
640 });
641 }
642
643 Ok(())
644 }
645}
646
647async fn load_package<'b>(
648 packages: &'b mut HashMap<PackageRef, Vec<VersionInfo>>,
649 client: &CachingClient<FileCache>,
650 package: PackageRef,
651) -> Result<Option<&'b Vec<VersionInfo>>> {
652 match packages.entry(package) {
653 hash_map::Entry::Occupied(e) => Ok(Some(e.into_mut())),
654 hash_map::Entry::Vacant(e) => match client.list_all_versions(e.key()).await {
655 Ok(p) => Ok(Some(e.insert(p))),
656 Err(WasmPkgError::PackageNotFound) => Ok(None),
657 Err(err) => Err(err.into()),
658 },
659 }
660}
661
/// A dependency queued for resolution against a particular registry.
struct RegistryDependency<'a> {
    /// Name the dependency was declared under.
    name: &'a PackageRef,
    /// Name of the package in the registry.
    package: PackageRef,
    /// Version requirement the selected release must satisfy.
    version: &'a VersionReq,
    /// Version and digest recorded in the lock file, if any.
    locked: Option<(Version, ContentDigest)>,
    /// Set by `Registry::resolve` once a release has been selected.
    resolution: Option<RegistryResolution>,
}
671
/// Resolved dependencies, keyed by the name each dependency was declared under.
pub type DependencyResolutionMap = HashMap<PackageRef, DependencyResolution>;
676
677fn find_latest_release<'a>(
678 versions: &'a [VersionInfo],
679 req: &VersionReq,
680) -> Option<&'a VersionInfo> {
681 versions
682 .iter()
683 .filter(|info| !info.yanked && req.matches(&info.version))
684 .max_by(|a, b| a.version.cmp(&b.version))
685}