1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use crate::requests::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)]
pub struct OpenAiCompletionRequest {
/// ID of the model to use.
/// See the [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.
pub model: String,
/// A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
pub messages: Vec<CompletionRequestMessage>, // min: 1
/// Modify the likelihood of specified tokens appearing in the completion.
///
/// Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
/// Mathematically, the bias is added to the logits generated by the model prior to sampling.
/// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection;
/// values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`.
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<bool>,
/// An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u8>,
/// The maximum number of [tokens](https://platform.openai.com/tokenizer) that can be generated in the chat completion.
///
/// The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u64>,
/// min: 0.0, max: 2.0, default: None
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// min: -2.0, max: 2.0, default: None
///
/// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// min: -2.0, max: 2.0, default: None
///
/// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Up to 4 sequences where the API will stop generating further tokens.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Stop>,
/// min: 0.0, max: 1.0, default: None
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
}
impl OpenAiCompletionRequest {
pub fn new(req: &CompletionRequest) -> crate::Result<Self, CompletionError> {
let mut messages = Vec::new();
match &req
.prompt
.api_prompt()
.map_err(|e| CompletionError::RequestBuilderError(e.to_string()))?
.get_built_prompt()
{
Ok(prompt_message) => {
for m in prompt_message {
messages.push(CompletionRequestMessage::new(m).unwrap());
}
}
Err(e) => return Err(CompletionError::RequestBuilderError(e.to_string())),
}
Ok(OpenAiCompletionRequest {
messages,
model: req.backend.model_id().to_owned(),
logit_bias: req.logit_bias.as_ref().and_then(|lb| lb.get_openai()),
frequency_penalty: req.config.frequency_penalty,
logprobs: None,
top_logprobs: None,
max_tokens: req.config.actual_request_tokens,
presence_penalty: Some(req.config.presence_penalty),
stop: Stop::new(&req.stop_sequences)?,
temperature: Some(req.config.temperature),
top_p: req.config.top_p,
})
}
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct CompletionRequestMessage {
pub role: String,
pub content: String,
}
impl CompletionRequestMessage {
pub fn new(
message: &std::collections::HashMap<String, String>,
) -> crate::Result<Self, CompletionError> {
let role = message
.get("role")
.ok_or_else(|| CompletionError::RequestBuilderError("Role not found".to_string()))?;
let content = message
.get("content")
.ok_or_else(|| CompletionError::RequestBuilderError("Content not found".to_string()))?;
match role.as_str() {
"system" | "user" | "assistant" => Ok(CompletionRequestMessage {
role: role.to_string(),
content: content.to_string(),
}),
_ => Err(CompletionError::RequestBuilderError(format!(
"Role {} not supported",
role
))),
}
}
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum Stop {
String(String), // nullable: true
StringArray(Vec<String>), // minItems: 1; maxItems: 4
}
impl Stop {
pub fn new(stop_sequences: &StopSequences) -> crate::Result<Option<Self>, CompletionError> {
match stop_sequences.sequences.len() {
0 => Ok(None),
1..=4 => Ok(Some(Stop::StringArray(stop_sequences.to_vec()))),
_ => Err(CompletionError::RequestBuilderError(
"OpenAI Stop array cannot have more than 4 elements".to_string(),
)),
}
}
}