1use std::io::Cursor;
2use std::io::{self};
3use std::path::Path;
4use std::path::PathBuf;
5use std::sync::Arc;
6use std::thread;
7
8use crate::AuthDotJson;
9use crate::get_auth_file;
10use crate::pkce::PkceCodes;
11use crate::pkce::generate_pkce;
12use base64::Engine;
13use chrono::Utc;
14use rand::RngCore;
15use tiny_http::Header;
16use tiny_http::Request;
17use tiny_http::Response;
18use tiny_http::Server;
19
/// Base URL of the OAuth issuer used when the caller does not override it.
const DEFAULT_ISSUER: &str = "https://auth.openai.com";
/// Port requested on `127.0.0.1` when the caller does not override it.
const DEFAULT_PORT: u16 = 1455;

/// Configuration for `run_login_server`.
#[derive(Debug, Clone)]
pub struct ServerOptions {
    /// Directory whose auth file (resolved via `get_auth_file`) receives the credentials.
    pub codex_home: PathBuf,
    /// OAuth client identifier sent to the issuer.
    pub client_id: String,
    /// Base URL of the OAuth issuer; `/oauth/...` paths are appended to it.
    pub issuer: String,
    /// Port requested when binding on `127.0.0.1`. The port actually bound is reported
    /// separately by the returned server.
    pub port: u16,
    /// Whether to try opening the authorize URL in the user's browser.
    pub open_browser: bool,
    /// Fixed OAuth `state` value; when `None` a random one is generated.
    pub force_state: Option<String>,
}

impl ServerOptions {
    /// Creates options using [`DEFAULT_ISSUER`] and [`DEFAULT_PORT`], with browser opening
    /// enabled and no forced `state`.
    pub fn new(codex_home: PathBuf, client_id: String) -> Self {
        Self {
            codex_home,
            // `client_id` is already an owned `String`; move it instead of copying it.
            client_id,
            issuer: DEFAULT_ISSUER.to_string(),
            port: DEFAULT_PORT,
            open_browser: true,
            force_state: None,
        }
    }
}
45
46pub struct LoginServer {
47 pub auth_url: String,
48 pub actual_port: u16,
49 server_handle: tokio::task::JoinHandle<io::Result<()>>,
50 shutdown_handle: ShutdownHandle,
51}
52
53impl LoginServer {
54 pub async fn block_until_done(self) -> io::Result<()> {
55 self.server_handle
56 .await
57 .map_err(|err| io::Error::other(format!("login server thread panicked: {err:?}")))?
58 }
59
60 pub fn cancel(&self) {
61 self.shutdown_handle.shutdown();
62 }
63
64 pub fn cancel_handle(&self) -> ShutdownHandle {
65 self.shutdown_handle.clone()
66 }
67}
68
69#[derive(Clone, Debug)]
70pub struct ShutdownHandle {
71 shutdown_notify: Arc<tokio::sync::Notify>,
72}
73
74impl ShutdownHandle {
75 pub fn shutdown(&self) {
76 self.shutdown_notify.notify_waiters();
77 }
78}
79
80pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
81 let pkce = generate_pkce();
82 let state = opts.force_state.clone().unwrap_or_else(generate_state);
83
84 let server = Server::http(format!("127.0.0.1:{}", opts.port)).map_err(io::Error::other)?;
85 let actual_port = match server.server_addr().to_ip() {
86 Some(addr) => addr.port(),
87 None => {
88 return Err(io::Error::new(
89 io::ErrorKind::AddrInUse,
90 "Unable to determine the server port",
91 ));
92 }
93 };
94 let server = Arc::new(server);
95
96 let redirect_uri = format!("http://localhost:{actual_port}/auth/callback");
97 let auth_url = build_authorize_url(&opts.issuer, &opts.client_id, &redirect_uri, &pkce, &state);
98
99 if opts.open_browser {
100 let _ = webbrowser::open(&auth_url);
101 }
102
103 let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
105 let _server_handle = {
106 let server = server.clone();
107 thread::spawn(move || -> io::Result<()> {
108 while let Ok(request) = server.recv() {
109 tx.blocking_send(request).map_err(|e| {
110 eprintln!("Failed to send request to channel: {e}");
111 io::Error::other("Failed to send request to channel")
112 })?;
113 }
114 Ok(())
115 })
116 };
117
118 let shutdown_notify = Arc::new(tokio::sync::Notify::new());
119 let server_handle = {
120 let shutdown_notify = shutdown_notify.clone();
121 let server = server.clone();
122 tokio::spawn(async move {
123 let result = loop {
124 tokio::select! {
125 _ = shutdown_notify.notified() => {
126 break Err(io::Error::other("Login was not completed"));
127 }
128 maybe_req = rx.recv() => {
129 let Some(req) = maybe_req else {
130 break Err(io::Error::other("Login was not completed"));
131 };
132
133 let url_raw = req.url().to_string();
134 let response =
135 process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
136
137 let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
138 match response {
139 HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
140 let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
141 }
142 HandledRequest::RedirectWithHeader(header) => {
143 let redirect = Response::empty(302).with_header(header);
144 let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
145 }
146 }
147
148 if is_login_complete {
149 break Ok(());
150 }
151 }
152 }
153 };
154
155 server.unblock();
158 result
159 })
160 };
161
162 Ok(LoginServer {
163 auth_url,
164 actual_port,
165 server_handle,
166 shutdown_handle: ShutdownHandle { shutdown_notify },
167 })
168}
169
/// What `process_request` decided to send back for one HTTP request.
enum HandledRequest {
    /// Send this response and keep serving further requests.
    Response(Response<Cursor<Vec<u8>>>),
    /// Reply with an empty `302` response carrying this header (the `Location` redirect).
    RedirectWithHeader(Header),
    /// Send this response, then end the server loop successfully.
    ResponseAndExit(Response<Cursor<Vec<u8>>>),
}
175
176async fn process_request(
177 url_raw: &str,
178 opts: &ServerOptions,
179 redirect_uri: &str,
180 pkce: &PkceCodes,
181 actual_port: u16,
182 state: &str,
183) -> HandledRequest {
184 let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
185 Ok(u) => u,
186 Err(e) => {
187 eprintln!("URL parse error: {e}");
188 return HandledRequest::Response(
189 Response::from_string("Bad Request").with_status_code(400),
190 );
191 }
192 };
193 let path = parsed_url.path().to_string();
194
195 match path.as_str() {
196 "/auth/callback" => {
197 let params: std::collections::HashMap<String, String> =
198 parsed_url.query_pairs().into_owned().collect();
199 if params.get("state").map(String::as_str) != Some(state) {
200 return HandledRequest::Response(
201 Response::from_string("State mismatch").with_status_code(400),
202 );
203 }
204 let code = match params.get("code") {
205 Some(c) if !c.is_empty() => c.clone(),
206 _ => {
207 return HandledRequest::Response(
208 Response::from_string("Missing authorization code").with_status_code(400),
209 );
210 }
211 };
212
213 match exchange_code_for_tokens(&opts.issuer, &opts.client_id, redirect_uri, pkce, &code)
214 .await
215 {
216 Ok(tokens) => {
217 let api_key = obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token)
219 .await
220 .ok();
221 if let Err(err) = persist_tokens_async(
222 &opts.codex_home,
223 api_key.clone(),
224 tokens.id_token.clone(),
225 Some(tokens.access_token.clone()),
226 Some(tokens.refresh_token.clone()),
227 )
228 .await
229 {
230 eprintln!("Persist error: {err}");
231 return HandledRequest::Response(
232 Response::from_string(format!("Unable to persist auth file: {err}"))
233 .with_status_code(500),
234 );
235 }
236
237 let success_url = compose_success_url(
238 actual_port,
239 &opts.issuer,
240 &tokens.id_token,
241 &tokens.access_token,
242 );
243 match tiny_http::Header::from_bytes(&b"Location"[..], success_url.as_bytes()) {
244 Ok(header) => HandledRequest::RedirectWithHeader(header),
245 Err(_) => HandledRequest::Response(
246 Response::from_string("Internal Server Error").with_status_code(500),
247 ),
248 }
249 }
250 Err(err) => {
251 eprintln!("Token exchange error: {err}");
252 HandledRequest::Response(
253 Response::from_string(format!("Token exchange failed: {err}"))
254 .with_status_code(500),
255 )
256 }
257 }
258 }
259 "/success" => {
260 let body = include_str!("assets/success.html");
261 let mut resp = Response::from_data(body.as_bytes());
262 if let Ok(h) = tiny_http::Header::from_bytes(
263 &b"Content-Type"[..],
264 &b"text/html; charset=utf-8"[..],
265 ) {
266 resp.add_header(h);
267 }
268 HandledRequest::ResponseAndExit(resp)
269 }
270 _ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)),
271 }
272}
273
274fn build_authorize_url(
275 issuer: &str,
276 client_id: &str,
277 redirect_uri: &str,
278 pkce: &PkceCodes,
279 state: &str,
280) -> String {
281 let query = vec![
282 ("response_type", "code"),
283 ("client_id", client_id),
284 ("redirect_uri", redirect_uri),
285 ("scope", "openid profile email offline_access"),
286 ("code_challenge", &pkce.code_challenge),
287 ("code_challenge_method", "S256"),
288 ("id_token_add_organizations", "true"),
289 ("agcodex_cli_simplified_flow", "true"),
290 ("state", state),
291 ];
292 let qs = query
293 .into_iter()
294 .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v)))
295 .collect::<Vec<_>>()
296 .join("&");
297 format!("{issuer}/oauth/authorize?{qs}")
298}
299
300fn generate_state() -> String {
301 let mut bytes = [0u8; 32];
302 rand::rng().fill_bytes(&mut bytes);
303 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
304}
305
/// Tokens returned by the issuer's `/oauth/token` endpoint for an authorization code.
struct ExchangedTokens {
    /// ID token as returned by the endpoint (an encoded JWT; it is parsed later).
    id_token: String,
    /// Access token as returned by the endpoint.
    access_token: String,
    /// Refresh token as returned by the endpoint.
    refresh_token: String,
}
311
312async fn exchange_code_for_tokens(
313 issuer: &str,
314 client_id: &str,
315 redirect_uri: &str,
316 pkce: &PkceCodes,
317 code: &str,
318) -> io::Result<ExchangedTokens> {
319 #[derive(serde::Deserialize)]
320 struct TokenResponse {
321 id_token: String,
322 access_token: String,
323 refresh_token: String,
324 }
325
326 let client = reqwest::Client::new();
327 let resp = client
328 .post(format!("{issuer}/oauth/token"))
329 .header("Content-Type", "application/x-www-form-urlencoded")
330 .body(format!(
331 "grant_type=authorization_code&code={}&redirect_uri={}&client_id={}&code_verifier={}",
332 urlencoding::encode(code),
333 urlencoding::encode(redirect_uri),
334 urlencoding::encode(client_id),
335 urlencoding::encode(&pkce.code_verifier)
336 ))
337 .send()
338 .await
339 .map_err(io::Error::other)?;
340
341 if !resp.status().is_success() {
342 return Err(io::Error::other(format!(
343 "token endpoint returned status {}",
344 resp.status()
345 )));
346 }
347
348 let tokens: TokenResponse = resp.json().await.map_err(io::Error::other)?;
349 Ok(ExchangedTokens {
350 id_token: tokens.id_token,
351 access_token: tokens.access_token,
352 refresh_token: tokens.refresh_token,
353 })
354}
355
356async fn persist_tokens_async(
357 codex_home: &Path,
358 api_key: Option<String>,
359 id_token: String,
360 access_token: Option<String>,
361 refresh_token: Option<String>,
362) -> io::Result<()> {
363 let codex_home = codex_home.to_path_buf();
365 tokio::task::spawn_blocking(move || {
366 let auth_file = get_auth_file(&codex_home);
367 if let Some(parent) = auth_file.parent()
368 && !parent.exists()
369 {
370 std::fs::create_dir_all(parent).map_err(io::Error::other)?;
371 }
372
373 let mut auth = read_or_default(&auth_file);
374 if let Some(key) = api_key {
375 auth.openai_api_key = Some(key);
376 }
377 let tokens = auth
378 .tokens
379 .get_or_insert_with(crate::token_data::TokenData::default);
380 tokens.id_token = crate::token_data::parse_id_token(&id_token).map_err(io::Error::other)?;
381 if let Some(acc) = jwt_auth_claims(&id_token)
383 .get("chatgpt_account_id")
384 .and_then(|v| v.as_str())
385 {
386 tokens.account_id = Some(acc.to_string());
387 }
388 if let Some(at) = access_token {
389 tokens.access_token = at;
390 }
391 if let Some(rt) = refresh_token {
392 tokens.refresh_token = rt;
393 }
394 auth.last_refresh = Some(Utc::now());
395 super::write_auth_json(&auth_file, &auth)
396 })
397 .await
398 .map_err(|e| io::Error::other(format!("persist task failed: {e}")))?
399}
400
401fn read_or_default(path: &Path) -> AuthDotJson {
402 match super::try_read_auth_json(path) {
403 Ok(auth) => auth,
404 Err(_) => AuthDotJson {
405 openai_api_key: None,
406 tokens: None,
407 last_refresh: None,
408 },
409 }
410}
411
412fn compose_success_url(port: u16, issuer: &str, id_token: &str, access_token: &str) -> String {
413 let token_claims = jwt_auth_claims(id_token);
414 let access_claims = jwt_auth_claims(access_token);
415
416 let org_id = token_claims
417 .get("organization_id")
418 .and_then(|v| v.as_str())
419 .unwrap_or("");
420 let project_id = token_claims
421 .get("project_id")
422 .and_then(|v| v.as_str())
423 .unwrap_or("");
424 let completed_onboarding = token_claims
425 .get("completed_platform_onboarding")
426 .and_then(|v| v.as_bool())
427 .unwrap_or(false);
428 let is_org_owner = token_claims
429 .get("is_org_owner")
430 .and_then(|v| v.as_bool())
431 .unwrap_or(false);
432 let needs_setup = (!completed_onboarding) && is_org_owner;
433 let plan_type = access_claims
434 .get("chatgpt_plan_type")
435 .and_then(|v| v.as_str())
436 .unwrap_or("");
437
438 let platform_url = if issuer == DEFAULT_ISSUER {
439 "https://platform.openai.com"
440 } else {
441 "https://platform.api.openai.org"
442 };
443
444 let mut params = vec![
445 ("id_token", id_token.to_string()),
446 ("needs_setup", needs_setup.to_string()),
447 ("org_id", org_id.to_string()),
448 ("project_id", project_id.to_string()),
449 ("plan_type", plan_type.to_string()),
450 ("platform_url", platform_url.to_string()),
451 ];
452 let qs = params
453 .drain(..)
454 .map(|(k, v)| format!("{}={}", k, urlencoding::encode(&v)))
455 .collect::<Vec<_>>()
456 .join("&");
457 format!("http://localhost:{port}/success?{qs}")
458}
459
460fn jwt_auth_claims(jwt: &str) -> serde_json::Map<String, serde_json::Value> {
461 let mut parts = jwt.split('.');
462 let (_h, payload_b64, _s) = match (parts.next(), parts.next(), parts.next()) {
463 (Some(h), Some(p), Some(s)) if !h.is_empty() && !p.is_empty() && !s.is_empty() => (h, p, s),
464 _ => {
465 eprintln!("Invalid JWT format while extracting claims");
466 return serde_json::Map::new();
467 }
468 };
469 match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64) {
470 Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
471 Ok(mut v) => {
472 if let Some(obj) = v
473 .get_mut("https://api.openai.com/auth")
474 .and_then(|x| x.as_object_mut())
475 {
476 return obj.clone();
477 }
478 eprintln!("JWT payload missing expected 'https://api.openai.com/auth' object");
479 }
480 Err(e) => {
481 eprintln!("Failed to parse JWT JSON payload: {e}");
482 }
483 },
484 Err(e) => {
485 eprintln!("Failed to base64url-decode JWT payload: {e}");
486 }
487 }
488 serde_json::Map::new()
489}
490
491async fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result<String> {
492 #[derive(serde::Deserialize)]
494 struct ExchangeResp {
495 access_token: String,
496 }
497 let client = reqwest::Client::new();
498 let resp = client
499 .post(format!("{issuer}/oauth/token"))
500 .header("Content-Type", "application/x-www-form-urlencoded")
501 .body(format!(
502 "grant_type={}&client_id={}&requested_token={}&subject_token={}&subject_token_type={}",
503 urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
504 urlencoding::encode(client_id),
505 urlencoding::encode("openai-api-key"),
506 urlencoding::encode(id_token),
507 urlencoding::encode("urn:ietf:params:oauth:token-type:id_token")
508 ))
509 .send()
510 .await
511 .map_err(io::Error::other)?;
512 if !resp.status().is_success() {
513 return Err(io::Error::other(format!(
514 "api key exchange failed with status {}",
515 resp.status()
516 )));
517 }
518 let body: ExchangeResp = resp.json().await.map_err(io::Error::other)?;
519 Ok(body.access_token)
520}