aleph_alpha_client/chat.rs

use core::str;
use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::{
    logprobs::{Logprob, Logprobs},
    Stopping, StreamTask, Task,
};

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
    pub role: Cow<'a, str>,
    pub content: Cow<'a, str>,
}

impl<'a> Message<'a> {
    pub fn new(role: impl Into<Cow<'a, str>>, content: impl Into<Cow<'a, str>>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }
    pub fn user(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("user", content)
    }
    pub fn assistant(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("assistant", content)
    }
    pub fn system(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("system", content)
    }
}

pub struct TaskChat<'a> {
    /// The list of messages comprising the conversation so far.
    pub messages: Vec<Message<'a>>,
    /// Controls in which circumstances the model will stop generating new tokens.
    pub stopping: Stopping<'a>,
    /// Sampling controls how the tokens ("words") are selected for the completion.
    pub sampling: ChatSampling,
    /// Use this to control which logarithmic probabilities you want to have returned. This is
    /// useful to figure out how likely it was that a specific token got sampled.
    pub logprobs: Logprobs,
}

impl<'a> TaskChat<'a> {
    /// Creates a new TaskChat containing the given message.
    /// All optional TaskChat attributes are left unset.
    pub fn with_message(message: Message<'a>) -> Self {
        Self::with_messages(vec![message])
    }

    /// Creates a new TaskChat containing the given messages.
    /// All optional TaskChat attributes are left unset.
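    ///
    /// # Example
    ///
    /// A usage sketch (not compiled as a doctest; assumes `Message` and `TaskChat` are
    /// re-exported from the crate root):
    ///
    /// ```ignore
    /// use aleph_alpha_client::{Message, TaskChat};
    ///
    /// let task = TaskChat::with_messages(vec![
    ///     Message::system("You are a helpful assistant."),
    ///     Message::user("What is the capital of France?"),
    /// ])
    /// .with_maximum_tokens(64);
    /// ```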
    pub fn with_messages(messages: Vec<Message<'a>>) -> Self {
        TaskChat {
            messages,
            sampling: ChatSampling::default(),
            stopping: Stopping::default(),
            logprobs: Logprobs::No,
        }
    }

    /// Pushes a new Message to this TaskChat.
    pub fn push_message(mut self, message: Message<'a>) -> Self {
        self.messages.push(message);
        self
    }

    /// Sets the maximum tokens attribute of this TaskChat.
    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
        self.stopping.maximum_tokens = Some(maximum_tokens);
        self
    }

    /// Sets the logprobs attribute of this TaskChat.
    pub fn with_logprobs(mut self, logprobs: Logprobs) -> Self {
        self.logprobs = logprobs;
        self
    }
}

/// Sampling controls how the tokens ("words") are selected for the completion. This is different
/// from [`crate::Sampling`], because it does **not** support the `top_k` parameter.
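///
/// # Example
///
/// A sketch (not compiled as a doctest): mild nucleus sampling, with the remaining fields left
/// at their [`ChatSampling::MOST_LIKELY`] defaults via struct update syntax.
///
/// ```ignore
/// let sampling = ChatSampling {
///     temperature: Some(0.7),
///     top_p: Some(0.9),
///     ..ChatSampling::default()
/// };
/// ```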
pub struct ChatSampling {
    /// A temperature encourages the model to produce less probable outputs ("be more creative").
    /// Values are expected to be between 0 and 1. Try high values for a more random ("creative")
    /// response.
    pub temperature: Option<f64>,
    /// Nucleus sampling: the model only considers the smallest set of most likely tokens whose
    /// cumulative probability exceeds this threshold; all less likely tokens are filtered out.
    /// Values are expected to be between 0 and 1.
    pub top_p: Option<f64>,
    /// When specified, this number will decrease (or increase) the likelihood of repeating tokens
    /// that already appeared earlier in the completion. The penalty is cumulative: the more often
    /// a token appears in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the likelihood of generating tokens that are already present
    /// in the generated text (`repetition_penalties_include_completion=true`) or in the prompt
    /// (`repetition_penalties_include_prompt=true`), respectively. The presence penalty is
    /// independent of the number of occurrences. Increase the value to reduce the likelihood of
    /// repeating text. An operation like the following is applied:
    ///
    /// `logits[t] -> logits[t] - 1 * penalty`
    ///
    /// where `logits[t]` is the logit for any given token. Note that the formula is independent
    /// of the number of times that a token appears.
    pub presence_penalty: Option<f64>,
}

impl ChatSampling {
    /// Always chooses the token most likely to come next. Choose this if you want close to
    /// deterministic behaviour and do not want to apply any penalties to avoid repetitions.
    pub const MOST_LIKELY: Self = ChatSampling {
        temperature: None,
        top_p: None,
        frequency_penalty: None,
        presence_penalty: None,
    };
}

impl Default for ChatSampling {
    fn default() -> Self {
        Self::MOST_LIKELY
    }
}

#[derive(Debug, PartialEq, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
}

#[derive(Debug, PartialEq)]
pub struct ChatOutput {
    pub message: Message<'static>,
    pub finish_reason: String,
    /// Contains the logprobs for the sampled and top n tokens, given that [`crate::Logprobs`] has
    /// been set to [`crate::Logprobs::Sampled`] or [`crate::Logprobs::Top`].
    pub logprobs: Vec<Distribution>,
    pub usage: Usage,
}

impl ChatOutput {
    pub fn new(
        message: Message<'static>,
        finish_reason: String,
        logprobs: Vec<Distribution>,
        usage: Usage,
    ) -> Self {
        Self {
            message,
            finish_reason,
            logprobs,
            usage,
        }
    }
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseChoice {
    pub message: Message<'static>,
    pub finish_reason: String,
    pub logprobs: Option<LogprobContent>,
}

#[derive(Deserialize, Debug, PartialEq, Default)]
pub struct LogprobContent {
    content: Vec<Distribution>,
}

/// Logprob information for a single token
#[derive(Deserialize, Debug, PartialEq)]
pub struct Distribution {
    /// Logarithmic probability of the token returned in the completion
    #[serde(flatten)]
    pub sampled: Logprob,
    /// Logarithmic probabilities of the most probable tokens, filled if the user has requested
    /// [`crate::Logprobs::Top`]
    #[serde(rename = "top_logprobs")]
    pub top: Vec<Logprob>,
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ChatResponse {
    choices: Vec<ResponseChoice>,
    usage: Usage,
}

/// Additional options to affect the streaming behavior.
#[derive(Serialize)]
struct StreamOptions {
    /// If set, an additional chunk will be streamed before the `data: [DONE]` message.
    /// The usage field on this chunk shows the token usage statistics for the entire request,
    /// and the choices field will always be an empty array.
    include_usage: bool,
}

#[derive(Serialize)]
struct ChatBody<'a> {
    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
    pub model: &'a str,
    /// The list of messages comprising the conversation so far.
    messages: &'a [Message<'a>],
    /// Limits the number of tokens that are generated for the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    pub stop: &'a [&'a str],
    /// Controls the randomness of the model. Lower values will make the model more deterministic and higher values will make it more random.
    /// Mathematically, the temperature is used to divide the logits before sampling. A temperature of 0 will always return the most likely token.
    /// When no value is provided, the default value of 1 will be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// "nucleus" parameter to dynamically adjust the number of choices for each predicted token based on the cumulative probabilities. It specifies a probability threshold, below which all less likely tokens are filtered out.
    /// When no value is provided, the default value of 1 will be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    /// Whether to stream the response or not.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub logprobs: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
}

impl<'a> ChatBody<'a> {
    pub fn new(model: &'a str, task: &'a TaskChat) -> Self {
        let TaskChat {
            messages,
            stopping:
                Stopping {
                    maximum_tokens,
                    stop_sequences,
                },
            sampling:
                ChatSampling {
                    temperature,
                    top_p,
                    frequency_penalty,
                    presence_penalty,
                },
            logprobs,
        } = task;

        Self {
            model,
            messages,
            max_tokens: *maximum_tokens,
            stop: stop_sequences,
            temperature: *temperature,
            top_p: *top_p,
            frequency_penalty: *frequency_penalty,
            presence_penalty: *presence_penalty,
            stream: false,
            logprobs: logprobs.logprobs(),
            top_logprobs: logprobs.top_logprobs(),
            stream_options: None,
        }
    }

    pub fn with_streaming(mut self) -> Self {
        self.stream = true;
        // Always set `include_usage` to true, as we have not yet seen a case where this
        // information might hurt.
        self.stream_options = Some(StreamOptions {
            include_usage: true,
        });
        self
    }
}

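// `ChatBody` is the wire representation shared by both execution paths below: the blocking
// `Task` impl serializes it as constructed, while the streaming `StreamTask` impl first calls
// `with_streaming` to enable streaming and request a final usage chunk.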
impl Task for TaskChat<'_> {
    type Output = ChatOutput;

    type ResponseBody = ChatResponse;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = ChatBody::new(model, self);
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        let ResponseChoice {
            message,
            finish_reason,
            logprobs,
        } = response.choices.pop().unwrap();
        ChatOutput::new(
            message,
            finish_reason,
            logprobs.unwrap_or_default().content,
            response.usage,
        )
    }
}

#[derive(Debug, Deserialize)]
pub struct StreamMessage {
    /// The role of the current chat completion. Will be assistant for the first chunk of every
    /// completion stream and missing for the remaining chunks.
    pub role: Option<String>,
    /// The content of the current chat completion. Will be empty for the first chunk of every
    /// completion stream and non-empty for the remaining chunks.
    pub content: String,
}

/// One chunk of a chat completion stream.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum DeserializedChatChunk {
    Delta {
        /// Chat completion chunk generated by the model when streaming is enabled.
        delta: StreamMessage,
        logprobs: Option<LogprobContent>,
    },
    /// The last chunk of a chat completion stream.
    Finished {
        /// The reason the model stopped generating tokens.
        finish_reason: String,
    },
}

/// Response received from a chat completion stream.
/// Will either have `Some(Usage)` or choices of length 1.
///
/// While we could deserialize directly into an enum, deserializing into a struct and
/// only having the enum on the output type seems to be the simpler solution.
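///
/// An illustrative content chunk, with the shape inferred from the types in this file rather
/// than copied from an API transcript:
///
/// ```json
/// {"choices":[{"delta":{"content":"Hello"},"logprobs":null}],"usage":null}
/// ```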
#[derive(Deserialize)]
pub struct StreamChatResponse {
    pub choices: Vec<DeserializedChatChunk>,
    pub usage: Option<Usage>,
}

#[derive(Debug)]
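/// Event emitted by a chat completion stream. A stream typically yields a sequence of
/// [`Self::Delta`] events followed by a [`Self::Finished`]; since usage reporting is always
/// requested for streaming bodies, a [`Self::Summary`] with the token usage is also emitted.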
pub enum ChatEvent {
    Delta {
        /// Chat completion chunk generated by the model when streaming is enabled.
        /// The role is always "assistant".
        content: String,
        /// Log probabilities of the completion tokens, if requested via the `logprobs` parameter
        /// in the request.
        logprobs: Vec<Distribution>,
    },
    /// The last chunk of a chat completion stream.
    Finished {
        /// The reason the model stopped generating tokens.
        reason: String,
    },
    /// Summary of the chat completion stream.
    Summary { usage: Usage },
}

impl StreamTask for TaskChat<'_> {
    type Output = ChatEvent;

    type ResponseBody = StreamChatResponse;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = ChatBody::new(model, self).with_streaming();
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Option<Self::Output> {
        if let Some(usage) = response.usage {
            Some(ChatEvent::Summary { usage })
        } else {
            // We always expect there to be exactly one choice, as the `n` parameter is not
            // supported by this crate.
            let chunk = response
                .choices
                .pop()
                .expect("There must always be at least one choice");

            match chunk {
                // Skip the role message
                DeserializedChatChunk::Delta {
                    delta: StreamMessage { role: Some(_), .. },
                    ..
                } => None,
                DeserializedChatChunk::Delta {
                    delta:
                        StreamMessage {
                            role: None,
                            content,
                        },
                    logprobs,
                } => Some(ChatEvent::Delta {
                    content,
                    logprobs: logprobs.unwrap_or_default().content,
                }),
                DeserializedChatChunk::Finished { finish_reason } => Some(ChatEvent::Finished {
                    reason: finish_reason,
                }),
            }
        }
    }
}

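// How `Logprobs` maps onto the two request-body parameters below:
//
//   Logprobs::No      -> logprobs: false, top_logprobs: None
//   Logprobs::Sampled -> logprobs: true,  top_logprobs: None
//   Logprobs::Top(n)  -> logprobs: true,  top_logprobs: Some(n)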
impl Logprobs {
    /// Representation for serialization in the request body, for the `logprobs` parameter
    pub fn logprobs(self) -> bool {
        match self {
            Logprobs::No => false,
            Logprobs::Sampled | Logprobs::Top(_) => true,
        }
    }

    /// Representation for serialization in the request body, for the `top_logprobs` parameter
    pub fn top_logprobs(self) -> Option<u8> {
        match self {
            Logprobs::No | Logprobs::Sampled => None,
            Logprobs::Top(n) => Some(n),
        }
    }
}
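
#[cfg(test)]
mod tests {
    use super::*;

    // Serialization sketch: with all optional parameters unset, only `model` and `messages`
    // should appear in the request body, because every other field is annotated with
    // `skip_serializing_if`. Assumes `serde_json` is available as a dev-dependency.
    #[test]
    fn minimal_chat_body_omits_unset_parameters() {
        let task = TaskChat::with_message(Message::user("Hello"));
        let body = ChatBody::new("luminous-base", &task);

        let value = serde_json::to_value(&body).unwrap();

        assert_eq!(
            value,
            serde_json::json!({
                "model": "luminous-base",
                "messages": [{ "role": "user", "content": "Hello" }]
            })
        );
    }
}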