1use serde::{Deserialize, Deserializer, Serialize};
2use serde_json::{json, Value};
3use std::{convert::Infallible, str::FromStr};
4
5use super::client::Client;
6use crate::providers::openai::StreamingCompletionResponse;
7use crate::{
8 completion::{self, CompletionError, CompletionRequest},
9 json_utils,
10 message::{self},
11 one_or_many::string_or_one_or_many,
12 OneOrMany,
13};
14
/// Envelope for a Huggingface API reply body.
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration
/// order: a body that deserializes as `T` becomes `Ok`, and any other JSON value
/// falls through to `Err` with the raw payload preserved. The order of the two
/// variants is therefore significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Body successfully decoded as `T`.
    Ok(T),
    /// Any JSON that could not be decoded as `T` (presumably a provider error object).
    Err(Value),
}
21
// Convenience model identifiers that can be passed to `CompletionModel::new`.
// The values have the `owner/name` shape of Hugging Face Hub repository ids;
// NOTE(review): availability depends on the selected sub-provider — confirm.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
pub const PHI_4: &str = "microsoft/phi-4";
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// Presumably vision-language models (per the "VL"/"QVQ" names) — confirm.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
47
48#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
49pub struct Function {
50 name: String,
51 #[serde(deserialize_with = "deserialize_arguments")]
52 pub arguments: serde_json::Value,
53}
54
55fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
56where
57 D: Deserializer<'de>,
58{
59 let value = Value::deserialize(deserializer)?;
60
61 match value {
62 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
63 other => Ok(other),
64 }
65}
66
67impl From<Function> for message::ToolFunction {
68 fn from(value: Function) -> Self {
69 message::ToolFunction {
70 name: value.name,
71 arguments: value.arguments,
72 }
73 }
74}
75
/// Kind of tool referenced by a [`ToolCall`].
///
/// Variants are serialized in lowercase, so `Function` (also the `Default`)
/// appears on the wire as `"function"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
82
/// A tool advertised to the model in the request's `tools` array.
///
/// Pairs rig's generic `completion::ToolDefinition` with a `type` tag. The
/// `From<completion::ToolDefinition>` conversion sets that tag to `"function"`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind tag (serialized as `type`).
    pub r#type: String,
    /// Name, description and parameter schema of the tool.
    pub function: completion::ToolDefinition,
}
88
89impl From<completion::ToolDefinition> for ToolDefinition {
90 fn from(tool: completion::ToolDefinition) -> Self {
91 Self {
92 r#type: "function".into(),
93 function: tool,
94 }
95 }
96}
97
/// A tool invocation requested by the assistant.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Provider-assigned identifier of this call.
    pub id: String,
    /// Tool kind (serialized as `type`; currently only `"function"`).
    pub r#type: ToolType,
    /// Function name and arguments to call it with.
    pub function: Function,
}
104
105impl From<ToolCall> for message::ToolCall {
106 fn from(value: ToolCall) -> Self {
107 message::ToolCall {
108 id: value.id,
109 function: value.function.into(),
110 }
111 }
112}
113
114impl From<message::ToolCall> for ToolCall {
115 fn from(value: message::ToolCall) -> Self {
116 ToolCall {
117 id: value.id,
118 r#type: ToolType::Function,
119 function: Function {
120 name: value.function.name,
121 arguments: value.function.arguments,
122 },
123 }
124 }
125}
126
127#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
128pub struct ImageUrl {
129 url: String,
130}
131
/// One part of a user message, tagged by its `type` field.
///
/// `Text` serializes as `{"type": "text", "text": ...}`. `ImageUrl` overrides the
/// lowercase rule and serializes as `{"type": "image_url", "image_url": {"url": ...}}`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
143
144impl FromStr for UserContent {
145 type Err = Infallible;
146
147 fn from_str(s: &str) -> Result<Self, Self::Err> {
148 Ok(UserContent::Text {
149 text: s.to_string(),
150 })
151 }
152}
153
/// One part of an assistant message, tagged by its `type` field.
/// Only text parts are modelled here; tool calls travel in `Message::Assistant::tool_calls`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
}
159
160impl FromStr for AssistantContent {
161 type Err = Infallible;
162
163 fn from_str(s: &str) -> Result<Self, Self::Err> {
164 Ok(AssistantContent::Text {
165 text: s.to_string(),
166 })
167 }
168}
169
/// One part of a system message, tagged by its `type` field (only text is supported).
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    Text { text: String },
}
175
176impl FromStr for SystemContent {
177 type Err = Infallible;
178
179 fn from_str(s: &str) -> Result<Self, Self::Err> {
180 Ok(SystemContent::Text {
181 text: s.to_string(),
182 })
183 }
184}
185
186impl From<UserContent> for message::UserContent {
187 fn from(value: UserContent) -> Self {
188 match value {
189 UserContent::Text { text } => message::UserContent::text(text),
190 UserContent::ImageUrl { image_url } => message::UserContent::image(
191 image_url.url,
192 Some(message::ContentFormat::String),
193 None,
194 None,
195 ),
196 }
197 }
198}
199
200impl TryFrom<message::UserContent> for UserContent {
201 type Error = message::MessageError;
202
203 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
204 match content {
205 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
206 message::UserContent::Image(message::Image { data, format, .. }) => match format {
207 Some(message::ContentFormat::String) => Ok(UserContent::ImageUrl {
208 image_url: ImageUrl { url: data },
209 }),
210 _ => Err(message::MessageError::ConversionError(
211 "Huggingface only supports images as urls".into(),
212 )),
213 },
214 _ => Err(message::MessageError::ConversionError(
215 "Huggingface only supports text and images".into(),
216 )),
217 }
218 }
219}
220
221#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
222#[serde(tag = "role", rename_all = "lowercase")]
223pub enum Message {
224 System {
225 #[serde(deserialize_with = "string_or_one_or_many")]
226 content: OneOrMany<SystemContent>,
227 },
228 User {
229 #[serde(deserialize_with = "string_or_one_or_many")]
230 content: OneOrMany<UserContent>,
231 },
232 Assistant {
233 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
234 content: Vec<AssistantContent>,
235 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
236 tool_calls: Vec<ToolCall>,
237 },
238 #[serde(rename = "Tool")]
239 ToolResult {
240 name: String,
241 #[serde(skip_serializing_if = "Option::is_none")]
242 arguments: Option<serde_json::Value>,
243 #[serde(deserialize_with = "string_or_one_or_many")]
244 content: OneOrMany<String>,
245 },
246}
247
248impl Message {
249 pub fn system(content: &str) -> Self {
250 Message::System {
251 content: OneOrMany::one(SystemContent::Text {
252 text: content.to_string(),
253 }),
254 }
255 }
256}
257
258impl TryFrom<message::Message> for Vec<Message> {
259 type Error = message::MessageError;
260
261 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
262 match message {
263 message::Message::User { content } => {
264 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
265 .into_iter()
266 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
267
268 if !tool_results.is_empty() {
269 tool_results
270 .into_iter()
271 .map(|content| match content {
272 message::UserContent::ToolResult(message::ToolResult {
273 id,
274 content,
275 }) => Ok::<_, message::MessageError>(Message::ToolResult {
276 name: id,
277 arguments: None,
278 content: content.try_map(|content| match content {
279 message::ToolResultContent::Text(message::Text { text }) => {
280 Ok(text)
281 }
282 _ => Err(message::MessageError::ConversionError(
283 "Tool result content does not support non-text".into(),
284 )),
285 })?,
286 }),
287 _ => unreachable!(),
288 })
289 .collect::<Result<Vec<_>, _>>()
290 } else {
291 let other_content = OneOrMany::many(other_content).expect(
292 "There must be other content here if there were no tool result content",
293 );
294
295 Ok(vec![Message::User {
296 content: other_content.try_map(|content| match content {
297 message::UserContent::Text(text) => {
298 Ok(UserContent::Text { text: text.text })
299 }
300 _ => Err(message::MessageError::ConversionError(
301 "Huggingface does not support non-text".into(),
302 )),
303 })?,
304 }])
305 }
306 }
307 message::Message::Assistant { content } => {
308 let (text_content, tool_calls) = content.into_iter().fold(
309 (Vec::new(), Vec::new()),
310 |(mut texts, mut tools), content| {
311 match content {
312 message::AssistantContent::Text(text) => texts.push(text),
313 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
314 }
315 (texts, tools)
316 },
317 );
318
319 Ok(vec![Message::Assistant {
322 content: text_content
323 .into_iter()
324 .map(|content| AssistantContent::Text { text: content.text })
325 .collect::<Vec<_>>(),
326 tool_calls: tool_calls
327 .into_iter()
328 .map(|tool_call| tool_call.into())
329 .collect::<Vec<_>>(),
330 }])
331 }
332 }
333 }
334}
335
336impl TryFrom<Message> for message::Message {
337 type Error = message::MessageError;
338
339 fn try_from(message: Message) -> Result<Self, Self::Error> {
340 Ok(match message {
341 Message::User { content, .. } => message::Message::User {
342 content: content.map(|content| content.into()),
343 },
344 Message::Assistant {
345 content,
346 tool_calls,
347 ..
348 } => {
349 let mut content = content
350 .into_iter()
351 .map(|content| match content {
352 AssistantContent::Text { text } => message::AssistantContent::text(text),
353 })
354 .collect::<Vec<_>>();
355
356 content.extend(
357 tool_calls
358 .into_iter()
359 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
360 .collect::<Result<Vec<_>, _>>()?,
361 );
362
363 message::Message::Assistant {
364 content: OneOrMany::many(content).map_err(|_| {
365 message::MessageError::ConversionError(
366 "Neither `content` nor `tool_calls` was provided to the Message"
367 .to_owned(),
368 )
369 })?,
370 }
371 }
372
373 Message::ToolResult { name, content, .. } => message::Message::User {
374 content: OneOrMany::one(message::UserContent::tool_result(
375 name,
376 content.map(message::ToolResultContent::text),
377 )),
378 },
379
380 Message::System { content, .. } => message::Message::User {
383 content: content.map(|c| match c {
384 SystemContent::Text { text } => message::UserContent::text(text),
385 }),
386 },
387 })
388 }
389}
390
/// One candidate completion in a response's `choices` array.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Why generation stopped (e.g. `"tool_calls"` in the test fixtures).
    // NOTE(review): required here; some providers may send `null` — confirm.
    pub finish_reason: String,
    /// Position of this choice in `choices`.
    pub index: usize,
    /// Raw log-probability payload; `Value::Null` when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    /// The generated message.
    pub message: Message,
}
399
/// Token accounting reported with a completion response.
#[derive(Debug, Deserialize, Clone)]
pub struct Usage {
    /// Tokens generated in the completion.
    pub completion_tokens: i32,
    /// Tokens consumed by the prompt.
    pub prompt_tokens: i32,
    /// Total tokens reported by the provider.
    pub total_tokens: i32,
}
406
/// Raw chat-completion response body returned by Huggingface.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Creation time, presumably Unix seconds.
    // NOTE(review): `i32` overflows in January 2038; consider `i64` in a future
    // breaking release.
    pub created: i32,
    /// Provider-assigned response id.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Candidate completions; only the first is used when converting.
    pub choices: Vec<Choice>,
    /// Backend fingerprint; empty when the field is missing or `null`.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    /// Token usage for this request.
    pub usage: Usage,
}
417
418fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
419where
420 D: Deserializer<'de>,
421{
422 match Option::<String>::deserialize(deserializer)? {
423 Some(value) => Ok(value), None => Ok(String::default()), }
426}
427
428impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
429 type Error = CompletionError;
430
431 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
432 let choice = response.choices.first().ok_or_else(|| {
433 CompletionError::ResponseError("Response contained no choices".to_owned())
434 })?;
435
436 let content = match &choice.message {
437 Message::Assistant {
438 content,
439 tool_calls,
440 ..
441 } => {
442 let mut content = content
443 .iter()
444 .map(|c| match c {
445 AssistantContent::Text { text } => message::AssistantContent::text(text),
446 })
447 .collect::<Vec<_>>();
448
449 content.extend(
450 tool_calls
451 .iter()
452 .map(|call| {
453 completion::AssistantContent::tool_call(
454 &call.id,
455 &call.function.name,
456 call.function.arguments.clone(),
457 )
458 })
459 .collect::<Vec<_>>(),
460 );
461 Ok(content)
462 }
463 _ => Err(CompletionError::ResponseError(
464 "Response did not contain a valid message or tool call".into(),
465 )),
466 }?;
467
468 let choice = OneOrMany::many(content).map_err(|_| {
469 CompletionError::ResponseError(
470 "Response contained no message or tool call (empty)".to_owned(),
471 )
472 })?;
473
474 Ok(completion::CompletionResponse {
475 choice,
476 raw_response: response,
477 })
478 }
479}
480
/// Huggingface chat-completion model bound to a client.
#[derive(Clone)]
pub struct CompletionModel {
    /// HTTP client; its `sub_provider` picks the endpoint and model identifier.
    pub(crate) client: Client,
    /// Model name as given by the caller, resolved through the sub-provider per request.
    pub model: String,
}
487
488impl CompletionModel {
489 pub fn new(client: Client, model: &str) -> Self {
490 Self {
491 client,
492 model: model.to_string(),
493 }
494 }
495
496 pub(crate) fn create_request_body(
497 &self,
498 completion_request: &CompletionRequest,
499 ) -> Result<serde_json::Value, CompletionError> {
500 let mut full_history: Vec<Message> = match &completion_request.preamble {
501 Some(preamble) => vec![Message::system(preamble)],
502 None => vec![],
503 };
504 if let Some(docs) = completion_request.normalized_documents() {
505 let docs: Vec<Message> = docs.try_into()?;
506 full_history.extend(docs);
507 }
508
509 let chat_history: Vec<Message> = completion_request
510 .chat_history
511 .clone()
512 .into_iter()
513 .map(|message| message.try_into())
514 .collect::<Result<Vec<Vec<Message>>, _>>()?
515 .into_iter()
516 .flatten()
517 .collect();
518
519 full_history.extend(chat_history);
520
521 let model = self.client.sub_provider.model_identifier(&self.model);
522
523 let request = if completion_request.tools.is_empty() {
524 json!({
525 "model": model,
526 "messages": full_history,
527 "temperature": completion_request.temperature,
528 })
529 } else {
530 json!({
531 "model": model,
532 "messages": full_history,
533 "temperature": completion_request.temperature,
534 "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
535 "tool_choice": "auto",
536 })
537 };
538 Ok(request)
539 }
540}
541
542impl completion::CompletionModel for CompletionModel {
543 type Response = CompletionResponse;
544 type StreamingResponse = StreamingCompletionResponse;
545
546 #[cfg_attr(feature = "worker", worker::send)]
547 async fn completion(
548 &self,
549 completion_request: CompletionRequest,
550 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
551 let request = self.create_request_body(&completion_request)?;
552
553 let path = self.client.sub_provider.completion_endpoint(&self.model);
554
555 let request = if let Some(ref params) = completion_request.additional_params {
556 json_utils::merge(request, params.clone())
557 } else {
558 request
559 };
560
561 let response = self.client.post(&path).json(&request).send().await?;
562
563 if response.status().is_success() {
564 let t = response.text().await?;
565 tracing::debug!(target: "rig", "Huggingface completion error: {}", t);
566
567 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
568 ApiResponse::Ok(response) => {
569 tracing::info!(target: "rig",
570 "Huggingface completion token usage: {:?}",
571 format!("{:?}", response.usage)
572 );
573 response.try_into()
574 }
575 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
576 }
577 } else {
578 Err(CompletionError::ProviderError(format!(
579 "{}: {}",
580 response.status(),
581 response.text().await?
582 )))
583 }
584 }
585
586 #[cfg_attr(feature = "worker", worker::send)]
587 async fn stream(
588 &self,
589 request: CompletionRequest,
590 ) -> Result<
591 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
592 CompletionError,
593 > {
594 CompletionModel::stream(self, request).await
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`.
    ///
    /// On failure, panics with the JSON path of the failing field and its line and
    /// column, so fixture mismatches are easy to locate.
    fn parse<T: serde::de::DeserializeOwned>(json: &str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse(assistant_message_json);
        let assistant_message2: Message = parse(assistant_message_json2);
        let assistant_message3: Message = parse(assistant_message_json3);
        let user_message: Message = parse(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let firework_response: CompletionResponse = parse(fireworks_response_json);
        let novita_response: CompletionResponse = parse(novita_response_json);

        // Fireworks sends `arguments` as a JSON-encoded string; it must be decoded.
        match &firework_response.choices[0].message {
            Message::Assistant {
                content,
                tool_calls,
            } => {
                assert!(content.is_empty());
                assert_eq!(tool_calls[0].id, "call_1BspL6mQqjKgvsQbH1TIYkHf");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 5})
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Novita sends `"system_fingerprint": null`, which must default to "".
        assert_eq!(novita_response.system_fingerprint, "");

        // The tool call must survive conversion into rig's generic response.
        let converted: completion::CompletionResponse<CompletionResponse> =
            novita_response.try_into().unwrap();
        assert_eq!(
            converted.choice.first(),
            completion::AssistantContent::tool_call(
                "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                "subtract",
                serde_json::json!({"x": "2", "y": "5"}),
            )
        );
    }
}