rig/providers/
groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY");
//!
//! let llama = client.completion_model(groq::LLAMA_3_1_8B_INSTANT);
//! ```
11use std::collections::HashMap;
12
13use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
14use crate::client::{ClientBuilderError, CompletionClient, TranscriptionClient};
15use crate::completion::GetTokenUsage;
16use crate::json_utils::merge;
17use futures::StreamExt;
18
19use crate::streaming::RawStreamingChoice;
20use crate::{
21    OneOrMany,
22    completion::{self, CompletionError, CompletionRequest},
23    json_utils,
24    message::{self, MessageError},
25    providers::openai::ToolDefinition,
26    transcription::{self, TranscriptionError},
27};
28use reqwest::RequestBuilder;
29use reqwest::multipart::Part;
30use rig::client::ProviderClient;
31use rig::impl_conversion_traits;
32use serde::{Deserialize, Serialize};
33use serde_json::{Value, json};
34
35// ================================================================
36// Main Groq Client
37// ================================================================
38const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
39
40pub struct ClientBuilder<'a> {
41    api_key: &'a str,
42    base_url: &'a str,
43    http_client: Option<reqwest::Client>,
44}
45
46impl<'a> ClientBuilder<'a> {
47    pub fn new(api_key: &'a str) -> Self {
48        Self {
49            api_key,
50            base_url: GROQ_API_BASE_URL,
51            http_client: None,
52        }
53    }
54
55    pub fn base_url(mut self, base_url: &'a str) -> Self {
56        self.base_url = base_url;
57        self
58    }
59
60    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
61        self.http_client = Some(client);
62        self
63    }
64
65    pub fn build(self) -> Result<Client, ClientBuilderError> {
66        let http_client = if let Some(http_client) = self.http_client {
67            http_client
68        } else {
69            reqwest::Client::builder().build()?
70        };
71
72        Ok(Client {
73            base_url: self.base_url.to_string(),
74            api_key: self.api_key.to_string(),
75            http_client,
76        })
77    }
78}
79
80#[derive(Clone)]
81pub struct Client {
82    base_url: String,
83    api_key: String,
84    http_client: reqwest::Client,
85}
86
87impl std::fmt::Debug for Client {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("Client")
90            .field("base_url", &self.base_url)
91            .field("http_client", &self.http_client)
92            .field("api_key", &"<REDACTED>")
93            .finish()
94    }
95}
96
97impl Client {
98    /// Create a new Groq client builder.
99    ///
100    /// # Example
101    /// ```
102    /// use rig::providers::groq::{ClientBuilder, self};
103    ///
104    /// // Initialize the Groq client
105    /// let groq = Client::builder("your-groq-api-key")
106    ///    .build()
107    /// ```
108    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
109        ClientBuilder::new(api_key)
110    }
111
112    /// Create a new Groq client with the given API key.
113    ///
114    /// # Panics
115    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
116    pub fn new(api_key: &str) -> Self {
117        Self::builder(api_key)
118            .build()
119            .expect("Groq client should build")
120    }
121
122    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
123        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
124        self.http_client.post(url).bearer_auth(&self.api_key)
125    }
126}
127
128impl ProviderClient for Client {
129    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
130    /// Panics if the environment variable is not set.
131    fn from_env() -> Self {
132        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
133        Self::new(&api_key)
134    }
135
136    fn from_val(input: crate::client::ProviderValue) -> Self {
137        let crate::client::ProviderValue::Simple(api_key) = input else {
138            panic!("Incorrect provider value type")
139        };
140        Self::new(&api_key)
141    }
142}
143
144impl CompletionClient for Client {
145    type CompletionModel = CompletionModel;
146
147    /// Create a completion model with the given name.
148    ///
149    /// # Example
150    /// ```
151    /// use rig::providers::groq::{Client, self};
152    ///
153    /// // Initialize the Groq client
154    /// let groq = Client::new("your-groq-api-key");
155    ///
156    /// let gpt4 = groq.completion_model(groq::GPT_4);
157    /// ```
158    fn completion_model(&self, model: &str) -> CompletionModel {
159        CompletionModel::new(self.clone(), model)
160    }
161}
162
163impl TranscriptionClient for Client {
164    type TranscriptionModel = TranscriptionModel;
165
166    /// Create a transcription model with the given name.
167    ///
168    /// # Example
169    /// ```
170    /// use rig::providers::groq::{Client, self};
171    ///
172    /// // Initialize the Groq client
173    /// let groq = Client::new("your-groq-api-key");
174    ///
175    /// let gpt4 = groq.transcription_model(groq::WHISPER_LARGE_V3);
176    /// ```
177    fn transcription_model(&self, model: &str) -> TranscriptionModel {
178        TranscriptionModel::new(self.clone(), model)
179    }
180}
181
// NOTE(review): presumably generates default implementations of `AsEmbeddings`,
// `AsImageGeneration` and `AsAudioGeneration` for `Client`. This module implements no
// embedding, image- or audio-generation client of its own. Confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client
);
187
/// Error payload decoded when a response body does not deserialize as the expected type.
///
/// NOTE(review): only a top-level `message` field is read. Confirm this matches the shape
/// of Groq's error bodies.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `#[serde(untagged)]` tries the variants in declaration order. `Ok(T)` is attempted
/// first, and `Err` is used only when the body does not deserialize as `T`, so do not
/// reorder the variants.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
199
/// A chat message in the wire format this module sends to and reads from Groq.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role string; this module produces/accepts `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Text content. There is no `skip_serializing_if`, so `None` serializes as `null`.
    pub content: Option<String>,
    /// Reasoning text; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}
207
208impl TryFrom<Message> for message::Message {
209    type Error = message::MessageError;
210
211    fn try_from(message: Message) -> Result<Self, Self::Error> {
212        match message.role.as_str() {
213            "user" => Ok(Self::User {
214                content: OneOrMany::one(
215                    message
216                        .content
217                        .map(|content| message::UserContent::text(&content))
218                        .ok_or_else(|| {
219                            message::MessageError::ConversionError("Empty user message".to_string())
220                        })?,
221                ),
222            }),
223            "assistant" => Ok(Self::Assistant {
224                id: None,
225                content: OneOrMany::one(
226                    message
227                        .content
228                        .map(|content| message::AssistantContent::text(&content))
229                        .ok_or_else(|| {
230                            message::MessageError::ConversionError(
231                                "Empty assistant message".to_string(),
232                            )
233                        })?,
234                ),
235            }),
236            _ => Err(message::MessageError::ConversionError(format!(
237                "Unknown role: {}",
238                message.role
239            ))),
240        }
241    }
242}
243
244impl TryFrom<message::Message> for Message {
245    type Error = message::MessageError;
246
247    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
248        match message {
249            message::Message::User { content } => Ok(Self {
250                role: "user".to_string(),
251                content: content.iter().find_map(|c| match c {
252                    message::UserContent::Text(text) => Some(text.text.clone()),
253                    _ => None,
254                }),
255                reasoning: None,
256            }),
257            message::Message::Assistant { content, .. } => {
258                let mut text_content: Option<String> = None;
259                let mut groq_reasoning: Option<String> = None;
260
261                for c in content.iter() {
262                    match c {
263                        message::AssistantContent::Text(text) => {
264                            text_content = Some(
265                                text_content
266                                    .map(|mut existing| {
267                                        existing.push('\n');
268                                        existing.push_str(&text.text);
269                                        existing
270                                    })
271                                    .unwrap_or_else(|| text.text.clone()),
272                            );
273                        }
274                        message::AssistantContent::ToolCall(_tool_call) => {
275                            return Err(MessageError::ConversionError(
276                                "Tool calls do not exist on this message".into(),
277                            ));
278                        }
279                        message::AssistantContent::Reasoning(message::Reasoning {
280                            reasoning,
281                            ..
282                        }) => {
283                            groq_reasoning =
284                                Some(reasoning.first().cloned().unwrap_or(String::new()));
285                        }
286                    }
287                }
288
289                Ok(Self {
290                    role: "assistant".to_string(),
291                    content: text_content,
292                    reasoning: groq_reasoning,
293                })
294            }
295        }
296    }
297}
298
// ================================================================
// Groq Completion API
// ================================================================
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
///
/// NOTE(review): Meta's Llama 3.2 family has no 70B model, so Groq may not serve this
/// ID. Confirm against Groq's model list.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
///
/// NOTE(review): Meta's Llama 3.2 family has no 70B model, so Groq may not serve this
/// ID. [`LLAMA_3_3_70B_VERSATILE`] is likely the intended model; confirm against
/// Groq's model list.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-3.3-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
328
329#[derive(Clone, Debug)]
330pub struct CompletionModel {
331    client: Client,
332    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
333    pub model: String,
334}
335
336impl CompletionModel {
337    pub fn new(client: Client, model: &str) -> Self {
338        Self {
339            client,
340            model: model.to_string(),
341        }
342    }
343
344    fn create_completion_request(
345        &self,
346        completion_request: CompletionRequest,
347    ) -> Result<Value, CompletionError> {
348        // Build up the order of messages (context, chat_history, prompt)
349        let mut partial_history = vec![];
350        if let Some(docs) = completion_request.normalized_documents() {
351            partial_history.push(docs);
352        }
353        partial_history.extend(completion_request.chat_history);
354
355        // Initialize full history with preamble (or empty if non-existent)
356        let mut full_history: Vec<Message> =
357            completion_request
358                .preamble
359                .map_or_else(Vec::new, |preamble| {
360                    vec![Message {
361                        role: "system".to_string(),
362                        content: Some(preamble),
363                        reasoning: None,
364                    }]
365                });
366
367        // Convert and extend the rest of the history
368        full_history.extend(
369            partial_history
370                .into_iter()
371                .map(message::Message::try_into)
372                .collect::<Result<Vec<Message>, _>>()?,
373        );
374
375        let request = if completion_request.tools.is_empty() {
376            json!({
377                "model": self.model,
378                "messages": full_history,
379                "temperature": completion_request.temperature,
380            })
381        } else {
382            json!({
383                "model": self.model,
384                "messages": full_history,
385                "temperature": completion_request.temperature,
386                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
387                "tool_choice": "auto",
388                "reasoning_format": "parsed"
389            })
390        };
391
392        let request = if let Some(params) = completion_request.additional_params {
393            json_utils::merge(request, params)
394        } else {
395            request
396        };
397
398        Ok(request)
399    }
400}
401
402impl completion::CompletionModel for CompletionModel {
403    type Response = CompletionResponse;
404    type StreamingResponse = StreamingCompletionResponse;
405
406    #[cfg_attr(feature = "worker", worker::send)]
407    async fn completion(
408        &self,
409        completion_request: CompletionRequest,
410    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
411        let request = self.create_completion_request(completion_request)?;
412
413        let response = self
414            .client
415            .post("/chat/completions")
416            .json(&request)
417            .send()
418            .await?;
419
420        if response.status().is_success() {
421            match response.json::<ApiResponse<CompletionResponse>>().await? {
422                ApiResponse::Ok(response) => {
423                    tracing::info!(target: "rig",
424                        "groq completion token usage: {:?}",
425                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
426                    );
427                    response.try_into()
428                }
429                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
430            }
431        } else {
432            Err(CompletionError::ProviderError(response.text().await?))
433        }
434    }
435
436    #[cfg_attr(feature = "worker", worker::send)]
437    async fn stream(
438        &self,
439        request: CompletionRequest,
440    ) -> Result<
441        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
442        CompletionError,
443    > {
444        let mut request = self.create_completion_request(request)?;
445
446        request = merge(
447            request,
448            json!({"stream": true, "stream_options": {"include_usage": true}}),
449        );
450
451        let builder = self.client.post("/chat/completions").json(&request);
452
453        send_compatible_streaming_request(builder).await
454    }
455}
456
// ================================================================
// Groq Transcription API
// ================================================================
/// The `whisper-large-v3` model. Used for audio transcription.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` model. Used for audio transcription.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` model. Used for audio transcription.
///
/// Note that the constant name omits the `-en` suffix that the model ID carries.
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
463
464#[derive(Clone)]
465pub struct TranscriptionModel {
466    client: Client,
467    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
468    pub model: String,
469}
470
471impl TranscriptionModel {
472    pub fn new(client: Client, model: &str) -> Self {
473        Self {
474            client,
475            model: model.to_string(),
476        }
477    }
478}
479impl transcription::TranscriptionModel for TranscriptionModel {
480    type Response = TranscriptionResponse;
481
482    #[cfg_attr(feature = "worker", worker::send)]
483    async fn transcription(
484        &self,
485        request: transcription::TranscriptionRequest,
486    ) -> Result<
487        transcription::TranscriptionResponse<Self::Response>,
488        transcription::TranscriptionError,
489    > {
490        let data = request.data;
491
492        let mut body = reqwest::multipart::Form::new()
493            .text("model", self.model.clone())
494            .text("language", request.language)
495            .part(
496                "file",
497                Part::bytes(data).file_name(request.filename.clone()),
498            );
499
500        if let Some(prompt) = request.prompt {
501            body = body.text("prompt", prompt.clone());
502        }
503
504        if let Some(ref temperature) = request.temperature {
505            body = body.text("temperature", temperature.to_string());
506        }
507
508        if let Some(ref additional_params) = request.additional_params {
509            for (key, value) in additional_params
510                .as_object()
511                .expect("Additional Parameters to OpenAI Transcription should be a map")
512            {
513                body = body.text(key.to_owned(), value.to_string());
514            }
515        }
516
517        let response = self
518            .client
519            .post("audio/transcriptions")
520            .multipart(body)
521            .send()
522            .await?;
523
524        if response.status().is_success() {
525            match response
526                .json::<ApiResponse<TranscriptionResponse>>()
527                .await?
528            {
529                ApiResponse::Ok(response) => response.try_into(),
530                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
531                    api_error_response.message,
532                )),
533            }
534        } else {
535            Err(TranscriptionError::ProviderError(response.text().await?))
536        }
537    }
538}
539
/// Delta payload of one streamed chat-completion choice.
///
/// `#[serde(untagged)]` tries the variants in declaration order. A delta with a string
/// `reasoning` field decodes as `Reasoning`; because that variant ignores unknown fields,
/// any `content`/`tool_calls` next to it are not captured. Everything else falls through
/// to `MessageContent`. Keep the variant order as is.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // presumably `null_or_vec` maps a JSON `null` to an empty Vec; verify in json_utils
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

/// One entry of a streamed chunk's `choices` array; only the `delta` is decoded.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// The decoded JSON payload of a single SSE `data:` line.
///
/// `usage` is optional. NOTE(review): presumably it is only present on the final chunk,
/// when `stream_options.include_usage` is requested; confirm with Groq's docs.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}

/// Final item of a Groq stream, carrying the token usage reported by the provider.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
569
570impl GetTokenUsage for StreamingCompletionResponse {
571    fn token_usage(&self) -> Option<crate::completion::Usage> {
572        let mut usage = crate::completion::Usage::new();
573
574        usage.input_tokens = self.usage.prompt_tokens as u64;
575        usage.total_tokens = self.usage.total_tokens as u64;
576        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
577
578        Some(usage)
579    }
580}
581
582pub async fn send_compatible_streaming_request(
583    request_builder: RequestBuilder,
584) -> Result<
585    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
586    CompletionError,
587> {
588    let response = request_builder.send().await?;
589
590    if !response.status().is_success() {
591        return Err(CompletionError::ProviderError(format!(
592            "{}: {}",
593            response.status(),
594            response.text().await?
595        )));
596    }
597
598    // Handle OpenAI Compatible SSE chunks
599    let inner = Box::pin(async_stream::stream! {
600        let mut stream = response.bytes_stream();
601
602        let mut final_usage = Usage {
603            prompt_tokens: 0,
604            total_tokens: 0
605        };
606
607        let mut partial_data = None;
608        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
609
610        while let Some(chunk_result) = stream.next().await {
611            let chunk = match chunk_result {
612                Ok(c) => c,
613                Err(e) => {
614                    yield Err(CompletionError::from(e));
615                    break;
616                }
617            };
618
619            let text = match String::from_utf8(chunk.to_vec()) {
620                Ok(t) => t,
621                Err(e) => {
622                    yield Err(CompletionError::ResponseError(e.to_string()));
623                    break;
624                }
625            };
626
627
628            for line in text.lines() {
629                let mut line = line.to_string();
630
631                // If there was a remaining part, concat with current line
632                if partial_data.is_some() {
633                    line = format!("{}{}", partial_data.unwrap(), line);
634                    partial_data = None;
635                }
636                // Otherwise full data line
637                else {
638                    let Some(data) = line.strip_prefix("data:") else {
639                        continue;
640                    };
641
642                    let data = data.trim_start();
643
644                    // Partial data, split somewhere in the middle
645                    if !line.ends_with("}") {
646                        partial_data = Some(data.to_string());
647                    } else {
648                        line = data.to_string();
649                    }
650                }
651
652                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
653
654                let Ok(data) = data else {
655                    let err = data.unwrap_err();
656                    tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
657                    continue;
658                };
659
660
661                if let Some(choice) = data.choices.first() {
662                    let delta = &choice.delta;
663
664                    match delta {
665                        StreamingDelta::Reasoning { reasoning } => {
666                            yield Ok(crate::streaming::RawStreamingChoice::Reasoning { id: None, reasoning: reasoning.to_string() })
667                        },
668                        StreamingDelta::MessageContent { content, tool_calls } => {
669                            if !tool_calls.is_empty() {
670                                for tool_call in tool_calls {
671                                    let function = tool_call.function.clone();
672                                    // Start of tool call
673                                    // name: Some(String)
674                                    // arguments: None
675                                    if function.name.is_some() && function.arguments.is_empty() {
676                                        let id = tool_call.id.clone().unwrap_or("".to_string());
677
678                                        calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
679                                    }
680                                    // Part of tool call
681                                    // name: None or Empty String
682                                    // arguments: Some(String)
683                                    else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
684                                        let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
685                                            tracing::debug!("Partial tool call received but tool call was never started.");
686                                            continue;
687                                        };
688
689                                        let new_arguments = &function.arguments;
690                                        let arguments = format!("{arguments}{new_arguments}");
691
692                                        calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
693                                    }
694                                    // Entire tool call
695                                    else {
696                                        let id = tool_call.id.clone().unwrap_or("".to_string());
697                                        let name = function.name.expect("function name should be present for complete tool call");
698                                        let arguments = function.arguments;
699                                        let Ok(arguments) = serde_json::from_str(&arguments) else {
700                                            tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
701                                            continue;
702                                        };
703
704                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
705                                    }
706                                }
707                            }
708
709                            if let Some(content) = &content {
710                                yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
711                            }
712                        }
713                    }
714                }
715
716
717                if let Some(usage) = data.usage {
718                    final_usage = usage.clone();
719                }
720            }
721        }
722
723        for (_, (id, name, arguments)) in calls {
724            let Ok(arguments) = serde_json::from_str(&arguments) else {
725                continue;
726            };
727
728            yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
729        }
730
731        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
732            usage: final_usage.clone()
733        }))
734    });
735
736    Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
737}