mod message;
pub use message::*;
use reqwest::{
    blocking::multipart::{Form, Part},
    blocking::{Client, Response},
    multipart::{Form as AForm, Part as APart},
    Client as AClient, Response as AResponse,
};
use serde_json::{json, Value};
use std::sync::Arc;

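/// Asynchronous Groq API client built on `reqwest::Client`.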
pub struct AsyncGroqClient {
    api_key: String,
    client: Arc<AClient>,
    endpoint: String,
}

impl AsyncGroqClient {
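    /// Creates a new async client, falling back to `https://api.groq.com/openai/v1` when no endpoint is given.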
    pub async fn new(api_key: String, endpoint: Option<String>) -> Self {
        let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
        Self {
            api_key,
            client: Arc::new(AClient::new()),
            endpoint: ep,
        }
    }

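    /// Sends a JSON POST request with the bearer token and parses the response body.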
    async fn send_request(&self, body: Value, link: &str) -> Result<Value, GroqError> {
        let res = self
            .client
            .post(link)
            .header("Content-Type", "application/json")
            .header("Authorization", &format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await?;

        self.parse_response(res).await
    }

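    /// Transcribes (or translates, when `english_text` is set) the given audio bytes
    /// by uploading them as a multipart form.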
    pub async fn speech_to_text(
        &self,
        request: SpeechToTextRequest,
    ) -> Result<SpeechToTextResponse, GroqError> {
        let file = request.file;
        let temperature = request.temperature;
        let language = request.language;
        let english_text = request.english_text;
        let model = request.model;
        let prompt = request.prompt;

        let mut form = AForm::new().part("file", APart::bytes(file).file_name("audio.wav"));
        if let Some(temp) = temperature {
            form = form.text("temperature", temp.to_string());
        }
        if let Some(lang) = language {
            form = form.text("language", lang);
        }

        // The translations endpoint returns English text; transcriptions stay in the source language.
        let link_addition = if english_text {
            "/audio/translations"
        } else {
            "/audio/transcriptions"
        };
        if let Some(mdl) = model {
            form = form.text("model", mdl);
        }
        if let Some(prompt) = prompt {
            form = form.text("prompt", prompt.to_string());
        }

        let link = format!("{}{}", self.endpoint, link_addition);
        let response = self
            .client
            .post(&link)
            .header("Authorization", &format!("Bearer {}", self.api_key))
            .multipart(form)
            .send()
            .await?;

        let speech_to_text_response: SpeechToTextResponse = response.json().await?;
        Ok(speech_to_text_response)
    }

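    /// Builds a chat completion request body and returns the parsed `ChatCompletionResponse`.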
    pub async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, GroqError> {
        let messages = request
            .messages
            .iter()
            .map(|m| {
                let mut msg_json = json!({
                    "role": m.role,
                    "content": m.content,
                });
                if let Some(name) = &m.name {
                    msg_json["name"] = json!(name);
                }
                msg_json
            })
            .collect::<Vec<Value>>();

        let mut body = json!({
            "model": request.model,
            "messages": messages,
            "temperature": request.temperature.unwrap_or(1.0),
            "max_tokens": request.max_tokens.unwrap_or(1024),
            "top_p": request.top_p.unwrap_or(1.0),
            "stream": request.stream.unwrap_or(false),
        });

        if let Some(stop) = &request.stop {
            body["stop"] = json!(stop);
        }
        if let Some(seed) = &request.seed {
            body["seed"] = json!(seed);
        }

        let response = self
            .send_request(body, &format!("{}/chat/completions", self.endpoint))
            .await?;
        let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
        Ok(chat_completion_response)
    }

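    /// Parses the response body as JSON, mapping API error payloads to `GroqError::ApiError`.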
    async fn parse_response(&self, response: AResponse) -> Result<Value, GroqError> {
        let status = response.status();
        let body: Value = response.json().await?;

        if !status.is_success() {
            if let Some(error) = body.get("error") {
                return Err(GroqError::ApiError {
                    message: error["message"]
                        .as_str()
                        .unwrap_or("Unknown error")
                        .to_string(),
                    type_: error["type"]
                        .as_str()
                        .unwrap_or("unknown_error")
                        .to_string(),
                });
            }
        }

        Ok(body)
    }
}

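/// Blocking Groq API client built on `reqwest::blocking::Client`.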
pub struct GroqClient {
    api_key: String,
    client: Client,
    endpoint: String,
}

impl GroqClient {
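    /// Creates a new blocking client, falling back to `https://api.groq.com/openai/v1` when no endpoint is given.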
    pub fn new(api_key: String, endpoint: Option<String>) -> Self {
        let ep = endpoint.unwrap_or_else(|| String::from("https://api.groq.com/openai/v1"));
        Self {
            api_key,
            client: Client::new(),
            endpoint: ep,
        }
    }

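    /// Sends a JSON POST request with the bearer token and parses the response body.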
    fn send_request(&self, body: Value, link: &str) -> Result<Value, GroqError> {
        let res = self
            .client
            .post(link)
            .header("Content-Type", "application/json")
            .header("Authorization", &format!("Bearer {}", self.api_key))
            .json(&body)
            .send()?;

        parse_response(res)
    }

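    /// Transcribes (or translates, when `english_text` is set) the given audio bytes
    /// by uploading them as a multipart form.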
    pub fn speech_to_text(
        &self,
        request: SpeechToTextRequest,
    ) -> Result<SpeechToTextResponse, GroqError> {
        let file = request.file;
        let temperature = request.temperature;
        let language = request.language;
        let english_text = request.english_text;
        let model = request.model;
        let prompt = request.prompt;
        let mut form = Form::new().part("file", Part::bytes(file).file_name("audio.wav"));

        if let Some(temp) = temperature {
            form = form.text("temperature", temp.to_string());
        }

        if let Some(lang) = language {
            form = form.text("language", lang);
        }

        let link_addition = if english_text {
            "/audio/translations"
        } else {
            "/audio/transcriptions"
        };

        if let Some(mdl) = model {
            form = form.text("model", mdl);
        }
        if let Some(prompt) = prompt {
            form = form.text("prompt", prompt.to_string());
        }

        let link = format!("{}{}", self.endpoint, link_addition);
        let response = self
            .client
            .post(link)
            .header("Authorization", &format!("Bearer {}", self.api_key))
            .multipart(form)
            .send()?;

        let speech_to_text_response: SpeechToTextResponse = response.json()?;
        Ok(speech_to_text_response)
    }

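    /// Builds a chat completion request body and returns the parsed `ChatCompletionResponse`.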
    pub fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, GroqError> {
        let messages = request
            .messages
            .iter()
            .map(|m| {
                let mut msg_json = json!({
                    "role": m.role,
                    "content": m.content,
                });
                if let Some(name) = &m.name {
                    msg_json["name"] = json!(name);
                }
                msg_json
            })
            .collect::<Vec<_>>();

        let mut body = json!({
            "model": request.model,
            "messages": messages,
            "temperature": request.temperature.unwrap_or(1.0),
            "max_tokens": request.max_tokens.unwrap_or(1024),
            "top_p": request.top_p.unwrap_or(1.0),
            "stream": request.stream.unwrap_or(false),
        });

        if let Some(stop) = &request.stop {
            body["stop"] = json!(stop);
        }
        if let Some(seed) = &request.seed {
            body["seed"] = json!(seed);
        }

        let response = self.send_request(body, &format!("{}/chat/completions", self.endpoint))?;
        let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
        Ok(chat_completion_response)
    }
}

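/// Parses a blocking response body as JSON, mapping API error payloads to `GroqError::ApiError`.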
fn parse_response(response: Response) -> Result<Value, GroqError> {
    let status = response.status();
    let body: Value = response.json()?;

    if !status.is_success() {
        if let Some(error) = body.get("error") {
            return Err(GroqError::ApiError {
                message: error["message"]
                    .as_str()
                    .unwrap_or("Unknown error")
                    .to_string(),
                type_: error["type"]
                    .as_str()
                    .unwrap_or("unknown_error")
                    .to_string(),
            });
        }
    }

    Ok(body)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Read;
    use tokio;

    #[test]
    fn test_chat_completion() {
        let api_key = std::env::var("GROQ_API_KEY").unwrap();
        let client = GroqClient::new(api_key.to_string(), None);
        let messages = vec![ChatCompletionMessage {
            role: ChatCompletionRoles::User,
            content: "Hello".to_string(),
            name: None,
        }];
        let request = ChatCompletionRequest::new("llama3-70b-8192", messages);
        let response = client.chat_completion(request).unwrap();
        println!("{:?}", response);
        assert!(!response.choices.is_empty());
    }

    #[test]
    fn test_speech_to_text() {
        let api_key = std::env::var("GROQ_API_KEY").unwrap();
        let client = GroqClient::new(api_key.to_string(), None);
        let audio_file_path = "onepiece_demo.mp4";
        let mut file = File::open(audio_file_path).expect("Failed to open audio file");
        let mut audio_data = Vec::new();
        file.read_to_end(&mut audio_data)
            .expect("Failed to read audio file");
        let request = SpeechToTextRequest::new(audio_data)
            .temperature(0.7)
            .language("en")
            .model("whisper-large-v3");
        let response = client
            .speech_to_text(request)
            .expect("Failed to get response");
        println!("Speech to Text Response: {}", response.text);
        assert!(!response.text.is_empty());
    }

    #[tokio::test]
    async fn test_async_chat_completion() {
        let api_key = std::env::var("GROQ_API_KEY").unwrap();
        let client = AsyncGroqClient::new(api_key, None).await;

        let messages1 = vec![ChatCompletionMessage {
            role: ChatCompletionRoles::User,
            content: "Hello".to_string(),
            name: None,
        }];
        let request1 = ChatCompletionRequest::new("llama3-70b-8192", messages1);

        let messages2 = vec![ChatCompletionMessage {
            role: ChatCompletionRoles::User,
            content: "How are you?".to_string(),
            name: None,
        }];
        let request2 = ChatCompletionRequest::new("llama3-70b-8192", messages2);

        let (response1, response2) = tokio::join!(
            client.chat_completion(request1),
            client.chat_completion(request2)
        );

        let response1 = response1.expect("Failed to get response for request 1");
        let response2 = response2.expect("Failed to get response for request 2");

        println!("Response 1: {}", response1.choices[0].message.content);
        println!("Response 2: {}", response2.choices[0].message.content);

        assert!(!response1.choices.is_empty());
        assert!(!response2.choices.is_empty());
    }

    #[tokio::test]
    async fn test_async_speech_to_text() {
        let api_key = std::env::var("GROQ_API_KEY").unwrap();
        let client = AsyncGroqClient::new(api_key, None).await;

        let audio_file_path1 = "onepiece_demo.mp4";
        let audio_file_path2 = "save.ogg";

        let (audio_data1, audio_data2) = tokio::join!(
            tokio::fs::read(audio_file_path1),
            tokio::fs::read(audio_file_path2)
        );

        let audio_data1 = audio_data1.expect("Failed to read first audio file");
        let audio_data2 = audio_data2.expect("Failed to read second audio file");

        let (request1, request2) = (
            SpeechToTextRequest::new(audio_data1)
                .temperature(0.7)
                .language("en")
                .model("whisper-large-v3"),
            SpeechToTextRequest::new(audio_data2)
                .temperature(0.7)
                .language("en")
                .model("whisper-large-v3"),
        );
        let (response1, response2) = tokio::join!(
            client.speech_to_text(request1),
            client.speech_to_text(request2)
        );

        let response1 = response1.expect("Failed to get response for first audio");
        let response2 = response2.expect("Failed to get response for second audio");

        println!("Speech to Text Response 1: {:?}", response1);
        println!("Speech to Text Response 2: {:?}", response2);

        assert!(!response1.text.is_empty());
        assert!(!response2.text.is_empty());
    }
}