Skip to main content

codetether_agent/provider/
openai_codex.rs

1//! OpenAI Codex provider using ChatGPT Plus/Pro subscription via OAuth
2//!
3//! This provider uses the OAuth PKCE flow that the official OpenAI Codex CLI uses,
4//! allowing users to authenticate with their ChatGPT subscription instead of API credits.
5//!
6//! Reference: https://github.com/numman-ali/opencode-openai-codex-auth
7
8use super::{
9    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
10    Role, StreamChunk, ToolDefinition, Usage,
11};
12use anyhow::{Context, Result};
13use async_trait::async_trait;
14use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use futures::StreamExt;
16use futures::stream::BoxStream;
17use reqwest::Client;
18use serde::{Deserialize, Serialize};
19use serde_json::{Value, json};
20use sha2::{Digest, Sha256};
21use std::sync::Arc;
22use tokio::sync::RwLock;
23
/// Base URL for OpenAI's REST API (chat completions live under this).
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
/// OAuth 2.0 authorization endpoint (user-facing consent page).
const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// OAuth 2.0 token endpoint (authorization-code exchange and refresh).
const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Public client ID registered for the official OpenAI Codex CLI.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Loopback redirect URI the CLI listens on during the PKCE flow.
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Requested scopes; `offline_access` is what yields a refresh token.
const SCOPE: &str = "openid profile email offline_access";
30
/// Cached OAuth tokens with expiration tracking
struct CachedTokens {
    // Bearer token attached to API requests.
    access_token: String,
    // Most recently seen refresh token (the server may rotate it on refresh).
    refresh_token: String,
    // Monotonic deadline after which `access_token` is treated as expired.
    expires_at: std::time::Instant,
}
37
/// Stored OAuth credentials (persisted to Vault)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthCredentials {
    /// Bearer token for API requests.
    pub access_token: String,
    /// Long-lived token used to obtain fresh access tokens.
    pub refresh_token: String,
    /// Absolute expiry of `access_token`.
    pub expires_at: u64, // Unix timestamp in seconds
}
45
/// PKCE code verifier and challenge pair
struct PkcePair {
    // High-entropy secret kept client-side; sent only during code exchange.
    verifier: String,
    // base64url(SHA-256(verifier)) without padding — the S256 challenge sent
    // in the authorization request.
    challenge: String,
}
51
/// Provider backed by a ChatGPT subscription via the Codex OAuth flow.
pub struct OpenAiCodexProvider {
    // Shared HTTP client (reused for connection pooling).
    client: Client,
    // In-memory token cache shared across concurrent requests.
    cached_tokens: Arc<RwLock<Option<CachedTokens>>>,
    /// Stored credentials from Vault (for refresh on startup)
    stored_credentials: Option<OAuthCredentials>,
}
58
// Manual Debug impl: `Client` does not matter for diagnostics and, more
// importantly, this keeps access/refresh token values out of debug output.
impl std::fmt::Debug for OpenAiCodexProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiCodexProvider")
            .field("has_credentials", &self.stored_credentials.is_some())
            .finish()
    }
}
66
67impl OpenAiCodexProvider {
68    /// Create from stored OAuth credentials (from Vault)
69    pub fn from_credentials(credentials: OAuthCredentials) -> Self {
70        Self {
71            client: Client::new(),
72            cached_tokens: Arc::new(RwLock::new(None)),
73            stored_credentials: Some(credentials),
74        }
75    }
76
77    /// Create a new unauthenticated instance (requires OAuth flow)
78    #[allow(dead_code)]
79    pub fn new() -> Self {
80        Self {
81            client: Client::new(),
82            cached_tokens: Arc::new(RwLock::new(None)),
83            stored_credentials: None,
84        }
85    }
86
87    /// Generate PKCE code verifier and challenge
88    fn generate_pkce() -> PkcePair {
89        let random_bytes: [u8; 32] = {
90            let timestamp = std::time::SystemTime::now()
91                .duration_since(std::time::UNIX_EPOCH)
92                .map(|d| d.as_nanos())
93                .unwrap_or(0);
94
95            let mut bytes = [0u8; 32];
96            let ts_bytes = timestamp.to_le_bytes();
97            let tid = std::thread::current().id();
98            let tid_repr = format!("{:?}", tid);
99            let tid_hash = Sha256::digest(tid_repr.as_bytes());
100
101            bytes[0..8].copy_from_slice(&ts_bytes);
102            bytes[8..24].copy_from_slice(&tid_hash[0..16]);
103            bytes[24..].copy_from_slice(&Sha256::digest(&ts_bytes)[0..8]);
104            bytes
105        };
106        let verifier = URL_SAFE_NO_PAD.encode(&random_bytes);
107
108        let mut hasher = Sha256::new();
109        hasher.update(verifier.as_bytes());
110        let challenge_bytes = hasher.finalize();
111        let challenge = URL_SAFE_NO_PAD.encode(&challenge_bytes);
112
113        PkcePair {
114            verifier,
115            challenge,
116        }
117    }
118
119    /// Generate random state value
120    fn generate_state() -> String {
121        let timestamp = std::time::SystemTime::now()
122            .duration_since(std::time::UNIX_EPOCH)
123            .map(|d| d.as_nanos())
124            .unwrap_or(0);
125        let random: [u8; 8] = {
126            let ptr = Box::into_raw(Box::new(timestamp)) as usize;
127            let bytes = ptr.to_le_bytes();
128            let mut arr = [0u8; 8];
129            arr.copy_from_slice(&bytes);
130            arr
131        };
132        format!("{:016x}{:016x}", timestamp, u64::from_le_bytes(random))
133    }
134
135    /// Get the OAuth authorization URL for the user to visit
136    #[allow(dead_code)]
137    pub fn get_authorization_url() -> (String, String, String) {
138        let pkce = Self::generate_pkce();
139        let state = Self::generate_state();
140
141        let url = format!(
142            "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}&id_token_add_organizations=true&codex_cli_simplified_flow=true&originator=codex_cli_rs",
143            AUTHORIZE_URL,
144            CLIENT_ID,
145            urlencoding::encode(REDIRECT_URI),
146            urlencoding::encode(SCOPE),
147            pkce.challenge,
148            state
149        );
150
151        (url, pkce.verifier, state)
152    }
153
154    /// Exchange authorization code for tokens
155    #[allow(dead_code)]
156    pub async fn exchange_code(code: &str, verifier: &str) -> Result<OAuthCredentials> {
157        let client = Client::new();
158        let form_body = format!(
159            "grant_type={}&client_id={}&code={}&code_verifier={}&redirect_uri={}",
160            urlencoding::encode("authorization_code"),
161            CLIENT_ID,
162            urlencoding::encode(code),
163            urlencoding::encode(verifier),
164            urlencoding::encode(REDIRECT_URI),
165        );
166
167        let response = client
168            .post(TOKEN_URL)
169            .header("Content-Type", "application/x-www-form-urlencoded")
170            .body(form_body)
171            .send()
172            .await
173            .context("Failed to exchange authorization code")?;
174
175        if !response.status().is_success() {
176            let body = response.text().await.unwrap_or_default();
177            anyhow::bail!("OAuth token exchange failed: {}", body);
178        }
179
180        #[derive(Deserialize)]
181        struct TokenResponse {
182            access_token: String,
183            refresh_token: String,
184            expires_in: u64,
185        }
186
187        let tokens: TokenResponse = response
188            .json()
189            .await
190            .context("Failed to parse token response")?;
191
192        let expires_at = std::time::SystemTime::now()
193            .duration_since(std::time::UNIX_EPOCH)
194            .context("System time error")?
195            .as_secs()
196            + tokens.expires_in;
197
198        Ok(OAuthCredentials {
199            access_token: tokens.access_token,
200            refresh_token: tokens.refresh_token,
201            expires_at,
202        })
203    }
204
205    /// Refresh access token using refresh token
206    async fn refresh_access_token(&self, refresh_token: &str) -> Result<OAuthCredentials> {
207        let form_body = format!(
208            "grant_type={}&refresh_token={}&client_id={}",
209            urlencoding::encode("refresh_token"),
210            urlencoding::encode(refresh_token),
211            CLIENT_ID,
212        );
213
214        let response = self
215            .client
216            .post(TOKEN_URL)
217            .header("Content-Type", "application/x-www-form-urlencoded")
218            .body(form_body)
219            .send()
220            .await
221            .context("Failed to refresh access token")?;
222
223        if !response.status().is_success() {
224            let body = response.text().await.unwrap_or_default();
225            anyhow::bail!("Token refresh failed: {}", body);
226        }
227
228        #[derive(Deserialize)]
229        struct TokenResponse {
230            access_token: String,
231            refresh_token: String,
232            expires_in: u64,
233        }
234
235        let tokens: TokenResponse = response
236            .json()
237            .await
238            .context("Failed to parse refresh response")?;
239
240        let expires_at = std::time::SystemTime::now()
241            .duration_since(std::time::UNIX_EPOCH)
242            .context("System time error")?
243            .as_secs()
244            + tokens.expires_in;
245
246        Ok(OAuthCredentials {
247            access_token: tokens.access_token,
248            refresh_token: tokens.refresh_token,
249            expires_at,
250        })
251    }
252
253    /// Get a valid access token, refreshing if necessary
254    async fn get_access_token(&self) -> Result<String> {
255        {
256            let cache = self.cached_tokens.read().await;
257            if let Some(ref tokens) = *cache {
258                if tokens
259                    .expires_at
260                    .duration_since(std::time::Instant::now())
261                    .as_secs()
262                    > 300
263                {
264                    return Ok(tokens.access_token.clone());
265                }
266            }
267        }
268
269        let mut cache = self.cached_tokens.write().await;
270
271        let creds = if let Some(ref stored) = self.stored_credentials {
272            let now = std::time::SystemTime::now()
273                .duration_since(std::time::UNIX_EPOCH)
274                .context("System time error")?
275                .as_secs();
276
277            if stored.expires_at > now + 300 {
278                stored.clone()
279            } else {
280                let new_creds = self.refresh_access_token(&stored.refresh_token).await?;
281                new_creds
282            }
283        } else {
284            anyhow::bail!("No OAuth credentials available. Run OAuth flow first.");
285        };
286
287        let expires_in = creds.expires_at
288            - std::time::SystemTime::now()
289                .duration_since(std::time::UNIX_EPOCH)
290                .context("System time error")?
291                .as_secs();
292
293        let cached = CachedTokens {
294            access_token: creds.access_token.clone(),
295            refresh_token: creds.refresh_token.clone(),
296            expires_at: std::time::Instant::now() + std::time::Duration::from_secs(expires_in),
297        };
298
299        let token = cached.access_token.clone();
300        *cache = Some(cached);
301        Ok(token)
302    }
303
304    fn convert_messages(messages: &[Message]) -> Vec<Value> {
305        messages
306            .iter()
307            .map(|msg| {
308                let role = match msg.role {
309                    Role::System => "system",
310                    Role::User => "user",
311                    Role::Assistant => "assistant",
312                    Role::Tool => "tool",
313                };
314
315                match msg.role {
316                    Role::Tool => {
317                        if let Some(ContentPart::ToolResult {
318                            tool_call_id,
319                            content,
320                        }) = msg.content.first()
321                        {
322                            json!({
323                                "role": "tool",
324                                "tool_call_id": tool_call_id,
325                                "content": content
326                            })
327                        } else {
328                            json!({ "role": role, "content": "" })
329                        }
330                    }
331                    Role::Assistant => {
332                        let text: String = msg
333                            .content
334                            .iter()
335                            .filter_map(|p| match p {
336                                ContentPart::Text { text } => Some(text.clone()),
337                                _ => None,
338                            })
339                            .collect::<Vec<_>>()
340                            .join("");
341
342                        let tool_calls: Vec<Value> = msg
343                            .content
344                            .iter()
345                            .filter_map(|p| match p {
346                                ContentPart::ToolCall {
347                                    id,
348                                    name,
349                                    arguments,
350                                    ..
351                                } => Some(json!({
352                                    "id": id,
353                                    "type": "function",
354                                    "function": {
355                                        "name": name,
356                                        "arguments": arguments
357                                    }
358                                })),
359                                _ => None,
360                            })
361                            .collect();
362
363                        if tool_calls.is_empty() {
364                            json!({ "role": "assistant", "content": text })
365                        } else {
366                            json!({
367                                "role": "assistant",
368                                "content": if text.is_empty() { Value::Null } else { json!(text) },
369                                "tool_calls": tool_calls
370                            })
371                        }
372                    }
373                    _ => {
374                        let text: String = msg
375                            .content
376                            .iter()
377                            .filter_map(|p| match p {
378                                ContentPart::Text { text } => Some(text.clone()),
379                                _ => None,
380                            })
381                            .collect::<Vec<_>>()
382                            .join("\n");
383                        json!({ "role": role, "content": text })
384                    }
385                }
386            })
387            .collect()
388    }
389
390    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
391        tools
392            .iter()
393            .map(|t| {
394                json!({
395                    "type": "function",
396                    "function": {
397                        "name": t.name,
398                        "description": t.description,
399                        "parameters": t.parameters
400                    }
401                })
402            })
403            .collect()
404    }
405}
406
407#[async_trait]
408impl Provider for OpenAiCodexProvider {
    /// Stable identifier for this provider ("openai-codex").
    fn name(&self) -> &str {
        "openai-codex"
    }
412
413    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
414        Ok(vec![
415            ModelInfo {
416                id: "gpt-5".to_string(),
417                name: "GPT-5".to_string(),
418                provider: "openai-codex".to_string(),
419                context_window: 400_000,
420                max_output_tokens: Some(128_000),
421                supports_vision: false,
422                supports_tools: true,
423                supports_streaming: true,
424                input_cost_per_million: Some(0.0),
425                output_cost_per_million: Some(0.0),
426            },
427            ModelInfo {
428                id: "gpt-5-mini".to_string(),
429                name: "GPT-5 Mini".to_string(),
430                provider: "openai-codex".to_string(),
431                context_window: 264_000,
432                max_output_tokens: Some(64_000),
433                supports_vision: false,
434                supports_tools: true,
435                supports_streaming: true,
436                input_cost_per_million: Some(0.0),
437                output_cost_per_million: Some(0.0),
438            },
439            ModelInfo {
440                id: "gpt-5.1-codex".to_string(),
441                name: "GPT-5.1 Codex".to_string(),
442                provider: "openai-codex".to_string(),
443                context_window: 400_000,
444                max_output_tokens: Some(128_000),
445                supports_vision: false,
446                supports_tools: true,
447                supports_streaming: true,
448                input_cost_per_million: Some(0.0),
449                output_cost_per_million: Some(0.0),
450            },
451            ModelInfo {
452                id: "gpt-5.2".to_string(),
453                name: "GPT-5.2".to_string(),
454                provider: "openai-codex".to_string(),
455                context_window: 400_000,
456                max_output_tokens: Some(128_000),
457                supports_vision: false,
458                supports_tools: true,
459                supports_streaming: true,
460                input_cost_per_million: Some(0.0),
461                output_cost_per_million: Some(0.0),
462            },
463            ModelInfo {
464                id: "gpt-5.3-codex".to_string(),
465                name: "GPT-5.3 Codex".to_string(),
466                provider: "openai-codex".to_string(),
467                context_window: 400_000,
468                max_output_tokens: Some(128_000),
469                supports_vision: false,
470                supports_tools: true,
471                supports_streaming: true,
472                input_cost_per_million: Some(0.0),
473                output_cost_per_million: Some(0.0),
474            },
475            ModelInfo {
476                id: "o3".to_string(),
477                name: "O3".to_string(),
478                provider: "openai-codex".to_string(),
479                context_window: 200_000,
480                max_output_tokens: Some(100_000),
481                supports_vision: true,
482                supports_tools: true,
483                supports_streaming: true,
484                input_cost_per_million: Some(0.0),
485                output_cost_per_million: Some(0.0),
486            },
487            ModelInfo {
488                id: "o4-mini".to_string(),
489                name: "O4 Mini".to_string(),
490                provider: "openai-codex".to_string(),
491                context_window: 200_000,
492                max_output_tokens: Some(100_000),
493                supports_vision: true,
494                supports_tools: true,
495                supports_streaming: true,
496                input_cost_per_million: Some(0.0),
497                output_cost_per_million: Some(0.0),
498            },
499        ])
500    }
501
    /// Send a non-streaming chat-completion request.
    ///
    /// Converts the request into OpenAI's `/chat/completions` wire format,
    /// authenticates with the (possibly refreshed) OAuth access token, and
    /// maps the response back into the provider-neutral types.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let access_token = self.get_access_token().await?;

        let messages = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let mut body = json!({
            "model": request.model,
            "messages": messages,
        });

        // Optional fields are only added when present, so the request never
        // carries explicit nulls.
        if !tools.is_empty() {
            body["tools"] = json!(tools);
        }
        if let Some(temp) = request.temperature {
            body["temperature"] = json!(temp);
        }
        if let Some(max_tokens) = request.max_tokens {
            body["max_tokens"] = json!(max_tokens);
        }

        let response = self
            .client
            .post(format!("{}/chat/completions", OPENAI_API_URL))
            .header("Authorization", format!("Bearer {}", access_token))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to OpenAI")?;

        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            anyhow::bail!("OpenAI API error ({}): {}", status, body);
        }

        // Minimal mirror of the fields we consume from OpenAI's response.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            choices: Vec<OpenAiChoice>,
            usage: Option<OpenAiUsage>,
        }

        #[derive(Deserialize)]
        struct OpenAiChoice {
            message: OpenAiMessage,
            finish_reason: Option<String>,
        }

        #[derive(Deserialize)]
        struct OpenAiMessage {
            content: Option<String>,
            tool_calls: Option<Vec<OpenAiToolCall>>,
        }

        #[derive(Deserialize)]
        struct OpenAiToolCall {
            id: String,
            function: OpenAiFunction,
        }

        #[derive(Deserialize)]
        struct OpenAiFunction {
            name: String,
            // JSON-encoded arguments, passed through as-is.
            arguments: String,
        }

        #[derive(Deserialize)]
        struct OpenAiUsage {
            prompt_tokens: usize,
            completion_tokens: usize,
            total_tokens: usize,
        }

        let openai_resp: OpenAiResponse = response
            .json()
            .await
            .context("Failed to parse OpenAI response")?;

        // Only the first choice is used (n>1 is never requested).
        let choice = openai_resp
            .choices
            .into_iter()
            .next()
            .context("No choices in response")?;

        let mut content = Vec::new();

        if let Some(text) = choice.message.content {
            if !text.is_empty() {
                content.push(ContentPart::Text { text });
            }
        }

        if let Some(tool_calls) = choice.message.tool_calls {
            for tc in tool_calls {
                content.push(ContentPart::ToolCall {
                    id: tc.id,
                    name: tc.function.name,
                    arguments: tc.function.arguments,
                    thought_signature: None,
                });
            }
        }

        // Unknown or absent finish reasons are treated as a normal stop.
        let finish_reason = match choice.finish_reason.as_deref() {
            Some("stop") => FinishReason::Stop,
            Some("tool_calls") => FinishReason::ToolCalls,
            Some("length") => FinishReason::Length,
            _ => FinishReason::Stop,
        };

        // Missing usage block -> zeroed Usage via Default.
        let usage = openai_resp
            .usage
            .map(|u| Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
                cache_read_tokens: None,
                cache_write_tokens: None,
            })
            .unwrap_or_default();

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage,
            finish_reason,
        })
    }
633
634    async fn complete_stream(
635        &self,
636        request: CompletionRequest,
637    ) -> Result<BoxStream<'static, StreamChunk>> {
638        let access_token = self.get_access_token().await?;
639
640        let messages = Self::convert_messages(&request.messages);
641        let tools = Self::convert_tools(&request.tools);
642
643        let mut body = json!({
644            "model": request.model,
645            "messages": messages,
646            "stream": true,
647        });
648
649        if !tools.is_empty() {
650            body["tools"] = json!(tools);
651        }
652        if let Some(temp) = request.temperature {
653            body["temperature"] = json!(temp);
654        }
655        if let Some(max_tokens) = request.max_tokens {
656            body["max_tokens"] = json!(max_tokens);
657        }
658
659        let response = self
660            .client
661            .post(format!("{}/chat/completions", OPENAI_API_URL))
662            .header("Authorization", format!("Bearer {}", access_token))
663            .header("Content-Type", "application/json")
664            .json(&body)
665            .send()
666            .await
667            .context("Failed to send streaming request to OpenAI")?;
668
669        let status = response.status();
670        if !status.is_success() {
671            let body = response.text().await.unwrap_or_default();
672            anyhow::bail!("OpenAI API error ({}): {}", status, body);
673        }
674
675        let stream = response.bytes_stream().flat_map(|result| match result {
676            Ok(bytes) => {
677                let text = String::from_utf8_lossy(&bytes);
678                let mut chunks = Vec::new();
679
680                for line in text.lines() {
681                    if !line.starts_with("data: ") {
682                        continue;
683                    }
684                    let data = &line[6..];
685                    if data == "[DONE]" {
686                        chunks.push(StreamChunk::Done { usage: None });
687                        continue;
688                    }
689
690                    #[derive(Deserialize)]
691                    struct StreamResponse {
692                        choices: Vec<StreamChoice>,
693                    }
694                    #[derive(Deserialize)]
695                    struct StreamChoice {
696                        delta: StreamDelta,
697                        #[allow(dead_code)]
698                        finish_reason: Option<String>,
699                    }
700                    #[derive(Deserialize)]
701                    struct StreamDelta {
702                        content: Option<String>,
703                        tool_calls: Option<Vec<StreamToolCall>>,
704                    }
705                    #[derive(Deserialize)]
706                    struct StreamToolCall {
707                        id: Option<String>,
708                        function: Option<StreamFunction>,
709                    }
710                    #[derive(Deserialize)]
711                    struct StreamFunction {
712                        name: Option<String>,
713                        arguments: Option<String>,
714                    }
715
716                    if let Ok(resp) = serde_json::from_str::<StreamResponse>(data) {
717                        for choice in resp.choices {
718                            if let Some(content) = choice.delta.content {
719                                chunks.push(StreamChunk::Text(content));
720                            }
721                            if let Some(tool_calls) = choice.delta.tool_calls {
722                                for tc in tool_calls {
723                                    if let Some(id) = &tc.id {
724                                        if let Some(func) = &tc.function {
725                                            if let Some(name) = &func.name {
726                                                chunks.push(StreamChunk::ToolCallStart {
727                                                    id: id.clone(),
728                                                    name: name.clone(),
729                                                });
730                                            }
731                                            if let Some(args) = &func.arguments {
732                                                chunks.push(StreamChunk::ToolCallDelta {
733                                                    id: id.clone(),
734                                                    arguments_delta: args.clone(),
735                                                });
736                                            }
737                                        }
738                                    }
739                                }
740                            }
741                        }
742                    }
743                }
744                futures::stream::iter(chunks)
745            }
746            Err(e) => futures::stream::iter(vec![StreamChunk::Error(e.to_string())]),
747        });
748
749        Ok(Box::pin(stream))
750    }
751}
752
#[cfg(test)]
mod tests {
    use super::*;

    /// The verifier must be non-empty and the challenge must be the S256
    /// transform of it (RFC 7636): unpadded base64url(SHA-256(verifier)).
    #[test]
    fn test_generate_pkce() {
        let pkce = OpenAiCodexProvider::generate_pkce();
        assert!(!pkce.verifier.is_empty());
        assert!(!pkce.challenge.is_empty());
        assert_ne!(pkce.verifier, pkce.challenge);
        let expected = URL_SAFE_NO_PAD.encode(Sha256::digest(pkce.verifier.as_bytes()));
        assert_eq!(pkce.challenge, expected);
    }

    /// State is two zero-padded 16-digit hex fields: 32 hex chars total.
    #[test]
    fn test_generate_state() {
        let state = OpenAiCodexProvider::generate_state();
        assert_eq!(state.len(), 32);
        assert!(state.chars().all(|c| c.is_ascii_hexdigit()));
    }
}