rig/providers/huggingface/
completion.rs

1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::providers::openai::StreamingCompletionResponse;
4use crate::telemetry::SpanCombinator;
5use crate::{
6    OneOrMany,
7    completion::{self, CompletionError, CompletionRequest},
8    json_utils,
9    message::{self},
10    one_or_many::string_or_one_or_many,
11};
12use serde::{Deserialize, Deserializer, Serialize};
13use serde_json::{Value, json};
14use std::{convert::Infallible, str::FromStr};
15use tracing::info_span;
16
/// Envelope for a provider reply that is either the expected payload or an
/// arbitrary JSON error body.
///
/// The enum is `#[serde(untagged)]`, so deserialization first attempts `Ok(T)`
/// and falls back to `Err`, which accepts any JSON value.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The body deserialized successfully as `T`.
    Ok(T),
    /// The body did not match `T`; the raw JSON is kept for error reporting.
    Err(Value),
}
23
24// ================================================================
25// Huggingface Completion API
26// ================================================================
27
28// Conversational LLMs
29
/// Model identifier for the `google/gemma-2-2b-it` completion model.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model identifier for the `meta-llama/Meta-Llama-3.1-8B-Instruct` completion model.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model identifier for the `microsoft/phi-4` completion model.
pub const PHI_4: &str = "microsoft/phi-4";
/// Model identifier for the `PowerInfer/SmallThinker-3B-Preview` completion model.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model identifier for the `Qwen/Qwen2.5-7B-Instruct` completion model.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model identifier for the `Qwen/Qwen2.5-Coder-32B-Instruct` completion model.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs

/// Model identifier for the `Qwen/Qwen2-VL-7B-Instruct` visual-language completion model.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model identifier for the `Qwen/QVQ-72B-Preview` visual-language completion model.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
49
/// The function part of a tool call: which function to invoke and with what arguments.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function being called.
    name: String,
    /// Call arguments as JSON. On deserialization a JSON string is parsed as
    /// embedded JSON, while any other value is taken as-is (see
    /// `deserialize_arguments`).
    #[serde(deserialize_with = "deserialize_arguments")]
    pub arguments: serde_json::Value,
}
56
57fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
58where
59    D: Deserializer<'de>,
60{
61    let value = Value::deserialize(deserializer)?;
62
63    match value {
64        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
65        other => Ok(other),
66    }
67}
68
69impl From<Function> for message::ToolFunction {
70    fn from(value: Function) -> Self {
71        message::ToolFunction {
72            name: value.name,
73            arguments: value.arguments,
74        }
75    }
76}
77
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A callable function (serialized as `"function"`). This is the default.
    #[default]
    Function,
}
84
/// Wire representation of a tool offered to the model: a type string plus the
/// generic tool definition.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool type string; the `From<completion::ToolDefinition>` impl sets it to `"function"`.
    pub r#type: String,
    /// The wrapped generic tool definition.
    pub function: completion::ToolDefinition,
}
90
91impl From<completion::ToolDefinition> for ToolDefinition {
92    fn from(tool: completion::ToolDefinition) -> Self {
93        Self {
94            r#type: "function".into(),
95            function: tool,
96        }
97    }
98}
99
/// A tool invocation carried in the `tool_calls` list of an assistant message.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Kind of tool being called.
    pub r#type: ToolType,
    /// The function name and arguments of the call.
    pub function: Function,
}
106
107impl From<ToolCall> for message::ToolCall {
108    fn from(value: ToolCall) -> Self {
109        message::ToolCall {
110            id: value.id,
111            call_id: None,
112            function: value.function.into(),
113        }
114    }
115}
116
117impl From<message::ToolCall> for ToolCall {
118    fn from(value: message::ToolCall) -> Self {
119        ToolCall {
120            id: value.id,
121            r#type: ToolType::Function,
122            function: Function {
123                name: value.function.name,
124                arguments: value.function.arguments,
125            },
126        }
127    }
128}
129
/// Wrapper object holding the URL of an image content part.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// The image URL string.
    url: String,
}
134
/// One content part of a user message, discriminated by a `type` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text (`"type": "text"`).
    Text {
        text: String,
    },
    /// An image referenced by URL (`"type": "image_url"`).
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
146
147impl FromStr for UserContent {
148    type Err = Infallible;
149
150    fn from_str(s: &str) -> Result<Self, Self::Err> {
151        Ok(UserContent::Text {
152            text: s.to_string(),
153        })
154    }
155}
156
/// One content part of an assistant message, discriminated by a `type` field.
/// Only text parts are modelled.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Plain text (`"type": "text"`).
    Text { text: String },
}
162
163impl FromStr for AssistantContent {
164    type Err = Infallible;
165
166    fn from_str(s: &str) -> Result<Self, Self::Err> {
167        Ok(AssistantContent::Text {
168            text: s.to_string(),
169        })
170    }
171}
172
/// One content part of a system message, discriminated by a `type` field.
/// Only text parts are modelled.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    /// Plain text (`"type": "text"`).
    Text { text: String },
}
178
179impl FromStr for SystemContent {
180    type Err = Infallible;
181
182    fn from_str(s: &str) -> Result<Self, Self::Err> {
183        Ok(SystemContent::Text {
184            text: s.to_string(),
185        })
186    }
187}
188
189impl From<UserContent> for message::UserContent {
190    fn from(value: UserContent) -> Self {
191        match value {
192            UserContent::Text { text } => message::UserContent::text(text),
193            UserContent::ImageUrl { image_url } => {
194                message::UserContent::image_url(image_url.url, None, None)
195            }
196        }
197    }
198}
199
200impl TryFrom<message::UserContent> for UserContent {
201    type Error = message::MessageError;
202
203    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
204        match content {
205            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
206            message::UserContent::Document(message::Document {
207                data: message::DocumentSourceKind::Raw(raw),
208                ..
209            }) => {
210                let text = String::from_utf8_lossy(raw.as_slice()).into();
211                Ok(UserContent::Text { text })
212            }
213            message::UserContent::Document(message::Document {
214                data:
215                    message::DocumentSourceKind::Base64(text)
216                    | message::DocumentSourceKind::String(text),
217                ..
218            }) => Ok(UserContent::Text { text }),
219            message::UserContent::Image(message::Image { data, .. }) => match data {
220                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
221                    image_url: ImageUrl { url },
222                }),
223                _ => Err(message::MessageError::ConversionError(
224                    "Huggingface only supports images as urls".into(),
225                )),
226            },
227            _ => Err(message::MessageError::ConversionError(
228                "Huggingface only supports text and images".into(),
229            )),
230        }
231    }
232}
233
/// A chat message in the provider wire format, discriminated by its `role` field.
///
/// Roles are serialized in lowercase, except for [`Message::ToolResult`], which
/// is explicitly renamed to `Tool`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System prompt; `content` accepts either a bare string or a list of parts.
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    /// User input; `content` accepts either a bare string or a list of parts.
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    /// Model output; both fields fall back to their defaults when absent.
    Assistant {
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool invocation.
    // NOTE(review): the role is serialized as capitalised `Tool`, unlike the other
    // variants — confirm this is what the target API expects.
    #[serde(rename = "Tool")]
    ToolResult {
        name: String,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<String>,
    },
}
260
261impl Message {
262    pub fn system(content: &str) -> Self {
263        Message::System {
264            content: OneOrMany::one(SystemContent::Text {
265                text: content.to_string(),
266            }),
267        }
268    }
269}
270
271impl TryFrom<message::Message> for Vec<Message> {
272    type Error = message::MessageError;
273
274    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
275        match message {
276            message::Message::User { content } => {
277                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
278                    .into_iter()
279                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
280
281                if !tool_results.is_empty() {
282                    tool_results
283                        .into_iter()
284                        .map(|content| match content {
285                            message::UserContent::ToolResult(message::ToolResult {
286                                id,
287                                content,
288                                ..
289                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
290                                name: id,
291                                arguments: None,
292                                content: content.try_map(|content| match content {
293                                    message::ToolResultContent::Text(message::Text { text }) => {
294                                        Ok(text)
295                                    }
296                                    _ => Err(message::MessageError::ConversionError(
297                                        "Tool result content does not support non-text".into(),
298                                    )),
299                                })?,
300                            }),
301                            _ => unreachable!(),
302                        })
303                        .collect::<Result<Vec<_>, _>>()
304                } else {
305                    let other_content = OneOrMany::many(other_content).expect(
306                        "There must be other content here if there were no tool result content",
307                    );
308
309                    Ok(vec![Message::User {
310                        content: other_content.try_map(|content| match content {
311                            message::UserContent::Text(text) => {
312                                Ok(UserContent::Text { text: text.text })
313                            }
314                            message::UserContent::Image(image) => {
315                                let url = image.try_into_url()?;
316
317                                Ok(UserContent::ImageUrl {
318                                    image_url: ImageUrl { url },
319                                })
320                            }
321                            message::UserContent::Document(message::Document {
322                                data: message::DocumentSourceKind::Raw(raw), ..
323                            }) => {
324                                let text = String::from_utf8_lossy(raw.as_slice()).into();
325                                Ok(UserContent::Text { text })
326                            }
327                            message::UserContent::Document(message::Document {
328                                data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
329                            }) => {
330                                Ok(UserContent::Text { text })
331                            }
332                            _ => Err(message::MessageError::ConversionError(
333                                "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
334                            )),
335                        })?,
336                    }])
337                }
338            }
339            message::Message::Assistant { content, .. } => {
340                let (text_content, tool_calls) = content.into_iter().fold(
341                    (Vec::new(), Vec::new()),
342                    |(mut texts, mut tools), content| {
343                        match content {
344                            message::AssistantContent::Text(text) => texts.push(text),
345                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
346                            message::AssistantContent::Reasoning(_) => {
347                                unimplemented!("Reasoning is not supported on HuggingFace via Rig");
348                            }
349                        }
350                        (texts, tools)
351                    },
352                );
353
354                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
355                //  so either `content` or `tool_calls` will have some content.
356                Ok(vec![Message::Assistant {
357                    content: text_content
358                        .into_iter()
359                        .map(|content| AssistantContent::Text { text: content.text })
360                        .collect::<Vec<_>>(),
361                    tool_calls: tool_calls
362                        .into_iter()
363                        .map(|tool_call| tool_call.into())
364                        .collect::<Vec<_>>(),
365                }])
366            }
367        }
368    }
369}
370
371impl TryFrom<Message> for message::Message {
372    type Error = message::MessageError;
373
374    fn try_from(message: Message) -> Result<Self, Self::Error> {
375        Ok(match message {
376            Message::User { content, .. } => message::Message::User {
377                content: content.map(|content| content.into()),
378            },
379            Message::Assistant {
380                content,
381                tool_calls,
382                ..
383            } => {
384                let mut content = content
385                    .into_iter()
386                    .map(|content| match content {
387                        AssistantContent::Text { text } => message::AssistantContent::text(text),
388                    })
389                    .collect::<Vec<_>>();
390
391                content.extend(
392                    tool_calls
393                        .into_iter()
394                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
395                        .collect::<Result<Vec<_>, _>>()?,
396                );
397
398                message::Message::Assistant {
399                    id: None,
400                    content: OneOrMany::many(content).map_err(|_| {
401                        message::MessageError::ConversionError(
402                            "Neither `content` nor `tool_calls` was provided to the Message"
403                                .to_owned(),
404                        )
405                    })?,
406                }
407            }
408
409            Message::ToolResult { name, content, .. } => message::Message::User {
410                content: OneOrMany::one(message::UserContent::tool_result(
411                    name,
412                    content.map(message::ToolResultContent::text),
413                )),
414            },
415
416            // System messages should get stripped out when converting message's, this is just a
417            // stop gap to avoid obnoxious error handling or panic occurring.
418            Message::System { content, .. } => message::Message::User {
419                content: content.map(|c| match c {
420                    SystemContent::Text { text } => message::UserContent::text(text),
421                }),
422            },
423        })
424    }
425}
426
/// One entry of the `choices` list in a completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Finish reason string as reported in the response.
    pub finish_reason: String,
    /// Index of this choice as reported in the response.
    pub index: usize,
    /// Raw log-probability payload; defaults to JSON `null` when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    /// The message produced for this choice.
    pub message: Message,
}
435
/// Token accounting reported with a completion response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Reported as output tokens by the `GetTokenUsage` impl.
    pub completion_tokens: i32,
    /// Reported as input tokens by the `GetTokenUsage` impl.
    pub prompt_tokens: i32,
    /// Total token count as reported in the response.
    pub total_tokens: i32,
}
442
443impl GetTokenUsage for Usage {
444    fn token_usage(&self) -> Option<crate::completion::Usage> {
445        let mut usage = crate::completion::Usage::new();
446        usage.input_tokens = self.prompt_tokens as u64;
447        usage.output_tokens = self.completion_tokens as u64;
448        usage.total_tokens = self.total_tokens as u64;
449
450        Some(usage)
451    }
452}
453
/// Body of a successful chat completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// `created` value as reported in the response.
    pub created: i32,
    /// Response identifier.
    pub id: String,
    /// Name of the model as reported in the response.
    pub model: String,
    /// Generated choices; the conversion to the generic response uses only the first one.
    pub choices: Vec<Choice>,
    /// System fingerprint; an absent or `null` field becomes the empty string.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    /// Token accounting for the request.
    pub usage: Usage,
}
464
465impl crate::telemetry::ProviderResponseExt for CompletionResponse {
466    type OutputMessage = Choice;
467    type Usage = Usage;
468
469    fn get_response_id(&self) -> Option<String> {
470        Some(self.id.clone())
471    }
472
473    fn get_response_model_name(&self) -> Option<String> {
474        Some(self.model.clone())
475    }
476
477    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
478        self.choices.clone()
479    }
480
481    fn get_text_response(&self) -> Option<String> {
482        let text_response = self
483            .choices
484            .iter()
485            .filter_map(|x| {
486                let Message::User { ref content } = x.message else {
487                    return None;
488                };
489
490                let text = content
491                    .iter()
492                    .filter_map(|x| {
493                        if let UserContent::Text { text } = x {
494                            Some(text.clone())
495                        } else {
496                            None
497                        }
498                    })
499                    .collect::<Vec<String>>();
500
501                if text.is_empty() {
502                    None
503                } else {
504                    Some(text.join("\n"))
505                }
506            })
507            .collect::<Vec<String>>()
508            .join("\n");
509
510        if text_response.is_empty() {
511            None
512        } else {
513            Some(text_response)
514        }
515    }
516
517    fn get_usage(&self) -> Option<Self::Usage> {
518        Some(self.usage.clone())
519    }
520}
521
522fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
523where
524    D: Deserializer<'de>,
525{
526    match Option::<String>::deserialize(deserializer)? {
527        Some(value) => Ok(value),      // Use provided value
528        None => Ok(String::default()), // Use `Default` implementation
529    }
530}
531
532impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
533    type Error = CompletionError;
534
535    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
536        let choice = response.choices.first().ok_or_else(|| {
537            CompletionError::ResponseError("Response contained no choices".to_owned())
538        })?;
539
540        let content = match &choice.message {
541            Message::Assistant {
542                content,
543                tool_calls,
544                ..
545            } => {
546                let mut content = content
547                    .iter()
548                    .map(|c| match c {
549                        AssistantContent::Text { text } => message::AssistantContent::text(text),
550                    })
551                    .collect::<Vec<_>>();
552
553                content.extend(
554                    tool_calls
555                        .iter()
556                        .map(|call| {
557                            completion::AssistantContent::tool_call(
558                                &call.id,
559                                &call.function.name,
560                                call.function.arguments.clone(),
561                            )
562                        })
563                        .collect::<Vec<_>>(),
564                );
565                Ok(content)
566            }
567            _ => Err(CompletionError::ResponseError(
568                "Response did not contain a valid message or tool call".into(),
569            )),
570        }?;
571
572        let choice = OneOrMany::many(content).map_err(|_| {
573            CompletionError::ResponseError(
574                "Response contained no message or tool call (empty)".to_owned(),
575            )
576        })?;
577
578        let usage = completion::Usage {
579            input_tokens: response.usage.prompt_tokens as u64,
580            output_tokens: response.usage.completion_tokens as u64,
581            total_tokens: response.usage.total_tokens as u64,
582        };
583
584        Ok(completion::CompletionResponse {
585            choice,
586            usage,
587            raw_response: response,
588        })
589    }
590}
591
/// A completion model bound to a client; `T` is the HTTP client type and
/// defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Name of the model (e.g: google/gemma-2-2b-it)
    pub model: String,
}
598
599impl<T> CompletionModel<T> {
600    pub fn new(client: Client<T>, model: &str) -> Self {
601        Self {
602            client,
603            model: model.to_string(),
604        }
605    }
606
607    pub(crate) fn create_request_body(
608        &self,
609        completion_request: &CompletionRequest,
610    ) -> Result<serde_json::Value, CompletionError> {
611        let mut full_history: Vec<Message> = match &completion_request.preamble {
612            Some(preamble) => vec![Message::system(preamble)],
613            None => vec![],
614        };
615        if let Some(docs) = completion_request.normalized_documents() {
616            let docs: Vec<Message> = docs.try_into()?;
617            full_history.extend(docs);
618        }
619
620        let chat_history: Vec<Message> = completion_request
621            .chat_history
622            .clone()
623            .into_iter()
624            .map(|message| message.try_into())
625            .collect::<Result<Vec<Vec<Message>>, _>>()?
626            .into_iter()
627            .flatten()
628            .collect();
629
630        full_history.extend(chat_history);
631
632        let model = self.client.sub_provider.model_identifier(&self.model);
633
634        let tool_choice = completion_request
635            .tool_choice
636            .clone()
637            .map(crate::providers::openai::completion::ToolChoice::try_from)
638            .transpose()?;
639
640        let request = if completion_request.tools.is_empty() {
641            json!({
642                "model": model,
643                "messages": full_history,
644                "temperature": completion_request.temperature,
645            })
646        } else {
647            json!({
648                "model": model,
649                "messages": full_history,
650                "temperature": completion_request.temperature,
651                "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
652                "tool_choice": tool_choice,
653            })
654        };
655        Ok(request)
656    }
657}
658
659impl completion::CompletionModel for CompletionModel<reqwest::Client> {
660    type Response = CompletionResponse;
661    type StreamingResponse = StreamingCompletionResponse;
662
663    #[cfg_attr(feature = "worker", worker::send)]
664    async fn completion(
665        &self,
666        completion_request: CompletionRequest,
667    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
668        let span = if tracing::Span::current().is_disabled() {
669            info_span!(
670                target: "rig::completions",
671                "chat",
672                gen_ai.operation.name = "chat",
673                gen_ai.provider.name = "huggingface",
674                gen_ai.request.model = self.model,
675                gen_ai.system_instructions = &completion_request.preamble,
676                gen_ai.response.id = tracing::field::Empty,
677                gen_ai.response.model = tracing::field::Empty,
678                gen_ai.usage.output_tokens = tracing::field::Empty,
679                gen_ai.usage.input_tokens = tracing::field::Empty,
680                gen_ai.input.messages = tracing::field::Empty,
681                gen_ai.output.messages = tracing::field::Empty,
682            )
683        } else {
684            tracing::Span::current()
685        };
686        let request = self.create_request_body(&completion_request)?;
687        span.record_model_input(&request.get("messages"));
688
689        let path = self.client.sub_provider.completion_endpoint(&self.model);
690
691        let request = if let Some(ref params) = completion_request.additional_params {
692            json_utils::merge(request, params.clone())
693        } else {
694            request
695        };
696
697        let request = serde_json::to_vec(&request)?;
698
699        let request = self
700            .client
701            .post(&path)?
702            .header("Content-Type", "application/json")
703            .body(request)
704            .map_err(|e| CompletionError::HttpError(e.into()))?;
705
706        let response = self.client.send(request).await?;
707
708        if response.status().is_success() {
709            let bytes: Vec<u8> = response.into_body().await?;
710            let text = String::from_utf8_lossy(&bytes);
711
712            tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
713
714            match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
715                ApiResponse::Ok(response) => {
716                    let span = tracing::Span::current();
717                    span.record_token_usage(&response.usage);
718                    span.record_model_output(&response.choices);
719                    span.record_response_metadata(&response);
720
721                    response.try_into()
722                }
723                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
724            }
725        } else {
726            let status = response.status();
727            let text: Vec<u8> = response.into_body().await?;
728            let text: String = String::from_utf8_lossy(&text).into();
729
730            Err(CompletionError::ProviderError(format!(
731                "{}: {}",
732                status, text
733            )))
734        }
735    }
736
737    #[cfg_attr(feature = "worker", worker::send)]
738    async fn stream(
739        &self,
740        request: CompletionRequest,
741    ) -> Result<
742        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
743        CompletionError,
744    > {
745        CompletionModel::stream(self, request).await
746    }
747}
748
749#[cfg(test)]
750mod tests {
751    use super::*;
752    use serde_path_to_error::deserialize;
753
754    #[test]
755    fn test_deserialize_message() {
756        let assistant_message_json = r#"
757        {
758            "role": "assistant",
759            "content": "\n\nHello there, how may I assist you today?"
760        }
761        "#;
762
763        let assistant_message_json2 = r#"
764        {
765            "role": "assistant",
766            "content": [
767                {
768                    "type": "text",
769                    "text": "\n\nHello there, how may I assist you today?"
770                }
771            ],
772            "tool_calls": null
773        }
774        "#;
775
776        let assistant_message_json3 = r#"
777        {
778            "role": "assistant",
779            "tool_calls": [
780                {
781                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
782                    "type": "function",
783                    "function": {
784                        "name": "subtract",
785                        "arguments": {"x": 2, "y": 5}
786                    }
787                }
788            ],
789            "content": null,
790            "refusal": null
791        }
792        "#;
793
794        let user_message_json = r#"
795        {
796            "role": "user",
797            "content": [
798                {
799                    "type": "text",
800                    "text": "What's in this image?"
801                },
802                {
803                    "type": "image_url",
804                    "image_url": {
805                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
806                    }
807                }
808            ]
809        }
810        "#;
811
812        let assistant_message: Message = {
813            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
814            deserialize(jd).unwrap_or_else(|err| {
815                panic!(
816                    "Deserialization error at {} ({}:{}): {}",
817                    err.path(),
818                    err.inner().line(),
819                    err.inner().column(),
820                    err
821                );
822            })
823        };
824
825        let assistant_message2: Message = {
826            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
827            deserialize(jd).unwrap_or_else(|err| {
828                panic!(
829                    "Deserialization error at {} ({}:{}): {}",
830                    err.path(),
831                    err.inner().line(),
832                    err.inner().column(),
833                    err
834                );
835            })
836        };
837
838        let assistant_message3: Message = {
839            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
840                &mut serde_json::Deserializer::from_str(assistant_message_json3);
841            deserialize(jd).unwrap_or_else(|err| {
842                panic!(
843                    "Deserialization error at {} ({}:{}): {}",
844                    err.path(),
845                    err.inner().line(),
846                    err.inner().column(),
847                    err
848                );
849            })
850        };
851
852        let user_message: Message = {
853            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
854            deserialize(jd).unwrap_or_else(|err| {
855                panic!(
856                    "Deserialization error at {} ({}:{}): {}",
857                    err.path(),
858                    err.inner().line(),
859                    err.inner().column(),
860                    err
861                );
862            })
863        };
864
865        match assistant_message {
866            Message::Assistant { content, .. } => {
867                assert_eq!(
868                    content[0],
869                    AssistantContent::Text {
870                        text: "\n\nHello there, how may I assist you today?".to_string()
871                    }
872                );
873            }
874            _ => panic!("Expected assistant message"),
875        }
876
877        match assistant_message2 {
878            Message::Assistant {
879                content,
880                tool_calls,
881                ..
882            } => {
883                assert_eq!(
884                    content[0],
885                    AssistantContent::Text {
886                        text: "\n\nHello there, how may I assist you today?".to_string()
887                    }
888                );
889
890                assert_eq!(tool_calls, vec![]);
891            }
892            _ => panic!("Expected assistant message"),
893        }
894
895        match assistant_message3 {
896            Message::Assistant {
897                content,
898                tool_calls,
899                ..
900            } => {
901                assert!(content.is_empty());
902                assert_eq!(
903                    tool_calls[0],
904                    ToolCall {
905                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
906                        r#type: ToolType::Function,
907                        function: Function {
908                            name: "subtract".to_string(),
909                            arguments: serde_json::json!({"x": 2, "y": 5}),
910                        },
911                    }
912                );
913            }
914            _ => panic!("Expected assistant message"),
915        }
916
917        match user_message {
918            Message::User { content, .. } => {
919                let (first, second) = {
920                    let mut iter = content.into_iter();
921                    (iter.next().unwrap(), iter.next().unwrap())
922                };
923                assert_eq!(
924                    first,
925                    UserContent::Text {
926                        text: "What's in this image?".to_string()
927                    }
928                );
929                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
930            }
931            _ => panic!("Expected user message"),
932        }
933    }
934
935    #[test]
936    fn test_message_to_message_conversion() {
937        let user_message = message::Message::User {
938            content: OneOrMany::one(message::UserContent::text("Hello")),
939        };
940
941        let assistant_message = message::Message::Assistant {
942            id: None,
943            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
944        };
945
946        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
947        let converted_assistant_message: Vec<Message> =
948            assistant_message.clone().try_into().unwrap();
949
950        match converted_user_message[0].clone() {
951            Message::User { content, .. } => {
952                assert_eq!(
953                    content.first(),
954                    UserContent::Text {
955                        text: "Hello".to_string()
956                    }
957                );
958            }
959            _ => panic!("Expected user message"),
960        }
961
962        match converted_assistant_message[0].clone() {
963            Message::Assistant { content, .. } => {
964                assert_eq!(
965                    content[0],
966                    AssistantContent::Text {
967                        text: "Hi there!".to_string()
968                    }
969                );
970            }
971            _ => panic!("Expected assistant message"),
972        }
973
974        let original_user_message: message::Message =
975            converted_user_message[0].clone().try_into().unwrap();
976        let original_assistant_message: message::Message =
977            converted_assistant_message[0].clone().try_into().unwrap();
978
979        assert_eq!(original_user_message, user_message);
980        assert_eq!(original_assistant_message, assistant_message);
981    }
982
983    #[test]
984    fn test_message_from_message_conversion() {
985        let user_message = Message::User {
986            content: OneOrMany::one(UserContent::Text {
987                text: "Hello".to_string(),
988            }),
989        };
990
991        let assistant_message = Message::Assistant {
992            content: vec![AssistantContent::Text {
993                text: "Hi there!".to_string(),
994            }],
995            tool_calls: vec![],
996        };
997
998        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
999        let converted_assistant_message: message::Message =
1000            assistant_message.clone().try_into().unwrap();
1001
1002        match converted_user_message.clone() {
1003            message::Message::User { content } => {
1004                assert_eq!(content.first(), message::UserContent::text("Hello"));
1005            }
1006            _ => panic!("Expected user message"),
1007        }
1008
1009        match converted_assistant_message.clone() {
1010            message::Message::Assistant { content, .. } => {
1011                assert_eq!(
1012                    content.first(),
1013                    message::AssistantContent::text("Hi there!")
1014                );
1015            }
1016            _ => panic!("Expected assistant message"),
1017        }
1018
1019        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1020        let original_assistant_message: Vec<Message> =
1021            converted_assistant_message.try_into().unwrap();
1022
1023        assert_eq!(original_user_message[0], user_message);
1024        assert_eq!(original_assistant_message[0], assistant_message);
1025    }
1026
1027    #[test]
1028    fn test_responses() {
1029        let fireworks_response_json = r#"
1030        {
1031            "choices": [
1032                {
1033                    "finish_reason": "tool_calls",
1034                    "index": 0,
1035                    "message": {
1036                        "role": "assistant",
1037                        "tool_calls": [
1038                            {
1039                                "function": {
1040                                "arguments": "{\"x\": 2, \"y\": 5}",
1041                                "name": "subtract"
1042                                },
1043                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
1044                                "index": 0,
1045                                "type": "function"
1046                            }
1047                        ]
1048                    }
1049                }
1050            ],
1051            "created": 1740704000,
1052            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
1053            "model": "accounts/fireworks/models/deepseek-v3",
1054            "object": "chat.completion",
1055            "usage": {
1056                "completion_tokens": 26,
1057                "prompt_tokens": 248,
1058                "total_tokens": 274
1059            }
1060        }
1061        "#;
1062
1063        let novita_response_json = r#"
1064        {
1065            "choices": [
1066                {
1067                    "finish_reason": "tool_calls",
1068                    "index": 0,
1069                    "logprobs": null,
1070                    "message": {
1071                        "audio": null,
1072                        "content": null,
1073                        "function_call": null,
1074                        "reasoning_content": null,
1075                        "refusal": null,
1076                        "role": "assistant",
1077                        "tool_calls": [
1078                            {
1079                                "function": {
1080                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
1081                                    "name": "subtract"
1082                                },
1083                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
1084                                "type": "function"
1085                            }
1086                        ]
1087                    },
1088                    "stop_reason": 128008
1089                }
1090            ],
1091            "created": 1740704592,
1092            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
1093            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
1094            "object": "chat.completion",
1095            "prompt_logprobs": null,
1096            "service_tier": null,
1097            "system_fingerprint": null,
1098            "usage": {
1099                "completion_tokens": 28,
1100                "completion_tokens_details": null,
1101                "prompt_tokens": 335,
1102                "prompt_tokens_details": null,
1103                "total_tokens": 363
1104            }
1105        }
1106        "#;
1107
1108        let _firework_response: CompletionResponse = {
1109            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
1110            deserialize(jd).unwrap_or_else(|err| {
1111                panic!(
1112                    "Deserialization error at {} ({}:{}): {}",
1113                    err.path(),
1114                    err.inner().line(),
1115                    err.inner().column(),
1116                    err
1117                );
1118            })
1119        };
1120
1121        let _novita_response: CompletionResponse = {
1122            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
1123            deserialize(jd).unwrap_or_else(|err| {
1124                panic!(
1125                    "Deserialization error at {} ({}:{}): {}",
1126                    err.path(),
1127                    err.inner().line(),
1128                    err.inner().column(),
1129                    err
1130                );
1131            })
1132        };
1133    }
1134}