mod message;
use futures::StreamExt;
pub use message::*;
use reqwest::{
Client as AClient, Response as AResponse,
blocking::multipart::{Form, Part},
blocking::{Client, Response},
multipart::{Form as AForm, Part as APart},
};
use serde_json::{Deserializer, StreamDeserializer, Value, json};
use std::sync::Arc;
/// Asynchronous client for the Groq HTTP API.
///
/// Requests are issued through the non-blocking `reqwest` client and carry the
/// API key as a `Bearer` token in the `Authorization` header.
pub struct AsyncGroqClient {
/// Secret key sent as `Authorization: Bearer <api_key>` on every request.
api_key: String,
/// Non-blocking `reqwest` client used for all requests, held behind an `Arc`.
client: Arc<AClient>,
/// Base URL to which paths such as `/chat/completions` are appended.
endpoint: String,
}
impl AsyncGroqClient {
/// Builds an asynchronous client.
///
/// `endpoint` overrides the API base URL. When it is `None`, the client
/// targets `https://api.groq.com/openai/v1`. The function is declared
/// `async` but awaits nothing.
pub async fn new(api_key: String, endpoint: Option<String>) -> Self {
    let endpoint = match endpoint {
        Some(custom) => custom,
        None => "https://api.groq.com/openai/v1".to_owned(),
    };
    let client = Arc::new(AClient::new());
    Self {
        api_key,
        client,
        endpoint,
    }
}
/// POSTs `body` as JSON to `link` with the bearer token attached.
///
/// Returns the raw response; the HTTP status is not inspected here, so
/// callers decide how to treat non-success answers.
///
/// # Errors
/// Returns a [`GroqError`] converted from the `reqwest` error when the
/// request cannot be sent.
async fn send_request(&self, body: Value, link: &str) -> Result<reqwest::Response, GroqError> {
    let bearer = format!("Bearer {}", self.api_key);
    let pending = self
        .client
        .post(link)
        .header("Content-Type", "application/json")
        .header("Authorization", &bearer)
        .json(&body);
    let response = pending.send().await?;
    Ok(response)
}
pub async fn speech_to_text(
&self,
request: SpeechToTextRequest,
) -> Result<SpeechToTextResponse, GroqError> {
let file = request.file;
let temperature = request.temperature;
let language = request.language;
let english_text = request.english_text;
let model = request.model;
let mut form = AForm::new().part("file", APart::bytes(file).file_name("audio.wav"));
if let Some(temp) = temperature {
form = form.text("temperature", temp.to_string());
}
if let Some(lang) = language {
form = form.text("language", lang);
}
let link_addition = if english_text {
"/audio/translations"
} else {
"/audio/transcriptions"
};
if let Some(mdl) = model {
form = form.text("model", mdl);
}
let link = format!("{}{}", self.endpoint, link_addition);
let response = self
.client
.post(&link)
.header("Authorization", &format!("Bearer {}", self.api_key))
.multipart(form)
.send()
.await?;
let speech_to_text_response: SpeechToTextResponse = response.json().await?;
Ok(speech_to_text_response)
}
/// Serialises `request` into a chat-completions JSON body and POSTs it.
///
/// Unset sampling fields fall back to `temperature = 1.0`,
/// `max_tokens = 1024` and `top_p = 1.0`; `stream` is used only when the
/// request itself does not specify a streaming flag. `stop` and `seed` are
/// included only when present. The raw response is returned without
/// inspecting its status.
async fn send_response(
    &self,
    request: ChatCompletionRequest,
    stream: bool,
) -> Result<reqwest::Response, GroqError> {
    let mut messages: Vec<Value> = Vec::new();
    for message in request.messages.iter() {
        let mut entry = json!({
            "role": message.role,
            "content": message.content,
        });
        if let Some(name) = &message.name {
            entry["name"] = json!(name);
        }
        messages.push(entry);
    }

    let temperature = request.temperature.unwrap_or(1.0);
    let max_tokens = request.max_tokens.unwrap_or(1024);
    let top_p = request.top_p.unwrap_or(1.0);
    let streaming = request.stream.unwrap_or(stream);
    let mut body = json!({
        "model": request.model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
        "stream": streaming,
    });
    if let Some(stop) = &request.stop {
        body["stop"] = json!(stop);
    }
    if let Some(seed) = &request.seed {
        body["seed"] = json!(seed);
    }

    let url = format!("{}/chat/completions", self.endpoint);
    self.send_request(body, &url).await
}
/// Performs a non-streaming chat completion.
///
/// # Errors
/// Returns `GroqError::InvalidRequest` when `request.stream` is
/// `Some(true)`, and otherwise any [`GroqError`] produced while sending
/// the request, interpreting an API error payload, or decoding the body
/// into a [`ChatCompletionResponse`].
pub async fn chat_completion(
    &self,
    request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, GroqError> {
    if matches!(request.stream, Some(true)) {
        let reason =
            String::from("Stream parameter must be set to false for non-streaming responses.");
        return Err(GroqError::InvalidRequest(reason));
    }
    let raw = self.send_response(request, false).await?;
    let body = self.parse_response(raw).await?;
    Ok(serde_json::from_value::<ChatCompletionResponse>(body)?)
}
/// Performs a streaming chat completion and yields each decoded delta.
///
/// The response body is read chunk by chunk. Each event is expected to be
/// a JSON document, optionally preceded by the `data: ` prefix; the
/// literal `[DONE]` ends the stream. Data that never decodes into a
/// [`ChatCompletionDeltaResponse`] (for example an API error body) is
/// reported once, when the body ends, as
/// `GroqError::DeserializationError` carrying the undecoded text, after
/// which the stream terminates.
///
/// # Errors
/// Returns `GroqError::InvalidRequest` when `request.stream` is
/// `Some(false)`, or a [`GroqError`] when the request cannot be sent.
/// Transport errors while reading the body are yielded as stream items.
pub async fn stream(
    &self,
    request: ChatCompletionRequest,
) -> Result<
    impl futures::Stream<Item = Result<ChatCompletionDeltaResponse, GroqError>>,
    GroqError,
> {
    const EVENT_PREFIX: &[u8] = b"data: ";
    const DONE_MARKER: &[u8] = b"[DONE]";

    if Some(false) == request.stream {
        return Err(GroqError::InvalidRequest(
            "Stream parameter must be set to true for streaming responses.".to_string(),
        ));
    }
    let response = self.send_response(request, true).await?;
    // Fused so polling again after the body has ended is always safe.
    let byte_stream = response.bytes_stream().fuse();
    Ok(futures::stream::unfold(
        // Raw bytes are buffered (not a lossily decoded, trimmed string) so a
        // chunk boundary inside a multi-byte character, inside a JSON string,
        // or inside the `data: ` prefix cannot corrupt the payload.
        (byte_stream, Vec::<u8>::new()),
        move |(mut byte_stream, mut buffer)| async move {
            loop {
                // Drop whitespace separating events, then one event prefix.
                let leading = buffer
                    .iter()
                    .take_while(|byte| byte.is_ascii_whitespace())
                    .count();
                buffer.drain(..leading);
                if buffer.starts_with(EVENT_PREFIX) {
                    buffer.drain(..EVENT_PREFIX.len());
                }

                // Try to decode one complete JSON value from the front of the
                // buffer; the deserializer is scoped so its borrow ends first.
                let decoded = {
                    let mut values: StreamDeserializer<_, ChatCompletionDeltaResponse> =
                        Deserializer::from_slice(&buffer).into_iter();
                    match values.next() {
                        Some(Ok(delta)) => Some((delta, values.byte_offset())),
                        _ => None,
                    }
                };
                if let Some((delta, consumed)) = decoded {
                    buffer.drain(..consumed);
                    return Some((Ok(delta), (byte_stream, buffer)));
                }
                if buffer.trim_ascii() == DONE_MARKER {
                    return None;
                }

                // Nothing decodable yet: pull more bytes or finish.
                match byte_stream.next().await {
                    Some(Ok(chunk)) => buffer.extend_from_slice(&chunk),
                    Some(Err(e)) => {
                        return Some((Err(GroqError::from(e)), (byte_stream, buffer)));
                    }
                    None if buffer.is_empty() => return None,
                    None => {
                        let message = String::from_utf8_lossy(&buffer).trim().to_string();
                        // Hand back an empty buffer so the next poll ends the
                        // stream instead of repeating this error forever.
                        return Some((
                            Err(GroqError::DeserializationError {
                                message,
                                type_: "DeserializationError".to_string(),
                            }),
                            (byte_stream, Vec::new()),
                        ));
                    }
                }
            }
        },
    ))
}
async fn parse_response(&self, response: AResponse) -> Result<Value, GroqError> {
let status = response.status();
let body: Value = response.json().await?;
if !status.is_success()
&& let Some(error) = body.get("error")
{
return Err(GroqError::ApiError {
message: error["message"]
.as_str()
.unwrap_or("Unknown error")
.to_string(),
type_: error["type"]
.as_str()
.unwrap_or("unknown_error")
.to_string(),
});
}
Ok(body)
}
}
/// Blocking client for the Groq HTTP API.
///
/// Requests are issued through `reqwest`'s blocking client and carry the API
/// key as a `Bearer` token in the `Authorization` header.
pub struct GroqClient {
/// Secret key sent as `Authorization: Bearer <api_key>` on every request.
api_key: String,
/// Blocking `reqwest` client used for all requests.
client: Client,
/// Base URL to which paths such as `/chat/completions` are appended.
endpoint: String,
}
impl GroqClient {
/// Builds a blocking client.
///
/// `endpoint` overrides the API base URL. When it is `None`, the client
/// targets `https://api.groq.com/openai/v1`.
pub fn new(api_key: String, endpoint: Option<String>) -> Self {
    let endpoint = match endpoint {
        Some(custom) => custom,
        None => "https://api.groq.com/openai/v1".to_owned(),
    };
    let client = Client::new();
    Self {
        api_key,
        client,
        endpoint,
    }
}
/// POSTs `body` as JSON to `link` with the bearer token attached and
/// returns the decoded JSON body.
///
/// # Errors
/// Returns a [`GroqError`] when the request cannot be sent, or whatever
/// error `parse_response` produces for the received response.
fn send_request(&self, body: Value, link: &str) -> Result<Value, GroqError> {
    let bearer = format!("Bearer {}", self.api_key);
    let pending = self
        .client
        .post(link)
        .header("Content-Type", "application/json")
        .header("Authorization", &bearer)
        .json(&body);
    let response = pending.send()?;
    parse_response(response)
}
pub fn speech_to_text(
&self,
request: SpeechToTextRequest,
) -> Result<SpeechToTextResponse, GroqError> {
let file = request.file;
let temperature = request.temperature;
let language = request.language;
let english_text = request.english_text;
let model = request.model;
let prompt = request.prompt;
let mut form = Form::new().part("file", Part::bytes(file).file_name("audio.wav"));
if let Some(temp) = temperature {
form = form.text("temperature", temp.to_string());
}
if let Some(lang) = language {
form = form.text("language", lang);
}
let link_addition = if english_text {
"/audio/translations"
} else {
"/audio/transcriptions"
};
if let Some(mdl) = model {
form = form.text("model", mdl);
}
if let Some(prompt) = prompt {
form = form.text("prompt", prompt.to_string());
}
let link = format!("{}{}", self.endpoint, link_addition);
let response = self
.client
.post(link)
.header("Authorization", &format!("Bearer {}", self.api_key))
.multipart(form)
.send()?;
let speech_to_text_response: SpeechToTextResponse = response.json()?;
Ok(speech_to_text_response)
}
pub fn chat_completion(
&self,
request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, GroqError> {
let messages = request
.messages
.iter()
.map(|m| {
let mut msg_json = json!({
"role": m.role,
"content": m.content,
});
if let Some(name) = &m.name {
msg_json["name"] = json!(name);
}
msg_json
})
.collect::<Vec<_>>();
let mut body = json!({
"model": request.model,
"messages": messages,
"temperature": request.temperature.unwrap_or(1.0),
"max_tokens": request.max_tokens.unwrap_or(1024),
"top_p": request.top_p.unwrap_or(1.0),
"stream": request.stream.unwrap_or(false),
});
if let Some(stop) = &request.stop {
body["stop"] = json!(stop);
}
if let Some(seed) = &request.seed {
body["seed"] = json!(seed);
}
let response = self.send_request(body, &format!("{}/chat/completions", self.endpoint))?;
let chat_completion_response: ChatCompletionResponse = serde_json::from_value(response)?;
Ok(chat_completion_response)
}
}
/// Decodes a blocking response body as JSON and surfaces API errors.
///
/// When the status is not a success and the body contains an `error` object,
/// its `message` and `type` strings (defaulting to `"Unknown error"` /
/// `"unknown_error"`) are returned as `GroqError::ApiError`. In every other
/// case the decoded body is returned as-is.
///
/// # Errors
/// Also fails with a [`GroqError`] when the body is not valid JSON.
fn parse_response(response: Response) -> Result<Value, GroqError> {
    let status = response.status();
    let body: Value = response.json()?;
    let failure = (!status.is_success())
        .then(|| body.get("error"))
        .flatten()
        .map(|detail| GroqError::ApiError {
            message: detail["message"]
                .as_str()
                .unwrap_or("Unknown error")
                .to_string(),
            type_: detail["type"]
                .as_str()
                .unwrap_or("unknown_error")
                .to_string(),
        });
    match failure {
        Some(err) => Err(err),
        None => Ok(body),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs::File;
use std::io::Read;
use tokio;
/// Live test: a blocking chat completion returns at least one choice.
///
/// Reads the API key from the `GROQ_API_KEY` environment variable and
/// calls the default endpoint.
#[test]
fn test_chat_completion() {
    let api_key = std::env::var("GROQ_API_KEY").unwrap();
    let client = GroqClient::new(api_key, None);
    let messages = vec![ChatCompletionMessage {
        role: ChatCompletionRoles::User,
        content: "Hello".to_string(),
        name: None,
    }];
    // `llama3-70b-8192` is reported as decommissioned by the API (see the
    // expected error in `test_async_stream_fail`), so use the model the
    // other success-path tests use.
    let request = ChatCompletionRequest::new("llama-3.3-70b-versatile", messages);
    let response = client.chat_completion(request).unwrap();
    println!("{:?}", response);
    assert!(!response.choices.is_empty());
}
/// Live test: blocking transcription of `onepiece_demo.mp4` yields text.
///
/// Reads the API key from `GROQ_API_KEY` and the audio from the working
/// directory.
#[test]
fn test_speech_to_text() {
    let key = std::env::var("GROQ_API_KEY").unwrap();
    let client = GroqClient::new(key, None);

    let mut audio_data = Vec::new();
    let mut source = File::open("onepiece_demo.mp4").expect("Failed to open audio file");
    source
        .read_to_end(&mut audio_data)
        .expect("Failed to read audio file");

    let request = SpeechToTextRequest::new(audio_data)
        .temperature(0.7)
        .language("en")
        .model("whisper-large-v3");
    let transcript = client
        .speech_to_text(request)
        .expect("Failed to get response");
    println!("Speech to Text Response: {}", transcript.text);
    assert!(!transcript.text.is_empty());
}
/// Live test: two chat completions issued concurrently both return at
/// least one choice.
///
/// Reads the API key from the `GROQ_API_KEY` environment variable.
#[tokio::test]
async fn test_async_chat_completion() {
    let api_key = std::env::var("GROQ_API_KEY").unwrap();
    let client = AsyncGroqClient::new(api_key, None).await;

    // Builds a single-message user request for the shared model.
    let ask = |text: &str| {
        let messages = vec![ChatCompletionMessage {
            role: ChatCompletionRoles::User,
            content: text.to_string(),
            name: None,
        }];
        ChatCompletionRequest::new("llama-3.3-70b-versatile", messages)
    };
    let first_request = ask("Hello");
    let second_request = ask("How are you?");

    let (first, second) = tokio::join!(
        client.chat_completion(first_request),
        client.chat_completion(second_request)
    );
    let first = first.expect("Failed to get response for request 1");
    let second = second.expect("Failed to get response for request 2");
    println!("Response 1: {}", first.choices[0].message.content);
    println!("Response 2: {}", second.choices[0].message.content);
    assert!(!first.choices.is_empty());
    assert!(!second.choices.is_empty());
}
#[tokio::test]
async fn test_async_stream() {
let api_key = std::env::var("GROQ_API_KEY").unwrap();
let client = AsyncGroqClient::new(api_key, None).await;
let messages1 = vec![ChatCompletionMessage {
role: ChatCompletionRoles::User,
content: "Hello!".to_string(),
name: None,
}];
let request1 =
ChatCompletionRequest::new("llama-3.3-70b-versatile", messages1).stream(true);
let messages2 = vec![ChatCompletionMessage {
role: ChatCompletionRoles::User,
content: "How are you?".to_string(),
name: None,
}];
let request2 =
ChatCompletionRequest::new("llama-3.3-70b-versatile", messages2).stream(true);
let (stream1, stream2) = tokio::join!(client.stream(request1), client.stream(request2));
let stream1 = stream1.expect("Failed to get response for request 1");
let stream2 = stream2.expect("Failed to get response for request 2");
let mut response1 = String::new();
let mut response2 = String::new();
tokio::pin!(stream1);
tokio::pin!(stream2);
while let Some(item) = stream1.next().await {
let delta = item.expect("Failed to get delta from stream 1");
if let Some(content) = &delta.choices[0].delta.content {
response1.push_str(&content);
}
}
println!();
while let Some(item) = stream2.next().await {
let delta = item.expect("Failed to get delta from stream 2");
if let Some(content) = &delta.choices[0].delta.content {
response2.push_str(&content);
}
}
println!();
println!("Response 1: {}", response1);
println!("Response 2: {}", response2);
assert!(!response1.is_empty());
assert!(!response2.is_empty());
}
/// Live test: streaming with a decommissioned model yields the API's error
/// body as the first (and only inspected) stream item.
///
/// Reads the API key from the `GROQ_API_KEY` environment variable.
#[tokio::test]
async fn test_async_stream_fail() {
    let api_key = std::env::var("GROQ_API_KEY").unwrap();
    let client = AsyncGroqClient::new(api_key, None).await;
    let messages1 = vec![ChatCompletionMessage {
        role: ChatCompletionRoles::User,
        content: "Hello!".to_string(),
        name: None,
    }];
    let request = ChatCompletionRequest::new("llama3-70b-8192", messages1).stream(true);
    let stream = client
        .stream(request)
        .await
        .expect("Failed to get response");
    tokio::pin!(stream);

    // Require a first item: a stream that ends without yielding anything
    // must fail the test rather than pass it vacuously.
    let first = stream
        .next()
        .await
        .expect("Expected an error but the stream ended without yielding an item");
    match first {
        Err(e) => {
            let expected_message = r#"Deserialization error: {"error":{"message":"The model `llama3-70b-8192` has been decommissioned and is no longer supported. Please refer to https://console.groq.com/docs/deprecations for a recommendation on which model to use instead.","type":"invalid_request_error","code":"model_decommissioned"}}"#;
            assert_eq!(e.to_string(), expected_message);
        }
        Ok(_) => panic!("Expected an error but got a successful response"),
    }
}
/// Live test: two audio files transcribed concurrently both yield text.
///
/// Reads the API key from `GROQ_API_KEY` and the files `onepiece_demo.mp4`
/// and `save.ogg` from the working directory.
#[tokio::test]
async fn test_async_speech_to_text() {
    let api_key = std::env::var("GROQ_API_KEY").unwrap();
    let client = AsyncGroqClient::new(api_key, None).await;

    let (first_audio, second_audio) = tokio::join!(
        tokio::fs::read("onepiece_demo.mp4"),
        tokio::fs::read("save.ogg")
    );
    let first_audio = first_audio.expect("Failed to read first audio file");
    let second_audio = second_audio.expect("Failed to read second audio file");

    // Both uploads share the same transcription settings.
    let transcription_request = |audio: Vec<u8>| {
        SpeechToTextRequest::new(audio)
            .temperature(0.7)
            .language("en")
            .model("whisper-large-v3")
    };
    let first_request = transcription_request(first_audio);
    let second_request = transcription_request(second_audio);

    let (first, second) = tokio::join!(
        client.speech_to_text(first_request),
        client.speech_to_text(second_request)
    );
    let first = first.expect("Failed to get response for first audio");
    let second = second.expect("Failed to get response for second audio");
    println!("Speech to Text Response 1: {:?}", first);
    println!("Speech to Text Response 2: {:?}", second);
    assert!(!first.text.is_empty());
    assert!(!second.text.is_empty());
}
}