1mod message;
2use futures::StreamExt;
3pub use message::*;
4use reqwest::{
5 Client as AClient, Response as AResponse,
6 blocking::multipart::{Form, Part},
7 blocking::{Client, Response},
8 multipart::{Form as AForm, Part as APart},
9};
10use serde_json::{Deserializer, StreamDeserializer, Value, json};
11use std::sync::Arc;
12
13pub struct AsyncGroqClient {
32 api_key: String,
33 client: Arc<AClient>,
34 endpoint: String,
35}
36
37impl AsyncGroqClient {
38 pub async fn new(api_key: String, endpoint: Option<String>) -> Self {
40 let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
41 Self {
42 api_key,
43 client: Arc::new(AClient::new()),
44 endpoint: ep,
45 }
46 }
47
48 async fn send_request(&self, body: Value, link: &str) -> Result<reqwest::Response, GroqError> {
59 let res = self
60 .client
61 .post(link)
62 .header("Content-Type", "application/json")
63 .header("Authorization", &format!("Bearer {}", self.api_key))
64 .json(&body)
65 .send()
66 .await?;
67 Ok(res)
68 }
69
70 pub async fn speech_to_text(
80 &self,
81 request: SpeechToTextRequest,
82 ) -> Result<SpeechToTextResponse, GroqError> {
83 let file = request.file;
84 let temperature = request.temperature;
85 let language = request.language;
86 let english_text = request.english_text;
87 let model = request.model;
88
89 let mut form = AForm::new().part("file", APart::bytes(file).file_name("audio.wav"));
90 if let Some(temp) = temperature {
91 form = form.text("temperature", temp.to_string());
92 }
93 if let Some(lang) = language {
94 form = form.text("language", lang);
95 }
96
97 let link_addition = if english_text {
98 "/audio/translations"
99 } else {
100 "/audio/transcriptions"
101 };
102 if let Some(mdl) = model {
103 form = form.text("model", mdl);
104 }
105
106 let link = format!("{}{}", self.endpoint, link_addition);
107 let response = self
108 .client
109 .post(&link)
110 .header("Authorization", &format!("Bearer {}", self.api_key))
111 .multipart(form)
112 .send()
113 .await?;
114
115 let speech_to_text_response: SpeechToTextResponse = response.json().await?;
116 Ok(speech_to_text_response)
117 }
118
119 async fn send_response(
129 &self,
130 request: ChatCompletionRequest,
131 stream: bool,
132 ) -> Result<reqwest::Response, GroqError> {
133 let messages = request
134 .messages
135 .iter()
136 .map(|m| {
137 let mut msg_json = json!({
138 "role": m.role,
139 "content": m.content,
140 });
141 if let Some(name) = &m.name {
142 msg_json["name"] = json!(name);
143 }
144 msg_json
145 })
146 .collect::<Vec<Value>>();
147
148 let mut body = json!({
149 "model": request.model,
150 "messages": messages,
151 "temperature": request.temperature.unwrap_or(1.0),
152 "max_tokens": request.max_tokens.unwrap_or(1024),
153 "top_p": request.top_p.unwrap_or(1.0),
154 "stream": request.stream.unwrap_or(stream),
155 });
156
157 if let Some(stop) = &request.stop {
158 body["stop"] = json!(stop);
159 }
160 if let Some(seed) = &request.seed {
161 body["seed"] = json!(seed);
162 }
163
164 let response = self
165 .send_request(body, &format!("{}/chat/completions", self.endpoint))
166 .await?;
167 Ok(response)
168 }
169
170 pub async fn chat_completion(
180 &self,
181 request: ChatCompletionRequest,
182 ) -> Result<ChatCompletionResponse, GroqError> {
183 if Some(true) == request.stream {
184 return Err(GroqError::InvalidRequest(
185 "Stream parameter must be set to false for non-streaming responses.".to_string(),
186 ));
187 }
188 let response = self.send_response(request, false).await?;
189 let response = self.parse_response(response).await?;
190
191 let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
192 Ok(chat_completion_response)
193 }
194
195 pub async fn stream(
205 &self,
206 request: ChatCompletionRequest,
207 ) -> Result<
208 impl futures::Stream<Item = Result<ChatCompletionDeltaResponse, GroqError>>,
209 GroqError,
210 > {
211 if Some(false) == request.stream {
212 return Err(GroqError::InvalidRequest(
213 "Stream parameter must be set to true for streaming responses.".to_string(),
214 ));
215 }
216 let response = self.send_response(request, true).await?;
217 let stream_response = response.bytes_stream();
218
219 Ok(futures::stream::unfold(
220 (stream_response, String::new()),
221 |(mut stream_response, mut resp_string)| async move {
222 let prefix = String::from("data: ");
223 if let Some(chunk) = stream_response.next().await {
224 if let Err(e) = chunk {
225 return Some((Err(GroqError::from(e)), (stream_response, resp_string)));
226 }
227 let chunk = String::from_utf8_lossy(&chunk.unwrap()).trim().to_string();
228 resp_string.push_str(&chunk);
229 }
230
231 loop {
232 if resp_string[..prefix.len()] != prefix {
233 return Some((
234 Err(GroqError::ApiError {
235 message: resp_string.clone(),
236 type_: "api_error".to_string(),
237 }),
238 (stream_response, resp_string),
239 ));
240 } else {
241 resp_string = resp_string[prefix.len()..].to_string();
242 }
243
244 let mut stream: StreamDeserializer<_, ChatCompletionDeltaResponse> =
245 Deserializer::from_slice(resp_string.as_bytes()).into_iter();
246
247 let line = match stream.next() {
248 Some(l) => l,
249 None => {
250 println!("Breaking, no complete line yet.");
251 continue;
252 }
253 };
254 let offset = stream.byte_offset();
255
256 if let Err(e) = &line {
257 if resp_string == "[DONE]" {
258 return None;
259 } else {
260 return Some((
261 Err(GroqError::DeserializationError {
262 message: e.to_string(),
263 type_: format!("{:?}", e.classify()),
264 }),
265 (stream_response, resp_string),
266 ));
267 }
268 }
269
270 let response = line.unwrap();
271
272 resp_string = resp_string[offset..].trim().to_string();
273 return Some((Ok(response.clone()), (stream_response, resp_string)));
274 }
275 },
276 ))
277 }
278
279 async fn parse_response(&self, response: AResponse) -> Result<Value, GroqError> {
289 let status = response.status();
290 let body: Value = response.json().await?;
291
292 if !status.is_success() {
293 if let Some(error) = body.get("error") {
294 return Err(GroqError::ApiError {
295 message: error["message"]
296 .as_str()
297 .unwrap_or("Unknown error")
298 .to_string(),
299 type_: error["type"]
300 .as_str()
301 .unwrap_or("unknown_error")
302 .to_string(),
303 });
304 }
305 }
306
307 Ok(body)
308 }
309}
310
311pub struct GroqClient {
330 api_key: String,
331 client: Client,
332 endpoint: String,
333}
334
335impl GroqClient {
336 pub fn new(api_key: String, endpoint: Option<String>) -> Self {
347 let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
348 Self {
349 api_key,
350 client: Client::new(),
351 endpoint: ep,
352 }
353 }
354
355 fn send_request(&self, body: Value, link: &str) -> Result<Value, GroqError> {
370 let res = self
371 .client
372 .post(link)
373 .header("Content-Type", "application/json")
374 .header("Authorization", &format!("Bearer {}", self.api_key))
375 .json(&body)
376 .send()?;
377
378 parse_response(res)
379 }
380
381 pub fn speech_to_text(
395 &self,
396 request: SpeechToTextRequest,
397 ) -> Result<SpeechToTextResponse, GroqError> {
398 let file = request.file;
400 let temperature = request.temperature;
401 let language = request.language;
402 let english_text = request.english_text;
403 let model = request.model;
404 let prompt = request.prompt;
405 let mut form = Form::new().part("file", Part::bytes(file).file_name("audio.wav"));
406
407 if let Some(temp) = temperature {
408 form = form.text("temperature", temp.to_string());
409 }
410
411 if let Some(lang) = language {
412 form = form.text("language", lang);
413 }
414
415 let link_addition = if english_text {
416 "/audio/translations"
417 } else {
418 "/audio/transcriptions"
419 };
420
421 if let Some(mdl) = model {
422 form = form.text("model", mdl);
423 }
424 if let Some(prompt) = prompt {
425 form = form.text("prompt", prompt.to_string());
426 }
427
428 let link = format!("{}{}", self.endpoint, link_addition);
429 let response = self
430 .client
431 .post(link)
432 .header("Authorization", &format!("Bearer {}", self.api_key))
433 .multipart(form)
434 .send()?;
435
436 let speech_to_text_response: SpeechToTextResponse = response.json()?;
437 Ok(speech_to_text_response)
438 }
439
440 pub fn chat_completion(
450 &self,
451 request: ChatCompletionRequest,
452 ) -> Result<ChatCompletionResponse, GroqError> {
453 let messages = request
454 .messages
455 .iter()
456 .map(|m| {
457 let mut msg_json = json!({
458 "role": m.role,
459 "content": m.content,
460 });
461 if let Some(name) = &m.name {
462 msg_json["name"] = json!(name);
463 }
464 msg_json
465 })
466 .collect::<Vec<_>>();
467
468 let mut body = json!({
469 "model": request.model,
470 "messages": messages,
471 "temperature": request.temperature.unwrap_or(1.0),
472 "max_tokens": request.max_tokens.unwrap_or(1024),
473 "top_p": request.top_p.unwrap_or(1.0),
474 "stream": request.stream.unwrap_or(false),
475 });
476
477 if let Some(stop) = &request.stop {
478 body["stop"] = json!(stop);
479 }
480 if let Some(seed) = &request.seed {
481 body["seed"] = json!(seed);
482 }
483
484 let response = self.send_request(body, &format!("{}/chat/completions", self.endpoint))?;
485 let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
486 Ok(chat_completion_response)
487 }
488}
489
490fn parse_response(response: Response) -> Result<Value, GroqError> {
504 let status = response.status();
505 let body: Value = response.json()?;
506
507 if !status.is_success() {
508 if let Some(error) = body.get("error") {
509 return Err(GroqError::ApiError {
510 message: error["message"]
511 .as_str()
512 .unwrap_or("Unknown error")
513 .to_string(),
514 type_: error["type"]
515 .as_str()
516 .unwrap_or("unknown_error")
517 .to_string(),
518 });
519 }
520 }
521
522 Ok(body)
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    // These tests call the live Groq API. They need `GROQ_API_KEY` set and,
    // for the speech tests, the audio fixtures in the working directory.

    /// Chat model used by the live tests. `llama3-70b-8192` is decommissioned
    /// (asserted by `test_async_stream_fail`), so it must not be used here.
    const CHAT_MODEL: &str = "llama-3.3-70b-versatile";

    fn api_key() -> String {
        std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY must be set to run live API tests")
    }

    /// Builds a single-message user conversation.
    fn user_message(content: &str) -> Vec<ChatCompletionMessage> {
        vec![ChatCompletionMessage {
            role: ChatCompletionRoles::User,
            content: content.to_string(),
            name: None,
        }]
    }

    /// Drains `stream`, concatenating the content of each delta's first choice.
    async fn collect_content<S>(stream: S, label: &str) -> String
    where
        S: futures::Stream<Item = Result<ChatCompletionDeltaResponse, GroqError>>,
    {
        tokio::pin!(stream);
        let mut text = String::new();
        while let Some(item) = stream.next().await {
            let delta =
                item.unwrap_or_else(|e| panic!("Failed to get delta from {}: {}", label, e));
            // A delta may legitimately carry no choices; don't index blindly.
            if let Some(content) = delta
                .choices
                .first()
                .and_then(|choice| choice.delta.content.as_ref())
            {
                text.push_str(content);
            }
        }
        text
    }

    #[test]
    fn test_chat_completion() {
        let client = GroqClient::new(api_key(), None);
        let request = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello"));
        let response = client.chat_completion(request).unwrap();
        println!("{:?}", response);
        assert!(!response.choices.is_empty());
    }

    #[test]
    fn test_speech_to_text() {
        let client = GroqClient::new(api_key(), None);
        let audio_data = std::fs::read("onepiece_demo.mp4").expect("Failed to read audio file");
        let request = SpeechToTextRequest::new(audio_data)
            .temperature(0.7)
            .language("en")
            .model("whisper-large-v3");
        let response = client
            .speech_to_text(request)
            .expect("Failed to get response");
        println!("Speech to Text Response: {}", response.text);
        assert!(!response.text.is_empty());
    }

    #[tokio::test]
    async fn test_async_chat_completion() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let request1 = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello"));
        let request2 = ChatCompletionRequest::new(CHAT_MODEL, user_message("How are you?"));

        let (response1, response2) = tokio::join!(
            client.chat_completion(request1),
            client.chat_completion(request2)
        );

        let response1 = response1.expect("Failed to get response for request 1");
        let response2 = response2.expect("Failed to get response for request 2");

        assert!(!response1.choices.is_empty());
        assert!(!response2.choices.is_empty());

        println!("Response 1: {}", response1.choices[0].message.content);
        println!("Response 2: {}", response2.choices[0].message.content);
    }

    #[tokio::test]
    async fn test_async_stream() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let request1 = ChatCompletionRequest::new(CHAT_MODEL, user_message("Hello!")).stream(true);
        let request2 =
            ChatCompletionRequest::new(CHAT_MODEL, user_message("How are you?")).stream(true);

        let (stream1, stream2) = tokio::join!(client.stream(request1), client.stream(request2));

        let stream1 = stream1.expect("Failed to get response for request 1");
        let stream2 = stream2.expect("Failed to get response for request 2");

        let response1 = collect_content(stream1, "stream 1").await;
        let response2 = collect_content(stream2, "stream 2").await;

        println!("Response 1: {}", response1);
        println!("Response 2: {}", response2);

        assert!(!response1.is_empty());
        assert!(!response2.is_empty());
    }

    #[tokio::test]
    async fn test_async_stream_fail() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        // Deliberately uses a decommissioned model to provoke an API error.
        let request =
            ChatCompletionRequest::new("llama3-70b-8192", user_message("Hello!")).stream(true);

        let stream = client
            .stream(request)
            .await
            .expect("Failed to get response");
        tokio::pin!(stream);

        let expected_message = r#"API error: {"error":{"message":"The model `llama3-70b-8192` has been decommissioned and is no longer supported. Please refer to https://console.groq.com/docs/deprecations for a recommendation on which model to use instead.","type":"invalid_request_error","code":"model_decommissioned"}}"#;
        // The first item must be the error; an empty stream is also a failure.
        match stream.next().await {
            Some(Err(e)) => assert_eq!(e.to_string(), expected_message),
            Some(Ok(_)) => panic!("Expected an error but got a successful response"),
            None => panic!("Expected an error but the stream ended without one"),
        }
    }

    #[tokio::test]
    async fn test_async_speech_to_text() {
        let client = AsyncGroqClient::new(api_key(), None).await;

        let (audio_data1, audio_data2) = tokio::join!(
            tokio::fs::read("onepiece_demo.mp4"),
            tokio::fs::read("save.ogg")
        );

        let audio_data1 = audio_data1.expect("Failed to read first audio file");
        let audio_data2 = audio_data2.expect("Failed to read second audio file");

        let request1 = SpeechToTextRequest::new(audio_data1)
            .temperature(0.7)
            .language("en")
            .model("whisper-large-v3");
        let request2 = SpeechToTextRequest::new(audio_data2)
            .temperature(0.7)
            .language("en")
            .model("whisper-large-v3");

        let (response1, response2) = tokio::join!(
            client.speech_to_text(request1),
            client.speech_to_text(request2)
        );

        let response1 = response1.expect("Failed to get response for first audio");
        let response2 = response2.expect("Failed to get response for second audio");

        println!("Speech to Text Response 1: {:?}", response1);
        println!("Speech to Text Response 2: {:?}", response2);

        assert!(!response1.text.is_empty());
        assert!(!response2.text.is_empty());
    }
}
725}