rig/providers/huggingface/
completion.rs

1use serde::{Deserialize, Deserializer, Serialize};
2use serde_json::{Value, json};
3use std::{convert::Infallible, str::FromStr};
4
5use super::client::Client;
6use crate::providers::openai::StreamingCompletionResponse;
7use crate::{
8    OneOrMany,
9    completion::{self, CompletionError, CompletionRequest},
10    json_utils,
11    message::{self},
12    one_or_many::string_or_one_or_many,
13};
14
/// Either a successfully parsed payload of type `T` or the raw JSON body returned by the API.
///
/// The enum is `untagged`: serde first tries to deserialize the body as `T` (`Ok`) and, only if
/// that fails, keeps the whole body as an arbitrary JSON value (`Err`). `Err` therefore matches
/// any JSON document that is not a valid `T`, not just documents shaped like error objects.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The body deserialized into the expected response type.
    Ok(T),
    /// Any other JSON body, kept verbatim (`completion` surfaces it as a provider error).
    Err(Value),
}
21
// ================================================================
// Huggingface Completion API
// ================================================================

// ---- Text-only conversational models ----

/// Hugging Face model id `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Hugging Face model id `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Hugging Face model id `microsoft/phi-4`.
pub const PHI_4: &str = "microsoft/phi-4";
/// Hugging Face model id `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Hugging Face model id `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Hugging Face model id `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// ---- Vision-language conversational models (accept image inputs) ----

/// Hugging Face visual-language model id `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Hugging Face visual-language model id `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
47
48#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
49pub struct Function {
50    name: String,
51    #[serde(deserialize_with = "deserialize_arguments")]
52    pub arguments: serde_json::Value,
53}
54
55fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
56where
57    D: Deserializer<'de>,
58{
59    let value = Value::deserialize(deserializer)?;
60
61    match value {
62        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
63        other => Ok(other),
64    }
65}
66
67impl From<Function> for message::ToolFunction {
68    fn from(value: Function) -> Self {
69        message::ToolFunction {
70            name: value.name,
71            arguments: value.arguments,
72        }
73    }
74}
75
76#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
77#[serde(rename_all = "lowercase")]
78pub enum ToolType {
79    #[default]
80    Function,
81}
82
83#[derive(Debug, Deserialize, Serialize, Clone)]
84pub struct ToolDefinition {
85    pub r#type: String,
86    pub function: completion::ToolDefinition,
87}
88
89impl From<completion::ToolDefinition> for ToolDefinition {
90    fn from(tool: completion::ToolDefinition) -> Self {
91        Self {
92            r#type: "function".into(),
93            function: tool,
94        }
95    }
96}
97
98#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
99pub struct ToolCall {
100    pub id: String,
101    pub r#type: ToolType,
102    pub function: Function,
103}
104
105impl From<ToolCall> for message::ToolCall {
106    fn from(value: ToolCall) -> Self {
107        message::ToolCall {
108            id: value.id,
109            call_id: None,
110            function: value.function.into(),
111        }
112    }
113}
114
115impl From<message::ToolCall> for ToolCall {
116    fn from(value: message::ToolCall) -> Self {
117        ToolCall {
118            id: value.id,
119            r#type: ToolType::Function,
120            function: Function {
121                name: value.function.name,
122                arguments: value.function.arguments,
123            },
124        }
125    }
126}
127
128#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
129pub struct ImageUrl {
130    url: String,
131}
132
133#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
134#[serde(tag = "type", rename_all = "lowercase")]
135pub enum UserContent {
136    Text {
137        text: String,
138    },
139    #[serde(rename = "image_url")]
140    ImageUrl {
141        image_url: ImageUrl,
142    },
143}
144
145impl FromStr for UserContent {
146    type Err = Infallible;
147
148    fn from_str(s: &str) -> Result<Self, Self::Err> {
149        Ok(UserContent::Text {
150            text: s.to_string(),
151        })
152    }
153}
154
155#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
156#[serde(tag = "type", rename_all = "lowercase")]
157pub enum AssistantContent {
158    Text { text: String },
159}
160
161impl FromStr for AssistantContent {
162    type Err = Infallible;
163
164    fn from_str(s: &str) -> Result<Self, Self::Err> {
165        Ok(AssistantContent::Text {
166            text: s.to_string(),
167        })
168    }
169}
170
171#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
172#[serde(tag = "type", rename_all = "lowercase")]
173pub enum SystemContent {
174    Text { text: String },
175}
176
177impl FromStr for SystemContent {
178    type Err = Infallible;
179
180    fn from_str(s: &str) -> Result<Self, Self::Err> {
181        Ok(SystemContent::Text {
182            text: s.to_string(),
183        })
184    }
185}
186
187impl From<UserContent> for message::UserContent {
188    fn from(value: UserContent) -> Self {
189        match value {
190            UserContent::Text { text } => message::UserContent::text(text),
191            UserContent::ImageUrl { image_url } => message::UserContent::image(
192                image_url.url,
193                Some(message::ContentFormat::String),
194                None,
195                None,
196            ),
197        }
198    }
199}
200
201impl TryFrom<message::UserContent> for UserContent {
202    type Error = message::MessageError;
203
204    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
205        match content {
206            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
207            message::UserContent::Image(message::Image { data, format, .. }) => match format {
208                Some(message::ContentFormat::String) => Ok(UserContent::ImageUrl {
209                    image_url: ImageUrl { url: data },
210                }),
211                _ => Err(message::MessageError::ConversionError(
212                    "Huggingface only supports images as urls".into(),
213                )),
214            },
215            _ => Err(message::MessageError::ConversionError(
216                "Huggingface only supports text and images".into(),
217            )),
218        }
219    }
220}
221
222#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
223#[serde(tag = "role", rename_all = "lowercase")]
224pub enum Message {
225    System {
226        #[serde(deserialize_with = "string_or_one_or_many")]
227        content: OneOrMany<SystemContent>,
228    },
229    User {
230        #[serde(deserialize_with = "string_or_one_or_many")]
231        content: OneOrMany<UserContent>,
232    },
233    Assistant {
234        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
235        content: Vec<AssistantContent>,
236        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
237        tool_calls: Vec<ToolCall>,
238    },
239    #[serde(rename = "Tool")]
240    ToolResult {
241        name: String,
242        #[serde(skip_serializing_if = "Option::is_none")]
243        arguments: Option<serde_json::Value>,
244        #[serde(deserialize_with = "string_or_one_or_many")]
245        content: OneOrMany<String>,
246    },
247}
248
249impl Message {
250    pub fn system(content: &str) -> Self {
251        Message::System {
252            content: OneOrMany::one(SystemContent::Text {
253                text: content.to_string(),
254            }),
255        }
256    }
257}
258
259impl TryFrom<message::Message> for Vec<Message> {
260    type Error = message::MessageError;
261
262    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
263        match message {
264            message::Message::User { content } => {
265                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
266                    .into_iter()
267                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
268
269                if !tool_results.is_empty() {
270                    tool_results
271                        .into_iter()
272                        .map(|content| match content {
273                            message::UserContent::ToolResult(message::ToolResult {
274                                id,
275                                content,
276                                ..
277                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
278                                name: id,
279                                arguments: None,
280                                content: content.try_map(|content| match content {
281                                    message::ToolResultContent::Text(message::Text { text }) => {
282                                        Ok(text)
283                                    }
284                                    _ => Err(message::MessageError::ConversionError(
285                                        "Tool result content does not support non-text".into(),
286                                    )),
287                                })?,
288                            }),
289                            _ => unreachable!(),
290                        })
291                        .collect::<Result<Vec<_>, _>>()
292                } else {
293                    let other_content = OneOrMany::many(other_content).expect(
294                        "There must be other content here if there were no tool result content",
295                    );
296
297                    Ok(vec![Message::User {
298                        content: other_content.try_map(|content| match content {
299                            message::UserContent::Text(text) => {
300                                Ok(UserContent::Text { text: text.text })
301                            }
302                            _ => Err(message::MessageError::ConversionError(
303                                "Huggingface does not support non-text".into(),
304                            )),
305                        })?,
306                    }])
307                }
308            }
309            message::Message::Assistant { content, .. } => {
310                let (text_content, tool_calls) = content.into_iter().fold(
311                    (Vec::new(), Vec::new()),
312                    |(mut texts, mut tools), content| {
313                        match content {
314                            message::AssistantContent::Text(text) => texts.push(text),
315                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
316                        }
317                        (texts, tools)
318                    },
319                );
320
321                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
322                //  so either `content` or `tool_calls` will have some content.
323                Ok(vec![Message::Assistant {
324                    content: text_content
325                        .into_iter()
326                        .map(|content| AssistantContent::Text { text: content.text })
327                        .collect::<Vec<_>>(),
328                    tool_calls: tool_calls
329                        .into_iter()
330                        .map(|tool_call| tool_call.into())
331                        .collect::<Vec<_>>(),
332                }])
333            }
334        }
335    }
336}
337
338impl TryFrom<Message> for message::Message {
339    type Error = message::MessageError;
340
341    fn try_from(message: Message) -> Result<Self, Self::Error> {
342        Ok(match message {
343            Message::User { content, .. } => message::Message::User {
344                content: content.map(|content| content.into()),
345            },
346            Message::Assistant {
347                content,
348                tool_calls,
349                ..
350            } => {
351                let mut content = content
352                    .into_iter()
353                    .map(|content| match content {
354                        AssistantContent::Text { text } => message::AssistantContent::text(text),
355                    })
356                    .collect::<Vec<_>>();
357
358                content.extend(
359                    tool_calls
360                        .into_iter()
361                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
362                        .collect::<Result<Vec<_>, _>>()?,
363                );
364
365                message::Message::Assistant {
366                    id: None,
367                    content: OneOrMany::many(content).map_err(|_| {
368                        message::MessageError::ConversionError(
369                            "Neither `content` nor `tool_calls` was provided to the Message"
370                                .to_owned(),
371                        )
372                    })?,
373                }
374            }
375
376            Message::ToolResult { name, content, .. } => message::Message::User {
377                content: OneOrMany::one(message::UserContent::tool_result(
378                    name,
379                    content.map(message::ToolResultContent::text),
380                )),
381            },
382
383            // System messages should get stripped out when converting message's, this is just a
384            // stop gap to avoid obnoxious error handling or panic occurring.
385            Message::System { content, .. } => message::Message::User {
386                content: content.map(|c| match c {
387                    SystemContent::Text { text } => message::UserContent::text(text),
388                }),
389            },
390        })
391    }
392}
393
/// One candidate completion inside a [`CompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Why generation stopped, as reported by the provider (e.g. `"tool_calls"`).
    pub finish_reason: String,
    /// Position of this choice in the `choices` array.
    pub index: usize,
    /// Log-probability data, kept as raw JSON; JSON `null` when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    /// The generated message (an assistant message for successful completions).
    pub message: Message,
}
402
/// Token accounting reported by the provider for one completion request.
#[derive(Debug, Deserialize, Clone)]
pub struct Usage {
    /// Tokens generated in the completion.
    pub completion_tokens: i32,
    /// Tokens consumed by the prompt.
    pub prompt_tokens: i32,
    /// Total tokens as reported by the provider.
    pub total_tokens: i32,
}
409
/// Raw body of a non-streaming chat-completion response.
///
/// Unknown fields (e.g. `object`, `service_tier`) are ignored during deserialization.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Creation time as sent by the provider (the test fixtures use Unix seconds).
    // NOTE(review): an `i32` Unix timestamp overflows in January 2038; consider `i64`
    // in a future breaking release.
    pub created: i32,
    /// Provider-assigned response id.
    pub id: String,
    /// Model that produced the response, as named by the provider.
    pub model: String,
    /// Candidate completions; only the first is converted into rig's response.
    pub choices: Vec<Choice>,
    /// System fingerprint; an empty string when the field is missing or `null`.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    /// Token usage for the request.
    pub usage: Usage,
}
420
421fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
422where
423    D: Deserializer<'de>,
424{
425    match Option::<String>::deserialize(deserializer)? {
426        Some(value) => Ok(value),      // Use provided value
427        None => Ok(String::default()), // Use `Default` implementation
428    }
429}
430
431impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
432    type Error = CompletionError;
433
434    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
435        let choice = response.choices.first().ok_or_else(|| {
436            CompletionError::ResponseError("Response contained no choices".to_owned())
437        })?;
438
439        let content = match &choice.message {
440            Message::Assistant {
441                content,
442                tool_calls,
443                ..
444            } => {
445                let mut content = content
446                    .iter()
447                    .map(|c| match c {
448                        AssistantContent::Text { text } => message::AssistantContent::text(text),
449                    })
450                    .collect::<Vec<_>>();
451
452                content.extend(
453                    tool_calls
454                        .iter()
455                        .map(|call| {
456                            completion::AssistantContent::tool_call(
457                                &call.id,
458                                &call.function.name,
459                                call.function.arguments.clone(),
460                            )
461                        })
462                        .collect::<Vec<_>>(),
463                );
464                Ok(content)
465            }
466            _ => Err(CompletionError::ResponseError(
467                "Response did not contain a valid message or tool call".into(),
468            )),
469        }?;
470
471        let choice = OneOrMany::many(content).map_err(|_| {
472            CompletionError::ResponseError(
473                "Response contained no message or tool call (empty)".to_owned(),
474            )
475        })?;
476
477        Ok(completion::CompletionResponse {
478            choice,
479            raw_response: response,
480        })
481    }
482}
483
484#[derive(Clone)]
485pub struct CompletionModel {
486    pub(crate) client: Client,
487    /// Name of the model (e.g: google/gemma-2-2b-it)
488    pub model: String,
489}
490
491impl CompletionModel {
492    pub fn new(client: Client, model: &str) -> Self {
493        Self {
494            client,
495            model: model.to_string(),
496        }
497    }
498
499    pub(crate) fn create_request_body(
500        &self,
501        completion_request: &CompletionRequest,
502    ) -> Result<serde_json::Value, CompletionError> {
503        let mut full_history: Vec<Message> = match &completion_request.preamble {
504            Some(preamble) => vec![Message::system(preamble)],
505            None => vec![],
506        };
507        if let Some(docs) = completion_request.normalized_documents() {
508            let docs: Vec<Message> = docs.try_into()?;
509            full_history.extend(docs);
510        }
511
512        let chat_history: Vec<Message> = completion_request
513            .chat_history
514            .clone()
515            .into_iter()
516            .map(|message| message.try_into())
517            .collect::<Result<Vec<Vec<Message>>, _>>()?
518            .into_iter()
519            .flatten()
520            .collect();
521
522        full_history.extend(chat_history);
523
524        let model = self.client.sub_provider.model_identifier(&self.model);
525
526        let request = if completion_request.tools.is_empty() {
527            json!({
528                "model": model,
529                "messages": full_history,
530                "temperature": completion_request.temperature,
531            })
532        } else {
533            json!({
534                "model": model,
535                "messages": full_history,
536                "temperature": completion_request.temperature,
537                "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
538                "tool_choice": "auto",
539            })
540        };
541        Ok(request)
542    }
543}
544
545impl completion::CompletionModel for CompletionModel {
546    type Response = CompletionResponse;
547    type StreamingResponse = StreamingCompletionResponse;
548
549    #[cfg_attr(feature = "worker", worker::send)]
550    async fn completion(
551        &self,
552        completion_request: CompletionRequest,
553    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
554        let request = self.create_request_body(&completion_request)?;
555
556        let path = self.client.sub_provider.completion_endpoint(&self.model);
557
558        let request = if let Some(ref params) = completion_request.additional_params {
559            json_utils::merge(request, params.clone())
560        } else {
561            request
562        };
563
564        let response = self.client.post(&path).json(&request).send().await?;
565
566        if response.status().is_success() {
567            let t = response.text().await?;
568            tracing::debug!(target: "rig", "Huggingface completion error: {}", t);
569
570            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
571                ApiResponse::Ok(response) => {
572                    tracing::info!(target: "rig",
573                        "Huggingface completion token usage: {:?}",
574                        format!("{:?}", response.usage)
575                    );
576                    response.try_into()
577                }
578                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
579            }
580        } else {
581            Err(CompletionError::ProviderError(format!(
582                "{}: {}",
583                response.status(),
584                response.text().await?
585            )))
586        }
587    }
588
589    #[cfg_attr(feature = "worker", worker::send)]
590    async fn stream(
591        &self,
592        request: CompletionRequest,
593    ) -> Result<
594        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
595        CompletionError,
596    > {
597        CompletionModel::stream(self, request).await
598    }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes a JSON fixture into `T`, panicking with the JSON path, line and column
    /// of the first failure so a broken fixture points straight at the offending field.
    fn parse_or_panic<'de, T: serde::Deserialize<'de>>(json: &'de str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        let greeting = "\n\nHello there, how may I assist you today?";
        let image_link = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg";

        // Assistant content given as a bare string.
        let assistant_message: Message = parse_or_panic(
            r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#,
        );

        // Assistant content given as a list of parts, with explicit `null` tool calls.
        let assistant_message2: Message = parse_or_panic(
            r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#,
        );

        // Tool-call-only assistant message with `null` content and an extra unknown field.
        let assistant_message3: Message = parse_or_panic(
            r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#,
        );

        // User message mixing a text part and an image URL part.
        let user_message: Message = parse_or_panic(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#,
        );

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: greeting.to_string()
            }
        );

        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message2
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: greeting.to_string()
            }
        );
        assert_eq!(tool_calls, vec![]);

        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_message3
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        let Message::User { content, .. } = user_message else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let (first, second) = (parts.next().unwrap(), parts.next().unwrap());
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(
            second,
            UserContent::ImageUrl {
                image_url: ImageUrl {
                    url: image_link.to_string()
                }
            }
        );
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };
        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = converted_user_message[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant_message[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Converting back must reproduce the original rig messages exactly.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };
        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content, .. } = converted_assistant_message.clone()
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Converting back must reproduce the original wire messages exactly.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        // Fireworks: string-encoded arguments, extra `index` on the tool call, no fingerprint.
        let _firework_response: CompletionResponse = parse_or_panic(
            r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#,
        );

        // Novita: many `null` fields, including `content` and `system_fingerprint`.
        let _novita_response: CompletionResponse = parse_or_panic(
            r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#,
        );
    }
}