use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use crate::token_map::TokenType;
/// Metadata describing a provider.
///
/// There is currently a single variant, and it is marked `#[serde(untagged)]`,
/// so the serialized form is the bare `ModelProvider` object with no
/// variant-name wrapper around it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProviderMetadata {
/// A provider described by a [`ModelProvider`] record.
#[serde(untagged)]
ModelProvider(ModelProvider),
}
impl ProviderMetadata {
    /// Looks up the price entry of `model_name` that is payable in `token`.
    ///
    /// The model is found by exact match on [`Model::name`]. Within that model,
    /// a pricing entry matches when its `token` field equals either the token's
    /// symbol or the string form of the token's mint. The first matching entry
    /// is returned as an owned clone.
    ///
    /// # Errors
    ///
    /// Returns an error when the provider lists no model named `model_name`, or
    /// when that model has no pricing entry for `token`.
    pub fn model_price(&self, model_name: &str, token: &TokenType) -> anyhow::Result<Price> {
        // The enum has a single variant, so this pattern is irrefutable. Adding
        // a variant later turns this line into a compile error, which forces
        // this method to be revisited.
        let ProviderMetadata::ModelProvider(provider) = self;

        // `ok_or_else` keeps the error (and its message formatting) off the
        // success path; `ok_or` would build it on every call.
        let model = provider
            .models
            .iter()
            .find(|m| m.name == model_name)
            .ok_or_else(|| anyhow!("Provider doesn't provide service for model {model_name}"))?;

        // Stringify the mint once instead of once per pricing entry.
        let mint = token.mint.to_string();

        model
            .pricing
            .iter()
            .find(|p| p.token == token.symbol || p.token == mint)
            .cloned()
            .ok_or_else(|| {
                anyhow!(
                    "Model {model_name} doesn't support payment with {} ({})",
                    token.symbol,
                    token.mint
                )
            })
    }
}
/// Category recorded on a [`ModelProvider`].
///
/// Variant names are serialized in snake_case, so `CompletionModel` is written
/// as `"completion_model"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Category {
/// Serialized as `"completion_model"`.
CompletionModel,
}
/// A provider record together with the models it lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProvider {
/// Identifier of the provider (format not visible here; presumably unique per provider).
pub id: String,
/// Name of the provider.
pub name: String,
/// Models listed by this provider; `ProviderMetadata::model_price` searches
/// this list by `Model::name`.
pub models: Vec<Model>,
/// Category of this provider.
pub category: Category,
}
/// A single model entry listed by a [`ModelProvider`], with its pricing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Model {
/// Name used for lookup: `ProviderMetadata::model_price` compares it for
/// exact equality with its `model_name` argument (e.g. `"gpt-4o"`).
pub name: String,
/// Display label for the model (e.g. `"OpenAI: GPT-4o"`).
pub display_name: String,
/// Provider name string (e.g. `"openai"`).
/// NOTE(review): presumably the upstream vendor of the model rather than
/// `ModelProvider::name` — confirm against callers.
pub provider_name: String,
/// Pricing entries for this model; `ProviderMetadata::model_price` returns
/// the first entry whose `token` matches the requested token.
pub pricing: Vec<Price>,
}
/// One pricing entry of a [`Model`] for a specific payment token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Price {
/// Payment token this entry applies to. `ProviderMetadata::model_price`
/// matches it against either the token's symbol or the string form of its mint.
pub token: String,
/// Input price.
/// NOTE(review): unit and granularity are not visible in this file — confirm.
pub input_price: u64,
/// Output price.
/// NOTE(review): unit and granularity are not visible in this file — confirm.
pub output_price: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token_map::token_map;

    /// Serializing a `ProviderMetadata::ModelProvider` must produce the flat
    /// JSON object below: no enum tag, and a snake_case `category` value.
    #[test]
    fn test_json_format() {
        let tokens = token_map(true);
        let _usdc = &tokens.usdc;
        let alt_mint = "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v";

        // Build the value under test bottom-up from named parts.
        let usdc_price = Price {
            token: "USDC".to_owned(),
            input_price: 10000,
            output_price: 50000,
        };
        let alt_price = Price {
            token: alt_mint.to_owned(),
            input_price: 500,
            output_price: 1000,
        };
        let gpt_4o = Model {
            name: "gpt-4o".to_owned(),
            display_name: "OpenAI: GPT-4o".to_owned(),
            provider_name: "openai".to_owned(),
            pricing: vec![usdc_price, alt_price],
        };
        let metadata = ProviderMetadata::ModelProvider(ModelProvider {
            id: "test-provider-id".to_owned(),
            name: "test".to_owned(),
            models: vec![gpt_4o],
            category: Category::CompletionModel,
        });
        let actual = serde_json::to_value(&metadata).unwrap();

        let want = serde_json::json!({
            "id": "test-provider-id",
            "name": "test",
            "category": "completion_model",
            "models": [
                {
                    "name": "gpt-4o",
                    "display_name": "OpenAI: GPT-4o",
                    "provider_name": "openai",
                    "pricing": [
                        {
                            "token": "USDC",
                            "input_price": 10000,
                            "output_price": 50000,
                        },
                        {
                            "token": alt_mint,
                            "input_price": 500,
                            "output_price": 1000,
                        }
                    ],
                }
            ]
        });

        // Compare canonical encodings so key order cannot affect the result.
        assert_eq!(
            canonical_json::to_string(&want).unwrap(),
            canonical_json::to_string(&actual).unwrap()
        );
    }
}