aleph_alpha_client/completion.rs

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::{http::Task, Distribution, Logprob, Logprobs, Prompt, StreamTask, Usage};

/// Completes a prompt. E.g. continues a text.
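///
/// A minimal usage sketch (assuming `TaskCompletion` and `Logprobs` are re-exported at the
/// crate root, as the crate's public API suggests):
///
/// ```
/// use aleph_alpha_client::{Logprobs, TaskCompletion};
///
/// let stop_sequences = ["\n"];
/// let task = TaskCompletion::from_text("An apple a day")
///     .with_maximum_tokens(64)
///     .with_stop_sequences(&stop_sequences)
///     .with_logprobs(Logprobs::Sampled);
/// ```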
pub struct TaskCompletion<'a> {
    /// The prompt (usually text) to be completed. Unconditional completion can be started with an
    /// empty string. The prompt may contain a zero-shot or few-shot task.
    pub prompt: Prompt<'a>,
    /// Controls in which circumstances the model will stop generating new tokens.
    pub stopping: Stopping<'a>,
    /// Sampling controls how the tokens ("words") are selected for the completion.
    pub sampling: Sampling,
    /// Whether to include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
    pub special_tokens: bool,
    /// Whether you are interested in the probabilities of the sampled tokens, or of the most
    /// likely tokens.
    pub logprobs: Logprobs,
}

impl<'a> TaskCompletion<'a> {
    /// Convenience constructor leaving most settings at their defaults, just completing a given text
    pub fn from_text(text: &'a str) -> Self {
        TaskCompletion {
            prompt: Prompt::from_text(text),
            stopping: Stopping::NO_TOKEN_LIMIT,
            sampling: Sampling::MOST_LIKELY,
            special_tokens: false,
            logprobs: Logprobs::No,
        }
    }

    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
        self.stopping.maximum_tokens = Some(maximum_tokens);
        self
    }

    pub fn with_stop_sequences(mut self, stop_sequences: &'a [&str]) -> Self {
        self.stopping.stop_sequences = stop_sequences;
        self
    }

    /// Include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
    pub fn with_special_tokens(mut self) -> Self {
        self.special_tokens = true;
        self
    }

    pub fn with_logprobs(mut self, logprobs: Logprobs) -> Self {
        self.logprobs = logprobs;
        self
    }
}

/// Sampling controls how the tokens ("words") are selected for the completion.
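///
/// A small sketch of overriding individual parameters while keeping the remaining fields at
/// their defaults (assuming `Sampling` is re-exported at the crate root):
///
/// ```
/// use aleph_alpha_client::Sampling;
///
/// let sampling = Sampling {
///     temperature: Some(0.7),
///     top_p: Some(0.9),
///     ..Sampling::MOST_LIKELY
/// };
/// ```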
pub struct Sampling {
    /// A higher temperature encourages the model to produce less probable outputs ("be more
    /// creative"). Values are expected to be between 0 and 1. Try high values for a more random
    /// ("creative") response.
    pub temperature: Option<f64>,
    /// Introduces random sampling for generated tokens by randomly selecting the next token from
    /// the k most likely options. A value larger than 1 encourages the model to be more creative.
    /// Set to 0 to get the same behaviour as `None`.
    pub top_k: Option<u32>,
    /// Introduces random sampling for generated tokens by randomly selecting the next token from
    /// the smallest possible set of tokens whose cumulative probability exceeds the probability
    /// top_p. Set to 0 to get the same behaviour as `None`.
    pub top_p: Option<f64>,
    /// When specified, this number will decrease (or increase) the likelihood of repeating tokens
    /// that already appeared earlier in the completion. The penalty is cumulative. The more often
    /// a token appears in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the likelihood of generating tokens that are already present
    /// in the generated text (repetition_penalties_include_completion=true) or in the prompt
    /// (repetition_penalties_include_prompt=true), respectively. The presence penalty is
    /// independent of the number of occurrences. Increase the value to reduce the likelihood of
    /// repeating text.
    /// An operation like the following is applied:
    ///
    /// logits[t] -> logits[t] - 1 * penalty
    ///
    /// where logits[t] is the logits for any given token. Note that the formula is independent
    /// of the number of times that a token appears.
    pub presence_penalty: Option<f64>,
}

impl Sampling {
    /// Always chooses the token most likely to come next. Choose this if you want close to
    /// deterministic behaviour and do not want to apply any penalties to avoid repetitions.
    pub const MOST_LIKELY: Self = Sampling {
        temperature: None,
        top_k: None,
        top_p: None,
        frequency_penalty: None,
        presence_penalty: None,
    };
}

impl Default for Sampling {
    fn default() -> Self {
        Self::MOST_LIKELY
    }
}

/// Controls the conditions under which the language model stops generating text.
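///
/// A brief sketch of the two typical configurations (assuming `Stopping` is re-exported at the
/// crate root):
///
/// ```
/// use aleph_alpha_client::Stopping;
///
/// // Limit the answer to at most 64 tokens.
/// let bounded = Stopping::from_maximum_tokens(64);
///
/// // Generate until one of the given stop sequences is produced.
/// let stop_sequences = ["\n", "Question:"];
/// let until_stop = Stopping::from_stop_sequences(&stop_sequences);
/// ```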
pub struct Stopping<'a> {
    /// The maximum number of tokens to be generated. Completion will terminate after the maximum
    /// number of tokens is reached. Increase this value to allow for longer outputs. A text is
    /// split into tokens. Usually there are more tokens than words. The maximum supported total of
    /// prompt tokens and maximum_tokens depends on the model.
    /// If maximum_tokens is set to None, no outside limit is imposed on the number of generated
    /// tokens. The model will generate tokens until it generates one of the specified
    /// stop_sequences or it reaches its technical limit, which usually is its context window.
    pub maximum_tokens: Option<u32>,
    /// List of strings which will stop generation if they are generated. Stop sequences are
    /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
    /// lines starting with either "Question: " or "Answer: " (alternating). After producing an
    /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be
    /// used as a stop sequence in order not to have the model generate more questions but rather
    /// restrict text generation to the answers.
    pub stop_sequences: &'a [&'a str],
}

impl<'a> Stopping<'a> {
    /// Only stop once the model reaches its technical limit, usually the context window.
    pub const NO_TOKEN_LIMIT: Self = Stopping {
        maximum_tokens: None,
        stop_sequences: &[],
    };

    /// Stop once the model has reached maximum_tokens.
    pub fn from_maximum_tokens(maximum_tokens: u32) -> Self {
        Self {
            maximum_tokens: Some(maximum_tokens),
            stop_sequences: &[],
        }
    }

    pub fn from_stop_sequences(stop_sequences: &'a [&'a str]) -> Self {
        Self {
            maximum_tokens: None,
            stop_sequences,
        }
    }
}

impl Default for Stopping<'_> {
    fn default() -> Self {
        Self::NO_TOKEN_LIMIT
    }
}

/// Body sent to the Aleph Alpha API on the POST `/complete` route
#[derive(Serialize, Debug)]
struct BodyCompletion<'a> {
    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
    pub model: &'a str,
    /// Prompt to complete. The modalities supported depend on `model`.
    pub prompt: Prompt<'a>,
    /// Limits the number of tokens which are generated for the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub maximum_tokens: Option<u32>,
    /// List of strings which will stop generation if they are generated. Stop sequences are
    /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
    /// lines starting with either "Question: " or "Answer: " (alternating). After producing an
    /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be
    /// used as a stop sequence in order not to have the model generate more questions but rather
    /// restrict text generation to the answers.
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    pub stop_sequences: &'a [&'a str],
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// If true, the response will be streamed.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    /// Forces the raw completion of the model to be returned.
    /// For some models, the completion that was generated by the model may be optimized and
    /// returned in the completion field of the CompletionResponse.
    /// The raw completion, if returned, will contain the un-optimized completion.
    /// Setting tokens to true or log_probs to any value will also trigger the raw completion to
    /// be returned.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub raw_completion: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_probs: Option<u8>,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub tokens: bool,
}

impl<'a> BodyCompletion<'a> {
    pub fn new(model: &'a str, task: &'a TaskCompletion<'a>) -> Self {
        let TaskCompletion {
            prompt,
            stopping,
            sampling,
            special_tokens,
            logprobs,
        } = task;
        Self {
            model,
            prompt: prompt.borrow(),
            maximum_tokens: stopping.maximum_tokens,
            stop_sequences: stopping.stop_sequences,
            temperature: sampling.temperature,
            top_k: sampling.top_k,
            top_p: sampling.top_p,
            stream: false,
            raw_completion: *special_tokens,
            frequency_penalty: sampling.frequency_penalty,
            presence_penalty: sampling.presence_penalty,
            log_probs: logprobs.to_logprobs_num(),
            tokens: logprobs.to_tokens(),
        }
    }

    pub fn with_streaming(mut self) -> Self {
        self.stream = true;
        self
    }
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseCompletion {
    model_version: String,
    completions: Vec<DeserializedCompletion>,
    num_tokens_prompt_total: u32,
    num_tokens_generated: u32,
}

#[derive(Deserialize, Debug, PartialEq)]
struct DeserializedCompletion {
    completion: String,
    finish_reason: String,
    raw_completion: Option<String>,
    #[serde(default)]
    log_probs: Vec<HashMap<String, f64>>,
    #[serde(default)]
    completion_tokens: Vec<String>,
}

/// Completion and meta information returned by a completion task
#[derive(Deserialize, Debug, PartialEq)]
pub struct CompletionOutput {
    pub completion: String,
    pub finish_reason: String,
    pub logprobs: Vec<Distribution>,
    pub usage: Usage,
}

impl Task for TaskCompletion<'_> {
    type Output = CompletionOutput;

    type ResponseBody = ResponseCompletion;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = BodyCompletion::new(model, self);
        client.post(format!("{base}/complete")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        // We expect the API to return exactly one completion, despite completions being modeled
        // as an array.
        let DeserializedCompletion {
            completion,
            finish_reason,
            raw_completion,
            log_probs,
            completion_tokens,
        } = response.completions.pop().unwrap();
        let completion = if self.special_tokens {
            raw_completion.unwrap()
        } else {
            completion
        };
        CompletionOutput {
            completion,
            finish_reason,
            logprobs: completion_logprobs_to_canonical(
                log_probs,
                completion_tokens,
                self.logprobs.top_logprobs().unwrap_or_default(),
            ),
            usage: Usage {
                prompt_tokens: response.num_tokens_prompt_total,
                completion_tokens: response.num_tokens_generated,
            },
        }
    }
}
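
// A hypothetical end-to-end sketch of driving this task through the crate's HTTP client.
// `Client`, `How`, `from_env`, and the `completion` method are assumptions about items defined
// elsewhere in the crate and may differ; only `TaskCompletion`, its builder methods, and
// `CompletionOutput::completion` are defined in this file.
//
//     use aleph_alpha_client::{Client, How, TaskCompletion};
//
//     let client = Client::from_env()?;
//     let task = TaskCompletion::from_text("An apple a day").with_maximum_tokens(64);
//     let output = client.completion(&task, "luminous-base", &How::default()).await?;
//     println!("An apple a day{}", output.completion);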

fn completion_logprobs_to_canonical(
    log_probs: Vec<HashMap<String, f64>>,
    completion_tokens: Vec<String>,
    num_expected_top_logprobs: u8,
) -> Vec<Distribution> {
    let mut logprobs = Vec::new();
    for (token, map) in completion_tokens.into_iter().zip(log_probs) {
        let logprob = *map.get(&token).unwrap_or(&f64::NAN);
        let mut top_logprobs = map
            .into_iter()
            .map(|(token, logprob)| Logprob {
                token: token.into_bytes(),
                logprob,
            })
            .collect::<Vec<_>>();
        // We want to make sure the most likely tokens are first in the array
        top_logprobs.sort_by(|a, b| b.logprob.total_cmp(&a.logprob));
        // The Aleph Alpha API always makes the sampled token part of the array, even if it is not
        // among the top n elements. Since we translate into a representation with the sampled
        // token kept separate, we can keep the top n elements constant. In case the sampled token
        // has not been in the top n, the line below will shorten the array by one.
        top_logprobs.resize_with(num_expected_top_logprobs as usize, || {
            unreachable!("Vec should only shorten")
        });
        logprobs.push(Distribution {
            sampled: Logprob {
                token: token.into_bytes(),
                logprob,
            },
            top: top_logprobs,
        });
    }
    logprobs
}

#[derive(Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum DeserializedCompletionEvent {
    StreamChunk {
        /// The completion of the stream.
        completion: String,
        /// Completion with special tokens still included
        raw_completion: Option<String>,
        #[serde(default)]
        log_probs: Vec<HashMap<String, f64>>,
        #[serde(default)]
        completion_tokens: Vec<String>,
    },
    StreamSummary {
        /// The reason why the model stopped generating new tokens.
        finish_reason: String,
    },
    CompletionSummary {
        /// Number of prompt tokens combined across all completion tasks.
        /// In particular, if you set best_of or n to a number larger than 1 then we report the
        /// combined prompt token count for all best_of or n tasks.
        num_tokens_prompt_total: u32,
        /// Number of generated tokens combined across all completion tasks.
        /// If multiple completions are returned or best_of is set to a value greater than 1 then
        /// this value contains the combined generated token count.
        num_tokens_generated: u32,
    },
}

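/// Event emitted while streaming a completion.
///
/// A small sketch of consuming events (assuming `CompletionEvent` is re-exported at the crate
/// root; how the stream itself is obtained depends on the crate's client and is not shown here):
///
/// ```
/// use aleph_alpha_client::CompletionEvent;
///
/// fn handle(event: CompletionEvent) {
///     match event {
///         CompletionEvent::Delta { completion, .. } => print!("{completion}"),
///         CompletionEvent::Finished { reason } => eprintln!("finish reason: {reason}"),
///         CompletionEvent::Summary { usage } => eprintln!("usage: {usage:?}"),
///     }
/// }
/// ```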
#[derive(Debug, PartialEq)]
pub enum CompletionEvent {
    Delta {
        /// The completion of the stream.
        completion: String,
        /// Log probabilities of the completion tokens, if requested via the logprobs parameter in
        /// the request.
        logprobs: Vec<Distribution>,
    },
    Finished {
        /// The reason why the model stopped generating new tokens.
        reason: String,
    },
    Summary {
        usage: Usage,
    },
}

impl StreamTask for TaskCompletion<'_> {
    type Output = CompletionEvent;

    type ResponseBody = DeserializedCompletionEvent;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = BodyCompletion::new(model, self).with_streaming();
        client.post(format!("{base}/complete")).json(&body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Option<Self::Output> {
        Some(match response {
            DeserializedCompletionEvent::StreamChunk {
                completion,
                raw_completion,
                log_probs,
                completion_tokens,
            } => CompletionEvent::Delta {
                completion: if self.special_tokens {
                    raw_completion.expect("Missing raw completion")
                } else {
                    completion
                },
                logprobs: completion_logprobs_to_canonical(
                    log_probs,
                    completion_tokens,
                    self.logprobs.top_logprobs().unwrap_or_default(),
                ),
            },
            DeserializedCompletionEvent::StreamSummary { finish_reason } => {
                CompletionEvent::Finished {
                    reason: finish_reason,
                }
            }
            DeserializedCompletionEvent::CompletionSummary {
                num_tokens_prompt_total,
                num_tokens_generated,
            } => CompletionEvent::Summary {
                usage: Usage {
                    prompt_tokens: num_tokens_prompt_total,
                    completion_tokens: num_tokens_generated,
                },
            },
        })
    }
}

impl Logprobs {
    /// Convert into a number for the completion endpoint
    fn to_logprobs_num(self) -> Option<u8> {
        match self {
            Logprobs::No => None,
            Logprobs::Sampled => Some(0),
            Logprobs::Top(n) => Some(n),
        }
    }

    /// Whether or not we want to return the completion tokens
    fn to_tokens(self) -> bool {
        match self {
            Logprobs::No => false,
            Logprobs::Sampled | Logprobs::Top(_) => true,
        }
    }
}
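
// A few illustrative tests, added as a sketch. They only exercise items defined in this file
// (plus the `Logprob`/`Distribution` field layout as constructed above); the model name used is
// an arbitrary placeholder.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn logprobs_translate_to_request_parameters() {
        // `Logprobs::No` requests nothing, `Sampled` requests only the sampled token,
        // `Top(n)` requests the n most likely alternatives.
        assert_eq!(Logprobs::No.to_logprobs_num(), None);
        assert_eq!(Logprobs::Sampled.to_logprobs_num(), Some(0));
        assert_eq!(Logprobs::Top(3).to_logprobs_num(), Some(3));
        assert!(!Logprobs::No.to_tokens());
        assert!(Logprobs::Top(3).to_tokens());
    }

    #[test]
    fn streaming_and_special_tokens_are_reflected_in_the_request_body() {
        let task = TaskCompletion::from_text("Hello").with_special_tokens();
        let body = BodyCompletion::new("pharia-1-llm-7b-control", &task).with_streaming();
        assert!(body.stream);
        assert!(body.raw_completion);
    }

    #[test]
    fn logprob_maps_are_translated_into_distributions() {
        // One generated token "An" with two candidates reported by the API.
        let log_probs = vec![HashMap::from([
            ("An".to_string(), -0.1_f64),
            ("The".to_string(), -2.3_f64),
        ])];
        let completion_tokens = vec!["An".to_string()];

        let distributions = completion_logprobs_to_canonical(log_probs, completion_tokens, 2);

        assert_eq!(distributions.len(), 1);
        // The sampled token is reported separately from the sorted top candidates.
        assert_eq!(distributions[0].sampled.token, b"An".to_vec());
        assert_eq!(distributions[0].top.len(), 2);
        assert_eq!(distributions[0].top[0].token, b"An".to_vec());
        assert_eq!(distributions[0].top[1].token, b"The".to_vec());
    }
}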