use serde::{Deserialize, Serialize};
/// Credentials used to authenticate against a model provider.
///
/// Serialized as an internally tagged object whose `"type"` field is `"api"` for
/// [`ProviderAuth::Api`] and `"oauth"` for [`ProviderAuth::OAuth`], e.g.
/// `{"type":"api","key":"..."}`.
///
/// `Debug` is implemented by hand (not derived) so that secrets — the API key and
/// the OAuth access/refresh tokens — never end up in logs or panic messages.
#[derive(Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProviderAuth {
    /// A static API key.
    Api {
        /// The secret key presented to the provider.
        key: String,
    },
    /// OAuth credentials.
    #[serde(rename = "oauth")]
    OAuth {
        /// Access token.
        access: String,
        /// Refresh token.
        refresh: String,
        /// Expiry instant as a Unix timestamp in milliseconds.
        expires: i64,
        /// Optional subscription name; omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}

impl std::fmt::Debug for ProviderAuth {
    /// Formats like the derived `Debug` (`Api { key: .. }` / `OAuth { .. }`) but
    /// replaces every secret field with `<redacted>`. Non-secret fields
    /// (`expires`, `name`) are printed as-is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Api { .. } => f
                .debug_struct("Api")
                .field("key", &format_args!("<redacted>"))
                .finish(),
            Self::OAuth { expires, name, .. } => f
                .debug_struct("OAuth")
                .field("access", &format_args!("<redacted>"))
                .field("refresh", &format_args!("<redacted>"))
                .field("expires", expires)
                .field("name", name)
                .finish(),
        }
    }
}
impl ProviderAuth {
    /// Safety margin before `expires` at which OAuth credentials are considered
    /// due for refresh: 5 minutes, in milliseconds.
    pub const REFRESH_BUFFER_MS: i64 = 5 * 60 * 1000;

    /// Creates API-key credentials.
    pub fn api_key(key: impl Into<String>) -> Self {
        Self::Api { key: key.into() }
    }

    /// Creates OAuth credentials without a subscription name.
    ///
    /// `expires` is a Unix timestamp in milliseconds.
    pub fn oauth(access: impl Into<String>, refresh: impl Into<String>, expires: i64) -> Self {
        Self::OAuth {
            access: access.into(),
            refresh: refresh.into(),
            expires,
            name: None,
        }
    }

    /// Creates OAuth credentials carrying a subscription `name`.
    ///
    /// `expires` is a Unix timestamp in milliseconds.
    pub fn oauth_with_name(
        access: impl Into<String>,
        refresh: impl Into<String>,
        expires: i64,
        name: impl Into<String>,
    ) -> Self {
        Self::OAuth {
            access: access.into(),
            refresh: refresh.into(),
            expires,
            name: Some(name.into()),
        }
    }

    /// Returns `true` when OAuth credentials expire within
    /// [`Self::REFRESH_BUFFER_MS`] of the current time. API keys never need refreshing.
    pub fn needs_refresh(&self) -> bool {
        self.needs_refresh_at(Self::now_ms())
    }

    /// Like [`Self::needs_refresh`], but evaluated at an explicit instant
    /// `now_ms` (Unix milliseconds) so callers and tests can control the clock.
    pub fn needs_refresh_at(&self, now_ms: i64) -> bool {
        match self {
            // Saturating so an extreme `now_ms` cannot overflow the comparison bound.
            Self::OAuth { expires, .. } => {
                *expires < now_ms.saturating_add(Self::REFRESH_BUFFER_MS)
            }
            Self::Api { .. } => false,
        }
    }

    /// Returns `true` when OAuth credentials' expiry is strictly before the
    /// current time. API keys never expire.
    pub fn is_expired(&self) -> bool {
        self.is_expired_at(Self::now_ms())
    }

    /// Like [`Self::is_expired`], but evaluated at an explicit instant `now_ms`
    /// (Unix milliseconds).
    pub fn is_expired_at(&self, now_ms: i64) -> bool {
        match self {
            Self::OAuth { expires, .. } => *expires < now_ms,
            Self::Api { .. } => false,
        }
    }

    /// The API key, or `None` for OAuth credentials.
    pub fn api_key_value(&self) -> Option<&str> {
        match self {
            Self::Api { key } => Some(key),
            Self::OAuth { .. } => None,
        }
    }

    /// The OAuth access token, or `None` for API-key credentials.
    pub fn access_token(&self) -> Option<&str> {
        match self {
            Self::OAuth { access, .. } => Some(access),
            Self::Api { .. } => None,
        }
    }

    /// The OAuth refresh token, or `None` for API-key credentials.
    pub fn refresh_token(&self) -> Option<&str> {
        match self {
            Self::OAuth { refresh, .. } => Some(refresh),
            Self::Api { .. } => None,
        }
    }

    /// Whether these are OAuth credentials.
    pub fn is_oauth(&self) -> bool {
        matches!(self, Self::OAuth { .. })
    }

    /// Whether these are API-key credentials.
    pub fn is_api_key(&self) -> bool {
        matches!(self, Self::Api { .. })
    }

    /// Short label for the credential kind: `"api_key"` or `"oauth"`.
    ///
    /// Note this differs from the serde `"type"` tag, which is `"api"` for API keys.
    pub fn auth_type_display(&self) -> &'static str {
        match self {
            Self::Api { .. } => "api_key",
            Self::OAuth { .. } => "oauth",
        }
    }

    /// The OAuth subscription name, if one was set; always `None` for API keys.
    pub fn subscription_name(&self) -> Option<&str> {
        match self {
            Self::OAuth { name, .. } => name.as_deref(),
            Self::Api { .. } => None,
        }
    }

    /// Current wall-clock time as Unix milliseconds, the unit of `expires`.
    fn now_ms() -> i64 {
        chrono::Utc::now().timestamp_millis()
    }
}
// Unit tests for `ProviderAuth`: constructors, accessors, expiry checks and the
// serde wire format.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_api_key_creation() {
        let auth = ProviderAuth::api_key("sk-test-key");
        assert!(auth.is_api_key());
        assert!(!auth.is_oauth());
        assert_eq!(auth.api_key_value(), Some("sk-test-key"));
        assert_eq!(auth.access_token(), None);
        assert_eq!(auth.refresh_token(), None);
        assert_eq!(auth.subscription_name(), None);
    }

    #[test]
    fn test_oauth_creation() {
        let expires = chrono::Utc::now().timestamp_millis() + 3_600_000;
        let auth = ProviderAuth::oauth("access-token", "refresh-token", expires);
        assert!(auth.is_oauth());
        assert!(!auth.is_api_key());
        assert_eq!(auth.access_token(), Some("access-token"));
        assert_eq!(auth.refresh_token(), Some("refresh-token"));
        assert_eq!(auth.api_key_value(), None);
        assert_eq!(auth.subscription_name(), None);
    }

    #[test]
    fn test_oauth_with_name_creation() {
        let auth = ProviderAuth::oauth_with_name("access", "refresh", 0, "Pro");
        assert!(auth.is_oauth());
        assert_eq!(auth.access_token(), Some("access"));
        assert_eq!(auth.refresh_token(), Some("refresh"));
        assert_eq!(auth.subscription_name(), Some("Pro"));
    }

    #[test]
    fn test_oauth_needs_refresh() {
        // Inside the 5-minute buffer: refresh.
        let expires = chrono::Utc::now().timestamp_millis() + 2 * 60 * 1000;
        let auth = ProviderAuth::oauth("access", "refresh", expires);
        assert!(auth.needs_refresh());

        // Outside the buffer: no refresh yet.
        let expires = chrono::Utc::now().timestamp_millis() + 10 * 60 * 1000;
        let auth = ProviderAuth::oauth("access", "refresh", expires);
        assert!(!auth.needs_refresh());
    }

    #[test]
    fn test_oauth_is_expired() {
        let expires = chrono::Utc::now().timestamp_millis() - 1000;
        let auth = ProviderAuth::oauth("access", "refresh", expires);
        assert!(auth.is_expired());

        let expires = chrono::Utc::now().timestamp_millis() + 3_600_000;
        let auth = ProviderAuth::oauth("access", "refresh", expires);
        assert!(!auth.is_expired());
    }

    #[test]
    fn test_api_key_never_needs_refresh() {
        let auth = ProviderAuth::api_key("sk-test");
        assert!(!auth.needs_refresh());
        assert!(!auth.is_expired());
    }

    #[test]
    fn test_serde_api_key() {
        let auth = ProviderAuth::api_key("sk-test-key");
        let json = serde_json::to_string(&auth).unwrap();
        assert!(json.contains("\"type\":\"api\""));
        assert!(json.contains("\"key\":\"sk-test-key\""));
        let parsed: ProviderAuth = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, parsed);
    }

    #[test]
    fn test_serde_oauth() {
        let auth = ProviderAuth::oauth("access-token", "refresh-token", 1735600000000);
        let json = serde_json::to_string(&auth).unwrap();
        assert!(json.contains("\"type\":\"oauth\""), "JSON was: {}", json);
        assert!(json.contains("\"access\":\"access-token\""));
        assert!(json.contains("\"refresh\":\"refresh-token\""));
        assert!(json.contains("\"expires\":1735600000000"));
        let parsed: ProviderAuth = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, parsed);
    }

    #[test]
    fn test_serde_oauth_omits_missing_name() {
        let auth = ProviderAuth::oauth("a", "r", 1);
        let json = serde_json::to_string(&auth).unwrap();
        assert!(!json.contains("\"name\""), "JSON was: {}", json);
    }

    #[test]
    fn test_serde_oauth_with_name_roundtrip() {
        let auth = ProviderAuth::oauth_with_name("a", "r", 1, "Pro");
        let json = serde_json::to_string(&auth).unwrap();
        assert!(json.contains("\"name\":\"Pro\""), "JSON was: {}", json);
        let parsed: ProviderAuth = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, parsed);
    }

    #[test]
    fn test_serde_oauth_deserializes_without_name() {
        // Stored credentials written without a `name` must still load.
        let json = r#"{"type":"oauth","access":"a","refresh":"r","expires":1}"#;
        let parsed: ProviderAuth = serde_json::from_str(json).unwrap();
        assert_eq!(parsed, ProviderAuth::oauth("a", "r", 1));
        assert_eq!(parsed.subscription_name(), None);
    }

    #[test]
    fn test_auth_type_display() {
        let api = ProviderAuth::api_key("key");
        assert_eq!(api.auth_type_display(), "api_key");
        let oauth = ProviderAuth::oauth("access", "refresh", 0);
        assert_eq!(oauth.auth_type_display(), "oauth");
    }
}