Skip to main content

rig/providers/huggingface/
completion.rs

1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7    OneOrMany,
8    completion::{self, CompletionError, CompletionRequest},
9    json_utils,
10    message::{self},
11    one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22    Ok(T),
23    Err(Value),
24}
25
26// ================================================================
27// Huggingface Completion API
28// ================================================================
29
30// Conversational LLMs
/// Model identifier for `google/gemma-2-2b-it` (conversational LLM).
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model identifier for `meta-llama/Meta-Llama-3.1-8B-Instruct` (conversational LLM).
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model identifier for `PowerInfer/SmallThinker-3B-Preview` (conversational LLM).
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model identifier for `Qwen/Qwen2.5-7B-Instruct` (conversational LLM).
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model identifier for `Qwen/Qwen2.5-Coder-32B-Instruct` (conversational LLM).
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs

/// Model identifier for `Qwen/Qwen2-VL-7B-Instruct` (visual-language model).
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model identifier for `Qwen/QVQ-72B-Preview` (visual-language model).
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51    name: String,
52    #[serde(
53        serialize_with = "json_utils::stringified_json::serialize",
54        deserialize_with = "deserialize_arguments"
55    )]
56    pub arguments: serde_json::Value,
57}
58
59fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
60where
61    D: Deserializer<'de>,
62{
63    let value = Value::deserialize(deserializer)?;
64
65    match value {
66        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
67        other => Ok(other),
68    }
69}
70
71impl From<Function> for message::ToolFunction {
72    fn from(value: Function) -> Self {
73        message::ToolFunction {
74            name: value.name,
75            arguments: value.arguments,
76        }
77    }
78}
79
80#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
81#[serde(rename_all = "lowercase")]
82pub enum ToolType {
83    #[default]
84    Function,
85}
86
87#[derive(Debug, Deserialize, Serialize, Clone)]
88pub struct ToolDefinition {
89    pub r#type: String,
90    pub function: completion::ToolDefinition,
91}
92
93impl From<completion::ToolDefinition> for ToolDefinition {
94    fn from(tool: completion::ToolDefinition) -> Self {
95        Self {
96            r#type: "function".into(),
97            function: tool,
98        }
99    }
100}
101
102#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
103pub struct ToolCall {
104    pub id: String,
105    pub r#type: ToolType,
106    pub function: Function,
107}
108
109impl From<ToolCall> for message::ToolCall {
110    fn from(value: ToolCall) -> Self {
111        message::ToolCall {
112            id: value.id,
113            call_id: None,
114            function: value.function.into(),
115            signature: None,
116            additional_params: None,
117        }
118    }
119}
120
121impl From<message::ToolCall> for ToolCall {
122    fn from(value: message::ToolCall) -> Self {
123        ToolCall {
124            id: value.id,
125            r#type: ToolType::Function,
126            function: Function {
127                name: value.function.name,
128                arguments: value.function.arguments,
129            },
130        }
131    }
132}
133
134#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
135pub struct ImageUrl {
136    url: String,
137}
138
139#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
140#[serde(tag = "type", rename_all = "lowercase")]
141pub enum UserContent {
142    Text {
143        text: String,
144    },
145    #[serde(rename = "image_url")]
146    ImageUrl {
147        image_url: ImageUrl,
148    },
149}
150
151impl FromStr for UserContent {
152    type Err = Infallible;
153
154    fn from_str(s: &str) -> Result<Self, Self::Err> {
155        Ok(UserContent::Text {
156            text: s.to_string(),
157        })
158    }
159}
160
161#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
162#[serde(tag = "type", rename_all = "lowercase")]
163pub enum AssistantContent {
164    Text { text: String },
165}
166
167impl FromStr for AssistantContent {
168    type Err = Infallible;
169
170    fn from_str(s: &str) -> Result<Self, Self::Err> {
171        Ok(AssistantContent::Text {
172            text: s.to_string(),
173        })
174    }
175}
176
177#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
178#[serde(tag = "type", rename_all = "lowercase")]
179pub enum SystemContent {
180    Text { text: String },
181}
182
183impl FromStr for SystemContent {
184    type Err = Infallible;
185
186    fn from_str(s: &str) -> Result<Self, Self::Err> {
187        Ok(SystemContent::Text {
188            text: s.to_string(),
189        })
190    }
191}
192
193impl From<UserContent> for message::UserContent {
194    fn from(value: UserContent) -> Self {
195        match value {
196            UserContent::Text { text } => message::UserContent::text(text),
197            UserContent::ImageUrl { image_url } => {
198                message::UserContent::image_url(image_url.url, None, None)
199            }
200        }
201    }
202}
203
204impl TryFrom<message::UserContent> for UserContent {
205    type Error = message::MessageError;
206
207    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
208        match content {
209            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
210            message::UserContent::Document(message::Document {
211                data: message::DocumentSourceKind::Raw(raw),
212                ..
213            }) => {
214                let text = String::from_utf8_lossy(raw.as_slice()).into();
215                Ok(UserContent::Text { text })
216            }
217            message::UserContent::Document(message::Document {
218                data:
219                    message::DocumentSourceKind::Base64(text)
220                    | message::DocumentSourceKind::String(text),
221                ..
222            }) => Ok(UserContent::Text { text }),
223            message::UserContent::Image(message::Image { data, .. }) => match data {
224                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
225                    image_url: ImageUrl { url },
226                }),
227                _ => Err(message::MessageError::ConversionError(
228                    "Huggingface only supports images as urls".into(),
229                )),
230            },
231            _ => Err(message::MessageError::ConversionError(
232                "Huggingface only supports text and images".into(),
233            )),
234        }
235    }
236}
237
238#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
239#[serde(tag = "role", rename_all = "lowercase")]
240pub enum Message {
241    System {
242        #[serde(deserialize_with = "string_or_one_or_many")]
243        content: OneOrMany<SystemContent>,
244    },
245    User {
246        #[serde(deserialize_with = "string_or_one_or_many")]
247        content: OneOrMany<UserContent>,
248    },
249    Assistant {
250        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
251        content: Vec<AssistantContent>,
252        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
253        tool_calls: Vec<ToolCall>,
254    },
255    #[serde(rename = "tool", alias = "Tool")]
256    ToolResult {
257        name: String,
258        #[serde(skip_serializing_if = "Option::is_none")]
259        arguments: Option<serde_json::Value>,
260        #[serde(
261            deserialize_with = "string_or_one_or_many",
262            serialize_with = "serialize_tool_content"
263        )]
264        content: OneOrMany<String>,
265    },
266}
267
268fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
269where
270    S: Serializer,
271{
272    // OpenAI-compatible APIs expect tool content as a string, not an array
273    let joined = content
274        .iter()
275        .map(String::as_str)
276        .collect::<Vec<_>>()
277        .join("\n");
278    serializer.serialize_str(&joined)
279}
280
281impl Message {
282    pub fn system(content: &str) -> Self {
283        Message::System {
284            content: OneOrMany::one(SystemContent::Text {
285                text: content.to_string(),
286            }),
287        }
288    }
289}
290
291impl TryFrom<message::Message> for Vec<Message> {
292    type Error = message::MessageError;
293
294    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
295        match message {
296            message::Message::System { content } => Ok(vec![Message::system(&content)]),
297            message::Message::User { content } => {
298                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
299                    .into_iter()
300                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
301
302                if !tool_results.is_empty() {
303                    tool_results
304                        .into_iter()
305                        .map(|content| match content {
306                            message::UserContent::ToolResult(message::ToolResult {
307                                id,
308                                content,
309                                ..
310                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
311                                name: id,
312                                arguments: None,
313                                content: content.try_map(|content| match content {
314                                    message::ToolResultContent::Text(message::Text { text }) => {
315                                        Ok(text)
316                                    }
317                                    _ => Err(message::MessageError::ConversionError(
318                                        "Tool result content does not support non-text".into(),
319                                    )),
320                                })?,
321                            }),
322                            _ => unreachable!(),
323                        })
324                        .collect::<Result<Vec<_>, _>>()
325                } else {
326                    let other_content = OneOrMany::many(other_content).expect(
327                        "There must be other content here if there were no tool result content",
328                    );
329
330                    Ok(vec![Message::User {
331                        content: other_content.try_map(|content| match content {
332                            message::UserContent::Text(text) => {
333                                Ok(UserContent::Text { text: text.text })
334                            }
335                            message::UserContent::Image(image) => {
336                                let url = image.try_into_url()?;
337
338                                Ok(UserContent::ImageUrl {
339                                    image_url: ImageUrl { url },
340                                })
341                            }
342                            message::UserContent::Document(message::Document {
343                                data: message::DocumentSourceKind::Raw(raw), ..
344                            }) => {
345                                let text = String::from_utf8_lossy(raw.as_slice()).into();
346                                Ok(UserContent::Text { text })
347                            }
348                            message::UserContent::Document(message::Document {
349                                data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
350                            }) => {
351                                Ok(UserContent::Text { text })
352                            }
353                            _ => Err(message::MessageError::ConversionError(
354                                "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
355                            )),
356                        })?,
357                    }])
358                }
359            }
360            message::Message::Assistant { content, .. } => {
361                let mut text_content = Vec::new();
362                let mut tool_calls = Vec::new();
363
364                for content in content {
365                    match content {
366                        message::AssistantContent::Text(text) => text_content.push(text),
367                        message::AssistantContent::ToolCall(tool_call) => {
368                            tool_calls.push(tool_call)
369                        }
370                        message::AssistantContent::Reasoning(_) => {
371                            // HuggingFace does not support assistant-history reasoning items.
372                            // Silently skip unsupported reasoning content.
373                        }
374                        message::AssistantContent::Image(_) => {
375                            panic!("Image content is not supported on HuggingFace via Rig");
376                        }
377                    }
378                }
379
380                if text_content.is_empty() && tool_calls.is_empty() {
381                    return Ok(vec![]);
382                }
383
384                Ok(vec![Message::Assistant {
385                    content: text_content
386                        .into_iter()
387                        .map(|content| AssistantContent::Text { text: content.text })
388                        .collect::<Vec<_>>(),
389                    tool_calls: tool_calls
390                        .into_iter()
391                        .map(|tool_call| tool_call.into())
392                        .collect::<Vec<_>>(),
393                }])
394            }
395        }
396    }
397}
398
399impl TryFrom<Message> for message::Message {
400    type Error = message::MessageError;
401
402    fn try_from(message: Message) -> Result<Self, Self::Error> {
403        Ok(match message {
404            Message::User { content, .. } => message::Message::User {
405                content: content.map(|content| content.into()),
406            },
407            Message::Assistant {
408                content,
409                tool_calls,
410                ..
411            } => {
412                let mut content = content
413                    .into_iter()
414                    .map(|content| match content {
415                        AssistantContent::Text { text } => message::AssistantContent::text(text),
416                    })
417                    .collect::<Vec<_>>();
418
419                content.extend(
420                    tool_calls
421                        .into_iter()
422                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
423                        .collect::<Result<Vec<_>, _>>()?,
424                );
425
426                message::Message::Assistant {
427                    id: None,
428                    content: OneOrMany::many(content).map_err(|_| {
429                        message::MessageError::ConversionError(
430                            "Neither `content` nor `tool_calls` was provided to the Message"
431                                .to_owned(),
432                        )
433                    })?,
434                }
435            }
436
437            Message::ToolResult { name, content, .. } => message::Message::User {
438                content: OneOrMany::one(message::UserContent::tool_result(
439                    name,
440                    content.map(message::ToolResultContent::text),
441                )),
442            },
443
444            // System messages should get stripped out when converting message's, this is just a
445            // stop gap to avoid obnoxious error handling or panic occurring.
446            Message::System { content, .. } => message::Message::User {
447                content: content.map(|c| match c {
448                    SystemContent::Text { text } => message::UserContent::text(text),
449                }),
450            },
451        })
452    }
453}
454
455#[derive(Clone, Debug, Deserialize, Serialize)]
456pub struct Choice {
457    pub finish_reason: String,
458    pub index: usize,
459    #[serde(default)]
460    pub logprobs: serde_json::Value,
461    pub message: Message,
462}
463
464#[derive(Debug, Deserialize, Clone, Serialize)]
465pub struct Usage {
466    pub completion_tokens: i32,
467    pub prompt_tokens: i32,
468    pub total_tokens: i32,
469}
470
471impl GetTokenUsage for Usage {
472    fn token_usage(&self) -> Option<crate::completion::Usage> {
473        let mut usage = crate::completion::Usage::new();
474        usage.input_tokens = self.prompt_tokens as u64;
475        usage.output_tokens = self.completion_tokens as u64;
476        usage.total_tokens = self.total_tokens as u64;
477
478        Some(usage)
479    }
480}
481
482#[derive(Clone, Debug, Deserialize, Serialize)]
483pub struct CompletionResponse {
484    pub created: i32,
485    pub id: String,
486    pub model: String,
487    pub choices: Vec<Choice>,
488    #[serde(default, deserialize_with = "default_string_on_null")]
489    pub system_fingerprint: String,
490    pub usage: Usage,
491}
492
493impl crate::telemetry::ProviderResponseExt for CompletionResponse {
494    type OutputMessage = Choice;
495    type Usage = Usage;
496
497    fn get_response_id(&self) -> Option<String> {
498        Some(self.id.clone())
499    }
500
501    fn get_response_model_name(&self) -> Option<String> {
502        Some(self.model.clone())
503    }
504
505    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
506        self.choices.clone()
507    }
508
509    fn get_text_response(&self) -> Option<String> {
510        let text_response = self
511            .choices
512            .iter()
513            .filter_map(|x| {
514                let Message::User { ref content } = x.message else {
515                    return None;
516                };
517
518                let text = content
519                    .iter()
520                    .filter_map(|x| {
521                        if let UserContent::Text { text } = x {
522                            Some(text.clone())
523                        } else {
524                            None
525                        }
526                    })
527                    .collect::<Vec<String>>();
528
529                if text.is_empty() {
530                    None
531                } else {
532                    Some(text.join("\n"))
533                }
534            })
535            .collect::<Vec<String>>()
536            .join("\n");
537
538        if text_response.is_empty() {
539            None
540        } else {
541            Some(text_response)
542        }
543    }
544
545    fn get_usage(&self) -> Option<Self::Usage> {
546        Some(self.usage.clone())
547    }
548}
549
550fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
551where
552    D: Deserializer<'de>,
553{
554    match Option::<String>::deserialize(deserializer)? {
555        Some(value) => Ok(value),      // Use provided value
556        None => Ok(String::default()), // Use `Default` implementation
557    }
558}
559
560impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
561    type Error = CompletionError;
562
563    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
564        let choice = response.choices.first().ok_or_else(|| {
565            CompletionError::ResponseError("Response contained no choices".to_owned())
566        })?;
567
568        let content = match &choice.message {
569            Message::Assistant {
570                content,
571                tool_calls,
572                ..
573            } => {
574                let mut content = content
575                    .iter()
576                    .map(|c| match c {
577                        AssistantContent::Text { text } => message::AssistantContent::text(text),
578                    })
579                    .collect::<Vec<_>>();
580
581                content.extend(
582                    tool_calls
583                        .iter()
584                        .map(|call| {
585                            completion::AssistantContent::tool_call(
586                                &call.id,
587                                &call.function.name,
588                                call.function.arguments.clone(),
589                            )
590                        })
591                        .collect::<Vec<_>>(),
592                );
593                Ok(content)
594            }
595            _ => Err(CompletionError::ResponseError(
596                "Response did not contain a valid message or tool call".into(),
597            )),
598        }?;
599
600        let choice = OneOrMany::many(content).map_err(|_| {
601            CompletionError::ResponseError(
602                "Response contained no message or tool call (empty)".to_owned(),
603            )
604        })?;
605
606        let usage = completion::Usage {
607            input_tokens: response.usage.prompt_tokens as u64,
608            output_tokens: response.usage.completion_tokens as u64,
609            total_tokens: response.usage.total_tokens as u64,
610            cached_input_tokens: 0,
611        };
612
613        Ok(completion::CompletionResponse {
614            choice,
615            usage,
616            raw_response: response,
617            message_id: None,
618        })
619    }
620}
621
622#[derive(Debug, Serialize, Deserialize)]
623pub(super) struct HuggingfaceCompletionRequest {
624    model: String,
625    pub messages: Vec<Message>,
626    #[serde(skip_serializing_if = "Option::is_none")]
627    temperature: Option<f64>,
628    #[serde(skip_serializing_if = "Vec::is_empty")]
629    tools: Vec<ToolDefinition>,
630    #[serde(skip_serializing_if = "Option::is_none")]
631    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
632    #[serde(flatten, skip_serializing_if = "Option::is_none")]
633    pub additional_params: Option<serde_json::Value>,
634}
635
636impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
637    type Error = CompletionError;
638
639    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
640        if req.output_schema.is_some() {
641            tracing::warn!("Structured outputs currently not supported for Huggingface");
642        }
643        let model = req.model.clone().unwrap_or_else(|| model.to_string());
644        let mut full_history: Vec<Message> = match &req.preamble {
645            Some(preamble) => vec![Message::system(preamble)],
646            None => vec![],
647        };
648        if let Some(docs) = req.normalized_documents() {
649            let docs: Vec<Message> = docs.try_into()?;
650            full_history.extend(docs);
651        }
652
653        let chat_history: Vec<Message> = req
654            .chat_history
655            .clone()
656            .into_iter()
657            .map(|message| message.try_into())
658            .collect::<Result<Vec<Vec<Message>>, _>>()?
659            .into_iter()
660            .flatten()
661            .collect();
662
663        full_history.extend(chat_history);
664
665        if full_history.is_empty() {
666            return Err(CompletionError::RequestError(
667                std::io::Error::new(
668                    std::io::ErrorKind::InvalidInput,
669                    "HuggingFace request has no provider-compatible messages after conversion",
670                )
671                .into(),
672            ));
673        }
674
675        let tool_choice = req
676            .tool_choice
677            .clone()
678            .map(crate::providers::openai::completion::ToolChoice::try_from)
679            .transpose()?;
680
681        Ok(Self {
682            model: model.to_string(),
683            messages: full_history,
684            temperature: req.temperature,
685            tools: req
686                .tools
687                .clone()
688                .into_iter()
689                .map(ToolDefinition::from)
690                .collect::<Vec<_>>(),
691            tool_choice,
692            additional_params: req.additional_params,
693        })
694    }
695}
696
697#[derive(Clone)]
698pub struct CompletionModel<T = reqwest::Client> {
699    pub(crate) client: Client<T>,
700    /// Name of the model (e.g: google/gemma-2-2b-it)
701    pub model: String,
702}
703
704impl<T> CompletionModel<T> {
705    pub fn new(client: Client<T>, model: &str) -> Self {
706        Self {
707            client,
708            model: model.to_string(),
709        }
710    }
711}
712
713impl<T> completion::CompletionModel for CompletionModel<T>
714where
715    T: HttpClientExt + Clone + 'static,
716{
717    type Response = CompletionResponse;
718    type StreamingResponse = StreamingCompletionResponse;
719
720    type Client = Client<T>;
721
722    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
723        Self::new(client.clone(), &model.into())
724    }
725
726    async fn completion(
727        &self,
728        completion_request: CompletionRequest,
729    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
730        let request_model = completion_request
731            .model
732            .clone()
733            .unwrap_or_else(|| self.model.clone());
734        let span = if tracing::Span::current().is_disabled() {
735            info_span!(
736                target: "rig::completions",
737                "chat",
738                gen_ai.operation.name = "chat",
739                gen_ai.provider.name = "huggingface",
740                gen_ai.request.model = &request_model,
741                gen_ai.system_instructions = &completion_request.preamble,
742                gen_ai.response.id = tracing::field::Empty,
743                gen_ai.response.model = tracing::field::Empty,
744                gen_ai.usage.output_tokens = tracing::field::Empty,
745                gen_ai.usage.input_tokens = tracing::field::Empty,
746                gen_ai.usage.cached_tokens = tracing::field::Empty,
747            )
748        } else {
749            tracing::Span::current()
750        };
751
752        let model = self.client.subprovider().model_identifier(&request_model);
753        let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
754
755        if enabled!(Level::TRACE) {
756            tracing::trace!(
757                target: "rig::completions",
758                "Huggingface completion request: {}",
759                serde_json::to_string_pretty(&request)?
760            );
761        }
762
763        let request = serde_json::to_vec(&request)?;
764
765        let path = self
766            .client
767            .subprovider()
768            .completion_endpoint(&request_model);
769        let request = self
770            .client
771            .post(&path)?
772            .header("Content-Type", "application/json")
773            .body(request)
774            .map_err(|e| CompletionError::HttpError(e.into()))?;
775
776        async move {
777            let response = self.client.send(request).await?;
778
779            if response.status().is_success() {
780                let bytes: Vec<u8> = response.into_body().await?;
781                let text = String::from_utf8_lossy(&bytes);
782
783                tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
784
785                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
786                    ApiResponse::Ok(response) => {
787                        if enabled!(Level::TRACE) {
788                            tracing::trace!(
789                                target: "rig::completions",
790                                "Huggingface completion response: {}",
791                                serde_json::to_string_pretty(&response)?
792                            );
793                        }
794
795                        let span = tracing::Span::current();
796                        span.record_token_usage(&response.usage);
797                        span.record_response_metadata(&response);
798
799                        response.try_into()
800                    }
801                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
802                }
803            } else {
804                let status = response.status();
805                let text: Vec<u8> = response.into_body().await?;
806                let text: String = String::from_utf8_lossy(&text).into();
807
808                Err(CompletionError::ProviderError(format!(
809                    "{}: {}",
810                    status, text
811                )))
812            }
813        }
814        .instrument(span)
815        .await
816    }
817
818    async fn stream(
819        &self,
820        request: CompletionRequest,
821    ) -> Result<
822        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
823        CompletionError,
824    > {
825        CompletionModel::stream(self, request).await
826    }
827}
828
829#[cfg(test)]
830mod tests {
831    use super::*;
832    use serde_path_to_error::deserialize;
833
834    #[test]
835    fn test_huggingface_request_uses_request_model_override() {
836        let request = CompletionRequest {
837            model: Some("meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()),
838            preamble: None,
839            chat_history: crate::OneOrMany::one("Hello".into()),
840            documents: vec![],
841            tools: vec![],
842            temperature: None,
843            max_tokens: None,
844            tool_choice: None,
845            additional_params: None,
846            output_schema: None,
847        };
848
849        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
850            .expect("request conversion should succeed");
851        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
852
853        assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
854    }
855
856    #[test]
857    fn test_huggingface_request_uses_default_model_when_override_unset() {
858        let request = CompletionRequest {
859            model: None,
860            preamble: None,
861            chat_history: crate::OneOrMany::one("Hello".into()),
862            documents: vec![],
863            tools: vec![],
864            temperature: None,
865            max_tokens: None,
866            tool_choice: None,
867            additional_params: None,
868            output_schema: None,
869        };
870
871        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
872            .expect("request conversion should succeed");
873        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
874
875        assert_eq!(serialized["model"], "mistralai/Mistral-7B");
876    }
877
878    #[test]
879    fn test_deserialize_message() {
880        let assistant_message_json = r#"
881        {
882            "role": "assistant",
883            "content": "\n\nHello there, how may I assist you today?"
884        }
885        "#;
886
887        let assistant_message_json2 = r#"
888        {
889            "role": "assistant",
890            "content": [
891                {
892                    "type": "text",
893                    "text": "\n\nHello there, how may I assist you today?"
894                }
895            ],
896            "tool_calls": null
897        }
898        "#;
899
900        let assistant_message_json3 = r#"
901        {
902            "role": "assistant",
903            "tool_calls": [
904                {
905                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
906                    "type": "function",
907                    "function": {
908                        "name": "subtract",
909                        "arguments": {"x": 2, "y": 5}
910                    }
911                }
912            ],
913            "content": null,
914            "refusal": null
915        }
916        "#;
917
918        let user_message_json = r#"
919        {
920            "role": "user",
921            "content": [
922                {
923                    "type": "text",
924                    "text": "What's in this image?"
925                },
926                {
927                    "type": "image_url",
928                    "image_url": {
929                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
930                    }
931                }
932            ]
933        }
934        "#;
935
936        let assistant_message: Message = {
937            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
938            deserialize(jd).unwrap_or_else(|err| {
939                panic!(
940                    "Deserialization error at {} ({}:{}): {}",
941                    err.path(),
942                    err.inner().line(),
943                    err.inner().column(),
944                    err
945                );
946            })
947        };
948
949        let assistant_message2: Message = {
950            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
951            deserialize(jd).unwrap_or_else(|err| {
952                panic!(
953                    "Deserialization error at {} ({}:{}): {}",
954                    err.path(),
955                    err.inner().line(),
956                    err.inner().column(),
957                    err
958                );
959            })
960        };
961
962        let assistant_message3: Message = {
963            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
964                &mut serde_json::Deserializer::from_str(assistant_message_json3);
965            deserialize(jd).unwrap_or_else(|err| {
966                panic!(
967                    "Deserialization error at {} ({}:{}): {}",
968                    err.path(),
969                    err.inner().line(),
970                    err.inner().column(),
971                    err
972                );
973            })
974        };
975
976        let user_message: Message = {
977            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
978            deserialize(jd).unwrap_or_else(|err| {
979                panic!(
980                    "Deserialization error at {} ({}:{}): {}",
981                    err.path(),
982                    err.inner().line(),
983                    err.inner().column(),
984                    err
985                );
986            })
987        };
988
989        match assistant_message {
990            Message::Assistant { content, .. } => {
991                assert_eq!(
992                    content[0],
993                    AssistantContent::Text {
994                        text: "\n\nHello there, how may I assist you today?".to_string()
995                    }
996                );
997            }
998            _ => panic!("Expected assistant message"),
999        }
1000
1001        match assistant_message2 {
1002            Message::Assistant {
1003                content,
1004                tool_calls,
1005                ..
1006            } => {
1007                assert_eq!(
1008                    content[0],
1009                    AssistantContent::Text {
1010                        text: "\n\nHello there, how may I assist you today?".to_string()
1011                    }
1012                );
1013
1014                assert_eq!(tool_calls, vec![]);
1015            }
1016            _ => panic!("Expected assistant message"),
1017        }
1018
1019        match assistant_message3 {
1020            Message::Assistant {
1021                content,
1022                tool_calls,
1023                ..
1024            } => {
1025                assert!(content.is_empty());
1026                assert_eq!(
1027                    tool_calls[0],
1028                    ToolCall {
1029                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
1030                        r#type: ToolType::Function,
1031                        function: Function {
1032                            name: "subtract".to_string(),
1033                            arguments: serde_json::json!({"x": 2, "y": 5}),
1034                        },
1035                    }
1036                );
1037            }
1038            _ => panic!("Expected assistant message"),
1039        }
1040
1041        match user_message {
1042            Message::User { content, .. } => {
1043                let (first, second) = {
1044                    let mut iter = content.into_iter();
1045                    (iter.next().unwrap(), iter.next().unwrap())
1046                };
1047                assert_eq!(
1048                    first,
1049                    UserContent::Text {
1050                        text: "What's in this image?".to_string()
1051                    }
1052                );
1053                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
1054            }
1055            _ => panic!("Expected user message"),
1056        }
1057    }
1058
1059    #[test]
1060    fn test_message_to_message_conversion() {
1061        let user_message = message::Message::User {
1062            content: OneOrMany::one(message::UserContent::text("Hello")),
1063        };
1064
1065        let assistant_message = message::Message::Assistant {
1066            id: None,
1067            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
1068        };
1069
1070        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
1071        let converted_assistant_message: Vec<Message> =
1072            assistant_message.clone().try_into().unwrap();
1073
1074        match converted_user_message[0].clone() {
1075            Message::User { content, .. } => {
1076                assert_eq!(
1077                    content.first(),
1078                    UserContent::Text {
1079                        text: "Hello".to_string()
1080                    }
1081                );
1082            }
1083            _ => panic!("Expected user message"),
1084        }
1085
1086        match converted_assistant_message[0].clone() {
1087            Message::Assistant { content, .. } => {
1088                assert_eq!(
1089                    content[0],
1090                    AssistantContent::Text {
1091                        text: "Hi there!".to_string()
1092                    }
1093                );
1094            }
1095            _ => panic!("Expected assistant message"),
1096        }
1097
1098        let original_user_message: message::Message =
1099            converted_user_message[0].clone().try_into().unwrap();
1100        let original_assistant_message: message::Message =
1101            converted_assistant_message[0].clone().try_into().unwrap();
1102
1103        assert_eq!(original_user_message, user_message);
1104        assert_eq!(original_assistant_message, assistant_message);
1105    }
1106
1107    #[test]
1108    fn test_message_from_message_conversion() {
1109        let user_message = Message::User {
1110            content: OneOrMany::one(UserContent::Text {
1111                text: "Hello".to_string(),
1112            }),
1113        };
1114
1115        let assistant_message = Message::Assistant {
1116            content: vec![AssistantContent::Text {
1117                text: "Hi there!".to_string(),
1118            }],
1119            tool_calls: vec![],
1120        };
1121
1122        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1123        let converted_assistant_message: message::Message =
1124            assistant_message.clone().try_into().unwrap();
1125
1126        match converted_user_message.clone() {
1127            message::Message::User { content } => {
1128                assert_eq!(content.first(), message::UserContent::text("Hello"));
1129            }
1130            _ => panic!("Expected user message"),
1131        }
1132
1133        match converted_assistant_message.clone() {
1134            message::Message::Assistant { content, .. } => {
1135                assert_eq!(
1136                    content.first(),
1137                    message::AssistantContent::text("Hi there!")
1138                );
1139            }
1140            _ => panic!("Expected assistant message"),
1141        }
1142
1143        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1144        let original_assistant_message: Vec<Message> =
1145            converted_assistant_message.try_into().unwrap();
1146
1147        assert_eq!(original_user_message[0], user_message);
1148        assert_eq!(original_assistant_message[0], assistant_message);
1149    }
1150
1151    #[test]
1152    fn test_responses() {
1153        let fireworks_response_json = r#"
1154        {
1155            "choices": [
1156                {
1157                    "finish_reason": "tool_calls",
1158                    "index": 0,
1159                    "message": {
1160                        "role": "assistant",
1161                        "tool_calls": [
1162                            {
1163                                "function": {
1164                                "arguments": "{\"x\": 2, \"y\": 5}",
1165                                "name": "subtract"
1166                                },
1167                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
1168                                "index": 0,
1169                                "type": "function"
1170                            }
1171                        ]
1172                    }
1173                }
1174            ],
1175            "created": 1740704000,
1176            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
1177            "model": "accounts/fireworks/models/deepseek-v3",
1178            "object": "chat.completion",
1179            "usage": {
1180                "completion_tokens": 26,
1181                "prompt_tokens": 248,
1182                "total_tokens": 274
1183            }
1184        }
1185        "#;
1186
1187        let novita_response_json = r#"
1188        {
1189            "choices": [
1190                {
1191                    "finish_reason": "tool_calls",
1192                    "index": 0,
1193                    "logprobs": null,
1194                    "message": {
1195                        "audio": null,
1196                        "content": null,
1197                        "function_call": null,
1198                        "reasoning_content": null,
1199                        "refusal": null,
1200                        "role": "assistant",
1201                        "tool_calls": [
1202                            {
1203                                "function": {
1204                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
1205                                    "name": "subtract"
1206                                },
1207                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
1208                                "type": "function"
1209                            }
1210                        ]
1211                    },
1212                    "stop_reason": 128008
1213                }
1214            ],
1215            "created": 1740704592,
1216            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
1217            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
1218            "object": "chat.completion",
1219            "prompt_logprobs": null,
1220            "service_tier": null,
1221            "system_fingerprint": null,
1222            "usage": {
1223                "completion_tokens": 28,
1224                "completion_tokens_details": null,
1225                "prompt_tokens": 335,
1226                "prompt_tokens_details": null,
1227                "total_tokens": 363
1228            }
1229        }
1230        "#;
1231
1232        let _firework_response: CompletionResponse = {
1233            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
1234            deserialize(jd).unwrap_or_else(|err| {
1235                panic!(
1236                    "Deserialization error at {} ({}:{}): {}",
1237                    err.path(),
1238                    err.inner().line(),
1239                    err.inner().column(),
1240                    err
1241                );
1242            })
1243        };
1244
1245        let _novita_response: CompletionResponse = {
1246            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
1247            deserialize(jd).unwrap_or_else(|err| {
1248                panic!(
1249                    "Deserialization error at {} ({}:{}): {}",
1250                    err.path(),
1251                    err.inner().line(),
1252                    err.inner().column(),
1253                    err
1254                );
1255            })
1256        };
1257    }
1258
1259    #[test]
1260    fn test_assistant_reasoning_is_silently_skipped() {
1261        let assistant = message::Message::Assistant {
1262            id: None,
1263            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1264        };
1265
1266        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1267        assert!(converted.is_empty());
1268    }
1269
1270    #[test]
1271    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
1272        let assistant = message::Message::Assistant {
1273            id: None,
1274            content: OneOrMany::many(vec![
1275                message::AssistantContent::reasoning("hidden"),
1276                message::AssistantContent::text("visible"),
1277                message::AssistantContent::tool_call(
1278                    "call_1",
1279                    "subtract",
1280                    serde_json::json!({"x": 2, "y": 1}),
1281                ),
1282            ])
1283            .expect("non-empty assistant content"),
1284        };
1285
1286        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1287        assert_eq!(converted.len(), 1);
1288
1289        match &converted[0] {
1290            Message::Assistant {
1291                content,
1292                tool_calls,
1293                ..
1294            } => {
1295                assert_eq!(
1296                    content,
1297                    &vec![AssistantContent::Text {
1298                        text: "visible".to_string()
1299                    }]
1300                );
1301                assert_eq!(tool_calls.len(), 1);
1302                assert_eq!(tool_calls[0].id, "call_1");
1303                assert_eq!(tool_calls[0].function.name, "subtract");
1304                assert_eq!(
1305                    tool_calls[0].function.arguments,
1306                    serde_json::json!({"x": 2, "y": 1})
1307                );
1308            }
1309            _ => panic!("expected assistant message"),
1310        }
1311    }
1312
1313    #[test]
1314    fn test_request_conversion_errors_when_all_messages_are_filtered() {
1315        let request = completion::CompletionRequest {
1316            preamble: None,
1317            chat_history: OneOrMany::one(message::Message::Assistant {
1318                id: None,
1319                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1320            }),
1321            documents: vec![],
1322            tools: vec![],
1323            temperature: None,
1324            max_tokens: None,
1325            tool_choice: None,
1326            additional_params: None,
1327            model: None,
1328            output_schema: None,
1329        };
1330
1331        let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
1332        assert!(matches!(result, Err(CompletionError::RequestError(_))));
1333    }
1334}