use crate::{
profile::{ApiFamily, ProviderProfile, RuntimeConfig},
ProviderError, ProviderResult,
};
use reqwest::header::HeaderMap;
/// Provider-specific request overrides, one variant per [`ApiFamily`].
///
/// Built by `resolve_overrides` from a profile and runtime config, and turned
/// into extra HTTP headers by `override_headers`.
#[derive(Debug, Clone)]
pub(crate) enum ProviderOverrides {
    /// Overrides for the Messages API family.
    Messages(MessagesOverrides),
    /// Overrides for the Completions API family (currently carries nothing).
    Completions(CompletionsOverrides),
    /// Overrides for the Responses API family.
    Responses(ResponsesOverrides),
}
/// Overrides applied to Messages-family requests.
#[derive(Debug, Clone, Default)]
pub(crate) struct MessagesOverrides {
    /// Value sent as the `anthropic-version` header; `None` sends no such header.
    pub anthropic_version: Option<String>,
}
/// Overrides for Completions-family requests.
///
/// Currently has no fields; `override_headers` emits no headers for it.
#[derive(Debug, Clone, Default)]
pub(crate) struct CompletionsOverrides {}
/// Overrides applied to Responses-family requests.
///
/// Every field defaults to `None`; `resolve_overrides` only populates them for
/// the `codex` profile.
#[derive(Debug, Clone, Default)]
pub(crate) struct ResponsesOverrides {
    /// Value sent as the `originator` header.
    pub originator: Option<String>,
    /// Value sent as the `User-Agent` header.
    pub user_agent: Option<String>,
    /// Value sent as the `chatgpt-account-id` header, decoded from the OAuth
    /// token claims.
    pub chatgpt_account_id: Option<String>,
    /// Fixed request-body settings.
    /// NOTE(review): the consumer is not in this file — presumably merged into
    /// the request body by the Responses client; confirm at the use site.
    pub fixed_body: Option<CodexFixedBody>,
    /// Endpoint path override (set to `/responses` for codex).
    /// NOTE(review): how this is combined with `base_url` is not visible here.
    pub endpoint_path: Option<String>,
}
/// Request-body settings pinned for the `codex` profile.
#[derive(Debug, Clone)]
pub(crate) struct CodexFixedBody {
    /// The body's `store` flag (defaults to `false`).
    pub store: bool,
    /// Reasoning effort level (defaults to `"medium"`).
    pub reasoning_effort: String,
    /// Whether parallel tool calls are allowed (defaults to `true`).
    pub parallel_tool_calls: bool,
}

impl Default for CodexFixedBody {
    /// The values `resolve_overrides` attaches for the codex profile:
    /// `store: false`, `reasoning_effort: "medium"`, `parallel_tool_calls: true`.
    fn default() -> Self {
        let reasoning_effort = String::from("medium");
        Self {
            parallel_tool_calls: true,
            reasoning_effort,
            store: false,
        }
    }
}
pub(crate) fn resolve_overrides(
profile: &ProviderProfile,
_runtime: &RuntimeConfig,
) -> ProviderOverrides {
match profile.family {
ApiFamily::Messages => {
let mut overrides = MessagesOverrides::default();
if profile.slug == "anthropic" || profile.base_url.contains("anthropic") {
overrides.anthropic_version = Some("2023-06-01".into());
}
ProviderOverrides::Messages(overrides)
}
ApiFamily::Completions => ProviderOverrides::Completions(CompletionsOverrides::default()),
ApiFamily::Responses => {
let mut overrides = ResponsesOverrides::default();
if profile.slug == "codex" {
overrides.originator = Some("iron-providers".into());
overrides.user_agent =
Some(format!("iron-providers/{}", env!("CARGO_PKG_VERSION")));
overrides.fixed_body = Some(CodexFixedBody::default());
overrides.endpoint_path = Some("/responses".into());
if let crate::profile::ProviderCredential::OAuthBearer {
access_token,
id_token,
..
} = &_runtime.credential
{
let token = id_token.as_deref().unwrap_or(access_token);
if let Some(account_id) = chatgpt_account_id_from_jwt(token) {
overrides.chatgpt_account_id = Some(account_id);
}
}
}
ProviderOverrides::Responses(overrides)
}
}
}
pub(crate) fn chatgpt_account_id_from_jwt(id_token: &str) -> Option<String> {
let payload_b64 = id_token.split('.').nth(1)?;
let payload_json = base64_decode_url_safe(payload_b64).ok()?;
let payload: serde_json::Value = serde_json::from_slice(&payload_json).ok()?;
payload
.get("chatgpt_account_id")
.and_then(|v| v.as_str().map(String::from))
.or_else(|| {
payload
.get("https://api.openai.com/auth.chatgpt_account_id")
.and_then(|v| v.as_str().map(String::from))
})
.or_else(|| {
payload
.get("https://api.openai.com/auth")
.and_then(|nested| nested.get("chatgpt_account_id"))
.and_then(|v| v.as_str().map(String::from))
})
.or_else(|| {
payload
.get("organizations")
.and_then(|orgs| orgs.as_array())
.and_then(|orgs| orgs.first())
.and_then(|first| first.get("id"))
.and_then(|v| v.as_str().map(String::from))
})
}
/// Decodes a base64url (RFC 4648 §5) segment such as a JWT payload.
///
/// JWT segments are specified without padding, but trailing `=` is trimmed
/// before decoding so that padded input from a non-conforming issuer still
/// decodes. `URL_SAFE_NO_PAD` on its own rejects padded input.
///
/// # Errors
///
/// Returns the underlying [`base64::DecodeError`] for invalid characters or
/// an invalid length.
fn base64_decode_url_safe(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
    use base64::{engine::general_purpose, Engine as _};
    general_purpose::URL_SAFE_NO_PAD.decode(input.trim_end_matches('='))
}
/// Converts resolved [`ProviderOverrides`] into the extra HTTP headers to send.
///
/// Only fields that are `Some` produce a header, and `Completions` overrides
/// produce none. Headers are checked in declaration order, so the first
/// invalid value aborts the conversion.
///
/// # Errors
///
/// Returns `ProviderError::invalid_request` when an override value is not a
/// valid HTTP header value (for example, it contains a newline).
pub(crate) fn override_headers(overrides: &ProviderOverrides) -> ProviderResult<HeaderMap> {
    use reqwest::header::{HeaderName, USER_AGENT};

    let mut headers = HeaderMap::new();
    match overrides {
        ProviderOverrides::Messages(msg) => {
            insert_override_header(
                &mut headers,
                HeaderName::from_static("anthropic-version"),
                msg.anthropic_version.as_deref(),
            )?;
        }
        ProviderOverrides::Responses(resp) => {
            insert_override_header(
                &mut headers,
                HeaderName::from_static("originator"),
                resp.originator.as_deref(),
            )?;
            insert_override_header(&mut headers, USER_AGENT, resp.user_agent.as_deref())?;
            insert_override_header(
                &mut headers,
                HeaderName::from_static("chatgpt-account-id"),
                resp.chatgpt_account_id.as_deref(),
            )?;
        }
        // Matched explicitly rather than with `_` so that adding a new variant
        // forces a decision here at compile time.
        ProviderOverrides::Completions(_) => {}
    }
    Ok(headers)
}

/// Inserts `name: value` into `headers` when `value` is `Some`; otherwise does nothing.
///
/// An invalid value becomes a `ProviderError::invalid_request` whose message
/// names the header.
fn insert_override_header(
    headers: &mut HeaderMap,
    name: reqwest::header::HeaderName,
    value: Option<&str>,
) -> ProviderResult<()> {
    if let Some(value) = value {
        let val = reqwest::header::HeaderValue::from_str(value).map_err(|e| {
            ProviderError::invalid_request(format!(
                "Invalid override header value for '{}': {}",
                name, e
            ))
        })?;
        headers.insert(name, val);
    }
    Ok(())
}