Skip to main content

agcodex_login/
server.rs

1use std::io::Cursor;
2use std::io::{self};
3use std::path::Path;
4use std::path::PathBuf;
5use std::sync::Arc;
6use std::thread;
7
8use crate::AuthDotJson;
9use crate::get_auth_file;
10use crate::pkce::PkceCodes;
11use crate::pkce::generate_pkce;
12use base64::Engine;
13use chrono::Utc;
14use rand::RngCore;
15use tiny_http::Header;
16use tiny_http::Request;
17use tiny_http::Response;
18use tiny_http::Server;
19
/// Issuer base URL that `ServerOptions::new` fills in; `compose_success_url`
/// also compares against it when choosing the platform URL.
const DEFAULT_ISSUER: &str = "https://auth.openai.com";
/// Port that `ServerOptions::new` fills in for the local callback server.
const DEFAULT_PORT: u16 = 1455;
22
/// Configuration consumed by `run_login_server`.
#[derive(Debug, Clone)]
pub struct ServerOptions {
    /// Directory handed to `get_auth_file` to locate the auth file that the
    /// login flow writes.
    pub codex_home: PathBuf,
    /// OAuth client identifier sent to the issuer's authorize and token endpoints.
    pub client_id: String,
    /// Base URL of the authorization server; endpoint paths such as
    /// `/oauth/authorize` and `/oauth/token` are appended to it directly.
    pub issuer: String,
    /// Port requested when binding the local server on `127.0.0.1`.
    pub port: u16,
    /// Whether `run_login_server` should try to open the authorize URL in a browser.
    pub open_browser: bool,
    /// When `Some`, used as the OAuth `state` value instead of a freshly
    /// generated random one.
    pub force_state: Option<String>,
}
32
33impl ServerOptions {
34    pub fn new(codex_home: PathBuf, client_id: String) -> Self {
35        Self {
36            codex_home,
37            client_id: client_id.to_string(),
38            issuer: DEFAULT_ISSUER.to_string(),
39            port: DEFAULT_PORT,
40            open_browser: true,
41            force_state: None,
42        }
43    }
44}
45
/// Handle to a running login server, as returned by `run_login_server`.
pub struct LoginServer {
    /// Authorize URL the user has to visit to start the login flow.
    pub auth_url: String,
    /// Port the local HTTP server reported after binding.
    pub actual_port: u16,
    // Task that serves incoming requests; its output is the login outcome.
    server_handle: tokio::task::JoinHandle<io::Result<()>>,
    // Used to ask the serving task to stop before login completes.
    shutdown_handle: ShutdownHandle,
}
52
53impl LoginServer {
54    pub async fn block_until_done(self) -> io::Result<()> {
55        self.server_handle
56            .await
57            .map_err(|err| io::Error::other(format!("login server thread panicked: {err:?}")))?
58    }
59
60    pub fn cancel(&self) {
61        self.shutdown_handle.shutdown();
62    }
63
64    pub fn cancel_handle(&self) -> ShutdownHandle {
65        self.shutdown_handle.clone()
66    }
67}
68
/// Cloneable handle used to request that the login server task stop.
#[derive(Clone, Debug)]
pub struct ShutdownHandle {
    // Shared with the serving task, which awaits `notified()` on it.
    shutdown_notify: Arc<tokio::sync::Notify>,
}
73
74impl ShutdownHandle {
75    pub fn shutdown(&self) {
76        self.shutdown_notify.notify_waiters();
77    }
78}
79
/// Starts the local HTTP server that drives the browser-based login flow.
///
/// Binds a `tiny_http` server on `127.0.0.1:{opts.port}`, builds the authorize
/// URL (optionally opening it in a browser), and spawns:
/// * an OS thread that forwards blocking `server.recv()` results into a Tokio
///   channel, and
/// * a Tokio task that handles each request via `process_request` until the
///   login completes, shutdown is signalled, or the channel closes.
///
/// # Errors
/// Returns an error if the server cannot be bound or if the bound address
/// cannot be resolved to an IP address/port.
///
/// # Panics
/// Calls `tokio::spawn`, so it must be invoked from within a Tokio runtime.
pub fn run_login_server(opts: ServerOptions) -> io::Result<LoginServer> {
    let pkce = generate_pkce();
    // Honor a caller-supplied state, otherwise generate a random one.
    let state = opts.force_state.clone().unwrap_or_else(generate_state);

    let server = Server::http(format!("127.0.0.1:{}", opts.port)).map_err(io::Error::other)?;
    // Read the port back from the listener rather than trusting `opts.port`.
    let actual_port = match server.server_addr().to_ip() {
        Some(addr) => addr.port(),
        None => {
            return Err(io::Error::new(
                io::ErrorKind::AddrInUse,
                "Unable to determine the server port",
            ));
        }
    };
    // Shared between the receiving thread and the request-handling task.
    let server = Arc::new(server);

    let redirect_uri = format!("http://localhost:{actual_port}/auth/callback");
    let auth_url = build_authorize_url(&opts.issuer, &opts.client_id, &redirect_uri, &pkce, &state);

    if opts.open_browser {
        // Best effort: a failure to launch the browser is ignored; the URL is
        // still returned to the caller in `LoginServer::auth_url`.
        let _ = webbrowser::open(&auth_url);
    }

    // Map blocking reads from server.recv() to an async channel.
    let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
    // The join handle is intentionally not kept; the thread ends when
    // `recv()` fails or the channel send fails.
    let _server_handle = {
        let server = server.clone();
        thread::spawn(move || -> io::Result<()> {
            while let Ok(request) = server.recv() {
                tx.blocking_send(request).map_err(|e| {
                    eprintln!("Failed to send request to channel: {e}");
                    io::Error::other("Failed to send request to channel")
                })?;
            }
            Ok(())
        })
    };

    let shutdown_notify = Arc::new(tokio::sync::Notify::new());
    let server_handle = {
        let shutdown_notify = shutdown_notify.clone();
        let server = server.clone();
        tokio::spawn(async move {
            let result = loop {
                tokio::select! {
                    // Shutdown requested through a `ShutdownHandle`.
                    _ = shutdown_notify.notified() => {
                        break Err(io::Error::other("Login was not completed"));
                    }
                    maybe_req = rx.recv() => {
                        // `None` means the receiving thread dropped its sender.
                        let Some(req) = maybe_req else {
                            break Err(io::Error::other("Login was not completed"));
                        };

                        let url_raw = req.url().to_string();
                        let response =
                            process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;

                        // Record this before `response` is moved into the match below.
                        let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
                        match response {
                            HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
                                // `respond` is a blocking call, so run it off the
                                // async worker; its result is ignored.
                                let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
                            }
                            HandledRequest::RedirectWithHeader(header) => {
                                let redirect = Response::empty(302).with_header(header);
                                let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
                            }
                        }

                        if is_login_complete {
                            break Ok(());
                        }
                    }
                }
            };

            // Ensure that the server is unblocked so the thread dedicated to
            // running `server.recv()` in a loop exits cleanly.
            server.unblock();
            result
        })
    };

    Ok(LoginServer {
        auth_url,
        actual_port,
        server_handle,
        shutdown_handle: ShutdownHandle { shutdown_notify },
    })
}
169
/// Outcome of `process_request`, telling the server loop what to send back
/// and whether to keep serving.
enum HandledRequest {
    /// Send this response and keep serving.
    Response(Response<Cursor<Vec<u8>>>),
    /// Send an empty `302` response carrying this header and keep serving.
    RedirectWithHeader(Header),
    /// Send this response, then stop the server loop with a successful result.
    ResponseAndExit(Response<Cursor<Vec<u8>>>),
}
175
176async fn process_request(
177    url_raw: &str,
178    opts: &ServerOptions,
179    redirect_uri: &str,
180    pkce: &PkceCodes,
181    actual_port: u16,
182    state: &str,
183) -> HandledRequest {
184    let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
185        Ok(u) => u,
186        Err(e) => {
187            eprintln!("URL parse error: {e}");
188            return HandledRequest::Response(
189                Response::from_string("Bad Request").with_status_code(400),
190            );
191        }
192    };
193    let path = parsed_url.path().to_string();
194
195    match path.as_str() {
196        "/auth/callback" => {
197            let params: std::collections::HashMap<String, String> =
198                parsed_url.query_pairs().into_owned().collect();
199            if params.get("state").map(String::as_str) != Some(state) {
200                return HandledRequest::Response(
201                    Response::from_string("State mismatch").with_status_code(400),
202                );
203            }
204            let code = match params.get("code") {
205                Some(c) if !c.is_empty() => c.clone(),
206                _ => {
207                    return HandledRequest::Response(
208                        Response::from_string("Missing authorization code").with_status_code(400),
209                    );
210                }
211            };
212
213            match exchange_code_for_tokens(&opts.issuer, &opts.client_id, redirect_uri, pkce, &code)
214                .await
215            {
216                Ok(tokens) => {
217                    // Obtain API key via token-exchange and persist
218                    let api_key = obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token)
219                        .await
220                        .ok();
221                    if let Err(err) = persist_tokens_async(
222                        &opts.codex_home,
223                        api_key.clone(),
224                        tokens.id_token.clone(),
225                        Some(tokens.access_token.clone()),
226                        Some(tokens.refresh_token.clone()),
227                    )
228                    .await
229                    {
230                        eprintln!("Persist error: {err}");
231                        return HandledRequest::Response(
232                            Response::from_string(format!("Unable to persist auth file: {err}"))
233                                .with_status_code(500),
234                        );
235                    }
236
237                    let success_url = compose_success_url(
238                        actual_port,
239                        &opts.issuer,
240                        &tokens.id_token,
241                        &tokens.access_token,
242                    );
243                    match tiny_http::Header::from_bytes(&b"Location"[..], success_url.as_bytes()) {
244                        Ok(header) => HandledRequest::RedirectWithHeader(header),
245                        Err(_) => HandledRequest::Response(
246                            Response::from_string("Internal Server Error").with_status_code(500),
247                        ),
248                    }
249                }
250                Err(err) => {
251                    eprintln!("Token exchange error: {err}");
252                    HandledRequest::Response(
253                        Response::from_string(format!("Token exchange failed: {err}"))
254                            .with_status_code(500),
255                    )
256                }
257            }
258        }
259        "/success" => {
260            let body = include_str!("assets/success.html");
261            let mut resp = Response::from_data(body.as_bytes());
262            if let Ok(h) = tiny_http::Header::from_bytes(
263                &b"Content-Type"[..],
264                &b"text/html; charset=utf-8"[..],
265            ) {
266                resp.add_header(h);
267            }
268            HandledRequest::ResponseAndExit(resp)
269        }
270        _ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)),
271    }
272}
273
274fn build_authorize_url(
275    issuer: &str,
276    client_id: &str,
277    redirect_uri: &str,
278    pkce: &PkceCodes,
279    state: &str,
280) -> String {
281    let query = vec![
282        ("response_type", "code"),
283        ("client_id", client_id),
284        ("redirect_uri", redirect_uri),
285        ("scope", "openid profile email offline_access"),
286        ("code_challenge", &pkce.code_challenge),
287        ("code_challenge_method", "S256"),
288        ("id_token_add_organizations", "true"),
289        ("agcodex_cli_simplified_flow", "true"),
290        ("state", state),
291    ];
292    let qs = query
293        .into_iter()
294        .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v)))
295        .collect::<Vec<_>>()
296        .join("&");
297    format!("{issuer}/oauth/authorize?{qs}")
298}
299
300fn generate_state() -> String {
301    let mut bytes = [0u8; 32];
302    rand::rng().fill_bytes(&mut bytes);
303    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
304}
305
/// Tokens returned by `exchange_code_for_tokens`, copied field-for-field from
/// the token endpoint's JSON response.
struct ExchangedTokens {
    id_token: String,
    access_token: String,
    refresh_token: String,
}
311
312async fn exchange_code_for_tokens(
313    issuer: &str,
314    client_id: &str,
315    redirect_uri: &str,
316    pkce: &PkceCodes,
317    code: &str,
318) -> io::Result<ExchangedTokens> {
319    #[derive(serde::Deserialize)]
320    struct TokenResponse {
321        id_token: String,
322        access_token: String,
323        refresh_token: String,
324    }
325
326    let client = reqwest::Client::new();
327    let resp = client
328        .post(format!("{issuer}/oauth/token"))
329        .header("Content-Type", "application/x-www-form-urlencoded")
330        .body(format!(
331            "grant_type=authorization_code&code={}&redirect_uri={}&client_id={}&code_verifier={}",
332            urlencoding::encode(code),
333            urlencoding::encode(redirect_uri),
334            urlencoding::encode(client_id),
335            urlencoding::encode(&pkce.code_verifier)
336        ))
337        .send()
338        .await
339        .map_err(io::Error::other)?;
340
341    if !resp.status().is_success() {
342        return Err(io::Error::other(format!(
343            "token endpoint returned status {}",
344            resp.status()
345        )));
346    }
347
348    let tokens: TokenResponse = resp.json().await.map_err(io::Error::other)?;
349    Ok(ExchangedTokens {
350        id_token: tokens.id_token,
351        access_token: tokens.access_token,
352        refresh_token: tokens.refresh_token,
353    })
354}
355
356async fn persist_tokens_async(
357    codex_home: &Path,
358    api_key: Option<String>,
359    id_token: String,
360    access_token: Option<String>,
361    refresh_token: Option<String>,
362) -> io::Result<()> {
363    // Reuse existing synchronous logic but run it off the async runtime.
364    let codex_home = codex_home.to_path_buf();
365    tokio::task::spawn_blocking(move || {
366        let auth_file = get_auth_file(&codex_home);
367        if let Some(parent) = auth_file.parent()
368            && !parent.exists()
369        {
370            std::fs::create_dir_all(parent).map_err(io::Error::other)?;
371        }
372
373        let mut auth = read_or_default(&auth_file);
374        if let Some(key) = api_key {
375            auth.openai_api_key = Some(key);
376        }
377        let tokens = auth
378            .tokens
379            .get_or_insert_with(crate::token_data::TokenData::default);
380        tokens.id_token = crate::token_data::parse_id_token(&id_token).map_err(io::Error::other)?;
381        // Persist chatgpt_account_id if present in claims
382        if let Some(acc) = jwt_auth_claims(&id_token)
383            .get("chatgpt_account_id")
384            .and_then(|v| v.as_str())
385        {
386            tokens.account_id = Some(acc.to_string());
387        }
388        if let Some(at) = access_token {
389            tokens.access_token = at;
390        }
391        if let Some(rt) = refresh_token {
392            tokens.refresh_token = rt;
393        }
394        auth.last_refresh = Some(Utc::now());
395        super::write_auth_json(&auth_file, &auth)
396    })
397    .await
398    .map_err(|e| io::Error::other(format!("persist task failed: {e}")))?
399}
400
401fn read_or_default(path: &Path) -> AuthDotJson {
402    match super::try_read_auth_json(path) {
403        Ok(auth) => auth,
404        Err(_) => AuthDotJson {
405            openai_api_key: None,
406            tokens: None,
407            last_refresh: None,
408        },
409    }
410}
411
412fn compose_success_url(port: u16, issuer: &str, id_token: &str, access_token: &str) -> String {
413    let token_claims = jwt_auth_claims(id_token);
414    let access_claims = jwt_auth_claims(access_token);
415
416    let org_id = token_claims
417        .get("organization_id")
418        .and_then(|v| v.as_str())
419        .unwrap_or("");
420    let project_id = token_claims
421        .get("project_id")
422        .and_then(|v| v.as_str())
423        .unwrap_or("");
424    let completed_onboarding = token_claims
425        .get("completed_platform_onboarding")
426        .and_then(|v| v.as_bool())
427        .unwrap_or(false);
428    let is_org_owner = token_claims
429        .get("is_org_owner")
430        .and_then(|v| v.as_bool())
431        .unwrap_or(false);
432    let needs_setup = (!completed_onboarding) && is_org_owner;
433    let plan_type = access_claims
434        .get("chatgpt_plan_type")
435        .and_then(|v| v.as_str())
436        .unwrap_or("");
437
438    let platform_url = if issuer == DEFAULT_ISSUER {
439        "https://platform.openai.com"
440    } else {
441        "https://platform.api.openai.org"
442    };
443
444    let mut params = vec![
445        ("id_token", id_token.to_string()),
446        ("needs_setup", needs_setup.to_string()),
447        ("org_id", org_id.to_string()),
448        ("project_id", project_id.to_string()),
449        ("plan_type", plan_type.to_string()),
450        ("platform_url", platform_url.to_string()),
451    ];
452    let qs = params
453        .drain(..)
454        .map(|(k, v)| format!("{}={}", k, urlencoding::encode(&v)))
455        .collect::<Vec<_>>()
456        .join("&");
457    format!("http://localhost:{port}/success?{qs}")
458}
459
460fn jwt_auth_claims(jwt: &str) -> serde_json::Map<String, serde_json::Value> {
461    let mut parts = jwt.split('.');
462    let (_h, payload_b64, _s) = match (parts.next(), parts.next(), parts.next()) {
463        (Some(h), Some(p), Some(s)) if !h.is_empty() && !p.is_empty() && !s.is_empty() => (h, p, s),
464        _ => {
465            eprintln!("Invalid JWT format while extracting claims");
466            return serde_json::Map::new();
467        }
468    };
469    match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64) {
470        Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
471            Ok(mut v) => {
472                if let Some(obj) = v
473                    .get_mut("https://api.openai.com/auth")
474                    .and_then(|x| x.as_object_mut())
475                {
476                    return obj.clone();
477                }
478                eprintln!("JWT payload missing expected 'https://api.openai.com/auth' object");
479            }
480            Err(e) => {
481                eprintln!("Failed to parse JWT JSON payload: {e}");
482            }
483        },
484        Err(e) => {
485            eprintln!("Failed to base64url-decode JWT payload: {e}");
486        }
487    }
488    serde_json::Map::new()
489}
490
491async fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result<String> {
492    // Token exchange for an API key access token
493    #[derive(serde::Deserialize)]
494    struct ExchangeResp {
495        access_token: String,
496    }
497    let client = reqwest::Client::new();
498    let resp = client
499        .post(format!("{issuer}/oauth/token"))
500        .header("Content-Type", "application/x-www-form-urlencoded")
501        .body(format!(
502            "grant_type={}&client_id={}&requested_token={}&subject_token={}&subject_token_type={}",
503            urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
504            urlencoding::encode(client_id),
505            urlencoding::encode("openai-api-key"),
506            urlencoding::encode(id_token),
507            urlencoding::encode("urn:ietf:params:oauth:token-type:id_token")
508        ))
509        .send()
510        .await
511        .map_err(io::Error::other)?;
512    if !resp.status().is_success() {
513        return Err(io::Error::other(format!(
514            "api key exchange failed with status {}",
515            resp.status()
516        )));
517    }
518    let body: ExchangeResp = resp.json().await.map_err(io::Error::other)?;
519    Ok(body.access_token)
520}