groq_api_rust/message.rs
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;

#[derive(Error, Debug)]
/// Represents errors that can occur when interacting with the Groq API.
///
/// - `RequestFailed`: Indicates a failure in the underlying HTTP request.
/// - `JsonParseError`: Indicates a failure in parsing the JSON response from the API.
/// - `ApiError`: Indicates an error returned by the API, with a message and error type.
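///
/// # Example
///
/// A minimal sketch of handling each variant; `client.chat_completion` is a
/// placeholder for whatever call in this crate returns `Result<_, GroqError>`:
///
/// ```ignore
/// match client.chat_completion(request).await {
///     Ok(response) => println!("{}", response.choices[0].message.content),
///     Err(GroqError::RequestFailed(e)) => eprintln!("transport error: {e}"),
///     Err(GroqError::JsonParseError(e)) => eprintln!("invalid JSON: {e}"),
///     Err(GroqError::ApiError { message, type_ }) => {
///         eprintln!("API error ({type_}): {message}");
///     }
/// }
/// ```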
pub enum GroqError {
    #[error("API request failed: {0}")]
    RequestFailed(#[from] reqwest::Error),
    #[error("Failed to parse JSON: {0}")]
    JsonParseError(#[from] serde_json::Error),
    #[error("API error: {message}")]
    ApiError { message: String, type_: String },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
/// Represents the different roles that can be used in a chat completion message.
///
/// - `System`: Indicates a message from the system.
/// - `User`: Indicates a message from the user.
/// - `Assistant`: Indicates a message from the assistant.
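///
/// # Example
///
/// Because of `rename_all = "lowercase"`, the variants serialize to the
/// lowercase strings the API expects:
///
/// ```ignore
/// let role = serde_json::to_string(&ChatCompletionRoles::User)?;
/// assert_eq!(role, "\"user\"");
/// ```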
pub enum ChatCompletionRoles {
    System,
    User,
    Assistant,
}

#[derive(Debug, Clone, Serialize)]
/// Represents a message sent as part of a chat completion request.
///
/// - `role`: The role of the message, such as `System`, `User`, or `Assistant`.
/// - `content`: The content of the message.
/// - `name`: An optional name associated with the message.
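///
/// # Example
///
/// Constructing a user message by hand:
///
/// ```ignore
/// let message = ChatCompletionMessage {
///     role: ChatCompletionRoles::User,
///     content: "What is the capital of France?".to_string(),
///     name: None,
/// };
/// ```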
pub struct ChatCompletionMessage {
    pub role: ChatCompletionRoles,
    pub content: String,
    pub name: Option<String>,
}

#[derive(Debug, Clone, Deserialize)]
/// Represents the response from a chat completion API request.
///
/// - `choices`: A vector of `Choice` objects, each representing a possible response.
/// - `created`: The timestamp (in seconds since the Unix epoch) when the response was generated.
/// - `id`: The unique identifier for the response.
/// - `model`: The name of the model used to generate the response.
/// - `object`: The type of the response object.
/// - `system_fingerprint`: An identifier for the backend configuration that generated the response.
/// - `usage`: Usage statistics for the request, including token counts and processing times.
/// - `x_groq`: Additional Groq-specific metadata, including the request ID.
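///
/// # Example
///
/// Pulling the generated text out of a deserialized response:
///
/// ```ignore
/// let reply = response
///     .choices
///     .first()
///     .map(|choice| choice.message.content.as_str())
///     .unwrap_or_default();
/// println!("{reply} ({} total tokens)", response.usage.total_tokens);
/// ```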
pub struct ChatCompletionResponse {
    pub choices: Vec<Choice>,
    pub created: u64,
    pub id: String,
    pub model: String,
    pub object: String,
    pub system_fingerprint: String,
    pub usage: Usage,
    pub x_groq: XGroq,
}

#[derive(Debug, Clone, Deserialize)]
/// Represents a single choice in a chat completion response.
///
/// - `finish_reason`: The reason the generation finished, such as "stop" or "length".
/// - `index`: The index of the choice within the list of choices.
/// - `logprobs`: Optional log probabilities for the tokens in the generated text.
/// - `message`: The message associated with this choice, containing the role and content.
pub struct Choice {
    pub finish_reason: String,
    pub index: u64,
    pub logprobs: Option<Value>,
    pub message: Message,
}

#[derive(Debug, Clone, Deserialize)]
/// Represents a message returned in a chat completion response.
///
/// - `content`: The content of the message.
/// - `role`: The role of the message, such as `System`, `User`, or `Assistant`.
pub struct Message {
    pub content: String,
    pub role: ChatCompletionRoles,
}

#[derive(Debug, Clone, Deserialize)]
/// Represents usage statistics for a chat completion request, including token counts and processing times.
///
/// - `completion_time`: The time (in seconds) it took to generate the completion.
/// - `completion_tokens`: The number of tokens in the generated completion.
/// - `prompt_time`: The time (in seconds) it took to process the prompt.
/// - `prompt_tokens`: The number of tokens in the prompt.
/// - `total_time`: The total time (in seconds) for the entire request.
/// - `total_tokens`: The total number of tokens used in the request.
pub struct Usage {
    pub completion_time: f64,
    pub completion_tokens: u64,
    pub prompt_time: f64,
    pub prompt_tokens: u64,
    pub total_time: f64,
    pub total_tokens: u64,
}

#[derive(Debug, Clone, Deserialize)]
/// Represents Groq-specific metadata attached to a chat completion response.
///
/// - `id`: The Groq request identifier.
pub struct XGroq {
    pub id: String,
}

#[derive(Debug, Clone)]
/// Represents a request to the speech-to-text API.
///
/// - `file`: The audio file to be transcribed.
/// - `model`: The speech recognition model to use.
/// - `temperature`: The temperature parameter to control the randomness of the transcription.
/// - `language`: The language of the audio file.
/// - `english_text`: If true, the API will use the translation endpoint instead of the transcription endpoint.
/// - `prompt`: An optional prompt to provide context for the transcription.
/// - `response_format`: The desired format of the transcription response, either "text" or "json".
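///
/// # Example
///
/// Building a transcription request with the builder-style setters defined
/// below (a sketch; the model name is an assumption):
///
/// ```ignore
/// let audio = std::fs::read("sample.wav")?;
/// let request = SpeechToTextRequest::new(audio)
///     .model("whisper-large-v3")
///     .language("en")
///     .response_format("json");
/// ```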
pub struct SpeechToTextRequest {
    pub file: Vec<u8>,
    pub model: Option<String>,
    pub temperature: Option<f64>,
    pub language: Option<String>,
    /// If true, the API will use the path `/audio/translations` instead of `/audio/transcriptions`.
    pub english_text: bool,
    pub prompt: Option<String>,
    pub response_format: Option<String>,
}

impl SpeechToTextRequest {
    /// Constructs a new `SpeechToTextRequest` with the given audio file.
    ///
    /// # Arguments
    /// * `file` - The audio file to be transcribed.
    ///
    /// # Returns
    /// A new `SpeechToTextRequest` instance with the given audio file and default values for the other fields.
    pub fn new(file: Vec<u8>) -> Self {
        Self {
            file,
            model: None,
            temperature: None,
            language: None,
            english_text: false,
            prompt: None,
            response_format: None,
        }
    }

    /// Sets the temperature parameter for the speech recognition model.
    ///
    /// # Arguments
    /// * `temperature` - The temperature parameter to control the randomness of the transcription.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated temperature.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the language of the audio file.
    ///
    /// # Arguments
    /// * `language` - The language of the audio file.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated language.
    pub fn language(mut self, language: &str) -> Self {
        self.language = Some(language.to_string());
        self
    }

    /// Sets whether the API should use the translation endpoint instead of the transcription endpoint.
    ///
    /// # Arguments
    /// * `english_text` - If true, the API will use the translation endpoint.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated `english_text` flag.
    pub fn english_text(mut self, english_text: bool) -> Self {
        self.english_text = english_text;
        self
    }

    /// Sets the speech recognition model to use.
    ///
    /// # Arguments
    /// * `model` - The speech recognition model to use.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated model.
    pub fn model(mut self, model: &str) -> Self {
        self.model = Some(model.to_string());
        self
    }

    /// Sets the prompt to provide context for the transcription.
    ///
    /// # Arguments
    /// * `prompt` - The prompt to provide context for the transcription.
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated prompt.
    pub fn prompt(mut self, prompt: &str) -> Self {
        self.prompt = Some(prompt.to_string());
        self
    }

    /// Sets the desired format of the transcription response.
    ///
    /// # Arguments
    /// * `response_format` - The desired format of the transcription response, either "text" or "json".
    ///
    /// # Returns
    /// The modified `SpeechToTextRequest` instance with the updated response format.
    pub fn response_format(mut self, response_format: &str) -> Self {
        // Currently only "text" and "json" are supported.
        self.response_format = Some(response_format.to_string());
        self
    }
}

#[derive(Debug, Clone, Deserialize)]
/// Represents the response from a speech-to-text transcription request.
///
/// The `text` field contains the transcribed text from the audio input.
pub struct SpeechToTextResponse {
    pub text: String,
}

/// Represents a request to the Groq chat completion API.
///
/// - `model`: The language model to use for the chat completion.
/// - `messages`: The messages to provide as context for the chat completion.
/// - `temperature`: The temperature parameter to control the randomness of the generated response.
/// - `max_tokens`: The maximum number of tokens to generate in the response.
/// - `top_p`: The top-p parameter to control nucleus sampling.
/// - `stream`: Whether to stream the response or return it all at once.
/// - `stop`: A list of strings that stop the generation when encountered.
/// - `seed`: The seed value to use for the random number generator.
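///
/// # Example
///
/// Building a request with the builder-style setters defined below (a sketch;
/// the model name is an assumption):
///
/// ```ignore
/// let messages = vec![ChatCompletionMessage {
///     role: ChatCompletionRoles::User,
///     content: "Hello!".to_string(),
///     name: None,
/// }];
/// let request = ChatCompletionRequest::new("llama3-8b-8192", messages)
///     .temperature(0.7)
///     .max_tokens(256);
/// ```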
#[derive(Debug, Clone)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatCompletionMessage>,
    pub temperature: Option<f64>,
    pub max_tokens: Option<u32>,
    pub top_p: Option<f64>,
    pub stream: Option<bool>,
    pub stop: Option<Vec<String>>,
    pub seed: Option<u64>,
}

/// Builder-style constructors and setters for `ChatCompletionRequest`.
///
/// The `new` method creates an instance with default values, and the remaining
/// methods modify individual parameters, returning `self` so calls can be
/// chained.
impl ChatCompletionRequest {
    /// Creates a new `ChatCompletionRequest` instance with the given model and messages,
    /// using the defaults `temperature: 1.0`, `max_tokens: 1024`, `top_p: 1.0`, and no streaming.
    ///
    /// # Arguments
    ///
    /// * `model` - The language model to use for the chat completion.
    /// * `messages` - The messages to provide as context for the chat completion.
    pub fn new(model: &str, messages: Vec<ChatCompletionMessage>) -> Self {
        ChatCompletionRequest {
            model: model.to_string(),
            messages,
            temperature: Some(1.0),
            max_tokens: Some(1024),
            top_p: Some(1.0),
            stream: Some(false),
            stop: None,
            seed: None,
        }
    }

    /// Sets the temperature parameter for the chat completion request.
    ///
    /// The temperature parameter controls the randomness of the generated response.
    /// Higher values make the output more random, while lower values make it more deterministic.
    ///
    /// # Arguments
    ///
    /// * `temperature` - The temperature value to use.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of tokens to generate in the response.
    ///
    /// # Arguments
    ///
    /// * `max_tokens` - The maximum number of tokens to generate.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the top-p parameter for the chat completion request.
    ///
    /// The top-p parameter controls nucleus sampling: generation samples only from the
    /// smallest set of tokens whose cumulative probability exceeds `top_p`.
    ///
    /// # Arguments
    ///
    /// * `top_p` - The top-p value to use.
    pub fn top_p(mut self, top_p: f64) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets whether to stream the response or return it all at once.
    ///
    /// # Arguments
    ///
    /// * `stream` - Whether to stream the response or not.
    pub fn stream(mut self, stream: bool) -> Self {
        self.stream = Some(stream);
        self
    }

    /// Sets the list of strings that stop the generation when encountered.
    ///
    /// # Arguments
    ///
    /// * `stop` - The list of stop strings.
    pub fn stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }

    /// Sets the seed value to use for the random number generator.
    ///
    /// # Arguments
    ///
    /// * `seed` - The seed value to use.
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
}