aleph_alpha_api/
completion.rs

use super::image_processing::{from_image_path, preprocess_image, LoadImageError};
use crate::impl_builder_methods;
use base64::prelude::{Engine as _, BASE64_STANDARD};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, path::Path};

#[derive(Serialize, Debug, Default)]
pub struct Prompt(Vec<Modality>);

impl Prompt {
    pub fn empty() -> Self {
        Self::default()
    }

    /// Create a prompt from a single text item.
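    ///
    /// A minimal usage sketch; the doc-test is `ignore`d since the crate and module
    /// paths are assumed from this file's location, not confirmed by it:
    ///
    /// ```ignore
    /// use aleph_alpha_api::completion::Prompt;
    ///
    /// let prompt = Prompt::from_text("An apple a day");
    /// ```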
    pub fn from_text(text: impl Into<String>) -> Self {
        Self(vec![Modality::from_text(text, None)])
    }

    pub fn from_text_with_controls(text: impl Into<String>, controls: Vec<TextControl>) -> Self {
        Self(vec![Modality::from_text(text, Some(controls))])
    }

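    /// Create a prompt from a list of token ids, optionally with attention controls.
    ///
    /// A sketch based on the `TokenControl` field docs below: a factor of 0.5
    /// suppresses the token at the given index. The doc-test is `ignore`d since the
    /// crate path is assumed.
    ///
    /// ```ignore
    /// use aleph_alpha_api::completion::{Prompt, TokenControl};
    ///
    /// // Suppress attention on the first token id.
    /// let controls = vec![TokenControl { index: 0, factor: 0.5 }];
    /// let prompt = Prompt::from_token_ids(vec![1, 2, 3], Some(controls));
    /// ```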
    pub fn from_token_ids(ids: Vec<u32>, controls: Option<Vec<TokenControl>>) -> Self {
        Self(vec![Modality::from_token_ids(ids, controls)])
    }

    /// Create a multimodal prompt from a list of individual items with any modality.
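    ///
    /// A sketch combining an image item and a text item; `"cat.png"` is a placeholder
    /// and the crate path is assumed, so the doc-test is `ignore`d:
    ///
    /// ```ignore
    /// use aleph_alpha_api::completion::{Modality, Prompt};
    ///
    /// let image = Modality::from_image_path("cat.png").unwrap();
    /// let text = Modality::from_text("A picture of ", None);
    /// let prompt = Prompt::from_vec(vec![image, text]);
    /// ```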
    pub fn from_vec(items: Vec<Modality>) -> Self {
        Self(items)
    }
}

#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct TokenControl {
    /// Index of the token, relative to the list of token IDs in the current prompt item.
    pub index: u32,

    /// Factor to apply to the given token in the attention matrix.
    ///
    /// - 0 <= factor < 1 => Suppress the given token
    /// - factor == 1 => identity operation, no change to attention
    /// - factor > 1 => Amplify the given token
    pub factor: f64,
}

#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct TextControl {
    /// Starting character index to apply the factor to.
    start: i32,

    /// The number of characters to apply the factor to.
    length: i32,

    /// Factor to apply to the given token in the attention matrix.
    ///
    /// - 0 <= factor < 1 => Suppress the given token
    /// - factor == 1 => identity operation, no change to attention
    /// - factor > 1 => Amplify the given token
    factor: f64,

    /// What to do if a control partially overlaps with a text token.
    ///
    /// If set to "partial", the factor will be adjusted proportionally to the amount
    /// of the token it overlaps. So a factor of 2.0 on a control that covers only 2 of
    /// 4 token characters would be adjusted to 1.5. (It always moves closer to 1, since
    /// 1 is an identity operation for control factors.)
    ///
    /// If set to "complete", the full factor will be applied as long as the control
    /// overlaps with the token at all.
    #[serde(skip_serializing_if = "Option::is_none")]
    token_overlap: Option<String>,
}

/// Bounding box in logical coordinates. From 0 to 1. With (0,0) being the upper left corner,
/// and relative to the entire image.
///
/// Keep in mind, non-square images are center-cropped by default before going to the model.
/// (You can specify a custom cropping if you want.) Since control coordinates are relative to
/// the entire image, all or a portion of your control may be outside the "model visible area".
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct BoundingBox {
    /// x-coordinate of top left corner of the control bounding box.
    /// Must be a value between 0 and 1, where 0 is the left corner and 1 is the right corner.
    left: f64,

    /// y-coordinate of top left corner of the control bounding box.
    /// Must be a value between 0 and 1, where 0 is the top pixel row and 1 is the bottom row.
    top: f64,

    /// Width of the control bounding box.
    /// Must be a value between 0 and 1, where 1 means the full width of the image.
    width: f64,

    /// Height of the control bounding box.
    /// Must be a value between 0 and 1, where 1 means the full height of the image.
    height: f64,
}

#[derive(Serialize, Clone, Debug)]
pub struct ImageControl {
    /// Bounding box in logical coordinates. From 0 to 1. With (0,0) being the upper left corner,
    /// and relative to the entire image.
    ///
    /// Keep in mind, non-square images are center-cropped by default before going to the model.
    /// (You can specify a custom cropping if you want.) Since control coordinates are relative to
    /// the entire image, all or a portion of your control may be outside the "model visible area".
    rect: BoundingBox,

    /// Factor to apply to the given token in the attention matrix.
    ///
    /// - 0 <= factor < 1 => Suppress the given token
    /// - factor == 1 => identity operation, no change to attention
    /// - factor > 1 => Amplify the given token
    factor: f64,

    /// What to do if a control partially overlaps with a text token.
    ///
    /// If set to "partial", the factor will be adjusted proportionally to the amount
    /// of the token it overlaps. So a factor of 2.0 on a control that covers only 2 of
    /// 4 token characters would be adjusted to 1.5. (It always moves closer to 1, since
    /// 1 is an identity operation for control factors.)
    ///
    /// If set to "complete", the full factor will be applied as long as the control
    /// overlaps with the token at all.
    #[serde(skip_serializing_if = "Option::is_none")]
    token_overlap: Option<String>,
}

/// The prompt for models can be a combination of different modalities (Text and Image). Which
/// modalities are supported depends on the model in question.
#[derive(Serialize, Debug, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Modality {
    /// The only type of prompt which can be used with pure language models
    Text {
        data: String,

        #[serde(skip_serializing_if = "Option::is_none")]
        controls: Option<Vec<TextControl>>,
    },
    /// An image input into the model. See [`Modality::from_image_path`].
    Image {
        /// An image sent as part of a prompt to a model. The image is represented as base64.
        ///
        /// Note: The models operate on square images. All non-square images are center-cropped
        /// before going to the model, so portions of the image may not be visible.
        ///
        /// You can supply specific cropping parameters if you like, to choose a different area
        /// of the image than a center-crop. Or, you can always transform the image yourself to
        /// a square before sending it.
        data: String,

        /// x-coordinate of top left corner of cropping box in pixels
        #[serde(skip_serializing_if = "Option::is_none")]
        x: Option<i32>,

        /// y-coordinate of top left corner of cropping box in pixels
        #[serde(skip_serializing_if = "Option::is_none")]
        y: Option<i32>,

        /// Size of the cropping square in pixels
        #[serde(skip_serializing_if = "Option::is_none")]
        size: Option<i32>,

        #[serde(skip_serializing_if = "Option::is_none")]
        controls: Option<Vec<ImageControl>>,
    },
    #[serde(rename = "token_ids")]
    TokenIds {
        data: Vec<u32>,

        #[serde(skip_serializing_if = "Option::is_none")]
        controls: Option<Vec<TokenControl>>,
    },
}

impl Modality {
    /// Instantiates a text prompt
    pub fn from_text(text: impl Into<String>, controls: Option<Vec<TextControl>>) -> Self {
        Modality::Text {
            data: text.into(),
            controls,
        }
    }

    /// Instantiates a token_ids prompt
    pub fn from_token_ids(ids: Vec<u32>, controls: Option<Vec<TokenControl>>) -> Self {
        Modality::TokenIds {
            data: ids,
            controls,
        }
    }

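    /// Instantiates an image prompt from an image file on disk.
    ///
    /// A minimal sketch; `"cat.png"` is a placeholder path and the crate path is an
    /// assumption, so the doc-test is `ignore`d:
    ///
    /// ```ignore
    /// use aleph_alpha_api::completion::Modality;
    ///
    /// let image_item = Modality::from_image_path("cat.png").unwrap();
    /// ```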
    pub fn from_image_path(path: impl AsRef<Path>) -> Result<Self, LoadImageError> {
        let bytes = from_image_path(path.as_ref())?;
        Ok(Self::from_image_bytes(&bytes))
    }

    /// Generates an image input from the binary representation of the image.
    ///
    /// With this constructor you must supply a binary representation compatible with the API. PNG
    /// is guaranteed to be supported, and all other formats are converted into it. Furthermore,
    /// the model can only look at square-shaped pictures. If the picture is not square-shaped, it
    /// will be center-cropped.
    fn from_image_bytes(image: &[u8]) -> Self {
        Modality::Image {
            data: BASE64_STANDARD.encode(image),
            x: None,
            y: None,
            size: None,
            controls: None,
        }
    }

    /// Image input for model
    ///
    /// The model can only see square pictures; images are center-cropped. You may want to use
    /// this method instead of [`Self::from_image_path`] in case you already have the image in
    /// memory and do not want to load it from a file again.
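    ///
    /// A minimal sketch using the `image` crate this signature already depends on;
    /// the crate path is assumed, so the doc-test is `ignore`d:
    ///
    /// ```ignore
    /// use aleph_alpha_api::completion::Modality;
    ///
    /// // A blank 256x256 RGB image, just to have something to send.
    /// let image = image::DynamicImage::new_rgb8(256, 256);
    /// let image_item = Modality::from_image(&image).unwrap();
    /// ```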
    pub fn from_image(image: &image::DynamicImage) -> Result<Self, LoadImageError> {
        let bytes = preprocess_image(image);
        Ok(Self::from_image_bytes(&bytes))
    }
}

/// Optional parameter that specifies which datacenters may process the request. You can either
/// set the parameter to "aleph-alpha" or omit it (defaulting to null).
///
/// Not setting this value, or setting it to null, gives us maximal flexibility in processing your
/// request in our own datacenters and on servers hosted with other providers. Choose this option
/// for maximum availability.
///
/// Setting it to "aleph-alpha" allows us to only process the request in our own datacenters.
/// Choose this option for maximal data privacy.
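///
/// A sketch of the serialized form; `serde_json` is assumed to be available as a
/// dev-dependency and the crate path is assumed, so the doc-test is `ignore`d:
///
/// ```ignore
/// use aleph_alpha_api::completion::Hosting;
///
/// let json = serde_json::to_string(&Hosting::AlephAlpha).unwrap();
/// assert_eq!(json, "\"aleph-alpha\"");
/// ```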
#[derive(Serialize, Debug)]
pub enum Hosting {
    #[serde(rename = "aleph-alpha")]
    AlephAlpha,
}

#[derive(Serialize, Debug, Default)]
pub struct CompletionRequest {
    /// The name of the model from the Luminous model family, e.g. `luminous-base`.
    /// Models and their respective architectures can differ in parameter size and capabilities.
    /// The most recent version of the model is always used. The model output contains information
    /// as to the model version.
    pub model: String,

    /// Determines in which datacenters the request may be processed.
    /// You can either set the parameter to "aleph-alpha" or omit it (defaulting to None).
    ///
    /// Not setting this value, or setting it to None, gives us maximal flexibility in processing
    /// your request in our own datacenters and on servers hosted with other providers. Choose this
    /// option for maximal availability.
    ///
    /// Setting it to "aleph-alpha" allows us to only process the request in our own datacenters.
    /// Choose this option for maximal data privacy.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hosting: Option<Hosting>,

    /// Prompt to complete. The modalities supported depend on `model`.
    pub prompt: Prompt,

    /// Limits the number of tokens which are generated for the completion.
    pub maximum_tokens: u32,

    /// Generate at least this number of tokens before an end-of-text token is generated. (default: 0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub minimum_tokens: Option<u32>,

    /// Echo the prompt in the completion. This may be especially helpful when log_probs is set to
    /// return log probabilities for the prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,

    /// A higher sampling temperature encourages the model to produce less probable outputs ("be more
    /// creative"). Values are expected in a range from 0.0 to 1.0. Try high values (e.g., 0.9) for a
    /// more "creative" response and the default 0.0 for a well defined and repeatable answer. It is
    /// advised to use either temperature, top_k, or top_p, but not all three at the same time. If a
    /// combination of temperature, top_k or top_p is used, rescaling of logits with temperature will
    /// be performed first. Then top_k is applied. Top_p follows last.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// Introduces random sampling for generated tokens by randomly selecting the next token from the
    /// k most likely options. A value larger than 1 encourages the model to be more creative. Set to
    /// 0 if repeatable output is desired. It is advised to use either temperature, top_k, or top_p,
    /// but not all three at the same time. If a combination of temperature, top_k or top_p is used,
    /// rescaling of logits with temperature will be performed first. Then top_k is applied. Top_p
    /// follows last.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,

    /// Introduces random sampling for generated tokens by randomly selecting the next token from the
    /// smallest possible set of tokens whose cumulative probability exceeds the probability top_p.
    /// Set to 0.0 if repeatable output is desired. It is advised to use either temperature, top_k, or
    /// top_p, but not all three at the same time. If a combination of temperature, top_k or top_p is
    /// used, rescaling of logits with temperature will be performed first. Then top_k is applied.
    /// Top_p follows last.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,

    /// The presence penalty reduces the likelihood of generating tokens that are already present in
    /// the generated text (`repetition_penalties_include_completion=true`) or in the prompt
    /// (`repetition_penalties_include_prompt=true`), respectively.
    /// Presence penalty is independent of the number of occurrences. Increase the value to reduce
    /// the likelihood of repeating text.
    /// An operation like the following is applied: `logits[t] -> logits[t] - 1 * penalty`
    /// where `logits[t]` is the logits for any given token. Note that the formula is independent of
    /// the number of times that a token appears.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,

    /// The frequency penalty reduces the likelihood of generating tokens that are already present in
    /// the generated text (`repetition_penalties_include_completion=true`) or in the prompt
    /// (`repetition_penalties_include_prompt=true`), respectively.
    /// If `repetition_penalties_include_prompt=true`, this also includes the tokens in the prompt.
    /// Frequency penalty is dependent on the number of occurrences of a token.
    /// An operation like the following is applied: `logits[t] -> logits[t] - count[t] * penalty`
    /// where `logits[t]` is the logits for any given token and `count[t]` is the number of times
    /// that token appears.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,

    /// Increasing the sequence penalty reduces the likelihood of reproducing token sequences that
    /// already appear in the prompt (if `repetition_penalties_include_prompt` is `true`) and in the
    /// prior completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_penalty: Option<f64>,

    /// Minimal number of tokens to be considered a sequence.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_penalty_min_length: Option<i32>,

    /// Flag deciding whether presence penalty or frequency penalty are updated from tokens in the prompt
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalties_include_prompt: Option<bool>,

    /// Flag deciding whether presence penalty or frequency penalty are updated from tokens in the completion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalties_include_completion: Option<bool>,

    /// Flag deciding whether presence penalty is applied multiplicatively (`true`) or additively
    /// (`false`). This changes the formula stated for presence penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_multiplicative_presence_penalty: Option<bool>,

    /// Flag deciding whether frequency penalty is applied multiplicatively (`true`) or additively
    /// (`false`). This changes the formula stated for frequency penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_multiplicative_frequency_penalty: Option<bool>,

    /// Flag deciding whether sequence penalty is applied multiplicatively (`true`) or additively (`false`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_multiplicative_sequence_penalty: Option<bool>,

    /// List of strings that may be generated without penalty, regardless of other penalty settings.
    /// By default, we will also include any `stop_sequences` you have set, since completion performance
    /// can be degraded if expected stop sequences are penalized.
    /// You can disable this behavior by setting `penalty_exceptions_include_stop_sequences` to `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub penalty_exceptions: Option<Vec<String>>,

    /// All tokens in this text will be used in addition to the already penalized tokens for repetition
    /// penalties.
    /// These consist of the already generated completion tokens and the prompt tokens, if
    /// `repetition_penalties_include_prompt` is set to `true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub penalty_bias: Option<String>,

    /// By default we include all `stop_sequences` in `penalty_exceptions`, so as not to penalize the
    /// presence of stop sequences that are present in few-shot prompts to give structure to your
    /// completions.
    ///
    /// You can set this to `false` if you do not want this behavior.
    ///
    /// See the description of `penalty_exceptions` for more information on what `penalty_exceptions`
    /// are used for.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub penalty_exceptions_include_stop_sequences: Option<bool>,

    /// If a value is given, the number of `best_of` completions will be generated on the server side.
    /// The completion with the highest log probability per token is returned. If the parameter `n` is
    /// greater than 1, more than one completion (`n` of them) will be returned. `best_of` must be
    /// strictly greater than `n`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u32>,

    /// The number of completions to return. If argmax sampling is used (temperature, top_k, top_p are
    /// all default) the same completions will be produced. This parameter should only be increased if
    /// random sampling is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,

    /// Number of top log probabilities for each token generated. Log probabilities can be used in
    /// downstream tasks or to assess the model's certainty when producing tokens. No log probabilities
    /// are returned if set to None. Log probabilities of generated tokens are returned if set to 0.
    /// Log probabilities of generated tokens and top n log probabilities are returned if set to n.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_probs: Option<i32>,

    /// List of strings which will stop generation if they are generated. Stop sequences are helpful
    /// in structured texts. E.g.: In a question answering scenario a text may consist of lines
    /// starting with either "Question: " or "Answer: " (alternating). After producing an answer, the
    /// model will likely generate "Question: ". "Question: " may therefore be used as a stop sequence
    /// in order not to have the model generate more questions but rather restrict text generation to
    /// the answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,

    /// Flag indicating whether individual tokens of the completion should be returned (`true`) or
    /// whether solely the generated text (i.e. the completion) is sufficient (`false`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tokens: Option<bool>,

    /// Setting this parameter to true forces the raw completion of the model to be returned.
    /// For some models, we may optimize the completion that was generated by the model and
    /// return the optimized completion in the completion field of the CompletionResponse.
    /// The raw completion, if returned, will contain the un-optimized completion.
    /// Setting tokens to true or log_probs to any value will also trigger the raw completion
    /// to be returned.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub raw_completion: Option<bool>,

    /// We continually research optimal ways to work with our models. By default, we apply these
    /// optimizations to both your prompt and completion for you.
    /// Our goal is to improve your results while using our API. But you can always pass
    /// `disable_optimizations: true` and we will leave your prompt and completion untouched.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub disable_optimizations: Option<bool>,

    /// Bias the completion to only generate options within this list;
    /// all other tokens are disregarded at sampling.
    ///
    /// Note that strings in the inclusion list must not be prefixes
    /// of strings in the exclusion list and vice versa.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_bias_inclusion: Option<Vec<String>>,

    /// Only consider the first token for the completion_bias_inclusion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_bias_inclusion_first_token_only: Option<bool>,

    /// Bias the completion to NOT generate options within this list;
    /// all other tokens are unaffected in sampling.
    ///
    /// Note that strings in the inclusion list must not be prefixes
    /// of strings in the exclusion list and vice versa.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_bias_exclusion: Option<Vec<String>>,

    /// Only consider the first token for the completion_bias_exclusion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_bias_exclusion_first_token_only: Option<bool>,

    /// If set to None, attention control parameters only apply to those tokens that have
    /// explicitly been set in the request.
    /// If set to a non-None value, we apply the control parameters to similar tokens as well.
    /// Controls that have been applied to one token will then be applied to all other tokens
    /// that have at least the similarity score defined by this parameter.
    /// The similarity score is the cosine similarity of token embeddings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub contextual_control_threshold: Option<f64>,

    /// `true`: apply controls on prompt items by adding the `log(control_factor)` to attention scores.
    /// `false`: apply controls on prompt items by
    /// `(attention_scores - -attention_scores.min(-1)) * control_factor`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub control_log_additive: Option<bool>,

    /// The logit bias allows you to influence the likelihood of generating tokens. A dictionary
    /// mapping token ids (int) to a bias (float) can be provided. Such bias is added to the logits
    /// as generated by the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<i32, f32>>,
}

impl CompletionRequest {
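    /// Creates a request from the three required parameters; every other field stays
    /// at its `None`/default value.
    ///
    /// A minimal sketch; `"luminous-base"` is an example model name and the crate path
    /// is assumed, so the doc-test is `ignore`d:
    ///
    /// ```ignore
    /// use aleph_alpha_api::completion::{CompletionRequest, Prompt};
    ///
    /// let request = CompletionRequest::new(
    ///     "luminous-base".to_owned(),
    ///     Prompt::from_text("An apple a day"),
    ///     64,
    /// );
    /// assert_eq!(request.maximum_tokens, 64);
    /// ```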
    pub fn new(model: String, prompt: Prompt, maximum_tokens: u32) -> Self {
        Self {
            model,
            prompt,
            maximum_tokens,
            ..Self::default()
        }
    }

    pub fn from_text(model: String, prompt: String, maximum_tokens: u32) -> Self {
        Self::new(model, Prompt::from_text(prompt), maximum_tokens)
    }
}

impl_builder_methods!(
    CompletionRequest,
    minimum_tokens: u32,
    echo: bool,
    temperature: f64,
    top_k: u32,
    top_p: f64,
    presence_penalty: f64,
    frequency_penalty: f64,
    sequence_penalty: f64,
    sequence_penalty_min_length: i32,
    repetition_penalties_include_prompt: bool,
    repetition_penalties_include_completion: bool,
    use_multiplicative_presence_penalty: bool,
    use_multiplicative_frequency_penalty: bool,
    use_multiplicative_sequence_penalty: bool,
    penalty_exceptions: Vec<String>,
    penalty_bias: String,
    penalty_exceptions_include_stop_sequences: bool,
    best_of: u32,
    n: u32,
    log_probs: i32,
    stop_sequences: Vec<String>,
    tokens: bool,
    raw_completion: bool,
    disable_optimizations: bool,
    completion_bias_inclusion: Vec<String>,
    completion_bias_inclusion_first_token_only: bool,
    completion_bias_exclusion: Vec<String>,
    completion_bias_exclusion_first_token_only: bool,
    contextual_control_threshold: f64,
    control_log_additive: bool,
    logit_bias: HashMap<i32, f32>
);

#[derive(Deserialize, Debug)]
pub struct CompletionResponse {
    /// Model name and version (if any) of the model used for inference.
    pub model_version: String,
    /// List of completions; may contain only one entry if no more are requested (see parameter `n`).
    pub completions: Vec<CompletionOutput>,
}

impl CompletionResponse {
    /// The best completion in the answer.
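    ///
    /// A sketch deserializing a response with `serde_json` (assumed to be available as
    /// a dev-dependency; crate path assumed, so the doc-test is `ignore`d):
    ///
    /// ```ignore
    /// use aleph_alpha_api::completion::CompletionResponse;
    ///
    /// let json = r#"{
    ///     "model_version": "2021-12",
    ///     "completions": [{"completion": " keeps the doctor away.", "finish_reason": "maximum_tokens"}]
    /// }"#;
    /// let response: CompletionResponse = serde_json::from_str(json).unwrap();
    /// assert_eq!(response.best_text(), " keeps the doctor away.");
    /// ```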
    pub fn best(&self) -> &CompletionOutput {
        self.completions
            .first()
            .expect("Response is assumed to always have at least one completion")
    }

    /// Text of the best completion.
    pub fn best_text(&self) -> &str {
        &self.best().completion
    }
}

#[derive(Deserialize, Debug)]
pub struct CompletionOutput {
    pub completion: String,
    pub finish_reason: String,
}