asla 0.1.5

An absurdly simple LLM API client for Rust
Documentation
use crate::error::Result;
use bon::{Builder, bon, builder};
use reqwest::Client;
use serde::Serialize;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::time::Duration;
use tap::Pipe;

/// Author of a [`Message`] in a chat.
///
/// Serialized in lowercase (`"user"` / `"assistant"`). There is no system variant:
/// `Chat` keeps its system prompt in a separate field.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// A message written by the end user; serialized as `"user"`.
    User,
    /// A reply from the model; serialized as `"assistant"`.
    Assistant,
}

/// One turn of a conversation: who produced it and its text.
#[derive(Debug, Clone, Serialize)]
pub struct Message {
    /// Author of this turn.
    pub role: MessageRole,
    /// Text of this turn.
    pub content: String,
}

impl Message {
    pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
        Self {
            role,
            content: content.into(),
        }
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self::new(MessageRole::User, content)
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new(MessageRole::Assistant, content)
    }

    pub fn is_user(&self) -> bool {
        self.role == MessageRole::User
    }

    pub fn is_assistant(&self) -> bool {
        self.role == MessageRole::Assistant
    }
}

/// Value for the request's `response_format` parameter.
///
/// Serializes as `{"type": <format_type>}`, plus a `"schema"` key only when `schema` is `Some`.
#[derive(Debug, Clone, Serialize)]
pub struct ResponseFormat {
    /// Format tag, written under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub format_type: String,
    /// Optional schema value; left out of the serialized object entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub schema: Option<Value>,
}

impl ResponseFormat {
    pub fn json_object() -> Self {
        Self {
            format_type: "json_object".to_string(),
            schema: None,
        }
    }

    pub fn json_schema(schema: Value) -> Self {
        Self {
            format_type: "json_object".to_string(),
            schema: Some(schema),
        }
    }
}

/// Optional generation parameters for a single request.
///
/// Every field is an `Option`: `None` means "don't send this parameter", and only `Some`
/// values are written into the request body, keyed by the field's name. Two configs can be
/// layered with `merge`, where the overriding config's `Some` values win.
#[derive(Debug, Clone, Default, Builder)]
pub struct GenerationConfig {
    // ---- Output length and core sampling ----
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub min_p: Option<f32>,
    pub typical_p: Option<f32>,

    // ---- Penalties ----
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
    pub repeat_penalty: Option<f32>,
    pub repeat_last_n: Option<u32>,

    // ---- Advanced sampling: tfs_z, Mirostat settings, explicit sampler list ----
    pub tfs_z: Option<f32>,
    pub mirostat_mode: Option<u32>,
    pub mirostat_tau: Option<f32>,
    pub mirostat_eta: Option<f32>,
    pub samplers: Option<Vec<String>>,

    // ---- Request control: seed, stop strings, streaming, echo, log-probs, logit bias ----
    // NOTE(review): `LlmClient` parses the reply as a single JSON document, so
    // `stream: Some(true)` is presumably unsupported there — confirm before relying on it.
    pub seed: Option<i32>,
    pub stop: Option<Vec<String>>,
    pub stream: Option<bool>,
    pub echo: Option<bool>,
    pub logprobs: Option<u32>,
    pub top_logprobs: Option<u32>,
    pub logit_bias: Option<HashMap<String, f32>>,

    // ---- Output constraints ----
    pub response_format: Option<ResponseFormat>,
    pub grammar: Option<String>,

    // ---- Miscellaneous ----
    pub n: Option<u32>,
    pub penalize_nl: Option<bool>,
    pub ignore_eos: Option<bool>,
}

impl GenerationConfig {
    /// Merge two configs, with the override config taking precedence
    pub fn merge(&self, override_config: &GenerationConfig) -> Self {
        Self {
            max_tokens: override_config.max_tokens.or(self.max_tokens),
            temperature: override_config.temperature.or(self.temperature),
            top_p: override_config.top_p.or(self.top_p),
            top_k: override_config.top_k.or(self.top_k),
            min_p: override_config.min_p.or(self.min_p),
            typical_p: override_config.typical_p.or(self.typical_p),
            frequency_penalty: override_config.frequency_penalty.or(self.frequency_penalty),
            presence_penalty: override_config.presence_penalty.or(self.presence_penalty),
            repeat_penalty: override_config.repeat_penalty.or(self.repeat_penalty),
            repeat_last_n: override_config.repeat_last_n.or(self.repeat_last_n),
            tfs_z: override_config.tfs_z.or(self.tfs_z),
            mirostat_mode: override_config.mirostat_mode.or(self.mirostat_mode),
            mirostat_tau: override_config.mirostat_tau.or(self.mirostat_tau),
            mirostat_eta: override_config.mirostat_eta.or(self.mirostat_eta),
            samplers: override_config
                .samplers
                .clone()
                .or_else(|| self.samplers.clone()),
            seed: override_config.seed.or(self.seed),
            stop: override_config.stop.clone().or_else(|| self.stop.clone()),
            stream: override_config.stream.or(self.stream),
            echo: override_config.echo.or(self.echo),
            logprobs: override_config.logprobs.or(self.logprobs),
            top_logprobs: override_config.top_logprobs.or(self.top_logprobs),
            logit_bias: override_config
                .logit_bias
                .clone()
                .or_else(|| self.logit_bias.clone()),
            response_format: override_config
                .response_format
                .clone()
                .or_else(|| self.response_format.clone()),
            grammar: override_config
                .grammar
                .clone()
                .or_else(|| self.grammar.clone()),
            n: override_config.n.or(self.n),
            penalize_nl: override_config.penalize_nl.or(self.penalize_nl),
            ignore_eos: override_config.ignore_eos.or(self.ignore_eos),
        }
    }

    /// Convert config to JSON, only including set values
    pub fn to_json(&self) -> Value {
        let mut obj = serde_json::Map::new();

        macro_rules! add_if_some {
            ($field:ident) => {
                if let Some(val) = &self.$field {
                    obj.insert(stringify!($field).to_string(), json!(val));
                }
            };
        }

        add_if_some!(max_tokens);
        add_if_some!(temperature);
        add_if_some!(top_p);
        add_if_some!(top_k);
        add_if_some!(min_p);
        add_if_some!(typical_p);
        add_if_some!(frequency_penalty);
        add_if_some!(presence_penalty);
        add_if_some!(repeat_penalty);
        add_if_some!(repeat_last_n);
        add_if_some!(tfs_z);
        add_if_some!(mirostat_mode);
        add_if_some!(mirostat_tau);
        add_if_some!(mirostat_eta);
        add_if_some!(samplers);
        add_if_some!(seed);
        add_if_some!(stop);
        add_if_some!(stream);
        add_if_some!(echo);
        add_if_some!(logprobs);
        add_if_some!(top_logprobs);
        add_if_some!(logit_bias);
        add_if_some!(response_format);
        add_if_some!(grammar);
        add_if_some!(n);
        add_if_some!(penalize_nl);
        add_if_some!(ignore_eos);

        Value::Object(obj)
    }
}

/// A conversation: an optional system prompt followed by user/assistant messages.
#[derive(Debug, Clone, Builder)]
pub struct Chat {
    /// When set, emitted ahead of `messages` with role `"system"`.
    pub system_prompt: Option<String>,
    /// Conversation turns in order; the builder defaults this to an empty list.
    #[builder(default)]
    pub messages: Vec<Message>,
}

impl Chat {
    /// Creates an empty chat: no system prompt and no messages.
    pub fn new() -> Self {
        Self {
            system_prompt: None,
            messages: Vec::new(),
        }
    }

    /// Appends `message` to the end of the conversation.
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Renders the chat as the JSON `messages` array sent to the API.
    ///
    /// If a system prompt is set it comes first as `{"role": "system", "content": ...}`,
    /// followed by every message, in order, as `{"role": ..., "content": ...}`.
    pub fn to_json(&self) -> Vec<Value> {
        // The exact output length is known up front, so allocate once.
        let mut json_messages = Vec::with_capacity(self.len_with_system_prompt());

        if let Some(system) = &self.system_prompt {
            json_messages.push(json!({
                "role": "system",
                "content": system
            }));
        }

        json_messages.extend(self.messages.iter().map(|message| {
            json!({
                "role": message.role,
                "content": message.content
            })
        }));

        json_messages
    }

    /// Number of entries `to_json` produces: the messages, plus one if a system prompt is set.
    pub fn len_with_system_prompt(&self) -> usize {
        self.messages.len() + usize::from(self.system_prompt.is_some())
    }
}

/// Equivalent to `Chat::new()`: an empty chat with no system prompt.
///
/// Lets `Chat` be used wherever a `Default` is expected (`unwrap_or_default`,
/// `..Default::default()`, `#[derive(Default)]` on containing types).
impl Default for Chat {
    fn default() -> Self {
        Self::new()
    }
}

/// HTTP client for an OpenAI-style chat endpoint (the reply text is read from
/// `choices[0].message.content`).
///
/// Construct it with the generated builder; `String` fields accept anything `Into<String>`.
#[derive(Debug, Clone, Builder)]
#[builder(on(String, into))]
pub struct LlmClient {
    /// `reqwest` client that performs every request.
    client: Client,
    /// URL requests are POSTed to, unless a call supplies its own.
    api_url: String,
    /// Key sent as a bearer token, unless a call supplies its own; omitted when `None`.
    api_key: Option<String>,
    /// Parameters applied to every request; a per-call config is merged on top of these.
    /// Unless set explicitly: `max_tokens` 512, `temperature` 0.7, `frequency_penalty` 0.0.
    #[builder(default = GenerationConfig {
        max_tokens: Some(512),
        temperature: Some(0.7),
        frequency_penalty: Some(0.0),
        ..Default::default()
    })]
    default_config: GenerationConfig,
}

#[bon]
impl LlmClient {
    pub fn builder_with_default_client() -> Result<LlmClientBuilder<llm_client_builder::SetClient>>
    {
        Ok(LlmClient::builder().client(
            reqwest::ClientBuilder::new()
                .connect_timeout(Duration::from_secs(10))
                .read_timeout(Duration::from_secs(360))
                .build()?,
        ))
    }

    #[builder]
    pub async fn generate_response_and_add_to_chat(
        &self,
        chat: &mut Chat,
        config: Option<GenerationConfig>,
        api_url: Option<&str>,
        api_key: Option<&str>,
    ) -> Result<String> {
        let messages = chat.to_json();
        let response = self.call_api(messages, config, api_url, api_key).await?;
        chat.add_message(Message::new(MessageRole::Assistant, &response));
        Ok(response)
    }

    #[builder]
    pub async fn generate_response(
        &self,
        chat: &Chat,
        config: Option<GenerationConfig>,
        api_url: Option<&str>,
        api_key: Option<&str>,
    ) -> Result<String> {
        let messages = chat.to_json();
        self.call_api(messages, config, api_url, api_key).await
    }

    #[cfg(feature = "mock_llm_api")]
    async fn call_api(
        &self,
        messages: Vec<Value>,
        _config: Option<GenerationConfig>,
        _api_url: Option<&str>,
        _api_key: Option<&str>,
    ) -> Result<String> {
        Ok(messages
            .iter()
            .map(|message| message.to_string())
            .collect::<Vec<String>>()
            .join("\n"))
    }

    #[cfg(not(feature = "mock_llm_api"))]
    async fn call_api(
        &self,
        messages: Vec<Value>,
        config: Option<GenerationConfig>,
        api_url: Option<&str>,
        api_key: Option<&str>,
    ) -> Result<String> {
        use crate::error::Error;

        let api_url = api_url.unwrap_or(&self.api_url);
        let api_key = api_key.or(self.api_key.as_deref());

        // Merge default config with provided config
        let final_config = if let Some(override_config) = config {
            self.default_config.merge(&override_config)
        } else {
            self.default_config.clone()
        };

        // Build JSON body
        let mut body = json!({
            "messages": messages,
        });

        // Merge in the config
        if let Value::Object(config_map) = final_config.to_json() {
            if let Value::Object(body_map) = &mut body {
                body_map.extend(config_map);
            }
        }

        let response = self
            .client
            .post(api_url)
            .pipe(|req| {
                if let Some(key) = api_key {
                    req.header("Authorization", format!("Bearer {}", key))
                } else {
                    req
                }
            })
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?
            .json::<Value>()
            .await?;

        let content = response
            .pointer("/choices/0/message/content")
            .ok_or(Error::IssueWithLlmApiReturnedJson)?
            .as_str()
            .ok_or(Error::FailedToExtractResponseContent)?
            .to_string();

        Ok(content)
    }
}