litellm-rs 0.1.1

A high-performance AI Gateway written in Rust, providing OpenAI-compatible APIs with intelligent routing, load balancing, and enterprise features
//! DeepInfra provider implementation
//!
//! This module provides DeepInfra API integration for serverless inference.

use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};

/// DeepInfra provider implementation.
///
/// Talks to DeepInfra's OpenAI-compatible endpoints (`/v1/openai/...`) through the
/// HTTP client, base URL and API key held in [`BaseProvider`], and answers pricing
/// queries from an in-memory table built when the provider is constructed.
#[derive(Debug, Clone)]
pub struct DeepInfraProvider {
    /// Shared provider state (name, HTTP client, base URL, API key).
    base: BaseProvider,
    /// Per-model prices keyed by DeepInfra model id; filled by
    /// `initialize_pricing_cache` in `new` and read by `get_model_pricing`.
    pricing_cache: HashMap<String, ModelPricing>,
}

impl DeepInfraProvider {
    /// Builds a DeepInfra provider from `config`.
    ///
    /// If `config.base_url` is not set, the public endpoint
    /// `https://api.deepinfra.com` is used instead.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`BaseProvider::new`].
    pub async fn new(config: &ProviderConfig) -> Result<Self> {
        let shared = BaseProvider::new(config)?;

        let endpoint = match config.base_url.as_ref() {
            Some(url) => url.clone(),
            None => String::from("https://api.deepinfra.com"),
        };

        let provider = Self {
            base: BaseProvider {
                base_url: endpoint,
                ..shared
            },
            pricing_cache: Self::initialize_pricing_cache(),
        };

        info!(
            "DeepInfra provider '{}' initialized successfully",
            config.name
        );
        Ok(provider)
    }

    /// Builds the static per-model price table for the DeepInfra models this
    /// provider knows about, keyed by model id. All prices are in USD per
    /// 1k tokens.
    fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
        // (model id, input cost per 1k tokens, output cost per 1k tokens)
        let price_table = [
            // Llama models
            ("meta-llama/Llama-2-7b-chat-hf", 0.00013, 0.00013),
            ("meta-llama/Llama-2-13b-chat-hf", 0.00022, 0.00022),
            ("meta-llama/Llama-2-70b-chat-hf", 0.0007, 0.0009),
            // Mixtral models
            ("mistralai/Mixtral-8x7B-Instruct-v0.1", 0.00027, 0.00027),
            // Code models
            ("codellama/CodeLlama-34b-Instruct-hf", 0.0006, 0.0006),
            // Embedding models (output cost is zero)
            ("BAAI/bge-base-en-v1.5", 0.00005, 0.0),
            ("BAAI/bge-large-en-v1.5", 0.0001, 0.0),
        ];

        price_table
            .iter()
            .map(|&(id, input_rate, output_rate)| {
                let pricing = ModelPricing {
                    model: id.to_string(),
                    input_cost_per_1k: input_rate,
                    output_cost_per_1k: output_rate,
                    currency: "USD".to_string(),
                    updated_at: chrono::Utc::now(),
                };
                (id.to_string(), pricing)
            })
            .collect()
    }
}

#[async_trait]
impl Provider for DeepInfraProvider {
    fn name(&self) -> &str {
        &self.base.name
    }

    fn provider_type(&self) -> ProviderType {
        ProviderType::Custom("deepinfra".to_string())
    }

    async fn supports_model(&self, model: &str) -> bool {
        self.base.is_model_supported(model)
            || model.contains("llama")
            || model.contains("mixtral")
            || model.contains("codellama")
            || model.contains("bge")
    }

    async fn supports_images(&self) -> bool {
        false
    }

    async fn supports_embeddings(&self) -> bool {
        true // DeepInfra has embedding models
    }

    async fn supports_streaming(&self) -> bool {
        true
    }

    async fn list_models(&self) -> Result<Vec<Model>> {
        let known_models = vec![
            "meta-llama/Llama-2-7b-chat-hf",
            "meta-llama/Llama-2-13b-chat-hf",
            "meta-llama/Llama-2-70b-chat-hf",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "codellama/CodeLlama-34b-Instruct-hf",
            "BAAI/bge-base-en-v1.5",
            "BAAI/bge-large-en-v1.5",
        ];

        let models = known_models
            .into_iter()
            .map(|model| Model {
                id: model.to_string(),
                object: "model".to_string(),
                created: chrono::Utc::now().timestamp() as u64,
                owned_by: "deepinfra".to_string(),
            })
            .collect();

        Ok(models)
    }

    async fn health_check(&self) -> Result<()> {
        debug!("Performing DeepInfra health check");
        Ok(()) // DeepInfra doesn't have a specific health endpoint
    }

    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
        _context: RequestContext,
    ) -> Result<ChatCompletionResponse> {
        debug!("DeepInfra chat completion for model: {}", request.model);

        let mut body = json!({
            "model": request.model,
            "messages": request.messages
        });

        if let Some(max_tokens) = request.max_tokens {
            body["max_tokens"] = json!(max_tokens);
        }
        if let Some(temperature) = request.temperature {
            body["temperature"] = json!(temperature);
        }
        if let Some(top_p) = request.top_p {
            body["top_p"] = json!(top_p);
        }
        if let Some(stream) = request.stream {
            body["stream"] = json!(stream);
        }

        let url = format!("{}/v1/openai/chat/completions", self.base.base_url);

        let response = self
            .base
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.base.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| ProviderError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            return Err(match status.as_u16() {
                401 => ProviderError::Authentication(error_text),
                429 => ProviderError::RateLimit(error_text),
                404 => ProviderError::ModelNotFound(error_text),
                400 => ProviderError::InvalidRequest(error_text),
                _ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
            }
            .into());
        }

        let response_json: ChatCompletionResponse = self.base.parse_json_response(response).await?;
        Ok(response_json)
    }

    async fn completion(
        &self,
        request: CompletionRequest,
        _context: RequestContext,
    ) -> Result<CompletionResponse> {
        debug!("DeepInfra completion for model: {}", request.model);

        let mut body = json!({
            "model": request.model,
            "prompt": request.prompt
        });

        if let Some(max_tokens) = request.max_tokens {
            body["max_tokens"] = json!(max_tokens);
        }
        if let Some(temperature) = request.temperature {
            body["temperature"] = json!(temperature);
        }
        if let Some(top_p) = request.top_p {
            body["top_p"] = json!(top_p);
        }
        if let Some(stream) = request.stream {
            body["stream"] = json!(stream);
        }

        let url = format!("{}/v1/openai/completions", self.base.base_url);

        let response = self
            .base
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.base.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| ProviderError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            return Err(match status.as_u16() {
                401 => ProviderError::Authentication(error_text),
                429 => ProviderError::RateLimit(error_text),
                404 => ProviderError::ModelNotFound(error_text),
                400 => ProviderError::InvalidRequest(error_text),
                _ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
            }
            .into());
        }

        let response_json: CompletionResponse = self.base.parse_json_response(response).await?;
        Ok(response_json)
    }

    async fn embedding(
        &self,
        request: EmbeddingRequest,
        _context: RequestContext,
    ) -> Result<EmbeddingResponse> {
        debug!("DeepInfra embedding for model: {}", request.model);

        let body = json!({
            "model": request.model,
            "input": request.input
        });

        let url = format!("{}/v1/openai/embeddings", self.base.base_url);

        let response = self
            .base
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.base.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| ProviderError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            return Err(match status.as_u16() {
                401 => ProviderError::Authentication(error_text),
                429 => ProviderError::RateLimit(error_text),
                404 => ProviderError::ModelNotFound(error_text),
                400 => ProviderError::InvalidRequest(error_text),
                _ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
            }
            .into());
        }

        let response_json: EmbeddingResponse = self.base.parse_json_response(response).await?;
        Ok(response_json)
    }

    async fn image_generation(
        &self,
        _request: ImageGenerationRequest,
        _context: RequestContext,
    ) -> Result<ImageGenerationResponse> {
        Err(ProviderError::InvalidRequest(
            "Image generation not supported by DeepInfra".to_string(),
        )
        .into())
    }

    async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
        if let Some(pricing) = self.pricing_cache.get(model) {
            Ok(pricing.clone())
        } else {
            Ok(ModelPricing {
                model: model.to_string(),
                input_cost_per_1k: 0.0003,
                output_cost_per_1k: 0.0003,
                currency: "USD".to_string(),
                updated_at: chrono::Utc::now(),
            })
        }
    }

    async fn calculate_cost(
        &self,
        model: &str,
        input_tokens: u32,
        output_tokens: u32,
    ) -> Result<f64> {
        let pricing = self.get_model_pricing(model).await?;

        let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
        let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;

        Ok(input_cost + output_cost)
    }
}