1use super::Error;
2use anyhow::Result;
3use futures::{stream::FuturesUnordered, StreamExt};
4use indexmap::IndexMap;
5use miette::SourceSpan;
6use semver::{Version, VersionReq};
7use std::{fs, path::Path, sync::Arc};
8use wac_types::BorrowedPackageKey;
9use warg_client::{Client, ClientError, Config, FileSystemClient};
10use warg_protocol::registry::PackageName;
11
/// A progress reporter that [`RegistryPackageResolver`] notifies while it works.
///
/// [`RegistryPackageResolver::resolve`] drives it in this order:
/// 1. `println` with the `"Updating"` status before package logs are fetched;
/// 2. `init`, followed by `println` with the `"Downloading"` status;
/// 3. one `inc(1)` plus one `println` (`"Downloaded"` status) per downloaded package;
/// 4. `finish`, once every package has been downloaded and read.
///
/// If resolution fails part-way, the remaining calls are not made.
pub trait ProgressBar {
    /// Prepares the reporter for `count` units of work.
    ///
    /// `resolve` passes the number of package keys it is about to download.
    fn init(&self, count: usize);

    /// Reports `msg` together with a short status label such as `"Downloading"`.
    fn println(&self, status: &str, msg: &str);

    /// Advances the progress by `delta` units of work.
    fn inc(&self, delta: usize);

    /// Signals that all reported work is complete.
    fn finish(&self);
}
28
/// Resolves package keys to package contents by downloading them through a `warg` registry
/// client.
pub struct RegistryPackageResolver {
    /// The registry client.
    ///
    /// It is wrapped in an `Arc` so that each spawned download task can hold its own handle.
    client: Arc<FileSystemClient>,
    /// An optional progress reporter, notified while package logs are fetched and contents
    /// downloaded.
    bar: Option<Box<dyn ProgressBar>>,
}
37
38impl RegistryPackageResolver {
39 pub async fn new(url: Option<&str>, bar: Option<Box<dyn ProgressBar>>) -> Result<Self> {
44 Ok(Self {
45 client: Arc::new(Client::new_with_default_config(url).await?),
46 bar,
47 })
48 }
49
50 pub async fn new_with_config(
54 url: Option<&str>,
55 config: &Config,
56 bar: Option<Box<dyn ProgressBar>>,
57 ) -> Result<Self> {
58 Ok(Self {
59 client: Arc::new(Client::new_with_config(url, config, None).await?),
60 bar,
61 })
62 }
63
64 pub async fn resolve<'a>(
68 &self,
69 keys: &IndexMap<BorrowedPackageKey<'a>, SourceSpan>,
70 ) -> Result<IndexMap<BorrowedPackageKey<'a>, Vec<u8>>, Error> {
71 let package_names_with_source_span = keys
73 .iter()
74 .map(|(key, span)| {
75 Ok((
76 PackageName::new(key.name.to_string()).map_err(|_| {
77 Error::InvalidPackageName {
78 name: key.name.to_string(),
79 span: *span,
80 }
81 })?,
82 (key.version.cloned(), *span),
83 ))
84 })
85 .collect::<Result<IndexMap<PackageName, (Option<Version>, SourceSpan)>, Error>>()?;
86
87 if let Some(bar) = self.bar.as_ref() {
89 bar.println("Updating", "package logs from the registry");
90 }
91
92 match self
93 .client
94 .fetch_packages(package_names_with_source_span.keys())
95 .await
96 {
97 Ok(_) => {}
98 Err(ClientError::PackageDoesNotExist { name, .. }) => {
99 return Err(Error::PackageDoesNotExist {
100 name: name.to_string(),
101 span: package_names_with_source_span.get(&name).unwrap().1,
102 });
103 }
104 Err(err) => {
105 return Err(Error::RegistryUpdateFailure { source: err.into() });
106 }
107 }
108
109 if let Some(bar) = self.bar.as_ref() {
110 bar.init(keys.len());
112 bar.println("Downloading", "package content from the registry");
113 }
114
115 let mut tasks = FuturesUnordered::new();
116 for (index, (package_name, (version, span))) in
117 package_names_with_source_span.into_iter().enumerate()
118 {
119 let client = self.client.clone();
120 tasks.push(tokio::spawn(async move {
121 Ok((
122 index,
123 if let Some(version) = version {
124 client
125 .download_exact(&package_name, &version)
126 .await
127 .map_err(|err| match err {
128 ClientError::PackageVersionDoesNotExist { name, version } => {
129 Error::PackageVersionDoesNotExist {
130 name: name.to_string(),
131 version,
132 span,
133 }
134 }
135 err => Error::RegistryDownloadFailure { source: err.into() },
136 })?
137 } else {
138 client
139 .download(&package_name, &VersionReq::STAR)
140 .await
141 .map_err(|err| Error::RegistryDownloadFailure { source: err.into() })?
142 .ok_or_else(|| Error::PackageNoReleases {
143 name: package_name.to_string(),
144 span,
145 })?
146 },
147 ))
148 }));
149 }
150
151 let mut packages = IndexMap::with_capacity(keys.len());
152 let count = tasks.len();
153 let mut finished = 0;
154
155 while let Some(res) = tasks.next().await {
156 let (index, download) = res.unwrap()?;
157
158 finished += 1;
159
160 let (key, _) = keys.get_index(index).unwrap();
161
162 if let Some(bar) = self.bar.as_ref() {
163 bar.inc(1);
164 let BorrowedPackageKey { name, .. } = key;
165 bar.println(
166 "Downloaded",
167 &format!("package `{name}` {version}", version = download.version),
168 )
169 }
170
171 packages.insert(*key, Self::read_contents(&download.path)?);
172 }
173
174 assert_eq!(finished, count);
175
176 if let Some(bar) = self.bar.as_ref() {
177 bar.finish();
178 }
179
180 Ok(packages)
181 }
182
183 fn read_contents(path: &Path) -> Result<Vec<u8>, Error> {
184 fs::read(path).map_err(|e| Error::RegistryContentFailure {
185 path: path.to_path_buf(),
186 source: e.into(),
187 })
188 }
189}