Skip to main content

alith_interface/llms/api/openai/completion/
req.rs

1use crate::requests::{completion::*, stop_sequence::StopSequences};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5#[derive(Clone, Serialize, Default, Debug, Deserialize)]
6pub struct OpenAICompletionRequest {
7    /// ID of the model to use.
8    /// See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.
9    pub model: String,
10
11    /// A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
12    pub messages: Vec<CompletionRequestMessage>, // min: 1
13
14    /// Modify the likelihood of specified tokens appearing in the completion.
15    ///
16    /// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
17    /// Mathematically, the bias is added to the logits generated by the model prior to sampling.
18    /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection;
19    /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null
22
23    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`.
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub logprobs: Option<bool>,
26
27    /// An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub top_logprobs: Option<u8>,
30
31    /// The maximum number of [tokens](https://platform.openai.com/tokenizer) that can be generated in the chat completion.
32    ///
33    /// The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub max_tokens: Option<u64>,
36
37    /// min: 0.0, max: 2.0, default: None
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub temperature: Option<f32>,
40
41    /// min: -2.0, max: 2.0, default: None
42    ///
43    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub frequency_penalty: Option<f32>,
46
47    /// min: -2.0, max: 2.0, default: None
48    ///
49    /// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub presence_penalty: Option<f32>,
52
53    /// Up to 4 sequences where the API will stop generating further tokens.
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub stop: Option<Stop>,
56
57    /// min: 0.0, max: 1.0, default: None
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub top_p: Option<f32>,
60
61    /// The tools for the request, default: None
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub tools: Option<Vec<OpenAIToolDefinition>>,
64
65    /// The tool choice for the request, default: None
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub tool_choice: Option<String>,
68}
69
70#[derive(Clone, Serialize, Debug, Deserialize)]
71pub struct OpenAIToolDefinition {
72    pub r#type: String,
73    pub function: ToolDefinition,
74}
75
76impl OpenAICompletionRequest {
77    pub fn new(req: &CompletionRequest) -> crate::Result<Self, CompletionError> {
78        let mut messages = Vec::new();
79        match &req.prompt.get_built_prompt_messages() {
80            Ok(prompt_message) => {
81                for m in prompt_message {
82                    messages.push(CompletionRequestMessage::new(m)?);
83                }
84            }
85            Err(e) => return Err(CompletionError::RequestBuilderError(e.to_string())),
86        }
87
88        Ok(OpenAICompletionRequest {
89            messages,
90            model: req.backend.model_id().to_owned(),
91            logit_bias: req.logit_bias.as_ref().and_then(|lb| lb.get_openai()),
92            frequency_penalty: req.config.frequency_penalty,
93            logprobs: None,
94            top_logprobs: None,
95            max_tokens: req.config.actual_request_tokens,
96            presence_penalty: Some(req.config.presence_penalty),
97            stop: Stop::new(&req.stop_sequences)?,
98            temperature: Some(req.config.temperature),
99            top_p: req.config.top_p,
100            tools: if !req.tools.is_empty() {
101                Some(
102                    req.tools
103                        .iter()
104                        .map(|tool| OpenAIToolDefinition {
105                            r#type: "function".to_string(),
106                            function: tool.clone(),
107                        })
108                        .collect(),
109                )
110            } else {
111                None
112            },
113            tool_choice: if !req.tools.is_empty() {
114                Some("auto".to_string())
115            } else {
116                None
117            },
118        })
119    }
120}
121
122#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
123pub struct CompletionRequestMessage {
124    pub role: String,
125    pub content: String,
126}
127
128impl CompletionRequestMessage {
129    pub fn new(
130        message: &std::collections::HashMap<String, String>,
131    ) -> crate::Result<Self, CompletionError> {
132        let role = message
133            .get("role")
134            .ok_or_else(|| CompletionError::RequestBuilderError("Role not found".to_string()))?;
135        let content = message
136            .get("content")
137            .ok_or_else(|| CompletionError::RequestBuilderError("Content not found".to_string()))?;
138
139        match role.as_str() {
140            "system" | "user" | "assistant" => Ok(CompletionRequestMessage {
141                role: role.to_string(),
142                content: content.to_string(),
143            }),
144            _ => Err(CompletionError::RequestBuilderError(format!(
145                "Role {} not supported",
146                role
147            ))),
148        }
149    }
150}
151
152#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
153#[serde(untagged)]
154pub enum Stop {
155    String(String),           // nullable: true
156    StringArray(Vec<String>), // minItems: 1; maxItems: 4
157}
158
159impl Stop {
160    pub fn new(stop_sequences: &StopSequences) -> crate::Result<Option<Self>, CompletionError> {
161        match stop_sequences.sequences.len() {
162            0 => Ok(None),
163            1..=4 => Ok(Some(Stop::StringArray(stop_sequences.to_vec()))),
164            _ => Err(CompletionError::RequestBuilderError(
165                "OpenAI Stop array cannot have more than 4 elements".to_string(),
166            )),
167        }
168    }
169}