// llm-cascade 0.1.0
// Resilient cascading LLM inference with automatic failover across multiple providers.
// (Crate metadata header; full documentation lives in the crate-level docs.)
//! OpenAI Chat Completions API provider (and compatible endpoints).

use reqwest::Client;
use serde_json::{json, Value};

use crate::error::ProviderError;
use crate::models::{ContentBlock, Conversation, LlmResponse, MessageRole};
use crate::providers::LlmProvider;

/// API root used when the caller supplies no custom `base_url` (see `OpenAiProvider::new`).
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

/// Provider for the OpenAI Chat Completions API and any compatible endpoint (Groq, Together, vLLM, etc.).
pub struct OpenAiProvider {
    // HTTP client reused for every request issued by this provider.
    client: Client,
    // Sent as a `Bearer` token in the `Authorization` header of each request.
    api_key: String,
    // Model id placed in the request body (e.g. "gpt-4o").
    model: String,
    // API root without the `/chat/completions` suffix; a trailing slash is tolerated
    // (trimmed when the request URL is built).
    base_url: String,
}

impl OpenAiProvider {
    /// Creates a new OpenAI-compatible provider.
    ///
    /// `base_url` overrides the endpoint root; passing `None` selects the
    /// official OpenAI API (`https://api.openai.com/v1`).
    pub fn new(api_key: String, model: String, base_url: Option<String>) -> Self {
        let base_url = match base_url {
            Some(url) => url,
            None => DEFAULT_BASE_URL.to_string(),
        };

        Self {
            client: Client::new(),
            api_key,
            model,
            base_url,
        }
    }
}

#[async_trait::async_trait]
impl LlmProvider for OpenAiProvider {
    /// Sends the conversation to `{base_url}/chat/completions` and parses the
    /// first choice into an [`LlmResponse`].
    ///
    /// # Errors
    ///
    /// * [`ProviderError::Http`] on a non-2xx status, carrying the status code,
    ///   response body, and any integer `Retry-After` header value.
    /// * [`ProviderError::Parse`] when the body is not valid JSON or contains
    ///   no choices.
    async fn complete(&self, conversation: &Conversation) -> Result<LlmResponse, ProviderError> {
        // Translate the internal role enum into the wire-format role strings.
        let mut messages = Vec::with_capacity(conversation.messages.len());
        for msg in &conversation.messages {
            let role = match msg.role {
                MessageRole::System => "system",
                MessageRole::User => "user",
                MessageRole::Assistant => "assistant",
                MessageRole::Tool => "tool",
            };

            let mut message = json!({
                "role": role,
                "content": msg.content,
            });

            // Tool-result messages must echo the id of the call they answer.
            if let Some(ref tool_call_id) = msg.tool_call_id {
                message["tool_call_id"] = json!(tool_call_id);
            }

            messages.push(message);
        }

        let mut body = json!({
            "model": self.model,
            "messages": messages,
        });

        // Tools are advertised using OpenAI's `{"type": "function", ...}` wrapper.
        if let Some(ref tools) = conversation.tools {
            let openai_tools: Vec<Value> = tools
                .iter()
                .map(|t| {
                    json!({
                        "type": "function",
                        "function": {
                            "name": t.name,
                            "description": t.description,
                            "parameters": t.parameters,
                        }
                    })
                })
                .collect();
            body["tools"] = json!(openai_tools);
        }

        // Trim any trailing slash so a custom `base_url` doesn't yield `//chat/...`.
        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await?;

        let status = response.status();
        // Read Retry-After before the response body is consumed below. Only the
        // integer-seconds form is honored; the HTTP-date form parses as None.
        let retry_after = response
            .headers()
            .get("retry-after")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<u64>().ok());

        if !status.is_success() {
            let error_body = response.text().await.unwrap_or_default();
            return Err(ProviderError::Http {
                status: status.as_u16(),
                body: error_body,
                retry_after,
            });
        }

        let data: Value = response
            .json()
            .await
            .map_err(|e| ProviderError::Parse(e.to_string()))?;

        // A 2xx body with a missing or empty `choices` array is malformed:
        // surface it as a parse error rather than silently returning an empty
        // response (the previous `data["choices"][0]` indexed into Null).
        // Borrowing the choice also avoids cloning the whole subtree.
        let choice = data["choices"]
            .get(0)
            .ok_or_else(|| ProviderError::Parse("response contained no choices".to_string()))?;
        let message = &choice["message"];
        let model = data["model"].as_str().unwrap_or(&self.model).to_string();

        let mut content_blocks = Vec::new();

        // `content` may be null or empty when the model responds only with tool calls.
        if let Some(text) = message["content"].as_str()
            && !text.is_empty()
        {
            content_blocks.push(ContentBlock::Text { text: text.to_string() });
        }

        if let Some(tool_calls) = message["tool_calls"].as_array() {
            for tc in tool_calls {
                let id = tc["id"].as_str().unwrap_or("").to_string();
                let function = &tc["function"];
                let name = function["name"].as_str().unwrap_or("").to_string();
                // OpenAI encodes tool-call arguments as a JSON *string*, not an object.
                let arguments = function["arguments"].as_str().unwrap_or("{}").to_string();
                content_blocks.push(ContentBlock::ToolCall { id, name, arguments });
            }
        }

        // Token counts are optional: compatible endpoints may omit `usage` entirely.
        let usage = &data["usage"];
        let input_tokens = usage["prompt_tokens"].as_u64().map(|v| v as u32);
        let output_tokens = usage["completion_tokens"].as_u64().map(|v| v as u32);

        Ok(LlmResponse {
            content: content_blocks,
            input_tokens,
            output_tokens,
            model,
        })
    }

    /// Stable identifier for this provider, used in cascade configuration and logs.
    fn provider_name(&self) -> &str {
        "openai"
    }

    /// The model id this provider was configured with.
    fn model_name(&self) -> &str {
        &self.model
    }
}