use base64::Engine;
use base64::prelude::BASE64_URL_SAFE_NO_PAD;
use reqwest::StatusCode;
use reqwest_middleware::ClientWithMiddleware;
use serde::{Deserialize, Serialize};
use std::env;
use std::fmt::Display;
use thiserror::Error;
use tracing::{debug, trace};
use url::Url;
use uv_redacted::{DisplaySafeUrl, DisplaySafeUrlError};
use uv_static::EnvVars;
/// Errors that can occur while obtaining an upload token through trusted publishing.
#[derive(Debug, Error)]
pub enum TrustedPublishingError {
    /// An OIDC endpoint URL derived from the registry URL failed to parse.
    #[error(transparent)]
    Url(#[from] DisplaySafeUrlError),
    /// GitHub Actions reported insufficient permissions when asked for an OIDC token.
    #[error("Failed to obtain OIDC token: is the `id-token: write` permission missing?")]
    GitHubPermissions(#[source] ambient_id::Error),
    /// Any other failure while detecting an ambient OIDC token.
    #[error("Failed to discover OIDC token")]
    Discovery(#[source] ambient_id::Error),
    /// No OIDC token is available in the current environment.
    ///
    /// NOTE(review): not constructed in this section. `get_token` reports a missing token as
    /// `Ok(None)`, so presumably a caller converts that into this variant. Confirm.
    #[error("No OIDC token discovered: are you in a supported trusted publishing environment?")]
    NoToken,
    /// Checking or reading the response from the given URL failed: a non-success status, a
    /// body read error, or a JSON decode error of the body.
    #[error("Failed to fetch: `{0}`")]
    Reqwest(DisplaySafeUrl, #[source] reqwest::Error),
    /// Sending the request to the given URL through the middleware client failed.
    #[error("Failed to fetch: `{0}`")]
    ReqwestMiddleware(DisplaySafeUrl, #[source] reqwest_middleware::Error),
    /// JSON serialization of the mint-token request, or deserialization of its successful
    /// response, failed.
    #[error(transparent)]
    SerdeJson(#[from] serde_json::error::Error),
    /// The mint-token endpoint rejected the exchange.
    ///
    /// Carries the HTTP status and response body. It also carries the claims decoded from the
    /// OIDC token, pretty-printed so the user can compare them with the publisher configured
    /// on the registry.
    #[error(
        "PyPI returned error code {0}, is trusted publishing correctly configured?\nResponse: {1}\nToken claims, which must match the PyPI configuration: {2:#?}"
    )]
    Pypi(StatusCode, String, OidcTokenClaims),
    /// The mint-token endpoint rejected the exchange and the OIDC token's claims could not be
    /// decoded. Carries the HTTP status and the response body.
    // NOTE(review): the message reads "the OIDC has". It was probably meant to be "the OIDC
    // token has". It is left unchanged because it is user-facing output.
    #[error("PyPI returned error code {0}, and the OIDC has an unexpected format.\nResponse: {1}")]
    InvalidOidcToken(StatusCode, String),
}
/// A short-lived upload token minted by the registry in exchange for an OIDC token.
///
/// Deserialized transparently from a bare JSON string. `Display` writes the raw secret
/// unchanged, so formatting a value exposes the credential. Avoid logging it.
#[derive(Deserialize)]
#[serde(transparent)]
pub struct TrustedPublishingToken(String);

impl Display for TrustedPublishingToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(secret) = self;
        // Emit the token verbatim. Any width/fill requested by the caller is ignored.
        f.write_str(secret)
    }
}
/// JSON response body of the registry's `/_/oidc/audience` endpoint.
#[derive(Deserialize)]
struct Audience {
    /// The audience value to request the OIDC token for.
    audience: String,
}
/// JSON request body POSTed to the registry's `/_/oidc/mint-token` endpoint.
#[derive(Serialize)]
struct MintTokenRequest {
    /// The raw (revealed) OIDC token being exchanged for an upload token.
    token: String,
}
/// JSON body of a successful response from the registry's `/_/oidc/mint-token` endpoint.
#[derive(Deserialize)]
struct PublishToken {
    /// The minted upload token.
    token: TrustedPublishingToken,
}
/// Subset of the OIDC token's JWT payload, shown to the user when the registry rejects the
/// token exchange.
///
/// NOTE(review): the claim names appear to be those of GitHub Actions OIDC tokens. Confirm.
/// No field is optional, so a token whose payload lacks any of them fails to deserialize.
/// `decode_oidc_token` then yields `None`.
#[derive(Deserialize, Debug)]
// The fields are never read directly. They are only displayed through the derived `Debug`
// impl, which the `dead_code` lint does not count as a use.
#[allow(dead_code)]
pub struct OidcTokenClaims {
    /// The `sub` (subject) claim.
    sub: String,
    repository: String,
    repository_owner: String,
    repository_owner_id: String,
    job_workflow_ref: String,
    /// The `ref` claim. It is spelled as a raw identifier because `ref` is a Rust keyword.
    r#ref: String,
}
/// Obtain a registry upload token through trusted publishing.
///
/// The steps are:
/// 1. Fetch the registry's OIDC audience.
/// 2. Ask the ambient environment for an OIDC token for that audience.
/// 3. Exchange that token at the registry for an upload token.
///
/// Returns `Ok(None)` when no OIDC token was discovered.
///
/// When the `GITHUB_ACTIONS` environment variable is exactly `true`, the minted token is
/// printed as an `::add-mask::` workflow command on stdout so the runner redacts it from logs.
///
/// # Errors
///
/// Propagates any [`TrustedPublishingError`] from audience discovery, OIDC token detection,
/// or the token exchange.
pub(crate) async fn get_token(
    registry: &DisplaySafeUrl,
    client: &ClientWithMiddleware,
) -> Result<Option<TrustedPublishingToken>, TrustedPublishingError> {
    let audience = get_audience(registry, client).await?;
    let Some(oidc_token) = get_oidc_token(&audience, client).await? else {
        return Ok(None);
    };
    let minted = get_publish_token(registry, oidc_token, client).await?;

    // Register the fresh secret with GitHub Actions so it never shows up in job output.
    #[allow(clippy::print_stdout)]
    if matches!(env::var(EnvVars::GITHUB_ACTIONS).as_deref(), Ok("true")) {
        println!("::add-mask::{minted}");
    }
    Ok(Some(minted))
}
async fn get_audience(
registry: &DisplaySafeUrl,
client: &ClientWithMiddleware,
) -> Result<String, TrustedPublishingError> {
let scheme: &str = if cfg!(feature = "test") {
registry.scheme()
} else {
"https"
};
let audience_url = DisplaySafeUrl::parse(&format!(
"{}://{}/_/oidc/audience",
scheme,
registry.authority()
))?;
debug!("Querying the trusted publishing audience from {audience_url}");
let response = client
.get(Url::from(audience_url.clone()))
.send()
.await
.map_err(|err| TrustedPublishingError::ReqwestMiddleware(audience_url.clone(), err))?;
let audience = response
.error_for_status()
.map_err(|err| TrustedPublishingError::Reqwest(audience_url.clone(), err))?
.json::<Audience>()
.await
.map_err(|err| TrustedPublishingError::Reqwest(audience_url.clone(), err))?;
trace!("The audience is `{}`", &audience.audience);
Ok(audience.audience)
}
async fn get_oidc_token(
audience: &str,
client: &ClientWithMiddleware,
) -> Result<Option<ambient_id::IdToken>, TrustedPublishingError> {
let detector = ambient_id::Detector::new_with_client(client.clone());
match detector.detect(audience).await {
Ok(token) => Ok(token),
Err(
err @ ambient_id::Error::GitHubActions(
ambient_id::GitHubError::InsufficientPermissions(_),
),
) => Err(TrustedPublishingError::GitHubPermissions(err)),
Err(err) => Err(TrustedPublishingError::Discovery(err)),
}
}
/// Decode the claims from an OIDC token's JWT payload, without verifying the signature.
///
/// The token is split on its first two `.` separators into header, payload and signature.
/// The payload is decoded as unpadded URL-safe base64 and parsed as JSON.
///
/// Returns `None` in any of these cases:
/// - the token does not have all three segments;
/// - the payload is not valid unpadded URL-safe base64;
/// - the JSON does not match [`OidcTokenClaims`].
fn decode_oidc_token(oidc_token: &str) -> Option<OidcTokenClaims> {
    // At most three pieces: anything after the second `.` stays in the signature segment.
    let mut pieces = oidc_token.splitn(3, '.');
    let (Some(_header), Some(payload), Some(_signature)) =
        (pieces.next(), pieces.next(), pieces.next())
    else {
        return None;
    };
    let json = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
    serde_json::from_slice::<OidcTokenClaims>(&json).ok()
}
async fn get_publish_token(
registry: &DisplaySafeUrl,
oidc_token: ambient_id::IdToken,
client: &ClientWithMiddleware,
) -> Result<TrustedPublishingToken, TrustedPublishingError> {
let scheme: &str = if cfg!(feature = "test") {
registry.scheme()
} else {
"https"
};
let mint_token_url = DisplaySafeUrl::parse(&format!(
"{}://{}/_/oidc/mint-token",
scheme,
registry.authority()
))?;
debug!("Querying the trusted publishing upload token from {mint_token_url}");
let mint_token_payload = MintTokenRequest {
token: oidc_token.reveal().to_string(),
};
let response = client
.post(Url::from(mint_token_url.clone()))
.body(serde_json::to_vec(&mint_token_payload)?)
.send()
.await
.map_err(|err| TrustedPublishingError::ReqwestMiddleware(mint_token_url.clone(), err))?;
let status = response.status();
let body = response
.bytes()
.await
.map_err(|err| TrustedPublishingError::Reqwest(mint_token_url.clone(), err))?;
if status.is_success() {
let publish_token: PublishToken = serde_json::from_slice(&body)?;
Ok(publish_token.token)
} else {
match decode_oidc_token(oidc_token.reveal()) {
Some(claims) => {
Err(TrustedPublishingError::Pypi(
status,
String::from_utf8_lossy(&body).to_string(),
claims,
))
}
None => {
Err(TrustedPublishingError::InvalidOidcToken(
status,
String::from_utf8_lossy(&body).to_string(),
))
}
}
}
}