Skip to main content

skg_provider_codex/
lib.rs

1#![deny(missing_docs)]
2//! OpenAI Codex (Responses API) provider for skg-turn.
3//!
4//! Implements [`Provider`] and [`StreamProvider`] for the OpenAI Responses API,
5//! supporting both the standard `api.openai.com/v1/responses` endpoint and the
6//! Codex backend at `chatgpt.com/backend-api/codex/responses`.
7//!
8//! # Authentication
9//!
10//! Codex uses OAuth JWT tokens. The provider extracts the `chatgpt_account_id`
11//! from the JWT payload and sends it as the `chatgpt-account-id` header.
12//!
13//! ```ignore
14//! use skg_provider_codex::CodexProvider;
15//!
//! let provider = CodexProvider::new("eyJ...")?;  // JWT from OMP; `new` returns a `Result`
17//! ```
18
19mod auth;
20mod convert;
21mod types;
22
23use convert::{messages_to_input, tools_to_codex};
24use futures_util::StreamExt;
25use layer0::content::{Content, ContentBlock};
26use skg_turn::infer::{InferRequest, InferResponse, ToolCall};
27use skg_turn::provider::{Provider, ProviderError};
28use skg_turn::stream::{StreamEvent, StreamProvider, StreamRequest};
29use skg_turn::types::*;
30use rust_decimal::Decimal;
31use tracing::Instrument;
32use types::*;
33
/// Default base URL for the Codex backend.
///
/// Both constructors start from this value; `with_base_url` replaces it.
const DEFAULT_BASE_URL: &str = "https://chatgpt.com/backend-api";

/// SSE path for responses.
///
/// Appended to the base URL (after trimming trailing slashes) to form the
/// endpoint that requests are POSTed to.
const CODEX_RESPONSES_PATH: &str = "/codex/responses";
39
/// OpenAI Codex (Responses API) provider.
///
/// Cloning yields an independent handle carrying the same credentials,
/// HTTP client and base URL.
#[derive(Clone)]
pub struct CodexProvider {
    /// Token sent as `authorization: Bearer <token>` on every request.
    access_token: String,
    /// Value sent in the `chatgpt-account-id` header.
    account_id: String,
    /// HTTP client used to POST requests to the endpoint.
    client: reqwest::Client,
    /// URL prefix the responses path is appended to.
    base_url: String,
}
48
49impl CodexProvider {
50    /// Create a new Codex provider with a JWT access token.
51    ///
52    /// The account ID is automatically extracted from the JWT payload.
53    /// Returns an error if the token is not a valid Codex JWT.
54    pub fn new(access_token: impl Into<String>) -> Result<Self, ProviderError> {
55        let token = access_token.into();
56        let account_id = auth::extract_account_id(&token).ok_or_else(|| {
57            ProviderError::AuthFailed("failed to extract account ID from JWT".into())
58        })?;
59        Ok(Self {
60            access_token: token,
61            account_id,
62            client: reqwest::Client::new(),
63            base_url: DEFAULT_BASE_URL.into(),
64        })
65    }
66
67    /// Create a provider with explicit token and account ID.
68    ///
69    /// Use this when you have the account ID from another source
70    /// (e.g., stored separately from the JWT).
71    pub fn with_account_id(access_token: impl Into<String>, account_id: impl Into<String>) -> Self {
72        Self {
73            access_token: access_token.into(),
74            account_id: account_id.into(),
75            client: reqwest::Client::new(),
76            base_url: DEFAULT_BASE_URL.into(),
77        }
78    }
79
80    /// Override the base URL (for testing or custom endpoints).
81    ///
82    /// Default: `https://chatgpt.com/backend-api`
83    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
84        self.base_url = url.into();
85        self
86    }
87
88    /// Build the full URL for the Responses API endpoint.
89    fn endpoint_url(&self) -> String {
90        let base = self.base_url.trim_end_matches('/');
91        format!("{base}{CODEX_RESPONSES_PATH}")
92    }
93
94    /// Build request headers for Codex API calls.
95    fn build_headers(&self) -> reqwest::header::HeaderMap {
96        let mut headers = reqwest::header::HeaderMap::new();
97        headers.insert(
98            "authorization",
99            format!("Bearer {}", self.access_token)
100                .parse()
101                .expect("valid header"),
102        );
103        headers.insert(
104            "chatgpt-account-id",
105            self.account_id.parse().expect("valid header"),
106        );
107        headers.insert(
108            "openai-beta",
109            "responses=experimental".parse().expect("valid header"),
110        );
111        headers.insert("originator", "pi".parse().expect("valid header"));
112        headers.insert(
113            "content-type",
114            "application/json".parse().expect("valid header"),
115        );
116        headers
117    }
118
119    /// Build a [`CodexRequest`] from an [`InferRequest`].
120    fn build_codex_request(&self, request: &InferRequest) -> CodexRequest {
121        let model = request.model.clone().unwrap_or_else(|| "gpt-5".into());
122
123        let input = messages_to_input(&request.messages);
124        let tools = tools_to_codex(&request.tools);
125
126        CodexRequest {
127            model,
128            input,
129            stream: true,
130            instructions: request.system.clone(),
131            tools,
132            tool_choice: None,
133            temperature: request.temperature,
134            max_output_tokens: request.max_tokens,
135            reasoning: None,
136            prompt_cache_key: None,
137            store: Some(false),
138        }
139    }
140
141    /// Build a [`CodexRequest`] from a [`StreamRequest`].
142    fn build_codex_stream_request(&self, request: &StreamRequest) -> CodexRequest {
143        let infer = InferRequest {
144            model: request.model.clone(),
145            messages: request.messages.clone(),
146            tools: request.tools.clone(),
147            max_tokens: request.max_tokens,
148            temperature: request.temperature,
149            system: request.system.clone(),
150            extra: request.extra.clone(),
151        };
152        self.build_codex_request(&infer)
153    }
154
155    /// Send request and process SSE stream, emitting events via callback.
156    async fn stream_sse(
157        &self,
158        codex_request: CodexRequest,
159        on_event: &(dyn Fn(StreamEvent) + Send + Sync),
160    ) -> Result<InferResponse, ProviderError> {
161        let url = self.endpoint_url();
162        let headers = self.build_headers();
163
164        let http_response = self
165            .client
166            .post(&url)
167            .headers(headers)
168            .json(&codex_request)
169            .send()
170            .await
171            .map_err(|e| ProviderError::TransientError {
172                message: e.to_string(),
173                status: None,
174            })?;
175
176        let status = http_response.status();
177        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
178            return Err(ProviderError::RateLimited);
179        }
180        if status == reqwest::StatusCode::UNAUTHORIZED || status == reqwest::StatusCode::FORBIDDEN {
181            let body = http_response.text().await.unwrap_or_default();
182            return Err(ProviderError::AuthFailed(body));
183        }
184        if !status.is_success() {
185            let body = http_response.text().await.unwrap_or_default();
186            return Err(map_error_response(status, &body));
187        }
188
189        // Process SSE stream
190        let mut stream = http_response.bytes_stream();
191        let mut buf = String::new();
192
193        // Accumulation state
194        let mut model_name = codex_request.model.clone();
195        let mut usage = ResponseUsage::default();
196        let mut stop_reason = StopReason::EndTurn;
197        let mut text_blocks: Vec<String> = Vec::new();
198        let mut tool_calls: Vec<ToolCall> = Vec::new();
199
200        // Per-block state (indexed by output item position)
201        let mut current_text = String::new();
202        let mut current_tool_call_id = String::new();
203        let mut current_tool_item_id = String::new();
204        let mut current_tool_name = String::new();
205        let mut current_tool_args = String::new();
206        let mut tool_call_index: usize = 0;
207
208        while let Some(chunk) = stream.next().await {
209            let bytes = chunk.map_err(|e| ProviderError::TransientError {
210                message: format!("stream read error: {e}"),
211                status: None,
212            })?;
213            buf.push_str(&String::from_utf8_lossy(&bytes));
214
215            // Process complete SSE frames
216            while let Some(frame_end) = buf.find("\n\n") {
217                let frame = buf[..frame_end].to_string();
218                buf = buf[frame_end + 2..].to_string();
219
220                // Extract data from SSE frame
221                let mut data = String::new();
222                for line in frame.lines() {
223                    if let Some(rest) = line.strip_prefix("data: ") {
224                        if !data.is_empty() {
225                            data.push('\n');
226                        }
227                        data.push_str(rest);
228                    }
229                }
230
231                if data.is_empty() {
232                    continue;
233                }
234
235                let event: SseEvent = match serde_json::from_str(&data) {
236                    Ok(ev) => ev,
237                    Err(e) => {
238                        tracing::warn!(error = %e, "failed to parse Codex SSE event");
239                        continue;
240                    }
241                };
242
243                match event.event_type.as_str() {
244                    "response.output_item.added" => {
245                        if let Some(item) = event.data.get("item") {
246                            let item_type = item.get("type").and_then(|v| v.as_str()).unwrap_or("");
247                            match item_type {
248                                "message" => {
249                                    current_text = String::new();
250                                }
251                                "function_call" => {
252                                    let call_id = item
253                                        .get("call_id")
254                                        .and_then(|v| v.as_str())
255                                        .unwrap_or("")
256                                        .to_string();
257                                    let item_id = item
258                                        .get("id")
259                                        .and_then(|v| v.as_str())
260                                        .unwrap_or("")
261                                        .to_string();
262                                    let name = item
263                                        .get("name")
264                                        .and_then(|v| v.as_str())
265                                        .unwrap_or("")
266                                        .to_string();
267
268                                    current_tool_call_id = call_id.clone();
269                                    current_tool_item_id = item_id;
270                                    current_tool_name = name.clone();
271                                    current_tool_args = String::new();
272
273                                    // Compose the skelegent tool ID as "call_id|item_id"
274                                    let skg_id = format!("{call_id}|{}", current_tool_item_id);
275                                    on_event(StreamEvent::ToolCallStart {
276                                        index: tool_call_index,
277                                        id: skg_id,
278                                        name,
279                                    });
280                                }
281                                _ => {}
282                            }
283                        }
284                    }
285                    "response.output_text.delta" => {
286                        if let Some(delta) = event.data.get("delta").and_then(|v| v.as_str()) {
287                            current_text.push_str(delta);
288                            on_event(StreamEvent::TextDelta(delta.to_string()));
289                        }
290                    }
291                    "response.function_call_arguments.delta" => {
292                        if let Some(delta) = event.data.get("delta").and_then(|v| v.as_str()) {
293                            current_tool_args.push_str(delta);
294                            on_event(StreamEvent::ToolCallDelta {
295                                index: tool_call_index,
296                                json_delta: delta.to_string(),
297                            });
298                        }
299                    }
300                    "response.output_item.done" => {
301                        if let Some(item) = event.data.get("item") {
302                            let item_type = item.get("type").and_then(|v| v.as_str()).unwrap_or("");
303
304                            match item_type {
305                                "message" => {
306                                    // Finalize text from the item itself
307                                    let final_text = extract_output_text(item);
308                                    if !final_text.is_empty() {
309                                        current_text = final_text;
310                                    }
311                                    if !current_text.is_empty() {
312                                        text_blocks.push(current_text.clone());
313                                    }
314                                    current_text = String::new();
315                                }
316                                "function_call" => {
317                                    // Finalize tool call
318                                    let final_args = item
319                                        .get("arguments")
320                                        .and_then(|v| v.as_str())
321                                        .unwrap_or(&current_tool_args);
322                                    let input: serde_json::Value = serde_json::from_str(final_args)
323                                        .unwrap_or(serde_json::Value::Object(
324                                            serde_json::Map::new(),
325                                        ));
326                                    tool_calls.push(ToolCall {
327                                        id: format!(
328                                            "{}|{}",
329                                            current_tool_call_id, current_tool_item_id
330                                        ),
331                                        name: current_tool_name.clone(),
332                                        input,
333                                    });
334                                    tool_call_index += 1;
335                                    current_tool_args = String::new();
336                                }
337                                _ => {}
338                            }
339                        }
340                    }
341                    "response.completed" | "response.done" => {
342                        if let Some(response) = event.data.get("response") {
343                            if let Some(u) = response.get("usage") {
344                                usage = ResponseUsage::from_value(u);
345                                on_event(StreamEvent::Usage(TokenUsage {
346                                    input_tokens: usage.input_tokens,
347                                    output_tokens: usage.output_tokens,
348                                    cache_read_tokens: if usage.cached_tokens > 0 {
349                                        Some(usage.cached_tokens)
350                                    } else {
351                                        None
352                                    },
353                                    cache_creation_tokens: None,
354                                }));
355                            }
356                            if let Some(status) = response.get("status").and_then(|v| v.as_str()) {
357                                stop_reason = match status {
358                                    "completed" => StopReason::EndTurn,
359                                    "incomplete" => StopReason::MaxTokens,
360                                    "failed" | "cancelled" => StopReason::EndTurn,
361                                    _ => StopReason::EndTurn,
362                                };
363                            }
364                            if let Some(m) = response.get("model").and_then(|v| v.as_str()) {
365                                model_name = m.to_string();
366                            }
367                        }
368                    }
369                    "error" | "response.failed" => {
370                        let msg = event
371                            .data
372                            .get("message")
373                            .and_then(|v| v.as_str())
374                            .or_else(|| {
375                                event
376                                    .data
377                                    .get("error")
378                                    .and_then(|e| e.get("message"))
379                                    .and_then(|v| v.as_str())
380                            })
381                            .unwrap_or("Codex stream error");
382                        return Err(ProviderError::TransientError {
383                            message: msg.to_string(),
384                            status: None,
385                        });
386                    }
387                    _ => {
388                        // Ignore: response.created, ping, reasoning events, etc.
389                    }
390                }
391            }
392        }
393
394        // If tool calls present but stop reason is EndTurn, fix it.
395        if !tool_calls.is_empty() && stop_reason == StopReason::EndTurn {
396            stop_reason = StopReason::ToolUse;
397        }
398
399        // Build final content.
400        let content = if text_blocks.len() == 1 {
401            Content::Text(text_blocks.into_iter().next().unwrap())
402        } else if text_blocks.is_empty() {
403            Content::text("")
404        } else {
405            Content::Blocks(
406                text_blocks
407                    .into_iter()
408                    .map(|t| ContentBlock::Text { text: t })
409                    .collect(),
410            )
411        };
412
413        // Codex is free (included in subscription), cost is zero.
414        let token_usage = TokenUsage {
415            input_tokens: usage.input_tokens,
416            output_tokens: usage.output_tokens,
417            cache_read_tokens: if usage.cached_tokens > 0 {
418                Some(usage.cached_tokens)
419            } else {
420                None
421            },
422            cache_creation_tokens: None,
423        };
424
425        let response = InferResponse {
426            content,
427            tool_calls,
428            stop_reason,
429            usage: token_usage,
430            model: model_name,
431            cost: Some(Decimal::ZERO),
432            truncated: None,
433        };
434
435        on_event(StreamEvent::Done(response.clone()));
436
437        tracing::info!(
438            input_tokens = usage.input_tokens,
439            output_tokens = usage.output_tokens,
440            "codex streaming inference finished"
441        );
442
443        Ok(response)
444    }
445}
446
447impl Provider for CodexProvider {
448    fn infer(
449        &self,
450        request: InferRequest,
451    ) -> impl std::future::Future<Output = Result<InferResponse, ProviderError>> + Send {
452        let codex_request = self.build_codex_request(&request);
453        let this = self.clone();
454        let model = request.model.as_deref().unwrap_or("unknown");
455        let span = tracing::info_span!("provider.infer", provider = "codex", model);
456
457        async move {
458            // Non-streaming: use stream_sse with a no-op callback, then return the response.
459            this.stream_sse(codex_request, &|_| {}).await
460        }
461        .instrument(span)
462    }
463}
464
465impl StreamProvider for CodexProvider {
466    fn infer_stream(
467        &self,
468        request: StreamRequest,
469        on_event: impl Fn(StreamEvent) + Send + Sync + 'static,
470    ) -> impl std::future::Future<Output = Result<InferResponse, ProviderError>> + Send {
471        let codex_request = self.build_codex_stream_request(&request);
472        let this = self.clone();
473        let model = request.model.as_deref().unwrap_or("unknown");
474        let span = tracing::info_span!("provider.infer_stream", provider = "codex", model);
475
476        async move { this.stream_sse(codex_request, &on_event).await }.instrument(span)
477    }
478}
479
480/// Extract combined text from a Responses API output message item.
481fn extract_output_text(item: &serde_json::Value) -> String {
482    item.get("content")
483        .and_then(|c| c.as_array())
484        .map(|parts| {
485            parts
486                .iter()
487                .filter_map(|p| {
488                    let ptype = p.get("type").and_then(|v| v.as_str()).unwrap_or("");
489                    match ptype {
490                        "output_text" => p.get("text").and_then(|v| v.as_str()),
491                        "refusal" => p.get("refusal").and_then(|v| v.as_str()),
492                        _ => None,
493                    }
494                })
495                .collect::<Vec<_>>()
496                .join("")
497        })
498        .unwrap_or_default()
499}
500
501/// Map a non-success HTTP response to a [`ProviderError`].
502fn map_error_response(status: reqwest::StatusCode, body: &str) -> ProviderError {
503    let status_u16 = status.as_u16();
504
505    // Check for rate limit / usage limit signals.
506    if body.contains("usage_limit_reached")
507        || body.contains("usage_not_included")
508        || body.contains("rate_limit_exceeded")
509    {
510        return ProviderError::RateLimited;
511    }
512
513    if body.contains("content_filter") || body.contains("content policy") {
514        return ProviderError::ContentBlocked {
515            message: body.to_string(),
516        };
517    }
518
519    ProviderError::TransientError {
520        message: format!("HTTP {status}: {body}"),
521        status: Some(status_u16),
522    }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider with placeholder credentials; no JWT parsing is involved.
    fn fixture() -> CodexProvider {
        CodexProvider::with_account_id("tok", "acct")
    }

    /// The default base URL yields the Codex responses endpoint.
    #[test]
    fn endpoint_url_default() {
        let url = fixture().endpoint_url();
        assert_eq!(url, "https://chatgpt.com/backend-api/codex/responses");
    }

    /// A trailing slash on a custom base URL is not duplicated.
    #[test]
    fn endpoint_url_custom() {
        let url = fixture()
            .with_base_url("http://localhost:8080/")
            .endpoint_url();
        assert_eq!(url, "http://localhost:8080/codex/responses");
    }

    /// A rate-limit code in the body wins over the HTTP status.
    #[test]
    fn error_mapping_rate_limit() {
        let body = r#"{"error":{"code":"rate_limit_exceeded"}}"#;
        let mapped = map_error_response(reqwest::StatusCode::BAD_REQUEST, body);
        assert!(matches!(mapped, ProviderError::RateLimited));
    }

    /// A content-filter marker in the body maps to `ContentBlocked`.
    #[test]
    fn error_mapping_content_filter() {
        let mapped =
            map_error_response(reqwest::StatusCode::BAD_REQUEST, "content_filter triggered");
        assert!(matches!(mapped, ProviderError::ContentBlocked { .. }));
    }

    /// Consecutive `output_text` parts are concatenated without a separator.
    #[test]
    fn extract_output_text_basic() {
        let message = serde_json::json!({
            "type": "message",
            "content": [
                {"type": "output_text", "text": "Hello "},
                {"type": "output_text", "text": "world"}
            ]
        });
        assert_eq!(extract_output_text(&message), "Hello world");
    }

    /// Requests are always streamed and never stored.
    #[test]
    fn build_request_sets_stream_true() {
        let built = fixture().build_codex_request(&InferRequest::new(vec![]));
        assert!(built.stream);
        assert_eq!(built.store, Some(false));
    }
}