1use std::collections::HashMap;
12
13use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
14use crate::client::{
15 ClientBuilderError, CompletionClient, TranscriptionClient, VerifyClient, VerifyError,
16};
17use crate::completion::GetTokenUsage;
18use crate::json_utils::merge;
19use futures::StreamExt;
20
21use crate::streaming::RawStreamingChoice;
22use crate::{
23 OneOrMany,
24 completion::{self, CompletionError, CompletionRequest},
25 json_utils,
26 message::{self, MessageError},
27 providers::openai::ToolDefinition,
28 transcription::{self, TranscriptionError},
29};
30use reqwest::RequestBuilder;
31use reqwest::multipart::Part;
32use rig::client::ProviderClient;
33use rig::impl_conversion_traits;
34use serde::{Deserialize, Serialize};
35use serde_json::{Value, json};
36
/// Default base URL for Groq's OpenAI-compatible API; request paths are appended to it.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
41
/// Builder for [`Client`].
///
/// Starts from the public Groq base URL (`GROQ_API_BASE_URL`). Unless
/// [`ClientBuilder::custom_client`] is used, a default `reqwest::Client` is created in
/// [`ClientBuilder::build`].
pub struct ClientBuilder<'a> {
    /// API key, sent as a bearer token on every request.
    api_key: &'a str,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// Optional caller-supplied HTTP client; `None` means "build a default one".
    http_client: Option<reqwest::Client>,
}
47
48impl<'a> ClientBuilder<'a> {
49 pub fn new(api_key: &'a str) -> Self {
50 Self {
51 api_key,
52 base_url: GROQ_API_BASE_URL,
53 http_client: None,
54 }
55 }
56
57 pub fn base_url(mut self, base_url: &'a str) -> Self {
58 self.base_url = base_url;
59 self
60 }
61
62 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
63 self.http_client = Some(client);
64 self
65 }
66
67 pub fn build(self) -> Result<Client, ClientBuilderError> {
68 let http_client = if let Some(http_client) = self.http_client {
69 http_client
70 } else {
71 reqwest::Client::builder().build()?
72 };
73
74 Ok(Client {
75 base_url: self.base_url.to_string(),
76 api_key: self.api_key.to_string(),
77 http_client,
78 })
79 }
80}
81
/// HTTP client for the Groq API.
///
/// Model handles created from it (`completion_model`, `transcription_model`) each hold
/// their own clone. The `Debug` output redacts the API key.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// API key, sent as a bearer token on every request.
    api_key: String,
    /// Underlying HTTP client used for all requests.
    http_client: reqwest::Client,
}
88
89impl std::fmt::Debug for Client {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 f.debug_struct("Client")
92 .field("base_url", &self.base_url)
93 .field("http_client", &self.http_client)
94 .field("api_key", &"<REDACTED>")
95 .finish()
96 }
97}
98
99impl Client {
100 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
111 ClientBuilder::new(api_key)
112 }
113
114 pub fn new(api_key: &str) -> Self {
119 Self::builder(api_key)
120 .build()
121 .expect("Groq client should build")
122 }
123
124 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
125 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
126 self.http_client.post(url).bearer_auth(&self.api_key)
127 }
128
129 pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
130 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
131 self.http_client.get(url).bearer_auth(&self.api_key)
132 }
133}
134
135impl ProviderClient for Client {
136 fn from_env() -> Self {
139 let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
140 Self::new(&api_key)
141 }
142
143 fn from_val(input: crate::client::ProviderValue) -> Self {
144 let crate::client::ProviderValue::Simple(api_key) = input else {
145 panic!("Incorrect provider value type")
146 };
147 Self::new(&api_key)
148 }
149}
150
151impl CompletionClient for Client {
152 type CompletionModel = CompletionModel;
153
154 fn completion_model(&self, model: &str) -> CompletionModel {
166 CompletionModel::new(self.clone(), model)
167 }
168}
169
170impl TranscriptionClient for Client {
171 type TranscriptionModel = TranscriptionModel;
172
173 fn transcription_model(&self, model: &str) -> TranscriptionModel {
185 TranscriptionModel::new(self.clone(), model)
186 }
187}
188
189impl VerifyClient for Client {
190 #[cfg_attr(feature = "worker", worker::send)]
191 async fn verify(&self) -> Result<(), VerifyError> {
192 let response = self.get("/models").send().await?;
193 match response.status() {
194 reqwest::StatusCode::OK => Ok(()),
195 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
196 reqwest::StatusCode::INTERNAL_SERVER_ERROR
197 | reqwest::StatusCode::SERVICE_UNAVAILABLE
198 | reqwest::StatusCode::BAD_GATEWAY => {
199 Err(VerifyError::ProviderError(response.text().await?))
200 }
201 _ => {
202 response.error_for_status()?;
203 Ok(())
204 }
205 }
206 }
207}
208
// Generates the `AsEmbeddings`, `AsImageGeneration` and `AsAudioGeneration` conversion
// traits for `Client`. NOTE(review): this presumably marks those capabilities as
// unsupported for Groq — confirm against the `impl_conversion_traits!` definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client
);
214
/// Error payload, decoded when a success-status body does not match the expected type.
///
/// NOTE(review): only a top-level `message` field is read; confirm this matches the error
/// shape Groq actually returns.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Human-readable error text, surfaced as a `ProviderError`.
    message: String,
}
219
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `untagged`: serde tries `Ok(T)` first and falls back to `Err` only when the body does not
/// deserialize as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
226
/// A chat message in the shape exchanged with Groq's chat-completions endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Role string; this module produces `"system"`, `"user"` and `"assistant"`.
    pub role: String,
    /// Text content, if any.
    pub content: Option<String>,
    /// Reasoning text; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}
234
235impl TryFrom<Message> for message::Message {
236 type Error = message::MessageError;
237
238 fn try_from(message: Message) -> Result<Self, Self::Error> {
239 match message.role.as_str() {
240 "user" => Ok(Self::User {
241 content: OneOrMany::one(
242 message
243 .content
244 .map(|content| message::UserContent::text(&content))
245 .ok_or_else(|| {
246 message::MessageError::ConversionError("Empty user message".to_string())
247 })?,
248 ),
249 }),
250 "assistant" => Ok(Self::Assistant {
251 id: None,
252 content: OneOrMany::one(
253 message
254 .content
255 .map(|content| message::AssistantContent::text(&content))
256 .ok_or_else(|| {
257 message::MessageError::ConversionError(
258 "Empty assistant message".to_string(),
259 )
260 })?,
261 ),
262 }),
263 _ => Err(message::MessageError::ConversionError(format!(
264 "Unknown role: {}",
265 message.role
266 ))),
267 }
268 }
269}
270
271impl TryFrom<message::Message> for Message {
272 type Error = message::MessageError;
273
274 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
275 match message {
276 message::Message::User { content } => Ok(Self {
277 role: "user".to_string(),
278 content: content.iter().find_map(|c| match c {
279 message::UserContent::Text(text) => Some(text.text.clone()),
280 _ => None,
281 }),
282 reasoning: None,
283 }),
284 message::Message::Assistant { content, .. } => {
285 let mut text_content: Option<String> = None;
286 let mut groq_reasoning: Option<String> = None;
287
288 for c in content.iter() {
289 match c {
290 message::AssistantContent::Text(text) => {
291 text_content = Some(
292 text_content
293 .map(|mut existing| {
294 existing.push('\n');
295 existing.push_str(&text.text);
296 existing
297 })
298 .unwrap_or_else(|| text.text.clone()),
299 );
300 }
301 message::AssistantContent::ToolCall(_tool_call) => {
302 return Err(MessageError::ConversionError(
303 "Tool calls do not exist on this message".into(),
304 ));
305 }
306 message::AssistantContent::Reasoning(message::Reasoning {
307 reasoning,
308 ..
309 }) => {
310 groq_reasoning =
311 Some(reasoning.first().cloned().unwrap_or(String::new()));
312 }
313 }
314 }
315
316 Ok(Self {
317 role: "assistant".to_string(),
318 content: text_content,
319 reasoning: groq_reasoning,
320 })
321 }
322 }
323 }
324}
325
// Groq chat model identifiers, for use as the `model` of a `CompletionModel`
// (e.g. via `CompletionClient::completion_model`).
// NOTE(review): several of these are preview/legacy IDs; current availability on Groq is
// not verified here.

/// Model ID `deepseek-r1-distill-llama-70b`.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// Model ID `gemma2-9b-it`.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// Model ID `llama-3.1-8b-instant`.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// Model ID `llama-3.2-11b-vision-preview`.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// Model ID `llama-3.2-1b-preview`.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// Model ID `llama-3.2-3b-preview`.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// Model ID `llama-3.2-90b-vision-preview`.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// Model ID `llama-3.2-70b-specdec`.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// Model ID `llama-3.2-70b-versatile`.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// Model ID `llama-guard-3-8b`.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// Model ID `llama3-70b-8192`.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// Model ID `llama3-8b-8192`.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// Model ID `mixtral-8x7b-32768`.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
355
/// Handle for a Groq chat-completion model.
#[derive(Clone, Debug)]
pub struct CompletionModel {
    /// Client used to send requests.
    client: Client,
    /// Model identifier, sent as the `model` field of each request.
    pub model: String,
}
362
363impl CompletionModel {
364 pub fn new(client: Client, model: &str) -> Self {
365 Self {
366 client,
367 model: model.to_string(),
368 }
369 }
370
371 fn create_completion_request(
372 &self,
373 completion_request: CompletionRequest,
374 ) -> Result<Value, CompletionError> {
375 let mut partial_history = vec![];
377 if let Some(docs) = completion_request.normalized_documents() {
378 partial_history.push(docs);
379 }
380 partial_history.extend(completion_request.chat_history);
381
382 let mut full_history: Vec<Message> =
384 completion_request
385 .preamble
386 .map_or_else(Vec::new, |preamble| {
387 vec![Message {
388 role: "system".to_string(),
389 content: Some(preamble),
390 reasoning: None,
391 }]
392 });
393
394 full_history.extend(
396 partial_history
397 .into_iter()
398 .map(message::Message::try_into)
399 .collect::<Result<Vec<Message>, _>>()?,
400 );
401
402 let request = if completion_request.tools.is_empty() {
403 json!({
404 "model": self.model,
405 "messages": full_history,
406 "temperature": completion_request.temperature,
407 })
408 } else {
409 json!({
410 "model": self.model,
411 "messages": full_history,
412 "temperature": completion_request.temperature,
413 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
414 "tool_choice": "auto",
415 "reasoning_format": "parsed"
416 })
417 };
418
419 let request = if let Some(params) = completion_request.additional_params {
420 json_utils::merge(request, params)
421 } else {
422 request
423 };
424
425 Ok(request)
426 }
427}
428
429impl completion::CompletionModel for CompletionModel {
430 type Response = CompletionResponse;
431 type StreamingResponse = StreamingCompletionResponse;
432
433 #[cfg_attr(feature = "worker", worker::send)]
434 async fn completion(
435 &self,
436 completion_request: CompletionRequest,
437 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
438 let request = self.create_completion_request(completion_request)?;
439
440 let response = self
441 .client
442 .post("/chat/completions")
443 .json(&request)
444 .send()
445 .await?;
446
447 if response.status().is_success() {
448 match response.json::<ApiResponse<CompletionResponse>>().await? {
449 ApiResponse::Ok(response) => {
450 tracing::info!(target: "rig",
451 "groq completion token usage: {:?}",
452 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
453 );
454 response.try_into()
455 }
456 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
457 }
458 } else {
459 Err(CompletionError::ProviderError(response.text().await?))
460 }
461 }
462
463 #[cfg_attr(feature = "worker", worker::send)]
464 async fn stream(
465 &self,
466 request: CompletionRequest,
467 ) -> Result<
468 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
469 CompletionError,
470 > {
471 let mut request = self.create_completion_request(request)?;
472
473 request = merge(
474 request,
475 json!({"stream": true, "stream_options": {"include_usage": true}}),
476 );
477
478 let builder = self.client.post("/chat/completions").json(&request);
479
480 send_compatible_streaming_request(builder).await
481 }
482}
483
// Groq transcription model identifiers, for use as the `model` of a `TranscriptionModel`
// (e.g. via `TranscriptionClient::transcription_model`).

/// Model ID `whisper-large-v3`.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// Model ID `whisper-large-v3-turbo`.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// Model ID `distil-whisper-large-v3-en`.
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";
490
/// Handle for a Groq speech-to-text model.
#[derive(Clone)]
pub struct TranscriptionModel {
    /// Client used to send requests.
    client: Client,
    /// Model identifier, sent as the `model` multipart field.
    pub model: String,
}
497
498impl TranscriptionModel {
499 pub fn new(client: Client, model: &str) -> Self {
500 Self {
501 client,
502 model: model.to_string(),
503 }
504 }
505}
506impl transcription::TranscriptionModel for TranscriptionModel {
507 type Response = TranscriptionResponse;
508
509 #[cfg_attr(feature = "worker", worker::send)]
510 async fn transcription(
511 &self,
512 request: transcription::TranscriptionRequest,
513 ) -> Result<
514 transcription::TranscriptionResponse<Self::Response>,
515 transcription::TranscriptionError,
516 > {
517 let data = request.data;
518
519 let mut body = reqwest::multipart::Form::new()
520 .text("model", self.model.clone())
521 .text("language", request.language)
522 .part(
523 "file",
524 Part::bytes(data).file_name(request.filename.clone()),
525 );
526
527 if let Some(prompt) = request.prompt {
528 body = body.text("prompt", prompt.clone());
529 }
530
531 if let Some(ref temperature) = request.temperature {
532 body = body.text("temperature", temperature.to_string());
533 }
534
535 if let Some(ref additional_params) = request.additional_params {
536 for (key, value) in additional_params
537 .as_object()
538 .expect("Additional Parameters to OpenAI Transcription should be a map")
539 {
540 body = body.text(key.to_owned(), value.to_string());
541 }
542 }
543
544 let response = self
545 .client
546 .post("audio/transcriptions")
547 .multipart(body)
548 .send()
549 .await?;
550
551 if response.status().is_success() {
552 match response
553 .json::<ApiResponse<TranscriptionResponse>>()
554 .await?
555 {
556 ApiResponse::Ok(response) => response.try_into(),
557 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
558 api_error_response.message,
559 )),
560 }
561 } else {
562 Err(TranscriptionError::ProviderError(response.text().await?))
563 }
564 }
565}
566
/// The `delta` of one streamed choice.
///
/// `untagged`: serde tries the variants in order. Any delta with a string `reasoning` field
/// becomes [`StreamingDelta::Reasoning`]; everything else is parsed as
/// [`StreamingDelta::MessageContent`].
/// NOTE(review): if one delta ever carried both `reasoning` and `content`/`tool_calls`, the
/// latter would be ignored — confirm Groq never sends them together.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum StreamingDelta {
    /// Incremental reasoning text.
    Reasoning {
        reasoning: String,
    },
    /// Incremental message text and/or tool-call fragments.
    MessageContent {
        /// Text fragment; a missing field becomes `None`.
        #[serde(default)]
        content: Option<String>,
        /// Tool-call fragments; a missing field becomes an empty list, and `null` is handled
        /// by `json_utils::null_or_vec`.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}
580
/// One entry of a streamed chunk's `choices` array; only its `delta` is read.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}
585
/// One JSON payload of a `data:` line in the streaming response.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    /// Streamed choices; the streaming code reads only the first one.
    choices: Vec<StreamingChoice>,
    /// Token usage. Presumably present only on the final chunk when
    /// `stream_options.include_usage` is requested — confirm against Groq docs.
    usage: Option<Usage>,
}
591
/// Final payload of a Groq stream: the token usage last reported by the server.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}
596
597impl GetTokenUsage for StreamingCompletionResponse {
598 fn token_usage(&self) -> Option<crate::completion::Usage> {
599 let mut usage = crate::completion::Usage::new();
600
601 usage.input_tokens = self.usage.prompt_tokens as u64;
602 usage.total_tokens = self.usage.total_tokens as u64;
603 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
604
605 Some(usage)
606 }
607}
608
609pub async fn send_compatible_streaming_request(
610 request_builder: RequestBuilder,
611) -> Result<
612 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
613 CompletionError,
614> {
615 let response = request_builder.send().await?;
616
617 if !response.status().is_success() {
618 return Err(CompletionError::ProviderError(format!(
619 "{}: {}",
620 response.status(),
621 response.text().await?
622 )));
623 }
624
625 let inner = Box::pin(async_stream::stream! {
627 let mut stream = response.bytes_stream();
628
629 let mut final_usage = Usage {
630 prompt_tokens: 0,
631 total_tokens: 0
632 };
633
634 let mut partial_data = None;
635 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
636
637 while let Some(chunk_result) = stream.next().await {
638 let chunk = match chunk_result {
639 Ok(c) => c,
640 Err(e) => {
641 yield Err(CompletionError::from(e));
642 break;
643 }
644 };
645
646 let text = match String::from_utf8(chunk.to_vec()) {
647 Ok(t) => t,
648 Err(e) => {
649 yield Err(CompletionError::ResponseError(e.to_string()));
650 break;
651 }
652 };
653
654
655 for line in text.lines() {
656 let mut line = line.to_string();
657
658 if partial_data.is_some() {
660 line = format!("{}{}", partial_data.unwrap(), line);
661 partial_data = None;
662 }
663 else {
665 let Some(data) = line.strip_prefix("data:") else {
666 continue;
667 };
668
669 let data = data.trim_start();
670
671 if !line.ends_with("}") {
673 partial_data = Some(data.to_string());
674 } else {
675 line = data.to_string();
676 }
677 }
678
679 let data = serde_json::from_str::<StreamingCompletionChunk>(&line);
680
681 let Ok(data) = data else {
682 let err = data.unwrap_err();
683 tracing::debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
684 continue;
685 };
686
687
688 if let Some(choice) = data.choices.first() {
689 let delta = &choice.delta;
690
691 match delta {
692 StreamingDelta::Reasoning { reasoning } => {
693 yield Ok(crate::streaming::RawStreamingChoice::Reasoning { id: None, reasoning: reasoning.to_string() })
694 },
695 StreamingDelta::MessageContent { content, tool_calls } => {
696 if !tool_calls.is_empty() {
697 for tool_call in tool_calls {
698 let function = tool_call.function.clone();
699 if function.name.is_some() && function.arguments.is_empty() {
703 let id = tool_call.id.clone().unwrap_or("".to_string());
704
705 calls.insert(tool_call.index, (id, function.name.clone().unwrap(), "".to_string()));
706 }
707 else if function.name.clone().is_none_or(|s| s.is_empty()) && !function.arguments.is_empty() {
711 let Some((id, name, arguments)) = calls.get(&tool_call.index) else {
712 tracing::debug!("Partial tool call received but tool call was never started.");
713 continue;
714 };
715
716 let new_arguments = &function.arguments;
717 let arguments = format!("{arguments}{new_arguments}");
718
719 calls.insert(tool_call.index, (id.clone(), name.clone(), arguments));
720 }
721 else {
723 let id = tool_call.id.clone().unwrap_or("".to_string());
724 let name = function.name.expect("function name should be present for complete tool call");
725 let arguments = function.arguments;
726 let Ok(arguments) = serde_json::from_str(&arguments) else {
727 tracing::debug!("Couldn't serialize '{}' as a json value", arguments);
728 continue;
729 };
730
731 yield Ok(crate::streaming::RawStreamingChoice::ToolCall {id, name, arguments, call_id: None })
732 }
733 }
734 }
735
736 if let Some(content) = &content {
737 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()))
738 }
739 }
740 }
741 }
742
743
744 if let Some(usage) = data.usage {
745 final_usage = usage.clone();
746 }
747 }
748 }
749
750 for (_, (id, name, arguments)) in calls {
751 let Ok(arguments) = serde_json::from_str(&arguments) else {
752 continue;
753 };
754
755 yield Ok(RawStreamingChoice::ToolCall {id, name, arguments, call_id: None });
756 }
757
758 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
759 usage: final_usage.clone()
760 }))
761 });
762
763 Ok(crate::streaming::StreamingCompletionResponse::stream(inner))
764}