rig/providers/huggingface/completion.rs

1use serde::{Deserialize, Deserializer, Serialize};
2use serde_json::{Value, json};
3use std::{convert::Infallible, str::FromStr};
4
5use super::client::Client;
6use crate::providers::openai::StreamingCompletionResponse;
7use crate::{
8    OneOrMany,
9    completion::{self, CompletionError, CompletionRequest},
10    json_utils,
11    message::{self},
12    one_or_many::string_or_one_or_many,
13};
14
15#[derive(Debug, Deserialize)]
16#[serde(untagged)]
17pub enum ApiResponse<T> {
18    Ok(T),
19    Err(Value),
20}
21
// ================================================================
// Hugging Face Completion API
// ================================================================

// ---- Conversational language models ----

/// Model id for Google's Gemma 2 (2B, instruction-tuned).
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model id for Meta Llama 3.1 (8B, instruct).
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model id for Microsoft Phi-4.
pub const PHI_4: &str = "microsoft/phi-4";
/// Model id for PowerInfer SmallThinker (3B, preview).
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model id for Qwen 2.5 (7B, instruct).
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model id for Qwen 2.5 Coder (32B, instruct).
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// ---- Conversational visual-language models ----

/// Model id for the Qwen2-VL (7B, instruct) visual-language model.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model id for the Qwen QVQ (72B, preview) visual-language model.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
47
48#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
49pub struct Function {
50    name: String,
51    #[serde(deserialize_with = "deserialize_arguments")]
52    pub arguments: serde_json::Value,
53}
54
55fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
56where
57    D: Deserializer<'de>,
58{
59    let value = Value::deserialize(deserializer)?;
60
61    match value {
62        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
63        other => Ok(other),
64    }
65}
66
67impl From<Function> for message::ToolFunction {
68    fn from(value: Function) -> Self {
69        message::ToolFunction {
70            name: value.name,
71            arguments: value.arguments,
72        }
73    }
74}
75
76#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
77#[serde(rename_all = "lowercase")]
78pub enum ToolType {
79    #[default]
80    Function,
81}
82
83#[derive(Debug, Deserialize, Serialize, Clone)]
84pub struct ToolDefinition {
85    pub r#type: String,
86    pub function: completion::ToolDefinition,
87}
88
89impl From<completion::ToolDefinition> for ToolDefinition {
90    fn from(tool: completion::ToolDefinition) -> Self {
91        Self {
92            r#type: "function".into(),
93            function: tool,
94        }
95    }
96}
97
98#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
99pub struct ToolCall {
100    pub id: String,
101    pub r#type: ToolType,
102    pub function: Function,
103}
104
105impl From<ToolCall> for message::ToolCall {
106    fn from(value: ToolCall) -> Self {
107        message::ToolCall {
108            id: value.id,
109            call_id: None,
110            function: value.function.into(),
111        }
112    }
113}
114
115impl From<message::ToolCall> for ToolCall {
116    fn from(value: message::ToolCall) -> Self {
117        ToolCall {
118            id: value.id,
119            r#type: ToolType::Function,
120            function: Function {
121                name: value.function.name,
122                arguments: value.function.arguments,
123            },
124        }
125    }
126}
127
128#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
129pub struct ImageUrl {
130    url: String,
131}
132
133#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
134#[serde(tag = "type", rename_all = "lowercase")]
135pub enum UserContent {
136    Text {
137        text: String,
138    },
139    #[serde(rename = "image_url")]
140    ImageUrl {
141        image_url: ImageUrl,
142    },
143}
144
145impl FromStr for UserContent {
146    type Err = Infallible;
147
148    fn from_str(s: &str) -> Result<Self, Self::Err> {
149        Ok(UserContent::Text {
150            text: s.to_string(),
151        })
152    }
153}
154
155#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
156#[serde(tag = "type", rename_all = "lowercase")]
157pub enum AssistantContent {
158    Text { text: String },
159}
160
161impl FromStr for AssistantContent {
162    type Err = Infallible;
163
164    fn from_str(s: &str) -> Result<Self, Self::Err> {
165        Ok(AssistantContent::Text {
166            text: s.to_string(),
167        })
168    }
169}
170
171#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
172#[serde(tag = "type", rename_all = "lowercase")]
173pub enum SystemContent {
174    Text { text: String },
175}
176
177impl FromStr for SystemContent {
178    type Err = Infallible;
179
180    fn from_str(s: &str) -> Result<Self, Self::Err> {
181        Ok(SystemContent::Text {
182            text: s.to_string(),
183        })
184    }
185}
186
187impl From<UserContent> for message::UserContent {
188    fn from(value: UserContent) -> Self {
189        match value {
190            UserContent::Text { text } => message::UserContent::text(text),
191            UserContent::ImageUrl { image_url } => {
192                message::UserContent::image_url(image_url.url, None, None)
193            }
194        }
195    }
196}
197
198impl TryFrom<message::UserContent> for UserContent {
199    type Error = message::MessageError;
200
201    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
202        match content {
203            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
204            message::UserContent::Image(message::Image { data, .. }) => match data {
205                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
206                    image_url: ImageUrl { url },
207                }),
208                _ => Err(message::MessageError::ConversionError(
209                    "Huggingface only supports images as urls".into(),
210                )),
211            },
212            _ => Err(message::MessageError::ConversionError(
213                "Huggingface only supports text and images".into(),
214            )),
215        }
216    }
217}
218
219#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
220#[serde(tag = "role", rename_all = "lowercase")]
221pub enum Message {
222    System {
223        #[serde(deserialize_with = "string_or_one_or_many")]
224        content: OneOrMany<SystemContent>,
225    },
226    User {
227        #[serde(deserialize_with = "string_or_one_or_many")]
228        content: OneOrMany<UserContent>,
229    },
230    Assistant {
231        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
232        content: Vec<AssistantContent>,
233        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
234        tool_calls: Vec<ToolCall>,
235    },
236    #[serde(rename = "Tool")]
237    ToolResult {
238        name: String,
239        #[serde(skip_serializing_if = "Option::is_none")]
240        arguments: Option<serde_json::Value>,
241        #[serde(deserialize_with = "string_or_one_or_many")]
242        content: OneOrMany<String>,
243    },
244}
245
246impl Message {
247    pub fn system(content: &str) -> Self {
248        Message::System {
249            content: OneOrMany::one(SystemContent::Text {
250                text: content.to_string(),
251            }),
252        }
253    }
254}
255
256impl TryFrom<message::Message> for Vec<Message> {
257    type Error = message::MessageError;
258
259    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
260        match message {
261            message::Message::User { content } => {
262                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
263                    .into_iter()
264                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
265
266                if !tool_results.is_empty() {
267                    tool_results
268                        .into_iter()
269                        .map(|content| match content {
270                            message::UserContent::ToolResult(message::ToolResult {
271                                id,
272                                content,
273                                ..
274                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
275                                name: id,
276                                arguments: None,
277                                content: content.try_map(|content| match content {
278                                    message::ToolResultContent::Text(message::Text { text }) => {
279                                        Ok(text)
280                                    }
281                                    _ => Err(message::MessageError::ConversionError(
282                                        "Tool result content does not support non-text".into(),
283                                    )),
284                                })?,
285                            }),
286                            _ => unreachable!(),
287                        })
288                        .collect::<Result<Vec<_>, _>>()
289                } else {
290                    let other_content = OneOrMany::many(other_content).expect(
291                        "There must be other content here if there were no tool result content",
292                    );
293
294                    Ok(vec![Message::User {
295                        content: other_content.try_map(|content| match content {
296                            message::UserContent::Text(text) => {
297                                Ok(UserContent::Text { text: text.text })
298                            }
299                            _ => Err(message::MessageError::ConversionError(
300                                "Huggingface does not support non-text".into(),
301                            )),
302                        })?,
303                    }])
304                }
305            }
306            message::Message::Assistant { content, .. } => {
307                let (text_content, tool_calls) = content.into_iter().fold(
308                    (Vec::new(), Vec::new()),
309                    |(mut texts, mut tools), content| {
310                        match content {
311                            message::AssistantContent::Text(text) => texts.push(text),
312                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
313                            message::AssistantContent::Reasoning(_) => {
314                                unimplemented!("Reasoning is not supported on HuggingFace via Rig");
315                            }
316                        }
317                        (texts, tools)
318                    },
319                );
320
321                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
322                //  so either `content` or `tool_calls` will have some content.
323                Ok(vec![Message::Assistant {
324                    content: text_content
325                        .into_iter()
326                        .map(|content| AssistantContent::Text { text: content.text })
327                        .collect::<Vec<_>>(),
328                    tool_calls: tool_calls
329                        .into_iter()
330                        .map(|tool_call| tool_call.into())
331                        .collect::<Vec<_>>(),
332                }])
333            }
334        }
335    }
336}
337
338impl TryFrom<Message> for message::Message {
339    type Error = message::MessageError;
340
341    fn try_from(message: Message) -> Result<Self, Self::Error> {
342        Ok(match message {
343            Message::User { content, .. } => message::Message::User {
344                content: content.map(|content| content.into()),
345            },
346            Message::Assistant {
347                content,
348                tool_calls,
349                ..
350            } => {
351                let mut content = content
352                    .into_iter()
353                    .map(|content| match content {
354                        AssistantContent::Text { text } => message::AssistantContent::text(text),
355                    })
356                    .collect::<Vec<_>>();
357
358                content.extend(
359                    tool_calls
360                        .into_iter()
361                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
362                        .collect::<Result<Vec<_>, _>>()?,
363                );
364
365                message::Message::Assistant {
366                    id: None,
367                    content: OneOrMany::many(content).map_err(|_| {
368                        message::MessageError::ConversionError(
369                            "Neither `content` nor `tool_calls` was provided to the Message"
370                                .to_owned(),
371                        )
372                    })?,
373                }
374            }
375
376            Message::ToolResult { name, content, .. } => message::Message::User {
377                content: OneOrMany::one(message::UserContent::tool_result(
378                    name,
379                    content.map(message::ToolResultContent::text),
380                )),
381            },
382
383            // System messages should get stripped out when converting message's, this is just a
384            // stop gap to avoid obnoxious error handling or panic occurring.
385            Message::System { content, .. } => message::Message::User {
386                content: content.map(|c| match c {
387                    SystemContent::Text { text } => message::UserContent::text(text),
388                }),
389            },
390        })
391    }
392}
393
394#[derive(Debug, Deserialize, Serialize)]
395pub struct Choice {
396    pub finish_reason: String,
397    pub index: usize,
398    #[serde(default)]
399    pub logprobs: serde_json::Value,
400    pub message: Message,
401}
402
403#[derive(Debug, Deserialize, Clone, Serialize)]
404pub struct Usage {
405    pub completion_tokens: i32,
406    pub prompt_tokens: i32,
407    pub total_tokens: i32,
408}
409
410#[derive(Debug, Deserialize, Serialize)]
411pub struct CompletionResponse {
412    pub created: i32,
413    pub id: String,
414    pub model: String,
415    pub choices: Vec<Choice>,
416    #[serde(default, deserialize_with = "default_string_on_null")]
417    pub system_fingerprint: String,
418    pub usage: Usage,
419}
420
421fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
422where
423    D: Deserializer<'de>,
424{
425    match Option::<String>::deserialize(deserializer)? {
426        Some(value) => Ok(value),      // Use provided value
427        None => Ok(String::default()), // Use `Default` implementation
428    }
429}
430
431impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
432    type Error = CompletionError;
433
434    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
435        let choice = response.choices.first().ok_or_else(|| {
436            CompletionError::ResponseError("Response contained no choices".to_owned())
437        })?;
438
439        let content = match &choice.message {
440            Message::Assistant {
441                content,
442                tool_calls,
443                ..
444            } => {
445                let mut content = content
446                    .iter()
447                    .map(|c| match c {
448                        AssistantContent::Text { text } => message::AssistantContent::text(text),
449                    })
450                    .collect::<Vec<_>>();
451
452                content.extend(
453                    tool_calls
454                        .iter()
455                        .map(|call| {
456                            completion::AssistantContent::tool_call(
457                                &call.id,
458                                &call.function.name,
459                                call.function.arguments.clone(),
460                            )
461                        })
462                        .collect::<Vec<_>>(),
463                );
464                Ok(content)
465            }
466            _ => Err(CompletionError::ResponseError(
467                "Response did not contain a valid message or tool call".into(),
468            )),
469        }?;
470
471        let choice = OneOrMany::many(content).map_err(|_| {
472            CompletionError::ResponseError(
473                "Response contained no message or tool call (empty)".to_owned(),
474            )
475        })?;
476
477        let usage = completion::Usage {
478            input_tokens: response.usage.prompt_tokens as u64,
479            output_tokens: response.usage.completion_tokens as u64,
480            total_tokens: response.usage.total_tokens as u64,
481        };
482
483        Ok(completion::CompletionResponse {
484            choice,
485            usage,
486            raw_response: response,
487        })
488    }
489}
490
491#[derive(Clone)]
492pub struct CompletionModel {
493    pub(crate) client: Client,
494    /// Name of the model (e.g: google/gemma-2-2b-it)
495    pub model: String,
496}
497
498impl CompletionModel {
499    pub fn new(client: Client, model: &str) -> Self {
500        Self {
501            client,
502            model: model.to_string(),
503        }
504    }
505
506    pub(crate) fn create_request_body(
507        &self,
508        completion_request: &CompletionRequest,
509    ) -> Result<serde_json::Value, CompletionError> {
510        let mut full_history: Vec<Message> = match &completion_request.preamble {
511            Some(preamble) => vec![Message::system(preamble)],
512            None => vec![],
513        };
514        if let Some(docs) = completion_request.normalized_documents() {
515            let docs: Vec<Message> = docs.try_into()?;
516            full_history.extend(docs);
517        }
518
519        let chat_history: Vec<Message> = completion_request
520            .chat_history
521            .clone()
522            .into_iter()
523            .map(|message| message.try_into())
524            .collect::<Result<Vec<Vec<Message>>, _>>()?
525            .into_iter()
526            .flatten()
527            .collect();
528
529        full_history.extend(chat_history);
530
531        let model = self.client.sub_provider.model_identifier(&self.model);
532
533        let request = if completion_request.tools.is_empty() {
534            json!({
535                "model": model,
536                "messages": full_history,
537                "temperature": completion_request.temperature,
538            })
539        } else {
540            json!({
541                "model": model,
542                "messages": full_history,
543                "temperature": completion_request.temperature,
544                "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
545                "tool_choice": "auto",
546            })
547        };
548        Ok(request)
549    }
550}
551
552impl completion::CompletionModel for CompletionModel {
553    type Response = CompletionResponse;
554    type StreamingResponse = StreamingCompletionResponse;
555
556    #[cfg_attr(feature = "worker", worker::send)]
557    async fn completion(
558        &self,
559        completion_request: CompletionRequest,
560    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
561        let request = self.create_request_body(&completion_request)?;
562
563        let path = self.client.sub_provider.completion_endpoint(&self.model);
564
565        let request = if let Some(ref params) = completion_request.additional_params {
566            json_utils::merge(request, params.clone())
567        } else {
568            request
569        };
570
571        let response = self.client.post(&path).json(&request).send().await?;
572
573        if response.status().is_success() {
574            let t = response.text().await?;
575            tracing::debug!(target: "rig", "Huggingface completion error: {}", t);
576
577            match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
578                ApiResponse::Ok(response) => {
579                    tracing::info!(target: "rig",
580                        "Huggingface completion token usage: {:?}",
581                        format!("{:?}", response.usage)
582                    );
583                    response.try_into()
584                }
585                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
586            }
587        } else {
588            Err(CompletionError::ProviderError(format!(
589                "{}: {}",
590                response.status(),
591                response.text().await?
592            )))
593        }
594    }
595
596    #[cfg_attr(feature = "worker", worker::send)]
597    async fn stream(
598        &self,
599        request: CompletionRequest,
600    ) -> Result<
601        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
602        CompletionError,
603    > {
604        CompletionModel::stream(self, request).await
605    }
606}
607
608#[cfg(test)]
609mod tests {
610    use super::*;
611    use serde_path_to_error::deserialize;
612
613    #[test]
614    fn test_deserialize_message() {
615        let assistant_message_json = r#"
616        {
617            "role": "assistant",
618            "content": "\n\nHello there, how may I assist you today?"
619        }
620        "#;
621
622        let assistant_message_json2 = r#"
623        {
624            "role": "assistant",
625            "content": [
626                {
627                    "type": "text",
628                    "text": "\n\nHello there, how may I assist you today?"
629                }
630            ],
631            "tool_calls": null
632        }
633        "#;
634
635        let assistant_message_json3 = r#"
636        {
637            "role": "assistant",
638            "tool_calls": [
639                {
640                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
641                    "type": "function",
642                    "function": {
643                        "name": "subtract",
644                        "arguments": {"x": 2, "y": 5}
645                    }
646                }
647            ],
648            "content": null,
649            "refusal": null
650        }
651        "#;
652
653        let user_message_json = r#"
654        {
655            "role": "user",
656            "content": [
657                {
658                    "type": "text",
659                    "text": "What's in this image?"
660                },
661                {
662                    "type": "image_url",
663                    "image_url": {
664                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
665                    }
666                }
667            ]
668        }
669        "#;
670
671        let assistant_message: Message = {
672            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
673            deserialize(jd).unwrap_or_else(|err| {
674                panic!(
675                    "Deserialization error at {} ({}:{}): {}",
676                    err.path(),
677                    err.inner().line(),
678                    err.inner().column(),
679                    err
680                );
681            })
682        };
683
684        let assistant_message2: Message = {
685            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
686            deserialize(jd).unwrap_or_else(|err| {
687                panic!(
688                    "Deserialization error at {} ({}:{}): {}",
689                    err.path(),
690                    err.inner().line(),
691                    err.inner().column(),
692                    err
693                );
694            })
695        };
696
697        let assistant_message3: Message = {
698            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
699                &mut serde_json::Deserializer::from_str(assistant_message_json3);
700            deserialize(jd).unwrap_or_else(|err| {
701                panic!(
702                    "Deserialization error at {} ({}:{}): {}",
703                    err.path(),
704                    err.inner().line(),
705                    err.inner().column(),
706                    err
707                );
708            })
709        };
710
711        let user_message: Message = {
712            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
713            deserialize(jd).unwrap_or_else(|err| {
714                panic!(
715                    "Deserialization error at {} ({}:{}): {}",
716                    err.path(),
717                    err.inner().line(),
718                    err.inner().column(),
719                    err
720                );
721            })
722        };
723
724        match assistant_message {
725            Message::Assistant { content, .. } => {
726                assert_eq!(
727                    content[0],
728                    AssistantContent::Text {
729                        text: "\n\nHello there, how may I assist you today?".to_string()
730                    }
731                );
732            }
733            _ => panic!("Expected assistant message"),
734        }
735
736        match assistant_message2 {
737            Message::Assistant {
738                content,
739                tool_calls,
740                ..
741            } => {
742                assert_eq!(
743                    content[0],
744                    AssistantContent::Text {
745                        text: "\n\nHello there, how may I assist you today?".to_string()
746                    }
747                );
748
749                assert_eq!(tool_calls, vec![]);
750            }
751            _ => panic!("Expected assistant message"),
752        }
753
754        match assistant_message3 {
755            Message::Assistant {
756                content,
757                tool_calls,
758                ..
759            } => {
760                assert!(content.is_empty());
761                assert_eq!(
762                    tool_calls[0],
763                    ToolCall {
764                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
765                        r#type: ToolType::Function,
766                        function: Function {
767                            name: "subtract".to_string(),
768                            arguments: serde_json::json!({"x": 2, "y": 5}),
769                        },
770                    }
771                );
772            }
773            _ => panic!("Expected assistant message"),
774        }
775
776        match user_message {
777            Message::User { content, .. } => {
778                let (first, second) = {
779                    let mut iter = content.into_iter();
780                    (iter.next().unwrap(), iter.next().unwrap())
781                };
782                assert_eq!(
783                    first,
784                    UserContent::Text {
785                        text: "What's in this image?".to_string()
786                    }
787                );
788                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
789            }
790            _ => panic!("Expected user message"),
791        }
792    }
793
794    #[test]
795    fn test_message_to_message_conversion() {
796        let user_message = message::Message::User {
797            content: OneOrMany::one(message::UserContent::text("Hello")),
798        };
799
800        let assistant_message = message::Message::Assistant {
801            id: None,
802            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
803        };
804
805        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
806        let converted_assistant_message: Vec<Message> =
807            assistant_message.clone().try_into().unwrap();
808
809        match converted_user_message[0].clone() {
810            Message::User { content, .. } => {
811                assert_eq!(
812                    content.first(),
813                    UserContent::Text {
814                        text: "Hello".to_string()
815                    }
816                );
817            }
818            _ => panic!("Expected user message"),
819        }
820
821        match converted_assistant_message[0].clone() {
822            Message::Assistant { content, .. } => {
823                assert_eq!(
824                    content[0],
825                    AssistantContent::Text {
826                        text: "Hi there!".to_string()
827                    }
828                );
829            }
830            _ => panic!("Expected assistant message"),
831        }
832
833        let original_user_message: message::Message =
834            converted_user_message[0].clone().try_into().unwrap();
835        let original_assistant_message: message::Message =
836            converted_assistant_message[0].clone().try_into().unwrap();
837
838        assert_eq!(original_user_message, user_message);
839        assert_eq!(original_assistant_message, assistant_message);
840    }
841
842    #[test]
843    fn test_message_from_message_conversion() {
844        let user_message = Message::User {
845            content: OneOrMany::one(UserContent::Text {
846                text: "Hello".to_string(),
847            }),
848        };
849
850        let assistant_message = Message::Assistant {
851            content: vec![AssistantContent::Text {
852                text: "Hi there!".to_string(),
853            }],
854            tool_calls: vec![],
855        };
856
857        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
858        let converted_assistant_message: message::Message =
859            assistant_message.clone().try_into().unwrap();
860
861        match converted_user_message.clone() {
862            message::Message::User { content } => {
863                assert_eq!(content.first(), message::UserContent::text("Hello"));
864            }
865            _ => panic!("Expected user message"),
866        }
867
868        match converted_assistant_message.clone() {
869            message::Message::Assistant { content, .. } => {
870                assert_eq!(
871                    content.first(),
872                    message::AssistantContent::text("Hi there!")
873                );
874            }
875            _ => panic!("Expected assistant message"),
876        }
877
878        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
879        let original_assistant_message: Vec<Message> =
880            converted_assistant_message.try_into().unwrap();
881
882        assert_eq!(original_user_message[0], user_message);
883        assert_eq!(original_assistant_message[0], assistant_message);
884    }
885
886    #[test]
887    fn test_responses() {
888        let fireworks_response_json = r#"
889        {
890            "choices": [
891                {
892                    "finish_reason": "tool_calls",
893                    "index": 0,
894                    "message": {
895                        "role": "assistant",
896                        "tool_calls": [
897                            {
898                                "function": {
899                                "arguments": "{\"x\": 2, \"y\": 5}",
900                                "name": "subtract"
901                                },
902                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
903                                "index": 0,
904                                "type": "function"
905                            }
906                        ]
907                    }
908                }
909            ],
910            "created": 1740704000,
911            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
912            "model": "accounts/fireworks/models/deepseek-v3",
913            "object": "chat.completion",
914            "usage": {
915                "completion_tokens": 26,
916                "prompt_tokens": 248,
917                "total_tokens": 274
918            }
919        }
920        "#;
921
922        let novita_response_json = r#"
923        {
924            "choices": [
925                {
926                    "finish_reason": "tool_calls",
927                    "index": 0,
928                    "logprobs": null,
929                    "message": {
930                        "audio": null,
931                        "content": null,
932                        "function_call": null,
933                        "reasoning_content": null,
934                        "refusal": null,
935                        "role": "assistant",
936                        "tool_calls": [
937                            {
938                                "function": {
939                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
940                                    "name": "subtract"
941                                },
942                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
943                                "type": "function"
944                            }
945                        ]
946                    },
947                    "stop_reason": 128008
948                }
949            ],
950            "created": 1740704592,
951            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
952            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
953            "object": "chat.completion",
954            "prompt_logprobs": null,
955            "service_tier": null,
956            "system_fingerprint": null,
957            "usage": {
958                "completion_tokens": 28,
959                "completion_tokens_details": null,
960                "prompt_tokens": 335,
961                "prompt_tokens_details": null,
962                "total_tokens": 363
963            }
964        }
965        "#;
966
967        let _firework_response: CompletionResponse = {
968            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
969            deserialize(jd).unwrap_or_else(|err| {
970                panic!(
971                    "Deserialization error at {} ({}:{}): {}",
972                    err.path(),
973                    err.inner().line(),
974                    err.inner().column(),
975                    err
976                );
977            })
978        };
979
980        let _novita_response: CompletionResponse = {
981            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
982            deserialize(jd).unwrap_or_else(|err| {
983                panic!(
984                    "Deserialization error at {} ({}:{}): {}",
985                    err.path(),
986                    err.inner().line(),
987                    err.inner().column(),
988                    err
989                );
990            })
991        };
992    }
993}