rig/providers/huggingface/
completion.rs

1use std::{convert::Infallible, str::FromStr};
2
3use serde::{Deserialize, Deserializer, Serialize};
4use serde_json::{json, Value};
5
6use crate::{
7    completion::{self, CompletionError, CompletionRequest},
8    json_utils,
9    message::{self},
10    one_or_many::string_or_one_or_many,
11    OneOrMany,
12};
13
14use super::client::Client;
15
/// Envelope used when decoding a provider reply whose shape is not known up front.
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration order:
/// the payload is first decoded as `T`, and only if that fails is it kept as raw JSON.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The payload matched the expected type `T`.
    Ok(T),
    /// The payload did not match `T`; the raw JSON value is retained as-is.
    Err(Value),
}
22
23// ================================================================
24// Huggingface Completion API
25// ================================================================
26
27// Conversational LLMs
28
/// Model identifier for `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model identifier for `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model identifier for `microsoft/phi-4`.
pub const PHI_4: &str = "microsoft/phi-4";
/// Model identifier for `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model identifier for `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model identifier for `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs (vision-language models)

/// Model identifier for the visual-language model `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model identifier for the visual-language model `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
/// Function invocation carried inside a [`ToolCall`].
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function.
    name: String,
    /// Call arguments. When deserializing, a JSON string is parsed into the JSON value it
    /// encodes; any other JSON value is kept as-is (see `deserialize_arguments`).
    #[serde(deserialize_with = "deserialize_arguments")]
    pub arguments: serde_json::Value,
}
55
56fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
57where
58    D: Deserializer<'de>,
59{
60    let value = Value::deserialize(deserializer)?;
61
62    match value {
63        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
64        other => Ok(other),
65    }
66}
67
68impl From<Function> for message::ToolFunction {
69    fn from(value: Function) -> Self {
70        message::ToolFunction {
71            name: value.name,
72            arguments: value.arguments,
73        }
74    }
75}
76
/// Kind of tool referenced by a [`ToolCall`]. Variant names are (de)serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool, written as `"function"` on the wire; this is the default variant.
    #[default]
    Function,
}
83
/// Tool description sent with a request: a `type` discriminator plus the wrapped
/// [`completion::ToolDefinition`].
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind; the `From<completion::ToolDefinition>` conversion sets it to `"function"`.
    pub r#type: String,
    /// The tool definition being wrapped.
    pub function: completion::ToolDefinition,
}
89
90impl From<completion::ToolDefinition> for ToolDefinition {
91    fn from(tool: completion::ToolDefinition) -> Self {
92        Self {
93            r#type: "function".into(),
94            function: tool,
95        }
96    }
97}
98
/// A tool invocation as it appears in the `tool_calls` list of an assistant message.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Kind of tool being called.
    pub r#type: ToolType,
    /// Function name and arguments of the call.
    pub function: Function,
}
105
106impl From<ToolCall> for message::ToolCall {
107    fn from(value: ToolCall) -> Self {
108        message::ToolCall {
109            id: value.id,
110            function: value.function.into(),
111        }
112    }
113}
114
115impl From<message::ToolCall> for ToolCall {
116    fn from(value: message::ToolCall) -> Self {
117        ToolCall {
118            id: value.id,
119            r#type: ToolType::Function,
120            function: Function {
121                name: value.function.name,
122                arguments: value.function.arguments,
123            },
124        }
125    }
126}
127
/// Payload of an `image_url` user content part.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Location of the image, as a string.
    url: String,
}
132
/// One part of a user message, internally tagged by a `type` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text part (`"type": "text"`).
    Text {
        text: String,
    },
    /// Image part; explicitly renamed so the tag is `"image_url"` rather than the
    /// lowercase form of the variant name.
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
144
145impl FromStr for UserContent {
146    type Err = Infallible;
147
148    fn from_str(s: &str) -> Result<Self, Self::Err> {
149        Ok(UserContent::Text {
150            text: s.to_string(),
151        })
152    }
153}
154
/// One part of an assistant message, internally tagged by a `type` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Plain text part (`"type": "text"`).
    Text { text: String },
}
160
161impl FromStr for AssistantContent {
162    type Err = Infallible;
163
164    fn from_str(s: &str) -> Result<Self, Self::Err> {
165        Ok(AssistantContent::Text {
166            text: s.to_string(),
167        })
168    }
169}
170
/// One part of a system message, internally tagged by a `type` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    /// Plain text part (`"type": "text"`).
    Text { text: String },
}
176
177impl FromStr for SystemContent {
178    type Err = Infallible;
179
180    fn from_str(s: &str) -> Result<Self, Self::Err> {
181        Ok(SystemContent::Text {
182            text: s.to_string(),
183        })
184    }
185}
186
187impl From<UserContent> for message::UserContent {
188    fn from(value: UserContent) -> Self {
189        match value {
190            UserContent::Text { text } => message::UserContent::text(text),
191            UserContent::ImageUrl { image_url } => message::UserContent::image(
192                image_url.url,
193                Some(message::ContentFormat::String),
194                None,
195                None,
196            ),
197        }
198    }
199}
200
201impl TryFrom<message::UserContent> for UserContent {
202    type Error = message::MessageError;
203
204    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
205        match content {
206            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
207            message::UserContent::Image(message::Image { data, format, .. }) => match format {
208                Some(message::ContentFormat::String) => Ok(UserContent::ImageUrl {
209                    image_url: ImageUrl { url: data },
210                }),
211                _ => Err(message::MessageError::ConversionError(
212                    "Huggingface only supports images as urls".into(),
213                )),
214            },
215            _ => Err(message::MessageError::ConversionError(
216                "Huggingface only supports text and images".into(),
217            )),
218        }
219    }
220}
221
/// A chat message in the provider wire format, internally tagged by a `role` field.
///
/// Variant names are lowercased for the tag, except `ToolResult`, which is explicitly
/// renamed.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System message (`"role": "system"`).
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    /// User message (`"role": "user"`).
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    /// Assistant message (`"role": "assistant"`). Both fields fall back to an empty list
    /// when the key is absent; the tests in this file also cover a plain-string `content`
    /// and `null` values for either field.
    Assistant {
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Result of a tool invocation.
    // NOTE(review): this role is written as "Tool" (capital T) while every other role is
    // lowercase — confirm this is what the endpoint expects.
    #[serde(rename = "Tool")]
    ToolResult {
        name: String,
        // Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<String>,
    },
}
248
249impl Message {
250    pub fn system(content: &str) -> Self {
251        Message::System {
252            content: OneOrMany::one(SystemContent::Text {
253                text: content.to_string(),
254            }),
255        }
256    }
257}
258
259impl TryFrom<message::Message> for Vec<Message> {
260    type Error = message::MessageError;
261
262    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
263        match message {
264            message::Message::User { content } => {
265                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
266                    .into_iter()
267                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
268
269                if !tool_results.is_empty() {
270                    tool_results
271                        .into_iter()
272                        .map(|content| match content {
273                            message::UserContent::ToolResult(message::ToolResult {
274                                id,
275                                content,
276                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
277                                name: id,
278                                arguments: None,
279                                content: content.try_map(|content| match content {
280                                    message::ToolResultContent::Text(message::Text { text }) => {
281                                        Ok(text)
282                                    }
283                                    _ => Err(message::MessageError::ConversionError(
284                                        "Tool result content does not support non-text".into(),
285                                    )),
286                                })?,
287                            }),
288                            _ => unreachable!(),
289                        })
290                        .collect::<Result<Vec<_>, _>>()
291                } else {
292                    let other_content = OneOrMany::many(other_content).expect(
293                        "There must be other content here if there were no tool result content",
294                    );
295
296                    Ok(vec![Message::User {
297                        content: other_content.try_map(|content| match content {
298                            message::UserContent::Text(text) => {
299                                Ok(UserContent::Text { text: text.text })
300                            }
301                            _ => Err(message::MessageError::ConversionError(
302                                "Huggingface does not support non-text".into(),
303                            )),
304                        })?,
305                    }])
306                }
307            }
308            message::Message::Assistant { content } => {
309                let (text_content, tool_calls) = content.into_iter().fold(
310                    (Vec::new(), Vec::new()),
311                    |(mut texts, mut tools), content| {
312                        match content {
313                            message::AssistantContent::Text(text) => texts.push(text),
314                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
315                        }
316                        (texts, tools)
317                    },
318                );
319
320                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
321                //  so either `content` or `tool_calls` will have some content.
322                Ok(vec![Message::Assistant {
323                    content: text_content
324                        .into_iter()
325                        .map(|content| AssistantContent::Text { text: content.text })
326                        .collect::<Vec<_>>(),
327                    tool_calls: tool_calls
328                        .into_iter()
329                        .map(|tool_call| tool_call.into())
330                        .collect::<Vec<_>>(),
331                }])
332            }
333        }
334    }
335}
336
337impl TryFrom<Message> for message::Message {
338    type Error = message::MessageError;
339
340    fn try_from(message: Message) -> Result<Self, Self::Error> {
341        Ok(match message {
342            Message::User { content, .. } => message::Message::User {
343                content: content.map(|content| content.into()),
344            },
345            Message::Assistant {
346                content,
347                tool_calls,
348                ..
349            } => {
350                let mut content = content
351                    .into_iter()
352                    .map(|content| match content {
353                        AssistantContent::Text { text } => message::AssistantContent::text(text),
354                    })
355                    .collect::<Vec<_>>();
356
357                content.extend(
358                    tool_calls
359                        .into_iter()
360                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
361                        .collect::<Result<Vec<_>, _>>()?,
362                );
363
364                message::Message::Assistant {
365                    content: OneOrMany::many(content).map_err(|_| {
366                        message::MessageError::ConversionError(
367                            "Neither `content` nor `tool_calls` was provided to the Message"
368                                .to_owned(),
369                        )
370                    })?,
371                }
372            }
373
374            Message::ToolResult { name, content, .. } => message::Message::User {
375                content: OneOrMany::one(message::UserContent::tool_result(
376                    name,
377                    content.map(message::ToolResultContent::text),
378                )),
379            },
380
381            // System messages should get stripped out when converting message's, this is just a
382            // stop gap to avoid obnoxious error handling or panic occurring.
383            Message::System { content, .. } => message::Message::User {
384                content: content.map(|c| match c {
385                    SystemContent::Text { text } => message::UserContent::text(text),
386                }),
387            },
388        })
389    }
390}
391
/// One entry of the `choices` array in a completion response.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Value of the `finish_reason` key (the test fixtures use `"tool_calls"`).
    pub finish_reason: String,
    /// Value of the `index` key.
    pub index: usize,
    /// Raw `logprobs` value; falls back to JSON `null` when the key is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    /// The message carried by this choice.
    pub message: Message,
}
400
/// Token counts reported in the `usage` object of a completion response.
/// Field names mirror the JSON keys.
#[derive(Debug, Deserialize, Clone)]
pub struct Usage {
    pub completion_tokens: i32,
    pub prompt_tokens: i32,
    pub total_tokens: i32,
}
407
/// Body of a successful chat completion response. Field names mirror the JSON keys.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    // NOTE(review): presumably a Unix timestamp in seconds (the test fixtures use values
    // such as 1740704000); stored as `i32` — confirm the range is sufficient.
    pub created: i32,
    pub id: String,
    pub model: String,
    /// Returned choices; only the first one is used when converting to the generic response.
    pub choices: Vec<Choice>,
    /// Empty string when the key is missing or explicitly `null`.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    pub usage: Usage,
}
418
419fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
420where
421    D: Deserializer<'de>,
422{
423    match Option::<String>::deserialize(deserializer)? {
424        Some(value) => Ok(value),      // Use provided value
425        None => Ok(String::default()), // Use `Default` implementation
426    }
427}
428
429impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
430    type Error = CompletionError;
431
432    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
433        let choice = response.choices.first().ok_or_else(|| {
434            CompletionError::ResponseError("Response contained no choices".to_owned())
435        })?;
436
437        let content = match &choice.message {
438            Message::Assistant {
439                content,
440                tool_calls,
441                ..
442            } => {
443                let mut content = content
444                    .iter()
445                    .map(|c| match c {
446                        AssistantContent::Text { text } => message::AssistantContent::text(text),
447                    })
448                    .collect::<Vec<_>>();
449
450                content.extend(
451                    tool_calls
452                        .iter()
453                        .map(|call| {
454                            completion::AssistantContent::tool_call(
455                                &call.function.name,
456                                &call.function.name,
457                                call.function.arguments.clone(),
458                            )
459                        })
460                        .collect::<Vec<_>>(),
461                );
462                Ok(content)
463            }
464            _ => Err(CompletionError::ResponseError(
465                "Response did not contain a valid message or tool call".into(),
466            )),
467        }?;
468
469        let choice = OneOrMany::many(content).map_err(|_| {
470            CompletionError::ResponseError(
471                "Response contained no message or tool call (empty)".to_owned(),
472            )
473        })?;
474
475        Ok(completion::CompletionResponse {
476            choice,
477            raw_response: response,
478        })
479    }
480}
481
/// A completion model: a client paired with the name of the model to query.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests and to resolve the endpoint and model identifier.
    pub(crate) client: Client,
    /// Name of the model (e.g: google/gemma-2-2b-it)
    pub model: String,
}
488
489impl CompletionModel {
490    pub fn new(client: Client, model: &str) -> Self {
491        Self {
492            client,
493            model: model.to_string(),
494        }
495    }
496
497    pub(crate) fn create_request_body(
498        &self,
499        completion_request: &CompletionRequest,
500    ) -> Result<serde_json::Value, CompletionError> {
501        let mut full_history: Vec<Message> = match &completion_request.preamble {
502            Some(preamble) => vec![Message::system(preamble)],
503            None => vec![],
504        };
505
506        let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
507
508        let chat_history: Vec<Message> = completion_request
509            .chat_history
510            .clone()
511            .into_iter()
512            .map(|message| message.try_into())
513            .collect::<Result<Vec<Vec<Message>>, _>>()?
514            .into_iter()
515            .flatten()
516            .collect();
517
518        full_history.extend(chat_history);
519        full_history.extend(prompt);
520
521        let model = self.client.sub_provider.model_identifier(&self.model);
522
523        let request = if completion_request.tools.is_empty() {
524            json!({
525                "model": model,
526                "messages": full_history,
527                "temperature": completion_request.temperature,
528            })
529        } else {
530            json!({
531                "model": model,
532                "messages": full_history,
533                "temperature": completion_request.temperature,
534                "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
535                "tool_choice": "auto",
536            })
537        };
538        Ok(request)
539    }
540}
541
542impl completion::CompletionModel for CompletionModel {
543    type Response = CompletionResponse;
544
545    #[cfg_attr(feature = "worker", worker::send)]
546    async fn completion(
547        &self,
548        completion_request: CompletionRequest,
549    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
550        let request = self.create_request_body(&completion_request)?;
551
552        let path = self.client.sub_provider.completion_endpoint(&self.model);
553
554        let request = if let Some(ref params) = completion_request.additional_params {
555            json_utils::merge(request, params.clone())
556        } else {
557            request
558        };
559
560        let response = self.client.post(&path).json(&request).send().await?;
561
562        if response.status().is_success() {
563            let t = response.text().await?;
564            tracing::debug!(target: "rig", "Huggingface completion error: {}", t);
565
566            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
567                ApiResponse::Ok(response) => {
568                    tracing::info!(target: "rig",
569                        "Huggingface completion token usage: {:?}",
570                        format!("{:?}", response.usage)
571                    );
572                    response.try_into()
573                }
574                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
575            }
576        } else {
577            Err(CompletionError::ProviderError(format!(
578                "{}: {}",
579                response.status(),
580                response.text().await?
581            )))
582        }
583    }
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Greeting used by the assistant fixtures.
    const GREETING: &str = "\n\nHello there, how may I assist you today?";

    /// Image URL used by the user fixture.
    const IMAGE_URL: &str = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg";

    /// Deserializes `json` into `T`, panicking with the failing path and position on error.
    fn parse_or_panic<T>(json: &str) -> T
    where
        T: serde::de::DeserializeOwned,
    {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    /// The assistant text part every greeting fixture is expected to produce.
    fn greeting_part() -> AssistantContent {
        AssistantContent::Text {
            text: GREETING.to_string(),
        }
    }

    #[test]
    fn test_deserialize_message() {
        // Content given as a bare string.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Content given as a list of parts, with an explicit null `tool_calls`.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Tool call only, with null content and an unknown extra key.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse_or_panic(assistant_message_json);
        let assistant_message2: Message = parse_or_panic(assistant_message_json2);
        let assistant_message3: Message = parse_or_panic(assistant_message_json3);
        let user_message: Message = parse_or_panic(user_message_json);

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(content[0], greeting_part());

        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message2
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(content[0], greeting_part());
        assert!(tool_calls.is_empty());

        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message3
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        let Message::User { content, .. } = user_message else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let first = parts.next().unwrap();
        let second = parts.next().unwrap();
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(
            second,
            UserContent::ImageUrl {
                image_url: ImageUrl {
                    url: IMAGE_URL.to_string()
                }
            }
        );
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        // Generic -> provider.
        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = converted_user_message[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant_message[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Provider -> generic must give back the starting messages.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        // Provider -> generic.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content } = converted_assistant_message.clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Generic -> provider must give back the starting messages.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        // Both fixtures only need to deserialize without error.
        let _firework_response: CompletionResponse = parse_or_panic(fireworks_response_json);
        let _novita_response: CompletionResponse = parse_or_panic(novita_response_json);
    }
}