rig/providers/huggingface/completion.rs

1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7    OneOrMany,
8    completion::{self, CompletionError, CompletionRequest},
9    json_utils,
10    message::{self},
11    one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::{Value, json};
15use std::{convert::Infallible, str::FromStr};
16use tracing::info_span;
17
18#[derive(Debug, Deserialize)]
19#[serde(untagged)]
20pub enum ApiResponse<T> {
21    Ok(T),
22    Err(Value),
23}
24
// ================================================================
// Huggingface Completion API
// ================================================================

// Conversational LLMs.
// Convenience model ids (Hugging Face Hub `owner/name` form) that can be passed to
// `CompletionModel::new`.

/// `google/gemma-2-2b-it` completion model
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// `meta-llama/Meta-Llama-3.1-8B-Instruct` completion model
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `microsoft/phi-4` completion model
pub const PHI_4: &str = "microsoft/phi-4";
/// `PowerInfer/SmallThinker-3B-Preview` completion model
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// `Qwen/Qwen2.5-7B-Instruct` completion model
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` completion model
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs (accept image inputs alongside text).

/// `Qwen/Qwen2-VL-7B-Instruct` visual-language completion model
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// `Qwen/QVQ-72B-Preview` visual-language completion model
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
50
51#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
52pub struct Function {
53    name: String,
54    #[serde(
55        serialize_with = "json_utils::stringified_json::serialize",
56        deserialize_with = "deserialize_arguments"
57    )]
58    pub arguments: serde_json::Value,
59}
60
61fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
62where
63    D: Deserializer<'de>,
64{
65    let value = Value::deserialize(deserializer)?;
66
67    match value {
68        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
69        other => Ok(other),
70    }
71}
72
73impl From<Function> for message::ToolFunction {
74    fn from(value: Function) -> Self {
75        message::ToolFunction {
76            name: value.name,
77            arguments: value.arguments,
78        }
79    }
80}
81
82#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
83#[serde(rename_all = "lowercase")]
84pub enum ToolType {
85    #[default]
86    Function,
87}
88
89#[derive(Debug, Deserialize, Serialize, Clone)]
90pub struct ToolDefinition {
91    pub r#type: String,
92    pub function: completion::ToolDefinition,
93}
94
95impl From<completion::ToolDefinition> for ToolDefinition {
96    fn from(tool: completion::ToolDefinition) -> Self {
97        Self {
98            r#type: "function".into(),
99            function: tool,
100        }
101    }
102}
103
104#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
105pub struct ToolCall {
106    pub id: String,
107    pub r#type: ToolType,
108    pub function: Function,
109}
110
111impl From<ToolCall> for message::ToolCall {
112    fn from(value: ToolCall) -> Self {
113        message::ToolCall {
114            id: value.id,
115            call_id: None,
116            function: value.function.into(),
117        }
118    }
119}
120
121impl From<message::ToolCall> for ToolCall {
122    fn from(value: message::ToolCall) -> Self {
123        ToolCall {
124            id: value.id,
125            r#type: ToolType::Function,
126            function: Function {
127                name: value.function.name,
128                arguments: value.function.arguments,
129            },
130        }
131    }
132}
133
134#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
135pub struct ImageUrl {
136    url: String,
137}
138
139#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
140#[serde(tag = "type", rename_all = "lowercase")]
141pub enum UserContent {
142    Text {
143        text: String,
144    },
145    #[serde(rename = "image_url")]
146    ImageUrl {
147        image_url: ImageUrl,
148    },
149}
150
151impl FromStr for UserContent {
152    type Err = Infallible;
153
154    fn from_str(s: &str) -> Result<Self, Self::Err> {
155        Ok(UserContent::Text {
156            text: s.to_string(),
157        })
158    }
159}
160
161#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
162#[serde(tag = "type", rename_all = "lowercase")]
163pub enum AssistantContent {
164    Text { text: String },
165}
166
167impl FromStr for AssistantContent {
168    type Err = Infallible;
169
170    fn from_str(s: &str) -> Result<Self, Self::Err> {
171        Ok(AssistantContent::Text {
172            text: s.to_string(),
173        })
174    }
175}
176
177#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
178#[serde(tag = "type", rename_all = "lowercase")]
179pub enum SystemContent {
180    Text { text: String },
181}
182
183impl FromStr for SystemContent {
184    type Err = Infallible;
185
186    fn from_str(s: &str) -> Result<Self, Self::Err> {
187        Ok(SystemContent::Text {
188            text: s.to_string(),
189        })
190    }
191}
192
193impl From<UserContent> for message::UserContent {
194    fn from(value: UserContent) -> Self {
195        match value {
196            UserContent::Text { text } => message::UserContent::text(text),
197            UserContent::ImageUrl { image_url } => {
198                message::UserContent::image_url(image_url.url, None, None)
199            }
200        }
201    }
202}
203
204impl TryFrom<message::UserContent> for UserContent {
205    type Error = message::MessageError;
206
207    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
208        match content {
209            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
210            message::UserContent::Document(message::Document {
211                data: message::DocumentSourceKind::Raw(raw),
212                ..
213            }) => {
214                let text = String::from_utf8_lossy(raw.as_slice()).into();
215                Ok(UserContent::Text { text })
216            }
217            message::UserContent::Document(message::Document {
218                data:
219                    message::DocumentSourceKind::Base64(text)
220                    | message::DocumentSourceKind::String(text),
221                ..
222            }) => Ok(UserContent::Text { text }),
223            message::UserContent::Image(message::Image { data, .. }) => match data {
224                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
225                    image_url: ImageUrl { url },
226                }),
227                _ => Err(message::MessageError::ConversionError(
228                    "Huggingface only supports images as urls".into(),
229                )),
230            },
231            _ => Err(message::MessageError::ConversionError(
232                "Huggingface only supports text and images".into(),
233            )),
234        }
235    }
236}
237
238#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
239#[serde(tag = "role", rename_all = "lowercase")]
240pub enum Message {
241    System {
242        #[serde(deserialize_with = "string_or_one_or_many")]
243        content: OneOrMany<SystemContent>,
244    },
245    User {
246        #[serde(deserialize_with = "string_or_one_or_many")]
247        content: OneOrMany<UserContent>,
248    },
249    Assistant {
250        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
251        content: Vec<AssistantContent>,
252        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
253        tool_calls: Vec<ToolCall>,
254    },
255    #[serde(rename = "tool", alias = "Tool")]
256    ToolResult {
257        name: String,
258        #[serde(skip_serializing_if = "Option::is_none")]
259        arguments: Option<serde_json::Value>,
260        #[serde(
261            deserialize_with = "string_or_one_or_many",
262            serialize_with = "serialize_tool_content"
263        )]
264        content: OneOrMany<String>,
265    },
266}
267
268fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
269where
270    S: Serializer,
271{
272    // OpenAI-compatible APIs expect tool content as a string, not an array
273    let joined = content
274        .iter()
275        .map(String::as_str)
276        .collect::<Vec<_>>()
277        .join("\n");
278    serializer.serialize_str(&joined)
279}
280
281impl Message {
282    pub fn system(content: &str) -> Self {
283        Message::System {
284            content: OneOrMany::one(SystemContent::Text {
285                text: content.to_string(),
286            }),
287        }
288    }
289}
290
291impl TryFrom<message::Message> for Vec<Message> {
292    type Error = message::MessageError;
293
294    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
295        match message {
296            message::Message::User { content } => {
297                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
298                    .into_iter()
299                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
300
301                if !tool_results.is_empty() {
302                    tool_results
303                        .into_iter()
304                        .map(|content| match content {
305                            message::UserContent::ToolResult(message::ToolResult {
306                                id,
307                                content,
308                                ..
309                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
310                                name: id,
311                                arguments: None,
312                                content: content.try_map(|content| match content {
313                                    message::ToolResultContent::Text(message::Text { text }) => {
314                                        Ok(text)
315                                    }
316                                    _ => Err(message::MessageError::ConversionError(
317                                        "Tool result content does not support non-text".into(),
318                                    )),
319                                })?,
320                            }),
321                            _ => unreachable!(),
322                        })
323                        .collect::<Result<Vec<_>, _>>()
324                } else {
325                    let other_content = OneOrMany::many(other_content).expect(
326                        "There must be other content here if there were no tool result content",
327                    );
328
329                    Ok(vec![Message::User {
330                        content: other_content.try_map(|content| match content {
331                            message::UserContent::Text(text) => {
332                                Ok(UserContent::Text { text: text.text })
333                            }
334                            message::UserContent::Image(image) => {
335                                let url = image.try_into_url()?;
336
337                                Ok(UserContent::ImageUrl {
338                                    image_url: ImageUrl { url },
339                                })
340                            }
341                            message::UserContent::Document(message::Document {
342                                data: message::DocumentSourceKind::Raw(raw), ..
343                            }) => {
344                                let text = String::from_utf8_lossy(raw.as_slice()).into();
345                                Ok(UserContent::Text { text })
346                            }
347                            message::UserContent::Document(message::Document {
348                                data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
349                            }) => {
350                                Ok(UserContent::Text { text })
351                            }
352                            _ => Err(message::MessageError::ConversionError(
353                                "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
354                            )),
355                        })?,
356                    }])
357                }
358            }
359            message::Message::Assistant { content, .. } => {
360                let (text_content, tool_calls) = content.into_iter().fold(
361                    (Vec::new(), Vec::new()),
362                    |(mut texts, mut tools), content| {
363                        match content {
364                            message::AssistantContent::Text(text) => texts.push(text),
365                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
366                            message::AssistantContent::Reasoning(_) => {
367                                unimplemented!("Reasoning is not supported on HuggingFace via Rig");
368                            }
369                        }
370                        (texts, tools)
371                    },
372                );
373
374                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
375                //  so either `content` or `tool_calls` will have some content.
376                Ok(vec![Message::Assistant {
377                    content: text_content
378                        .into_iter()
379                        .map(|content| AssistantContent::Text { text: content.text })
380                        .collect::<Vec<_>>(),
381                    tool_calls: tool_calls
382                        .into_iter()
383                        .map(|tool_call| tool_call.into())
384                        .collect::<Vec<_>>(),
385                }])
386            }
387        }
388    }
389}
390
391impl TryFrom<Message> for message::Message {
392    type Error = message::MessageError;
393
394    fn try_from(message: Message) -> Result<Self, Self::Error> {
395        Ok(match message {
396            Message::User { content, .. } => message::Message::User {
397                content: content.map(|content| content.into()),
398            },
399            Message::Assistant {
400                content,
401                tool_calls,
402                ..
403            } => {
404                let mut content = content
405                    .into_iter()
406                    .map(|content| match content {
407                        AssistantContent::Text { text } => message::AssistantContent::text(text),
408                    })
409                    .collect::<Vec<_>>();
410
411                content.extend(
412                    tool_calls
413                        .into_iter()
414                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
415                        .collect::<Result<Vec<_>, _>>()?,
416                );
417
418                message::Message::Assistant {
419                    id: None,
420                    content: OneOrMany::many(content).map_err(|_| {
421                        message::MessageError::ConversionError(
422                            "Neither `content` nor `tool_calls` was provided to the Message"
423                                .to_owned(),
424                        )
425                    })?,
426                }
427            }
428
429            Message::ToolResult { name, content, .. } => message::Message::User {
430                content: OneOrMany::one(message::UserContent::tool_result(
431                    name,
432                    content.map(message::ToolResultContent::text),
433                )),
434            },
435
436            // System messages should get stripped out when converting message's, this is just a
437            // stop gap to avoid obnoxious error handling or panic occurring.
438            Message::System { content, .. } => message::Message::User {
439                content: content.map(|c| match c {
440                    SystemContent::Text { text } => message::UserContent::text(text),
441                }),
442            },
443        })
444    }
445}
446
447#[derive(Clone, Debug, Deserialize, Serialize)]
448pub struct Choice {
449    pub finish_reason: String,
450    pub index: usize,
451    #[serde(default)]
452    pub logprobs: serde_json::Value,
453    pub message: Message,
454}
455
456#[derive(Debug, Deserialize, Clone, Serialize)]
457pub struct Usage {
458    pub completion_tokens: i32,
459    pub prompt_tokens: i32,
460    pub total_tokens: i32,
461}
462
463impl GetTokenUsage for Usage {
464    fn token_usage(&self) -> Option<crate::completion::Usage> {
465        let mut usage = crate::completion::Usage::new();
466        usage.input_tokens = self.prompt_tokens as u64;
467        usage.output_tokens = self.completion_tokens as u64;
468        usage.total_tokens = self.total_tokens as u64;
469
470        Some(usage)
471    }
472}
473
474#[derive(Clone, Debug, Deserialize, Serialize)]
475pub struct CompletionResponse {
476    pub created: i32,
477    pub id: String,
478    pub model: String,
479    pub choices: Vec<Choice>,
480    #[serde(default, deserialize_with = "default_string_on_null")]
481    pub system_fingerprint: String,
482    pub usage: Usage,
483}
484
485impl crate::telemetry::ProviderResponseExt for CompletionResponse {
486    type OutputMessage = Choice;
487    type Usage = Usage;
488
489    fn get_response_id(&self) -> Option<String> {
490        Some(self.id.clone())
491    }
492
493    fn get_response_model_name(&self) -> Option<String> {
494        Some(self.model.clone())
495    }
496
497    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
498        self.choices.clone()
499    }
500
501    fn get_text_response(&self) -> Option<String> {
502        let text_response = self
503            .choices
504            .iter()
505            .filter_map(|x| {
506                let Message::User { ref content } = x.message else {
507                    return None;
508                };
509
510                let text = content
511                    .iter()
512                    .filter_map(|x| {
513                        if let UserContent::Text { text } = x {
514                            Some(text.clone())
515                        } else {
516                            None
517                        }
518                    })
519                    .collect::<Vec<String>>();
520
521                if text.is_empty() {
522                    None
523                } else {
524                    Some(text.join("\n"))
525                }
526            })
527            .collect::<Vec<String>>()
528            .join("\n");
529
530        if text_response.is_empty() {
531            None
532        } else {
533            Some(text_response)
534        }
535    }
536
537    fn get_usage(&self) -> Option<Self::Usage> {
538        Some(self.usage.clone())
539    }
540}
541
542fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
543where
544    D: Deserializer<'de>,
545{
546    match Option::<String>::deserialize(deserializer)? {
547        Some(value) => Ok(value),      // Use provided value
548        None => Ok(String::default()), // Use `Default` implementation
549    }
550}
551
552impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
553    type Error = CompletionError;
554
555    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
556        let choice = response.choices.first().ok_or_else(|| {
557            CompletionError::ResponseError("Response contained no choices".to_owned())
558        })?;
559
560        let content = match &choice.message {
561            Message::Assistant {
562                content,
563                tool_calls,
564                ..
565            } => {
566                let mut content = content
567                    .iter()
568                    .map(|c| match c {
569                        AssistantContent::Text { text } => message::AssistantContent::text(text),
570                    })
571                    .collect::<Vec<_>>();
572
573                content.extend(
574                    tool_calls
575                        .iter()
576                        .map(|call| {
577                            completion::AssistantContent::tool_call(
578                                &call.id,
579                                &call.function.name,
580                                call.function.arguments.clone(),
581                            )
582                        })
583                        .collect::<Vec<_>>(),
584                );
585                Ok(content)
586            }
587            _ => Err(CompletionError::ResponseError(
588                "Response did not contain a valid message or tool call".into(),
589            )),
590        }?;
591
592        let choice = OneOrMany::many(content).map_err(|_| {
593            CompletionError::ResponseError(
594                "Response contained no message or tool call (empty)".to_owned(),
595            )
596        })?;
597
598        let usage = completion::Usage {
599            input_tokens: response.usage.prompt_tokens as u64,
600            output_tokens: response.usage.completion_tokens as u64,
601            total_tokens: response.usage.total_tokens as u64,
602        };
603
604        Ok(completion::CompletionResponse {
605            choice,
606            usage,
607            raw_response: response,
608        })
609    }
610}
611
612#[derive(Clone)]
613pub struct CompletionModel<T = reqwest::Client> {
614    pub(crate) client: Client<T>,
615    /// Name of the model (e.g: google/gemma-2-2b-it)
616    pub model: String,
617}
618
619impl<T> CompletionModel<T> {
620    pub fn new(client: Client<T>, model: &str) -> Self {
621        Self {
622            client,
623            model: model.to_string(),
624        }
625    }
626
627    pub(crate) fn create_request_body(
628        &self,
629        completion_request: &CompletionRequest,
630    ) -> Result<serde_json::Value, CompletionError> {
631        let mut full_history: Vec<Message> = match &completion_request.preamble {
632            Some(preamble) => vec![Message::system(preamble)],
633            None => vec![],
634        };
635        if let Some(docs) = completion_request.normalized_documents() {
636            let docs: Vec<Message> = docs.try_into()?;
637            full_history.extend(docs);
638        }
639
640        let chat_history: Vec<Message> = completion_request
641            .chat_history
642            .clone()
643            .into_iter()
644            .map(|message| message.try_into())
645            .collect::<Result<Vec<Vec<Message>>, _>>()?
646            .into_iter()
647            .flatten()
648            .collect();
649
650        full_history.extend(chat_history);
651
652        let model = self.client.sub_provider.model_identifier(&self.model);
653
654        let tool_choice = completion_request
655            .tool_choice
656            .clone()
657            .map(crate::providers::openai::completion::ToolChoice::try_from)
658            .transpose()?;
659
660        let request = if completion_request.tools.is_empty() {
661            json!({
662                "model": model,
663                "messages": full_history,
664                "temperature": completion_request.temperature,
665            })
666        } else {
667            json!({
668                "model": model,
669                "messages": full_history,
670                "temperature": completion_request.temperature,
671                "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
672                "tool_choice": tool_choice,
673            })
674        };
675        Ok(request)
676    }
677}
678
679impl<T> completion::CompletionModel for CompletionModel<T>
680where
681    T: HttpClientExt + Clone + 'static,
682{
683    type Response = CompletionResponse;
684    type StreamingResponse = StreamingCompletionResponse;
685
686    #[cfg_attr(feature = "worker", worker::send)]
687    async fn completion(
688        &self,
689        completion_request: CompletionRequest,
690    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
691        let span = if tracing::Span::current().is_disabled() {
692            info_span!(
693                target: "rig::completions",
694                "chat",
695                gen_ai.operation.name = "chat",
696                gen_ai.provider.name = "huggingface",
697                gen_ai.request.model = self.model,
698                gen_ai.system_instructions = &completion_request.preamble,
699                gen_ai.response.id = tracing::field::Empty,
700                gen_ai.response.model = tracing::field::Empty,
701                gen_ai.usage.output_tokens = tracing::field::Empty,
702                gen_ai.usage.input_tokens = tracing::field::Empty,
703                gen_ai.input.messages = tracing::field::Empty,
704                gen_ai.output.messages = tracing::field::Empty,
705            )
706        } else {
707            tracing::Span::current()
708        };
709        let request = self.create_request_body(&completion_request)?;
710        span.record_model_input(&request.get("messages"));
711
712        let path = self.client.sub_provider.completion_endpoint(&self.model);
713
714        let request = if let Some(ref params) = completion_request.additional_params {
715            json_utils::merge(request, params.clone())
716        } else {
717            request
718        };
719
720        let request = serde_json::to_vec(&request)?;
721
722        let request = self
723            .client
724            .post(&path)?
725            .header("Content-Type", "application/json")
726            .body(request)
727            .map_err(|e| CompletionError::HttpError(e.into()))?;
728
729        let response = self.client.send(request).await?;
730
731        if response.status().is_success() {
732            let bytes: Vec<u8> = response.into_body().await?;
733            let text = String::from_utf8_lossy(&bytes);
734
735            tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
736
737            match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
738                ApiResponse::Ok(response) => {
739                    let span = tracing::Span::current();
740                    span.record_token_usage(&response.usage);
741                    span.record_model_output(&response.choices);
742                    span.record_response_metadata(&response);
743
744                    response.try_into()
745                }
746                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
747            }
748        } else {
749            let status = response.status();
750            let text: Vec<u8> = response.into_body().await?;
751            let text: String = String::from_utf8_lossy(&text).into();
752
753            Err(CompletionError::ProviderError(format!(
754                "{}: {}",
755                status, text
756            )))
757        }
758    }
759
760    #[cfg_attr(feature = "worker", worker::send)]
761    async fn stream(
762        &self,
763        request: CompletionRequest,
764    ) -> Result<
765        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
766        CompletionError,
767    > {
768        CompletionModel::stream(self, request).await
769    }
770}
771
772#[cfg(test)]
773mod tests {
774    use super::*;
775    use serde_path_to_error::deserialize;
776
777    #[test]
778    fn test_deserialize_message() {
779        let assistant_message_json = r#"
780        {
781            "role": "assistant",
782            "content": "\n\nHello there, how may I assist you today?"
783        }
784        "#;
785
786        let assistant_message_json2 = r#"
787        {
788            "role": "assistant",
789            "content": [
790                {
791                    "type": "text",
792                    "text": "\n\nHello there, how may I assist you today?"
793                }
794            ],
795            "tool_calls": null
796        }
797        "#;
798
799        let assistant_message_json3 = r#"
800        {
801            "role": "assistant",
802            "tool_calls": [
803                {
804                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
805                    "type": "function",
806                    "function": {
807                        "name": "subtract",
808                        "arguments": {"x": 2, "y": 5}
809                    }
810                }
811            ],
812            "content": null,
813            "refusal": null
814        }
815        "#;
816
817        let user_message_json = r#"
818        {
819            "role": "user",
820            "content": [
821                {
822                    "type": "text",
823                    "text": "What's in this image?"
824                },
825                {
826                    "type": "image_url",
827                    "image_url": {
828                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
829                    }
830                }
831            ]
832        }
833        "#;
834
835        let assistant_message: Message = {
836            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
837            deserialize(jd).unwrap_or_else(|err| {
838                panic!(
839                    "Deserialization error at {} ({}:{}): {}",
840                    err.path(),
841                    err.inner().line(),
842                    err.inner().column(),
843                    err
844                );
845            })
846        };
847
848        let assistant_message2: Message = {
849            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
850            deserialize(jd).unwrap_or_else(|err| {
851                panic!(
852                    "Deserialization error at {} ({}:{}): {}",
853                    err.path(),
854                    err.inner().line(),
855                    err.inner().column(),
856                    err
857                );
858            })
859        };
860
861        let assistant_message3: Message = {
862            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
863                &mut serde_json::Deserializer::from_str(assistant_message_json3);
864            deserialize(jd).unwrap_or_else(|err| {
865                panic!(
866                    "Deserialization error at {} ({}:{}): {}",
867                    err.path(),
868                    err.inner().line(),
869                    err.inner().column(),
870                    err
871                );
872            })
873        };
874
875        let user_message: Message = {
876            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
877            deserialize(jd).unwrap_or_else(|err| {
878                panic!(
879                    "Deserialization error at {} ({}:{}): {}",
880                    err.path(),
881                    err.inner().line(),
882                    err.inner().column(),
883                    err
884                );
885            })
886        };
887
888        match assistant_message {
889            Message::Assistant { content, .. } => {
890                assert_eq!(
891                    content[0],
892                    AssistantContent::Text {
893                        text: "\n\nHello there, how may I assist you today?".to_string()
894                    }
895                );
896            }
897            _ => panic!("Expected assistant message"),
898        }
899
900        match assistant_message2 {
901            Message::Assistant {
902                content,
903                tool_calls,
904                ..
905            } => {
906                assert_eq!(
907                    content[0],
908                    AssistantContent::Text {
909                        text: "\n\nHello there, how may I assist you today?".to_string()
910                    }
911                );
912
913                assert_eq!(tool_calls, vec![]);
914            }
915            _ => panic!("Expected assistant message"),
916        }
917
918        match assistant_message3 {
919            Message::Assistant {
920                content,
921                tool_calls,
922                ..
923            } => {
924                assert!(content.is_empty());
925                assert_eq!(
926                    tool_calls[0],
927                    ToolCall {
928                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
929                        r#type: ToolType::Function,
930                        function: Function {
931                            name: "subtract".to_string(),
932                            arguments: serde_json::json!({"x": 2, "y": 5}),
933                        },
934                    }
935                );
936            }
937            _ => panic!("Expected assistant message"),
938        }
939
940        match user_message {
941            Message::User { content, .. } => {
942                let (first, second) = {
943                    let mut iter = content.into_iter();
944                    (iter.next().unwrap(), iter.next().unwrap())
945                };
946                assert_eq!(
947                    first,
948                    UserContent::Text {
949                        text: "What's in this image?".to_string()
950                    }
951                );
952                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
953            }
954            _ => panic!("Expected user message"),
955        }
956    }
957
958    #[test]
959    fn test_message_to_message_conversion() {
960        let user_message = message::Message::User {
961            content: OneOrMany::one(message::UserContent::text("Hello")),
962        };
963
964        let assistant_message = message::Message::Assistant {
965            id: None,
966            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
967        };
968
969        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
970        let converted_assistant_message: Vec<Message> =
971            assistant_message.clone().try_into().unwrap();
972
973        match converted_user_message[0].clone() {
974            Message::User { content, .. } => {
975                assert_eq!(
976                    content.first(),
977                    UserContent::Text {
978                        text: "Hello".to_string()
979                    }
980                );
981            }
982            _ => panic!("Expected user message"),
983        }
984
985        match converted_assistant_message[0].clone() {
986            Message::Assistant { content, .. } => {
987                assert_eq!(
988                    content[0],
989                    AssistantContent::Text {
990                        text: "Hi there!".to_string()
991                    }
992                );
993            }
994            _ => panic!("Expected assistant message"),
995        }
996
997        let original_user_message: message::Message =
998            converted_user_message[0].clone().try_into().unwrap();
999        let original_assistant_message: message::Message =
1000            converted_assistant_message[0].clone().try_into().unwrap();
1001
1002        assert_eq!(original_user_message, user_message);
1003        assert_eq!(original_assistant_message, assistant_message);
1004    }
1005
1006    #[test]
1007    fn test_message_from_message_conversion() {
1008        let user_message = Message::User {
1009            content: OneOrMany::one(UserContent::Text {
1010                text: "Hello".to_string(),
1011            }),
1012        };
1013
1014        let assistant_message = Message::Assistant {
1015            content: vec![AssistantContent::Text {
1016                text: "Hi there!".to_string(),
1017            }],
1018            tool_calls: vec![],
1019        };
1020
1021        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1022        let converted_assistant_message: message::Message =
1023            assistant_message.clone().try_into().unwrap();
1024
1025        match converted_user_message.clone() {
1026            message::Message::User { content } => {
1027                assert_eq!(content.first(), message::UserContent::text("Hello"));
1028            }
1029            _ => panic!("Expected user message"),
1030        }
1031
1032        match converted_assistant_message.clone() {
1033            message::Message::Assistant { content, .. } => {
1034                assert_eq!(
1035                    content.first(),
1036                    message::AssistantContent::text("Hi there!")
1037                );
1038            }
1039            _ => panic!("Expected assistant message"),
1040        }
1041
1042        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1043        let original_assistant_message: Vec<Message> =
1044            converted_assistant_message.try_into().unwrap();
1045
1046        assert_eq!(original_user_message[0], user_message);
1047        assert_eq!(original_assistant_message[0], assistant_message);
1048    }
1049
1050    #[test]
1051    fn test_responses() {
1052        let fireworks_response_json = r#"
1053        {
1054            "choices": [
1055                {
1056                    "finish_reason": "tool_calls",
1057                    "index": 0,
1058                    "message": {
1059                        "role": "assistant",
1060                        "tool_calls": [
1061                            {
1062                                "function": {
1063                                "arguments": "{\"x\": 2, \"y\": 5}",
1064                                "name": "subtract"
1065                                },
1066                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
1067                                "index": 0,
1068                                "type": "function"
1069                            }
1070                        ]
1071                    }
1072                }
1073            ],
1074            "created": 1740704000,
1075            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
1076            "model": "accounts/fireworks/models/deepseek-v3",
1077            "object": "chat.completion",
1078            "usage": {
1079                "completion_tokens": 26,
1080                "prompt_tokens": 248,
1081                "total_tokens": 274
1082            }
1083        }
1084        "#;
1085
1086        let novita_response_json = r#"
1087        {
1088            "choices": [
1089                {
1090                    "finish_reason": "tool_calls",
1091                    "index": 0,
1092                    "logprobs": null,
1093                    "message": {
1094                        "audio": null,
1095                        "content": null,
1096                        "function_call": null,
1097                        "reasoning_content": null,
1098                        "refusal": null,
1099                        "role": "assistant",
1100                        "tool_calls": [
1101                            {
1102                                "function": {
1103                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
1104                                    "name": "subtract"
1105                                },
1106                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
1107                                "type": "function"
1108                            }
1109                        ]
1110                    },
1111                    "stop_reason": 128008
1112                }
1113            ],
1114            "created": 1740704592,
1115            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
1116            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
1117            "object": "chat.completion",
1118            "prompt_logprobs": null,
1119            "service_tier": null,
1120            "system_fingerprint": null,
1121            "usage": {
1122                "completion_tokens": 28,
1123                "completion_tokens_details": null,
1124                "prompt_tokens": 335,
1125                "prompt_tokens_details": null,
1126                "total_tokens": 363
1127            }
1128        }
1129        "#;
1130
1131        let _firework_response: CompletionResponse = {
1132            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
1133            deserialize(jd).unwrap_or_else(|err| {
1134                panic!(
1135                    "Deserialization error at {} ({}:{}): {}",
1136                    err.path(),
1137                    err.inner().line(),
1138                    err.inner().column(),
1139                    err
1140                );
1141            })
1142        };
1143
1144        let _novita_response: CompletionResponse = {
1145            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
1146            deserialize(jd).unwrap_or_else(|err| {
1147                panic!(
1148                    "Deserialization error at {} ({}:{}): {}",
1149                    err.path(),
1150                    err.inner().line(),
1151                    err.inner().column(),
1152                    err
1153                );
1154            })
1155        };
1156    }
1157}