//! aleph_alpha_api/completion.rs — completion request/response types for the Aleph Alpha API.
1use super::image_processing::{from_image_path, preprocess_image, LoadImageError};
2use crate::impl_builder_methods;
3use base64::prelude::{Engine as _, BASE64_STANDARD};
4use serde::{Deserialize, Serialize};
5use std::{collections::HashMap, path::Path};
6
7#[derive(Serialize, Debug)]
8pub struct Prompt(Vec<Modality>);
9
10impl Default for Prompt {
11 fn default() -> Self {
12 Self(vec![])
13 }
14}
15
16impl Prompt {
17 pub fn empty() -> Self {
18 Self::default()
19 }
20
21 /// Create a prompt from a single text item.
22 pub fn from_text(text: impl Into<String>) -> Self {
23 Self(vec![Modality::from_text(text, None)])
24 }
25
26 pub fn from_text_with_controls(text: impl Into<String>, controls: Vec<TextControl>) -> Self {
27 Self(vec![Modality::from_text(text, Some(controls))])
28 }
29
30 pub fn from_token_ids(ids: Vec<u32>, controls: Option<Vec<TokenControl>>) -> Self {
31 Self(vec![Modality::from_token_ids(ids, controls)])
32 }
33
34 /// Create a multimodal prompt from a list of individual items with any modality.
35 pub fn from_vec(items: Vec<Modality>) -> Self {
36 Self(items)
37 }
38}
39
/// Attention manipulation for a single token inside a [`Modality::TokenIds`] prompt item.
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct TokenControl {
    /// Index of the token, relative to the list of tokens IDs in the current prompt item.
    pub index: u32,

    /// Factor to apply to the given token in the attention matrix.
    ///
    /// - 0 <= factor < 1 => Suppress the given token
    /// - factor == 1 => identity operation, no change to attention
    /// - factor > 1 => Amplify the given token
    pub factor: f64,
}
52
/// Attention manipulation for a span of characters inside a [`Modality::Text`] prompt item.
///
/// NOTE(review): all fields are private and no constructor is visible in this file,
/// so users outside this crate cannot currently build a `TextControl` for
/// [`Prompt::from_text_with_controls`] — confirm a constructor exists elsewhere.
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct TextControl {
    /// Starting character index to apply the factor to.
    start: i32,

    /// The amount of characters to apply the factor to.
    length: i32,

    /// Factor to apply to the given token in the attention matrix.
    ///
    /// - 0 <= factor < 1 => Suppress the given token
    /// - factor == 1 => identity operation, no change to attention
    /// - factor > 1 => Amplify the given token
    factor: f64,

    /// What to do if a control partially overlaps with a text token.
    ///
    /// If set to "partial", the factor will be adjusted proportionally with the amount
    /// of the token it overlaps. So a factor of 2.0 of a control that only covers 2 of
    /// 4 token characters, would be adjusted to 1.5. (It always moves closer to 1, since
    /// 1 is an identity operation for control factors.)
    ///
    /// If set to "complete", the full factor will be applied as long as the control
    /// overlaps with the token at all.
    #[serde(skip_serializing_if = "Option::is_none")]
    token_overlap: Option<String>,
}
80
81/// Bounding box in logical coordinates. From 0 to 1. With (0,0) being the upper left corner,
82/// and relative to the entire image.
83///
84/// Keep in mind, non-square images are center-cropped by default before going to the model.
85/// (You can specify a custom cropping if you want.). Since control coordinates are relative to
86/// the entire image, all or a portion of your control may be outside the "model visible area".
87#[derive(Serialize, Deserialize, Clone, Debug, Default)]
88pub struct BoundingBox {
89 /// x-coordinate of top left corner of the control bounding box.
90 /// Must be a value between 0 and 1, where 0 is the left corner and 1 is the right corner.
91 left: f64,
92
93 /// y-coordinate of top left corner of the control bounding box
94 /// Must be a value between 0 and 1, where 0 is the top pixel row and 1 is the bottom row.
95 top: f64,
96
97 /// width of the control bounding box
98 /// Must be a value between 0 and 1, where 1 means the full width of the image.
99 width: f64,
100
101 /// height of the control bounding box
102 /// Must be a value between 0 and 1, where 1 means the full height of the image.
103 heigh: f64,
104}
105
/// Attention manipulation for a rectangular region inside a [`Modality::Image`] prompt item.
#[derive(Serialize, Clone, Debug)]
pub struct ImageControl {
    /// Bounding box in logical coordinates. From 0 to 1. With (0,0) being the upper left corner,
    /// and relative to the entire image.
    ///
    /// Keep in mind, non-square images are center-cropped by default before going to the model. (You
    /// can specify a custom cropping if you want.). Since control coordinates are relative to the
    /// entire image, all or a portion of your control may be outside the "model visible area".
    rect: BoundingBox,

    /// Factor to apply to the given token in the attention matrix.
    ///
    /// - 0 <= factor < 1 => Suppress the given token
    /// - factor == 1 => identity operation, no change to attention
    /// - factor > 1 => Amplify the given token
    factor: f64,

    /// What to do if a control partially overlaps with a text token.
    ///
    /// If set to "partial", the factor will be adjusted proportionally with the amount
    /// of the token it overlaps. So a factor of 2.0 of a control that only covers 2 of
    /// 4 token characters, would be adjusted to 1.5. (It always moves closer to 1, since
    /// 1 is an identity operation for control factors.)
    ///
    /// If set to "complete", the full factor will be applied as long as the control
    /// overlaps with the token at all.
    #[serde(skip_serializing_if = "Option::is_none")]
    token_overlap: Option<String>,
}
135
136/// The prompt for models can be a combination of different modalities (Text and Image). The type of
137/// modalities which are supported depend on the Model in question.
138#[derive(Serialize, Debug, Clone)]
139#[serde(tag = "type", rename_all = "snake_case")]
140pub enum Modality {
141 /// The only type of prompt which can be used with pure language models
142 Text {
143 data: String,
144
145 #[serde(skip_serializing_if = "Option::is_none")]
146 controls: Option<Vec<TextControl>>,
147 },
148 /// An image input into the model. See [`Modality::from_image_path`].
149 Image {
150 /// An image send as part of a prompt to a model. The image is represented as base64.
151 ///
152 /// Note: The models operate on square images. All non-square images are center-cropped
153 /// before going to the model, so portions of the image may not be visible.
154 ///
155 /// You can supply specific cropping parameters if you like, to choose a different area
156 /// of the image than a center-crop. Or, you can always transform the image yourself to
157 /// a square before sending it.
158 data: String,
159
160 /// x-coordinate of top left corner of cropping box in pixels
161 #[serde(skip_serializing_if = "Option::is_none")]
162 x: Option<i32>,
163
164 /// y-coordinate of top left corner of cropping box in pixels
165 #[serde(skip_serializing_if = "Option::is_none")]
166 y: Option<i32>,
167
168 /// Size of the cropping square in pixels
169 #[serde(skip_serializing_if = "Option::is_none")]
170 size: Option<i32>,
171
172 #[serde(skip_serializing_if = "Option::is_none")]
173 controls: Option<Vec<ImageControl>>,
174 },
175 #[serde(rename = "token_ids")]
176 TokenIds {
177 data: Vec<u32>,
178
179 #[serde(skip_serializing_if = "Option::is_none")]
180 controls: Option<Vec<TokenControl>>,
181 },
182}
183
184impl Modality {
185 /// Instantiates a text prompt
186 pub fn from_text(text: impl Into<String>, controls: Option<Vec<TextControl>>) -> Self {
187 Modality::Text {
188 data: text.into(),
189 controls,
190 }
191 }
192
193 /// Instantiates a token_ids prompt
194 pub fn from_token_ids(ids: Vec<u32>, controls: Option<Vec<TokenControl>>) -> Self {
195 Modality::TokenIds {
196 data: ids,
197 controls,
198 }
199 }
200
201 pub fn from_image_path(path: impl AsRef<Path>) -> Result<Self, LoadImageError> {
202 let bytes = from_image_path(path.as_ref())?;
203 Ok(Self::from_image_bytes(&bytes))
204 }
205
206 /// Generates an image input from the binary representation of the image.
207 ///
208 /// Using this constructor you must use a binary representation compatible with the API. Png is
209 /// guaranteed to be supported, and all others formats are converted into it. Furthermore, the
210 /// model can only look at square shaped pictures. If the picture is not square shaped it will
211 /// be center cropped.
212 fn from_image_bytes(image: &[u8]) -> Self {
213 Modality::Image {
214 data: BASE64_STANDARD.encode(image).into(),
215 x: None,
216 y: None,
217 size: None,
218 controls: None,
219 }
220 }
221
222 /// Image input for model
223 ///
224 /// The model can only see squared pictures. Images are centercropped. You may want to use this
225 /// method instead of [`Self::from_image_path`] in case you have the image in memory already
226 /// and do not want to load it from a file again.
227 pub fn from_image(image: &image::DynamicImage) -> Result<Self, LoadImageError> {
228 let bytes = preprocess_image(image);
229 Ok(Self::from_image_bytes(&bytes))
230 }
231}
232
/// Optional parameter that specifies which datacenters may process the request. You can either set the
/// parameter to "aleph-alpha" or omit it (defaulting to null).
///
/// Not setting this value, or setting it to null, gives us maximal flexibility in processing your request
/// in our own datacenters and on servers hosted with other providers. Choose this option for maximum
/// availability.
///
/// Setting it to "aleph-alpha" allows us to only process the request in our own datacenters. Choose this
/// option for maximal data privacy.
#[derive(Serialize, Debug)]
pub enum Hosting {
    /// Restrict processing to Aleph Alpha's own datacenters (maximal data privacy).
    /// Serialized as the string "aleph-alpha".
    #[serde(rename = "aleph-alpha")]
    AlephAlpha,
}
247
/// Body of a request to the completions endpoint. Only `model`, `prompt` and
/// `maximum_tokens` are mandatory; every `Option` field is omitted from the
/// serialized JSON when `None`.
#[derive(Serialize, Debug, Default)]
pub struct CompletionRequest {
    /// The name of the model from the Luminous model family, e.g. `luminous-base`.
    /// Models and their respective architectures can differ in parameter size and capabilities.
    /// The most recent version of the model is always used. The model output contains information
    /// as to the model version.
    pub model: String,

    /// Determines in which datacenters the request may be processed.
    /// You can either set the parameter to "aleph-alpha" or omit it (defaulting to None).
    ///
    /// Not setting this value, or setting it to None, gives us maximal flexibility in processing your request in our
    /// own datacenters and on servers hosted with other providers. Choose this option for maximal availability.
    ///
    /// Setting it to "aleph-alpha" allows us to only process the request in our own datacenters.
    /// Choose this option for maximal data privacy.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hosting: Option<Hosting>,

    /// Prompt to complete. The modalities supported depend on `model`.
    pub prompt: Prompt,

    /// Limits the number of tokens, which are generated for the completion.
    pub maximum_tokens: u32,

    /// Generate at least this number of tokens before an end-of-text token is generated. (default: 0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub minimum_tokens: Option<u32>,

    /// Echo the prompt in the completion. This may be especially helpful when log_probs is set to return logprobs for the
    /// prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,

    /// A higher sampling temperature encourages the model to produce less probable outputs ("be more creative").
    /// Values are expected in a range from 0.0 to 1.0. Try high values (e.g., 0.9) for a more "creative" response and the
    /// default 0.0 for a well defined and repeatable answer. It is advised to use either temperature, top_k, or top_p, but
    /// not all three at the same time. If a combination of temperature, top_k or top_p is used, rescaling of logits with
    /// temperature will be performed first. Then top_k is applied. Top_p follows last.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// Introduces random sampling for generated tokens by randomly selecting the next token from the k most likely options.
    /// A value larger than 1 encourages the model to be more creative. Set to 0.0 if repeatable output is desired. It is
    /// advised to use either temperature, top_k, or top_p, but not all three at the same time. If a combination of
    /// temperature, top_k or top_p is used, rescaling of logits with temperature will be performed first. Then top_k is
    /// applied. Top_p follows last.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,

    /// Introduces random sampling for generated tokens by randomly selecting the next token from the smallest possible set
    /// of tokens whose cumulative probability exceeds the probability top_p. Set to 0.0 if repeatable output is desired. It
    /// is advised to use either temperature, top_k, or top_p, but not all three at the same time. If a combination of
    /// temperature, top_k or top_p is used, rescaling of logits with temperature will be performed first. Then top_k is
    /// applied. Top_p follows last.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,

    /// The presence penalty reduces the likelihood of generating tokens that are already present in the
    /// generated text (`repetition_penalties_include_completion=true`) respectively the prompt
    /// (`repetition_penalties_include_prompt=true`).
    /// Presence penalty is independent of the number of occurrences. Increase the value to reduce the likelihood of repeating
    /// text.
    /// An operation like the following is applied: `logits[t] -> logits[t] - 1 * penalty`
    /// where `logits[t]` is the logits for any given token. Note that the formula is independent of the number of times
    /// that a token appears.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,

    /// The frequency penalty reduces the likelihood of generating tokens that are already present in the
    /// generated text (`repetition_penalties_include_completion=true`) respectively the prompt
    /// (`repetition_penalties_include_prompt=true`).
    /// If `repetition_penalties_include_prompt=True`, this also includes the tokens in the prompt.
    /// Frequency penalty is dependent on the number of occurrences of a token.
    /// An operation like the following is applied: `logits[t] -> logits[t] - count[t] * penalty`
    /// where `logits[t]` is the logits for any given token and `count[t]` is the number of times that token appears.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,

    /// Increasing the sequence penalty reduces the likelihood of reproducing token sequences that already
    /// appear in the prompt
    /// (if repetition_penalties_include_prompt is True) and prior completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_penalty: Option<f64>,

    /// Minimal number of tokens to be considered as sequence
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_penalty_min_length: Option<i32>,

    /// Flag deciding whether presence penalty or frequency penalty are updated from tokens in the prompt
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalties_include_prompt: Option<bool>,

    /// Flag deciding whether presence penalty or frequency penalty are updated from tokens in the completion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalties_include_completion: Option<bool>,

    /// Flag deciding whether presence penalty is applied multiplicatively (True) or additively (False).
    /// This changes the formula stated for presence penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_multiplicative_presence_penalty: Option<bool>,

    /// Flag deciding whether frequency penalty is applied multiplicatively (True) or additively (False).
    /// This changes the formula stated for frequency penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_multiplicative_frequency_penalty: Option<bool>,

    /// Flag deciding whether sequence penalty is applied multiplicatively (True) or additively (False).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_multiplicative_sequence_penalty: Option<bool>,

    /// List of strings that may be generated without penalty, regardless of other penalty settings.
    /// By default, we will also include any `stop_sequences` you have set, since completion performance
    /// can be degraded if expected stop sequences are penalized.
    /// You can disable this behavior by setting `penalty_exceptions_include_stop_sequences` to `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub penalty_exceptions: Option<Vec<String>>,

    /// All tokens in this text will be used in addition to the already penalized tokens for repetition
    /// penalties.
    /// These consist of the already generated completion tokens and the prompt tokens, if
    /// `repetition_penalties_include_prompt` is set to `true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub penalty_bias: Option<String>,

    /// By default we include all `stop_sequences` in `penalty_exceptions`, so as not to penalise the
    /// presence of stop sequences that are present in few-shot prompts to give structure to your
    /// completions.
    ///
    /// You can set this to `false` if you do not want this behaviour.
    ///
    /// See the description of `penalty_exceptions` for more information on what `penalty_exceptions` are
    /// used for.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub penalty_exceptions_include_stop_sequences: Option<bool>,

    /// If a value is given, the number of `best_of` completions will be generated on the server side. The
    /// completion with the highest log probability per token is returned. If the parameter `n` is greater
    /// than 1, more than one (`n`) completions will be returned. `best_of` must be strictly greater than `n`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u32>,

    /// The number of completions to return. If argmax sampling is used (temperature, top_k, top_p are all
    /// default) the same completions will be produced. This parameter should only be increased if random
    /// sampling is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    /// Number of top log probabilities for each token generated. Log probabilities can be used in downstream
    /// tasks or to assess the model's certainty when producing tokens. No log probabilities are returned if
    /// set to None. Log probabilities of generated tokens are returned if set to 0. Log probabilities of
    /// generated tokens and top n log probabilities are returned if set to n.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_probs: Option<i32>,

    /// List of strings that will stop generation if they're generated. Stop sequences may be helpful in
    /// structured texts. E.g.: In a question answering scenario a text may consist of lines starting with
    /// either "Question: " or "Answer: " (alternating); "Question: " may then be used as stop sequence so
    /// the model answers only, rather than generating more questions.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,

    /// Flag indicating whether individual tokens of the completion should be returned (True) or whether
    /// solely the generated text (i.e. the completion) is sufficient (False).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens: Option<bool>,

    /// Setting this parameter to true forces the raw completion of the model to be returned.
    /// For some models, we may optimize the completion that was generated by the model and
    /// return the optimized completion in the completion field of the CompletionResponse.
    /// The raw completion, if returned, will contain the un-optimized completion.
    /// Setting tokens to true or log_probs to any value will also trigger the raw completion
    /// to be returned.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub raw_completion: Option<bool>,

    /// We continually research optimal ways to work with our models. By default, we apply these
    /// optimizations to both your prompt and completion for you.
    /// Our goal is to improve your results while using our API. But you can always pass
    /// `disable_optimizations: true` and we will leave your prompt and completion untouched.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub disable_optimizations: Option<bool>,

    /// Bias the completion to only generate options within this list;
    /// all other tokens are disregarded at sampling
    ///
    /// Note that strings in the inclusion list must not be prefixes
    /// of strings in the exclusion list and vice versa
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_bias_inclusion: Option<Vec<String>>,

    /// Only consider the first token for the completion_bias_inclusion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_bias_inclusion_first_token_only: Option<bool>,

    /// Bias the completion to NOT generate options within this list;
    /// all other tokens are unaffected in sampling
    ///
    /// Note that strings in the inclusion list must not be prefixes
    /// of strings in the exclusion list and vice versa
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_bias_exclusion: Option<Vec<String>>,

    /// Only consider the first token for the completion_bias_exclusion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_bias_exclusion_first_token_only: Option<bool>,

    /// If set to `null`, attention control parameters only apply to those tokens that have
    /// explicitly been set in the request.
    /// If set to a non-null value, we apply the control parameters to similar tokens as well.
    /// Controls that have been applied to one token will then be applied to all other tokens
    /// that have at least the similarity score defined by this parameter.
    /// The similarity score is the cosine similarity of token embeddings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub contextual_control_threshold: Option<f64>,

    /// `true`: apply controls on prompt items by adding the `log(control_factor)` to attention scores.
    /// `false`: apply controls on prompt items by
    /// `(attention_scores - -attention_scores.min(-1)) * control_factor`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub control_log_additive: Option<bool>,

    /// The logit bias allows to influence the likelihood of generating tokens. A dictionary mapping token
    /// ids (int) to a bias (float) can be provided. Such bias is added to the logits as generated by the
    /// model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<i32, f32>>,
}
480
481impl CompletionRequest {
482 pub fn new(model: String, prompt: Prompt, maximum_tokens: u32) -> Self {
483 Self {
484 model,
485 prompt,
486 maximum_tokens,
487 ..Self::default()
488 }
489 }
490 pub fn from_text(model: String, prompt: String, maximum_tokens: u32) -> Self {
491 Self::new(model, Prompt::from_text(prompt), maximum_tokens)
492 }
493}
494
// Presumably generates chainable builder-style setters on `CompletionRequest`
// for each listed optional field — see the `impl_builder_methods!` definition
// imported at the top of this file to confirm the exact method shape.
impl_builder_methods!(
    CompletionRequest,
    minimum_tokens: u32,
    echo: bool,
    temperature: f64,
    top_k: u32,
    top_p: f64,
    presence_penalty: f64,
    frequency_penalty: f64,
    sequence_penalty: f64,
    sequence_penalty_min_length: i32,
    repetition_penalties_include_prompt: bool,
    repetition_penalties_include_completion: bool,
    use_multiplicative_presence_penalty: bool,
    use_multiplicative_frequency_penalty: bool,
    use_multiplicative_sequence_penalty: bool,
    penalty_exceptions: Vec<String>,
    penalty_bias: String,
    penalty_exceptions_include_stop_sequences: bool,
    best_of: u32,
    n: u32,
    log_probs: i32,
    stop_sequences: Vec<String>,
    tokens: bool,
    raw_completion: bool,
    disable_optimizations: bool,
    completion_bias_inclusion: Vec<String>,
    completion_bias_inclusion_first_token_only: bool,
    completion_bias_exclusion: Vec<String>,
    completion_bias_exclusion_first_token_only: bool,
    contextual_control_threshold: f64,
    control_log_additive: bool,
    logit_bias: HashMap<i32, f32>
);
529
/// Deserialized body of a successful response from the completions endpoint.
#[derive(Deserialize, Debug)]
pub struct CompletionResponse {
    /// model name and version (if any) of the used model for inference
    pub model_version: String,
    /// list of completions; may contain only one entry if no more are requested (see parameter n)
    pub completions: Vec<CompletionOutput>,
}
537
538impl CompletionResponse {
539 /// The best completion in the answer.
540 pub fn best(&self) -> &CompletionOutput {
541 self.completions
542 .first()
543 .expect("Response is assumed to always have at least one completion")
544 }
545
546 /// Text of the best completion.
547 pub fn best_text(&self) -> &str {
548 &self.best().completion
549 }
550}
551
/// A single generated completion within a [`CompletionResponse`].
#[derive(Deserialize, Debug)]
pub struct CompletionOutput {
    /// The generated completion text.
    pub completion: String,
    /// Reason why generation stopped, as reported by the API.
    /// NOTE(review): the set of possible values is not visible in this file —
    /// consult the API reference before matching on it.
    pub finish_reason: String,
}