Skip to main content

rig/providers/huggingface/
completion.rs

1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7    OneOrMany,
8    completion::{self, CompletionError, CompletionRequest},
9    json_utils,
10    message::{self},
11    one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22    Ok(T),
23    Err(Value),
24}
25
26// ================================================================
27// Huggingface Completion API
28// ================================================================
29
// Conversational LLMs
// Each constant is a model name that can be passed wherever a model string is expected.
/// `google/gemma-2-2b-it` completion model
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// `meta-llama/Meta-Llama-3.1-8B-Instruct` completion model
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `PowerInfer/SmallThinker-3B-Preview` completion model
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// `Qwen/Qwen2.5-7B-Instruct` completion model
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` completion model
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs (visual-language models)

/// `Qwen/Qwen2-VL-7B-Instruct` visual-language completion model
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// `Qwen/QVQ-72B-Preview` visual-language completion model
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51    name: String,
52    #[serde(
53        serialize_with = "json_utils::stringified_json::serialize",
54        deserialize_with = "deserialize_arguments"
55    )]
56    pub arguments: serde_json::Value,
57}
58
59fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
60where
61    D: Deserializer<'de>,
62{
63    let value = Value::deserialize(deserializer)?;
64
65    match value {
66        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
67        other => Ok(other),
68    }
69}
70
71impl From<Function> for message::ToolFunction {
72    fn from(value: Function) -> Self {
73        message::ToolFunction {
74            name: value.name,
75            arguments: value.arguments,
76        }
77    }
78}
79
80#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
81#[serde(rename_all = "lowercase")]
82pub enum ToolType {
83    #[default]
84    Function,
85}
86
87#[derive(Debug, Deserialize, Serialize, Clone)]
88pub struct ToolDefinition {
89    pub r#type: String,
90    pub function: completion::ToolDefinition,
91}
92
93impl From<completion::ToolDefinition> for ToolDefinition {
94    fn from(tool: completion::ToolDefinition) -> Self {
95        Self {
96            r#type: "function".into(),
97            function: tool,
98        }
99    }
100}
101
102#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
103pub struct ToolCall {
104    pub id: String,
105    pub r#type: ToolType,
106    pub function: Function,
107}
108
109impl From<ToolCall> for message::ToolCall {
110    fn from(value: ToolCall) -> Self {
111        message::ToolCall {
112            id: value.id,
113            call_id: None,
114            function: value.function.into(),
115            signature: None,
116            additional_params: None,
117        }
118    }
119}
120
121impl From<message::ToolCall> for ToolCall {
122    fn from(value: message::ToolCall) -> Self {
123        ToolCall {
124            id: value.id,
125            r#type: ToolType::Function,
126            function: Function {
127                name: value.function.name,
128                arguments: value.function.arguments,
129            },
130        }
131    }
132}
133
134#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
135pub struct ImageUrl {
136    url: String,
137}
138
139#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
140#[serde(tag = "type", rename_all = "lowercase")]
141pub enum UserContent {
142    Text {
143        text: String,
144    },
145    #[serde(rename = "image_url")]
146    ImageUrl {
147        image_url: ImageUrl,
148    },
149}
150
151impl FromStr for UserContent {
152    type Err = Infallible;
153
154    fn from_str(s: &str) -> Result<Self, Self::Err> {
155        Ok(UserContent::Text {
156            text: s.to_string(),
157        })
158    }
159}
160
161#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
162#[serde(tag = "type", rename_all = "lowercase")]
163pub enum AssistantContent {
164    Text { text: String },
165}
166
167impl FromStr for AssistantContent {
168    type Err = Infallible;
169
170    fn from_str(s: &str) -> Result<Self, Self::Err> {
171        Ok(AssistantContent::Text {
172            text: s.to_string(),
173        })
174    }
175}
176
177#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
178#[serde(tag = "type", rename_all = "lowercase")]
179pub enum SystemContent {
180    Text { text: String },
181}
182
183impl FromStr for SystemContent {
184    type Err = Infallible;
185
186    fn from_str(s: &str) -> Result<Self, Self::Err> {
187        Ok(SystemContent::Text {
188            text: s.to_string(),
189        })
190    }
191}
192
193impl From<UserContent> for message::UserContent {
194    fn from(value: UserContent) -> Self {
195        match value {
196            UserContent::Text { text } => message::UserContent::text(text),
197            UserContent::ImageUrl { image_url } => {
198                message::UserContent::image_url(image_url.url, None, None)
199            }
200        }
201    }
202}
203
204impl TryFrom<message::UserContent> for UserContent {
205    type Error = message::MessageError;
206
207    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
208        match content {
209            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
210            message::UserContent::Document(message::Document {
211                data: message::DocumentSourceKind::Raw(raw),
212                ..
213            }) => {
214                let text = String::from_utf8_lossy(raw.as_slice()).into();
215                Ok(UserContent::Text { text })
216            }
217            message::UserContent::Document(message::Document {
218                data:
219                    message::DocumentSourceKind::Base64(text)
220                    | message::DocumentSourceKind::String(text),
221                ..
222            }) => Ok(UserContent::Text { text }),
223            message::UserContent::Image(message::Image { data, .. }) => match data {
224                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
225                    image_url: ImageUrl { url },
226                }),
227                _ => Err(message::MessageError::ConversionError(
228                    "Huggingface only supports images as urls".into(),
229                )),
230            },
231            _ => Err(message::MessageError::ConversionError(
232                "Huggingface only supports text and images".into(),
233            )),
234        }
235    }
236}
237
238#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
239#[serde(tag = "role", rename_all = "lowercase")]
240pub enum Message {
241    System {
242        #[serde(deserialize_with = "string_or_one_or_many")]
243        content: OneOrMany<SystemContent>,
244    },
245    User {
246        #[serde(deserialize_with = "string_or_one_or_many")]
247        content: OneOrMany<UserContent>,
248    },
249    Assistant {
250        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
251        content: Vec<AssistantContent>,
252        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
253        tool_calls: Vec<ToolCall>,
254    },
255    #[serde(rename = "tool", alias = "Tool")]
256    ToolResult {
257        name: String,
258        #[serde(skip_serializing_if = "Option::is_none")]
259        arguments: Option<serde_json::Value>,
260        #[serde(
261            deserialize_with = "string_or_one_or_many",
262            serialize_with = "serialize_tool_content"
263        )]
264        content: OneOrMany<String>,
265    },
266}
267
268fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
269where
270    S: Serializer,
271{
272    // OpenAI-compatible APIs expect tool content as a string, not an array
273    let joined = content
274        .iter()
275        .map(String::as_str)
276        .collect::<Vec<_>>()
277        .join("\n");
278    serializer.serialize_str(&joined)
279}
280
281impl Message {
282    pub fn system(content: &str) -> Self {
283        Message::System {
284            content: OneOrMany::one(SystemContent::Text {
285                text: content.to_string(),
286            }),
287        }
288    }
289}
290
291impl TryFrom<message::Message> for Vec<Message> {
292    type Error = message::MessageError;
293
294    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
295        match message {
296            message::Message::System { content } => Ok(vec![Message::system(&content)]),
297            message::Message::User { content } => {
298                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
299                    .into_iter()
300                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
301
302                if !tool_results.is_empty() {
303                    tool_results
304                        .into_iter()
305                        .map(|content| match content {
306                            message::UserContent::ToolResult(message::ToolResult {
307                                id,
308                                content,
309                                ..
310                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
311                                name: id,
312                                arguments: None,
313                                content: content.try_map(|content| match content {
314                                    message::ToolResultContent::Text(message::Text { text }) => {
315                                        Ok(text)
316                                    }
317                                    _ => Err(message::MessageError::ConversionError(
318                                        "Tool result content does not support non-text".into(),
319                                    )),
320                                })?,
321                            }),
322                            _ => unreachable!(),
323                        })
324                        .collect::<Result<Vec<_>, _>>()
325                } else {
326                    let other_content = OneOrMany::many(other_content).expect(
327                        "There must be other content here if there were no tool result content",
328                    );
329
330                    Ok(vec![Message::User {
331                        content: other_content.try_map(|content| match content {
332                            message::UserContent::Text(text) => {
333                                Ok(UserContent::Text { text: text.text })
334                            }
335                            message::UserContent::Image(image) => {
336                                let url = image.try_into_url()?;
337
338                                Ok(UserContent::ImageUrl {
339                                    image_url: ImageUrl { url },
340                                })
341                            }
342                            message::UserContent::Document(message::Document {
343                                data: message::DocumentSourceKind::Raw(raw), ..
344                            }) => {
345                                let text = String::from_utf8_lossy(raw.as_slice()).into();
346                                Ok(UserContent::Text { text })
347                            }
348                            message::UserContent::Document(message::Document {
349                                data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
350                            }) => {
351                                Ok(UserContent::Text { text })
352                            }
353                            _ => Err(message::MessageError::ConversionError(
354                                "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
355                            )),
356                        })?,
357                    }])
358                }
359            }
360            message::Message::Assistant { content, .. } => {
361                let mut text_content = Vec::new();
362                let mut tool_calls = Vec::new();
363
364                for content in content {
365                    match content {
366                        message::AssistantContent::Text(text) => text_content.push(text),
367                        message::AssistantContent::ToolCall(tool_call) => {
368                            tool_calls.push(tool_call)
369                        }
370                        message::AssistantContent::Reasoning(_) => {
371                            // HuggingFace does not support assistant-history reasoning items.
372                            // Silently skip unsupported reasoning content.
373                        }
374                        message::AssistantContent::Image(_) => {
375                            panic!("Image content is not supported on HuggingFace via Rig");
376                        }
377                    }
378                }
379
380                if text_content.is_empty() && tool_calls.is_empty() {
381                    return Ok(vec![]);
382                }
383
384                Ok(vec![Message::Assistant {
385                    content: text_content
386                        .into_iter()
387                        .map(|content| AssistantContent::Text { text: content.text })
388                        .collect::<Vec<_>>(),
389                    tool_calls: tool_calls
390                        .into_iter()
391                        .map(|tool_call| tool_call.into())
392                        .collect::<Vec<_>>(),
393                }])
394            }
395        }
396    }
397}
398
399impl TryFrom<Message> for message::Message {
400    type Error = message::MessageError;
401
402    fn try_from(message: Message) -> Result<Self, Self::Error> {
403        Ok(match message {
404            Message::User { content, .. } => message::Message::User {
405                content: content.map(|content| content.into()),
406            },
407            Message::Assistant {
408                content,
409                tool_calls,
410                ..
411            } => {
412                let mut content = content
413                    .into_iter()
414                    .map(|content| match content {
415                        AssistantContent::Text { text } => message::AssistantContent::text(text),
416                    })
417                    .collect::<Vec<_>>();
418
419                content.extend(
420                    tool_calls
421                        .into_iter()
422                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
423                        .collect::<Result<Vec<_>, _>>()?,
424                );
425
426                message::Message::Assistant {
427                    id: None,
428                    content: OneOrMany::many(content).map_err(|_| {
429                        message::MessageError::ConversionError(
430                            "Neither `content` nor `tool_calls` was provided to the Message"
431                                .to_owned(),
432                        )
433                    })?,
434                }
435            }
436
437            Message::ToolResult { name, content, .. } => message::Message::User {
438                content: OneOrMany::one(message::UserContent::tool_result(
439                    name,
440                    content.map(message::ToolResultContent::text),
441                )),
442            },
443
444            // System messages should get stripped out when converting message's, this is just a
445            // stop gap to avoid obnoxious error handling or panic occurring.
446            Message::System { content, .. } => message::Message::User {
447                content: content.map(|c| match c {
448                    SystemContent::Text { text } => message::UserContent::text(text),
449                }),
450            },
451        })
452    }
453}
454
455#[derive(Clone, Debug, Deserialize, Serialize)]
456pub struct Choice {
457    pub finish_reason: String,
458    pub index: usize,
459    #[serde(default)]
460    pub logprobs: serde_json::Value,
461    pub message: Message,
462}
463
464#[derive(Debug, Deserialize, Clone, Serialize)]
465pub struct Usage {
466    pub completion_tokens: i32,
467    pub prompt_tokens: i32,
468    pub total_tokens: i32,
469}
470
471impl GetTokenUsage for Usage {
472    fn token_usage(&self) -> Option<crate::completion::Usage> {
473        let mut usage = crate::completion::Usage::new();
474        usage.input_tokens = self.prompt_tokens as u64;
475        usage.output_tokens = self.completion_tokens as u64;
476        usage.total_tokens = self.total_tokens as u64;
477
478        Some(usage)
479    }
480}
481
482#[derive(Clone, Debug, Deserialize, Serialize)]
483pub struct CompletionResponse {
484    pub created: i32,
485    pub id: String,
486    pub model: String,
487    pub choices: Vec<Choice>,
488    #[serde(default, deserialize_with = "default_string_on_null")]
489    pub system_fingerprint: String,
490    pub usage: Usage,
491}
492
493impl crate::telemetry::ProviderResponseExt for CompletionResponse {
494    type OutputMessage = Choice;
495    type Usage = Usage;
496
497    fn get_response_id(&self) -> Option<String> {
498        Some(self.id.clone())
499    }
500
501    fn get_response_model_name(&self) -> Option<String> {
502        Some(self.model.clone())
503    }
504
505    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
506        self.choices.clone()
507    }
508
509    fn get_text_response(&self) -> Option<String> {
510        let text_response = self
511            .choices
512            .iter()
513            .filter_map(|x| {
514                let Message::User { ref content } = x.message else {
515                    return None;
516                };
517
518                let text = content
519                    .iter()
520                    .filter_map(|x| {
521                        if let UserContent::Text { text } = x {
522                            Some(text.clone())
523                        } else {
524                            None
525                        }
526                    })
527                    .collect::<Vec<String>>();
528
529                if text.is_empty() {
530                    None
531                } else {
532                    Some(text.join("\n"))
533                }
534            })
535            .collect::<Vec<String>>()
536            .join("\n");
537
538        if text_response.is_empty() {
539            None
540        } else {
541            Some(text_response)
542        }
543    }
544
545    fn get_usage(&self) -> Option<Self::Usage> {
546        Some(self.usage.clone())
547    }
548}
549
550fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
551where
552    D: Deserializer<'de>,
553{
554    match Option::<String>::deserialize(deserializer)? {
555        Some(value) => Ok(value),      // Use provided value
556        None => Ok(String::default()), // Use `Default` implementation
557    }
558}
559
560impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
561    type Error = CompletionError;
562
563    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
564        let choice = response.choices.first().ok_or_else(|| {
565            CompletionError::ResponseError("Response contained no choices".to_owned())
566        })?;
567
568        let content = match &choice.message {
569            Message::Assistant {
570                content,
571                tool_calls,
572                ..
573            } => {
574                let mut content = content
575                    .iter()
576                    .map(|c| match c {
577                        AssistantContent::Text { text } => message::AssistantContent::text(text),
578                    })
579                    .collect::<Vec<_>>();
580
581                content.extend(
582                    tool_calls
583                        .iter()
584                        .map(|call| {
585                            completion::AssistantContent::tool_call(
586                                &call.id,
587                                &call.function.name,
588                                call.function.arguments.clone(),
589                            )
590                        })
591                        .collect::<Vec<_>>(),
592                );
593                Ok(content)
594            }
595            _ => Err(CompletionError::ResponseError(
596                "Response did not contain a valid message or tool call".into(),
597            )),
598        }?;
599
600        let choice = OneOrMany::many(content).map_err(|_| {
601            CompletionError::ResponseError(
602                "Response contained no message or tool call (empty)".to_owned(),
603            )
604        })?;
605
606        let usage = completion::Usage {
607            input_tokens: response.usage.prompt_tokens as u64,
608            output_tokens: response.usage.completion_tokens as u64,
609            total_tokens: response.usage.total_tokens as u64,
610            cached_input_tokens: 0,
611            cache_creation_input_tokens: 0,
612        };
613
614        Ok(completion::CompletionResponse {
615            choice,
616            usage,
617            raw_response: response,
618            message_id: None,
619        })
620    }
621}
622
623#[derive(Debug, Serialize, Deserialize)]
624pub(super) struct HuggingfaceCompletionRequest {
625    model: String,
626    pub messages: Vec<Message>,
627    #[serde(skip_serializing_if = "Option::is_none")]
628    temperature: Option<f64>,
629    #[serde(skip_serializing_if = "Vec::is_empty")]
630    tools: Vec<ToolDefinition>,
631    #[serde(skip_serializing_if = "Option::is_none")]
632    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
633    #[serde(flatten, skip_serializing_if = "Option::is_none")]
634    pub additional_params: Option<serde_json::Value>,
635}
636
637impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
638    type Error = CompletionError;
639
640    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
641        if req.output_schema.is_some() {
642            tracing::warn!("Structured outputs currently not supported for Huggingface");
643        }
644        let model = req.model.clone().unwrap_or_else(|| model.to_string());
645        let mut full_history: Vec<Message> = match &req.preamble {
646            Some(preamble) => vec![Message::system(preamble)],
647            None => vec![],
648        };
649        if let Some(docs) = req.normalized_documents() {
650            let docs: Vec<Message> = docs.try_into()?;
651            full_history.extend(docs);
652        }
653
654        let chat_history: Vec<Message> = req
655            .chat_history
656            .clone()
657            .into_iter()
658            .map(|message| message.try_into())
659            .collect::<Result<Vec<Vec<Message>>, _>>()?
660            .into_iter()
661            .flatten()
662            .collect();
663
664        full_history.extend(chat_history);
665
666        if full_history.is_empty() {
667            return Err(CompletionError::RequestError(
668                std::io::Error::new(
669                    std::io::ErrorKind::InvalidInput,
670                    "HuggingFace request has no provider-compatible messages after conversion",
671                )
672                .into(),
673            ));
674        }
675
676        let tool_choice = req
677            .tool_choice
678            .clone()
679            .map(crate::providers::openai::completion::ToolChoice::try_from)
680            .transpose()?;
681
682        Ok(Self {
683            model: model.to_string(),
684            messages: full_history,
685            temperature: req.temperature,
686            tools: req
687                .tools
688                .clone()
689                .into_iter()
690                .map(ToolDefinition::from)
691                .collect::<Vec<_>>(),
692            tool_choice,
693            additional_params: req.additional_params,
694        })
695    }
696}
697
/// Completion model bound to a Huggingface client.
///
/// `T` is the HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Name of the model (e.g: google/gemma-2-2b-it)
    pub model: String,
}
704
705impl<T> CompletionModel<T> {
706    pub fn new(client: Client<T>, model: &str) -> Self {
707        Self {
708            client,
709            model: model.to_string(),
710        }
711    }
712}
713
714impl<T> completion::CompletionModel for CompletionModel<T>
715where
716    T: HttpClientExt + Clone + 'static,
717{
718    type Response = CompletionResponse;
719    type StreamingResponse = StreamingCompletionResponse;
720
721    type Client = Client<T>;
722
723    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
724        Self::new(client.clone(), &model.into())
725    }
726
727    async fn completion(
728        &self,
729        completion_request: CompletionRequest,
730    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
731        let request_model = completion_request
732            .model
733            .clone()
734            .unwrap_or_else(|| self.model.clone());
735        let span = if tracing::Span::current().is_disabled() {
736            info_span!(
737                target: "rig::completions",
738                "chat",
739                gen_ai.operation.name = "chat",
740                gen_ai.provider.name = "huggingface",
741                gen_ai.request.model = &request_model,
742                gen_ai.system_instructions = &completion_request.preamble,
743                gen_ai.response.id = tracing::field::Empty,
744                gen_ai.response.model = tracing::field::Empty,
745                gen_ai.usage.output_tokens = tracing::field::Empty,
746                gen_ai.usage.input_tokens = tracing::field::Empty,
747                gen_ai.usage.cached_tokens = tracing::field::Empty,
748            )
749        } else {
750            tracing::Span::current()
751        };
752
753        let model = self.client.subprovider().model_identifier(&request_model);
754        let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
755
756        if enabled!(Level::TRACE) {
757            tracing::trace!(
758                target: "rig::completions",
759                "Huggingface completion request: {}",
760                serde_json::to_string_pretty(&request)?
761            );
762        }
763
764        let request = serde_json::to_vec(&request)?;
765
766        let path = self
767            .client
768            .subprovider()
769            .completion_endpoint(&request_model);
770        let request = self
771            .client
772            .post(&path)?
773            .header("Content-Type", "application/json")
774            .body(request)
775            .map_err(|e| CompletionError::HttpError(e.into()))?;
776
777        async move {
778            let response = self.client.send(request).await?;
779
780            if response.status().is_success() {
781                let bytes: Vec<u8> = response.into_body().await?;
782                let text = String::from_utf8_lossy(&bytes);
783
784                tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
785
786                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
787                    ApiResponse::Ok(response) => {
788                        if enabled!(Level::TRACE) {
789                            tracing::trace!(
790                                target: "rig::completions",
791                                "Huggingface completion response: {}",
792                                serde_json::to_string_pretty(&response)?
793                            );
794                        }
795
796                        let span = tracing::Span::current();
797                        span.record_token_usage(&response.usage);
798                        span.record_response_metadata(&response);
799
800                        response.try_into()
801                    }
802                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
803                }
804            } else {
805                let status = response.status();
806                let text: Vec<u8> = response.into_body().await?;
807                let text: String = String::from_utf8_lossy(&text).into();
808
809                Err(CompletionError::ProviderError(format!(
810                    "{}: {}",
811                    status, text
812                )))
813            }
814        }
815        .instrument(span)
816        .await
817    }
818
819    async fn stream(
820        &self,
821        request: CompletionRequest,
822    ) -> Result<
823        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
824        CompletionError,
825    > {
826        CompletionModel::stream(self, request).await
827    }
828}
829
830#[cfg(test)]
831mod tests {
832    use super::*;
833    use serde_path_to_error::deserialize;
834
835    #[test]
836    fn test_huggingface_request_uses_request_model_override() {
837        let request = CompletionRequest {
838            model: Some("meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()),
839            preamble: None,
840            chat_history: crate::OneOrMany::one("Hello".into()),
841            documents: vec![],
842            tools: vec![],
843            temperature: None,
844            max_tokens: None,
845            tool_choice: None,
846            additional_params: None,
847            output_schema: None,
848        };
849
850        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
851            .expect("request conversion should succeed");
852        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
853
854        assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
855    }
856
857    #[test]
858    fn test_huggingface_request_uses_default_model_when_override_unset() {
859        let request = CompletionRequest {
860            model: None,
861            preamble: None,
862            chat_history: crate::OneOrMany::one("Hello".into()),
863            documents: vec![],
864            tools: vec![],
865            temperature: None,
866            max_tokens: None,
867            tool_choice: None,
868            additional_params: None,
869            output_schema: None,
870        };
871
872        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
873            .expect("request conversion should succeed");
874        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
875
876        assert_eq!(serialized["model"], "mistralai/Mistral-7B");
877    }
878
879    #[test]
880    fn test_deserialize_message() {
881        let assistant_message_json = r#"
882        {
883            "role": "assistant",
884            "content": "\n\nHello there, how may I assist you today?"
885        }
886        "#;
887
888        let assistant_message_json2 = r#"
889        {
890            "role": "assistant",
891            "content": [
892                {
893                    "type": "text",
894                    "text": "\n\nHello there, how may I assist you today?"
895                }
896            ],
897            "tool_calls": null
898        }
899        "#;
900
901        let assistant_message_json3 = r#"
902        {
903            "role": "assistant",
904            "tool_calls": [
905                {
906                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
907                    "type": "function",
908                    "function": {
909                        "name": "subtract",
910                        "arguments": {"x": 2, "y": 5}
911                    }
912                }
913            ],
914            "content": null,
915            "refusal": null
916        }
917        "#;
918
919        let user_message_json = r#"
920        {
921            "role": "user",
922            "content": [
923                {
924                    "type": "text",
925                    "text": "What's in this image?"
926                },
927                {
928                    "type": "image_url",
929                    "image_url": {
930                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
931                    }
932                }
933            ]
934        }
935        "#;
936
937        let assistant_message: Message = {
938            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
939            deserialize(jd).unwrap_or_else(|err| {
940                panic!(
941                    "Deserialization error at {} ({}:{}): {}",
942                    err.path(),
943                    err.inner().line(),
944                    err.inner().column(),
945                    err
946                );
947            })
948        };
949
950        let assistant_message2: Message = {
951            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
952            deserialize(jd).unwrap_or_else(|err| {
953                panic!(
954                    "Deserialization error at {} ({}:{}): {}",
955                    err.path(),
956                    err.inner().line(),
957                    err.inner().column(),
958                    err
959                );
960            })
961        };
962
963        let assistant_message3: Message = {
964            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
965                &mut serde_json::Deserializer::from_str(assistant_message_json3);
966            deserialize(jd).unwrap_or_else(|err| {
967                panic!(
968                    "Deserialization error at {} ({}:{}): {}",
969                    err.path(),
970                    err.inner().line(),
971                    err.inner().column(),
972                    err
973                );
974            })
975        };
976
977        let user_message: Message = {
978            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
979            deserialize(jd).unwrap_or_else(|err| {
980                panic!(
981                    "Deserialization error at {} ({}:{}): {}",
982                    err.path(),
983                    err.inner().line(),
984                    err.inner().column(),
985                    err
986                );
987            })
988        };
989
990        match assistant_message {
991            Message::Assistant { content, .. } => {
992                assert_eq!(
993                    content[0],
994                    AssistantContent::Text {
995                        text: "\n\nHello there, how may I assist you today?".to_string()
996                    }
997                );
998            }
999            _ => panic!("Expected assistant message"),
1000        }
1001
1002        match assistant_message2 {
1003            Message::Assistant {
1004                content,
1005                tool_calls,
1006                ..
1007            } => {
1008                assert_eq!(
1009                    content[0],
1010                    AssistantContent::Text {
1011                        text: "\n\nHello there, how may I assist you today?".to_string()
1012                    }
1013                );
1014
1015                assert_eq!(tool_calls, vec![]);
1016            }
1017            _ => panic!("Expected assistant message"),
1018        }
1019
1020        match assistant_message3 {
1021            Message::Assistant {
1022                content,
1023                tool_calls,
1024                ..
1025            } => {
1026                assert!(content.is_empty());
1027                assert_eq!(
1028                    tool_calls[0],
1029                    ToolCall {
1030                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
1031                        r#type: ToolType::Function,
1032                        function: Function {
1033                            name: "subtract".to_string(),
1034                            arguments: serde_json::json!({"x": 2, "y": 5}),
1035                        },
1036                    }
1037                );
1038            }
1039            _ => panic!("Expected assistant message"),
1040        }
1041
1042        match user_message {
1043            Message::User { content, .. } => {
1044                let (first, second) = {
1045                    let mut iter = content.into_iter();
1046                    (iter.next().unwrap(), iter.next().unwrap())
1047                };
1048                assert_eq!(
1049                    first,
1050                    UserContent::Text {
1051                        text: "What's in this image?".to_string()
1052                    }
1053                );
1054                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
1055            }
1056            _ => panic!("Expected user message"),
1057        }
1058    }
1059
1060    #[test]
1061    fn test_message_to_message_conversion() {
1062        let user_message = message::Message::User {
1063            content: OneOrMany::one(message::UserContent::text("Hello")),
1064        };
1065
1066        let assistant_message = message::Message::Assistant {
1067            id: None,
1068            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
1069        };
1070
1071        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
1072        let converted_assistant_message: Vec<Message> =
1073            assistant_message.clone().try_into().unwrap();
1074
1075        match converted_user_message[0].clone() {
1076            Message::User { content, .. } => {
1077                assert_eq!(
1078                    content.first(),
1079                    UserContent::Text {
1080                        text: "Hello".to_string()
1081                    }
1082                );
1083            }
1084            _ => panic!("Expected user message"),
1085        }
1086
1087        match converted_assistant_message[0].clone() {
1088            Message::Assistant { content, .. } => {
1089                assert_eq!(
1090                    content[0],
1091                    AssistantContent::Text {
1092                        text: "Hi there!".to_string()
1093                    }
1094                );
1095            }
1096            _ => panic!("Expected assistant message"),
1097        }
1098
1099        let original_user_message: message::Message =
1100            converted_user_message[0].clone().try_into().unwrap();
1101        let original_assistant_message: message::Message =
1102            converted_assistant_message[0].clone().try_into().unwrap();
1103
1104        assert_eq!(original_user_message, user_message);
1105        assert_eq!(original_assistant_message, assistant_message);
1106    }
1107
1108    #[test]
1109    fn test_message_from_message_conversion() {
1110        let user_message = Message::User {
1111            content: OneOrMany::one(UserContent::Text {
1112                text: "Hello".to_string(),
1113            }),
1114        };
1115
1116        let assistant_message = Message::Assistant {
1117            content: vec![AssistantContent::Text {
1118                text: "Hi there!".to_string(),
1119            }],
1120            tool_calls: vec![],
1121        };
1122
1123        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1124        let converted_assistant_message: message::Message =
1125            assistant_message.clone().try_into().unwrap();
1126
1127        match converted_user_message.clone() {
1128            message::Message::User { content } => {
1129                assert_eq!(content.first(), message::UserContent::text("Hello"));
1130            }
1131            _ => panic!("Expected user message"),
1132        }
1133
1134        match converted_assistant_message.clone() {
1135            message::Message::Assistant { content, .. } => {
1136                assert_eq!(
1137                    content.first(),
1138                    message::AssistantContent::text("Hi there!")
1139                );
1140            }
1141            _ => panic!("Expected assistant message"),
1142        }
1143
1144        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1145        let original_assistant_message: Vec<Message> =
1146            converted_assistant_message.try_into().unwrap();
1147
1148        assert_eq!(original_user_message[0], user_message);
1149        assert_eq!(original_assistant_message[0], assistant_message);
1150    }
1151
1152    #[test]
1153    fn test_responses() {
1154        let fireworks_response_json = r#"
1155        {
1156            "choices": [
1157                {
1158                    "finish_reason": "tool_calls",
1159                    "index": 0,
1160                    "message": {
1161                        "role": "assistant",
1162                        "tool_calls": [
1163                            {
1164                                "function": {
1165                                "arguments": "{\"x\": 2, \"y\": 5}",
1166                                "name": "subtract"
1167                                },
1168                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
1169                                "index": 0,
1170                                "type": "function"
1171                            }
1172                        ]
1173                    }
1174                }
1175            ],
1176            "created": 1740704000,
1177            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
1178            "model": "accounts/fireworks/models/deepseek-v3",
1179            "object": "chat.completion",
1180            "usage": {
1181                "completion_tokens": 26,
1182                "prompt_tokens": 248,
1183                "total_tokens": 274
1184            }
1185        }
1186        "#;
1187
1188        let novita_response_json = r#"
1189        {
1190            "choices": [
1191                {
1192                    "finish_reason": "tool_calls",
1193                    "index": 0,
1194                    "logprobs": null,
1195                    "message": {
1196                        "audio": null,
1197                        "content": null,
1198                        "function_call": null,
1199                        "reasoning_content": null,
1200                        "refusal": null,
1201                        "role": "assistant",
1202                        "tool_calls": [
1203                            {
1204                                "function": {
1205                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
1206                                    "name": "subtract"
1207                                },
1208                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
1209                                "type": "function"
1210                            }
1211                        ]
1212                    },
1213                    "stop_reason": 128008
1214                }
1215            ],
1216            "created": 1740704592,
1217            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
1218            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
1219            "object": "chat.completion",
1220            "prompt_logprobs": null,
1221            "service_tier": null,
1222            "system_fingerprint": null,
1223            "usage": {
1224                "completion_tokens": 28,
1225                "completion_tokens_details": null,
1226                "prompt_tokens": 335,
1227                "prompt_tokens_details": null,
1228                "total_tokens": 363
1229            }
1230        }
1231        "#;
1232
1233        let _firework_response: CompletionResponse = {
1234            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
1235            deserialize(jd).unwrap_or_else(|err| {
1236                panic!(
1237                    "Deserialization error at {} ({}:{}): {}",
1238                    err.path(),
1239                    err.inner().line(),
1240                    err.inner().column(),
1241                    err
1242                );
1243            })
1244        };
1245
1246        let _novita_response: CompletionResponse = {
1247            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
1248            deserialize(jd).unwrap_or_else(|err| {
1249                panic!(
1250                    "Deserialization error at {} ({}:{}): {}",
1251                    err.path(),
1252                    err.inner().line(),
1253                    err.inner().column(),
1254                    err
1255                );
1256            })
1257        };
1258    }
1259
1260    #[test]
1261    fn test_assistant_reasoning_is_silently_skipped() {
1262        let assistant = message::Message::Assistant {
1263            id: None,
1264            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1265        };
1266
1267        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1268        assert!(converted.is_empty());
1269    }
1270
1271    #[test]
1272    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
1273        let assistant = message::Message::Assistant {
1274            id: None,
1275            content: OneOrMany::many(vec![
1276                message::AssistantContent::reasoning("hidden"),
1277                message::AssistantContent::text("visible"),
1278                message::AssistantContent::tool_call(
1279                    "call_1",
1280                    "subtract",
1281                    serde_json::json!({"x": 2, "y": 1}),
1282                ),
1283            ])
1284            .expect("non-empty assistant content"),
1285        };
1286
1287        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1288        assert_eq!(converted.len(), 1);
1289
1290        match &converted[0] {
1291            Message::Assistant {
1292                content,
1293                tool_calls,
1294                ..
1295            } => {
1296                assert_eq!(
1297                    content,
1298                    &vec![AssistantContent::Text {
1299                        text: "visible".to_string()
1300                    }]
1301                );
1302                assert_eq!(tool_calls.len(), 1);
1303                assert_eq!(tool_calls[0].id, "call_1");
1304                assert_eq!(tool_calls[0].function.name, "subtract");
1305                assert_eq!(
1306                    tool_calls[0].function.arguments,
1307                    serde_json::json!({"x": 2, "y": 1})
1308                );
1309            }
1310            _ => panic!("expected assistant message"),
1311        }
1312    }
1313
1314    #[test]
1315    fn test_request_conversion_errors_when_all_messages_are_filtered() {
1316        let request = completion::CompletionRequest {
1317            preamble: None,
1318            chat_history: OneOrMany::one(message::Message::Assistant {
1319                id: None,
1320                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1321            }),
1322            documents: vec![],
1323            tools: vec![],
1324            temperature: None,
1325            max_tokens: None,
1326            tool_choice: None,
1327            additional_params: None,
1328            model: None,
1329            output_schema: None,
1330        };
1331
1332        let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
1333        assert!(matches!(result, Err(CompletionError::RequestError(_))));
1334    }
1335}