wac_resolver/
registry.rs

1use super::Error;
2use anyhow::Result;
3use futures::{stream::FuturesUnordered, StreamExt};
4use indexmap::IndexMap;
5use miette::SourceSpan;
6use semver::{Version, VersionReq};
7use std::{fs, path::Path, sync::Arc};
8use wac_types::BorrowedPackageKey;
9use warg_client::{Client, ClientError, Config, FileSystemClient};
10use warg_protocol::registry::PackageName;
11
/// Implemented by progress bars.
///
/// This is used to abstract a UI for the registry resolver, so the resolver
/// can report progress without depending on a concrete terminal library.
///
/// All methods take `&self`; implementations that need to track state must
/// therefore use interior mutability.
pub trait ProgressBar {
    /// Initializes the progress bar with the given count.
    fn init(&self, count: usize);

    /// Prints a message and then redraws the progress bar.
    ///
    /// `status` is a short label for the message (the resolver passes values
    /// such as `"Updating"` or `"Downloaded"`) and `msg` is the message text.
    fn println(&self, status: &str, msg: &str);

    /// Increments the progress bar by the given amount.
    fn inc(&self, delta: usize);

    /// Finishes the progress bar.
    fn finish(&self);
}
28
/// Used to resolve packages from a Warg registry.
///
/// Note that the registry will be locked for the lifetime of
/// the resolver.
pub struct RegistryPackageResolver {
    /// The Warg client used to fetch package logs and download package
    /// content. It is held in an `Arc` so that a handle can be cloned into
    /// each spawned download task.
    client: Arc<FileSystemClient>,
    /// Optional progress bar used to report resolution progress; when `None`,
    /// no progress is reported.
    bar: Option<Box<dyn ProgressBar>>,
}
37
38impl RegistryPackageResolver {
39    /// Creates a new registry package resolver using the default
40    /// client configuration file.
41    ///
42    /// If `url` is `None`, the default URL will be used.
43    pub async fn new(url: Option<&str>, bar: Option<Box<dyn ProgressBar>>) -> Result<Self> {
44        Ok(Self {
45            client: Arc::new(Client::new_with_default_config(url).await?),
46            bar,
47        })
48    }
49
50    /// Creates a new registry package resolver with the given configuration.
51    ///
52    /// If `url` is `None`, the default URL will be used.
53    pub async fn new_with_config(
54        url: Option<&str>,
55        config: &Config,
56        bar: Option<Box<dyn ProgressBar>>,
57    ) -> Result<Self> {
58        Ok(Self {
59            client: Arc::new(Client::new_with_config(url, config, None).await?),
60            bar,
61        })
62    }
63
64    /// Resolves the provided package keys to packages.
65    ///
66    /// If the package isn't found, an error is returned.
67    pub async fn resolve<'a>(
68        &self,
69        keys: &IndexMap<BorrowedPackageKey<'a>, SourceSpan>,
70    ) -> Result<IndexMap<BorrowedPackageKey<'a>, Vec<u8>>, Error> {
71        // parses into `PackageName` and maps back to `SourceSpan`
72        let package_names_with_source_span = keys
73            .iter()
74            .map(|(key, span)| {
75                Ok((
76                    PackageName::new(key.name.to_string()).map_err(|_| {
77                        Error::InvalidPackageName {
78                            name: key.name.to_string(),
79                            span: *span,
80                        }
81                    })?,
82                    (key.version.cloned(), *span),
83                ))
84            })
85            .collect::<Result<IndexMap<PackageName, (Option<Version>, SourceSpan)>, Error>>()?;
86
87        // fetch required package logs and return error if any not found
88        if let Some(bar) = self.bar.as_ref() {
89            bar.println("Updating", "package logs from the registry");
90        }
91
92        match self
93            .client
94            .fetch_packages(package_names_with_source_span.keys())
95            .await
96        {
97            Ok(_) => {}
98            Err(ClientError::PackageDoesNotExist { name, .. }) => {
99                return Err(Error::PackageDoesNotExist {
100                    name: name.to_string(),
101                    span: package_names_with_source_span.get(&name).unwrap().1,
102                });
103            }
104            Err(err) => {
105                return Err(Error::RegistryUpdateFailure { source: err.into() });
106            }
107        }
108
109        if let Some(bar) = self.bar.as_ref() {
110            // download package content if not in cache
111            bar.init(keys.len());
112            bar.println("Downloading", "package content from the registry");
113        }
114
115        let mut tasks = FuturesUnordered::new();
116        for (index, (package_name, (version, span))) in
117            package_names_with_source_span.into_iter().enumerate()
118        {
119            let client = self.client.clone();
120            tasks.push(tokio::spawn(async move {
121                Ok((
122                    index,
123                    if let Some(version) = version {
124                        client
125                            .download_exact(&package_name, &version)
126                            .await
127                            .map_err(|err| match err {
128                                ClientError::PackageVersionDoesNotExist { name, version } => {
129                                    Error::PackageVersionDoesNotExist {
130                                        name: name.to_string(),
131                                        version,
132                                        span,
133                                    }
134                                }
135                                err => Error::RegistryDownloadFailure { source: err.into() },
136                            })?
137                    } else {
138                        client
139                            .download(&package_name, &VersionReq::STAR)
140                            .await
141                            .map_err(|err| Error::RegistryDownloadFailure { source: err.into() })?
142                            .ok_or_else(|| Error::PackageNoReleases {
143                                name: package_name.to_string(),
144                                span,
145                            })?
146                    },
147                ))
148            }));
149        }
150
151        let mut packages = IndexMap::with_capacity(keys.len());
152        let count = tasks.len();
153        let mut finished = 0;
154
155        while let Some(res) = tasks.next().await {
156            let (index, download) = res.unwrap()?;
157
158            finished += 1;
159
160            let (key, _) = keys.get_index(index).unwrap();
161
162            if let Some(bar) = self.bar.as_ref() {
163                bar.inc(1);
164                let BorrowedPackageKey { name, .. } = key;
165                bar.println(
166                    "Downloaded",
167                    &format!("package `{name}` {version}", version = download.version),
168                )
169            }
170
171            packages.insert(*key, Self::read_contents(&download.path)?);
172        }
173
174        assert_eq!(finished, count);
175
176        if let Some(bar) = self.bar.as_ref() {
177            bar.finish();
178        }
179
180        Ok(packages)
181    }
182
183    fn read_contents(path: &Path) -> Result<Vec<u8>, Error> {
184        fs::read(path).map_err(|e| Error::RegistryContentFailure {
185            path: path.to_path_buf(),
186            source: e.into(),
187        })
188    }
189}