1use super::{
9 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
10 Role, StreamChunk, ToolDefinition, Usage,
11};
12use anyhow::{Context, Result};
13use async_trait::async_trait;
14use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use futures::StreamExt;
16use futures::stream::BoxStream;
17use reqwest::Client;
18use serde::{Deserialize, Serialize};
19use serde_json::{Value, json};
20use sha2::{Digest, Sha256};
21use std::sync::Arc;
22use tokio::sync::RwLock;
23
/// Base URL of the OpenAI REST API; `/chat/completions` is appended for each request.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
/// Authorization endpoint the user is sent to when the OAuth flow starts.
const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// Token endpoint used both to exchange an authorization code and to refresh tokens.
const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// OAuth client identifier sent with the authorize URL and every token request.
/// NOTE(review): appears to be the public Codex CLI client id (the authorize URL also sends
/// `originator=codex_cli_rs`); confirm before changing.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Redirect URI sent in both the authorize URL and the code-exchange request.
/// NOTE(review): nothing in this file listens on this address; presumably a caller does.
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Space-separated OAuth scopes requested in the authorize URL.
/// NOTE(review): `offline_access` is presumably what makes the server issue a refresh token.
const SCOPE: &str = "openid profile email offline_access";
30
/// In-memory copy of the token set currently in use, with its expiry tracked on the
/// monotonic clock so wall-clock adjustments do not affect it.
struct CachedTokens {
    /// Moment after which `access_token` is no longer usable.
    expires_at: std::time::Instant,
    /// Bearer token handed to API requests.
    access_token: String,
    /// Refresh token issued together with `access_token`.
    refresh_token: String,
}
37
/// OAuth tokens returned by the token endpoint.
///
/// Derives the serde traits, presumably so callers can persist the credentials between runs.
/// NOTE(review): the derived `Debug` prints the raw tokens; avoid logging this value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthCredentials {
    /// Bearer token sent in the `Authorization` header of API requests.
    pub access_token: String,
    /// Token sent to the token endpoint to obtain a new access token.
    pub refresh_token: String,
    /// Expiry of `access_token`, in seconds since the Unix epoch (issue time + `expires_in`).
    pub expires_at: u64,
}
45
/// A PKCE code verifier together with the S256 code challenge derived from it.
struct PkcePair {
    /// `BASE64URL(SHA256(verifier))`, placed in the authorization URL.
    challenge: String,
    /// Secret kept by the client and sent when the authorization code is exchanged.
    verifier: String,
}
51
52pub struct OpenAiCodexProvider {
53 client: Client,
54 cached_tokens: Arc<RwLock<Option<CachedTokens>>>,
55 stored_credentials: Option<OAuthCredentials>,
57}
58
59impl std::fmt::Debug for OpenAiCodexProvider {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("OpenAiCodexProvider")
62 .field("has_credentials", &self.stored_credentials.is_some())
63 .finish()
64 }
65}
66
67impl OpenAiCodexProvider {
68 pub fn from_credentials(credentials: OAuthCredentials) -> Self {
70 Self {
71 client: Client::new(),
72 cached_tokens: Arc::new(RwLock::new(None)),
73 stored_credentials: Some(credentials),
74 }
75 }
76
77 #[allow(dead_code)]
79 pub fn new() -> Self {
80 Self {
81 client: Client::new(),
82 cached_tokens: Arc::new(RwLock::new(None)),
83 stored_credentials: None,
84 }
85 }
86
87 fn generate_pkce() -> PkcePair {
89 let random_bytes: [u8; 32] = {
90 let timestamp = std::time::SystemTime::now()
91 .duration_since(std::time::UNIX_EPOCH)
92 .map(|d| d.as_nanos())
93 .unwrap_or(0);
94
95 let mut bytes = [0u8; 32];
96 let ts_bytes = timestamp.to_le_bytes();
97 let tid = std::thread::current().id();
98 let tid_repr = format!("{:?}", tid);
99 let tid_hash = Sha256::digest(tid_repr.as_bytes());
100
101 bytes[0..8].copy_from_slice(&ts_bytes);
102 bytes[8..24].copy_from_slice(&tid_hash[0..16]);
103 bytes[24..].copy_from_slice(&Sha256::digest(&ts_bytes)[0..8]);
104 bytes
105 };
106 let verifier = URL_SAFE_NO_PAD.encode(&random_bytes);
107
108 let mut hasher = Sha256::new();
109 hasher.update(verifier.as_bytes());
110 let challenge_bytes = hasher.finalize();
111 let challenge = URL_SAFE_NO_PAD.encode(&challenge_bytes);
112
113 PkcePair {
114 verifier,
115 challenge,
116 }
117 }
118
119 fn generate_state() -> String {
121 let timestamp = std::time::SystemTime::now()
122 .duration_since(std::time::UNIX_EPOCH)
123 .map(|d| d.as_nanos())
124 .unwrap_or(0);
125 let random: [u8; 8] = {
126 let ptr = Box::into_raw(Box::new(timestamp)) as usize;
127 let bytes = ptr.to_le_bytes();
128 let mut arr = [0u8; 8];
129 arr.copy_from_slice(&bytes);
130 arr
131 };
132 format!("{:016x}{:016x}", timestamp, u64::from_le_bytes(random))
133 }
134
135 #[allow(dead_code)]
137 pub fn get_authorization_url() -> (String, String, String) {
138 let pkce = Self::generate_pkce();
139 let state = Self::generate_state();
140
141 let url = format!(
142 "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}&id_token_add_organizations=true&codex_cli_simplified_flow=true&originator=codex_cli_rs",
143 AUTHORIZE_URL,
144 CLIENT_ID,
145 urlencoding::encode(REDIRECT_URI),
146 urlencoding::encode(SCOPE),
147 pkce.challenge,
148 state
149 );
150
151 (url, pkce.verifier, state)
152 }
153
154 #[allow(dead_code)]
156 pub async fn exchange_code(code: &str, verifier: &str) -> Result<OAuthCredentials> {
157 let client = Client::new();
158 let form_body = format!(
159 "grant_type={}&client_id={}&code={}&code_verifier={}&redirect_uri={}",
160 urlencoding::encode("authorization_code"),
161 CLIENT_ID,
162 urlencoding::encode(code),
163 urlencoding::encode(verifier),
164 urlencoding::encode(REDIRECT_URI),
165 );
166
167 let response = client
168 .post(TOKEN_URL)
169 .header("Content-Type", "application/x-www-form-urlencoded")
170 .body(form_body)
171 .send()
172 .await
173 .context("Failed to exchange authorization code")?;
174
175 if !response.status().is_success() {
176 let body = response.text().await.unwrap_or_default();
177 anyhow::bail!("OAuth token exchange failed: {}", body);
178 }
179
180 #[derive(Deserialize)]
181 struct TokenResponse {
182 access_token: String,
183 refresh_token: String,
184 expires_in: u64,
185 }
186
187 let tokens: TokenResponse = response
188 .json()
189 .await
190 .context("Failed to parse token response")?;
191
192 let expires_at = std::time::SystemTime::now()
193 .duration_since(std::time::UNIX_EPOCH)
194 .context("System time error")?
195 .as_secs()
196 + tokens.expires_in;
197
198 Ok(OAuthCredentials {
199 access_token: tokens.access_token,
200 refresh_token: tokens.refresh_token,
201 expires_at,
202 })
203 }
204
205 async fn refresh_access_token(&self, refresh_token: &str) -> Result<OAuthCredentials> {
207 let form_body = format!(
208 "grant_type={}&refresh_token={}&client_id={}",
209 urlencoding::encode("refresh_token"),
210 urlencoding::encode(refresh_token),
211 CLIENT_ID,
212 );
213
214 let response = self
215 .client
216 .post(TOKEN_URL)
217 .header("Content-Type", "application/x-www-form-urlencoded")
218 .body(form_body)
219 .send()
220 .await
221 .context("Failed to refresh access token")?;
222
223 if !response.status().is_success() {
224 let body = response.text().await.unwrap_or_default();
225 anyhow::bail!("Token refresh failed: {}", body);
226 }
227
228 #[derive(Deserialize)]
229 struct TokenResponse {
230 access_token: String,
231 refresh_token: String,
232 expires_in: u64,
233 }
234
235 let tokens: TokenResponse = response
236 .json()
237 .await
238 .context("Failed to parse refresh response")?;
239
240 let expires_at = std::time::SystemTime::now()
241 .duration_since(std::time::UNIX_EPOCH)
242 .context("System time error")?
243 .as_secs()
244 + tokens.expires_in;
245
246 Ok(OAuthCredentials {
247 access_token: tokens.access_token,
248 refresh_token: tokens.refresh_token,
249 expires_at,
250 })
251 }
252
253 async fn get_access_token(&self) -> Result<String> {
255 {
256 let cache = self.cached_tokens.read().await;
257 if let Some(ref tokens) = *cache {
258 if tokens
259 .expires_at
260 .duration_since(std::time::Instant::now())
261 .as_secs()
262 > 300
263 {
264 return Ok(tokens.access_token.clone());
265 }
266 }
267 }
268
269 let mut cache = self.cached_tokens.write().await;
270
271 let creds = if let Some(ref stored) = self.stored_credentials {
272 let now = std::time::SystemTime::now()
273 .duration_since(std::time::UNIX_EPOCH)
274 .context("System time error")?
275 .as_secs();
276
277 if stored.expires_at > now + 300 {
278 stored.clone()
279 } else {
280 let new_creds = self.refresh_access_token(&stored.refresh_token).await?;
281 new_creds
282 }
283 } else {
284 anyhow::bail!("No OAuth credentials available. Run OAuth flow first.");
285 };
286
287 let expires_in = creds.expires_at
288 - std::time::SystemTime::now()
289 .duration_since(std::time::UNIX_EPOCH)
290 .context("System time error")?
291 .as_secs();
292
293 let cached = CachedTokens {
294 access_token: creds.access_token.clone(),
295 refresh_token: creds.refresh_token.clone(),
296 expires_at: std::time::Instant::now() + std::time::Duration::from_secs(expires_in),
297 };
298
299 let token = cached.access_token.clone();
300 *cache = Some(cached);
301 Ok(token)
302 }
303
304 fn convert_messages(messages: &[Message]) -> Vec<Value> {
305 messages
306 .iter()
307 .map(|msg| {
308 let role = match msg.role {
309 Role::System => "system",
310 Role::User => "user",
311 Role::Assistant => "assistant",
312 Role::Tool => "tool",
313 };
314
315 match msg.role {
316 Role::Tool => {
317 if let Some(ContentPart::ToolResult {
318 tool_call_id,
319 content,
320 }) = msg.content.first()
321 {
322 json!({
323 "role": "tool",
324 "tool_call_id": tool_call_id,
325 "content": content
326 })
327 } else {
328 json!({ "role": role, "content": "" })
329 }
330 }
331 Role::Assistant => {
332 let text: String = msg
333 .content
334 .iter()
335 .filter_map(|p| match p {
336 ContentPart::Text { text } => Some(text.clone()),
337 _ => None,
338 })
339 .collect::<Vec<_>>()
340 .join("");
341
342 let tool_calls: Vec<Value> = msg
343 .content
344 .iter()
345 .filter_map(|p| match p {
346 ContentPart::ToolCall {
347 id,
348 name,
349 arguments,
350 ..
351 } => Some(json!({
352 "id": id,
353 "type": "function",
354 "function": {
355 "name": name,
356 "arguments": arguments
357 }
358 })),
359 _ => None,
360 })
361 .collect();
362
363 if tool_calls.is_empty() {
364 json!({ "role": "assistant", "content": text })
365 } else {
366 json!({
367 "role": "assistant",
368 "content": if text.is_empty() { Value::Null } else { json!(text) },
369 "tool_calls": tool_calls
370 })
371 }
372 }
373 _ => {
374 let text: String = msg
375 .content
376 .iter()
377 .filter_map(|p| match p {
378 ContentPart::Text { text } => Some(text.clone()),
379 _ => None,
380 })
381 .collect::<Vec<_>>()
382 .join("\n");
383 json!({ "role": role, "content": text })
384 }
385 }
386 })
387 .collect()
388 }
389
390 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
391 tools
392 .iter()
393 .map(|t| {
394 json!({
395 "type": "function",
396 "function": {
397 "name": t.name,
398 "description": t.description,
399 "parameters": t.parameters
400 }
401 })
402 })
403 .collect()
404 }
405}
406
407#[async_trait]
408impl Provider for OpenAiCodexProvider {
409 fn name(&self) -> &str {
410 "openai-codex"
411 }
412
413 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
414 Ok(vec![
415 ModelInfo {
416 id: "gpt-5".to_string(),
417 name: "GPT-5".to_string(),
418 provider: "openai-codex".to_string(),
419 context_window: 400_000,
420 max_output_tokens: Some(128_000),
421 supports_vision: false,
422 supports_tools: true,
423 supports_streaming: true,
424 input_cost_per_million: Some(0.0),
425 output_cost_per_million: Some(0.0),
426 },
427 ModelInfo {
428 id: "gpt-5-mini".to_string(),
429 name: "GPT-5 Mini".to_string(),
430 provider: "openai-codex".to_string(),
431 context_window: 264_000,
432 max_output_tokens: Some(64_000),
433 supports_vision: false,
434 supports_tools: true,
435 supports_streaming: true,
436 input_cost_per_million: Some(0.0),
437 output_cost_per_million: Some(0.0),
438 },
439 ModelInfo {
440 id: "gpt-5.1-codex".to_string(),
441 name: "GPT-5.1 Codex".to_string(),
442 provider: "openai-codex".to_string(),
443 context_window: 400_000,
444 max_output_tokens: Some(128_000),
445 supports_vision: false,
446 supports_tools: true,
447 supports_streaming: true,
448 input_cost_per_million: Some(0.0),
449 output_cost_per_million: Some(0.0),
450 },
451 ModelInfo {
452 id: "gpt-5.2".to_string(),
453 name: "GPT-5.2".to_string(),
454 provider: "openai-codex".to_string(),
455 context_window: 400_000,
456 max_output_tokens: Some(128_000),
457 supports_vision: false,
458 supports_tools: true,
459 supports_streaming: true,
460 input_cost_per_million: Some(0.0),
461 output_cost_per_million: Some(0.0),
462 },
463 ModelInfo {
464 id: "gpt-5.3-codex".to_string(),
465 name: "GPT-5.3 Codex".to_string(),
466 provider: "openai-codex".to_string(),
467 context_window: 400_000,
468 max_output_tokens: Some(128_000),
469 supports_vision: false,
470 supports_tools: true,
471 supports_streaming: true,
472 input_cost_per_million: Some(0.0),
473 output_cost_per_million: Some(0.0),
474 },
475 ModelInfo {
476 id: "o3".to_string(),
477 name: "O3".to_string(),
478 provider: "openai-codex".to_string(),
479 context_window: 200_000,
480 max_output_tokens: Some(100_000),
481 supports_vision: true,
482 supports_tools: true,
483 supports_streaming: true,
484 input_cost_per_million: Some(0.0),
485 output_cost_per_million: Some(0.0),
486 },
487 ModelInfo {
488 id: "o4-mini".to_string(),
489 name: "O4 Mini".to_string(),
490 provider: "openai-codex".to_string(),
491 context_window: 200_000,
492 max_output_tokens: Some(100_000),
493 supports_vision: true,
494 supports_tools: true,
495 supports_streaming: true,
496 input_cost_per_million: Some(0.0),
497 output_cost_per_million: Some(0.0),
498 },
499 ])
500 }
501
502 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
503 let access_token = self.get_access_token().await?;
504
505 let messages = Self::convert_messages(&request.messages);
506 let tools = Self::convert_tools(&request.tools);
507
508 let mut body = json!({
509 "model": request.model,
510 "messages": messages,
511 });
512
513 if !tools.is_empty() {
514 body["tools"] = json!(tools);
515 }
516 if let Some(temp) = request.temperature {
517 body["temperature"] = json!(temp);
518 }
519 if let Some(max_tokens) = request.max_tokens {
520 body["max_tokens"] = json!(max_tokens);
521 }
522
523 let response = self
524 .client
525 .post(format!("{}/chat/completions", OPENAI_API_URL))
526 .header("Authorization", format!("Bearer {}", access_token))
527 .header("Content-Type", "application/json")
528 .json(&body)
529 .send()
530 .await
531 .context("Failed to send request to OpenAI")?;
532
533 let status = response.status();
534 if !status.is_success() {
535 let body = response.text().await.unwrap_or_default();
536 anyhow::bail!("OpenAI API error ({}): {}", status, body);
537 }
538
539 #[derive(Deserialize)]
540 struct OpenAiResponse {
541 choices: Vec<OpenAiChoice>,
542 usage: Option<OpenAiUsage>,
543 }
544
545 #[derive(Deserialize)]
546 struct OpenAiChoice {
547 message: OpenAiMessage,
548 finish_reason: Option<String>,
549 }
550
551 #[derive(Deserialize)]
552 struct OpenAiMessage {
553 content: Option<String>,
554 tool_calls: Option<Vec<OpenAiToolCall>>,
555 }
556
557 #[derive(Deserialize)]
558 struct OpenAiToolCall {
559 id: String,
560 function: OpenAiFunction,
561 }
562
563 #[derive(Deserialize)]
564 struct OpenAiFunction {
565 name: String,
566 arguments: String,
567 }
568
569 #[derive(Deserialize)]
570 struct OpenAiUsage {
571 prompt_tokens: usize,
572 completion_tokens: usize,
573 total_tokens: usize,
574 }
575
576 let openai_resp: OpenAiResponse = response
577 .json()
578 .await
579 .context("Failed to parse OpenAI response")?;
580
581 let choice = openai_resp
582 .choices
583 .into_iter()
584 .next()
585 .context("No choices in response")?;
586
587 let mut content = Vec::new();
588
589 if let Some(text) = choice.message.content {
590 if !text.is_empty() {
591 content.push(ContentPart::Text { text });
592 }
593 }
594
595 if let Some(tool_calls) = choice.message.tool_calls {
596 for tc in tool_calls {
597 content.push(ContentPart::ToolCall {
598 id: tc.id,
599 name: tc.function.name,
600 arguments: tc.function.arguments,
601 thought_signature: None,
602 });
603 }
604 }
605
606 let finish_reason = match choice.finish_reason.as_deref() {
607 Some("stop") => FinishReason::Stop,
608 Some("tool_calls") => FinishReason::ToolCalls,
609 Some("length") => FinishReason::Length,
610 _ => FinishReason::Stop,
611 };
612
613 let usage = openai_resp
614 .usage
615 .map(|u| Usage {
616 prompt_tokens: u.prompt_tokens,
617 completion_tokens: u.completion_tokens,
618 total_tokens: u.total_tokens,
619 cache_read_tokens: None,
620 cache_write_tokens: None,
621 })
622 .unwrap_or_default();
623
624 Ok(CompletionResponse {
625 message: Message {
626 role: Role::Assistant,
627 content,
628 },
629 usage,
630 finish_reason,
631 })
632 }
633
634 async fn complete_stream(
635 &self,
636 request: CompletionRequest,
637 ) -> Result<BoxStream<'static, StreamChunk>> {
638 let access_token = self.get_access_token().await?;
639
640 let messages = Self::convert_messages(&request.messages);
641 let tools = Self::convert_tools(&request.tools);
642
643 let mut body = json!({
644 "model": request.model,
645 "messages": messages,
646 "stream": true,
647 });
648
649 if !tools.is_empty() {
650 body["tools"] = json!(tools);
651 }
652 if let Some(temp) = request.temperature {
653 body["temperature"] = json!(temp);
654 }
655 if let Some(max_tokens) = request.max_tokens {
656 body["max_tokens"] = json!(max_tokens);
657 }
658
659 let response = self
660 .client
661 .post(format!("{}/chat/completions", OPENAI_API_URL))
662 .header("Authorization", format!("Bearer {}", access_token))
663 .header("Content-Type", "application/json")
664 .json(&body)
665 .send()
666 .await
667 .context("Failed to send streaming request to OpenAI")?;
668
669 let status = response.status();
670 if !status.is_success() {
671 let body = response.text().await.unwrap_or_default();
672 anyhow::bail!("OpenAI API error ({}): {}", status, body);
673 }
674
675 let stream = response.bytes_stream().flat_map(|result| match result {
676 Ok(bytes) => {
677 let text = String::from_utf8_lossy(&bytes);
678 let mut chunks = Vec::new();
679
680 for line in text.lines() {
681 if !line.starts_with("data: ") {
682 continue;
683 }
684 let data = &line[6..];
685 if data == "[DONE]" {
686 chunks.push(StreamChunk::Done { usage: None });
687 continue;
688 }
689
690 #[derive(Deserialize)]
691 struct StreamResponse {
692 choices: Vec<StreamChoice>,
693 }
694 #[derive(Deserialize)]
695 struct StreamChoice {
696 delta: StreamDelta,
697 #[allow(dead_code)]
698 finish_reason: Option<String>,
699 }
700 #[derive(Deserialize)]
701 struct StreamDelta {
702 content: Option<String>,
703 tool_calls: Option<Vec<StreamToolCall>>,
704 }
705 #[derive(Deserialize)]
706 struct StreamToolCall {
707 id: Option<String>,
708 function: Option<StreamFunction>,
709 }
710 #[derive(Deserialize)]
711 struct StreamFunction {
712 name: Option<String>,
713 arguments: Option<String>,
714 }
715
716 if let Ok(resp) = serde_json::from_str::<StreamResponse>(data) {
717 for choice in resp.choices {
718 if let Some(content) = choice.delta.content {
719 chunks.push(StreamChunk::Text(content));
720 }
721 if let Some(tool_calls) = choice.delta.tool_calls {
722 for tc in tool_calls {
723 if let Some(id) = &tc.id {
724 if let Some(func) = &tc.function {
725 if let Some(name) = &func.name {
726 chunks.push(StreamChunk::ToolCallStart {
727 id: id.clone(),
728 name: name.clone(),
729 });
730 }
731 if let Some(args) = &func.arguments {
732 chunks.push(StreamChunk::ToolCallDelta {
733 id: id.clone(),
734 arguments_delta: args.clone(),
735 });
736 }
737 }
738 }
739 }
740 }
741 }
742 }
743 }
744 futures::stream::iter(chunks)
745 }
746 Err(e) => futures::stream::iter(vec![StreamChunk::Error(e.to_string())]),
747 });
748
749 Ok(Box::pin(stream))
750 }
751}
752
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine as _;

    #[test]
    fn test_generate_pkce() {
        let pkce = OpenAiCodexProvider::generate_pkce();
        assert!(!pkce.verifier.is_empty());
        assert!(!pkce.challenge.is_empty());
        assert_ne!(pkce.verifier, pkce.challenge);
    }

    #[test]
    fn pkce_verifier_matches_rfc7636_format() {
        let pkce = OpenAiCodexProvider::generate_pkce();
        // RFC 7636 §4.1: 43..=128 characters from the unreserved set.
        assert!((43..=128).contains(&pkce.verifier.len()));
        assert!(pkce
            .verifier
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || "-._~".contains(c)));
    }

    #[test]
    fn pkce_challenge_is_s256_of_verifier() {
        let pkce = OpenAiCodexProvider::generate_pkce();
        let expected = URL_SAFE_NO_PAD.encode(&Sha256::digest(pkce.verifier.as_bytes()));
        assert_eq!(pkce.challenge, expected);
    }

    #[test]
    fn pkce_verifiers_are_not_repeated() {
        let first = OpenAiCodexProvider::generate_pkce();
        let second = OpenAiCodexProvider::generate_pkce();
        assert_ne!(first.verifier, second.verifier);
    }

    #[test]
    fn test_generate_state() {
        let state = OpenAiCodexProvider::generate_state();
        assert_eq!(state.len(), 32);
        assert!(state.chars().all(|c| c.is_ascii_hexdigit()));
        assert_ne!(state, OpenAiCodexProvider::generate_state());
    }

    #[test]
    fn convert_messages_joins_user_text_parts() {
        let messages = vec![Message {
            role: Role::User,
            content: vec![
                ContentPart::Text {
                    text: "hello".to_string(),
                },
                ContentPart::Text {
                    text: "world".to_string(),
                },
            ],
        }];
        assert_eq!(
            OpenAiCodexProvider::convert_messages(&messages),
            vec![json!({ "role": "user", "content": "hello\nworld" })]
        );
    }
}