use async_trait::async_trait;
use futures_core::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::json;
use std::pin::Pin;
use crate::openai_types::ChatCompletionRequest;
/// Running totals of token usage that [`UsageAccumulator`] implementations add into.
///
/// `Default` yields zero for the mandatory counters and `None` for the optional ones.
#[derive(Debug, Clone, Default)]
pub struct UsageSummary {
/// Accumulated input (prompt) tokens.
pub input_tokens: i64,
/// Accumulated output (completion) tokens.
pub output_tokens: i64,
/// Accumulated total as reported by each source; it is summed independently and is
/// not recomputed from `input_tokens + output_tokens`.
pub total_tokens: i64,
/// Accumulated cached tokens; stays `None` until some source reports the detail.
pub cached_tokens: Option<i64>,
/// Accumulated reasoning tokens; stays `None` until some source reports the detail.
pub reasoning_tokens: Option<i64>,
}
/// Crate-internal usage record keyed by `input_tokens` / `output_tokens`.
///
/// This is the type returned by `parse_usage_payload`. When serialized, the two
/// optional counters are omitted entirely if they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct InputOutputUsage {
/// Input (prompt) token count.
pub input_tokens: i64,
/// Output (completion) token count.
pub output_tokens: i64,
/// Total token count.
pub total_tokens: i64,
/// Cached token count, if known; skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub cached_tokens: Option<i64>,
/// Reasoning token count, if known; skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_tokens: Option<i64>,
}
/// Implemented by usage records that can be added into a running [`UsageSummary`].
pub trait UsageAccumulator {
/// Adds this record's counters into `total`, mutating it in place.
fn accumulate_into(&self, total: &mut UsageSummary);
}
/// Implemented by usage records that can be expressed as `crate::openai_types::Usage`.
pub trait ToOpenAIUsage {
/// Builds a new `crate::openai_types::Usage` value from `self`; `self` is only borrowed.
fn to_openai_usage(&self) -> crate::openai_types::Usage;
}
impl UsageAccumulator for crate::openai_types::Usage {
    /// Adds this OpenAI-shaped usage record into `total`.
    ///
    /// `prompt_tokens` feeds `input_tokens` and `completion_tokens` feeds `output_tokens`.
    /// The optional cached/reasoning counters in `total` are only touched when the
    /// corresponding details object is present on `self`; otherwise they keep their
    /// current value (including `None`).
    fn accumulate_into(&self, total: &mut UsageSummary) {
        total.input_tokens += i64::from(self.prompt_tokens);
        total.output_tokens += i64::from(self.completion_tokens);
        total.total_tokens += i64::from(self.total_tokens);

        if let Some(prompt_details) = self.prompt_tokens_details.as_ref() {
            let cached = i64::from(prompt_details.cached_tokens);
            // Start the counter at zero the first time a details object shows up.
            *total.cached_tokens.get_or_insert(0) += cached;
        }

        if let Some(completion_details) = self.completion_tokens_details.as_ref() {
            let reasoning = i64::from(completion_details.reasoning_tokens);
            *total.reasoning_tokens.get_or_insert(0) += reasoning;
        }
    }
}
impl ToOpenAIUsage for crate::openai_types::Usage {
    /// The record is already in OpenAI shape, so the conversion is a plain copy.
    fn to_openai_usage(&self) -> crate::openai_types::Usage {
        Clone::clone(self)
    }
}
impl UsageAccumulator for InputOutputUsage {
    /// Adds this record's counters into `total`.
    ///
    /// The three mandatory counters are always added. An optional counter in `total`
    /// is left exactly as it was when this record carries `None` for it, and is
    /// otherwise incremented starting from zero if it had not been set yet.
    fn accumulate_into(&self, total: &mut UsageSummary) {
        total.input_tokens += self.input_tokens;
        total.output_tokens += self.output_tokens;
        total.total_tokens += self.total_tokens;

        total.cached_tokens = match (total.cached_tokens, self.cached_tokens) {
            (running, None) => running,
            (running, Some(extra)) => Some(running.unwrap_or(0) + extra),
        };

        total.reasoning_tokens = match (total.reasoning_tokens, self.reasoning_tokens) {
            (running, None) => running,
            (running, Some(extra)) => Some(running.unwrap_or(0) + extra),
        };
    }
}
impl ToOpenAIUsage for InputOutputUsage {
    /// Converts this record into the OpenAI usage shape.
    ///
    /// Both details objects are always populated: missing cached/reasoning counts are
    /// reported as `0`, and the audio/prediction counters are fixed at `0`.
    ///
    /// The OpenAI fields are `i32` while this record stores `i64`. Counts that do not
    /// fit are pinned to the nearest `i32` bound instead of being wrapped by an `as`
    /// cast, so an oversized count can never surface as an unrelated or negative number.
    fn to_openai_usage(&self) -> crate::openai_types::Usage {
        /// Narrows `count` to `i32`, saturating at `i32::MIN` / `i32::MAX`.
        fn saturate(count: i64) -> i32 {
            // After clamping the value is within the `i32` range, so the cast is lossless.
            count.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
        }

        crate::openai_types::Usage {
            prompt_tokens: saturate(self.input_tokens),
            completion_tokens: saturate(self.output_tokens),
            total_tokens: saturate(self.total_tokens),
            prompt_tokens_details: Some(crate::openai_types::PromptTokensDetails {
                audio_tokens: 0,
                cached_tokens: saturate(self.cached_tokens.unwrap_or(0)),
            }),
            completion_tokens_details: Some(crate::openai_types::CompletionTokensDetails {
                accepted_prediction_tokens: 0,
                audio_tokens: 0,
                reasoning_tokens: saturate(self.reasoning_tokens.unwrap_or(0)),
                rejected_prediction_tokens: 0,
            }),
        }
    }
}
/// Extracts token usage from a JSON usage payload, accepting several key spellings.
///
/// Returns `None` when `payload` is not a JSON object or when no integer input-token
/// count is present under any accepted key. Keys are tried in the order listed and the
/// first one holding an integer wins:
///
/// * input: `input_tokens`, `prompt_tokens`, `prompt_token_count` (required)
/// * output: `output_tokens`, `completion_tokens`, `candidates_token_count`
///   (defaults to `0`)
/// * total: `total_tokens`, `total_token_count` (defaults to input + output)
/// * cached: `cached_tokens`, `prompt_tokens_details.cached_tokens`,
///   `input_tokens_details.cached_tokens`, `cached_content_token_count`
/// * reasoning: `reasoning_tokens`, `completion_tokens_details.reasoning_tokens`,
///   `thoughts_token_count`, `output_tokens_details.reasoning_tokens`
///
/// A key whose value is not representable as `i64` is treated as absent.
pub(crate) fn parse_usage_payload(payload: &Value) -> Option<InputOutputUsage> {
    let obj = payload.as_object()?;

    // Integer stored directly under `key` of the payload object.
    let top_level = |key: &str| obj.get(key).and_then(Value::as_i64);
    // Integer stored under `key` of the object found at `parent`.
    let nested = |parent: &str, key: &str| {
        obj.get(parent)
            .and_then(Value::as_object)
            .and_then(|details| details.get(key))
            .and_then(Value::as_i64)
    };

    let input_tokens = top_level("input_tokens")
        .or_else(|| top_level("prompt_tokens"))
        .or_else(|| top_level("prompt_token_count"))?;

    let output_tokens = top_level("output_tokens")
        .or_else(|| top_level("completion_tokens"))
        .or_else(|| top_level("candidates_token_count"))
        .unwrap_or(0);

    // The fallback sum is computed lazily and saturates: the payload comes from outside
    // this process, so extreme values must not overflow, and nothing should be computed
    // at all when the payload already supplies a total.
    let total_tokens = top_level("total_tokens")
        .or_else(|| top_level("total_token_count"))
        .unwrap_or_else(|| input_tokens.saturating_add(output_tokens));

    // The last two lookups cover payloads that use the `input_tokens` spelling (whose
    // details object is `input_tokens_details`) and the `*_token_count` spelling.
    let cached_tokens = top_level("cached_tokens")
        .or_else(|| nested("prompt_tokens_details", "cached_tokens"))
        .or_else(|| nested("input_tokens_details", "cached_tokens"))
        .or_else(|| top_level("cached_content_token_count"));

    // `output_tokens_details` is consulted last so every payload that already parsed
    // keeps producing the same value.
    let reasoning_tokens = top_level("reasoning_tokens")
        .or_else(|| nested("completion_tokens_details", "reasoning_tokens"))
        .or_else(|| top_level("thoughts_token_count"))
        .or_else(|| nested("output_tokens_details", "reasoning_tokens"));

    Some(InputOutputUsage {
        input_tokens,
        output_tokens,
        total_tokens,
        cached_tokens,
        reasoning_tokens,
    })
}
pub fn usage_summary_to_value(summary: &UsageSummary) -> Value {
json!({
"input_tokens": summary.input_tokens,
"output_tokens": summary.output_tokens,
"total_tokens": summary.total_tokens,
"cached_tokens": summary.cached_tokens,
"reasoning_tokens": summary.reasoning_tokens,
})
}
#[cfg(test)]
mod tests {
    use super::{InputOutputUsage, ToOpenAIUsage, parse_usage_payload};

    /// An OpenAI-shaped payload (`prompt_tokens` / `completion_tokens` plus nested
    /// details objects) is parsed into the matching counters.
    #[test]
    fn parse_usage_payload_supports_openai_usage_shape() {
        let payload = serde_json::json!({
            "prompt_tokens": 12,
            "completion_tokens": 8,
            "total_tokens": 20,
            "prompt_tokens_details": {"cached_tokens": 3},
            "completion_tokens_details": {"reasoning_tokens": 2}
        });

        let InputOutputUsage {
            input_tokens,
            output_tokens,
            total_tokens,
            cached_tokens,
            reasoning_tokens,
        } = parse_usage_payload(&payload).expect("must parse usage");

        assert_eq!(
            (input_tokens, output_tokens, total_tokens),
            (12, 8, 20)
        );
        assert_eq!(cached_tokens, Some(3));
        assert_eq!(reasoning_tokens, Some(2));
    }

    /// Every counter of an `InputOutputUsage` lands in the corresponding OpenAI field.
    #[test]
    fn maps_input_output_usage_to_openai_usage() {
        let mapped = InputOutputUsage {
            input_tokens: 10,
            output_tokens: 5,
            total_tokens: 15,
            cached_tokens: Some(2),
            reasoning_tokens: Some(1),
        }
        .to_openai_usage();

        assert_eq!(
            (
                mapped.prompt_tokens,
                mapped.completion_tokens,
                mapped.total_tokens
            ),
            (10, 5, 15)
        );

        let prompt_details = mapped
            .prompt_tokens_details
            .as_ref()
            .expect("prompt details");
        assert_eq!(prompt_details.cached_tokens, 2);

        let completion_details = mapped
            .completion_tokens_details
            .as_ref()
            .expect("completion details");
        assert_eq!(completion_details.reasoning_tokens, 1);
    }
}
/// Why a response ended, as recorded on [`UnifiedResponse`].
///
/// Serialized as a snake_case string: `stop`, `length`, `tool_calls`,
/// `content_filter`, `other`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
Stop,
Length,
ToolCalls,
ContentFilter,
Other,
}
/// A tool invocation carried in [`UnifiedResponse::tool_calls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Optional call identifier; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Tool name.
pub name: String,
/// Tool description.
pub description: String,
/// Call arguments as a string.
// NOTE(review): presumably JSON-encoded text — confirm against the providers that fill it.
pub arguments: String,
}
/// Complete response in the shape returned by [`Provider::complete`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnifiedResponse {
/// Identifier of the request this response belongs to.
pub request_id: String,
/// Model name associated with the response.
pub model: String,
/// Text output of the response.
pub output_text: String,
/// Tool calls, if any; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// Why the response ended.
pub finish_reason: FinishReason,
/// Usage information as free-form JSON.
// NOTE(review): the expected shape is not fixed by this type — presumably the object
// produced by `usage_summary_to_value`; confirm against the providers.
pub usage: Value,
}
/// Streaming event in the shape yielded by [`Provider::stream`].
///
/// Serialized as an internally tagged object: the variant name, converted to
/// snake_case, is stored under the `"type"` key next to the variant's own fields
/// (for example `ResponseCreated` becomes `"type": "response_created"`).
///
/// `Option` fields without a `skip_serializing_if` attribute are written as `null`
/// when they are `None`; only the two optional fields of `ServerToolCallResult` are
/// omitted instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum UnifiedEvent {
// Response lifecycle events.
ResponseCreated {
id: String,
model: String,
created_at: String,
},
ResponseInProgress {
id: String,
model: String,
created_at: String,
},
// Output item events.
OutputItemAdded {
id: String,
item_type: String,
},
OutputItemDone {
id: String,
item_type: String,
},
// Content part events; `item_id` names the item the part belongs to.
ContentPartAdded {
item_id: String,
part_type: String,
},
ContentPartDelta {
item_id: String,
part_type: String,
delta: String,
},
ContentPartDone {
item_id: String,
part_type: String,
},
// Thinking events.
ThinkingDelta {
id: String,
delta: String,
},
ThinkingDone {
id: String,
// Serialized as `null` when `None`.
summary: Option<String>,
},
// Tool call events: argument fragments, then the finished argument string.
ToolCallDelta {
id: String,
name: String,
arguments_delta: String,
},
ToolCallDone {
id: String,
name: String,
arguments: String,
},
// Server tool call events.
ServerToolCall {
tool_call_id: String,
name: String,
arguments: String,
},
ServerToolCallResult {
tool_call_id: String,
name: String,
// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
result: Option<String>,
// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
},
// Message events.
MessageStart {
id: String,
role: String,
},
MessageDelta {
id: String,
delta: String,
},
MessageStop {
id: String,
// Serialized as `null` when `None`.
stop_reason: Option<String>,
},
// Usage information as free-form JSON.
Usage {
usage: Value,
},
// Terminal events. Both fields of `Completed` serialize as `null` when `None`.
Completed {
finish_reason: Option<String>,
usage: Option<Value>,
},
Failed {
code: String,
message: String,
},
Cancelled {
reason: String,
},
}
/// Error type returned by [`Provider`] operations.
#[derive(Debug)]
pub enum ProviderError {
/// Error carrying an HTTP status code together with an OpenAI-style error detail.
// NOTE(review): presumably safe to forward to the API client as-is — confirm at the
// call sites that build the HTTP response.
Public {
status: axum::http::StatusCode,
error: crate::openai_types::ErrorDetail,
},
/// Error described only by a message string.
// NOTE(review): presumably not meant to be shown to clients verbatim — confirm.
Internal(String),
}
impl std::fmt::Display for ProviderError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProviderError::Public { error, .. } => {
write!(f, "provider error: {}", error.message)
}
ProviderError::Internal(message) => write!(f, "provider error: {message}"),
}
}
}
// Marker impl: relies on the `Debug` and `Display` impls and keeps the default
// `source()`, so no underlying cause is exposed.
impl std::error::Error for ProviderError {}
/// Interface for a chat-completion backend.
///
/// Implementors must be `Send + Sync`. The async methods are declared through
/// `#[async_trait]`.
#[async_trait]
pub trait Provider: Send + Sync {
/// Returns the model identifier as a string slice borrowed from `self`.
// NOTE(review): presumably the id of the model this provider serves — confirm.
fn model_id(&self) -> &str;
/// Handles `request` and resolves to one complete [`UnifiedResponse`], or to a
/// [`ProviderError`] on failure.
async fn complete(
&self,
request: ChatCompletionRequest,
) -> Result<UnifiedResponse, ProviderError>;
/// Handles `request` as a stream.
///
/// On success resolves to a pinned, boxed, `Send` stream whose items are
/// `Result<UnifiedEvent, ProviderError>`. The stream is allowed to borrow `self`
/// for `'a`. The outer `Result` reports a failure to obtain the stream at all.
async fn stream<'a>(
&'a self,
request: ChatCompletionRequest,
) -> Result<
Pin<Box<dyn Stream<Item = Result<UnifiedEvent, ProviderError>> + Send + 'a>>,
ProviderError,
>;
}