aimo_core/
provider.rs

1use anyhow::anyhow;
2use serde::{Deserialize, Serialize};
3
4use crate::token_map::TokenType;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub enum ProviderMetadata {
8    #[serde(untagged)]
9    ModelProvider(ModelProvider),
10}
11
12impl ProviderMetadata {
13    pub fn model_price(&self, model_name: &str, token: &TokenType) -> anyhow::Result<Price> {
14        let model_provider = match self {
15            ProviderMetadata::ModelProvider(model) => model,
16        };
17
18        let price = model_provider
19            .models
20            .iter()
21            .find(|m| m.name == model_name)
22            .ok_or(anyhow!(
23                "Provider doesn't provide service for model {model_name}"
24            ))?
25            .pricing
26            .iter()
27            .find(|p| {
28                // Try to find pricing by token symbol first, then by mint address
29                p.token == token.symbol || p.token == token.mint.to_string()
30            })
31            .ok_or(anyhow!(
32                "Model {model_name} doesn't support payment with {} ({})",
33                token.symbol,
34                token.mint
35            ))?;
36
37        Ok(price.clone())
38    }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub enum Category {
44    CompletionModel,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ModelProvider {
49    pub id: String,
50    pub name: String,
51    pub models: Vec<Model>,
52    pub category: Category,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct Model {
57    pub name: String,
58    pub display_name: String,
59    pub provider_name: String,
60    /// Input token price (per token)
61    pub pricing: Vec<Price>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct Price {
66    pub token: String,
67    pub input_price: u64,
68    pub output_price: u64,
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing a `ProviderMetadata` must produce the bare provider object
    /// (no variant tag), with the category in `snake_case` and pricing entries
    /// keyed by either a token symbol or a mint address string.
    #[test]
    fn test_json_format() {
        // USDC mainnet mint, used as a second pricing entry keyed by address.
        let another_token_mint = "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v";

        let expected = serde_json::json!({
            "id": "test-provider-id",
            "name": "test",
            "category": "completion_model",
            "models": [
                {
                    "name": "gpt-4o",
                    "display_name": "OpenAI: GPT-4o",
                    "provider_name": "openai",
                    "pricing": [
                        {
                            "token": "USDC",
                            "input_price": 10000,
                            "output_price": 50000,
                        },
                        {
                            "token": another_token_mint,
                            "input_price": 500,
                            "output_price": 1000,
                        }
                    ],
                }
            ]
        });

        let serialized = ProviderMetadata::ModelProvider(ModelProvider {
            id: "test-provider-id".to_string(),
            name: "test".to_string(),
            models: vec![Model {
                name: "gpt-4o".to_string(),
                display_name: "OpenAI: GPT-4o".to_string(),
                provider_name: "openai".to_string(),
                pricing: vec![
                    Price {
                        token: "USDC".to_string(),
                        input_price: 10000,
                        output_price: 50000,
                    },
                    Price {
                        token: another_token_mint.to_string(),
                        input_price: 500,
                        output_price: 1000,
                    },
                ],
            }],
            category: Category::CompletionModel,
        });

        // Compare canonical encodings so object key order cannot affect the result.
        let json = serde_json::to_value(&serialized).unwrap();
        assert_eq!(
            canonical_json::to_string(&expected).unwrap(),
            canonical_json::to_string(&json).unwrap()
        );
    }
}