aleph_alpha_client/chat.rs

use core::str;
use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::{
    logprobs::{Logprob, Logprobs},
    Stopping, StreamTask, Task,
};

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
    pub role: Cow<'a, str>,
    pub content: Cow<'a, str>,
}

impl<'a> Message<'a> {
    pub fn new(role: impl Into<Cow<'a, str>>, content: impl Into<Cow<'a, str>>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }
    pub fn user(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("user", content)
    }
    pub fn assistant(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("assistant", content)
    }
    pub fn system(content: impl Into<Cow<'a, str>>) -> Self {
        Self::new("system", content)
    }
}

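/// A chat task to be executed against a chat-capable model.
///
/// A minimal construction sketch using only the builders defined in this module. Executing the
/// task against the API requires a configured client, which is assumed and not shown here; the
/// `use` path also assumes these types are re-exported at the crate root.
///
/// ```ignore
/// use aleph_alpha_client::{Message, TaskChat};
///
/// let task = TaskChat::with_messages(vec![
///     Message::system("You are a helpful assistant."),
///     Message::user("What is the capital of France?"),
/// ])
/// .with_maximum_tokens(64);
/// ```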
pub struct TaskChat<'a> {
    /// The list of messages comprising the conversation so far.
    pub messages: Vec<Message<'a>>,
    /// Controls in which circumstances the model will stop generating new tokens.
    pub stopping: Stopping<'a>,
    /// Sampling controls how the tokens ("words") are selected for the completion.
    pub sampling: ChatSampling,
    /// Use this to control the logarithmic probabilities you want to have returned. This is
    /// useful for figuring out how likely it was that a specific token was sampled.
    pub logprobs: Logprobs,
}

impl<'a> TaskChat<'a> {
    /// Creates a new TaskChat containing the given message.
    /// All optional TaskChat attributes are left unset.
    pub fn with_message(message: Message<'a>) -> Self {
        Self::with_messages(vec![message])
    }

    /// Creates a new TaskChat containing the given messages.
    /// All optional TaskChat attributes are left unset.
    pub fn with_messages(messages: Vec<Message<'a>>) -> Self {
        TaskChat {
            messages,
            sampling: ChatSampling::default(),
            stopping: Stopping::default(),
            logprobs: Logprobs::No,
        }
    }

    /// Pushes a new Message to this TaskChat.
    pub fn push_message(mut self, message: Message<'a>) -> Self {
        self.messages.push(message);
        self
    }

    /// Sets the maximum number of tokens to be generated for this TaskChat.
    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
        self.stopping.maximum_tokens = Some(maximum_tokens);
        self
    }
}

/// Sampling controls how the tokens ("words") are selected for the completion. This is different
/// from [`crate::Sampling`], because it does **not** support the `top_k` parameter.
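///
/// A construction sketch using the `MOST_LIKELY` constant defined below as a base (struct update
/// syntax; the `use` path assumes `ChatSampling` is re-exported at the crate root):
///
/// ```ignore
/// use aleph_alpha_client::ChatSampling;
///
/// let sampling = ChatSampling {
///     temperature: Some(0.7),
///     ..ChatSampling::MOST_LIKELY
/// };
/// ```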
pub struct ChatSampling {
    /// A temperature encourages the model to produce less probable outputs ("be more creative").
    /// Values are expected to be between 0 and 1. Try high values for a more random ("creative")
    /// response.
    pub temperature: Option<f64>,
    /// Introduces random sampling by restricting the choice of the next token to the smallest
    /// possible set of tokens whose cumulative probability exceeds the threshold `top_p`
    /// ("nucleus sampling"). Values are expected to be between 0 and 1.
    pub top_p: Option<f64>,
    /// When specified, this number will decrease (or increase) the likelihood of repeating tokens
    /// that already appeared in the completion. The penalty is cumulative: the more often a token
    /// appears in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the likelihood of generating tokens that are already present
    /// in the generated text (if repetition_penalties_include_completion=true) or in the prompt
    /// (if repetition_penalties_include_prompt=true). The presence penalty is independent of the
    /// number of occurrences. Increase the value to reduce the likelihood of repeating text.
    /// An operation like the following is applied:
    ///
    /// logits[t] -> logits[t] - 1 * penalty
    ///
    /// where logits[t] is the logits for any given token. Note that the formula is independent
    /// of the number of times that a token appears.
    pub presence_penalty: Option<f64>,
}

impl ChatSampling {
    /// Always chooses the token most likely to come next. Choose this if you want close to
    /// deterministic behaviour and do not want to apply any penalties to avoid repetitions.
    pub const MOST_LIKELY: Self = ChatSampling {
        temperature: None,
        top_p: None,
        frequency_penalty: None,
        presence_penalty: None,
    };
}

impl Default for ChatSampling {
    fn default() -> Self {
        Self::MOST_LIKELY
    }
}

/// Token usage statistics reported for a chat completion.
#[derive(Debug, PartialEq, Deserialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: u32,
    /// Number of tokens in the generated completion.
    pub completion_tokens: u32,
}

#[derive(Debug, PartialEq)]
pub struct ChatOutput {
    pub message: Message<'static>,
    pub finish_reason: String,
    /// Contains the logprobs for the sampled and top n tokens, given that [`crate::Logprobs`] has
    /// been set to [`crate::Logprobs::Sampled`] or [`crate::Logprobs::Top`].
    pub logprobs: Vec<Distribution>,
    pub usage: Usage,
}

impl ChatOutput {
    pub fn new(
        message: Message<'static>,
        finish_reason: String,
        logprobs: Vec<Distribution>,
        usage: Usage,
    ) -> Self {
        Self {
            message,
            finish_reason,
            logprobs,
            usage,
        }
    }
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseChoice {
    pub message: Message<'static>,
    pub finish_reason: String,
    pub logprobs: Option<LogprobContent>,
}

#[derive(Deserialize, Debug, PartialEq, Default)]
pub struct LogprobContent {
    content: Vec<Distribution>,
}

/// Logprob information for a single token
#[derive(Deserialize, Debug, PartialEq)]
pub struct Distribution {
    /// Logarithmic probability of the token returned in the completion
    #[serde(flatten)]
    pub sampled: Logprob,
    /// Logarithmic probabilities of the most probable tokens, filled if the user has requested
    /// [`crate::Logprobs::Top`]
    #[serde(rename = "top_logprobs")]
    pub top: Vec<Logprob>,
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseChat {
    choices: Vec<ResponseChoice>,
    usage: Usage,
}

#[derive(Serialize)]
struct ChatBody<'a> {
    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
    pub model: &'a str,
    /// The list of messages comprising the conversation so far.
    messages: &'a [Message<'a>],
    /// Limits the number of tokens which are generated for the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    pub stop: &'a [&'a str],
    /// Controls the randomness of the model. Lower values will make the model more deterministic
    /// and higher values will make it more random. Mathematically, the temperature is used to
    /// divide the logits before sampling. A temperature of 0 will always return the most likely
    /// token. When no value is provided, the default value of 1 will be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// "nucleus" parameter to dynamically adjust the number of choices for each predicted token
    /// based on the cumulative probabilities. It specifies a probability threshold, below which
    /// all less likely tokens are filtered out. When no value is provided, the default value of 1
    /// will be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    /// Whether to stream the response or not.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub logprobs: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u8>,
}

impl<'a> ChatBody<'a> {
    pub fn new(model: &'a str, task: &'a TaskChat) -> Self {
        let TaskChat {
            messages,
            stopping:
                Stopping {
                    maximum_tokens,
                    stop_sequences,
                },
            sampling:
                ChatSampling {
                    temperature,
                    top_p,
                    frequency_penalty,
                    presence_penalty,
                },
            logprobs,
        } = task;

        Self {
            model,
            messages,
            max_tokens: *maximum_tokens,
            stop: stop_sequences,
            temperature: *temperature,
            top_p: *top_p,
            frequency_penalty: *frequency_penalty,
            presence_penalty: *presence_penalty,
            stream: false,
            logprobs: logprobs.logprobs(),
            top_logprobs: logprobs.top_logprobs(),
        }
    }

    pub fn with_streaming(mut self) -> Self {
        self.stream = true;
        self
    }
}

impl Task for TaskChat<'_> {
    type Output = ChatOutput;

    type ResponseBody = ResponseChat;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = ChatBody::new(model, self);
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
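        // We always expect exactly one choice, as the `n` parameter is not supported by this
        // crate (see the analogous streaming implementation below).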
        let ResponseChoice {
            message,
            finish_reason,
            logprobs,
        } = response.choices.pop().unwrap();
        ChatOutput::new(
            message,
            finish_reason,
            logprobs.unwrap_or_default().content,
            response.usage,
        )
    }
}

#[derive(Deserialize)]
pub struct StreamMessage {
    /// The role of the current chat completion. Will be assistant for the first chunk of every
    /// completion stream and missing for the remaining chunks.
    pub role: Option<String>,
    /// The content of the current chat completion. Will be empty for the first chunk of every
    /// completion stream and non-empty for the remaining chunks.
    pub content: String,
}

/// One chunk of a chat completion stream.
#[derive(Deserialize)]
pub struct ChatStreamChunk {
    /// The reason the model stopped generating tokens.
    /// The value is only set in the last chunk of a completion and null otherwise.
    pub finish_reason: Option<String>,
    /// Chat completion chunk generated by the model when streaming is enabled.
    pub delta: StreamMessage,
}

/// Event received from a chat completion stream. As the crate does not support multiple
/// chat completions, there will always be exactly one choice item.
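///
/// A deserialization sketch of the chunk shape implied by the derives below (the JSON body is
/// hypothetical and only illustrates the field layout):
///
/// ```ignore
/// let event: ChatEvent = serde_json::from_str(
///     r#"{"choices":[{"finish_reason":null,"delta":{"role":"assistant","content":""}}]}"#,
/// )
/// .unwrap();
/// assert!(event.choices[0].finish_reason.is_none());
/// ```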
#[derive(Deserialize)]
pub struct ChatEvent {
    pub choices: Vec<ChatStreamChunk>,
}

impl StreamTask for TaskChat<'_> {
    type Output = ChatStreamChunk;

    type ResponseBody = ChatEvent;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = ChatBody::new(model, self).with_streaming();
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    fn body_to_output(mut response: Self::ResponseBody) -> Self::Output {
        // We always expect there to be exactly one choice, as the `n` parameter is not
        // supported by this crate.
        response
            .choices
            .pop()
            .expect("There must always be at least one choice")
    }
}

impl Logprobs {
    /// Representation for serialization in the request body, for the `logprobs` parameter
    pub fn logprobs(self) -> bool {
        match self {
            Logprobs::No => false,
            Logprobs::Sampled | Logprobs::Top(_) => true,
        }
    }

    /// Representation for serialization in the request body, for the `top_logprobs` parameter
    pub fn top_logprobs(self) -> Option<u8> {
        match self {
            Logprobs::No | Logprobs::Sampled => None,
            Logprobs::Top(n) => Some(n),
        }
    }
}
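
// The following test module is a sketch added for illustration (not part of the original
// implementation); it demonstrates how `Logprobs` maps onto the `logprobs` and `top_logprobs`
// request-body parameters used in `ChatBody::new` above.
#[cfg(test)]
mod chat_logprobs_tests {
    use super::*;

    #[test]
    fn logprobs_request_body_representation() {
        // `No` disables logprobs entirely; `Sampled` and `Top(_)` both set `logprobs: true`.
        assert!(!Logprobs::No.logprobs());
        assert!(Logprobs::Sampled.logprobs());
        assert!(Logprobs::Top(3).logprobs());
        // Only `Top(n)` additionally requests the `n` most probable tokens per position.
        assert_eq!(Logprobs::No.top_logprobs(), None);
        assert_eq!(Logprobs::Sampled.top_logprobs(), None);
        assert_eq!(Logprobs::Top(3).top_logprobs(), Some(3));
    }
}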