pharia_skill/csi/
inference.rs

1use serde::{Deserialize, Serialize};
2
3/// The reason that the model stopped completing text
4#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq)]
5#[serde(rename_all = "snake_case")]
6pub enum FinishReason {
7    /// The model hit a natural stopping point or a provided stop sequence
8    Stop,
9    /// The maximum number of tokens specified in the request was reached
10    Length,
11    /// Content was omitted due to a flag from content filters
12    ContentFilter,
13}
14
/// Controls which log probabilities are returned alongside a completion
/// or chat response.
#[derive(Clone, Debug, Default, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Logprobs {
    /// Do not return any log probabilities (the default).
    #[default]
    No,
    /// Return the log probability of each sampled token.
    Sampled,
    /// Return the log probability of each sampled token plus the given
    /// number of most likely alternatives at every position.
    Top(u8),
}
23
/// The log probability of a single token.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct Logprob {
    // Raw token bytes rather than a String — presumably because a single
    // token need not be valid UTF-8 on its own; TODO confirm with the API.
    pub token: Vec<u8>,
    // Log probability assigned to this token by the model.
    pub logprob: f64,
}
29
/// Log probabilities for one position of the generated text.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct Distribution {
    /// The log probability of the token that was actually sampled.
    pub sampled: Logprob,
    /// Log probabilities of the most likely tokens at this position.
    /// NOTE(review): presumably empty unless `Logprobs::Top` was requested —
    /// verify against the backing API.
    pub top: Vec<Logprob>,
}
35
/// Token counts reported for a single inference request.
#[derive(Clone, Debug, Deserialize, PartialEq, Eq)]
pub struct TokenUsage {
    /// Number of tokens in the prompt (input).
    pub prompt: u32,
    /// Number of tokens generated in the completion (output).
    pub completion: u32,
}
41
/// Completion request parameters
///
/// Use [`CompletionParams::default`] for sensible defaults; note that
/// `return_special_tokens` defaults to `true` there.
#[derive(Clone, Debug, Serialize)]
pub struct CompletionParams {
    /// The maximum tokens that should be inferred.
    ///
    /// Note: the backing implementation may return less tokens due to
    /// other stop reasons.
    pub max_tokens: Option<u32>,
    /// The randomness with which the next token is selected.
    pub temperature: Option<f64>,
    /// The number of possible next tokens the model will choose from.
    pub top_k: Option<u32>,
    /// The probability total of next tokens the model will choose from.
    pub top_p: Option<f64>,
    /// A list of sequences that, if encountered, the API will stop generating further tokens.
    pub stop: Vec<String>,
    /// Whether to include special tokens like `<|eot_id|>` in the completion
    pub return_special_tokens: bool,
    /// When specified, this number will decrease (or increase) the probability of repeating
    /// tokens that were mentioned prior in the completion. The penalty is cumulative. The more
    /// a token is mentioned in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the probability of generating tokens that are already
    /// present in the generated text respectively prompt. Presence penalty is independent of the
    /// number of occurrences. Increase the value to reduce the probability of repeating text.
    pub presence_penalty: Option<f64>,
    /// Use this to control the logarithmic probabilities you want to have returned. This is useful
    /// to figure out how likely it had been that this specific token had been sampled.
    pub logprobs: Logprobs,
}
73
74impl Default for CompletionParams {
75    fn default() -> Self {
76        Self {
77            return_special_tokens: true,
78            max_tokens: None,
79            temperature: None,
80            top_k: None,
81            top_p: None,
82            stop: Vec::new(),
83            frequency_penalty: None,
84            presence_penalty: None,
85            logprobs: Logprobs::default(),
86        }
87    }
88}
89
/// Parameters required to make a completion request.
#[derive(Clone, Debug, Serialize)]
pub struct CompletionRequest {
    /// The model to generate a completion from.
    pub model: String,
    /// The text to prompt the model with.
    pub prompt: String,
    /// Parameters to adjust the sampling behavior of the model.
    pub params: CompletionParams,
}
100
101impl CompletionRequest {
102    pub fn new(model: impl Into<String>, prompt: impl Into<String>) -> Self {
103        Self {
104            model: model.into(),
105            prompt: prompt.into(),
106            params: CompletionParams::default(),
107        }
108    }
109
110    #[must_use]
111    pub fn with_params(mut self, params: CompletionParams) -> Self {
112        self.params = params;
113        self
114    }
115}
116
/// The result of a completion, including the text generated as well as
/// why the model finished completing.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct Completion {
    /// The text generated by the model
    pub text: String,
    /// The reason the model finished generating
    pub finish_reason: FinishReason,
    /// Contains the logprobs for the sampled and top n tokens, given that
    /// `completion-request.params.logprobs` has been set to `sampled` or `top`.
    pub logprobs: Vec<Distribution>,
    /// Usage statistics for the completion request.
    pub usage: TokenUsage,
}
131
/// A single message in a chat conversation.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message, e.g. `"user"`, `"assistant"` or `"system"`
    /// (see the constructors below).
    pub role: String,
    /// The text content of the message.
    pub content: String,
}
137
138impl Message {
139    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
140        Self {
141            role: role.into(),
142            content: content.into(),
143        }
144    }
145
146    pub fn user(content: impl Into<String>) -> Self {
147        Self::new("user", content)
148    }
149
150    pub fn assistant(content: impl Into<String>) -> Self {
151        Self::new("assistant", content)
152    }
153
154    pub fn system(content: impl Into<String>) -> Self {
155        Self::new("system", content)
156    }
157}
158
/// Chat request parameters.
///
/// Mirrors [`CompletionParams`] minus `top_k`, `stop` and
/// `return_special_tokens`; all fields default to unset via `Default`.
#[derive(Clone, Debug, Default, Serialize)]
pub struct ChatParams {
    /// The maximum tokens that should be inferred.
    ///
    /// Note: the backing implementation may return less tokens due to
    /// other stop reasons.
    pub max_tokens: Option<u32>,
    /// The randomness with which the next token is selected.
    pub temperature: Option<f64>,
    /// The probability total of next tokens the model will choose from.
    pub top_p: Option<f64>,
    /// When specified, this number will decrease (or increase) the probability of repeating
    /// tokens that were mentioned prior in the completion. The penalty is cumulative. The more
    /// a token is mentioned in the completion, the more its probability will decrease.
    /// A negative value will increase the likelihood of repeating tokens.
    pub frequency_penalty: Option<f64>,
    /// The presence penalty reduces the probability of generating tokens that are already
    /// present in the generated text respectively prompt. Presence penalty is independent of the
    /// number of occurrences. Increase the value to reduce the probability of repeating text.
    pub presence_penalty: Option<f64>,
    /// Use this to control the logarithmic probabilities you want to have returned. This is useful
    /// to figure out how likely it had been that this specific token had been sampled.
    pub logprobs: Logprobs,
}
183
/// Parameters required to make a chat request.
#[derive(Clone, Debug, Serialize)]
pub struct ChatRequest {
    /// The model to chat with.
    pub model: String,
    /// The conversation so far, in order.
    pub messages: Vec<Message>,
    /// Parameters to adjust the sampling behavior of the model.
    pub params: ChatParams,
}
190
191impl ChatRequest {
192    pub fn new(model: impl Into<String>, message: Message) -> Self {
193        Self {
194            model: model.into(),
195            messages: vec![message],
196            params: ChatParams::default(),
197        }
198    }
199
200    #[must_use]
201    pub fn and_message(mut self, message: Message) -> Self {
202        self.messages.push(message);
203        self
204    }
205
206    #[must_use]
207    pub fn with_params(mut self, params: ChatParams) -> Self {
208        self.params = params;
209        self
210    }
211}
212
/// The result of a chat request, including the message generated as well
/// as why the model finished completing.
#[derive(Clone, Debug, Deserialize)]
pub struct ChatResponse {
    /// The message generated by the model
    pub message: Message,
    /// The reason the model finished generating
    pub finish_reason: FinishReason,
    /// Contains the logprobs for the sampled and top n tokens, given that
    /// `chat-request.params.logprobs` has been set to `sampled` or `top`.
    pub logprobs: Vec<Distribution>,
    /// Usage statistics for the chat request.
    pub usage: TokenUsage,
}