aleph_alpha_client/completion.rs
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::{http::Task, Distribution, Logprob, Logprobs, Prompt, StreamTask, Usage};

/// Completes a prompt. E.g. continues a text.
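///
/// A minimal construction sketch (assuming `TaskCompletion` and `Logprobs` are re-exported at
/// the crate root; executing the task against the API is not shown here):
///
/// ```
/// use aleph_alpha_client::{Logprobs, TaskCompletion};
///
/// // Request at most 64 tokens and the log probability of each sampled token.
/// let task = TaskCompletion::from_text("An apple a day")
///     .with_maximum_tokens(64)
///     .with_logprobs(Logprobs::Sampled);
///
/// assert_eq!(task.stopping.maximum_tokens, Some(64));
/// ```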
pub struct TaskCompletion<'a> {
    /// The prompt (usually text) to be completed. Unconditional completion can be started with an
    /// empty string. The prompt may contain a zero shot or few shot task.
    pub prompt: Prompt<'a>,
    /// Controls in which circumstances the model will stop generating new tokens.
    pub stopping: Stopping<'a>,
    /// Sampling controls how the tokens ("words") are selected for the completion.
    pub sampling: Sampling,
    /// Whether to include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
    pub special_tokens: bool,
    /// Whether you are interested in the probabilities of the sampled tokens, or most likely
    /// tokens.
    pub logprobs: Logprobs,
}

impl<'a> TaskCompletion<'a> {
    /// Convenience constructor leaving most settings at their defaults, just completing a given text.
    pub fn from_text(text: &'a str) -> Self {
        TaskCompletion {
            prompt: Prompt::from_text(text),
            stopping: Stopping::NO_TOKEN_LIMIT,
            sampling: Sampling::MOST_LIKELY,
            special_tokens: false,
            logprobs: Logprobs::No,
        }
    }

    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
        self.stopping.maximum_tokens = Some(maximum_tokens);
        self
    }

    pub fn with_stop_sequences(mut self, stop_sequences: &'a [&str]) -> Self {
        self.stopping.stop_sequences = stop_sequences;
        self
    }

    /// Include special tokens (e.g. <|endoftext|>, <|python_tag|>) in the completion.
    pub fn with_special_tokens(mut self) -> Self {
        self.special_tokens = true;
        self
    }

    pub fn with_logprobs(mut self, logprobs: Logprobs) -> Self {
        self.logprobs = logprobs;
        self
    }
}

/// Sampling controls how the tokens ("words") are selected for the completion.
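///
/// A minimal construction sketch (assuming `Sampling` is re-exported at the crate root):
///
/// ```
/// use aleph_alpha_client::Sampling;
///
/// // Mildly random sampling; every other knob keeps its `MOST_LIKELY` default.
/// let sampling = Sampling {
///     temperature: Some(0.7),
///     top_p: Some(0.9),
///     ..Default::default()
/// };
///
/// assert!(sampling.top_k.is_none());
/// ```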
pub struct Sampling {
    /// A temperature encourages the model to produce less probable outputs ("be more creative").
    /// Values are expected to be between 0 and 1. Try high values for a more random ("creative")
    /// response.
    pub temperature: Option<f64>,
    /// Introduces random sampling for generated tokens by randomly selecting the next token from
    /// the k most likely options. A value larger than 1 encourages the model to be more creative.
    /// Set to 0 to get the same behaviour as `None`.
    pub top_k: Option<u32>,
    /// Introduces random sampling for generated tokens by randomly selecting the next token from
    /// the smallest possible set of tokens whose cumulative probability exceeds the probability
    /// top_p. Set to 0 to get the same behaviour as `None`.
    pub top_p: Option<f64>,
    /// When specified, this number will decrease (or increase) the likelihood of repeating tokens
    /// that were mentioned prior in the completion. The penalty is cumulative. The more a token
    /// is mentioned in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the likelihood of generating tokens that are already present
    /// in the generated text (repetition_penalties_include_completion=true) respectively the
    /// prompt (repetition_penalties_include_prompt=true). Presence penalty is independent of the
    /// number of occurrences. Increase the value to reduce the likelihood of repeating text.
    /// An operation like the following is applied:
    ///
    /// logits[t] -> logits[t] - 1 * penalty
    ///
    /// where logits[t] is the logits for any given token. Note that the formula is independent
    /// of the number of times that a token appears.
    pub presence_penalty: Option<f64>,
}

impl Sampling {
    /// Always chooses the token most likely to come next. Choose this if you want close to
    /// deterministic behaviour and do not want to apply any penalties to avoid repetitions.
    pub const MOST_LIKELY: Self = Sampling {
        temperature: None,
        top_k: None,
        top_p: None,
        frequency_penalty: None,
        presence_penalty: None,
    };
}

impl Default for Sampling {
    fn default() -> Self {
        Self::MOST_LIKELY
    }
}

/// Controls the conditions under which the language model stops generating text.
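///
/// A minimal construction sketch (assuming `Stopping` is re-exported at the crate root):
///
/// ```
/// use aleph_alpha_client::Stopping;
///
/// // Generate at most 64 tokens, but stop earlier if a blank line is produced.
/// let stopping = Stopping {
///     maximum_tokens: Some(64),
///     stop_sequences: &["\n\n"],
/// };
///
/// assert_eq!(stopping.stop_sequences.len(), 1);
/// ```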
pub struct Stopping<'a> {
    /// The maximum number of tokens to be generated. Completion will terminate after the maximum
    /// number of tokens is reached. Increase this value to allow for longer outputs. A text is split
    /// into tokens. Usually there are more tokens than words. The total number of tokens of prompt
    /// and maximum_tokens depends on the model.
    /// If maximum_tokens is set to None, no outside limit is imposed on the number of generated
    /// tokens. The model will generate tokens until it generates one of the specified stop_sequences
    /// or it reaches its technical limit, which usually is its context window.
    pub maximum_tokens: Option<u32>,
    /// List of strings which will stop generation if they are generated. Stop sequences are
    /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
    /// lines starting with either "Question: " or "Answer: " (alternating). After producing an
    /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used
    /// as stop sequence in order not to have the model generate more questions but rather restrict
    /// text generation to the answers.
    pub stop_sequences: &'a [&'a str],
}

impl<'a> Stopping<'a> {
    /// Only stop once the model reaches its technical limit, usually the context window.
    pub const NO_TOKEN_LIMIT: Self = Stopping {
        maximum_tokens: None,
        stop_sequences: &[],
    };

    /// Stop once the model has reached maximum_tokens.
    pub fn from_maximum_tokens(maximum_tokens: u32) -> Self {
        Self {
            maximum_tokens: Some(maximum_tokens),
            stop_sequences: &[],
        }
    }

    pub fn from_stop_sequences(stop_sequences: &'a [&'a str]) -> Self {
        Self {
            maximum_tokens: None,
            stop_sequences,
        }
    }
}

impl Default for Stopping<'_> {
    fn default() -> Self {
        Self::NO_TOKEN_LIMIT
    }
}

/// Body sent to the Aleph Alpha API on the POST `/complete` route
#[derive(Serialize, Debug)]
struct BodyCompletion<'a> {
    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
    pub model: &'a str,
    /// Prompt to complete. The modalities supported depend on `model`.
    pub prompt: Prompt<'a>,
    /// Limits the number of tokens which are generated for the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub maximum_tokens: Option<u32>,
    /// List of strings which will stop generation if they are generated. Stop sequences are
    /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
    /// lines starting with either "Question: " or "Answer: " (alternating). After producing an
    /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used
    /// as stop sequence in order not to have the model generate more questions but rather restrict
    /// text generation to the answers.
    #[serde(skip_serializing_if = "<[_]>::is_empty")]
    pub stop_sequences: &'a [&'a str],
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// If true, the response will be streamed.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub stream: bool,
    /// Forces the raw completion of the model to be returned.
    /// For some models, the completion that was generated by the model may be optimized and
    /// returned in the completion field of the CompletionResponse.
    /// The raw completion, if returned, will contain the un-optimized completion.
    /// Setting tokens to true or log_probs to any value will also trigger the raw completion to
    /// be returned.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub raw_completion: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_probs: Option<u8>,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub tokens: bool,
}

impl<'a> BodyCompletion<'a> {
    pub fn new(model: &'a str, task: &'a TaskCompletion<'a>) -> Self {
        let TaskCompletion {
            prompt,
            stopping,
            sampling,
            special_tokens,
            logprobs,
        } = task;
        Self {
            model,
            prompt: prompt.borrow(),
            maximum_tokens: stopping.maximum_tokens,
            stop_sequences: stopping.stop_sequences,
            temperature: sampling.temperature,
            top_k: sampling.top_k,
            top_p: sampling.top_p,
            stream: false,
            raw_completion: *special_tokens,
            frequency_penalty: sampling.frequency_penalty,
            presence_penalty: sampling.presence_penalty,
            log_probs: logprobs.to_logprobs_num(),
            tokens: logprobs.to_tokens(),
        }
    }

    pub fn with_streaming(mut self) -> Self {
        self.stream = true;
        self
    }
}

#[derive(Deserialize, Debug, PartialEq)]
pub struct ResponseCompletion {
    model_version: String,
    completions: Vec<DeserializedCompletion>,
    num_tokens_prompt_total: u32,
    num_tokens_generated: u32,
}

#[derive(Deserialize, Debug, PartialEq)]
struct DeserializedCompletion {
    completion: String,
    finish_reason: String,
    raw_completion: Option<String>,
    #[serde(default)]
    log_probs: Vec<HashMap<String, f64>>,
    #[serde(default)]
    completion_tokens: Vec<String>,
}

/// Completion and meta information returned by a completion task
#[derive(Deserialize, Debug, PartialEq)]
pub struct CompletionOutput {
    pub completion: String,
    pub finish_reason: String,
    pub logprobs: Vec<Distribution>,
    pub usage: Usage,
}

impl Task for TaskCompletion<'_> {
    type Output = CompletionOutput;

    type ResponseBody = ResponseCompletion;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = BodyCompletion::new(model, self);
        client.post(format!("{base}/complete")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        // We expect the API to return exactly one completion, despite completions being modeled
        // as an array.
        let DeserializedCompletion {
            completion,
            finish_reason,
            raw_completion,
            log_probs,
            completion_tokens,
        } = response.completions.pop().unwrap();
        let completion = if self.special_tokens {
            raw_completion.unwrap()
        } else {
            completion
        };
        CompletionOutput {
            completion,
            finish_reason,
            logprobs: completion_logprobs_to_canonical(
                log_probs,
                completion_tokens,
                self.logprobs.top_logprobs().unwrap_or_default(),
            ),
            usage: Usage {
                prompt_tokens: response.num_tokens_prompt_total,
                completion_tokens: response.num_tokens_generated,
            },
        }
    }
}

fn completion_logprobs_to_canonical(
    log_probs: Vec<HashMap<String, f64>>,
    completion_tokens: Vec<String>,
    num_expected_top_logprobs: u8,
) -> Vec<Distribution> {
    let mut logprobs = Vec::new();
    for (token, map) in completion_tokens.into_iter().zip(log_probs) {
        let logprob = *map.get(&token).unwrap_or(&f64::NAN);
        let mut top_logprobs = map
            .into_iter()
            .map(|(token, logprob)| Logprob {
                token: token.into_bytes(),
                logprob,
            })
            .collect::<Vec<_>>();
        // We want to make sure the most likely tokens are first in the array.
        top_logprobs.sort_by(|a, b| b.logprob.total_cmp(&a.logprob));
        // The Aleph Alpha API always makes the sampled token part of the array, even if it is not
        // in the top n elements. Since we translate into a representation with the sampled token
        // separate, we can keep the top n elements constant. In case the sampled token has not
        // been in the top n, the line below will shorten the array by one.
        top_logprobs.resize_with(num_expected_top_logprobs as usize, || {
            unreachable!("Vec should only shorten")
        });
        logprobs.push(Distribution {
            sampled: Logprob {
                token: token.into_bytes(),
                logprob,
            },
            top: top_logprobs,
        });
    }
    logprobs
}

#[derive(Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum DeserializedCompletionEvent {
    StreamChunk {
        /// The completion of the stream.
        completion: String,
        /// Completion with special tokens still included
        raw_completion: Option<String>,
        #[serde(default)]
        log_probs: Vec<HashMap<String, f64>>,
        #[serde(default)]
        completion_tokens: Vec<String>,
    },
    StreamSummary {
        /// The reason why the model stopped generating new tokens.
        finish_reason: String,
    },
    CompletionSummary {
        /// Number of tokens combined across all completion tasks.
        /// In particular, if you set best_of or n to a number larger than 1 then we report the
        /// combined prompt token count for all best_of or n tasks.
        num_tokens_prompt_total: u32,
        /// Number of tokens combined across all completion tasks.
        /// If multiple completions are returned or best_of is set to a value greater than 1 then
        /// this value contains the combined generated token count.
        num_tokens_generated: u32,
    },
}

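/// Event yielded while streaming a completion.
///
/// A minimal sketch of handling these events. How the stream itself is obtained lies outside
/// this module, so the example is not compiled as a doc test:
///
/// ```ignore
/// use aleph_alpha_client::CompletionEvent;
///
/// fn handle(event: CompletionEvent) {
///     match event {
///         CompletionEvent::Delta { completion, .. } => print!("{completion}"),
///         CompletionEvent::Finished { reason } => eprintln!("finish reason: {reason}"),
///         CompletionEvent::Summary { usage } => {
///             eprintln!("prompt tokens: {}", usage.prompt_tokens)
///         }
///     }
/// }
/// ```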
#[derive(Debug, PartialEq)]
pub enum CompletionEvent {
    Delta {
        /// The completion of the stream.
        completion: String,
        /// Log probabilities of the completion tokens, if requested via the logprobs parameter in
        /// the request.
        logprobs: Vec<Distribution>,
    },
    Finished {
        /// The reason why the model stopped generating new tokens.
        reason: String,
    },
    Summary {
        usage: Usage,
    },
}

impl StreamTask for TaskCompletion<'_> {
    type Output = CompletionEvent;

    type ResponseBody = DeserializedCompletionEvent;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = BodyCompletion::new(model, self).with_streaming();
        client.post(format!("{base}/complete")).json(&body)
    }

    fn body_to_output(&self, response: Self::ResponseBody) -> Option<Self::Output> {
        Some(match response {
            DeserializedCompletionEvent::StreamChunk {
                completion,
                raw_completion,
                log_probs,
                completion_tokens,
            } => CompletionEvent::Delta {
                completion: if self.special_tokens {
                    raw_completion.expect("Missing raw completion")
                } else {
                    completion
                },
                logprobs: completion_logprobs_to_canonical(
                    log_probs,
                    completion_tokens,
                    self.logprobs.top_logprobs().unwrap_or_default(),
                ),
            },
            DeserializedCompletionEvent::StreamSummary { finish_reason } => {
                CompletionEvent::Finished {
                    reason: finish_reason,
                }
            }
            DeserializedCompletionEvent::CompletionSummary {
                num_tokens_prompt_total,
                num_tokens_generated,
            } => CompletionEvent::Summary {
                usage: Usage {
                    prompt_tokens: num_tokens_prompt_total,
                    completion_tokens: num_tokens_generated,
                },
            },
        })
    }
}

impl Logprobs {
    /// Convert into a number for the completion endpoint.
    fn to_logprobs_num(self) -> Option<u8> {
        match self {
            Logprobs::No => None,
            Logprobs::Sampled => Some(0),
            Logprobs::Top(n) => Some(n),
        }
    }

    /// Whether or not we want to return the completion tokens.
    fn to_tokens(self) -> bool {
        match self {
            Logprobs::No => false,
            Logprobs::Sampled | Logprobs::Top(_) => true,
        }
    }
}