use bytes::Bytes;
use http::{Method, Request};
use std::collections::HashMap;
use tracing::info_span;
use tracing_futures::Instrument;

use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage};
use crate::client::{CompletionClient, TranscriptionClient, VerifyClient, VerifyError};
use crate::completion::GetTokenUsage;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::merge;
use crate::providers::openai::{AssistantContent, Function, ToolType};
use async_stream::stream;
use futures::StreamExt;

use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    message::{self, MessageError},
    providers::openai::ToolDefinition,
    transcription::{self, TranscriptionError},
};
use reqwest::multipart::Part;
use rig::client::ProviderClient;
use rig::impl_conversion_traits;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};

const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

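/// Builder for a Groq [`Client`].
///
/// Illustrative usage, assuming this module is exposed as `rig::providers::groq`
/// and a valid Groq API key is supplied:
///
/// ```rust,no_run
/// use rig::providers::groq;
///
/// // Uses the default `reqwest::Client` as the HTTP backend.
/// let client = groq::Client::builder("your-groq-api-key").build();
/// ```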
pub struct ClientBuilder<'a, T = reqwest::Client> {
    api_key: &'a str,
    base_url: &'a str,
    http_client: T,
}

impl<'a, T> ClientBuilder<'a, T>
where
    T: Default,
{
    pub fn new(api_key: &'a str) -> Self {
        Self {
            api_key,
            base_url: GROQ_API_BASE_URL,
            http_client: Default::default(),
        }
    }
}

impl<'a, T> ClientBuilder<'a, T> {
    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
        Self {
            api_key,
            base_url: GROQ_API_BASE_URL,
            http_client,
        }
    }

    pub fn base_url(mut self, base_url: &'a str) -> Self {
        self.base_url = base_url;
        self
    }

    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
        ClientBuilder {
            api_key: self.api_key,
            base_url: self.base_url,
            http_client,
        }
    }

    pub fn build(self) -> Client<T> {
        Client {
            base_url: self.base_url.to_string(),
            api_key: self.api_key.to_string(),
            http_client: self.http_client,
        }
    }
}

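/// A Groq API client, generic over the HTTP backend (defaults to `reqwest::Client`).
/// Holds the base URL and the API key used to authenticate requests.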
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    base_url: String,
    api_key: String,
    http_client: T,
}

impl<T> std::fmt::Debug for Client<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}

impl<T> Client<T>
where
    T: HttpClientExt,
{
    fn req(
        &self,
        method: http_client::Method,
        path: &str,
    ) -> http_client::Result<http_client::Builder> {
        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));

        http_client::with_bearer_auth(
            http_client::Builder::new().method(method).uri(url),
            &self.api_key,
        )
    }

    fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
        self.req(http_client::Method::GET, path)
    }
}

impl Client<reqwest::Client> {
    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
        ClientBuilder::new(api_key)
    }

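    /// Create a new Groq client with the given API key, using the default
    /// `reqwest::Client` as the HTTP backend.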
    pub fn new(api_key: &str) -> Self {
        ClientBuilder::new(api_key).build()
    }

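    /// Create a new Groq client from the `GROQ_API_KEY` environment variable.
    /// Panics if the variable is not set.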
    pub fn from_env() -> Self {
        <Self as ProviderClient>::from_env()
    }
}

impl<T> ProviderClient for Client<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    fn from_env() -> Self {
        let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
        ClientBuilder::<T>::new(&api_key).build()
    }

    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::Simple(api_key) = input else {
            panic!("Incorrect provider value type")
        };
        ClientBuilder::<T>::new(&api_key).build()
    }
}

impl<T> CompletionClient for Client<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type CompletionModel = CompletionModel<T>;

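    /// Create a completion model with the given name.
    ///
    /// Illustrative usage, assuming this module is exposed as
    /// `rig::providers::groq` and `GROQ_API_KEY` is set in the environment:
    ///
    /// ```rust,no_run
    /// use rig::client::CompletionClient;
    /// use rig::providers::groq;
    ///
    /// let client = groq::Client::from_env();
    /// let model = client.completion_model(groq::LLAMA_3_1_8B_INSTANT);
    /// ```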
    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}

impl<T> TranscriptionClient for Client<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type TranscriptionModel = TranscriptionModel<T>;

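    /// Create a transcription model with the given name.
    ///
    /// Illustrative usage, assuming this module is exposed as
    /// `rig::providers::groq` and `GROQ_API_KEY` is set in the environment:
    ///
    /// ```rust,no_run
    /// use rig::client::TranscriptionClient;
    /// use rig::providers::groq;
    ///
    /// let client = groq::Client::from_env();
    /// let whisper = client.transcription_model(groq::WHISPER_LARGE_V3);
    /// ```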
    fn transcription_model(&self, model: &str) -> Self::TranscriptionModel {
        TranscriptionModel::new(self.clone(), model)
    }
}

impl<T> VerifyClient for Client<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        let req = self
            .get("/models")?
            .body(http_client::NoBody)
            .map_err(http_client::Error::from)?;

        let response = HttpClientExt::send(&self.http_client, req).await?;

        match response.status() {
            reqwest::StatusCode::OK => Ok(()),
            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
            reqwest::StatusCode::INTERNAL_SERVER_ERROR
            | reqwest::StatusCode::SERVICE_UNAVAILABLE
            | reqwest::StatusCode::BAD_GATEWAY => {
                let text = http_client::text(response).await?;
                Err(VerifyError::ProviderError(text))
            }
            _ => Ok(()),
        }
    }
}

impl_conversion_traits!(
    AsEmbeddings,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

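/// A Groq chat message. Unlike the OpenAI message type, Groq can return a
/// separate `reasoning` field when the request sets `"reasoning_format": "parsed"`.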
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: String,
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<String>,
}

impl TryFrom<Message> for message::Message {
    type Error = message::MessageError;

    fn try_from(message: Message) -> Result<Self, Self::Error> {
        match message.role.as_str() {
            "user" => Ok(Self::User {
                content: OneOrMany::one(
                    message
                        .content
                        .map(|content| message::UserContent::text(&content))
                        .ok_or_else(|| {
                            message::MessageError::ConversionError("Empty user message".to_string())
                        })?,
                ),
            }),
            "assistant" => Ok(Self::Assistant {
                id: None,
                content: OneOrMany::one(
                    message
                        .content
                        .map(|content| message::AssistantContent::text(&content))
                        .ok_or_else(|| {
                            message::MessageError::ConversionError(
                                "Empty assistant message".to_string(),
                            )
                        })?,
                ),
            }),
            _ => Err(message::MessageError::ConversionError(format!(
                "Unknown role: {}",
                message.role
            ))),
        }
    }
}

impl TryFrom<message::Message> for Message {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        match message {
            message::Message::User { content } => Ok(Self {
                role: "user".to_string(),
                content: content.iter().find_map(|c| match c {
                    message::UserContent::Text(text) => Some(text.text.clone()),
                    _ => None,
                }),
                reasoning: None,
            }),
            message::Message::Assistant { content, .. } => {
                let mut text_content: Option<String> = None;
                let mut groq_reasoning: Option<String> = None;

                for c in content.iter() {
                    match c {
                        message::AssistantContent::Text(text) => {
                            text_content = Some(
                                text_content
                                    .map(|mut existing| {
                                        existing.push('\n');
                                        existing.push_str(&text.text);
                                        existing
                                    })
                                    .unwrap_or_else(|| text.text.clone()),
                            );
                        }
                        message::AssistantContent::ToolCall(_tool_call) => {
                            return Err(MessageError::ConversionError(
                                "Tool calls are not supported on this message".into(),
                            ));
                        }
                        message::AssistantContent::Reasoning(message::Reasoning {
                            reasoning,
                            ..
                        }) => {
                            groq_reasoning =
                                Some(reasoning.first().cloned().unwrap_or(String::new()));
                        }
                    }
                }

                Ok(Self {
                    role: "assistant".to_string(),
                    content: text_content,
                    reasoning: groq_reasoning,
                })
            }
        }
    }
}

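// Completion model identifiers accepted by the Groq API.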
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";

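/// A completion model backed by the Groq chat-completions endpoint,
/// created via [`CompletionClient::completion_model`].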
#[derive(Clone, Debug)]
pub struct CompletionModel<T> {
    client: Client<T>,
    pub model: String,
}

impl<T> CompletionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }

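    /// Build the JSON body for the Groq chat-completions endpoint, folding the
    /// preamble, documents, chat history, tools, tool choice, and any additional
    /// parameters into a single request payload.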
    fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        let mut full_history: Vec<Message> =
            completion_request
                .preamble
                .map_or_else(Vec::new, |preamble| {
                    vec![Message {
                        role: "system".to_string(),
                        content: Some(preamble),
                        reasoning: None,
                    }]
                });

        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Message>, _>>()?,
        );

        let tool_choice = completion_request
            .tool_choice
            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

        let request = if completion_request.tools.is_empty() {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
            })
        } else {
            json!({
                "model": self.model,
                "messages": full_history,
                "temperature": completion_request.temperature,
                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
                "tool_choice": tool_choice,
                "reasoning_format": "parsed"
            })
        };

        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = StreamingCompletionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();

        let request = self.create_completion_request(completion_request)?;
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .req(Method::POST, "/chat/completions")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let async_block = async move {
            let response = self.client.http_client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        let span = tracing::Span::current();
                        span.record("gen_ai.response.id", response.id.clone());
                        span.record("gen_ai.response.model", response.model.clone());
                        span.record(
                            "gen_ai.output.messages",
                            serde_json::to_string(&response.choices).unwrap(),
                        );
                        if let Some(ref usage) = response.usage {
                            span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
                            span.record(
                                "gen_ai.usage.output_tokens",
                                usage.total_tokens - usage.prompt_tokens,
                            );
                        }
                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        tracing::Instrument::instrument(async_block, span).await
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        let preamble = request.preamble.clone();
        let mut request = self.create_completion_request(request)?;

        request = merge(
            request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .req(Method::POST, "/chat/completions")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(|e| http_client::Error::Instance(e.into()))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "groq",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        tracing::Instrument::instrument(
            send_compatible_streaming_request(self.client.http_client.clone(), req),
            span,
        )
        .await
    }
}

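// Transcription (Whisper) model identifiers accepted by the Groq API.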
pub const WHISPER_LARGE_V3: &str = "whisper-large-v3";
pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo";
pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en";

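/// A transcription model backed by the Groq audio-transcriptions endpoint,
/// created via [`TranscriptionClient::transcription_model`].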
#[derive(Clone)]
pub struct TranscriptionModel<T> {
    client: Client<T>,
    pub model: String,
}

impl<T> TranscriptionModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl<T> transcription::TranscriptionModel for TranscriptionModel<T>
where
    T: HttpClientExt + Clone + Send + std::fmt::Debug + Default + 'static,
{
    type Response = TranscriptionResponse;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn transcription(
        &self,
        request: transcription::TranscriptionRequest,
    ) -> Result<
        transcription::TranscriptionResponse<Self::Response>,
        transcription::TranscriptionError,
    > {
        let data = request.data;

        let mut body = reqwest::multipart::Form::new()
            .text("model", self.model.clone())
            .part(
                "file",
                Part::bytes(data).file_name(request.filename.clone()),
            );

        if let Some(language) = request.language {
            body = body.text("language", language);
        }

        if let Some(prompt) = request.prompt {
            body = body.text("prompt", prompt.clone());
        }

        if let Some(ref temperature) = request.temperature {
            body = body.text("temperature", temperature.to_string());
        }

        if let Some(ref additional_params) = request.additional_params {
            for (key, value) in additional_params
                .as_object()
                .expect("Additional Parameters to Groq Transcription should be a map")
            {
                body = body.text(key.to_owned(), value.to_string());
            }
        }

        let req = self
            .client
            .req(Method::POST, "/audio/transcriptions")?
            .body(body)
            .unwrap();

        let response = self
            .client
            .http_client
            .send_multipart::<Bytes>(req)
            .await
            .unwrap();

        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<TranscriptionResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError(
                    api_error_response.message,
                )),
            }
        } else {
            Err(TranscriptionError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}

#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum StreamingDelta {
    Reasoning {
        reasoning: String,
    },
    MessageContent {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<StreamingToolCall>,
    },
}

#[derive(Deserialize, Debug)]
struct StreamingChoice {
    delta: StreamingDelta,
}

#[derive(Deserialize, Debug)]
struct StreamingCompletionChunk {
    choices: Vec<StreamingChoice>,
    usage: Option<Usage>,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct StreamingCompletionResponse {
    pub usage: Usage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.usage.prompt_tokens as u64;
        usage.total_tokens = self.usage.total_tokens as u64;
        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;

        Some(usage)
    }
}

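/// Drive an OpenAI-compatible SSE stream from Groq, yielding reasoning, text,
/// and tool-call deltas as they arrive, followed by a final response carrying
/// the token usage reported by the API.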
pub async fn send_compatible_streaming_request<T>(
    client: T,
    req: Request<Vec<u8>>,
) -> Result<
    crate::streaming::StreamingCompletionResponse<StreamingCompletionResponse>,
    CompletionError,
>
where
    T: HttpClientExt + Clone + 'static,
{
    let span = tracing::Span::current();

    let mut event_source = GenericEventSource::new(client, req);

    let stream = stream! {
        let span = tracing::Span::current();
        let mut final_usage = Usage {
            prompt_tokens: 0,
            total_tokens: 0
        };

        let mut text_response = String::new();

        let mut calls: HashMap<usize, (String, String, String)> = HashMap::new();

        while let Some(event_result) = event_source.next().await {
            match event_result {
                Ok(Event::Open) => {
                    tracing::trace!("SSE connection opened");
                    continue;
                }

                Ok(Event::Message(message)) => {
                    let data_str = message.data.trim();

                    let parsed = serde_json::from_str::<StreamingCompletionChunk>(data_str);
                    let Ok(data) = parsed else {
                        let err = parsed.unwrap_err();
                        tracing::debug!("Couldn't parse SSE payload as StreamingCompletionChunk: {:?}", err);
                        continue;
                    };

                    if let Some(choice) = data.choices.first() {
                        match &choice.delta {
                            StreamingDelta::Reasoning { reasoning } => {
                                yield Ok(crate::streaming::RawStreamingChoice::Reasoning {
                                    id: None,
                                    reasoning: reasoning.to_string(),
                                    signature: None,
                                });
                            }

                            StreamingDelta::MessageContent { content, tool_calls } => {
                                for tool_call in tool_calls {
                                    let function = &tool_call.function;

                                    // Start of a streamed tool call: a name arrives with no arguments yet.
                                    if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false)
                                        && function.arguments.is_empty()
                                    {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap();
                                        calls.insert(tool_call.index, (id, name, String::new()));
                                    }
                                    // Continuation chunk: arguments only, appended to the call started above.
                                    else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true)
                                        && !function.arguments.is_empty()
                                    {
                                        if let Some((id, name, existing_args)) = calls.get(&tool_call.index) {
                                            let combined = format!("{}{}", existing_args, function.arguments);
                                            calls.insert(tool_call.index, (id.clone(), name.clone(), combined));
                                        } else {
                                            tracing::debug!("Partial tool call received but tool call was never started.");
                                        }
                                    }
                                    // A complete tool call delivered in a single chunk.
                                    else {
                                        let id = tool_call.id.clone().unwrap_or_default();
                                        let name = function.name.clone().unwrap_or_default();
                                        let arguments_str = function.arguments.clone();

                                        let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments_str) else {
                                            tracing::debug!("Couldn't parse tool call args '{}'", arguments_str);
                                            continue;
                                        };

                                        yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                                            id,
                                            name,
                                            arguments: arguments_json,
                                            call_id: None
                                        });
                                    }
                                }

                                if let Some(content) = content {
                                    text_response += content;
                                    yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone()));
                                }
                            }
                        }
                    }

                    if let Some(usage) = data.usage {
                        final_usage = usage.clone();
                    }
                }

                Err(crate::http_client::Error::StreamEnded) => break,
                Err(err) => {
                    tracing::error!(?err, "SSE error");
                    yield Err(CompletionError::ResponseError(err.to_string()));
                    break;
                }
            }
        }

        event_source.close();

        let mut tool_calls = Vec::new();
        // Emit any tool calls that were accumulated across multiple chunks.
        for (_, (id, name, arguments)) in calls {
            let Ok(arguments_json) = serde_json::from_str::<serde_json::Value>(&arguments) else {
                continue;
            };

            tool_calls.push(rig::providers::openai::completion::ToolCall {
                id: id.clone(),
                r#type: ToolType::Function,
                function: Function {
                    name: name.clone(),
                    arguments: arguments_json.clone()
                }
            });
            yield Ok(crate::streaming::RawStreamingChoice::ToolCall {
                id,
                name,
                arguments: arguments_json,
                call_id: None,
            });
        }

        let response_message = crate::providers::openai::completion::Message::Assistant {
            content: vec![AssistantContent::Text { text: text_response }],
            refusal: None,
            audio: None,
            name: None,
            tool_calls
        };

        span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap());
        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);

        yield Ok(crate::streaming::RawStreamingChoice::FinalResponse(
            StreamingCompletionResponse { usage: final_usage.clone() }
        ));
    }.instrument(span);

    Ok(crate::streaming::StreamingCompletionResponse::stream(
        Box::pin(stream),
    ))
}