// uv_distribution/distribution_database.rs

use std::future::Future;
use std::io;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::{FutureExt, TryStreamExt};
use tempfile::TempDir;
use tokio::io::{AsyncRead, AsyncSeekExt, ReadBuf};
use tokio::sync::Semaphore;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tracing::{Instrument, info_span, instrument, warn};
use url::Url;

use uv_cache::{ArchiveId, CacheBucket, CacheEntry, WheelCache};
use uv_cache_info::{CacheInfo, Timestamp};
use uv_client::{
    CacheControl, CachedClientError, Connectivity, DataWithCachePolicy, RegistryClient,
};
use uv_distribution_filename::WheelFilename;
use uv_distribution_types::{
    BuildInfo, BuildableSource, BuiltDist, Dist, File, HashPolicy, Hashed, IndexUrl, InstalledDist,
    Name, SourceDist, ToUrlError,
};
use uv_extract::hash::Hasher;
use uv_fs::write_atomic;
use uv_platform_tags::Tags;
use uv_pypi_types::{HashDigest, HashDigests, PyProjectToml};
use uv_redacted::DisplaySafeUrl;
use uv_types::{BuildContext, BuildStack};

use crate::archive::Archive;
use crate::metadata::{ArchiveMetadata, Metadata};
use crate::source::SourceDistributionBuilder;
use crate::{Error, LocalWheel, Reporter, RequiresDist};

/// A cached high-level interface to convert distributions (a requirement resolved to a location)
/// to a wheel or wheel metadata.
///
/// For wheel metadata, this happens by either fetching the metadata from the remote wheel or by
/// building the source distribution. For wheel files, either the wheel is downloaded directly, or
/// a source distribution is downloaded and built, and the resulting wheel is returned.
///
/// All kinds of wheel sources (index, URL, path) and source distribution sources (index, URL,
/// path, Git) are supported.
///
/// This struct is also responsible for acquiring locks around source distribution builds in
/// general and Git operations in particular, as well as respecting concurrency limits.
pub struct DistributionDatabase<'a, Context: BuildContext> {
    build_context: &'a Context,
    builder: SourceDistributionBuilder<'a, Context>,
    client: ManagedClient<'a>,
    reporter: Option<Arc<dyn Reporter>>,
}

impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
    pub fn new(
        client: &'a RegistryClient,
        build_context: &'a Context,
        concurrent_downloads: usize,
    ) -> Self {
        Self {
            build_context,
            builder: SourceDistributionBuilder::new(build_context),
            client: ManagedClient::new(client, concurrent_downloads),
            reporter: None,
        }
    }

    /// Set the build stack to use for the [`DistributionDatabase`].
    #[must_use]
    pub fn with_build_stack(self, build_stack: &'a BuildStack) -> Self {
        Self {
            builder: self.builder.with_build_stack(build_stack),
            ..self
        }
    }

    /// Set the [`Reporter`] to use for the [`DistributionDatabase`].
    #[must_use]
    pub fn with_reporter(self, reporter: Arc<dyn Reporter>) -> Self {
        Self {
            builder: self.builder.with_reporter(reporter.clone()),
            reporter: Some(reporter),
            ..self
        }
    }
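
    // Usage sketch (illustrative, not part of the module): constructing a database with a
    // hypothetical `BuildContext` implementation `context`, a `RegistryClient` `client`, and a
    // shared `Reporter` `reporter`; the concurrency limit of 8 is likewise an assumed value.
    //
    //     let database = DistributionDatabase::new(&client, &context, 8)
    //         .with_reporter(reporter);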

    /// Handle a specific `reqwest` error, and convert it to an [`io::Error`].
    fn handle_response_errors(&self, err: reqwest::Error) -> io::Error {
        if err.is_timeout() {
            io::Error::new(
                io::ErrorKind::TimedOut,
                format!(
                    "Failed to download distribution due to network timeout. Try increasing UV_HTTP_TIMEOUT (current value: {}s).",
                    self.client.unmanaged.timeout().as_secs()
                ),
            )
        } else {
            io::Error::other(err)
        }
    }

    /// Either fetch the wheel or fetch and build the source distribution.
    ///
    /// Returns a wheel that's compliant with the given platform tags.
    ///
    /// While hashes will be generated in some cases, hash-checking is only enforced for source
    /// distributions, and should be enforced by the caller for wheels.
    #[instrument(skip_all, fields(%dist))]
    pub async fn get_or_build_wheel(
        &self,
        dist: &Dist,
        tags: &Tags,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        match dist {
            Dist::Built(built) => self.get_wheel(built, hashes).await,
            Dist::Source(source) => self.build_wheel(source, tags, hashes).await,
        }
    }
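
    // Sketch of a typical call, assuming `database`, `dist`, and `tags` from the surrounding
    // context; `HashPolicy::None` is chosen purely for illustration:
    //
    //     let wheel = database.get_or_build_wheel(&dist, &tags, HashPolicy::None).await?;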

    /// Fetch the metadata for an installed distribution, preferring metadata that was provided
    /// directly by the user over metadata read from the distribution itself.
    #[instrument(skip_all, fields(%dist))]
    pub async fn get_installed_metadata(
        &self,
        dist: &InstalledDist,
    ) -> Result<ArchiveMetadata, Error> {
        // If the metadata was provided by the user directly, prefer it.
        if let Some(metadata) = self
            .build_context
            .dependency_metadata()
            .get(dist.name(), Some(dist.version()))
        {
            return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
        }

        let metadata = dist
            .read_metadata()
            .map_err(|err| Error::ReadInstalled(Box::new(dist.clone()), err))?;

        Ok(ArchiveMetadata::from_metadata23(metadata.clone()))
    }

    /// Either fetch the wheel metadata (directly from the index or with range requests) or
    /// fetch and build the source distribution.
    ///
    /// While hashes will be generated in some cases, hash-checking is only enforced for source
    /// distributions, and should be enforced by the caller for wheels.
    #[instrument(skip_all, fields(%dist))]
    pub async fn get_or_build_wheel_metadata(
        &self,
        dist: &Dist,
        hashes: HashPolicy<'_>,
    ) -> Result<ArchiveMetadata, Error> {
        match dist {
            Dist::Built(built) => self.get_wheel_metadata(built, hashes).await,
            Dist::Source(source) => {
                self.build_wheel_metadata(&BuildableSource::Dist(source), hashes)
                    .await
            }
        }
    }
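
    // Sketch: the metadata analogue of the call above, again with assumed names. For a built
    // distribution this hits the index (or issues range requests); for a source distribution it
    // falls through to `build_wheel_metadata`:
    //
    //     let metadata = database
    //         .get_or_build_wheel_metadata(&dist, HashPolicy::None)
    //         .await?;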

    /// Fetch a wheel from the cache or download it from the index.
    ///
    /// While hashes will be generated in all cases, hash-checking is _not_ enforced and should
    /// instead be enforced by the caller.
    async fn get_wheel(
        &self,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        match dist {
            BuiltDist::Registry(wheels) => {
                let wheel = wheels.best_wheel();
                let WheelTarget {
                    url,
                    extension,
                    size,
                } = WheelTarget::try_from(&*wheel.file)?;

                // Create a cache entry for the wheel.
                let wheel_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Index(&wheel.index).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                // If the URL is a file URL, load the wheel directly.
                if url.scheme() == "file" {
                    let path = url
                        .to_file_path()
                        .map_err(|()| Error::NonFileUrl(url.clone()))?;
                    return self
                        .load_wheel(
                            &path,
                            &wheel.filename,
                            WheelExtension::Whl,
                            wheel_entry,
                            dist,
                            hashes,
                        )
                        .await;
                }

                // Download and unzip.
                match self
                    .stream_wheel(
                        url.clone(),
                        dist.index(),
                        &wheel.filename,
                        extension,
                        size,
                        &wheel_entry,
                        dist,
                        hashes,
                    )
                    .await
                {
                    Ok(archive) => Ok(LocalWheel {
                        dist: Dist::Built(dist.clone()),
                        archive: self
                            .build_context
                            .cache()
                            .archive(&archive.id)
                            .into_boxed_path(),
                        hashes: archive.hashes,
                        filename: wheel.filename.clone(),
                        cache: CacheInfo::default(),
                        build: None,
                    }),
                    Err(Error::Extract(name, err)) => {
                        if err.is_http_streaming_unsupported() {
                            warn!(
                                "Streaming unsupported for {dist}; downloading wheel to disk ({err})"
                            );
                        } else if err.is_http_streaming_failed() {
                            warn!("Streaming failed for {dist}; downloading wheel to disk ({err})");
                        } else {
                            return Err(Error::Extract(name, err));
                        }

                        // If the request failed because streaming is unsupported, download the
                        // wheel directly.
                        let archive = self
                            .download_wheel(
                                url,
                                dist.index(),
                                &wheel.filename,
                                extension,
                                size,
                                &wheel_entry,
                                dist,
                                hashes,
                            )
                            .await?;

                        Ok(LocalWheel {
                            dist: Dist::Built(dist.clone()),
                            archive: self
                                .build_context
                                .cache()
                                .archive(&archive.id)
                                .into_boxed_path(),
                            hashes: archive.hashes,
                            filename: wheel.filename.clone(),
                            cache: CacheInfo::default(),
                            build: None,
                        })
                    }
                    Err(err) => Err(err),
                }
            }

            BuiltDist::DirectUrl(wheel) => {
                // Create a cache entry for the wheel.
                let wheel_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                // Download and unzip.
                match self
                    .stream_wheel(
                        wheel.url.raw().clone(),
                        None,
                        &wheel.filename,
                        WheelExtension::Whl,
                        None,
                        &wheel_entry,
                        dist,
                        hashes,
                    )
                    .await
                {
                    Ok(archive) => Ok(LocalWheel {
                        dist: Dist::Built(dist.clone()),
                        archive: self
                            .build_context
                            .cache()
                            .archive(&archive.id)
                            .into_boxed_path(),
                        hashes: archive.hashes,
                        filename: wheel.filename.clone(),
                        cache: CacheInfo::default(),
                        build: None,
                    }),
                    Err(Error::Client(err)) if err.is_http_streaming_unsupported() => {
                        warn!(
                            "Streaming unsupported for {dist}; downloading wheel to disk ({err})"
                        );

                        // If the request failed because streaming is unsupported, download the
                        // wheel directly.
                        let archive = self
                            .download_wheel(
                                wheel.url.raw().clone(),
                                None,
                                &wheel.filename,
                                WheelExtension::Whl,
                                None,
                                &wheel_entry,
                                dist,
                                hashes,
                            )
                            .await?;
                        Ok(LocalWheel {
                            dist: Dist::Built(dist.clone()),
                            archive: self
                                .build_context
                                .cache()
                                .archive(&archive.id)
                                .into_boxed_path(),
                            hashes: archive.hashes,
                            filename: wheel.filename.clone(),
                            cache: CacheInfo::default(),
                            build: None,
                        })
                    }
                    Err(err) => Err(err),
                }
            }

            BuiltDist::Path(wheel) => {
                let cache_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                self.load_wheel(
                    &wheel.install_path,
                    &wheel.filename,
                    WheelExtension::Whl,
                    cache_entry,
                    dist,
                    hashes,
                )
                .await
            }
        }
    }

    /// Convert a source distribution into a wheel, fetching it from the cache or building it if
    /// necessary.
    ///
    /// The returned wheel is guaranteed to come from a distribution with a matching hash, and
    /// no build processes will be executed for distributions with mismatched hashes.
    async fn build_wheel(
        &self,
        dist: &SourceDist,
        tags: &Tags,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        let built_wheel = self
            .builder
            .download_and_build(&BuildableSource::Dist(dist), tags, hashes, &self.client)
            .boxed_local()
            .await?;

        // Acquire the advisory lock.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = CacheEntry::new(
                built_wheel.target.parent().unwrap(),
                format!(
                    "{}.lock",
                    built_wheel.target.file_name().unwrap().to_str().unwrap()
                ),
            );
            lock_entry.lock().await.map_err(Error::CacheWrite)?
        };

        // If the wheel was unzipped previously, respect it. Source distributions are
        // cached under a unique revision ID, so unzipped directories are never stale.
        match self.build_context.cache().resolve_link(&built_wheel.target) {
            Ok(archive) => {
                return Ok(LocalWheel {
                    dist: Dist::Source(dist.clone()),
                    archive: archive.into_boxed_path(),
                    filename: built_wheel.filename,
                    hashes: built_wheel.hashes,
                    cache: built_wheel.cache_info,
                    build: Some(built_wheel.build_info),
                });
            }
            Err(err) if err.kind() == io::ErrorKind::NotFound => {}
            Err(err) => return Err(Error::CacheRead(err)),
        }

        // Otherwise, unzip the wheel.
        let id = self
            .unzip_wheel(&built_wheel.path, &built_wheel.target)
            .await?;

        Ok(LocalWheel {
            dist: Dist::Source(dist.clone()),
            archive: self.build_context.cache().archive(&id).into_boxed_path(),
            hashes: built_wheel.hashes,
            filename: built_wheel.filename,
            cache: built_wheel.cache_info,
            build: Some(built_wheel.build_info),
        })
    }

    /// Fetch the wheel metadata from the index, or from the cache if possible.
    ///
    /// While hashes will be generated in some cases, hash-checking is _not_ enforced and should
    /// instead be enforced by the caller.
    async fn get_wheel_metadata(
        &self,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<ArchiveMetadata, Error> {
        // If hash generation is enabled, and the distribution isn't hosted on a registry, get the
        // entire wheel to ensure that the hashes are included in the response. If the distribution
        // is hosted on an index, the hashes will be included in the simple metadata response.
        // For hash _validation_, callers are expected to enforce the policy when retrieving the
        // wheel.
        //
        // Historically, for `uv pip compile --universal`, we also generate hashes for
        // registry-based distributions when the relevant registry doesn't provide them. This was
        // motivated by `--find-links`. We continue that behavior (under `HashGeneration::All`) for
        // backwards compatibility, but it's a little dubious, since we're only hashing _one_
        // distribution here (as opposed to hashing all distributions for the version), and it may
        // not even be a compatible distribution!
        //
        // TODO(charlie): Request the hashes via a separate method, to reduce the coupling in this API.
        if hashes.is_generate(dist) {
            let wheel = self.get_wheel(dist, hashes).await?;
            // If the metadata was provided by the user directly, prefer it.
            let metadata = if let Some(metadata) = self
                .build_context
                .dependency_metadata()
                .get(dist.name(), Some(dist.version()))
            {
                metadata.clone()
            } else {
                wheel.metadata()?
            };
            let hashes = wheel.hashes;
            return Ok(ArchiveMetadata {
                metadata: Metadata::from_metadata23(metadata),
                hashes,
            });
        }

        // If the metadata was provided by the user directly, prefer it.
        if let Some(metadata) = self
            .build_context
            .dependency_metadata()
            .get(dist.name(), Some(dist.version()))
        {
            return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
        }

        let result = self
            .client
            .managed(|client| {
                client
                    .wheel_metadata(dist, self.build_context.capabilities())
                    .boxed_local()
            })
            .await;

        match result {
            Ok(metadata) => {
                // Validate that the metadata is consistent with the distribution.
                Ok(ArchiveMetadata::from_metadata23(metadata))
            }
            Err(err) if err.is_http_streaming_unsupported() => {
                warn!(
                    "Streaming unsupported when fetching metadata for {dist}; downloading wheel directly ({err})"
                );

                // If the request failed due to an error that could be resolved by
                // downloading the wheel directly, try that.
                let wheel = self.get_wheel(dist, hashes).await?;
                let metadata = wheel.metadata()?;
                let hashes = wheel.hashes;
                Ok(ArchiveMetadata {
                    metadata: Metadata::from_metadata23(metadata),
                    hashes,
                })
            }
            Err(err) => Err(err.into()),
        }
    }

    /// Build the wheel metadata for a source distribution, or fetch it from the cache if possible.
    ///
    /// The returned metadata is guaranteed to come from a distribution with a matching hash, and
    /// no build processes will be executed for distributions with mismatched hashes.
    pub async fn build_wheel_metadata(
        &self,
        source: &BuildableSource<'_>,
        hashes: HashPolicy<'_>,
    ) -> Result<ArchiveMetadata, Error> {
        // If the metadata was provided by the user directly, prefer it.
        if let Some(dist) = source.as_dist() {
            if let Some(metadata) = self
                .build_context
                .dependency_metadata()
                .get(dist.name(), dist.version())
            {
                // If we skipped the build, we should still resolve any Git dependencies to precise
                // commits.
                self.builder.resolve_revision(source, &self.client).await?;

                return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
            }
        }

        let metadata = self
            .builder
            .download_and_build_metadata(source, hashes, &self.client)
            .boxed_local()
            .await?;

        Ok(metadata)
    }

    /// Return the [`RequiresDist`] from a `pyproject.toml`, if it can be statically extracted.
    pub async fn requires_dist(
        &self,
        path: &Path,
        pyproject_toml: &PyProjectToml,
    ) -> Result<Option<RequiresDist>, Error> {
        self.builder
            .source_tree_requires_dist(path, pyproject_toml)
            .await
    }

    /// Stream a wheel from a URL, unzipping it into the cache as it's downloaded.
    async fn stream_wheel(
        &self,
        url: DisplaySafeUrl,
        index: Option<&IndexUrl>,
        filename: &WheelFilename,
        extension: WheelExtension,
        size: Option<u64>,
        wheel_entry: &CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<Archive, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheWrite)?
        };

        // Create an entry for the HTTP cache.
        let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));

        let download = |response: reqwest::Response| {
            async {
                let size = size.or_else(|| content_length(&response));

                let progress = self
                    .reporter
                    .as_ref()
                    .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));

                let reader = response
                    .bytes_stream()
                    .map_err(|err| self.handle_response_errors(err))
                    .into_async_read();

                // Create a hasher for each hash algorithm.
                let algorithms = hashes.algorithms();
                let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
                let mut hasher = uv_extract::hash::HashReader::new(reader.compat(), &mut hashers);

                // Download and unzip the wheel to a temporary directory.
                let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;

                match progress {
                    Some((reporter, progress)) => {
                        let mut reader = ProgressReader::new(&mut hasher, progress, &**reporter);
                        match extension {
                            WheelExtension::Whl => {
                                uv_extract::stream::unzip(&mut reader, temp_dir.path())
                                    .await
                                    .map_err(|err| Error::Extract(filename.to_string(), err))?;
                            }
                            WheelExtension::WhlZst => {
                                uv_extract::stream::untar_zst(&mut reader, temp_dir.path())
                                    .await
                                    .map_err(|err| Error::Extract(filename.to_string(), err))?;
                            }
                        }
                    }
                    None => match extension {
                        WheelExtension::Whl => {
                            uv_extract::stream::unzip(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                        WheelExtension::WhlZst => {
                            uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                    },
                }

                // If necessary, exhaust the reader to compute the hash.
                if !hashes.is_none() {
                    hasher.finish().await.map_err(Error::HashExhaustion)?;
                }

                // Persist the temporary directory to the directory store.
                let id = self
                    .build_context
                    .cache()
                    .persist(temp_dir.keep(), wheel_entry.path())
                    .await
                    .map_err(Error::CacheRead)?;

                if let Some((reporter, progress)) = progress {
                    reporter.on_download_complete(dist.name(), progress);
                }

                Ok(Archive::new(
                    id,
                    hashers.into_iter().map(HashDigest::from).collect(),
                    filename.clone(),
                ))
            }
            .instrument(info_span!("wheel", wheel = %dist))
        };

        // Fetch the archive from the cache, or download it if necessary.
        let req = self.request(url.clone())?;

        // Determine the cache control policy for the URL.
        let cache_control = match self.client.unmanaged.connectivity() {
            Connectivity::Online => {
                if let Some(header) = index.and_then(|index| {
                    self.build_context
                        .locations()
                        .artifact_cache_control_for(index)
                }) {
                    CacheControl::Override(header)
                } else {
                    CacheControl::from(
                        self.build_context
                            .cache()
                            .freshness(&http_entry, Some(&filename.name), None)
                            .map_err(Error::CacheRead)?,
                    )
                }
            }
            Connectivity::Offline => CacheControl::AllowStale,
        };

        let archive = self
            .client
            .managed(|client| {
                client.cached_client().get_serde_with_retry(
                    req,
                    &http_entry,
                    cache_control,
                    download,
                )
            })
            .await
            .map_err(|err| match err {
                CachedClientError::Callback { err, .. } => err,
                CachedClientError::Client { err, .. } => Error::Client(err),
            })?;

        // If the archive is missing the required hashes, or has since been removed, force a refresh.
        let archive = Some(archive)
            .filter(|archive| archive.has_digests(hashes))
            .filter(|archive| archive.exists(self.build_context.cache()));

        let archive = if let Some(archive) = archive {
            archive
        } else {
            self.client
                .managed(async |client| {
                    client
                        .cached_client()
                        .skip_cache_with_retry(
                            self.request(url)?,
                            &http_entry,
                            cache_control,
                            download,
                        )
                        .await
                        .map_err(|err| match err {
                            CachedClientError::Callback { err, .. } => err,
                            CachedClientError::Client { err, .. } => Error::Client(err),
                        })
                })
                .await?
        };

        Ok(archive)
    }

    /// Download a wheel from a URL, then unzip it into the cache.
    async fn download_wheel(
        &self,
        url: DisplaySafeUrl,
        index: Option<&IndexUrl>,
        filename: &WheelFilename,
        extension: WheelExtension,
        size: Option<u64>,
        wheel_entry: &CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<Archive, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheWrite)?
        };

        // Create an entry for the HTTP cache.
        let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));

        let download = |response: reqwest::Response| {
            async {
                let size = size.or_else(|| content_length(&response));

                let progress = self
                    .reporter
                    .as_ref()
                    .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));

                let reader = response
                    .bytes_stream()
                    .map_err(|err| self.handle_response_errors(err))
                    .into_async_read();

                // Download the wheel to a temporary file.
                let temp_file = tempfile::tempfile_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;
                let mut writer = tokio::io::BufWriter::new(fs_err::tokio::File::from_std(
                    // The file is unnamed, so the cache root is the best approximation of its
                    // location for error messages.
                    fs_err::File::from_parts(temp_file, self.build_context.cache().root()),
                ));

                match progress {
                    Some((reporter, progress)) => {
                        // Wrap the reader in a progress reporter. This will report 100% progress
                        // after the download is complete, even if we still have to unzip and hash
                        // part of the file.
                        let mut reader =
                            ProgressReader::new(reader.compat(), progress, &**reporter);

                        tokio::io::copy(&mut reader, &mut writer)
                            .await
                            .map_err(Error::CacheWrite)?;
                    }
                    None => {
                        tokio::io::copy(&mut reader.compat(), &mut writer)
                            .await
                            .map_err(Error::CacheWrite)?;
                    }
                }

                // Unzip the wheel to a temporary directory.
                let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;
                let mut file = writer.into_inner();
                file.seek(io::SeekFrom::Start(0))
                    .await
                    .map_err(Error::CacheWrite)?;

                // If no hashes are required, parallelize the unzip operation.
                let hashes = if hashes.is_none() {
                    let file = file.into_std().await;
                    tokio::task::spawn_blocking({
                        let target = temp_dir.path().to_owned();
                        move || -> Result<(), uv_extract::Error> {
                            // Unzip the wheel into a temporary directory.
                            match extension {
                                WheelExtension::Whl => {
                                    uv_extract::unzip(file, &target)?;
                                }
                                WheelExtension::WhlZst => {
                                    uv_extract::stream::untar_zst_file(file, &target)?;
                                }
                            }
                            Ok(())
                        }
                    })
                    .await?
                    .map_err(|err| Error::Extract(filename.to_string(), err))?;

                    HashDigests::empty()
                } else {
                    // Create a hasher for each hash algorithm.
                    let algorithms = hashes.algorithms();
                    let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
                    let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);

                    match extension {
                        WheelExtension::Whl => {
                            uv_extract::stream::unzip(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                        WheelExtension::WhlZst => {
                            uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?;
                        }
                    }

                    // If necessary, exhaust the reader to compute the hash.
                    hasher.finish().await.map_err(Error::HashExhaustion)?;

                    hashers.into_iter().map(HashDigest::from).collect()
                };

                // Persist the temporary directory to the directory store.
                let id = self
                    .build_context
                    .cache()
                    .persist(temp_dir.keep(), wheel_entry.path())
                    .await
                    .map_err(Error::CacheRead)?;

                if let Some((reporter, progress)) = progress {
                    reporter.on_download_complete(dist.name(), progress);
                }

                Ok(Archive::new(id, hashes, filename.clone()))
            }
            .instrument(info_span!("wheel", wheel = %dist))
        };

        // Fetch the archive from the cache, or download it if necessary.
        let req = self.request(url.clone())?;

        // Determine the cache control policy for the URL.
        let cache_control = match self.client.unmanaged.connectivity() {
            Connectivity::Online => {
                if let Some(header) = index.and_then(|index| {
                    self.build_context
                        .locations()
                        .artifact_cache_control_for(index)
                }) {
                    CacheControl::Override(header)
                } else {
                    CacheControl::from(
                        self.build_context
                            .cache()
                            .freshness(&http_entry, Some(&filename.name), None)
                            .map_err(Error::CacheRead)?,
                    )
                }
            }
            Connectivity::Offline => CacheControl::AllowStale,
        };

        let archive = self
            .client
            .managed(|client| {
                client.cached_client().get_serde_with_retry(
                    req,
                    &http_entry,
                    cache_control,
                    download,
                )
            })
            .await
            .map_err(|err| match err {
                CachedClientError::Callback { err, .. } => err,
                CachedClientError::Client { err, .. } => Error::Client(err),
            })?;

        // If the archive is missing the required hashes, or has since been removed, force a refresh.
        let archive = Some(archive)
            .filter(|archive| archive.has_digests(hashes))
            .filter(|archive| archive.exists(self.build_context.cache()));

        let archive = if let Some(archive) = archive {
            archive
        } else {
            self.client
                .managed(async |client| {
                    client
                        .cached_client()
                        .skip_cache_with_retry(
                            self.request(url)?,
                            &http_entry,
                            cache_control,
                            download,
                        )
                        .await
                        .map_err(|err| match err {
                            CachedClientError::Callback { err, .. } => err,
                            CachedClientError::Client { err, .. } => Error::Client(err),
                        })
                })
                .await?
        };

        Ok(archive)
    }

    /// Load a wheel from a local path.
    async fn load_wheel(
        &self,
        path: &Path,
        filename: &WheelFilename,
        extension: WheelExtension,
        wheel_entry: CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheWrite)?
        };

        // Determine the last-modified time of the wheel.
        let modified = Timestamp::from_path(path).map_err(Error::CacheRead)?;

        // Attempt to read the archive pointer from the cache.
        let pointer_entry = wheel_entry.with_file(format!("{}.rev", filename.cache_key()));
        let pointer = LocalArchivePointer::read_from(&pointer_entry)?;

        // Extract the archive from the pointer.
        let archive = pointer
            .filter(|pointer| pointer.is_up_to_date(modified))
            .map(LocalArchivePointer::into_archive)
            .filter(|archive| archive.has_digests(hashes));

        // If the file is already unzipped, and the cache is up-to-date, return it.
        if let Some(archive) = archive {
            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        } else if hashes.is_none() {
            // Otherwise, unzip the wheel.
            let archive = Archive::new(
                self.unzip_wheel(path, wheel_entry.path()).await?,
                HashDigests::empty(),
                filename.clone(),
            );

            // Write the archive pointer to the cache.
            let pointer = LocalArchivePointer {
                timestamp: modified,
                archive: archive.clone(),
            };
            pointer.write_to(&pointer_entry).await?;

            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        } else {
            // If necessary, compute the hashes of the wheel.
            let file = fs_err::tokio::File::open(path)
                .await
                .map_err(Error::CacheRead)?;
            let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                .map_err(Error::CacheWrite)?;

            // Create a hasher for each hash algorithm.
            let algorithms = hashes.algorithms();
            let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
            let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);

            // Unzip the wheel to a temporary directory.
            match extension {
                WheelExtension::Whl => {
                    uv_extract::stream::unzip(&mut hasher, temp_dir.path())
                        .await
                        .map_err(|err| Error::Extract(filename.to_string(), err))?;
                }
                WheelExtension::WhlZst => {
                    uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                        .await
                        .map_err(|err| Error::Extract(filename.to_string(), err))?;
                }
            }

            // Exhaust the reader to compute the hash.
            hasher.finish().await.map_err(Error::HashExhaustion)?;

            let hashes = hashers.into_iter().map(HashDigest::from).collect();

            // Persist the temporary directory to the directory store.
            let id = self
                .build_context
                .cache()
                .persist(temp_dir.keep(), wheel_entry.path())
                .await
                .map_err(Error::CacheWrite)?;

            // Create an archive.
            let archive = Archive::new(id, hashes, filename.clone());

            // Write the archive pointer to the cache.
            let pointer = LocalArchivePointer {
                timestamp: modified,
                archive: archive.clone(),
            };
            pointer.write_to(&pointer_entry).await?;

            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        }
    }

    /// Unzip a wheel into the cache, returning the ID of the unzipped archive.
    async fn unzip_wheel(&self, path: &Path, target: &Path) -> Result<ArchiveId, Error> {
        let temp_dir = tokio::task::spawn_blocking({
            let path = path.to_owned();
            let root = self.build_context.cache().root().to_path_buf();
            move || -> Result<TempDir, Error> {
                // Unzip the wheel into a temporary directory.
                let temp_dir = tempfile::tempdir_in(root).map_err(Error::CacheWrite)?;
                let reader = fs_err::File::open(&path).map_err(Error::CacheWrite)?;
                uv_extract::unzip(reader, temp_dir.path())
                    .map_err(|err| Error::Extract(path.to_string_lossy().into_owned(), err))?;
                Ok(temp_dir)
            }
        })
        .await??;

        // Persist the temporary directory to the directory store.
        let id = self
            .build_context
            .cache()
            .persist(temp_dir.keep(), target)
            .await
            .map_err(Error::CacheWrite)?;

        Ok(id)
    }

    /// Returns a GET [`reqwest::Request`] for the given URL.
    fn request(&self, url: DisplaySafeUrl) -> Result<reqwest::Request, reqwest::Error> {
        self.client
            .unmanaged
            .uncached_client(&url)
            .get(Url::from(url))
            .header(
                // `reqwest` defaults to accepting compressed responses.
                // Specify identity encoding to get consistent .whl downloading
                // behavior from servers. ref: https://github.com/pypa/pip/pull/1688
                "accept-encoding",
                reqwest::header::HeaderValue::from_static("identity"),
            )
            .build()
    }

    /// Return the [`ManagedClient`] used by this resolver.
    pub fn client(&self) -> &ManagedClient<'a> {
        &self.client
    }
}

/// A wrapper around `RegistryClient` that manages a concurrency limit.
pub struct ManagedClient<'a> {
    pub unmanaged: &'a RegistryClient,
    control: Semaphore,
}

impl<'a> ManagedClient<'a> {
    /// Create a new `ManagedClient` using the given client and concurrency limit.
    fn new(client: &'a RegistryClient, concurrency: usize) -> Self {
        ManagedClient {
            unmanaged: client,
            control: Semaphore::new(concurrency),
        }
    }

    /// Perform a request using the client, respecting the concurrency limit.
    ///
    /// If the concurrency limit has been reached, this method will wait until a pending
    /// operation completes before executing the closure.
    pub async fn managed<F, T>(&self, f: impl FnOnce(&'a RegistryClient) -> F) -> T
    where
        F: Future<Output = T>,
    {
        let _permit = self.control.acquire().await.unwrap();
        f(self.unmanaged).await
    }

    /// Perform a request using a client that internally manages the concurrency limit.
    ///
    /// The callback is passed the client and a semaphore. It must acquire the semaphore before
    /// any request through the client and drop it after.
    ///
    /// This method serves as an escape hatch for functions that may want to send multiple requests
    /// in parallel.
    pub async fn manual<F, T>(&'a self, f: impl FnOnce(&'a RegistryClient, &'a Semaphore) -> F) -> T
    where
        F: Future<Output = T>,
    {
        f(self.unmanaged, &self.control).await
    }
}
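
// Usage sketch contrasting the two entry points (the `managed_client`, `dist`, and
// `capabilities` names are assumed from a calling context): `managed` acquires one semaphore
// permit around a single request, while `manual` hands the callback the semaphore so it can
// gate several parallel requests itself.
//
//     let metadata = managed_client
//         .managed(|client| client.wheel_metadata(&dist, capabilities))
//         .await?;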

/// Returns the value of the `Content-Length` header from the [`reqwest::Response`], if present.
fn content_length(response: &reqwest::Response) -> Option<u64> {
    response
        .headers()
        .get(reqwest::header::CONTENT_LENGTH)
        .and_then(|val| val.to_str().ok())
        .and_then(|val| val.parse::<u64>().ok())
}

/// An asynchronous reader that reports progress as bytes are read.
struct ProgressReader<'a, R> {
    reader: R,
    index: usize,
    reporter: &'a dyn Reporter,
}

impl<'a, R> ProgressReader<'a, R> {
    /// Create a new [`ProgressReader`] that wraps another reader.
    fn new(reader: R, index: usize, reporter: &'a dyn Reporter) -> Self {
        Self {
            reader,
            index,
            reporter,
        }
    }
}

impl<R> AsyncRead for ProgressReader<'_, R>
where
    R: AsyncRead + Unpin,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        Pin::new(&mut self.as_mut().reader)
            .poll_read(cx, buf)
            .map_ok(|()| {
                self.reporter
                    .on_download_progress(self.index, buf.filled().len() as u64);
            })
    }
}
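
// Sketch: wrapping a reader so each `poll_read` reports progress to the `Reporter`; the
// `inner`, `download_id`, `reporter`, and `writer` names are assumed from a calling context:
//
//     let mut reader = ProgressReader::new(inner, download_id, &*reporter);
//     tokio::io::copy(&mut reader, &mut writer).await?;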

/// A pointer to an archive in the cache, fetched from an HTTP archive.
///
/// Encoded with `MsgPack`, and represented on disk by a `.http` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HttpArchivePointer {
    archive: Archive,
}

impl HttpArchivePointer {
    /// Read an [`HttpArchivePointer`] from the cache.
    pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
        match fs_err::File::open(path.as_ref()) {
            Ok(file) => {
                let data = DataWithCachePolicy::from_reader(file)?.data;
                let archive = rmp_serde::from_slice::<Archive>(&data)?;
                Ok(Some(Self { archive }))
            }
            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
            Err(err) => Err(Error::CacheRead(err)),
        }
    }

    /// Return the [`Archive`] from the pointer.
    pub fn into_archive(self) -> Archive {
        self.archive
    }

    /// Return the [`CacheInfo`] from the pointer.
    pub fn to_cache_info(&self) -> CacheInfo {
        CacheInfo::default()
    }

    /// Return the [`BuildInfo`] from the pointer.
    pub fn to_build_info(&self) -> Option<BuildInfo> {
        None
    }
}

/// A pointer to an archive in the cache, fetched from a local path.
///
/// Encoded with `MsgPack`, and represented on disk by a `.rev` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LocalArchivePointer {
    timestamp: Timestamp,
    archive: Archive,
}

impl LocalArchivePointer {
    /// Read a [`LocalArchivePointer`] from the cache.
    pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
        match fs_err::read(path) {
            Ok(cached) => Ok(Some(rmp_serde::from_slice::<Self>(&cached)?)),
            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
            Err(err) => Err(Error::CacheRead(err)),
        }
    }

    /// Write a [`LocalArchivePointer`] to the cache.
    pub async fn write_to(&self, entry: &CacheEntry) -> Result<(), Error> {
        write_atomic(entry.path(), rmp_serde::to_vec(&self)?)
            .await
            .map_err(Error::CacheWrite)
    }

    /// Returns `true` if the archive is up-to-date with the given modified timestamp.
    pub fn is_up_to_date(&self, modified: Timestamp) -> bool {
        self.timestamp == modified
    }

    /// Return the [`Archive`] from the pointer.
    pub fn into_archive(self) -> Archive {
        self.archive
    }

    /// Return the [`CacheInfo`] from the pointer.
    pub fn to_cache_info(&self) -> CacheInfo {
        CacheInfo::from_timestamp(self.timestamp)
    }

    /// Return the [`BuildInfo`] from the pointer.
    pub fn to_build_info(&self) -> Option<BuildInfo> {
        None
    }
}
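
// Sketch of the `.rev` pointer protocol as used by `load_wheel` (names assumed from that
// context): read the pointer, and reuse the cached archive only while the wheel's mtime
// still matches.
//
//     let archive = LocalArchivePointer::read_from(&pointer_entry)?
//         .filter(|pointer| pointer.is_up_to_date(modified))
//         .map(LocalArchivePointer::into_archive);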

#[derive(Debug, Clone)]
struct WheelTarget {
    /// The URL from which the wheel can be downloaded.
    url: DisplaySafeUrl,
    /// The expected extension of the wheel file.
    extension: WheelExtension,
    /// The expected size of the wheel file, if known.
    size: Option<u64>,
}

impl TryFrom<&File> for WheelTarget {
    type Error = ToUrlError;

    /// Determine the [`WheelTarget`] from a [`File`].
    fn try_from(file: &File) -> Result<Self, Self::Error> {
        let url = file.url.to_url()?;
        if let Some(zstd) = file.zstd.as_ref() {
            Ok(Self {
                url: add_tar_zst_extension(url),
                extension: WheelExtension::WhlZst,
                size: zstd.size,
            })
        } else {
            Ok(Self {
                url,
                extension: WheelExtension::Whl,
                size: file.size,
            })
        }
    }
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum WheelExtension {
    /// A `.whl` file.
    Whl,
    /// A `.whl.tar.zst` file.
    WhlZst,
}

/// Add `.tar.zst` to the end of the URL path, if it doesn't already exist.
#[must_use]
fn add_tar_zst_extension(mut url: DisplaySafeUrl) -> DisplaySafeUrl {
    let mut path = url.path().to_string();

    if !path.ends_with(".tar.zst") {
        path.push_str(".tar.zst");
    }

    url.set_path(&path);
    url
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_tar_zst_extension() {
        let url =
            DisplaySafeUrl::parse("https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl")
                .unwrap();
        assert_eq!(
            add_tar_zst_extension(url).as_str(),
            "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst"
        );

        let url = DisplaySafeUrl::parse(
            "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
        )
        .unwrap();
        assert_eq!(
            add_tar_zst_extension(url).as_str(),
            "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst"
        );

        let url = DisplaySafeUrl::parse(
            "https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl",
        )
        .unwrap();
        assert_eq!(
            add_tar_zst_extension(url).as_str(),
            "https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl.tar.zst"
        );
    }
}