rig/providers/huggingface/
completion.rs

1use std::{convert::Infallible, str::FromStr};
2
3use serde::{Deserialize, Deserializer, Serialize};
4use serde_json::{json, Value};
5
6use crate::{
7    completion::{self, CompletionError, CompletionRequest},
8    json_utils,
9    message::{self},
10    one_or_many::string_or_one_or_many,
11    OneOrMany,
12};
13
14use super::client::Client;
15
16#[derive(Debug, Deserialize)]
17#[serde(untagged)]
18pub enum ApiResponse<T> {
19    Ok(T),
20    Err(Value),
21}
22
// ================================================================
// Huggingface Completion API
// ================================================================

// ---- Conversational LLMs ----

/// Model id for `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model id for `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model id for `microsoft/phi-4`.
pub const PHI_4: &str = "microsoft/phi-4";
/// Model id for `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model id for `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model id for `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// ---- Conversational VLMs (visual-language) ----

/// Visual-language model id for `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Visual-language model id for `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51    name: String,
52    #[serde(deserialize_with = "deserialize_arguments")]
53    pub arguments: serde_json::Value,
54}
55
56fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
57where
58    D: Deserializer<'de>,
59{
60    let value = Value::deserialize(deserializer)?;
61
62    match value {
63        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
64        other => Ok(other),
65    }
66}
67
68impl From<Function> for message::ToolFunction {
69    fn from(value: Function) -> Self {
70        message::ToolFunction {
71            name: value.name,
72            arguments: value.arguments,
73        }
74    }
75}
76
77#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
78#[serde(rename_all = "lowercase")]
79pub enum ToolType {
80    #[default]
81    Function,
82}
83
84#[derive(Debug, Deserialize, Serialize, Clone)]
85pub struct ToolDefinition {
86    pub r#type: String,
87    pub function: completion::ToolDefinition,
88}
89
90impl From<completion::ToolDefinition> for ToolDefinition {
91    fn from(tool: completion::ToolDefinition) -> Self {
92        Self {
93            r#type: "function".into(),
94            function: tool,
95        }
96    }
97}
98
99#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
100pub struct ToolCall {
101    pub id: String,
102    pub r#type: ToolType,
103    pub function: Function,
104}
105
106impl From<ToolCall> for message::ToolCall {
107    fn from(value: ToolCall) -> Self {
108        message::ToolCall {
109            id: value.id,
110            function: value.function.into(),
111        }
112    }
113}
114
115impl From<message::ToolCall> for ToolCall {
116    fn from(value: message::ToolCall) -> Self {
117        ToolCall {
118            id: value.id,
119            r#type: ToolType::Function,
120            function: Function {
121                name: value.function.name,
122                arguments: value.function.arguments,
123            },
124        }
125    }
126}
127
128#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
129pub struct ImageUrl {
130    url: String,
131}
132
133#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
134#[serde(tag = "type", rename_all = "lowercase")]
135pub enum UserContent {
136    Text {
137        text: String,
138    },
139    #[serde(rename = "image_url")]
140    ImageUrl {
141        image_url: ImageUrl,
142    },
143}
144
145impl FromStr for UserContent {
146    type Err = Infallible;
147
148    fn from_str(s: &str) -> Result<Self, Self::Err> {
149        Ok(UserContent::Text {
150            text: s.to_string(),
151        })
152    }
153}
154
155#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
156#[serde(tag = "type", rename_all = "lowercase")]
157pub enum AssistantContent {
158    Text { text: String },
159}
160
161impl FromStr for AssistantContent {
162    type Err = Infallible;
163
164    fn from_str(s: &str) -> Result<Self, Self::Err> {
165        Ok(AssistantContent::Text {
166            text: s.to_string(),
167        })
168    }
169}
170
171#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
172#[serde(tag = "type", rename_all = "lowercase")]
173pub enum SystemContent {
174    Text { text: String },
175}
176
177impl FromStr for SystemContent {
178    type Err = Infallible;
179
180    fn from_str(s: &str) -> Result<Self, Self::Err> {
181        Ok(SystemContent::Text {
182            text: s.to_string(),
183        })
184    }
185}
186
187impl From<UserContent> for message::UserContent {
188    fn from(value: UserContent) -> Self {
189        match value {
190            UserContent::Text { text } => message::UserContent::text(text),
191            UserContent::ImageUrl { image_url } => message::UserContent::image(
192                image_url.url,
193                Some(message::ContentFormat::String),
194                None,
195                None,
196            ),
197        }
198    }
199}
200
201impl TryFrom<message::UserContent> for UserContent {
202    type Error = message::MessageError;
203
204    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
205        match content {
206            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
207            message::UserContent::Image(message::Image { data, format, .. }) => match format {
208                Some(message::ContentFormat::String) => Ok(UserContent::ImageUrl {
209                    image_url: ImageUrl { url: data },
210                }),
211                _ => Err(message::MessageError::ConversionError(
212                    "Huggingface only supports images as urls".into(),
213                )),
214            },
215            _ => Err(message::MessageError::ConversionError(
216                "Huggingface only supports text and images".into(),
217            )),
218        }
219    }
220}
221
222#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
223#[serde(tag = "role", rename_all = "lowercase")]
224pub enum Message {
225    System {
226        #[serde(deserialize_with = "string_or_one_or_many")]
227        content: OneOrMany<SystemContent>,
228    },
229    User {
230        #[serde(deserialize_with = "string_or_one_or_many")]
231        content: OneOrMany<UserContent>,
232    },
233    Assistant {
234        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
235        content: Vec<AssistantContent>,
236        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
237        tool_calls: Vec<ToolCall>,
238    },
239    #[serde(rename = "Tool")]
240    ToolResult {
241        name: String,
242        #[serde(skip_serializing_if = "Option::is_none")]
243        arguments: Option<serde_json::Value>,
244        #[serde(deserialize_with = "string_or_one_or_many")]
245        content: OneOrMany<String>,
246    },
247}
248
249impl Message {
250    pub fn system(content: &str) -> Self {
251        Message::System {
252            content: OneOrMany::one(SystemContent::Text {
253                text: content.to_string(),
254            }),
255        }
256    }
257}
258
259impl TryFrom<message::Message> for Vec<Message> {
260    type Error = message::MessageError;
261
262    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
263        match message {
264            message::Message::User { content } => {
265                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
266                    .into_iter()
267                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
268
269                if !tool_results.is_empty() {
270                    tool_results
271                        .into_iter()
272                        .map(|content| match content {
273                            message::UserContent::ToolResult(message::ToolResult {
274                                id,
275                                content,
276                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
277                                name: id,
278                                arguments: None,
279                                content: content.try_map(|content| match content {
280                                    message::ToolResultContent::Text(message::Text { text }) => {
281                                        Ok(text)
282                                    }
283                                    _ => Err(message::MessageError::ConversionError(
284                                        "Tool result content does not support non-text".into(),
285                                    )),
286                                })?,
287                            }),
288                            _ => unreachable!(),
289                        })
290                        .collect::<Result<Vec<_>, _>>()
291                } else {
292                    let other_content = OneOrMany::many(other_content).expect(
293                        "There must be other content here if there were no tool result content",
294                    );
295
296                    Ok(vec![Message::User {
297                        content: other_content.try_map(|content| match content {
298                            message::UserContent::Text(text) => {
299                                Ok(UserContent::Text { text: text.text })
300                            }
301                            _ => Err(message::MessageError::ConversionError(
302                                "Huggingface does not support non-text".into(),
303                            )),
304                        })?,
305                    }])
306                }
307            }
308            message::Message::Assistant { content } => {
309                let (text_content, tool_calls) = content.into_iter().fold(
310                    (Vec::new(), Vec::new()),
311                    |(mut texts, mut tools), content| {
312                        match content {
313                            message::AssistantContent::Text(text) => texts.push(text),
314                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
315                        }
316                        (texts, tools)
317                    },
318                );
319
320                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
321                //  so either `content` or `tool_calls` will have some content.
322                Ok(vec![Message::Assistant {
323                    content: text_content
324                        .into_iter()
325                        .map(|content| AssistantContent::Text { text: content.text })
326                        .collect::<Vec<_>>(),
327                    tool_calls: tool_calls
328                        .into_iter()
329                        .map(|tool_call| tool_call.into())
330                        .collect::<Vec<_>>(),
331                }])
332            }
333        }
334    }
335}
336
337impl TryFrom<Message> for message::Message {
338    type Error = message::MessageError;
339
340    fn try_from(message: Message) -> Result<Self, Self::Error> {
341        Ok(match message {
342            Message::User { content, .. } => message::Message::User {
343                content: content.map(|content| content.into()),
344            },
345            Message::Assistant {
346                content,
347                tool_calls,
348                ..
349            } => {
350                let mut content = content
351                    .into_iter()
352                    .map(|content| match content {
353                        AssistantContent::Text { text } => message::AssistantContent::text(text),
354                    })
355                    .collect::<Vec<_>>();
356
357                content.extend(
358                    tool_calls
359                        .into_iter()
360                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
361                        .collect::<Result<Vec<_>, _>>()?,
362                );
363
364                message::Message::Assistant {
365                    content: OneOrMany::many(content).map_err(|_| {
366                        message::MessageError::ConversionError(
367                            "Neither `content` nor `tool_calls` was provided to the Message"
368                                .to_owned(),
369                        )
370                    })?,
371                }
372            }
373
374            Message::ToolResult { name, content, .. } => message::Message::User {
375                content: OneOrMany::one(message::UserContent::tool_result(
376                    name,
377                    content.map(message::ToolResultContent::text),
378                )),
379            },
380
381            // System messages should get stripped out when converting message's, this is just a
382            // stop gap to avoid obnoxious error handling or panic occurring.
383            Message::System { content, .. } => message::Message::User {
384                content: content.map(|c| match c {
385                    SystemContent::Text { text } => message::UserContent::text(text),
386                }),
387            },
388        })
389    }
390}
391
392#[derive(Debug, Deserialize)]
393pub struct Choice {
394    pub finish_reason: String,
395    pub index: usize,
396    #[serde(default)]
397    pub logprobs: serde_json::Value,
398    pub message: Message,
399}
400
401#[derive(Debug, Deserialize, Clone)]
402pub struct Usage {
403    pub completion_tokens: i32,
404    pub prompt_tokens: i32,
405    pub total_tokens: i32,
406}
407
408#[derive(Debug, Deserialize)]
409pub struct CompletionResponse {
410    pub created: i32,
411    pub id: String,
412    pub model: String,
413    pub choices: Vec<Choice>,
414    #[serde(default, deserialize_with = "default_string_on_null")]
415    pub system_fingerprint: String,
416    pub usage: Usage,
417}
418
419fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
420where
421    D: Deserializer<'de>,
422{
423    match Option::<String>::deserialize(deserializer)? {
424        Some(value) => Ok(value),      // Use provided value
425        None => Ok(String::default()), // Use `Default` implementation
426    }
427}
428
429impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
430    type Error = CompletionError;
431
432    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
433        let choice = response.choices.first().ok_or_else(|| {
434            CompletionError::ResponseError("Response contained no choices".to_owned())
435        })?;
436
437        let content = match &choice.message {
438            Message::Assistant {
439                content,
440                tool_calls,
441                ..
442            } => {
443                let mut content = content
444                    .iter()
445                    .map(|c| match c {
446                        AssistantContent::Text { text } => message::AssistantContent::text(text),
447                    })
448                    .collect::<Vec<_>>();
449
450                content.extend(
451                    tool_calls
452                        .iter()
453                        .map(|call| {
454                            completion::AssistantContent::tool_call(
455                                &call.id,
456                                &call.function.name,
457                                call.function.arguments.clone(),
458                            )
459                        })
460                        .collect::<Vec<_>>(),
461                );
462                Ok(content)
463            }
464            _ => Err(CompletionError::ResponseError(
465                "Response did not contain a valid message or tool call".into(),
466            )),
467        }?;
468
469        let choice = OneOrMany::many(content).map_err(|_| {
470            CompletionError::ResponseError(
471                "Response contained no message or tool call (empty)".to_owned(),
472            )
473        })?;
474
475        Ok(completion::CompletionResponse {
476            choice,
477            raw_response: response,
478        })
479    }
480}
481
482#[derive(Clone)]
483pub struct CompletionModel {
484    pub(crate) client: Client,
485    /// Name of the model (e.g: google/gemma-2-2b-it)
486    pub model: String,
487}
488
489impl CompletionModel {
490    pub fn new(client: Client, model: &str) -> Self {
491        Self {
492            client,
493            model: model.to_string(),
494        }
495    }
496
497    pub(crate) fn create_request_body(
498        &self,
499        completion_request: &CompletionRequest,
500    ) -> Result<serde_json::Value, CompletionError> {
501        let mut full_history: Vec<Message> = match &completion_request.preamble {
502            Some(preamble) => vec![Message::system(preamble)],
503            None => vec![],
504        };
505        if let Some(docs) = completion_request.normalized_documents() {
506            let docs: Vec<Message> = docs.try_into()?;
507            full_history.extend(docs);
508        }
509
510        let chat_history: Vec<Message> = completion_request
511            .chat_history
512            .clone()
513            .into_iter()
514            .map(|message| message.try_into())
515            .collect::<Result<Vec<Vec<Message>>, _>>()?
516            .into_iter()
517            .flatten()
518            .collect();
519
520        full_history.extend(chat_history);
521
522        let model = self.client.sub_provider.model_identifier(&self.model);
523
524        let request = if completion_request.tools.is_empty() {
525            json!({
526                "model": model,
527                "messages": full_history,
528                "temperature": completion_request.temperature,
529            })
530        } else {
531            json!({
532                "model": model,
533                "messages": full_history,
534                "temperature": completion_request.temperature,
535                "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
536                "tool_choice": "auto",
537            })
538        };
539        Ok(request)
540    }
541}
542
543impl completion::CompletionModel for CompletionModel {
544    type Response = CompletionResponse;
545
546    #[cfg_attr(feature = "worker", worker::send)]
547    async fn completion(
548        &self,
549        completion_request: CompletionRequest,
550    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
551        let request = self.create_request_body(&completion_request)?;
552
553        let path = self.client.sub_provider.completion_endpoint(&self.model);
554
555        let request = if let Some(ref params) = completion_request.additional_params {
556            json_utils::merge(request, params.clone())
557        } else {
558            request
559        };
560
561        let response = self.client.post(&path).json(&request).send().await?;
562
563        if response.status().is_success() {
564            let t = response.text().await?;
565            tracing::debug!(target: "rig", "Huggingface completion error: {}", t);
566
567            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
568                ApiResponse::Ok(response) => {
569                    tracing::info!(target: "rig",
570                        "Huggingface completion token usage: {:?}",
571                        format!("{:?}", response.usage)
572                    );
573                    response.try_into()
574                }
575                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
576            }
577        } else {
578            Err(CompletionError::ProviderError(format!(
579                "{}: {}",
580                response.status(),
581                response.text().await?
582            )))
583        }
584    }
585}
586
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`.
    ///
    /// On failure it panics with the JSON path, line and column of the error, so fixture
    /// mismatches are easy to locate.
    fn parse<T: serde::de::DeserializeOwned>(json: &str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        const GREETING: &str = "\n\nHello there, how may I assist you today?";
        const IMAGE_URL: &str = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg";

        // Assistant `content` given as a bare string.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant `content` given as an array of parts, with `tool_calls: null`.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Tool call with object-valued arguments and `content: null`.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing text and an image URL.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let Message::Assistant { content, .. } = parse::<Message>(assistant_message_json) else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: GREETING.to_string()
            }
        );

        let Message::Assistant {
            content,
            tool_calls,
        } = parse::<Message>(assistant_message_json2)
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: GREETING.to_string()
            }
        );
        assert!(tool_calls.is_empty());

        let Message::Assistant {
            content,
            tool_calls,
        } = parse::<Message>(assistant_message_json3)
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        let Message::User { content } = parse::<Message>(user_message_json) else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let (first, second) = (parts.next().unwrap(), parts.next().unwrap());
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(
            second,
            UserContent::ImageUrl {
                image_url: ImageUrl {
                    url: IMAGE_URL.to_string()
                }
            }
        );
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };
        let assistant_message = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant: Vec<Message> = assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = converted_user[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Converting back must reproduce the original rig messages.
        let round_trip_user: message::Message = converted_user[0].clone().try_into().unwrap();
        let round_trip_assistant: message::Message =
            converted_assistant[0].clone().try_into().unwrap();

        assert_eq!(round_trip_user, user_message);
        assert_eq!(round_trip_assistant, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };
        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant: message::Message = assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content } = converted_assistant.clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Converting back must reproduce the original Huggingface messages.
        let round_trip_user: Vec<Message> = converted_user.try_into().unwrap();
        let round_trip_assistant: Vec<Message> = converted_assistant.try_into().unwrap();

        assert_eq!(round_trip_user[0], user_message);
        assert_eq!(round_trip_assistant[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        // Both provider payloads must deserialize without error.
        let _firework_response: CompletionResponse = parse(fireworks_response_json);
        let _novita_response: CompletionResponse = parse(novita_response_json);
    }
}