cohere_rust/api/generate.rs
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use super::{GenerateModel, Truncate};

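/// Request body for the Cohere generate endpoint.
///
/// Every field except `prompt` is optional and is omitted from the JSON body
/// when left unset. A minimal sketch of building a request via struct update
/// syntax (the values are illustrative, and the doc-test path assumes this
/// module is exported as `cohere_rust::api::generate`):
///
/// ```no_run
/// # use cohere_rust::api::generate::GenerateRequest;
/// let request = GenerateRequest {
///     prompt: "Write a haiku about Rust",
///     max_tokens: Some(64),
///     temperature: Some(0.8),
///     ..Default::default()
/// };
/// ```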
#[derive(Serialize, Default, Debug)]
pub struct GenerateRequest<'input> {
    /// Represents the prompt or text to be completed.
    pub prompt: &'input str,
    /// optional - The model to use for text generation. Custom models can also be supplied with their full ID.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<GenerateModel>,
    /// optional - Denotes the number of tokens to predict per generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// optional - The ID of a custom playground preset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preset: Option<String>,
    /// optional - A non-negative float that tunes the degree of randomness in generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// optional - Denotes the maximum number of generations that will be returned. Defaults to 1,
    /// max value of 5.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_generations: Option<u8>,
    /// optional - If set to a positive integer, it ensures only the top k most likely tokens are
    /// considered for generation at each step. Defaults to 0 (disabled).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub k: Option<u64>,
    /// optional - If set to a probability 0.0 < p < 1.0, it ensures that only the most likely tokens,
    /// with total probability mass of p, are considered for generation at each step. If both k and
    /// p are enabled, p acts after k. Max value of 1.0. Defaults to 0.75.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p: Option<f64>,
    /// optional - Can be used to reduce repetitiveness of generated tokens. The higher the value,
    /// the stronger a penalty is applied to previously present tokens, proportional to how many
    /// times they have already appeared in the prompt or prior generation. Max value of 1.0. Defaults to 0.0.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    /// optional - Can be used to reduce repetitiveness of generated tokens. Similar to frequency_penalty,
    /// except that this penalty is applied equally to all tokens that have already appeared, regardless
    /// of their exact frequencies. Max value of 1.0. Defaults to 0.0.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    /// optional - The generated text will be cut at the beginning of the earliest occurrence of an end sequence.
    /// The sequence will be excluded from the text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_sequences: Option<Vec<String>>,
    /// optional - The generated text will be cut at the end of the earliest occurrence of a stop sequence.
    /// The sequence will be included in the text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    /// optional - One of GENERATION|ALL|NONE to specify how and if the token likelihoods are returned with
    /// the response. If GENERATION is selected, the token likelihoods will only be provided for generated
    /// text. If ALL is selected, the token likelihoods will be provided for both the prompt and the generated
    /// text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub return_likelihoods: Option<ReturnLikelihoods>,
    /// optional - Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens.
    /// A map of token IDs to biases, where each bias is a float between -10 and +10.
    /// Negative values will disincentivize that token from appearing, while positive values will incentivize it.
    /// Token IDs can be obtained from text using the tokenizer.
    /// Note: logit bias may not be supported for all finetuned models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<u64, f32>>,
    /// optional - Specify how the API will handle inputs longer than the maximum token length.
    /// Passing START will discard the start of the input. END will discard the end of the input.
    /// In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
    /// If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncate: Option<Truncate>,
}

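/// Controls how token likelihoods are returned with the response; serializes
/// to the uppercase wire names the API expects. A quick sketch of the JSON
/// form (assumes `serde_json` is available as a dev-dependency):
///
/// ```
/// # use cohere_rust::api::generate::ReturnLikelihoods;
/// assert_eq!(
///     serde_json::to_string(&ReturnLikelihoods::Generation).unwrap(),
///     "\"GENERATION\""
/// );
/// ```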
#[derive(strum_macros::Display, Serialize, Debug)]
pub enum ReturnLikelihoods {
    #[strum(serialize = "GENERATION")]
    #[serde(rename = "GENERATION")]
    Generation,
    #[strum(serialize = "ALL")]
    #[serde(rename = "ALL")]
    All,
    #[strum(serialize = "NONE")]
    #[serde(rename = "NONE")]
    None,
}

#[derive(Deserialize, Debug)]
pub(crate) struct GenerateResponse {
    /// Contains the generations.
    pub generations: Vec<Generation>,
}

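/// A single generation returned by the API. A minimal sketch of what
/// deserializing one looks like (assumes `serde_json`; `likelihood` and
/// `token_likelihoods` fall back to their defaults when absent):
///
/// ```
/// # use cohere_rust::api::generate::Generation;
/// let generation: Generation =
///     serde_json::from_str(r#"{"text": "Hello, world!"}"#).unwrap();
/// assert_eq!(generation.text, "Hello, world!");
/// assert!(generation.token_likelihoods.is_empty());
/// ```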
#[derive(Deserialize, Debug)]
pub struct Generation {
    /// Contains the generated text.
    pub text: String,
    /// The sum of the log-likelihoods of each token in the string.
    #[serde(default)]
    pub likelihood: f64,
    /// The per-token likelihoods. Only returned if `return_likelihoods`
    /// is not set to NONE.
    #[serde(default)]
    pub token_likelihoods: Vec<TokenLikelihood>,
}

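/// Per-token likelihood entry. The API returns `likelihood` as a string,
/// which may be empty for the first token of a context; a minimal sketch of
/// converting it to a float where present:
///
/// ```
/// # use cohere_rust::api::generate::TokenLikelihood;
/// fn log_likelihood(tl: &TokenLikelihood) -> Option<f64> {
///     tl.likelihood.parse().ok()
/// }
/// ```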
#[derive(Deserialize, Debug)]
pub struct TokenLikelihood {
    /// The token.
    pub token: String,
    /// Refers to the log-likelihood of the token. The first token of a context will not
    /// have a likelihood.
    pub likelihood: String,
}
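
// A small serialization smoke test, sketched under the assumption that
// `serde_json` is available as a dev-dependency: every optional field is
// skipped when `None`, so a default request should serialize to just the
// required `prompt`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_request_serializes_only_prompt() {
        let request = GenerateRequest {
            prompt: "hello",
            ..Default::default()
        };
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json, serde_json::json!({ "prompt": "hello" }));
    }
}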