//! cohere_rust/api/generate.rs
//! Request and response types for the Cohere `generate` endpoint.

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4use super::{GenerateModel, Truncate};
5
6#[derive(Serialize, Default, Debug)]
7pub struct GenerateRequest<'input> {
8    /// Represents the prompt or text to be completed.
9    pub prompt: &'input str,
10    /// optional - The model to use for text generation. Custom models can also be supplied with their full ID.
11    #[serde(skip_serializing_if = "Option::is_none")]
12    pub model: Option<GenerateModel>,
13    /// optional - Denotes the number of tokens to predict per generation.
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub max_tokens: Option<u32>,
16    /// optional - The ID of a custom playground preset.
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub preset: Option<String>,
19    /// optional - A non-negative float that tunes the degree of randomness in generation.
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub temperature: Option<f64>,
22    /// optional - Denotes the maximum number of generations that will be returned. Defaults to 1,
23    /// max value of 5.
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub num_generations: Option<u8>,
26    /// optional - If set to a positive integer, it ensures only the top k most likely tokens are
27    /// considered for generation at each step. Defaults to 0 (disabled)
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub k: Option<u64>,
30    /// optional - If set to a probability 0.0 < p < 1.0, it ensures that only the most likely tokens,
31    /// with total probability mass of p, are considered for generation at each step. If both k and
32    /// p are enabled, p acts after k. Max value of 1.0. Defaults to 0.75.
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub p: Option<f64>,
35    /// optional - Can be used to reduce repetitiveness of generated tokens. The higher the value,
36    /// the stronger a penalty is applied to previously present tokens, proportional to how many
37    /// times they have already appeared in the prompt or prior generation. Max value of 1.0. Defaults to 0.0.
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub frequency_penalty: Option<f64>,
40    /// optional - Can be used to reduce repetitiveness of generated tokens. Similar to frequency_penalty,
41    /// except that this penalty is applied equally to all tokens that have already appeared, regardless
42    /// of their exact frequencies. Max value of 1.0. Defaults to 0.0.
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub presence_penalty: Option<f64>,
45    /// optional - The generated text will be cut at the beginning of the earliest occurrence of an end sequence.
46    /// The sequence will be excluded from the text.
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub end_sequences: Option<Vec<String>>,
49    /// optional - The generated text will be cut at the end of the earliest occurrence of a stop sequence.
50    /// The sequence will be included the text.
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub stop_sequences: Option<Vec<String>>,
53    /// optional - One of GENERATION|ALL|NONE to specify how and if the token likelihoods are returned with
54    /// the response. If GENERATION is selected, the token likelihoods will only be provided for generated
55    /// text. If ALL is selected, the token likelihoods will be provided both for the prompt and the generated
56    /// text.
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub return_likelihoods: Option<ReturnLikelihoods>,
59    /// optional - Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens
60    /// A map of tokens to biases where bias is a float between -10 and +10
61    /// Negative values will disincentivize that token from appearing while positives values will incentivize them
62    /// Tokens can be obtained from text using the tokenizer
63    /// Note: logit bias may not be supported for all finetune models
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub logit_bias: Option<HashMap<u64, f32>>,
66    /// optional - Specify how the API will handle inputs longer than the maximum token length.
67    /// Passing START will discard the start of the input. END will discard the end of the input.
68    /// In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
69    /// If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
70    pub truncate: Option<Truncate>,
71}
72
73#[derive(strum_macros::Display, Serialize, Debug)]
74pub enum ReturnLikelihoods {
75    #[strum(serialize = "GENERATION")]
76    #[serde(rename = "GENERATION")]
77    Generation,
78    #[strum(serialize = "ALL")]
79    #[serde(rename = "ALL")]
80    All,
81    #[strum(serialize = "NONE")]
82    #[serde(rename = "NONE")]
83    None,
84}
85
/// Wire-format wrapper for the `generate` endpoint response body; crate-internal
/// because callers are expected to consume the inner `Vec<Generation>` directly.
#[derive(Deserialize, Debug)]
pub(crate) struct GenerateResponse {
    /// Contains the generations.
    pub generations: Vec<Generation>,
}
91
/// A single completion returned by the `generate` endpoint.
#[derive(Deserialize, Debug)]
pub struct Generation {
    /// Contains the generated text.
    pub text: String,
    /// The sum of the log-likelihood of each token in the string.
    /// Falls back to `0.0` when the API omits the field (e.g. when
    /// `return_likelihoods` was NONE or unset), via `#[serde(default)]`.
    #[serde(default)]
    pub likelihood: f64,
    /// Only returned if `return_likelihoods` is not set to NONE.
    /// Per-token likelihood details; empty when the API omits the field,
    /// via `#[serde(default)]`.
    #[serde(default)]
    pub token_likelihoods: Vec<TokenLikelihood>,
}
104
/// Likelihood information for one token of a generation.
#[derive(Deserialize, Debug)]
pub struct TokenLikelihood {
    /// The token.
    pub token: String,
    /// Refers to the log-likelihood of the token. The first token of a context will not
    /// have a likelihood.
    // NOTE(review): deserialized as a String rather than a float — presumably the API
    // encodes it as a string (or it can be absent/empty for the first token).
    // Confirm against the Cohere generate response schema before changing the type.
    pub likelihood: String,
}