#[cfg(feature = "oauth")]
pub mod oauth;
use clap::Parser;
use rattler_networking::{
Authentication, AuthenticationStorage, authentication_storage::AuthenticationStorageError,
};
use reqwest::{Client, header::CONTENT_TYPE};
use serde_json::json;
use thiserror;
use url::Url;
/// `User-Agent` header value (`rattler/<crate version>`) used in this file when a prefix.dev
/// token is validated without an explicit `--user-agent`.
pub const DEFAULT_USER_AGENT: &str = concat!("rattler/", env!("CARGO_PKG_VERSION"));
/// Store authentication information for a given host
#[derive(Parser, Debug)]
struct LoginArgs {
    /// The host to authenticate with (e.g. prefix.dev, anaconda.org or s3://bucket)
    host: String,

    /// Bearer token to use (required for prefix.dev)
    #[clap(long, help_heading = "Token / Basic Authentication")]
    token: Option<String>,

    /// Username for basic HTTP authentication (use together with --password)
    #[clap(long, help_heading = "Token / Basic Authentication")]
    username: Option<String>,

    /// Password for basic HTTP authentication
    #[clap(long, help_heading = "Token / Basic Authentication")]
    password: Option<String>,

    /// Conda token to use (required for anaconda.org)
    #[clap(long, help_heading = "Token / Basic Authentication")]
    conda_token: Option<String>,

    /// S3 access key ID (requires --s3-secret-access-key)
    #[clap(long, requires_all = ["s3_secret_access_key"], conflicts_with_all = ["token", "username", "password", "conda_token"], help_heading = "S3 Authentication")]
    s3_access_key_id: Option<String>,

    /// S3 secret access key
    #[clap(long, requires_all = ["s3_access_key_id"], help_heading = "S3 Authentication")]
    s3_secret_access_key: Option<String>,

    /// Optional S3 session token
    #[clap(long, requires_all = ["s3_access_key_id"], help_heading = "S3 Authentication")]
    s3_session_token: Option<String>,

    /// Log in using OAuth/OIDC instead of a token or password
    #[cfg(feature = "oauth")]
    #[clap(long, conflicts_with_all = ["token", "username", "password", "conda_token", "s3_access_key_id"], help_heading = "OAuth/OIDC Authentication")]
    oauth: bool,

    /// OIDC issuer URL (defaults to the host URL)
    #[cfg(feature = "oauth")]
    #[clap(long, requires = "oauth", help_heading = "OAuth/OIDC Authentication")]
    oauth_issuer_url: Option<String>,

    /// OAuth client ID (defaults to `rattler`)
    #[cfg(feature = "oauth")]
    #[clap(long, requires = "oauth", help_heading = "OAuth/OIDC Authentication")]
    oauth_client_id: Option<String>,

    /// OAuth client secret
    #[cfg(feature = "oauth")]
    #[clap(long, requires = "oauth", help_heading = "OAuth/OIDC Authentication")]
    oauth_client_secret: Option<String>,

    /// OAuth flow to use (defaults to `auto`)
    #[cfg(feature = "oauth")]
    #[clap(long, requires = "oauth", value_parser = ["auto", "auth-code", "device-code"], help_heading = "OAuth/OIDC Authentication")]
    oauth_flow: Option<String>,

    /// OAuth scope to request; may be repeated and replaces the default scopes
    #[cfg(feature = "oauth")]
    #[clap(
        long = "oauth-scope",
        requires = "oauth",
        help_heading = "OAuth/OIDC Authentication"
    )]
    oauth_scopes: Vec<String>,

    /// Redirect URI to send to the OAuth provider
    #[cfg(feature = "oauth")]
    #[clap(long, requires = "oauth", help_heading = "OAuth/OIDC Authentication")]
    oauth_redirect_uri: Option<String>,

    /// Custom User-Agent header for HTTP requests made during login
    #[clap(long)]
    user_agent: Option<String>,
}
/// Remove authentication information for a given host
#[derive(Parser, Debug)]
struct LogoutArgs {
    /// The host to remove authentication for (e.g. prefix.dev)
    host: String,
}
// The `auth` subcommands.
#[derive(Parser, Debug)]
// `LoginArgs` has many more fields than `LogoutArgs`; the enum is only built once when the
// command line is parsed, so boxing the larger variant would not buy anything.
#[allow(clippy::large_enum_variant)]
enum Subcommand {
    /// Store authentication information for a given host
    Login(LoginArgs),
    /// Remove authentication information for a given host
    Logout(LogoutArgs),
}
/// Manage stored authentication credentials (login / logout)
#[derive(Parser, Debug)]
pub struct Args {
    #[clap(subcommand)]
    subcommand: Subcommand,
}
/// Errors that can occur while running the `login` and `logout` subcommands.
#[derive(thiserror::Error, Debug)]
pub enum AuthenticationCLIError {
    /// A host or URL could not be parsed.
    #[error("Failed to parse the URL")]
    ParseUrlError(#[from] url::ParseError),
    /// `--username` was given without `--password`.
    #[error("Password must be provided when using basic authentication")]
    MissingPassword,
    /// None of the credential flags were given.
    #[error("No authentication method provided")]
    NoAuthenticationMethod,
    /// A prefix.dev host was given credentials other than a bearer token.
    #[error("Authentication with prefix.dev requires a token. Use `--token` to provide one")]
    PrefixDevBadMethod,
    /// An anaconda.org host was given credentials other than a conda token.
    #[error(
        "Authentication with anaconda.org requires a conda token. Use `--conda-token` to provide one"
    )]
    AnacondaOrgBadMethod,
    /// S3 credentials were given for a non-`s3://` host, or an `s3://` host was given
    /// non-S3 credentials.
    #[error(
        "Authentication with S3 requires a S3 access key ID and a secret access key. Use `--s3-access-key-id` and `--s3-secret-access-key` to provide them"
    )]
    S3BadMethod,
    /// An `anyhow::Error` propagated with `?`.
    /// NOTE(review): the message assumes it comes from the storage system — confirm that is
    /// the only source.
    #[error("Failed to interact with the authentication storage system")]
    AnyhowError(#[from] anyhow::Error),
    /// The credential storage failed while reading, storing or deleting credentials.
    #[error("Failed to interact with the authentication storage system")]
    AuthenticationStorageError(#[from] AuthenticationStorageError),
    /// Building the HTTP client or performing an HTTP request failed.
    #[error("General http request error")]
    ReqwestError(#[from] reqwest::Error),
    /// A response body could not be parsed as JSON; holds the parser's message.
    #[error("Failed to parse JSON: {0}")]
    JsonParseError(String),
    /// The prefix.dev API did not accept the token.
    #[error("Unauthorized or invalid token")]
    UnauthorizedToken,
    /// The OAuth/OIDC login failed.
    #[cfg(feature = "oauth")]
    #[error(transparent)]
    OAuthError(#[from] oauth::OAuthError),
}
fn normalize_login_host(host: &str) -> String {
let host = host.trim_start_matches("*.");
if let Some(h) = url::Url::parse(host)
.ok()
.and_then(|u| u.host_str().map(str::to_string))
{
return h;
}
url::Url::parse(&format!("https://{host}"))
.ok()
.and_then(|u| u.host_str().map(str::to_string))
.unwrap_or_else(|| host.trim_end_matches('/').to_string())
}
/// Scopes requested by default when logging in to prefix.dev (or a subdomain) via OAuth
/// without `--oauth-scope`.
#[cfg(feature = "oauth")]
const PREFIX_DEV_OAUTH_SCOPES: &[&str] = &[
    "openid",
    "profile",
    "offline_access",
    "channel:read",
    "channel:upload",
];
/// Built-in OAuth/OIDC settings for a known host, used for every value that was not given
/// through the corresponding `--oauth-*` flag.
#[cfg(feature = "oauth")]
struct DefaultOAuthConfig {
    /// OIDC issuer URL.
    issuer_url: String,
    /// OAuth client ID.
    client_id: String,
    /// Scopes to request when no `--oauth-scope` is given.
    scopes: Vec<String>,
    /// Redirect URI to use when no `--oauth-redirect-uri` is given; `None` if there is no default.
    redirect_uri: Option<String>,
}
/// Returns the built-in OAuth configuration for `host`, or `None` if it has none.
///
/// Only prefix.dev and its subdomains have a default. `host` may be a bare host, a `*.`
/// wildcard or a full URL; it is reduced with [`normalize_login_host`] before matching.
#[cfg(feature = "oauth")]
fn default_oauth_config_for_host(host: &str) -> Option<DefaultOAuthConfig> {
    let normalized = normalize_login_host(host);
    // Compare whole DNS labels so look-alikes such as `notprefix.dev` do not match.
    if !(normalized == "prefix.dev" || normalized.ends_with(".prefix.dev")) {
        return None;
    }
    Some(DefaultOAuthConfig {
        // Derive the issuer from the normalized host rather than the raw input. The raw input
        // may be a wildcard (`*.prefix.dev`, which is not a reachable issuer), carry a trailing
        // slash, or use a plain `http://` scheme.
        issuer_url: format!("https://{normalized}"),
        client_id: "rattler".to_string(),
        scopes: PREFIX_DEV_OAUTH_SCOPES
            .iter()
            .map(|&s| s.to_string())
            .collect(),
        redirect_uri: None,
    })
}
/// Picks the built-in OAuth configuration for an implicit OAuth login.
///
/// Returns `None` if any explicit credential flag (`--token`, `--username`, `--password`,
/// `--conda-token`, `--s3-access-key-id`) was given. Otherwise defers to
/// [`default_oauth_config_for_host`] for `args.host`.
#[cfg(feature = "oauth")]
fn default_oauth_for_login(args: &LoginArgs) -> Option<DefaultOAuthConfig> {
    let explicit_flags = [
        args.token.is_some(),
        args.username.is_some(),
        args.password.is_some(),
        args.conda_token.is_some(),
        args.s3_access_key_id.is_some(),
    ];
    if explicit_flags.iter().any(|&given| given) {
        None
    } else {
        default_oauth_config_for_host(&args.host)
    }
}
/// Turns the host given on the command line into the storage key used by `login`/`logout`.
///
/// If `url` contains `://`, it is parsed as a URL and only its host name is kept (scheme, port
/// and path are dropped). Otherwise it is used verbatim. A host with exactly one `.` (e.g.
/// `prefix.dev`) becomes the wildcard key `*.prefix.dev`.
///
/// # Errors
/// Returns [`AuthenticationCLIError::ParseUrlError`] if `url` contains `://` but is not a valid
/// URL, or parses but has no host (e.g. `file:///tmp`). The latter previously panicked.
fn get_url(url: &str) -> Result<String, AuthenticationCLIError> {
    let host = if url.contains("://") {
        url::Url::parse(url)?
            .host_str()
            .ok_or(url::ParseError::EmptyHost)?
            .to_string()
    } else {
        url.to_string()
    };
    // Two-label hosts are stored under a `*.` wildcard key.
    let host = if host.matches('.').count() == 1 {
        format!("*.{host}")
    } else {
        host
    };
    Ok(host)
}
/// Returns `host` as a URL string, prefixing `https://` when it has no `://` separator.
///
/// Inputs that already contain `://` (any scheme, e.g. `http://` or `s3://`) are returned
/// unchanged.
fn ensure_url_scheme(host: &str) -> String {
    if !host.contains("://") {
        return format!("https://{host}");
    }
    host.to_owned()
}
/// Outcome of validating a prefix.dev token.
#[derive(Debug, PartialEq)]
pub enum ValidationResult {
    /// The token was accepted; holds the account's login name and the API base URL queried.
    Valid(String, Url),
    /// The token was rejected (no logged-in user could be determined for it).
    Invalid,
}
async fn login(
args: LoginArgs,
storage: AuthenticationStorage,
) -> Result<(), AuthenticationCLIError> {
#[cfg(feature = "oauth")]
{
let auto_default = default_oauth_for_login(&args);
if args.oauth || auto_default.is_some() {
if !args.oauth {
eprintln!(
"No credentials provided; using OAuth browser login for {}.",
args.host
);
}
let host_default = auto_default.or_else(|| default_oauth_config_for_host(&args.host));
let issuer_url = args
.oauth_issuer_url
.or_else(|| host_default.as_ref().map(|c| c.issuer_url.clone()))
.unwrap_or_else(|| ensure_url_scheme(&args.host));
let client_id = args
.oauth_client_id
.or_else(|| host_default.as_ref().map(|c| c.client_id.clone()))
.unwrap_or_else(|| "rattler".to_string());
let flow = match args.oauth_flow.as_deref() {
Some("auth-code") => oauth::OAuthFlow::AuthCode,
Some("device-code") => oauth::OAuthFlow::DeviceCode,
_ => oauth::OAuthFlow::Auto,
};
let redirect_uri = args
.oauth_redirect_uri
.or_else(|| host_default.as_ref().and_then(|c| c.redirect_uri.clone()));
let scopes: std::collections::HashSet<String> = if !args.oauth_scopes.is_empty() {
args.oauth_scopes.into_iter().collect()
} else if let Some(default) = host_default {
default.scopes.into_iter().collect()
} else {
oauth::DEFAULT_OAUTH_SCOPES
.iter()
.map(|&s| s.to_string())
.collect()
};
let config = oauth::OAuthConfig {
issuer_url,
client_id,
client_secret: args.oauth_client_secret,
flow,
scopes,
redirect_uri,
user_agent: args.user_agent,
};
let auth = oauth::perform_oauth_login(config).await?;
let host = normalize_login_host(&args.host);
storage.store(&host, &auth)?;
eprintln!("Credentials stored for {host}.");
return Ok(());
}
}
let auth = if let Some(conda_token) = args.conda_token {
Authentication::CondaToken(conda_token)
} else if let Some(username) = args.username {
if let Some(password) = args.password {
Authentication::BasicHTTP { username, password }
} else {
return Err(AuthenticationCLIError::MissingPassword);
}
} else if let Some(token) = args.token {
Authentication::BearerToken(token)
} else if let (Some(access_key_id), Some(secret_access_key)) =
(args.s3_access_key_id, args.s3_secret_access_key)
{
let session_token = args.s3_session_token;
Authentication::S3Credentials {
access_key_id,
secret_access_key,
session_token,
}
} else {
return Err(AuthenticationCLIError::NoAuthenticationMethod);
};
if args.host.contains("prefix.dev") && !matches!(auth, Authentication::BearerToken(_)) {
return Err(AuthenticationCLIError::PrefixDevBadMethod);
}
if args.host.contains("anaconda.org") && !matches!(auth, Authentication::CondaToken(_)) {
return Err(AuthenticationCLIError::AnacondaOrgBadMethod);
}
if args.host.contains("s3://") && !matches!(auth, Authentication::S3Credentials { .. })
|| matches!(auth, Authentication::S3Credentials { .. }) && !args.host.contains("s3://")
{
return Err(AuthenticationCLIError::S3BadMethod);
}
let host = get_url(&args.host)?;
eprintln!("Authenticating with {host} using {} method", auth.method());
if args.host.contains("prefix.dev") {
let token = match &auth {
Authentication::BearerToken(t) => t,
_ => return Err(AuthenticationCLIError::PrefixDevBadMethod),
};
match validate_prefix_dev_token(token, &args.host, args.user_agent.as_deref()).await? {
ValidationResult::Valid(username, url) => {
println!(
"✅ Token is valid. Logged into {url} as \"{username}\". Storing credentials..."
);
storage.store(&host, &auth)?;
}
ValidationResult::Invalid => {
return Err(AuthenticationCLIError::UnauthorizedToken);
}
}
} else {
storage.store(&host, &auth)?;
}
Ok(())
}
async fn validate_prefix_dev_token(
token: &str,
host: &str,
user_agent: Option<&str>,
) -> Result<ValidationResult, AuthenticationCLIError> {
let prefix_url = if let Ok(env_var) = std::env::var("PREFIX_DEV_API_URL") {
Url::parse(&env_var).expect("PREFIX_DEV_API_URL must be a valid URL")
} else {
let host = host.replace("*.", "");
let host_url = Url::parse(&ensure_url_scheme(&host))?;
let host_url = host_url.host_str().unwrap_or("prefix.dev");
let host_url = host_url.strip_prefix("repo.").unwrap_or(host_url);
Url::parse(&format!("https://{host_url}")).expect("constructed url must be valid")
};
let body = json!({
"query": "query { viewer { login } }"
});
let client = Client::builder()
.user_agent(user_agent.unwrap_or(DEFAULT_USER_AGENT))
.build()?;
let response = client
.post(prefix_url.join("api/graphql").expect("must be valid"))
.bearer_auth(token)
.header(CONTENT_TYPE, "application/json")
.json(&body)
.send()
.await?
.error_for_status()?;
let text = response.text().await?;
let json: serde_json::Value = serde_json::from_str(&text)
.map_err(|e| AuthenticationCLIError::JsonParseError(e.to_string()))?;
match &json["data"]["viewer"] {
serde_json::Value::Null => Ok(ValidationResult::Invalid),
viewer_data => {
if let Some(username) = viewer_data["login"].as_str() {
Ok(ValidationResult::Valid(username.to_string(), prefix_url))
} else {
Ok(ValidationResult::Invalid)
}
}
}
}
/// Picks the storage key that `logout` should operate on.
///
/// OAuth logins store credentials under [`normalize_login_host`] (e.g. `prefix.dev`), whereas
/// `logout` derives its key with [`get_url`] (e.g. `*.prefix.dev`), so the two differ for
/// two-label hosts. When nothing is stored under `key` but credentials exist under the
/// normalized host, the normalized host is returned so those credentials are revoked and
/// removed. In every other case, including storage lookup errors, `key` is returned unchanged.
#[cfg(feature = "oauth")]
fn logout_storage_key(raw_host: &str, key: String, storage: &AuthenticationStorage) -> String {
    let normalized = normalize_login_host(raw_host);
    if normalized != key
        && matches!(storage.get(&key), Ok(None))
        && matches!(storage.get(&normalized), Ok(Some(_)))
    {
        normalized
    } else {
        key
    }
}

/// Removes the stored credentials for `args.host`.
///
/// With the `oauth` feature, stored OAuth credentials that carry a revocation endpoint are
/// revoked first via `oauth::revoke_tokens`. Its outcome is not checked here, so revocation is
/// best effort.
///
/// # Errors
/// Fails if the host cannot be turned into a storage key, or the storage fails to delete the
/// entry.
async fn logout(
    args: LogoutArgs,
    storage: AuthenticationStorage,
) -> Result<(), AuthenticationCLIError> {
    let host = get_url(&args.host)?;
    #[cfg(feature = "oauth")]
    let host = logout_storage_key(&args.host, host, &storage);

    #[cfg(feature = "oauth")]
    if let Ok(Some(Authentication::OAuth {
        ref access_token,
        ref refresh_token,
        revocation_endpoint: Some(ref revocation_endpoint),
        ref client_id,
        ..
    })) = storage.get(&host)
    {
        eprintln!("Revoking OAuth tokens...");
        oauth::revoke_tokens(
            revocation_endpoint,
            access_token,
            refresh_token.as_deref(),
            client_id,
            None,
        )
        .await;
    }

    println!("Removing authentication for {host}");
    storage.delete(&host)?;
    Ok(())
}
pub async fn execute(args: Args) -> Result<(), AuthenticationCLIError> {
let storage = AuthenticationStorage::from_env_and_defaults()?;
match args.subcommand {
Subcommand::Login(args) => login(args, storage).await,
Subcommand::Logout(args) => logout(args, storage).await,
}
}
#[cfg(test)]
mod tests {
    use mockito::Server;
    use rattler_networking::{
        AuthenticationStorage, authentication_storage::backends::memory::MemoryStorage,
    };
    use serde_json::json;
    use temp_env::async_with_vars;
    use tempfile::TempDir;

    use super::*;

    /// Creates a storage backed only by an in-memory backend, so tests never touch real
    /// credentials.
    ///
    /// NOTE(review): the returned `TempDir` is not used by the in-memory backend and could
    /// probably be dropped from the signature.
    fn create_test_storage() -> (AuthenticationStorage, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let mut storage = AuthenticationStorage::empty();
        storage.add_backend(std::sync::Arc::new(MemoryStorage::new()));
        (storage, temp_dir)
    }

    /// Builds `LoginArgs` for `host` with every credential and option unset.
    fn create_login_args(host: &str) -> LoginArgs {
        LoginArgs {
            host: host.to_string(),
            token: None,
            username: None,
            password: None,
            conda_token: None,
            s3_access_key_id: None,
            s3_secret_access_key: None,
            s3_session_token: None,
            #[cfg(feature = "oauth")]
            oauth: false,
            #[cfg(feature = "oauth")]
            oauth_issuer_url: None,
            #[cfg(feature = "oauth")]
            oauth_client_id: None,
            #[cfg(feature = "oauth")]
            oauth_client_secret: None,
            #[cfg(feature = "oauth")]
            oauth_flow: None,
            #[cfg(feature = "oauth")]
            oauth_scopes: vec![],
            #[cfg(feature = "oauth")]
            oauth_redirect_uri: None,
            user_agent: None,
        }
    }

    #[tokio::test]
    async fn test_login_with_token_success() {
        let (storage, _temp_dir) = create_test_storage();
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/api/graphql")
            // Only answer requests that carry the expected bearer token, so the test verifies
            // the token is actually sent (`with_header` would merely add it to the response).
            .match_header("authorization", "Bearer valid_token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "data": {
                        "viewer": {
                            "login": "testuser"
                        }
                    }
                })
                .to_string(),
            )
            .expect(1)
            .create();
        let mut args = create_login_args("prefix.dev");
        args.token = Some("valid_token".to_string());
        let result = async_with_vars(
            [("PREFIX_DEV_API_URL", Some(server.url().as_str()))],
            async { login(args, storage).await },
        )
        .await;
        assert!(result.is_ok());
        mock.assert();
    }

    #[tokio::test]
    async fn test_login_with_invalid_token() {
        let (storage, _temp_dir) = create_test_storage();
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/api/graphql")
            .match_header("authorization", "Bearer invalid_token")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "data": {
                        "viewer": null
                    }
                })
                .to_string(),
            )
            .expect(1)
            .create();
        let mut args = create_login_args("prefix.dev");
        args.token = Some("invalid_token".to_string());
        let result = async_with_vars(
            [("PREFIX_DEV_API_URL", Some(server.url().as_str()))],
            async { login(args, storage).await },
        )
        .await;
        assert!(matches!(
            result,
            Err(AuthenticationCLIError::UnauthorizedToken)
        ));
        mock.assert();
    }

    #[tokio::test]
    async fn test_login_missing_password_for_basic_auth() {
        let (storage, _temp_dir) = create_test_storage();
        let mut args = create_login_args("example.com");
        args.username = Some("testuser".to_string());
        let result = login(args, storage).await;
        assert!(matches!(
            result,
            Err(AuthenticationCLIError::MissingPassword)
        ));
    }

    #[tokio::test]
    async fn test_login_basic_auth_success() {
        let (storage, _temp_dir) = create_test_storage();
        let mut args = create_login_args("example.com");
        args.username = Some("testuser".to_string());
        args.password = Some("testpass".to_string());
        let result = login(args, storage).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_login_conda_token_success() {
        let (storage, _temp_dir) = create_test_storage();
        let mut args = create_login_args("anaconda.org");
        args.conda_token = Some("conda_token_123".to_string());
        let result = login(args, storage).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_login_s3_credentials_success() {
        let (storage, _temp_dir) = create_test_storage();
        let mut args = create_login_args("s3://my-bucket");
        args.s3_access_key_id = Some("access_key".to_string());
        args.s3_secret_access_key = Some("secret_key".to_string());
        args.s3_session_token = Some("session_token".to_string());
        let result = login(args, storage).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_login_no_authentication_method() {
        let (storage, _temp_dir) = create_test_storage();
        let args = create_login_args("example.com");
        let result = login(args, storage).await;
        assert!(matches!(
            result,
            Err(AuthenticationCLIError::NoAuthenticationMethod)
        ));
    }

    #[tokio::test]
    async fn test_login_prefix_dev_requires_token() {
        let (storage, _temp_dir) = create_test_storage();
        let mut args = create_login_args("prefix.dev");
        args.username = Some("testuser".to_string());
        args.password = Some("testpass".to_string());
        let result = login(args, storage).await;
        assert!(matches!(
            result,
            Err(AuthenticationCLIError::PrefixDevBadMethod)
        ));
    }

    #[tokio::test]
    async fn test_login_anaconda_org_requires_conda_token() {
        let (storage, _temp_dir) = create_test_storage();
        let mut args = create_login_args("anaconda.org");
        args.token = Some("bearer_token".to_string());
        let result = login(args, storage).await;
        assert!(matches!(
            result,
            Err(AuthenticationCLIError::AnacondaOrgBadMethod)
        ));
    }

    #[tokio::test]
    async fn test_login_s3_requires_proper_credentials() {
        let (storage, _temp_dir) = create_test_storage();
        let mut args = create_login_args("s3://my-bucket");
        args.token = Some("bearer_token".to_string());
        let result = login(args, storage).await;
        assert!(matches!(result, Err(AuthenticationCLIError::S3BadMethod)));
    }

    #[tokio::test]
    async fn test_login_s3_credentials_with_non_s3_host() {
        let (storage, _temp_dir) = create_test_storage();
        let mut args = create_login_args("example.com");
        args.s3_access_key_id = Some("access_key".to_string());
        args.s3_secret_access_key = Some("secret_key".to_string());
        let result = login(args, storage).await;
        assert!(matches!(result, Err(AuthenticationCLIError::S3BadMethod)));
    }

    #[test]
    fn get_url_wildcards_two_label_hosts() {
        assert_eq!(get_url("prefix.dev").unwrap(), "*.prefix.dev");
        assert_eq!(
            get_url("https://repo.prefix.dev/channel").unwrap(),
            "repo.prefix.dev"
        );
        assert_eq!(get_url("localhost").unwrap(), "localhost");
    }

    #[test]
    fn normalize_login_host_extracts_bare_host() {
        assert_eq!(normalize_login_host("prefix.dev"), "prefix.dev");
        assert_eq!(normalize_login_host("*.prefix.dev"), "prefix.dev");
        assert_eq!(normalize_login_host("https://prefix.dev/"), "prefix.dev");
        assert_eq!(normalize_login_host("localhost:8080"), "localhost");
    }

    #[test]
    fn ensure_url_scheme_prepends_https_for_bare_host() {
        assert_eq!(ensure_url_scheme("prefix.dev"), "https://prefix.dev");
    }

    #[test]
    fn ensure_url_scheme_keeps_existing_https_scheme() {
        assert_eq!(
            ensure_url_scheme("https://prefix.dev"),
            "https://prefix.dev"
        );
    }

    #[test]
    fn ensure_url_scheme_keeps_existing_http_scheme() {
        assert_eq!(
            ensure_url_scheme("http://localhost:4444"),
            "http://localhost:4444"
        );
    }

    #[cfg(feature = "oauth")]
    #[test]
    fn test_default_oauth_config_for_host() {
        let has_default = |h: &str| default_oauth_config_for_host(h).is_some();
        assert!(has_default("prefix.dev"));
        assert!(has_default("repo.prefix.dev"));
        assert!(has_default("https://prefix.dev"));
        assert!(has_default("*.prefix.dev"));
        assert!(has_default("prefix.dev/"));
        assert!(has_default("https://prefix.dev/"));
        assert!(has_default("https://repo.prefix.dev/"));
        assert!(!has_default("localhost"));
        assert!(!has_default("localhost:8080"));
        assert!(!has_default("127.0.0.1"));
        assert!(!has_default("example.com"));
        assert!(!has_default("evil-prefix.dev.attacker.com"));
        assert!(!has_default("notprefix.dev"));
        let prefix = default_oauth_config_for_host("prefix.dev").unwrap();
        assert_eq!(prefix.issuer_url, "https://prefix.dev");
        assert_eq!(prefix.client_id, "rattler");
        assert!(prefix.scopes.iter().any(|s| s == "channel:upload"));
    }

    #[cfg(feature = "oauth")]
    #[test]
    fn test_default_oauth_for_login() {
        assert!(default_oauth_for_login(&create_login_args("prefix.dev")).is_some());
        let mut args = create_login_args("prefix.dev");
        args.token = Some("t".into());
        assert!(default_oauth_for_login(&args).is_none());
        assert!(default_oauth_for_login(&create_login_args("example.com")).is_none());
    }
}