llmkit-openai 0.1.0

OpenAI (GPT-4o, o1) provider adapter for llmkit-rs
Documentation
//! [`OpenAiProvider`] — implements [`LlmProvider`] against the OpenAI REST API.

use std::time::{Duration, Instant};

use async_trait::async_trait;
use llmkit_core::{
    pricing, ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse,
    LlmError, LlmProvider, LlmResult,
};

use crate::types::{ApiError, ChatCompletionResponse, EmbeddingsResponse};
use crate::{chat, embed, stream};

/// Default API root; endpoint paths such as `/chat/completions` and `/embeddings` are appended.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Chat model installed by [`OpenAiProvider::new`]; a request's own `model` takes precedence.
const DEFAULT_MODEL: &str = "gpt-4o-mini";
/// Embedding model installed by [`OpenAiProvider::new`]; an [`EmbedRequest`]'s `model` wins.
const DEFAULT_EMBED_MODEL: &str = "text-embedding-3-small";

/// OpenAI provider (GPT-4o, o1, embeddings).
///
/// Cheap to clone: the inner [`reqwest::Client`] shares its connection pool between clones.
#[derive(Clone)]
pub struct OpenAiProvider {
    /// HTTP client used for every request.
    http: reqwest::Client,
    /// Secret sent as a bearer token; never printed by the `Debug` impl.
    api_key: String,
    /// API root that endpoint paths are appended to.
    base_url: String,
    /// Chat model used when a request does not name one.
    model: String,
    /// Embedding model used when a request does not name one.
    embed_model: String,
}

// Hand-written rather than derived so the API key cannot leak into logs or panic messages.
impl std::fmt::Debug for OpenAiProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiProvider")
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("embed_model", &self.embed_model)
            .finish_non_exhaustive()
    }
}

impl OpenAiProvider {
    /// Construct with an explicit API key.
    ///
    /// Starts from a fresh [`reqwest::Client`], the public OpenAI base URL, and the default chat
    /// and embedding models; each can be overridden with the builder methods below.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            http: reqwest::Client::new(),
            api_key: api_key.into(),
            base_url: DEFAULT_BASE_URL.to_string(),
            model: DEFAULT_MODEL.to_string(),
            embed_model: DEFAULT_EMBED_MODEL.to_string(),
        }
    }

    /// Construct from the `OPENAI_API_KEY` environment variable.
    ///
    /// Surrounding whitespace (e.g. a trailing newline from `$(cat key.txt)`) is stripped, since
    /// it can never be part of a valid key.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::Auth`] when the variable is unset, is not valid Unicode, or is blank.
    pub fn from_env() -> LlmResult<Self> {
        match std::env::var("OPENAI_API_KEY") {
            Ok(raw) => {
                let key = raw.trim();
                if key.is_empty() {
                    Err(LlmError::Auth("OPENAI_API_KEY is empty".into()))
                } else {
                    Ok(Self::new(key))
                }
            }
            Err(std::env::VarError::NotPresent) => {
                Err(LlmError::Auth("OPENAI_API_KEY not set".into()))
            }
            Err(std::env::VarError::NotUnicode(_)) => {
                Err(LlmError::Auth("OPENAI_API_KEY is not valid Unicode".into()))
            }
        }
    }

    /// Set the default chat model.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Set the default embedding model.
    pub fn embed_model(mut self, model: impl Into<String>) -> Self {
        self.embed_model = model.into();
        self
    }

    /// Override the base URL (e.g. for an OpenAI-compatible gateway).
    ///
    /// Trailing slashes are removed: endpoint paths are joined as `{base_url}/path`, so
    /// `https://gw.example/v1/` would otherwise produce `https://gw.example/v1//path`.
    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        let mut base_url = base_url.into();
        let trimmed_len = base_url.trim_end_matches('/').len();
        base_url.truncate(trimmed_len);
        self.base_url = base_url;
        self
    }

    /// Provide a custom [`reqwest::Client`] (connection pooling, timeouts, …).
    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        self.http = client;
        self
    }

    /// Model for `req`: its explicit `model` if set, otherwise this provider's default.
    fn resolved_model(&self, req: &ChatRequest) -> String {
        req.model.clone().unwrap_or_else(|| self.model.clone())
    }
}

#[async_trait]
impl LlmProvider for OpenAiProvider {
    /// Non-streaming chat completion via `POST /chat/completions`.
    ///
    /// Latency covers sending the request through parsing the response body. Cost is filled in
    /// from the pricing table for the model name the API reports, when that model is known.
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        let model = self.resolved_model(&req);
        let body = chat::build_request(&req, model, false);

        let start = Instant::now();
        let resp = send_checked(self.authed_post("chat/completions").json(&body)).await?;
        let parsed: ChatCompletionResponse = resp.json().await.map_err(map_reqwest_err)?;
        // Saturate instead of truncating the u128 millisecond count.
        let latency_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
        let mut out = chat::map_response(parsed, latency_ms)?;
        out.cost = pricing::pricing_for(&out.model).map(|p| p.cost_for(out.usage));
        Ok(out)
    }

    /// Streaming chat completion; the checked response is handed to [`stream::parse`].
    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        let model = self.resolved_model(&req);
        let body = chat::build_request(&req, model, true);

        let resp = send_checked(self.authed_post("chat/completions").json(&body)).await?;
        Ok(stream::parse(resp))
    }

    /// Embeddings via `POST /embeddings`, defaulting to the provider's embedding model.
    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
        let model = req.model.clone().unwrap_or_else(|| self.embed_model.clone());
        let body = embed::build_request(req.input, model);

        let resp = send_checked(self.authed_post("embeddings").json(&body)).await?;
        let parsed: EmbeddingsResponse = resp.json().await.map_err(map_reqwest_err)?;
        Ok(embed::map_response(parsed))
    }

    fn name(&self) -> &'static str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }

    /// Rough pre-flight cost estimate, or `None` when the model has no pricing entry.
    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
        let model = self.resolved_model(req);
        let pricing = pricing::pricing_for(&model)?;
        // Heuristic: ~4 bytes of UTF-8 text per prompt token (only text content is counted),
        // and assume the response fills max_tokens (or a small default).
        let prompt_bytes: usize = req
            .messages
            .iter()
            .filter_map(|m| m.content.as_text())
            .map(|t| t.len())
            .sum();
        let prompt_tokens = u32::try_from(prompt_bytes / 4).unwrap_or(u32::MAX);
        let completion_tokens = req.max_tokens.unwrap_or(256);
        Some(pricing.cost_for(llmkit_core::TokenUsage::new(prompt_tokens, completion_tokens)))
    }
}

impl OpenAiProvider {
    /// Start an authenticated `POST` to `{base_url}/{path}`.
    fn authed_post(&self, path: &str) -> reqwest::RequestBuilder {
        // Tolerate a base URL configured with trailing slashes so we never emit `//path`.
        let base = self.base_url.trim_end_matches('/');
        self.http
            .post(format!("{base}/{path}"))
            .bearer_auth(&self.api_key)
    }
}

/// Send a prepared request, mapping transport failures and non-2xx statuses to [`LlmError`].
async fn send_checked(request: reqwest::RequestBuilder) -> LlmResult<reqwest::Response> {
    let resp = request.send().await.map_err(map_reqwest_err)?;
    check_status(resp).await
}

/// Translate a [`reqwest::Error`] into the crate-level [`LlmError`].
///
/// Timeouts become [`LlmError::Timeout`], body-decoding failures become
/// [`LlmError::Serialization`], and every other failure is reported as
/// [`LlmError::Transport`] with the error's display text.
pub(crate) fn map_reqwest_err(e: reqwest::Error) -> LlmError {
    match e {
        err if err.is_timeout() => LlmError::Timeout,
        err if err.is_decode() => LlmError::Serialization(err.to_string()),
        err => LlmError::Transport(err.to_string()),
    }
}

/// Inspect the HTTP status and convert non-2xx responses into [`LlmError`].
///
/// Successful responses are returned untouched. For failures, the body is read on a best-effort
/// basis. If it parses as an [`ApiError`] envelope, its `error.message` is used; otherwise the
/// raw body is used. If that leaves a blank message, `HTTP <status>` is substituted so the error
/// is never empty.
///
/// # Errors
///
/// - 401 / 403 → [`LlmError::Auth`]
/// - 429 → [`LlmError::RateLimited`], with `retry_after` taken from an integer-seconds
///   `Retry-After` header when present
/// - 400 / 404 / 422 → [`LlmError::InvalidRequest`]
/// - any other status → [`LlmError::Provider`] carrying the numeric code
pub(crate) async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
    let status = resp.status();
    if status.is_success() {
        return Ok(resp);
    }

    // Only the delta-seconds form of Retry-After is understood; HTTP-date values are ignored.
    let retry_after = resp
        .headers()
        .get(reqwest::header::RETRY_AFTER)
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.trim().parse::<u64>().ok())
        .map(Duration::from_secs);

    // Best effort: a failure to read the body must not hide the status-code error itself.
    let body = resp.text().await.unwrap_or_default();
    let message = serde_json::from_str::<ApiError>(&body)
        .map(|e| e.error.message)
        .unwrap_or(body);
    // Gateways and proxies often return empty bodies; fall back to e.g. "HTTP 502 Bad Gateway".
    let message = if message.trim().is_empty() {
        format!("HTTP {status}")
    } else {
        message
    };

    Err(match status.as_u16() {
        401 | 403 => LlmError::Auth(message),
        429 => LlmError::RateLimited { retry_after, message },
        400 | 404 | 422 => LlmError::InvalidRequest(message),
        code => LlmError::Provider { status: code, message },
    })
}