use super::config::OAuthConfig;
use super::device_flow::{DeviceCodeResponse, DeviceFlow, DeviceTokenResponse};
use super::error::{OAuthError, OAuthResult};
use super::flow::TokenResponse;
use crate::models::auth::ProviderAuth;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
#[derive(Debug, Clone)]
pub struct AuthMethod {
pub id: String,
pub label: String,
pub description: Option<String>,
pub method_type: AuthMethodType,
}
impl AuthMethod {
pub fn oauth(
id: impl Into<String>,
label: impl Into<String>,
description: Option<String>,
) -> Self {
Self {
id: id.into(),
label: label.into(),
description,
method_type: AuthMethodType::OAuth,
}
}
pub fn api_key(
id: impl Into<String>,
label: impl Into<String>,
description: Option<String>,
) -> Self {
Self {
id: id.into(),
label: label.into(),
description,
method_type: AuthMethodType::ApiKey,
}
}
pub fn display(&self) -> String {
match &self.description {
Some(desc) => format!("{} - {}", self.label, desc),
None => self.label.clone(),
}
}
}
/// Discriminates the kind of mechanism an `AuthMethod` uses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMethodType {
/// Kind assigned by `AuthMethod::oauth`.
OAuth,
/// Kind assigned by `AuthMethod::api_key`.
ApiKey,
/// Device-flow kind.
/// NOTE(review): presumably pairs with the device-flow methods on
/// `OAuthProvider` (Device Authorization Grant) — confirm.
DeviceFlow,
}
/// Contract implemented by each authentication provider.
///
/// The required methods expose the provider's identity, its available
/// [`AuthMethod`]s, the per-method [`OAuthConfig`], and how credentials are
/// turned into a [`ProviderAuth`] and applied to request headers.
///
/// All device-flow methods have default implementations. A provider opts in
/// by overriding [`OAuthProvider::device_flow`] and
/// [`OAuthProvider::post_device_authorize`]; without that, both return an
/// `unknown_method` error.
#[async_trait]
pub trait OAuthProvider: Send + Sync {
    /// Static identifier of the provider; interpolated into the error
    /// messages produced by the default device-flow methods.
    fn id(&self) -> &'static str;

    /// Static name of the provider.
    /// NOTE(review): presumably the human-readable counterpart of `id` —
    /// confirm with implementors.
    fn name(&self) -> &'static str;

    /// Returns the [`AuthMethod`]s of this provider.
    fn auth_methods(&self) -> Vec<AuthMethod>;

    /// Returns the [`OAuthConfig`] for `method_id`, if there is one.
    /// NOTE(review): presumably `None` means the method is unknown or not
    /// OAuth-based — confirm with implementors.
    fn oauth_config(&self, method_id: &str) -> Option<OAuthConfig>;

    /// Produces the [`ProviderAuth`] for `method_id` from a [`TokenResponse`].
    ///
    /// # Errors
    /// Implementation-defined; surfaced as an `OAuthError`.
    async fn post_authorize(
        &self,
        method_id: &str,
        tokens: &TokenResponse,
    ) -> OAuthResult<ProviderAuth>;

    /// Mutates `headers` based on `auth`.
    ///
    /// # Errors
    /// Implementation-defined; surfaced as an `OAuthError`.
    fn apply_auth_headers(&self, auth: &ProviderAuth, headers: &mut HeaderMap) -> OAuthResult<()>;

    /// Name of an environment variable associated with this provider's API
    /// key, if any. The default implementation returns `None`.
    fn api_key_env_var(&self) -> Option<&'static str> {
        None
    }

    /// Returns the [`DeviceFlow`] to use for `method_id`.
    ///
    /// # Errors
    /// The default implementation always fails with an `unknown_method`
    /// error naming the provider id and `method_id`.
    fn device_flow(&self, method_id: &str) -> OAuthResult<DeviceFlow> {
        let provider_id = self.id();
        let message = format!(
            "Provider '{}' does not support the Device Authorization Grant for method '{}'",
            provider_id, method_id,
        );
        Err(OAuthError::unknown_method(message))
    }

    /// Resolves the flow for `method_id` and requests a device code from it.
    ///
    /// Returns the flow together with the code so the caller can pass both to
    /// [`OAuthProvider::wait_for_token`].
    ///
    /// # Errors
    /// Propagates failures from [`OAuthProvider::device_flow`] and from the
    /// flow's device-code request.
    async fn request_device_code(
        &self,
        method_id: &str,
    ) -> OAuthResult<(DeviceFlow, DeviceCodeResponse)> {
        let grant = self.device_flow(method_id)?;
        let issued = grant.request_device_code().await?;
        Ok((grant, issued))
    }

    /// Awaits the token for `device_code` by delegating to
    /// `DeviceFlow::poll_for_token`.
    async fn wait_for_token(
        &self,
        flow: &DeviceFlow,
        device_code: &DeviceCodeResponse,
    ) -> OAuthResult<DeviceTokenResponse> {
        flow.poll_for_token(device_code).await
    }

    /// Produces the [`ProviderAuth`] for `method_id` from a device-flow token.
    ///
    /// # Errors
    /// The default implementation always fails with an `unknown_method`
    /// error naming the provider id and `method_id`.
    async fn post_device_authorize(
        &self,
        method_id: &str,
        token: &DeviceTokenResponse,
    ) -> OAuthResult<ProviderAuth> {
        // The default has no use for the token; bind it so it counts as used.
        let _ = token;
        let provider_id = self.id();
        let message = format!(
            "Provider '{}' does not support post_device_authorize for method '{}'",
            provider_id, method_id,
        );
        Err(OAuthError::unknown_method(message))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `AuthMethod::oauth` stores its arguments and tags the method as OAuth.
    #[test]
    fn test_auth_method_oauth() {
        let description = "Use your subscription".to_string();
        let method = AuthMethod::oauth("claude-max", "Claude Pro/Max", Some(description));

        assert_eq!(method.id, "claude-max");
        assert_eq!(method.label, "Claude Pro/Max");
        assert_eq!(method.description.as_deref(), Some("Use your subscription"));
        assert_eq!(method.method_type, AuthMethodType::OAuth);
    }

    /// `AuthMethod::api_key` stores its arguments and tags the method as ApiKey.
    #[test]
    fn test_auth_method_api_key() {
        let AuthMethod {
            id,
            label,
            description,
            method_type,
        } = AuthMethod::api_key("api-key", "Manual API Key", None);

        assert_eq!(id, "api-key");
        assert_eq!(label, "Manual API Key");
        assert!(description.is_none());
        assert_eq!(method_type, AuthMethodType::ApiKey);
    }

    /// `display` appends the description after " - " only when one is set.
    #[test]
    fn test_auth_method_display() {
        let cases = [
            (
                Some("Description here".to_string()),
                "Test Method - Description here",
            ),
            (None, "Test Method"),
        ];

        for (description, expected) in cases {
            let method = AuthMethod::oauth("test", "Test Method", description);
            assert_eq!(method.display(), expected);
        }
    }
}