rig/providers/
groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY");
//!
//! let llama = client.completion_model(groq::LLAMA_3_70B_8192);
//! ```
11use std::collections::HashMap;
12
13use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
14use crate::client::{
15    ClientBuilderError, CompletionClient, TranscriptionClient, VerifyClient, VerifyError,
16};
17use crate::completion::GetTokenUsage;
18use crate::json_utils::merge;
19use futures::StreamExt;
20
21use crate::streaming::RawStreamingChoice;
22use crate::{
23    OneOrMany,
24    completion::{self, CompletionError, CompletionRequest},
25    json_utils,
26    message::{self, MessageError},
27    providers::openai::ToolDefinition,
28    transcription::{self, TranscriptionError},
29};
30use reqwest::RequestBuilder;
31use reqwest::multipart::Part;
32use rig::client::ProviderClient;
33use rig::impl_conversion_traits;
34use serde::{Deserialize, Serialize};
35use serde_json::{Value, json};
36
// ================================================================
// Main Groq Client
// ================================================================
/// Base URL of Groq's OpenAI-compatible REST API.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
41
/// Builder for [`Client`]. Borrows its configuration until `build` is called.
pub struct ClientBuilder<'a> {
    api_key: &'a str,
    // Defaults to GROQ_API_BASE_URL; overridable via `base_url` (proxies/tests).
    base_url: &'a str,
    // If None, `build` constructs a default reqwest client.
    http_client: Option<reqwest::Client>,
}
47
48impl<'a> ClientBuilder<'a> {
49    pub fn new(api_key: &'a str) -> Self {
50        Self {
51            api_key,
52            base_url: GROQ_API_BASE_URL,
53            http_client: None,
54        }
55    }
56
57    pub fn base_url(mut self, base_url: &'a str) -> Self {
58        self.base_url = base_url;
59        self
60    }
61
62    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
63        self.http_client = Some(client);
64        self
65    }
66
67    pub fn build(self) -> Result<Client, ClientBuilderError> {
68        let http_client = if let Some(http_client) = self.http_client {
69            http_client
70        } else {
71            reqwest::Client::builder().build()?
72        };
73
74        Ok(Client {
75            base_url: self.base_url.to_string(),
76            api_key: self.api_key.to_string(),
77            http_client,
78        })
79    }
80}
81
/// Groq API client. Construct via [`Client::new`] or [`Client::builder`].
#[derive(Clone)]
pub struct Client {
    base_url: String,
    api_key: String,
    http_client: reqwest::Client,
}
88
impl std::fmt::Debug for Client {
    // Manual impl so the API key is never printed in debug/log output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}
98
99impl Client {
100    /// Create a new Groq client builder.
101    ///
102    /// # Example
103    /// ```
104    /// use rig::providers::groq::{ClientBuilder, self};
105    ///
106    /// // Initialize the Groq client
107    /// let groq = Client::builder("your-groq-api-key")
108    ///    .build()
109    /// ```
110    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
111        ClientBuilder::new(api_key)
112    }
113
114    /// Create a new Groq client with the given API key.
115    ///
116    /// # Panics
117    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
118    pub fn new(api_key: &str) -> Self {
119        Self::builder(api_key)
120            .build()
121            .expect("Groq client should build")
122    }
123
124    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
125        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
126        self.http_client.post(url).bearer_auth(&self.api_key)
127    }
128
129    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
130        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
131        self.http_client.get(url).bearer_auth(&self.api_key)
132    }
133}
134
impl ProviderClient for Client {
    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
    /// Panics if the environment variable is not set.
    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        Self::new(&api_key)
    }

    /// Create a client from a dynamic provider value.
    ///
    /// # Panics
    /// Panics if `input` is not the `Simple` (API-key) variant.
    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::Simple(api_key) = input else {
            panic!("Incorrect provider value type")
        };
        Self::new(&api_key)
    }
}
150
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::{Client, self};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let llama = groq.completion_model(groq::LLAMA_3_70B_8192);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
169
impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;

    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::{Client, self};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let whisper = groq.transcription_model(groq::WHISPER_LARGE_V3);
    /// ```
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
188
189impl VerifyClient for Client {
190    #[cfg_attr(feature = "worker", worker::send)]
191    async fn verify(&self) -> Result<(), VerifyError> {
192        let response = self.get("/models").send().await?;
193        match response.status() {
194            reqwest::StatusCode::OK => Ok(()),
195            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
196            reqwest::StatusCode::INTERNAL_SERVER_ERROR
197            | reqwest::StatusCode::SERVICE_UNAVAILABLE
198            | reqwest::StatusCode::BAD_GATEWAY => {
199                Err(VerifyError::ProviderError(response.text().await?))
200            }
201            _ => {
202                response.error_for_status()?;
203                Ok(())
204            }
205        }
206    }
207}
208
// NOTE(review): presumably this macro marks Client as NOT supporting
// embeddings, image generation, or audio generation — confirm against the
// macro definition in `rig::impl_conversion_traits`.
impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client
);
214
/// Error payload the Groq API may return in a response body.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
219
/// Either a successful payload or an API-level error. Untagged: serde tries
/// `Ok(T)` first and falls back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
226
/// Wire-format chat message for the Groq API.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    // "system", "user" or "assistant" (see the TryFrom impls below).
    pub role: String,
    // Message text; `None` serializes as JSON null.
    pub content: Option<String>,
    // Groq's parsed reasoning text; set only on assistant messages here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}
234
235impl TryFrom<Message> for message::Message {
236    type Error = message::MessageError;
237
238    fn try_from(message: Message) -> Result<Self, Self::Error> {
239        match message.role.as_str() {
240            "user" => Ok(Self::User {
241                content: OneOrMany::one(
242                    message
243                        .content
244                        .map(|content| message::UserContent::text(&content))
245                        .ok_or_else(|| {
246                            message::MessageError::ConversionError("Empty user message".to_string())
247                        })?,
248                ),
249            }),
250            "assistant" => Ok(Self::Assistant {
251                id: None,
252                content: OneOrMany::one(
253                    message
254                        .content
255                        .map(|content| message::AssistantContent::text(&content))
256                        .ok_or_else(|| {
257                            message::MessageError::ConversionError(
258                                "Empty assistant message".to_string(),
259                            )
260                        })?,
261                ),
262            }),
263            _ => Err(message::MessageError::ConversionError(format!(
264                "Unknown role: {}",
265                message.role
266            ))),
267        }
268    }
269}
270
271impl TryFrom<message::Message> for Message {
272    type Error = message::MessageError;
273
274    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
275        match message {
276            message::Message::User { content } => Ok(Self {
277                role: "user".to_string(),
278                content: content.iter().find_map(|c| match c {
279                    message::UserContent::Text(text) => Some(text.text.clone()),
280                    _ => None,
281                }),
282                reasoning: None,
283            }),
284            message::Message::Assistant { content, .. } => {
285                let mut text_content: Option<String> = None;
286                let mut groq_reasoning: Option<String> = None;
287
288                for c in content.iter() {
289                    match c {
290                        message::AssistantContent::Text(text) => {
291                            text_content = Some(
292                                text_content
293                                    .map(|mut existing| {
294                                        existing.push('\n');
295                                        existing.push_str(&text.text);
296                                        existing
297                                    })
298                                    .unwrap_or_else(|| text.text.clone()),
299                            );
300                        }
301                        message::AssistantContent::ToolCall(_tool_call) => {
302                            return Err(MessageError::ConversionError(
303                                "Tool calls do not exist on this message".into(),
304                            ));
305                        }
306                        message::AssistantContent::Reasoning(message::Reasoning {
307                            reasoning,
308                            ..
309                        }) => {
310                            groq_reasoning =
311                                Some(reasoning.first().cloned().unwrap_or(String::new()));
312                        }
313                    }
314                }
315
316                Ok(Self {
317                    role: "assistant".to_string(),
318                    content: text_content,
319                    reasoning: groq_reasoning,
320                })
321            }
322        }
323    }
324}
325
// ================================================================
// Groq Completion API
// ================================================================
// Model-ID strings accepted by `Client::completion_model`.
// NOTE(review): Groq retires preview models over time — verify these IDs
// against the current Groq model list before relying on them.
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
355
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}
362
impl CompletionModel {
    /// Bind a completion model name to the given client.
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

    /// Assemble the JSON body for a chat-completions request.
    ///
    /// Message order is: system preamble (if any), normalized documents
    /// (if any), then the chat history. When tools are present they are
    /// attached with `tool_choice: auto` and `reasoning_format: parsed`.
    /// `additional_params` is merged last, so it can override any field above.
    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // Initialize full history with preamble (or empty if non-existent)
        let mut full_history: Vec<Message> =
            completion_request
                .preamble
                .map_or_else(Vec::new, |preamble| {
                    vec![Message {
                        role: "system".to_string(),
                        content: Some(preamble),
                        reasoning: None,
                    }]
                });

        // Convert and extend the rest of the history; the first conversion
        // failure aborts the whole request.
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Message>, _>>()?,
        );

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": "auto",
                // Ask Groq to return reasoning in a separate parsed field.
                "reasoning_format": "parsed"
            })
        };

        // Caller-supplied params win over the defaults above.
        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}
428
impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    /// Send a (non-streaming) chat-completion request to Groq.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post("/chat/completions")
            .json(&request)
            .send()
            .await?;

        if response.status().is_success() {
            // A 2xx body can still carry an API-level error payload, hence
            // the untagged ApiResponse enum.
            match response.json::<ApiResponse<CompletionResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "groq completion token usage: {:?}",
                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                    );
                    response.try_into()
                }
                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
            }
        } else {
            Err(CompletionError::ProviderError(response.text().await?))
        }
    }

    /// Send a streaming chat-completion request (server-sent events).
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let mut request = self.create_completion_request(request)?;

        // Enable SSE streaming and ask for a usage block in the final chunk.
        request = merge(
            request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let builder = self.client.post("/chat/completions").json(&request);

        send_compatible_streaming_request(builder).await
    }
}
483
// ================================================================
// Groq Transcription API
// ================================================================
/// The `whisper-large-v3` transcription model.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` transcription model.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` transcription model (the `-en` suffix
/// suggests English-only — verify against Groq's model list).
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
490
/// A Groq transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}
497
498impl TranscriptionModel {
499    pub fn new(client: Client, model: &str) -> Self {
500        Self {
501            client,
502            model: model.to_string(),
503        }
504    }
505}
506impl transcription::TranscriptionModel for TranscriptionModel {
507    type Response = TranscriptionResponse;
508
509    #[cfg_attr(feature = "worker", worker::send)]
510    async fn transcription(
511        &self,
512        request: transcription::TranscriptionRequest,
513    ) -> Result<
514        transcription::TranscriptionResponse<Self::Response>,
515        transcription::TranscriptionError,
516    > {
517        let data = request.data;
518
519        let mut body = reqwest::multipart::Form::new()
520            .text("model", self.model.clone())
521            .text("language", request.language)
522            .part(
523                "file",
524                Part::bytes(data).file_name(request.filename.clone()),
525            );
526
527        if let Some(prompt) = request.prompt {
528            body = body.text("prompt", prompt.clone());
529        }
530
531        if let Some(ref temperature) = request.temperature {
532            body = body.text("temperature", temperature.to_string());
533        }
534
535        if let Some(ref additional_params) = request.additional_params {
536            for (key, value) in additional_params
537                .as_object()
538                .expect("Additional Parameters to OpenAI Transcription should be a map")
539            {
540                body = body.text(key.to_owned(), value.to_string());
541            }
542        }
543
544        let response = self
545            .client
546            .post("audio/transcriptions")
547            .multipart(body)
548            .send()
549            .await?;
550
551        if response.status().is_success() {
552            match response
553                .json::<ApiResponse<TranscriptionResponse>>()
554                .await?
555            {
556                ApiResponse::Ok(response) => response.try_into(),
557                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
558                    api_error_response.message,
559                )),
560            }
561        } else {
562            Err(TranscriptionError::ProviderError(response.text().await?))
563        }
564    }
565}
566
/// A streamed message delta. Untagged: serde tries `Reasoning` first, so a
/// chunk containing a `reasoning` field is classified as reasoning; anything
/// else falls through to `MessageContent`.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // A JSON `null` here deserializes to an empty Vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}
580
/// One streamed choice; only its delta is consumed by the stream loop.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
585
/// A single SSE chunk from the chat-completions stream.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    // Typically present only on the final chunk when `include_usage` is set.
    usage: Option<Usage>,
}
591
/// Final (non-delta) payload attached to the end of a Groq stream.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Token usage reported by the last stream chunk that carried one.
    pub usage: Usage,
}
596
597impl GetTokenUsage for StreamingCompletionResponse {
598    fn token_usage(&self) -> Option<crate::completion::Usage> {
599        let mut usage = crate::completion::Usage::new();
600
601        usage.input_tokens = self.usage.prompt_tokens as u64;
602        usage.total_tokens = self.usage.total_tokens as u64;
603        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
604
605        Some(usage)
606    }
607}
608
/// Drive an OpenAI-compatible SSE stream and adapt it to Rig's streaming API.
///
/// Yields `Reasoning`, `Message` and `ToolCall` choices as deltas arrive,
/// accumulates tool-call arguments that are split across chunks, and ends
/// with a `FinalResponse` carrying the last usage block seen.
///
/// # Errors
/// Returns `CompletionError::ProviderError` if the initial response is not a
/// success status; transport and UTF-8 errors surface as stream items.
pub async fn send_compatible_streaming_request(
    request_builder: RequestBuilder,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
> {
    let response = request_builder.send().await?;

    if !response.status().is_success() {
        return Err(CompletionError::ProviderError(format!(
            "{}: {}",
            response.status(),
            response.text().await?
        )));
    }

    // Handle OpenAI Compatible SSE chunks
    let inner = Box::pin(async_stream::stream! {
        let mut stream = response.bytes_stream();

        // Last usage block seen wins; stays zeroed if none is ever sent.
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0
        };

        // Carry-over for an SSE data line split across byte chunks.
        let mut partial_data = None;
        // In-flight tool calls keyed by index: (id, name, accumulated args).
        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = match chunk_result {
                Ok(c) => c,
                Err(e) => {
                    // Transport error mid-stream: surface it and stop.
                    yield Err(CompletionError::from(e));
                    break;
                }
            };

            let text = match String::from_utf8(chunk.to_vec()) {
                Ok(t) => t,
                Err(e) => {
                    yield Err(CompletionError::ResponseError(e.to_string()));
                    break;
                }
            };


            for line in text.lines() {
                let mut line = line.to_string();

                // If there was a remaining part, concat with current line
                if partial_data.is_some() {
                    line = format!("{}{}", partial_data.unwrap(), line);
                    partial_data = None;
                }
                // Otherwise full data line
                else {
                    // Non-SSE-data lines (comments, blank keep-alives) are skipped.
                    let Some(data) = line.strip_prefix("data:") else {
                        continue;
                    };

                    let data = data.trim_start();

                    // Partial data, split somewhere in the middle
                    if !line.ends_with("}") {
                        partial_data = Some(data.to_string());
                    } else {
                        line = data.to_string();
                    }
                }

                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);

                // Unparsable payloads (e.g. "[DONE]" or still-partial JSON)
                // are logged and skipped rather than aborting the stream.
                let Ok(data) = data else {
                    let err = data.unwrap_err();
                    tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
                    continue;
                };


                if let Some(choice) = data.choices.first() {
                    let delta = &choice.delta;

                    match delta {
                        StreamingDelta::Reasoning { reasoning } => {
                            yield Ok(crate::streaming::RawStreamingChoice::Reasoning { id: None, reasoning: reasoning.to_string() })
                        },
                        StreamingDelta::MessageContent { content, tool_calls } => {
                            if !tool_calls.is_empty() {
                                for tool_call in tool_calls {
                                    let function = tool_call.function.clone();
                                    // Start of tool call
                                    // name: Some(String)
                                    // arguments: None
                                    if function.name.is_some() && function.arguments.is_empty() {
                                        let id = tool_call.id.clone().unwrap_or("".to_string());

                                        calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
                                    }
                                    // Part of tool call
                                    // name: None or Empty String
                                    // arguments: Some(String)
                                    else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
                                        let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                            continue;
                                        };

                                        // Append this fragment to the accumulated argument string.
                                        let new_arguments = &function.arguments;
                                        let arguments = format!("{arguments}{new_arguments}");

                                        calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
                                    }
                                    // Entire tool call
                                    else {
                                        let id = tool_call.id.clone().unwrap_or("".to_string());
                                        let name = function.name.expect("function name should be present for complete tool call");
                                        let arguments = function.arguments;
                                        let Ok(arguments) = serde_json::from_str(&arguments) else {
                                            tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
                                    }
                                }
                            }

                            if let Some(content) = &content {
                                yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
                            }
                        }
                    }
                }


                if let Some(usage) = data.usage {
                    final_usage = usage.clone();
                }
            }
        }

        // Flush tool calls that were accumulated but never emitted inline;
        // ones whose arguments still aren't valid JSON are dropped silently.
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments) = serde_json::from_str(&arguments) else {
                continue;
            };

            yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
        }

        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
            usage: final_usage.clone()
        }))
    });

    Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
}