// uv_distribution/distribution_database.rs

use std::future::Future;
use std::io;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::{FutureExt, TryStreamExt};
use tokio::io::{AsyncRead, AsyncSeekExt, ReadBuf};
use tokio::sync::Semaphore;
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tracing::{Instrument, info_span, instrument, warn};
use url::Url;

use uv_cache::{ArchiveId, CacheBucket, CacheEntry, WheelCache};
use uv_cache_info::{CacheInfo, Timestamp};
use uv_client::{
    CacheControl, CachedClientError, Connectivity, DataWithCachePolicy, RegistryClient,
};
use uv_distribution_filename::{SourceDistExtension, WheelFilename};
use uv_distribution_types::{
    BuildInfo, BuildableSource, BuiltDist, Dist, DistRef, File, HashPolicy, Hashed, IndexUrl,
    InstalledDist, Name, SourceDist, ToUrlError,
};
use uv_extract::hash::Hasher;
use uv_fs::write_atomic;
use uv_install_wheel::validate_and_heal_record;
use uv_platform_tags::Tags;
use uv_pypi_types::{HashDigest, HashDigests, PyProjectToml};
use uv_python::PythonVariant;
use uv_redacted::DisplaySafeUrl;
use uv_types::{BuildContext, BuildStack};
use uv_warnings::warn_user_once;

use crate::archive::Archive;
use crate::error::PythonVersion;
use crate::metadata::{ArchiveMetadata, Metadata};
use crate::source::SourceDistributionBuilder;
use crate::{Error, LocalWheel, Reporter, RequiresDist};
41
/// A cached high-level interface to convert distributions (a requirement resolved to a location)
/// to a wheel or wheel metadata.
///
/// For wheel metadata, this happens by either fetching the metadata from the remote wheel or by
/// building the source distribution. For wheel files, either the wheel is downloaded or a source
/// distribution is downloaded, built and the new wheel gets returned.
///
/// All kinds of wheel sources (index, URL, path) and source distribution sources (index, URL,
/// path, Git) are supported.
///
/// This struct also has the task of acquiring locks around source dist builds in general and git
/// operations especially, as well as respecting concurrency limits.
pub struct DistributionDatabase<'a, Context: BuildContext> {
    /// Provides access to the cache, user-supplied dependency metadata, and index locations.
    build_context: &'a Context,
    /// Downloads and builds source distributions into wheels.
    builder: SourceDistributionBuilder<'a, Context>,
    /// Registry client wrapped with a semaphore that bounds concurrent downloads.
    client: ManagedClient<'a>,
    /// Optional reporter for download progress events.
    reporter: Option<Arc<dyn Reporter>>,
}
60
impl<'a, Context: BuildContext> DistributionDatabase<'a, Context> {
62    pub fn new(
63        client: &'a RegistryClient,
64        build_context: &'a Context,
65        downloads_semaphore: Arc<Semaphore>,
66    ) -> Self {
67        Self {
68            build_context,
69            builder: SourceDistributionBuilder::new(build_context),
70            client: ManagedClient::new(client, downloads_semaphore),
71            reporter: None,
72        }
73    }
74
75    /// Set the build stack to use for the [`DistributionDatabase`].
76    #[must_use]
77    pub fn with_build_stack(self, build_stack: &'a BuildStack) -> Self {
78        Self {
79            builder: self.builder.with_build_stack(build_stack),
80            ..self
81        }
82    }
83
84    /// Set the [`Reporter`] to use for the [`DistributionDatabase`].
85    #[must_use]
86    pub fn with_reporter(self, reporter: Arc<dyn Reporter>) -> Self {
87        Self {
88            builder: self.builder.with_reporter(reporter.clone()),
89            reporter: Some(reporter),
90            ..self
91        }
92    }
93
94    /// Handle a specific `reqwest` error, and convert it to [`io::Error`].
95    fn handle_response_errors(&self, err: reqwest::Error) -> io::Error {
96        if err.is_timeout() {
97            // Assumption: The connect timeout with the 10s default is not the culprit.
98            io::Error::new(
99                io::ErrorKind::TimedOut,
100                format!(
101                    "Failed to download distribution due to network timeout. Try increasing UV_HTTP_TIMEOUT (current value: {}s).",
102                    self.client.unmanaged.read_timeout().as_secs()
103                ),
104            )
105        } else {
106            io::Error::other(err)
107        }
108    }
109
110    /// Either fetch the wheel or fetch and build the source distribution
111    ///
112    /// Returns a wheel that's compliant with the given platform tags.
113    ///
114    /// While hashes will be generated in some cases, hash-checking is only enforced for source
115    /// distributions, and should be enforced by the caller for wheels.
116    #[instrument(skip_all, fields(%dist))]
117    pub async fn get_or_build_wheel(
118        &self,
119        dist: &Dist,
120        tags: &Tags,
121        hashes: HashPolicy<'_>,
122    ) -> Result<LocalWheel, Error> {
123        match dist {
124            Dist::Built(built) => self.get_wheel(built, hashes).await,
125            Dist::Source(source) => self.build_wheel(source, tags, hashes).await,
126        }
127    }
128
129    /// Either fetch the only wheel metadata (directly from the index or with range requests) or
130    /// fetch and build the source distribution.
131    ///
132    /// While hashes will be generated in some cases, hash-checking is only enforced for source
133    /// distributions, and should be enforced by the caller for wheels.
134    #[instrument(skip_all, fields(%dist))]
135    pub async fn get_installed_metadata(
136        &self,
137        dist: &InstalledDist,
138    ) -> Result<ArchiveMetadata, Error> {
139        // If the metadata was provided by the user directly, prefer it.
140        if let Some(metadata) = self
141            .build_context
142            .dependency_metadata()
143            .get(dist.name(), Some(dist.version()))
144        {
145            return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
146        }
147
148        let metadata = dist
149            .read_metadata()
150            .map_err(|err| Error::ReadInstalled(Box::new(dist.clone()), err))?;
151
152        Ok(ArchiveMetadata::from_metadata23(metadata.clone()))
153    }
154
155    /// Either fetch the only wheel metadata (directly from the index or with range requests) or
156    /// fetch and build the source distribution.
157    ///
158    /// While hashes will be generated in some cases, hash-checking is only enforced for source
159    /// distributions, and should be enforced by the caller for wheels.
160    #[instrument(skip_all, fields(%dist))]
161    pub async fn get_or_build_wheel_metadata(
162        &self,
163        dist: &Dist,
164        hashes: HashPolicy<'_>,
165    ) -> Result<ArchiveMetadata, Error> {
166        match dist {
167            Dist::Built(built) => self.get_wheel_metadata(built, hashes).await,
168            Dist::Source(source) => {
169                self.build_wheel_metadata(&BuildableSource::Dist(source), hashes)
170                    .await
171            }
172        }
173    }
174
    /// Fetch a wheel from the cache or download it from the index.
    ///
    /// Streaming download-and-unzip is attempted first; if the server does not support (or
    /// fails) HTTP streaming, the wheel is downloaded to disk instead via `download_wheel`.
    ///
    /// While hashes will be generated in all cases, hash-checking is _not_ enforced and should
    /// instead be enforced by the caller.
    async fn get_wheel(
        &self,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        match dist {
            // A wheel hosted on a registry index.
            BuiltDist::Registry(wheels) => {
                let wheel = wheels.best_wheel();
                let WheelTarget {
                    url,
                    extension,
                    size,
                } = WheelTarget::try_from(&*wheel.file)?;

                // Create a cache entry for the wheel.
                let wheel_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Index(&wheel.index).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                // If the URL is a file URL, load the wheel directly.
                if url.scheme() == "file" {
                    let path = url
                        .to_file_path()
                        .map_err(|()| Error::NonFileUrl(url.clone()))?;
                    return self
                        .load_wheel(
                            &path,
                            &wheel.filename,
                            WheelExtension::Whl,
                            wheel_entry,
                            dist,
                            hashes,
                        )
                        .await;
                }

                // Download and unzip.
                match self
                    .stream_wheel(
                        url.clone(),
                        dist.index(),
                        &wheel.filename,
                        extension,
                        size,
                        &wheel_entry,
                        dist,
                        hashes,
                    )
                    .await
                {
                    Ok(archive) => Ok(LocalWheel {
                        dist: Dist::Built(dist.clone()),
                        archive: self
                            .build_context
                            .cache()
                            .archive(&archive.id)
                            .into_boxed_path(),
                        hashes: archive.hashes,
                        filename: wheel.filename.clone(),
                        cache: CacheInfo::default(),
                        build: None,
                    }),
                    // Recoverable streaming failures fall back to a full download; any other
                    // extraction error is propagated unchanged.
                    Err(Error::Extract(name, err)) => {
                        if err.is_http_streaming_unsupported() {
                            warn!(
                                "Streaming unsupported for {dist}; downloading wheel to disk ({err})"
                            );
                        } else if err.is_http_streaming_failed() {
                            warn!("Streaming failed for {dist}; downloading wheel to disk ({err})");
                        } else {
                            return Err(Error::Extract(name, err));
                        }

                        // If the request failed because streaming was unsupported or failed,
                        // download the wheel directly.
                        let archive = self
                            .download_wheel(
                                url,
                                dist.index(),
                                &wheel.filename,
                                extension,
                                size,
                                &wheel_entry,
                                dist,
                                hashes,
                            )
                            .await?;

                        Ok(LocalWheel {
                            dist: Dist::Built(dist.clone()),
                            archive: self
                                .build_context
                                .cache()
                                .archive(&archive.id)
                                .into_boxed_path(),
                            hashes: archive.hashes,
                            filename: wheel.filename.clone(),
                            cache: CacheInfo::default(),
                            build: None,
                        })
                    }
                    Err(err) => Err(err),
                }
            }

            // A wheel referenced by a direct URL (no index involved).
            BuiltDist::DirectUrl(wheel) => {
                // Create a cache entry for the wheel.
                let wheel_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                // Download and unzip.
                match self
                    .stream_wheel(
                        wheel.url.raw().clone(),
                        None,
                        &wheel.filename,
                        WheelExtension::Whl,
                        None,
                        &wheel_entry,
                        dist,
                        hashes,
                    )
                    .await
                {
                    Ok(archive) => Ok(LocalWheel {
                        dist: Dist::Built(dist.clone()),
                        archive: self
                            .build_context
                            .cache()
                            .archive(&archive.id)
                            .into_boxed_path(),
                        hashes: archive.hashes,
                        filename: wheel.filename.clone(),
                        cache: CacheInfo::default(),
                        build: None,
                    }),
                    // Same streaming-fallback policy as the registry arm above.
                    Err(Error::Extract(name, err)) => {
                        if err.is_http_streaming_unsupported() {
                            warn!(
                                "Streaming unsupported for {dist}; downloading wheel to disk ({err})"
                            );
                        } else if err.is_http_streaming_failed() {
                            warn!("Streaming failed for {dist}; downloading wheel to disk ({err})");
                        } else {
                            return Err(Error::Extract(name, err));
                        }

                        // If the request failed because streaming was unsupported or failed,
                        // download the wheel directly.
                        let archive = self
                            .download_wheel(
                                wheel.url.raw().clone(),
                                None,
                                &wheel.filename,
                                WheelExtension::Whl,
                                None,
                                &wheel_entry,
                                dist,
                                hashes,
                            )
                            .await?;
                        Ok(LocalWheel {
                            dist: Dist::Built(dist.clone()),
                            archive: self
                                .build_context
                                .cache()
                                .archive(&archive.id)
                                .into_boxed_path(),
                            hashes: archive.hashes,
                            filename: wheel.filename.clone(),
                            cache: CacheInfo::default(),
                            build: None,
                        })
                    }
                    Err(err) => Err(err),
                }
            }

            // A wheel on the local filesystem: no download, load directly from disk.
            BuiltDist::Path(wheel) => {
                let cache_entry = self.build_context.cache().entry(
                    CacheBucket::Wheels,
                    WheelCache::Url(&wheel.url).wheel_dir(wheel.name().as_ref()),
                    wheel.filename.cache_key(),
                );

                self.load_wheel(
                    &wheel.install_path,
                    &wheel.filename,
                    WheelExtension::Whl,
                    cache_entry,
                    dist,
                    hashes,
                )
                .await
            }
        }
    }
381
    /// Convert a source distribution into a wheel, fetching it from the cache or building it if
    /// necessary.
    ///
    /// The returned wheel is guaranteed to come from a distribution with a matching hash, and
    /// no build processes will be executed for distributions with mismatched hashes.
    async fn build_wheel(
        &self,
        dist: &SourceDist,
        tags: &Tags,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        // Warn if the source distribution isn't PEP 625 compliant.
        // We do this here instead of in `SourceDistExtension::from_path` to minimize log volume:
        // a non-compliant distribution isn't a huge problem if it's not actually being
        // materialized into a wheel. Observe that we also allow no extension, since we expect that
        // for directory and Git installs.
        // NOTE: Observe that we also allow `.zip` sdists here, which are not PEP 625 compliant.
        // This is because they were allowed on PyPI until relatively recently (2020).
        if let Some(extension) = dist.extension()
            && !matches!(
                extension,
                SourceDistExtension::TarGz | SourceDistExtension::Zip
            )
        {
            if matches!(dist, SourceDist::Registry(_)) {
                // Observe that we display a slightly different warning when the sdist comes
                // from a registry, since that suggests that the user has inadvertently
                // (rather than explicitly) depended on a non-compliant sdist.
                warn_user_once!(
                    "{dist} uses a legacy source distribution format ('.{extension}') that is not compliant with PEP 625. A future version of uv will reject this source distribution. Consider upgrading to a newer version of {package}",
                    package = dist.name(),
                );
            } else {
                warn_user_once!(
                    "{dist} is not a standards-compliant source distribution: expected '.tar.gz' but found '.{extension}'. A future version of uv will reject source distributions that do not meet the requirements specified in PEP 625",
                );
            }
        }

        // Download (if necessary) and build the source distribution into a wheel.
        let built_wheel = self
            .builder
            .download_and_build(&BuildableSource::Dist(dist), tags, hashes, &self.client)
            .boxed_local()
            .await?;

        // Check that the wheel is compatible with its install target.
        //
        // When building a build dependency for a cross-install, the build dependency needs
        // to install and run on the host instead of the target. In this case the `tags` are already
        // for the host instead of the target, so this check passes.
        if !built_wheel.filename.is_compatible(tags) {
            return if tags.is_cross() {
                Err(Error::BuiltWheelIncompatibleTargetPlatform {
                    filename: built_wheel.filename,
                    python_platform: tags.python_platform().clone(),
                    python_version: PythonVersion {
                        version: tags.python_version(),
                        variant: if tags.is_freethreaded() {
                            PythonVariant::Freethreaded
                        } else {
                            PythonVariant::Default
                        },
                    },
                })
            } else {
                Err(Error::BuiltWheelIncompatibleHostPlatform {
                    filename: built_wheel.filename,
                    python_platform: tags.python_platform().clone(),
                    python_version: PythonVersion {
                        version: tags.python_version(),
                        variant: if tags.is_freethreaded() {
                            PythonVariant::Freethreaded
                        } else {
                            PythonVariant::Default
                        },
                    },
                })
            };
        }

        // Acquire the advisory lock.
        // NOTE(review): Windows-only; presumably other platforms rely on a different guard
        // against concurrent writes to the same cache entry — confirm.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = CacheEntry::new(
                built_wheel.target.parent().unwrap(),
                format!(
                    "{}.lock",
                    built_wheel.target.file_name().unwrap().to_str().unwrap()
                ),
            );
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // If the wheel was unzipped previously, respect it. Source distributions are
        // cached under a unique revision ID, so unzipped directories are never stale.
        match self.build_context.cache().resolve_link(&built_wheel.target) {
            Ok(archive) => {
                return Ok(LocalWheel {
                    dist: Dist::Source(dist.clone()),
                    archive: archive.into_boxed_path(),
                    filename: built_wheel.filename,
                    hashes: built_wheel.hashes,
                    cache: built_wheel.cache_info,
                    build: Some(built_wheel.build_info),
                });
            }
            // A missing link just means the wheel hasn't been unzipped yet; fall through.
            Err(err) if err.kind() == io::ErrorKind::NotFound => {}
            Err(err) => return Err(Error::CacheRead(err)),
        }

        // Otherwise, unzip the wheel.
        let id = self
            .unzip_wheel(
                &built_wheel.path,
                &built_wheel.target,
                DistRef::Source(dist),
            )
            .await?;

        Ok(LocalWheel {
            dist: Dist::Source(dist.clone()),
            archive: self.build_context.cache().archive(&id).into_boxed_path(),
            hashes: built_wheel.hashes,
            filename: built_wheel.filename,
            cache: built_wheel.cache_info,
            build: Some(built_wheel.build_info),
        })
    }
510
    /// Fetch the wheel metadata from the index, or from the cache if possible.
    ///
    /// While hashes will be generated in some cases, hash-checking is _not_ enforced and should
    /// instead be enforced by the caller.
    async fn get_wheel_metadata(
        &self,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<ArchiveMetadata, Error> {
        // If hash generation is enabled, and the distribution isn't hosted on a registry, get the
        // entire wheel to ensure that the hashes are included in the response. If the distribution
        // is hosted on an index, the hashes will be included in the simple metadata response.
        // For hash _validation_, callers are expected to enforce the policy when retrieving the
        // wheel.
        //
        // Historically, for `uv pip compile --universal`, we also generate hashes for
        // registry-based distributions when the relevant registry doesn't provide them. This was
        // motivated by `--find-links`. We continue that behavior (under `HashGeneration::All`) for
        // backwards compatibility, but it's a little dubious, since we're only hashing _one_
        // distribution here (as opposed to hashing all distributions for the version), and it may
        // not even be a compatible distribution!
        //
        // TODO(charlie): Request the hashes via a separate method, to reduce the coupling in this API.
        if hashes.is_generate(dist) {
            let wheel = self.get_wheel(dist, hashes).await?;
            // If the metadata was provided by the user directly, prefer it.
            let metadata = if let Some(metadata) = self
                .build_context
                .dependency_metadata()
                .get(dist.name(), Some(dist.version()))
            {
                metadata.clone()
            } else {
                wheel.metadata()?
            };
            // Use the hashes computed while fetching the wheel.
            let hashes = wheel.hashes;
            return Ok(ArchiveMetadata {
                metadata: Metadata::from_metadata23(metadata),
                hashes,
            });
        }

        // If the metadata was provided by the user directly, prefer it.
        if let Some(metadata) = self
            .build_context
            .dependency_metadata()
            .get(dist.name(), Some(dist.version()))
        {
            return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
        }

        // Otherwise, fetch just the metadata from the client rather than the full wheel.
        let result = self
            .client
            .managed(|client| {
                client
                    .wheel_metadata(dist, self.build_context.capabilities())
                    .boxed_local()
            })
            .await;

        match result {
            Ok(metadata) => {
                // Validate that the metadata is consistent with the distribution.
                Ok(ArchiveMetadata::from_metadata23(metadata))
            }
            Err(err) if err.is_http_streaming_unsupported() => {
                warn!(
                    "Streaming unsupported when fetching metadata for {dist}; downloading wheel directly ({err})"
                );

                // If the request failed due to an error that could be resolved by
                // downloading the wheel directly, try that.
                let wheel = self.get_wheel(dist, hashes).await?;
                let metadata = wheel.metadata()?;
                let hashes = wheel.hashes;
                Ok(ArchiveMetadata {
                    metadata: Metadata::from_metadata23(metadata),
                    hashes,
                })
            }
            Err(err) => Err(err.into()),
        }
    }
594
595    /// Build the wheel metadata for a source distribution, or fetch it from the cache if possible.
596    ///
597    /// The returned metadata is guaranteed to come from a distribution with a matching hash, and
598    /// no build processes will be executed for distributions with mismatched hashes.
599    pub async fn build_wheel_metadata(
600        &self,
601        source: &BuildableSource<'_>,
602        hashes: HashPolicy<'_>,
603    ) -> Result<ArchiveMetadata, Error> {
604        // If the metadata was provided by the user directly, prefer it.
605        if let Some(dist) = source.as_dist() {
606            if let Some(metadata) = self
607                .build_context
608                .dependency_metadata()
609                .get(dist.name(), dist.version())
610            {
611                // If we skipped the build, we should still resolve any Git dependencies to precise
612                // commits.
613                self.builder.resolve_revision(source, &self.client).await?;
614
615                return Ok(ArchiveMetadata::from_metadata23(metadata.clone()));
616            }
617        }
618
619        let metadata = self
620            .builder
621            .download_and_build_metadata(source, hashes, &self.client)
622            .boxed_local()
623            .await?;
624
625        Ok(metadata)
626    }
627
628    /// Return the [`RequiresDist`] from a `pyproject.toml`, if it can be statically extracted.
629    pub async fn requires_dist(
630        &self,
631        path: &Path,
632        pyproject_toml: &PyProjectToml,
633    ) -> Result<Option<RequiresDist>, Error> {
634        self.builder
635            .source_tree_requires_dist(
636                path,
637                pyproject_toml,
638                self.client.unmanaged.credentials_cache(),
639            )
640            .await
641    }
642
643    /// Stream a wheel from a URL, unzipping it into the cache as it's downloaded.
644    async fn stream_wheel(
645        &self,
646        url: DisplaySafeUrl,
647        index: Option<&IndexUrl>,
648        filename: &WheelFilename,
649        extension: WheelExtension,
650        size: Option<u64>,
651        wheel_entry: &CacheEntry,
652        dist: &BuiltDist,
653        hashes: HashPolicy<'_>,
654    ) -> Result<Archive, Error> {
655        // Acquire an advisory lock, to guard against concurrent writes.
656        #[cfg(windows)]
657        let _lock = {
658            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
659            lock_entry.lock().await.map_err(Error::CacheLock)?
660        };
661
662        // Create an entry for the HTTP cache.
663        let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));
664
665        let query_url = &url.clone();
666
667        let download = |response: reqwest::Response| {
668            async {
669                let size = size.or_else(|| content_length(&response));
670
671                let progress = self
672                    .reporter
673                    .as_ref()
674                    .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));
675
676                let reader = response
677                    .bytes_stream()
678                    .map_err(|err| self.handle_response_errors(err))
679                    .into_async_read();
680
681                // Create a hasher for each hash algorithm.
682                let algorithms = hashes.algorithms();
683                let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
684                let mut hasher = uv_extract::hash::HashReader::new(reader.compat(), &mut hashers);
685
686                // Download and unzip the wheel to a temporary directory.
687                let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
688                    .map_err(Error::CacheWrite)?;
689
690                let files = match progress {
691                    Some((reporter, progress)) => {
692                        let mut reader = ProgressReader::new(&mut hasher, progress, &**reporter);
693                        match extension {
694                            WheelExtension::Whl => {
695                                uv_extract::stream::unzip(query_url, &mut reader, temp_dir.path())
696                                    .await
697                                    .map_err(|err| Error::Extract(filename.to_string(), err))?
698                            }
699                            WheelExtension::WhlZst => {
700                                uv_extract::stream::untar_zst(&mut reader, temp_dir.path())
701                                    .await
702                                    .map_err(|err| Error::Extract(filename.to_string(), err))?
703                            }
704                        }
705                    }
706                    None => match extension {
707                        WheelExtension::Whl => {
708                            uv_extract::stream::unzip(query_url, &mut hasher, temp_dir.path())
709                                .await
710                                .map_err(|err| Error::Extract(filename.to_string(), err))?
711                        }
712                        WheelExtension::WhlZst => {
713                            uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
714                                .await
715                                .map_err(|err| Error::Extract(filename.to_string(), err))?
716                        }
717                    },
718                };
719                // If necessary, exhaust the reader to compute the hash.
720                if !hashes.is_none() {
721                    hasher.finish().await.map_err(Error::HashExhaustion)?;
722                }
723
724                // Before we make the wheel accessible by persisting it, ensure that the RECORD is
725                // valid.
726                validate_and_heal_record(temp_dir.path(), files.iter(), dist)
727                    .map_err(Error::InstallWheelError)?;
728
729                // Persist the temporary directory to the directory store.
730                let id = self
731                    .build_context
732                    .cache()
733                    .persist(temp_dir.keep(), wheel_entry.path())
734                    .await
735                    .map_err(Error::CacheRead)?;
736
737                if let Some((reporter, progress)) = progress {
738                    reporter.on_download_complete(dist.name(), progress);
739                }
740
741                Ok(Archive::new(
742                    id,
743                    hashers.into_iter().map(HashDigest::from).collect(),
744                    filename.clone(),
745                ))
746            }
747            .instrument(info_span!("wheel", wheel = %dist))
748        };
749
750        // Fetch the archive from the cache, or download it if necessary.
751        let req = self.request(url.clone())?;
752
753        // Determine the cache control policy for the URL.
754        let cache_control = match self.client.unmanaged.connectivity() {
755            Connectivity::Online => {
756                if let Some(header) = index.and_then(|index| {
757                    self.build_context
758                        .locations()
759                        .artifact_cache_control_for(index)
760                }) {
761                    CacheControl::Override(header)
762                } else {
763                    CacheControl::from(
764                        self.build_context
765                            .cache()
766                            .freshness(&http_entry, Some(&filename.name), None)
767                            .map_err(Error::CacheRead)?,
768                    )
769                }
770            }
771            Connectivity::Offline => CacheControl::AllowStale,
772        };
773
774        let archive = self
775            .client
776            .managed(|client| {
777                client.cached_client().get_serde_with_retry(
778                    req,
779                    &http_entry,
780                    cache_control.clone(),
781                    download,
782                )
783            })
784            .await
785            .map_err(|err| match err {
786                CachedClientError::Callback { err, .. } => err,
787                CachedClientError::Client(err) => Error::Client(err),
788            })?;
789
790        // If the archive is missing the required hashes, or has since been removed, force a refresh.
791        let archive = Some(archive)
792            .filter(|archive| archive.has_digests(hashes))
793            .filter(|archive| archive.exists(self.build_context.cache()));
794
795        let archive = if let Some(archive) = archive {
796            archive
797        } else {
798            self.client
799                .managed(async |client| {
800                    client
801                        .cached_client()
802                        .skip_cache_with_retry(
803                            self.request(url)?,
804                            &http_entry,
805                            cache_control,
806                            download,
807                        )
808                        .await
809                        .map_err(|err| match err {
810                            CachedClientError::Callback { err, .. } => err,
811                            CachedClientError::Client(err) => Error::Client(err),
812                        })
813                })
814                .await?
815        };
816
817        Ok(archive)
818    }
819
    /// Download a wheel from a URL, then unzip it into the cache.
    ///
    /// The remote file is first staged in a temporary file, then extracted into a temporary
    /// directory (computing digests along the way when `hashes` requires them), validated, and
    /// finally persisted into the cache's archive store.
    ///
    /// # Arguments
    ///
    /// * `url` - The URL from which to download the wheel.
    /// * `index` - The index the wheel came from, if any; consulted for per-index
    ///   cache-control overrides.
    /// * `filename` - The parsed wheel filename; used to derive cache keys and error messages.
    /// * `extension` - Whether the remote file is a plain `.whl` or a `.whl.tar.zst`.
    /// * `size` - The expected download size, if known; used for progress reporting. Falls
    ///   back to the response's `Content-Length` header.
    /// * `wheel_entry` - The cache entry under which the unpacked wheel is stored.
    /// * `dist` - The distribution being downloaded; used for progress reporting and `RECORD`
    ///   validation.
    /// * `hashes` - The hash policy; when non-empty, digests are computed during extraction.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] on client failures, cache I/O failures, extraction failures, or an
    /// invalid `RECORD` file.
    async fn download_wheel(
        &self,
        url: DisplaySafeUrl,
        index: Option<&IndexUrl>,
        filename: &WheelFilename,
        extension: WheelExtension,
        size: Option<u64>,
        wheel_entry: &CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<Archive, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // Create an entry for the HTTP cache.
        let http_entry = wheel_entry.with_file(format!("{}.http", filename.cache_key()));

        // Clone the URL for use inside the `download` closure, since `url` itself is moved
        // into `self.request(...)` below.
        let query_url = &url.clone();

        // Callback invoked by the cached client whenever a (re-)download is actually required.
        let download = |response: reqwest::Response| {
            async {
                // Prefer the caller-provided size; fall back to the `Content-Length` header.
                let size = size.or_else(|| content_length(&response));

                // Start progress reporting, if a reporter is configured. The pair couples the
                // reporter with the download identifier it handed back.
                let progress = self
                    .reporter
                    .as_ref()
                    .map(|reporter| (reporter, reporter.on_download_start(dist.name(), size)));

                let reader = response
                    .bytes_stream()
                    .map_err(|err| self.handle_response_errors(err))
                    .into_async_read();

                // Download the wheel to a temporary file.
                let temp_file = tempfile::tempfile_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;
                let mut writer = tokio::io::BufWriter::new(fs_err::tokio::File::from_std(
                    // It's an unnamed file on Linux so that's the best approximation.
                    fs_err::File::from_parts(temp_file, self.build_context.cache().root()),
                ));

                match progress {
                    Some((reporter, progress)) => {
                        // Wrap the reader in a progress reporter. This will report 100% progress
                        // after the download is complete, even if we still have to unzip and hash
                        // part of the file.
                        let mut reader =
                            ProgressReader::new(reader.compat(), progress, &**reporter);

                        tokio::io::copy(&mut reader, &mut writer)
                            .await
                            .map_err(Error::CacheWrite)?;
                    }
                    None => {
                        tokio::io::copy(&mut reader.compat(), &mut writer)
                            .await
                            .map_err(Error::CacheWrite)?;
                    }
                }

                // Unzip the wheel to a temporary directory.
                let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                    .map_err(Error::CacheWrite)?;
                // Rewind the temporary file so extraction reads it from the start.
                let mut file = writer.into_inner();
                file.seek(io::SeekFrom::Start(0))
                    .await
                    .map_err(Error::CacheWrite)?;

                // If no hashes are required, extract the wheel without hashing.
                let (files, hashes) = if hashes.is_none() {
                    let target = temp_dir.path().to_owned();
                    let files = match extension {
                        WheelExtension::Whl => {
                            // Plain `.whl`: unzip synchronously on a blocking thread.
                            let file = file.into_std().await;
                            tokio::task::spawn_blocking(move || uv_extract::unzip(file, &target))
                                .await?
                        }
                        WheelExtension::WhlZst => {
                            uv_extract::stream::untar_zst(file, &target).await
                        }
                    }
                    .map_err(|err| Error::Extract(filename.to_string(), err))?;

                    (files, HashDigests::empty())
                } else {
                    // Create a hasher for each hash algorithm.
                    let algorithms = hashes.algorithms();
                    let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
                    let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);

                    // Extract through the hashing reader, so the digests cover the full file.
                    let files = match extension {
                        WheelExtension::Whl => {
                            uv_extract::stream::unzip(query_url, &mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?
                        }
                        WheelExtension::WhlZst => {
                            uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                                .await
                                .map_err(|err| Error::Extract(filename.to_string(), err))?
                        }
                    };

                    // If necessary, exhaust the reader to compute the hash.
                    hasher.finish().await.map_err(Error::HashExhaustion)?;
                    let hashes = hashers.into_iter().map(HashDigest::from).collect();

                    (files, hashes)
                };

                // Before we make the wheel accessible by persisting it, ensure that the RECORD is
                // valid.
                validate_and_heal_record(temp_dir.path(), files.iter(), dist)
                    .map_err(Error::InstallWheelError)?;

                // Persist the temporary directory to the directory store.
                // NOTE(review): persist failures surface as `CacheRead` here, whereas
                // `load_wheel` maps the same call to `CacheWrite` — confirm which is intended.
                let id = self
                    .build_context
                    .cache()
                    .persist(temp_dir.keep(), wheel_entry.path())
                    .await
                    .map_err(Error::CacheRead)?;

                if let Some((reporter, progress)) = progress {
                    reporter.on_download_complete(dist.name(), progress);
                }

                Ok(Archive::new(id, hashes, filename.clone()))
            }
            .instrument(info_span!("wheel", wheel = %dist))
        };

        // Fetch the archive from the cache, or download it if necessary.
        let req = self.request(url.clone())?;

        // Determine the cache control policy for the URL. A per-index cache-control override
        // takes precedence over the cache's freshness check.
        let cache_control = match self.client.unmanaged.connectivity() {
            Connectivity::Online => {
                if let Some(header) = index.and_then(|index| {
                    self.build_context
                        .locations()
                        .artifact_cache_control_for(index)
                }) {
                    CacheControl::Override(header)
                } else {
                    CacheControl::from(
                        self.build_context
                            .cache()
                            .freshness(&http_entry, Some(&filename.name), None)
                            .map_err(Error::CacheRead)?,
                    )
                }
            }
            Connectivity::Offline => CacheControl::AllowStale,
        };

        let archive = self
            .client
            .managed(|client| {
                client.cached_client().get_serde_with_retry(
                    req,
                    &http_entry,
                    cache_control.clone(),
                    download,
                )
            })
            .await
            .map_err(|err| match err {
                CachedClientError::Callback { err, .. } => err,
                CachedClientError::Client(err) => Error::Client(err),
            })?;

        // If the archive is missing the required hashes, or has since been removed, force a refresh.
        let archive = Some(archive)
            .filter(|archive| archive.has_digests(hashes))
            .filter(|archive| archive.exists(self.build_context.cache()));

        let archive = if let Some(archive) = archive {
            archive
        } else {
            // The cached archive was unusable: bypass the HTTP cache and re-download.
            self.client
                .managed(async |client| {
                    client
                        .cached_client()
                        .skip_cache_with_retry(
                            self.request(url)?,
                            &http_entry,
                            cache_control,
                            download,
                        )
                        .await
                        .map_err(|err| match err {
                            CachedClientError::Callback { err, .. } => err,
                            CachedClientError::Client(err) => Error::Client(err),
                        })
                })
                .await?
        };

        Ok(archive)
    }
1026
    /// Load a wheel from a local path.
    ///
    /// If an up-to-date unpacked archive (carrying any required digests) is already present in
    /// the cache, it's reused. Otherwise, the wheel is unzipped (hashing it if `hashes`
    /// requires digests) and a pointer to the new archive is written alongside the file's
    /// last-modified timestamp.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] on cache I/O failures, extraction failures, or an invalid
    /// `RECORD` file.
    async fn load_wheel(
        &self,
        path: &Path,
        filename: &WheelFilename,
        extension: WheelExtension,
        wheel_entry: CacheEntry,
        dist: &BuiltDist,
        hashes: HashPolicy<'_>,
    ) -> Result<LocalWheel, Error> {
        // Acquire an advisory lock, to guard against concurrent writes.
        #[cfg(windows)]
        let _lock = {
            let lock_entry = wheel_entry.with_file(format!("{}.lock", filename.stem()));
            lock_entry.lock().await.map_err(Error::CacheLock)?
        };

        // Determine the last-modified time of the wheel.
        let modified = Timestamp::from_path(path).map_err(Error::CacheRead)?;

        // Attempt to read the archive pointer from the cache.
        let pointer_entry = wheel_entry.with_file(format!("{}.rev", filename.cache_key()));
        let pointer = LocalArchivePointer::read_from(&pointer_entry)?;

        // Extract the archive from the pointer. The pointer is only usable if it matches the
        // file's current timestamp and carries all required digests.
        let archive = pointer
            .filter(|pointer| pointer.is_up_to_date(modified))
            .map(LocalArchivePointer::into_archive)
            .filter(|archive| archive.has_digests(hashes));

        // If the file is already unzipped, and the cache is up-to-date, return it.
        if let Some(archive) = archive {
            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        } else if hashes.is_none() {
            // Otherwise, unzip the wheel. No digests are required, so skip hashing entirely.
            let archive = Archive::new(
                self.unzip_wheel(path, wheel_entry.path(), DistRef::Built(dist))
                    .await?,
                HashDigests::empty(),
                filename.clone(),
            );

            // Write the archive pointer to the cache.
            let pointer = LocalArchivePointer {
                timestamp: modified,
                archive: archive.clone(),
            };
            pointer.write_to(&pointer_entry).await?;

            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        } else {
            // If necessary, compute the hashes of the wheel.
            let file = fs_err::tokio::File::open(path)
                .await
                .map_err(Error::CacheRead)?;
            let temp_dir = tempfile::tempdir_in(self.build_context.cache().root())
                .map_err(Error::CacheWrite)?;

            // Create a hasher for each hash algorithm.
            let algorithms = hashes.algorithms();
            let mut hashers = algorithms.into_iter().map(Hasher::from).collect::<Vec<_>>();
            let mut hasher = uv_extract::hash::HashReader::new(file, &mut hashers);

            // Unzip the wheel to a temporary directory, hashing the bytes as they're read.
            let files = match extension {
                WheelExtension::Whl => {
                    uv_extract::stream::unzip(path.display(), &mut hasher, temp_dir.path())
                        .await
                        .map_err(|err| Error::Extract(filename.to_string(), err))?
                }
                WheelExtension::WhlZst => {
                    uv_extract::stream::untar_zst(&mut hasher, temp_dir.path())
                        .await
                        .map_err(|err| Error::Extract(filename.to_string(), err))?
                }
            };

            // Exhaust the reader to compute the hash.
            hasher.finish().await.map_err(Error::HashExhaustion)?;

            let hashes = hashers.into_iter().map(HashDigest::from).collect();

            // Before we make the wheel accessible by persisting it, ensure that the RECORD is
            // valid.
            validate_and_heal_record(temp_dir.path(), files.iter(), dist)
                .map_err(Error::InstallWheelError)?;

            // Persist the temporary directory to the directory store.
            let id = self
                .build_context
                .cache()
                .persist(temp_dir.keep(), wheel_entry.path())
                .await
                .map_err(Error::CacheWrite)?;

            // Create an archive.
            let archive = Archive::new(id, hashes, filename.clone());

            // Write the archive pointer to the cache.
            let pointer = LocalArchivePointer {
                timestamp: modified,
                archive: archive.clone(),
            };
            pointer.write_to(&pointer_entry).await?;

            Ok(LocalWheel {
                dist: Dist::Built(dist.clone()),
                archive: self
                    .build_context
                    .cache()
                    .archive(&archive.id)
                    .into_boxed_path(),
                hashes: archive.hashes,
                filename: filename.clone(),
                cache: CacheInfo::from_timestamp(modified),
                build: None,
            })
        }
    }
1167
1168    /// Unzip a wheel into the cache, returning the path to the unzipped directory.
1169    async fn unzip_wheel(
1170        &self,
1171        path: &Path,
1172        target: &Path,
1173        dist: DistRef<'_>,
1174    ) -> Result<ArchiveId, Error> {
1175        let (temp_dir, files) = tokio::task::spawn_blocking({
1176            let path = path.to_owned();
1177            let root = self.build_context.cache().root().to_path_buf();
1178            move || -> Result<_, Error> {
1179                // Unzip the wheel into a temporary directory.
1180                let temp_dir = tempfile::tempdir_in(root).map_err(Error::CacheWrite)?;
1181                let reader = fs_err::File::open(&path).map_err(Error::CacheWrite)?;
1182                let files = uv_extract::unzip(reader, temp_dir.path())
1183                    .map_err(|err| Error::Extract(path.to_string_lossy().into_owned(), err))?;
1184                Ok((temp_dir, files))
1185            }
1186        })
1187        .await??;
1188
1189        // Before we make the wheel accessible by persisting it, ensure that the RECORD is valid.
1190        validate_and_heal_record(temp_dir.path(), files.iter(), dist)
1191            .map_err(Error::InstallWheelError)?;
1192
1193        // Persist the temporary directory to the directory store.
1194        let id = self
1195            .build_context
1196            .cache()
1197            .persist(temp_dir.keep(), target)
1198            .await
1199            .map_err(Error::CacheWrite)?;
1200
1201        Ok(id)
1202    }
1203
1204    /// Returns a GET [`reqwest::Request`] for the given URL.
1205    fn request(&self, url: DisplaySafeUrl) -> Result<reqwest::Request, reqwest::Error> {
1206        self.client
1207            .unmanaged
1208            .uncached_client(&url)
1209            .get(Url::from(url))
1210            .header(
1211                // `reqwest` defaults to accepting compressed responses.
1212                // Specify identity encoding to get consistent .whl downloading
1213                // behavior from servers. ref: https://github.com/pypa/pip/pull/1688
1214                "accept-encoding",
1215                reqwest::header::HeaderValue::from_static("identity"),
1216            )
1217            .build()
1218    }
1219
    /// Return the [`ManagedClient`] used by this resolver.
    ///
    /// The returned client enforces the resolver's request-concurrency limit.
    pub fn client(&self) -> &ManagedClient<'a> {
        &self.client
    }
1224}
1225
/// A wrapper around `RegistryClient` that manages a concurrency limit.
pub struct ManagedClient<'a> {
    /// The underlying client, with no concurrency limiting applied.
    pub unmanaged: &'a RegistryClient,
    /// The semaphore bounding the number of concurrent requests.
    control: Arc<Semaphore>,
}
1231
1232impl<'a> ManagedClient<'a> {
1233    /// Create a new `ManagedClient` using the given client and concurrency semaphore.
1234    fn new(client: &'a RegistryClient, control: Arc<Semaphore>) -> Self {
1235        ManagedClient {
1236            unmanaged: client,
1237            control,
1238        }
1239    }
1240
1241    /// Perform a request using the client, respecting the concurrency limit.
1242    ///
1243    /// If the concurrency limit has been reached, this method will wait until a pending
1244    /// operation completes before executing the closure.
1245    pub async fn managed<F, T>(&self, f: impl FnOnce(&'a RegistryClient) -> F) -> T
1246    where
1247        F: Future<Output = T>,
1248    {
1249        let _permit = self.control.acquire().await.unwrap();
1250        f(self.unmanaged).await
1251    }
1252
1253    /// Perform a request using a client that internally manages the concurrency limit.
1254    ///
1255    /// The callback is passed the client and a semaphore. It must acquire the semaphore before
1256    /// any request through the client and drop it after.
1257    ///
1258    /// This method serves as an escape hatch for functions that may want to send multiple requests
1259    /// in parallel.
1260    pub async fn manual<F, T>(&'a self, f: impl FnOnce(&'a RegistryClient, &'a Semaphore) -> F) -> T
1261    where
1262        F: Future<Output = T>,
1263    {
1264        f(self.unmanaged, &self.control).await
1265    }
1266}
1267
1268/// Returns the value of the `Content-Length` header from the [`reqwest::Response`], if present.
1269fn content_length(response: &reqwest::Response) -> Option<u64> {
1270    response
1271        .headers()
1272        .get(reqwest::header::CONTENT_LENGTH)
1273        .and_then(|val| val.to_str().ok())
1274        .and_then(|val| val.parse::<u64>().ok())
1275}
1276
/// An asynchronous reader that reports progress as bytes are read.
struct ProgressReader<'a, R> {
    /// The underlying reader.
    reader: R,
    /// The download identifier, as handed out by `Reporter::on_download_start`.
    index: usize,
    /// The reporter receiving the progress updates.
    reporter: &'a dyn Reporter,
}
1283
1284impl<'a, R> ProgressReader<'a, R> {
1285    /// Create a new [`ProgressReader`] that wraps another reader.
1286    fn new(reader: R, index: usize, reporter: &'a dyn Reporter) -> Self {
1287        Self {
1288            reader,
1289            index,
1290            reporter,
1291        }
1292    }
1293}
1294
1295impl<R> AsyncRead for ProgressReader<'_, R>
1296where
1297    R: AsyncRead + Unpin,
1298{
1299    fn poll_read(
1300        mut self: Pin<&mut Self>,
1301        cx: &mut Context<'_>,
1302        buf: &mut ReadBuf<'_>,
1303    ) -> Poll<io::Result<()>> {
1304        Pin::new(&mut self.as_mut().reader)
1305            .poll_read(cx, buf)
1306            .map_ok(|()| {
1307                self.reporter
1308                    .on_download_progress(self.index, buf.filled().len() as u64);
1309            })
1310    }
1311}
1312
/// A pointer to an archive in the cache, fetched from an HTTP archive.
///
/// Encoded with `MsgPack`, and represented on disk by a `.http` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HttpArchivePointer {
    /// The cached archive the pointer refers to.
    archive: Archive,
}
1320
1321impl HttpArchivePointer {
1322    /// Read an [`HttpArchivePointer`] from the cache.
1323    pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
1324        match fs_err::File::open(path.as_ref()) {
1325            Ok(file) => {
1326                let data = DataWithCachePolicy::from_reader(file)?.data;
1327                let archive = rmp_serde::from_slice::<Archive>(&data)?;
1328                Ok(Some(Self { archive }))
1329            }
1330            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
1331            Err(err) => Err(Error::CacheRead(err)),
1332        }
1333    }
1334
1335    /// Return the [`Archive`] from the pointer.
1336    pub fn into_archive(self) -> Archive {
1337        self.archive
1338    }
1339
1340    /// Return the [`CacheInfo`] from the pointer.
1341    pub fn to_cache_info(&self) -> CacheInfo {
1342        CacheInfo::default()
1343    }
1344
1345    /// Return the [`BuildInfo`] from the pointer.
1346    pub fn to_build_info(&self) -> Option<BuildInfo> {
1347        None
1348    }
1349}
1350
/// A pointer to an archive in the cache, fetched from a local path.
///
/// Encoded with `MsgPack`, and represented on disk by a `.rev` file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LocalArchivePointer {
    /// The last-modified time of the local file at the moment the archive was created.
    timestamp: Timestamp,
    /// The cached archive the pointer refers to.
    archive: Archive,
}
1359
1360impl LocalArchivePointer {
1361    /// Read an [`LocalArchivePointer`] from the cache.
1362    pub fn read_from(path: impl AsRef<Path>) -> Result<Option<Self>, Error> {
1363        match fs_err::read(path) {
1364            Ok(cached) => Ok(Some(rmp_serde::from_slice::<Self>(&cached)?)),
1365            Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
1366            Err(err) => Err(Error::CacheRead(err)),
1367        }
1368    }
1369
1370    /// Write an [`LocalArchivePointer`] to the cache.
1371    pub async fn write_to(&self, entry: &CacheEntry) -> Result<(), Error> {
1372        write_atomic(entry.path(), rmp_serde::to_vec(&self)?)
1373            .await
1374            .map_err(Error::CacheWrite)
1375    }
1376
1377    /// Returns `true` if the archive is up-to-date with the given modified timestamp.
1378    pub fn is_up_to_date(&self, modified: Timestamp) -> bool {
1379        self.timestamp == modified
1380    }
1381
1382    /// Return the [`Archive`] from the pointer.
1383    pub fn into_archive(self) -> Archive {
1384        self.archive
1385    }
1386
1387    /// Return the [`CacheInfo`] from the pointer.
1388    pub fn to_cache_info(&self) -> CacheInfo {
1389        CacheInfo::from_timestamp(self.timestamp)
1390    }
1391
1392    /// Return the [`BuildInfo`] from the pointer.
1393    pub fn to_build_info(&self) -> Option<BuildInfo> {
1394        None
1395    }
1396}
1397
/// The concrete download target for a wheel: its URL, on-the-wire format, and expected size.
#[derive(Debug, Clone)]
struct WheelTarget {
    /// The URL from which the wheel can be downloaded.
    url: DisplaySafeUrl,
    /// The expected extension of the wheel file.
    extension: WheelExtension,
    /// The expected size of the wheel file, if known.
    size: Option<u64>,
}
1407
1408impl TryFrom<&File> for WheelTarget {
1409    type Error = ToUrlError;
1410
1411    /// Determine the [`WheelTarget`] from a [`File`].
1412    fn try_from(file: &File) -> Result<Self, Self::Error> {
1413        let url = file.url.to_url()?;
1414        if let Some(zstd) = file.zstd.as_ref() {
1415            Ok(Self {
1416                url: add_tar_zst_extension(url),
1417                extension: WheelExtension::WhlZst,
1418                size: zstd.size,
1419            })
1420        } else {
1421            Ok(Self {
1422                url,
1423                extension: WheelExtension::Whl,
1424                size: file.size,
1425            })
1426        }
1427    }
1428}
1429
/// The on-the-wire format of a wheel file.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum WheelExtension {
    /// A `.whl` file.
    Whl,
    /// A `.whl.tar.zst` file.
    WhlZst,
}
1437
1438/// Add `.tar.zst` to the end of the URL path, if it doesn't already exist.
1439#[must_use]
1440fn add_tar_zst_extension(mut url: DisplaySafeUrl) -> DisplaySafeUrl {
1441    let mut path = url.path().to_string();
1442
1443    if !path.ends_with(".tar.zst") {
1444        path.push_str(".tar.zst");
1445    }
1446
1447    url.set_path(&path);
1448    url
1449}
1450
#[cfg(test)]
mod tests {
    use super::*;

    /// `add_tar_zst_extension` should append the suffix exactly once, and leave
    /// percent-encoded path segments untouched.
    #[test]
    fn test_add_tar_zst_extension() {
        // A plain `.whl` URL gains the `.tar.zst` suffix.
        let url =
            DisplaySafeUrl::parse("https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl")
                .unwrap();
        assert_eq!(
            add_tar_zst_extension(url).as_str(),
            "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst"
        );

        // A URL that already ends in `.tar.zst` is returned unchanged.
        let url = DisplaySafeUrl::parse(
            "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst",
        )
        .unwrap();
        assert_eq!(
            add_tar_zst_extension(url).as_str(),
            "https://files.pythonhosted.org/flask-3.1.0-py3-none-any.whl.tar.zst"
        );

        // Percent-encoded characters (e.g. `%2B` for `+` in a local version) are preserved.
        let url = DisplaySafeUrl::parse(
            "https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl",
        )
        .unwrap();
        assert_eq!(
            add_tar_zst_extension(url).as_str(),
            "https://files.pythonhosted.org/flask-3.1.0%2Bcu124-py3-none-any.whl.tar.zst"
        );
    }
}