1use anyhow::anyhow;
4use serde::{Deserialize, Serialize};
5
6use crate::token_map::TokenType;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub enum ProviderMetadata {
10 #[serde(untagged)]
11 ModelProvider(ModelProvider),
12}
13
14impl ProviderMetadata {
15 pub fn model_price(&self, model_name: &str, token: &TokenType) -> anyhow::Result<Price> {
16 let model_provider = match self {
17 ProviderMetadata::ModelProvider(model) => model,
18 };
19
20 let price = model_provider
21 .models
22 .iter()
23 .find(|m| m.name == model_name)
24 .ok_or(anyhow!(
25 "Provider doesn't provide service for model {model_name}"
26 ))?
27 .pricing
28 .iter()
29 .find(|p| {
30 p.token == token.symbol || p.token == token.mint.to_string()
32 })
33 .ok_or(anyhow!(
34 "Model {model_name} doesn't support payment with {} ({})",
35 token.symbol,
36 token.mint
37 ))?;
38
39 Ok(price.clone())
40 }
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44#[serde(rename_all = "snake_case")]
45pub enum Category {
46 CompletionModel,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct ModelProvider {
51 pub id: String,
52 pub name: String,
53 pub category: Category,
54 pub models: Vec<Model>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Model {
59 pub name: String,
60 pub display_name: String,
61 pub provider_name: String,
62 pub pricing: Vec<Price>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct Price {
68 pub token: String,
69 pub input_price: u64,
70 pub output_price: u64,
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `ProviderMetadata::ModelProvider` serializes to the
    /// expected JSON document: no variant tag, a `snake_case` category, and
    /// pricing entries keyed by either a token symbol or a mint string.
    #[test]
    fn test_json_format() {
        let another_token_mint = "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v";

        let expected = serde_json::json!({
            "id": "test-provider-id",
            "name": "test",
            "category": "completion_model",
            "models": [
                {
                    "name": "gpt-4o",
                    "display_name": "OpenAI: GPT-4o",
                    "provider_name": "openai",
                    "pricing": [
                        {
                            "token": "USDC",
                            "input_price": 10000,
                            "output_price": 50000,
                        },
                        {
                            "token": another_token_mint,
                            "input_price": 500,
                            "output_price": 1000,
                        }
                    ],
                }
            ]
        });

        let serialized = ProviderMetadata::ModelProvider(ModelProvider {
            id: "test-provider-id".to_string(),
            name: "test".to_string(),
            models: vec![Model {
                name: "gpt-4o".to_string(),
                display_name: "OpenAI: GPT-4o".to_string(),
                provider_name: "openai".to_string(),
                pricing: vec![
                    Price {
                        token: "USDC".to_string(),
                        input_price: 10000,
                        output_price: 50000,
                    },
                    Price {
                        token: another_token_mint.to_string(),
                        input_price: 500,
                        output_price: 1000,
                    },
                ],
            }],
            category: Category::CompletionModel,
        });

        let json = serde_json::to_value(&serialized).unwrap();

        // Compare canonical renderings so key order cannot affect the result.
        assert_eq!(
            canonical_json::to_string(&expected).unwrap(),
            canonical_json::to_string(&json).unwrap()
        );
    }
}