aimo-core 0.1.1

AiMo Network core protocol Rust specs
Documentation
use anyhow::anyhow;
use serde::{Deserialize, Serialize};

use crate::token_map::TokenType;

/// Metadata a provider publishes about itself.
///
/// There is currently a single variant. Because that variant is marked
/// `#[serde(untagged)]`, the value serializes as the inner [`ModelProvider`]
/// object directly, with no wrapping variant tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProviderMetadata {
    /// A provider that serves one or more models.
    #[serde(untagged)]
    ModelProvider(ModelProvider),
}

impl ProviderMetadata {
    /// Looks up the price this provider charges for `model_name` when paid in `token`.
    ///
    /// The model is matched by exact `name`. Within that model's pricing list, the first
    /// entry whose `token` field equals either the token's symbol or the string form of
    /// its mint is returned (as an owned clone).
    ///
    /// # Errors
    ///
    /// Returns an error if the provider lists no model named `model_name`, or if that
    /// model has no pricing entry for `token`.
    pub fn model_price(&self, model_name: &str, token: &TokenType) -> anyhow::Result<Price> {
        let model_provider = match self {
            ProviderMetadata::ModelProvider(model) => model,
        };

        // `ok_or_else` defers building the error until the lookup has actually failed.
        let model = model_provider
            .models
            .iter()
            .find(|m| m.name == model_name)
            .ok_or_else(|| anyhow!("Provider doesn't provide service for model {model_name}"))?;

        // Render the mint once instead of once per pricing entry inside the closure.
        let mint = token.mint.to_string();

        model
            .pricing
            .iter()
            // Pricing may be keyed by token symbol or by mint address.
            .find(|p| p.token == token.symbol || p.token == mint)
            .cloned()
            .ok_or_else(|| {
                anyhow!(
                    "Model {model_name} doesn't support payment with {} ({})",
                    token.symbol,
                    token.mint
                )
            })
    }
}

/// Kind of service a provider offers.
///
/// Variant names are serialized in snake_case, e.g. `CompletionModel` becomes
/// `"completion_model"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Category {
    /// Provider serving completion models.
    CompletionModel,
}

/// Description of a provider and the models it serves.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelProvider {
    /// Identifier of the provider.
    pub id: String,
    /// Name of the provider.
    pub name: String,
    /// Models offered by this provider; `ProviderMetadata::model_price` searches this list by name.
    pub models: Vec<Model>,
    /// Category of service this provider belongs to.
    pub category: Category,
}

/// A single model offered by a provider, together with its accepted payment options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Model {
    /// Model identifier (e.g. `"gpt-4o"`); this is the value price lookups match against.
    pub name: String,
    /// Display label for the model (e.g. `"OpenAI: GPT-4o"`).
    pub display_name: String,
    /// Name of the upstream provider (e.g. `"openai"`).
    // NOTE(review): relationship to `ModelProvider::name` is not visible here; confirm.
    pub provider_name: String,
    /// Price entries for this model, one per accepted payment token.
    pub pricing: Vec<Price>,
}

/// Price of a model denominated in one payment token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Price {
    /// Payment token this entry applies to: either the token symbol (e.g. `"USDC"`)
    /// or the string form of its mint address.
    pub token: String,
    /// Price charged for input.
    // NOTE(review): presumably per input token, in the payment token's smallest unit; confirm.
    pub input_price: u64,
    /// Price charged for output.
    // NOTE(review): presumably per output token, in the payment token's smallest unit; confirm.
    pub output_price: u64,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the JSON wire format of [`ProviderMetadata`]: the untagged variant serializes
    /// as the bare provider object and `Category` uses its snake_case name.
    #[test]
    fn test_json_format() {
        // A pricing entry may be keyed by a mint address instead of a symbol.
        let another_token_mint = "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"; // USDC mainnet

        let expected = serde_json::json!({
            "id": "test-provider-id",
            "name": "test",
            "category": "completion_model",
            "models": [
                {
                    "name": "gpt-4o",
                    "display_name": "OpenAI: GPT-4o",
                    "provider_name": "openai",
                    "pricing": [
                        {
                            "token": "USDC",
                            "input_price": 10000,
                            "output_price": 50000,
                        },
                        {
                            "token": another_token_mint,
                            "input_price": 500,
                            "output_price": 1000,
                        }
                    ],
                }
            ]
        });

        let serialized = ProviderMetadata::ModelProvider(ModelProvider {
            id: "test-provider-id".to_string(),
            name: "test".to_string(),
            models: vec![Model {
                name: "gpt-4o".to_string(),
                display_name: "OpenAI: GPT-4o".to_string(),
                provider_name: "openai".to_string(),
                pricing: vec![
                    Price {
                        token: "USDC".to_string(),
                        input_price: 10000,
                        output_price: 50000,
                    },
                    Price {
                        token: another_token_mint.to_string(),
                        input_price: 500,
                        output_price: 1000,
                    },
                ],
            }],
            category: Category::CompletionModel,
        });

        let json = serde_json::to_value(&serialized).unwrap();
        // Compare canonical encodings so object key order cannot affect the result.
        assert_eq!(
            canonical_json::to_string(&expected).unwrap(),
            canonical_json::to_string(&json).unwrap()
        );
    }
}