// llm_interface 0.0.3
//
// llm_interface: the backend for the `llm_client` crate.
use crate::requests::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Request body for the OpenAI Chat Completions API.
///
/// Built from a generic [`CompletionRequest`] by [`OpenAiCompletionRequest::new`] and
/// serialized with serde. Every `Option` field carries
/// `skip_serializing_if = "Option::is_none"`, so `None` values are left out of the
/// serialized JSON entirely instead of being sent as `null`.
#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)]
pub struct OpenAiCompletionRequest {
    /// ID of the model to use.
    /// See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.
    pub model: String,

    /// A list of messages comprising the conversation so far; the API expects at least one.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
    pub messages: Vec<CompletionRequestMessage>, // min: 1

    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
    /// Because this is a JSON object, the token IDs are carried as string keys.
    /// Mathematically, the bias is added to the logits generated by the model prior to sampling.
    /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection;
    /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null

    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,

    /// An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) that can be generated in the chat completion.
    ///
    /// The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,

    /// Value sent as the API's `temperature` parameter; omitted when `None`.
    ///
    /// min: 0.0, max: 2.0, default: None
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// Value sent as the API's `frequency_penalty` parameter; omitted when `None`.
    ///
    /// min: -2.0, max: 2.0, default: None
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    /// Value sent as the API's `presence_penalty` parameter; omitted when `None`.
    ///
    /// min: -2.0, max: 2.0, default: None
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    /// Up to 4 sequences where the API will stop generating further tokens.
    /// See [`Stop`] for the accepted JSON shapes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,

    /// Value sent as the API's `top_p` parameter; omitted when `None`.
    ///
    /// min: 0.0, max: 1.0, default: None
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}

impl OpenAiCompletionRequest {
    /// Builds an OpenAI chat completion request body from a generic [`CompletionRequest`].
    ///
    /// The prompt obtained from `req.prompt.api_prompt()` is built and each resulting
    /// message is converted with [`CompletionRequestMessage::new`]. The model id, logit
    /// bias, sampling settings, token limit and stop sequences are copied from `req`.
    /// `logprobs` and `top_logprobs` are never requested.
    ///
    /// # Errors
    ///
    /// Returns [`CompletionError::RequestBuilderError`] if the API prompt cannot be
    /// obtained or built, or if any built message lacks a `role`/`content` entry or
    /// uses an unsupported role. Errors from [`Stop::new`] (e.g. too many stop
    /// sequences) are propagated unchanged.
    pub fn new(req: &CompletionRequest) -> crate::Result<Self, CompletionError> {
        let built_prompt = req
            .prompt
            .api_prompt()
            .map_err(|e| CompletionError::RequestBuilderError(e.to_string()))?
            .get_built_prompt()
            .map_err(|e| CompletionError::RequestBuilderError(e.to_string()))?;

        // Convert fallibly: a malformed message (missing key, unsupported role) must
        // surface as a `CompletionError` to the caller rather than panic.
        let messages = built_prompt
            .iter()
            .map(CompletionRequestMessage::new)
            .collect::<crate::Result<Vec<_>, CompletionError>>()?;

        Ok(OpenAiCompletionRequest {
            messages,
            model: req.backend.model_id().to_owned(),
            logit_bias: req.logit_bias.as_ref().and_then(|lb| lb.get_openai()),
            frequency_penalty: req.config.frequency_penalty,
            logprobs: None,
            top_logprobs: None,
            max_tokens: req.config.actual_request_tokens,
            presence_penalty: Some(req.config.presence_penalty),
            stop: Stop::new(&req.stop_sequences)?,
            temperature: Some(req.config.temperature),
            top_p: req.config.top_p,
        })
    }
}

/// A single chat message in an [`OpenAiCompletionRequest`].
///
/// [`CompletionRequestMessage::new`] builds one from a string map and validates the
/// role, accepting only `system`, `user` and `assistant`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CompletionRequestMessage {
    /// Role of the message author, e.g. `"user"`.
    pub role: String,
    /// Text content of the message.
    pub content: String,
}

impl CompletionRequestMessage {
    /// Creates a message from a map holding `"role"` and `"content"` entries.
    ///
    /// # Errors
    ///
    /// Returns `CompletionError::RequestBuilderError` when:
    /// - the `"role"` key is missing (reported even if `"content"` is also missing),
    /// - the `"content"` key is missing,
    /// - the role is anything other than `system`, `user` or `assistant`.
    pub fn new(
        message: &std::collections::HashMap<String, String>,
    ) -> crate::Result<Self, CompletionError> {
        // A missing role takes precedence over a missing content entry.
        let (role, content) = match (message.get("role"), message.get("content")) {
            (None, _) => {
                return Err(CompletionError::RequestBuilderError(
                    "Role not found".to_string(),
                ))
            }
            (Some(_), None) => {
                return Err(CompletionError::RequestBuilderError(
                    "Content not found".to_string(),
                ))
            }
            (Some(found_role), Some(found_content)) => (found_role, found_content),
        };

        let is_supported_role = matches!(role.as_str(), "system" | "user" | "assistant");
        if !is_supported_role {
            return Err(CompletionError::RequestBuilderError(format!(
                "Role {} not supported",
                role
            )));
        }

        Ok(Self {
            role: role.clone(),
            content: content.clone(),
        })
    }
}

/// The `stop` parameter of an OpenAI chat completion request.
///
/// Because of `#[serde(untagged)]`, the value serializes as either a bare JSON string
/// or a JSON array of strings, with no variant tag.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum Stop {
    /// A single stop sequence, serialized as a JSON string.
    String(String),           // nullable: true
    /// Several stop sequences, serialized as a JSON array of strings.
    StringArray(Vec<String>), // minItems: 1; maxItems: 4
}

impl Stop {
    /// Converts the configured stop sequences into an OpenAI `stop` value.
    ///
    /// Returns `Ok(None)` when no sequences are configured. Otherwise it always
    /// produces [`Stop::StringArray`], even for a single sequence.
    ///
    /// # Errors
    ///
    /// Returns `CompletionError::RequestBuilderError` when more than four sequences
    /// are configured.
    pub fn new(stop_sequences: &StopSequences) -> crate::Result<Option<Self>, CompletionError> {
        // Upper bound on the number of stop sequences OpenAI accepts.
        const MAX_STOP_SEQUENCES: usize = 4;

        let count = stop_sequences.sequences.len();
        if count == 0 {
            return Ok(None);
        }
        if count > MAX_STOP_SEQUENCES {
            return Err(CompletionError::RequestBuilderError(
                "OpenAI Stop array cannot have more than 4 elements".to_string(),
            ));
        }
        Ok(Some(Self::StringArray(stop_sequences.to_vec())))
    }
}