use super::auth::Auth;
use super::response::{CompletionResponse, StreamChunk, Usage};
use super::{CostInfo, CostResolution};
use crate::generator::CompletionParameters;
use crate::message::Message;
use std::future::Future;
use std::pin::Pin;
/// Per-million-token rates used to turn a [`Usage`] record into a cost.
///
/// Every rate is expressed per 1,000,000 tokens; [`TokenPrice::cost_of`] returns
/// the cost in the same currency unit as the rates. The two cache rates are
/// optional: when one is `None`, those tokens are billed at `input_per_mtok`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TokenPrice {
    /// Rate applied to `Usage::uncached_input_tokens`, and the fallback for unset cache rates.
    pub input_per_mtok: f64,
    /// Rate applied to `Usage::completion_tokens`.
    pub output_per_mtok: f64,
    /// Rate applied to `Usage::cache_read_tokens`; `None` means "use `input_per_mtok`".
    pub cache_read_per_mtok: Option<f64>,
    /// Rate applied to `Usage::cache_write_tokens`; `None` means "use `input_per_mtok`".
    pub cache_write_per_mtok: Option<f64>,
}

impl TokenPrice {
    /// Creates a price with the given input and output rates and no dedicated cache rates.
    pub fn new(input_per_mtok: f64, output_per_mtok: f64) -> Self {
        Self {
            input_per_mtok,
            output_per_mtok,
            ..Self::default()
        }
    }

    /// Returns this price with both cache rates set, replacing any previous cache rates.
    pub fn with_cache_rates(self, read_per_mtok: f64, write_per_mtok: f64) -> Self {
        Self {
            cache_read_per_mtok: Some(read_per_mtok),
            cache_write_per_mtok: Some(write_per_mtok),
            ..self
        }
    }

    /// Computes the cost of `usage` at these rates.
    ///
    /// Each token bucket (uncached input, cache read, cache write, completion) is
    /// multiplied by its per-million rate. The four products are summed in that
    /// order, and the total is divided by one million.
    pub fn cost_of(&self, usage: &Usage) -> f64 {
        const TOKENS_PER_MTOK: f64 = 1_000_000.0;

        // Unset cache rates fall back to the plain input rate.
        let fallback = self.input_per_mtok;
        let cache_read_rate = self.cache_read_per_mtok.unwrap_or(fallback);
        let cache_write_rate = self.cache_write_per_mtok.unwrap_or(fallback);

        let uncached_input = usage.uncached_input_tokens as f64 * self.input_per_mtok;
        let cache_read = usage.cache_read_tokens as f64 * cache_read_rate;
        let cache_write = usage.cache_write_tokens as f64 * cache_write_rate;
        let output = usage.completion_tokens as f64 * self.output_per_mtok;

        (uncached_input + cache_read + cache_write + output) / TOKENS_PER_MTOK
    }
}
/// The result of pricing one completion: how the cost was resolved, the amount,
/// and the token usage it relates to.
#[derive(Debug, Clone)]
pub struct CostOutcome {
    /// Which [`CostResolution`] variant this outcome carries.
    pub resolution: CostResolution,
    /// Cost amount; the `unpriced` and `unknown` constructors set it to `0.0`.
    pub usd: f64,
    /// Token counts associated with this outcome.
    pub usage: Usage,
}

impl CostOutcome {
    /// Single place where every public constructor assembles its value.
    fn assemble(resolution: CostResolution, usd: f64, usage: Usage) -> Self {
        Self {
            resolution,
            usd,
            usage,
        }
    }

    /// A `Resolved` outcome carrying the given amount and usage unchanged.
    pub fn resolved(usd: f64, usage: Usage) -> Self {
        Self::assemble(CostResolution::Resolved, usd, usage)
    }

    /// An `Unpriced` outcome: usage is recorded, `usd` is `0.0`.
    pub fn unpriced(usage: Usage) -> Self {
        Self::assemble(CostResolution::Unpriced, 0.0, usage)
    }

    /// An `Unknown` outcome: `usd` is `0.0` and `usage` is `Usage::default()`.
    pub fn unknown() -> Self {
        Self::assemble(CostResolution::Unknown, 0.0, Usage::default())
    }

    /// Converts this outcome into a [`CostInfo`] tagged with `model` and `response_id`.
    ///
    /// `cost` comes from `usd`. `prompt_tokens` and `total_tokens` come from
    /// `Usage::prompt_tokens()` / `Usage::total_tokens()`. The completion, cache
    /// and reasoning counts are copied from `usage` unchanged.
    pub fn into_cost_info(
        self,
        model: impl Into<String>,
        response_id: impl Into<String>,
    ) -> CostInfo {
        let CostOutcome {
            resolution,
            usd,
            usage,
        } = self;

        CostInfo {
            cost: usd,
            prompt_tokens: usage.prompt_tokens(),
            completion_tokens: usage.completion_tokens,
            total_tokens: usage.total_tokens(),
            cache_read_tokens: usage.cache_read_tokens,
            cache_write_tokens: usage.cache_write_tokens,
            reasoning_tokens: usage.reasoning_tokens,
            model: model.into(),
            response_id: response_id.into(),
            resolution,
        }
    }
}
/// Borrowed inputs handed to [`Provider::resolve_post_stream`].
///
/// Every field is a shared reference (or an `Option` of one), so the context
/// derives `Clone` and `Copy`. Implementations receive it by value and can pass
/// it to helpers while still reading it afterwards.
#[derive(Clone, Copy)]
pub struct PostStreamCtx<'a> {
    /// HTTP client available to the implementation.
    pub client: &'a reqwest::Client,
    /// Generation identifier the implementation is asked to resolve a cost for.
    pub generation_id: &'a str,
    /// Credentials available to the implementation.
    pub auth: &'a Auth,
    /// Configured price, if any.
    pub price: Option<&'a TokenPrice>,
}

/// Pinned, boxed, `Send` future yielding a [`CostOutcome`]; it may borrow data
/// that lives for `'a`. This is the return type of [`Provider::resolve_post_stream`].
pub type CostFuture<'a> = Pin<Box<dyn Future<Output = CostOutcome> + Send + 'a>>;
/// Identity of the calling application, as accepted by [`Provider::attribution_headers`].
///
/// A plain value type, so it also derives `PartialEq`, `Eq` and `Hash`. This lets
/// identities be compared and used as keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AppIdentity {
    /// Application URL.
    pub url: String,
    /// Application title.
    pub title: String,
}
pub trait Provider: Send + Sync + std::fmt::Debug {
fn endpoint_url(&self, base_url: &str) -> String {
format!("{}/chat/completions", base_url.trim_end_matches('/'))
}
fn auth_headers(&self, auth: &Auth) -> crate::error::Result<Vec<(String, String)>> {
super::providers::openai_auth_headers(auth)
}
fn build_request(
&self,
model: &str,
messages: &[Message],
params: &CompletionParameters,
stream: bool,
include_usage: bool,
) -> crate::error::Result<serde_json::Value> {
super::providers::openai_build_request(
model,
messages,
params,
stream,
include_usage,
self.openai_token_limit_field(),
|body| self.openai_request_usage(body, stream),
)
}
fn openai_token_limit_field(&self) -> &'static str {
"max_completion_tokens"
}
fn openai_request_usage(&self, _body: &mut serde_json::Value, _stream: bool) {}
fn parse_response(&self, raw: serde_json::Value) -> crate::error::Result<CompletionResponse> {
super::response::parse_openai_response(raw, self)
}
fn parse_chunk(&self, data: &str) -> Option<crate::error::Result<StreamChunk>> {
super::response::parse_openai_chunk(data, self)
}
fn parse_usage(&self, raw: &serde_json::Value) -> Option<Usage> {
super::providers::parse_openai_usage_field(raw)
}
fn emits_stream_usage(&self, requested: bool) -> bool {
requested
}
fn attribution_headers(&self, _app: Option<&AppIdentity>) -> Vec<(String, String)> {
Vec::new()
}
fn cost_of(&self, usage: Usage, price: Option<&TokenPrice>) -> CostOutcome;
fn resolve_post_stream<'a>(&'a self, _ctx: PostStreamCtx<'a>) -> CostFuture<'a> {
Box::pin(async { CostOutcome::unknown() })
}
}