Skip to main content

rig/providers/huggingface/
completion.rs

1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7    OneOrMany,
8    completion::{self, CompletionError, CompletionRequest},
9    json_utils,
10    message::{self},
11    one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22    Ok(T),
23    Err(Value),
24}
25
// ================================================================
// Huggingface Completion API
// ================================================================

// Conversational LLMs

/// Model id for `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model id for `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model id for `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model id for `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model id for `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs (visual-language models)

/// Visual-language model id for `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Visual-language model id for `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51    name: String,
52    #[serde(
53        serialize_with = "json_utils::stringified_json::serialize",
54        deserialize_with = "json_utils::stringified_json::deserialize_maybe_stringified"
55    )]
56    pub arguments: serde_json::Value,
57}
58
59impl From<Function> for message::ToolFunction {
60    fn from(value: Function) -> Self {
61        message::ToolFunction {
62            name: value.name,
63            arguments: value.arguments,
64        }
65    }
66}
67
68#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
69#[serde(rename_all = "lowercase")]
70pub enum ToolType {
71    #[default]
72    Function,
73}
74
75#[derive(Debug, Deserialize, Serialize, Clone)]
76pub struct ToolDefinition {
77    pub r#type: String,
78    pub function: completion::ToolDefinition,
79}
80
81impl From<completion::ToolDefinition> for ToolDefinition {
82    fn from(tool: completion::ToolDefinition) -> Self {
83        Self {
84            r#type: "function".into(),
85            function: tool,
86        }
87    }
88}
89
90#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
91pub struct ToolCall {
92    pub id: String,
93    pub r#type: ToolType,
94    pub function: Function,
95}
96
97impl From<ToolCall> for message::ToolCall {
98    fn from(value: ToolCall) -> Self {
99        message::ToolCall {
100            id: value.id,
101            call_id: None,
102            function: value.function.into(),
103            signature: None,
104            additional_params: None,
105        }
106    }
107}
108
109impl From<message::ToolCall> for ToolCall {
110    fn from(value: message::ToolCall) -> Self {
111        ToolCall {
112            id: value.id,
113            r#type: ToolType::Function,
114            function: Function {
115                name: value.function.name,
116                arguments: value.function.arguments,
117            },
118        }
119    }
120}
121
122#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
123pub struct ImageUrl {
124    url: String,
125}
126
127#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
128#[serde(tag = "type", rename_all = "lowercase")]
129pub enum UserContent {
130    Text {
131        text: String,
132    },
133    #[serde(rename = "image_url")]
134    ImageUrl {
135        image_url: ImageUrl,
136    },
137}
138
139impl FromStr for UserContent {
140    type Err = Infallible;
141
142    fn from_str(s: &str) -> Result<Self, Self::Err> {
143        Ok(UserContent::Text {
144            text: s.to_string(),
145        })
146    }
147}
148
149#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
150#[serde(tag = "type", rename_all = "lowercase")]
151pub enum AssistantContent {
152    Text { text: String },
153}
154
155impl FromStr for AssistantContent {
156    type Err = Infallible;
157
158    fn from_str(s: &str) -> Result<Self, Self::Err> {
159        Ok(AssistantContent::Text {
160            text: s.to_string(),
161        })
162    }
163}
164
165#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
166#[serde(tag = "type", rename_all = "lowercase")]
167pub enum SystemContent {
168    Text { text: String },
169}
170
171impl FromStr for SystemContent {
172    type Err = Infallible;
173
174    fn from_str(s: &str) -> Result<Self, Self::Err> {
175        Ok(SystemContent::Text {
176            text: s.to_string(),
177        })
178    }
179}
180
181impl From<UserContent> for message::UserContent {
182    fn from(value: UserContent) -> Self {
183        match value {
184            UserContent::Text { text } => message::UserContent::text(text),
185            UserContent::ImageUrl { image_url } => {
186                message::UserContent::image_url(image_url.url, None, None)
187            }
188        }
189    }
190}
191
192impl TryFrom<message::UserContent> for UserContent {
193    type Error = message::MessageError;
194
195    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
196        match content {
197            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
198            message::UserContent::Document(message::Document {
199                data: message::DocumentSourceKind::Raw(raw),
200                ..
201            }) => {
202                let text = String::from_utf8_lossy(raw.as_slice()).into();
203                Ok(UserContent::Text { text })
204            }
205            message::UserContent::Document(message::Document {
206                data:
207                    message::DocumentSourceKind::Base64(text)
208                    | message::DocumentSourceKind::String(text),
209                ..
210            }) => Ok(UserContent::Text { text }),
211            message::UserContent::Image(message::Image { data, .. }) => match data {
212                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
213                    image_url: ImageUrl { url },
214                }),
215                _ => Err(message::MessageError::ConversionError(
216                    "Huggingface only supports images as urls".into(),
217                )),
218            },
219            _ => Err(message::MessageError::ConversionError(
220                "Huggingface only supports text and images".into(),
221            )),
222        }
223    }
224}
225
226#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
227#[serde(tag = "role", rename_all = "lowercase")]
228pub enum Message {
229    System {
230        #[serde(deserialize_with = "string_or_one_or_many")]
231        content: OneOrMany<SystemContent>,
232    },
233    User {
234        #[serde(deserialize_with = "string_or_one_or_many")]
235        content: OneOrMany<UserContent>,
236    },
237    Assistant {
238        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
239        content: Vec<AssistantContent>,
240        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
241        tool_calls: Vec<ToolCall>,
242    },
243    #[serde(rename = "tool", alias = "Tool")]
244    ToolResult {
245        name: String,
246        #[serde(skip_serializing_if = "Option::is_none")]
247        arguments: Option<serde_json::Value>,
248        #[serde(
249            deserialize_with = "string_or_one_or_many",
250            serialize_with = "serialize_tool_content"
251        )]
252        content: OneOrMany<String>,
253    },
254}
255
256fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
257where
258    S: Serializer,
259{
260    // OpenAI-compatible APIs expect tool content as a string, not an array
261    let joined = content
262        .iter()
263        .map(String::as_str)
264        .collect::<Vec<_>>()
265        .join("\n");
266    serializer.serialize_str(&joined)
267}
268
269impl Message {
270    pub fn system(content: &str) -> Self {
271        Message::System {
272            content: OneOrMany::one(SystemContent::Text {
273                text: content.to_string(),
274            }),
275        }
276    }
277}
278
279impl TryFrom<message::Message> for Vec<Message> {
280    type Error = message::MessageError;
281
282    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
283        match message {
284            message::Message::System { content } => Ok(vec![Message::system(&content)]),
285            message::Message::User { content } => {
286                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
287                    .into_iter()
288                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
289
290                if !tool_results.is_empty() {
291                    tool_results
292                        .into_iter()
293                        .map(|content| match content {
294                            message::UserContent::ToolResult(message::ToolResult {
295                                id,
296                                content,
297                                ..
298                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
299                                name: id,
300                                arguments: None,
301                                content: content.try_map(|content| match content {
302                                    message::ToolResultContent::Text(message::Text { text }) => {
303                                        Ok(text)
304                                    }
305                                    _ => Err(message::MessageError::ConversionError(
306                                        "Tool result content does not support non-text".into(),
307                                    )),
308                                })?,
309                            }),
310                            _ => Err(message::MessageError::ConversionError(
311                                "expected tool result content while converting HuggingFace input"
312                                    .into(),
313                            )),
314                        })
315                        .collect::<Result<Vec<_>, _>>()
316                } else {
317                    let other_content = OneOrMany::many(other_content).map_err(|_| {
318                        message::MessageError::ConversionError(
319                            "HuggingFace user message did not contain any non-tool content".into(),
320                        )
321                    })?;
322
323                    Ok(vec![Message::User {
324                        content: other_content.try_map(|content| match content {
325                            message::UserContent::Text(text) => {
326                                Ok(UserContent::Text { text: text.text })
327                            }
328                            message::UserContent::Image(image) => {
329                                let url = image.try_into_url()?;
330
331                                Ok(UserContent::ImageUrl {
332                                    image_url: ImageUrl { url },
333                                })
334                            }
335                            message::UserContent::Document(message::Document {
336                                data: message::DocumentSourceKind::Raw(raw), ..
337                            }) => {
338                                let text = String::from_utf8_lossy(raw.as_slice()).into();
339                                Ok(UserContent::Text { text })
340                            }
341                            message::UserContent::Document(message::Document {
342                                data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
343                            }) => {
344                                Ok(UserContent::Text { text })
345                            }
346                            _ => Err(message::MessageError::ConversionError(
347                                "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
348                            )),
349                        })?,
350                    }])
351                }
352            }
353            message::Message::Assistant { content, .. } => {
354                let mut text_content = Vec::new();
355                let mut tool_calls = Vec::new();
356
357                for content in content {
358                    match content {
359                        message::AssistantContent::Text(text) => text_content.push(text),
360                        message::AssistantContent::ToolCall(tool_call) => {
361                            tool_calls.push(tool_call)
362                        }
363                        message::AssistantContent::Reasoning(_) => {
364                            // HuggingFace does not support assistant-history reasoning items.
365                            // Silently skip unsupported reasoning content.
366                        }
367                        message::AssistantContent::Image(_) => {
368                            return Err(message::MessageError::ConversionError(
369                                "HuggingFace assistant messages do not support image content"
370                                    .into(),
371                            ));
372                        }
373                    }
374                }
375
376                if text_content.is_empty() && tool_calls.is_empty() {
377                    return Ok(vec![]);
378                }
379
380                Ok(vec![Message::Assistant {
381                    content: text_content
382                        .into_iter()
383                        .map(|content| AssistantContent::Text { text: content.text })
384                        .collect::<Vec<_>>(),
385                    tool_calls: tool_calls
386                        .into_iter()
387                        .map(|tool_call| tool_call.into())
388                        .collect::<Vec<_>>(),
389                }])
390            }
391        }
392    }
393}
394
395impl TryFrom<Message> for message::Message {
396    type Error = message::MessageError;
397
398    fn try_from(message: Message) -> Result<Self, Self::Error> {
399        Ok(match message {
400            Message::User { content, .. } => message::Message::User {
401                content: content.map(|content| content.into()),
402            },
403            Message::Assistant {
404                content,
405                tool_calls,
406                ..
407            } => {
408                let mut content = content
409                    .into_iter()
410                    .map(|content| match content {
411                        AssistantContent::Text { text } => message::AssistantContent::text(text),
412                    })
413                    .collect::<Vec<_>>();
414
415                content.extend(
416                    tool_calls
417                        .into_iter()
418                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
419                        .collect::<Result<Vec<_>, _>>()?,
420                );
421
422                message::Message::Assistant {
423                    id: None,
424                    content: OneOrMany::many(content).map_err(|_| {
425                        message::MessageError::ConversionError(
426                            "Neither `content` nor `tool_calls` was provided to the Message"
427                                .to_owned(),
428                        )
429                    })?,
430                }
431            }
432
433            Message::ToolResult { name, content, .. } => message::Message::User {
434                content: OneOrMany::one(message::UserContent::tool_result(
435                    name,
436                    content.map(message::ToolResultContent::text),
437                )),
438            },
439
440            // System messages should get stripped out when converting message's, this is just a
441            // stop gap to avoid obnoxious error handling or panic occurring.
442            Message::System { content, .. } => message::Message::User {
443                content: content.map(|c| match c {
444                    SystemContent::Text { text } => message::UserContent::text(text),
445                }),
446            },
447        })
448    }
449}
450
451#[derive(Clone, Debug, Deserialize, Serialize)]
452pub struct Choice {
453    pub finish_reason: String,
454    pub index: usize,
455    #[serde(default)]
456    pub logprobs: serde_json::Value,
457    pub message: Message,
458}
459
460#[derive(Debug, Deserialize, Clone, Serialize)]
461pub struct Usage {
462    pub completion_tokens: i32,
463    pub prompt_tokens: i32,
464    pub total_tokens: i32,
465}
466
467impl GetTokenUsage for Usage {
468    fn token_usage(&self) -> Option<crate::completion::Usage> {
469        let mut usage = crate::completion::Usage::new();
470        usage.input_tokens = self.prompt_tokens as u64;
471        usage.output_tokens = self.completion_tokens as u64;
472        usage.total_tokens = self.total_tokens as u64;
473
474        Some(usage)
475    }
476}
477
478#[derive(Clone, Debug, Deserialize, Serialize)]
479pub struct CompletionResponse {
480    pub created: i32,
481    pub id: String,
482    pub model: String,
483    pub choices: Vec<Choice>,
484    #[serde(default, deserialize_with = "default_string_on_null")]
485    pub system_fingerprint: String,
486    pub usage: Usage,
487}
488
489impl crate::telemetry::ProviderResponseExt for CompletionResponse {
490    type OutputMessage = Choice;
491    type Usage = Usage;
492
493    fn get_response_id(&self) -> Option<String> {
494        Some(self.id.clone())
495    }
496
497    fn get_response_model_name(&self) -> Option<String> {
498        Some(self.model.clone())
499    }
500
501    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
502        self.choices.clone()
503    }
504
505    fn get_text_response(&self) -> Option<String> {
506        let text_response = self
507            .choices
508            .iter()
509            .filter_map(|x| {
510                let Message::User { ref content } = x.message else {
511                    return None;
512                };
513
514                let text = content
515                    .iter()
516                    .filter_map(|x| {
517                        if let UserContent::Text { text } = x {
518                            Some(text.clone())
519                        } else {
520                            None
521                        }
522                    })
523                    .collect::<Vec<String>>();
524
525                if text.is_empty() {
526                    None
527                } else {
528                    Some(text.join("\n"))
529                }
530            })
531            .collect::<Vec<String>>()
532            .join("\n");
533
534        if text_response.is_empty() {
535            None
536        } else {
537            Some(text_response)
538        }
539    }
540
541    fn get_usage(&self) -> Option<Self::Usage> {
542        Some(self.usage.clone())
543    }
544}
545
546fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
547where
548    D: Deserializer<'de>,
549{
550    match Option::<String>::deserialize(deserializer)? {
551        Some(value) => Ok(value),      // Use provided value
552        None => Ok(String::default()), // Use `Default` implementation
553    }
554}
555
556impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
557    type Error = CompletionError;
558
559    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
560        let choice = response.choices.first().ok_or_else(|| {
561            CompletionError::ResponseError("Response contained no choices".to_owned())
562        })?;
563
564        let content = match &choice.message {
565            Message::Assistant {
566                content,
567                tool_calls,
568                ..
569            } => {
570                let mut content = content
571                    .iter()
572                    .map(|c| match c {
573                        AssistantContent::Text { text } => message::AssistantContent::text(text),
574                    })
575                    .collect::<Vec<_>>();
576
577                content.extend(
578                    tool_calls
579                        .iter()
580                        .map(|call| {
581                            completion::AssistantContent::tool_call(
582                                &call.id,
583                                &call.function.name,
584                                call.function.arguments.clone(),
585                            )
586                        })
587                        .collect::<Vec<_>>(),
588                );
589                Ok(content)
590            }
591            _ => Err(CompletionError::ResponseError(
592                "Response did not contain a valid message or tool call".into(),
593            )),
594        }?;
595
596        let choice = OneOrMany::many(content).map_err(|_| {
597            CompletionError::ResponseError(
598                "Response contained no message or tool call (empty)".to_owned(),
599            )
600        })?;
601
602        let usage = completion::Usage {
603            input_tokens: response.usage.prompt_tokens as u64,
604            output_tokens: response.usage.completion_tokens as u64,
605            total_tokens: response.usage.total_tokens as u64,
606            cached_input_tokens: 0,
607            cache_creation_input_tokens: 0,
608        };
609
610        Ok(completion::CompletionResponse {
611            choice,
612            usage,
613            raw_response: response,
614            message_id: None,
615        })
616    }
617}
618
619#[derive(Debug, Serialize, Deserialize)]
620pub(super) struct HuggingfaceCompletionRequest {
621    model: String,
622    pub messages: Vec<Message>,
623    #[serde(skip_serializing_if = "Option::is_none")]
624    temperature: Option<f64>,
625    #[serde(skip_serializing_if = "Vec::is_empty")]
626    tools: Vec<ToolDefinition>,
627    #[serde(skip_serializing_if = "Option::is_none")]
628    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
629    #[serde(flatten, skip_serializing_if = "Option::is_none")]
630    pub additional_params: Option<serde_json::Value>,
631}
632
633impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
634    type Error = CompletionError;
635
636    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
637        if req.output_schema.is_some() {
638            tracing::warn!("Structured outputs currently not supported for Huggingface");
639        }
640        let model = req.model.clone().unwrap_or_else(|| model.to_string());
641        let mut full_history: Vec<Message> = match &req.preamble {
642            Some(preamble) => vec![Message::system(preamble)],
643            None => vec![],
644        };
645        if let Some(docs) = req.normalized_documents() {
646            let docs: Vec<Message> = docs.try_into()?;
647            full_history.extend(docs);
648        }
649
650        let chat_history: Vec<Message> = req
651            .chat_history
652            .clone()
653            .into_iter()
654            .map(|message| message.try_into())
655            .collect::<Result<Vec<Vec<Message>>, _>>()?
656            .into_iter()
657            .flatten()
658            .collect();
659
660        full_history.extend(chat_history);
661
662        if full_history.is_empty() {
663            return Err(CompletionError::RequestError(
664                std::io::Error::new(
665                    std::io::ErrorKind::InvalidInput,
666                    "HuggingFace request has no provider-compatible messages after conversion",
667                )
668                .into(),
669            ));
670        }
671
672        let tool_choice = req
673            .tool_choice
674            .clone()
675            .map(crate::providers::openai::completion::ToolChoice::try_from)
676            .transpose()?;
677
678        Ok(Self {
679            model: model.to_string(),
680            messages: full_history,
681            temperature: req.temperature,
682            tools: req
683                .tools
684                .clone()
685                .into_iter()
686                .map(ToolDefinition::from)
687                .collect::<Vec<_>>(),
688            tool_choice,
689            additional_params: req.additional_params,
690        })
691    }
692}
693
694#[derive(Clone)]
695pub struct CompletionModel<T = reqwest::Client> {
696    pub(crate) client: Client<T>,
697    /// Name of the model (e.g: google/gemma-2-2b-it)
698    pub model: String,
699}
700
701impl<T> CompletionModel<T> {
702    pub fn new(client: Client<T>, model: &str) -> Self {
703        Self {
704            client,
705            model: model.to_string(),
706        }
707    }
708}
709
710impl<T> completion::CompletionModel for CompletionModel<T>
711where
712    T: HttpClientExt + Clone + 'static,
713{
714    type Response = CompletionResponse;
715    type StreamingResponse = StreamingCompletionResponse;
716
717    type Client = Client<T>;
718
719    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
720        Self::new(client.clone(), &model.into())
721    }
722
723    async fn completion(
724        &self,
725        completion_request: CompletionRequest,
726    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
727        let request_model = completion_request
728            .model
729            .clone()
730            .unwrap_or_else(|| self.model.clone());
731        let span = if tracing::Span::current().is_disabled() {
732            info_span!(
733                target: "rig::completions",
734                "chat",
735                gen_ai.operation.name = "chat",
736                gen_ai.provider.name = "huggingface",
737                gen_ai.request.model = &request_model,
738                gen_ai.system_instructions = &completion_request.preamble,
739                gen_ai.response.id = tracing::field::Empty,
740                gen_ai.response.model = tracing::field::Empty,
741                gen_ai.usage.output_tokens = tracing::field::Empty,
742                gen_ai.usage.input_tokens = tracing::field::Empty,
743                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
744            )
745        } else {
746            tracing::Span::current()
747        };
748
749        let model = self.client.subprovider().model_identifier(&request_model);
750        let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
751
752        if enabled!(Level::TRACE) {
753            tracing::trace!(
754                target: "rig::completions",
755                "Huggingface completion request: {}",
756                serde_json::to_string_pretty(&request)?
757            );
758        }
759
760        let request = serde_json::to_vec(&request)?;
761
762        let path = self
763            .client
764            .subprovider()
765            .completion_endpoint(&request_model);
766        let request = self
767            .client
768            .post(&path)?
769            .header("Content-Type", "application/json")
770            .body(request)
771            .map_err(|e| CompletionError::HttpError(e.into()))?;
772
773        async move {
774            let response = self.client.send(request).await?;
775
776            if response.status().is_success() {
777                let bytes: Vec<u8> = response.into_body().await?;
778                let text = String::from_utf8_lossy(&bytes);
779
780                tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
781
782                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
783                    ApiResponse::Ok(response) => {
784                        if enabled!(Level::TRACE) {
785                            tracing::trace!(
786                                target: "rig::completions",
787                                "Huggingface completion response: {}",
788                                serde_json::to_string_pretty(&response)?
789                            );
790                        }
791
792                        let span = tracing::Span::current();
793                        span.record_token_usage(&response.usage);
794                        span.record_response_metadata(&response);
795
796                        response.try_into()
797                    }
798                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
799                }
800            } else {
801                let status = response.status();
802                let text: Vec<u8> = response.into_body().await?;
803                let text: String = String::from_utf8_lossy(&text).into();
804
805                Err(CompletionError::ProviderError(format!(
806                    "{}: {}",
807                    status, text
808                )))
809            }
810        }
811        .instrument(span)
812        .await
813    }
814
815    async fn stream(
816        &self,
817        request: CompletionRequest,
818    ) -> Result<
819        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
820        CompletionError,
821    > {
822        CompletionModel::stream(self, request).await
823    }
824}
825
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`.
    ///
    /// # Panics
    /// Panics on failure. The message includes the JSON path, line and column of the
    /// error, so a broken fixture is easy to locate.
    fn parse_json<'a, T>(json: &'a str) -> T
    where
        T: Deserialize<'a>,
    {
        let mut jd = serde_json::Deserializer::from_str(json);
        deserialize(&mut jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    /// Builds a completion request with the given model override and chat history.
    /// Every other field is left empty or unset.
    fn minimal_request(
        model: Option<&str>,
        chat_history: OneOrMany<message::Message>,
    ) -> CompletionRequest {
        CompletionRequest {
            model: model.map(str::to_string),
            preamble: None,
            chat_history,
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    #[test]
    fn test_huggingface_request_uses_request_model_override() {
        let request = minimal_request(
            Some("meta-llama/Meta-Llama-3.1-8B-Instruct"),
            OneOrMany::one("Hello".into()),
        );

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
    }

    #[test]
    fn test_huggingface_request_uses_default_model_when_override_unset() {
        let request = minimal_request(None, OneOrMany::one("Hello".into()));

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "mistralai/Mistral-7B");
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse_json(assistant_message_json);
        let assistant_message2: Message = parse_json(assistant_message_json2);
        let assistant_message3: Message = parse_json(assistant_message_json3);
        let user_message: Message = parse_json(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                // `"tool_calls": null` must deserialize to an empty list.
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // `"content": null` must deserialize to an empty list.
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Round-trip back to the generic message type must be lossless.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Round-trip back to the provider message type must be lossless.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    /// Smoke test: real payloads from two inference backends (Fireworks, Novita) must
    /// deserialize. Only the shape is checked, not the field values.
    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let _fireworks_response: CompletionResponse = parse_json(fireworks_response_json);
        let _novita_response: CompletionResponse = parse_json(novita_response_json);
    }

    #[test]
    fn test_assistant_reasoning_is_silently_skipped() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }

    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        // The only message holds nothing but reasoning, which the conversion drops,
        // so the resulting request would have no messages at all.
        let request = minimal_request(
            None,
            OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
        );

        let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}