groq_api_rust/
lib.rs

1mod message;
2use futures::StreamExt;
3pub use message::*;
4use reqwest::{
5    Client as AClient, Response as AResponse,
6    blocking::multipart::{Form, Part},
7    blocking::{Client, Response},
8    multipart::{Form as AForm, Part as APart},
9};
10use serde_json::{Deserializer, StreamDeserializer, Value, json};
11use std::sync::Arc;
12
13/// An asynchronous client for interacting with the Groq API.
14///
15/// # Parameters
16///
17/// - `api_key`: The API key for authenticating with the Groq API.
18/// - `endpoint`: The URL of the Groq API endpoint. If not provided, it defaults to <https://api.groq.com/openai/v1>.
19///
20/// # Returns
21///
22/// An instance of `AsyncGroqClient` configured with the provided API key and endpoint.
23///
24/// # Example
25///
26///```
27/// use groq_client::AsyncGroqClient;
28///
29/// let client = AsyncGroqClient::new("my_api_key".to_string(), None).await;
30///```
31pub struct AsyncGroqClient {
32    api_key: String,
33    client: Arc<AClient>,
34    endpoint: String,
35}
36
37impl AsyncGroqClient {
38    /// Creates a new `AsyncGroqClient`
39    pub async fn new(api_key: String, endpoint: Option<String>) -> Self {
40        let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
41        Self {
42            api_key,
43            client: Arc::new(AClient::new()),
44            endpoint: ep,
45        }
46    }
47
48    /// Sends a request to the Groq API with the provided JSON body and returns the parsed response.
49    ///
50    /// # Parameters
51    ///
52    /// - `body`: The JSON body to send in the request.
53    /// - `link`: The URL link to send the request to.
54    ///
55    /// # Returns
56    ///
57    /// The parsed JSON response from the Groq API.
58    async fn send_request(&self, body: Value, link: &str) -> Result<reqwest::Response, GroqError> {
59        let res = self
60            .client
61            .post(link)
62            .header("Content-Type", "application/json")
63            .header("Authorization", &format!("Bearer {}", self.api_key))
64            .json(&body)
65            .send()
66            .await?;
67        Ok(res)
68    }
69
70    /// Sends a speech-to-text request to the Groq API and returns the parsed response.
71    ///
72    /// # Parameters
73    ///
74    /// - `request`: The `SpeechToTextRequest` containing the audio file, temperature, language, and other options.
75    ///
76    /// # Returns
77    ///
78    /// The parsed `SpeechToTextResponse` from the Groq API.
79    pub async fn speech_to_text(
80        &self,
81        request: SpeechToTextRequest,
82    ) -> Result<SpeechToTextResponse, GroqError> {
83        let file = request.file;
84        let temperature = request.temperature;
85        let language = request.language;
86        let english_text = request.english_text;
87        let model = request.model;
88
89        let mut form = AForm::new().part("file", APart::bytes(file).file_name("audio.wav"));
90        if let Some(temp) = temperature {
91            form = form.text("temperature", temp.to_string());
92        }
93        if let Some(lang) = language {
94            form = form.text("language", lang);
95        }
96
97        let link_addition = if english_text {
98            "/audio/translations"
99        } else {
100            "/audio/transcriptions"
101        };
102        if let Some(mdl) = model {
103            form = form.text("model", mdl);
104        }
105
106        let link = format!("{}{}", self.endpoint, link_addition);
107        let response = self
108            .client
109            .post(&link)
110            .header("Authorization", &format!("Bearer {}", self.api_key))
111            .multipart(form)
112            .send()
113            .await?;
114
115        let speech_to_text_response: SpeechToTextResponse = response.json().await?;
116        Ok(speech_to_text_response)
117    }
118
119    /// Internal function which sends a request to the Groq API and returns the raw response.
120    ///
121    /// # Parameters
122    ///
123    /// - `request`: The `ChatCompletionRequest` containing the model, messages, temperature, max tokens, top-p, and other options.
124    ///
125    /// # Returns
126    ///
127    /// The parsed `ChatCompletionResponse` from the Groq API.
128    async fn send_response(
129        &self,
130        request: ChatCompletionRequest,
131        stream: bool,
132    ) -> Result<reqwest::Response, GroqError> {
133        let messages = request
134            .messages
135            .iter()
136            .map(|m| {
137                let mut msg_json = json!({
138                    "role": m.role,
139                    "content": m.content,
140                });
141                if let Some(name) = &m.name {
142                    msg_json["name"] = json!(name);
143                }
144                msg_json
145            })
146            .collect::<Vec<Value>>();
147
148        let mut body = json!({
149            "model": request.model,
150            "messages": messages,
151            "temperature": request.temperature.unwrap_or(1.0),
152            "max_tokens": request.max_tokens.unwrap_or(1024),
153            "top_p": request.top_p.unwrap_or(1.0),
154            "stream": request.stream.unwrap_or(stream),
155        });
156
157        if let Some(stop) = &request.stop {
158            body["stop"] = json!(stop);
159        }
160        if let Some(seed) = &request.seed {
161            body["seed"] = json!(seed);
162        }
163
164        let response = self
165            .send_request(body, &format!("{}/chat/completions", self.endpoint))
166            .await?;
167        Ok(response)
168    }
169
170    /// Sends a chat completion request to the Groq API and returns the parsed response.
171    ///
172    /// # Parameters
173    ///
174    /// - `request`: The `ChatCompletionRequest` containing the model, messages, temperature, max tokens, top-p, and other options.
175    ///
176    /// # Returns
177    ///
178    /// The parsed `ChatCompletionResponse` from the Groq API.
179    pub async fn chat_completion(
180        &self,
181        request: ChatCompletionRequest,
182    ) -> Result<ChatCompletionResponse, GroqError> {
183        if Some(true) == request.stream {
184            return Err(GroqError::InvalidRequest(
185                "Stream parameter must be set to false for non-streaming responses.".to_string(),
186            ));
187        }
188        let response = self.send_response(request, false).await?;
189        let response = self.parse_response(response).await?;
190
191        let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
192        Ok(chat_completion_response)
193    }
194
195    /// Streams to the Groq API and returns a stream of responses.
196    ///
197    /// # Parameters
198    ///
199    /// - `request`: The `ChatCompletionRequest` containing the model, messages, temperature, max tokens, top-p, and other options.
200    ///
201    /// # Returns
202    ///
203    /// A stream of `ChatCompletionDeltaResponse` from the Groq API.
204    pub async fn stream(
205        &self,
206        request: ChatCompletionRequest,
207    ) -> Result<
208        impl futures::Stream<Item = Result<ChatCompletionDeltaResponse, GroqError>>,
209        GroqError,
210    > {
211        if Some(false) == request.stream {
212            return Err(GroqError::InvalidRequest(
213                "Stream parameter must be set to true for streaming responses.".to_string(),
214            ));
215        }
216        let response = self.send_response(request, true).await?;
217        let stream_response = response.bytes_stream();
218
219        Ok(futures::stream::unfold(
220            (stream_response, String::new()),
221            |(mut stream_response, mut resp_string)| async move {
222                let prefix = String::from("data: ");
223                if let Some(chunk) = stream_response.next().await {
224                    if let Err(e) = chunk {
225                        return Some((Err(GroqError::from(e)), (stream_response, resp_string)));
226                    }
227                    let chunk = String::from_utf8_lossy(&chunk.unwrap()).trim().to_string();
228                    resp_string.push_str(&chunk);
229                }
230
231                loop {
232                    if resp_string[..prefix.len()] != prefix {
233                        return Some((
234                            Err(GroqError::ApiError {
235                                message: resp_string.clone(),
236                                type_: "api_error".to_string(),
237                            }),
238                            (stream_response, resp_string),
239                        ));
240                    } else {
241                        resp_string = resp_string[prefix.len()..].to_string();
242                    }
243
244                    let mut stream: StreamDeserializer<_, ChatCompletionDeltaResponse> =
245                        Deserializer::from_slice(resp_string.as_bytes()).into_iter();
246
247                    let line = match stream.next() {
248                        Some(l) => l,
249                        None => {
250                            println!("Breaking, no complete line yet.");
251                            continue;
252                        }
253                    };
254                    let offset = stream.byte_offset();
255
256                    if let Err(e) = &line {
257                        if resp_string == "[DONE]" {
258                            return None;
259                        } else {
260                            return Some((
261                                Err(GroqError::DeserializationError {
262                                    message: e.to_string(),
263                                    type_: format!("{:?}", e.classify()),
264                                }),
265                                (stream_response, resp_string),
266                            ));
267                        }
268                    }
269
270                    let response = line.unwrap();
271
272                    resp_string = resp_string[offset..].trim().to_string();
273                    return Some((Ok(response.clone()), (stream_response, resp_string)));
274                }
275            },
276        ))
277    }
278
279    /// Parses the response from a Groq API request and returns the response body as a JSON value.
280    ///
281    /// # Parameters
282    ///
283    /// - `response`: The HTTP response from the Groq API request.
284    ///
285    /// # Returns
286    ///
287    /// The parsed JSON value from the response body, or a `GroqError` if the response was not successful.
288    async fn parse_response(&self, response: AResponse) -> Result<Value, GroqError> {
289        let status = response.status();
290        let body: Value = response.json().await?;
291
292        if !status.is_success() {
293            if let Some(error) = body.get("error") {
294                return Err(GroqError::ApiError {
295                    message: error["message"]
296                        .as_str()
297                        .unwrap_or("Unknown error")
298                        .to_string(),
299                    type_: error["type"]
300                        .as_str()
301                        .unwrap_or("unknown_error")
302                        .to_string(),
303                });
304            }
305        }
306
307        Ok(body)
308    }
309}
310
311/// An client for interacting with the Groq API.
312///
313/// # Parameters
314///
315/// - `api_key`: The API key for authenticating with the Groq API.
316/// - `endpoint`: The URL of the Groq API endpoint. If not provided, it defaults to <https://api.groq.com/openai/v1>.
317///
318/// # Returns
319///
320/// An instance of `GroqClient` configured with the provided API key and endpoint.
321///
322/// # Example
323///
324///```
325/// use groq_client::GroqClient;
326///
327/// let client = GroqClient::new("my_api_key".to_string(), None);
328///```
329pub struct GroqClient {
330    api_key: String,
331    client: Client,
332    endpoint: String,
333}
334
335impl GroqClient {
336    /// Constructs a new `GroqClient` instance with the provided API key and optional endpoint.
337    ///
338    /// # Parameters
339    ///
340    /// - `api_key`: The API key for authenticating with the Groq API.
341    /// - `endpoint`: The URL of the Groq API endpoint. If not provided, it defaults to <https://api.groq.com/openai/v1>.
342    ///
343    /// # Returns
344    ///
345    /// A new `GroqClient` instance configured with the provided API key and endpoint.
346    pub fn new(api_key: String, endpoint: Option<String>) -> Self {
347        let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
348        Self {
349            api_key,
350            client: Client::new(),
351            endpoint: ep,
352        }
353    }
354
355    /// Sends a request to the Groq API with the provided JSON body and returns the parsed response.
356    ///
357    /// # Parameters
358    ///
359    /// - `body`: The JSON body to send in the request.
360    /// - `link`: The URL link to send the request to.
361    ///
362    /// # Returns
363    ///
364    /// The parsed response from the Groq API as a `Value`.
365    ///
366    /// # Errors
367    ///
368    /// Returns a `GroqError` if there is an issue sending the request or parsing the response.
369    fn send_request(&self, body: Value, link: &str) -> Result<Value, GroqError> {
370        let res = self
371            .client
372            .post(link)
373            .header("Content-Type", "application/json")
374            .header("Authorization", &format!("Bearer {}", self.api_key))
375            .json(&body)
376            .send()?;
377
378        parse_response(res)
379    }
380
381    /// Sends a speech-to-text request to the Groq API and returns the parsed response.
382    ///
383    /// # Parameters
384    ///
385    /// - `request`: A `SpeechToTextRequest` containing the necessary parameters for the speech-to-text request.
386    ///
387    /// # Returns
388    ///
389    /// The parsed `SpeechToTextResponse` from the Groq API.
390    ///
391    /// # Errors
392    ///
393    /// Returns a `GroqError` if there is an issue sending the request or parsing the response.
394    pub fn speech_to_text(
395        &self,
396        request: SpeechToTextRequest,
397    ) -> Result<SpeechToTextResponse, GroqError> {
398        // Extract values from request
399        let file = request.file;
400        let temperature = request.temperature;
401        let language = request.language;
402        let english_text = request.english_text;
403        let model = request.model;
404        let prompt = request.prompt;
405        let mut form = Form::new().part("file", Part::bytes(file).file_name("audio.wav"));
406
407        if let Some(temp) = temperature {
408            form = form.text("temperature", temp.to_string());
409        }
410
411        if let Some(lang) = language {
412            form = form.text("language", lang);
413        }
414
415        let link_addition = if english_text {
416            "/audio/translations"
417        } else {
418            "/audio/transcriptions"
419        };
420
421        if let Some(mdl) = model {
422            form = form.text("model", mdl);
423        }
424        if let Some(prompt) = prompt {
425            form = form.text("prompt", prompt.to_string());
426        }
427
428        let link = format!("{}{}", self.endpoint, link_addition);
429        let response = self
430            .client
431            .post(link)
432            .header("Authorization", &format!("Bearer {}", self.api_key))
433            .multipart(form)
434            .send()?;
435
436        let speech_to_text_response: SpeechToTextResponse = response.json()?;
437        Ok(speech_to_text_response)
438    }
439
440    /// Sends a chat completion request to the GROQ API and returns the response.
441    ///
442    /// # Parameters
443    ///
444    /// - `request` - A `ChatCompletionRequest` containing the details of the chat completion request.
445    ///
446    /// # Errors
447    ///
448    /// Returns a `GroqError` if there is an issue sending the request or parsing the response.
449    pub fn chat_completion(
450        &self,
451        request: ChatCompletionRequest,
452    ) -> Result<ChatCompletionResponse, GroqError> {
453        let messages = request
454            .messages
455            .iter()
456            .map(|m| {
457                let mut msg_json = json!({
458                    "role": m.role,
459                    "content": m.content,
460                });
461                if let Some(name) = &m.name {
462                    msg_json["name"] = json!(name);
463                }
464                msg_json
465            })
466            .collect::<Vec<_>>();
467
468        let mut body = json!({
469            "model": request.model,
470            "messages": messages,
471            "temperature": request.temperature.unwrap_or(1.0),
472            "max_tokens": request.max_tokens.unwrap_or(1024),
473            "top_p": request.top_p.unwrap_or(1.0),
474            "stream": request.stream.unwrap_or(false),
475        });
476
477        if let Some(stop) = &request.stop {
478            body["stop"] = json!(stop);
479        }
480        if let Some(seed) = &request.seed {
481            body["seed"] = json!(seed);
482        }
483
484        let response = self.send_request(body, &format!("{}/chat/completions", self.endpoint))?;
485        let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
486        Ok(chat_completion_response)
487    }
488}
489
490/// Parses the response from a GROQ API request and returns the response body as a JSON value.
491///
492/// # Parameters
493///
494/// - `response` - The HTTP response from the GROQ API request.
495///
496/// # Errors
497///
498/// Returns a `GroqError` if the response status is not successful or if there is an error parsing the response body.
499///
500/// # Returns
501///
502/// The response body as a JSON value.
503fn parse_response(response: Response) -> Result<Value, GroqError> {
504    let status = response.status();
505    let body: Value = response.json()?;
506
507    if !status.is_success() {
508        if let Some(error) = body.get("error") {
509            return Err(GroqError::ApiError {
510                message: error["message"]
511                    .as_str()
512                    .unwrap_or("Unknown error")
513                    .to_string(),
514                type_: error["type"]
515                    .as_str()
516                    .unwrap_or("unknown_error")
517                    .to_string(),
518            });
519        }
520    }
521
522    Ok(body)
523}
524
#[cfg(test)]
mod tests {
    //! Integration tests that talk to the live Groq API.
    //!
    //! They need network access and the `GROQ_API_KEY` environment variable. The speech tests
    //! also read the audio fixtures `onepiece_demo.mp4` and `save.ogg` from the current
    //! working directory.

    use super::*;

    /// Model used by the tests that expect a successful completion.
    const CHAT_MODEL: &str = "llama-3.3-70b-versatile";

    /// Reads the API key, panicking with an actionable message when it is not set.
    fn api_key() -> String {
        std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY must be set to run these tests")
    }

    /// Builds a conversation consisting of a single user message.
    fn user_messages(content: &str) -> Vec<ChatCompletionMessage> {
        vec![ChatCompletionMessage {
            role: ChatCompletionRoles::User,
            content: content.to_string(),
            name: None,
        }]
    }

    /// Builds the English `whisper-large-v3` request shared by the speech tests.
    fn speech_request(audio: Vec<u8>) -> SpeechToTextRequest {
        SpeechToTextRequest::new(audio)
            .temperature(0.7)
            .language("en")
            .model("whisper-large-v3")
    }

    /// Concatenates the text deltas of a chat completion stream, panicking on any error item.
    async fn collect_text(
        stream: impl futures::Stream<Item = Result<ChatCompletionDeltaResponse, GroqError>>,
        label: &str,
    ) -> String {
        tokio::pin!(stream);
        let mut text = String::new();
        while let Some(item) = stream.next().await {
            let delta =
                item.unwrap_or_else(|e| panic!("Failed to get delta from {}: {}", label, e));
            // Don't index blindly: a chunk is not guaranteed to carry any choices.
            if let Some(content) = delta
                .choices
                .first()
                .and_then(|choice| choice.delta.content.as_deref())
            {
                text.push_str(content);
            }
        }
        text
    }

    #[test]
    fn test_chat_completion() {
        let client = GroqClient::new(api_key(), None);
        // `llama3-70b-8192` is decommissioned (see `test_async_stream_fail`), so use a live model.
        let request = ChatCompletionRequest::new(CHAT_MODEL, user_messages("Hello"));
        let response = client
            .chat_completion(request)
            .expect("Failed to get response");
        println!("{:?}", response);
        assert!(!response.choices.is_empty());
    }

    #[test]
    fn test_speech_to_text() {
        let client = GroqClient::new(api_key(), None);
        let audio_data = std::fs::read("onepiece_demo.mp4").expect("Failed to read audio file");
        let response = client
            .speech_to_text(speech_request(audio_data))
            .expect("Failed to get response");
        println!("Speech to Text Response: {}", response.text);
        assert!(!response.text.is_empty());
    }

    #[tokio::test]
    async fn test_async_chat_completion() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let request1 = ChatCompletionRequest::new(CHAT_MODEL, user_messages("Hello"));
        let request2 = ChatCompletionRequest::new(CHAT_MODEL, user_messages("How are you?"));

        let (response1, response2) = tokio::join!(
            client.chat_completion(request1),
            client.chat_completion(request2)
        );

        let response1 = response1.expect("Failed to get response for request 1");
        let response2 = response2.expect("Failed to get response for request 2");

        // Assert before indexing so an empty reply fails the assertion instead of panicking.
        assert!(!response1.choices.is_empty());
        assert!(!response2.choices.is_empty());

        println!("Response 1: {}", response1.choices[0].message.content);
        println!("Response 2: {}", response2.choices[0].message.content);
    }

    #[tokio::test]
    async fn test_async_stream() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let request1 = ChatCompletionRequest::new(CHAT_MODEL, user_messages("Hello!")).stream(true);
        let request2 =
            ChatCompletionRequest::new(CHAT_MODEL, user_messages("How are you?")).stream(true);

        let (stream1, stream2) = tokio::join!(client.stream(request1), client.stream(request2));

        let stream1 = stream1.expect("Failed to get response for request 1");
        let stream2 = stream2.expect("Failed to get response for request 2");

        let response1 = collect_text(stream1, "stream 1").await;
        let response2 = collect_text(stream2, "stream 2").await;

        println!("Response 1: {}", response1);
        println!("Response 2: {}", response2);

        assert!(!response1.is_empty());
        assert!(!response2.is_empty());
    }

    #[tokio::test]
    async fn test_async_stream_fail() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        // Deliberately targets a decommissioned model so the API rejects the request.
        let request =
            ChatCompletionRequest::new("llama3-70b-8192", user_messages("Hello!")).stream(true);

        let stream = client
            .stream(request)
            .await
            .expect("Failed to get response");

        tokio::pin!(stream);

        match stream.next().await {
            Some(Err(e)) => {
                let expected_message = r#"API error: {"error":{"message":"The model `llama3-70b-8192` has been decommissioned and is no longer supported. Please refer to https://console.groq.com/docs/deprecations for a recommendation on which model to use instead.","type":"invalid_request_error","code":"model_decommissioned"}}"#;
                assert_eq!(e.to_string(), expected_message);
            }
            Some(Ok(_)) => panic!("Expected an error but got a successful response"),
            // Previously this case let the test pass without checking anything.
            None => panic!("Expected an error but the stream ended without yielding anything"),
        }
    }

    #[tokio::test]
    async fn test_async_speech_to_text() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let (audio_data1, audio_data2) = tokio::join!(
            tokio::fs::read("onepiece_demo.mp4"),
            tokio::fs::read("save.ogg")
        );

        let audio_data1 = audio_data1.expect("Failed to read first audio file");
        let audio_data2 = audio_data2.expect("Failed to read second audio file");

        let (response1, response2) = tokio::join!(
            client.speech_to_text(speech_request(audio_data1)),
            client.speech_to_text(speech_request(audio_data2))
        );

        let response1 = response1.expect("Failed to get response for first audio");
        let response2 = response2.expect("Failed to get response for second audio");

        println!("Speech to Text Response 1: {:?}", response1);
        println!("Speech to Text Response 2: {:?}", response2);

        assert!(!response1.text.is_empty());
        assert!(!response2.text.is_empty());
    }
}