//! pharia_skill/csi/inference.rs — completion and chat inference request/response types.

use serde::{Deserialize, Serialize};

/// The reason that the model stopped completing text
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")] // wire format: "stop", "length", "content_filter"
pub enum FinishReason {
    /// The model hit a natural stopping point or a provided stop sequence
    Stop,
    /// The maximum number of tokens specified in the request was reached
    Length,
    /// Content was omitted due to a flag from content filters
    ContentFilter,
}

/// Controls which log probabilities are returned with a completion or chat
/// response (see the `logprobs` field on the response types).
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Logprobs {
    /// Do not return any log probabilities (the default).
    #[default]
    No,
    /// Return the log probability of each sampled token.
    Sampled,
    /// Return the log probabilities of the top `n` most likely tokens,
    /// in addition to the sampled token.
    Top(u8),
}

/// The log probability of a single token.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct Logprob {
    // Raw bytes rather than `String` — presumably because a single token need
    // not be valid UTF-8 on its own; confirm against the API producing this.
    pub token: Vec<u8>,
    /// The natural-log probability assigned to this token.
    pub logprob: f64,
}

/// Log probabilities reported for one generated token position.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct Distribution {
    /// The logprob of the token that was actually sampled.
    pub sampled: Logprob,
    /// Logprobs of the top candidate tokens (populated when `Logprobs::Top`
    /// was requested).
    pub top: Vec<Logprob>,
}

/// Token counts for a single inference request.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
pub struct TokenUsage {
    /// Number of tokens in the prompt.
    pub prompt: u32,
    /// Number of tokens generated for the completion.
    pub completion: u32,
}

/// Completion request parameters
///
/// Every field is optional or has a sensible default; use
/// `CompletionParams::default()` and override what you need.
#[derive(Clone, Debug, Serialize)]
pub struct CompletionParams {
    /// The maximum tokens that should be inferred.
    ///
    /// Note: the backing implementation may return less tokens due to
    /// other stop reasons.
    pub max_tokens: Option<u32>,
    /// The randomness with which the next token is selected.
    pub temperature: Option<f64>,
    /// The number of possible next tokens the model will choose from.
    pub top_k: Option<u32>,
    /// The probability total of next tokens the model will choose from.
    pub top_p: Option<f64>,
    /// A list of sequences that, if encountered, the API will stop generating further tokens.
    pub stop: Vec<String>,
    /// Whether to include special tokens like `<|eot_id|>` in the completion
    pub return_special_tokens: bool,
    /// When specified, this number will decrease (or increase) the probability of repeating
    /// tokens that were mentioned prior in the completion. The penalty is cumulative. The more
    /// a token is mentioned in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the probability of generating tokens that are already
    /// present in the generated text respectively prompt. Presence penalty is independent of the
    /// number of occurrences. Increase the value to reduce the probability of repeating text.
    pub presence_penalty: Option<f64>,
    /// Use this to control the logarithmic probabilities you want to have returned. This is useful
    /// to figure out how likely it had been that this specific token had been sampled.
    pub logprobs: Logprobs,
}

74impl Default for CompletionParams {
75 fn default() -> Self {
76 Self {
77 return_special_tokens: true,
78 max_tokens: None,
79 temperature: None,
80 top_k: None,
81 top_p: None,
82 stop: Vec::new(),
83 frequency_penalty: None,
84 presence_penalty: None,
85 logprobs: Logprobs::default(),
86 }
87 }
88}
89
/// Parameters required to make a completion request.
#[derive(Clone, Debug, Serialize)]
pub struct CompletionRequest {
    /// The model to generate a completion from.
    pub model: String,
    /// The text to prompt the model with.
    pub prompt: String,
    /// Parameters to adjust the sampling behavior of the model.
    pub params: CompletionParams,
}

101impl CompletionRequest {
102 pub fn new(model: impl Into<String>, prompt: impl Into<String>) -> Self {
103 Self {
104 model: model.into(),
105 prompt: prompt.into(),
106 params: CompletionParams::default(),
107 }
108 }
109
110 #[must_use]
111 pub fn with_params(mut self, params: CompletionParams) -> Self {
112 self.params = params;
113 self
114 }
115}
116
/// The result of a completion, including the text generated as well as
/// why the model finished completing.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct Completion {
    /// The text generated by the model
    pub text: String,
    /// The reason the model finished generating
    pub finish_reason: FinishReason,
    /// Contains the logprobs for the sampled and top n tokens, given that
    /// `completion-request.params.logprobs` has been set to `sampled` or `top`.
    pub logprobs: Vec<Distribution>,
    /// Usage statistics for the completion request.
    pub usage: TokenUsage,
}

132#[derive(Clone, Debug, Serialize, Deserialize)]
133pub struct Message {
134 pub role: String,
135 pub content: String,
136}
137
138impl Message {
139 pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
140 Self {
141 role: role.into(),
142 content: content.into(),
143 }
144 }
145
146 pub fn user(content: impl Into<String>) -> Self {
147 Self::new("user", content)
148 }
149
150 pub fn assistant(content: impl Into<String>) -> Self {
151 Self::new("assistant", content)
152 }
153
154 pub fn system(content: impl Into<String>) -> Self {
155 Self::new("system", content)
156 }
157}
158
/// Chat request parameters
///
/// Mirrors [`CompletionParams`] for the chat endpoint; all fields default to
/// "unset" via the derived `Default`.
#[derive(Clone, Debug, Default, Serialize)]
pub struct ChatParams {
    /// The maximum tokens that should be inferred.
    ///
    /// Note: the backing implementation may return less tokens due to
    /// other stop reasons.
    pub max_tokens: Option<u32>,
    /// The randomness with which the next token is selected.
    pub temperature: Option<f64>,
    /// The probability total of next tokens the model will choose from.
    pub top_p: Option<f64>,
    /// When specified, this number will decrease (or increase) the probability of repeating
    /// tokens that were mentioned prior in the completion. The penalty is cumulative. The more
    /// a token is mentioned in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the probability of generating tokens that are already
    /// present in the generated text respectively prompt. Presence penalty is independent of the
    /// number of occurrences. Increase the value to reduce the probability of repeating text.
    pub presence_penalty: Option<f64>,
    /// Use this to control the logarithmic probabilities you want to have returned. This is useful
    /// to figure out how likely it had been that this specific token had been sampled.
    pub logprobs: Logprobs,
}

/// Parameters required to make a chat request.
#[derive(Clone, Debug, Serialize)]
pub struct ChatRequest {
    /// The model to chat with.
    pub model: String,
    /// The conversation so far, in order.
    pub messages: Vec<Message>,
    /// Parameters to adjust the sampling behavior of the model.
    pub params: ChatParams,
}

191impl ChatRequest {
192 pub fn new(model: impl Into<String>, message: Message) -> Self {
193 Self {
194 model: model.into(),
195 messages: vec![message],
196 params: ChatParams::default(),
197 }
198 }
199
200 #[must_use]
201 pub fn and_message(mut self, message: Message) -> Self {
202 self.messages.push(message);
203 self
204 }
205
206 #[must_use]
207 pub fn with_params(mut self, params: ChatParams) -> Self {
208 self.params = params;
209 self
210 }
211}
212
/// The result of a chat request, including the message generated as well as
/// why the model finished generating.
#[derive(Clone, Debug, Deserialize)]
pub struct ChatResponse {
    /// The message generated by the model
    pub message: Message,
    /// The reason the model finished generating
    pub finish_reason: FinishReason,
    /// Contains the logprobs for the sampled and top n tokens, given that
    /// `chat-request.params.logprobs` has been set to `sampled` or `top`.
    pub logprobs: Vec<Distribution>,
    /// Usage statistics for the chat request.
    pub usage: TokenUsage,
}