use async_trait::async_trait;
use axum::http::{HeaderMap, HeaderValue, header};
use serde_json::Value;
use std::sync::Arc;
use crate::model_api_provider::provider::{
ModelApiProvider, ProviderRequest, ProviderResponse, proxy_request,
};
use crate::serve_config::{ConfigError, ProviderConfig};
/// Provider client that sends requests to an upstream `/responses` endpoint.
///
/// Every field except `model_id` is handed to `proxy_request` each time
/// `execute` is called; `model_id` is what `ModelApiProvider::model_id` returns.
#[derive(Clone)]
pub struct ResponsesClient {
/// HTTP client passed by reference to `proxy_request`.
client: reqwest::Client,
/// Base URL passed to `proxy_request`; presumably the `/responses` endpoint
/// path is resolved against it there (not visible in this file).
base_url: String,
/// Headers cloned for each request and passed to `proxy_request`.
/// `build_client` populates this with the `Authorization: Bearer ...` header.
auth_headers: HeaderMap,
/// Identifier reported through `ModelApiProvider::model_id`.
model_id: String,
/// Model name passed to `proxy_request` as `Some(..)` on every request.
upstream_model: String,
}
impl ResponsesClient {
/// Creates a client from already-prepared parts.
///
/// Each argument is stored unchanged in the field of the same name; no
/// validation or normalisation is performed here.
pub fn new(
client: reqwest::Client,
base_url: String,
auth_headers: HeaderMap,
model_id: String,
upstream_model: String,
) -> Self {
Self {
client,
base_url,
auth_headers,
model_id,
upstream_model,
}
}
}
#[async_trait]
impl ModelApiProvider for ResponsesClient {
    /// Returns the model identifier this client was constructed with.
    fn model_id(&self) -> &str {
        self.model_id.as_str()
    }

    /// Sends `req` upstream via `proxy_request`.
    ///
    /// The request's `endpoint_path` is overwritten with `/responses`
    /// regardless of what the caller set, and the configured upstream model
    /// name is passed along with a fresh copy of the auth headers.
    ///
    /// # Errors
    /// Returns whatever error `proxy_request` reports.
    async fn execute(&self, mut req: ProviderRequest) -> Result<ProviderResponse, anyhow::Error> {
        req.endpoint_path = String::from("/responses");

        let headers = self.auth_headers.clone();
        let upstream_model = Some(self.upstream_model.as_str());

        proxy_request(&self.client, &self.base_url, headers, upstream_model, req).await
    }

    /// Looks up the response id in `payload_json`; see `extract_request_id_json`.
    fn extract_request_id(&self, payload_json: &Value) -> Option<String> {
        extract_request_id_json(payload_json)
    }

    /// Looks up the usage object in `payload_json`; see `extract_usage_json`.
    fn extract_usage(&self, payload_json: &Value) -> Option<Value> {
        extract_usage_json(payload_json)
    }

    /// Writes `usage` into `payload_json`; see `inject_usage_json`.
    /// Returns `true` when a location for it was found.
    fn inject_usage(&self, payload_json: &mut Value, usage: Value) -> bool {
        inject_usage_json(payload_json, usage)
    }
}
/// Returns the string `id` of a payload.
///
/// The top-level `id` wins when it is present and a string; otherwise the
/// `id` nested under `response` is used. Yields `None` when neither location
/// holds a string.
fn extract_request_id_json(payload_json: &Value) -> Option<String> {
    // Reads a string-valued "id" member from a single JSON node.
    let id_of = |node: &Value| node.get("id").and_then(Value::as_str).map(str::to_string);

    match id_of(payload_json) {
        Some(top_level) => Some(top_level),
        None => payload_json.get("response").and_then(id_of),
    }
}
/// Returns a copy of the payload's `usage` value.
///
/// A non-null top-level `usage` takes precedence; a missing or `null` one
/// falls back to `response.usage`. Yields `None` when both are missing or
/// `null`.
fn extract_usage_json(payload_json: &Value) -> Option<Value> {
    if let Some(top_level) = non_null_usage(payload_json.get("usage")) {
        return Some(top_level);
    }

    let nested = payload_json
        .get("response")
        .and_then(|response| response.get("usage"));
    non_null_usage(nested)
}
/// Stores `usage` in the payload and reports whether a place for it existed.
///
/// Target selection:
/// 1. If the payload object already has a top-level `usage` key (even a
///    `null` one), that entry is overwritten.
/// 2. Otherwise, if `response` is an object, `usage` is inserted into it
///    (replacing any existing nested value).
///
/// Returns `false`, leaving the payload untouched, when the payload is not an
/// object or neither target is available.
fn inject_usage_json(payload_json: &mut Value, usage: Value) -> bool {
    let root = match payload_json.as_object_mut() {
        Some(map) => map,
        None => return false,
    };

    // Choose the single map that will receive the value, then insert once.
    let target = if root.contains_key("usage") {
        root
    } else {
        match root.get_mut("response").and_then(Value::as_object_mut) {
            Some(nested) => nested,
            None => return false,
        }
    };

    target.insert("usage".to_string(), usage);
    true
}
fn non_null_usage(value: Option<&Value>) -> Option<Value> {
value.filter(|usage| !usage.is_null()).cloned()
}
/// Unit tests for the JSON helpers that locate and rewrite `id` / `usage`
/// in both payload shapes: fields at the top level, or nested under `response`.
#[cfg(test)]
mod tests {
    use super::{extract_request_id_json, extract_usage_json, inject_usage_json};
    use serde_json::json;

    #[test]
    fn extract_request_id_json_prefers_top_level_id() {
        let payload = json!({"id": "resp_top", "response": {"id": "resp_nested"}});
        let request_id = extract_request_id_json(&payload);
        assert_eq!(request_id.as_deref(), Some("resp_top"));
    }

    #[test]
    fn extract_request_id_json_supports_nested_response_id() {
        let payload = json!({"response": {"id": "resp_nested"}});
        let request_id = extract_request_id_json(&payload);
        assert_eq!(request_id.as_deref(), Some("resp_nested"));
    }

    #[test]
    fn extract_request_id_json_returns_none_without_any_id() {
        let payload = json!({"response": {"status": "in_progress"}});
        assert_eq!(extract_request_id_json(&payload), None);
    }

    #[test]
    fn extract_usage_json_reads_nested_usage() {
        let payload = json!({"response": {"usage": {"total_tokens": 42}}});
        let usage = extract_usage_json(&payload);
        assert_eq!(usage, Some(json!({"total_tokens": 42})));
    }

    #[test]
    fn extract_usage_json_ignores_null_usage() {
        let payload = json!({"usage": null});
        let usage = extract_usage_json(&payload);
        assert_eq!(usage, None);
    }

    #[test]
    fn extract_usage_json_reads_nested_response_usage() {
        let payload =
            json!({"type": "response.completed", "response": {"usage": {"input_tokens": 8}}});
        let usage = extract_usage_json(&payload);
        assert_eq!(usage, Some(json!({"input_tokens": 8})));
    }

    #[test]
    fn extract_usage_json_falls_back_to_nested_when_top_level_is_null() {
        let payload = json!({"usage": null, "response": {"usage": {"output_tokens": 3}}});
        let usage = extract_usage_json(&payload);
        assert_eq!(usage, Some(json!({"output_tokens": 3})));
    }

    #[test]
    fn inject_usage_json_updates_top_level_usage() {
        let mut payload = json!({"usage": {"total_tokens": 1}});
        let new_usage = json!({"total_tokens": 99});
        let injected = inject_usage_json(&mut payload, new_usage.clone());
        assert!(injected);
        assert_eq!(payload.get("usage"), Some(&new_usage));
    }

    #[test]
    fn inject_usage_json_inserts_into_nested_response() {
        let mut payload = json!({"type": "response.completed", "response": {"id": "resp_1"}});
        let new_usage = json!({"total_tokens": 7});
        let injected = inject_usage_json(&mut payload, new_usage.clone());
        assert!(injected);
        assert_eq!(payload.pointer("/response/usage"), Some(&new_usage));
        // The top level must stay untouched when the nested target is used.
        assert_eq!(payload.get("usage"), None);
    }

    #[test]
    fn inject_usage_json_rejects_payload_without_target() {
        // Not an object at all.
        let mut not_an_object = json!(["usage"]);
        assert!(!inject_usage_json(
            &mut not_an_object,
            json!({"total_tokens": 1})
        ));
        assert_eq!(not_an_object, json!(["usage"]));

        // Object with neither a `usage` key nor an object-valued `response`.
        let mut no_target = json!({"id": "resp_1", "response": "not-an-object"});
        assert!(!inject_usage_json(&mut no_target, json!({"total_tokens": 1})));
        assert_eq!(
            no_target,
            json!({"id": "resp_1", "response": "not-an-object"})
        );
    }
}
/// Builds a [`ResponsesClient`] from a provider configuration entry.
///
/// Reads three required entries from `provider.params`:
/// - `api_key`: sent as `Authorization: Bearer <api_key>`
/// - `base_url`: stored as the client's base URL
/// - `model`: stored as the upstream model name
///
/// # Errors
/// - [`ConfigError::UnknownProviderType`] when `provider.provider_type` is
///   not `"responses"`.
/// - [`ConfigError::InvalidProvider`] when a required parameter is missing,
///   when the API key cannot be encoded as an HTTP header value, or when the
///   HTTP client cannot be constructed.
pub fn build_client(provider: &ProviderConfig) -> Result<Arc<dyn ModelApiProvider>, ConfigError> {
    if provider.provider_type != "responses" {
        return Err(ConfigError::UnknownProviderType(
            provider.provider_type.clone(),
        ));
    }

    let api_key = provider
        .params
        .get("api_key")
        .ok_or_else(|| ConfigError::InvalidProvider("api_key is required".to_string()))?;
    let base_url = provider
        .params
        .get("base_url")
        .cloned()
        .ok_or_else(|| ConfigError::InvalidProvider("base_url is required".to_string()))?;
    let upstream_model = provider
        .params
        .get("model")
        .cloned()
        .ok_or_else(|| ConfigError::InvalidProvider("model is required".to_string()))?;

    let mut auth_value = format!("Bearer {api_key}")
        .parse::<HeaderValue>()
        .map_err(|err| {
            ConfigError::InvalidProvider(format!("api_key is not a valid header value: {err}"))
        })?;
    // Keep the credential out of `Debug` output of the header map.
    auth_value.set_sensitive(true);

    let mut headers = HeaderMap::new();
    headers.insert(header::AUTHORIZATION, auth_value);

    // `reqwest::Client::new()` panics when the client cannot be initialised;
    // the builder surfaces that as an error so a bad environment is reported
    // as a configuration failure instead of aborting the process.
    let http_client = reqwest::Client::builder().build().map_err(|err| {
        ConfigError::InvalidProvider(format!("failed to build HTTP client: {err}"))
    })?;

    Ok(Arc::new(ResponsesClient::new(
        http_client,
        base_url,
        headers,
        provider.model_id.clone(),
        upstream_model,
    )))
}