1use serde::{Deserialize, Deserializer, Serialize};
2use serde_json::{Value, json};
3use std::{convert::Infallible, str::FromStr};
4
5use super::client::Client;
6use crate::providers::openai::StreamingCompletionResponse;
7use crate::{
8 OneOrMany,
9 completion::{self, CompletionError, CompletionRequest},
10 json_utils,
11 message::{self},
12 one_or_many::string_or_one_or_many,
13};
14
15#[derive(Debug, Deserialize)]
16#[serde(untagged)]
17pub enum ApiResponse<T> {
18 Ok(T),
19 Err(Value),
20}
21
// Model identifier strings accepted by `CompletionModel::new`.

/// `google/gemma-2-2b-it`
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// `meta-llama/Meta-Llama-3.1-8B-Instruct`
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `microsoft/phi-4`
pub const PHI_4: &str = "microsoft/phi-4";
/// `PowerInfer/SmallThinker-3B-Preview`
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// `Qwen/Qwen2.5-7B-Instruct`
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct`
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// `Qwen/Qwen2-VL-7B-Instruct`
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// `Qwen/QVQ-72B-Preview`
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
47
48#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
49pub struct Function {
50 name: String,
51 #[serde(deserialize_with = "deserialize_arguments")]
52 pub arguments: serde_json::Value,
53}
54
55fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
56where
57 D: Deserializer<'de>,
58{
59 let value = Value::deserialize(deserializer)?;
60
61 match value {
62 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
63 other => Ok(other),
64 }
65}
66
67impl From<Function> for message::ToolFunction {
68 fn from(value: Function) -> Self {
69 message::ToolFunction {
70 name: value.name,
71 arguments: value.arguments,
72 }
73 }
74}
75
76#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
77#[serde(rename_all = "lowercase")]
78pub enum ToolType {
79 #[default]
80 Function,
81}
82
83#[derive(Debug, Deserialize, Serialize, Clone)]
84pub struct ToolDefinition {
85 pub r#type: String,
86 pub function: completion::ToolDefinition,
87}
88
89impl From<completion::ToolDefinition> for ToolDefinition {
90 fn from(tool: completion::ToolDefinition) -> Self {
91 Self {
92 r#type: "function".into(),
93 function: tool,
94 }
95 }
96}
97
98#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
99pub struct ToolCall {
100 pub id: String,
101 pub r#type: ToolType,
102 pub function: Function,
103}
104
105impl From<ToolCall> for message::ToolCall {
106 fn from(value: ToolCall) -> Self {
107 message::ToolCall {
108 id: value.id,
109 call_id: None,
110 function: value.function.into(),
111 }
112 }
113}
114
115impl From<message::ToolCall> for ToolCall {
116 fn from(value: message::ToolCall) -> Self {
117 ToolCall {
118 id: value.id,
119 r#type: ToolType::Function,
120 function: Function {
121 name: value.function.name,
122 arguments: value.function.arguments,
123 },
124 }
125 }
126}
127
128#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
129pub struct ImageUrl {
130 url: String,
131}
132
133#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
134#[serde(tag = "type", rename_all = "lowercase")]
135pub enum UserContent {
136 Text {
137 text: String,
138 },
139 #[serde(rename = "image_url")]
140 ImageUrl {
141 image_url: ImageUrl,
142 },
143}
144
145impl FromStr for UserContent {
146 type Err = Infallible;
147
148 fn from_str(s: &str) -> Result<Self, Self::Err> {
149 Ok(UserContent::Text {
150 text: s.to_string(),
151 })
152 }
153}
154
155#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
156#[serde(tag = "type", rename_all = "lowercase")]
157pub enum AssistantContent {
158 Text { text: String },
159}
160
161impl FromStr for AssistantContent {
162 type Err = Infallible;
163
164 fn from_str(s: &str) -> Result<Self, Self::Err> {
165 Ok(AssistantContent::Text {
166 text: s.to_string(),
167 })
168 }
169}
170
171#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
172#[serde(tag = "type", rename_all = "lowercase")]
173pub enum SystemContent {
174 Text { text: String },
175}
176
177impl FromStr for SystemContent {
178 type Err = Infallible;
179
180 fn from_str(s: &str) -> Result<Self, Self::Err> {
181 Ok(SystemContent::Text {
182 text: s.to_string(),
183 })
184 }
185}
186
187impl From<UserContent> for message::UserContent {
188 fn from(value: UserContent) -> Self {
189 match value {
190 UserContent::Text { text } => message::UserContent::text(text),
191 UserContent::ImageUrl { image_url } => {
192 message::UserContent::image_url(image_url.url, None, None)
193 }
194 }
195 }
196}
197
198impl TryFrom<message::UserContent> for UserContent {
199 type Error = message::MessageError;
200
201 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
202 match content {
203 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
204 message::UserContent::Image(message::Image { data, .. }) => match data {
205 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
206 image_url: ImageUrl { url },
207 }),
208 _ => Err(message::MessageError::ConversionError(
209 "Huggingface only supports images as urls".into(),
210 )),
211 },
212 _ => Err(message::MessageError::ConversionError(
213 "Huggingface only supports text and images".into(),
214 )),
215 }
216 }
217}
218
219#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
220#[serde(tag = "role", rename_all = "lowercase")]
221pub enum Message {
222 System {
223 #[serde(deserialize_with = "string_or_one_or_many")]
224 content: OneOrMany<SystemContent>,
225 },
226 User {
227 #[serde(deserialize_with = "string_or_one_or_many")]
228 content: OneOrMany<UserContent>,
229 },
230 Assistant {
231 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
232 content: Vec<AssistantContent>,
233 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
234 tool_calls: Vec<ToolCall>,
235 },
236 #[serde(rename = "Tool")]
237 ToolResult {
238 name: String,
239 #[serde(skip_serializing_if = "Option::is_none")]
240 arguments: Option<serde_json::Value>,
241 #[serde(deserialize_with = "string_or_one_or_many")]
242 content: OneOrMany<String>,
243 },
244}
245
246impl Message {
247 pub fn system(content: &str) -> Self {
248 Message::System {
249 content: OneOrMany::one(SystemContent::Text {
250 text: content.to_string(),
251 }),
252 }
253 }
254}
255
256impl TryFrom<message::Message> for Vec<Message> {
257 type Error = message::MessageError;
258
259 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
260 match message {
261 message::Message::User { content } => {
262 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
263 .into_iter()
264 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
265
266 if !tool_results.is_empty() {
267 tool_results
268 .into_iter()
269 .map(|content| match content {
270 message::UserContent::ToolResult(message::ToolResult {
271 id,
272 content,
273 ..
274 }) => Ok::<_, message::MessageError>(Message::ToolResult {
275 name: id,
276 arguments: None,
277 content: content.try_map(|content| match content {
278 message::ToolResultContent::Text(message::Text { text }) => {
279 Ok(text)
280 }
281 _ => Err(message::MessageError::ConversionError(
282 "Tool result content does not support non-text".into(),
283 )),
284 })?,
285 }),
286 _ => unreachable!(),
287 })
288 .collect::<Result<Vec<_>, _>>()
289 } else {
290 let other_content = OneOrMany::many(other_content).expect(
291 "There must be other content here if there were no tool result content",
292 );
293
294 Ok(vec![Message::User {
295 content: other_content.try_map(|content| match content {
296 message::UserContent::Text(text) => {
297 Ok(UserContent::Text { text: text.text })
298 }
299 _ => Err(message::MessageError::ConversionError(
300 "Huggingface does not support non-text".into(),
301 )),
302 })?,
303 }])
304 }
305 }
306 message::Message::Assistant { content, .. } => {
307 let (text_content, tool_calls) = content.into_iter().fold(
308 (Vec::new(), Vec::new()),
309 |(mut texts, mut tools), content| {
310 match content {
311 message::AssistantContent::Text(text) => texts.push(text),
312 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
313 message::AssistantContent::Reasoning(_) => {
314 unimplemented!("Reasoning is not supported on HuggingFace via Rig");
315 }
316 }
317 (texts, tools)
318 },
319 );
320
321 Ok(vec![Message::Assistant {
324 content: text_content
325 .into_iter()
326 .map(|content| AssistantContent::Text { text: content.text })
327 .collect::<Vec<_>>(),
328 tool_calls: tool_calls
329 .into_iter()
330 .map(|tool_call| tool_call.into())
331 .collect::<Vec<_>>(),
332 }])
333 }
334 }
335 }
336}
337
338impl TryFrom<Message> for message::Message {
339 type Error = message::MessageError;
340
341 fn try_from(message: Message) -> Result<Self, Self::Error> {
342 Ok(match message {
343 Message::User { content, .. } => message::Message::User {
344 content: content.map(|content| content.into()),
345 },
346 Message::Assistant {
347 content,
348 tool_calls,
349 ..
350 } => {
351 let mut content = content
352 .into_iter()
353 .map(|content| match content {
354 AssistantContent::Text { text } => message::AssistantContent::text(text),
355 })
356 .collect::<Vec<_>>();
357
358 content.extend(
359 tool_calls
360 .into_iter()
361 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
362 .collect::<Result<Vec<_>, _>>()?,
363 );
364
365 message::Message::Assistant {
366 id: None,
367 content: OneOrMany::many(content).map_err(|_| {
368 message::MessageError::ConversionError(
369 "Neither `content` nor `tool_calls` was provided to the Message"
370 .to_owned(),
371 )
372 })?,
373 }
374 }
375
376 Message::ToolResult { name, content, .. } => message::Message::User {
377 content: OneOrMany::one(message::UserContent::tool_result(
378 name,
379 content.map(message::ToolResultContent::text),
380 )),
381 },
382
383 Message::System { content, .. } => message::Message::User {
386 content: content.map(|c| match c {
387 SystemContent::Text { text } => message::UserContent::text(text),
388 }),
389 },
390 })
391 }
392}
393
394#[derive(Debug, Deserialize, Serialize)]
395pub struct Choice {
396 pub finish_reason: String,
397 pub index: usize,
398 #[serde(default)]
399 pub logprobs: serde_json::Value,
400 pub message: Message,
401}
402
403#[derive(Debug, Deserialize, Clone, Serialize)]
404pub struct Usage {
405 pub completion_tokens: i32,
406 pub prompt_tokens: i32,
407 pub total_tokens: i32,
408}
409
410#[derive(Debug, Deserialize, Serialize)]
411pub struct CompletionResponse {
412 pub created: i32,
413 pub id: String,
414 pub model: String,
415 pub choices: Vec<Choice>,
416 #[serde(default, deserialize_with = "default_string_on_null")]
417 pub system_fingerprint: String,
418 pub usage: Usage,
419}
420
421fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
422where
423 D: Deserializer<'de>,
424{
425 match Option::<String>::deserialize(deserializer)? {
426 Some(value) => Ok(value), None => Ok(String::default()), }
429}
430
431impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
432 type Error = CompletionError;
433
434 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
435 let choice = response.choices.first().ok_or_else(|| {
436 CompletionError::ResponseError("Response contained no choices".to_owned())
437 })?;
438
439 let content = match &choice.message {
440 Message::Assistant {
441 content,
442 tool_calls,
443 ..
444 } => {
445 let mut content = content
446 .iter()
447 .map(|c| match c {
448 AssistantContent::Text { text } => message::AssistantContent::text(text),
449 })
450 .collect::<Vec<_>>();
451
452 content.extend(
453 tool_calls
454 .iter()
455 .map(|call| {
456 completion::AssistantContent::tool_call(
457 &call.id,
458 &call.function.name,
459 call.function.arguments.clone(),
460 )
461 })
462 .collect::<Vec<_>>(),
463 );
464 Ok(content)
465 }
466 _ => Err(CompletionError::ResponseError(
467 "Response did not contain a valid message or tool call".into(),
468 )),
469 }?;
470
471 let choice = OneOrMany::many(content).map_err(|_| {
472 CompletionError::ResponseError(
473 "Response contained no message or tool call (empty)".to_owned(),
474 )
475 })?;
476
477 let usage = completion::Usage {
478 input_tokens: response.usage.prompt_tokens as u64,
479 output_tokens: response.usage.completion_tokens as u64,
480 total_tokens: response.usage.total_tokens as u64,
481 };
482
483 Ok(completion::CompletionResponse {
484 choice,
485 usage,
486 raw_response: response,
487 })
488 }
489}
490
491#[derive(Clone)]
492pub struct CompletionModel {
493 pub(crate) client: Client,
494 pub model: String,
496}
497
498impl CompletionModel {
499 pub fn new(client: Client, model: &str) -> Self {
500 Self {
501 client,
502 model: model.to_string(),
503 }
504 }
505
506 pub(crate) fn create_request_body(
507 &self,
508 completion_request: &CompletionRequest,
509 ) -> Result<serde_json::Value, CompletionError> {
510 let mut full_history: Vec<Message> = match &completion_request.preamble {
511 Some(preamble) => vec![Message::system(preamble)],
512 None => vec![],
513 };
514 if let Some(docs) = completion_request.normalized_documents() {
515 let docs: Vec<Message> = docs.try_into()?;
516 full_history.extend(docs);
517 }
518
519 let chat_history: Vec<Message> = completion_request
520 .chat_history
521 .clone()
522 .into_iter()
523 .map(|message| message.try_into())
524 .collect::<Result<Vec<Vec<Message>>, _>>()?
525 .into_iter()
526 .flatten()
527 .collect();
528
529 full_history.extend(chat_history);
530
531 let model = self.client.sub_provider.model_identifier(&self.model);
532
533 let request = if completion_request.tools.is_empty() {
534 json!({
535 "model": model,
536 "messages": full_history,
537 "temperature": completion_request.temperature,
538 })
539 } else {
540 json!({
541 "model": model,
542 "messages": full_history,
543 "temperature": completion_request.temperature,
544 "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
545 "tool_choice": "auto",
546 })
547 };
548 Ok(request)
549 }
550}
551
552impl completion::CompletionModel for CompletionModel {
553 type Response = CompletionResponse;
554 type StreamingResponse = StreamingCompletionResponse;
555
556 #[cfg_attr(feature = "worker", worker::send)]
557 async fn completion(
558 &self,
559 completion_request: CompletionRequest,
560 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
561 let request = self.create_request_body(&completion_request)?;
562
563 let path = self.client.sub_provider.completion_endpoint(&self.model);
564
565 let request = if let Some(ref params) = completion_request.additional_params {
566 json_utils::merge(request, params.clone())
567 } else {
568 request
569 };
570
571 let response = self.client.post(&path).json(&request).send().await?;
572
573 if response.status().is_success() {
574 let t = response.text().await?;
575 tracing::debug!(target: "rig", "Huggingface completion error: {}", t);
576
577 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
578 ApiResponse::Ok(response) => {
579 tracing::info!(target: "rig",
580 "Huggingface completion token usage: {:?}",
581 format!("{:?}", response.usage)
582 );
583 response.try_into()
584 }
585 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
586 }
587 } else {
588 Err(CompletionError::ProviderError(format!(
589 "{}: {}",
590 response.status(),
591 response.text().await?
592 )))
593 }
594 }
595
596 #[cfg_attr(feature = "worker", worker::send)]
597 async fn stream(
598 &self,
599 request: CompletionRequest,
600 ) -> Result<
601 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
602 CompletionError,
603 > {
604 CompletionModel::stream(self, request).await
605 }
606}
607
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Parses `json` into `T`, panicking with the failing path and position so
    /// that a broken fixture is easy to locate.
    fn parse<T: serde::de::DeserializeOwned>(json: &str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    /// Returns the single tool call of the first choice, asserting that the
    /// choice is an assistant message without text content.
    fn sole_tool_call(response: &CompletionResponse) -> &ToolCall {
        match &response.choices[0].message {
            Message::Assistant {
                content,
                tool_calls,
            } => {
                assert!(content.is_empty());
                assert_eq!(tool_calls.len(), 1);
                &tool_calls[0]
            }
            _ => panic!("Expected assistant message"),
        }
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse(assistant_message_json);
        let assistant_message2: Message = parse(assistant_message_json2);
        let assistant_message3: Message = parse(assistant_message_json3);
        let user_message: Message = parse(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    second,
                    UserContent::ImageUrl {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string()
                        }
                    }
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Round-trip back to the generic representation.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Round-trip back to the provider representation.
        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": 2, \"y\": 5}",
                                    "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let fireworks_response: CompletionResponse = parse(fireworks_response_json);
        let novita_response: CompletionResponse = parse(novita_response_json);

        // String-encoded arguments must be decoded into structured JSON.
        let fireworks_call = sole_tool_call(&fireworks_response);
        assert_eq!(fireworks_call.id, "call_1BspL6mQqjKgvsQbH1TIYkHf");
        assert_eq!(fireworks_call.function.name, "subtract");
        assert_eq!(
            fireworks_call.function.arguments,
            serde_json::json!({"x": 2, "y": 5})
        );
        // A missing fingerprint falls back to the empty string.
        assert_eq!(fireworks_response.system_fingerprint, "");
        assert_eq!(fireworks_response.usage.prompt_tokens, 248);
        assert_eq!(fireworks_response.usage.completion_tokens, 26);
        assert_eq!(fireworks_response.usage.total_tokens, 274);

        let novita_call = sole_tool_call(&novita_response);
        assert_eq!(
            novita_call.id,
            "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5"
        );
        assert_eq!(novita_call.function.name, "subtract");
        assert_eq!(
            novita_call.function.arguments,
            serde_json::json!({"x": "2", "y": "5"})
        );
        // A `null` fingerprint also falls back to the empty string.
        assert_eq!(novita_response.system_fingerprint, "");
        assert_eq!(novita_response.usage.prompt_tokens, 335);
        assert_eq!(novita_response.usage.completion_tokens, 28);
        assert_eq!(novita_response.usage.total_tokens, 363);
    }
}