groq_api_rust/
message.rs

use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;

/// Represents errors that can occur when interacting with the Groq API.
///
/// - `InvalidRequest`: Indicates that the request itself was malformed or invalid.
/// - `RequestFailed`: Indicates a failure in the underlying HTTP request.
/// - `JsonParseError`: Indicates a failure in parsing the JSON response from the API.
/// - `ApiError`: Indicates an error returned by the API, with a message and error type.
/// - `DeserializationError`: Indicates an error with deserialization, with a message and error type.
#[derive(Error, Debug)]
pub enum GroqError {
    #[error("Invalid request: {0}")]
    InvalidRequest(String),
    #[error("API request failed: {0}")]
    RequestFailed(#[from] reqwest::Error),
    #[error("Failed to parse JSON: {0}")]
    JsonParseError(#[from] serde_json::Error),
    #[error("API error: {message}")]
    ApiError { message: String, type_: String },
    #[error("Deserialization error: {message}")]
    DeserializationError { message: String, type_: String },
}
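
// A minimal error-handling sketch (hypothetical: the `Result` would come from a
// client call elsewhere in the crate, not from anything defined in this file):
//
// fn report(result: Result<ChatCompletionResponse, GroqError>) {
//     match result {
//         Ok(response) => println!("got {} choice(s)", response.choices.len()),
//         Err(GroqError::ApiError { message, type_ }) => {
//             eprintln!("API error ({}): {}", type_, message);
//         }
//         Err(other) => eprintln!("request failed: {}", other),
//     }
// }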

/// Represents the different roles that can be used in a chat completion message.
///
/// - `System`: Indicates a message from the system.
/// - `User`: Indicates a message from the user.
/// - `Assistant`: Indicates a message from the assistant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatCompletionRoles {
    System,
    User,
    Assistant,
}
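
// Because of `#[serde(rename_all = "lowercase")]`, the variants map to the
// lowercase strings used on the wire. A round-trip sketch:
//
// let json = serde_json::to_string(&ChatCompletionRoles::User).unwrap();
// assert_eq!(json, "\"user\"");
// let role: ChatCompletionRoles = serde_json::from_str("\"assistant\"").unwrap();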

/// Represents a message sent as part of a chat completion request.
///
/// - `role`: The role of the message, such as `System`, `User`, or `Assistant`.
/// - `content`: The content of the message.
/// - `name`: An optional name associated with the message.
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionMessage {
    pub role: ChatCompletionRoles,
    pub content: String,
    // Skipped when `None` so the serialized request does not contain `"name": null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
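
// Constructing a user message for a request; `name` is optional and usually omitted:
//
// let message = ChatCompletionMessage {
//     role: ChatCompletionRoles::User,
//     content: "Explain the borrow checker in one sentence.".to_string(),
//     name: None,
// };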

/// Represents the response from a chat completion API request.
///
/// - `choices`: A vector of `Choice` objects, each representing a possible response.
/// - `created`: The timestamp (in seconds since the epoch) when the response was generated.
/// - `id`: The unique identifier for the response.
/// - `model`: The name of the model used to generate the response.
/// - `object`: The type of the response object.
/// - `system_fingerprint`: A unique identifier for the system that generated the response.
/// - `usage`: Usage statistics for the request, including token counts and processing times.
/// - `x_groq`: Additional Groq-specific metadata about the response, including its ID.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionResponse {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: String,
    pub usage: Usage,
    pub x_groq: XGroq,
}
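
// Pulling the assistant's reply out of a parsed response (a sketch; returns
// `None` when the API sent back no choices):
//
// fn first_reply(response: &ChatCompletionResponse) -> Option<&str> {
//     response.choices.first().map(|choice| choice.message.content.as_str())
// }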

/// Represents a single choice in a chat completion response.
///
/// - `finish_reason`: The reason the generation finished, such as "stop" or "length".
/// - `index`: The index of the choice within the list of choices.
/// - `logprobs`: Optional log probabilities for the tokens in the generated text.
/// - `message`: The message associated with this choice, containing the role and content.
#[derive(Debug, Clone, Deserialize)]
pub struct Choice {
    pub finish_reason: String,
    pub index: u64,
    pub logprobs: Option<Value>,
    pub message: Message,
}

/// Represents a message in a chat completion response.
///
/// - `content`: The content of the message.
/// - `role`: The role of the message, such as `System`, `User`, or `Assistant`.
#[derive(Debug, Clone, Deserialize)]
pub struct Message {
    pub content: String,
    pub role: ChatCompletionRoles,
}

/// Represents one chunk of a streamed chat completion response.
///
/// - `id`: The unique identifier for the response.
/// - `object`: The type of the response object.
/// - `created`: The timestamp (in seconds since the epoch) when the response was generated.
/// - `model`: The name of the model used to generate the response.
/// - `system_fingerprint`: A unique identifier for the system that generated the response.
/// - `choices`: A vector of `ChoiceDelta` objects, each carrying an incremental piece of the response.
/// - `x_groq`: Optional Groq-specific metadata about the response.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionDeltaResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub system_fingerprint: String,
    pub choices: Vec<ChoiceDelta>,
    pub x_groq: Option<XGroq>,
}
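
// Each streamed chunk deserializes into a `ChatCompletionDeltaResponse`. A sketch
// with a hand-written payload (all field values are illustrative):
//
// let chunk: ChatCompletionDeltaResponse = serde_json::from_str(
//     r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,
//        "model":"example-model","system_fingerprint":"fp_abc",
//        "choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}"#,
// ).unwrap();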

/// Represents a single choice within a streamed chat completion chunk.
///
/// - `index`: The index of the choice within the list of choices.
/// - `delta`: The incremental piece of the message carried by this chunk.
/// - `logprobs`: Optional log probabilities for the tokens in the generated text.
/// - `finish_reason`: The reason the generation finished, such as "stop" or "length"; `None` until the final chunk.
#[derive(Debug, Clone, Deserialize)]
pub struct ChoiceDelta {
    pub index: u64,
    pub delta: Delta,
    pub logprobs: Option<Value>,
    pub finish_reason: Option<String>,
}

/// Represents the incremental message content carried by one streamed chunk.
///
/// - `role`: The role of the message, typically present only in the first chunk of a stream.
/// - `content`: The next fragment of the message content, if any.
#[derive(Debug, Clone, Deserialize)]
pub struct Delta {
    pub role: Option<ChatCompletionRoles>,
    pub content: Option<String>,
}
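
// Reassembling the full reply from a stream (a sketch; `chunks` stands in for
// whatever iterator of parsed chunks the caller obtains):
//
// fn collect_reply(chunks: impl Iterator<Item = ChatCompletionDeltaResponse>) -> String {
//     let mut reply = String::new();
//     for chunk in chunks {
//         for choice in &chunk.choices {
//             if let Some(fragment) = &choice.delta.content {
//                 reply.push_str(fragment);
//             }
//         }
//     }
//     reply
// }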

/// Represents usage statistics for a chat completion request, including token counts and processing times.
///
/// - `completion_time`: The time (in seconds) it took to generate the completion.
/// - `completion_tokens`: The number of tokens in the generated completion.
/// - `prompt_time`: The time (in seconds) it took to process the prompt.
/// - `prompt_tokens`: The number of tokens in the prompt.
/// - `total_time`: The total time (in seconds) for the entire request.
/// - `total_tokens`: The total number of tokens used in the request.
#[derive(Debug, Clone, Deserialize)]
pub struct Usage {
    pub completion_time: f64,
    pub completion_tokens: u64,
    pub prompt_time: f64,
    pub prompt_tokens: u64,
    pub total_time: f64,
    pub total_tokens: u64,
}
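
// The timing fields make simple throughput figures easy to derive. A sketch that
// guards against a zero completion time:
//
// fn tokens_per_second(usage: &Usage) -> Option<f64> {
//     (usage.completion_time > 0.0)
//         .then(|| usage.completion_tokens as f64 / usage.completion_time)
// }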

/// Represents the Groq-specific metadata attached to a response under the `x_groq` key.
///
/// - `id`: The unique identifier carried in this metadata.
#[derive(Debug, Clone, Deserialize)]
pub struct XGroq {
    pub id: String,
}

/// Represents a request to the speech-to-text API.
///
/// - `file`: The audio file to be transcribed.
/// - `model`: The speech recognition model to use.
/// - `temperature`: The temperature parameter to control the randomness of the transcription.
/// - `language`: The language of the audio file.
/// - `english_text`: If true, the API will use the translation endpoint instead of the transcription endpoint.
/// - `prompt`: An optional prompt to provide context for the transcription.
/// - `response_format`: The desired format of the transcription response, either "text" or "json".
#[derive(Debug, Clone)]
pub struct SpeechToTextRequest {
    pub file: Vec<u8>,
    pub model: Option<String>,
    pub temperature: Option<f64>,
    pub language: Option<String>,
    /// If true, the API will use the path `/audio/translations` instead of `/audio/transcriptions`.
    pub english_text: bool,
    pub prompt: Option<String>,
    pub response_format: Option<String>,
}

impl SpeechToTextRequest {
    /// Constructs a new `SpeechToTextRequest` with the given audio file.
    ///
    /// # Arguments
    /// * `file` - The audio file to be transcribed.
    ///
    /// # Returns
    /// A new `SpeechToTextRequest` instance with the given audio file and default values for the other fields.
    pub fn new(file: Vec<u8>) -> Self {
        Self {
            file,
            model: None,
            temperature: None,
            language: None,
            english_text: false,
            prompt: None,
            response_format: None,
        }
    }

    /// Sets the temperature parameter for the speech recognition model.
    ///
    /// # Arguments
    /// * `temperature` - The temperature parameter to control the randomness of the transcription.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated temperature.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the language of the audio file.
    ///
    /// # Arguments
    /// * `language` - The language of the audio file.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated language.
    pub fn language(mut self, language: &str) -> Self {
        self.language = Some(language.to_string());
        self
    }

    /// Sets whether the API should use the translation endpoint instead of the transcription endpoint.
    ///
    /// # Arguments
    /// * `english_text` - If true, the API will use the translation endpoint.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated `english_text` flag.
    pub fn english_text(mut self, english_text: bool) -> Self {
        self.english_text = english_text;
        self
    }

    /// Sets the speech recognition model to use.
    ///
    /// # Arguments
    /// * `model` - The speech recognition model to use.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated model.
    pub fn model(mut self, model: &str) -> Self {
        self.model = Some(model.to_string());
        self
    }

    /// Sets the prompt to provide context for the transcription.
    ///
    /// # Arguments
    /// * `prompt` - The prompt to provide context for the transcription.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated prompt.
    pub fn prompt(mut self, prompt: &str) -> Self {
        self.prompt = Some(prompt.to_string());
        self
    }

    /// Sets the desired format of the transcription response.
    ///
    /// # Arguments
    /// * `response_format` - The desired format of the transcription response, either "text" or "json".
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated response format.
    pub fn response_format(mut self, response_format: &str) -> Self {
        // Currently only "text" and "json" are supported.
        self.response_format = Some(response_format.to_string());
        self
    }
}
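
// A builder-style construction sketch (the file path and model name are
// illustrative; "whisper-large-v3" is assumed to be a valid Groq model):
//
// let audio: Vec<u8> = std::fs::read("sample.wav").unwrap();
// let request = SpeechToTextRequest::new(audio)
//     .model("whisper-large-v3")
//     .language("en")
//     .response_format("json");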

/// Represents the response from a speech-to-text transcription request.
///
/// The `text` field contains the transcribed text from the audio input.
#[derive(Debug, Clone, Deserialize)]
pub struct SpeechToTextResponse {
    pub text: String,
}

/// Represents a request to the Groq chat completion API.
///
/// - `model`: The language model to use for the chat completion.
/// - `messages`: The messages to provide as context for the chat completion.
/// - `temperature`: The temperature parameter to control the randomness of the generated response.
/// - `max_tokens`: The maximum number of tokens to generate in the response.
/// - `top_p`: The top-p parameter to control nucleus sampling.
/// - `stream`: Whether to stream the response or return it all at once.
/// - `stop`: A list of strings that stop the generation when encountered.
/// - `seed`: The seed value to use for the random number generator.
#[derive(Debug, Clone)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatCompletionMessage>,
    pub temperature: Option<f64>,
    pub max_tokens: Option<u32>,
    pub top_p: Option<f64>,
    pub stream: Option<bool>,
    pub stop: Option<Vec<String>>,
    pub seed: Option<u64>,
}

/// Builder-style methods for constructing a `ChatCompletionRequest`.
///
/// `new` creates an instance with sensible defaults, and the remaining methods
/// override individual parameters, consuming and returning the request so calls
/// can be chained.
impl ChatCompletionRequest {
    /// Creates a new `ChatCompletionRequest` instance with the given model and messages.
    ///
    /// Defaults: `temperature = 1.0`, `max_tokens = 1024`, `top_p = 1.0`, `stream = false`,
    /// and no `stop` sequences or `seed`.
    ///
    /// # Arguments
    ///
    /// * `model` - The language model to use for the chat completion.
    /// * `messages` - The messages to provide as context for the chat completion.
    pub fn new(model: &str, messages: Vec<ChatCompletionMessage>) -> Self {
        ChatCompletionRequest {
            model: model.to_string(),
            messages,
            temperature: Some(1.0),
            max_tokens: Some(1024),
            top_p: Some(1.0),
            stream: Some(false),
            stop: None,
            seed: None,
        }
    }

    /// Sets the temperature parameter for the chat completion request.
    ///
    /// The temperature parameter controls the randomness of the generated response.
    /// Higher values make the output more random, while lower values make it more deterministic.
    ///
    /// # Arguments
    ///
    /// * `temperature` - The temperature value to use.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of tokens to generate in the response.
    ///
    /// # Arguments
    ///
    /// * `max_tokens` - The maximum number of tokens to generate.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the top-p parameter for the chat completion request.
    ///
    /// The top-p parameter controls nucleus sampling, which restricts sampling to the
    /// smallest set of most likely tokens whose cumulative probability exceeds `top_p`.
    ///
    /// # Arguments
    ///
    /// * `top_p` - The top-p value to use.
    pub fn top_p(mut self, top_p: f64) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets whether to stream the response or return it all at once.
    ///
    /// # Arguments
    ///
    /// * `stream` - Whether to stream the response or not.
    pub fn stream(mut self, stream: bool) -> Self {
        self.stream = Some(stream);
        self
    }

    /// Sets the list of strings that stop the generation when encountered.
    ///
    /// # Arguments
    ///
    /// * `stop` - The list of stop strings.
    pub fn stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }

    /// Sets the seed value to use for the random number generator.
    ///
    /// # Arguments
    ///
    /// * `seed` - The seed value to use.
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
}
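
// A builder chaining sketch (the model name is illustrative, not verified):
//
// let messages = vec![ChatCompletionMessage {
//     role: ChatCompletionRoles::User,
//     content: "Summarize this file in one sentence.".to_string(),
//     name: None,
// }];
// let request = ChatCompletionRequest::new("llama-3.1-8b-instant", messages)
//     .temperature(0.2)
//     .max_tokens(256)
//     .stream(false);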