1mod message;
2use futures::StreamExt;
3pub use message::*;
4use reqwest::{
5 Client as AClient, Response as AResponse,
6 blocking::multipart::{Form, Part},
7 blocking::{Client, Response},
8 multipart::{Form as AForm, Part as APart},
9};
10use serde_json::{Deserializer, StreamDeserializer, Value, json};
11use std::sync::Arc;
12
13pub struct AsyncGroqClient {
32 api_key: String,
33 client: Arc<AClient>,
34 endpoint: String,
35}
36
37impl AsyncGroqClient {
38 pub async fn new(api_key: String, endpoint: Option<String>) -> Self {
40 let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
41 Self {
42 api_key,
43 client: Arc::new(AClient::new()),
44 endpoint: ep,
45 }
46 }
47
48 async fn send_request(&self, body: Value, link: &str) -> Result<reqwest::Response, GroqError> {
59 let res = self
60 .client
61 .post(link)
62 .header("Content-Type", "application/json")
63 .header("Authorization", &format!("Bearer {}", self.api_key))
64 .json(&body)
65 .send()
66 .await?;
67 Ok(res)
68 }
69
70 pub async fn speech_to_text(
80 &self,
81 request: SpeechToTextRequest,
82 ) -> Result<SpeechToTextResponse, GroqError> {
83 let file = request.file;
84 let temperature = request.temperature;
85 let language = request.language;
86 let english_text = request.english_text;
87 let model = request.model;
88
89 let mut form = AForm::new().part("file", APart::bytes(file).file_name("audio.wav"));
90 if let Some(temp) = temperature {
91 form = form.text("temperature", temp.to_string());
92 }
93 if let Some(lang) = language {
94 form = form.text("language", lang);
95 }
96
97 let link_addition = if english_text {
98 "/audio/translations"
99 } else {
100 "/audio/transcriptions"
101 };
102 if let Some(mdl) = model {
103 form = form.text("model", mdl);
104 }
105
106 let link = format!("{}{}", self.endpoint, link_addition);
107 let response = self
108 .client
109 .post(&link)
110 .header("Authorization", &format!("Bearer {}", self.api_key))
111 .multipart(form)
112 .send()
113 .await?;
114
115 let speech_to_text_response: SpeechToTextResponse = response.json().await?;
116 Ok(speech_to_text_response)
117 }
118
119 async fn send_response(
129 &self,
130 request: ChatCompletionRequest,
131 stream: bool,
132 ) -> Result<reqwest::Response, GroqError> {
133 let messages = request
134 .messages
135 .iter()
136 .map(|m| {
137 let mut msg_json = json!({
138 "role": m.role,
139 "content": m.content,
140 });
141 if let Some(name) = &m.name {
142 msg_json["name"] = json!(name);
143 }
144 msg_json
145 })
146 .collect::<Vec<Value>>();
147
148 let mut body = json!({
149 "model": request.model,
150 "messages": messages,
151 "temperature": request.temperature.unwrap_or(1.0),
152 "max_tokens": request.max_tokens.unwrap_or(1024),
153 "top_p": request.top_p.unwrap_or(1.0),
154 "stream": request.stream.unwrap_or(stream),
155 });
156
157 if let Some(stop) = &request.stop {
158 body["stop"] = json!(stop);
159 }
160 if let Some(seed) = &request.seed {
161 body["seed"] = json!(seed);
162 }
163
164 let response = self
165 .send_request(body, &format!("{}/chat/completions", self.endpoint))
166 .await?;
167 Ok(response)
168 }
169
170 pub async fn chat_completion(
180 &self,
181 request: ChatCompletionRequest,
182 ) -> Result<ChatCompletionResponse, GroqError> {
183 if Some(true) == request.stream {
184 return Err(GroqError::InvalidRequest(
185 "Stream parameter must be set to false for non-streaming responses.".to_string(),
186 ));
187 }
188 let response = self.send_response(request, false).await?;
189 let response = self.parse_response(response).await?;
190
191 let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
192 Ok(chat_completion_response)
193 }
194
195 pub async fn stream(
205 &self,
206 request: ChatCompletionRequest,
207 ) -> Result<
208 impl futures::Stream<Item = Result<ChatCompletionDeltaResponse, GroqError>>,
209 GroqError,
210 > {
211 if Some(false) == request.stream {
212 return Err(GroqError::InvalidRequest(
213 "Stream parameter must be set to true for streaming responses.".to_string(),
214 ));
215 }
216 let response = self.send_response(request, true).await?;
217 let stream_response = response.bytes_stream();
218
219 let prefix = "data: ";
223 Ok(futures::stream::unfold(
224 (stream_response, String::new()),
225 move |(mut stream_response, mut resp_string)| async move {
226 loop {
227 resp_string = resp_string
229 .strip_prefix(&prefix)
230 .unwrap_or(&resp_string)
231 .to_string();
232
233 let mut stream: StreamDeserializer<_, ChatCompletionDeltaResponse> =
235 Deserializer::from_slice(resp_string.as_bytes()).into_iter();
236
237 if let Some(line) = stream.next() {
238 if let Ok(line) = line {
241 let offset = stream.byte_offset();
242 resp_string = resp_string[offset..].trim().to_string();
243 return Some((Ok(line), (stream_response, resp_string)));
244 } else if resp_string == "[DONE]" {
245 return None;
246 }
247 }
248
249 if let Some(chunk) = stream_response.next().await {
250 if let Err(e) = chunk {
253 return Some((Err(GroqError::from(e)), (stream_response, resp_string)));
254 }
255 let chunk = String::from_utf8_lossy(&chunk.unwrap()).trim().to_string();
256 resp_string.push_str(&chunk);
257 continue;
258 } else if resp_string.is_empty() {
259 return None;
260 } else {
261 return Some((
264 Err(GroqError::DeserializationError {
265 message: resp_string.clone(),
266 type_: "DeserializationError".to_string(),
267 }),
268 (stream_response, resp_string),
269 ));
270 }
271 }
272 },
273 ))
274 }
275
276 async fn parse_response(&self, response: AResponse) -> Result<Value, GroqError> {
286 let status = response.status();
287 let body: Value = response.json().await?;
288
289 if !status.is_success()
290 && let Some(error) = body.get("error")
291 {
292 return Err(GroqError::ApiError {
293 message: error["message"]
294 .as_str()
295 .unwrap_or("Unknown error")
296 .to_string(),
297 type_: error["type"]
298 .as_str()
299 .unwrap_or("unknown_error")
300 .to_string(),
301 });
302 }
303
304 Ok(body)
305 }
306}
307
308pub struct GroqClient {
327 api_key: String,
328 client: Client,
329 endpoint: String,
330}
331
332impl GroqClient {
333 pub fn new(api_key: String, endpoint: Option<String>) -> Self {
344 let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
345 Self {
346 api_key,
347 client: Client::new(),
348 endpoint: ep,
349 }
350 }
351
352 fn send_request(&self, body: Value, link: &str) -> Result<Value, GroqError> {
367 let res = self
368 .client
369 .post(link)
370 .header("Content-Type", "application/json")
371 .header("Authorization", &format!("Bearer {}", self.api_key))
372 .json(&body)
373 .send()?;
374
375 parse_response(res)
376 }
377
378 pub fn speech_to_text(
392 &self,
393 request: SpeechToTextRequest,
394 ) -> Result<SpeechToTextResponse, GroqError> {
395 let file = request.file;
397 let temperature = request.temperature;
398 let language = request.language;
399 let english_text = request.english_text;
400 let model = request.model;
401 let prompt = request.prompt;
402 let mut form = Form::new().part("file", Part::bytes(file).file_name("audio.wav"));
403
404 if let Some(temp) = temperature {
405 form = form.text("temperature", temp.to_string());
406 }
407
408 if let Some(lang) = language {
409 form = form.text("language", lang);
410 }
411
412 let link_addition = if english_text {
413 "/audio/translations"
414 } else {
415 "/audio/transcriptions"
416 };
417
418 if let Some(mdl) = model {
419 form = form.text("model", mdl);
420 }
421 if let Some(prompt) = prompt {
422 form = form.text("prompt", prompt.to_string());
423 }
424
425 let link = format!("{}{}", self.endpoint, link_addition);
426 let response = self
427 .client
428 .post(link)
429 .header("Authorization", &format!("Bearer {}", self.api_key))
430 .multipart(form)
431 .send()?;
432
433 let speech_to_text_response: SpeechToTextResponse = response.json()?;
434 Ok(speech_to_text_response)
435 }
436
437 pub fn chat_completion(
447 &self,
448 request: ChatCompletionRequest,
449 ) -> Result<ChatCompletionResponse, GroqError> {
450 let messages = request
451 .messages
452 .iter()
453 .map(|m| {
454 let mut msg_json = json!({
455 "role": m.role,
456 "content": m.content,
457 });
458 if let Some(name) = &m.name {
459 msg_json["name"] = json!(name);
460 }
461 msg_json
462 })
463 .collect::<Vec<_>>();
464
465 let mut body = json!({
466 "model": request.model,
467 "messages": messages,
468 "temperature": request.temperature.unwrap_or(1.0),
469 "max_tokens": request.max_tokens.unwrap_or(1024),
470 "top_p": request.top_p.unwrap_or(1.0),
471 "stream": request.stream.unwrap_or(false),
472 });
473
474 if let Some(stop) = &request.stop {
475 body["stop"] = json!(stop);
476 }
477 if let Some(seed) = &request.seed {
478 body["seed"] = json!(seed);
479 }
480
481 let response = self.send_request(body, &format!("{}/chat/completions", self.endpoint))?;
482 let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
483 Ok(chat_completion_response)
484 }
485}
486
487fn parse_response(response: Response) -> Result<Value, GroqError> {
501 let status = response.status();
502 let body: Value = response.json()?;
503
504 if !status.is_success()
505 && let Some(error) = body.get("error")
506 {
507 return Err(GroqError::ApiError {
508 message: error["message"]
509 .as_str()
510 .unwrap_or("Unknown error")
511 .to_string(),
512 type_: error["type"]
513 .as_str()
514 .unwrap_or("unknown_error")
515 .to_string(),
516 });
517 }
518
519 Ok(body)
520}
521
#[cfg(test)]
mod tests {
    //! Live integration tests. They call the real Groq API, read
    //! `GROQ_API_KEY` from the environment, and expect the audio fixtures
    //! `onepiece_demo.mp4` and `save.ogg` in the working directory.
    use super::*;

    /// Model used by every test that expects a successful completion.
    const CHAT_MODEL: &str = "llama-3.3-70b-versatile";

    /// Reads the API key, failing the test with a clear message when it is unset.
    fn api_key() -> String {
        std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY must be set to run the live API tests")
    }

    /// Builds a single-message conversation from the user.
    fn user_message(content: &str) -> Vec<ChatCompletionMessage> {
        vec![ChatCompletionMessage {
            role: ChatCompletionRoles::User,
            content: content.to_string(),
            name: None,
        }]
    }

    /// Drains a completion stream and concatenates the text of every delta.
    ///
    /// Panics on the first error item, naming the stream with `label`.
    async fn collect_content(
        stream: impl futures::Stream<Item = Result<ChatCompletionDeltaResponse, GroqError>>,
        label: &str,
    ) -> String {
        tokio::pin!(stream);
        let mut text = String::new();
        while let Some(item) = stream.next().await {
            let delta = item.unwrap_or_else(|e| panic!("Failed to get delta from {label}: {e}"));
            // Chunks without choices (e.g. usage-only chunks) carry no text.
            if let Some(content) = delta
                .choices
                .first()
                .and_then(|choice| choice.delta.content.as_ref())
            {
                text.push_str(content);
            }
        }
        text
    }

    #[test]
    fn test_chat_completion() {
        let client = GroqClient::new(api_key(), None);
        // `llama3-70b-8192` is decommissioned (see `test_async_stream_fail`).
        let request = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello"));
        let response = client.chat_completion(request).unwrap();
        println!("{:?}", response);
        assert!(!response.choices.is_empty());
    }

    #[test]
    fn test_speech_to_text() {
        let client = GroqClient::new(api_key(), None);
        let audio_data = std::fs::read("onepiece_demo.mp4").expect("Failed to read audio file");
        let request = SpeechToTextRequest::new(audio_data)
            .temperature(0.7)
            .language("en")
            .model("whisper-large-v3");
        let response = client
            .speech_to_text(request)
            .expect("Failed to get response");
        println!("Speech to Text Response: {}", response.text);
        assert!(!response.text.is_empty());
    }

    #[tokio::test]
    async fn test_async_chat_completion() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let request1 = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello"));
        let request2 = ChatCompletionRequest::new(CHAT_MODEL, user_message("How are you?"));

        let (response1, response2) = tokio::join!(
            client.chat_completion(request1),
            client.chat_completion(request2)
        );

        let response1 = response1.expect("Failed to get response for request 1");
        let response2 = response2.expect("Failed to get response for request 2");

        // Assert before indexing so an empty reply fails with an assertion, not an index panic.
        assert!(!response1.choices.is_empty());
        assert!(!response2.choices.is_empty());

        println!("Response 1: {}", response1.choices[0].message.content);
        println!("Response 2: {}", response2.choices[0].message.content);
    }

    #[tokio::test]
    async fn test_async_stream() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let request1 = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello!")).stream(true);
        let request2 =
            ChatCompletionRequest::new(CHAT_MODEL, user_message("How are you?")).stream(true);

        let (stream1, stream2) = tokio::join!(client.stream(request1), client.stream(request2));

        let stream1 = stream1.expect("Failed to get response for request 1");
        let stream2 = stream2.expect("Failed to get response for request 2");

        let response1 = collect_content(stream1, "stream 1").await;
        let response2 = collect_content(stream2, "stream 2").await;

        println!("Response 1: {}", response1);
        println!("Response 2: {}", response2);

        assert!(!response1.is_empty());
        assert!(!response2.is_empty());
    }

    #[tokio::test]
    async fn test_async_stream_fail() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        // Deliberately uses a decommissioned model so the API answers with an error body.
        let request =
            ChatCompletionRequest::new("llama3-70b-8192", user_message("Hello!")).stream(true);

        let stream = client
            .stream(request)
            .await
            .expect("Failed to get response");
        tokio::pin!(stream);

        // The first item must be the error; an empty stream is a failure as well.
        let error = match stream.next().await {
            Some(Err(e)) => e,
            Some(Ok(_)) => panic!("Expected an error but got a successful response"),
            None => panic!("Expected an error but the stream ended without any item"),
        };
        let expected_message = r#"Deserialization error: {"error":{"message":"The model `llama3-70b-8192` has been decommissioned and is no longer supported. Please refer to https://console.groq.com/docs/deprecations for a recommendation on which model to use instead.","type":"invalid_request_error","code":"model_decommissioned"}}"#;
        assert_eq!(error.to_string(), expected_message);
    }

    #[tokio::test]
    async fn test_async_speech_to_text() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let (audio_data1, audio_data2) = tokio::join!(
            tokio::fs::read("onepiece_demo.mp4"),
            tokio::fs::read("save.ogg")
        );

        let audio_data1 = audio_data1.expect("Failed to read first audio file");
        let audio_data2 = audio_data2.expect("Failed to read second audio file");

        let request1 = SpeechToTextRequest::new(audio_data1)
            .temperature(0.7)
            .language("en")
            .model("whisper-large-v3");
        let request2 = SpeechToTextRequest::new(audio_data2)
            .temperature(0.7)
            .language("en")
            .model("whisper-large-v3");

        let (response1, response2) = tokio::join!(
            client.speech_to_text(request1),
            client.speech_to_text(request2)
        );

        let response1 = response1.expect("Failed to get response for first audio");
        let response2 = response2.expect("Failed to get response for second audio");

        println!("Speech to Text Response 1: {:?}", response1);
        println!("Speech to Text Response 2: {:?}", response2);

        assert!(!response1.text.is_empty());
        assert!(!response2.text.is_empty());
    }
}