groq_api_rust/
lib.rs

1mod message;
2use futures::StreamExt;
3pub use message::*;
4use reqwest::{
5    Client as AClient, Response as AResponse,
6    blocking::multipart::{Form, Part},
7    blocking::{Client, Response},
8    multipart::{Form as AForm, Part as APart},
9};
10use serde_json::{Deserializer, StreamDeserializer, Value, json};
11use std::sync::Arc;
12
13/// An asynchronous client for interacting with the Groq API.
14///
15/// # Parameters
16///
17/// - `api_key`: The API key for authenticating with the Groq API.
18/// - `endpoint`: The URL of the Groq API endpoint. If not provided, it defaults to <https://api.groq.com/openai/v1>.
19///
20/// # Returns
21///
22/// An instance of `AsyncGroqClient` configured with the provided API key and endpoint.
23///
24/// # Example
25///
26///```
27/// use groq_client::AsyncGroqClient;
28///
29/// let client = AsyncGroqClient::new("my_api_key".to_string(), None).await;
30///```
31pub struct AsyncGroqClient {
32    api_key: String,
33    client: Arc<AClient>,
34    endpoint: String,
35}
36
37impl AsyncGroqClient {
38    /// Creates a new `AsyncGroqClient`
39    pub async fn new(api_key: String, endpoint: Option<String>) -> Self {
40        let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
41        Self {
42            api_key,
43            client: Arc::new(AClient::new()),
44            endpoint: ep,
45        }
46    }
47
48    /// Sends a request to the Groq API with the provided JSON body and returns the parsed response.
49    ///
50    /// # Parameters
51    ///
52    /// - `body`: The JSON body to send in the request.
53    /// - `link`: The URL link to send the request to.
54    ///
55    /// # Returns
56    ///
57    /// The parsed JSON response from the Groq API.
58    async fn send_request(&self, body: Value, link: &str) -> Result<reqwest::Response, GroqError> {
59        let res = self
60            .client
61            .post(link)
62            .header("Content-Type", "application/json")
63            .header("Authorization", &format!("Bearer {}", self.api_key))
64            .json(&body)
65            .send()
66            .await?;
67        Ok(res)
68    }
69
70    /// Sends a speech-to-text request to the Groq API and returns the parsed response.
71    ///
72    /// # Parameters
73    ///
74    /// - `request`: The `SpeechToTextRequest` containing the audio file, temperature, language, and other options.
75    ///
76    /// # Returns
77    ///
78    /// The parsed `SpeechToTextResponse` from the Groq API.
79    pub async fn speech_to_text(
80        &self,
81        request: SpeechToTextRequest,
82    ) -> Result<SpeechToTextResponse, GroqError> {
83        let file = request.file;
84        let temperature = request.temperature;
85        let language = request.language;
86        let english_text = request.english_text;
87        let model = request.model;
88
89        let mut form = AForm::new().part("file", APart::bytes(file).file_name("audio.wav"));
90        if let Some(temp) = temperature {
91            form = form.text("temperature", temp.to_string());
92        }
93        if let Some(lang) = language {
94            form = form.text("language", lang);
95        }
96
97        let link_addition = if english_text {
98            "/audio/translations"
99        } else {
100            "/audio/transcriptions"
101        };
102        if let Some(mdl) = model {
103            form = form.text("model", mdl);
104        }
105
106        let link = format!("{}{}", self.endpoint, link_addition);
107        let response = self
108            .client
109            .post(&link)
110            .header("Authorization", &format!("Bearer {}", self.api_key))
111            .multipart(form)
112            .send()
113            .await?;
114
115        let speech_to_text_response: SpeechToTextResponse = response.json().await?;
116        Ok(speech_to_text_response)
117    }
118
119    /// Internal function which sends a request to the Groq API and returns the raw response.
120    ///
121    /// # Parameters
122    ///
123    /// - `request`: The `ChatCompletionRequest` containing the model, messages, temperature, max tokens, top-p, and other options.
124    ///
125    /// # Returns
126    ///
127    /// The parsed `ChatCompletionResponse` from the Groq API.
128    async fn send_response(
129        &self,
130        request: ChatCompletionRequest,
131        stream: bool,
132    ) -> Result<reqwest::Response, GroqError> {
133        let messages = request
134            .messages
135            .iter()
136            .map(|m| {
137                let mut msg_json = json!({
138                    "role": m.role,
139                    "content": m.content,
140                });
141                if let Some(name) = &m.name {
142                    msg_json["name"] = json!(name);
143                }
144                msg_json
145            })
146            .collect::<Vec<Value>>();
147
148        let mut body = json!({
149            "model": request.model,
150            "messages": messages,
151            "temperature": request.temperature.unwrap_or(1.0),
152            "max_tokens": request.max_tokens.unwrap_or(1024),
153            "top_p": request.top_p.unwrap_or(1.0),
154            "stream": request.stream.unwrap_or(stream),
155        });
156
157        if let Some(stop) = &request.stop {
158            body["stop"] = json!(stop);
159        }
160        if let Some(seed) = &request.seed {
161            body["seed"] = json!(seed);
162        }
163
164        let response = self
165            .send_request(body, &format!("{}/chat/completions", self.endpoint))
166            .await?;
167        Ok(response)
168    }
169
170    /// Sends a chat completion request to the Groq API and returns the parsed response.
171    ///
172    /// # Parameters
173    ///
174    /// - `request`: The `ChatCompletionRequest` containing the model, messages, temperature, max tokens, top-p, and other options.
175    ///
176    /// # Returns
177    ///
178    /// The parsed `ChatCompletionResponse` from the Groq API.
179    pub async fn chat_completion(
180        &self,
181        request: ChatCompletionRequest,
182    ) -> Result<ChatCompletionResponse, GroqError> {
183        if Some(true) == request.stream {
184            return Err(GroqError::InvalidRequest(
185                "Stream parameter must be set to false for non-streaming responses.".to_string(),
186            ));
187        }
188        let response = self.send_response(request, false).await?;
189        let response = self.parse_response(response).await?;
190
191        let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
192        Ok(chat_completion_response)
193    }
194
195    /// Streams to the Groq API and returns a stream of responses.
196    ///
197    /// # Parameters
198    ///
199    /// - `request`: The `ChatCompletionRequest` containing the model, messages, temperature, max tokens, top-p, and other options.
200    ///
201    /// # Returns
202    ///
203    /// A stream of `ChatCompletionDeltaResponse` from the Groq API.
204    pub async fn stream(
205        &self,
206        request: ChatCompletionRequest,
207    ) -> Result<
208        impl futures::Stream<Item = Result<ChatCompletionDeltaResponse, GroqError>>,
209        GroqError,
210    > {
211        if Some(false) == request.stream {
212            return Err(GroqError::InvalidRequest(
213                "Stream parameter must be set to true for streaming responses.".to_string(),
214            ));
215        }
216        let response = self.send_response(request, true).await?;
217        let stream_response = response.bytes_stream();
218
219        // Returns a stream that outputs everything returned from groq
220        // resp_string is the current state of what has been returned
221        // stream_response is what groq has currently sent
222        let prefix = "data: ";
223        Ok(futures::stream::unfold(
224            (stream_response, String::new()),
225            move |(mut stream_response, mut resp_string)| async move {
226                loop {
227                    // Remove prefix if it exists
228                    resp_string = resp_string
229                        .strip_prefix(&prefix)
230                        .unwrap_or(&resp_string)
231                        .to_string();
232
233                    // Attempts to deserialize resp_string
234                    let mut stream: StreamDeserializer<_, ChatCompletionDeltaResponse> =
235                        Deserializer::from_slice(resp_string.as_bytes()).into_iter();
236
237                    if let Some(line) = stream.next() {
238                        // If resp_string has a valid ChatCompletionDeltaResponse, return it
239                        // If erroring, check that it does not equal [DONE]
240                        if let Ok(line) = line {
241                            let offset = stream.byte_offset();
242                            resp_string = resp_string[offset..].trim().to_string();
243                            return Some((Ok(line), (stream_response, resp_string)));
244                        } else if resp_string == "[DONE]" {
245                            return None;
246                        }
247                    }
248
249                    if let Some(chunk) = stream_response.next().await {
250                        // Get the next chunk from the groq stream,
251                        // append it to resp_string and continue the loop to try and deserialize
252                        if let Err(e) = chunk {
253                            return Some((Err(GroqError::from(e)), (stream_response, resp_string)));
254                        }
255                        let chunk = String::from_utf8_lossy(&chunk.unwrap()).trim().to_string();
256                        resp_string.push_str(&chunk);
257                        continue;
258                    } else if resp_string.is_empty() {
259                        return None;
260                    } else {
261                        // If the stream has ended, and resp_string is not empty/[DONE]
262                        // then parsing must have failed, and there must be a deserialization error
263                        return Some((
264                            Err(GroqError::DeserializationError {
265                                message: resp_string.clone(),
266                                type_: "DeserializationError".to_string(),
267                            }),
268                            (stream_response, resp_string),
269                        ));
270                    }
271                }
272            },
273        ))
274    }
275
276    /// Parses the response from a Groq API request and returns the response body as a JSON value.
277    ///
278    /// # Parameters
279    ///
280    /// - `response`: The HTTP response from the Groq API request.
281    ///
282    /// # Returns
283    ///
284    /// The parsed JSON value from the response body, or a `GroqError` if the response was not successful.
285    async fn parse_response(&self, response: AResponse) -> Result<Value, GroqError> {
286        let status = response.status();
287        let body: Value = response.json().await?;
288
289        if !status.is_success()
290            && let Some(error) = body.get("error")
291        {
292            return Err(GroqError::ApiError {
293                message: error["message"]
294                    .as_str()
295                    .unwrap_or("Unknown error")
296                    .to_string(),
297                type_: error["type"]
298                    .as_str()
299                    .unwrap_or("unknown_error")
300                    .to_string(),
301            });
302        }
303
304        Ok(body)
305    }
306}
307
308/// An client for interacting with the Groq API.
309///
310/// # Parameters
311///
312/// - `api_key`: The API key for authenticating with the Groq API.
313/// - `endpoint`: The URL of the Groq API endpoint. If not provided, it defaults to <https://api.groq.com/openai/v1>.
314///
315/// # Returns
316///
317/// An instance of `GroqClient` configured with the provided API key and endpoint.
318///
319/// # Example
320///
321///```
322/// use groq_client::GroqClient;
323///
324/// let client = GroqClient::new("my_api_key".to_string(), None);
325///```
326pub struct GroqClient {
327    api_key: String,
328    client: Client,
329    endpoint: String,
330}
331
332impl GroqClient {
333    /// Constructs a new `GroqClient` instance with the provided API key and optional endpoint.
334    ///
335    /// # Parameters
336    ///
337    /// - `api_key`: The API key for authenticating with the Groq API.
338    /// - `endpoint`: The URL of the Groq API endpoint. If not provided, it defaults to <https://api.groq.com/openai/v1>.
339    ///
340    /// # Returns
341    ///
342    /// A new `GroqClient` instance configured with the provided API key and endpoint.
343    pub fn new(api_key: String, endpoint: Option<String>) -> Self {
344        let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
345        Self {
346            api_key,
347            client: Client::new(),
348            endpoint: ep,
349        }
350    }
351
352    /// Sends a request to the Groq API with the provided JSON body and returns the parsed response.
353    ///
354    /// # Parameters
355    ///
356    /// - `body`: The JSON body to send in the request.
357    /// - `link`: The URL link to send the request to.
358    ///
359    /// # Returns
360    ///
361    /// The parsed response from the Groq API as a `Value`.
362    ///
363    /// # Errors
364    ///
365    /// Returns a `GroqError` if there is an issue sending the request or parsing the response.
366    fn send_request(&self, body: Value, link: &str) -> Result<Value, GroqError> {
367        let res = self
368            .client
369            .post(link)
370            .header("Content-Type", "application/json")
371            .header("Authorization", &format!("Bearer {}", self.api_key))
372            .json(&body)
373            .send()?;
374
375        parse_response(res)
376    }
377
378    /// Sends a speech-to-text request to the Groq API and returns the parsed response.
379    ///
380    /// # Parameters
381    ///
382    /// - `request`: A `SpeechToTextRequest` containing the necessary parameters for the speech-to-text request.
383    ///
384    /// # Returns
385    ///
386    /// The parsed `SpeechToTextResponse` from the Groq API.
387    ///
388    /// # Errors
389    ///
390    /// Returns a `GroqError` if there is an issue sending the request or parsing the response.
391    pub fn speech_to_text(
392        &self,
393        request: SpeechToTextRequest,
394    ) -> Result<SpeechToTextResponse, GroqError> {
395        // Extract values from request
396        let file = request.file;
397        let temperature = request.temperature;
398        let language = request.language;
399        let english_text = request.english_text;
400        let model = request.model;
401        let prompt = request.prompt;
402        let mut form = Form::new().part("file", Part::bytes(file).file_name("audio.wav"));
403
404        if let Some(temp) = temperature {
405            form = form.text("temperature", temp.to_string());
406        }
407
408        if let Some(lang) = language {
409            form = form.text("language", lang);
410        }
411
412        let link_addition = if english_text {
413            "/audio/translations"
414        } else {
415            "/audio/transcriptions"
416        };
417
418        if let Some(mdl) = model {
419            form = form.text("model", mdl);
420        }
421        if let Some(prompt) = prompt {
422            form = form.text("prompt", prompt.to_string());
423        }
424
425        let link = format!("{}{}", self.endpoint, link_addition);
426        let response = self
427            .client
428            .post(link)
429            .header("Authorization", &format!("Bearer {}", self.api_key))
430            .multipart(form)
431            .send()?;
432
433        let speech_to_text_response: SpeechToTextResponse = response.json()?;
434        Ok(speech_to_text_response)
435    }
436
437    /// Sends a chat completion request to the GROQ API and returns the response.
438    ///
439    /// # Parameters
440    ///
441    /// - `request` - A `ChatCompletionRequest` containing the details of the chat completion request.
442    ///
443    /// # Errors
444    ///
445    /// Returns a `GroqError` if there is an issue sending the request or parsing the response.
446    pub fn chat_completion(
447        &self,
448        request: ChatCompletionRequest,
449    ) -> Result<ChatCompletionResponse, GroqError> {
450        let messages = request
451            .messages
452            .iter()
453            .map(|m| {
454                let mut msg_json = json!({
455                    "role": m.role,
456                    "content": m.content,
457                });
458                if let Some(name) = &m.name {
459                    msg_json["name"] = json!(name);
460                }
461                msg_json
462            })
463            .collect::<Vec<_>>();
464
465        let mut body = json!({
466            "model": request.model,
467            "messages": messages,
468            "temperature": request.temperature.unwrap_or(1.0),
469            "max_tokens": request.max_tokens.unwrap_or(1024),
470            "top_p": request.top_p.unwrap_or(1.0),
471            "stream": request.stream.unwrap_or(false),
472        });
473
474        if let Some(stop) = &request.stop {
475            body["stop"] = json!(stop);
476        }
477        if let Some(seed) = &request.seed {
478            body["seed"] = json!(seed);
479        }
480
481        let response = self.send_request(body, &format!("{}/chat/completions", self.endpoint))?;
482        let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
483        Ok(chat_completion_response)
484    }
485}
486
487/// Parses the response from a GROQ API request and returns the response body as a JSON value.
488///
489/// # Parameters
490///
491/// - `response` - The HTTP response from the GROQ API request.
492///
493/// # Errors
494///
495/// Returns a `GroqError` if the response status is not successful or if there is an error parsing the response body.
496///
497/// # Returns
498///
499/// The response body as a JSON value.
500fn parse_response(response: Response) -> Result<Value, GroqError> {
501    let status = response.status();
502    let body: Value = response.json()?;
503
504    if !status.is_success()
505        && let Some(error) = body.get("error")
506    {
507        return Err(GroqError::ApiError {
508            message: error["message"]
509                .as_str()
510                .unwrap_or("Unknown error")
511                .to_string(),
512            type_: error["type"]
513                .as_str()
514                .unwrap_or("unknown_error")
515                .to_string(),
516        });
517    }
518
519    Ok(body)
520}
521
#[cfg(test)]
mod tests {
    //! Integration tests that call the live Groq API.
    //!
    //! They need `GROQ_API_KEY` in the environment, network access, and the audio fixtures
    //! `onepiece_demo.mp4` and `save.ogg` in the working directory.

    use super::*;

    /// Model used by tests that expect a successful completion. `llama3-70b-8192` has been
    /// decommissioned (see `test_async_stream_fail`).
    const CHAT_MODEL: &str = "llama-3.3-70b-versatile";

    /// Reads the API key, failing with a clear message when it is missing.
    fn api_key() -> String {
        std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY must be set to run the Groq API tests")
    }

    /// Builds a one-message conversation from the user.
    fn user_message(content: &str) -> Vec<ChatCompletionMessage> {
        vec![ChatCompletionMessage {
            role: ChatCompletionRoles::User,
            content: content.to_string(),
            name: None,
        }]
    }

    /// Drains a completion stream and concatenates the text of every delta.
    async fn collect_content(
        stream: impl futures::Stream<Item = Result<ChatCompletionDeltaResponse, GroqError>>,
        label: &str,
    ) -> String {
        tokio::pin!(stream);
        let mut text = String::new();
        while let Some(item) = stream.next().await {
            let delta = item.unwrap_or_else(|e| panic!("Failed to get delta from {label}: {e}"));
            // Tolerate chunks that carry no choice or no content.
            if let Some(content) = delta
                .choices
                .first()
                .and_then(|choice| choice.delta.content.as_deref())
            {
                text.push_str(content);
            }
        }
        text
    }

    #[test]
    fn test_chat_completion() {
        let client = GroqClient::new(api_key(), None);
        let request = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello"));
        let response = client
            .chat_completion(request)
            .expect("Failed to get chat completion");
        println!("{:?}", response);
        assert!(!response.choices.is_empty());
    }

    #[test]
    fn test_speech_to_text() {
        let client = GroqClient::new(api_key(), None);
        let audio_data = std::fs::read("onepiece_demo.mp4").expect("Failed to read audio file");
        let request = SpeechToTextRequest::new(audio_data)
            .temperature(0.7)
            .language("en")
            .model("whisper-large-v3");
        let response = client
            .speech_to_text(request)
            .expect("Failed to get response");
        println!("Speech to Text Response: {}", response.text);
        assert!(!response.text.is_empty());
    }

    #[tokio::test]
    async fn test_async_chat_completion() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let request1 = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello"));
        let request2 = ChatCompletionRequest::new(CHAT_MODEL, user_message("How are you?"));

        let (response1, response2) = tokio::join!(
            client.chat_completion(request1),
            client.chat_completion(request2)
        );

        let response1 = response1.expect("Failed to get response for request 1");
        let response2 = response2.expect("Failed to get response for request 2");

        // Assert before indexing so an empty response fails with a clear message.
        assert!(!response1.choices.is_empty());
        assert!(!response2.choices.is_empty());

        println!("Response 1: {}", response1.choices[0].message.content);
        println!("Response 2: {}", response2.choices[0].message.content);
    }

    #[tokio::test]
    async fn test_async_stream() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let request1 = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello!")).stream(true);
        let request2 =
            ChatCompletionRequest::new(CHAT_MODEL, user_message("How are you?")).stream(true);

        let (stream1, stream2) = tokio::join!(client.stream(request1), client.stream(request2));

        let stream1 = stream1.expect("Failed to get response for request 1");
        let stream2 = stream2.expect("Failed to get response for request 2");

        let response1 = collect_content(stream1, "stream 1").await;
        let response2 = collect_content(stream2, "stream 2").await;

        println!("Response 1: {}", response1);
        println!("Response 2: {}", response2);

        assert!(!response1.is_empty());
        assert!(!response2.is_empty());
    }

    #[tokio::test]
    async fn test_async_stream_fail() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let request =
            ChatCompletionRequest::new("llama3-70b-8192", user_message("Hello!")).stream(true);

        let stream = client
            .stream(request)
            .await
            .expect("Failed to get response");

        tokio::pin!(stream);

        let expected_message = r#"Deserialization error: {"error":{"message":"The model `llama3-70b-8192` has been decommissioned and is no longer supported. Please refer to https://console.groq.com/docs/deprecations for a recommendation on which model to use instead.","type":"invalid_request_error","code":"model_decommissioned"}}"#;
        match stream.next().await {
            Some(Err(e)) => assert_eq!(e.to_string(), expected_message),
            Some(Ok(_)) => panic!("Expected an error but got a successful response"),
            None => panic!("Expected an error but the stream ended without yielding anything"),
        }
    }

    #[tokio::test]
    async fn test_async_speech_to_text() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let audio_file_path1 = "onepiece_demo.mp4";
        let audio_file_path2 = "save.ogg";

        let (audio_data1, audio_data2) = tokio::join!(
            tokio::fs::read(audio_file_path1),
            tokio::fs::read(audio_file_path2)
        );

        let audio_data1 = audio_data1.expect("Failed to read first audio file");
        let audio_data2 = audio_data2.expect("Failed to read second audio file");

        let (request1, request2) = (
            SpeechToTextRequest::new(audio_data1)
                .temperature(0.7)
                .language("en")
                .model("whisper-large-v3"),
            SpeechToTextRequest::new(audio_data2)
                .temperature(0.7)
                .language("en")
                .model("whisper-large-v3"),
        );
        let (response1, response2) = tokio::join!(
            client.speech_to_text(request1),
            client.speech_to_text(request2)
        );

        let response1 = response1.expect("Failed to get response for first audio");
        let response2 = response2.expect("Failed to get response for second audio");

        println!("Speech to Text Response 1: {:?}", response1);
        println!("Speech to Text Response 2: {:?}", response2);

        assert!(!response1.text.is_empty());
        assert!(!response2.text.is_empty());
    }
}