use super::DockerConfigAuth;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Where registry credentials come from.
///
/// Serialized as an internally tagged object whose `type` field is the
/// snake_case variant name, e.g. `{"type":"basic","username":"u","password":"p"}`
/// or `{"type":"docker_config"}`.
// NOTE(review): `Debug` is derived, so `{:?}` on a `Basic` value prints the
// password in clear text — confirm these values are never logged.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthSource {
/// No credentials. This is the `Default` variant.
#[default]
Anonymous,
/// Fixed username/password stored directly in the configuration.
Basic { username: String, password: String },
/// Credentials for the registry host looked up in the Docker config file
/// loaded by `AuthResolver::new`; anonymous if that file could not be loaded
/// or has no entry for the registry.
DockerConfig,
/// Username and password read from the named environment variables every
/// time a source is resolved; anonymous unless both hold non-empty Unicode values.
EnvVar {
/// Name of the environment variable holding the username.
username_var: String,
/// Name of the environment variable holding the password.
password_var: String,
},
/// Credential kept in an external secret store, identified by `credential_id`.
/// The synchronous `AuthResolver::resolve_source` cannot fetch it: it logs a
/// warning and returns anonymous auth.
SecretStore { credential_id: String },
}
/// Credential source to use for one specific registry host.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RegistryAuthConfig {
/// Registry host in the form `AuthResolver` extracts from image references,
/// e.g. `ghcr.io`, `localhost:5000`, or `docker.io` for Docker Hub images.
/// Compared by exact (case-sensitive) string equality.
pub registry: String,
/// How to obtain credentials for `registry`.
pub source: AuthSource,
}
/// Top-level authentication configuration consumed by `AuthResolver`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AuthConfig {
/// Per-registry overrides. When a registry is listed more than once,
/// `AuthResolver::new` keeps the last entry. Empty when omitted from input.
#[serde(default)]
pub registries: Vec<RegistryAuthConfig>,
/// Source used for every registry without an override.
// NOTE(review): when this field is omitted during deserialization serde uses
// `AuthSource::default()` (`Anonymous`), whereas `AuthConfig::default()` sets
// `DockerConfig`. Confirm the two defaults are meant to differ.
#[serde(default)]
pub default: AuthSource,
/// Explicit Docker config file to read. `None` makes `AuthResolver` call
/// `DockerConfigAuth::load()` instead (presumably the standard location — verify).
pub docker_config_path: Option<PathBuf>,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
registries: Vec::new(),
default: AuthSource::DockerConfig,
docker_config_path: None,
}
}
}
/// Maps image references to registry credentials according to an `AuthConfig`.
pub struct AuthResolver {
/// Configuration the resolver was built from; its `default` is the fallback source.
config: AuthConfig,
/// Docker config contents, loaded once in `new` and only when some source is
/// `AuthSource::DockerConfig`; `None` if not needed or if loading failed.
docker_config: Option<DockerConfigAuth>,
/// Registry host -> source, built from `config.registries` (later entries win).
registry_map: HashMap<String, AuthSource>,
}
impl AuthResolver {
#[must_use]
pub fn new(config: AuthConfig) -> Self {
let registry_map: HashMap<String, AuthSource> = config
.registries
.iter()
.map(|r| (r.registry.clone(), r.source.clone()))
.collect();
let needs_docker_config = config.default == AuthSource::DockerConfig
|| registry_map
.values()
.any(|s| matches!(s, AuthSource::DockerConfig));
let docker_config = if needs_docker_config {
Self::load_docker_config(config.docker_config_path.as_ref())
} else {
None
};
Self {
config,
docker_config,
registry_map,
}
}
#[must_use]
pub fn resolve(&self, image: &str) -> oci_client::secrets::RegistryAuth {
let registry = Self::extract_registry(image);
let source = self
.registry_map
.get(®istry)
.unwrap_or(&self.config.default);
self.resolve_source(source, ®istry)
}
#[must_use]
pub fn source_for_registry(&self, registry: &str) -> &AuthSource {
self.registry_map
.get(registry)
.unwrap_or(&self.config.default)
}
pub fn resolve_source(
&self,
source: &AuthSource,
registry: &str,
) -> oci_client::secrets::RegistryAuth {
match source {
AuthSource::Anonymous => oci_client::secrets::RegistryAuth::Anonymous,
AuthSource::Basic { username, password } => {
oci_client::secrets::RegistryAuth::Basic(username.clone(), password.clone())
}
AuthSource::DockerConfig => {
if let Some(ref docker_config) = self.docker_config {
if let Some((username, password)) = docker_config.get_credentials(registry) {
return oci_client::secrets::RegistryAuth::Basic(username, password);
}
}
oci_client::secrets::RegistryAuth::Anonymous
}
AuthSource::EnvVar {
username_var,
password_var,
} => {
let username = std::env::var(username_var).unwrap_or_default();
let password = std::env::var(password_var).unwrap_or_default();
if !username.is_empty() && !password.is_empty() {
oci_client::secrets::RegistryAuth::Basic(username, password)
} else {
oci_client::secrets::RegistryAuth::Anonymous
}
}
AuthSource::SecretStore { .. } => {
tracing::warn!(
"SecretStore auth source requires async resolver; returning Anonymous"
);
oci_client::secrets::RegistryAuth::Anonymous
}
}
}
fn extract_registry(image: &str) -> String {
let image_without_digest = image.split('@').next().unwrap_or(image);
let parts: Vec<&str> = image_without_digest.split('/').collect();
if parts.len() == 1 {
return "docker.io".to_string();
}
let first_part = parts[0];
if first_part.contains('.') || first_part.contains(':') || first_part == "localhost" {
first_part.to_string()
} else {
"docker.io".to_string()
}
}
fn load_docker_config(path: Option<&PathBuf>) -> Option<DockerConfigAuth> {
let config = if let Some(path) = path {
DockerConfigAuth::load_from_path(path).ok()
} else {
DockerConfigAuth::load().ok()
};
if config.is_none() {
tracing::debug!("Failed to load Docker config, using anonymous auth as fallback");
}
config
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use oci_client::secrets::RegistryAuth;

    /// Asserts that `auth` is `Basic` carrying exactly `username` / `password`.
    #[track_caller]
    fn assert_basic(auth: RegistryAuth, username: &str, password: &str) {
        match auth {
            RegistryAuth::Basic(u, p) => {
                assert_eq!(u, username);
                assert_eq!(p, password);
            }
            _ => panic!("Expected Basic auth"),
        }
    }

    #[test]
    fn test_extract_registry() {
        assert_eq!(AuthResolver::extract_registry("ubuntu"), "docker.io");
        assert_eq!(AuthResolver::extract_registry("ubuntu:latest"), "docker.io");
        assert_eq!(
            AuthResolver::extract_registry("library/ubuntu"),
            "docker.io"
        );
        assert_eq!(
            AuthResolver::extract_registry("ghcr.io/owner/repo"),
            "ghcr.io"
        );
        assert_eq!(
            AuthResolver::extract_registry("ghcr.io/owner/repo:tag"),
            "ghcr.io"
        );
        assert_eq!(
            AuthResolver::extract_registry("localhost:5000/image"),
            "localhost:5000"
        );
        assert_eq!(
            AuthResolver::extract_registry("myregistry.com/path/to/image:v1.0"),
            "myregistry.com"
        );
    }

    #[test]
    fn test_extract_registry_edge_cases() {
        // Digest references.
        assert_eq!(
            AuthResolver::extract_registry("ubuntu@sha256:abcd"),
            "docker.io"
        );
        assert_eq!(
            AuthResolver::extract_registry("ghcr.io/owner/repo@sha256:abcd"),
            "ghcr.io"
        );
        assert_eq!(
            AuthResolver::extract_registry("localhost:5000/image:tag@sha256:abcd"),
            "localhost:5000"
        );
        // `localhost` without a port is still a registry host.
        assert_eq!(
            AuthResolver::extract_registry("localhost/image"),
            "localhost"
        );
        assert_eq!(AuthResolver::extract_registry(""), "docker.io");
    }

    #[test]
    fn test_anonymous_auth() {
        let config = AuthConfig {
            default: AuthSource::Anonymous,
            ..Default::default()
        };
        let resolver = AuthResolver::new(config);
        let auth = resolver.resolve("ubuntu:latest");
        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_basic_auth() {
        let config = AuthConfig {
            default: AuthSource::Basic {
                username: "user".to_string(),
                password: "pass".to_string(),
            },
            ..Default::default()
        };
        let resolver = AuthResolver::new(config);
        assert_basic(resolver.resolve("ubuntu:latest"), "user", "pass");
    }

    #[test]
    fn test_per_registry_auth() {
        let config = AuthConfig {
            registries: vec![RegistryAuthConfig {
                registry: "ghcr.io".to_string(),
                source: AuthSource::Basic {
                    username: "ghcr_user".to_string(),
                    password: "ghcr_pass".to_string(),
                },
            }],
            default: AuthSource::Anonymous,
            ..Default::default()
        };
        let resolver = AuthResolver::new(config);
        assert_basic(
            resolver.resolve("ghcr.io/owner/repo:tag"),
            "ghcr_user",
            "ghcr_pass",
        );
        assert_basic(
            resolver.resolve("ghcr.io/owner/repo@sha256:abcd"),
            "ghcr_user",
            "ghcr_pass",
        );
        let auth = resolver.resolve("ubuntu:latest");
        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_duplicate_registry_entries_last_wins() {
        let config = AuthConfig {
            registries: vec![
                RegistryAuthConfig {
                    registry: "ghcr.io".to_string(),
                    source: AuthSource::Anonymous,
                },
                RegistryAuthConfig {
                    registry: "ghcr.io".to_string(),
                    source: AuthSource::Basic {
                        username: "second".to_string(),
                        password: "pw".to_string(),
                    },
                },
            ],
            default: AuthSource::Anonymous,
            ..Default::default()
        };
        let resolver = AuthResolver::new(config);
        assert_basic(resolver.resolve("ghcr.io/owner/repo"), "second", "pw");
    }

    #[test]
    fn test_env_var_auth() {
        // Variable names are unique to this test so they cannot clash with the
        // environment or with other tests running in the same process.
        std::env::set_var("AUTH_RESOLVER_TEST_USERNAME", "env_user");
        std::env::set_var("AUTH_RESOLVER_TEST_PASSWORD", "env_pass");
        let config = AuthConfig {
            default: AuthSource::EnvVar {
                username_var: "AUTH_RESOLVER_TEST_USERNAME".to_string(),
                password_var: "AUTH_RESOLVER_TEST_PASSWORD".to_string(),
            },
            ..Default::default()
        };
        let auth = AuthResolver::new(config).resolve("ubuntu:latest");
        // Clean up before asserting so a failing assertion cannot leak the variables.
        std::env::remove_var("AUTH_RESOLVER_TEST_USERNAME");
        std::env::remove_var("AUTH_RESOLVER_TEST_PASSWORD");
        assert_basic(auth, "env_user", "env_pass");
    }

    #[test]
    fn test_env_var_auth_fallback() {
        let config = AuthConfig {
            default: AuthSource::EnvVar {
                username_var: "NONEXISTENT_USER".to_string(),
                password_var: "NONEXISTENT_PASS".to_string(),
            },
            ..Default::default()
        };
        let resolver = AuthResolver::new(config);
        let auth = resolver.resolve("ubuntu:latest");
        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_secret_store_sync_fallback_returns_anonymous() {
        let config = AuthConfig {
            registries: vec![RegistryAuthConfig {
                registry: "private.registry.io".to_string(),
                source: AuthSource::SecretStore {
                    credential_id: "cred-uuid-123".to_string(),
                },
            }],
            default: AuthSource::Anonymous,
            ..Default::default()
        };
        let resolver = AuthResolver::new(config);
        let auth = resolver.resolve("private.registry.io/image:latest");
        assert!(matches!(auth, RegistryAuth::Anonymous));
        let auth = resolver.resolve("ubuntu:latest");
        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_source_for_registry_returns_correct_source() {
        let config = AuthConfig {
            registries: vec![RegistryAuthConfig {
                registry: "ghcr.io".to_string(),
                source: AuthSource::Basic {
                    username: "user".to_string(),
                    password: "pass".to_string(),
                },
            }],
            default: AuthSource::Anonymous,
            ..Default::default()
        };
        let resolver = AuthResolver::new(config);
        let source = resolver.source_for_registry("ghcr.io");
        assert!(matches!(source, AuthSource::Basic { .. }));
        let source = resolver.source_for_registry("docker.io");
        assert!(matches!(source, AuthSource::Anonymous));
    }

    #[test]
    fn test_secret_store_serde_roundtrip() {
        let source = AuthSource::SecretStore {
            credential_id: "abc-123".to_string(),
        };
        let json = serde_json::to_string(&source).unwrap();
        let parsed: AuthSource = serde_json::from_str(&json).unwrap();
        assert_eq!(source, parsed);
    }

    #[test]
    fn test_auth_source_json_shape() {
        // Pins the internally tagged, snake_case wire format.
        let json = serde_json::to_string(&AuthSource::DockerConfig).unwrap();
        assert_eq!(json, r#"{"type":"docker_config"}"#);
        let parsed: AuthSource =
            serde_json::from_str(r#"{"type":"env_var","username_var":"U","password_var":"P"}"#)
                .unwrap();
        assert_eq!(
            parsed,
            AuthSource::EnvVar {
                username_var: "U".to_string(),
                password_var: "P".to_string(),
            }
        );
    }
}