Skip to main content

rig_core/providers/huggingface/
completion.rs

1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7    OneOrMany,
8    completion::{self, CompletionError, CompletionRequest},
9    json_utils,
10    message::{self},
11    one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
/// Envelope used when decoding a JSON body returned by the provider.
///
/// The enum is `#[serde(untagged)]`, so deserialization tries `Ok(T)` first and, when the body
/// does not match `T`, keeps the raw JSON in `Err` instead.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The body deserialized as the expected payload `T`.
    Ok(T),
    /// The body did not match `T`; the raw JSON value is preserved.
    Err(Value),
}
25
26// ================================================================
27// Huggingface Completion API
28// ================================================================
29
// Conversational LLMs
// Each constant is the model name string, usable wherever a model name is expected
// (e.g. the `model` argument of `CompletionModel::new`).
/// `google/gemma-2-2b-it` completion model
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// `meta-llama/Meta-Llama-3.1-8B-Instruct` completion model
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `PowerInfer/SmallThinker-3B-Preview` completion model
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// `Qwen/Qwen2.5-7B-Instruct` completion model
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` completion model
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs (visual-language models)

/// `Qwen/Qwen2-VL-7B-Instruct` visual-language completion model
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// `Qwen/QVQ-72B-Preview` visual-language completion model
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51    name: String,
52    #[serde(
53        serialize_with = "json_utils::stringified_json::serialize",
54        deserialize_with = "json_utils::stringified_json::deserialize_maybe_stringified"
55    )]
56    pub arguments: serde_json::Value,
57}
58
59impl From<Function> for message::ToolFunction {
60    fn from(value: Function) -> Self {
61        message::ToolFunction {
62            name: value.name,
63            arguments: value.arguments,
64        }
65    }
66}
67
/// Kind of tool referenced by a [`ToolCall`].
///
/// Variant names are lowercased on the wire, so `Function` is written as `"function"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool; this is the only variant and also the default.
    #[default]
    Function,
}
74
75#[derive(Debug, Deserialize, Serialize, Clone)]
76pub struct ToolDefinition {
77    pub r#type: String,
78    pub function: completion::ToolDefinition,
79}
80
81impl From<completion::ToolDefinition> for ToolDefinition {
82    fn from(tool: completion::ToolDefinition) -> Self {
83        Self {
84            r#type: "function".into(),
85            function: tool,
86        }
87    }
88}
89
90#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
91pub struct ToolCall {
92    pub id: String,
93    pub r#type: ToolType,
94    pub function: Function,
95}
96
97impl From<ToolCall> for message::ToolCall {
98    fn from(value: ToolCall) -> Self {
99        message::ToolCall {
100            id: value.id,
101            call_id: None,
102            function: value.function.into(),
103            signature: None,
104            additional_params: None,
105        }
106    }
107}
108
109impl From<message::ToolCall> for ToolCall {
110    fn from(value: message::ToolCall) -> Self {
111        ToolCall {
112            id: value.id,
113            r#type: ToolType::Function,
114            function: Function {
115                name: value.function.name,
116                arguments: value.function.arguments,
117            },
118        }
119    }
120}
121
/// Payload of [`UserContent::ImageUrl`]: a single `url` string.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    url: String,
}
126
127#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
128#[serde(tag = "type", rename_all = "lowercase")]
129pub enum UserContent {
130    Text {
131        text: String,
132    },
133    #[serde(rename = "image_url")]
134    ImageUrl {
135        image_url: ImageUrl,
136    },
137}
138
139impl FromStr for UserContent {
140    type Err = Infallible;
141
142    fn from_str(s: &str) -> Result<Self, Self::Err> {
143        Ok(UserContent::Text {
144            text: s.to_string(),
145        })
146    }
147}
148
149#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
150#[serde(tag = "type", rename_all = "lowercase")]
151pub enum AssistantContent {
152    Text { text: String },
153}
154
155impl FromStr for AssistantContent {
156    type Err = Infallible;
157
158    fn from_str(s: &str) -> Result<Self, Self::Err> {
159        Ok(AssistantContent::Text {
160            text: s.to_string(),
161        })
162    }
163}
164
165#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
166#[serde(tag = "type", rename_all = "lowercase")]
167pub enum SystemContent {
168    Text { text: String },
169}
170
171impl FromStr for SystemContent {
172    type Err = Infallible;
173
174    fn from_str(s: &str) -> Result<Self, Self::Err> {
175        Ok(SystemContent::Text {
176            text: s.to_string(),
177        })
178    }
179}
180
181impl From<UserContent> for message::UserContent {
182    fn from(value: UserContent) -> Self {
183        match value {
184            UserContent::Text { text } => message::UserContent::text(text),
185            UserContent::ImageUrl { image_url } => {
186                message::UserContent::image_url(image_url.url, None, None)
187            }
188        }
189    }
190}
191
192impl TryFrom<message::UserContent> for UserContent {
193    type Error = message::MessageError;
194
195    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
196        match content {
197            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
198            message::UserContent::Document(message::Document {
199                data: message::DocumentSourceKind::Raw(raw),
200                ..
201            }) => {
202                let text = String::from_utf8_lossy(raw.as_slice()).into();
203                Ok(UserContent::Text { text })
204            }
205            message::UserContent::Document(message::Document {
206                data:
207                    message::DocumentSourceKind::Base64(text)
208                    | message::DocumentSourceKind::String(text),
209                ..
210            }) => Ok(UserContent::Text { text }),
211            message::UserContent::Image(message::Image { data, .. }) => match data {
212                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
213                    image_url: ImageUrl { url },
214                }),
215                _ => Err(message::MessageError::ConversionError(
216                    "Huggingface only supports images as urls".into(),
217                )),
218            },
219            _ => Err(message::MessageError::ConversionError(
220                "Huggingface only supports text and images".into(),
221            )),
222        }
223    }
224}
225
/// Chat message in the provider's wire format.
///
/// Internally tagged by `role`, with variant names lowercased on the wire (`system`, `user`,
/// `assistant`). `ToolResult` is written as `tool` and is also accepted as `Tool` when
/// deserializing.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message with role `system`.
    System {
        // NOTE(review): presumably accepts a bare string or one/many content items — confirm
        // against `string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    /// Message with role `user`.
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    /// Message with role `assistant`; both fields fall back to their default (empty) value when
    /// absent from the input.
    Assistant {
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool invocation, sent with role `tool`.
    #[serde(rename = "tool", alias = "Tool")]
    ToolResult {
        name: String,
        // Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        // Serialized as one newline-joined string by `serialize_tool_content`.
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_tool_content"
        )]
        content: OneOrMany<String>,
    },
}
255
256fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
257where
258    S: Serializer,
259{
260    // OpenAI-compatible APIs expect tool content as a string, not an array
261    let joined = content
262        .iter()
263        .map(String::as_str)
264        .collect::<Vec<_>>()
265        .join("\n");
266    serializer.serialize_str(&joined)
267}
268
269impl Message {
270    pub fn system(content: &str) -> Self {
271        Message::System {
272            content: OneOrMany::one(SystemContent::Text {
273                text: content.to_string(),
274            }),
275        }
276    }
277}
278
279impl TryFrom<message::Message> for Vec<Message> {
280    type Error = message::MessageError;
281
282    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
283        match message {
284            message::Message::System { content } => Ok(vec![Message::system(&content)]),
285            message::Message::User { content } => {
286                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
287                    .into_iter()
288                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
289
290                if !tool_results.is_empty() {
291                    tool_results
292                        .into_iter()
293                        .map(|content| match content {
294                            message::UserContent::ToolResult(message::ToolResult {
295                                id,
296                                content,
297                                ..
298                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
299                                name: id,
300                                arguments: None,
301                                content: content.try_map(|content| match content {
302                                    message::ToolResultContent::Text(message::Text { text }) => {
303                                        Ok(text)
304                                    }
305                                    _ => Err(message::MessageError::ConversionError(
306                                        "Tool result content does not support non-text".into(),
307                                    )),
308                                })?,
309                            }),
310                            _ => Err(message::MessageError::ConversionError(
311                                "expected tool result content while converting HuggingFace input"
312                                    .into(),
313                            )),
314                        })
315                        .collect::<Result<Vec<_>, _>>()
316                } else {
317                    let other_content = OneOrMany::many(other_content).map_err(|_| {
318                        message::MessageError::ConversionError(
319                            "HuggingFace user message did not contain any non-tool content".into(),
320                        )
321                    })?;
322
323                    Ok(vec![Message::User {
324                        content: other_content.try_map(|content| match content {
325                            message::UserContent::Text(text) => {
326                                Ok(UserContent::Text { text: text.text })
327                            }
328                            message::UserContent::Image(image) => {
329                                let url = image.try_into_url()?;
330
331                                Ok(UserContent::ImageUrl {
332                                    image_url: ImageUrl { url },
333                                })
334                            }
335                            message::UserContent::Document(message::Document {
336                                data: message::DocumentSourceKind::Raw(raw), ..
337                            }) => {
338                                let text = String::from_utf8_lossy(raw.as_slice()).into();
339                                Ok(UserContent::Text { text })
340                            }
341                            message::UserContent::Document(message::Document {
342                                data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
343                            }) => {
344                                Ok(UserContent::Text { text })
345                            }
346                            _ => Err(message::MessageError::ConversionError(
347                                "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
348                            )),
349                        })?,
350                    }])
351                }
352            }
353            message::Message::Assistant { content, .. } => {
354                let mut text_content = Vec::new();
355                let mut tool_calls = Vec::new();
356
357                for content in content {
358                    match content {
359                        message::AssistantContent::Text(text) => text_content.push(text),
360                        message::AssistantContent::ToolCall(tool_call) => {
361                            tool_calls.push(tool_call)
362                        }
363                        message::AssistantContent::Reasoning(_) => {
364                            // HuggingFace does not support assistant-history reasoning items.
365                            // Silently skip unsupported reasoning content.
366                        }
367                        message::AssistantContent::Image(_) => {
368                            return Err(message::MessageError::ConversionError(
369                                "HuggingFace assistant messages do not support image content"
370                                    .into(),
371                            ));
372                        }
373                    }
374                }
375
376                if text_content.is_empty() && tool_calls.is_empty() {
377                    return Ok(vec![]);
378                }
379
380                Ok(vec![Message::Assistant {
381                    content: text_content
382                        .into_iter()
383                        .map(|content| AssistantContent::Text { text: content.text })
384                        .collect::<Vec<_>>(),
385                    tool_calls: tool_calls
386                        .into_iter()
387                        .map(|tool_call| tool_call.into())
388                        .collect::<Vec<_>>(),
389                }])
390            }
391        }
392    }
393}
394
395impl TryFrom<Message> for message::Message {
396    type Error = message::MessageError;
397
398    fn try_from(message: Message) -> Result<Self, Self::Error> {
399        Ok(match message {
400            Message::User { content, .. } => message::Message::User {
401                content: content.map(|content| content.into()),
402            },
403            Message::Assistant {
404                content,
405                tool_calls,
406                ..
407            } => {
408                let mut content = content
409                    .into_iter()
410                    .map(|content| match content {
411                        AssistantContent::Text { text } => message::AssistantContent::text(text),
412                    })
413                    .collect::<Vec<_>>();
414
415                content.extend(
416                    tool_calls
417                        .into_iter()
418                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
419                        .collect::<Result<Vec<_>, _>>()?,
420                );
421
422                message::Message::Assistant {
423                    id: None,
424                    content: OneOrMany::many(content).map_err(|_| {
425                        message::MessageError::ConversionError(
426                            "Neither `content` nor `tool_calls` was provided to the Message"
427                                .to_owned(),
428                        )
429                    })?,
430                }
431            }
432
433            Message::ToolResult { name, content, .. } => message::Message::User {
434                content: OneOrMany::one(message::UserContent::tool_result(
435                    name,
436                    content.map(message::ToolResultContent::text),
437                )),
438            },
439
440            // System messages should get stripped out when converting message's, this is just a
441            // stop gap to avoid obnoxious error handling or panic occurring.
442            Message::System { content, .. } => message::Message::User {
443                content: content.map(|c| match c {
444                    SystemContent::Text { text } => message::UserContent::text(text),
445                }),
446            },
447        })
448    }
449}
450
/// One entry of the `choices` array in a [`CompletionResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    pub finish_reason: String,
    pub index: usize,
    // Falls back to `Value::default()` when the field is absent from the response.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    pub message: Message,
}
459
460#[derive(Debug, Deserialize, Clone, Serialize)]
461pub struct Usage {
462    pub completion_tokens: i32,
463    pub prompt_tokens: i32,
464    pub total_tokens: i32,
465}
466
467impl GetTokenUsage for Usage {
468    fn token_usage(&self) -> Option<crate::completion::Usage> {
469        let mut usage = crate::completion::Usage::new();
470        usage.input_tokens = self.prompt_tokens as u64;
471        usage.output_tokens = self.completion_tokens as u64;
472        usage.total_tokens = self.total_tokens as u64;
473
474        Some(usage)
475    }
476}
477
478#[derive(Clone, Debug, Deserialize, Serialize)]
479pub struct CompletionResponse {
480    pub created: i32,
481    pub id: String,
482    pub model: String,
483    pub choices: Vec<Choice>,
484    #[serde(default, deserialize_with = "default_string_on_null")]
485    pub system_fingerprint: String,
486    pub usage: Usage,
487}
488
489impl crate::telemetry::ProviderResponseExt for CompletionResponse {
490    type OutputMessage = Choice;
491    type Usage = Usage;
492
493    fn get_response_id(&self) -> Option<String> {
494        Some(self.id.clone())
495    }
496
497    fn get_response_model_name(&self) -> Option<String> {
498        Some(self.model.clone())
499    }
500
501    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
502        self.choices.clone()
503    }
504
505    fn get_text_response(&self) -> Option<String> {
506        let text_response = self
507            .choices
508            .iter()
509            .filter_map(|x| {
510                let Message::User { ref content } = x.message else {
511                    return None;
512                };
513
514                let text = content
515                    .iter()
516                    .filter_map(|x| {
517                        if let UserContent::Text { text } = x {
518                            Some(text.clone())
519                        } else {
520                            None
521                        }
522                    })
523                    .collect::<Vec<String>>();
524
525                if text.is_empty() {
526                    None
527                } else {
528                    Some(text.join("\n"))
529                }
530            })
531            .collect::<Vec<String>>()
532            .join("\n");
533
534        if text_response.is_empty() {
535            None
536        } else {
537            Some(text_response)
538        }
539    }
540
541    fn get_usage(&self) -> Option<Self::Usage> {
542        Some(self.usage.clone())
543    }
544}
545
546fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
547where
548    D: Deserializer<'de>,
549{
550    match Option::<String>::deserialize(deserializer)? {
551        Some(value) => Ok(value),      // Use provided value
552        None => Ok(String::default()), // Use `Default` implementation
553    }
554}
555
556impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
557    type Error = CompletionError;
558
559    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
560        let choice = response.choices.first().ok_or_else(|| {
561            CompletionError::ResponseError("Response contained no choices".to_owned())
562        })?;
563
564        let content = match &choice.message {
565            Message::Assistant {
566                content,
567                tool_calls,
568                ..
569            } => {
570                let mut content = content
571                    .iter()
572                    .map(|c| match c {
573                        AssistantContent::Text { text } => message::AssistantContent::text(text),
574                    })
575                    .collect::<Vec<_>>();
576
577                content.extend(
578                    tool_calls
579                        .iter()
580                        .map(|call| {
581                            completion::AssistantContent::tool_call(
582                                &call.id,
583                                &call.function.name,
584                                call.function.arguments.clone(),
585                            )
586                        })
587                        .collect::<Vec<_>>(),
588                );
589                Ok(content)
590            }
591            _ => Err(CompletionError::ResponseError(
592                "Response did not contain a valid message or tool call".into(),
593            )),
594        }?;
595
596        let choice = OneOrMany::many(content).map_err(|_| {
597            CompletionError::ResponseError(
598                "Response contained no message or tool call (empty)".to_owned(),
599            )
600        })?;
601
602        let usage = completion::Usage {
603            input_tokens: response.usage.prompt_tokens as u64,
604            output_tokens: response.usage.completion_tokens as u64,
605            total_tokens: response.usage.total_tokens as u64,
606            cached_input_tokens: 0,
607            cache_creation_input_tokens: 0,
608            reasoning_tokens: 0,
609        };
610
611        Ok(completion::CompletionResponse {
612            choice,
613            usage,
614            raw_response: response,
615            message_id: None,
616        })
617    }
618}
619
620#[derive(Debug, Serialize, Deserialize)]
621pub(super) struct HuggingfaceCompletionRequest {
622    model: String,
623    pub messages: Vec<Message>,
624    #[serde(skip_serializing_if = "Option::is_none")]
625    temperature: Option<f64>,
626    #[serde(skip_serializing_if = "Vec::is_empty")]
627    tools: Vec<ToolDefinition>,
628    #[serde(skip_serializing_if = "Option::is_none")]
629    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
630    #[serde(flatten, skip_serializing_if = "Option::is_none")]
631    pub additional_params: Option<serde_json::Value>,
632}
633
634impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
635    type Error = CompletionError;
636
637    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
638        if req.output_schema.is_some() {
639            tracing::warn!("Structured outputs currently not supported for Huggingface");
640        }
641        let model = req.model.clone().unwrap_or_else(|| model.to_string());
642        let mut full_history: Vec<Message> = match &req.preamble {
643            Some(preamble) => vec![Message::system(preamble)],
644            None => vec![],
645        };
646        if let Some(docs) = req.normalized_documents() {
647            let docs: Vec<Message> = docs.try_into()?;
648            full_history.extend(docs);
649        }
650
651        let chat_history: Vec<Message> = req
652            .chat_history
653            .clone()
654            .into_iter()
655            .map(|message| message.try_into())
656            .collect::<Result<Vec<Vec<Message>>, _>>()?
657            .into_iter()
658            .flatten()
659            .collect();
660
661        full_history.extend(chat_history);
662
663        if full_history.is_empty() {
664            return Err(CompletionError::RequestError(
665                std::io::Error::new(
666                    std::io::ErrorKind::InvalidInput,
667                    "HuggingFace request has no provider-compatible messages after conversion",
668                )
669                .into(),
670            ));
671        }
672
673        let tool_choice = req
674            .tool_choice
675            .clone()
676            .map(crate::providers::openai::completion::ToolChoice::try_from)
677            .transpose()?;
678
679        Ok(Self {
680            model: model.to_string(),
681            messages: full_history,
682            temperature: req.temperature,
683            tools: req
684                .tools
685                .clone()
686                .into_iter()
687                .map(ToolDefinition::from)
688                .collect::<Vec<_>>(),
689            tool_choice,
690            additional_params: req.additional_params,
691        })
692    }
693}
694
695#[derive(Clone)]
696pub struct CompletionModel<T = reqwest::Client> {
697    pub(crate) client: Client<T>,
698    /// Name of the model (e.g: google/gemma-2-2b-it)
699    pub model: String,
700}
701
702impl<T> CompletionModel<T> {
703    pub fn new(client: Client<T>, model: &str) -> Self {
704        Self {
705            client,
706            model: model.to_string(),
707        }
708    }
709}
710
711impl<T> completion::CompletionModel for CompletionModel<T>
712where
713    T: HttpClientExt + Clone + 'static,
714{
715    type Response = CompletionResponse;
716    type StreamingResponse = StreamingCompletionResponse;
717
718    type Client = Client<T>;
719
720    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
721        Self::new(client.clone(), &model.into())
722    }
723
724    async fn completion(
725        &self,
726        completion_request: CompletionRequest,
727    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
728        let request_model = completion_request
729            .model
730            .clone()
731            .unwrap_or_else(|| self.model.clone());
732        let span = if tracing::Span::current().is_disabled() {
733            info_span!(
734                target: "rig::completions",
735                "chat",
736                gen_ai.operation.name = "chat",
737                gen_ai.provider.name = "huggingface",
738                gen_ai.request.model = &request_model,
739                gen_ai.system_instructions = &completion_request.preamble,
740                gen_ai.response.id = tracing::field::Empty,
741                gen_ai.response.model = tracing::field::Empty,
742                gen_ai.usage.output_tokens = tracing::field::Empty,
743                gen_ai.usage.input_tokens = tracing::field::Empty,
744                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
745            )
746        } else {
747            tracing::Span::current()
748        };
749
750        let model = self.client.subprovider().model_identifier(&request_model);
751        let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
752
753        if enabled!(Level::TRACE) {
754            tracing::trace!(
755                target: "rig::completions",
756                "Huggingface completion request: {}",
757                serde_json::to_string_pretty(&request)?
758            );
759        }
760
761        let request = serde_json::to_vec(&request)?;
762
763        let path = self
764            .client
765            .subprovider()
766            .completion_endpoint(&request_model);
767        let request = self
768            .client
769            .post(&path)?
770            .header("Content-Type", "application/json")
771            .body(request)
772            .map_err(|e| CompletionError::HttpError(e.into()))?;
773
774        async move {
775            let response = self.client.send(request).await?;
776
777            if response.status().is_success() {
778                let bytes: Vec<u8> = response.into_body().await?;
779                let text = String::from_utf8_lossy(&bytes);
780
781                tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
782
783                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
784                    ApiResponse::Ok(response) => {
785                        if enabled!(Level::TRACE) {
786                            tracing::trace!(
787                                target: "rig::completions",
788                                "Huggingface completion response: {}",
789                                serde_json::to_string_pretty(&response)?
790                            );
791                        }
792
793                        let span = tracing::Span::current();
794                        span.record_token_usage(&response.usage);
795                        span.record_response_metadata(&response);
796
797                        response.try_into()
798                    }
799                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
800                }
801            } else {
802                let status = response.status();
803                let text: Vec<u8> = response.into_body().await?;
804                let text: String = String::from_utf8_lossy(&text).into();
805
806                Err(CompletionError::ProviderError(format!(
807                    "{}: {}",
808                    status, text
809                )))
810            }
811        }
812        .instrument(span)
813        .await
814    }
815
816    async fn stream(
817        &self,
818        request: CompletionRequest,
819    ) -> Result<
820        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
821        CompletionError,
822    > {
823        CompletionModel::stream(self, request).await
824    }
825}
826
827#[cfg(test)]
828mod tests {
829    use super::*;
830    use serde_path_to_error::deserialize;
831
832    #[test]
833    fn test_huggingface_request_uses_request_model_override() {
834        let request = CompletionRequest {
835            model: Some("meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()),
836            preamble: None,
837            chat_history: crate::OneOrMany::one("Hello".into()),
838            documents: vec![],
839            tools: vec![],
840            temperature: None,
841            max_tokens: None,
842            tool_choice: None,
843            additional_params: None,
844            output_schema: None,
845        };
846
847        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
848            .expect("request conversion should succeed");
849        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
850
851        assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
852    }
853
854    #[test]
855    fn test_huggingface_request_uses_default_model_when_override_unset() {
856        let request = CompletionRequest {
857            model: None,
858            preamble: None,
859            chat_history: crate::OneOrMany::one("Hello".into()),
860            documents: vec![],
861            tools: vec![],
862            temperature: None,
863            max_tokens: None,
864            tool_choice: None,
865            additional_params: None,
866            output_schema: None,
867        };
868
869        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
870            .expect("request conversion should succeed");
871        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
872
873        assert_eq!(serialized["model"], "mistralai/Mistral-7B");
874    }
875
876    #[test]
877    fn test_deserialize_message() {
878        let assistant_message_json = r#"
879        {
880            "role": "assistant",
881            "content": "\n\nHello there, how may I assist you today?"
882        }
883        "#;
884
885        let assistant_message_json2 = r#"
886        {
887            "role": "assistant",
888            "content": [
889                {
890                    "type": "text",
891                    "text": "\n\nHello there, how may I assist you today?"
892                }
893            ],
894            "tool_calls": null
895        }
896        "#;
897
898        let assistant_message_json3 = r#"
899        {
900            "role": "assistant",
901            "tool_calls": [
902                {
903                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
904                    "type": "function",
905                    "function": {
906                        "name": "subtract",
907                        "arguments": {"x": 2, "y": 5}
908                    }
909                }
910            ],
911            "content": null,
912            "refusal": null
913        }
914        "#;
915
916        let user_message_json = r#"
917        {
918            "role": "user",
919            "content": [
920                {
921                    "type": "text",
922                    "text": "What's in this image?"
923                },
924                {
925                    "type": "image_url",
926                    "image_url": {
927                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
928                    }
929                }
930            ]
931        }
932        "#;
933
934        let assistant_message: Message = {
935            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
936            deserialize(jd).unwrap_or_else(|err| {
937                panic!(
938                    "Deserialization error at {} ({}:{}): {}",
939                    err.path(),
940                    err.inner().line(),
941                    err.inner().column(),
942                    err
943                );
944            })
945        };
946
947        let assistant_message2: Message = {
948            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
949            deserialize(jd).unwrap_or_else(|err| {
950                panic!(
951                    "Deserialization error at {} ({}:{}): {}",
952                    err.path(),
953                    err.inner().line(),
954                    err.inner().column(),
955                    err
956                );
957            })
958        };
959
960        let assistant_message3: Message = {
961            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
962                &mut serde_json::Deserializer::from_str(assistant_message_json3);
963            deserialize(jd).unwrap_or_else(|err| {
964                panic!(
965                    "Deserialization error at {} ({}:{}): {}",
966                    err.path(),
967                    err.inner().line(),
968                    err.inner().column(),
969                    err
970                );
971            })
972        };
973
974        let user_message: Message = {
975            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
976            deserialize(jd).unwrap_or_else(|err| {
977                panic!(
978                    "Deserialization error at {} ({}:{}): {}",
979                    err.path(),
980                    err.inner().line(),
981                    err.inner().column(),
982                    err
983                );
984            })
985        };
986
987        match assistant_message {
988            Message::Assistant { content, .. } => {
989                assert_eq!(
990                    content[0],
991                    AssistantContent::Text {
992                        text: "\n\nHello there, how may I assist you today?".to_string()
993                    }
994                );
995            }
996            _ => panic!("Expected assistant message"),
997        }
998
999        match assistant_message2 {
1000            Message::Assistant {
1001                content,
1002                tool_calls,
1003                ..
1004            } => {
1005                assert_eq!(
1006                    content[0],
1007                    AssistantContent::Text {
1008                        text: "\n\nHello there, how may I assist you today?".to_string()
1009                    }
1010                );
1011
1012                assert_eq!(tool_calls, vec![]);
1013            }
1014            _ => panic!("Expected assistant message"),
1015        }
1016
1017        match assistant_message3 {
1018            Message::Assistant {
1019                content,
1020                tool_calls,
1021                ..
1022            } => {
1023                assert!(content.is_empty());
1024                assert_eq!(
1025                    tool_calls[0],
1026                    ToolCall {
1027                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
1028                        r#type: ToolType::Function,
1029                        function: Function {
1030                            name: "subtract".to_string(),
1031                            arguments: serde_json::json!({"x": 2, "y": 5}),
1032                        },
1033                    }
1034                );
1035            }
1036            _ => panic!("Expected assistant message"),
1037        }
1038
1039        match user_message {
1040            Message::User { content, .. } => {
1041                let (first, second) = {
1042                    let mut iter = content.into_iter();
1043                    (iter.next().unwrap(), iter.next().unwrap())
1044                };
1045                assert_eq!(
1046                    first,
1047                    UserContent::Text {
1048                        text: "What's in this image?".to_string()
1049                    }
1050                );
1051                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
1052            }
1053            _ => panic!("Expected user message"),
1054        }
1055    }
1056
1057    #[test]
1058    fn test_message_to_message_conversion() {
1059        let user_message = message::Message::User {
1060            content: OneOrMany::one(message::UserContent::text("Hello")),
1061        };
1062
1063        let assistant_message = message::Message::Assistant {
1064            id: None,
1065            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
1066        };
1067
1068        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
1069        let converted_assistant_message: Vec<Message> =
1070            assistant_message.clone().try_into().unwrap();
1071
1072        match converted_user_message[0].clone() {
1073            Message::User { content, .. } => {
1074                assert_eq!(
1075                    content.first(),
1076                    UserContent::Text {
1077                        text: "Hello".to_string()
1078                    }
1079                );
1080            }
1081            _ => panic!("Expected user message"),
1082        }
1083
1084        match converted_assistant_message[0].clone() {
1085            Message::Assistant { content, .. } => {
1086                assert_eq!(
1087                    content[0],
1088                    AssistantContent::Text {
1089                        text: "Hi there!".to_string()
1090                    }
1091                );
1092            }
1093            _ => panic!("Expected assistant message"),
1094        }
1095
1096        let original_user_message: message::Message =
1097            converted_user_message[0].clone().try_into().unwrap();
1098        let original_assistant_message: message::Message =
1099            converted_assistant_message[0].clone().try_into().unwrap();
1100
1101        assert_eq!(original_user_message, user_message);
1102        assert_eq!(original_assistant_message, assistant_message);
1103    }
1104
1105    #[test]
1106    fn test_message_from_message_conversion() {
1107        let user_message = Message::User {
1108            content: OneOrMany::one(UserContent::Text {
1109                text: "Hello".to_string(),
1110            }),
1111        };
1112
1113        let assistant_message = Message::Assistant {
1114            content: vec![AssistantContent::Text {
1115                text: "Hi there!".to_string(),
1116            }],
1117            tool_calls: vec![],
1118        };
1119
1120        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1121        let converted_assistant_message: message::Message =
1122            assistant_message.clone().try_into().unwrap();
1123
1124        match converted_user_message.clone() {
1125            message::Message::User { content } => {
1126                assert_eq!(content.first(), message::UserContent::text("Hello"));
1127            }
1128            _ => panic!("Expected user message"),
1129        }
1130
1131        match converted_assistant_message.clone() {
1132            message::Message::Assistant { content, .. } => {
1133                assert_eq!(
1134                    content.first(),
1135                    message::AssistantContent::text("Hi there!")
1136                );
1137            }
1138            _ => panic!("Expected assistant message"),
1139        }
1140
1141        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1142        let original_assistant_message: Vec<Message> =
1143            converted_assistant_message.try_into().unwrap();
1144
1145        assert_eq!(original_user_message[0], user_message);
1146        assert_eq!(original_assistant_message[0], assistant_message);
1147    }
1148
1149    #[test]
1150    fn test_responses() {
1151        let fireworks_response_json = r#"
1152        {
1153            "choices": [
1154                {
1155                    "finish_reason": "tool_calls",
1156                    "index": 0,
1157                    "message": {
1158                        "role": "assistant",
1159                        "tool_calls": [
1160                            {
1161                                "function": {
1162                                "arguments": "{\"x\": 2, \"y\": 5}",
1163                                "name": "subtract"
1164                                },
1165                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
1166                                "index": 0,
1167                                "type": "function"
1168                            }
1169                        ]
1170                    }
1171                }
1172            ],
1173            "created": 1740704000,
1174            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
1175            "model": "accounts/fireworks/models/deepseek-v3",
1176            "object": "chat.completion",
1177            "usage": {
1178                "completion_tokens": 26,
1179                "prompt_tokens": 248,
1180                "total_tokens": 274
1181            }
1182        }
1183        "#;
1184
1185        let novita_response_json = r#"
1186        {
1187            "choices": [
1188                {
1189                    "finish_reason": "tool_calls",
1190                    "index": 0,
1191                    "logprobs": null,
1192                    "message": {
1193                        "audio": null,
1194                        "content": null,
1195                        "function_call": null,
1196                        "reasoning_content": null,
1197                        "refusal": null,
1198                        "role": "assistant",
1199                        "tool_calls": [
1200                            {
1201                                "function": {
1202                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
1203                                    "name": "subtract"
1204                                },
1205                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
1206                                "type": "function"
1207                            }
1208                        ]
1209                    },
1210                    "stop_reason": 128008
1211                }
1212            ],
1213            "created": 1740704592,
1214            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
1215            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
1216            "object": "chat.completion",
1217            "prompt_logprobs": null,
1218            "service_tier": null,
1219            "system_fingerprint": null,
1220            "usage": {
1221                "completion_tokens": 28,
1222                "completion_tokens_details": null,
1223                "prompt_tokens": 335,
1224                "prompt_tokens_details": null,
1225                "total_tokens": 363
1226            }
1227        }
1228        "#;
1229
1230        let _firework_response: CompletionResponse = {
1231            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
1232            deserialize(jd).unwrap_or_else(|err| {
1233                panic!(
1234                    "Deserialization error at {} ({}:{}): {}",
1235                    err.path(),
1236                    err.inner().line(),
1237                    err.inner().column(),
1238                    err
1239                );
1240            })
1241        };
1242
1243        let _novita_response: CompletionResponse = {
1244            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
1245            deserialize(jd).unwrap_or_else(|err| {
1246                panic!(
1247                    "Deserialization error at {} ({}:{}): {}",
1248                    err.path(),
1249                    err.inner().line(),
1250                    err.inner().column(),
1251                    err
1252                );
1253            })
1254        };
1255    }
1256
1257    #[test]
1258    fn test_assistant_reasoning_is_silently_skipped() {
1259        let assistant = message::Message::Assistant {
1260            id: None,
1261            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1262        };
1263
1264        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1265        assert!(converted.is_empty());
1266    }
1267
1268    #[test]
1269    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
1270        let assistant = message::Message::Assistant {
1271            id: None,
1272            content: OneOrMany::many(vec![
1273                message::AssistantContent::reasoning("hidden"),
1274                message::AssistantContent::text("visible"),
1275                message::AssistantContent::tool_call(
1276                    "call_1",
1277                    "subtract",
1278                    serde_json::json!({"x": 2, "y": 1}),
1279                ),
1280            ])
1281            .expect("non-empty assistant content"),
1282        };
1283
1284        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1285        assert_eq!(converted.len(), 1);
1286
1287        match &converted[0] {
1288            Message::Assistant {
1289                content,
1290                tool_calls,
1291                ..
1292            } => {
1293                assert_eq!(
1294                    content,
1295                    &vec![AssistantContent::Text {
1296                        text: "visible".to_string()
1297                    }]
1298                );
1299                assert_eq!(tool_calls.len(), 1);
1300                assert_eq!(tool_calls[0].id, "call_1");
1301                assert_eq!(tool_calls[0].function.name, "subtract");
1302                assert_eq!(
1303                    tool_calls[0].function.arguments,
1304                    serde_json::json!({"x": 2, "y": 1})
1305                );
1306            }
1307            _ => panic!("expected assistant message"),
1308        }
1309    }
1310
1311    #[test]
1312    fn test_request_conversion_errors_when_all_messages_are_filtered() {
1313        let request = completion::CompletionRequest {
1314            preamble: None,
1315            chat_history: OneOrMany::one(message::Message::Assistant {
1316                id: None,
1317                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1318            }),
1319            documents: vec![],
1320            tools: vec![],
1321            temperature: None,
1322            max_tokens: None,
1323            tool_choice: None,
1324            additional_params: None,
1325            model: None,
1326            output_schema: None,
1327        };
1328
1329        let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
1330        assert!(matches!(result, Err(CompletionError::RequestError(_))));
1331    }
1332}