Skip to main content

rig_core/providers/huggingface/
completion.rs

1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7    OneOrMany,
8    completion::{self, CompletionError, CompletionRequest},
9    json_utils,
10    message::{self},
11    one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
/// Envelope used when decoding a Huggingface response body.
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration order:
/// first the expected payload `T`, then a catch-all that keeps the raw JSON value.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The body deserialized as the expected payload type.
    Ok(T),
    /// The body did not match `T`; the JSON value is kept verbatim for error reporting.
    Err(Value),
}
25
// ================================================================
// Huggingface Completion API
// ================================================================

// Conversational LLMs (text-only chat models). Each constant holds the model identifier string.
/// `google/gemma-2-2b-it` completion model
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// `meta-llama/Meta-Llama-3.1-8B-Instruct` completion model
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `PowerInfer/SmallThinker-3B-Preview` completion model
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// `Qwen/Qwen2.5-7B-Instruct` completion model
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` completion model
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs (visual-language chat models).

/// `Qwen/Qwen2-VL-7B-Instruct` visual-language completion model
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// `Qwen/QVQ-72B-Preview` visual-language completion model
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
/// The function part of a tool call: which function to invoke and with what arguments.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function being called (private; set through the conversions in this module).
    name: String,
    /// Arguments for the call, held as a JSON value.
    // NOTE(review): the helper names suggest the value is written as a JSON-encoded string and
    // read back from either a string or a plain JSON value — confirm in
    // `json_utils::stringified_json`.
    #[serde(
        serialize_with = "json_utils::stringified_json::serialize",
        deserialize_with = "json_utils::stringified_json::deserialize_maybe_stringified"
    )]
    pub arguments: serde_json::Value,
}
58
59impl From<Function> for message::ToolFunction {
60    fn from(value: Function) -> Self {
61        message::ToolFunction {
62            name: value.name,
63            arguments: value.arguments,
64        }
65    }
66}
67
/// Kind of tool carried by a [`ToolCall`].
///
/// Variant names are (de)serialized in lowercase, so `Function` appears as `"function"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A callable function; the only kind defined here and the `Default`.
    #[default]
    Function,
}
74
/// Tool description sent in the `tools` list of a request: a type label wrapping the
/// generic [`completion::ToolDefinition`].
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool type label; the `From` conversion in this module always sets it to `"function"`.
    pub r#type: String,
    /// The generic tool definition being wrapped.
    pub function: completion::ToolDefinition,
}
80
81impl From<completion::ToolDefinition> for ToolDefinition {
82    fn from(tool: completion::ToolDefinition) -> Self {
83        Self {
84            r#type: "function".into(),
85            function: tool,
86        }
87    }
88}
89
/// A tool call attached to an assistant message.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this tool call.
    pub id: String,
    /// Kind of tool being called.
    pub r#type: ToolType,
    /// The function to invoke and its arguments.
    pub function: Function,
}
96
97impl From<ToolCall> for message::ToolCall {
98    fn from(value: ToolCall) -> Self {
99        message::ToolCall {
100            id: value.id,
101            call_id: None,
102            function: value.function.into(),
103            signature: None,
104            additional_params: None,
105        }
106    }
107}
108
109impl From<message::ToolCall> for ToolCall {
110    fn from(value: message::ToolCall) -> Self {
111        ToolCall {
112            id: value.id,
113            r#type: ToolType::Function,
114            function: Function {
115                name: value.function.name,
116                arguments: value.function.arguments,
117            },
118        }
119    }
120}
121
/// Wrapper object holding the URL of an image content part.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// The image URL string.
    url: String,
}
126
/// One content part of a user message.
///
/// Internally tagged by a `type` field: `"text"` for [`UserContent::Text`] and
/// `"image_url"` for [`UserContent::ImageUrl`].
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image referenced by URL. Renamed explicitly because `rename_all = "lowercase"`
    /// would otherwise produce `"imageurl"`.
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
138
139impl FromStr for UserContent {
140    type Err = Infallible;
141
142    fn from_str(s: &str) -> Result<Self, Self::Err> {
143        Ok(UserContent::Text {
144            text: s.to_string(),
145        })
146    }
147}
148
/// One content part of an assistant message.
///
/// Internally tagged by a `type` field; only `"text"` parts are defined here.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Plain text produced by the assistant.
    Text { text: String },
}
154
155impl FromStr for AssistantContent {
156    type Err = Infallible;
157
158    fn from_str(s: &str) -> Result<Self, Self::Err> {
159        Ok(AssistantContent::Text {
160            text: s.to_string(),
161        })
162    }
163}
164
/// One content part of a system message.
///
/// Internally tagged by a `type` field; only `"text"` parts are defined here.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    /// Plain system-prompt text.
    Text { text: String },
}
170
171impl FromStr for SystemContent {
172    type Err = Infallible;
173
174    fn from_str(s: &str) -> Result<Self, Self::Err> {
175        Ok(SystemContent::Text {
176            text: s.to_string(),
177        })
178    }
179}
180
181impl From<UserContent> for message::UserContent {
182    fn from(value: UserContent) -> Self {
183        match value {
184            UserContent::Text { text, .. } => message::UserContent::text(text),
185            UserContent::ImageUrl { image_url } => {
186                message::UserContent::image_url(image_url.url, None, None)
187            }
188        }
189    }
190}
191
192impl TryFrom<message::UserContent> for UserContent {
193    type Error = message::MessageError;
194
195    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
196        match content {
197            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
198            message::UserContent::Document(message::Document {
199                data: message::DocumentSourceKind::Raw(raw),
200                ..
201            }) => {
202                let text = String::from_utf8_lossy(raw.as_slice()).into();
203                Ok(UserContent::Text { text })
204            }
205            message::UserContent::Document(message::Document {
206                data:
207                    message::DocumentSourceKind::Base64(text)
208                    | message::DocumentSourceKind::String(text),
209                ..
210            }) => Ok(UserContent::Text { text }),
211            message::UserContent::Image(message::Image { data, .. }) => match data {
212                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
213                    image_url: ImageUrl { url },
214                }),
215                _ => Err(message::MessageError::ConversionError(
216                    "Huggingface only supports images as urls".into(),
217                )),
218            },
219            _ => Err(message::MessageError::ConversionError(
220                "Huggingface only supports text and images".into(),
221            )),
222        }
223    }
224}
225
/// A chat message in the provider's wire format.
///
/// Internally tagged by a `role` field whose value is the lowercase variant name, except
/// for [`Message::ToolResult`], which is written as `"tool"` (and also read from `"Tool"`).
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System prompt.
    System {
        // Custom deserializer; by its name it accepts a bare string as well as one or many
        // content parts — confirm in `one_or_many::string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    /// User turn made of one or more content parts.
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    /// Assistant turn: text parts and/or tool calls. Both fields fall back to an empty
    /// vector when absent (`default`).
    Assistant {
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        // NOTE(review): `null_or_vec` presumably maps an explicit `null` to an empty vector.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool invocation.
    #[serde(rename = "tool", alias = "Tool")]
    ToolResult {
        name: String,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        /// Serialized as a single string (see `serialize_tool_content`).
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_tool_content"
        )]
        content: OneOrMany<String>,
    },
}
255
256fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
257where
258    S: Serializer,
259{
260    // OpenAI-compatible APIs expect tool content as a string, not an array
261    let joined = content
262        .iter()
263        .map(String::as_str)
264        .collect::<Vec<_>>()
265        .join("\n");
266    serializer.serialize_str(&joined)
267}
268
269impl Message {
270    pub fn system(content: &str) -> Self {
271        Message::System {
272            content: OneOrMany::one(SystemContent::Text {
273                text: content.to_string(),
274            }),
275        }
276    }
277}
278
/// Converts one generic message into zero or more provider messages.
///
/// * System messages become a single system message.
/// * User messages become either one tool message per tool result, or a single user message.
/// * Assistant messages become a single assistant message, or nothing at all when they
///   carry neither text nor tool calls.
impl TryFrom<message::Message> for Vec<Message> {
    type Error = message::MessageError;

    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
        match message {
            message::Message::System { content } => Ok(vec![Message::system(&content)]),
            message::Message::User { content } => {
                // Split the parts into tool results and everything else.
                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
                    .into_iter()
                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));

                // NOTE(review): when at least one tool result is present, `other_content` is
                // not used in this branch, so any text/image/document parts of the same user
                // message are dropped — confirm this is intended.
                if !tool_results.is_empty() {
                    tool_results
                        .into_iter()
                        .map(|content| match content {
                            message::UserContent::ToolResult(message::ToolResult {
                                id,
                                content,
                                ..
                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
                                // The tool result's `id` is what ends up in the `name` field.
                                name: id,
                                arguments: None,
                                // Only text tool-result content is accepted.
                                content: content.try_map(|content| match content {
                                    message::ToolResultContent::Text(message::Text {
                                        text,
                                        ..
                                    }) => Ok(text),
                                    _ => Err(message::MessageError::ConversionError(
                                        "Tool result content does not support non-text".into(),
                                    )),
                                })?,
                            }),
                            // Not reachable through the partition above; kept as a defensive error.
                            _ => Err(message::MessageError::ConversionError(
                                "expected tool result content while converting HuggingFace input"
                                    .into(),
                            )),
                        })
                        .collect::<Result<Vec<_>, _>>()
                } else {
                    // `OneOrMany::many` fails on an empty list; that is reported as a
                    // conversion error.
                    let other_content = OneOrMany::many(other_content).map_err(|_| {
                        message::MessageError::ConversionError(
                            "HuggingFace user message did not contain any non-tool content".into(),
                        )
                    })?;

                    Ok(vec![Message::User {
                        content: other_content.try_map(|content| match content {
                            message::UserContent::Text(text) => {
                                Ok(UserContent::Text { text: text.text })
                            }
                            message::UserContent::Image(image) => {
                                let url = image.try_into_url()?;

                                Ok(UserContent::ImageUrl {
                                    image_url: ImageUrl { url },
                                })
                            }
                            // Raw document bytes are decoded lossily as UTF-8 text.
                            message::UserContent::Document(message::Document {
                                data: message::DocumentSourceKind::Raw(raw), ..
                            }) => {
                                let text = String::from_utf8_lossy(raw.as_slice()).into();
                                Ok(UserContent::Text { text })
                            }
                            // Base64 and string documents are forwarded verbatim as text.
                            message::UserContent::Document(message::Document {
                                data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
                            }) => {
                                Ok(UserContent::Text { text })
                            }
                            _ => Err(message::MessageError::ConversionError(
                                "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
                            )),
                        })?,
                    }])
                }
            }
            message::Message::Assistant { content, .. } => {
                let mut text_content = Vec::new();
                let mut tool_calls = Vec::new();

                for content in content {
                    match content {
                        message::AssistantContent::Text(text) => text_content.push(text),
                        message::AssistantContent::ToolCall(tool_call) => {
                            tool_calls.push(tool_call)
                        }
                        message::AssistantContent::Reasoning(_) => {
                            // HuggingFace does not support assistant-history reasoning items.
                            // Silently skip unsupported reasoning content.
                        }
                        message::AssistantContent::Image(_) => {
                            return Err(message::MessageError::ConversionError(
                                "HuggingFace assistant messages do not support image content"
                                    .into(),
                            ));
                        }
                    }
                }

                // Nothing convertible was left (e.g. only reasoning items): emit no message.
                if text_content.is_empty() && tool_calls.is_empty() {
                    return Ok(vec![]);
                }

                Ok(vec![Message::Assistant {
                    content: text_content
                        .into_iter()
                        .map(|content| AssistantContent::Text { text: content.text })
                        .collect::<Vec<_>>(),
                    tool_calls: tool_calls
                        .into_iter()
                        .map(|tool_call| tool_call.into())
                        .collect::<Vec<_>>(),
                }])
            }
        }
    }
}
395
396impl TryFrom<Message> for message::Message {
397    type Error = message::MessageError;
398
399    fn try_from(message: Message) -> Result<Self, Self::Error> {
400        Ok(match message {
401            Message::User { content, .. } => message::Message::User {
402                content: content.map(|content| content.into()),
403            },
404            Message::Assistant {
405                content,
406                tool_calls,
407                ..
408            } => {
409                let mut content = content
410                    .into_iter()
411                    .map(|content| match content {
412                        AssistantContent::Text { text, .. } => {
413                            message::AssistantContent::text(text)
414                        }
415                    })
416                    .collect::<Vec<_>>();
417
418                content.extend(
419                    tool_calls
420                        .into_iter()
421                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
422                        .collect::<Result<Vec<_>, _>>()?,
423                );
424
425                message::Message::Assistant {
426                    id: None,
427                    content: OneOrMany::many(content).map_err(|_| {
428                        message::MessageError::ConversionError(
429                            "Neither `content` nor `tool_calls` was provided to the Message"
430                                .to_owned(),
431                        )
432                    })?,
433                }
434            }
435
436            Message::ToolResult { name, content, .. } => message::Message::User {
437                content: OneOrMany::one(message::UserContent::tool_result(
438                    name,
439                    content.map(message::ToolResultContent::text),
440                )),
441            },
442
443            // System messages should get stripped out when converting message's, this is just a
444            // stop gap to avoid obnoxious error handling or panic occurring.
445            Message::System { content, .. } => message::Message::User {
446                content: content.map(|c| match c {
447                    SystemContent::Text { text, .. } => message::UserContent::text(text),
448                }),
449            },
450        })
451    }
452}
453
/// One entry of the `choices` array in a completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Reason reported by the provider for ending the generation.
    pub finish_reason: String,
    /// Position of this choice in the response.
    pub index: usize,
    /// Log-probability data kept as raw JSON; defaults (JSON `null`) when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    /// The message produced for this choice.
    pub message: Message,
}
462
/// Token accounting as reported in a completion response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Tokens in the generated completion.
    pub completion_tokens: i32,
    /// Tokens in the prompt.
    pub prompt_tokens: i32,
    /// Total token count reported by the provider.
    pub total_tokens: i32,
}
469
470impl GetTokenUsage for Usage {
471    fn token_usage(&self) -> Option<crate::completion::Usage> {
472        let mut usage = crate::completion::Usage::new();
473        usage.input_tokens = self.prompt_tokens as u64;
474        usage.output_tokens = self.completion_tokens as u64;
475        usage.total_tokens = self.total_tokens as u64;
476
477        Some(usage)
478    }
479}
480
/// Body of a successful chat-completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Creation timestamp as reported by the provider.
    pub created: i32,
    /// Response identifier.
    pub id: String,
    /// Name of the model that produced the response.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<Choice>,
    /// Empty string when the field is absent (`default`) or `null` (`default_string_on_null`).
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    /// Token accounting for the request.
    pub usage: Usage,
}
491
492impl crate::telemetry::ProviderResponseExt for CompletionResponse {
493    type OutputMessage = Choice;
494    type Usage = Usage;
495
496    fn get_response_id(&self) -> Option<String> {
497        Some(self.id.clone())
498    }
499
500    fn get_response_model_name(&self) -> Option<String> {
501        Some(self.model.clone())
502    }
503
504    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
505        self.choices.clone()
506    }
507
508    fn get_text_response(&self) -> Option<String> {
509        let text_response = self
510            .choices
511            .iter()
512            .filter_map(|x| {
513                let Message::User { ref content } = x.message else {
514                    return None;
515                };
516
517                let text = content
518                    .iter()
519                    .filter_map(|x| {
520                        if let UserContent::Text { text, .. } = x {
521                            Some(text.clone())
522                        } else {
523                            None
524                        }
525                    })
526                    .collect::<Vec<String>>();
527
528                if text.is_empty() {
529                    None
530                } else {
531                    Some(text.join("\n"))
532                }
533            })
534            .collect::<Vec<String>>()
535            .join("\n");
536
537        if text_response.is_empty() {
538            None
539        } else {
540            Some(text_response)
541        }
542    }
543
544    fn get_usage(&self) -> Option<Self::Usage> {
545        Some(self.usage.clone())
546    }
547}
548
549fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
550where
551    D: Deserializer<'de>,
552{
553    match Option::<String>::deserialize(deserializer)? {
554        Some(value) => Ok(value),      // Use provided value
555        None => Ok(String::default()), // Use `Default` implementation
556    }
557}
558
559impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
560    type Error = CompletionError;
561
562    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
563        let choice = response.choices.first().ok_or_else(|| {
564            CompletionError::ResponseError("Response contained no choices".to_owned())
565        })?;
566
567        let content = match &choice.message {
568            Message::Assistant {
569                content,
570                tool_calls,
571                ..
572            } => {
573                let mut content = content
574                    .iter()
575                    .map(|c| match c {
576                        AssistantContent::Text { text, .. } => {
577                            message::AssistantContent::text(text)
578                        }
579                    })
580                    .collect::<Vec<_>>();
581
582                content.extend(
583                    tool_calls
584                        .iter()
585                        .map(|call| {
586                            completion::AssistantContent::tool_call(
587                                &call.id,
588                                &call.function.name,
589                                call.function.arguments.clone(),
590                            )
591                        })
592                        .collect::<Vec<_>>(),
593                );
594                Ok(content)
595            }
596            _ => Err(CompletionError::ResponseError(
597                "Response did not contain a valid message or tool call".into(),
598            )),
599        }?;
600
601        let choice = OneOrMany::many(content).map_err(|_| {
602            CompletionError::ResponseError(
603                "Response contained no message or tool call (empty)".to_owned(),
604            )
605        })?;
606
607        let usage = completion::Usage {
608            input_tokens: response.usage.prompt_tokens as u64,
609            output_tokens: response.usage.completion_tokens as u64,
610            total_tokens: response.usage.total_tokens as u64,
611            cached_input_tokens: 0,
612            cache_creation_input_tokens: 0,
613            tool_use_prompt_tokens: 0,
614            reasoning_tokens: 0,
615        };
616
617        Ok(completion::CompletionResponse {
618            choice,
619            usage,
620            raw_response: response,
621            message_id: None,
622        })
623    }
624}
625
/// Request body sent to the completion endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HuggingfaceCompletionRequest {
    /// Model identifier to run.
    model: String,
    /// Conversation, in order.
    pub messages: Vec<Message>,
    /// Sampling temperature; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Tool definitions; omitted from the output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Tool-choice setting, reusing the OpenAI provider's type; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra parameters flattened into the top level of the request object; omitted when `None`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
639
640impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
641    type Error = CompletionError;
642
643    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
644        if req.output_schema.is_some() {
645            tracing::warn!("Structured outputs currently not supported for Huggingface");
646        }
647        let model = req.model.clone().unwrap_or_else(|| model.to_string());
648        let mut full_history: Vec<Message> = match &req.preamble {
649            Some(preamble) => vec![Message::system(preamble)],
650            None => vec![],
651        };
652        if let Some(docs) = req.normalized_documents() {
653            let docs: Vec<Message> = docs.try_into()?;
654            full_history.extend(docs);
655        }
656
657        let chat_history: Vec<Message> = req
658            .chat_history
659            .clone()
660            .into_iter()
661            .map(|message| message.try_into())
662            .collect::<Result<Vec<Vec<Message>>, _>>()?
663            .into_iter()
664            .flatten()
665            .collect();
666
667        full_history.extend(chat_history);
668
669        if full_history.is_empty() {
670            return Err(CompletionError::RequestError(
671                std::io::Error::new(
672                    std::io::ErrorKind::InvalidInput,
673                    "HuggingFace request has no provider-compatible messages after conversion",
674                )
675                .into(),
676            ));
677        }
678
679        let tool_choice = req
680            .tool_choice
681            .clone()
682            .map(crate::providers::openai::completion::ToolChoice::try_from)
683            .transpose()?;
684
685        Ok(Self {
686            model: model.to_string(),
687            messages: full_history,
688            temperature: req.temperature,
689            tools: req
690                .tools
691                .clone()
692                .into_iter()
693                .map(ToolDefinition::from)
694                .collect::<Vec<_>>(),
695            tool_choice,
696            additional_params: req.additional_params,
697        })
698    }
699}
700
/// Completion model handle: a provider client paired with a model name.
///
/// `T` is the underlying HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Name of the model (e.g: google/gemma-2-2b-it)
    pub model: String,
}
707
708impl<T> CompletionModel<T> {
709    pub fn new(client: Client<T>, model: &str) -> Self {
710        Self {
711            client,
712            model: model.to_string(),
713        }
714    }
715}
716
717impl<T> completion::CompletionModel for CompletionModel<T>
718where
719    T: HttpClientExt + Clone + 'static,
720{
721    type Response = CompletionResponse;
722    type StreamingResponse = StreamingCompletionResponse;
723
724    type Client = Client<T>;
725
726    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
727        Self::new(client.clone(), &model.into())
728    }
729
730    async fn completion(
731        &self,
732        completion_request: CompletionRequest,
733    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
734        let request_model = completion_request
735            .model
736            .clone()
737            .unwrap_or_else(|| self.model.clone());
738        let span = if tracing::Span::current().is_disabled() {
739            info_span!(
740                target: "rig::completions",
741                "chat",
742                gen_ai.operation.name = "chat",
743                gen_ai.provider.name = "huggingface",
744                gen_ai.request.model = &request_model,
745                gen_ai.system_instructions = &completion_request.preamble,
746                gen_ai.response.id = tracing::field::Empty,
747                gen_ai.response.model = tracing::field::Empty,
748                gen_ai.usage.output_tokens = tracing::field::Empty,
749                gen_ai.usage.input_tokens = tracing::field::Empty,
750                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
751            )
752        } else {
753            tracing::Span::current()
754        };
755
756        let model = self.client.subprovider().model_identifier(&request_model);
757        let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
758
759        if enabled!(Level::TRACE) {
760            tracing::trace!(
761                target: "rig::completions",
762                "Huggingface completion request: {}",
763                serde_json::to_string_pretty(&request)?
764            );
765        }
766
767        let request = serde_json::to_vec(&request)?;
768
769        let path = self
770            .client
771            .subprovider()
772            .completion_endpoint(&request_model);
773        let request = self
774            .client
775            .post(&path)?
776            .header("Content-Type", "application/json")
777            .body(request)
778            .map_err(|e| CompletionError::HttpError(e.into()))?;
779
780        async move {
781            let response = self.client.send(request).await?;
782
783            if response.status().is_success() {
784                let bytes: Vec<u8> = response.into_body().await?;
785                let text = String::from_utf8_lossy(&bytes);
786
787                tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
788
789                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
790                    ApiResponse::Ok(response) => {
791                        if enabled!(Level::TRACE) {
792                            tracing::trace!(
793                                target: "rig::completions",
794                                "Huggingface completion response: {}",
795                                serde_json::to_string_pretty(&response)?
796                            );
797                        }
798
799                        let span = tracing::Span::current();
800                        span.record_token_usage(&response.usage);
801                        span.record_response_metadata(&response);
802
803                        response.try_into()
804                    }
805                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
806                }
807            } else {
808                let status = response.status();
809                let text: Vec<u8> = response.into_body().await?;
810                let text: String = String::from_utf8_lossy(&text).into();
811
812                Err(CompletionError::ProviderError(format!(
813                    "{}: {}",
814                    status, text
815                )))
816            }
817        }
818        .instrument(span)
819        .await
820    }
821
822    async fn stream(
823        &self,
824        request: CompletionRequest,
825    ) -> Result<
826        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
827        CompletionError,
828    > {
829        CompletionModel::stream(self, request).await
830    }
831}
832
#[cfg(test)]
mod tests {
    //! Unit tests for the Huggingface completion request/response types and
    //! their conversions to and from the provider-agnostic message types.

    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`.
    ///
    /// # Panics
    ///
    /// Panics with the failing JSON path plus the line/column reported by
    /// `serde_json` when the input does not match `T`, so a broken fixture
    /// points straight at the offending field.
    fn parse_json<'de, T>(json: &'de str) -> T
    where
        T: serde::Deserialize<'de>,
    {
        let mut deserializer = serde_json::Deserializer::from_str(json);
        deserialize(&mut deserializer).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    /// Builds a [`CompletionRequest`] whose only populated fields are the
    /// optional `model` override and a chat history holding exactly `history`.
    fn minimal_request(model: Option<&str>, history: message::Message) -> CompletionRequest {
        CompletionRequest {
            model: model.map(String::from),
            preamble: None,
            chat_history: OneOrMany::one(history),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// A model set on the request wins over the default model passed alongside it.
    #[test]
    fn test_huggingface_request_uses_request_model_override() {
        let request = minimal_request(
            Some("meta-llama/Meta-Llama-3.1-8B-Instruct"),
            "Hello".into(),
        );

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
    }

    /// Without a model on the request, the default model passed alongside it is used.
    #[test]
    fn test_huggingface_request_uses_default_model_when_override_unset() {
        let request = minimal_request(None, "Hello".into());

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "mistralai/Mistral-7B");
    }

    /// Provider messages deserialize from plain-string content, content-part
    /// arrays, tool-call-only assistant messages, and multi-part user messages.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse_json(assistant_message_json);
        let assistant_message2: Message = parse_json(assistant_message_json2);
        let assistant_message3: Message = parse_json(assistant_message_json3);
        let user_message: Message = parse_json(user_message_json);

        // Plain string content becomes a single text part.
        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // An explicit `"tool_calls": null` yields an empty tool-call list.
        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        // `"content": null` yields empty content while tool calls are kept.
        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    second,
                    UserContent::ImageUrl {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string()
                        }
                    }
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    /// Generic messages convert into provider messages and back without loss.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        // Inspect by reference; the converted values are consumed further down.
        match &converted_user_message[0] {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match &converted_assistant_message[0] {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    /// Provider messages convert into generic messages and back without loss.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        // Inspect by reference; the converted values are consumed further down.
        match &converted_user_message {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match &converted_assistant_message {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    /// Responses captured from two inference backends deserialize into
    /// [`CompletionResponse`]; the test only checks that parsing succeeds.
    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let _fireworks_response: CompletionResponse = parse_json(fireworks_response_json);
        let _novita_response: CompletionResponse = parse_json(novita_response_json);
    }

    /// An assistant message holding only reasoning converts to no provider messages.
    #[test]
    fn test_assistant_reasoning_is_silently_skipped() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }

    /// Reasoning is dropped while sibling text and tool-call parts survive in
    /// a single provider message.
    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    /// Request conversion fails with a `RequestError` when the only history
    /// entry is a reasoning-only assistant message.
    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        let request = minimal_request(
            None,
            message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            },
        );

        let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}