1use bytes::Bytes;
12use http::Request;
13use serde_json::{Map, Value};
14use std::collections::HashMap;
15use tracing::info_span;
16use tracing_futures::Instrument;
17
18use super::openai::{
19 CompletionResponse, Message as OpenAIMessage, StreamingToolCall, TranscriptionResponse, Usage,
20};
21use crate::client::{
22 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
23 ProviderClient,
24};
25use crate::completion::GetTokenUsage;
26use crate::http_client::multipart::Part;
27use crate::http_client::sse::{Event, GenericEventSource};
28use crate::http_client::{self, HttpClientExt, MultipartForm};
29use crate::json_utils::empty_or_none;
30use crate::providers::openai::{AssistantContent, Function, ToolType};
31use async_stream::stream;
32use futures::StreamExt;
33
34use crate::{
35 completion::{self, CompletionError, CompletionRequest},
36 json_utils,
37 message::{self},
38 providers::openai::ToolDefinition,
39 transcription::{self, TranscriptionError},
40};
41use serde::{Deserialize, Serialize};
42
/// Base URL for Groq's OpenAI-compatible REST API.
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";
47
/// Marker type implementing the Groq provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqExt;
/// Marker type used to build a Groq [`Client`] through the generic builder.
#[derive(Debug, Default, Clone, Copy)]
pub struct GroqBuilder;

// Groq authenticates with a standard `Authorization: Bearer <key>` header.
type GroqApiKey = BearerAuth;
54
impl Provider for GroqExt {
    type Builder = GroqBuilder;
    // `GET /models` is a cheap authenticated endpoint used to verify credentials.
    const VERIFY_PATH: &'static str = "/models";
}
59
// Groq supports chat completions and audio transcription; embeddings,
// model listing, image generation, and audio generation are not wired up.
impl<H> Capabilities<H> for GroqExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
71
// Use the default debug formatting provided by `DebugExt`.
impl DebugExt for GroqExt {}
73
impl ProviderBuilder for GroqBuilder {
    type Extension<H>
        = GroqExt
    where
        H: HttpClientExt;
    type ApiKey = GroqApiKey;

    const BASE_URL: &'static str = GROQ_API_BASE_URL;

    /// The extension carries no state, so building it never fails.
    fn build<H>(
        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(GroqExt)
    }
}
92
/// A Groq client over an arbitrary HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<GroqExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GroqBuilder, String, H>;
95
impl ProviderClient for Client {
    type Input = String;

    /// Builds a client from the `GROQ_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if `GROQ_API_KEY` is unset or if client construction fails.
    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    /// Builds a client from an API key passed in directly.
    ///
    /// # Panics
    /// Panics if client construction fails.
    fn from_val(input: Self::Input) -> Self {
        Self::new(&input).unwrap()
    }
}
110
/// Error payload returned by the Groq API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either the expected response payload or an API-level error.
///
/// NOTE: with `untagged`, serde tries the variants in order — `Ok(T)` first,
/// falling back to the error shape — so variant order is significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
122
/// The `deepseek-r1-distill-llama-70b` completion model.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` completion model.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` completion model.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` completion model.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` completion model.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` completion model.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` completion model.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.2-70b-specdec` completion model.
pub const LLAMA_3_2_70B_SPECDEC: &str = "llama-3.2-70b-specdec";
/// The `llama-3.2-70b-versatile` completion model.
pub const LLAMA_3_2_70B_VERSATILE: &str = "llama-3.2-70b-versatile";
/// The `llama-guard-3-8b` completion model.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` completion model.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` completion model.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` completion model.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
153
/// Controls how Groq formats reasoning tokens in a completion response.
///
/// Serialized in lowercase (`"parsed"`, `"raw"`, `"hidden"`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningFormat {
    Parsed,
    Raw,
    Hidden,
}
161
/// Request body for Groq's `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct GroqCompletionRequest {
    model: String,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Omitted entirely when no tools are configured.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    // Flattened so Groq-specific keys serialize at the top level of the body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<GroqAdditionalParameters>,
    pub(super) stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(super) stream_options: Option<StreamOptions>,
}

/// Streaming options; `include_usage: true` asks the API to append a final
/// usage chunk to the SSE stream.
#[derive(Debug, Serialize, Deserialize, Default)]
pub(super) struct StreamOptions {
    pub(super) include_usage: bool,
}
183
184impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
185 type Error = CompletionError;
186
187 fn try_from((model, mut req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
188 if req.output_schema.is_some() {
189 tracing::warn!("Structured outputs currently not supported for Groq");
190 }
191 let model = req.model.clone().unwrap_or_else(|| model.to_string());
192 let mut partial_history = vec![];
194 if let Some(docs) = req.normalized_documents() {
195 partial_history.push(docs);
196 }
197 partial_history.extend(req.chat_history);
198
199 let mut full_history: Vec<OpenAIMessage> = match &req.preamble {
201 Some(preamble) => vec![OpenAIMessage::system(preamble)],
202 None => vec![],
203 };
204
205 full_history.extend(
207 partial_history
208 .into_iter()
209 .map(message::Message::try_into)
210 .collect::<Result<Vec<Vec<OpenAIMessage>>, _>>()?
211 .into_iter()
212 .flatten()
213 .collect::<Vec<_>>(),
214 );
215
216 let tool_choice = req
217 .tool_choice
218 .clone()
219 .map(crate::providers::openai::ToolChoice::try_from)
220 .transpose()?;
221
222 let mut additional_params_payload = req.additional_params.take().unwrap_or(Value::Null);
223 let native_tools =
224 extract_native_tools_from_additional_params(&mut additional_params_payload)?;
225
226 let mut additional_params: Option<GroqAdditionalParameters> =
227 if additional_params_payload.is_null() {
228 None
229 } else {
230 Some(serde_json::from_value(additional_params_payload)?)
231 };
232 apply_native_tools_to_additional_params(&mut additional_params, native_tools);
233
234 Ok(Self {
235 model: model.to_string(),
236 messages: full_history,
237 temperature: req.temperature,
238 tools: req
239 .tools
240 .clone()
241 .into_iter()
242 .map(ToolDefinition::from)
243 .collect::<Vec<_>>(),
244 tool_choice,
245 additional_params,
246 stream: false,
247 stream_options: None,
248 })
249 }
250}
251
252fn extract_native_tools_from_additional_params(
253 additional_params: &mut Value,
254) -> Result<Vec<Value>, CompletionError> {
255 if let Some(map) = additional_params.as_object_mut()
256 && let Some(raw_tools) = map.remove("tools")
257 {
258 return serde_json::from_value::<Vec<Value>>(raw_tools).map_err(|err| {
259 CompletionError::RequestError(
260 format!("Invalid Groq `additional_params.tools` payload: {err}").into(),
261 )
262 });
263 }
264
265 Ok(Vec::new())
266}
267
/// Merges `native_tools` into
/// `additional_params.extra["compound_custom"]["enabled_tools"]`, creating
/// the intermediate objects as needed and skipping tools that are already
/// enabled (matched via [`native_tools_match`]).
fn apply_native_tools_to_additional_params(
    additional_params: &mut Option<GroqAdditionalParameters>,
    native_tools: Vec<Value>,
) {
    // Nothing to merge — leave `additional_params` untouched.
    if native_tools.is_empty() {
        return;
    }

    // Ensure the nested containers exist before merging into them.
    let params = additional_params.get_or_insert_with(GroqAdditionalParameters::default);
    let extra = params.extra.get_or_insert_with(Map::new);

    // Temporarily remove the nested objects so they can be rebuilt below.
    let mut compound_custom = match extra.remove("compound_custom") {
        Some(Value::Object(map)) => map,
        _ => Map::new(),
    };

    let mut enabled_tools = match compound_custom.remove("enabled_tools") {
        Some(Value::Array(values)) => values,
        _ => Vec::new(),
    };

    // De-duplicate: a tool is only appended when no enabled tool matches it.
    for native_tool in native_tools {
        let already_enabled = enabled_tools
            .iter()
            .any(|existing| native_tools_match(existing, &native_tool));
        if !already_enabled {
            enabled_tools.push(native_tool);
        }
    }

    // Write the rebuilt containers back into place.
    compound_custom.insert("enabled_tools".to_string(), Value::Array(enabled_tools));
    extra.insert(
        "compound_custom".to_string(),
        Value::Object(compound_custom),
    );
}
304
305fn native_tools_match(lhs: &Value, rhs: &Value) -> bool {
306 if let (Some(lhs_type), Some(rhs_type)) = (native_tool_kind(lhs), native_tool_kind(rhs)) {
307 return lhs_type == rhs_type;
308 }
309
310 lhs == rhs
311}
312
313fn native_tool_kind(value: &Value) -> Option<&str> {
314 match value {
315 Value::String(kind) => Some(kind),
316 Value::Object(map) => map.get("type").and_then(Value::as_str),
317 _ => None,
318 }
319}
320
/// Groq-specific request parameters; serialized flattened into the top level
/// of [`GroqCompletionRequest`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GroqAdditionalParameters {
    /// How reasoning tokens should be formatted, if requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_format: Option<ReasoningFormat>,
    /// Whether reasoning should be included in the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_reasoning: Option<bool>,
    /// Any further provider parameters, flattened to top-level JSON keys.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub extra: Option<Map<String, serde_json::Value>>,
}
334
/// A Groq chat-completion model bound to a [`Client`].
#[derive(Clone, Debug)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Model identifier sent as the `model` field, e.g. [`LLAMA_3_1_8B_INSTANT`].
    pub model: String,
}
341
342impl<T> CompletionModel<T> {
343 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
344 Self {
345 client,
346 model: model.into(),
347 }
348 }
349}
350
351impl<T> completion::CompletionModel for CompletionModel<T>
352where
353 T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
354{
355 type Response = CompletionResponse;
356 type StreamingResponse = StreamingCompletionResponse;
357
358 type Client = Client<T>;
359
360 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
361 Self::new(client.clone(), model)
362 }
363
364 async fn completion(
365 &self,
366 completion_request: CompletionRequest,
367 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
368 let span = if tracing::Span::current().is_disabled() {
369 info_span!(
370 target: "rig::completions",
371 "chat",
372 gen_ai.operation.name = "chat",
373 gen_ai.provider.name = "groq",
374 gen_ai.request.model = self.model,
375 gen_ai.system_instructions = tracing::field::Empty,
376 gen_ai.response.id = tracing::field::Empty,
377 gen_ai.response.model = tracing::field::Empty,
378 gen_ai.usage.output_tokens = tracing::field::Empty,
379 gen_ai.usage.input_tokens = tracing::field::Empty,
380 gen_ai.usage.cached_tokens = tracing::field::Empty,
381 )
382 } else {
383 tracing::Span::current()
384 };
385
386 span.record("gen_ai.system_instructions", &completion_request.preamble);
387
388 let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
389
390 if tracing::enabled!(tracing::Level::TRACE) {
391 tracing::trace!(target: "rig::completions",
392 "Groq completion request: {}",
393 serde_json::to_string_pretty(&request)?
394 );
395 }
396
397 let body = serde_json::to_vec(&request)?;
398 let req = self
399 .client
400 .post("/chat/completions")?
401 .body(body)
402 .map_err(|e| http_client::Error::Instance(e.into()))?;
403
404 let async_block = async move {
405 let response = self.client.send::<_, Bytes>(req).await?;
406 let status = response.status();
407 let response_body = response.into_body().into_future().await?.to_vec();
408
409 if status.is_success() {
410 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
411 ApiResponse::Ok(response) => {
412 let span = tracing::Span::current();
413 span.record("gen_ai.response.id", response.id.clone());
414 span.record("gen_ai.response.model_name", response.model.clone());
415 if let Some(ref usage) = response.usage {
416 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
417 span.record(
418 "gen_ai.usage.output_tokens",
419 usage.total_tokens - usage.prompt_tokens,
420 );
421 span.record(
422 "gen_ai.usage.cached_tokens",
423 usage
424 .prompt_tokens_details
425 .as_ref()
426 .map(|d| d.cached_tokens)
427 .unwrap_or(0),
428 );
429 }
430
431 if tracing::enabled!(tracing::Level::TRACE) {
432 tracing::trace!(target: "rig::completions",
433 "Groq completion response: {}",
434 serde_json::to_string_pretty(&response)?
435 );
436 }
437
438 response.try_into()
439 }
440 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
441 }
442 } else {
443 Err(CompletionError::ProviderError(
444 String::from_utf8_lossy(&response_body).to_string(),
445 ))
446 }
447 };
448
449 tracing::Instrument::instrument(async_block, span).await
450 }
451
452 async fn stream(
453 &self,
454 request: CompletionRequest,
455 ) -> Result<
456 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
457 CompletionError,
458 > {
459 let span = if tracing::Span::current().is_disabled() {
460 info_span!(
461 target: "rig::completions",
462 "chat_streaming",
463 gen_ai.operation.name = "chat_streaming",
464 gen_ai.provider.name = "groq",
465 gen_ai.request.model = self.model,
466 gen_ai.system_instructions = tracing::field::Empty,
467 gen_ai.response.id = tracing::field::Empty,
468 gen_ai.response.model = tracing::field::Empty,
469 gen_ai.usage.output_tokens = tracing::field::Empty,
470 gen_ai.usage.input_tokens = tracing::field::Empty,
471 gen_ai.usage.cached_tokens = tracing::field::Empty,
472 )
473 } else {
474 tracing::Span::current()
475 };
476
477 span.record("gen_ai.system_instructions", &request.preamble);
478
479 let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;
480
481 request.stream = true;
482 request.stream_options = Some(StreamOptions {
483 include_usage: true,
484 });
485
486 if tracing::enabled!(tracing::Level::TRACE) {
487 tracing::trace!(target: "rig::completions",
488 "Groq streaming completion request: {}",
489 serde_json::to_string_pretty(&request)?
490 );
491 }
492
493 let body = serde_json::to_vec(&request)?;
494 let req = self
495 .client
496 .post("/chat/completions")?
497 .body(body)
498 .map_err(|e| http_client::Error::Instance(e.into()))?;
499
500 tracing::Instrument::instrument(
501 send_compatible_streaming_request(self.client.clone(), req),
502 span,
503 )
504 .await
505 }
506}
507
/// The `whisper-large-v3` transcription model.
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
/// The `whisper-large-v3-turbo` transcription model.
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
/// The `distil-whisper-large-v3-en` transcription model.
pub const DISTIL_WHISPER_LARGE_V3_EN: &str = "distil-whisper-large-v3-en";
515
/// A Groq audio-transcription model bound to a [`Client`].
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    /// Model identifier, e.g. [`WHISPER_LARGE_V3`].
    pub model: String,
}
522
523impl<T> TranscriptionModel<T> {
524 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
525 Self {
526 client,
527 model: model.into(),
528 }
529 }
530}
531impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
532where
533 T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
534{
535 type Response = TranscriptionResponse;
536
537 type Client = Client<T>;
538
539 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
540 Self::new(client.clone(), model)
541 }
542
543 async fn transcription(
544 &self,
545 request: transcription::TranscriptionRequest,
546 ) -> Result<
547 transcription::TranscriptionResponse<Self::Response>,
548 transcription::TranscriptionError,
549 > {
550 let data = request.data;
551
552 let mut body = MultipartForm::new()
553 .text("model", self.model.clone())
554 .part(Part::bytes("file", data).filename(request.filename.clone()));
555
556 if let Some(language) = request.language {
557 body = body.text("language", language);
558 }
559
560 if let Some(prompt) = request.prompt {
561 body = body.text("prompt", prompt.clone());
562 }
563
564 if let Some(ref temperature) = request.temperature {
565 body = body.text("temperature", temperature.to_string());
566 }
567
568 if let Some(ref additional_params) = request.additional_params {
569 for (key, value) in additional_params
570 .as_object()
571 .expect("Additional Parameters to OpenAI Transcription should be a map")
572 {
573 body = body.text(key.to_owned(), value.to_string());
574 }
575 }
576
577 let req = self
578 .client
579 .post("/audio/transcriptions")?
580 .body(body)
581 .unwrap();
582
583 let response = self.client.send_multipart::<Bytes>(req).await.unwrap();
584
585 let status = response.status();
586 let response_body = response.into_body().into_future().await?.to_vec();
587
588 if status.is_success() {
589 match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
590 ApiResponse::Ok(response) => response.try_into(),
591 ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
592 api_error_response.message,
593 )),
594 }
595 } else {
596 Err(TranscriptionError::ProviderError(
597 String::from_utf8_lossy(&response_body).to_string(),
598 ))
599 }
600 }
601}
602
/// Incremental delta payload inside a streaming chunk.
///
/// NOTE: with `untagged`, serde tries `Reasoning` first and falls back to
/// `MessageContent`, so the variant order is significant.
#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum StreamingDelta {
    /// A chunk of model reasoning text.
    Reasoning {
        reasoning: String,
    },
    /// Regular message content and/or partial tool-call fragments.
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        // `null_or_vec` tolerates the API sending `null` instead of `[]`.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

/// A single choice inside a streaming chunk.
#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

/// One SSE chunk of a streaming completion. `usage` is optional; we request
/// it via `stream_options.include_usage`, and the stream keeps the latest
/// value it sees.
#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}
627
/// Final aggregate response emitted at the end of a completion stream.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    /// Token usage reported by the last usage-bearing stream chunk.
    pub usage: Usage,
}
632
633impl GetTokenUsage for StreamingCompletionResponse {
634 fn token_usage(&self) -> Option<crate::completion::Usage> {
635 let mut usage = crate::completion::Usage::new();
636
637 usage.input_tokens = self.usage.prompt_tokens as u64;
638 usage.total_tokens = self.usage.total_tokens as u64;
639 usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
640 usage.cached_input_tokens = self
641 .usage
642 .prompt_tokens_details
643 .as_ref()
644 .map(|d| d.cached_tokens as u64)
645 .unwrap_or(0);
646
647 Some(usage)
648 }
649}
650
651pub async fn send_compatible_streaming_request<T>(
652 client: T,
653 req: Request<Vec<u8>>,
654) -> Result<
655 crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
656 CompletionError,
657>
658where
659 T: HttpClientExt + Clone + 'static,
660{
661 let span = tracing::Span::current();
662
663 let mut event_source = GenericEventSource::new(client, req);
664
665 let stream = stream! {
666 let span = tracing::Span::current();
667 let mut final_usage = Usage {
668 prompt_tokens: 0,
669 total_tokens: 0,
670 prompt_tokens_details: None,
671 };
672
673 let mut text_response = String::new();
674
675 let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();
676
677 while let Some(event_result) = event_source.next().await {
678 match event_result {
679 Ok(Event::Open) => {
680 tracing::trace!("SSE connection opened");
681 continue;
682 }
683
684 Ok(Event::Message(message)) => {
685 let data_str = message.data.trim();
686
687 let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
688 let Ok(data) = parsed else {
689 let err = parsed.unwrap_err();
690 tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
691 continue;
692 };
693
694 if let Some(choice) = data.choices.first() {
695 match &choice.delta {
696 StreamingDelta::Reasoning { reasoning } => {
697 yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta {
698 id: None,
699 reasoning: reasoning.to_string(),
700 });
701 }
702
703 StreamingDelta::MessageContent { content, tool_calls } => {
704 for tool_call in tool_calls {
706 let function = &tool_call.function;
707
708 if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
710 && empty_or_none(&function.arguments)
711 {
712 let id = tool_call.id.clone().unwrap_or_default();
713 let name = function.name.clone().unwrap();
714 calls.insert(tool_call.index, (id, name, String::new()));
715 }
716 else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
718 && let Some(arguments) = &function.arguments
719 && !arguments.is_empty()
720 {
721 if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
722 let combined = format!("{}{}", existing_args, arguments);
723 calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
724 } else {
725 tracing::debug!("Partial tool call received but tool call was never started.");
726 }
727 }
728 else {
730 let id = tool_call.id.clone().unwrap_or_default();
731 let name = function.name.clone().unwrap_or_default();
732 let arguments_str = function.arguments.clone().unwrap_or_default();
733
734 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
735 tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
736 continue;
737 };
738
739 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
740 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
741 ));
742 }
743 }
744
745 if let Some(content) = content {
747 text_response += content;
748 yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
749 }
750 }
751 }
752 }
753
754 if let Some(usage) = data.usage {
755 final_usage = usage.clone();
756 }
757 }
758
759 Err(crate::http_client::Error::StreamEnded) => break,
760 Err(err) => {
761 tracing::error!(?err, "SSE error");
762 yield Err(CompletionError::ResponseError(err.to_string()));
763 break;
764 }
765 }
766 }
767
768 event_source.close();
769
770 let mut tool_calls = Vec::new();
771 for (_, (id, name, arguments)) in calls {
773 let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
774 continue;
775 };
776
777 tool_calls.push(rig::providers::openai::completion::ToolCall {
778 id: id.clone(),
779 r#type: ToolType::Function,
780 function: Function {
781 name: name.clone(),
782 arguments: arguments_json.clone()
783 }
784 });
785 yield Ok(crate::streaming::RawStreamingChoice::ToolCall(
786 crate::streaming::RawStreamingToolCall::new(id, name, arguments_json)
787 ));
788 }
789
790 let response_message = crate::providers::openai::completion::Message::Assistant {
791 content: vec![AssistantContent::Text { text: text_response }],
792 refusal: None,
793 audio: None,
794 name: None,
795 tool_calls
796 };
797
798 span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
799 span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
800 span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
801 span.record(
802 "gen_ai.usage.cached_tokens",
803 final_usage
804 .prompt_tokens_details
805 .as_ref()
806 .map(|d| d.cached_tokens)
807 .unwrap_or(0),
808 );
809
810 yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
812 StreamingCompletionResponse { usage: final_usage.clone() }
813 ));
814 }.instrument(span);
815
816 Ok(crate::streaming::StreamingCompletionResponse::stream(
817 Box::pin(stream),
818 ))
819}
820
#[cfg(test)]
mod tests {
    use crate::{
        OneOrMany,
        providers::{
            groq::{GroqAdditionalParameters, GroqCompletionRequest},
            openai::{Message, UserContent},
        },
    };

    /// `GroqAdditionalParameters` must flatten into the top level of the
    /// serialized request, and `None`/empty fields must be omitted entirely.
    #[test]
    fn serialize_groq_request() {
        let additional_params = GroqAdditionalParameters {
            include_reasoning: Some(true),
            reasoning_format: Some(super::ReasoningFormat::Parsed),
            ..Default::default()
        };

        let groq = GroqCompletionRequest {
            model: "openai/gpt-120b-oss".to_string(),
            temperature: None,
            tool_choice: None,
            stream_options: None,
            tools: Vec::new(),
            messages: vec![Message::User {
                content: OneOrMany::one(UserContent::Text {
                    text: "Hello world!".to_string(),
                }),
                name: None,
            }],
            stream: false,
            additional_params: Some(additional_params),
        };

        let json = serde_json::to_value(&groq).unwrap();

        assert_eq!(
            json,
            serde_json::json!({
                "model": "openai/gpt-120b-oss",
                "messages": [
                    {
                        "role": "user",
                        "content": "Hello world!"
                    }
                ],
                "stream": false,
                "include_reasoning": true,
                "reasoning_format": "parsed"
            })
        )
    }

    /// Smoke test: both client construction paths succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::groq::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::groq::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}