1use serde::{Deserialize, Serialize};
2use std::time::{Duration, SystemTime, UNIX_EPOCH};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct TokenSet {
7 pub access_token: String,
9 pub refresh_token: String,
11 pub expires_at: u64,
13}
14
15impl TokenSet {
16 pub fn is_expired(&self) -> bool {
21 self.expires_in() <= Duration::from_secs(300)
22 }
23
24 pub fn expires_in(&self) -> Duration {
28 let now = SystemTime::now()
29 .duration_since(UNIX_EPOCH)
30 .unwrap()
31 .as_secs();
32
33 if self.expires_at > now {
34 Duration::from_secs(self.expires_at - now)
35 } else {
36 Duration::ZERO
37 }
38 }
39}
40
/// Data describing an authorization-code flow that has been started.
///
/// Holds the URL the user should visit plus the values that must be kept
/// until the redirect comes back.
///
/// NOTE(review): the derived `Debug` prints `pkce_verifier` and `state`
/// verbatim; avoid logging this type with `{:?}`.
#[derive(Clone, Debug)]
pub struct OAuthFlow {
    /// Authorization endpoint URL to open in the user's browser.
    pub authorization_url: String,
    /// PKCE code verifier, presumably sent with the code exchange — confirm
    /// against the caller.
    pub pkce_verifier: String,
    /// Opaque `state` value, presumably compared with the one echoed back to
    /// the redirect URI — confirm against the caller.
    pub state: String,
}
54
/// Endpoints and client settings for the OAuth authorization-code flow.
///
/// Use `OAuthConfig::default()` for the built-in values or
/// `OAuthConfig::builder()` to override individual fields.
#[derive(Clone, Debug)]
pub struct OAuthConfig {
    /// OAuth client identifier.
    pub client_id: String,
    /// Authorization endpoint URL.
    pub auth_url: String,
    /// Token endpoint URL.
    pub token_url: String,
    /// Redirect URI registered for the client.
    pub redirect_uri: String,
}
67
68impl Default for OAuthConfig {
69 fn default() -> Self {
70 Self {
71 client_id: "app_EMoamEEZ73f0CkXaXp7hrann".to_string(),
72 auth_url: "https://auth.openai.com/oauth/authorize".to_string(),
73 token_url: "https://auth.openai.com/oauth/token".to_string(),
74 redirect_uri: "http://localhost:1455/auth/callback".to_string(),
75 }
76 }
77}
78
79impl OAuthConfig {
80 pub fn builder() -> OAuthConfigBuilder {
82 OAuthConfigBuilder::default()
83 }
84}
85
/// Incremental builder for [`OAuthConfig`].
///
/// Each field is `None` until its setter is called; `None` fields are
/// filled from `OAuthConfig::default()` by `build`.
#[derive(Clone, Debug, Default)]
pub struct OAuthConfigBuilder {
    /// Override for `OAuthConfig::client_id`.
    client_id: Option<String>,
    /// Override for `OAuthConfig::auth_url`.
    auth_url: Option<String>,
    /// Override for `OAuthConfig::token_url`.
    token_url: Option<String>,
    /// Override for `OAuthConfig::redirect_uri`.
    redirect_uri: Option<String>,
}
94
95impl OAuthConfigBuilder {
96 pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
98 self.client_id = Some(client_id.into());
99 self
100 }
101
102 pub fn auth_url(mut self, auth_url: impl Into<String>) -> Self {
104 self.auth_url = Some(auth_url.into());
105 self
106 }
107
108 pub fn token_url(mut self, token_url: impl Into<String>) -> Self {
110 self.token_url = Some(token_url.into());
111 self
112 }
113
114 pub fn redirect_uri(mut self, redirect_uri: impl Into<String>) -> Self {
116 self.redirect_uri = Some(redirect_uri.into());
117 self
118 }
119
120 pub fn redirect_port(mut self, port: u16) -> Self {
122 self.redirect_uri = Some(format!("http://localhost:{}/auth/callback", port));
123 self
124 }
125
126 pub fn build(self) -> OAuthConfig {
128 let defaults = OAuthConfig::default();
129 OAuthConfig {
130 client_id: self.client_id.unwrap_or(defaults.client_id),
131 auth_url: self.auth_url.unwrap_or(defaults.auth_url),
132 token_url: self.token_url.unwrap_or(defaults.token_url),
133 redirect_uri: self.redirect_uri.unwrap_or(defaults.redirect_uri),
134 }
135 }
136}
137
138#[derive(Debug, Deserialize)]
140pub(crate) struct TokenResponse {
141 pub access_token: String,
142 pub refresh_token: Option<String>,
143 pub expires_in: Option<u64>,
144}
145
146impl From<TokenResponse> for TokenSet {
147 fn from(response: TokenResponse) -> Self {
148 let expires_at = SystemTime::now()
149 .duration_since(UNIX_EPOCH)
150 .unwrap()
151 .as_secs()
152 + response.expires_in.unwrap_or(3600);
153
154 TokenSet {
155 access_token: response.access_token,
156 refresh_token: response.refresh_token.unwrap_or_default(),
157 expires_at,
158 }
159 }
160}
161
162pub(crate) fn generate_random_state() -> String {
164 use base64::{engine::general_purpose, Engine as _};
165 use rand::Rng;
166
167 let random_bytes: Vec<u8> = (0..32).map(|_| rand::thread_rng().gen()).collect();
168 general_purpose::URL_SAFE_NO_PAD.encode(&random_bytes)
169}