// alith-interface 0.4.3
//
// The Backend for the alith-client Crate
// Documentation
use crate::requests::{completion::*, stop_sequence::StopSequences};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Request body for an OpenAI chat completion call.
///
/// Every `Option` field carries `skip_serializing_if = "Option::is_none"`, so a
/// parameter left as `None` is omitted from the serialized output rather than
/// being written as `null`.
#[derive(Clone, Serialize, Default, Debug, Deserialize)]
pub struct OpenAICompletionRequest {
    /// ID of the model to use.
    /// See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.
    pub model: String,

    /// A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
    pub messages: Vec<CompletionRequestMessage>, // min: 1

    /// Modify the likelihood of specified tokens appearing in the completion.
    ///
    /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
    /// Mathematically, the bias is added to the logits generated by the model prior to sampling.
    /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection;
    /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null

    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,

    /// An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,

    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) that can be generated in the chat completion.
    ///
    /// The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,

    /// Sampling temperature for the request.
    ///
    /// min: 0.0, max: 2.0, default: None
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// Frequency penalty for the request.
    ///
    /// min: -2.0, max: 2.0, default: None
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    /// Presence penalty for the request.
    ///
    /// min: -2.0, max: 2.0, default: None
    ///
    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    /// Up to 4 sequences where the API will stop generating further tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,

    /// The `top_p` sampling parameter for the request.
    ///
    /// min: 0.0, max: 1.0, default: None
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    /// The tools for the request, default: None
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<OpenAIToolDefinition>>,

    /// The tool choice for the request, default: None
    ///
    /// Carried as a plain string; `OpenAICompletionRequest::new` sets it to
    /// `"auto"` whenever the request has at least one tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}

/// One entry of the `tools` list in an [`OpenAICompletionRequest`]: a tool
/// definition paired with a string naming its kind.
#[derive(Clone, Serialize, Debug, Deserialize)]
pub struct OpenAIToolDefinition {
    /// Kind of tool. Written as a raw identifier because `type` is a Rust keyword.
    /// `OpenAICompletionRequest::new` always sets this to `"function"`.
    pub r#type: String,
    /// The wrapped tool definition.
    pub function: ToolDefinition,
}

impl OpenAICompletionRequest {
    /// Translates a backend-agnostic [`CompletionRequest`] into an OpenAI request body.
    ///
    /// The built prompt messages are converted one by one into
    /// [`CompletionRequestMessage`]s. `logprobs` and `top_logprobs` are always left
    /// unset, while `temperature` and `presence_penalty` are always sent. When the
    /// request carries tools, each one is wrapped as a `"function"` tool and
    /// `tool_choice` is set to `"auto"`; otherwise both fields are omitted.
    ///
    /// # Errors
    ///
    /// Returns [`CompletionError::RequestBuilderError`] when the prompt messages cannot
    /// be built, and propagates any error from `CompletionRequestMessage::new` or
    /// `Stop::new`.
    pub fn new(req: &CompletionRequest) -> crate::Result<Self, CompletionError> {
        let built_prompt = req
            .prompt
            .get_built_prompt_messages()
            .map_err(|err| CompletionError::RequestBuilderError(err.to_string()))?;

        // Stops at the first message that fails to convert.
        let messages = built_prompt
            .iter()
            .map(CompletionRequestMessage::new)
            .collect::<crate::Result<Vec<_>, CompletionError>>()?;

        // Both tool-related fields are driven by the same condition.
        let has_tools = !req.tools.is_empty();

        Ok(Self {
            messages,
            model: req.backend.model_id().to_owned(),
            logit_bias: req.logit_bias.as_ref().and_then(|bias| bias.get_openai()),
            frequency_penalty: req.config.frequency_penalty,
            logprobs: None,
            top_logprobs: None,
            max_tokens: req.config.actual_request_tokens,
            presence_penalty: Some(req.config.presence_penalty),
            stop: Stop::new(&req.stop_sequences)?,
            temperature: Some(req.config.temperature),
            top_p: req.config.top_p,
            tools: has_tools.then(|| {
                req.tools
                    .iter()
                    .map(|definition| OpenAIToolDefinition {
                        r#type: "function".to_string(),
                        function: definition.clone(),
                    })
                    .collect()
            }),
            tool_choice: has_tools.then(|| "auto".to_string()),
        })
    }
}

/// A single chat message as sent in the `messages` list of an
/// [`OpenAICompletionRequest`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CompletionRequestMessage {
    /// Author of the message. `CompletionRequestMessage::new` accepts only
    /// `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Text of the message.
    pub content: String,
}

impl CompletionRequestMessage {
    /// Builds a message from a map holding `"role"` and `"content"` entries.
    ///
    /// # Errors
    ///
    /// Returns [`CompletionError::RequestBuilderError`] when either key is missing
    /// (the role is looked up first) or when the role is not one of `"system"`,
    /// `"user"` or `"assistant"`.
    pub fn new(
        message: &std::collections::HashMap<String, String>,
    ) -> crate::Result<Self, CompletionError> {
        let role = match message.get("role") {
            Some(found) => found,
            None => {
                return Err(CompletionError::RequestBuilderError(
                    "Role not found".to_string(),
                ))
            }
        };
        let content = match message.get("content") {
            Some(found) => found,
            None => {
                return Err(CompletionError::RequestBuilderError(
                    "Content not found".to_string(),
                ))
            }
        };

        // Reject unknown roles only after both keys are known to be present.
        let supported = matches!(role.as_str(), "system" | "user" | "assistant");
        if !supported {
            return Err(CompletionError::RequestBuilderError(format!(
                "Role {} not supported",
                role
            )));
        }

        Ok(Self {
            role: role.clone(),
            content: content.clone(),
        })
    }
}

/// Value of the `stop` field of an [`OpenAICompletionRequest`].
///
/// Marked `#[serde(untagged)]`, so the variant name is not part of the
/// serialized form: the value is written as a bare string or a bare array of strings.
/// `Stop::new` in this file only ever produces [`Stop::StringArray`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum Stop {
    /// A single stop sequence.
    String(String),           // nullable: true
    /// A list of stop sequences.
    StringArray(Vec<String>), // minItems: 1; maxItems: 4
}

impl Stop {
    /// Converts the request's stop sequences into the OpenAI `stop` value.
    ///
    /// Returns `Ok(None)` when there are no sequences, and
    /// `Ok(Some(Stop::StringArray(..)))` when there are between one and four.
    ///
    /// # Errors
    ///
    /// Returns [`CompletionError::RequestBuilderError`] when more than four
    /// sequences are supplied.
    pub fn new(stop_sequences: &StopSequences) -> crate::Result<Option<Self>, CompletionError> {
        let count = stop_sequences.sequences.len();

        if count == 0 {
            return Ok(None);
        }
        if count > 4 {
            return Err(CompletionError::RequestBuilderError(
                "OpenAI Stop array cannot have more than 4 elements".to_string(),
            ));
        }

        Ok(Some(Self::StringArray(stop_sequences.to_vec())))
    }
}