llmix-rs 2.0.6

Rust binding for the LLMix orchestration contract with cache, resilience, and config parity
Documentation
use crate::dispatch::DispatchFn;
use crate::error::{InvalidConfigError, LlmixError, LlmixResult, ProviderError};
use crate::types::{DispatchContext, LlmUsage, ProviderResult};
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE};
use serde_json::{Map, Value};
use std::collections::HashMap;
use url::Url;

/// Base URL that `OpenAiChatHelper::new` installs when no other base URL is configured.
const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";

/// Helper that posts chat-completion requests to an OpenAI-compatible HTTP endpoint.
#[derive(Clone)]
pub struct OpenAiChatHelper {
    /// HTTP client used to send every request.
    client: reqwest::Client,
    /// Base URL the `/chat/completions` path is appended to. When `None`, the base URL
    /// must be supplied per request through the `base_url` / `baseUrl` kwargs.
    base_url: Option<String>,
    /// Headers copied into every outgoing request before the per-request headers are set.
    default_headers: HeaderMap,
}

impl Default for OpenAiChatHelper {
    fn default() -> Self {
        Self::new()
    }
}

impl OpenAiChatHelper {
    pub fn new() -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: Some(DEFAULT_OPENAI_BASE_URL.to_string()),
            default_headers: HeaderMap::new(),
        }
    }

    pub fn without_base_url(mut self) -> Self {
        self.base_url = None;
        self
    }

    pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> LlmixResult<Self> {
        self.base_url = Some(normalize_base_url(base_url.as_ref())?);
        Ok(self)
    }

    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        self.client = client;
        self
    }

    pub fn with_default_header(
        mut self,
        name: impl AsRef<str>,
        value: impl AsRef<str>,
    ) -> LlmixResult<Self> {
        let header_name = HeaderName::from_bytes(name.as_ref().as_bytes())
            .map_err(|error| invalid_config(format!("invalid header name: {error}")))?;
        let header_value = HeaderValue::from_str(value.as_ref())
            .map_err(|error| invalid_config(format!("invalid header value: {error}")))?;
        self.default_headers.insert(header_name, header_value);
        Ok(self)
    }

    pub(crate) async fn dispatch_with<F>(
        &self,
        ctx: DispatchContext,
        customize: F,
    ) -> LlmixResult<ProviderResult>
    where
        F: FnOnce(&DispatchContext, &mut HeaderMap, &mut Map<String, Value>) -> LlmixResult<()>,
    {
        let mut headers = self.request_headers(&ctx)?;
        let mut body = build_request_body(&ctx);
        customize(&ctx, &mut headers, &mut body)?;

        let base_url = self.resolve_base_url(&ctx.kwargs)?;
        let endpoint = format!("{}/chat/completions", base_url.trim_end_matches('/'));
        let response = self
            .client
            .post(endpoint)
            .headers(headers)
            .json(&Value::Object(body))
            .send()
            .await
            .map_err(|error| {
                provider_transport_error(format!("provider request failed: {error}"))
            })?;

        let status = response.status();
        let headers = collect_headers(response.headers());
        let body = response.bytes().await.map_err(|error| {
            provider_transport_error(format!("failed reading provider response: {error}"))
        })?;

        if !status.is_success() {
            return Err(parse_provider_error(status.as_u16(), headers, &body));
        }

        parse_provider_result(&ctx, &headers, &body)
    }

    fn request_headers(&self, ctx: &DispatchContext) -> LlmixResult<HeaderMap> {
        let mut headers = self.default_headers.clone();
        headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

        if !ctx.api_key.trim().is_empty() {
            let bearer = format!("Bearer {}", ctx.api_key.trim());
            let value = HeaderValue::from_str(&bearer)
                .map_err(|error| invalid_config(format!("invalid api key header: {error}")))?;
            headers.insert(AUTHORIZATION, value);
        }

        Ok(headers)
    }

    fn resolve_base_url(&self, body: &Map<String, Value>) -> LlmixResult<String> {
        let override_base_url = body
            .get("base_url")
            .or_else(|| body.get("baseUrl"))
            .and_then(Value::as_str)
            .map(str::trim)
            .filter(|value| !value.is_empty());

        match override_base_url.or(self.base_url.as_deref()) {
            Some(base_url) => normalize_base_url(base_url),
            None => Err(invalid_config(
                "OpenAI-compatible helper requires a non-empty base_url".to_string(),
            )),
        }
    }
}

#[async_trait]
impl DispatchFn for OpenAiChatHelper {
    /// Dispatches `ctx` with the headers and body exactly as the helper builds them,
    /// i.e. with a customization hook that changes nothing.
    async fn dispatch(&self, ctx: DispatchContext) -> LlmixResult<ProviderResult> {
        self.dispatch_with(ctx, |_, _, _| Ok(())).await
    }
}

/// Builds the JSON request body from the dispatch context.
///
/// The body starts as a copy of `ctx.kwargs`, loses the helper-only keys
/// (`base_url`, `baseUrl`, `top_k`, `topK`), and then has `model` and `messages`
/// set from the context, overwriting any kwargs entries with those names.
fn build_request_body(ctx: &DispatchContext) -> Map<String, Value> {
    let mut payload = ctx.kwargs.clone();

    for stripped_key in ["base_url", "baseUrl", "top_k", "topK"] {
        payload.remove(stripped_key);
    }

    let model = Value::String(ctx.model.clone());
    let messages = Value::Array(ctx.messages.clone());
    payload.insert(String::from("model"), model);
    payload.insert(String::from("messages"), messages);

    payload
}

fn parse_provider_result(
    ctx: &DispatchContext,
    headers: &HashMap<String, String>,
    body: &[u8],
) -> LlmixResult<ProviderResult> {
    let payload: Value = serde_json::from_slice(body).map_err(|error| ProviderError {
        message: format!("invalid provider response: {error}"),
        status_code: None,
        headers: Some(headers.clone()),
    })?;

    let content = payload
        .pointer("/choices/0/message/content")
        .map(extract_content)
        .unwrap_or_default();
    let model = payload
        .get("model")
        .and_then(Value::as_str)
        .unwrap_or(&ctx.model)
        .to_string();
    let tool_calls = payload
        .pointer("/choices/0/message/tool_calls")
        .or_else(|| payload.pointer("/choices/0/message/toolCalls"))
        .and_then(Value::as_array)
        .cloned();

    Ok(ProviderResult {
        content,
        model,
        usage: extract_usage(&payload),
        headers: (!headers.is_empty()).then(|| headers.clone()),
        tool_calls,
    })
}

fn parse_provider_error(
    status_code: u16,
    headers: HashMap<String, String>,
    body: &[u8],
) -> LlmixError {
    let message = serde_json::from_slice::<Value>(body)
        .ok()
        .and_then(|payload| {
            payload
                .pointer("/error/message")
                .and_then(Value::as_str)
                .or_else(|| payload.get("message").and_then(Value::as_str))
                .map(ToOwned::to_owned)
        })
        .unwrap_or_else(|| {
            let text = String::from_utf8_lossy(body).trim().to_string();
            if text.is_empty() {
                format!("provider request failed with status {status_code}")
            } else {
                text
            }
        });

    ProviderError {
        message,
        status_code: Some(status_code),
        headers: Some(headers),
    }
    .into()
}

/// Reads token counts from the payload's `usage` object.
///
/// Missing or unusable prompt/completion counts become `0`. A missing or unusable
/// total is computed as the saturating sum of the other two.
fn extract_usage(payload: &Value) -> LlmUsage {
    let counter = |path: &str| payload.pointer(path).and_then(value_as_u32);

    let input_tokens = counter("/usage/prompt_tokens").unwrap_or(0);
    let output_tokens = counter("/usage/completion_tokens").unwrap_or(0);
    let total_tokens = match counter("/usage/total_tokens") {
        Some(reported) => reported,
        None => input_tokens.saturating_add(output_tokens),
    };

    LlmUsage {
        input_tokens,
        output_tokens,
        total_tokens,
    }
}

/// Flattens a message `content` value into plain text.
///
/// A string is returned as-is; an array has the text of each part (see
/// `extract_text_part`) concatenated with no separator; anything else yields an
/// empty string.
fn extract_content(value: &Value) -> String {
    if let Some(text) = value.as_str() {
        return text.to_owned();
    }

    match value.as_array() {
        Some(parts) => parts.iter().filter_map(extract_text_part).collect(),
        None => String::new(),
    }
}

/// Returns the text carried by one content part, if any.
///
/// A string part is its own text. An object part contributes its `text` field
/// when that is a string, otherwise its `content` field when that is a string.
/// Every other shape yields `None`.
fn extract_text_part(value: &Value) -> Option<&str> {
    match value {
        Value::String(text) => Some(text.as_str()),
        Value::Object(fields) => ["text", "content"]
            .iter()
            .find_map(|key| fields.get(*key).and_then(Value::as_str)),
        _ => None,
    }
}

/// Copies response headers into a plain string map.
///
/// Headers whose value `HeaderValue::to_str` rejects are skipped. When a name
/// occurs more than once, the value yielded last by the header map is kept.
fn collect_headers(headers: &reqwest::header::HeaderMap) -> HashMap<String, String> {
    let mut collected = HashMap::new();

    for (name, value) in headers {
        if let Ok(text) = value.to_str() {
            collected.insert(name.as_str().to_owned(), text.to_owned());
        }
    }

    collected
}

fn normalize_base_url(base_url: &str) -> LlmixResult<String> {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return Err(invalid_config("base_url must not be empty".to_string()));
    }

    Url::parse(trimmed).map_err(|error| invalid_config(format!("invalid base_url: {error}")))?;
    Ok(trimmed.trim_end_matches('/').to_string())
}

/// Converts a JSON number to `u32` when it is a non-negative integer that fits.
///
/// The unsigned reading is tried first, then the signed reading; anything
/// negative, fractional, non-numeric, or above `u32::MAX` yields `None`.
fn value_as_u32(value: &Value) -> Option<u32> {
    let from_unsigned = value
        .as_u64()
        .and_then(|unsigned| u32::try_from(unsigned).ok());
    if let Some(narrowed) = from_unsigned {
        return Some(narrowed);
    }

    match value.as_i64() {
        Some(signed) if signed >= 0 => u32::try_from(signed).ok(),
        _ => None,
    }
}

fn invalid_config(message: String) -> LlmixError {
    InvalidConfigError { message }.into()
}

fn provider_transport_error(message: String) -> LlmixError {
    ProviderError {
        message,
        status_code: None,
        headers: None,
    }
    .into()
}