use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
/// Authentication mechanism used by a provider.
///
/// Serialized with serde's `snake_case` rule. That rule inserts an underscore
/// before every uppercase letter after the first, so the `OAuth2*` variants
/// serialize as `o_auth2_…` (e.g. `o_auth2_device_flow`). Those generated names
/// remain the canonical serialized form, so existing data keeps round-tripping.
/// When deserializing, each `OAuth2*` variant also accepts the natural
/// `oauth2_…` spelling that people tend to write by hand in config files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AuthType {
    /// Serialized as `o_auth2_device_flow`; also accepts `oauth2_device_flow`.
    #[serde(alias = "oauth2_device_flow")]
    OAuth2DeviceFlow,
    /// Serialized as `o_auth2_authorization_code`; also accepts `oauth2_authorization_code`.
    #[serde(alias = "oauth2_authorization_code")]
    OAuth2AuthorizationCode,
    /// Serialized as `o_auth2_client_credentials`; also accepts `oauth2_client_credentials`.
    #[serde(alias = "oauth2_client_credentials")]
    OAuth2ClientCredentials,
    /// Serialized as `api_key`.
    ApiKey,
    /// Serialized as `aws_iam`.
    AwsIam,
    /// Serialized as `azure_ad`.
    AzureAd,
    /// Serialized as `mutual_tls`.
    MutualTls,
    /// Serialized as `custom`.
    Custom,
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AuthType::OAuth2DeviceFlow => write!(f, "OAuth2 Device Flow"),
AuthType::OAuth2AuthorizationCode => write!(f, "OAuth2 Authorization Code"),
AuthType::OAuth2ClientCredentials => write!(f, "OAuth2 Client Credentials"),
AuthType::ApiKey => write!(f, "API Key"),
AuthType::AwsIam => write!(f, "AWS IAM"),
AuthType::AzureAd => write!(f, "Azure AD"),
AuthType::MutualTls => write!(f, "Mutual TLS"),
AuthType::Custom => write!(f, "Custom"),
}
}
}
/// Configuration describing one authentication provider.
///
/// `auth_type` selects the mechanism. The optional sections (`oauth2`,
/// `api_key`, `aws`) and `custom` may all be omitted when deserializing.
/// NOTE(review): nothing in this type enforces that the section matching
/// `auth_type` is present — confirm where that is validated.
/// NOTE(review): `Debug` is derived, so `custom` values are printed verbatim —
/// confirm `custom` never carries secrets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
/// Identifier of the provider.
pub id: String,
/// Human-readable provider name.
pub display_name: String,
/// Authentication mechanism this provider uses.
pub auth_type: AuthType,
/// OAuth 2.0 settings; `None` when absent from the input.
#[serde(default)]
pub oauth2: Option<OAuth2Config>,
/// API-key settings; `None` when absent from the input.
#[serde(default)]
pub api_key: Option<ApiKeyConfig>,
/// AWS settings; `None` when absent from the input.
#[serde(default)]
pub aws: Option<AwsConfig>,
/// Free-form extra settings; empty when absent from the input.
#[serde(default)]
pub custom: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2Config {
pub device_authorization_endpoint: Option<String>,
pub authorization_endpoint: Option<String>,
pub token_endpoint: String,
pub revocation_endpoint: Option<String>,
pub client_id: String,
#[serde(default)]
pub client_secret: Option<String>,
#[serde(default)]
pub scopes: Vec<String>,
pub audience: Option<String>,
}
/// Settings for providers that authenticate with an API key.
///
/// This type stores only where the key goes and where it comes from, not the
/// key itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyConfig {
/// Name of the request header the key is sent in (presumably HTTP — confirm).
pub header_name: String,
/// Optional text placed before the key in the header value; `None` when
/// absent from the input. NOTE(review): confirm whether a separator (e.g. a
/// trailing space) is expected to be part of this value.
#[serde(default)]
pub header_prefix: Option<String>,
/// Name of the environment variable the key is presumably read from — confirm.
pub env_var_name: String,
}
/// AWS settings for providers using [`AuthType::AwsIam`]-style authentication.
///
/// All fields are optional and become `None` when absent from the input.
/// NOTE(review): presumably `None` defers to the AWS SDK's default lookup —
/// confirm with the consuming code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AwsConfig {
/// AWS region name.
pub region: Option<String>,
/// Named profile to use.
pub profile: Option<String>,
/// Where credentials come from; free-form string — TODO confirm accepted values.
pub credential_source: Option<String>,
}
/// Successful outcome of [`AuthProvider::authenticate`] or [`AuthProvider::refresh`].
///
/// NOTE(review): `expires_at` and `scopes` duplicate fields on [`Credentials`];
/// confirm which copy is authoritative.
#[derive(Debug, Clone)]
pub struct AuthResult {
/// The issued credentials.
pub credentials: Credentials,
/// Optional expiry instant of the issued credentials.
pub expires_at: Option<DateTime<Utc>>,
/// Refresh token, if one was issued. It is held as a `SecretString` so the
/// derived `Debug` does not print it, and it is handed back to
/// `AuthProvider::refresh` as `&SecretString`.
pub refresh_token: Option<SecretString>,
/// Scopes associated with the result.
pub scopes: Vec<String>,
/// Extra provider-specific key/value data.
pub metadata: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Credentials {
pub provider_id: String,
pub credential_type: CredentialType,
pub expires_at: Option<DateTime<Utc>>,
#[serde(default)]
pub scopes: Vec<String>,
#[serde(default)]
pub data: HashMap<String, String>,
#[serde(default)]
pub metadata: HashMap<String, String>,
}
impl Credentials {
    /// Returns `true` once the expiry instant has been reached.
    ///
    /// Credentials without `expires_at` never expire.
    pub fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(expires) => Utc::now() >= expires,
            None => false,
        }
    }

    /// Returns `true` if the credentials expire within `duration` from now.
    /// This includes credentials that have already expired.
    ///
    /// Credentials without `expires_at` never report `true`.
    ///
    /// The remaining lifetime (`expires - now`) is compared against `duration`.
    /// The original form, `now + duration >= expires`, is mathematically the
    /// same, but chrono's `DateTime + Duration` panics when the sum leaves the
    /// representable date range (e.g. for a very large `duration`). The
    /// difference of two valid instants always fits in a `Duration`, so this
    /// form accepts any `duration` without panicking.
    pub fn expires_within(&self, duration: chrono::Duration) -> bool {
        match self.expires_at {
            Some(expires) => expires.signed_duration_since(Utc::now()) <= duration,
            None => false,
        }
    }

    /// Returns `true` when the credentials expire within the next five
    /// minutes (or already have), i.e. they should be refreshed proactively.
    pub fn needs_refresh(&self) -> bool {
        self.expires_within(chrono::Duration::minutes(5))
    }
}
/// Kind of credential held in a [`Credentials`] value.
///
/// Serialized with serde's `snake_case` rule, which renders the `OAuth2*`
/// variants as `o_auth2_…` (e.g. `o_auth2_access_token`). Those names stay the
/// serialized form, so previously stored credentials still load. When
/// deserializing, the `oauth2_…` spelling is also accepted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CredentialType {
    /// Serialized as `o_auth2_access_token`; also accepts `oauth2_access_token`.
    #[serde(alias = "oauth2_access_token")]
    OAuth2AccessToken,
    /// Serialized as `o_auth2_refresh_token`; also accepts `oauth2_refresh_token`.
    #[serde(alias = "oauth2_refresh_token")]
    OAuth2RefreshToken,
    /// Serialized as `api_key`.
    ApiKey,
    /// Serialized as `aws_access_key_id`.
    AwsAccessKeyId,
    /// Serialized as `aws_secret_access_key`.
    AwsSecretAccessKey,
    /// Serialized as `aws_session_token`.
    AwsSessionToken,
    /// Serialized as `certificate`.
    Certificate,
    /// Serialized as `private_key`.
    PrivateKey,
    /// Serialized as `secret`.
    Secret,
}
/// Summary of a provider's authentication state, presumably for display to
/// the user (e.g. a status listing) — confirm with callers.
#[derive(Debug, Clone)]
pub struct AuthStatus {
/// Identifier of the provider.
pub provider_id: String,
/// Human-readable provider name.
pub display_name: String,
/// Whether credentials are currently held for the provider.
pub authenticated: bool,
/// NOTE(review): presumably the skill these credentials are bound to — confirm.
pub skill: Option<String>,
/// NOTE(review): presumably the skill instance these credentials are bound to — confirm.
pub instance: Option<String>,
/// Expiry instant of the credentials, if known.
pub expires_at: Option<DateTime<Utc>>,
/// Scopes associated with the credentials.
pub scopes: Vec<String>,
/// Free-form status message.
pub message: String,
}
/// Behaviour every authentication provider implements.
///
/// Implementors must be `Send + Sync`. The async methods are expanded by
/// `#[async_trait]`.
#[async_trait]
pub trait AuthProvider: Send + Sync {
/// Provider identifier (presumably equal to `config().id` — confirm).
fn id(&self) -> &str;
/// Human-readable provider name.
fn display_name(&self) -> &str;
/// Authentication mechanism implemented by this provider.
fn auth_type(&self) -> AuthType;
/// Configuration this provider was built from.
fn config(&self) -> &ProviderConfig;
/// Runs the provider's authentication flow and returns the issued credentials.
/// `scopes`: scopes to request. NOTE(review): presumably `None` means "use the
/// configured defaults" — confirm with implementors.
async fn authenticate(&self, scopes: Option<Vec<String>>) -> Result<AuthResult>;
/// Obtains fresh credentials using `refresh_token` for the given `credentials`.
async fn refresh(&self, credentials: &Credentials, refresh_token: &SecretString) -> Result<AuthResult>;
/// Checks whether `credentials` are valid. NOTE(review): how `Ok(false)`
/// differs from `Err` is left to implementors — confirm the convention.
async fn validate(&self, credentials: &Credentials) -> Result<bool>;
/// Revokes `credentials` with the provider.
async fn revoke(&self, credentials: &Credentials) -> Result<()>;
/// Converts `credentials` into a flat key/value map for skill configuration.
fn to_skill_config(&self, credentials: &Credentials) -> HashMap<String, String>;
/// Names of keys whose values are secret. NOTE(review): presumably keys of the
/// map returned by `to_skill_config` — confirm.
fn secret_keys(&self) -> Vec<&str>;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceAuthorizationResponse {
pub device_code: String,
pub user_code: String,
pub verification_uri: String,
pub verification_uri_complete: Option<String>,
pub expires_in: u64,
pub interval: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
pub access_token: String,
pub token_type: String,
pub expires_in: Option<u64>,
pub refresh_token: Option<String>,
pub scope: Option<String>,
}
/// OAuth 2.0 error response body; the field names match the error response
/// parameters of RFC 6749 §5.2.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuth2Error {
/// Error code (e.g. `invalid_grant`).
pub error: String,
/// Optional human-readable description.
pub error_description: Option<String>,
/// Optional URI of a page with more information about the error.
pub error_uri: Option<String>,
}
impl fmt::Display for OAuth2Error {
    /// Renders as `error` alone, or as `error: description` when a
    /// description is present. `error_uri` is not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.error)?;
        if let Some(description) = self.error_description.as_deref() {
            write!(f, ": {}", description)?;
        }
        Ok(())
    }
}

/// Marks `OAuth2Error` as a standard error with no underlying source, so it
/// can be boxed or converted into `anyhow::Error`.
impl std::error::Error for OAuth2Error {}