vllora 0.1.23

AI gateway for managing and routing LLM requests - Govern, Secure, and Optimize your AI Traffic.
use vllora_core::pricing::calculator::{calculate_image_price, calculate_tokens_cost};
use vllora_llm::types::credentials_ident::CredentialsIdent;
use vllora_llm::types::gateway::{
    CostCalculationResult, CostCalculator, CostCalculatorError, Usage,
};
use vllora_llm::types::provider::ModelPrice;

/// Cost calculator used by the gateway to price model usage.
///
/// Stores the image cost value that is forwarded to `calculate_image_price`
/// whenever image-generation usage is priced.
#[derive(Debug, Clone)]
pub struct GatewayCostCalculator {
    /// Forwarded as the last argument of `calculate_image_price`.
    /// NOTE(review): presumably a fallback per-image cost — confirm against that helper.
    default_image_cost: f64,
}

impl GatewayCostCalculator {
    /// Creates a calculator whose `default_image_cost` is `0.0`.
    pub fn new() -> Self {
        Self {
            default_image_cost: 0.0,
        }
    }
}

impl Default for GatewayCostCalculator {
    /// Equivalent to [`GatewayCostCalculator::new`]; kept as a thin delegate so
    /// the initial values are defined in exactly one place.
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait::async_trait]
impl CostCalculator for GatewayCostCalculator {
    async fn calculate_cost(
        &self,
        price: &ModelPrice,
        usage: &Usage,
        _credentials_ident: &CredentialsIdent,
    ) -> Result<CostCalculationResult, CostCalculatorError> {
        match usage {
            vllora_llm::types::gateway::Usage::ImageGenerationModelUsage(usage) => {
                if let ModelPrice::ImageGeneration(p) = &price {
                    Ok(calculate_image_price(p, usage, self.default_image_cost))
                } else {
                    Err(CostCalculatorError::CalculationError(
                        "Image model pricing are not set".to_string(),
                    ))
                }
            }
            vllora_llm::types::gateway::Usage::CompletionModelUsage(usage) => {
                let (input_price, cached_input_price, cached_input_write_price, output_price) =
                    match price {
                        ModelPrice::Completion(c) => (
                            c.per_input_token,
                            c.per_cached_input_token,
                            c.per_cached_input_write_token,
                            c.per_output_token,
                        ),
                        ModelPrice::Embedding(c) => (c.per_input_token, None, None, 0.0),
                        ModelPrice::ImageGeneration(_) => {
                            return Err(CostCalculatorError::CalculationError(
                                "Model pricing not supported".to_string(),
                            ))
                        }
                    };
                Ok(calculate_tokens_cost(
                    usage,
                    input_price,
                    cached_input_price,
                    cached_input_write_price,
                    output_price,
                ))
            }
        }
    }
}