groq_api_rust/
message.rs

use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
#[derive(Error, Debug)]
/// Represents errors that can occur when interacting with the Groq API.
///
/// - `RequestFailed`: Indicates a failure in the underlying HTTP request.
/// - `JsonParseError`: Indicates a failure in parsing the JSON response from the API.
/// - `ApiError`: Indicates an error returned by the API, with a message and error type.
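///
/// # Example
///
/// A minimal sketch of handling each variant (the error messages and the
/// crate-root import path are assumptions for illustration):
///
/// ```
/// # use groq_api_rust::GroqError;
/// fn report(err: GroqError) {
///     match err {
///         // Transport-level failure from reqwest.
///         GroqError::RequestFailed(e) => eprintln!("request failed: {e}"),
///         // The body came back but was not valid JSON.
///         GroqError::JsonParseError(e) => eprintln!("malformed JSON: {e}"),
///         // The API itself reported an error.
///         GroqError::ApiError { message, type_ } => eprintln!("{type_}: {message}"),
///     }
/// }
/// ```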
pub enum GroqError {
    #[error("API request failed: {0}")]
    RequestFailed(#[from] reqwest::Error),
    #[error("Failed to parse JSON: {0}")]
    JsonParseError(#[from] serde_json::Error),
    #[error("API error: {message}")]
    ApiError { message: String, type_: String },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
/// Represents the different roles that can be used in a chat completion message.
///
/// - `System`: Indicates a message from the system.
/// - `User`: Indicates a message from the user.
/// - `Assistant`: Indicates a message from the assistant.
pub enum ChatCompletionRoles {
    System,
    User,
    Assistant,
}

#[derive(Debug, Clone, Serialize)]
/// Represents a message sent in a chat completion request.
///
/// - `role`: The role of the message, such as `System`, `User`, or `Assistant`.
/// - `content`: The content of the message.
/// - `name`: An optional name associated with the message.
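///
/// # Example
///
/// A minimal sketch of building a user message (the content string and the
/// crate-root import path are assumptions for illustration):
///
/// ```
/// # use groq_api_rust::{ChatCompletionMessage, ChatCompletionRoles};
/// let msg = ChatCompletionMessage {
///     role: ChatCompletionRoles::User,
///     content: "Hello!".to_string(),
///     name: None,
/// };
/// ```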
pub struct ChatCompletionMessage {
    pub role: ChatCompletionRoles,
    pub content: String,
    pub name: Option<String>,
}

#[derive(Debug, Clone, Deserialize)]
/// Represents the response from a chat completion API request.
///
/// - `choices`: A vector of `Choice` objects, each representing a possible response.
/// - `created`: The timestamp (in seconds since the Unix epoch) when the response was generated.
/// - `id`: The unique identifier for the response.
/// - `model`: The name of the model used to generate the response.
/// - `object`: The type of the response object.
/// - `system_fingerprint`: A unique identifier for the system that generated the response.
/// - `usage`: Usage statistics for the request, including token counts and processing times.
/// - `x_groq`: Additional metadata about the response, including the Groq request ID.
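///
/// # Example
///
/// A minimal sketch of extracting the reply text from a parsed response (the
/// helper name and crate-root import path are assumptions for illustration):
///
/// ```
/// # use groq_api_rust::ChatCompletionResponse;
/// fn first_reply(response: &ChatCompletionResponse) -> Option<&str> {
///     // The first choice carries the primary completion.
///     response.choices.first().map(|c| c.message.content.as_str())
/// }
/// ```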
pub struct ChatCompletionResponse {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: String,
    pub usage: Usage,
    pub x_groq: XGroq,
}

#[derive(Debug, Clone, Deserialize)]
/// Represents a single choice in a chat completion response.
///
/// - `finish_reason`: The reason the generation finished, such as "stop" or "length".
/// - `index`: The index of the choice within the list of choices.
/// - `logprobs`: Optional log probabilities for the tokens in the generated text.
/// - `message`: The message associated with this choice, containing the role and content.
pub struct Choice {
    pub finish_reason: String,
    pub index: u64,
    pub logprobs: Option<Value>,
    pub message: Message,
}

#[derive(Debug, Clone, Deserialize)]
/// Represents a message in a chat completion response.
///
/// - `content`: The content of the message.
/// - `role`: The role of the message, such as `System`, `User`, or `Assistant`.
pub struct Message {
    pub content: String,
    pub role: ChatCompletionRoles,
}

#[derive(Debug, Clone, Deserialize)]
/// Represents usage statistics for a chat completion request, including token counts and processing times.
///
/// - `completion_time`: The time (in seconds) it took to generate the completion.
/// - `completion_tokens`: The number of tokens in the generated completion.
/// - `prompt_time`: The time (in seconds) it took to process the prompt.
/// - `prompt_tokens`: The number of tokens in the prompt.
/// - `total_time`: The total time (in seconds) for the entire request.
/// - `total_tokens`: The total number of tokens used in the request.
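///
/// # Example
///
/// A minimal sketch of deriving generation throughput from these counters (the
/// helper name and crate-root import path are assumptions for illustration):
///
/// ```
/// # use groq_api_rust::Usage;
/// fn tokens_per_second(usage: &Usage) -> f64 {
///     // Guard against division by zero for instantaneous completions.
///     usage.completion_tokens as f64 / usage.completion_time.max(f64::EPSILON)
/// }
/// ```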
pub struct Usage {
    pub completion_time: f64,
    pub completion_tokens: u64,
    pub prompt_time: f64,
    pub prompt_tokens: u64,
    pub total_time: f64,
    pub total_tokens: u64,
}

#[derive(Debug, Clone, Deserialize)]
/// Represents Groq-specific metadata attached to a response.
///
/// - `id`: The unique identifier Groq assigns to the request.
pub struct XGroq {
    pub id: String,
}

#[derive(Debug, Clone)]
/// Represents a request to the speech-to-text API.
///
/// - `file`: The audio file to be transcribed.
/// - `model`: The speech recognition model to use.
/// - `temperature`: The temperature parameter to control the randomness of the transcription.
/// - `language`: The language of the audio file.
/// - `english_text`: If true, the API will use the translation endpoint instead of the transcription endpoint.
/// - `prompt`: An optional prompt to provide context for the transcription.
/// - `response_format`: The desired format of the transcription response, either "text" or "json".
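///
/// # Example
///
/// A minimal sketch of the builder chain (the file path, model name, and
/// crate-root import path are assumptions for illustration):
///
/// ```no_run
/// # use groq_api_rust::SpeechToTextRequest;
/// let audio = std::fs::read("sample.wav").expect("audio file");
/// let request = SpeechToTextRequest::new(audio)
///     .model("whisper-large-v3")
///     .language("en")
///     .response_format("json");
/// ```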
pub struct SpeechToTextRequest {
    pub file: Vec<u8>,
    pub model: Option<String>,
    pub temperature: Option<f64>,
    pub language: Option<String>,
    /// If true, the API will use the following path: `/audio/translations` instead of `/audio/transcriptions`.
    pub english_text: bool,
    pub prompt: Option<String>,
    pub response_format: Option<String>,
}

impl SpeechToTextRequest {
    /// Constructs a new `SpeechToTextRequest` with the given audio file.
    ///
    /// # Arguments
    /// * `file` - The audio file to be transcribed.
    ///
    /// # Returns
    /// A new `SpeechToTextRequest` instance with the given audio file and default values for the remaining fields.
    pub fn new(file: Vec<u8>) -> Self {
        Self {
            file,
            model: None,
            temperature: None,
            language: None,
            english_text: false,
            prompt: None,
            response_format: None,
        }
    }

    /// Sets the temperature parameter for the speech recognition model.
    ///
    /// # Arguments
    /// * `temperature` - The temperature parameter to control the randomness of the transcription.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated temperature.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the language of the audio file.
    ///
    /// # Arguments
    /// * `language` - The language of the audio file.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated language.
    pub fn language(mut self, language: &str) -> Self {
        self.language = Some(language.to_string());
        self
    }

    /// Sets whether the API should use the translation endpoint instead of the transcription endpoint.
    ///
    /// # Arguments
    /// * `english_text` - If true, the API will use the translation endpoint.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated `english_text` flag.
    pub fn english_text(mut self, english_text: bool) -> Self {
        self.english_text = english_text;
        self
    }

    /// Sets the speech recognition model to use.
    ///
    /// # Arguments
    /// * `model` - The speech recognition model to use.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated model.
    pub fn model(mut self, model: &str) -> Self {
        self.model = Some(model.to_string());
        self
    }

    /// Sets the prompt to provide context for the transcription.
    ///
    /// # Arguments
    /// * `prompt` - The prompt to provide context for the transcription.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated prompt.
    pub fn prompt(mut self, prompt: &str) -> Self {
        self.prompt = Some(prompt.to_string());
        self
    }

    /// Sets the desired format of the transcription response.
    ///
    /// # Arguments
    /// * `response_format` - The desired format of the transcription response, either "text" or "json".
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated response format.
    pub fn response_format(mut self, response_format: &str) -> Self {
        // Currently only "text" and "json" are supported.
        self.response_format = Some(response_format.to_string());
        self
    }
}

#[derive(Debug, Clone, Deserialize)]
/// Represents the response from a speech-to-text transcription request.
///
/// The `text` field contains the transcribed text from the audio input.
pub struct SpeechToTextResponse {
    pub text: String,
}

/// Represents a request to the Groq chat completion API (OpenAI-compatible).
///
/// - `model`: The language model to use for the chat completion.
/// - `messages`: The messages to provide as context for the chat completion.
/// - `temperature`: The temperature parameter to control the randomness of the generated response.
/// - `max_tokens`: The maximum number of tokens to generate in the response.
/// - `top_p`: The top-p parameter to control nucleus sampling.
/// - `stream`: Whether to stream the response or return it all at once.
/// - `stop`: A list of strings that stop generation when encountered.
/// - `seed`: The seed value to use for the random number generator.
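///
/// # Example
///
/// A minimal sketch of constructing and tuning a request (the model name,
/// message content, and crate-root import path are assumptions for
/// illustration):
///
/// ```
/// # use groq_api_rust::{ChatCompletionMessage, ChatCompletionRequest, ChatCompletionRoles};
/// let messages = vec![ChatCompletionMessage {
///     role: ChatCompletionRoles::User,
///     content: "Explain Rust lifetimes briefly.".to_string(),
///     name: None,
/// }];
/// let request = ChatCompletionRequest::new("llama3-8b-8192", messages)
///     .temperature(0.7)
///     .max_tokens(256);
/// ```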
#[derive(Debug, Clone)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatCompletionMessage>,
    pub temperature: Option<f64>,
    pub max_tokens: Option<u32>,
    pub top_p: Option<f64>,
    pub stream: Option<bool>,
    pub stop: Option<Vec<String>>,
    pub seed: Option<u64>,
}

/// Builder-style constructors and setters for `ChatCompletionRequest`.
///
/// `new` creates an instance with sensible defaults; the remaining methods
/// override individual parameters and return `self` so calls can be chained.
impl ChatCompletionRequest {
    /// Creates a new `ChatCompletionRequest` instance with the given model and messages.
    ///
    /// Defaults: `temperature = 1.0`, `max_tokens = 1024`, `top_p = 1.0`,
    /// `stream = false`; `stop` and `seed` are unset.
    ///
    /// # Arguments
    ///
    /// * `model` - The language model to use for the chat completion.
    /// * `messages` - The messages to provide as context for the chat completion.
    pub fn new(model: &str, messages: Vec<ChatCompletionMessage>) -> Self {
        ChatCompletionRequest {
            model: model.to_string(),
            messages,
            temperature: Some(1.0),
            max_tokens: Some(1024),
            top_p: Some(1.0),
            stream: Some(false),
            stop: None,
            seed: None,
        }
    }

    /// Sets the temperature parameter for the chat completion request.
    ///
    /// The temperature parameter controls the randomness of the generated response.
    /// Higher values make the output more random, while lower values make it more deterministic.
    ///
    /// # Arguments
    ///
    /// * `temperature` - The temperature value to use.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of tokens to generate in the response.
    ///
    /// # Arguments
    ///
    /// * `max_tokens` - The maximum number of tokens to generate.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the top-p parameter for the chat completion request.
    ///
    /// The top-p parameter controls nucleus sampling: only the smallest set of
    /// tokens whose cumulative probability reaches `top_p` is considered.
    ///
    /// # Arguments
    ///
    /// * `top_p` - The top-p value to use.
    pub fn top_p(mut self, top_p: f64) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets whether to stream the response or return it all at once.
    ///
    /// # Arguments
    ///
    /// * `stream` - Whether to stream the response or not.
    pub fn stream(mut self, stream: bool) -> Self {
        self.stream = Some(stream);
        self
    }

    /// Sets the list of strings to stop the generation when encountered.
    ///
    /// # Arguments
    ///
    /// * `stop` - The list of stop strings.
    pub fn stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }

    /// Sets the seed value to use for the random number generator.
    ///
    /// # Arguments
    ///
    /// * `seed` - The seed value to use.
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
}