use super::config::OAuthConfig;
use super::error::OAuthResult;
use super::flow::TokenResponse;
use crate::models::auth::ProviderAuth;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
/// One way a user can authenticate with an [`OAuthProvider`], as returned by
/// [`OAuthProvider::auth_methods`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthMethod {
    /// Identifier of this method; presumably the value later handed to provider
    /// hooks as `method_id` (confirm against callers).
    pub id: String,
    /// Human-readable name; the leading text of [`AuthMethod::display`].
    pub label: String,
    /// Optional extra detail appended after the label by [`AuthMethod::display`].
    pub description: Option<String>,
    /// Whether this is an OAuth method or an API-key method.
    pub method_type: AuthMethodType,
}

impl AuthMethod {
    /// Creates a method of type [`AuthMethodType::OAuth`].
    pub fn oauth(
        id: impl Into<String>,
        label: impl Into<String>,
        description: Option<String>,
    ) -> Self {
        Self {
            id: id.into(),
            label: label.into(),
            description,
            method_type: AuthMethodType::OAuth,
        }
    }

    /// Creates a method of type [`AuthMethodType::ApiKey`].
    pub fn api_key(
        id: impl Into<String>,
        label: impl Into<String>,
        description: Option<String>,
    ) -> Self {
        Self {
            id: id.into(),
            label: label.into(),
            description,
            method_type: AuthMethodType::ApiKey,
        }
    }

    /// Returns the user-facing text `"<label> - <description>"`, or just the
    /// label when there is no description.
    ///
    /// Produces exactly what the [`Display`](std::fmt::Display) impl writes; kept as
    /// an inherent method so existing callers keep working.
    pub fn display(&self) -> String {
        self.to_string()
    }
}

/// Formats as `"<label> - <description>"`, or just `"<label>"` when there is no
/// description. This lets an [`AuthMethod`] be used directly in `format!`/`write!`
/// without building an intermediate `String`.
impl std::fmt::Display for AuthMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.description {
            Some(desc) => write!(f, "{} - {}", self.label, desc),
            None => f.write_str(&self.label),
        }
    }
}

/// Distinguishes OAuth-based [`AuthMethod`]s from API-key ones.
// Fieldless enum: `Copy`/`Hash` are free and let callers read it out of a
// borrowed `AuthMethod` without cloning or use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthMethodType {
    /// Authenticate through an OAuth flow.
    OAuth,
    /// Authenticate with an API key.
    ApiKey,
}
/// An authentication provider that advertises one or more [`AuthMethod`]s.
///
/// It turns OAuth tokens into [`ProviderAuth`] and applies that auth to outgoing
/// request headers. Implementors must be `Send + Sync`, so they can be shared across
/// threads.
#[async_trait]
pub trait OAuthProvider: Send + Sync {
    /// Identifier of this provider.
    fn id(&self) -> &'static str;

    /// Human-readable name of this provider.
    fn name(&self) -> &'static str;

    /// Name of an environment variable associated with this provider's API key.
    ///
    /// Presumably consulted by callers for an API key; confirm at call sites.
    /// The default implementation returns `None`.
    fn api_key_env_var(&self) -> Option<&'static str> {
        None
    }

    /// All authentication methods this provider offers.
    fn auth_methods(&self) -> Vec<AuthMethod>;

    /// OAuth settings for the method identified by `method_id`.
    ///
    /// Returns `None` when the provider has none for it. This is presumably the case
    /// for API-key methods; confirm per implementation.
    fn oauth_config(&self, method_id: &str) -> Option<OAuthConfig>;

    /// Asynchronously converts the `tokens` obtained for `method_id` into a [`ProviderAuth`].
    ///
    /// # Errors
    /// Implementation-defined; reported through [`OAuthResult`].
    async fn post_authorize(
        &self,
        method_id: &str,
        tokens: &TokenResponse,
    ) -> OAuthResult<ProviderAuth>;

    /// Writes the credentials in `auth` into `headers`.
    ///
    /// # Errors
    /// Implementation-defined; reported through [`OAuthResult`].
    fn apply_auth_headers(&self, auth: &ProviderAuth, headers: &mut HeaderMap) -> OAuthResult<()>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `AuthMethod::oauth` stores every argument and tags the method as OAuth.
    #[test]
    fn test_auth_method_oauth() {
        let subscription = Some(String::from("Use your subscription"));
        let method = AuthMethod::oauth("claude-max", "Claude Pro/Max", subscription.clone());

        assert_eq!(
            (method.id.as_str(), method.label.as_str()),
            ("claude-max", "Claude Pro/Max")
        );
        assert_eq!(method.description, subscription);
        assert_eq!(method.method_type, AuthMethodType::OAuth);
    }

    /// `AuthMethod::api_key` keeps an absent description as `None` and tags the
    /// method as an API key.
    #[test]
    fn test_auth_method_api_key() {
        let AuthMethod {
            id,
            label,
            description,
            method_type,
        } = AuthMethod::api_key("api-key", "Manual API Key", None);

        assert_eq!(id, "api-key");
        assert_eq!(label, "Manual API Key");
        assert!(description.is_none());
        assert_eq!(method_type, AuthMethodType::ApiKey);
    }

    /// `display` joins label and description with `" - "`, or returns the bare label.
    #[test]
    fn test_auth_method_display() {
        let cases = [
            (
                Some("Description here".to_string()),
                "Test Method - Description here",
            ),
            (None, "Test Method"),
        ];

        for (description, expected) in cases {
            let method = AuthMethod::oauth("test", "Test Method", description);
            assert_eq!(method.display(), expected);
        }
    }
}