rig/providers/huggingface/
completion.rs

1use serde::{Deserialize, Deserializer, Serialize};
2use serde_json::{json, Value};
3use std::{convert::Infallible, str::FromStr};
4
5use super::client::Client;
6use crate::providers::openai::StreamingCompletionResponse;
7use crate::{
8    completion::{self, CompletionError, CompletionRequest},
9    json_utils,
10    message::{self},
11    one_or_many::string_or_one_or_many,
12    OneOrMany,
13};
14
/// Envelope for a raw Hugging Face API reply: either the expected payload or anything else.
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration order. The
/// body is first parsed as `T`. Only if that fails is it captured verbatim as a JSON `Value`
/// in `Err`. `Ok` must stay declared before `Err`, because `Value` accepts any JSON and would
/// otherwise match every body.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The body deserialized successfully as `T`.
    Ok(T),
    /// The body did not match `T`; holds the raw JSON (presumably an error object — not validated).
    Err(Value),
}
21
22// ================================================================
23// Huggingface Completion API
24// ================================================================
25
// Conversational LLMs
//
// Hub repository ids (`owner/name`) that can be used as the model name of a `CompletionModel`.
// The client's sub-provider turns the name into the request's model identifier and endpoint.

/// `google/gemma-2-2b-it` completion model
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// `meta-llama/Meta-Llama-3.1-8B-Instruct` completion model
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `microsoft/phi-4` completion model
pub const PHI_4: &str = "microsoft/phi-4";
/// `PowerInfer/SmallThinker-3B-Preview` completion model
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// `Qwen/Qwen2.5-7B-Instruct` completion model
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` completion model
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs (accept image content parts in user messages)

/// `Qwen/Qwen2-VL-7B-Instruct` visual-language completion model
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// `Qwen/QVQ-72B-Preview` visual-language completion model
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
47
48#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
49pub struct Function {
50    name: String,
51    #[serde(deserialize_with = "deserialize_arguments")]
52    pub arguments: serde_json::Value,
53}
54
55fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
56where
57    D: Deserializer<'de>,
58{
59    let value = Value::deserialize(deserializer)?;
60
61    match value {
62        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
63        other => Ok(other),
64    }
65}
66
67impl From<Function> for message::ToolFunction {
68    fn from(value: Function) -> Self {
69        message::ToolFunction {
70            name: value.name,
71            arguments: value.arguments,
72        }
73    }
74}
75
/// The `type` discriminator of a tool call. `function` is the only kind modeled here.
///
/// Serialized in lowercase (`"function"`); `Function` is also the `Default`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
82
/// One entry of the request's `tools` array: `{"type": ..., "function": {...}}`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Serialized as `"type"`. A free-form string; the `From<completion::ToolDefinition>`
    /// impl sets it to `"function"`.
    pub r#type: String,
    /// rig's tool definition, serialized as the `function` object.
    pub function: completion::ToolDefinition,
}
88
89impl From<completion::ToolDefinition> for ToolDefinition {
90    fn from(tool: completion::ToolDefinition) -> Self {
91        Self {
92            r#type: "function".into(),
93            function: tool,
94        }
95    }
96}
97
/// A tool call as it appears in an assistant message's `tool_calls` array.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of the call, as carried in the message JSON.
    pub id: String,
    /// Serialized as `"type"`; `ToolType` only has the `function` variant.
    pub r#type: ToolType,
    /// The function name and arguments.
    pub function: Function,
}
104
105impl From<ToolCall> for message::ToolCall {
106    fn from(value: ToolCall) -> Self {
107        message::ToolCall {
108            id: value.id,
109            function: value.function.into(),
110        }
111    }
112}
113
114impl From<message::ToolCall> for ToolCall {
115    fn from(value: message::ToolCall) -> Self {
116        ToolCall {
117            id: value.id,
118            r#type: ToolType::Function,
119            function: Function {
120                name: value.function.name,
121                arguments: value.function.arguments,
122            },
123        }
124    }
125}
126
127#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
128pub struct ImageUrl {
129    url: String,
130}
131
/// A single content part of a user message, tagged by its `type` field.
///
/// Serializes as `{"type": "text", "text": ...}` or
/// `{"type": "image_url", "image_url": {"url": ...}}`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    // Explicit rename: `rename_all = "lowercase"` alone would produce `imageurl`.
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
143
144impl FromStr for UserContent {
145    type Err = Infallible;
146
147    fn from_str(s: &str) -> Result<Self, Self::Err> {
148        Ok(UserContent::Text {
149            text: s.to_string(),
150        })
151    }
152}
153
/// A content part of an assistant message, tagged by `type`; only text is modeled.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
}
159
160impl FromStr for AssistantContent {
161    type Err = Infallible;
162
163    fn from_str(s: &str) -> Result<Self, Self::Err> {
164        Ok(AssistantContent::Text {
165            text: s.to_string(),
166        })
167    }
168}
169
/// A content part of a system message, tagged by `type`; only text is modeled.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    Text { text: String },
}
175
176impl FromStr for SystemContent {
177    type Err = Infallible;
178
179    fn from_str(s: &str) -> Result<Self, Self::Err> {
180        Ok(SystemContent::Text {
181            text: s.to_string(),
182        })
183    }
184}
185
186impl From<UserContent> for message::UserContent {
187    fn from(value: UserContent) -> Self {
188        match value {
189            UserContent::Text { text } => message::UserContent::text(text),
190            UserContent::ImageUrl { image_url } => message::UserContent::image(
191                image_url.url,
192                Some(message::ContentFormat::String),
193                None,
194                None,
195            ),
196        }
197    }
198}
199
200impl TryFrom<message::UserContent> for UserContent {
201    type Error = message::MessageError;
202
203    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
204        match content {
205            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
206            message::UserContent::Image(message::Image { data, format, .. }) => match format {
207                Some(message::ContentFormat::String) => Ok(UserContent::ImageUrl {
208                    image_url: ImageUrl { url: data },
209                }),
210                _ => Err(message::MessageError::ConversionError(
211                    "Huggingface only supports images as urls".into(),
212                )),
213            },
214            _ => Err(message::MessageError::ConversionError(
215                "Huggingface only supports text and images".into(),
216            )),
217        }
218    }
219}
220
/// A chat message in the provider's request/response format, tagged by its `role` field.
///
/// `System`, `User` and `Assistant` serialize with lowercase roles. When deserializing,
/// system/user/tool-result `content` may be a plain string or one or more content parts.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    Assistant {
        // Accepts a bare string or a list of parts; missing or `null` yields an empty Vec.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        // Missing or `null` yields an empty Vec.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    // NOTE(review): this explicit rename overrides `rename_all`, so the role is serialized as
    // capitalized "Tool" rather than "tool". Confirm the target providers accept it.
    // The tool call id is carried in `name`; there is no `tool_call_id` field.
    #[serde(rename = "Tool")]
    ToolResult {
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<String>,
    },
}
247
248impl Message {
249    pub fn system(content: &str) -> Self {
250        Message::System {
251            content: OneOrMany::one(SystemContent::Text {
252                text: content.to_string(),
253            }),
254        }
255    }
256}
257
258impl TryFrom<message::Message> for Vec<Message> {
259    type Error = message::MessageError;
260
261    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
262        match message {
263            message::Message::User { content } => {
264                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
265                    .into_iter()
266                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
267
268                if !tool_results.is_empty() {
269                    tool_results
270                        .into_iter()
271                        .map(|content| match content {
272                            message::UserContent::ToolResult(message::ToolResult {
273                                id,
274                                content,
275                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
276                                name: id,
277                                arguments: None,
278                                content: content.try_map(|content| match content {
279                                    message::ToolResultContent::Text(message::Text { text }) => {
280                                        Ok(text)
281                                    }
282                                    _ => Err(message::MessageError::ConversionError(
283                                        "Tool result content does not support non-text".into(),
284                                    )),
285                                })?,
286                            }),
287                            _ => unreachable!(),
288                        })
289                        .collect::<Result<Vec<_>, _>>()
290                } else {
291                    let other_content = OneOrMany::many(other_content).expect(
292                        "There must be other content here if there were no tool result content",
293                    );
294
295                    Ok(vec![Message::User {
296                        content: other_content.try_map(|content| match content {
297                            message::UserContent::Text(text) => {
298                                Ok(UserContent::Text { text: text.text })
299                            }
300                            _ => Err(message::MessageError::ConversionError(
301                                "Huggingface does not support non-text".into(),
302                            )),
303                        })?,
304                    }])
305                }
306            }
307            message::Message::Assistant { content } => {
308                let (text_content, tool_calls) = content.into_iter().fold(
309                    (Vec::new(), Vec::new()),
310                    |(mut texts, mut tools), content| {
311                        match content {
312                            message::AssistantContent::Text(text) => texts.push(text),
313                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
314                        }
315                        (texts, tools)
316                    },
317                );
318
319                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
320                //  so either `content` or `tool_calls` will have some content.
321                Ok(vec![Message::Assistant {
322                    content: text_content
323                        .into_iter()
324                        .map(|content| AssistantContent::Text { text: content.text })
325                        .collect::<Vec<_>>(),
326                    tool_calls: tool_calls
327                        .into_iter()
328                        .map(|tool_call| tool_call.into())
329                        .collect::<Vec<_>>(),
330                }])
331            }
332        }
333    }
334}
335
336impl TryFrom<Message> for message::Message {
337    type Error = message::MessageError;
338
339    fn try_from(message: Message) -> Result<Self, Self::Error> {
340        Ok(match message {
341            Message::User { content, .. } => message::Message::User {
342                content: content.map(|content| content.into()),
343            },
344            Message::Assistant {
345                content,
346                tool_calls,
347                ..
348            } => {
349                let mut content = content
350                    .into_iter()
351                    .map(|content| match content {
352                        AssistantContent::Text { text } => message::AssistantContent::text(text),
353                    })
354                    .collect::<Vec<_>>();
355
356                content.extend(
357                    tool_calls
358                        .into_iter()
359                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
360                        .collect::<Result<Vec<_>, _>>()?,
361                );
362
363                message::Message::Assistant {
364                    content: OneOrMany::many(content).map_err(|_| {
365                        message::MessageError::ConversionError(
366                            "Neither `content` nor `tool_calls` was provided to the Message"
367                                .to_owned(),
368                        )
369                    })?,
370                }
371            }
372
373            Message::ToolResult { name, content, .. } => message::Message::User {
374                content: OneOrMany::one(message::UserContent::tool_result(
375                    name,
376                    content.map(message::ToolResultContent::text),
377                )),
378            },
379
380            // System messages should get stripped out when converting message's, this is just a
381            // stop gap to avoid obnoxious error handling or panic occurring.
382            Message::System { content, .. } => message::Message::User {
383                content: content.map(|c| match c {
384                    SystemContent::Text { text } => message::UserContent::text(text),
385                }),
386            },
387        })
388    }
389}
390
/// One entry of a completion response's `choices` array.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Why generation stopped, as reported by the provider (e.g. `"tool_calls"`).
    pub finish_reason: String,
    pub index: usize,
    /// Raw log-probability data; `Value::Null` when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    pub message: Message,
}
399
/// Token accounting from a completion response's `usage` object.
///
/// Extra fields some providers send (e.g. `completion_tokens_details`) are ignored.
#[derive(Debug, Deserialize, Clone)]
pub struct Usage {
    pub completion_tokens: i32,
    pub prompt_tokens: i32,
    pub total_tokens: i32,
}
406
/// Raw body of a successful (non-streaming) chat completion response.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    // NOTE(review): appears to be a Unix timestamp (see test fixtures); `i32` overflows in
    // January 2038 — consider `i64` in a future breaking release.
    pub created: i32,
    pub id: String,
    pub model: String,
    pub choices: Vec<Choice>,
    /// Empty string when the field is missing or `null`.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    pub usage: Usage,
}
417
418fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
419where
420    D: Deserializer<'de>,
421{
422    match Option::<String>::deserialize(deserializer)? {
423        Some(value) => Ok(value),      // Use provided value
424        None => Ok(String::default()), // Use `Default` implementation
425    }
426}
427
428impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
429    type Error = CompletionError;
430
431    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
432        let choice = response.choices.first().ok_or_else(|| {
433            CompletionError::ResponseError("Response contained no choices".to_owned())
434        })?;
435
436        let content = match &choice.message {
437            Message::Assistant {
438                content,
439                tool_calls,
440                ..
441            } => {
442                let mut content = content
443                    .iter()
444                    .map(|c| match c {
445                        AssistantContent::Text { text } => message::AssistantContent::text(text),
446                    })
447                    .collect::<Vec<_>>();
448
449                content.extend(
450                    tool_calls
451                        .iter()
452                        .map(|call| {
453                            completion::AssistantContent::tool_call(
454                                &call.id,
455                                &call.function.name,
456                                call.function.arguments.clone(),
457                            )
458                        })
459                        .collect::<Vec<_>>(),
460                );
461                Ok(content)
462            }
463            _ => Err(CompletionError::ResponseError(
464                "Response did not contain a valid message or tool call".into(),
465            )),
466        }?;
467
468        let choice = OneOrMany::many(content).map_err(|_| {
469            CompletionError::ResponseError(
470                "Response contained no message or tool call (empty)".to_owned(),
471            )
472        })?;
473
474        Ok(completion::CompletionResponse {
475            choice,
476            raw_response: response,
477        })
478    }
479}
480
/// A Hugging Face chat completion model bound to a client.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client whose sub-provider determines the endpoint and model identifier.
    pub(crate) client: Client,
    /// Name of the model (e.g: google/gemma-2-2b-it)
    pub model: String,
}
487
488impl CompletionModel {
489    pub fn new(client: Client, model: &str) -> Self {
490        Self {
491            client,
492            model: model.to_string(),
493        }
494    }
495
496    pub(crate) fn create_request_body(
497        &self,
498        completion_request: &CompletionRequest,
499    ) -> Result<serde_json::Value, CompletionError> {
500        let mut full_history: Vec<Message> = match &completion_request.preamble {
501            Some(preamble) => vec![Message::system(preamble)],
502            None => vec![],
503        };
504        if let Some(docs) = completion_request.normalized_documents() {
505            let docs: Vec<Message> = docs.try_into()?;
506            full_history.extend(docs);
507        }
508
509        let chat_history: Vec<Message> = completion_request
510            .chat_history
511            .clone()
512            .into_iter()
513            .map(|message| message.try_into())
514            .collect::<Result<Vec<Vec<Message>>, _>>()?
515            .into_iter()
516            .flatten()
517            .collect();
518
519        full_history.extend(chat_history);
520
521        let model = self.client.sub_provider.model_identifier(&self.model);
522
523        let request = if completion_request.tools.is_empty() {
524            json!({
525                "model": model,
526                "messages": full_history,
527                "temperature": completion_request.temperature,
528            })
529        } else {
530            json!({
531                "model": model,
532                "messages": full_history,
533                "temperature": completion_request.temperature,
534                "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
535                "tool_choice": "auto",
536            })
537        };
538        Ok(request)
539    }
540}
541
542impl completion::CompletionModel for CompletionModel {
543    type Response = CompletionResponse;
544    type StreamingResponse = StreamingCompletionResponse;
545
546    #[cfg_attr(feature = "worker", worker::send)]
547    async fn completion(
548        &self,
549        completion_request: CompletionRequest,
550    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
551        let request = self.create_request_body(&completion_request)?;
552
553        let path = self.client.sub_provider.completion_endpoint(&self.model);
554
555        let request = if let Some(ref params) = completion_request.additional_params {
556            json_utils::merge(request, params.clone())
557        } else {
558            request
559        };
560
561        let response = self.client.post(&path).json(&request).send().await?;
562
563        if response.status().is_success() {
564            let t = response.text().await?;
565            tracing::debug!(target: "rig", "Huggingface completion error: {}", t);
566
567            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
568                ApiResponse::Ok(response) => {
569                    tracing::info!(target: "rig",
570                        "Huggingface completion token usage: {:?}",
571                        format!("{:?}", response.usage)
572                    );
573                    response.try_into()
574                }
575                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
576            }
577        } else {
578            Err(CompletionError::ProviderError(format!(
579                "{}: {}",
580                response.status(),
581                response.text().await?
582            )))
583        }
584    }
585
586    #[cfg_attr(feature = "worker", worker::send)]
587    async fn stream(
588        &self,
589        request: CompletionRequest,
590    ) -> Result<
591        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
592        CompletionError,
593    > {
594        CompletionModel::stream(self, request).await
595    }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`, panicking with the failing JSON path and position.
    fn parse<T: serde::de::DeserializeOwned>(json: &str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    /// Converts `response` and asserts that its first content part is a `subtract` tool call
    /// with the given (decoded) arguments.
    fn assert_subtract_call(response: CompletionResponse, expected_args: serde_json::Value) {
        let converted: completion::CompletionResponse<CompletionResponse> =
            response.try_into().unwrap();
        match converted.choice.first() {
            message::AssistantContent::ToolCall(call) => {
                assert_eq!(call.function.name, "subtract");
                assert_eq!(call.function.arguments, expected_args);
            }
            _ => panic!("Expected a tool call"),
        }
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse(assistant_message_json);
        let assistant_message2: Message = parse(assistant_message_json2);
        let assistant_message3: Message = parse(assistant_message_json3);
        let user_message: Message = parse(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let firework_response: CompletionResponse = parse(fireworks_response_json);
        let novita_response: CompletionResponse = parse(novita_response_json);

        // String-encoded `arguments` must come out as structured JSON after conversion.
        assert_subtract_call(firework_response, serde_json::json!({"x": 2, "y": 5}));
        assert_subtract_call(novita_response, serde_json::json!({"x": "2", "y": "5"}));
    }
}