groq_api_rust/
lib.rs

1mod message;
2pub use message::*;
3use reqwest::{
4    blocking::multipart::{Form, Part},
5    blocking::{Client, Response},
6    multipart::{Form as AForm, Part as APart},
7    Client as AClient, Response as AResponse,
8};
9use serde_json::{json, Value};
10use std::sync::Arc;
11
12/// An asynchronous client for interacting with the Groq API.
13///
14/// # Parameters
15///
16/// - `api_key`: The API key for authenticating with the Groq API.
17/// - `endpoint`: The URL of the Groq API endpoint. If not provided, it defaults to <https://api.groq.com/openai/v1>.
18///
19/// # Returns
20///
21/// An instance of `AsyncGroqClient` configured with the provided API key and endpoint.
22///
23/// # Example
24///
25///```
26/// use groq_client::AsyncGroqClient;
27///
28/// let client = AsyncGroqClient::new("my_api_key".to_string(), None).await;
29///```
30pub struct AsyncGroqClient {
31    api_key: String,
32    client: Arc<AClient>,
33    endpoint: String,
34}
35
36impl AsyncGroqClient {
37    /// Creates a new `AsyncGroqClient`
38    pub async fn new(api_key: String, endpoint: Option<String>) -> Self {
39        let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
40        Self {
41            api_key,
42            client: Arc::new(AClient::new()),
43            endpoint: ep,
44        }
45    }
46
47    /// Sends a request to the Groq API with the provided JSON body and returns the parsed response.
48    ///
49    /// # Parameters
50    ///
51    /// - `body`: The JSON body to send in the request.
52    /// - `link`: The URL link to send the request to.
53    ///
54    /// # Returns
55    ///
56    /// The parsed JSON response from the Groq API.
57    async fn send_request(&self, body: Value, link: &str) -> Result<Value, GroqError> {
58        let res = self
59            .client
60            .post(link)
61            .header("Content-Type", "application/json")
62            .header("Authorization", &format!("Bearer {}", self.api_key))
63            .json(&body)
64            .send()
65            .await?;
66
67        self.parse_response(res).await
68    }
69
70    /// Sends a speech-to-text request to the Groq API and returns the parsed response.
71    ///
72    /// # Parameters
73    ///
74    /// - `request`: The `SpeechToTextRequest` containing the audio file, temperature, language, and other options.
75    ///
76    /// # Returns
77    ///
78    /// The parsed `SpeechToTextResponse` from the Groq API.
79    pub async fn speech_to_text(
80        &self,
81        request: SpeechToTextRequest,
82    ) -> Result<SpeechToTextResponse, GroqError> {
83        let file = request.file;
84        let temperature = request.temperature;
85        let language = request.language;
86        let english_text = request.english_text;
87        let model = request.model;
88
89        let mut form = AForm::new().part("file", APart::bytes(file).file_name("audio.wav"));
90        if let Some(temp) = temperature {
91            form = form.text("temperature", temp.to_string());
92        }
93        if let Some(lang) = language {
94            form = form.text("language", lang);
95        }
96
97        let link_addition = if english_text {
98            "/audio/translations"
99        } else {
100            "/audio/transcriptions"
101        };
102        if let Some(mdl) = model {
103            form = form.text("model", mdl);
104        }
105
106        let link = format!("{}{}", self.endpoint, link_addition);
107        let response = self
108            .client
109            .post(&link)
110            .header("Authorization", &format!("Bearer {}", self.api_key))
111            .multipart(form)
112            .send()
113            .await?;
114
115        let speech_to_text_response: SpeechToTextResponse = response.json().await?;
116        Ok(speech_to_text_response)
117    }
118
119    /// Sends a chat completion request to the Groq API and returns the parsed response.
120    ///
121    /// # Parameters
122    ///
123    /// - `request`: The `ChatCompletionRequest` containing the model, messages, temperature, max tokens, top-p, and other options.
124    ///
125    /// # Returns
126    ///
127    /// The parsed `ChatCompletionResponse` from the Groq API.
128    pub async fn chat_completion(
129        &self,
130        request: ChatCompletionRequest,
131    ) -> Result<ChatCompletionResponse, GroqError> {
132        let messages = request
133            .messages
134            .iter()
135            .map(|m| {
136                let mut msg_json = json!({
137                    "role": m.role,
138                    "content": m.content,
139                });
140                if let Some(name) = &m.name {
141                    msg_json["name"] = json!(name);
142                }
143                msg_json
144            })
145            .collect::<Vec<Value>>();
146
147        let mut body = json!({
148            "model": request.model,
149            "messages": messages,
150            "temperature": request.temperature.unwrap_or(1.0),
151            "max_tokens": request.max_tokens.unwrap_or(1024),
152            "top_p": request.top_p.unwrap_or(1.0),
153            "stream": request.stream.unwrap_or(false),
154        });
155
156        if let Some(stop) = &request.stop {
157            body["stop"] = json!(stop);
158        }
159        if let Some(seed) = &request.seed {
160            body["seed"] = json!(seed);
161        }
162
163        let response = self
164            .send_request(body, &format!("{}/chat/completions", self.endpoint))
165            .await?;
166        let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
167        Ok(chat_completion_response)
168    }
169
170    /// Parses the response from a Groq API request and returns the response body as a JSON value.
171    ///
172    /// # Parameters
173    ///
174    /// - `response`: The HTTP response from the Groq API request.
175    ///
176    /// # Returns
177    ///
178    /// The parsed JSON value from the response body, or a `GroqError` if the response was not successful.
179    async fn parse_response(&self, response: AResponse) -> Result<Value, GroqError> {
180        let status = response.status();
181        let body: Value = response.json().await?;
182
183        if !status.is_success() {
184            if let Some(error) = body.get("error") {
185                return Err(GroqError::ApiError {
186                    message: error["message"]
187                        .as_str()
188                        .unwrap_or("Unknown error")
189                        .to_string(),
190                    type_: error["type"]
191                        .as_str()
192                        .unwrap_or("unknown_error")
193                        .to_string(),
194                });
195            }
196        }
197
198        Ok(body)
199    }
200}
201
202/// An client for interacting with the Groq API.
203///
204/// # Parameters
205///
206/// - `api_key`: The API key for authenticating with the Groq API.
207/// - `endpoint`: The URL of the Groq API endpoint. If not provided, it defaults to <https://api.groq.com/openai/v1>.
208///
209/// # Returns
210///
211/// An instance of `GroqClient` configured with the provided API key and endpoint.
212///
213/// # Example
214///
215///```
216/// use groq_client::GroqClient;
217///
218/// let client = GroqClient::new("my_api_key".to_string(), None);
219///```
220pub struct GroqClient {
221    api_key: String,
222    client: Client,
223    endpoint: String,
224}
225
226impl GroqClient {
227    /// Constructs a new `GroqClient` instance with the provided API key and optional endpoint.
228    ///
229    /// # Parameters
230    ///
231    /// - `api_key`: The API key for authenticating with the Groq API.
232    /// - `endpoint`: The URL of the Groq API endpoint. If not provided, it defaults to <https://api.groq.com/openai/v1>.
233    ///
234    /// # Returns
235    ///
236    /// A new `GroqClient` instance configured with the provided API key and endpoint.
237    pub fn new(api_key: String, endpoint: Option<String>) -> Self {
238        let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
239        Self {
240            api_key,
241            client: Client::new(),
242            endpoint: ep,
243        }
244    }
245
246    /// Sends a request to the Groq API with the provided JSON body and returns the parsed response.
247    ///
248    /// # Parameters
249    ///
250    /// - `body`: The JSON body to send in the request.
251    /// - `link`: The URL link to send the request to.
252    ///
253    /// # Returns
254    ///
255    /// The parsed response from the Groq API as a `Value`.
256    ///
257    /// # Errors
258    ///
259    /// Returns a `GroqError` if there is an issue sending the request or parsing the response.
260    fn send_request(&self, body: Value, link: &str) -> Result<Value, GroqError> {
261        let res = self
262            .client
263            .post(link)
264            .header("Content-Type", "application/json")
265            .header("Authorization", &format!("Bearer {}", self.api_key))
266            .json(&body)
267            .send()?;
268
269        parse_response(res)
270    }
271
272    /// Sends a speech-to-text request to the Groq API and returns the parsed response.
273    ///
274    /// # Parameters
275    ///
276    /// - `request`: A `SpeechToTextRequest` containing the necessary parameters for the speech-to-text request.
277    ///
278    /// # Returns
279    ///
280    /// The parsed `SpeechToTextResponse` from the Groq API.
281    ///
282    /// # Errors
283    ///
284    /// Returns a `GroqError` if there is an issue sending the request or parsing the response.
285    pub fn speech_to_text(
286        &self,
287        request: SpeechToTextRequest,
288    ) -> Result<SpeechToTextResponse, GroqError> {
289        // Extract values from request
290        let file = request.file;
291        let temperature = request.temperature;
292        let language = request.language;
293        let english_text = request.english_text;
294        let model = request.model;
295        let prompt = request.prompt;
296        let mut form = Form::new().part("file", Part::bytes(file).file_name("audio.wav"));
297
298        if let Some(temp) = temperature {
299            form = form.text("temperature", temp.to_string());
300        }
301
302        if let Some(lang) = language {
303            form = form.text("language", lang);
304        }
305
306        let link_addition = if english_text {
307            "/audio/translations"
308        } else {
309            "/audio/transcriptions"
310        };
311
312        if let Some(mdl) = model {
313            form = form.text("model", mdl);
314        }
315        if let Some(prompt) = prompt {
316            form = form.text("prompt", prompt.to_string());
317        }
318
319        let link = format!("{}{}", self.endpoint, link_addition);
320        let response = self
321            .client
322            .post(link)
323            .header("Authorization", &format!("Bearer {}", self.api_key))
324            .multipart(form)
325            .send()?;
326
327        let speech_to_text_response: SpeechToTextResponse = response.json()?;
328        Ok(speech_to_text_response)
329    }
330
331    /// Sends a chat completion request to the GROQ API and returns the response.
332    ///
333    /// # Parameters
334    ///
335    /// - `request` - A `ChatCompletionRequest` containing the details of the chat completion request.
336    ///
337    /// # Errors
338    ///
339    /// Returns a `GroqError` if there is an issue sending the request or parsing the response.
340    pub fn chat_completion(
341        &self,
342        request: ChatCompletionRequest,
343    ) -> Result<ChatCompletionResponse, GroqError> {
344        let messages = request
345            .messages
346            .iter()
347            .map(|m| {
348                let mut msg_json = json!({
349                    "role": m.role,
350                    "content": m.content,
351                });
352                if let Some(name) = &m.name {
353                    msg_json["name"] = json!(name);
354                }
355                msg_json
356            })
357            .collect::<Vec<_>>();
358
359        let mut body = json!({
360            "model": request.model,
361            "messages": messages,
362            "temperature": request.temperature.unwrap_or(1.0),
363            "max_tokens": request.max_tokens.unwrap_or(1024),
364            "top_p": request.top_p.unwrap_or(1.0),
365            "stream": request.stream.unwrap_or(false),
366        });
367
368        if let Some(stop) = &request.stop {
369            body["stop"] = json!(stop);
370        }
371        if let Some(seed) = &request.seed {
372            body["seed"] = json!(seed);
373        }
374
375        let response = self.send_request(body, &format!("{}/chat/completions", self.endpoint))?;
376        let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
377        Ok(chat_completion_response)
378    }
379}
380
381/// Parses the response from a GROQ API request and returns the response body as a JSON value.
382///
383/// # Parameters
384///
385/// - `response` - The HTTP response from the GROQ API request.
386///
387/// # Errors
388///
389/// Returns a `GroqError` if the response status is not successful or if there is an error parsing the response body.
390///
391/// # Returns
392///
393/// The response body as a JSON value.
394fn parse_response(response: Response) -> Result<Value, GroqError> {
395    let status = response.status();
396    let body: Value = response.json()?;
397
398    if !status.is_success() {
399        if let Some(error) = body.get("error") {
400            return Err(GroqError::ApiError {
401                message: error["message"]
402                    .as_str()
403                    .unwrap_or("Unknown error")
404                    .to_string(),
405                type_: error["type"]
406                    .as_str()
407                    .unwrap_or("unknown_error")
408                    .to_string(),
409            });
410        }
411    }
412
413    Ok(body)
414}
415
#[cfg(test)]
mod tests {
    //! Live integration tests for the Groq clients.
    //!
    //! They call the real API, so they need network access and the `GROQ_API_KEY`
    //! environment variable. They also need the audio fixtures `onepiece_demo.mp4` and
    //! `save.ogg` in the working directory.
    use super::*;

    const CHAT_MODEL: &str = "llama3-70b-8192";
    const SPEECH_MODEL: &str = "whisper-large-v3";

    /// Reads the API key, failing with a message that says how to fix the setup.
    fn api_key() -> String {
        std::env::var("GROQ_API_KEY")
            .expect("GROQ_API_KEY must be set to run the Groq API integration tests")
    }

    /// Builds a one-message conversation from the user.
    fn user_message(content: &str) -> Vec<ChatCompletionMessage> {
        vec![ChatCompletionMessage {
            role: ChatCompletionRoles::User,
            content: content.to_string(),
            name: None,
        }]
    }

    /// Builds the transcription request shared by all speech-to-text tests.
    fn speech_request(audio: Vec<u8>) -> SpeechToTextRequest {
        SpeechToTextRequest::new(audio)
            .temperature(0.7)
            .language("en")
            .model(SPEECH_MODEL)
    }

    #[test]
    fn test_chat_completion() {
        let client = GroqClient::new(api_key(), None);
        let request = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello"));
        let response = client
            .chat_completion(request)
            .expect("Failed to get chat completion response");
        println!("{:?}", response);
        assert!(!response.choices.is_empty());
    }

    #[test]
    fn test_speech_to_text() {
        let client = GroqClient::new(api_key(), None);
        let audio_data = std::fs::read("onepiece_demo.mp4").expect("Failed to read audio file");
        let response = client
            .speech_to_text(speech_request(audio_data))
            .expect("Failed to get response");
        println!("Speech to Text Response: {}", response.text);
        assert!(!response.text.is_empty());
    }

    #[tokio::test]
    async fn test_async_chat_completion() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let request1 = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello"));
        let request2 = ChatCompletionRequest::new(CHAT_MODEL, user_message("How are you?"));

        let (response1, response2) = tokio::join!(
            client.chat_completion(request1),
            client.chat_completion(request2)
        );

        let response1 = response1.expect("Failed to get response for request 1");
        let response2 = response2.expect("Failed to get response for request 2");

        // Assert before indexing so an empty `choices` fails with a clear assertion rather than
        // an index-out-of-bounds panic.
        assert!(!response1.choices.is_empty());
        assert!(!response2.choices.is_empty());

        println!("Response 1: {}", response1.choices[0].message.content);
        println!("Response 2: {}", response2.choices[0].message.content);
    }

    #[tokio::test]
    async fn test_async_speech_to_text() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let (audio_data1, audio_data2) = tokio::join!(
            tokio::fs::read("onepiece_demo.mp4"),
            tokio::fs::read("save.ogg")
        );

        let audio_data1 = audio_data1.expect("Failed to read first audio file");
        let audio_data2 = audio_data2.expect("Failed to read second audio file");

        let (response1, response2) = tokio::join!(
            client.speech_to_text(speech_request(audio_data1)),
            client.speech_to_text(speech_request(audio_data2))
        );

        let response1 = response1.expect("Failed to get response for first audio");
        let response2 = response2.expect("Failed to get response for second audio");

        println!("Speech to Text Response 1: {:?}", response1);
        println!("Speech to Text Response 2: {:?}", response2);

        assert!(!response1.text.is_empty());
        assert!(!response2.text.is_empty());
    }
}