// rig/providers/groq.rs

//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY");
//!
//! let model = client.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B);
//! ```
11use std::collections::HashMap;
12
13use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
14use crate::client::{ClientBuilderError, CompletionClient, TranscriptionClient};
15use crate::json_utils::merge;
16use futures::StreamExt;
17
18use crate::streaming::RawStreamingChoice;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    json_utils,
23    message::{self, MessageError},
24    providers::openai::ToolDefinition,
25    transcription::{self, TranscriptionError},
26};
27use reqwest::RequestBuilder;
28use reqwest::multipart::Part;
29use rig::client::ProviderClient;
30use rig::impl_conversion_traits;
31use serde::{Deserialize, Serialize};
32use serde_json::{Value, json};
33
34// ================================================================
35// Main Groq Client
36// ================================================================
/// Default base URL for Groq's OpenAI-compatible API.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
38
/// Builder for [`Client`], allowing the base URL and the underlying
/// HTTP client to be customized before construction.
pub struct ClientBuilder<'a> {
    api_key: &'a str,
    base_url: &'a str,
    // `None` means "build a default reqwest client in `build()`".
    http_client: Option<reqwest::Client>,
}
44
45impl<'a> ClientBuilder<'a> {
46    pub fn new(api_key: &'a str) -> Self {
47        Self {
48            api_key,
49            base_url: GROQ_API_BASE_URL,
50            http_client: None,
51        }
52    }
53
54    pub fn base_url(mut self, base_url: &'a str) -> Self {
55        self.base_url = base_url;
56        self
57    }
58
59    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
60        self.http_client = Some(client);
61        self
62    }
63
64    pub fn build(self) -> Result<Client, ClientBuilderError> {
65        let http_client = if let Some(http_client) = self.http_client {
66            http_client
67        } else {
68            reqwest::Client::builder().build()?
69        };
70
71        Ok(Client {
72            base_url: self.base_url.to_string(),
73            api_key: self.api_key.to_string(),
74            http_client,
75        })
76    }
77}
78
/// A Groq API client: base URL, bearer API key, and the shared
/// `reqwest` HTTP client used for all requests.
#[derive(Clone)]
pub struct Client {
    base_url: String,
    api_key: String,
    http_client: reqwest::Client,
}
85
// Manual `Debug` implementation so the API key is never written to logs;
// it is rendered as `<REDACTED>` while the other fields print normally.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}
95
96impl Client {
97    /// Create a new Groq client builder.
98    ///
99    /// # Example
100    /// ```
101    /// use rig::providers::groq::{ClientBuilder, self};
102    ///
103    /// // Initialize the Groq client
104    /// let groq = Client::builder("your-groq-api-key")
105    ///    .build()
106    /// ```
107    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
108        ClientBuilder::new(api_key)
109    }
110
111    /// Create a new Groq client with the given API key.
112    ///
113    /// # Panics
114    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
115    pub fn new(api_key: &str) -> Self {
116        Self::builder(api_key)
117            .build()
118            .expect("Groq client should build")
119    }
120
121    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
122        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
123        self.http_client.post(url).bearer_auth(&self.api_key)
124    }
125}
126
127impl ProviderClient for Client {
128    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
129    /// Panics if the environment variable is not set.
130    fn from_env() -> Self {
131        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
132        Self::new(&api_key)
133    }
134
135    fn from_val(input: crate::client::ProviderValue) -> Self {
136        let crate::client::ProviderValue::Simple(api_key) = input else {
137            panic!("Incorrect provider value type")
138        };
139        Self::new(&api_key)
140    }
141}
142
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::{Client, self};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let model = groq.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
161
impl TranscriptionClient for Client {
    type TranscriptionModel = TranscriptionModel;

    /// Create a transcription model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::groq::{Client, self};
    ///
    /// // Initialize the Groq client
    /// let groq = Client::new("your-groq-api-key");
    ///
    /// let whisper = groq.transcription_model(groq::WHISPER_LARGE_V3);
    /// ```
    fn transcription_model(&self, model: &str) -> TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}
180
// NOTE(review): presumably this macro generates "unsupported"/no-op
// conversions for capabilities Groq doesn't offer through this client —
// confirm against the `impl_conversion_traits!` definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client
);
186
/// Error payload returned by the Groq API; only the message text is kept.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
191
/// Either a successful payload of type `T` or an API error body.
/// With `untagged`, serde tries the variants in declaration order:
/// `T` first, then the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
198
/// Groq wire-format chat message.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    // "system", "user", or "assistant" (see the TryFrom impls in this file).
    pub role: String,
    // Message text; may be absent on the wire.
    pub content: Option<String>,
    // Reasoning text for reasoning-capable models; omitted when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}
206
207impl TryFrom<Message> for message::Message {
208    type Error = message::MessageError;
209
210    fn try_from(message: Message) -> Result<Self, Self::Error> {
211        match message.role.as_str() {
212            "user" => Ok(Self::User {
213                content: OneOrMany::one(
214                    message
215                        .content
216                        .map(|content| message::UserContent::text(&content))
217                        .ok_or_else(|| {
218                            message::MessageError::ConversionError("Empty user message".to_string())
219                        })?,
220                ),
221            }),
222            "assistant" => Ok(Self::Assistant {
223                id: None,
224                content: OneOrMany::one(
225                    message
226                        .content
227                        .map(|content| message::AssistantContent::text(&content))
228                        .ok_or_else(|| {
229                            message::MessageError::ConversionError(
230                                "Empty assistant message".to_string(),
231                            )
232                        })?,
233                ),
234            }),
235            _ => Err(message::MessageError::ConversionError(format!(
236                "Unknown role: {}",
237                message.role
238            ))),
239        }
240    }
241}
242
243impl TryFrom<message::Message> for Message {
244    type Error = message::MessageError;
245
246    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
247        match message {
248            message::Message::User { content } => Ok(Self {
249                role: "user".to_string(),
250                content: content.iter().find_map(|c| match c {
251                    message::UserContent::Text(text) => Some(text.text.clone()),
252                    _ => None,
253                }),
254                reasoning: None,
255            }),
256            message::Message::Assistant { content, .. } => {
257                let mut text_content: Option<String> = None;
258                let mut groq_reasoning: Option<String> = None;
259
260                for c in content.iter() {
261                    match c {
262                        message::AssistantContent::Text(text) => {
263                            text_content = Some(
264                                text_content
265                                    .map(|mut existing| {
266                                        existing.push('\n');
267                                        existing.push_str(&text.text);
268                                        existing
269                                    })
270                                    .unwrap_or_else(|| text.text.clone()),
271                            );
272                        }
273                        message::AssistantContent::ToolCall(_tool_call) => {
274                            return Err(MessageError::ConversionError(
275                                "Tool calls do not exist on this message".into(),
276                            ));
277                        }
278                        message::AssistantContent::Reasoning(message::Reasoning { reasoning }) => {
279                            groq_reasoning = Some(reasoning.to_owned());
280                        }
281                    }
282                }
283
284                Ok(Self {
285                    role: "assistant".to_string(),
286                    content: text_content,
287                    reasoning: groq_reasoning,
288                })
289            }
290        }
291    }
292}
293
294// ================================================================
295// Groq Completion API
296// ================================================================
// NOTE(review): several of these (notably the `*-preview` and
// `mixtral-8x7b-32768` entries) may have been decommissioned by Groq —
// verify against Groq's current model list before relying on them.
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
323
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
    pub model: String,
}
330
331impl CompletionModel {
332    pub fn new(client: Client, model: &str) -> Self {
333        Self {
334            client,
335            model: model.to_string(),
336        }
337    }
338
339    fn create_completion_request(
340        &self,
341        completion_request: CompletionRequest,
342    ) -> Result<Value, CompletionError> {
343        // Build up the order of messages (context, chat_history, prompt)
344        let mut partial_history = vec![];
345        if let Some(docs) = completion_request.normalized_documents() {
346            partial_history.push(docs);
347        }
348        partial_history.extend(completion_request.chat_history);
349
350        // Initialize full history with preamble (or empty if non-existent)
351        let mut full_history: Vec<Message> =
352            completion_request
353                .preamble
354                .map_or_else(Vec::new, |preamble| {
355                    vec![Message {
356                        role: "system".to_string(),
357                        content: Some(preamble),
358                        reasoning: None,
359                    }]
360                });
361
362        // Convert and extend the rest of the history
363        full_history.extend(
364            partial_history
365                .into_iter()
366                .map(message::Message::try_into)
367                .collect::<Result<Vec<Message>, _>>()?,
368        );
369
370        let request = if completion_request.tools.is_empty() {
371            json!({
372                "model": self.model,
373                "messages": full_history,
374                "temperature": completion_request.temperature,
375            })
376        } else {
377            json!({
378                "model": self.model,
379                "messages": full_history,
380                "temperature": completion_request.temperature,
381                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
382                "tool_choice": "auto",
383                "reasoning_format": "parsed"
384            })
385        };
386
387        let request = if let Some(params) = completion_request.additional_params {
388            json_utils::merge(request, params)
389        } else {
390            request
391        };
392
393        Ok(request)
394    }
395}
396
397impl completion::CompletionModel for CompletionModel {
398    type Response = CompletionResponse;
399    type StreamingResponse = StreamingCompletionResponse;
400
401    #[cfg_attr(feature = "worker", worker::send)]
402    async fn completion(
403        &self,
404        completion_request: CompletionRequest,
405    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
406        let request = self.create_completion_request(completion_request)?;
407
408        let response = self
409            .client
410            .post("/chat/completions")
411            .json(&request)
412            .send()
413            .await?;
414
415        if response.status().is_success() {
416            match response.json::<ApiResponse<CompletionResponse>>().await? {
417                ApiResponse::Ok(response) => {
418                    tracing::info!(target: "rig",
419                        "groq completion token usage: {:?}",
420                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
421                    );
422                    response.try_into()
423                }
424                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
425            }
426        } else {
427            Err(CompletionError::ProviderError(response.text().await?))
428        }
429    }
430
431    #[cfg_attr(feature = "worker", worker::send)]
432    async fn stream(
433        &self,
434        request: CompletionRequest,
435    ) -> Result<
436        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
437        CompletionError,
438    > {
439        let mut request = self.create_completion_request(request)?;
440
441        request = merge(
442            request,
443            json!({"stream": true, "stream_options": {"include_usage": true}}),
444        );
445
446        let builder = self.client.post("/chat/completions").json(&request);
447
448        send_compatible_streaming_request(builder).await
449    }
450}
451
452// ================================================================
453// Groq Transcription API
454// ================================================================
/// The `whisper-large-v3` model. Used for transcription.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` model. Used for transcription.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` model. Used for transcription.
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
458
/// A Groq transcription (speech-to-text) model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel {
    client: Client,
    /// Name of the model (e.g.: whisper-large-v3)
    pub model: String,
}
465
impl TranscriptionModel {
    /// Wrap a [`Client`] and a model name into a transcription model handle.
    pub fn new(client: Client, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}
474impl transcription::TranscriptionModel for TranscriptionModel {
475    type Response = TranscriptionResponse;
476
477    #[cfg_attr(feature = "worker", worker::send)]
478    async fn transcription(
479        &self,
480        request: transcription::TranscriptionRequest,
481    ) -> Result<
482        transcription::TranscriptionResponse<Self::Response>,
483        transcription::TranscriptionError,
484    > {
485        let data = request.data;
486
487        let mut body = reqwest::multipart::Form::new()
488            .text("model", self.model.clone())
489            .text("language", request.language)
490            .part(
491                "file",
492                Part::bytes(data).file_name(request.filename.clone()),
493            );
494
495        if let Some(prompt) = request.prompt {
496            body = body.text("prompt", prompt.clone());
497        }
498
499        if let Some(ref temperature) = request.temperature {
500            body = body.text("temperature", temperature.to_string());
501        }
502
503        if let Some(ref additional_params) = request.additional_params {
504            for (key, value) in additional_params
505                .as_object()
506                .expect("Additional Parameters to OpenAI Transcription should be a map")
507            {
508                body = body.text(key.to_owned(), value.to_string());
509            }
510        }
511
512        let response = self
513            .client
514            .post("audio/transcriptions")
515            .multipart(body)
516            .send()
517            .await?;
518
519        if response.status().is_success() {
520            match response
521                .json::<ApiResponse<TranscriptionResponse>>()
522                .await?
523            {
524                ApiResponse::Ok(response) => response.try_into(),
525                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
526                    api_error_response.message,
527                )),
528            }
529        } else {
530            Err(TranscriptionError::ProviderError(response.text().await?))
531        }
532    }
533}
534
/// One SSE delta payload. With `untagged`, serde tries the variants in
/// declaration order: a chunk with a `reasoning` field deserializes as
/// `Reasoning`; anything else falls through to `MessageContent`
/// (text content and/or tool-call fragments).
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // `null` and a missing field both deserialize to an empty Vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}
548
/// A single choice inside a streaming chunk; only its delta is consumed.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
553
/// One parsed SSE data chunk: zero or more choices, plus token usage
/// (present on the final chunk when `include_usage` was requested).
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
559
/// Final response yielded once the stream ends: aggregate token usage.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
564
/// Send a streaming (SSE) completion request and adapt the OpenAI-compatible
/// chunk stream into Rig's `RawStreamingChoice` stream.
///
/// Per chunk this yields reasoning deltas, message-content deltas, and
/// completed tool calls; partial tool-call fragments are accumulated in a
/// map keyed by tool-call index and flushed when the stream ends. The last
/// chunk's `usage` (when present) is surfaced via `FinalResponse`.
pub async fn send_compatible_streaming_request(
    request_builder: RequestBuilder,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
> {
    let response = request_builder.send().await?;

    // Non-2xx: report status plus body as a provider error.
    if !response.status().is_success() {
        return Err(CompletionError::ProviderError(format!(
            "{}: {}",
            response.status(),
            response.text().await?
        )));
    }

    // Handle OpenAI Compatible SSE chunks
    let inner = Box::pin(async_stream::stream! {
        let mut stream = response.bytes_stream();

        // Usage reported by the last chunk that carried one (zeros otherwise).
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0
        };

        // Carries an incomplete `data:` line across network-chunk boundaries.
        let mut partial_data = None;
        // index -> (id, name, accumulated-arguments) for in-flight tool calls.
        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(chunk_result) = stream.next().await {
            let chunk = match chunk_result {
                Ok(c) => c,
                Err(e) => {
                    yield Err(CompletionError::from(e));
                    break;
                }
            };

            // SSE frames are text; a non-UTF-8 chunk aborts the stream.
            let text = match String::from_utf8(chunk.to_vec()) {
                Ok(t) => t,
                Err(e) => {
                    yield Err(CompletionError::ResponseError(e.to_string()));
                    break;
                }
            };


            for line in text.lines() {
                let mut line = line.to_string();

                // If there was a remaining part, concat with current line
                if partial_data.is_some() {
                    line = format!("{}{}", partial_data.unwrap(), line);
                    partial_data = None;
                }
                // Otherwise full data line
                else {
                    // Non-`data:` lines (comments, `event:`, blanks) are skipped.
                    let Some(data) = line.strip_prefix("data:") else {
                        continue;
                    };

                    let data = data.trim_start();

                    // Partial data, split somewhere in the middle
                    if !line.ends_with("}") {
                        // Stash the fragment; the parse below then fails on the
                        // un-stripped line and falls through to `continue`.
                        partial_data = Some(data.to_string());
                    } else {
                        line = data.to_string();
                    }
                }

                let data = serde_json::from_str::<StreamingCompletionChunk>(&line);

                // Unparseable lines (e.g. "[DONE]", stashed partials) are skipped.
                let Ok(data) = data else {
                    let err = data.unwrap_err();
                    tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
                    continue;
                };


                if let Some(choice) = data.choices.first() {
                    let delta = &choice.delta;

                    match delta {
                        StreamingDelta::Reasoning { reasoning } => {
                            yield Ok(crate::streaming::RawStreamingChoice::Reasoning { reasoning: reasoning.to_string() })
                        },
                        StreamingDelta::MessageContent { content, tool_calls } => {
                            if !tool_calls.is_empty() {
                                for tool_call in tool_calls {
                                    let function = tool_call.function.clone();
                                    // Start of tool call
                                    // name: Some(String)
                                    // arguments: None
                                    if function.name.is_some() && function.arguments.is_empty() {
                                        let id = tool_call.id.clone().unwrap_or("".to_string());

                                        calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
                                    }
                                    // Part of tool call
                                    // name: None or Empty String
                                    // arguments: Some(String)
                                    else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
                                        let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                            continue;
                                        };

                                        // Append this fragment to the accumulated arguments.
                                        let new_arguments = &function.arguments;
                                        let arguments = format!("{arguments}{new_arguments}");

                                        calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
                                    }
                                    // Entire tool call
                                    else {
                                        let id = tool_call.id.clone().unwrap_or("".to_string());
                                        let name = function.name.expect("function name should be present for complete tool call");
                                        let arguments = function.arguments;
                                        let Ok(arguments) = serde_json::from_str(&arguments) else {
                                            tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
                                    }
                                }
                            }

                            if let Some(content) = &content {
                                yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
                            }
                        }
                    }
                }


                // Remember the most recent usage block (final chunk wins).
                if let Some(usage) = data.usage {
                    final_usage = usage.clone();
                }
            }
        }

        // Flush tool calls that were accumulated but never yielded inline;
        // entries whose arguments still aren't valid JSON are dropped.
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments) = serde_json::from_str(&arguments) else {
                continue;
            };

            yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
        }

        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
            usage: final_usage.clone()
        }))
    });

    Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
}