aleph_alpha_client/completion.rs
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::{http::Task, Distribution, Logprob, Logprobs, Prompt, StreamTask, Usage};

/// Completes a prompt. E.g. continues a text.
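///
/// A minimal usage sketch of the builder methods below, assuming `TaskCompletion` and
/// `Logprobs` are re-exported at the crate root:
///
/// ```
/// use aleph_alpha_client::{Logprobs, TaskCompletion};
///
/// let task = TaskCompletion::from_text("An apple a day")
///     .with_maximum_tokens(64)
///     .with_stop_sequences(&["\n"])
///     .with_logprobs(Logprobs::Sampled);
///
/// assert_eq!(task.stopping.maximum_tokens, Some(64));
/// ```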
pub struct TaskCompletion<'a> {
    /// The prompt (usually text) to be completed. Unconditional completion can be started with an
    /// empty string. The prompt may contain a zero shot or few shot task.
    pub prompt: Prompt<'a>,
    /// Controls in which circumstances the model will stop generating new tokens.
    pub stopping: Stopping<'a>,
    /// Sampling controls how the tokens ("words") are selected for the completion.
    pub sampling: Sampling,
    /// Whether to include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
    pub special_tokens: bool,
    /// Whether you are interested in the probabilities of the sampled tokens, or the most likely
    /// tokens.
    pub logprobs: Logprobs,
    /// Echo the prompt in the completion. This may be especially helpful when log_probs is set
    /// to return logprobs for the prompt.
    pub echo: bool,
}

impl<'a> TaskCompletion<'a> {
    /// Convenience constructor leaving most settings at their defaults, just completing a given
    /// text.
    pub fn from_text(text: &'a str) -> Self {
        TaskCompletion {
            prompt: Prompt::from_text(text),
            stopping: Stopping::NO_TOKEN_LIMIT,
            sampling: Sampling::MOST_LIKELY,
            special_tokens: false,
            logprobs: Logprobs::No,
            echo: false,
        }
    }

    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
        self.stopping.maximum_tokens = Some(maximum_tokens);
        self
    }

    pub fn with_stop_sequences(mut self, stop_sequences: &'a [&str]) -> Self {
        self.stopping.stop_sequences = stop_sequences;
        self
    }

    /// Include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
    pub fn with_special_tokens(mut self) -> Self {
        self.special_tokens = true;
        self
    }

    pub fn with_logprobs(mut self, logprobs: Logprobs) -> Self {
        self.logprobs = logprobs;
        self
    }

    pub fn with_echo(mut self) -> Self {
        self.echo = true;
        self
    }
}

/// Sampling controls how the tokens ("words") are selected for the completion.
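///
/// A small sketch of overriding individual options while keeping the rest at their defaults,
/// assuming `Sampling` is re-exported at the crate root:
///
/// ```
/// use aleph_alpha_client::Sampling;
///
/// let creative = Sampling {
///     temperature: Some(0.9),
///     top_p: Some(0.9),
///     ..Default::default()
/// };
///
/// assert_eq!(creative.top_k, None);
/// ```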
pub struct Sampling {
    /// A temperature encourages the model to produce less probable outputs ("be more creative").
    /// Values are expected to be between 0 and 1. Try high values for a more random ("creative")
    /// response.
    pub temperature: Option<f64>,
    /// Introduces random sampling for generated tokens by randomly selecting the next token from
    /// the k most likely options. A value larger than 1 encourages the model to be more creative.
    /// Set to 0 to get the same behaviour as `None`.
    pub top_k: Option<u32>,
    /// Introduces random sampling for generated tokens by randomly selecting the next token from
    /// the smallest possible set of tokens whose cumulative probability exceeds the probability
    /// top_p. Set to 0 to get the same behaviour as `None`.
    pub top_p: Option<f64>,
    /// When specified, this number will decrease (or increase) the likelihood of repeating tokens
    /// that were mentioned prior in the completion. The penalty is cumulative. The more a token
    /// is mentioned in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the likelihood of generating tokens that are already present
    /// in the generated text (repetition_penalties_include_completion=true) or in the prompt
    /// (repetition_penalties_include_prompt=true), respectively. The presence penalty is
    /// independent of the number of occurrences. Increase the value to reduce the likelihood of
    /// repeating text.
    /// An operation like the following is applied:
    ///
    /// logits[t] -> logits[t] - 1 * penalty
    ///
    /// where logits[t] is the logits for any given token. Note that the formula is independent
    /// of the number of times that a token appears.
    pub presence_penalty: Option<f64>,
}

impl Sampling {
    /// Always chooses the token most likely to come next. Choose this if you want close to
    /// deterministic behaviour and do not want to apply any penalties to avoid repetitions.
    pub const MOST_LIKELY: Self = Sampling {
        temperature: None,
        top_k: None,
        top_p: None,
        frequency_penalty: None,
        presence_penalty: None,
    };
}

impl Default for Sampling {
    fn default() -> Self {
        Self::MOST_LIKELY
    }
}

/// Controls the conditions under which the language model stops generating text.
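///
/// A brief sketch of the two convenience constructors, assuming `Stopping` is re-exported at
/// the crate root:
///
/// ```
/// use aleph_alpha_client::Stopping;
///
/// // Stop after at most 64 generated tokens.
/// let by_budget = Stopping::from_maximum_tokens(64);
/// // Stop as soon as the model would generate a blank line.
/// let by_sequence = Stopping::from_stop_sequences(&["\n\n"]);
///
/// assert_eq!(by_budget.maximum_tokens, Some(64));
/// assert_eq!(by_sequence.stop_sequences.len(), 1);
/// ```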
pub struct Stopping<'a> {
    /// The maximum number of tokens to be generated. Completion will terminate after the maximum
    /// number of tokens is reached. Increase this value to allow for longer outputs. A text is
    /// split into tokens. Usually there are more tokens than words. The total number of tokens of
    /// prompt and maximum_tokens depends on the model.
    /// If maximum_tokens is set to None, no outside limit is imposed on the number of maximum
    /// tokens. The model will generate tokens until it generates one of the specified
    /// stop_sequences or it reaches its technical limit, which usually is its context window.
    pub maximum_tokens: Option<u32>,
    /// List of strings which will stop generation if they are generated. Stop sequences are
    /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
    /// lines starting with either "Question: " or "Answer: " (alternating). After producing an
    /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be
    /// used as a stop sequence in order not to have the model generate more questions but rather
    /// restrict text generation to the answers.
    pub stop_sequences: &'a [&'a str],
}

impl<'a> Stopping<'a> {
    /// Only stop once the model reaches its technical limit, usually the context window.
    pub const NO_TOKEN_LIMIT: Self = Stopping {
        maximum_tokens: None,
        stop_sequences: &[],
    };

    /// Stop once the model has reached maximum_tokens.
    pub fn from_maximum_tokens(maximum_tokens: u32) -> Self {
        Self {
            maximum_tokens: Some(maximum_tokens),
            stop_sequences: &[],
        }
    }

    pub fn from_stop_sequences(stop_sequences: &'a [&'a str]) -> Self {
        Self {
            maximum_tokens: None,
            stop_sequences,
        }
    }
}

impl Default for Stopping<'_> {
    fn default() -> Self {
        Self::NO_TOKEN_LIMIT
    }
}

/// Body sent to the Aleph Alpha API on the POST `/complete` route.
#[derive(Serialize, Debug)]
struct BodyCompletion<'a> {
    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
    pub model: &'a str,
    /// Prompt to complete. The modalities supported depend on `model`.
    pub prompt: Prompt<'a>,
    /// Limits the number of tokens which are generated for the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub maximum_tokens: Option<u32>,
    /// List of strings which will stop generation if they are generated. Stop sequences are
    /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
    /// lines starting with either "Question: " or "Answer: " (alternating). After producing an
    /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be
    /// used as a stop sequence in order not to have the model generate more questions but rather
    /// restrict text generation to the answers.
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    pub stop_sequences: &'a [&'a str],
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// If true, the response will be streamed.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    /// Forces the raw completion of the model to be returned.
    /// For some models, the completion that was generated by the model may be optimized and
    /// returned in the completion field of the CompletionResponse.
    /// The raw completion, if returned, will contain the un-optimized completion.
    /// Setting tokens to true or log_probs to any value will also trigger the raw completion to
    /// be returned.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub raw_completion: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_probs: Option<u8>,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub tokens: bool,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub echo: bool,
}

impl<'a> BodyCompletion<'a> {
    pub fn new(model: &'a str, task: &'a TaskCompletion<'a>) -> Self {
        let TaskCompletion {
            prompt,
            stopping,
            sampling,
            special_tokens,
            logprobs,
            echo,
        } = task;
        Self {
            model,
            prompt: prompt.borrow(),
            maximum_tokens: stopping.maximum_tokens,
            stop_sequences: stopping.stop_sequences,
            temperature: sampling.temperature,
            top_k: sampling.top_k,
            top_p: sampling.top_p,
            stream: false,
            raw_completion: *special_tokens,
            frequency_penalty: sampling.frequency_penalty,
            presence_penalty: sampling.presence_penalty,
            log_probs: logprobs.to_logprobs_num(),
            tokens: logprobs.to_tokens(),
            echo: *echo,
        }
    }

    pub fn with_streaming(mut self) -> Self {
        self.stream = true;
        self
    }
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseCompletion {
    model_version: String,
    completions: Vec<DeserializedCompletion>,
    num_tokens_prompt_total: u32,
    num_tokens_generated: u32,
}

#[derive(Deserialize, Debug, PartialEq)]
struct DeserializedCompletion {
    completion: String,
    finish_reason: String,
    raw_completion: Option<String>,
    #[serde(default)]
    log_probs: Vec<HashMap<String, f64>>,
    #[serde(default)]
    completion_tokens: Vec<String>,
}

/// Completion and meta information returned by a completion task.
#[derive(Deserialize, Debug, PartialEq)]
pub struct CompletionOutput {
    pub completion: String,
    pub finish_reason: String,
    pub logprobs: Vec<Distribution>,
    pub usage: Usage,
}

impl Task for TaskCompletion<'_> {
    type Output = CompletionOutput;

    type ResponseBody = ResponseCompletion;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = BodyCompletion::new(model, self);
        client.post(format!("{base}/complete")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        // We expect the API to return exactly one completion, despite them being modelled as an
        // array.
        let DeserializedCompletion {
            completion,
            finish_reason,
            raw_completion,
            log_probs,
            completion_tokens,
        } = response.completions.pop().unwrap();
        let completion = if self.special_tokens {
            raw_completion.unwrap()
        } else {
            completion
        };
        CompletionOutput {
            completion,
            finish_reason,
            logprobs: completion_logprobs_to_canonical(
                log_probs,
                completion_tokens,
                self.logprobs.top_logprobs().unwrap_or_default(),
            ),
            usage: Usage {
                prompt_tokens: response.num_tokens_prompt_total,
                completion_tokens: response.num_tokens_generated,
            },
        }
    }
}

fn completion_logprobs_to_canonical(
    log_probs: Vec<HashMap<String, f64>>,
    completion_tokens: Vec<String>,
    num_expected_top_logprobs: u8,
) -> Vec<Distribution> {
    let mut logprobs = Vec::new();
    for (token, map) in completion_tokens.into_iter().zip(log_probs) {
        // The NAN case can occur if `echo` is set to true, as there are no logprobs returned for
        // the first token.
        let logprob = *map.get(&token).unwrap_or(&f64::NAN);
        let mut top_logprobs = map
            .into_iter()
            .map(|(token, logprob)| Logprob {
                token: token.into_bytes(),
                logprob,
            })
            .collect::<Vec<_>>();
        // We want to make sure the most likely tokens are first in the array.
        top_logprobs.sort_by(|a, b| b.logprob.total_cmp(&a.logprob));
        // The Aleph Alpha API always makes the sampled token part of the array, even if it is not
        // among the top n elements. Since we translate into a representation with the sampled
        // token separate, we can keep the top n elements constant. In case the sampled token has
        // not been in the top n, the line below will shorten the array by one.
        // For the special case of `echo` being set to true, we will receive an empty list of top
        // logprobs for the first token.
        top_logprobs = top_logprobs
            .into_iter()
            .take(num_expected_top_logprobs as usize)
            .collect();
        logprobs.push(Distribution {
            sampled: Logprob {
                token: token.into_bytes(),
                logprob,
            },
            top: top_logprobs,
        });
    }
    logprobs
}

#[derive(Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum DeserializedCompletionEvent {
    StreamChunk {
        /// The completion of the stream.
        completion: String,
        /// Completion with special tokens still included
        raw_completion: Option<String>,
        #[serde(default)]
        log_probs: Vec<HashMap<String, f64>>,
        #[serde(default)]
        completion_tokens: Vec<String>,
    },
    StreamSummary {
        /// The reason why the model stopped generating new tokens.
        finish_reason: String,
    },
    CompletionSummary {
        /// Number of tokens combined across all completion tasks.
        /// In particular, if you set best_of or n to a number larger than 1 then we report the
        /// combined prompt token count for all best_of or n tasks.
        num_tokens_prompt_total: u32,
        /// Number of tokens combined across all completion tasks.
        /// If multiple completions are returned or best_of is set to a value greater than 1 then
        /// this value contains the combined generated token count.
        num_tokens_generated: u32,
    },
}
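
/// Event emitted while streaming the result of a completion task.
///
/// A sketch of consuming these events, assuming `CompletionEvent` is re-exported at the crate
/// root:
///
/// ```
/// use aleph_alpha_client::CompletionEvent;
///
/// fn handle(event: CompletionEvent) {
///     match event {
///         // Incremental text, plus logprobs if they were requested.
///         CompletionEvent::Delta { completion, .. } => print!("{completion}"),
///         // The model stopped generating; `reason` tells us why.
///         CompletionEvent::Finished { reason } => println!("finished: {reason}"),
///         // Token accounting for the whole request.
///         CompletionEvent::Summary { usage } => println!("usage: {usage:?}"),
///     }
/// }
///
/// handle(CompletionEvent::Finished { reason: "maximum_tokens".to_owned() });
/// ```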
#[derive(Debug, PartialEq)]
pub enum CompletionEvent {
    Delta {
        /// The completion of the stream.
        completion: String,
        /// Log probabilities of the completion tokens if requested via the logprobs parameter in
        /// the request.
        logprobs: Vec<Distribution>,
    },
    Finished {
        /// The reason why the model stopped generating new tokens.
        reason: String,
    },
    Summary {
        usage: Usage,
    },
}

impl StreamTask for TaskCompletion<'_> {
    type Output = CompletionEvent;

    type ResponseBody = DeserializedCompletionEvent;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = BodyCompletion::new(model, self).with_streaming();
        client.post(format!("{base}/complete")).json(&body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        match response {
            DeserializedCompletionEvent::StreamChunk {
                completion,
                raw_completion,
                log_probs,
                completion_tokens,
            } => CompletionEvent::Delta {
                completion: if self.special_tokens {
                    raw_completion.expect("Missing raw completion")
                } else {
                    completion
                },
                logprobs: completion_logprobs_to_canonical(
                    log_probs,
                    completion_tokens,
                    self.logprobs.top_logprobs().unwrap_or_default(),
                ),
            },
            DeserializedCompletionEvent::StreamSummary { finish_reason } => {
                CompletionEvent::Finished {
                    reason: finish_reason,
                }
            }
            DeserializedCompletionEvent::CompletionSummary {
                num_tokens_prompt_total,
                num_tokens_generated,
            } => CompletionEvent::Summary {
                usage: Usage {
                    prompt_tokens: num_tokens_prompt_total,
                    completion_tokens: num_tokens_generated,
                },
            },
        }
    }
}

impl Logprobs {
    /// Convert into a number for the completion endpoint.
    fn to_logprobs_num(self) -> Option<u8> {
        match self {
            Logprobs::No => None,
            Logprobs::Sampled => Some(0),
            Logprobs::Top(n) => Some(n),
        }
    }

    /// Whether or not we want to return the completion tokens.
    fn to_tokens(self) -> bool {
        match self {
            Logprobs::No => false,
            Logprobs::Sampled | Logprobs::Top(_) => true,
        }
    }
}
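
// Sketch of unit tests for the pure helpers above. These assume the crate-visible fields of
// `Distribution` and `Logprob` used in this module, and that `serde_json` is available as a
// dependency of the crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sampled_token_is_separated_from_top_logprobs() {
        // One generated token "Hello" with two candidates reported by the API.
        let log_probs = vec![HashMap::from([
            ("Hello".to_string(), -0.5_f64),
            ("World".to_string(), -1.5_f64),
        ])];
        let completion_tokens = vec!["Hello".to_string()];

        // Request only the single most likely alternative per position.
        let distributions = completion_logprobs_to_canonical(log_probs, completion_tokens, 1);

        assert_eq!(distributions.len(), 1);
        assert_eq!(distributions[0].sampled.token, b"Hello".to_vec());
        assert_eq!(distributions[0].sampled.logprob, -0.5);
        // The top list is truncated to the requested length and sorted by descending logprob.
        assert_eq!(distributions[0].top.len(), 1);
        assert_eq!(distributions[0].top[0].logprob, -0.5);
    }

    #[test]
    fn unset_options_are_not_serialized() {
        let task = TaskCompletion::from_text("Hello");
        let body = BodyCompletion::new("any-model", &task);

        let json = serde_json::to_string(&body).unwrap();

        // The model name is always present ...
        assert!(json.contains("\"model\":\"any-model\""));
        // ... while options left at `None`/`false`/empty are skipped entirely.
        assert!(!json.contains("maximum_tokens"));
        assert!(!json.contains("temperature"));
        assert!(!json.contains("stop_sequences"));
        assert!(!json.contains("stream"));
    }
}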