1use anyhow::anyhow;
2use serde::{Deserialize, Serialize};
3
4use crate::token_map::TokenType;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub enum ProviderMetadata {
8 #[serde(untagged)]
9 ModelProvider(ModelProvider),
10}
11
12impl ProviderMetadata {
13 pub fn model_price(&self, model_name: &str, token: &TokenType) -> anyhow::Result<Price> {
14 let model_provider = match self {
15 ProviderMetadata::ModelProvider(model) => model,
16 };
17
18 let price = model_provider
19 .models
20 .iter()
21 .find(|m| m.name == model_name)
22 .ok_or(anyhow!(
23 "Provider doesn't provide service for model {model_name}"
24 ))?
25 .pricing
26 .iter()
27 .find(|p| {
28 p.token == token.symbol || p.token == token.mint.to_string()
30 })
31 .ok_or(anyhow!(
32 "Model {model_name} doesn't support payment with {} ({})",
33 token.symbol,
34 token.mint
35 ))?;
36
37 Ok(price.clone())
38 }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub enum Category {
44 CompletionModel,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ModelProvider {
49 pub id: String,
50 pub name: String,
51 pub models: Vec<Model>,
52 pub category: Category,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct Model {
57 pub name: String,
58 pub display_name: String,
59 pub provider_name: String,
60 pub pricing: Vec<Price>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct Price {
66 pub token: String,
67 pub input_price: u64,
68 pub output_price: u64,
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the JSON wire format of `ProviderMetadata` in both directions.
    ///
    /// Serialization must produce the hand-written JSON below. Deserializing
    /// that JSON must yield a value that serializes back to the same thing.
    /// The round trip exercises the variant-level `#[serde(untagged)]` and the
    /// `snake_case` category on the read path.
    #[test]
    fn test_json_format() {
        let another_token_mint = "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v";
        let expected = serde_json::json!({
            "id": "test-provider-id",
            "name": "test",
            "category": "completion_model",
            "models": [
                {
                    "name": "gpt-4o",
                    "display_name": "OpenAI: GPT-4o",
                    "provider_name": "openai",
                    "pricing": [
                        {
                            "token": "USDC",
                            "input_price": 10000,
                            "output_price": 50000,
                        },
                        {
                            "token": another_token_mint,
                            "input_price": 500,
                            "output_price": 1000,
                        }
                    ],
                }
            ]
        });

        let metadata = ProviderMetadata::ModelProvider(ModelProvider {
            id: "test-provider-id".to_string(),
            name: "test".to_string(),
            models: vec![Model {
                name: "gpt-4o".to_string(),
                display_name: "OpenAI: GPT-4o".to_string(),
                provider_name: "openai".to_string(),
                pricing: vec![
                    Price {
                        token: "USDC".to_string(),
                        input_price: 10000,
                        output_price: 50000,
                    },
                    Price {
                        token: another_token_mint.to_string(),
                        input_price: 500,
                        output_price: 1000,
                    },
                ],
            }],
            category: Category::CompletionModel,
        });

        // Compare canonical string forms so object key order does not matter.
        let expected_canonical =
            canonical_json::to_string(&expected).expect("expected JSON canonicalizes");

        let serialized = serde_json::to_value(&metadata).expect("metadata serializes");
        assert_eq!(
            expected_canonical,
            canonical_json::to_string(&serialized).expect("serialized JSON canonicalizes")
        );

        // Read path: the hand-written JSON must deserialize into the untagged variant
        // and serialize back unchanged.
        let deserialized: ProviderMetadata =
            serde_json::from_value(expected.clone()).expect("expected JSON deserializes");
        let reserialized =
            serde_json::to_value(&deserialized).expect("deserialized metadata serializes");
        assert_eq!(
            expected_canonical,
            canonical_json::to_string(&reserialized).expect("re-serialized JSON canonicalizes")
        );
    }
}