1use std::collections::HashMap;
12
13use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
14use crate::client::{ClientBuilderError, CompletionClient, TranscriptionClient};
15use crate::json_utils::merge;
16use futures::StreamExt;
17
18use crate::streaming::RawStreamingChoice;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 json_utils,
23 message::{self, MessageError},
24 providers::openai::ToolDefinition,
25 transcription::{self, TranscriptionError},
26};
27use reqwest::RequestBuilder;
28use reqwest::multipart::Part;
29use rig::client::ProviderClient;
30use rig::impl_conversion_traits;
31use serde::{Deserialize, Serialize};
32use serde_json::{Value, json};
33
/// Default base URL for Groq's OpenAI-compatible API; used by [`ClientBuilder::new`]
/// unless overridden with [`ClientBuilder::base_url`].
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
38
39pub struct ClientBuilder<'a> {
40 api_key: &'a str,
41 base_url: &'a str,
42 http_client: Option<reqwest::Client>,
43}
44
45impl<'a> ClientBuilder<'a> {
46 pub fn new(api_key: &'a str) -> Self {
47 Self {
48 api_key,
49 base_url: GROQ_API_BASE_URL,
50 http_client: None,
51 }
52 }
53
54 pub fn base_url(mut self, base_url: &'a str) -> Self {
55 self.base_url = base_url;
56 self
57 }
58
59 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
60 self.http_client = Some(client);
61 self
62 }
63
64 pub fn build(self) -> Result<Client, ClientBuilderError> {
65 let http_client = if let Some(http_client) = self.http_client {
66 http_client
67 } else {
68 reqwest::Client::builder().build()?
69 };
70
71 Ok(Client {
72 base_url: self.base_url.to_string(),
73 api_key: self.api_key.to_string(),
74 http_client,
75 })
76 }
77}
78
79#[derive(Clone)]
80pub struct Client {
81 base_url: String,
82 api_key: String,
83 http_client: reqwest::Client,
84}
85
86impl std::fmt::Debug for Client {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct("Client")
89 .field("base_url", &self.base_url)
90 .field("http_client", &self.http_client)
91 .field("api_key", &"<REDACTED>")
92 .finish()
93 }
94}
95
96impl Client {
97 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
108 ClientBuilder::new(api_key)
109 }
110
111 pub fn new(api_key: &str) -> Self {
116 Self::builder(api_key)
117 .build()
118 .expect("Groq client should build")
119 }
120
121 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
122 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
123 self.http_client.post(url).bearer_auth(&self.api_key)
124 }
125}
126
127impl ProviderClient for Client {
128 fn from_env() -> Self {
131 let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
132 Self::new(&api_key)
133 }
134
135 fn from_val(input: crate::client::ProviderValue) -> Self {
136 let crate::client::ProviderValue::Simple(api_key) = input else {
137 panic!("Incorrect provider value type")
138 };
139 Self::new(&api_key)
140 }
141}
142
143impl CompletionClient for Client {
144 type CompletionModel = CompletionModel;
145
146 fn completion_model(&self, model: &str) -> CompletionModel {
158 CompletionModel::new(self.clone(), model)
159 }
160}
161
162impl TranscriptionClient for Client {
163 type TranscriptionModel = TranscriptionModel;
164
165 fn transcription_model(&self, model: &str) -> TranscriptionModel {
177 TranscriptionModel::new(self.clone(), model)
178 }
179}
180
// NOTE(review): presumably generates the `AsEmbeddings`, `AsImageGeneration` and
// `AsAudioGeneration` conversion impls for `Client` with their default behaviour, since
// this client only implements completion and transcription here — confirm against the
// macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client
);
186
/// Error body returned by the API; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// With `untagged`, serde tries the variants in declaration order: `Ok` first, then
/// `Err`. A body that deserializes as `T` is therefore never treated as an error, so
/// keep `Ok` first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
198
/// A chat message as serialized into the `messages` array of a chat-completion request.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role string; this module produces `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Text content, if any.
    pub content: Option<String>,
    /// Reasoning text; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}
206
207impl TryFrom<Message> for message::Message {
208 type Error = message::MessageError;
209
210 fn try_from(message: Message) -> Result<Self, Self::Error> {
211 match message.role.as_str() {
212 "user" => Ok(Self::User {
213 content: OneOrMany::one(
214 message
215 .content
216 .map(|content| message::UserContent::text(&content))
217 .ok_or_else(|| {
218 message::MessageError::ConversionError("Empty user message".to_string())
219 })?,
220 ),
221 }),
222 "assistant" => Ok(Self::Assistant {
223 id: None,
224 content: OneOrMany::one(
225 message
226 .content
227 .map(|content| message::AssistantContent::text(&content))
228 .ok_or_else(|| {
229 message::MessageError::ConversionError(
230 "Empty assistant message".to_string(),
231 )
232 })?,
233 ),
234 }),
235 _ => Err(message::MessageError::ConversionError(format!(
236 "Unknown role: {}",
237 message.role
238 ))),
239 }
240 }
241}
242
243impl TryFrom<message::Message> for Message {
244 type Error = message::MessageError;
245
246 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
247 match message {
248 message::Message::User { content } => Ok(Self {
249 role: "user".to_string(),
250 content: content.iter().find_map(|c| match c {
251 message::UserContent::Text(text) => Some(text.text.clone()),
252 _ => None,
253 }),
254 reasoning: None,
255 }),
256 message::Message::Assistant { content, .. } => {
257 let mut text_content: Option<String> = None;
258 let mut groq_reasoning: Option<String> = None;
259
260 for c in content.iter() {
261 match c {
262 message::AssistantContent::Text(text) => {
263 text_content = Some(
264 text_content
265 .map(|mut existing| {
266 existing.push('\n');
267 existing.push_str(&text.text);
268 existing
269 })
270 .unwrap_or_else(|| text.text.clone()),
271 );
272 }
273 message::AssistantContent::ToolCall(_tool_call) => {
274 return Err(MessageError::ConversionError(
275 "Tool calls do not exist on this message".into(),
276 ));
277 }
278 message::AssistantContent::Reasoning(message::Reasoning { reasoning }) => {
279 groq_reasoning = Some(reasoning.to_owned());
280 }
281 }
282 }
283
284 Ok(Self {
285 role: "assistant".to_string(),
286 content: text_content,
287 reasoning: groq_reasoning,
288 })
289 }
290 }
291 }
292}
293
// Model identifier strings accepted by `completion_model`.
// NOTE(review): several of these look like preview or since-renamed Groq models — verify
// against Groq's current model list before relying on them.

/// Model id `deepseek-r1-distill-llama-70b`.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// Model id `gemma2-9b-it`.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// Model id `llama-3.1-8b-instant`.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// Model id `llama-3.2-11b-vision-preview`.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// Model id `llama-3.2-1b-preview`.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// Model id `llama-3.2-3b-preview`.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// Model id `llama-3.2-90b-vision-preview`.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// Model id `llama-3.2-70b-specdec`.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// Model id `llama-3.2-70b-versatile`.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// Model id `llama-guard-3-8b`.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// Model id `llama3-70b-8192`.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// Model id `llama3-8b-8192`.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// Model id `mixtral-8x7b-32768`.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
323
324#[derive(Clone, Debug)]
325pub struct CompletionModel {
326 client: Client,
327 pub model: String,
329}
330
331impl CompletionModel {
332 pub fn new(client: Client, model: &str) -> Self {
333 Self {
334 client,
335 model: model.to_string(),
336 }
337 }
338
339 fn create_completion_request(
340 &self,
341 completion_request: CompletionRequest,
342 ) -> Result<Value, CompletionError> {
343 let mut partial_history = vec![];
345 if let Some(docs) = completion_request.normalized_documents() {
346 partial_history.push(docs);
347 }
348 partial_history.extend(completion_request.chat_history);
349
350 let mut full_history: Vec<Message> =
352 completion_request
353 .preamble
354 .map_or_else(Vec::new, |preamble| {
355 vec![Message {
356 role: "system".to_string(),
357 content: Some(preamble),
358 reasoning: None,
359 }]
360 });
361
362 full_history.extend(
364 partial_history
365 .into_iter()
366 .map(message::Message::try_into)
367 .collect::<Result<Vec<Message>, _>>()?,
368 );
369
370 let request = if completion_request.tools.is_empty() {
371 json!({
372 "model": self.model,
373 "messages": full_history,
374 "temperature": completion_request.temperature,
375 })
376 } else {
377 json!({
378 "model": self.model,
379 "messages": full_history,
380 "temperature": completion_request.temperature,
381 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
382 "tool_choice": "auto",
383 "reasoning_format": "parsed"
384 })
385 };
386
387 let request = if let Some(params) = completion_request.additional_params {
388 json_utils::merge(request, params)
389 } else {
390 request
391 };
392
393 Ok(request)
394 }
395}
396
397impl completion::CompletionModel for CompletionModel {
398 type Response = CompletionResponse;
399 type StreamingResponse = StreamingCompletionResponse;
400
401 #[cfg_attr(feature = "worker", worker::send)]
402 async fn completion(
403 &self,
404 completion_request: CompletionRequest,
405 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
406 let request = self.create_completion_request(completion_request)?;
407
408 let response = self
409 .client
410 .post("/chat/completions")
411 .json(&request)
412 .send()
413 .await?;
414
415 if response.status().is_success() {
416 match response.json::<ApiResponse<CompletionResponse>>().await? {
417 ApiResponse::Ok(response) => {
418 tracing::info!(target: "rig",
419 "groq completion token usage: {:?}",
420 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
421 );
422 response.try_into()
423 }
424 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
425 }
426 } else {
427 Err(CompletionError::ProviderError(response.text().await?))
428 }
429 }
430
431 #[cfg_attr(feature = "worker", worker::send)]
432 async fn stream(
433 &self,
434 request: CompletionRequest,
435 ) -> Result<
436 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
437 CompletionError,
438 > {
439 let mut request = self.create_completion_request(request)?;
440
441 request = merge(
442 request,
443 json!({"stream": true, "stream_options": {"include_usage": true}}),
444 );
445
446 let builder = self.client.post("/chat/completions").json(&request);
447
448 send_compatible_streaming_request(builder).await
449 }
450}
451
// Transcription model identifier strings accepted by `transcription_model`.

/// Model id `whisper-large-v3`.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// Model id `whisper-large-v3-turbo`.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// Model id `distil-whisper-large-v3-en`.
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
458
459#[derive(Clone)]
460pub struct TranscriptionModel {
461 client: Client,
462 pub model: String,
464}
465
466impl TranscriptionModel {
467 pub fn new(client: Client, model: &str) -> Self {
468 Self {
469 client,
470 model: model.to_string(),
471 }
472 }
473}
474impl transcription::TranscriptionModel for TranscriptionModel {
475 type Response = TranscriptionResponse;
476
477 #[cfg_attr(feature = "worker", worker::send)]
478 async fn transcription(
479 &self,
480 request: transcription::TranscriptionRequest,
481 ) -> Result<
482 transcription::TranscriptionResponse<Self::Response>,
483 transcription::TranscriptionError,
484 > {
485 let data = request.data;
486
487 let mut body = reqwest::multipart::Form::new()
488 .text("model", self.model.clone())
489 .text("language", request.language)
490 .part(
491 "file",
492 Part::bytes(data).file_name(request.filename.clone()),
493 );
494
495 if let Some(prompt) = request.prompt {
496 body = body.text("prompt", prompt.clone());
497 }
498
499 if let Some(ref temperature) = request.temperature {
500 body = body.text("temperature", temperature.to_string());
501 }
502
503 if let Some(ref additional_params) = request.additional_params {
504 for (key, value) in additional_params
505 .as_object()
506 .expect("Additional Parameters to OpenAI Transcription should be a map")
507 {
508 body = body.text(key.to_owned(), value.to_string());
509 }
510 }
511
512 let response = self
513 .client
514 .post("audio/transcriptions")
515 .multipart(body)
516 .send()
517 .await?;
518
519 if response.status().is_success() {
520 match response
521 .json::<ApiResponse<TranscriptionResponse>>()
522 .await?
523 {
524 ApiResponse::Ok(response) => response.try_into(),
525 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
526 api_error_response.message,
527 )),
528 }
529 } else {
530 Err(TranscriptionError::ProviderError(response.text().await?))
531 }
532 }
533}
534
/// The `delta` object of one streamed chat-completion choice.
///
/// With `untagged`, serde tries `Reasoning` first; it matches only when a string
/// `reasoning` field is present, and any other fields of that delta are then ignored.
/// Otherwise serde falls back to `MessageContent`, whose fields are both defaulted, so
/// keep the variant order as is.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // NOTE(review): `null_or_vec` presumably maps a JSON `null` to an empty Vec — confirm.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}
548
/// One entry of a streamed chunk's `choices` array; only `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One `data:` payload of the SSE stream.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    /// Token usage, when the server includes it in this chunk.
    usage: Option<Usage>,
}

/// Payload of the final streaming item: the last usage seen on the stream.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
564
565pub async fn send_compatible_streaming_request(
566 request_builder: RequestBuilder,
567) -> Result<
568 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
569 CompletionError,
570> {
571 let response = request_builder.send().await?;
572
573 if !response.status().is_success() {
574 return Err(CompletionError::ProviderError(format!(
575 "{}: {}",
576 response.status(),
577 response.text().await?
578 )));
579 }
580
581 let inner = Box::pin(async_stream::stream! {
583 let mut stream = response.bytes_stream();
584
585 let mut final_usage = Usage {
586 prompt_tokens: 0,
587 total_tokens: 0
588 };
589
590 let mut partial_data = None;
591 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
592
593 while let Some(chunk_result) = stream.next().await {
594 let chunk = match chunk_result {
595 Ok(c) => c,
596 Err(e) => {
597 yield Err(CompletionError::from(e));
598 break;
599 }
600 };
601
602 let text = match String::from_utf8(chunk.to_vec()) {
603 Ok(t) => t,
604 Err(e) => {
605 yield Err(CompletionError::ResponseError(e.to_string()));
606 break;
607 }
608 };
609
610
611 for line in text.lines() {
612 let mut line = line.to_string();
613
614 if partial_data.is_some() {
616 line = format!("{}{}", partial_data.unwrap(), line);
617 partial_data = None;
618 }
619 else {
621 let Some(data) = line.strip_prefix("data:") else {
622 continue;
623 };
624
625 let data = data.trim_start();
626
627 if !line.ends_with("}") {
629 partial_data = Some(data.to_string());
630 } else {
631 line = data.to_string();
632 }
633 }
634
635 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
636
637 let Ok(data) = data else {
638 let err = data.unwrap_err();
639 tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
640 continue;
641 };
642
643
644 if let Some(choice) = data.choices.first() {
645 let delta = &choice.delta;
646
647 match delta {
648 StreamingDelta::Reasoning { reasoning } => {
649 yield Ok(crate::streaming::RawStreamingChoice::Reasoning { reasoning: reasoning.to_string() })
650 },
651 StreamingDelta::MessageContent { content, tool_calls } => {
652 if !tool_calls.is_empty() {
653 for tool_call in tool_calls {
654 let function = tool_call.function.clone();
655 if function.name.is_some() && function.arguments.is_empty() {
659 let id = tool_call.id.clone().unwrap_or("".to_string());
660
661 calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
662 }
663 else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
667 let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
668 tracing::debug!("Partial tool call received but tool call was never started.");
669 continue;
670 };
671
672 let new_arguments = &function.arguments;
673 let arguments = format!("{arguments}{new_arguments}");
674
675 calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
676 }
677 else {
679 let id = tool_call.id.clone().unwrap_or("".to_string());
680 let name = function.name.expect("function name should be present for complete tool call");
681 let arguments = function.arguments;
682 let Ok(arguments) = serde_json::from_str(&arguments) else {
683 tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
684 continue;
685 };
686
687 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
688 }
689 }
690 }
691
692 if let Some(content) = &content {
693 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
694 }
695 }
696 }
697 }
698
699
700 if let Some(usage) = data.usage {
701 final_usage = usage.clone();
702 }
703 }
704 }
705
706 for (_, (id, name, arguments)) in calls {
707 let Ok(arguments) = serde_json::from_str(&arguments) else {
708 continue;
709 };
710
711 yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
712 }
713
714 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
715 usage: final_usage.clone()
716 }))
717 });
718
719 Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
720}