1use async_trait::async_trait;
2use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
3use rand::Rng;
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::net::SocketAddr;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime, UNIX_EPOCH};
9use tracing::info;
10
11use steer_auth_plugin::AuthPlugin;
12use steer_auth_plugin::{
13 AuthDirective, AuthError, AuthErrorAction, AuthErrorContext, AuthHeaderContext,
14 AuthHeaderProvider, AuthMethod, AuthProgress, AuthSource, AuthStorage, AuthTokens,
15 AuthenticationFlow, Credential, CredentialType, DynAuthenticationFlow, HeaderPair,
16 InstructionPolicy, ModelId, ModelVisibilityPolicy, OpenAiResponsesAuth, ProviderId, Result,
17};
18use steer_tools::tools::{
19 AST_GREP_TOOL_NAME, BASH_TOOL_NAME, DISPATCH_AGENT_TOOL_NAME, EDIT_TOOL_NAME, FETCH_TOOL_NAME,
20 GLOB_TOOL_NAME, GREP_TOOL_NAME, LS_TOOL_NAME, MULTI_EDIT_TOOL_NAME, REPLACE_TOOL_NAME,
21 TODO_READ_TOOL_NAME, TODO_WRITE_TOOL_NAME, VIEW_TOOL_NAME,
22};
23
24mod callback_server;
25use callback_server::{CallbackResponse, CallbackServerHandle, spawn_callback_server};
26
/// Provider id used as this plugin's `ProviderId` and as the storage key for
/// its OAuth credentials.
const PROVIDER_ID: &str = "openai";
/// OAuth authorization endpoint the user is sent to in the browser.
const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// Token endpoint used for both the authorization-code exchange and refreshes.
const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// OAuth client id sent in the authorize URL and in token requests.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Redirect URI sent to the authorization server; it points at the local
/// callback listener (`CALLBACK_PORT` + `CALLBACK_PATH`).
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Space-separated scopes requested in the authorize URL.
const SCOPES: &str = "openid profile email offline_access";
/// Sent as the `originator` query parameter of the authorize URL and as the
/// `originator` header on API requests.
const ORIGINATOR: &str = "codex_cli_rs";
/// Path the local callback server listens on; must match `REDIRECT_URI`.
const CALLBACK_PATH: &str = "/auth/callback";
/// Port the local callback server binds on 127.0.0.1; must match `REDIRECT_URI`.
const CALLBACK_PORT: u16 = 1455;

/// Base URL override for Responses API calls when this plugin supplies auth.
const CODEX_BASE_URL: &str = "https://chatgpt.com/backend-api/codex/responses";
/// Value of the `openai-beta` header sent with every request.
const OPENAI_BETA: &str = "responses=experimental";
/// Codex model that is only visible when authenticated through this plugin.
const GPT_5_2_CODEX_MODEL_ID: &str = "gpt-5.2-codex";
/// Codex model that is only visible when authenticated through this plugin.
const GPT_5_3_CODEX_MODEL_ID: &str = "gpt-5.3-codex";
41const CODEX_SYSTEM_PROMPT: &str = r#"You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
42
43## General
44
45- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
46
47## Editing constraints
48
49- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
50- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
51- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
52- You may be in a dirty git worktree.
53 * NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
54 * If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
55 * If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
56 * If the changes are in unrelated files, just ignore them and don't revert them.
57- Do not amend a commit unless explicitly requested to do so.
58- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
59- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
60
61## Plan tool
62
63When using the planning tool:
64- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
65- Do not make single-step plans.
66- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
67
68## Special user requests
69
70- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
71- If the user asks for a "review", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
72
73## Frontend tasks
74When doing frontend design tasks, avoid collapsing into "AI slop" or safe, average-looking layouts.
75Aim for interfaces that feel intentional, bold, and a bit surprising.
76- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
77- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
78- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
79- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
80- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
81- Ensure the page loads properly on both desktop and mobile
82
83Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
84
85## Presenting your work and final message
86
87You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
88
89- Default: be very concise; friendly coding teammate tone.
90- Ask only when needed; suggest ideas; mirror the user's style.
91- For substantial work, summarize clearly; follow final‑answer formatting.
92- Skip heavy formatting for simple confirmations.
93- Don't dump large files you've written; reference paths only.
94- No "save/copy this file" - User is on the same machine.
95- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
96- For code changes:
97 * Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
98 * If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
99 * When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
100- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
101
102### Final answer structure and style guidelines
103
104- Plain text; CLI handles styling. Use structure only when it helps scanability.
105- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
106- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
107- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
108- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
109- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
110- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no "above/below"; parallel wording.
111- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
112- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
113- File References: When referencing files in your response follow the below rules:
114 * Use inline code to make file paths clickable.
115 * Each reference should have a stand alone path. Even if it's the same file.
116 * Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
117 * Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
118 * Do not use URIs like file://, vscode://, or https://.
119 * Do not provide range of lines
120 * Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
121"#;
122
/// Steer-specific addendum appended to the Codex system prompt.
///
/// It tells the model which Codex-native tools do not exist here and maps
/// them onto Steer's tool names. The names are interpolated from the
/// `steer_tools` constants, so renaming a tool updates the prompt.
fn steer_codex_bridge_prompt() -> String {
    format!(
        r"## Codex Running in Steer

You are running Codex inside Steer, an open-source terminal coding assistant.

### CRITICAL tool replacements
- apply_patch does NOT exist. Use `{EDIT_TOOL_NAME}` instead.
- update_plan does NOT exist. Use `{TODO_WRITE_TOOL_NAME}` instead.
- read_plan does NOT exist. Use `{TODO_READ_TOOL_NAME}` instead.

### Steer tool names
- File: `{VIEW_TOOL_NAME}`, `{REPLACE_TOOL_NAME}`, `{EDIT_TOOL_NAME}`, `{MULTI_EDIT_TOOL_NAME}`
- Search: `{GREP_TOOL_NAME}` (text), `{AST_GREP_TOOL_NAME}` (syntax), `{GLOB_TOOL_NAME}` (paths), `{LS_TOOL_NAME}` (list directories)
- Exec: `{BASH_TOOL_NAME}`
- Web: `{FETCH_TOOL_NAME}`
- Agents: `{DISPATCH_AGENT_TOOL_NAME}`
- Todos: `{TODO_READ_TOOL_NAME}`, `{TODO_WRITE_TOOL_NAME}`

Tool names are case-sensitive; use exact casing.

### File path rules
- `{VIEW_TOOL_NAME}`, `{REPLACE_TOOL_NAME}`, `{EDIT_TOOL_NAME}`, `{MULTI_EDIT_TOOL_NAME}`, and `{LS_TOOL_NAME}` require absolute paths.

### Edit semantics
- `{EDIT_TOOL_NAME}` uses exact string replacement (empty `old_string` creates a file).
- `{MULTI_EDIT_TOOL_NAME}` applies multiple exact replacements in a single file.
- `{REPLACE_TOOL_NAME}` overwrites the entire file contents.

### Search guidance
- Prefer `{GREP_TOOL_NAME}`/`{AST_GREP_TOOL_NAME}`/`{GLOB_TOOL_NAME}`/`{LS_TOOL_NAME}` over shelling out to `rg` via `{BASH_TOOL_NAME}`.

### Todo guidance
- Use `{TODO_READ_TOOL_NAME}`/`{TODO_WRITE_TOOL_NAME}` for complex or multi-step tasks; skip them for simple, single-step work unless the user asks.
",
    )
}
160
161fn codex_instructions() -> String {
162 format!("{CODEX_SYSTEM_PROMPT}\n\n{}", steer_codex_bridge_prompt())
163}
164
/// Namespaced id-token claim whose object value holds `chatgpt_account_id`.
/// It is checked before the top-level `chatgpt_account_id` claim.
const CHATGPT_ACCOUNT_ID_NESTED_CLAIM: &str = "https://api.openai.com/auth";
166
/// PKCE (RFC 7636, `S256`) verifier/challenge pair for one authorization attempt.
#[derive(Debug)]
pub struct PkceChallenge {
    /// Random secret; sent only with the code exchange as `code_verifier`.
    pub verifier: String,
    /// Unpadded base64url SHA-256 of `verifier`; sent in the authorize URL.
    pub challenge: String,
}

/// ChatGPT account id read from the OAuth id token and sent as the
/// `chatgpt-account-id` request header.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChatGptAccountId(pub String);
175
/// Client for OpenAI's OAuth endpoints: authorize-URL construction,
/// authorization-code exchange, and token refresh.
#[derive(Clone)]
pub struct OpenAIOAuth {
    /// OAuth client id; initialised from `CLIENT_ID`.
    client_id: String,
    /// Redirect URI sent in the authorize URL and code exchange; initialised
    /// from `REDIRECT_URI`.
    redirect_uri: String,
    /// HTTP client used for token-endpoint requests.
    http_client: reqwest::Client,
}

impl Default for OpenAIOAuth {
    /// Equivalent to [`OpenAIOAuth::new`].
    fn default() -> Self {
        Self::new()
    }
}
188
189impl OpenAIOAuth {
190 pub fn new() -> Self {
191 Self {
192 client_id: CLIENT_ID.to_string(),
193 redirect_uri: REDIRECT_URI.to_string(),
194 http_client: reqwest::Client::new(),
195 }
196 }
197
198 pub fn generate_pkce() -> PkceChallenge {
199 let verifier = generate_random_string(128);
200 let challenge = base64_url_encode(&sha256(&verifier));
201 PkceChallenge {
202 verifier,
203 challenge,
204 }
205 }
206
207 pub fn generate_state() -> String {
208 generate_random_string(32)
209 }
210
211 pub fn build_auth_url(&self, pkce: &PkceChallenge, state: &str) -> String {
212 let params = [
213 ("response_type", "code"),
214 ("client_id", &self.client_id),
215 ("redirect_uri", &self.redirect_uri),
216 ("scope", SCOPES),
217 ("code_challenge", &pkce.challenge),
218 ("code_challenge_method", "S256"),
219 ("state", state),
220 ("id_token_add_organizations", "true"),
221 ("codex_cli_simplified_flow", "true"),
222 ("originator", ORIGINATOR),
223 ];
224
225 let query = serde_urlencoded::to_string(params).unwrap_or_default();
226 format!("{AUTHORIZE_URL}?{query}")
227 }
228
229 pub async fn exchange_code_for_tokens(
230 &self,
231 code: &str,
232 pkce_verifier: &str,
233 ) -> Result<AuthTokens> {
234 #[derive(Serialize)]
235 struct TokenRequest {
236 grant_type: String,
237 client_id: String,
238 code: String,
239 redirect_uri: String,
240 code_verifier: String,
241 }
242
243 #[derive(Deserialize)]
244 struct TokenResponse {
245 id_token: Option<String>,
246 access_token: String,
247 refresh_token: Option<String>,
248 expires_in: Option<u64>,
249 }
250
251 let request = TokenRequest {
252 grant_type: "authorization_code".to_string(),
253 client_id: self.client_id.clone(),
254 code: code.to_string(),
255 redirect_uri: self.redirect_uri.clone(),
256 code_verifier: pkce_verifier.to_string(),
257 };
258
259 let response = self
260 .http_client
261 .post(TOKEN_URL)
262 .form(&request)
263 .send()
264 .await?;
265
266 if !response.status().is_success() {
267 let status = response.status();
268 let error_text = response
269 .text()
270 .await
271 .unwrap_or_else(|_| "Unknown error".to_string());
272 return Err(AuthError::InvalidResponse(format!(
273 "Token exchange failed with status {status}: {error_text}"
274 )));
275 }
276
277 let token_response: TokenResponse = response.json().await.map_err(|e| {
278 AuthError::InvalidResponse(format!("Failed to parse token response: {e}"))
279 })?;
280
281 if token_response.access_token.trim().is_empty() {
282 return Err(AuthError::InvalidResponse(
283 "Empty access_token in token response".to_string(),
284 ));
285 }
286
287 let id_token = token_response.id_token.ok_or_else(|| {
288 AuthError::InvalidResponse("Missing id_token in token response".to_string())
289 })?;
290
291 let refresh_token = token_response.refresh_token.ok_or_else(|| {
292 AuthError::InvalidResponse("Missing refresh_token in token response".to_string())
293 })?;
294
295 let expires_at =
296 resolve_expires_at(token_response.expires_in, &token_response.access_token)?;
297
298 Ok(AuthTokens {
299 access_token: token_response.access_token,
300 refresh_token,
301 expires_at,
302 id_token: Some(id_token),
303 })
304 }
305
306 pub async fn refresh_tokens(&self, refresh_token: &str) -> Result<AuthTokens> {
307 #[derive(Serialize)]
308 struct RefreshRequest {
309 grant_type: String,
310 refresh_token: String,
311 client_id: String,
312 }
313
314 #[derive(Deserialize)]
315 struct TokenResponse {
316 id_token: Option<String>,
317 access_token: String,
318 refresh_token: Option<String>,
319 expires_in: Option<u64>,
320 }
321
322 let request = RefreshRequest {
323 grant_type: "refresh_token".to_string(),
324 refresh_token: refresh_token.to_string(),
325 client_id: self.client_id.clone(),
326 };
327
328 let response = self
329 .http_client
330 .post(TOKEN_URL)
331 .form(&request)
332 .send()
333 .await?;
334
335 if !response.status().is_success() {
336 if matches!(
337 response.status(),
338 reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::BAD_REQUEST
339 ) {
340 return Err(AuthError::ReauthRequired);
341 }
342
343 let status = response.status();
344 let error_text = response
345 .text()
346 .await
347 .unwrap_or_else(|_| "Unknown error".to_string());
348 return Err(AuthError::InvalidResponse(format!(
349 "Token refresh failed with status {status}: {error_text}"
350 )));
351 }
352
353 let token_response: TokenResponse = response.json().await.map_err(|e| {
354 AuthError::InvalidResponse(format!("Failed to parse refresh response: {e}"))
355 })?;
356
357 if token_response.access_token.trim().is_empty() {
358 return Err(AuthError::InvalidResponse(
359 "Empty access_token in refresh response".to_string(),
360 ));
361 }
362
363 let expires_at =
364 resolve_expires_at(token_response.expires_in, &token_response.access_token)?;
365
366 let refresh_token = token_response
367 .refresh_token
368 .unwrap_or_else(|| refresh_token.to_string());
369
370 Ok(AuthTokens {
371 access_token: token_response.access_token,
372 refresh_token,
373 expires_at,
374 id_token: token_response.id_token,
375 })
376 }
377}
378
379pub fn tokens_need_refresh(tokens: &AuthTokens) -> bool {
381 match tokens.expires_at.duration_since(SystemTime::now()) {
382 Ok(duration) => duration.as_secs() <= 300,
383 Err(_) => true,
384 }
385}
386
387pub async fn refresh_if_needed(
389 storage: &Arc<dyn AuthStorage>,
390 oauth_client: &OpenAIOAuth,
391) -> Result<AuthTokens> {
392 let credential = storage
393 .get_credential(PROVIDER_ID, CredentialType::OAuth2)
394 .await?
395 .ok_or(AuthError::ReauthRequired)?;
396
397 let mut tokens = match credential {
398 Credential::OAuth2(tokens) => tokens,
399 Credential::ApiKey { .. } => return Err(AuthError::ReauthRequired),
400 };
401
402 if tokens.id_token.is_none() || tokens_need_refresh(&tokens) {
403 match oauth_client.refresh_tokens(&tokens.refresh_token).await {
404 Ok(new_tokens) => {
405 let merged_tokens = AuthTokens {
406 id_token: new_tokens.id_token.or(tokens.id_token),
407 ..new_tokens
408 };
409 storage
410 .set_credential(PROVIDER_ID, Credential::OAuth2(merged_tokens.clone()))
411 .await?;
412 tokens = merged_tokens;
413 }
414 Err(AuthError::ReauthRequired) => {
415 storage
416 .remove_credential(PROVIDER_ID, CredentialType::OAuth2)
417 .await?;
418 return Err(AuthError::ReauthRequired);
419 }
420 Err(e) => return Err(e),
421 }
422 }
423
424 if tokens.id_token.is_none() {
425 return Err(AuthError::ReauthRequired);
426 }
427
428 Ok(tokens)
429}
430
431async fn force_refresh(
432 storage: &Arc<dyn AuthStorage>,
433 oauth_client: &OpenAIOAuth,
434) -> Result<AuthTokens> {
435 let credential = storage
436 .get_credential(PROVIDER_ID, CredentialType::OAuth2)
437 .await?
438 .ok_or(AuthError::ReauthRequired)?;
439
440 let tokens = match credential {
441 Credential::OAuth2(tokens) => tokens,
442 Credential::ApiKey { .. } => return Err(AuthError::ReauthRequired),
443 };
444
445 match oauth_client.refresh_tokens(&tokens.refresh_token).await {
446 Ok(new_tokens) => {
447 let merged_tokens = AuthTokens {
448 id_token: new_tokens.id_token.or(tokens.id_token),
449 ..new_tokens
450 };
451 storage
452 .set_credential(PROVIDER_ID, Credential::OAuth2(merged_tokens.clone()))
453 .await?;
454 Ok(merged_tokens)
455 }
456 Err(AuthError::ReauthRequired) => {
457 storage
458 .remove_credential(PROVIDER_ID, CredentialType::OAuth2)
459 .await?;
460 Err(AuthError::ReauthRequired)
461 }
462 Err(e) => Err(e),
463 }
464}
465
466pub fn extract_chatgpt_account_id(id_token: &str) -> Result<ChatGptAccountId> {
467 extract_chatgpt_account_id_from_id_token(id_token)
468}
469
470fn resolve_expires_at(expires_in: Option<u64>, access_token: &str) -> Result<SystemTime> {
471 if let Some(expires_in) = expires_in {
472 return Ok(SystemTime::now() + Duration::from_secs(expires_in));
473 }
474
475 let payload = decode_jwt_payload(access_token)?;
476 let exp = payload
477 .get("exp")
478 .and_then(|v| v.as_u64().or_else(|| v.as_i64().map(|v| v as u64)))
479 .ok_or_else(|| {
480 AuthError::InvalidResponse("Missing exp claim in access token".to_string())
481 })?;
482
483 Ok(UNIX_EPOCH + Duration::from_secs(exp))
484}
485
486fn decode_jwt_payload(access_token: &str) -> Result<serde_json::Value> {
487 let parts: Vec<&str> = access_token.split('.').collect();
488 if parts.len() < 2 {
489 return Err(AuthError::InvalidResponse(
490 "Invalid access token format".to_string(),
491 ));
492 }
493
494 let payload_bytes = URL_SAFE_NO_PAD
495 .decode(parts[1])
496 .map_err(|e| AuthError::InvalidResponse(format!("Invalid token payload: {e}")))?;
497
498 serde_json::from_slice(&payload_bytes)
499 .map_err(|e| AuthError::InvalidResponse(format!("Invalid token payload JSON: {e}")))
500}
501
502fn extract_chatgpt_account_id_from_id_token(id_token: &str) -> Result<ChatGptAccountId> {
503 let payload = decode_jwt_payload(id_token)?;
504
505 if let Some(account_id) = payload
506 .get(CHATGPT_ACCOUNT_ID_NESTED_CLAIM)
507 .and_then(|v| v.get("chatgpt_account_id"))
508 .and_then(|v| v.as_str())
509 .filter(|s| !s.is_empty())
510 {
511 return Ok(ChatGptAccountId(account_id.to_string()));
512 }
513
514 if let Some(account_id) = payload
515 .get("chatgpt_account_id")
516 .and_then(|v| v.as_str())
517 .filter(|s| !s.is_empty())
518 {
519 return Ok(ChatGptAccountId(account_id.to_string()));
520 }
521
522 Err(AuthError::InvalidResponse(
523 "Missing chatgpt account id in token".to_string(),
524 ))
525}
526
527fn parse_callback_input(input: &str) -> Result<CallbackResponse> {
528 let trimmed = input.trim();
529
530 if trimmed.contains("code=") && trimmed.contains("state=") {
531 let query = if trimmed.contains("://") {
532 let url = url::Url::parse(trimmed)
533 .map_err(|_| AuthError::InvalidCredential("Invalid redirect URL".to_string()))?;
534 url.query().unwrap_or("").to_string()
535 } else {
536 trimmed.to_string()
537 };
538
539 let params: std::collections::HashMap<String, String> =
540 url::form_urlencoded::parse(query.as_bytes())
541 .into_owned()
542 .collect();
543
544 let code = params
545 .get("code")
546 .ok_or_else(|| AuthError::MissingInput("code parameter".to_string()))?;
547 let state = params
548 .get("state")
549 .ok_or_else(|| AuthError::MissingInput("state parameter".to_string()))?;
550
551 return Ok(CallbackResponse {
552 code: code.clone(),
553 state: state.clone(),
554 });
555 }
556
557 if let Some((code, state)) = trimmed.split_once('#') {
558 if code.is_empty() || state.is_empty() {
559 return Err(AuthError::InvalidResponse(
560 "Invalid callback code format".to_string(),
561 ));
562 }
563 return Ok(CallbackResponse {
564 code: code.to_string(),
565 state: state.to_string(),
566 });
567 }
568
569 let parts: Vec<&str> = trimmed.split_whitespace().collect();
570 if parts.len() == 2 {
571 return Ok(CallbackResponse {
572 code: parts[0].to_string(),
573 state: parts[1].to_string(),
574 });
575 }
576
577 Err(AuthError::InvalidResponse(
578 "Invalid callback input. Paste the full redirect URL or code#state.".to_string(),
579 ))
580}
581
582fn generate_random_string(length: usize) -> String {
583 const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
584 let mut rng = rand::thread_rng();
585
586 (0..length)
587 .map(|_| {
588 let idx = rng.gen_range(0..CHARSET.len());
589 CHARSET[idx] as char
590 })
591 .collect()
592}
593
594fn sha256(data: &str) -> Vec<u8> {
595 let mut hasher = Sha256::new();
596 hasher.update(data.as_bytes());
597 hasher.finalize().to_vec()
598}
599
600fn base64_url_encode(data: &[u8]) -> String {
601 URL_SAFE_NO_PAD.encode(data)
602}
603
/// Per-attempt state of the OpenAI OAuth login flow.
#[derive(Debug)]
pub struct OpenAIAuthState {
    /// Current phase of the flow.
    pub kind: OpenAIAuthStateKind,
}

/// Phases of the OpenAI OAuth login flow.
#[derive(Debug)]
pub enum OpenAIAuthStateKind {
    /// The authorize URL exists and the flow is waiting for the redirect. The
    /// redirect arrives either via the local callback server or as a URL
    /// pasted by the user.
    OAuthStarted {
        /// PKCE verifier to send with the code exchange.
        verifier: String,
        /// Expected `state`; a callback with any other value is rejected.
        state: String,
        /// URL the user opens to sign in.
        auth_url: String,
        /// Local redirect listener. `None` when it could not be started, in
        /// which case the user must paste the redirect URL.
        callback_server: Option<CallbackServerHandle>,
    },
}
618
619pub struct OpenAIOAuthFlow {
620 storage: Arc<dyn AuthStorage>,
621 oauth_client: OpenAIOAuth,
622}
623
624impl OpenAIOAuthFlow {
625 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
626 Self {
627 storage,
628 oauth_client: OpenAIOAuth::new(),
629 }
630 }
631}
632
633#[async_trait]
634impl AuthenticationFlow for OpenAIOAuthFlow {
635 type State = OpenAIAuthState;
636
637 fn available_methods(&self) -> Vec<AuthMethod> {
638 vec![AuthMethod::OAuth]
639 }
640
641 async fn start_auth(&self, method: AuthMethod) -> Result<Self::State> {
642 match method {
643 AuthMethod::OAuth => {
644 let pkce = OpenAIOAuth::generate_pkce();
645 let state = OpenAIOAuth::generate_state();
646 let auth_url = self.oauth_client.build_auth_url(&pkce, &state);
647
648 let callback_server = match spawn_callback_server(
649 state.clone(),
650 SocketAddr::from(([127, 0, 0, 1], CALLBACK_PORT)),
651 CALLBACK_PATH,
652 )
653 .await
654 {
655 Ok(handle) => Some(handle),
656 Err(err) => {
657 info!(
658 "OpenAI OAuth callback server unavailable, falling back to manual paste: {}",
659 err
660 );
661 None
662 }
663 };
664
665 Ok(OpenAIAuthState {
666 kind: OpenAIAuthStateKind::OAuthStarted {
667 verifier: pkce.verifier,
668 state,
669 auth_url,
670 callback_server,
671 },
672 })
673 }
674 AuthMethod::ApiKey => Err(AuthError::UnsupportedMethod {
675 method: format!("{method:?}"),
676 provider: PROVIDER_ID.to_string(),
677 }),
678 }
679 }
680
681 async fn get_initial_progress(
682 &self,
683 state: &Self::State,
684 method: AuthMethod,
685 ) -> Result<AuthProgress> {
686 match method {
687 AuthMethod::OAuth => {
688 let OpenAIAuthStateKind::OAuthStarted { auth_url, .. } = &state.kind;
689 Ok(AuthProgress::OAuthStarted {
690 auth_url: auth_url.clone(),
691 })
692 }
693 AuthMethod::ApiKey => Err(AuthError::UnsupportedMethod {
694 method: format!("{method:?}"),
695 provider: PROVIDER_ID.to_string(),
696 }),
697 }
698 }
699
700 async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress> {
701 match &mut state.kind {
702 OpenAIAuthStateKind::OAuthStarted {
703 verifier,
704 state: expected_state,
705 callback_server,
706 ..
707 } => {
708 let callback = if input.trim().is_empty() {
709 if let Some(server) = callback_server {
710 if let Some(result) = server.try_recv() {
711 result?
712 } else {
713 return Ok(AuthProgress::InProgress(
714 "Waiting for OAuth callback...".to_string(),
715 ));
716 }
717 } else {
718 return Ok(AuthProgress::NeedInput(
719 "Paste the redirect URL from your browser".to_string(),
720 ));
721 }
722 } else {
723 parse_callback_input(input)?
724 };
725
726 if callback.state != *expected_state {
727 return Err(AuthError::StateMismatch);
728 }
729
730 let tokens = self
731 .oauth_client
732 .exchange_code_for_tokens(&callback.code, verifier)
733 .await?;
734
735 self.storage
736 .set_credential(PROVIDER_ID, Credential::OAuth2(tokens))
737 .await?;
738
739 if let Some(server) = callback_server.take() {
740 drop(server);
741 }
742
743 Ok(AuthProgress::Complete)
744 }
745 }
746 }
747
748 async fn is_authenticated(&self) -> Result<bool> {
749 if let Some(Credential::OAuth2(tokens)) = self
750 .storage
751 .get_credential(PROVIDER_ID, CredentialType::OAuth2)
752 .await?
753 {
754 return Ok(tokens.id_token.is_some() && !tokens_need_refresh(&tokens));
755 }
756
757 Ok(false)
758 }
759
760 fn provider_name(&self) -> String {
761 PROVIDER_ID.to_string()
762 }
763}
764
765#[derive(Clone)]
766struct OpenAiHeaderProvider {
767 storage: Arc<dyn AuthStorage>,
768 oauth: OpenAIOAuth,
769}
770
771impl OpenAiHeaderProvider {
772 fn new(storage: Arc<dyn AuthStorage>) -> Self {
773 Self {
774 storage,
775 oauth: OpenAIOAuth::new(),
776 }
777 }
778
779 async fn header_pairs(&self, _ctx: AuthHeaderContext) -> Result<Vec<HeaderPair>> {
780 let tokens = refresh_if_needed(&self.storage, &self.oauth).await?;
781 let id_token = tokens
782 .id_token
783 .as_deref()
784 .ok_or(AuthError::ReauthRequired)?;
785 let account_id = extract_chatgpt_account_id(id_token)?;
786
787 Ok(vec![
788 HeaderPair {
789 name: "authorization".to_string(),
790 value: format!("Bearer {}", tokens.access_token),
791 },
792 HeaderPair {
793 name: "chatgpt-account-id".to_string(),
794 value: account_id.0,
795 },
796 HeaderPair {
797 name: "openai-beta".to_string(),
798 value: OPENAI_BETA.to_string(),
799 },
800 HeaderPair {
801 name: "originator".to_string(),
802 value: ORIGINATOR.to_string(),
803 },
804 ])
805 }
806}
807
808#[async_trait]
809impl AuthHeaderProvider for OpenAiHeaderProvider {
810 async fn headers(&self, ctx: AuthHeaderContext) -> Result<Vec<HeaderPair>> {
811 self.header_pairs(ctx).await
812 }
813
814 async fn on_auth_error(&self, _ctx: AuthErrorContext) -> Result<AuthErrorAction> {
815 match force_refresh(&self.storage, &self.oauth).await {
816 Ok(_) => Ok(AuthErrorAction::RetryOnce),
817 Err(AuthError::ReauthRequired) => Ok(AuthErrorAction::ReauthRequired),
818 Err(err) => Err(err),
819 }
820 }
821}
822
823struct OpenAiModelVisibility;
824
825impl ModelVisibilityPolicy for OpenAiModelVisibility {
826 fn allow_model(&self, model_id: &ModelId, auth_source: &AuthSource) -> bool {
827 if model_id.provider_id.0 != PROVIDER_ID {
828 return true;
829 }
830
831 if matches!(
832 model_id.model_id.as_str(),
833 GPT_5_2_CODEX_MODEL_ID | GPT_5_3_CODEX_MODEL_ID
834 ) {
835 return matches!(auth_source, AuthSource::Plugin { .. });
836 }
837
838 true
839 }
840}
841
/// Auth plugin that signs users in to OpenAI via ChatGPT OAuth and routes
/// requests to the Codex backend.
#[derive(Clone, Default)]
pub struct OpenAiAuthPlugin;

impl OpenAiAuthPlugin {
    /// Creates the (stateless) plugin.
    pub fn new() -> Self {
        Self
    }
}
856
857#[async_trait]
858impl AuthPlugin for OpenAiAuthPlugin {
859 fn provider_id(&self) -> ProviderId {
860 ProviderId(PROVIDER_ID.to_string())
861 }
862
863 fn supported_methods(&self) -> Vec<AuthMethod> {
864 vec![AuthMethod::OAuth]
865 }
866
867 fn create_flow(&self, storage: Arc<dyn AuthStorage>) -> Option<Box<dyn DynAuthenticationFlow>> {
868 Some(Box::new(steer_auth_plugin::AuthFlowWrapper::new(
869 OpenAIOAuthFlow::new(storage),
870 )))
871 }
872
873 async fn resolve_auth(&self, storage: Arc<dyn AuthStorage>) -> Result<Option<AuthDirective>> {
874 let is_authenticated = self.is_authenticated(storage.clone()).await?;
875 if !is_authenticated {
876 return Ok(None);
877 }
878
879 let headers = Arc::new(OpenAiHeaderProvider::new(storage));
880 let directive = OpenAiResponsesAuth {
881 headers,
882 base_url_override: Some(CODEX_BASE_URL.to_string()),
883 require_streaming: Some(true),
884 instruction_policy: Some(InstructionPolicy::Override(codex_instructions())),
885 include: Some(vec!["reasoning.encrypted_content".to_string()]),
886 };
887
888 Ok(Some(AuthDirective::OpenAiResponses(directive)))
889 }
890
891 async fn is_authenticated(&self, storage: Arc<dyn AuthStorage>) -> Result<bool> {
892 if let Some(Credential::OAuth2(tokens)) = storage
893 .get_credential(PROVIDER_ID, CredentialType::OAuth2)
894 .await?
895 {
896 return Ok(tokens.id_token.is_some() && !tokens_need_refresh(&tokens));
897 }
898
899 Ok(false)
900 }
901
902 fn model_visibility(&self) -> Option<Box<dyn ModelVisibilityPolicy>> {
903 Some(Box::new(OpenAiModelVisibility))
904 }
905}
906
#[cfg(test)]
mod tests {
    use super::*;
    use steer_auth_plugin::{AuthMethod, AuthSource};

    #[test]
    fn test_auth_url_building() {
        let oauth = OpenAIOAuth::new();
        let pkce = OpenAIOAuth::generate_pkce();
        let state = OpenAIOAuth::generate_state();

        let url = oauth.build_auth_url(&pkce, &state);

        assert!(url.contains(AUTHORIZE_URL));
        assert!(url.contains(&format!("client_id={CLIENT_ID}")));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("code_challenge="));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("id_token_add_organizations=true"));
        assert!(url.contains("codex_cli_simplified_flow=true"));
        assert!(url.contains(&format!("originator={ORIGINATOR}")));
        assert!(url.contains("redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback"));
    }

    #[test]
    fn test_parse_callback_input_url() {
        let input = "http://localhost:1455/auth/callback?code=abc123&state=state456";
        let parsed = parse_callback_input(input).unwrap();
        assert_eq!(parsed.code, "abc123");
        assert_eq!(parsed.state, "state456");
    }

    #[test]
    fn test_parse_callback_input_code_hash_state() {
        let parsed = parse_callback_input("abc123#state456").unwrap();
        assert_eq!(parsed.code, "abc123");
        assert_eq!(parsed.state, "state456");
    }

    #[test]
    fn test_parse_callback_input_whitespace_separated() {
        let parsed = parse_callback_input("  abc123   state456\n").unwrap();
        assert_eq!(parsed.code, "abc123");
        assert_eq!(parsed.state, "state456");
    }

    #[test]
    fn test_parse_callback_input_rejects_unrecognized_input() {
        assert!(parse_callback_input("not-a-callback").is_err());
        assert!(parse_callback_input("#state456").is_err());
    }

    #[test]
    fn test_extract_chatgpt_account_id() {
        // The top-level claim is used when the namespaced claim is absent.
        let payload = serde_json::json!({
            "chatgpt_account_id": "acct_123",
            "exp": 1_700_000_000u64
        });
        let token = make_jwt(payload);
        let account_id = extract_chatgpt_account_id(&token).unwrap();
        assert_eq!(account_id.0, "acct_123");
    }

    #[test]
    fn test_extract_chatgpt_account_id_nested_claim() {
        let payload = serde_json::json!({
            CHATGPT_ACCOUNT_ID_NESTED_CLAIM: {
                "chatgpt_account_id": "acct_nested"
            },
            "exp": 1_700_000_000u64
        });
        let token = make_jwt(payload);
        let account_id = extract_chatgpt_account_id(&token).unwrap();
        assert_eq!(account_id.0, "acct_nested");
    }

    #[test]
    fn test_extract_chatgpt_account_id_prefers_nested_claim() {
        let payload = serde_json::json!({
            CHATGPT_ACCOUNT_ID_NESTED_CLAIM: {
                "chatgpt_account_id": "acct_nested"
            },
            "chatgpt_account_id": "acct_top_level",
            "exp": 1_700_000_000u64
        });
        let token = make_jwt(payload);
        let account_id = extract_chatgpt_account_id(&token).unwrap();
        assert_eq!(account_id.0, "acct_nested");
    }

    #[test]
    fn test_extract_chatgpt_account_id_missing() {
        let token = make_jwt(serde_json::json!({ "exp": 1_700_000_000u64 }));
        assert!(extract_chatgpt_account_id(&token).is_err());
    }

    #[test]
    fn test_resolve_expires_at_from_token() {
        let payload = serde_json::json!({
            "chatgpt_account_id": "acct_123",
            "exp": 1_700_000_000u64
        });
        let token = make_jwt(payload);
        let exp = resolve_expires_at(None, &token).unwrap();
        assert_eq!(exp, UNIX_EPOCH + Duration::from_secs(1_700_000_000u64));
    }

    #[test]
    fn test_resolve_expires_at_prefers_expires_in() {
        // With `expires_in` present the access token is never decoded.
        let before = SystemTime::now();
        let exp = resolve_expires_at(Some(3600), "not-a-jwt").unwrap();
        assert!(exp >= before + Duration::from_secs(3600));
    }

    #[test]
    fn test_tokens_need_refresh_window() {
        let expiring_at = |expires_at: SystemTime| AuthTokens {
            access_token: "access".to_string(),
            refresh_token: "refresh".to_string(),
            expires_at,
            id_token: None,
        };

        assert!(tokens_need_refresh(&expiring_at(UNIX_EPOCH)));
        assert!(tokens_need_refresh(&expiring_at(
            SystemTime::now() + Duration::from_secs(60)
        )));
        assert!(!tokens_need_refresh(&expiring_at(
            SystemTime::now() + Duration::from_secs(3600)
        )));
    }

    #[test]
    fn test_openai_codex_models_require_plugin_auth() {
        let visibility = OpenAiModelVisibility;
        let codex_5_2 = ModelId {
            provider_id: ProviderId(PROVIDER_ID.to_string()),
            model_id: GPT_5_2_CODEX_MODEL_ID.to_string(),
        };
        let codex_5_3 = ModelId {
            provider_id: ProviderId(PROVIDER_ID.to_string()),
            model_id: GPT_5_3_CODEX_MODEL_ID.to_string(),
        };
        let other_openai = ModelId {
            provider_id: ProviderId(PROVIDER_ID.to_string()),
            model_id: "gpt-4.1".to_string(),
        };

        assert!(visibility.allow_model(
            &codex_5_2,
            &AuthSource::Plugin {
                method: AuthMethod::OAuth,
            }
        ));
        assert!(visibility.allow_model(
            &codex_5_3,
            &AuthSource::Plugin {
                method: AuthMethod::OAuth,
            }
        ));
        assert!(!visibility.allow_model(
            &codex_5_2,
            &AuthSource::ApiKey {
                origin: steer_auth_plugin::ApiKeyOrigin::Env,
            }
        ));
        assert!(!visibility.allow_model(
            &codex_5_3,
            &AuthSource::ApiKey {
                origin: steer_auth_plugin::ApiKeyOrigin::Stored,
            }
        ));
        assert!(visibility.allow_model(
            &other_openai,
            &AuthSource::ApiKey {
                origin: steer_auth_plugin::ApiKeyOrigin::Env,
            }
        ));
    }

    fn make_jwt(payload: serde_json::Value) -> String {
        let header = base64_url_encode(b"{}");
        let payload = base64_url_encode(payload.to_string().as_bytes());
        format!("{header}.{payload}.sig")
    }
}