llmix-rs 2.0.6

Rust binding for the LLMix orchestration contract with cache, resilience, and config parity
Documentation
use crate::dispatch::DispatchFn;
use crate::error::{InvalidConfigError, LlmixResult};
use crate::providers::openai::OpenAiChatHelper;
use crate::types::{DispatchContext, ProviderResult};
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde_json::{Map, Value};

/// Name of the header under which the internal token is inserted into the dispatch headers.
/// Kept lowercase because it is passed to `HeaderName::from_static`.
const INTERNAL_TOKEN_HEADER: &str = "x-internal-token";

/// Chat helper that wraps an [`OpenAiChatHelper`] and customizes each dispatched request:
/// it can attach an `x-internal-token` header and moves the `enable_thinking` flag into
/// `extra_body`.
#[derive(Clone)]
pub struct SnoGpuChatHelper {
    /// Wrapped helper that every dispatch is delegated to.
    inner: OpenAiChatHelper,
    /// Optional token sent as the `x-internal-token` header. `None`, or a value that is
    /// empty after trimming, results in no header being added.
    internal_token: Option<String>,
}

/// `Default` produces the same value as [`SnoGpuChatHelper::new`].
impl Default for SnoGpuChatHelper {
    fn default() -> Self {
        Self::new()
    }
}

impl SnoGpuChatHelper {
    /// Creates a helper with no internal token, wrapping a fresh [`OpenAiChatHelper`]
    /// configured with `without_base_url()`.
    pub fn new() -> Self {
        Self {
            inner: OpenAiChatHelper::new().without_base_url(),
            internal_token: None,
        }
    }

    /// Sets the base URL on the wrapped helper and returns the updated builder.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the wrapped helper's `with_base_url` returns.
    pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> LlmixResult<Self> {
        // `self` is owned here, so the wrapped helper is moved through the builder call
        // and written back; cloning it first would only create a throwaway copy.
        self.inner = self.inner.with_base_url(base_url)?;
        Ok(self)
    }

    /// Replaces the HTTP client used by the wrapped helper.
    pub fn with_client(mut self, client: reqwest::Client) -> Self {
        // Same reasoning as `with_base_url`: move the field instead of cloning it.
        self.inner = self.inner.with_client(client);
        self
    }

    /// Stores the token to send as the `x-internal-token` header on dispatch.
    ///
    /// The value is kept as given; trimming and the empty check happen at dispatch time.
    pub fn with_internal_token(mut self, internal_token: impl Into<String>) -> Self {
        self.internal_token = Some(internal_token.into());
        self
    }
}

#[async_trait]
impl DispatchFn for SnoGpuChatHelper {
    /// Dispatches `ctx` through the wrapped helper's `dispatch_with`, using the callback to
    /// customize the request: the internal token header is inserted first, then the
    /// `enable_thinking` flag is moved into `extra_body`.
    ///
    /// An error from building the token header is returned from the callback via `?`.
    async fn dispatch(&self, ctx: DispatchContext) -> LlmixResult<ProviderResult> {
        // The closure below is `move`, so it gets its own copy of the token instead of
        // borrowing from `self`.
        let internal_token = self.internal_token.clone();
        self.inner
            .dispatch_with(ctx, move |ctx, headers, body| {
                insert_internal_token_header(headers, internal_token.as_deref())?;
                inject_sno_gpu_extra_body(ctx, body);
                Ok(())
            })
            .await
    }
}

/// Adds the `x-internal-token` header to `headers` when a usable token is configured.
///
/// The token is trimmed first. A missing, empty, or whitespace-only token leaves `headers`
/// untouched. Otherwise the header is inserted, replacing any existing value stored under
/// the same name, and is flagged as sensitive.
///
/// # Errors
///
/// Returns an error built from [`InvalidConfigError`] when the trimmed token cannot be
/// represented as an HTTP header value.
fn insert_internal_token_header(
    headers: &mut HeaderMap,
    internal_token: Option<&str>,
) -> LlmixResult<()> {
    let Some(internal_token) = internal_token
        .map(str::trim)
        .filter(|value| !value.is_empty())
    else {
        return Ok(());
    };

    let mut header_value =
        HeaderValue::from_str(internal_token).map_err(|error| InvalidConfigError {
            message: format!("invalid x-internal-token header: {error}"),
        })?;
    // The token is a credential: flag the value as sensitive so it is not printed when the
    // header map is formatted with `Debug` (e.g. in request logging).
    header_value.set_sensitive(true);

    headers.insert(HeaderName::from_static(INTERNAL_TOKEN_HEADER), header_value);
    Ok(())
}

/// Moves the `enable_thinking` flag out of the top level of `body` and into `extra_body`.
///
/// The flag is resolved from, in order of precedence: the request `body`, the `common`
/// section of `ctx.config`, then the `sno-gpu` provider options of `ctx.config`. Both
/// spellings (`enable_thinking` / `enableThinking`) are always stripped from the top level
/// of `body`, even when no value could be resolved.
///
/// When a value is resolved it is written to `extra_body.enable_thinking` and to
/// `extra_body.chat_template_kwargs.enable_thinking`, reusing existing `extra_body` /
/// `chat_template_kwargs` objects (under either spelling) when they are present.
fn inject_sno_gpu_extra_body(ctx: &DispatchContext, body: &mut Map<String, Value>) {
    const FLAG_ALIASES: [&str; 2] = ["enable_thinking", "enableThinking"];

    // A value in the request body takes precedence over anything in the config.
    let from_body = FLAG_ALIASES
        .iter()
        .find_map(|alias| body.get(*alias))
        .and_then(Value::as_bool);
    let resolved = from_body
        .or_else(|| common_bool(&ctx.config, &FLAG_ALIASES))
        .or_else(|| provider_option_bool(&ctx.config, &FLAG_ALIASES));

    // The flag never stays at the top level, whether or not it resolved to a value.
    for alias in FLAG_ALIASES {
        body.remove(alias);
    }

    let flag = match resolved {
        Some(flag) => flag,
        None => return,
    };

    // The map operations below are kept in this exact sequence so the resulting key order
    // does not change if the map preserves insertion order.
    let mut extra = remove_object_alias(body, &["extra_body", "extraBody"]);
    extra.insert(String::from("enable_thinking"), Value::Bool(flag));

    let mut template_kwargs =
        remove_object_alias(&mut extra, &["chat_template_kwargs", "chatTemplateKwargs"]);
    template_kwargs.insert(String::from("enable_thinking"), Value::Bool(flag));

    extra.insert(
        String::from("chat_template_kwargs"),
        Value::Object(template_kwargs),
    );
    body.insert(String::from("extra_body"), Value::Object(extra));
}

/// Removes every key in `keys` from `map` and returns the first removed value that is a
/// JSON object.
///
/// `keys` lists alternative spellings of the same logical field, in priority order. All of
/// them are removed so that no stale spelling is left behind next to the canonical key the
/// caller writes back afterwards. Non-object values, and objects found under lower-priority
/// spellings, are discarded. Returns an empty map when no spelling held an object.
fn remove_object_alias(map: &mut Map<String, Value>, keys: &[&str]) -> Map<String, Value> {
    let mut first_object = None;
    for key in keys {
        // Always remove, even after a match, so later aliases do not survive in `map`.
        if let Some(Value::Object(object)) = map.remove(*key) {
            if first_object.is_none() {
                first_object = Some(object);
            }
        }
    }
    first_object.unwrap_or_else(Map::new)
}

/// Looks up a boolean under any of `keys` inside the `common` object of `config`.
///
/// Returns `None` when `config` is not an object, when `common` is missing or not an
/// object, or when `get_bool_alias` finds no boolean for `keys`.
fn common_bool(config: &Value, keys: &[&str]) -> Option<bool> {
    let root = config.as_object()?;
    let common = root.get("common")?.as_object()?;
    get_bool_alias(common, keys)
}

/// Looks up a boolean under any of `keys` inside the `sno-gpu` entry of the provider
/// options in `config`.
///
/// The provider options are read from `provider_options`, falling back to
/// `providerOptions` only when the former key is absent. Returns `None` when any level on
/// the path is missing or is not a JSON object, or when `get_bool_alias` finds no boolean.
fn provider_option_bool(config: &Value, keys: &[&str]) -> Option<bool> {
    let root = config.as_object()?;
    let provider_options = root
        .get("provider_options")
        .or_else(|| root.get("providerOptions"))?
        .as_object()?;
    let sno_gpu = provider_options.get("sno-gpu")?.as_object()?;
    get_bool_alias(sno_gpu, keys)
}

/// Returns the boolean stored under the first key in `keys` whose value is a JSON boolean.
///
/// `keys` lists alternative spellings of the same setting, in priority order. A spelling
/// that is present but holds a non-boolean value (for example `null`) is skipped, so it
/// cannot hide a valid boolean stored under a later spelling. Returns `None` when no
/// spelling holds a boolean.
fn get_bool_alias(map: &Map<String, Value>, keys: &[&str]) -> Option<bool> {
    keys.iter()
        .find_map(|key| map.get(*key).and_then(Value::as_bool))
}