rab/provider/oauth/
mod.rs1use std::collections::HashMap;
7
8use async_trait::async_trait;
9use std::sync::{Arc, Mutex, OnceLock};
10
11pub mod device_code;
12pub mod github_copilot;
13
/// Credential bundle returned by [`OAuthProvider::login`] and
/// [`OAuthProvider::refresh_token`].
///
/// NOTE(review): the derived `Debug` implementation prints every field
/// verbatim, including `access` and `refresh`. Avoid `{:?}`-logging values of
/// this type if those fields hold secrets.
#[derive(Debug, Clone)]
pub struct OAuthCredentials {
    /// Presumably the OAuth access token — confirm against the provider
    /// implementations.
    pub access: String,
    /// Presumably the refresh token used to obtain new credentials — TODO
    /// confirm.
    pub refresh: String,
    /// Expiry of the credentials. The unit and epoch are not visible in this
    /// module — TODO confirm with the code that sets it.
    pub expires: i64,
    /// Optional enterprise URL associated with the credentials; `None` when
    /// absent.
    pub enterprise_url: Option<String>,
    /// Additional string key/value data carried alongside the credentials.
    pub extra: HashMap<String, String>,
}
24
/// Device-code details handed to the `on_device_code` callback of
/// [`OAuthLoginCallbacks`].
#[derive(Debug, Clone)]
pub struct DeviceCodeInfo {
    /// Code associated with the device-code flow; presumably shown to the
    /// user for entry at `verification_uri` — TODO confirm.
    pub user_code: String,
    /// URI accompanying `user_code`.
    pub verification_uri: String,
    /// Optional interval in seconds; presumably the polling interval
    /// requested by the server — TODO confirm.
    pub interval_seconds: Option<u32>,
    /// Optional lifetime in seconds; presumably how long the code stays
    /// valid — TODO confirm.
    pub expires_in_seconds: Option<u32>,
}
33
/// A request for user input, passed to the `on_prompt` callback of
/// [`OAuthLoginCallbacks`].
#[derive(Debug, Clone)]
pub enum OAuthPrompt {
    /// Free-form text prompt.
    Text {
        /// Message accompanying the prompt.
        message: String,
        /// Optional placeholder text for the input.
        placeholder: Option<String>,
        /// Flag indicating whether an empty answer is acceptable. Enforcement
        /// is left to whoever handles the prompt; nothing in this module
        /// checks it.
        allow_empty: bool,
    },
}
43
44pub struct OAuthLoginCallbacks<'a> {
46 pub on_device_code: Box<dyn FnMut(DeviceCodeInfo) + Send + 'a>,
47 pub on_prompt: Box<dyn FnMut(OAuthPrompt) -> Result<String, String> + Send + 'a>,
48 pub on_progress: Box<dyn FnMut(String) + Send + 'a>,
49 pub signal: Option<tokio_util::sync::CancellationToken>,
50}
51
52#[async_trait]
54pub trait OAuthProvider: Send + Sync {
55 fn id(&self) -> &str;
57
58 fn name(&self) -> &str;
60
61 async fn login(
63 &self,
64 callbacks: &mut OAuthLoginCallbacks<'_>,
65 ) -> Result<OAuthCredentials, String>;
66
67 async fn refresh_token(
69 &self,
70 credentials: &OAuthCredentials,
71 ) -> Result<OAuthCredentials, String>;
72
73 fn get_api_key<'a>(&self, credentials: &'a OAuthCredentials) -> &'a str;
75}
76
/// Identifiers that [`is_built_in`] reports as built-in providers.
static BUILT_IN_PROVIDERS: &[&str] = &["github-copilot"];
80
81static REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn OAuthProvider>>>> = OnceLock::new();
82
83fn registry() -> &'static Mutex<HashMap<String, Arc<dyn OAuthProvider>>> {
84 REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
85}
86
87pub fn register(provider: Arc<dyn OAuthProvider>) {
89 registry()
90 .lock()
91 .unwrap()
92 .insert(provider.id().to_string(), provider);
93}
94
95pub fn get(id: &str) -> Option<Arc<dyn OAuthProvider>> {
97 registry().lock().unwrap().get(id).cloned()
98}
99
100pub fn list_ids() -> Vec<String> {
102 registry().lock().unwrap().keys().cloned().collect()
103}
104
105pub fn is_built_in(id: &str) -> bool {
107 BUILT_IN_PROVIDERS.contains(&id)
108}
109
110pub fn register_builtins() {
112 static INIT: std::sync::Once = std::sync::Once::new();
113 INIT.call_once(|| {
114 let gh = crate::provider::oauth::github_copilot::GitHubCopilotOAuth;
115 register(Arc::new(gh));
116 });
117}