// rab/provider/oauth/mod.rs

//! OAuth provider trait and registry — matching pi's `OAuthProviderInterface`.
//!
//! Each OAuth provider implements login (device-code or callback-server flow),
//! token refresh, and API key derivation.

use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};

use async_trait::async_trait;

pub mod device_code;
pub mod github_copilot;

/// Credentials returned from a successful OAuth login or refresh.
///
/// `Debug` is implemented by hand rather than derived so that the access and
/// refresh tokens are redacted: a derived impl would print the secrets into
/// any `{:?}` log or error message.
#[derive(Clone)]
pub struct OAuthCredentials {
    /// Access token (the value `OAuthProvider::get_api_key` is documented to derive).
    pub access: String,
    /// Refresh token.
    pub refresh: String,
    /// Expiry timestamp, in epoch milliseconds.
    pub expires: i64,
    /// Enterprise URL, if one applies to these credentials.
    pub enterprise_url: Option<String>,
    /// Provider-specific extra data (e.g. available model IDs for Copilot).
    pub extra: HashMap<String, String>,
}

impl std::fmt::Debug for OAuthCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Non-secret fields are shown verbatim; both tokens are masked.
        f.debug_struct("OAuthCredentials")
            .field("access", &"<redacted>")
            .field("refresh", &"<redacted>")
            .field("expires", &self.expires)
            .field("enterprise_url", &self.enterprise_url)
            .field("extra", &self.extra)
            .finish()
    }
}

/// Device-code details handed to the `on_device_code` login callback.
#[derive(Clone, Debug)]
pub struct DeviceCodeInfo {
    /// Code the user is asked to enter.
    pub user_code: String,
    /// URI the user should visit to enter `user_code`.
    pub verification_uri: String,
    /// Suggested polling interval in seconds, if the provider supplied one.
    pub interval_seconds: Option<u32>,
    /// Seconds until the code expires, if the provider supplied it.
    pub expires_in_seconds: Option<u32>,
}

/// A question put to the user while a login flow is running.
#[derive(Clone, Debug)]
pub enum OAuthPrompt {
    /// Free-form text input.
    Text {
        /// Text shown to the user.
        message: String,
        /// Optional placeholder for the input.
        placeholder: Option<String>,
        /// Whether an empty answer is acceptable.
        allow_empty: bool,
    },
}

44/// Callbacks the login flow uses to interact with the user.
45pub struct OAuthLoginCallbacks<'a> {
46    pub on_device_code: Box<dyn FnMut(DeviceCodeInfo) + Send + 'a>,
47    pub on_prompt: Box<dyn FnMut(OAuthPrompt) -> Result<String, String> + Send + 'a>,
48    pub on_progress: Box<dyn FnMut(String) + Send + 'a>,
49    pub signal: Option<tokio_util::sync::CancellationToken>,
50}
51
52/// An OAuth provider (matching pi's OAuthProviderInterface).
53#[async_trait]
54pub trait OAuthProvider: Send + Sync {
55    /// The provider ID (matches the registry/provider id).
56    fn id(&self) -> &str;
57
58    /// Human-readable name.
59    fn name(&self) -> &str;
60
61    /// Run the login flow (device code for Copilot, callback-server for Anthropic).
62    async fn login(
63        &self,
64        callbacks: &mut OAuthLoginCallbacks<'_>,
65    ) -> Result<OAuthCredentials, String>;
66
67    /// Refresh an expired token.
68    async fn refresh_token(
69        &self,
70        credentials: &OAuthCredentials,
71    ) -> Result<OAuthCredentials, String>;
72
73    /// Derive the API key (access token) for API requests.
74    fn get_api_key<'a>(&self, credentials: &'a OAuthCredentials) -> &'a str;
75}
76
// ── Registry ───────────────────────────────────────────────────────

/// IDs of the built-in OAuth providers, as checked by `is_built_in`.
const BUILT_IN_PROVIDERS: &[&str] = &["github-copilot"];

81static REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn OAuthProvider>>>> = OnceLock::new();
82
83fn registry() -> &'static Mutex<HashMap<String, Arc<dyn OAuthProvider>>> {
84    REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
85}
86
87/// Register an OAuth provider.
88pub fn register(provider: Arc<dyn OAuthProvider>) {
89    registry()
90        .lock()
91        .unwrap()
92        .insert(provider.id().to_string(), provider);
93}
94
95/// Get an OAuth provider by ID.
96pub fn get(id: &str) -> Option<Arc<dyn OAuthProvider>> {
97    registry().lock().unwrap().get(id).cloned()
98}
99
100/// List all registered OAuth provider IDs.
101pub fn list_ids() -> Vec<String> {
102    registry().lock().unwrap().keys().cloned().collect()
103}
104
105/// Check if a provider ID corresponds to a built-in OAuth provider.
106pub fn is_built_in(id: &str) -> bool {
107    BUILT_IN_PROVIDERS.contains(&id)
108}
109
110/// Register all built-in OAuth providers (called once at startup).
111pub fn register_builtins() {
112    static INIT: std::sync::Once = std::sync::Once::new();
113    INIT.call_once(|| {
114        let gh = crate::provider::oauth::github_copilot::GitHubCopilotOAuth;
115        register(Arc::new(gh));
116    });
117}