Skip to main content

rig/providers/huggingface/
completion.rs

1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7    OneOrMany,
8    completion::{self, CompletionError, CompletionRequest},
9    json_utils,
10    message::{self},
11    one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
/// Envelope for a Huggingface API response body.
///
/// Deserialized `untagged`: serde first tries to parse the body as `T` and, if
/// that fails, falls back to `Err` holding the raw JSON value. Variant order is
/// therefore significant — `Ok` must stay first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Body parsed successfully as the expected response type.
    Ok(T),
    /// Any body that did not parse as `T`, kept verbatim (reported by the caller
    /// as a provider error).
    Err(Value),
}
25
// ================================================================
// Huggingface Completion API
// ================================================================
//
// Convenience model names. `CompletionModel::new` takes any model name string,
// so these are shortcuts rather than an exhaustive list.

// Conversational LLMs
/// `google/gemma-2-2b-it` completion model
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// `meta-llama/Meta-Llama-3.1-8B-Instruct` completion model
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `PowerInfer/SmallThinker-3B-Preview` completion model
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// `Qwen/Qwen2.5-7B-Instruct` completion model
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` completion model
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs (visual-language models)

/// `Qwen/Qwen2-VL-7B-Instruct` visual-language completion model
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// `Qwen/QVQ-72B-Preview` visual-language completion model
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
/// Function name and arguments of a tool call, in the provider's wire format.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Function {
    name: String,
    // Serialized through `stringified_json::serialize` (presumably as a
    // JSON-encoded string). Deserialized through `deserialize_maybe_stringified`,
    // which, going by its name, accepts either a string or an inline JSON value —
    // TODO confirm against json_utils.
    #[serde(
        serialize_with = "json_utils::stringified_json::serialize",
        deserialize_with = "json_utils::stringified_json::deserialize_maybe_stringified"
    )]
    pub arguments: serde_json::Value,
}
58
59impl From<Function> for message::ToolFunction {
60    fn from(value: Function) -> Self {
61        message::ToolFunction {
62            name: value.name,
63            arguments: value.arguments,
64        }
65    }
66}
67
/// Kind of tool a tool call refers to. Only functions exist, serialized as
/// `"function"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
74
/// Tool definition as sent in a request's `tools` array: a `type` tag plus
/// rig's generic tool definition under `function`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind tag; the `From<completion::ToolDefinition>` conversion sets it
    /// to `"function"`.
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
80
81impl From<completion::ToolDefinition> for ToolDefinition {
82    fn from(tool: completion::ToolDefinition) -> Self {
83        Self {
84            r#type: "function".into(),
85            function: tool,
86        }
87    }
88}
89
/// A tool call in wire format, either emitted by the model or replayed from
/// chat history.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of the call; carried over to rig's `message::ToolCall::id`.
    pub id: String,
    pub r#type: ToolType,
    pub function: Function,
}
96
97impl From<ToolCall> for message::ToolCall {
98    fn from(value: ToolCall) -> Self {
99        message::ToolCall {
100            id: value.id,
101            call_id: None,
102            function: value.function.into(),
103            signature: None,
104            additional_params: None,
105        }
106    }
107}
108
109impl From<message::ToolCall> for ToolCall {
110    fn from(value: message::ToolCall) -> Self {
111        ToolCall {
112            id: value.id,
113            r#type: ToolType::Function,
114            function: Function {
115                name: value.function.name,
116                arguments: value.function.arguments,
117            },
118        }
119    }
120}
121
/// Wrapper object for an image reference, serialized as `{"url": "..."}`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    url: String,
}
126
/// One part of a user message's content, internally tagged by `type`
/// (`"text"` or `"image_url"`).
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text part.
    Text {
        text: String,
    },
    /// Image part referencing an image by URL. Renamed explicitly because
    /// `rename_all = "lowercase"` alone would produce `"imageurl"`.
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
138
139impl FromStr for UserContent {
140    type Err = Infallible;
141
142    fn from_str(s: &str) -> Result<Self, Self::Err> {
143        Ok(UserContent::Text {
144            text: s.to_string(),
145        })
146    }
147}
148
/// One part of an assistant message's content. Only text parts exist, tagged as
/// `{"type": "text", "text": ...}`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
}
154
155impl FromStr for AssistantContent {
156    type Err = Infallible;
157
158    fn from_str(s: &str) -> Result<Self, Self::Err> {
159        Ok(AssistantContent::Text {
160            text: s.to_string(),
161        })
162    }
163}
164
/// One part of a system message's content. Only text parts exist, tagged as
/// `{"type": "text", "text": ...}`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    Text { text: String },
}
170
171impl FromStr for SystemContent {
172    type Err = Infallible;
173
174    fn from_str(s: &str) -> Result<Self, Self::Err> {
175        Ok(SystemContent::Text {
176            text: s.to_string(),
177        })
178    }
179}
180
181impl From<UserContent> for message::UserContent {
182    fn from(value: UserContent) -> Self {
183        match value {
184            UserContent::Text { text } => message::UserContent::text(text),
185            UserContent::ImageUrl { image_url } => {
186                message::UserContent::image_url(image_url.url, None, None)
187            }
188        }
189    }
190}
191
192impl TryFrom<message::UserContent> for UserContent {
193    type Error = message::MessageError;
194
195    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
196        match content {
197            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
198            message::UserContent::Document(message::Document {
199                data: message::DocumentSourceKind::Raw(raw),
200                ..
201            }) => {
202                let text = String::from_utf8_lossy(raw.as_slice()).into();
203                Ok(UserContent::Text { text })
204            }
205            message::UserContent::Document(message::Document {
206                data:
207                    message::DocumentSourceKind::Base64(text)
208                    | message::DocumentSourceKind::String(text),
209                ..
210            }) => Ok(UserContent::Text { text }),
211            message::UserContent::Image(message::Image { data, .. }) => match data {
212                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
213                    image_url: ImageUrl { url },
214                }),
215                _ => Err(message::MessageError::ConversionError(
216                    "Huggingface only supports images as urls".into(),
217                )),
218            },
219            _ => Err(message::MessageError::ConversionError(
220                "Huggingface only supports text and images".into(),
221            )),
222        }
223    }
224}
225
/// A chat message in the Huggingface wire format, internally tagged by `role`.
///
/// On input, `System` and `User` content go through `string_or_one_or_many`, so
/// either a bare string or a list of parts is accepted.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    /// Assistant turn. Both fields default to empty when absent. Going by the
    /// helper names (`string_or_vec`, `null_or_vec`), `content` also accepts a
    /// bare string and `tool_calls` also accepts `null`.
    Assistant {
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool call. Serialized with role `"tool"`; `"Tool"` is also
    /// accepted on input.
    #[serde(rename = "tool", alias = "Tool")]
    ToolResult {
        /// Filled from rig's `ToolResult::id` when converting outgoing messages.
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        // Serialized as one newline-joined string by `serialize_tool_content`.
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_tool_content"
        )]
        content: OneOrMany<String>,
    },
}
255
256fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
257where
258    S: Serializer,
259{
260    // OpenAI-compatible APIs expect tool content as a string, not an array
261    let joined = content
262        .iter()
263        .map(String::as_str)
264        .collect::<Vec<_>>()
265        .join("\n");
266    serializer.serialize_str(&joined)
267}
268
269impl Message {
270    pub fn system(content: &str) -> Self {
271        Message::System {
272            content: OneOrMany::one(SystemContent::Text {
273                text: content.to_string(),
274            }),
275        }
276    }
277}
278
279impl TryFrom<message::Message> for Vec<Message> {
280    type Error = message::MessageError;
281
282    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
283        match message {
284            message::Message::System { content } => Ok(vec![Message::system(&content)]),
285            message::Message::User { content } => {
286                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
287                    .into_iter()
288                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
289
290                if !tool_results.is_empty() {
291                    tool_results
292                        .into_iter()
293                        .map(|content| match content {
294                            message::UserContent::ToolResult(message::ToolResult {
295                                id,
296                                content,
297                                ..
298                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
299                                name: id,
300                                arguments: None,
301                                content: content.try_map(|content| match content {
302                                    message::ToolResultContent::Text(message::Text { text }) => {
303                                        Ok(text)
304                                    }
305                                    _ => Err(message::MessageError::ConversionError(
306                                        "Tool result content does not support non-text".into(),
307                                    )),
308                                })?,
309                            }),
310                            _ => unreachable!(),
311                        })
312                        .collect::<Result<Vec<_>, _>>()
313                } else {
314                    let other_content = OneOrMany::many(other_content).expect(
315                        "There must be other content here if there were no tool result content",
316                    );
317
318                    Ok(vec![Message::User {
319                        content: other_content.try_map(|content| match content {
320                            message::UserContent::Text(text) => {
321                                Ok(UserContent::Text { text: text.text })
322                            }
323                            message::UserContent::Image(image) => {
324                                let url = image.try_into_url()?;
325
326                                Ok(UserContent::ImageUrl {
327                                    image_url: ImageUrl { url },
328                                })
329                            }
330                            message::UserContent::Document(message::Document {
331                                data: message::DocumentSourceKind::Raw(raw), ..
332                            }) => {
333                                let text = String::from_utf8_lossy(raw.as_slice()).into();
334                                Ok(UserContent::Text { text })
335                            }
336                            message::UserContent::Document(message::Document {
337                                data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
338                            }) => {
339                                Ok(UserContent::Text { text })
340                            }
341                            _ => Err(message::MessageError::ConversionError(
342                                "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
343                            )),
344                        })?,
345                    }])
346                }
347            }
348            message::Message::Assistant { content, .. } => {
349                let mut text_content = Vec::new();
350                let mut tool_calls = Vec::new();
351
352                for content in content {
353                    match content {
354                        message::AssistantContent::Text(text) => text_content.push(text),
355                        message::AssistantContent::ToolCall(tool_call) => {
356                            tool_calls.push(tool_call)
357                        }
358                        message::AssistantContent::Reasoning(_) => {
359                            // HuggingFace does not support assistant-history reasoning items.
360                            // Silently skip unsupported reasoning content.
361                        }
362                        message::AssistantContent::Image(_) => {
363                            panic!("Image content is not supported on HuggingFace via Rig");
364                        }
365                    }
366                }
367
368                if text_content.is_empty() && tool_calls.is_empty() {
369                    return Ok(vec![]);
370                }
371
372                Ok(vec![Message::Assistant {
373                    content: text_content
374                        .into_iter()
375                        .map(|content| AssistantContent::Text { text: content.text })
376                        .collect::<Vec<_>>(),
377                    tool_calls: tool_calls
378                        .into_iter()
379                        .map(|tool_call| tool_call.into())
380                        .collect::<Vec<_>>(),
381                }])
382            }
383        }
384    }
385}
386
387impl TryFrom<Message> for message::Message {
388    type Error = message::MessageError;
389
390    fn try_from(message: Message) -> Result<Self, Self::Error> {
391        Ok(match message {
392            Message::User { content, .. } => message::Message::User {
393                content: content.map(|content| content.into()),
394            },
395            Message::Assistant {
396                content,
397                tool_calls,
398                ..
399            } => {
400                let mut content = content
401                    .into_iter()
402                    .map(|content| match content {
403                        AssistantContent::Text { text } => message::AssistantContent::text(text),
404                    })
405                    .collect::<Vec<_>>();
406
407                content.extend(
408                    tool_calls
409                        .into_iter()
410                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
411                        .collect::<Result<Vec<_>, _>>()?,
412                );
413
414                message::Message::Assistant {
415                    id: None,
416                    content: OneOrMany::many(content).map_err(|_| {
417                        message::MessageError::ConversionError(
418                            "Neither `content` nor `tool_calls` was provided to the Message"
419                                .to_owned(),
420                        )
421                    })?,
422                }
423            }
424
425            Message::ToolResult { name, content, .. } => message::Message::User {
426                content: OneOrMany::one(message::UserContent::tool_result(
427                    name,
428                    content.map(message::ToolResultContent::text),
429                )),
430            },
431
432            // System messages should get stripped out when converting message's, this is just a
433            // stop gap to avoid obnoxious error handling or panic occurring.
434            Message::System { content, .. } => message::Message::User {
435                content: content.map(|c| match c {
436                    SystemContent::Text { text } => message::UserContent::text(text),
437                }),
438            },
439        })
440    }
441}
442
/// One entry of a completion response's `choices` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    pub finish_reason: String,
    pub index: usize,
    // Falls back to `Value::default()` (JSON `null`) when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    pub message: Message,
}
451
/// Token accounting taken from a completion response's `usage` object.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    pub completion_tokens: i32,
    pub prompt_tokens: i32,
    pub total_tokens: i32,
}
458
459impl GetTokenUsage for Usage {
460    fn token_usage(&self) -> Option<crate::completion::Usage> {
461        let mut usage = crate::completion::Usage::new();
462        usage.input_tokens = self.prompt_tokens as u64;
463        usage.output_tokens = self.completion_tokens as u64;
464        usage.total_tokens = self.total_tokens as u64;
465
466        Some(usage)
467    }
468}
469
/// Raw body of a non-streaming chat completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    // NOTE(review): presumably a Unix timestamp in seconds; `i32` overflows in
    // 2038 — consider `i64` in a breaking release.
    pub created: i32,
    pub id: String,
    pub model: String,
    pub choices: Vec<Choice>,
    // A missing field or an explicit `null` both become an empty string.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    pub usage: Usage,
}
480
481impl crate::telemetry::ProviderResponseExt for CompletionResponse {
482    type OutputMessage = Choice;
483    type Usage = Usage;
484
485    fn get_response_id(&self) -> Option<String> {
486        Some(self.id.clone())
487    }
488
489    fn get_response_model_name(&self) -> Option<String> {
490        Some(self.model.clone())
491    }
492
493    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
494        self.choices.clone()
495    }
496
497    fn get_text_response(&self) -> Option<String> {
498        let text_response = self
499            .choices
500            .iter()
501            .filter_map(|x| {
502                let Message::User { ref content } = x.message else {
503                    return None;
504                };
505
506                let text = content
507                    .iter()
508                    .filter_map(|x| {
509                        if let UserContent::Text { text } = x {
510                            Some(text.clone())
511                        } else {
512                            None
513                        }
514                    })
515                    .collect::<Vec<String>>();
516
517                if text.is_empty() {
518                    None
519                } else {
520                    Some(text.join("\n"))
521                }
522            })
523            .collect::<Vec<String>>()
524            .join("\n");
525
526        if text_response.is_empty() {
527            None
528        } else {
529            Some(text_response)
530        }
531    }
532
533    fn get_usage(&self) -> Option<Self::Usage> {
534        Some(self.usage.clone())
535    }
536}
537
538fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
539where
540    D: Deserializer<'de>,
541{
542    match Option::<String>::deserialize(deserializer)? {
543        Some(value) => Ok(value),      // Use provided value
544        None => Ok(String::default()), // Use `Default` implementation
545    }
546}
547
548impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
549    type Error = CompletionError;
550
551    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
552        let choice = response.choices.first().ok_or_else(|| {
553            CompletionError::ResponseError("Response contained no choices".to_owned())
554        })?;
555
556        let content = match &choice.message {
557            Message::Assistant {
558                content,
559                tool_calls,
560                ..
561            } => {
562                let mut content = content
563                    .iter()
564                    .map(|c| match c {
565                        AssistantContent::Text { text } => message::AssistantContent::text(text),
566                    })
567                    .collect::<Vec<_>>();
568
569                content.extend(
570                    tool_calls
571                        .iter()
572                        .map(|call| {
573                            completion::AssistantContent::tool_call(
574                                &call.id,
575                                &call.function.name,
576                                call.function.arguments.clone(),
577                            )
578                        })
579                        .collect::<Vec<_>>(),
580                );
581                Ok(content)
582            }
583            _ => Err(CompletionError::ResponseError(
584                "Response did not contain a valid message or tool call".into(),
585            )),
586        }?;
587
588        let choice = OneOrMany::many(content).map_err(|_| {
589            CompletionError::ResponseError(
590                "Response contained no message or tool call (empty)".to_owned(),
591            )
592        })?;
593
594        let usage = completion::Usage {
595            input_tokens: response.usage.prompt_tokens as u64,
596            output_tokens: response.usage.completion_tokens as u64,
597            total_tokens: response.usage.total_tokens as u64,
598            cached_input_tokens: 0,
599            cache_creation_input_tokens: 0,
600        };
601
602        Ok(completion::CompletionResponse {
603            choice,
604            usage,
605            raw_response: response,
606            message_id: None,
607        })
608    }
609}
610
611#[derive(Debug, Serialize, Deserialize)]
612pub(super) struct HuggingfaceCompletionRequest {
613    model: String,
614    pub messages: Vec<Message>,
615    #[serde(skip_serializing_if = "Option::is_none")]
616    temperature: Option<f64>,
617    #[serde(skip_serializing_if = "Vec::is_empty")]
618    tools: Vec<ToolDefinition>,
619    #[serde(skip_serializing_if = "Option::is_none")]
620    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
621    #[serde(flatten, skip_serializing_if = "Option::is_none")]
622    pub additional_params: Option<serde_json::Value>,
623}
624
625impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
626    type Error = CompletionError;
627
628    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
629        if req.output_schema.is_some() {
630            tracing::warn!("Structured outputs currently not supported for Huggingface");
631        }
632        let model = req.model.clone().unwrap_or_else(|| model.to_string());
633        let mut full_history: Vec<Message> = match &req.preamble {
634            Some(preamble) => vec![Message::system(preamble)],
635            None => vec![],
636        };
637        if let Some(docs) = req.normalized_documents() {
638            let docs: Vec<Message> = docs.try_into()?;
639            full_history.extend(docs);
640        }
641
642        let chat_history: Vec<Message> = req
643            .chat_history
644            .clone()
645            .into_iter()
646            .map(|message| message.try_into())
647            .collect::<Result<Vec<Vec<Message>>, _>>()?
648            .into_iter()
649            .flatten()
650            .collect();
651
652        full_history.extend(chat_history);
653
654        if full_history.is_empty() {
655            return Err(CompletionError::RequestError(
656                std::io::Error::new(
657                    std::io::ErrorKind::InvalidInput,
658                    "HuggingFace request has no provider-compatible messages after conversion",
659                )
660                .into(),
661            ));
662        }
663
664        let tool_choice = req
665            .tool_choice
666            .clone()
667            .map(crate::providers::openai::completion::ToolChoice::try_from)
668            .transpose()?;
669
670        Ok(Self {
671            model: model.to_string(),
672            messages: full_history,
673            temperature: req.temperature,
674            tools: req
675                .tools
676                .clone()
677                .into_iter()
678                .map(ToolDefinition::from)
679                .collect::<Vec<_>>(),
680            tool_choice,
681            additional_params: req.additional_params,
682        })
683    }
684}
685
/// Huggingface chat completion model bound to a client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Name of the model (e.g: google/gemma-2-2b-it)
    /// Used whenever a request does not carry its own model override.
    pub model: String,
}
692
693impl<T> CompletionModel<T> {
694    pub fn new(client: Client<T>, model: &str) -> Self {
695        Self {
696            client,
697            model: model.to_string(),
698        }
699    }
700}
701
702impl<T> completion::CompletionModel for CompletionModel<T>
703where
704    T: HttpClientExt + Clone + 'static,
705{
706    type Response = CompletionResponse;
707    type StreamingResponse = StreamingCompletionResponse;
708
709    type Client = Client<T>;
710
711    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
712        Self::new(client.clone(), &model.into())
713    }
714
715    async fn completion(
716        &self,
717        completion_request: CompletionRequest,
718    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
719        let request_model = completion_request
720            .model
721            .clone()
722            .unwrap_or_else(|| self.model.clone());
723        let span = if tracing::Span::current().is_disabled() {
724            info_span!(
725                target: "rig::completions",
726                "chat",
727                gen_ai.operation.name = "chat",
728                gen_ai.provider.name = "huggingface",
729                gen_ai.request.model = &request_model,
730                gen_ai.system_instructions = &completion_request.preamble,
731                gen_ai.response.id = tracing::field::Empty,
732                gen_ai.response.model = tracing::field::Empty,
733                gen_ai.usage.output_tokens = tracing::field::Empty,
734                gen_ai.usage.input_tokens = tracing::field::Empty,
735                gen_ai.usage.cached_tokens = tracing::field::Empty,
736            )
737        } else {
738            tracing::Span::current()
739        };
740
741        let model = self.client.subprovider().model_identifier(&request_model);
742        let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
743
744        if enabled!(Level::TRACE) {
745            tracing::trace!(
746                target: "rig::completions",
747                "Huggingface completion request: {}",
748                serde_json::to_string_pretty(&request)?
749            );
750        }
751
752        let request = serde_json::to_vec(&request)?;
753
754        let path = self
755            .client
756            .subprovider()
757            .completion_endpoint(&request_model);
758        let request = self
759            .client
760            .post(&path)?
761            .header("Content-Type", "application/json")
762            .body(request)
763            .map_err(|e| CompletionError::HttpError(e.into()))?;
764
765        async move {
766            let response = self.client.send(request).await?;
767
768            if response.status().is_success() {
769                let bytes: Vec<u8> = response.into_body().await?;
770                let text = String::from_utf8_lossy(&bytes);
771
772                tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
773
774                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
775                    ApiResponse::Ok(response) => {
776                        if enabled!(Level::TRACE) {
777                            tracing::trace!(
778                                target: "rig::completions",
779                                "Huggingface completion response: {}",
780                                serde_json::to_string_pretty(&response)?
781                            );
782                        }
783
784                        let span = tracing::Span::current();
785                        span.record_token_usage(&response.usage);
786                        span.record_response_metadata(&response);
787
788                        response.try_into()
789                    }
790                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
791                }
792            } else {
793                let status = response.status();
794                let text: Vec<u8> = response.into_body().await?;
795                let text: String = String::from_utf8_lossy(&text).into();
796
797                Err(CompletionError::ProviderError(format!(
798                    "{}: {}",
799                    status, text
800                )))
801            }
802        }
803        .instrument(span)
804        .await
805    }
806
807    async fn stream(
808        &self,
809        request: CompletionRequest,
810    ) -> Result<
811        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
812        CompletionError,
813    > {
814        CompletionModel::stream(self, request).await
815    }
816}
817
// Unit tests for Hugging Face request building, message (de)serialization, conversions to and
// from the generic `message` types, and parsing of real provider responses.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`.
    ///
    /// # Panics
    ///
    /// Panics when deserialization fails, reporting the JSON path, line and column of the
    /// error so a fixture mismatch points straight at the offending field.
    fn parse_json<'de, T: Deserialize<'de>>(json: &'de str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    /// Builds a [`CompletionRequest`] that carries only `model` and `chat_history`; every
    /// other field is left empty or unset.
    fn request_with(
        model: Option<&str>,
        chat_history: OneOrMany<message::Message>,
    ) -> CompletionRequest {
        CompletionRequest {
            model: model.map(str::to_string),
            preamble: None,
            chat_history,
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    #[test]
    fn test_huggingface_request_uses_request_model_override() {
        // A per-request model must win over the model passed alongside the request.
        let request = request_with(
            Some("meta-llama/Meta-Llama-3.1-8B-Instruct"),
            OneOrMany::one("Hello".into()),
        );

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
    }

    #[test]
    fn test_huggingface_request_uses_default_model_when_override_unset() {
        // Without a per-request model, the model passed alongside the request is used.
        let request = request_with(None, OneOrMany::one("Hello".into()));

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "mistralai/Mistral-7B");
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse_json(assistant_message_json);
        let assistant_message2: Message = parse_json(assistant_message_json2);
        let assistant_message3: Message = parse_json(assistant_message_json3);
        let user_message: Message = parse_json(user_message_json);

        // A plain-string `content` becomes a single text part.
        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content,
            vec![AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }]
        );

        // An array `content` with `"tool_calls": null` yields no tool calls.
        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message2
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content,
            vec![AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }]
        );
        assert!(tool_calls.is_empty(), "expected no tool calls, got {tool_calls:?}");

        // `"content": null` with tool calls yields empty content and the parsed call.
        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message3
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert_eq!(
            tool_calls,
            vec![ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }]
        );

        // Mixed text + image user content keeps both parts, in order, and nothing else.
        let Message::User { content, .. } = user_message else {
            panic!("Expected user message");
        };
        let expected_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg";
        assert_eq!(
            content.into_iter().collect::<Vec<_>>(),
            vec![
                UserContent::Text {
                    text: "What's in this image?".to_string()
                },
                UserContent::ImageUrl {
                    image_url: ImageUrl {
                        url: expected_url.to_string()
                    }
                },
            ]
        );
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message
            .clone()
            .try_into()
            .expect("user message should convert");
        let converted_assistant_message: Vec<Message> = assistant_message
            .clone()
            .try_into()
            .expect("assistant message should convert");

        let Message::User { content, .. } = converted_user_message[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant_message[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Converting back must round-trip to the original generic messages.
        let original_user_message: message::Message = converted_user_message[0]
            .clone()
            .try_into()
            .expect("user message should convert back");
        let original_assistant_message: message::Message = converted_assistant_message[0]
            .clone()
            .try_into()
            .expect("assistant message should convert back");

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message
            .clone()
            .try_into()
            .expect("user message should convert");
        let converted_assistant_message: message::Message = assistant_message
            .clone()
            .try_into()
            .expect("assistant message should convert");

        let message::Message::User { content, .. } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content, .. } = converted_assistant_message.clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Converting back must round-trip to the original provider messages.
        let original_user_message: Vec<Message> = converted_user_message
            .try_into()
            .expect("user message should convert back");
        let original_assistant_message: Vec<Message> = converted_assistant_message
            .try_into()
            .expect("assistant message should convert back");

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        // Responses captured from two providers: stringified tool arguments, a per-call
        // `index`, and extra nullable fields (`logprobs`, `stop_reason`, ...) must all parse.
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let _fireworks_response: CompletionResponse = parse_json(fireworks_response_json);
        let _novita_response: CompletionResponse = parse_json(novita_response_json);
    }

    #[test]
    fn test_assistant_reasoning_is_silently_skipped() {
        // An assistant turn containing only reasoning converts to no provider messages.
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }

    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = &converted[0]
        else {
            panic!("expected assistant message");
        };
        // Reasoning is dropped; the text part survives on its own.
        assert_eq!(
            *content,
            vec![AssistantContent::Text {
                text: "visible".to_string()
            }]
        );
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_1");
        assert_eq!(tool_calls[0].function.name, "subtract");
        assert_eq!(
            tool_calls[0].function.arguments,
            serde_json::json!({"x": 2, "y": 1})
        );
    }

    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        // If every history message is filtered out (reasoning only), conversion must fail
        // with a request error rather than send an empty conversation.
        let request = request_with(
            None,
            OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
        );

        let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}