1mod error;
24mod exchange;
25mod pkce;
26mod server;
27
28use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
29pub use error::Error;
30use rand::RngCore as _;
31
/// OAuth client identifier, sent as `client_id` in the authorization request.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Sent as `redirect_uri` in the authorization request. Its port (1455) must
/// stay in sync with `CALLBACK_PORT`, which `login` hands to
/// `server::wait_for_callback` to receive the redirect.
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Authorization endpoint; the base URL that `build_auth_url` appends the
/// query parameters to.
const AUTH_URL: &str = "https://auth.openai.com/oauth/authorize";
/// Token endpoint. Crate-visible, presumably for the `exchange` module's
/// code-exchange and refresh requests (NOTE(review): confirm against `exchange`).
pub(crate) const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Port passed to `server::wait_for_callback`; must match the port in
/// `REDIRECT_URI`.
const CALLBACK_PORT: u16 = 1455;
/// How many seconds `login` waits for the browser callback before failing.
const LOGIN_TIMEOUT_SECS: u64 = 120;
38
/// Token set produced by [`login`] or [`refresh`], together with the time it
/// was issued so expiry can be computed locally.
///
/// Derives serde `Serialize`/`Deserialize`; NOTE(review): where it is
/// persisted, if anywhere, is not visible in this module.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Token {
    /// OAuth access token.
    pub access_token: String,
    /// Refresh token; presumably the value [`refresh`] expects.
    pub refresh_token: String,
    /// OpenID Connect ID token.
    pub id_token: String,
    /// Access-token lifetime in seconds, counted from `issued_at`.
    pub expires_in: u64,
    /// Issue time as whole seconds since the Unix epoch (same scale as
    /// `unix_now()`, which `is_expired` compares against).
    pub issued_at: u64,
}
51
52impl Token {
53 pub fn is_expired(&self) -> bool {
55 unix_now() >= self.issued_at + self.expires_in
56 }
57}
58
59pub async fn login() -> Result<Token, Error> {
64 let pkce = pkce::Pkce::generate();
65
66 let mut state_bytes = [0u8; 16];
67 rand::rng().fill_bytes(&mut state_bytes);
68 let state = URL_SAFE_NO_PAD.encode(state_bytes);
69
70 let auth_url = build_auth_url(&pkce.challenge, &state);
71
72 println!("Open this URL to log in:\n\n {auth_url}\n");
73 let _ = open_browser(&auth_url);
74
75 let (code, returned_state) = tokio::time::timeout(
76 std::time::Duration::from_secs(LOGIN_TIMEOUT_SECS),
77 server::wait_for_callback(CALLBACK_PORT),
78 )
79 .await
80 .map_err(|_| {
81 Error::Callback(format!(
82 "timed out waiting for browser callback ({LOGIN_TIMEOUT_SECS}s)"
83 ))
84 })??;
85
86 if returned_state != state {
87 return Err(Error::StateMismatch);
88 }
89
90 exchange::exchange_code(&code, &pkce.verifier).await
91}
92
93pub async fn refresh(refresh_token: &str) -> Result<Token, Error> {
95 exchange::refresh_token(refresh_token).await
96}
97
98pub(crate) fn build_auth_url(challenge: &str, state: &str) -> String {
99 let mut url = reqwest::Url::parse(AUTH_URL).expect("AUTH_URL is valid");
100 url.query_pairs_mut()
101 .append_pair("client_id", CLIENT_ID)
102 .append_pair("response_type", "code")
103 .append_pair("redirect_uri", REDIRECT_URI)
104 .append_pair("scope", "openid profile email offline_access")
105 .append_pair("state", state)
106 .append_pair("code_challenge", challenge)
107 .append_pair("code_challenge_method", "S256");
108 url.to_string()
109}
110
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
pub(crate) fn unix_now() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
117
/// Best-effort attempt to open `url` in the user's default browser.
///
/// Spawns the platform's URL opener and returns once it has started; it does
/// not wait for the opener to exit. On targets other than macOS, Linux and
/// Windows this does nothing and returns `Ok(())`.
///
/// # Errors
///
/// Returns the I/O error from spawning the opener process (for example when
/// `xdg-open` is not installed).
fn open_browser(url: &str) -> std::io::Result<()> {
    #[cfg(target_os = "macos")]
    std::process::Command::new("open").arg(url).spawn()?;
    #[cfg(target_os = "linux")]
    std::process::Command::new("xdg-open").arg(url).spawn()?;
    // Do not route this through `cmd /c start "" <url>`. The URL contains no
    // spaces, so it is passed to cmd.exe unquoted, and cmd treats every `&` in
    // the query string as a command separator. The browser would then receive
    // only `...?client_id=...`, losing `state` and `code_challenge`.
    // `url.dll,FileProtocolHandler` gets the URL verbatim with no shell parsing.
    #[cfg(target_os = "windows")]
    std::process::Command::new("rundll32")
        .args(["url.dll,FileProtocolHandler", url])
        .spawn()?;
    Ok(())
}
131
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Parses `url` and returns its decoded query parameters.
    ///
    /// Panics if `url` is not a valid URL or repeats a parameter name, so
    /// every caller also checks well-formedness and uniqueness.
    fn query_params(url: &str) -> HashMap<String, String> {
        let parsed = reqwest::Url::parse(url).expect("auth URL must be valid");
        let mut params = HashMap::new();
        for (key, value) in parsed.query_pairs().into_owned() {
            assert!(
                params.insert(key.clone(), value).is_none(),
                "duplicate query parameter `{key}`"
            );
        }
        params
    }

    #[test]
    fn auth_url_contains_required_params() {
        // Compare decoded values exactly instead of substring-matching the
        // raw URL, which could match inside another parameter.
        let params = query_params(&build_auth_url("challenge123", "state456"));
        assert_eq!(params["client_id"], "app_EMoamEEZ73f0CkXaXp7hrann");
        assert_eq!(params["response_type"], "code");
        assert_eq!(params["redirect_uri"], REDIRECT_URI);
        assert_eq!(params["scope"], "openid profile email offline_access");
        assert_eq!(params["state"], "state456");
        assert_eq!(params["code_challenge"], "challenge123");
        assert_eq!(params["code_challenge_method"], "S256");
        assert_eq!(params.len(), 7);
    }

    #[test]
    fn redirect_uri_is_percent_encoded_in_auth_url() {
        let url = build_auth_url("c", "s");
        assert!(url.contains("redirect_uri=http%3A%2F%2F"));
    }

    #[test]
    fn auth_url_escapes_reserved_characters_in_values() {
        // Values containing query delimiters must round-trip unchanged rather
        // than splitting into extra parameters.
        let params = query_params(&build_auth_url("a+b/c=", "x&y=z"));
        assert_eq!(params["code_challenge"], "a+b/c=");
        assert_eq!(params["state"], "x&y=z");
        assert_eq!(params.len(), 7);
    }

    #[test]
    fn auth_url_parses_as_valid_url() {
        let url = reqwest::Url::parse(&build_auth_url("challenge", "state"))
            .expect("auth URL must be valid");
        assert_eq!(url.scheme(), "https");
        assert_eq!(url.host_str(), Some("auth.openai.com"));
        assert_eq!(url.path(), "/oauth/authorize");
    }

    #[test]
    fn token_expiry_detection() {
        let expired = Token {
            access_token: "a".into(),
            refresh_token: "r".into(),
            id_token: "i".into(),
            expires_in: 3600,
            issued_at: 0,
        };
        assert!(expired.is_expired());

        let valid = Token {
            issued_at: unix_now(),
            expires_in: 3600,
            ..expired.clone()
        };
        assert!(!valid.is_expired());
    }

    #[test]
    fn token_with_zero_lifetime_is_expired_immediately() {
        // Boundary: the deadline equals the issue time, and `is_expired` uses `>=`.
        let token = Token {
            access_token: "a".into(),
            refresh_token: "r".into(),
            id_token: "i".into(),
            expires_in: 0,
            issued_at: unix_now(),
        };
        assert!(token.is_expired());
    }
}