rig/providers/huggingface/completion.rs

1use serde::{Deserialize, Deserializer, Serialize};
2use serde_json::{Value, json};
3use std::{convert::Infallible, str::FromStr};
4
5use super::client::Client;
6use crate::providers::openai::StreamingCompletionResponse;
7use crate::{
8    OneOrMany,
9    completion::{self, CompletionError, CompletionRequest},
10    json_utils,
11    message::{self},
12    one_or_many::string_or_one_or_many,
13};
14
15#[derive(Debug, Deserialize)]
16#[serde(untagged)]
17pub enum ApiResponse<T> {
18    Ok(T),
19    Err(Value),
20}
21
// ================================================================
// Hugging Face Completion API
// ================================================================

// Conversational LLMs (text only)

/// Model id `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model id `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model id `microsoft/phi-4`.
pub const PHI_4: &str = "microsoft/phi-4";
/// Model id `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model id `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model id `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs (text + images)

/// Visual-language model id `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Visual-language model id `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
47
48#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
49pub struct Function {
50    name: String,
51    #[serde(deserialize_with = "deserialize_arguments")]
52    pub arguments: serde_json::Value,
53}
54
55fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
56where
57    D: Deserializer<'de>,
58{
59    let value = Value::deserialize(deserializer)?;
60
61    match value {
62        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
63        other => Ok(other),
64    }
65}
66
67impl From<Function> for message::ToolFunction {
68    fn from(value: Function) -> Self {
69        message::ToolFunction {
70            name: value.name,
71            arguments: value.arguments,
72        }
73    }
74}
75
76#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
77#[serde(rename_all = "lowercase")]
78pub enum ToolType {
79    #[default]
80    Function,
81}
82
83#[derive(Debug, Deserialize, Serialize, Clone)]
84pub struct ToolDefinition {
85    pub r#type: String,
86    pub function: completion::ToolDefinition,
87}
88
89impl From<completion::ToolDefinition> for ToolDefinition {
90    fn from(tool: completion::ToolDefinition) -> Self {
91        Self {
92            r#type: "function".into(),
93            function: tool,
94        }
95    }
96}
97
98#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
99pub struct ToolCall {
100    pub id: String,
101    pub r#type: ToolType,
102    pub function: Function,
103}
104
105impl From<ToolCall> for message::ToolCall {
106    fn from(value: ToolCall) -> Self {
107        message::ToolCall {
108            id: value.id,
109            call_id: None,
110            function: value.function.into(),
111        }
112    }
113}
114
115impl From<message::ToolCall> for ToolCall {
116    fn from(value: message::ToolCall) -> Self {
117        ToolCall {
118            id: value.id,
119            r#type: ToolType::Function,
120            function: Function {
121                name: value.function.name,
122                arguments: value.function.arguments,
123            },
124        }
125    }
126}
127
128#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
129pub struct ImageUrl {
130    url: String,
131}
132
133#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
134#[serde(tag = "type", rename_all = "lowercase")]
135pub enum UserContent {
136    Text {
137        text: String,
138    },
139    #[serde(rename = "image_url")]
140    ImageUrl {
141        image_url: ImageUrl,
142    },
143}
144
145impl FromStr for UserContent {
146    type Err = Infallible;
147
148    fn from_str(s: &str) -> Result<Self, Self::Err> {
149        Ok(UserContent::Text {
150            text: s.to_string(),
151        })
152    }
153}
154
155#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
156#[serde(tag = "type", rename_all = "lowercase")]
157pub enum AssistantContent {
158    Text { text: String },
159}
160
161impl FromStr for AssistantContent {
162    type Err = Infallible;
163
164    fn from_str(s: &str) -> Result<Self, Self::Err> {
165        Ok(AssistantContent::Text {
166            text: s.to_string(),
167        })
168    }
169}
170
171#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
172#[serde(tag = "type", rename_all = "lowercase")]
173pub enum SystemContent {
174    Text { text: String },
175}
176
177impl FromStr for SystemContent {
178    type Err = Infallible;
179
180    fn from_str(s: &str) -> Result<Self, Self::Err> {
181        Ok(SystemContent::Text {
182            text: s.to_string(),
183        })
184    }
185}
186
187impl From<UserContent> for message::UserContent {
188    fn from(value: UserContent) -> Self {
189        match value {
190            UserContent::Text { text } => message::UserContent::text(text),
191            UserContent::ImageUrl { image_url } => message::UserContent::image(
192                image_url.url,
193                Some(message::ContentFormat::String),
194                None,
195                None,
196            ),
197        }
198    }
199}
200
201impl TryFrom<message::UserContent> for UserContent {
202    type Error = message::MessageError;
203
204    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
205        match content {
206            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
207            message::UserContent::Image(message::Image { data, format, .. }) => match format {
208                Some(message::ContentFormat::String) => Ok(UserContent::ImageUrl {
209                    image_url: ImageUrl { url: data },
210                }),
211                _ => Err(message::MessageError::ConversionError(
212                    "Huggingface only supports images as urls".into(),
213                )),
214            },
215            _ => Err(message::MessageError::ConversionError(
216                "Huggingface only supports text and images".into(),
217            )),
218        }
219    }
220}
221
222#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
223#[serde(tag = "role", rename_all = "lowercase")]
224pub enum Message {
225    System {
226        #[serde(deserialize_with = "string_or_one_or_many")]
227        content: OneOrMany<SystemContent>,
228    },
229    User {
230        #[serde(deserialize_with = "string_or_one_or_many")]
231        content: OneOrMany<UserContent>,
232    },
233    Assistant {
234        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
235        content: Vec<AssistantContent>,
236        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
237        tool_calls: Vec<ToolCall>,
238    },
239    #[serde(rename = "Tool")]
240    ToolResult {
241        name: String,
242        #[serde(skip_serializing_if = "Option::is_none")]
243        arguments: Option<serde_json::Value>,
244        #[serde(deserialize_with = "string_or_one_or_many")]
245        content: OneOrMany<String>,
246    },
247}
248
249impl Message {
250    pub fn system(content: &str) -> Self {
251        Message::System {
252            content: OneOrMany::one(SystemContent::Text {
253                text: content.to_string(),
254            }),
255        }
256    }
257}
258
259impl TryFrom<message::Message> for Vec<Message> {
260    type Error = message::MessageError;
261
262    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
263        match message {
264            message::Message::User { content } => {
265                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
266                    .into_iter()
267                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
268
269                if !tool_results.is_empty() {
270                    tool_results
271                        .into_iter()
272                        .map(|content| match content {
273                            message::UserContent::ToolResult(message::ToolResult {
274                                id,
275                                content,
276                                ..
277                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
278                                name: id,
279                                arguments: None,
280                                content: content.try_map(|content| match content {
281                                    message::ToolResultContent::Text(message::Text { text }) => {
282                                        Ok(text)
283                                    }
284                                    _ => Err(message::MessageError::ConversionError(
285                                        "Tool result content does not support non-text".into(),
286                                    )),
287                                })?,
288                            }),
289                            _ => unreachable!(),
290                        })
291                        .collect::<Result<Vec<_>, _>>()
292                } else {
293                    let other_content = OneOrMany::many(other_content).expect(
294                        "There must be other content here if there were no tool result content",
295                    );
296
297                    Ok(vec![Message::User {
298                        content: other_content.try_map(|content| match content {
299                            message::UserContent::Text(text) => {
300                                Ok(UserContent::Text { text: text.text })
301                            }
302                            _ => Err(message::MessageError::ConversionError(
303                                "Huggingface does not support non-text".into(),
304                            )),
305                        })?,
306                    }])
307                }
308            }
309            message::Message::Assistant { content, .. } => {
310                let (text_content, tool_calls) = content.into_iter().fold(
311                    (Vec::new(), Vec::new()),
312                    |(mut texts, mut tools), content| {
313                        match content {
314                            message::AssistantContent::Text(text) => texts.push(text),
315                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
316                            message::AssistantContent::Reasoning(_) => {
317                                unimplemented!("Reasoning is not supported on HuggingFace via Rig");
318                            }
319                        }
320                        (texts, tools)
321                    },
322                );
323
324                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
325                //  so either `content` or `tool_calls` will have some content.
326                Ok(vec![Message::Assistant {
327                    content: text_content
328                        .into_iter()
329                        .map(|content| AssistantContent::Text { text: content.text })
330                        .collect::<Vec<_>>(),
331                    tool_calls: tool_calls
332                        .into_iter()
333                        .map(|tool_call| tool_call.into())
334                        .collect::<Vec<_>>(),
335                }])
336            }
337        }
338    }
339}
340
341impl TryFrom<Message> for message::Message {
342    type Error = message::MessageError;
343
344    fn try_from(message: Message) -> Result<Self, Self::Error> {
345        Ok(match message {
346            Message::User { content, .. } => message::Message::User {
347                content: content.map(|content| content.into()),
348            },
349            Message::Assistant {
350                content,
351                tool_calls,
352                ..
353            } => {
354                let mut content = content
355                    .into_iter()
356                    .map(|content| match content {
357                        AssistantContent::Text { text } => message::AssistantContent::text(text),
358                    })
359                    .collect::<Vec<_>>();
360
361                content.extend(
362                    tool_calls
363                        .into_iter()
364                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
365                        .collect::<Result<Vec<_>, _>>()?,
366                );
367
368                message::Message::Assistant {
369                    id: None,
370                    content: OneOrMany::many(content).map_err(|_| {
371                        message::MessageError::ConversionError(
372                            "Neither `content` nor `tool_calls` was provided to the Message"
373                                .to_owned(),
374                        )
375                    })?,
376                }
377            }
378
379            Message::ToolResult { name, content, .. } => message::Message::User {
380                content: OneOrMany::one(message::UserContent::tool_result(
381                    name,
382                    content.map(message::ToolResultContent::text),
383                )),
384            },
385
386            // System messages should get stripped out when converting message's, this is just a
387            // stop gap to avoid obnoxious error handling or panic occurring.
388            Message::System { content, .. } => message::Message::User {
389                content: content.map(|c| match c {
390                    SystemContent::Text { text } => message::UserContent::text(text),
391                }),
392            },
393        })
394    }
395}
396
397#[derive(Debug, Deserialize, Serialize)]
398pub struct Choice {
399    pub finish_reason: String,
400    pub index: usize,
401    #[serde(default)]
402    pub logprobs: serde_json::Value,
403    pub message: Message,
404}
405
406#[derive(Debug, Deserialize, Clone, Serialize)]
407pub struct Usage {
408    pub completion_tokens: i32,
409    pub prompt_tokens: i32,
410    pub total_tokens: i32,
411}
412
413#[derive(Debug, Deserialize, Serialize)]
414pub struct CompletionResponse {
415    pub created: i32,
416    pub id: String,
417    pub model: String,
418    pub choices: Vec<Choice>,
419    #[serde(default, deserialize_with = "default_string_on_null")]
420    pub system_fingerprint: String,
421    pub usage: Usage,
422}
423
424fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
425where
426    D: Deserializer<'de>,
427{
428    match Option::<String>::deserialize(deserializer)? {
429        Some(value) => Ok(value),      // Use provided value
430        None => Ok(String::default()), // Use `Default` implementation
431    }
432}
433
434impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
435    type Error = CompletionError;
436
437    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
438        let choice = response.choices.first().ok_or_else(|| {
439            CompletionError::ResponseError("Response contained no choices".to_owned())
440        })?;
441
442        let content = match &choice.message {
443            Message::Assistant {
444                content,
445                tool_calls,
446                ..
447            } => {
448                let mut content = content
449                    .iter()
450                    .map(|c| match c {
451                        AssistantContent::Text { text } => message::AssistantContent::text(text),
452                    })
453                    .collect::<Vec<_>>();
454
455                content.extend(
456                    tool_calls
457                        .iter()
458                        .map(|call| {
459                            completion::AssistantContent::tool_call(
460                                &call.id,
461                                &call.function.name,
462                                call.function.arguments.clone(),
463                            )
464                        })
465                        .collect::<Vec<_>>(),
466                );
467                Ok(content)
468            }
469            _ => Err(CompletionError::ResponseError(
470                "Response did not contain a valid message or tool call".into(),
471            )),
472        }?;
473
474        let choice = OneOrMany::many(content).map_err(|_| {
475            CompletionError::ResponseError(
476                "Response contained no message or tool call (empty)".to_owned(),
477            )
478        })?;
479
480        let usage = completion::Usage {
481            input_tokens: response.usage.prompt_tokens as u64,
482            output_tokens: response.usage.completion_tokens as u64,
483            total_tokens: response.usage.total_tokens as u64,
484        };
485
486        Ok(completion::CompletionResponse {
487            choice,
488            usage,
489            raw_response: response,
490        })
491    }
492}
493
494#[derive(Clone)]
495pub struct CompletionModel {
496    pub(crate) client: Client,
497    /// Name of the model (e.g: google/gemma-2-2b-it)
498    pub model: String,
499}
500
501impl CompletionModel {
502    pub fn new(client: Client, model: &str) -> Self {
503        Self {
504            client,
505            model: model.to_string(),
506        }
507    }
508
509    pub(crate) fn create_request_body(
510        &self,
511        completion_request: &CompletionRequest,
512    ) -> Result<serde_json::Value, CompletionError> {
513        let mut full_history: Vec<Message> = match &completion_request.preamble {
514            Some(preamble) => vec![Message::system(preamble)],
515            None => vec![],
516        };
517        if let Some(docs) = completion_request.normalized_documents() {
518            let docs: Vec<Message> = docs.try_into()?;
519            full_history.extend(docs);
520        }
521
522        let chat_history: Vec<Message> = completion_request
523            .chat_history
524            .clone()
525            .into_iter()
526            .map(|message| message.try_into())
527            .collect::<Result<Vec<Vec<Message>>, _>>()?
528            .into_iter()
529            .flatten()
530            .collect();
531
532        full_history.extend(chat_history);
533
534        let model = self.client.sub_provider.model_identifier(&self.model);
535
536        let request = if completion_request.tools.is_empty() {
537            json!({
538                "model": model,
539                "messages": full_history,
540                "temperature": completion_request.temperature,
541            })
542        } else {
543            json!({
544                "model": model,
545                "messages": full_history,
546                "temperature": completion_request.temperature,
547                "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
548                "tool_choice": "auto",
549            })
550        };
551        Ok(request)
552    }
553}
554
555impl completion::CompletionModel for CompletionModel {
556    type Response = CompletionResponse;
557    type StreamingResponse = StreamingCompletionResponse;
558
559    #[cfg_attr(feature = "worker", worker::send)]
560    async fn completion(
561        &self,
562        completion_request: CompletionRequest,
563    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
564        let request = self.create_request_body(&completion_request)?;
565
566        let path = self.client.sub_provider.completion_endpoint(&self.model);
567
568        let request = if let Some(ref params) = completion_request.additional_params {
569            json_utils::merge(request, params.clone())
570        } else {
571            request
572        };
573
574        let response = self.client.post(&path).json(&request).send().await?;
575
576        if response.status().is_success() {
577            let t = response.text().await?;
578            tracing::debug!(target: "rig", "Huggingface completion error: {}", t);
579
580            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
581                ApiResponse::Ok(response) => {
582                    tracing::info!(target: "rig",
583                        "Huggingface completion token usage: {:?}",
584                        format!("{:?}", response.usage)
585                    );
586                    response.try_into()
587                }
588                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
589            }
590        } else {
591            Err(CompletionError::ProviderError(format!(
592                "{}: {}",
593                response.status(),
594                response.text().await?
595            )))
596        }
597    }
598
599    #[cfg_attr(feature = "worker", worker::send)]
600    async fn stream(
601        &self,
602        request: CompletionRequest,
603    ) -> Result<
604        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
605        CompletionError,
606    > {
607        CompletionModel::stream(self, request).await
608    }
609}
610
611#[cfg(test)]
612mod tests {
613    use super::*;
614    use serde_path_to_error::deserialize;
615
616    #[test]
617    fn test_deserialize_message() {
618        let assistant_message_json = r#"
619        {
620            "role": "assistant",
621            "content": "\n\nHello there, how may I assist you today?"
622        }
623        "#;
624
625        let assistant_message_json2 = r#"
626        {
627            "role": "assistant",
628            "content": [
629                {
630                    "type": "text",
631                    "text": "\n\nHello there, how may I assist you today?"
632                }
633            ],
634            "tool_calls": null
635        }
636        "#;
637
638        let assistant_message_json3 = r#"
639        {
640            "role": "assistant",
641            "tool_calls": [
642                {
643                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
644                    "type": "function",
645                    "function": {
646                        "name": "subtract",
647                        "arguments": {"x": 2, "y": 5}
648                    }
649                }
650            ],
651            "content": null,
652            "refusal": null
653        }
654        "#;
655
656        let user_message_json = r#"
657        {
658            "role": "user",
659            "content": [
660                {
661                    "type": "text",
662                    "text": "What's in this image?"
663                },
664                {
665                    "type": "image_url",
666                    "image_url": {
667                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
668                    }
669                }
670            ]
671        }
672        "#;
673
674        let assistant_message: Message = {
675            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
676            deserialize(jd).unwrap_or_else(|err| {
677                panic!(
678                    "Deserialization error at {} ({}:{}): {}",
679                    err.path(),
680                    err.inner().line(),
681                    err.inner().column(),
682                    err
683                );
684            })
685        };
686
687        let assistant_message2: Message = {
688            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
689            deserialize(jd).unwrap_or_else(|err| {
690                panic!(
691                    "Deserialization error at {} ({}:{}): {}",
692                    err.path(),
693                    err.inner().line(),
694                    err.inner().column(),
695                    err
696                );
697            })
698        };
699
700        let assistant_message3: Message = {
701            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
702                &mut serde_json::Deserializer::from_str(assistant_message_json3);
703            deserialize(jd).unwrap_or_else(|err| {
704                panic!(
705                    "Deserialization error at {} ({}:{}): {}",
706                    err.path(),
707                    err.inner().line(),
708                    err.inner().column(),
709                    err
710                );
711            })
712        };
713
714        let user_message: Message = {
715            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
716            deserialize(jd).unwrap_or_else(|err| {
717                panic!(
718                    "Deserialization error at {} ({}:{}): {}",
719                    err.path(),
720                    err.inner().line(),
721                    err.inner().column(),
722                    err
723                );
724            })
725        };
726
727        match assistant_message {
728            Message::Assistant { content, .. } => {
729                assert_eq!(
730                    content[0],
731                    AssistantContent::Text {
732                        text: "\n\nHello there, how may I assist you today?".to_string()
733                    }
734                );
735            }
736            _ => panic!("Expected assistant message"),
737        }
738
739        match assistant_message2 {
740            Message::Assistant {
741                content,
742                tool_calls,
743                ..
744            } => {
745                assert_eq!(
746                    content[0],
747                    AssistantContent::Text {
748                        text: "\n\nHello there, how may I assist you today?".to_string()
749                    }
750                );
751
752                assert_eq!(tool_calls, vec![]);
753            }
754            _ => panic!("Expected assistant message"),
755        }
756
757        match assistant_message3 {
758            Message::Assistant {
759                content,
760                tool_calls,
761                ..
762            } => {
763                assert!(content.is_empty());
764                assert_eq!(
765                    tool_calls[0],
766                    ToolCall {
767                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
768                        r#type: ToolType::Function,
769                        function: Function {
770                            name: "subtract".to_string(),
771                            arguments: serde_json::json!({"x": 2, "y": 5}),
772                        },
773                    }
774                );
775            }
776            _ => panic!("Expected assistant message"),
777        }
778
779        match user_message {
780            Message::User { content, .. } => {
781                let (first, second) = {
782                    let mut iter = content.into_iter();
783                    (iter.next().unwrap(), iter.next().unwrap())
784                };
785                assert_eq!(
786                    first,
787                    UserContent::Text {
788                        text: "What's in this image?".to_string()
789                    }
790                );
791                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
792            }
793            _ => panic!("Expected user message"),
794        }
795    }
796
797    #[test]
798    fn test_message_to_message_conversion() {
799        let user_message = message::Message::User {
800            content: OneOrMany::one(message::UserContent::text("Hello")),
801        };
802
803        let assistant_message = message::Message::Assistant {
804            id: None,
805            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
806        };
807
808        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
809        let converted_assistant_message: Vec<Message> =
810            assistant_message.clone().try_into().unwrap();
811
812        match converted_user_message[0].clone() {
813            Message::User { content, .. } => {
814                assert_eq!(
815                    content.first(),
816                    UserContent::Text {
817                        text: "Hello".to_string()
818                    }
819                );
820            }
821            _ => panic!("Expected user message"),
822        }
823
824        match converted_assistant_message[0].clone() {
825            Message::Assistant { content, .. } => {
826                assert_eq!(
827                    content[0],
828                    AssistantContent::Text {
829                        text: "Hi there!".to_string()
830                    }
831                );
832            }
833            _ => panic!("Expected assistant message"),
834        }
835
836        let original_user_message: message::Message =
837            converted_user_message[0].clone().try_into().unwrap();
838        let original_assistant_message: message::Message =
839            converted_assistant_message[0].clone().try_into().unwrap();
840
841        assert_eq!(original_user_message, user_message);
842        assert_eq!(original_assistant_message, assistant_message);
843    }
844
845    #[test]
846    fn test_message_from_message_conversion() {
847        let user_message = Message::User {
848            content: OneOrMany::one(UserContent::Text {
849                text: "Hello".to_string(),
850            }),
851        };
852
853        let assistant_message = Message::Assistant {
854            content: vec![AssistantContent::Text {
855                text: "Hi there!".to_string(),
856            }],
857            tool_calls: vec![],
858        };
859
860        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
861        let converted_assistant_message: message::Message =
862            assistant_message.clone().try_into().unwrap();
863
864        match converted_user_message.clone() {
865            message::Message::User { content } => {
866                assert_eq!(content.first(), message::UserContent::text("Hello"));
867            }
868            _ => panic!("Expected user message"),
869        }
870
871        match converted_assistant_message.clone() {
872            message::Message::Assistant { content, .. } => {
873                assert_eq!(
874                    content.first(),
875                    message::AssistantContent::text("Hi there!")
876                );
877            }
878            _ => panic!("Expected assistant message"),
879        }
880
881        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
882        let original_assistant_message: Vec<Message> =
883            converted_assistant_message.try_into().unwrap();
884
885        assert_eq!(original_user_message[0], user_message);
886        assert_eq!(original_assistant_message[0], assistant_message);
887    }
888
889    #[test]
890    fn test_responses() {
891        let fireworks_response_json = r#"
892        {
893            "choices": [
894                {
895                    "finish_reason": "tool_calls",
896                    "index": 0,
897                    "message": {
898                        "role": "assistant",
899                        "tool_calls": [
900                            {
901                                "function": {
902                                "arguments": "{\"x\": 2, \"y\": 5}",
903                                "name": "subtract"
904                                },
905                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
906                                "index": 0,
907                                "type": "function"
908                            }
909                        ]
910                    }
911                }
912            ],
913            "created": 1740704000,
914            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
915            "model": "accounts/fireworks/models/deepseek-v3",
916            "object": "chat.completion",
917            "usage": {
918                "completion_tokens": 26,
919                "prompt_tokens": 248,
920                "total_tokens": 274
921            }
922        }
923        "#;
924
925        let novita_response_json = r#"
926        {
927            "choices": [
928                {
929                    "finish_reason": "tool_calls",
930                    "index": 0,
931                    "logprobs": null,
932                    "message": {
933                        "audio": null,
934                        "content": null,
935                        "function_call": null,
936                        "reasoning_content": null,
937                        "refusal": null,
938                        "role": "assistant",
939                        "tool_calls": [
940                            {
941                                "function": {
942                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
943                                    "name": "subtract"
944                                },
945                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
946                                "type": "function"
947                            }
948                        ]
949                    },
950                    "stop_reason": 128008
951                }
952            ],
953            "created": 1740704592,
954            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
955            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
956            "object": "chat.completion",
957            "prompt_logprobs": null,
958            "service_tier": null,
959            "system_fingerprint": null,
960            "usage": {
961                "completion_tokens": 28,
962                "completion_tokens_details": null,
963                "prompt_tokens": 335,
964                "prompt_tokens_details": null,
965                "total_tokens": 363
966            }
967        }
968        "#;
969
970        let _firework_response: CompletionResponse = {
971            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
972            deserialize(jd).unwrap_or_else(|err| {
973                panic!(
974                    "Deserialization error at {} ({}:{}): {}",
975                    err.path(),
976                    err.inner().line(),
977                    err.inner().column(),
978                    err
979                );
980            })
981        };
982
983        let _novita_response: CompletionResponse = {
984            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
985            deserialize(jd).unwrap_or_else(|err| {
986                panic!(
987                    "Deserialization error at {} ({}:{}): {}",
988                    err.path(),
989                    err.inner().line(),
990                    err.inner().column(),
991                    err
992                );
993            })
994        };
995    }
996}