use std::io;
use std::path::{Path, PathBuf};
use std::time::Duration;
use base64::Engine;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use etcetera::BaseStrategy;
use reqwest_middleware::ClientWithMiddleware;
use tracing::debug;
use url::Url;
use uv_fs::{LockedFile, LockedFileMode};
use uv_cache_key::CanonicalUrl;
use uv_redacted::{DisplaySafeUrl, DisplaySafeUrlError};
use uv_small_str::SmallString;
use uv_state::{StateBucket, StateStore};
use uv_static::EnvVars;
use crate::credentials::Token;
use crate::{AccessToken, Credentials, Realm};
/// pyx API URL used when `EnvVars::PYX_API_URL` is unset; also the reference API for
/// `is_default_pyx_domain`.
const PYX_DEFAULT_API_URL: &str = "https://api.pyx.dev";
/// CDN domain used when `EnvVars::PYX_CDN_DOMAIN` is unset; HTTPS URLs on this domain (or any
/// subdomain of it) are treated as pyx-hosted.
const PYX_DEFAULT_CDN_DOMAIN: &str = "astralhosted.com";
/// Read the pyx API key from the environment.
///
/// `EnvVars::PYX_API_KEY` takes precedence. `EnvVars::UV_API_KEY` is consulted only when the
/// former is unset (or not valid Unicode).
fn read_pyx_api_key() -> Option<String> {
    std::env::var(EnvVars::PYX_API_KEY)
        .or_else(|_| std::env::var(EnvVars::UV_API_KEY))
        .ok()
}
/// Read a pre-issued pyx access token from the environment.
///
/// `EnvVars::PYX_AUTH_TOKEN` takes precedence. `EnvVars::UV_AUTH_TOKEN` is consulted only when
/// the former is unset (or not valid Unicode).
fn read_pyx_auth_token() -> Option<AccessToken> {
    let raw = std::env::var(EnvVars::PYX_AUTH_TOKEN)
        .or_else(|_| std::env::var(EnvVars::UV_AUTH_TOKEN))
        .ok()?;
    Some(AccessToken::from(raw))
}
/// Tokens obtained through the pyx OAuth login flow, persisted as `tokens.json`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct PyxOAuthTokens {
    /// Short-lived bearer token sent to the pyx API.
    pub access_token: AccessToken,
    /// Token exchanged at `auth/cli/refresh` for a new pair of tokens.
    pub refresh_token: String,
}
/// An access token obtained by exchanging a pyx API key.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct PyxApiKeyTokens {
    /// Short-lived bearer token sent to the pyx API.
    pub access_token: AccessToken,
    /// The API key that `access_token` was obtained from (and can be re-obtained with).
    pub api_key: String,
}
/// pyx credentials, tagged by how they were obtained.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum PyxTokens {
    /// Tokens from an interactive OAuth login.
    OAuth(PyxOAuthTokens),
    /// An access token minted from an API key.
    ApiKey(PyxApiKeyTokens),
}
impl From<PyxTokens> for AccessToken {
    /// Extract the access token, discarding the refresh token or API key.
    fn from(tokens: PyxTokens) -> Self {
        let (PyxTokens::OAuth(PyxOAuthTokens { access_token, .. })
        | PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. })) = tokens;
        access_token
    }
}
impl From<PyxTokens> for Credentials {
fn from(tokens: PyxTokens) -> Self {
let access_token = match tokens {
PyxTokens::OAuth(PyxOAuthTokens { access_token, .. }) => access_token,
PyxTokens::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
};
Self::from(access_token)
}
}
impl From<AccessToken> for Credentials {
    /// Wrap an access token as `Bearer` credentials.
    fn from(access_token: AccessToken) -> Self {
        let token = Token::new(access_token.into_bytes());
        Self::Bearer { token }
    }
}
/// Why a stored access token is considered stale and must be refreshed.
#[derive(Clone, Debug)]
enum ExpiredTokenReason {
    /// The token could not be decoded, or lacks a usable `exp` claim.
    MissingExpiration,
    /// A tolerance of zero was requested, which always triggers a refresh.
    ForcedRefresh,
    /// The token expired at the contained timestamp.
    Expired(jiff::Timestamp),
    /// The token expires at the contained timestamp, which falls inside the tolerance window.
    ExpiringSoon(jiff::Timestamp),
}
impl std::fmt::Display for ExpiredTokenReason {
    /// Render the reason as a short phrase for debug logging.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingExpiration => f.write_str("missing expiration"),
            Self::ForcedRefresh => f.write_str("forced refresh"),
            Self::Expired(timestamp) => write!(f, "token expired (`{timestamp}`)"),
            Self::ExpiringSoon(timestamp) => {
                write!(f, "token will expire within tolerance (`{timestamp}`)")
            }
        }
    }
}
impl PyxTokens {
    /// Borrow the access token, regardless of how it was obtained.
    fn access_token(&self) -> &AccessToken {
        match self {
            Self::OAuth(PyxOAuthTokens { access_token, .. })
            | Self::ApiKey(PyxApiKeyTokens { access_token, .. }) => access_token,
        }
    }

    /// Check whether the access token stays valid for at least `tolerance_secs` seconds.
    ///
    /// The token is decoded as a JWT (see [`PyxJwt::decode`]), and its `exp` claim is compared
    /// against the current time.
    ///
    /// Returns the expiration timestamp if the token is fresh. Otherwise returns the reason it
    /// needs a refresh:
    ///
    /// - [`ExpiredTokenReason::MissingExpiration`]: the token cannot be decoded, has no `exp`
    ///   claim, or the claim is not a representable timestamp.
    /// - [`ExpiredTokenReason::ForcedRefresh`]: `tolerance_secs` is `0`.
    /// - [`ExpiredTokenReason::Expired`]: the expiration is in the past.
    /// - [`ExpiredTokenReason::ExpiringSoon`]: the token expires within the tolerance window.
    ///   This includes windows that extend past the largest representable timestamp.
    fn check_fresh(&self, tolerance_secs: u64) -> Result<jiff::Timestamp, ExpiredTokenReason> {
        let Ok(jwt) = PyxJwt::decode(self.access_token()) else {
            return Err(ExpiredTokenReason::MissingExpiration);
        };
        match jwt.exp {
            None => Err(ExpiredTokenReason::MissingExpiration),
            Some(_) if tolerance_secs == 0 => Err(ExpiredTokenReason::ForcedRefresh),
            Some(exp) => {
                let Ok(exp) = jiff::Timestamp::from_second(exp) else {
                    return Err(ExpiredTokenReason::MissingExpiration);
                };
                let now = jiff::Timestamp::now();
                if exp < now {
                    return Err(ExpiredTokenReason::Expired(exp));
                }
                // `Timestamp + Duration` panics on overflow. Use checked arithmetic so that a huge
                // tolerance just means "always refresh" instead of aborting the process.
                match now.checked_add(Duration::from_secs(tolerance_secs)) {
                    Ok(deadline) if exp >= deadline => Ok(exp),
                    _ => Err(ExpiredTokenReason::ExpiringSoon(exp)),
                }
            }
        }
    }
}
/// Default refresh tolerance, in seconds (five minutes): when passed as `tolerance_secs`, tokens
/// expiring within this window of the current time are refreshed ahead of expiry.
pub const DEFAULT_TOLERANCE_SECS: u64 = 5 * 60;
/// Resolved on-disk locations for the credentials of one pyx API.
#[derive(Clone, Debug)]
struct PyxDirectories {
    /// Directory that holds one subdirectory per API URL.
    root: PathBuf,
    /// `root` joined with a digest of the canonicalized API URL.
    subdirectory: PathBuf,
}
impl PyxDirectories {
fn from_api(api: &DisplaySafeUrl) -> Result<Self, io::Error> {
let digest = uv_cache_key::cache_digest(&CanonicalUrl::new(api));
if let Some(root) = std::env::var_os(EnvVars::PYX_CREDENTIALS_DIR) {
let root = std::path::absolute(root)?;
let subdirectory = root.join(&digest);
return Ok(Self { root, subdirectory });
}
let root = if let Some(tool_dir) = std::env::var_os(EnvVars::UV_CREDENTIALS_DIR) {
std::path::absolute(tool_dir)?
} else {
StateStore::from_settings(None)?.bucket(StateBucket::Credentials)
};
let subdirectory = root.join(&digest);
if subdirectory.exists() {
return Ok(Self { root, subdirectory });
}
let Ok(xdg) = etcetera::base_strategy::choose_base_strategy() else {
return Err(io::Error::new(
io::ErrorKind::NotFound,
"Could not determine user data directory",
));
};
let root = xdg.data_dir().join("pyx").join("credentials");
let subdirectory = root.join(&digest);
Ok(Self { root, subdirectory })
}
}
/// Reads, writes, and refreshes pyx credentials for a single API.
#[derive(Clone, Debug)]
pub struct PyxTokenStore {
    /// Root of the credentials directory.
    root: PathBuf,
    /// Per-API subdirectory under `root` where tokens and lock files live.
    subdirectory: PathBuf,
    /// Base URL of the pyx API.
    api: DisplaySafeUrl,
    /// Domain whose HTTPS URLs (including subdomains) are treated as pyx-hosted.
    cdn: SmallString,
}
impl PyxTokenStore {
    /// Create a token store configured from the environment.
    ///
    /// The API URL comes from `EnvVars::PYX_API_URL` (default: `https://api.pyx.dev`). The CDN
    /// domain comes from `EnvVars::PYX_CDN_DOMAIN` (default: `astralhosted.com`). The on-disk
    /// credentials directory is derived from the API URL.
    ///
    /// # Errors
    ///
    /// Returns an error if the API URL cannot be parsed or the credentials directory cannot be
    /// determined.
    pub fn from_settings() -> Result<Self, TokenStoreError> {
        let api = if let Ok(api_url) = std::env::var(EnvVars::PYX_API_URL) {
            DisplaySafeUrl::parse(&api_url)
        } else {
            DisplaySafeUrl::parse(PYX_DEFAULT_API_URL)
        }?;
        let cdn = std::env::var(EnvVars::PYX_CDN_DOMAIN)
            .ok()
            .map(SmallString::from)
            .unwrap_or_else(|| SmallString::from(arcstr::literal!(PYX_DEFAULT_CDN_DOMAIN)));
        let PyxDirectories { root, subdirectory } = PyxDirectories::from_api(&api)?;
        Ok(Self {
            root,
            subdirectory,
            api,
            cdn,
        })
    }

    /// The root directory under which pyx credentials are stored.
    pub fn root(&self) -> &Path {
        &self.root
    }

    /// The pyx API URL this store manages credentials for.
    pub fn api(&self) -> &DisplaySafeUrl {
        &self.api
    }

    /// Return an access token for the pyx API, if one is available.
    ///
    /// A token supplied through `PYX_AUTH_TOKEN` (or `UV_AUTH_TOKEN`) takes precedence and is
    /// returned without any expiration check. Otherwise, the result of [`Self::init`] is used.
    ///
    /// Returns `Ok(None)` when no credentials are available.
    pub async fn access_token(
        &self,
        client: &ClientWithMiddleware,
        tolerance_secs: u64,
    ) -> Result<Option<AccessToken>, TokenStoreError> {
        if let Some(access_token) = read_pyx_auth_token() {
            return Ok(Some(access_token));
        }
        let tokens = self.init(client, tolerance_secs).await?;
        Ok(tokens.map(AccessToken::from))
    }

    /// Load the stored tokens, refreshing them if they expire within `tolerance_secs`.
    ///
    /// If nothing is stored, this tries to bootstrap an access token from a configured API key.
    /// Returns `Ok(None)` when neither stored tokens nor an API key are available.
    pub async fn init(
        &self,
        client: &ClientWithMiddleware,
        tolerance_secs: u64,
    ) -> Result<Option<PyxTokens>, TokenStoreError> {
        match self.read().await? {
            Some(tokens) => {
                let tokens = self.refresh(tokens, client, tolerance_secs).await?;
                Ok(Some(tokens))
            }
            None => self.bootstrap(client).await,
        }
    }

    /// Persist `tokens` to the credentials subdirectory, creating it if needed.
    ///
    /// OAuth tokens are written as JSON to `tokens.json`. For API-key tokens, only the raw access
    /// token is written, to a file named after a digest of the API key. The API key itself is not
    /// written.
    ///
    /// NOTE(review): files are created with the process's default permissions and overwritten
    /// in place rather than atomically. Confirm this is acceptable for credential files.
    pub async fn write(&self, tokens: &PyxTokens) -> Result<(), TokenStoreError> {
        fs_err::tokio::create_dir_all(&self.subdirectory).await?;
        match tokens {
            PyxTokens::OAuth(tokens) => {
                fs_err::tokio::write(
                    self.subdirectory.join("tokens.json"),
                    serde_json::to_vec(tokens)?,
                )
                .await?;
            }
            PyxTokens::ApiKey(tokens) => {
                let digest = uv_cache_key::cache_digest(&tokens.api_key);
                fs_err::tokio::write(
                    self.subdirectory.join(format!("{digest}.json")),
                    &tokens.access_token,
                )
                .await?;
            }
        }
        Ok(())
    }

    /// Whether an access token is provided directly via the environment.
    pub fn has_auth_token(&self) -> bool {
        read_pyx_auth_token().is_some()
    }

    /// Whether an API key is provided via the environment.
    pub fn has_api_key(&self) -> bool {
        read_pyx_api_key().is_some()
    }

    /// Whether OAuth tokens are stored on disk for this API.
    pub fn has_oauth_tokens(&self) -> bool {
        self.subdirectory.join("tokens.json").is_file()
    }

    /// Whether any form of pyx credentials is available.
    pub fn has_credentials(&self) -> bool {
        self.has_auth_token() || self.has_api_key() || self.has_oauth_tokens()
    }

    /// Load stored tokens for the current configuration.
    ///
    /// If an API key is configured, this reads the access token cached for that key. Otherwise it
    /// reads OAuth tokens from `tokens.json`. Returns `Ok(None)` if the relevant file does not
    /// exist.
    ///
    /// # Errors
    ///
    /// Returns an error in three cases:
    /// - the file cannot be read;
    /// - a cached API-key access token is not valid UTF-8 (reported as
    ///   [`io::ErrorKind::InvalidData`] rather than panicking);
    /// - `tokens.json` cannot be deserialized.
    pub async fn read(&self) -> Result<Option<PyxTokens>, TokenStoreError> {
        if let Some(api_key) = read_pyx_api_key() {
            let digest = uv_cache_key::cache_digest(&api_key);
            let path = self.subdirectory.join(format!("{digest}.json"));
            match fs_err::tokio::read(&path).await {
                Ok(data) => {
                    // The file lives on disk and may be corrupt or hand-edited. Surface that as
                    // an error instead of aborting the process.
                    let access_token = String::from_utf8(data)
                        .map(AccessToken::from)
                        .map_err(|err| {
                            io::Error::new(
                                io::ErrorKind::InvalidData,
                                format!(
                                    "Cached access token at `{}` is not valid UTF-8: {err}",
                                    path.display()
                                ),
                            )
                        })?;
                    Ok(Some(PyxTokens::ApiKey(PyxApiKeyTokens {
                        access_token,
                        api_key,
                    })))
                }
                Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
                Err(err) => Err(err.into()),
            }
        } else {
            match fs_err::tokio::read(self.subdirectory.join("tokens.json")).await {
                Ok(data) => {
                    let tokens: PyxOAuthTokens = serde_json::from_slice(&data)?;
                    Ok(Some(PyxTokens::OAuth(tokens)))
                }
                Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(None),
                Err(err) => Err(err.into()),
            }
        }
    }

    /// Remove all stored credentials for this API.
    ///
    /// Fails (e.g. with `NotFound`) if the credentials subdirectory does not exist.
    pub async fn delete(&self) -> Result<(), io::Error> {
        fs_err::tokio::remove_dir_all(&self.subdirectory).await?;
        Ok(())
    }

    /// Path of the lock file that serializes refreshes of `tokens`.
    ///
    /// API-key tokens are locked per key digest, so different keys do not contend.
    fn lock_path(&self, tokens: &PyxTokens) -> PathBuf {
        match tokens {
            PyxTokens::OAuth(_) => self.subdirectory.join("tokens.lock"),
            PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
                let digest = uv_cache_key::cache_digest(api_key);
                self.subdirectory.join(format!("{digest}.lock"))
            }
        }
    }

    /// Exchange a pyx API key for a fresh access token.
    ///
    /// Sends `POST auth/cli/access-token`, with the key as a bearer token, and returns the
    /// `access_token` field of the JSON response. `set_path` replaces any path already present
    /// on the configured API URL.
    ///
    /// # Errors
    ///
    /// Fails if the key is not a valid header value, the request fails, the server responds with
    /// a non-success status, or the response body cannot be decoded.
    async fn request_access_token(
        &self,
        client: &ClientWithMiddleware,
        api_key: &str,
    ) -> Result<AccessToken, TokenStoreError> {
        #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
        struct Payload {
            access_token: AccessToken,
        }

        let mut url = self.api.clone();
        url.set_path("auth/cli/access-token");
        let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
        let mut authorization =
            reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}"))?;
        // Keep the API key out of any `Debug` output of the request headers.
        authorization.set_sensitive(true);
        request
            .headers_mut()
            .insert(reqwest::header::AUTHORIZATION, authorization);
        let response = client.execute(request).await?;
        let Payload { access_token } = response.error_for_status()?.json::<Payload>().await?;
        Ok(access_token)
    }

    /// Mint and persist an access token from the configured API key, if there is one.
    ///
    /// Returns `Ok(None)` when no API key is configured.
    async fn bootstrap(
        &self,
        client: &ClientWithMiddleware,
    ) -> Result<Option<PyxTokens>, TokenStoreError> {
        let Some(api_key) = read_pyx_api_key() else {
            return Ok(None);
        };
        debug!("Bootstrapping access token from an API key");
        let access_token = self.request_access_token(client, &api_key).await?;
        let tokens = PyxTokens::ApiKey(PyxApiKeyTokens {
            access_token,
            api_key,
        });
        self.write(&tokens).await?;
        Ok(Some(tokens))
    }

    /// Return `tokens` unchanged if still fresh; otherwise refresh and persist them.
    ///
    /// Refreshes are serialized across processes with an exclusive file lock. After the lock is
    /// acquired, the tokens are re-read from disk:
    /// - if another process already refreshed them, that result is reused;
    /// - otherwise the refresh starts from the latest on-disk tokens, because an OAuth refresh
    ///   token held in memory may already have been rotated.
    async fn refresh(
        &self,
        tokens: PyxTokens,
        client: &ClientWithMiddleware,
        tolerance_secs: u64,
    ) -> Result<PyxTokens, TokenStoreError> {
        let reason = match tokens.check_fresh(tolerance_secs) {
            Ok(exp) => {
                debug!("Access token is up-to-date (`{exp}`)");
                return Ok(tokens);
            }
            Err(reason) => reason,
        };
        debug!("Refreshing token due to {reason}");

        fs_err::tokio::create_dir_all(&self.subdirectory).await?;
        let lock_path = self.lock_path(&tokens);
        let _lock = LockedFile::acquire(&lock_path, LockedFileMode::Exclusive, "pyx refresh")
            .await
            .map_err(|err| TokenStoreError::Io(io::Error::other(err.to_string())))?;

        // Continue from whatever is on disk now; fall back to the in-memory tokens only if the
        // stored copy has disappeared.
        let tokens = match self.read().await? {
            Some(latest) => match latest.check_fresh(tolerance_secs) {
                Ok(exp) => {
                    debug!("Using recently refreshed token (`{exp}`)");
                    return Ok(latest);
                }
                Err(reason) => {
                    debug!("Token on disk still needs refresh due to {reason}");
                    latest
                }
            },
            None => tokens,
        };

        let tokens = match tokens {
            PyxTokens::OAuth(PyxOAuthTokens { refresh_token, .. }) => {
                let mut url = self.api.clone();
                url.set_path("auth/cli/refresh");
                let mut request = reqwest::Request::new(reqwest::Method::POST, Url::from(url));
                let body = serde_json::json!({
                    "refresh_token": refresh_token
                });
                request.headers_mut().insert(
                    reqwest::header::CONTENT_TYPE,
                    reqwest::header::HeaderValue::from_static("application/json"),
                );
                *request.body_mut() = Some(body.to_string().into());
                let response = client.execute(request).await?;
                let tokens = response
                    .error_for_status()?
                    .json::<PyxOAuthTokens>()
                    .await?;
                PyxTokens::OAuth(tokens)
            }
            PyxTokens::ApiKey(PyxApiKeyTokens { api_key, .. }) => {
                let access_token = self.request_access_token(client, &api_key).await?;
                PyxTokens::ApiKey(PyxApiKeyTokens {
                    access_token,
                    api_key,
                })
            }
        };
        self.write(&tokens).await?;
        Ok(tokens)
    }

    /// Whether `url` belongs to this pyx deployment (API realm or HTTPS CDN).
    pub fn is_known_url(&self, url: &Url) -> bool {
        is_known_url(url, &self.api, &self.cdn)
    }

    /// Like [`Self::is_known_url`], but also accepts parent domains of the API host.
    pub fn is_known_domain(&self, url: &Url) -> bool {
        is_known_domain(url, &self.api, &self.cdn)
    }
}
#[derive(thiserror::Error, Debug)]
pub enum TokenStoreError {
#[error(transparent)]
Url(#[from] DisplaySafeUrlError),
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Serialization(#[from] serde_json::Error),
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
#[error(transparent)]
ReqwestMiddleware(#[from] reqwest_middleware::Error),
#[error(transparent)]
InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
#[error(transparent)]
Jiff(#[from] jiff::Error),
#[error(transparent)]
Jwt(#[from] JwtError),
}
impl TokenStoreError {
    /// Whether this error is an HTTP `401 Unauthorized` response from the pyx API.
    pub fn is_unauthorized(&self) -> bool {
        let status = match self {
            Self::Reqwest(err) => err.status(),
            Self::ReqwestMiddleware(err) => err.status(),
            Self::Url(_)
            | Self::Io(_)
            | Self::Serialization(_)
            | Self::InvalidHeaderValue(_)
            | Self::Jiff(_)
            | Self::Jwt(_) => None,
        };
        status == Some(reqwest::StatusCode::UNAUTHORIZED)
    }
}
/// The subset of pyx access-token (JWT) claims that the client inspects.
#[derive(serde::Deserialize, Debug)]
pub struct PyxJwt {
    /// The `exp` claim: expiration time, in seconds since the Unix epoch.
    pub exp: Option<i64>,
    /// The `iss` (issuer) claim.
    pub iss: Option<String>,
    /// The `urn:pyx:org_name` claim (presumably the pyx organization name — confirm).
    #[serde(rename = "urn:pyx:org_name")]
    pub name: Option<String>,
}
impl PyxJwt {
    /// Decode the claims of a pyx access token.
    ///
    /// The token must consist of exactly three `.`-separated segments: header, payload, and
    /// signature. Only the payload is decoded, as unpadded URL-safe base64 containing JSON. The
    /// signature is **not** verified, so the returned claims must not be trusted for
    /// authorization decisions.
    ///
    /// # Errors
    ///
    /// Returns a [`JwtError`] if the token has too few or too many segments, or if the payload is
    /// not valid base64 or JSON.
    pub fn decode(access_token: &AccessToken) -> Result<Self, JwtError> {
        // Use `split` rather than `splitn(3, ..)`. `splitn` folds any extra segments into the
        // signature, which would make the `TooManySegments` check below unreachable.
        let mut token_segments = access_token.as_str().split('.');
        let _header = token_segments.next().ok_or(JwtError::MissingHeader)?;
        let payload = token_segments.next().ok_or(JwtError::MissingPayload)?;
        let _signature = token_segments.next().ok_or(JwtError::MissingSignature)?;
        if token_segments.next().is_some() {
            return Err(JwtError::TooManySegments);
        }
        let decoded = BASE64_URL_SAFE_NO_PAD.decode(payload)?;
        let jwt = serde_json::from_slice::<Self>(&decoded)?;
        Ok(jwt)
    }
}
#[derive(thiserror::Error, Debug)]
pub enum JwtError {
#[error("JWT is missing a header")]
MissingHeader,
#[error("JWT is missing a payload")]
MissingPayload,
#[error("JWT is missing a signature")]
MissingSignature,
#[error("JWT has too many segments")]
TooManySegments,
#[error(transparent)]
Base64(#[from] base64::DecodeError),
#[error(transparent)]
Serde(#[from] serde_json::Error),
}
/// Whether `url` belongs to the pyx deployment described by `api` and `cdn`.
///
/// A URL is known if it shares the API's realm, or if it is an `https` URL whose host is `cdn`
/// or a subdomain of it.
fn is_known_url(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
    let same_realm_as_api = Realm::from(url) == Realm::from(&**api);
    same_realm_as_api || (url.scheme() == "https" && matches_domain(url, cdn))
}
/// Like `is_known_url`, but also accepts URLs whose domain is the API host or a parent domain
/// of it (e.g. `pyx.dev` for an API at `api.pyx.dev`).
fn is_known_domain(url: &Url, api: &DisplaySafeUrl, cdn: &str) -> bool {
    let api_is_within_url_domain = url
        .domain()
        .is_some_and(|domain| matches_domain(api, domain));
    api_is_within_url_domain || is_known_url(url, api, cdn)
}
/// Whether `url` belongs to the default (production) pyx deployment, ignoring any API URL or CDN
/// overrides from the environment.
pub fn is_default_pyx_domain(url: &Url) -> bool {
    let default_api =
        DisplaySafeUrl::parse(PYX_DEFAULT_API_URL).expect("default API URL should be valid");
    is_known_domain(url, &default_api, PYX_DEFAULT_CDN_DOMAIN)
}
/// Whether the host of `url` is `domain` itself or a subdomain of it.
///
/// The comparison is ASCII case-insensitive. DNS names are case-insensitive, and `domain` may
/// come from user configuration (e.g. `PYX_CDN_DOMAIN`) that is not normalized the way [`Url`]
/// normalizes hosts. A suffix match only counts on a label boundary, so `badexample.com` does
/// not match `example.com`.
///
/// Two cases never match: an empty `domain`, and URLs without a domain (such as IP-address
/// hosts).
fn matches_domain(url: &Url, domain: &str) -> bool {
    if domain.is_empty() {
        return false;
    }
    url.domain().is_some_and(|host| {
        if host.eq_ignore_ascii_case(domain) {
            return true;
        }
        // Split off a suffix of `domain.len()` bytes. `str::get` returns `None` if the split
        // point is not a char boundary, which also rules out a match.
        let Some(split) = host.len().checked_sub(domain.len()) else {
            return false;
        };
        let (Some(prefix), Some(suffix)) = (host.get(..split), host.get(split..)) else {
            return false;
        };
        // Require a `.` right before the suffix so that only whole labels match.
        prefix.ends_with('.') && suffix.eq_ignore_ascii_case(domain)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a URL literal used as test input.
    fn parse_url(input: &str) -> Url {
        Url::parse(input).unwrap()
    }

    /// API URL and CDN domain of the default pyx deployment, as exercised by these tests.
    fn default_settings() -> (DisplaySafeUrl, &'static str) {
        let api_url = DisplaySafeUrl::parse("https://api.pyx.dev").unwrap();
        (api_url, "astralhosted.com")
    }

    #[test]
    fn test_is_known_url() {
        let (api_url, cdn_domain) = default_settings();
        let cases = [
            // Same realm as the API.
            ("https://api.pyx.dev/simple/", true),
            ("https://api.pyx.dev/v1/", true),
            // The CDN domain and its subdomains, over HTTPS only.
            ("https://astralhosted.com/packages/", true),
            ("https://files.astralhosted.com/packages/", true),
            ("http://astralhosted.com/packages/", false),
            // Unrelated hosts, including one that merely shares a string suffix.
            ("https://pypi.org/simple/", false),
            ("https://badastralhosted.com/packages/", false),
        ];
        for (input, expected) in cases {
            assert_eq!(
                is_known_url(&parse_url(input), &api_url, cdn_domain),
                expected,
                "is_known_url({input})"
            );
        }
    }

    #[test]
    fn test_is_known_domain() {
        let (api_url, cdn_domain) = default_settings();
        let cases = [
            ("https://api.pyx.dev/simple/", true),
            // A parent domain of the API host.
            ("https://pyx.dev", true),
            // Subdomains of the API host and sibling hosts are not known.
            ("https://foo.api.pyx.dev", false),
            ("https://beta.pyx.dev/", false),
            ("https://astralhosted.com/packages/", true),
            ("https://files.astralhosted.com/packages/", true),
            ("https://pypi.org/simple/", false),
            ("https://pyx.com/", false),
        ];
        for (input, expected) in cases {
            assert_eq!(
                is_known_domain(&parse_url(input), &api_url, cdn_domain),
                expected,
                "is_known_domain({input})"
            );
        }
    }

    #[test]
    fn test_is_default_pyx_domain() {
        let cases = [
            ("https://pyx.dev", true),
            ("https://api.pyx.dev", true),
            ("https://astralhosted.com", true),
            ("https://files.astralhosted.com", true),
            ("http://localhost:8000", false),
            ("https://pypi.org", false),
            ("https://pyx.com", false),
        ];
        for (input, expected) in cases {
            assert_eq!(
                is_default_pyx_domain(&parse_url(input)),
                expected,
                "is_default_pyx_domain({input})"
            );
        }
    }

    #[test]
    fn test_matches_domain() {
        let cases = [
            ("https://example.com", "example.com", true),
            ("https://foo.example.com", "example.com", true),
            ("https://bar.foo.example.com", "example.com", true),
            ("https://example.com", "other.com", false),
            ("https://example.org", "example.com", false),
            ("https://badexample.com", "example.com", false),
        ];
        for (input, domain, expected) in cases {
            assert_eq!(
                matches_domain(&parse_url(input), domain),
                expected,
                "matches_domain({input}, {domain})"
            );
        }
    }

    #[test]
    fn test_is_default_pyx_domain_staging() {
        for input in [
            "https://astral-sh-staging-api.pyx.dev",
            "https://beta.pyx.dev",
        ] {
            assert!(!is_default_pyx_domain(&parse_url(input)), "{input}");
        }
    }
}