1use serde::{Deserialize, Serialize};
2use sha2::{Digest, Sha256};
3use std::time::{Duration, SystemTime, UNIX_EPOCH};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct TokenSet {
8 pub access_token: String,
10 #[serde(default, skip_serializing_if = "Option::is_none")]
12 pub id_token: Option<String>,
13 pub refresh_token: String,
15 pub expires_at: u64,
17 #[serde(default, skip_serializing_if = "Option::is_none")]
19 pub api_key: Option<String>,
20}
21
22impl TokenSet {
23 pub fn is_expired(&self) -> bool {
28 self.expires_in() <= Duration::from_secs(300)
29 }
30
31 pub fn expires_in(&self) -> Duration {
35 let now = SystemTime::now()
36 .duration_since(UNIX_EPOCH)
37 .unwrap()
38 .as_secs();
39
40 if self.expires_at > now {
41 Duration::from_secs(self.expires_at - now)
42 } else {
43 Duration::ZERO
44 }
45 }
46}
47
/// Values produced when an authorization-code login is started.
///
/// The PKCE verifier and `state` are presumably held until the redirect
/// callback arrives, then used to validate it and exchange the code.
/// TODO(review): confirm against the code that constructs this.
#[derive(Clone, Debug)]
pub struct OAuthFlow {
    /// Authorization URL for the login.
    pub authorization_url: String,
    /// PKCE code verifier matching the challenge sent in the authorization request.
    pub pkce_verifier: String,
    /// Opaque `state` value sent with the authorization request.
    pub state: String,
}
61
/// Endpoint and client settings for the OAuth login.
///
/// Construct directly, via `OAuthConfig::default()`, or with
/// `OAuthConfig::builder()` to override individual fields.
#[derive(Clone, Debug)]
pub struct OAuthConfig {
    /// OAuth client identifier.
    pub client_id: String,
    /// Authorization endpoint URL.
    pub auth_url: String,
    /// Token endpoint URL.
    pub token_url: String,
    /// Redirect URI registered for the client.
    pub redirect_uri: String,
}
74
75impl Default for OAuthConfig {
76 fn default() -> Self {
77 Self {
78 client_id: "app_EMoamEEZ73f0CkXaXp7hrann".to_string(),
79 auth_url: "https://auth.openai.com/oauth/authorize".to_string(),
80 token_url: "https://auth.openai.com/oauth/token".to_string(),
81 redirect_uri: "http://localhost:1455/auth/callback".to_string(),
82 }
83 }
84}
85
86impl OAuthConfig {
87 pub fn builder() -> OAuthConfigBuilder {
89 OAuthConfigBuilder::default()
90 }
91}
92
/// Incremental builder for an `OAuthConfig`.
///
/// Each field is `None` until the matching setter is called; `None` means
/// "use the value from `OAuthConfig::default()`" at build time.
#[derive(Default, Clone, Debug)]
pub struct OAuthConfigBuilder {
    /// Override for `OAuthConfig::client_id`.
    client_id: Option<String>,
    /// Override for `OAuthConfig::auth_url`.
    auth_url: Option<String>,
    /// Override for `OAuthConfig::token_url`.
    token_url: Option<String>,
    /// Override for `OAuthConfig::redirect_uri`.
    redirect_uri: Option<String>,
}
101
102impl OAuthConfigBuilder {
103 pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
105 self.client_id = Some(client_id.into());
106 self
107 }
108
109 pub fn auth_url(mut self, auth_url: impl Into<String>) -> Self {
111 self.auth_url = Some(auth_url.into());
112 self
113 }
114
115 pub fn token_url(mut self, token_url: impl Into<String>) -> Self {
117 self.token_url = Some(token_url.into());
118 self
119 }
120
121 pub fn redirect_uri(mut self, redirect_uri: impl Into<String>) -> Self {
123 self.redirect_uri = Some(redirect_uri.into());
124 self
125 }
126
127 pub fn redirect_port(mut self, port: u16) -> Self {
129 self.redirect_uri = Some(format!("http://localhost:{}/auth/callback", port));
130 self
131 }
132
133 pub fn build(self) -> OAuthConfig {
135 let defaults = OAuthConfig::default();
136 OAuthConfig {
137 client_id: self.client_id.unwrap_or(defaults.client_id),
138 auth_url: self.auth_url.unwrap_or(defaults.auth_url),
139 token_url: self.token_url.unwrap_or(defaults.token_url),
140 redirect_uri: self.redirect_uri.unwrap_or(defaults.redirect_uri),
141 }
142 }
143}
144
145#[derive(Debug, Deserialize)]
147pub(crate) struct TokenResponse {
148 pub access_token: String,
149 pub id_token: Option<String>,
150 pub refresh_token: Option<String>,
151 pub expires_in: Option<u64>,
152}
153
154impl From<TokenResponse> for TokenSet {
155 fn from(response: TokenResponse) -> Self {
156 let expires_at = SystemTime::now()
157 .duration_since(UNIX_EPOCH)
158 .unwrap()
159 .as_secs()
160 + response.expires_in.unwrap_or(3600);
161
162 TokenSet {
163 access_token: response.access_token,
164 id_token: response.id_token,
165 refresh_token: response.refresh_token.unwrap_or_default(),
166 expires_at,
167 api_key: None,
168 }
169 }
170}
171
172pub(crate) fn generate_random_state() -> String {
174 use base64::{Engine as _, engine::general_purpose};
175 use rand::Rng;
176
177 let random_bytes: Vec<u8> = (0..32).map(|_| rand::thread_rng().r#gen()).collect();
178 general_purpose::URL_SAFE_NO_PAD.encode(&random_bytes)
179}
180
181pub(crate) fn generate_pkce_pair() -> (String, String) {
182 use base64::{Engine as _, engine::general_purpose};
183 use rand::RngCore;
184
185 let mut bytes = [0u8; 32];
186 rand::thread_rng().fill_bytes(&mut bytes);
187 let verifier = general_purpose::URL_SAFE_NO_PAD.encode(bytes);
188 let digest = Sha256::digest(verifier.as_bytes());
189 let challenge = general_purpose::URL_SAFE_NO_PAD.encode(digest);
190 (challenge, verifier)
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct SessionData {
198 pub organization_id: Option<String>,
200 pub project_id: Option<String>,
202 pub completed_platform_onboarding: bool,
204 pub is_org_owner: bool,
206 pub chatgpt_plan_type: Option<String>,
208 pub chatgpt_account_id: Option<String>,
210}