openai_auth/
types.rs

1use serde::{Deserialize, Serialize};
2use sha2::{Digest, Sha256};
3use std::time::{Duration, SystemTime, UNIX_EPOCH};
4
5/// OAuth token set containing access token, refresh token, and expiration info
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct TokenSet {
8    /// The access token used to authenticate API requests
9    pub access_token: String,
10    /// The ID token returned by OpenAI (used for API key exchange)
11    #[serde(default, skip_serializing_if = "Option::is_none")]
12    pub id_token: Option<String>,
13    /// The refresh token used to obtain new access tokens
14    pub refresh_token: String,
15    /// Unix timestamp (seconds) when the access token expires
16    pub expires_at: u64,
17    /// OpenAI API key derived from token exchange
18    #[serde(default, skip_serializing_if = "Option::is_none")]
19    pub api_key: Option<String>,
20}
21
22impl TokenSet {
23    /// Check if the token is expired or will expire soon (within 5 minutes)
24    ///
25    /// This includes a 5-minute buffer to prevent race conditions where a token
26    /// expires between checking and using it.
27    pub fn is_expired(&self) -> bool {
28        self.expires_in() <= Duration::from_secs(300)
29    }
30
31    /// Get the duration until the token expires
32    ///
33    /// Returns `Duration::ZERO` if the token is already expired.
34    pub fn expires_in(&self) -> Duration {
35        let now = SystemTime::now()
36            .duration_since(UNIX_EPOCH)
37            .unwrap()
38            .as_secs();
39
40        if self.expires_at > now {
41            Duration::from_secs(self.expires_at - now)
42        } else {
43            Duration::ZERO
44        }
45    }
46}
47
/// State of an in-progress OAuth authorization flow.
///
/// Bundles the URL to present to the user with the PKCE verifier and the CSRF
/// state token that are needed to complete the flow afterwards.
#[derive(Debug, Clone)]
pub struct OAuthFlow {
    /// Address the user must visit to authorize the application.
    pub authorization_url: String,
    /// PKCE code verifier, supplied when exchanging the authorization code for tokens.
    pub pkce_verifier: String,
    /// CSRF protection token, used for security validation of the flow.
    pub state: String,
}
61
/// Endpoints and client settings used by the OpenAI OAuth client.
///
/// [`OAuthConfig::default`] yields the stock OpenAI values;
/// [`OAuthConfig::builder`] lets individual settings be overridden.
#[derive(Debug, Clone)]
pub struct OAuthConfig {
    /// Client identifier registered with the OAuth provider
    /// (default: `"app_EMoamEEZ73f0CkXaXp7hrann"`).
    pub client_id: String,
    /// URL of the authorization endpoint the user is sent to.
    pub auth_url: String,
    /// URL of the endpoint that exchanges authorization codes for tokens.
    pub token_url: String,
    /// Callback URI the provider redirects to after authorization
    /// (default: `"http://localhost:1455/auth/callback"`).
    pub redirect_uri: String,
}
74
75impl Default for OAuthConfig {
76    fn default() -> Self {
77        Self {
78            client_id: "app_EMoamEEZ73f0CkXaXp7hrann".to_string(),
79            auth_url: "https://auth.openai.com/oauth/authorize".to_string(),
80            token_url: "https://auth.openai.com/oauth/token".to_string(),
81            redirect_uri: "http://localhost:1455/auth/callback".to_string(),
82        }
83    }
84}
85
86impl OAuthConfig {
87    /// Create a new config builder
88    pub fn builder() -> OAuthConfigBuilder {
89        OAuthConfigBuilder::default()
90    }
91}
92
/// Incremental builder for [`OAuthConfig`].
///
/// Each field stays `None` until explicitly set; unset fields are filled from
/// the default configuration at build time.
#[derive(Debug, Clone, Default)]
pub struct OAuthConfigBuilder {
    /// Override for the OAuth client ID.
    client_id: Option<String>,
    /// Override for the authorization endpoint URL.
    auth_url: Option<String>,
    /// Override for the token exchange endpoint URL.
    token_url: Option<String>,
    /// Override for the redirect (callback) URI.
    redirect_uri: Option<String>,
}
101
102impl OAuthConfigBuilder {
103    /// Set the OAuth client ID
104    pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
105        self.client_id = Some(client_id.into());
106        self
107    }
108
109    /// Set the authorization endpoint URL
110    pub fn auth_url(mut self, auth_url: impl Into<String>) -> Self {
111        self.auth_url = Some(auth_url.into());
112        self
113    }
114
115    /// Set the token exchange endpoint URL
116    pub fn token_url(mut self, token_url: impl Into<String>) -> Self {
117        self.token_url = Some(token_url.into());
118        self
119    }
120
121    /// Set the redirect URI
122    pub fn redirect_uri(mut self, redirect_uri: impl Into<String>) -> Self {
123        self.redirect_uri = Some(redirect_uri.into());
124        self
125    }
126
127    /// Set the redirect URI with a custom port
128    pub fn redirect_port(mut self, port: u16) -> Self {
129        self.redirect_uri = Some(format!("http://localhost:{}/auth/callback", port));
130        self
131    }
132
133    /// Build the OAuthConfig
134    pub fn build(self) -> OAuthConfig {
135        let defaults = OAuthConfig::default();
136        OAuthConfig {
137            client_id: self.client_id.unwrap_or(defaults.client_id),
138            auth_url: self.auth_url.unwrap_or(defaults.auth_url),
139            token_url: self.token_url.unwrap_or(defaults.token_url),
140            redirect_uri: self.redirect_uri.unwrap_or(defaults.redirect_uri),
141        }
142    }
143}
144
145/// Token response from OAuth server
146#[derive(Debug, Deserialize)]
147pub(crate) struct TokenResponse {
148    pub access_token: String,
149    pub id_token: Option<String>,
150    pub refresh_token: Option<String>,
151    pub expires_in: Option<u64>,
152}
153
154impl From<TokenResponse> for TokenSet {
155    fn from(response: TokenResponse) -> Self {
156        let expires_at = SystemTime::now()
157            .duration_since(UNIX_EPOCH)
158            .unwrap()
159            .as_secs()
160            + response.expires_in.unwrap_or(3600);
161
162        TokenSet {
163            access_token: response.access_token,
164            id_token: response.id_token,
165            refresh_token: response.refresh_token.unwrap_or_default(),
166            expires_at,
167            api_key: None,
168        }
169    }
170}
171
172/// Generate a random state string for CSRF protection
173pub(crate) fn generate_random_state() -> String {
174    use base64::{Engine as _, engine::general_purpose};
175    use rand::Rng;
176
177    let random_bytes: Vec<u8> = (0..32).map(|_| rand::thread_rng().r#gen()).collect();
178    general_purpose::URL_SAFE_NO_PAD.encode(&random_bytes)
179}
180
181pub(crate) fn generate_pkce_pair() -> (String, String) {
182    use base64::{Engine as _, engine::general_purpose};
183    use rand::RngCore;
184
185    let mut bytes = [0u8; 32];
186    rand::thread_rng().fill_bytes(&mut bytes);
187    let verifier = general_purpose::URL_SAFE_NO_PAD.encode(bytes);
188    let digest = Sha256::digest(verifier.as_bytes());
189    let challenge = general_purpose::URL_SAFE_NO_PAD.encode(digest);
190    (challenge, verifier)
191}
192
193/// Session data extracted from OAuth tokens
194///
195/// This contains user and organization information from the id_token and access_token
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct SessionData {
198    /// Organization ID from the id_token
199    pub organization_id: Option<String>,
200    /// Project ID from the id_token
201    pub project_id: Option<String>,
202    /// Whether the user has completed platform onboarding
203    pub completed_platform_onboarding: bool,
204    /// Whether the user is an organization owner
205    pub is_org_owner: bool,
206    /// ChatGPT plan type (from access_token)
207    pub chatgpt_plan_type: Option<String>,
208    /// ChatGPT account ID (from access_token)
209    pub chatgpt_account_id: Option<String>,
210}