use serde::{Deserialize, Serialize};
/// Discriminates which payload an `AuthCredential` carries.
///
/// Serialized in `snake_case` (`api_key`, `http`, `open_id_connect`, `service_account`),
/// except `OAuth2`, which is pinned to `oauth2` (see the note on that variant).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthCredentialType {
    /// A static API key, held in the `api_key` field of `AuthCredential`.
    ApiKey,
    /// HTTP authentication (the constructors build `bearer` and `basic` schemes), held in
    /// the `http` field of `AuthCredential`.
    Http,
    /// OAuth 2.0, held in the `oauth2` field of `AuthCredential`.
    ///
    /// `rename_all = "snake_case"` alone would spell this variant `o_auth2`, which does
    /// not match the `oauth2` payload field. The old spelling is still accepted when
    /// deserializing so previously stored credentials keep loading.
    #[serde(rename = "oauth2", alias = "o_auth2")]
    OAuth2,
    /// OpenID Connect; uses the same `oauth2` payload as `OAuth2`.
    OpenIdConnect,
    /// A service-account key, held in the `service_account` field of `AuthCredential`.
    ServiceAccount,
}
/// Parameters for an HTTP authentication scheme.
///
/// `Debug` is implemented by hand so that `token` and `password` are never printed;
/// only whether they are set is shown. This keeps them out of logs.
#[derive(Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct HttpAuth {
    /// Scheme name; the `AuthCredential` constructors use `"bearer"` and `"basic"`.
    pub scheme: String,
    /// Token for token-based schemes such as `bearer`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub token: Option<String>,
    /// Username for `basic` authentication.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub username: Option<String>,
    /// Password for `basic` authentication.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub password: Option<String>,
}

impl std::fmt::Debug for HttpAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether a secret is present, never its value.
        let redact = |present: bool| present.then_some("<redacted>");
        f.debug_struct("HttpAuth")
            .field("scheme", &self.scheme)
            .field("token", &redact(self.token.is_some()))
            .field("username", &self.username)
            .field("password", &redact(self.password.is_some()))
            .finish()
    }
}
/// OAuth 2.0 client configuration and token state.
///
/// `Debug` is implemented by hand so that `client_secret`, `code_verifier`, `auth_code`,
/// `access_token` and `refresh_token` are never printed; only whether they are set is
/// shown.
#[derive(Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuth2Auth {
    /// Client identifier.
    pub client_id: String,
    /// Client secret (redacted in `Debug`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// Authorization endpoint.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth_uri: Option<String>,
    /// Token endpoint.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub token_uri: Option<String>,
    /// Redirect URI.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub redirect_uri: Option<String>,
    /// `state` value of the authorization flow.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub state: Option<String>,
    /// Code verifier (redacted in `Debug`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub code_verifier: Option<String>,
    /// Authorization code (redacted in `Debug`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub auth_code: Option<String>,
    /// Access token; `AuthCredential::is_ready` treats OAuth2/OIDC credentials as ready
    /// once this is set (redacted in `Debug`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub access_token: Option<String>,
    /// Refresh token (redacted in `Debug`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Expiry timestamp compared against `now_unix` by `AuthCredential::is_expired`
    /// (presumably Unix seconds — confirm against callers).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<i64>,
    /// Requested scopes; omitted from the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub scopes: Vec<String>,
}

impl std::fmt::Debug for OAuth2Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether each secret is present, never its value.
        let redact = |present: bool| present.then_some("<redacted>");
        f.debug_struct("OAuth2Auth")
            .field("client_id", &self.client_id)
            .field("client_secret", &redact(self.client_secret.is_some()))
            .field("auth_uri", &self.auth_uri)
            .field("token_uri", &self.token_uri)
            .field("redirect_uri", &self.redirect_uri)
            .field("state", &self.state)
            .field("code_verifier", &redact(self.code_verifier.is_some()))
            .field("auth_code", &redact(self.auth_code.is_some()))
            .field("access_token", &redact(self.access_token.is_some()))
            .field("refresh_token", &redact(self.refresh_token.is_some()))
            .field("expires_at", &self.expires_at)
            .field("scopes", &self.scopes)
            .finish()
    }
}
/// Service-account key material plus any access token obtained with it.
///
/// `Debug` is implemented by hand so that `private_key` and `access_token` are never
/// printed; only whether they are set is shown.
#[derive(Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ServiceAccountAuth {
    /// Account type, serialized under the key `type`; empty when absent on input.
    #[serde(rename = "type", default)]
    pub account_type: String,
    /// Project identifier.
    pub project_id: String,
    /// Identifier of the private key.
    pub private_key_id: String,
    /// Private key material (redacted in `Debug`).
    pub private_key: String,
    /// Service-account e-mail address.
    pub client_email: String,
    /// Client identifier; empty when absent on input.
    #[serde(default)]
    pub client_id: String,
    /// Authorization endpoint; empty when absent on input.
    #[serde(default)]
    pub auth_uri: String,
    /// Token endpoint.
    pub token_uri: String,
    /// Requested scopes; omitted from the serialized form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub scopes: Vec<String>,
    /// Target audience, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_audience: Option<String>,
    /// Access token; `AuthCredential::is_ready` treats service-account credentials as
    /// ready once this is set (redacted in `Debug`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub access_token: Option<String>,
    /// Expiry timestamp compared against `now_unix` by `AuthCredential::is_expired`
    /// (presumably Unix seconds — confirm against callers).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<i64>,
}

impl std::fmt::Debug for ServiceAccountAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether a secret is present, never its value.
        let redact = |present: bool| present.then_some("<redacted>");
        f.debug_struct("ServiceAccountAuth")
            .field("account_type", &self.account_type)
            .field("project_id", &self.project_id)
            .field("private_key_id", &self.private_key_id)
            .field("private_key", &redact(!self.private_key.is_empty()))
            .field("client_email", &self.client_email)
            .field("client_id", &self.client_id)
            .field("auth_uri", &self.auth_uri)
            .field("token_uri", &self.token_uri)
            .field("scopes", &self.scopes)
            .field("target_audience", &self.target_audience)
            .field("access_token", &redact(self.access_token.is_some()))
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
/// A credential of one of the kinds in `AuthCredentialType`.
///
/// `auth_type` says which payload field is meaningful. The constructors on this type
/// populate only the matching field, but deserialization does not enforce that.
/// `Debug` is implemented by hand so the `api_key` value is never printed.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthCredential {
    /// Which payload below is in use.
    pub auth_type: AuthCredentialType,
    /// API key payload (redacted in `Debug`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// HTTP payload.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub http: Option<HttpAuth>,
    /// OAuth2 payload, also used for OpenID Connect.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub oauth2: Option<OAuth2Auth>,
    /// Service-account payload.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub service_account: Option<ServiceAccountAuth>,
}

impl std::fmt::Debug for AuthCredential {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The API key is a bearer secret: show presence only. Nested payloads use their
        // own `Debug` impls.
        f.debug_struct("AuthCredential")
            .field("auth_type", &self.auth_type)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("http", &self.http)
            .field("oauth2", &self.oauth2)
            .field("service_account", &self.service_account)
            .finish()
    }
}
impl AuthCredential {
#[must_use]
pub fn api_key(value: impl Into<String>) -> Self {
Self {
auth_type: AuthCredentialType::ApiKey,
api_key: Some(value.into()),
http: None,
oauth2: None,
service_account: None,
}
}
#[must_use]
pub fn bearer(token: impl Into<String>) -> Self {
Self {
auth_type: AuthCredentialType::Http,
api_key: None,
http: Some(HttpAuth {
scheme: "bearer".into(),
token: Some(token.into()),
username: None,
password: None,
}),
oauth2: None,
service_account: None,
}
}
#[must_use]
pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
Self {
auth_type: AuthCredentialType::Http,
api_key: None,
http: Some(HttpAuth {
scheme: "basic".into(),
token: None,
username: Some(username.into()),
password: Some(password.into()),
}),
oauth2: None,
service_account: None,
}
}
#[must_use]
pub fn oauth2(oauth2: OAuth2Auth) -> Self {
Self {
auth_type: AuthCredentialType::OAuth2,
api_key: None,
http: None,
oauth2: Some(oauth2),
service_account: None,
}
}
#[must_use]
pub fn service_account(sa: ServiceAccountAuth) -> Self {
Self {
auth_type: AuthCredentialType::ServiceAccount,
api_key: None,
http: None,
oauth2: None,
service_account: Some(sa),
}
}
#[must_use]
pub fn is_ready(&self) -> bool {
match self.auth_type {
AuthCredentialType::ApiKey => self.api_key.is_some(),
AuthCredentialType::Http => self
.http
.as_ref()
.is_some_and(|h| h.token.is_some() || h.username.is_some()),
AuthCredentialType::OAuth2 | AuthCredentialType::OpenIdConnect => self
.oauth2
.as_ref()
.and_then(|o| o.access_token.as_ref())
.is_some(),
AuthCredentialType::ServiceAccount => self
.service_account
.as_ref()
.and_then(|s| s.access_token.as_ref())
.is_some(),
}
}
#[must_use]
pub fn is_expired(&self, now_unix: i64) -> bool {
let exp = match self.auth_type {
AuthCredentialType::OAuth2 | AuthCredentialType::OpenIdConnect => {
self.oauth2.as_ref().and_then(|o| o.expires_at)
}
AuthCredentialType::ServiceAccount => {
self.service_account.as_ref().and_then(|s| s.expires_at)
}
_ => None,
};
matches!(exp, Some(e) if e <= now_unix + 60)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `c` to JSON and parses it back, returning both the text and the value.
    fn round_trip(c: &AuthCredential) -> (String, AuthCredential) {
        let json = serde_json::to_string(c).unwrap();
        let back: AuthCredential = serde_json::from_str(&json).unwrap();
        (json, back)
    }

    #[test]
    fn api_key_round_trip() {
        let c = AuthCredential::api_key("sk-xyz");
        let (_, back) = round_trip(&c);
        assert_eq!(c, back);
        assert!(c.is_ready());
    }

    #[test]
    fn bearer_round_trip() {
        let c = AuthCredential::bearer("token-abc");
        let (s, back) = round_trip(&c);
        assert!(s.contains("\"bearer\""));
        assert_eq!(c, back);
        assert!(c.is_ready());
    }

    #[test]
    fn basic_round_trip() {
        let c = AuthCredential::basic("user", "pass");
        let (s, back) = round_trip(&c);
        assert!(s.contains("\"basic\""));
        assert_eq!(c, back);
        assert!(c.is_ready());
    }

    #[test]
    fn oauth2_unready_until_access_token() {
        let mut c = AuthCredential::oauth2(OAuth2Auth {
            client_id: "id".into(),
            ..OAuth2Auth::default()
        });
        assert!(!c.is_ready());
        c.oauth2.as_mut().unwrap().access_token = Some("at".into());
        assert!(c.is_ready());
    }

    #[test]
    fn oauth2_expiry_leeway() {
        let mut c = AuthCredential::oauth2(OAuth2Auth {
            client_id: "id".into(),
            access_token: Some("at".into()),
            expires_at: Some(1000),
            ..OAuth2Auth::default()
        });
        assert!(!c.is_expired(0));
        // 60-unit leeway: expiry at 1000 counts as expired from now = 940 onward.
        assert!(!c.is_expired(939));
        assert!(c.is_expired(940));
        assert!(c.is_expired(1000));
        assert!(c.is_expired(2000));
        c.oauth2.as_mut().unwrap().expires_at = None;
        assert!(!c.is_expired(9_999_999));
    }

    #[test]
    fn service_account_round_trip_and_readiness() {
        let mut c = AuthCredential::service_account(ServiceAccountAuth {
            project_id: "proj".into(),
            client_email: "sa@example.com".into(),
            token_uri: "https://example.com/token".into(),
            expires_at: Some(1000),
            ..ServiceAccountAuth::default()
        });
        let (_, back) = round_trip(&c);
        assert_eq!(c, back);
        assert!(!c.is_ready());
        c.service_account.as_mut().unwrap().access_token = Some("at".into());
        assert!(c.is_ready());
        assert!(!c.is_expired(0));
        assert!(c.is_expired(1000));
    }

    #[test]
    fn api_key_and_http_never_expire() {
        assert!(!AuthCredential::api_key("k").is_expired(9_999_999));
        assert!(!AuthCredential::bearer("t").is_expired(9_999_999));
        assert!(!AuthCredential::basic("u", "p").is_expired(9_999_999));
    }
}