aleph_alpha_client/completion.rs

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::{http::Task, Distribution, Logprob, Logprobs, Prompt, StreamTask, Usage};

/// Completes a prompt. E.g. continues a text.
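///
/// A minimal construction sketch, assuming `TaskCompletion` and `Logprobs` are re-exported at
/// the crate root like the other types imported above:
///
/// ```no_run
/// use aleph_alpha_client::{Logprobs, TaskCompletion};
///
/// let task = TaskCompletion::from_text("An apple a day")
///     .with_maximum_tokens(64)
///     .with_stop_sequences(&["\n\n"])
///     .with_logprobs(Logprobs::Sampled);
/// ```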
pub struct TaskCompletion<'a> {
    /// The prompt (usually text) to be completed. Unconditional completion can be started with an
    /// empty string. The prompt may contain a zero shot or few shot task.
    pub prompt: Prompt<'a>,
    /// Controls in which circumstances the model will stop generating new tokens.
    pub stopping: Stopping<'a>,
    /// Sampling controls how the tokens ("words") are selected for the completion.
    pub sampling: Sampling,
    /// Whether to include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
    pub special_tokens: bool,
    /// Whether you are interested in the probabilities of the sampled tokens, or of the most
    /// likely tokens.
    pub logprobs: Logprobs,
    /// Echo the prompt in the completion. This may be especially helpful when log_probs is set
    /// to return logprobs for the prompt.
    pub echo: bool,
}

impl<'a> TaskCompletion<'a> {
    /// Convenience constructor leaving most settings at their defaults, just completing a given text.
    pub fn from_text(text: &'a str) -> Self {
        TaskCompletion {
            prompt: Prompt::from_text(text),
            stopping: Stopping::NO_TOKEN_LIMIT,
            sampling: Sampling::MOST_LIKELY,
            special_tokens: false,
            logprobs: Logprobs::No,
            echo: false,
        }
    }

    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
        self.stopping.maximum_tokens = Some(maximum_tokens);
        self
    }

    pub fn with_stop_sequences(mut self, stop_sequences: &'a [&str]) -> Self {
        self.stopping.stop_sequences = stop_sequences;
        self
    }

    /// Include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
    pub fn with_special_tokens(mut self) -> Self {
        self.special_tokens = true;
        self
    }

    pub fn with_logprobs(mut self, logprobs: Logprobs) -> Self {
        self.logprobs = logprobs;
        self
    }

    pub fn with_echo(mut self) -> Self {
        self.echo = true;
        self
    }
}

/// Sampling controls how the tokens ("words") are selected for the completion.
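///
/// A minimal sketch of overriding individual parameters via struct update syntax, assuming
/// `Sampling` is re-exported at the crate root:
///
/// ```no_run
/// use aleph_alpha_client::Sampling;
///
/// let sampling = Sampling {
///     temperature: Some(0.7),
///     top_p: Some(0.9),
///     ..Sampling::MOST_LIKELY
/// };
/// ```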
pub struct Sampling {
    /// A temperature encourages the model to produce less probable outputs ("be more creative").
    /// Values are expected to be between 0 and 1. Try high values for a more random ("creative")
    /// response.
    pub temperature: Option<f64>,
    /// Introduces random sampling for generated tokens by randomly selecting the next token from
    /// the k most likely options. A value larger than 1 encourages the model to be more creative.
    /// Set to 0 to get the same behaviour as `None`.
    pub top_k: Option<u32>,
    /// Introduces random sampling for generated tokens by randomly selecting the next token from
    /// the smallest possible set of tokens whose cumulative probability exceeds the probability
    /// top_p. Set to 0 to get the same behaviour as `None`.
    pub top_p: Option<f64>,
    /// When specified, this number will decrease (or increase) the likelihood of repeating tokens
    /// that already appeared in the completion. The penalty is cumulative: the more often a token
    /// appears in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the likelihood of generating tokens that are already present
    /// in the generated text (repetition_penalties_include_completion=true) or in the prompt
    /// (repetition_penalties_include_prompt=true), respectively. The presence penalty is
    /// independent of the number of occurrences. Increase the value to reduce the likelihood of
    /// repeating text.
    /// An operation like the following is applied:
    ///
    /// logits[t] -> logits[t] - 1 * penalty
    ///
    /// where logits[t] is the logits for any given token. Note that the formula is independent
    /// of the number of times that a token appears.
    pub presence_penalty: Option<f64>,
}

impl Sampling {
    /// Always chooses the token most likely to come next. Choose this if you want close to
    /// deterministic behaviour and do not want to apply any penalties to avoid repetitions.
    pub const MOST_LIKELY: Self = Sampling {
        temperature: None,
        top_k: None,
        top_p: None,
        frequency_penalty: None,
        presence_penalty: None,
    };
}

impl Default for Sampling {
    fn default() -> Self {
        Self::MOST_LIKELY
    }
}

/// Controls the conditions under which the language model stops generating text.
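///
/// A minimal sketch of the two dedicated constructors, assuming `Stopping` is re-exported at the
/// crate root:
///
/// ```no_run
/// use aleph_alpha_client::Stopping;
///
/// // Stop after at most 64 generated tokens.
/// let by_length = Stopping::from_maximum_tokens(64);
/// // Stop as soon as the model starts a new question.
/// let by_sequence = Stopping::from_stop_sequences(&["Question:"]);
/// ```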
pub struct Stopping<'a> {
    /// The maximum number of tokens to be generated. Completion will terminate after the maximum
    /// number of tokens is reached. Increase this value to allow for longer outputs. A text is
    /// split into tokens; usually there are more tokens than words. The maximum total number of
    /// tokens for prompt and completion depends on the model.
    /// If maximum_tokens is set to `None`, no outside limit is imposed on the number of generated
    /// tokens. The model will generate tokens until it generates one of the specified
    /// stop_sequences or it reaches its technical limit, which usually is its context window.
    pub maximum_tokens: Option<u32>,
    /// List of strings which will stop generation if they are generated. Stop sequences are
    /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
    /// lines starting with either "Question: " or "Answer: " (alternating). After producing an
    /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used
    /// as stop sequence in order not to have the model generate more questions but rather restrict
    /// text generation to the answers.
    pub stop_sequences: &'a [&'a str],
}

impl<'a> Stopping<'a> {
    /// Only stop once the model reaches its technical limit, usually the context window.
    pub const NO_TOKEN_LIMIT: Self = Stopping {
        maximum_tokens: None,
        stop_sequences: &[],
    };

    /// Stop once the model has reached maximum_tokens.
    pub fn from_maximum_tokens(maximum_tokens: u32) -> Self {
        Self {
            maximum_tokens: Some(maximum_tokens),
            stop_sequences: &[],
        }
    }

    pub fn from_stop_sequences(stop_sequences: &'a [&'a str]) -> Self {
        Self {
            maximum_tokens: None,
            stop_sequences,
        }
    }
}

impl Default for Stopping<'_> {
    fn default() -> Self {
        Self::NO_TOKEN_LIMIT
    }
}

/// Body sent to the Aleph Alpha API on the POST `/complete` route.
#[derive(Serialize, Debug)]
struct BodyCompletion<'a> {
    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
    pub model: &'a str,
    /// Prompt to complete. The modalities supported depend on `model`.
    pub prompt: Prompt<'a>,
    /// Limits the number of tokens generated for the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub maximum_tokens: Option<u32>,
    /// List of strings which will stop generation if they are generated. Stop sequences are
    /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
    /// lines starting with either "Question: " or "Answer: " (alternating). After producing an
    /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used
    /// as stop sequence in order not to have the model generate more questions but rather restrict
    /// text generation to the answers.
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    pub stop_sequences: &'a [&'a str],
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// If true, the response will be streamed.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    /// Forces the raw completion of the model to be returned.
    /// For some models, the completion that was generated by the model may be optimized and
    /// returned in the completion field of the CompletionResponse.
    /// The raw completion, if returned, will contain the un-optimized completion.
    /// Setting tokens to true or log_probs to any value will also trigger the raw completion to be returned.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub raw_completion: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_probs: Option<u8>,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub tokens: bool,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub echo: bool,
}

impl<'a> BodyCompletion<'a> {
    pub fn new(model: &'a str, task: &'a TaskCompletion<'a>) -> Self {
        let TaskCompletion {
            prompt,
            stopping,
            sampling,
            special_tokens,
            logprobs,
            echo,
        } = task;
        Self {
            model,
            prompt: prompt.borrow(),
            maximum_tokens: stopping.maximum_tokens,
            stop_sequences: stopping.stop_sequences,
            temperature: sampling.temperature,
            top_k: sampling.top_k,
            top_p: sampling.top_p,
            stream: false,
            raw_completion: *special_tokens,
            frequency_penalty: sampling.frequency_penalty,
            presence_penalty: sampling.presence_penalty,
            log_probs: logprobs.to_logprobs_num(),
            tokens: logprobs.to_tokens(),
            echo: *echo,
        }
    }

    pub fn with_streaming(mut self) -> Self {
        self.stream = true;
        self
    }
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseCompletion {
    model_version: String,
    completions: Vec<DeserializedCompletion>,
    num_tokens_prompt_total: u32,
    num_tokens_generated: u32,
}

#[derive(Deserialize, Debug, PartialEq)]
struct DeserializedCompletion {
    completion: String,
    finish_reason: String,
    raw_completion: Option<String>,
    #[serde(default)]
    log_probs: Vec<HashMap<String, f64>>,
    #[serde(default)]
    completion_tokens: Vec<String>,
}

/// Completion and metainformation returned by a completion task.
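///
/// A small sketch of reading the output fields, assuming `CompletionOutput` is re-exported at
/// the crate root like the other public types used here:
///
/// ```no_run
/// use aleph_alpha_client::CompletionOutput;
///
/// fn report(output: &CompletionOutput) {
///     println!("{} (finish reason: {})", output.completion, output.finish_reason);
///     println!("sampled tokens with logprobs: {}", output.logprobs.len());
/// }
/// ```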
#[derive(Deserialize, Debug, PartialEq)]
pub struct CompletionOutput {
    pub completion: String,
    pub finish_reason: String,
    pub logprobs: Vec<Distribution>,
    pub usage: Usage,
}

impl Task for TaskCompletion<'_> {
    type Output = CompletionOutput;

    type ResponseBody = ResponseCompletion;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = BodyCompletion::new(model, self);
        client.post(format!("{base}/complete")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        // We expect the API to return exactly one completion, despite completions being modeled as an array.
        let DeserializedCompletion {
            completion,
            finish_reason,
            raw_completion,
            log_probs,
            completion_tokens,
        } = response.completions.pop().unwrap();
        let completion = if self.special_tokens {
            raw_completion.unwrap()
        } else {
            completion
        };
        CompletionOutput {
            completion,
            finish_reason,
            logprobs: completion_logprobs_to_canonical(
                log_probs,
                completion_tokens,
                self.logprobs.top_logprobs().unwrap_or_default(),
            ),
            usage: Usage {
                prompt_tokens: response.num_tokens_prompt_total,
                completion_tokens: response.num_tokens_generated,
            },
        }
    }
}

/// Translates the Aleph Alpha representation of logprobs (one map of token to logprob per
/// sampled token) into the canonical representation used by this crate, in which the sampled
/// token is kept separate from the top candidates.
fn completion_logprobs_to_canonical(
    log_probs: Vec<HashMap<String, f64>>,
    completion_tokens: Vec<String>,
    num_expected_top_logprobs: u8,
) -> Vec<Distribution> {
    let mut logprobs = Vec::new();
    for (token, map) in completion_tokens.into_iter().zip(log_probs) {
        // The NAN case can occur if `echo` is set to true, as there are no logprobs returned for the first token.
        let logprob = *map.get(&token).unwrap_or(&f64::NAN);
        let mut top_logprobs = map
            .into_iter()
            .map(|(token, logprob)| Logprob {
                token: token.into_bytes(),
                logprob,
            })
            .collect::<Vec<_>>();
        // We want to make sure the most likely tokens are first in the array
        top_logprobs.sort_by(|a, b| b.logprob.total_cmp(&a.logprob));
        // The Aleph Alpha API always makes the sampled token part of the array, even if it is not
        // among the top n elements. Since we translate into a representation with the sampled
        // token kept separate, we can keep the top n elements constant. If the sampled token was
        // not among the top n, the line below shortens the array by one.
        // For the special case of `echo` being set to true, we will receive an empty list of top
        // logprobs for the first token.
        top_logprobs = top_logprobs
            .into_iter()
            .take(num_expected_top_logprobs as usize)
            .collect();
        logprobs.push(Distribution {
            sampled: Logprob {
                token: token.into_bytes(),
                logprob,
            },
            top: top_logprobs,
        });
    }
    logprobs
}

#[derive(Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum DeserializedCompletionEvent {
    StreamChunk {
        /// The completion of the stream.
        completion: String,
        /// Completion with special tokens still included
        raw_completion: Option<String>,
        #[serde(default)]
        log_probs: Vec<HashMap<String, f64>>,
        #[serde(default)]
        completion_tokens: Vec<String>,
    },
    StreamSummary {
        /// The reason why the model stopped generating new tokens.
        finish_reason: String,
    },
    CompletionSummary {
        /// Number of tokens combined across all completion tasks.
        /// In particular, if you set best_of or n to a number larger than 1 then we report the
        /// combined prompt token count for all best_of or n tasks.
        num_tokens_prompt_total: u32,
        /// Number of tokens combined across all completion tasks.
        /// If multiple completions are returned or best_of is set to a value greater than 1 then
        /// this value contains the combined generated token count.
        num_tokens_generated: u32,
    },
}

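/// Event emitted while streaming a completion.
///
/// A minimal consumption sketch, assuming `CompletionEvent` is re-exported at the crate root
/// like the other public types in this module:
///
/// ```no_run
/// use aleph_alpha_client::CompletionEvent;
///
/// fn handle(event: CompletionEvent) {
///     match event {
///         CompletionEvent::Delta { completion, .. } => print!("{completion}"),
///         CompletionEvent::Finished { reason } => println!("\nfinish reason: {reason}"),
///         CompletionEvent::Summary { usage } => println!("usage: {usage:?}"),
///     }
/// }
/// ```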
#[derive(Debug, PartialEq)]
pub enum CompletionEvent {
    Delta {
        /// The completion of the stream.
        completion: String,
        /// Log probabilities of the completion tokens if requested via logprobs parameter in request.
        logprobs: Vec<Distribution>,
    },
    Finished {
        /// The reason why the model stopped generating new tokens.
        reason: String,
    },
    Summary {
        usage: Usage,
    },
}

impl StreamTask for TaskCompletion<'_> {
    type Output = CompletionEvent;

    type ResponseBody = DeserializedCompletionEvent;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = BodyCompletion::new(model, self).with_streaming();
        client.post(format!("{base}/complete")).json(&body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        match response {
            DeserializedCompletionEvent::StreamChunk {
                completion,
                raw_completion,
                log_probs,
                completion_tokens,
            } => CompletionEvent::Delta {
                completion: if self.special_tokens {
                    raw_completion.expect("Missing raw completion")
                } else {
                    completion
                },
                logprobs: completion_logprobs_to_canonical(
                    log_probs,
                    completion_tokens,
                    self.logprobs.top_logprobs().unwrap_or_default(),
                ),
            },
            DeserializedCompletionEvent::StreamSummary { finish_reason } => {
                CompletionEvent::Finished {
                    reason: finish_reason,
                }
            }
            DeserializedCompletionEvent::CompletionSummary {
                num_tokens_prompt_total,
                num_tokens_generated,
            } => CompletionEvent::Summary {
                usage: Usage {
                    prompt_tokens: num_tokens_prompt_total,
                    completion_tokens: num_tokens_generated,
                },
            },
        }
    }
}

impl Logprobs {
    /// Convert into a number for the completion endpoint.
    fn to_logprobs_num(self) -> Option<u8> {
        match self {
            Logprobs::No => None,
            Logprobs::Sampled => Some(0),
            Logprobs::Top(n) => Some(n),
        }
    }

    /// Whether or not we want to return the completion tokens.
    fn to_tokens(self) -> bool {
        match self {
            Logprobs::No => false,
            Logprobs::Sampled | Logprobs::Top(_) => true,
        }
    }
}
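
// A small unit-test sketch for the private helpers above. The expectations are derived directly
// from the code in this module; the field names on `Distribution` and `Logprob` are the ones used
// in `completion_logprobs_to_canonical`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn logprobs_translate_to_completion_request_parameters() {
        // `No` requests neither logprobs nor tokens, `Sampled` requests the tokens with zero top
        // alternatives, `Top(n)` requests the tokens and the n most likely alternatives.
        assert_eq!(Logprobs::No.to_logprobs_num(), None);
        assert!(!Logprobs::No.to_tokens());
        assert_eq!(Logprobs::Sampled.to_logprobs_num(), Some(0));
        assert!(Logprobs::Sampled.to_tokens());
        assert_eq!(Logprobs::Top(3).to_logprobs_num(), Some(3));
        assert!(Logprobs::Top(3).to_tokens());
    }

    #[test]
    fn sampled_token_is_separated_from_top_logprobs() {
        // One generated token " world" with two candidates reported by the API.
        let log_probs = vec![HashMap::from([
            (" world".to_owned(), -0.5),
            (" there".to_owned(), -1.5),
        ])];
        let completion_tokens = vec![" world".to_owned()];

        let distributions = completion_logprobs_to_canonical(log_probs, completion_tokens, 1);

        assert_eq!(distributions.len(), 1);
        assert_eq!(distributions[0].sampled.token, b" world".to_vec());
        assert_eq!(distributions[0].sampled.logprob, -0.5);
        // Only the single most likely candidate is kept, as one top logprob was requested.
        assert_eq!(distributions[0].top.len(), 1);
    }
}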