use serde::{Deserialize, Serialize};
// Content-type tag used by `Modalities` to list what a model accepts or emits.
//
// Variants serialize under snake_case names (`text`, `image`, `audio`, `video`, `pdf`)
// because of `rename_all = "snake_case"`.
//
// NOTE(review): plain `//` comments are used deliberately. This type derives
// `schemars::JsonSchema`, and schemars copies `///` doc comments into the generated
// schema as descriptions; switch to `///` only if that schema change is wanted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum Modality {
Text,
Image,
Audio,
Video,
Pdf,
}
// Pair of modality lists: one for input, one for output.
//
// Deserialization rejects unknown fields (`deny_unknown_fields`). Each list falls
// back to empty when its key is absent and is left out of serialized output when
// it is empty.
//
// NOTE(review): plain `//` comments are used deliberately. This type derives
// `schemars::JsonSchema`, and schemars copies `///` doc comments into the generated
// schema as descriptions; switch to `///` only if that schema change is wanted.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct Modalities {
// Modalities listed for input; empty by default.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub input: Vec<Modality>,
// Modalities listed for output; empty by default.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub output: Vec<Modality>,
}
impl Modalities {
    /// Reports whether both the input and the output list are empty.
    ///
    /// `ModelSpec::modalities` uses this as its `skip_serializing_if` predicate, so
    /// a value with no modalities at all is dropped from serialized output.
    pub(crate) fn is_empty(&self) -> bool {
        matches!(
            (self.input.as_slice(), self.output.as_slice()),
            ([], [])
        )
    }
}
// Description of a single model entry: three required string identifiers plus
// optional limits, modalities, knowledge cutoff and token pricing.
//
// Deserialization rejects unknown fields (`deny_unknown_fields`). Every optional
// field defaults when its key is absent and is omitted from serialized output when
// unset (or, for `modalities`, when both lists are empty).
//
// NOTE(review): plain `//` comments are used deliberately. This type derives
// `schemars::JsonSchema`, and schemars copies `///` doc comments into the generated
// schema as descriptions; switch to `///` only if that schema change is wanted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct ModelSpec {
// Required identifier of this spec.
pub id: String,
// Required identifier of the provider this model belongs to.
pub provider_id: String,
// Required model name; presumably the name used by the upstream provider — confirm.
pub upstream_model: String,
// Optional context window size; presumably measured in tokens — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub context_window: Option<u32>,
// Optional upper bound on output tokens.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u32>,
// Input/output modalities; skipped on serialization when both lists are empty.
#[serde(default, skip_serializing_if = "Modalities::is_empty")]
pub modalities: Modalities,
// Optional cutoff date. On deserialization the value goes through
// `deserialize_knowledge_cutoff`, which only accepts `YYYY-MM` or `YYYY-MM-DD`.
// The field is public, so values assigned directly in code are not validated.
#[serde(
default,
deserialize_with = "deserialize_knowledge_cutoff",
skip_serializing_if = "Option::is_none"
)]
pub knowledge_cutoff: Option<String>,
// Optional price in USD per one million input tokens (read by `compute_cost_usd`).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub input_token_price_per_million_usd: Option<f64>,
// Optional price in USD per one million output tokens (read by `compute_cost_usd`).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub output_token_price_per_million_usd: Option<f64>,
}
impl ModelSpec {
    /// Builds a spec from its three required identifiers.
    ///
    /// All optional fields start out as `None` and `modalities` starts empty;
    /// callers fill them in afterwards through the public fields.
    pub fn new(
        id: impl Into<String>,
        provider_id: impl Into<String>,
        upstream_model: impl Into<String>,
    ) -> Self {
        let id = id.into();
        let provider_id = provider_id.into();
        let upstream_model = upstream_model.into();
        Self {
            id,
            provider_id,
            upstream_model,
            modalities: Modalities::default(),
            context_window: None,
            max_output_tokens: None,
            knowledge_cutoff: None,
            input_token_price_per_million_usd: None,
            output_token_price_per_million_usd: None,
        }
    }

    /// Computes the USD cost of a request from its token counts.
    ///
    /// Returns `None` unless both the input and the output price are set;
    /// otherwise returns the sum of each token count multiplied by its
    /// per-million price and divided by one million.
    pub fn compute_cost_usd(&self, input_tokens: u32, output_tokens: u32) -> Option<f64> {
        // Scale one side of the bill: tokens * price-per-million / 1e6.
        let scaled = |tokens: u32, price: f64| f64::from(tokens) * price / 1_000_000.0;

        match (
            self.input_token_price_per_million_usd,
            self.output_token_price_per_million_usd,
        ) {
            (Some(input_price), Some(output_price)) => {
                Some(scaled(input_tokens, input_price) + scaled(output_tokens, output_price))
            }
            _ => None,
        }
    }
}
/// Validates a knowledge-cutoff date and returns it with surrounding whitespace removed.
///
/// Accepted shapes are exactly `YYYY-MM` and `YYYY-MM-DD`, written with ASCII
/// digits and `-` separators. The month must be `01`..=`12`; when a day is
/// present it must exist in that month of that year (leap years included).
///
/// Returns `None` for anything else. The returned string is the trimmed input,
/// otherwise unchanged.
#[must_use]
pub fn normalize_knowledge_cutoff(value: &str) -> Option<String> {
    // A field is valid only if it is exactly `width` ASCII digits.
    let is_digit_field =
        |field: &str, width: usize| field.len() == width && field.bytes().all(|b| b.is_ascii_digit());

    let trimmed = value.trim();
    let mut fields = trimmed.split('-');
    let year_field = fields.next()?;
    let month_field = fields.next()?;
    let day_field = fields.next();

    // More than three dash-separated fields can never be a valid date.
    if fields.next().is_some() {
        return None;
    }
    if !is_digit_field(year_field, 4) || !is_digit_field(month_field, 2) {
        return None;
    }

    let month: u32 = month_field.parse().ok()?;
    if !(1..=12).contains(&month) {
        return None;
    }

    if let Some(day_field) = day_field {
        if !is_digit_field(day_field, 2) {
            return None;
        }
        let year: i32 = year_field.parse().ok()?;
        let day: u32 = day_field.parse().ok()?;
        if !(1..=days_in_month(year, month)).contains(&day) {
            return None;
        }
    }

    Some(trimmed.to_owned())
}
/// Number of days in `month` (1-based) of `year` in the Gregorian calendar.
///
/// February depends on whether `year` is a leap year. A month outside
/// `1..=12` yields `0`.
fn days_in_month(year: i32, month: u32) -> u32 {
    if !(1..=12).contains(&month) {
        return 0;
    }
    match month {
        2 => {
            if is_leap_year(year) {
                29
            } else {
                28
            }
        }
        4 | 6 | 9 | 11 => 30,
        _ => 31,
    }
}
/// Returns `true` when `year` is a leap year under the Gregorian rule:
/// divisible by 400, or divisible by 4 but not by 100.
fn is_leap_year(year: i32) -> bool {
    if year % 400 == 0 {
        return true;
    }
    if year % 100 == 0 {
        return false;
    }
    year % 4 == 0
}
fn deserialize_knowledge_cutoff<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = Option::<String>::deserialize(deserializer)?;
match raw {
None => Ok(None),
Some(value) => normalize_knowledge_cutoff(&value).map(Some).ok_or_else(|| {
serde::de::Error::custom(format!(
"knowledge_cutoff must be an ISO date of the form YYYY-MM or YYYY-MM-DD, got {value:?}"
))
}),
}
}