aleph_alpha_client/chat.rs
use core::str;
use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::{
    logprobs::{Logprob, Logprobs},
    Stopping, StreamTask, Task,
};

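/// A message in a chat conversation, consisting of a role (e.g. "user", "assistant" or "system")
/// and its text content.
///
/// The sketch below shows how messages could be constructed with the helper constructors defined
/// further down; it assumes `Message` is re-exported at the crate root, and the contents are made
/// up for illustration.
///
/// ```no_run
/// use aleph_alpha_client::Message;
///
/// // A system instruction followed by a user question.
/// let messages = vec![
///     Message::system("You are a helpful assistant."),
///     Message::user("What is the capital of France?"),
/// ];
/// assert_eq!(messages[1].role, "user");
/// ```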
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
    pub role: Cow<'a, str>,
    pub content: Cow<'a, str>,
}

impl<'a> Message<'a> {
    pub fn new(role: impl Into<Cow<'a, str>>, content: impl Into<Cow<'a, str>>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }
    pub fn user(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("user", content)
    }
    pub fn assistant(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("assistant", content)
    }
    pub fn system(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("system", content)
    }
}

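/// A chat task: the conversation so far plus sampling, stopping and logprob settings for the
/// next assistant turn.
///
/// A minimal sketch of building a task with the helpers from the `impl` block below, assuming
/// the listed types are re-exported at the crate root; the prompt and the values are purely
/// illustrative.
///
/// ```no_run
/// use aleph_alpha_client::{Logprobs, Message, TaskChat};
///
/// let task = TaskChat::with_message(Message::user("Tell me a joke."))
///     .with_maximum_tokens(64)
///     .with_logprobs(Logprobs::Sampled);
/// ```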
pub struct TaskChat<'a> {
    /// The list of messages comprising the conversation so far.
    pub messages: Vec<Message<'a>>,
    /// Controls in which circumstances the model will stop generating new tokens.
    pub stopping: Stopping<'a>,
    /// Sampling controls how the tokens ("words") are selected for the completion.
    pub sampling: ChatSampling,
    /// Use this to control the logarithmic probabilities you want to have returned. This is
    /// useful to figure out how likely it was that this specific token would be sampled.
    pub logprobs: Logprobs,
}

impl<'a> TaskChat<'a> {
    /// Creates a new TaskChat containing one message with the given role and content.
    /// All optional TaskChat attributes are left unset.
    pub fn with_message(message: Message<'a>) -> Self {
        Self::with_messages(vec![message])
    }

    /// Creates a new TaskChat containing the given messages.
    /// All optional TaskChat attributes are left unset.
    pub fn with_messages(messages: Vec<Message<'a>>) -> Self {
        TaskChat {
            messages,
            sampling: ChatSampling::default(),
            stopping: Stopping::default(),
            logprobs: Logprobs::No,
        }
    }

    /// Pushes a new Message to this TaskChat.
    pub fn push_message(mut self, message: Message<'a>) -> Self {
        self.messages.push(message);
        self
    }

    /// Sets the maximum tokens attribute of this TaskChat.
    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
        self.stopping.maximum_tokens = Some(maximum_tokens);
        self
    }

    /// Sets the logprobs attribute of this TaskChat.
    pub fn with_logprobs(mut self, logprobs: Logprobs) -> Self {
        self.logprobs = logprobs;
        self
    }
}

/// Sampling controls how the tokens ("words") are selected for the completion. This is different
/// from [`crate::Sampling`], because it does **not** support the `top_k` parameter.
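///
/// The sketch below overrides only the temperature and keeps the remaining fields at their
/// defaults; the concrete value is illustrative, and `ChatSampling` is assumed to be re-exported
/// at the crate root.
///
/// ```no_run
/// use aleph_alpha_client::ChatSampling;
///
/// let sampling = ChatSampling {
///     temperature: Some(0.7),
///     ..ChatSampling::default()
/// };
/// ```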
pub struct ChatSampling {
    /// A temperature encourages the model to produce less probable outputs ("be more creative").
    /// Values are expected to be between 0 and 1. Try high values for a more random ("creative")
    /// response.
    pub temperature: Option<f64>,
    /// "Nucleus" parameter to dynamically adjust the number of choices for each predicted token
    /// based on the cumulative probabilities. It specifies a probability threshold, below which
    /// all less likely tokens are filtered out.
    pub top_p: Option<f64>,
    /// When specified, this number will decrease (or increase) the likelihood of repeating tokens
    /// that were mentioned prior in the completion. The penalty is cumulative: the more often a
    /// token is mentioned in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the likelihood of generating tokens that are already present
    /// in the generated text (`repetition_penalties_include_completion=true`) or in the prompt
    /// (`repetition_penalties_include_prompt=true`), respectively. The presence penalty is
    /// independent of the number of occurrences. Increase the value to reduce the likelihood of
    /// repeating text.
    /// An operation like the following is applied:
    ///
    /// `logits[t] -> logits[t] - 1 * penalty`
    ///
    /// where `logits[t]` is the logits for any given token. Note that the formula is independent
    /// of the number of times that a token appears.
    pub presence_penalty: Option<f64>,
}

impl ChatSampling {
    /// Always chooses the token most likely to come next. Choose this if you want close to
    /// deterministic behaviour and do not want to apply any penalties to avoid repetitions.
    pub const MOST_LIKELY: Self = ChatSampling {
        temperature: None,
        top_p: None,
        frequency_penalty: None,
        presence_penalty: None,
    };
}

impl Default for ChatSampling {
    fn default() -> Self {
        Self::MOST_LIKELY
    }
}

#[derive(Debug, PartialEq, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
}

#[derive(Debug, PartialEq)]
pub struct ChatOutput {
    pub message: Message<'static>,
    pub finish_reason: String,
    /// Contains the logprobs for the sampled and top n tokens, given that [`crate::Logprobs`] has
    /// been set to [`crate::Logprobs::Sampled`] or [`crate::Logprobs::Top`].
    pub logprobs: Vec<Distribution>,
    pub usage: Usage,
}

impl ChatOutput {
    pub fn new(
        message: Message<'static>,
        finish_reason: String,
        logprobs: Vec<Distribution>,
        usage: Usage,
    ) -> Self {
        Self {
            message,
            finish_reason,
            logprobs,
            usage,
        }
    }
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseChoice {
    pub message: Message<'static>,
    pub finish_reason: String,
    pub logprobs: Option<LogprobContent>,
}

#[derive(Deserialize, Debug, PartialEq, Default)]
pub struct LogprobContent {
    content: Vec<Distribution>,
}

/// Logprob information for a single token
#[derive(Deserialize, Debug, PartialEq)]
pub struct Distribution {
    /// Logarithmic probability of the token returned in the completion
    #[serde(flatten)]
    pub sampled: Logprob,
    /// Logarithmic probabilities of the most probable tokens, filled if the user has requested
    /// [`crate::Logprobs::Top`]. The length of this array is always equal to the value of the
    /// `top_logprobs` parameter. For the special case of echo being set to true, it will include
    /// an empty element for the first token.
    #[serde(rename = "top_logprobs")]
    pub top: Vec<Logprob>,
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ChatResponse {
    choices: Vec<ResponseChoice>,
    usage: Usage,
}

/// Additional options to affect the streaming behavior.
#[derive(Serialize)]
struct StreamOptions {
    /// If set, an additional chunk will be streamed before the `data: [DONE]` message.
    /// The usage field on this chunk shows the token usage statistics for the entire request,
    /// and the choices field will always be an empty array.
    include_usage: bool,
}

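/// Request body for the `chat/completions` endpoint.
///
/// Thanks to the `skip_serializing_if` attributes, options left at their defaults are omitted
/// from the wire format. A non-streaming request built from a default [`TaskChat`] therefore
/// serializes to roughly the following (model name and message content are placeholders):
///
/// ```text
/// {"model": "<model>", "messages": [{"role": "user", "content": "<content>"}]}
/// ```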
#[derive(Serialize)]
struct ChatBody<'a> {
    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
    pub model: &'a str,
    /// The list of messages comprising the conversation so far.
    messages: &'a [Message<'a>],
    /// Limits the number of tokens which are generated for the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    pub stop: &'a [&'a str],
    /// Controls the randomness of the model. Lower values will make the model more deterministic
    /// and higher values will make it more random. Mathematically, the temperature is used to
    /// divide the logits before sampling. A temperature of 0 will always return the most likely
    /// token. When no value is provided, the default value of 1 will be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// "Nucleus" parameter to dynamically adjust the number of choices for each predicted token
    /// based on the cumulative probabilities. It specifies a probability threshold, below which
    /// all less likely tokens are filtered out. When no value is provided, the default value of
    /// 1 will be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    /// Whether to stream the response or not.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub logprobs: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
}

impl<'a> ChatBody<'a> {
    pub fn new(model: &'a str, task: &'a TaskChat) -> Self {
        let TaskChat {
            messages,
            stopping:
                Stopping {
                    maximum_tokens,
                    stop_sequences,
                },
            sampling:
                ChatSampling {
                    temperature,
                    top_p,
                    frequency_penalty,
                    presence_penalty,
                },
            logprobs,
        } = task;

        Self {
            model,
            messages,
            max_tokens: *maximum_tokens,
            stop: stop_sequences,
            temperature: *temperature,
            top_p: *top_p,
            frequency_penalty: *frequency_penalty,
            presence_penalty: *presence_penalty,
            stream: false,
            logprobs: logprobs.logprobs(),
            top_logprobs: logprobs.top_logprobs(),
            stream_options: None,
        }
    }

    pub fn with_streaming(mut self) -> Self {
        self.stream = true;
        // Always set `include_usage` to true, as we have not yet seen a case where this
        // information hurts.
        self.stream_options = Some(StreamOptions {
            include_usage: true,
        });
        self
    }
}

impl Task for TaskChat<'_> {
    type Output = ChatOutput;

    type ResponseBody = ChatResponse;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = ChatBody::new(model, self);
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        let ResponseChoice {
            message,
            finish_reason,
            logprobs,
        } = response.choices.pop().unwrap();
        ChatOutput::new(
            message,
            finish_reason,
            logprobs.unwrap_or_default().content,
            response.usage,
        )
    }
}

#[derive(Debug, Deserialize)]
pub struct StreamMessage {
    /// The role of the current chat completion. Will be assistant for the first chunk of every
    /// completion stream and missing for the remaining chunks.
    pub role: Option<String>,
    /// The content of the current chat completion. Will be empty for the first chunk of every
    /// completion stream and non-empty for the remaining chunks.
    pub content: Option<String>,
}

/// One chunk of a chat completion stream.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum DeserializedChatChunk {
    Delta {
        /// Chat completion chunk generated by the model when streaming is enabled.
        delta: StreamMessage,
        logprobs: Option<LogprobContent>,
        /// The reason the model stopped generating tokens.
        finish_reason: Option<String>,
    },
}

/// Response received from a chat completion stream.
/// Will either have `Some(Usage)` or choices of length 1.
///
/// While we could deserialize directly into an enum, deserializing into a struct and
/// only having the enum on the output type seems to be the simpler solution.
#[derive(Deserialize)]
pub struct StreamChatResponse {
    pub choices: Vec<DeserializedChatChunk>,
    pub usage: Option<Usage>,
}

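/// A single event emitted by a chat completion stream, after the raw chunks have been converted
/// by `StreamTask::body_to_output`.
///
/// A minimal sketch of consuming events, assuming `ChatEvent` is re-exported at the crate root;
/// how the events are obtained from the client is out of scope here.
///
/// ```no_run
/// use aleph_alpha_client::ChatEvent;
///
/// fn handle(buffer: &mut String, event: ChatEvent) {
///     match event {
///         ChatEvent::MessageStart { role } => println!("role: {role}"),
///         ChatEvent::MessageDelta { content, .. } => {
///             if let Some(content) = content {
///                 buffer.push_str(&content);
///             }
///         }
///         ChatEvent::Summary { usage } => println!("prompt tokens: {}", usage.prompt_tokens),
///     }
/// }
/// ```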
#[derive(Debug, PartialEq)]
pub enum ChatEvent {
    MessageStart {
        role: String,
    },
    MessageDelta {
        /// Chat completion chunk generated by the model when streaming is enabled.
        /// The role is always "assistant".
        content: Option<String>,
        /// Log probabilities of the completion tokens if requested via the logprobs parameter in
        /// the request.
        logprobs: Vec<Distribution>,
        /// The reason the model stopped generating tokens. Only present in the last chunk.
        finish_reason: Option<String>,
    },
    /// Summary of the chat completion stream.
    Summary {
        usage: Usage,
    },
}

impl StreamTask for TaskChat<'_> {
    type Output = ChatEvent;

    type ResponseBody = StreamChatResponse;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = ChatBody::new(model, self).with_streaming();
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        if let Some(usage) = response.usage {
            ChatEvent::Summary { usage }
        } else {
            // We always expect there to be exactly one choice, as the `n` parameter is not
            // supported by this crate.
            let chunk = response
                .choices
                .pop()
                .expect("There must always be at least one choice");

            match chunk {
                // The first chunk of a stream only carries the role; surface it as a
                // `MessageStart` event.
                DeserializedChatChunk::Delta {
                    delta:
                        StreamMessage {
                            role: Some(role), ..
                        },
                    ..
                } => ChatEvent::MessageStart { role },
                DeserializedChatChunk::Delta {
                    delta:
                        StreamMessage {
                            role: None,
                            content,
                        },
                    logprobs,
                    finish_reason,
                } => ChatEvent::MessageDelta {
                    content,
                    logprobs: logprobs.unwrap_or_default().content,
                    finish_reason,
                },
            }
        }
    }
}

impl Logprobs {
    /// Representation for serialization in request body, for `logprobs` parameter
    pub fn logprobs(self) -> bool {
        match self {
            Logprobs::No => false,
            Logprobs::Sampled | Logprobs::Top(_) => true,
        }
    }

    /// Representation for serialization in request body, for `top_logprobs` parameter
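    ///
    /// A quick sketch of the mapping, assuming `Logprobs` is re-exported at the crate root:
    ///
    /// ```no_run
    /// use aleph_alpha_client::Logprobs;
    ///
    /// assert_eq!(Logprobs::Sampled.top_logprobs(), None);
    /// assert_eq!(Logprobs::Top(5).top_logprobs(), Some(5));
    /// ```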
    pub fn top_logprobs(self) -> Option<u8> {
        match self {
            Logprobs::No | Logprobs::Sampled => None,
            Logprobs::Top(n) => Some(n),
        }
    }
}