// aleph_alpha_client/chat.rs

use core::str;
use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::{
    logprobs::{Logprob, Logprobs},
    Stopping, StreamTask, Task,
};

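/// A single message in a chat conversation, together with the role it was sent in
/// (`"user"`, `"assistant"` or `"system"`).
///
/// # Example
///
/// A minimal sketch of assembling a conversation with the convenience constructors below. It
/// assumes that `Message` is re-exported at the crate root.
///
/// ```
/// use aleph_alpha_client::Message; // assumption: re-exported at the crate root
///
/// let messages = vec![
///     Message::system("You are a helpful assistant."),
///     Message::user("Hello!"),
/// ];
/// assert_eq!(messages[0].role, "system");
/// assert_eq!(messages[1].content, "Hello!");
/// ```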
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
    pub role: Cow<'a, str>,
    pub content: Cow<'a, str>,
}

impl<'a> Message<'a> {
    pub fn new(role: impl Into<Cow<'a, str>>, content: impl Into<Cow<'a, str>>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }
    pub fn user(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("user", content)
    }
    pub fn assistant(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("assistant", content)
    }
    pub fn system(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("system", content)
    }
}

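/// A chat completion task, i.e. the messages sent to the model together with stopping,
/// sampling and logprob settings.
///
/// # Example
///
/// A sketch of configuring a task via the builder-style setters defined below. The concrete
/// values are illustrative, and the example assumes that `TaskChat`, `Message` and `Logprobs`
/// are re-exported at the crate root.
///
/// ```
/// // Assumption: these items are re-exported at the crate root.
/// use aleph_alpha_client::{Logprobs, Message, TaskChat};
///
/// let task = TaskChat::with_message(Message::user("An apple a day"))
///     .with_maximum_tokens(32)
///     .with_logprobs(Logprobs::Sampled);
/// ```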
pub struct TaskChat<'a> {
    /// The list of messages comprising the conversation so far.
    pub messages: Vec<Message<'a>>,
    /// Controls in which circumstances the model will stop generating new tokens.
    pub stopping: Stopping<'a>,
    /// Sampling controls how the tokens ("words") are selected for the completion.
    pub sampling: ChatSampling,
    /// Use this to control the logarithmic probabilities you want to have returned. This is useful
    /// to figure out how likely it was that this specific token was sampled.
    pub logprobs: Logprobs,
}

impl<'a> TaskChat<'a> {
    /// Creates a new TaskChat containing the given message.
    /// All optional TaskChat attributes are left unset.
    pub fn with_message(message: Message<'a>) -> Self {
        Self::with_messages(vec![message])
    }

    /// Creates a new TaskChat containing the given messages.
    /// All optional TaskChat attributes are left unset.
    pub fn with_messages(messages: Vec<Message<'a>>) -> Self {
        TaskChat {
            messages,
            sampling: ChatSampling::default(),
            stopping: Stopping::default(),
            logprobs: Logprobs::No,
        }
    }

    /// Pushes a new `Message` to this TaskChat.
    pub fn push_message(mut self, message: Message<'a>) -> Self {
        self.messages.push(message);
        self
    }

    /// Sets the `maximum_tokens` stopping attribute of this TaskChat.
    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
        self.stopping.maximum_tokens = Some(maximum_tokens);
        self
    }

    /// Sets the `logprobs` attribute of this TaskChat.
    pub fn with_logprobs(mut self, logprobs: Logprobs) -> Self {
        self.logprobs = logprobs;
        self
    }
}

/// Sampling controls how the tokens ("words") are selected for the completion. This is different
/// from [`crate::Sampling`], because it does **not** support the `top_k` parameter.
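///
/// # Example
///
/// A sketch of a mildly creative configuration; the values are illustrative, and the example
/// assumes that `ChatSampling` is re-exported at the crate root.
///
/// ```
/// use aleph_alpha_client::ChatSampling; // assumption: re-exported at the crate root
///
/// let sampling = ChatSampling {
///     temperature: Some(0.7),
///     top_p: Some(0.9),
///     ..ChatSampling::MOST_LIKELY
/// };
/// assert_eq!(sampling.frequency_penalty, None);
/// ```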
pub struct ChatSampling {
    /// A temperature encourages the model to produce less probable outputs ("be more creative").
    /// Values are expected to be between 0 and 1. Try high values for a more random ("creative")
    /// response.
    pub temperature: Option<f64>,
    /// Introduces nucleus sampling: the next token is sampled only from the smallest set of most
    /// likely tokens whose cumulative probability exceeds `top_p`. Values are expected to be
    /// between 0 and 1. When no value is provided, the default of 1 is used, so no tokens are
    /// filtered out.
    pub top_p: Option<f64>,
    /// When specified, this number will decrease (or increase) the likelihood of repeating tokens
    /// that were mentioned prior in the completion. The penalty is cumulative. The more a token
    /// is mentioned in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the likelihood of generating tokens that are already present
    /// in the generated text (if `repetition_penalties_include_completion=true`) or in the prompt
    /// (if `repetition_penalties_include_prompt=true`), respectively. The presence penalty is
    /// independent of the number of occurrences. Increase the value to reduce the likelihood of
    /// repeating text.
    /// An operation like the following is applied:
    ///
    /// logits[t] -> logits[t] - 1 * penalty
    ///
    /// where logits[t] is the logits for any given token. Note that the formula is independent
    /// of the number of times that a token appears.
    pub presence_penalty: Option<f64>,
}

impl ChatSampling {
    /// Always chooses the token most likely to come next. Choose this if you do want close to
    /// deterministic behaviour and do not want to apply any penalties to avoid repetitions.
    pub const MOST_LIKELY: Self = ChatSampling {
        temperature: None,
        top_p: None,
        frequency_penalty: None,
        presence_penalty: None,
    };
}

impl Default for ChatSampling {
    fn default() -> Self {
        Self::MOST_LIKELY
    }
}

#[derive(Debug, PartialEq, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
}

#[derive(Debug, PartialEq)]
pub struct ChatOutput {
    pub message: Message<'static>,
    pub finish_reason: String,
    /// Contains the logprobs for the sampled and top n tokens, given that [`crate::Logprobs`] has
    /// been set to [`crate::Logprobs::Sampled`] or [`crate::Logprobs::Top`].
    pub logprobs: Vec<Distribution>,
    pub usage: Usage,
}

impl ChatOutput {
    pub fn new(
        message: Message<'static>,
        finish_reason: String,
        logprobs: Vec<Distribution>,
        usage: Usage,
    ) -> Self {
        Self {
            message,
            finish_reason,
            logprobs,
            usage,
        }
    }
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseChoice {
    pub message: Message<'static>,
    pub finish_reason: String,
    pub logprobs: Option<LogprobContent>,
}

#[derive(Deserialize, Debug, PartialEq, Default)]
pub struct LogprobContent {
    content: Vec<Distribution>,
}

/// Logprob information for a single token
#[derive(Deserialize, Debug, PartialEq)]
pub struct Distribution {
    /// Logarithmic probability of the token returned in the completion
    #[serde(flatten)]
    pub sampled: Logprob,
    /// Logarithmic probabilities of the most probable tokens, filled if the user has requested
    /// [`crate::Logprobs::Top`]. The length of this array is always equal to the value of the
    /// `top_logprobs` parameter. For the special case of echo being set to true, it will include
    /// an empty element for the first token.
    #[serde(rename = "top_logprobs")]
    pub top: Vec<Logprob>,
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ChatResponse {
    choices: Vec<ResponseChoice>,
    usage: Usage,
}

/// Additional options to affect the streaming behavior.
#[derive(Serialize)]
struct StreamOptions {
    /// If set, an additional chunk will be streamed before the `data: [DONE]` message.
    /// The usage field on this chunk shows the token usage statistics for the entire request,
    /// and the choices field will always be an empty array.
    include_usage: bool,
}

#[derive(Serialize)]
struct ChatBody<'a> {
    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
    pub model: &'a str,
    /// The list of messages comprising the conversation so far.
    messages: &'a [Message<'a>],
    /// Limits the number of tokens that are generated for the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    pub stop: &'a [&'a str],
    /// Controls the randomness of the model. Lower values will make the model more deterministic and higher values will make it more random.
    /// Mathematically, the temperature is used to divide the logits before sampling. A temperature of 0 will always return the most likely token.
    /// When no value is provided, the default value of 1 will be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// "Nucleus" parameter to dynamically adjust the number of choices for each predicted token based on the cumulative probabilities. It specifies a probability threshold, below which all less likely tokens are filtered out.
    /// When no value is provided, the default value of 1 will be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    /// Whether to stream the response or not.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub logprobs: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
}

impl<'a> ChatBody<'a> {
    pub fn new(model: &'a str, task: &'a TaskChat) -> Self {
        let TaskChat {
            messages,
            stopping:
                Stopping {
                    maximum_tokens,
                    stop_sequences,
                },
            sampling:
                ChatSampling {
                    temperature,
                    top_p,
                    frequency_penalty,
                    presence_penalty,
                },
            logprobs,
        } = task;

        Self {
            model,
            messages,
            max_tokens: *maximum_tokens,
            stop: stop_sequences,
            temperature: *temperature,
            top_p: *top_p,
            frequency_penalty: *frequency_penalty,
            presence_penalty: *presence_penalty,
            stream: false,
            logprobs: logprobs.logprobs(),
            top_logprobs: logprobs.top_logprobs(),
            stream_options: None,
        }
    }

    pub fn with_streaming(mut self) -> Self {
        self.stream = true;
        // Always set `include_usage` to true, as we have not yet seen a case where this
        // information would hurt.
        self.stream_options = Some(StreamOptions {
            include_usage: true,
        });
        self
    }
}

impl Task for TaskChat<'_> {
    type Output = ChatOutput;

    type ResponseBody = ChatResponse;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = ChatBody::new(model, self);
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        let ResponseChoice {
            message,
            finish_reason,
            logprobs,
        } = response.choices.pop().unwrap();
        ChatOutput::new(
            message,
            finish_reason,
            logprobs.unwrap_or_default().content,
            response.usage,
        )
    }
}

#[derive(Debug, Deserialize)]
pub struct StreamMessage {
    /// The role of the current chat completion. Will be assistant for the first chunk of every
    /// completion stream and missing for the remaining chunks.
    pub role: Option<String>,
    /// The content of the current chat completion. Will be empty for the first chunk of every
    /// completion stream and non-empty for the remaining chunks.
    pub content: Option<String>,
}

/// One chunk of a chat completion stream.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum DeserializedChatChunk {
    Delta {
        /// Chat completion chunk generated by the model when streaming is enabled.
        delta: StreamMessage,
        logprobs: Option<LogprobContent>,
        /// The reason the model stopped generating tokens.
        finish_reason: Option<String>,
    },
}

/// Response received from a chat completion stream.
/// Will either have Some(Usage) or choices of length 1.
///
/// While we could deserialize directly into an enum, deserializing into a struct and
/// only having the enum on the output type seems to be the simpler solution.
#[derive(Deserialize)]
pub struct StreamChatResponse {
    pub choices: Vec<DeserializedChatChunk>,
    pub usage: Option<Usage>,
}

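/// An event emitted while streaming a chat completion.
///
/// # Example
///
/// A sketch of how a caller might consume these events; it assumes that `ChatEvent` is
/// re-exported at the crate root and that the events themselves are obtained elsewhere
/// (e.g. from a streaming client).
///
/// ```
/// use aleph_alpha_client::ChatEvent; // assumption: re-exported at the crate root
///
/// fn handle(event: ChatEvent) {
///     match event {
///         ChatEvent::MessageStart { role } => println!("role: {role}"),
///         ChatEvent::MessageDelta {
///             content,
///             finish_reason,
///             ..
///         } => {
///             if let Some(content) = content {
///                 print!("{content}");
///             }
///             if let Some(reason) = finish_reason {
///                 println!("\nfinish reason: {reason}");
///             }
///         }
///         ChatEvent::Summary { usage } => println!("prompt tokens: {}", usage.prompt_tokens),
///     }
/// }
///
/// handle(ChatEvent::MessageStart {
///     role: "assistant".to_owned(),
/// });
/// ```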
#[derive(Debug, PartialEq)]
pub enum ChatEvent {
    MessageStart {
        role: String,
    },
    MessageDelta {
        /// Chat completion chunk generated by the model when streaming is enabled.
        /// The role is always "assistant".
        content: Option<String>,
        /// Log probabilities of the completion tokens if requested via the `logprobs` parameter in
        /// the request.
        logprobs: Vec<Distribution>,
        /// The reason the model stopped generating tokens. Only present in the last chunk.
        finish_reason: Option<String>,
    },
    /// Summary of the chat completion stream.
    Summary {
        usage: Usage,
    },
}

impl StreamTask for TaskChat<'_> {
    type Output = ChatEvent;

    type ResponseBody = StreamChatResponse;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = ChatBody::new(model, self).with_streaming();
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        if let Some(usage) = response.usage {
            ChatEvent::Summary { usage }
        } else {
            // We always expect there to be exactly one choice, as the `n` parameter is not
            // supported by this crate.
            let chunk = response
                .choices
                .pop()
                .expect("There must always be at least one choice");

            match chunk {
                // The first chunk of a stream carries the role and marks the start of a new
                // message.
                DeserializedChatChunk::Delta {
                    delta:
                        StreamMessage {
                            role: Some(role), ..
                        },
                    ..
                } => ChatEvent::MessageStart { role },
                DeserializedChatChunk::Delta {
                    delta:
                        StreamMessage {
                            role: None,
                            content,
                        },
                    logprobs,
                    finish_reason,
                } => ChatEvent::MessageDelta {
                    content,
                    logprobs: logprobs.unwrap_or_default().content,
                    finish_reason,
                },
            }
        }
    }
}

impl Logprobs {
    /// Representation for serialization in the request body, for the `logprobs` parameter
    pub fn logprobs(self) -> bool {
        match self {
            Logprobs::No => false,
            Logprobs::Sampled | Logprobs::Top(_) => true,
        }
    }

    /// Representation for serialization in the request body, for the `top_logprobs` parameter
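    ///
    /// For example, a sketch of the mapping (assuming `Logprobs` is re-exported at the crate
    /// root):
    ///
    /// ```
    /// use aleph_alpha_client::Logprobs; // assumption: re-exported at the crate root
    ///
    /// assert_eq!(Logprobs::Top(5).top_logprobs(), Some(5));
    /// assert_eq!(Logprobs::Sampled.top_logprobs(), None);
    /// ```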
    pub fn top_logprobs(self) -> Option<u8> {
        match self {
            Logprobs::No | Logprobs::Sampled => None,
            Logprobs::Top(n) => Some(n),
        }
    }
}