rig/providers/huggingface/
completion.rs

1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7    OneOrMany,
8    completion::{self, CompletionError, CompletionRequest},
9    json_utils,
10    message::{self},
11    one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::info_span;
17
18#[derive(Debug, Deserialize)]
19#[serde(untagged)]
20pub enum ApiResponse<T> {
21    Ok(T),
22    Err(Value),
23}
24
25// ================================================================
26// Huggingface Completion API
27// ================================================================
28
// Conversational LLMs

/// Model identifier for `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model identifier for `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model identifier for `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model identifier for `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model identifier for `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Conversational VLMs (visual-language models)

/// Model identifier for `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model identifier for `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
47
48#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
49pub struct Function {
50    name: String,
51    #[serde(
52        serialize_with = "json_utils::stringified_json::serialize",
53        deserialize_with = "deserialize_arguments"
54    )]
55    pub arguments: serde_json::Value,
56}
57
58fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
59where
60    D: Deserializer<'de>,
61{
62    let value = Value::deserialize(deserializer)?;
63
64    match value {
65        Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
66        other => Ok(other),
67    }
68}
69
70impl From<Function> for message::ToolFunction {
71    fn from(value: Function) -> Self {
72        message::ToolFunction {
73            name: value.name,
74            arguments: value.arguments,
75        }
76    }
77}
78
79#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
80#[serde(rename_all = "lowercase")]
81pub enum ToolType {
82    #[default]
83    Function,
84}
85
86#[derive(Debug, Deserialize, Serialize, Clone)]
87pub struct ToolDefinition {
88    pub r#type: String,
89    pub function: completion::ToolDefinition,
90}
91
92impl From<completion::ToolDefinition> for ToolDefinition {
93    fn from(tool: completion::ToolDefinition) -> Self {
94        Self {
95            r#type: "function".into(),
96            function: tool,
97        }
98    }
99}
100
101#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
102pub struct ToolCall {
103    pub id: String,
104    pub r#type: ToolType,
105    pub function: Function,
106}
107
108impl From<ToolCall> for message::ToolCall {
109    fn from(value: ToolCall) -> Self {
110        message::ToolCall {
111            id: value.id,
112            call_id: None,
113            function: value.function.into(),
114        }
115    }
116}
117
118impl From<message::ToolCall> for ToolCall {
119    fn from(value: message::ToolCall) -> Self {
120        ToolCall {
121            id: value.id,
122            r#type: ToolType::Function,
123            function: Function {
124                name: value.function.name,
125                arguments: value.function.arguments,
126            },
127        }
128    }
129}
130
131#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
132pub struct ImageUrl {
133    url: String,
134}
135
136#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
137#[serde(tag = "type", rename_all = "lowercase")]
138pub enum UserContent {
139    Text {
140        text: String,
141    },
142    #[serde(rename = "image_url")]
143    ImageUrl {
144        image_url: ImageUrl,
145    },
146}
147
148impl FromStr for UserContent {
149    type Err = Infallible;
150
151    fn from_str(s: &str) -> Result<Self, Self::Err> {
152        Ok(UserContent::Text {
153            text: s.to_string(),
154        })
155    }
156}
157
158#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
159#[serde(tag = "type", rename_all = "lowercase")]
160pub enum AssistantContent {
161    Text { text: String },
162}
163
164impl FromStr for AssistantContent {
165    type Err = Infallible;
166
167    fn from_str(s: &str) -> Result<Self, Self::Err> {
168        Ok(AssistantContent::Text {
169            text: s.to_string(),
170        })
171    }
172}
173
174#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
175#[serde(tag = "type", rename_all = "lowercase")]
176pub enum SystemContent {
177    Text { text: String },
178}
179
180impl FromStr for SystemContent {
181    type Err = Infallible;
182
183    fn from_str(s: &str) -> Result<Self, Self::Err> {
184        Ok(SystemContent::Text {
185            text: s.to_string(),
186        })
187    }
188}
189
190impl From<UserContent> for message::UserContent {
191    fn from(value: UserContent) -> Self {
192        match value {
193            UserContent::Text { text } => message::UserContent::text(text),
194            UserContent::ImageUrl { image_url } => {
195                message::UserContent::image_url(image_url.url, None, None)
196            }
197        }
198    }
199}
200
201impl TryFrom<message::UserContent> for UserContent {
202    type Error = message::MessageError;
203
204    fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
205        match content {
206            message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
207            message::UserContent::Document(message::Document {
208                data: message::DocumentSourceKind::Raw(raw),
209                ..
210            }) => {
211                let text = String::from_utf8_lossy(raw.as_slice()).into();
212                Ok(UserContent::Text { text })
213            }
214            message::UserContent::Document(message::Document {
215                data:
216                    message::DocumentSourceKind::Base64(text)
217                    | message::DocumentSourceKind::String(text),
218                ..
219            }) => Ok(UserContent::Text { text }),
220            message::UserContent::Image(message::Image { data, .. }) => match data {
221                message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
222                    image_url: ImageUrl { url },
223                }),
224                _ => Err(message::MessageError::ConversionError(
225                    "Huggingface only supports images as urls".into(),
226                )),
227            },
228            _ => Err(message::MessageError::ConversionError(
229                "Huggingface only supports text and images".into(),
230            )),
231        }
232    }
233}
234
235#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
236#[serde(tag = "role", rename_all = "lowercase")]
237pub enum Message {
238    System {
239        #[serde(deserialize_with = "string_or_one_or_many")]
240        content: OneOrMany<SystemContent>,
241    },
242    User {
243        #[serde(deserialize_with = "string_or_one_or_many")]
244        content: OneOrMany<UserContent>,
245    },
246    Assistant {
247        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
248        content: Vec<AssistantContent>,
249        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
250        tool_calls: Vec<ToolCall>,
251    },
252    #[serde(rename = "tool", alias = "Tool")]
253    ToolResult {
254        name: String,
255        #[serde(skip_serializing_if = "Option::is_none")]
256        arguments: Option<serde_json::Value>,
257        #[serde(
258            deserialize_with = "string_or_one_or_many",
259            serialize_with = "serialize_tool_content"
260        )]
261        content: OneOrMany<String>,
262    },
263}
264
265fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
266where
267    S: Serializer,
268{
269    // OpenAI-compatible APIs expect tool content as a string, not an array
270    let joined = content
271        .iter()
272        .map(String::as_str)
273        .collect::<Vec<_>>()
274        .join("\n");
275    serializer.serialize_str(&joined)
276}
277
278impl Message {
279    pub fn system(content: &str) -> Self {
280        Message::System {
281            content: OneOrMany::one(SystemContent::Text {
282                text: content.to_string(),
283            }),
284        }
285    }
286}
287
288impl TryFrom<message::Message> for Vec<Message> {
289    type Error = message::MessageError;
290
291    fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
292        match message {
293            message::Message::User { content } => {
294                let (tool_results, other_content): (Vec<_>, Vec<_>) = content
295                    .into_iter()
296                    .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
297
298                if !tool_results.is_empty() {
299                    tool_results
300                        .into_iter()
301                        .map(|content| match content {
302                            message::UserContent::ToolResult(message::ToolResult {
303                                id,
304                                content,
305                                ..
306                            }) => Ok::<_, message::MessageError>(Message::ToolResult {
307                                name: id,
308                                arguments: None,
309                                content: content.try_map(|content| match content {
310                                    message::ToolResultContent::Text(message::Text { text }) => {
311                                        Ok(text)
312                                    }
313                                    _ => Err(message::MessageError::ConversionError(
314                                        "Tool result content does not support non-text".into(),
315                                    )),
316                                })?,
317                            }),
318                            _ => unreachable!(),
319                        })
320                        .collect::<Result<Vec<_>, _>>()
321                } else {
322                    let other_content = OneOrMany::many(other_content).expect(
323                        "There must be other content here if there were no tool result content",
324                    );
325
326                    Ok(vec![Message::User {
327                        content: other_content.try_map(|content| match content {
328                            message::UserContent::Text(text) => {
329                                Ok(UserContent::Text { text: text.text })
330                            }
331                            message::UserContent::Image(image) => {
332                                let url = image.try_into_url()?;
333
334                                Ok(UserContent::ImageUrl {
335                                    image_url: ImageUrl { url },
336                                })
337                            }
338                            message::UserContent::Document(message::Document {
339                                data: message::DocumentSourceKind::Raw(raw), ..
340                            }) => {
341                                let text = String::from_utf8_lossy(raw.as_slice()).into();
342                                Ok(UserContent::Text { text })
343                            }
344                            message::UserContent::Document(message::Document {
345                                data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
346                            }) => {
347                                Ok(UserContent::Text { text })
348                            }
349                            _ => Err(message::MessageError::ConversionError(
350                                "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
351                            )),
352                        })?,
353                    }])
354                }
355            }
356            message::Message::Assistant { content, .. } => {
357                let (text_content, tool_calls) = content.into_iter().fold(
358                    (Vec::new(), Vec::new()),
359                    |(mut texts, mut tools), content| {
360                        match content {
361                            message::AssistantContent::Text(text) => texts.push(text),
362                            message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
363                            message::AssistantContent::Reasoning(_) => {
364                                unimplemented!("Reasoning is not supported on HuggingFace via Rig");
365                            }
366                            message::AssistantContent::Image(_) => {
367                                unimplemented!(
368                                    "Image content is not supported on HuggingFace via Rig"
369                                );
370                            }
371                        }
372                        (texts, tools)
373                    },
374                );
375
376                // `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
377                //  so either `content` or `tool_calls` will have some content.
378                Ok(vec![Message::Assistant {
379                    content: text_content
380                        .into_iter()
381                        .map(|content| AssistantContent::Text { text: content.text })
382                        .collect::<Vec<_>>(),
383                    tool_calls: tool_calls
384                        .into_iter()
385                        .map(|tool_call| tool_call.into())
386                        .collect::<Vec<_>>(),
387                }])
388            }
389        }
390    }
391}
392
393impl TryFrom<Message> for message::Message {
394    type Error = message::MessageError;
395
396    fn try_from(message: Message) -> Result<Self, Self::Error> {
397        Ok(match message {
398            Message::User { content, .. } => message::Message::User {
399                content: content.map(|content| content.into()),
400            },
401            Message::Assistant {
402                content,
403                tool_calls,
404                ..
405            } => {
406                let mut content = content
407                    .into_iter()
408                    .map(|content| match content {
409                        AssistantContent::Text { text } => message::AssistantContent::text(text),
410                    })
411                    .collect::<Vec<_>>();
412
413                content.extend(
414                    tool_calls
415                        .into_iter()
416                        .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
417                        .collect::<Result<Vec<_>, _>>()?,
418                );
419
420                message::Message::Assistant {
421                    id: None,
422                    content: OneOrMany::many(content).map_err(|_| {
423                        message::MessageError::ConversionError(
424                            "Neither `content` nor `tool_calls` was provided to the Message"
425                                .to_owned(),
426                        )
427                    })?,
428                }
429            }
430
431            Message::ToolResult { name, content, .. } => message::Message::User {
432                content: OneOrMany::one(message::UserContent::tool_result(
433                    name,
434                    content.map(message::ToolResultContent::text),
435                )),
436            },
437
438            // System messages should get stripped out when converting message's, this is just a
439            // stop gap to avoid obnoxious error handling or panic occurring.
440            Message::System { content, .. } => message::Message::User {
441                content: content.map(|c| match c {
442                    SystemContent::Text { text } => message::UserContent::text(text),
443                }),
444            },
445        })
446    }
447}
448
449#[derive(Clone, Debug, Deserialize, Serialize)]
450pub struct Choice {
451    pub finish_reason: String,
452    pub index: usize,
453    #[serde(default)]
454    pub logprobs: serde_json::Value,
455    pub message: Message,
456}
457
458#[derive(Debug, Deserialize, Clone, Serialize)]
459pub struct Usage {
460    pub completion_tokens: i32,
461    pub prompt_tokens: i32,
462    pub total_tokens: i32,
463}
464
465impl GetTokenUsage for Usage {
466    fn token_usage(&self) -> Option<crate::completion::Usage> {
467        let mut usage = crate::completion::Usage::new();
468        usage.input_tokens = self.prompt_tokens as u64;
469        usage.output_tokens = self.completion_tokens as u64;
470        usage.total_tokens = self.total_tokens as u64;
471
472        Some(usage)
473    }
474}
475
476#[derive(Clone, Debug, Deserialize, Serialize)]
477pub struct CompletionResponse {
478    pub created: i32,
479    pub id: String,
480    pub model: String,
481    pub choices: Vec<Choice>,
482    #[serde(default, deserialize_with = "default_string_on_null")]
483    pub system_fingerprint: String,
484    pub usage: Usage,
485}
486
487impl crate::telemetry::ProviderResponseExt for CompletionResponse {
488    type OutputMessage = Choice;
489    type Usage = Usage;
490
491    fn get_response_id(&self) -> Option<String> {
492        Some(self.id.clone())
493    }
494
495    fn get_response_model_name(&self) -> Option<String> {
496        Some(self.model.clone())
497    }
498
499    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
500        self.choices.clone()
501    }
502
503    fn get_text_response(&self) -> Option<String> {
504        let text_response = self
505            .choices
506            .iter()
507            .filter_map(|x| {
508                let Message::User { ref content } = x.message else {
509                    return None;
510                };
511
512                let text = content
513                    .iter()
514                    .filter_map(|x| {
515                        if let UserContent::Text { text } = x {
516                            Some(text.clone())
517                        } else {
518                            None
519                        }
520                    })
521                    .collect::<Vec<String>>();
522
523                if text.is_empty() {
524                    None
525                } else {
526                    Some(text.join("\n"))
527                }
528            })
529            .collect::<Vec<String>>()
530            .join("\n");
531
532        if text_response.is_empty() {
533            None
534        } else {
535            Some(text_response)
536        }
537    }
538
539    fn get_usage(&self) -> Option<Self::Usage> {
540        Some(self.usage.clone())
541    }
542}
543
544fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
545where
546    D: Deserializer<'de>,
547{
548    match Option::<String>::deserialize(deserializer)? {
549        Some(value) => Ok(value),      // Use provided value
550        None => Ok(String::default()), // Use `Default` implementation
551    }
552}
553
554impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
555    type Error = CompletionError;
556
557    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
558        let choice = response.choices.first().ok_or_else(|| {
559            CompletionError::ResponseError("Response contained no choices".to_owned())
560        })?;
561
562        let content = match &choice.message {
563            Message::Assistant {
564                content,
565                tool_calls,
566                ..
567            } => {
568                let mut content = content
569                    .iter()
570                    .map(|c| match c {
571                        AssistantContent::Text { text } => message::AssistantContent::text(text),
572                    })
573                    .collect::<Vec<_>>();
574
575                content.extend(
576                    tool_calls
577                        .iter()
578                        .map(|call| {
579                            completion::AssistantContent::tool_call(
580                                &call.id,
581                                &call.function.name,
582                                call.function.arguments.clone(),
583                            )
584                        })
585                        .collect::<Vec<_>>(),
586                );
587                Ok(content)
588            }
589            _ => Err(CompletionError::ResponseError(
590                "Response did not contain a valid message or tool call".into(),
591            )),
592        }?;
593
594        let choice = OneOrMany::many(content).map_err(|_| {
595            CompletionError::ResponseError(
596                "Response contained no message or tool call (empty)".to_owned(),
597            )
598        })?;
599
600        let usage = completion::Usage {
601            input_tokens: response.usage.prompt_tokens as u64,
602            output_tokens: response.usage.completion_tokens as u64,
603            total_tokens: response.usage.total_tokens as u64,
604        };
605
606        Ok(completion::CompletionResponse {
607            choice,
608            usage,
609            raw_response: response,
610        })
611    }
612}
613
614#[derive(Debug, Serialize, Deserialize)]
615pub(super) struct HuggingfaceCompletionRequest {
616    model: String,
617    pub messages: Vec<Message>,
618    #[serde(flatten, skip_serializing_if = "Option::is_none")]
619    temperature: Option<f64>,
620    #[serde(skip_serializing_if = "Vec::is_empty")]
621    tools: Vec<ToolDefinition>,
622    #[serde(flatten, skip_serializing_if = "Option::is_none")]
623    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
624    #[serde(flatten, skip_serializing_if = "Option::is_none")]
625    pub additional_params: Option<serde_json::Value>,
626}
627
628impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
629    type Error = CompletionError;
630
631    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
632        let mut full_history: Vec<Message> = match &req.preamble {
633            Some(preamble) => vec![Message::system(preamble)],
634            None => vec![],
635        };
636        if let Some(docs) = req.normalized_documents() {
637            let docs: Vec<Message> = docs.try_into()?;
638            full_history.extend(docs);
639        }
640
641        let chat_history: Vec<Message> = req
642            .chat_history
643            .clone()
644            .into_iter()
645            .map(|message| message.try_into())
646            .collect::<Result<Vec<Vec<Message>>, _>>()?
647            .into_iter()
648            .flatten()
649            .collect();
650
651        full_history.extend(chat_history);
652
653        let tool_choice = req
654            .tool_choice
655            .clone()
656            .map(crate::providers::openai::completion::ToolChoice::try_from)
657            .transpose()?;
658
659        Ok(Self {
660            model: model.to_string(),
661            messages: full_history,
662            temperature: req.temperature,
663            tools: req
664                .tools
665                .clone()
666                .into_iter()
667                .map(ToolDefinition::from)
668                .collect::<Vec<_>>(),
669            tool_choice,
670            additional_params: req.additional_params,
671        })
672    }
673}
674
675#[derive(Clone)]
676pub struct CompletionModel<T = reqwest::Client> {
677    pub(crate) client: Client<T>,
678    /// Name of the model (e.g: google/gemma-2-2b-it)
679    pub model: String,
680}
681
682impl<T> CompletionModel<T> {
683    pub fn new(client: Client<T>, model: &str) -> Self {
684        Self {
685            client,
686            model: model.to_string(),
687        }
688    }
689}
690
691impl<T> completion::CompletionModel for CompletionModel<T>
692where
693    T: HttpClientExt + Clone + 'static,
694{
695    type Response = CompletionResponse;
696    type StreamingResponse = StreamingCompletionResponse;
697
698    type Client = Client<T>;
699
700    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
701        Self::new(client.clone(), &model.into())
702    }
703
704    #[cfg_attr(feature = "worker", worker::send)]
705    async fn completion(
706        &self,
707        completion_request: CompletionRequest,
708    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
709        let span = if tracing::Span::current().is_disabled() {
710            info_span!(
711                target: "rig::completions",
712                "chat",
713                gen_ai.operation.name = "chat",
714                gen_ai.provider.name = "huggingface",
715                gen_ai.request.model = self.model,
716                gen_ai.system_instructions = &completion_request.preamble,
717                gen_ai.response.id = tracing::field::Empty,
718                gen_ai.response.model = tracing::field::Empty,
719                gen_ai.usage.output_tokens = tracing::field::Empty,
720                gen_ai.usage.input_tokens = tracing::field::Empty,
721                gen_ai.input.messages = tracing::field::Empty,
722                gen_ai.output.messages = tracing::field::Empty,
723            )
724        } else {
725            tracing::Span::current()
726        };
727
728        let model = self.client.subprovider().model_identifier(&self.model);
729        let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
730
731        span.record_model_input(&request.messages);
732
733        let request = serde_json::to_vec(&request)?;
734
735        let path = self.client.subprovider().completion_endpoint(&self.model);
736        let request = self
737            .client
738            .post(&path)?
739            .header("Content-Type", "application/json")
740            .body(request)
741            .map_err(|e| CompletionError::HttpError(e.into()))?;
742
743        let response = self.client.send(request).await?;
744
745        if response.status().is_success() {
746            let bytes: Vec<u8> = response.into_body().await?;
747            let text = String::from_utf8_lossy(&bytes);
748
749            tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
750
751            match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
752                ApiResponse::Ok(response) => {
753                    let span = tracing::Span::current();
754                    span.record_token_usage(&response.usage);
755                    span.record_model_output(&response.choices);
756                    span.record_response_metadata(&response);
757
758                    response.try_into()
759                }
760                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
761            }
762        } else {
763            let status = response.status();
764            let text: Vec<u8> = response.into_body().await?;
765            let text: String = String::from_utf8_lossy(&text).into();
766
767            Err(CompletionError::ProviderError(format!(
768                "{}: {}",
769                status, text
770            )))
771        }
772    }
773
774    #[cfg_attr(feature = "worker", worker::send)]
775    async fn stream(
776        &self,
777        request: CompletionRequest,
778    ) -> Result<
779        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
780        CompletionError,
781    > {
782        CompletionModel::stream(self, request).await
783    }
784}
785
786#[cfg(test)]
787mod tests {
788    use super::*;
789    use serde_path_to_error::deserialize;
790
791    #[test]
792    fn test_deserialize_message() {
793        let assistant_message_json = r#"
794        {
795            "role": "assistant",
796            "content": "\n\nHello there, how may I assist you today?"
797        }
798        "#;
799
800        let assistant_message_json2 = r#"
801        {
802            "role": "assistant",
803            "content": [
804                {
805                    "type": "text",
806                    "text": "\n\nHello there, how may I assist you today?"
807                }
808            ],
809            "tool_calls": null
810        }
811        "#;
812
813        let assistant_message_json3 = r#"
814        {
815            "role": "assistant",
816            "tool_calls": [
817                {
818                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
819                    "type": "function",
820                    "function": {
821                        "name": "subtract",
822                        "arguments": {"x": 2, "y": 5}
823                    }
824                }
825            ],
826            "content": null,
827            "refusal": null
828        }
829        "#;
830
831        let user_message_json = r#"
832        {
833            "role": "user",
834            "content": [
835                {
836                    "type": "text",
837                    "text": "What's in this image?"
838                },
839                {
840                    "type": "image_url",
841                    "image_url": {
842                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
843                    }
844                }
845            ]
846        }
847        "#;
848
849        let assistant_message: Message = {
850            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
851            deserialize(jd).unwrap_or_else(|err| {
852                panic!(
853                    "Deserialization error at {} ({}:{}): {}",
854                    err.path(),
855                    err.inner().line(),
856                    err.inner().column(),
857                    err
858                );
859            })
860        };
861
862        let assistant_message2: Message = {
863            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
864            deserialize(jd).unwrap_or_else(|err| {
865                panic!(
866                    "Deserialization error at {} ({}:{}): {}",
867                    err.path(),
868                    err.inner().line(),
869                    err.inner().column(),
870                    err
871                );
872            })
873        };
874
875        let assistant_message3: Message = {
876            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
877                &mut serde_json::Deserializer::from_str(assistant_message_json3);
878            deserialize(jd).unwrap_or_else(|err| {
879                panic!(
880                    "Deserialization error at {} ({}:{}): {}",
881                    err.path(),
882                    err.inner().line(),
883                    err.inner().column(),
884                    err
885                );
886            })
887        };
888
889        let user_message: Message = {
890            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
891            deserialize(jd).unwrap_or_else(|err| {
892                panic!(
893                    "Deserialization error at {} ({}:{}): {}",
894                    err.path(),
895                    err.inner().line(),
896                    err.inner().column(),
897                    err
898                );
899            })
900        };
901
902        match assistant_message {
903            Message::Assistant { content, .. } => {
904                assert_eq!(
905                    content[0],
906                    AssistantContent::Text {
907                        text: "\n\nHello there, how may I assist you today?".to_string()
908                    }
909                );
910            }
911            _ => panic!("Expected assistant message"),
912        }
913
914        match assistant_message2 {
915            Message::Assistant {
916                content,
917                tool_calls,
918                ..
919            } => {
920                assert_eq!(
921                    content[0],
922                    AssistantContent::Text {
923                        text: "\n\nHello there, how may I assist you today?".to_string()
924                    }
925                );
926
927                assert_eq!(tool_calls, vec![]);
928            }
929            _ => panic!("Expected assistant message"),
930        }
931
932        match assistant_message3 {
933            Message::Assistant {
934                content,
935                tool_calls,
936                ..
937            } => {
938                assert!(content.is_empty());
939                assert_eq!(
940                    tool_calls[0],
941                    ToolCall {
942                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
943                        r#type: ToolType::Function,
944                        function: Function {
945                            name: "subtract".to_string(),
946                            arguments: serde_json::json!({"x": 2, "y": 5}),
947                        },
948                    }
949                );
950            }
951            _ => panic!("Expected assistant message"),
952        }
953
954        match user_message {
955            Message::User { content, .. } => {
956                let (first, second) = {
957                    let mut iter = content.into_iter();
958                    (iter.next().unwrap(), iter.next().unwrap())
959                };
960                assert_eq!(
961                    first,
962                    UserContent::Text {
963                        text: "What's in this image?".to_string()
964                    }
965                );
966                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
967            }
968            _ => panic!("Expected user message"),
969        }
970    }
971
972    #[test]
973    fn test_message_to_message_conversion() {
974        let user_message = message::Message::User {
975            content: OneOrMany::one(message::UserContent::text("Hello")),
976        };
977
978        let assistant_message = message::Message::Assistant {
979            id: None,
980            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
981        };
982
983        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
984        let converted_assistant_message: Vec<Message> =
985            assistant_message.clone().try_into().unwrap();
986
987        match converted_user_message[0].clone() {
988            Message::User { content, .. } => {
989                assert_eq!(
990                    content.first(),
991                    UserContent::Text {
992                        text: "Hello".to_string()
993                    }
994                );
995            }
996            _ => panic!("Expected user message"),
997        }
998
999        match converted_assistant_message[0].clone() {
1000            Message::Assistant { content, .. } => {
1001                assert_eq!(
1002                    content[0],
1003                    AssistantContent::Text {
1004                        text: "Hi there!".to_string()
1005                    }
1006                );
1007            }
1008            _ => panic!("Expected assistant message"),
1009        }
1010
1011        let original_user_message: message::Message =
1012            converted_user_message[0].clone().try_into().unwrap();
1013        let original_assistant_message: message::Message =
1014            converted_assistant_message[0].clone().try_into().unwrap();
1015
1016        assert_eq!(original_user_message, user_message);
1017        assert_eq!(original_assistant_message, assistant_message);
1018    }
1019
1020    #[test]
1021    fn test_message_from_message_conversion() {
1022        let user_message = Message::User {
1023            content: OneOrMany::one(UserContent::Text {
1024                text: "Hello".to_string(),
1025            }),
1026        };
1027
1028        let assistant_message = Message::Assistant {
1029            content: vec![AssistantContent::Text {
1030                text: "Hi there!".to_string(),
1031            }],
1032            tool_calls: vec![],
1033        };
1034
1035        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1036        let converted_assistant_message: message::Message =
1037            assistant_message.clone().try_into().unwrap();
1038
1039        match converted_user_message.clone() {
1040            message::Message::User { content } => {
1041                assert_eq!(content.first(), message::UserContent::text("Hello"));
1042            }
1043            _ => panic!("Expected user message"),
1044        }
1045
1046        match converted_assistant_message.clone() {
1047            message::Message::Assistant { content, .. } => {
1048                assert_eq!(
1049                    content.first(),
1050                    message::AssistantContent::text("Hi there!")
1051                );
1052            }
1053            _ => panic!("Expected assistant message"),
1054        }
1055
1056        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1057        let original_assistant_message: Vec<Message> =
1058            converted_assistant_message.try_into().unwrap();
1059
1060        assert_eq!(original_user_message[0], user_message);
1061        assert_eq!(original_assistant_message[0], assistant_message);
1062    }
1063
1064    #[test]
1065    fn test_responses() {
1066        let fireworks_response_json = r#"
1067        {
1068            "choices": [
1069                {
1070                    "finish_reason": "tool_calls",
1071                    "index": 0,
1072                    "message": {
1073                        "role": "assistant",
1074                        "tool_calls": [
1075                            {
1076                                "function": {
1077                                "arguments": "{\"x\": 2, \"y\": 5}",
1078                                "name": "subtract"
1079                                },
1080                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
1081                                "index": 0,
1082                                "type": "function"
1083                            }
1084                        ]
1085                    }
1086                }
1087            ],
1088            "created": 1740704000,
1089            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
1090            "model": "accounts/fireworks/models/deepseek-v3",
1091            "object": "chat.completion",
1092            "usage": {
1093                "completion_tokens": 26,
1094                "prompt_tokens": 248,
1095                "total_tokens": 274
1096            }
1097        }
1098        "#;
1099
1100        let novita_response_json = r#"
1101        {
1102            "choices": [
1103                {
1104                    "finish_reason": "tool_calls",
1105                    "index": 0,
1106                    "logprobs": null,
1107                    "message": {
1108                        "audio": null,
1109                        "content": null,
1110                        "function_call": null,
1111                        "reasoning_content": null,
1112                        "refusal": null,
1113                        "role": "assistant",
1114                        "tool_calls": [
1115                            {
1116                                "function": {
1117                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
1118                                    "name": "subtract"
1119                                },
1120                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
1121                                "type": "function"
1122                            }
1123                        ]
1124                    },
1125                    "stop_reason": 128008
1126                }
1127            ],
1128            "created": 1740704592,
1129            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
1130            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
1131            "object": "chat.completion",
1132            "prompt_logprobs": null,
1133            "service_tier": null,
1134            "system_fingerprint": null,
1135            "usage": {
1136                "completion_tokens": 28,
1137                "completion_tokens_details": null,
1138                "prompt_tokens": 335,
1139                "prompt_tokens_details": null,
1140                "total_tokens": 363
1141            }
1142        }
1143        "#;
1144
1145        let _firework_response: CompletionResponse = {
1146            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
1147            deserialize(jd).unwrap_or_else(|err| {
1148                panic!(
1149                    "Deserialization error at {} ({}:{}): {}",
1150                    err.path(),
1151                    err.inner().line(),
1152                    err.inner().column(),
1153                    err
1154                );
1155            })
1156        };
1157
1158        let _novita_response: CompletionResponse = {
1159            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
1160            deserialize(jd).unwrap_or_else(|err| {
1161                panic!(
1162                    "Deserialization error at {} ({}:{}): {}",
1163                    err.path(),
1164                    err.inner().line(),
1165                    err.inner().column(),
1166                    err
1167                );
1168            })
1169        };
1170    }
1171}