uv_auth/
pyx.rs

1use std::io;
2use std::path::{Path, PathBuf};
3use std::time::Duration;
4
5use base64::Engine;
6use base64::prelude::BASE64_URL_SAFE_NO_PAD;
7use etcetera::BaseStrategy;
8use reqwest_middleware::ClientWithMiddleware;
9use tracing::debug;
10use url::Url;
11use uv_fs::{LockedFile, LockedFileMode};
12
13use uv_cache_key::CanonicalUrl;
14use uv_redacted::{DisplaySafeUrl, DisplaySafeUrlError};
15use uv_small_str::SmallString;
16use uv_state::{StateBucket, StateStore};
17use uv_static::EnvVars;
18
19use crate::credentials::Token;
20use crate::{AccessToken, Credentials, Realm};
21
22/// Retrieve the pyx API key from the environment variable, or return `None`.
23fn read_pyx_api_key() -> Option<String> {
24    std::env::var(EnvVars::PYX_API_KEY)
25        .ok()
26        .or_else(|| std::env::var(EnvVars::UV_API_KEY).ok())
27}
28
29/// Retrieve the pyx authentication token (JWT) from the environment variable, or return `None`.
30fn read_pyx_auth_token() -> Option<AccessToken> {
31    std::env::var(EnvVars::PYX_AUTH_TOKEN)
32        .ok()
33        .or_else(|| std::env::var(EnvVars::UV_AUTH_TOKEN).ok())
34        .map(AccessToken::from)
35}
36
/// An access token with an accompanying refresh token.
///
/// Refresh tokens are single-use tokens that can be exchanged for a renewed access token
/// and a new refresh token.
///
/// Persisted as JSON in the store's `tokens.json` file.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct PyxOAuthTokens {
    /// The access token sent as a bearer credential.
    pub access_token: AccessToken,
    /// The single-use refresh token, exchanged at the API's `auth/cli/refresh` endpoint.
    pub refresh_token: String,
}
46
/// An access token with an accompanying API key.
///
/// Only the access token is written to disk, in a file named after a digest of the API key.
/// The API key itself is read from the environment.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct PyxApiKeyTokens {
    /// The access token sent as a bearer credential.
    pub access_token: AccessToken,
    /// The API key, exchanged at the API's `auth/cli/access-token` endpoint for access tokens.
    pub api_key: String,
}
53
/// The tokens held by a [`PyxTokenStore`], distinguished by how the access token is renewed.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum PyxTokens {
    /// An access token with an accompanying refresh token.
    ///
    /// Refresh tokens are single-use tokens that can be exchanged for a renewed access token
    /// and a new refresh token.
    OAuth(PyxOAuthTokens),
    /// An access token with an accompanying API key.
    ///
    /// API keys are long-lived tokens that can be exchanged for an access token.
    ApiKey(PyxApiKeyTokens),
}
66
67impl From<PyxTokens> for AccessToken {
68    fn from(tokens: PyxTokens) -> Self {
69        match tokens {
70            PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
71            PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
72        }
73    }
74}
75
76impl From<PyxTokens> for Credentials {
77    fn from(tokens: PyxTokens) -> Self {
78        let access_token = match tokens {
79            PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
80            PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
81        };
82        Self::from(access_token)
83    }
84}
85
86impl From<AccessToken> for Credentials {
87    fn from(access_token: AccessToken) -> Self {
88        Self::Bearer {
89            token: Token::new(access_token.into_bytes()),
90        }
91    }
92}
93
/// Reason why a token is considered expired and needs refresh.
#[derive(Debug, Clone)]
enum ExpiredTokenReason {
    /// The token has no expiration claim.
    ///
    /// Also used when the token cannot be decoded as a JWT, or when its `exp` claim is not a
    /// representable timestamp.
    MissingExpiration,
    /// Zero tolerance was requested, forcing a refresh.
    ForcedRefresh,
    /// The token's expiration time has passed.
    Expired(jiff::Timestamp),
    /// The token will expire within the tolerance window.
    ExpiringSoon(jiff::Timestamp),
}
106
107impl std::fmt::Display for ExpiredTokenReason {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        match self {
110            Self::MissingExpiration => write!(f, "missing expiration"),
111            Self::ForcedRefresh => write!(f, "forced refresh"),
112            Self::Expired(exp) => write!(f, "token expired (`{exp}`)"),
113            Self::ExpiringSoon(exp) => write!(f, "token will expire within tolerance (`{exp}`)"),
114        }
115    }
116}
117
118impl PyxTokens {
119    /// Returns the access token.
120    fn access_token(&self) -> &AccessToken {
121        match self {
122            Self::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
123            Self::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
124        }
125    }
126
127    /// Check if the token is fresh (not expired and not expiring within tolerance).
128    ///
129    /// Returns `Ok(expiration)` if fresh, or `Err(reason)` if refresh is needed.
130    fn check_fresh(&self, tolerance_secs: u64) -> Result<jiff::Timestamp, ExpiredTokenReason> {
131        let Ok(jwt) = PyxJwt::decode(self.access_token()) else {
132            return Err(ExpiredTokenReason::MissingExpiration);
133        };
134        match jwt.exp {
135            None => Err(ExpiredTokenReason::MissingExpiration),
136            Some(_) if tolerance_secs == 0 => Err(ExpiredTokenReason::ForcedRefresh),
137            Some(exp) => {
138                let Ok(exp) = jiff::Timestamp::from_second(exp) else {
139                    return Err(ExpiredTokenReason::MissingExpiration);
140                };
141                let now = jiff::Timestamp::now();
142                if exp < now {
143                    Err(ExpiredTokenReason::Expired(exp))
144                } else if exp < now + Duration::from_secs(tolerance_secs) {
145                    Err(ExpiredTokenReason::ExpiringSoon(exp))
146                } else {
147                    Ok(exp)
148                }
149            }
150        }
151    }
152}
153
/// Default refresh window, in seconds: access tokens expiring within five minutes are
/// considered stale and refreshed ahead of time.
pub const DEFAULT_TOLERANCE_SECS: u64 = 5 * 60;
156
/// The on-disk locations resolved for a token store.
#[derive(Debug, Clone)]
struct PyxDirectories {
    /// The root directory for the token store (e.g., `/Users/ferris/.local/share/pyx/credentials`).
    root: PathBuf,
    /// The per-API subdirectory of `root`, named after a digest of the API URL
    /// (e.g., `/Users/ferris/.local/share/pyx/credentials/3859a629b26fda96`).
    subdirectory: PathBuf,
}
164
165impl PyxDirectories {
166    /// Detect the [`PyxDirectories`] for a given API URL.
167    fn from_api(api: &DisplaySafeUrl) -> Result<Self, io::Error> {
168        // Store credentials in a subdirectory based on the API URL.
169        let digest = uv_cache_key::cache_digest(&CanonicalUrl::new(api));
170
171        // If the user explicitly set `PYX_CREDENTIALS_DIR`, use that.
172        if let Some(root) = std::env::var_os(EnvVars::PYX_CREDENTIALS_DIR) {
173            let root = std::path::absolute(root)?;
174            let subdirectory = root.join(&digest);
175            return Ok(Self { root, subdirectory });
176        }
177
178        // If the user has pyx credentials in their uv credentials directory, read them for
179        // backwards compatibility.
180        let root = if let Some(tool_dir) = std::env::var_os(EnvVars::UV_CREDENTIALS_DIR) {
181            std::path::absolute(tool_dir)?
182        } else {
183            StateStore::from_settings(None)?.bucket(StateBucket::Credentials)
184        };
185        let subdirectory = root.join(&digest);
186        if subdirectory.exists() {
187            return Ok(Self { root, subdirectory });
188        }
189
190        // Otherwise, use (e.g.) `~/.local/share/pyx`.
191        let Ok(xdg) = etcetera::base_strategy::choose_base_strategy() else {
192            return Err(io::Error::new(
193                io::ErrorKind::NotFound,
194                "Could not determine user data directory",
195            ));
196        };
197
198        let root = xdg.data_dir().join("pyx").join("credentials");
199        let subdirectory = root.join(&digest);
200        Ok(Self { root, subdirectory })
201    }
202}
203
/// A store for pyx access tokens, backed by a per-API directory on disk.
#[derive(Debug, Clone)]
pub struct PyxTokenStore {
    /// The root directory for the token store (e.g., `/Users/ferris/.local/share/pyx/credentials`).
    root: PathBuf,
    /// The per-API subdirectory of `root`, named after a digest of the API URL
    /// (e.g., `/Users/ferris/.local/share/pyx/credentials/3859a629b26fda96`).
    subdirectory: PathBuf,
    /// The API URL for the token store (e.g., `https://api.pyx.dev`).
    api: DisplaySafeUrl,
    /// The CDN domain for the token store (e.g., `astralhosted.com`).
    cdn: SmallString,
}
215
216impl PyxTokenStore {
217    /// Create a new [`PyxTokenStore`] from settings.
218    pub fn from_settings() -> Result<Self, TokenStoreError> {
219        // Read the API URL and CDN domain from the environment variables, or fallback to the
220        // defaults.
221        let api = if let Ok(api_url) = std::env::var(EnvVars::PYX_API_URL) {
222            DisplaySafeUrl::parse(&api_url)
223        } else {
224            DisplaySafeUrl::parse("https://api.pyx.dev")
225        }?;
226        let cdn = std::env::var(EnvVars::PYX_CDN_DOMAIN)
227            .ok()
228            .map(SmallString::from)
229            .unwrap_or_else(|| SmallString::from(arcstr::literal!("astralhosted.com")));
230
231        // Determine the root directory for the token store.
232        let PyxDirectories { root, subdirectory } = PyxDirectories::from_api(&api)?;
233
234        Ok(Self {
235            root,
236            subdirectory,
237            api,
238            cdn,
239        })
240    }
241
242    /// Return the root directory for the token store.
243    pub fn root(&self) -> &Path {
244        &self.root
245    }
246
247    /// Return the API URL for the token store.
248    pub fn api(&self) -> &DisplaySafeUrl {
249        &self.api
250    }
251
252    /// Get or initialize an [`AccessToken`] from the store.
253    ///
254    /// If an access token is set in the environment, it will be returned as-is.
255    ///
256    /// If an access token is present on-disk, it will be returned (and refreshed, if necessary).
257    ///
258    /// If no access token is found, but an API key is present, the API key will be used to
259    /// bootstrap an access token.
260    pub async fn access_token(
261        &self,
262        client: &ClientWithMiddleware,
263        tolerance_secs: u64,
264    ) -> Result<Option<AccessToken>, TokenStoreError> {
265        // If the access token is already set in the environment, return it.
266        if let Some(access_token) = read_pyx_auth_token() {
267            return Ok(Some(access_token));
268        }
269
270        // Initialize the tokens from the store.
271        let tokens = self.init(client, tolerance_secs).await?;
272
273        // Extract the access token from the OAuth tokens or API key.
274        Ok(tokens.map(AccessToken::from))
275    }
276
277    /// Initialize the [`PyxTokens`] from the store.
278    ///
279    /// If an access token is already present, it will be returned (and refreshed, if necessary).
280    ///
281    /// If no access token is found, but an API key is present, the API key will be used to
282    /// bootstrap an access token.
283    pub async fn init(
284        &self,
285        client: &ClientWithMiddleware,
286        tolerance_secs: u64,
287    ) -> Result<Option<PyxTokens>, TokenStoreError> {
288        match self.read().await? {
289            Some(tokens) => {
290                // Refresh the tokens if they are expired.
291                let tokens = self.refresh(tokens, client, tolerance_secs).await?;
292                Ok(Some(tokens))
293            }
294            None => {
295                // If no tokens are present, bootstrap them from an API key.
296                self.bootstrap(client).await
297            }
298        }
299    }
300
301    /// Write the tokens to the store.
302    pub async fn write(&self, tokens: &PyxTokens) -> Result<(), TokenStoreError> {
303        fs_err::tokio::create_dir_all(&self.subdirectory).await?;
304        match tokens {
305            PyxTokens::OAuth(tokens) => {
306                // Write OAuth tokens to a generic `tokens.json` file.
307                fs_err::tokio::write(
308                    self.subdirectory.join("tokens.json"),
309                    serde_json::to_vec(tokens)?,
310                )
311                .await?;
312            }
313            PyxTokens::ApiKey(tokens) => {
314                // Write API key tokens to a file based on the API key.
315                let digest = uv_cache_key::cache_digest(&tokens.api_key);
316                fs_err::tokio::write(
317                    self.subdirectory.join(format!("{digest}.json")),
318                    &tokens.access_token,
319                )
320                .await?;
321            }
322        }
323        Ok(())
324    }
325
326    /// Returns `true` if the user appears to have an authentication token set.
327    pub fn has_auth_token(&self) -> bool {
328        read_pyx_auth_token().is_some()
329    }
330
331    /// Returns `true` if the user appears to have an API key set.
332    pub fn has_api_key(&self) -> bool {
333        read_pyx_api_key().is_some()
334    }
335
336    /// Returns `true` if the user appears to have OAuth tokens stored on disk.
337    pub fn has_oauth_tokens(&self) -> bool {
338        self.subdirectory.join("tokens.json").is_file()
339    }
340
341    /// Returns `true` if the user appears to have credentials (which may be invalid).
342    pub fn has_credentials(&self) -> bool {
343        self.has_auth_token() || self.has_api_key() || self.has_oauth_tokens()
344    }
345
346    /// Read the tokens from the store.
347    pub async fn read(&self) -> Result<Option<PyxTokens>, TokenStoreError> {
348        if let Some(api_key) = read_pyx_api_key() {
349            // Read the API key tokens from a file based on the API key.
350            let digest = uv_cache_key::cache_digest(&api_key);
351            match fs_err::tokio::read(self.subdirectory.join(format!("{digest}.json"))).await {
352                Ok(data) => {
353                    let access_token =
354                        AccessToken::from(String::from_utf8(data).expect("Invalid UTF-8"));
355                    Ok(Some(PyxTokens::ApiKey(PyxApiKeyTokens {
356                        access_token,
357                        api_key,
358                    })))
359                }
360                Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
361                Err(err) => Err(err.into()),
362            }
363        } else {
364            match fs_err::tokio::read(self.subdirectory.join("tokens.json")).await {
365                Ok(data) => {
366                    let tokens: PyxOAuthTokens = serde_json::from_slice(&data)?;
367                    Ok(Some(PyxTokens::OAuth(tokens)))
368                }
369                Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
370                Err(err) => Err(err.into()),
371            }
372        }
373    }
374
375    /// Remove the tokens from the store.
376    pub async fn delete(&self) -> Result<(), io::Error> {
377        fs_err::tokio::remove_dir_all(&self.subdirectory).await?;
378        Ok(())
379    }
380
381    /// Return the path to the refresh lock file for a given token type.
382    ///
383    /// For OAuth tokens, uses a fixed "tokens.lock" file.
384    /// For API key tokens, uses a file based on the API key digest.
385    fn lock_path(&self, tokens: &PyxTokens) -> PathBuf {
386        match tokens {
387            PyxTokens::OAuth(_) => self.subdirectory.join("tokens.lock"),
388            PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
389                let digest = uv_cache_key::cache_digest(api_key);
390                self.subdirectory.join(format!("{digest}.lock"))
391            }
392        }
393    }
394
395    /// Bootstrap the tokens from the store.
396    async fn bootstrap(
397        &self,
398        client: &ClientWithMiddleware,
399    ) -> Result<Option<PyxTokens>, TokenStoreError> {
400        #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
401        struct Payload {
402            access_token: AccessToken,
403        }
404
405        // Retrieve the API key from the environment variable, if set.
406        let Some(api_key) = read_pyx_api_key() else {
407            return Ok(None);
408        };
409
410        debug!("Bootstrapping access token from an API key");
411
412        // Parse the API URL.
413        let mut url = self.api.clone();
414        url.set_path("auth/cli/access-token");
415
416        let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
417        request.headers_mut().insert(
418            "Authorization",
419            reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?,
420        );
421
422        let response = client.execute(request).await?;
423        let Payload { access_token } = response.error_for_status()?.json::<Payload>().await?;
424        let tokens = PyxTokens::ApiKey(PyxApiKeyTokens {
425            access_token,
426            api_key,
427        });
428
429        // Write the tokens to disk.
430        self.write(&tokens).await?;
431
432        Ok(Some(tokens))
433    }
434
435    /// Refresh the tokens in the store, if they are expired.
436    ///
437    /// In theory, we should _also_ refresh if we hit a 401; but for now, we only refresh ahead of
438    /// time.
439    async fn refresh(
440        &self,
441        tokens: PyxTokens,
442        client: &ClientWithMiddleware,
443        tolerance_secs: u64,
444    ) -> Result<PyxTokens, TokenStoreError> {
445        let reason = match tokens.check_fresh(tolerance_secs) {
446            Ok(exp) => {
447                debug!("Access token is up-to-date (`{exp}`)");
448                return Ok(tokens);
449            }
450            Err(reason) => reason,
451        };
452        debug!("Refreshing token due to {reason}");
453
454        // Ensure the subdirectory exists before acquiring the lock
455        fs_err::tokio::create_dir_all(&self.subdirectory).await?;
456
457        // Get the lock path for this specific token
458        let lock_path = self.lock_path(&tokens);
459
460        // Acquire a lock to prevent concurrent refresh attempts for this token
461        let _lock = LockedFile::acquire(&lock_path, LockedFileMode::Exclusive, "pyx refresh")
462            .await
463            .map_err(|err| TokenStoreError::Io(io::Error::other(err.to_string())))?;
464
465        // Check if another process has already refreshed the tokens
466        if let Some(tokens) = self.read().await? {
467            match tokens.check_fresh(tolerance_secs) {
468                Ok(exp) => {
469                    debug!("Using recently refreshed token (`{exp}`)");
470                    return Ok(tokens);
471                }
472                Err(reason) => {
473                    debug!("Token on disk still needs refresh due to {reason}");
474                }
475            }
476        }
477
478        // Refresh the tokens
479        let tokens = match tokens {
480            PyxTokens::OAuth(PyxOAuthTokens { refresh_token, .. }) => {
481                // Parse the API URL.
482                let mut url = self.api.clone();
483                url.set_path("auth/cli/refresh");
484
485                let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
486                let body = serde_json::json!({
487                    "refresh_token": refresh_token
488                });
489                *request.body_mut() = Some(body.to_string().into());
490
491                let response = client.execute(request).await?;
492                let tokens = response
493                    .error_for_status()?
494                    .json::<PyxOAuthTokens>()
495                    .await?;
496                PyxTokens::OAuth(tokens)
497            }
498            PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
499                #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
500                struct Payload {
501                    access_token: AccessToken,
502                }
503
504                // Parse the API URL.
505                let mut url = self.api.clone();
506                url.set_path("auth/cli/access-token");
507
508                let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
509                request.headers_mut().insert(
510                    "Authorization",
511                    reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?,
512                );
513
514                let response = client.execute(request).await?;
515                let Payload { access_token } =
516                    response.error_for_status()?.json::<Payload>().await?;
517                PyxTokens::ApiKey(PyxApiKeyTokens {
518                    access_token,
519                    api_key,
520                })
521            }
522        };
523
524        // Write the new tokens to disk
525        self.write(&tokens).await?;
526
527        Ok(tokens)
528    }
529
530    /// Returns `true` if the given URL is "known" to this token store (i.e., should be
531    /// authenticated using the store's tokens).
532    pub fn is_known_url(&self, url: &Url) -> bool {
533        is_known_url(url, &self.api, &self.cdn)
534    }
535
536    /// Returns `true` if the URL is on a "known" domain (i.e., the same domain as the API or CDN).
537    ///
538    /// Like [`is_known_url`](Self::is_known_url), but also returns `true` if the API is on the
539    /// subdomain of the URL (e.g., if the API is `api.pyx.dev` and the URL is `pyx.dev`).
540    pub fn is_known_domain(&self, url: &Url) -> bool {
541        is_known_domain(url, &self.api, &self.cdn)
542    }
543}
544
/// Errors produced while resolving, reading, writing, or refreshing pyx tokens.
#[derive(thiserror::Error, Debug)]
pub enum TokenStoreError {
    /// Wraps a [`DisplaySafeUrlError`] (e.g., from parsing the API URL).
    #[error(transparent)]
    Url(#[from] DisplaySafeUrlError),
    /// Wraps an [`io::Error`] from filesystem access or lock acquisition.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// Wraps a [`serde_json::Error`] from (de)serializing stored tokens.
    #[error(transparent)]
    Serialization(#[from] serde_json::Error),
    /// Wraps a [`reqwest::Error`] (e.g., a non-success HTTP status or a body decoding failure).
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),
    /// Wraps a [`reqwest_middleware::Error`] from executing a request through the client.
    #[error(transparent)]
    ReqwestMiddleware(#[from] reqwest_middleware::Error),
    /// Wraps a [`reqwest::header::InvalidHeaderValue`] (e.g., an API key that cannot be placed
    /// in an `Authorization` header).
    #[error(transparent)]
    InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
    /// Wraps a [`jiff::Error`].
    #[error(transparent)]
    Jiff(#[from] jiff::Error),
    /// Wraps a [`JwtError`].
    #[error(transparent)]
    Jwt(#[from] JwtError),
}
564
565impl TokenStoreError {
566    /// Returns `true` if the error is a 401 (Unauthorized) error.
567    pub fn is_unauthorized(&self) -> bool {
568        match self {
569            Self::Reqwest(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED),
570            Self::ReqwestMiddleware(err) => err.status() == Some(reqwest::StatusCode::UNAUTHORIZED),
571            _ => false,
572        }
573    }
574}
575
/// The payload of the JWT.
///
/// Populated by [`PyxJwt::decode`], which only base64-decodes and parses the payload segment.
/// The signature is not verified.
#[derive(Debug, serde::Deserialize)]
pub struct PyxJwt {
    /// The expiration time of the JWT, as a Unix timestamp.
    pub exp: Option<i64>,
    /// The issuer of the JWT.
    pub iss: Option<String>,
    /// The name of the organization, if any.
    #[serde(rename = "urn:pyx:org_name")]
    pub name: Option<String>,
}
587
588impl PyxJwt {
589    /// Decode the JWT from the access token.
590    pub fn decode(access_token: &AccessToken) -> Result<Self, JwtError> {
591        let mut token_segments = access_token.as_str().splitn(3, '.');
592
593        let _header = token_segments.next().ok_or(JwtError::MissingHeader)?;
594        let payload = token_segments.next().ok_or(JwtError::MissingPayload)?;
595        let _signature = token_segments.next().ok_or(JwtError::MissingSignature)?;
596        if token_segments.next().is_some() {
597            return Err(JwtError::TooManySegments);
598        }
599
600        let decoded = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
601
602        let jwt = serde_json::from_slice::<Self>(&decoded)?;
603        Ok(jwt)
604    }
605}
606
/// Errors produced by [`PyxJwt::decode`].
#[derive(thiserror::Error, Debug)]
pub enum JwtError {
    /// The token has no header segment.
    #[error("JWT is missing a header")]
    MissingHeader,
    /// The token has no payload segment.
    #[error("JWT is missing a payload")]
    MissingPayload,
    /// The token has no signature segment.
    #[error("JWT is missing a signature")]
    MissingSignature,
    /// The token has more than three `.`-separated segments.
    #[error("JWT has too many segments")]
    TooManySegments,
    /// The payload segment is not valid base64url.
    #[error(transparent)]
    Base64(#[from] base64::DecodeError),
    /// The decoded payload is not the expected JSON.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
622
623fn is_known_url(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
624    // Determine whether the URL matches the API realm.
625    if Realm::from(url) == Realm::from(&**api) {
626        return true;
627    }
628
629    // Determine whether the URL matches the CDN domain (or a subdomain of it).
630    //
631    // For example, if URL is on `files.astralhosted.com` and the CDN domain is
632    // `astralhosted.com`, consider it known.
633    if matches!(url.scheme(), "https") && matches_domain(url, cdn) {
634        return true;
635    }
636
637    false
638}
639
640fn is_known_domain(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
641    // Determine whether the URL matches the API domain.
642    if let Some(domain) = url.domain() {
643        if matches_domain(api, domain) {
644            return true;
645        }
646    }
647    is_known_url(url, api, cdn)
648}
649
650/// Returns `true` if the target URL is on the given domain.
651fn matches_domain(url: &Url, domain: &str) -> bool {
652    url.domain().is_some_and(|subdomain| {
653        subdomain == domain
654            || subdomain
655                .strip_suffix(domain)
656                .is_some_and(|prefix| prefix.ends_with('.'))
657    })
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a hard-coded test URL.
    fn parse_url(raw: &str) -> Url {
        Url::parse(raw).unwrap()
    }

    #[test]
    fn test_is_known_url() {
        let api_url = DisplaySafeUrl::parse("https://api.pyx.dev").unwrap();
        let cdn_domain = "astralhosted.com";
        let known = |raw: &str| is_known_url(&parse_url(raw), &api_url, cdn_domain);

        // Same realm as API.
        assert!(known("https://api.pyx.dev/simple/"));

        // Different path on same API domain
        assert!(known("https://api.pyx.dev/v1/"));

        // CDN domain.
        assert!(known("https://astralhosted.com/packages/"));

        // CDN subdomain.
        assert!(known("https://files.astralhosted.com/packages/"));

        // CDN on HTTP.
        assert!(!known("http://astralhosted.com/packages/"));

        // Unknown domain.
        assert!(!known("https://pypi.org/simple/"));

        // Similar but not matching domain.
        assert!(!known("https://badastralhosted.com/packages/"));
    }

    #[test]
    fn test_is_known_domain() {
        let api_url = DisplaySafeUrl::parse("https://api.pyx.dev").unwrap();
        let cdn_domain = "astralhosted.com";
        let known = |raw: &str| is_known_domain(&parse_url(raw), &api_url, cdn_domain);

        // Same realm as API.
        assert!(known("https://api.pyx.dev/simple/"));

        // API super-domain.
        assert!(known("https://pyx.dev"));

        // API subdomain.
        assert!(!known("https://foo.api.pyx.dev"));

        // Different subdomain.
        assert!(!known("https://beta.pyx.dev/"));

        // CDN domain.
        assert!(known("https://astralhosted.com/packages/"));

        // CDN subdomain.
        assert!(known("https://files.astralhosted.com/packages/"));

        // Unknown domain.
        assert!(!known("https://pypi.org/simple/"));

        // Different TLD.
        assert!(!known("https://pyx.com/"));
    }

    #[test]
    fn test_matches_domain() {
        let matches = |raw: &str, domain: &str| matches_domain(&parse_url(raw), domain);

        assert!(matches("https://example.com", "example.com"));
        assert!(matches("https://foo.example.com", "example.com"));
        assert!(matches("https://bar.foo.example.com", "example.com"));

        assert!(!matches("https://example.com", "other.com"));
        assert!(!matches("https://example.org", "example.com"));
        assert!(!matches("https://badexample.com", "example.com"));
    }
}