1use serde::{Deserialize, Deserializer, Serialize};
2use serde_json::{Value, json};
3use std::{convert::Infallible, str::FromStr};
4
5use super::client::Client;
6use crate::providers::openai::StreamingCompletionResponse;
7use crate::{
8 OneOrMany,
9 completion::{self, CompletionError, CompletionRequest},
10 json_utils,
11 message::{self},
12 one_or_many::string_or_one_or_many,
13};
14
15#[derive(Debug, Deserialize)]
16#[serde(untagged)]
17pub enum ApiResponse<T> {
18 Ok(T),
19 Err(Value),
20}
21
/// Model ID `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model ID `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model ID `microsoft/phi-4`.
pub const PHI_4: &str = "microsoft/phi-4";
/// Model ID `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model ID `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model ID `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// Model ID `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model ID `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
47
48#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
49pub struct Function {
50 name: String,
51 #[serde(deserialize_with = "deserialize_arguments")]
52 pub arguments: serde_json::Value,
53}
54
55fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
56where
57 D: Deserializer<'de>,
58{
59 let value = Value::deserialize(deserializer)?;
60
61 match value {
62 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
63 other => Ok(other),
64 }
65}
66
67impl From<Function> for message::ToolFunction {
68 fn from(value: Function) -> Self {
69 message::ToolFunction {
70 name: value.name,
71 arguments: value.arguments,
72 }
73 }
74}
75
76#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
77#[serde(rename_all = "lowercase")]
78pub enum ToolType {
79 #[default]
80 Function,
81}
82
83#[derive(Debug, Deserialize, Serialize, Clone)]
84pub struct ToolDefinition {
85 pub r#type: String,
86 pub function: completion::ToolDefinition,
87}
88
89impl From<completion::ToolDefinition> for ToolDefinition {
90 fn from(tool: completion::ToolDefinition) -> Self {
91 Self {
92 r#type: "function".into(),
93 function: tool,
94 }
95 }
96}
97
98#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
99pub struct ToolCall {
100 pub id: String,
101 pub r#type: ToolType,
102 pub function: Function,
103}
104
105impl From<ToolCall> for message::ToolCall {
106 fn from(value: ToolCall) -> Self {
107 message::ToolCall {
108 id: value.id,
109 call_id: None,
110 function: value.function.into(),
111 }
112 }
113}
114
115impl From<message::ToolCall> for ToolCall {
116 fn from(value: message::ToolCall) -> Self {
117 ToolCall {
118 id: value.id,
119 r#type: ToolType::Function,
120 function: Function {
121 name: value.function.name,
122 arguments: value.function.arguments,
123 },
124 }
125 }
126}
127
128#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
129pub struct ImageUrl {
130 url: String,
131}
132
133#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
134#[serde(tag = "type", rename_all = "lowercase")]
135pub enum UserContent {
136 Text {
137 text: String,
138 },
139 #[serde(rename = "image_url")]
140 ImageUrl {
141 image_url: ImageUrl,
142 },
143}
144
145impl FromStr for UserContent {
146 type Err = Infallible;
147
148 fn from_str(s: &str) -> Result<Self, Self::Err> {
149 Ok(UserContent::Text {
150 text: s.to_string(),
151 })
152 }
153}
154
155#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
156#[serde(tag = "type", rename_all = "lowercase")]
157pub enum AssistantContent {
158 Text { text: String },
159}
160
161impl FromStr for AssistantContent {
162 type Err = Infallible;
163
164 fn from_str(s: &str) -> Result<Self, Self::Err> {
165 Ok(AssistantContent::Text {
166 text: s.to_string(),
167 })
168 }
169}
170
171#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
172#[serde(tag = "type", rename_all = "lowercase")]
173pub enum SystemContent {
174 Text { text: String },
175}
176
177impl FromStr for SystemContent {
178 type Err = Infallible;
179
180 fn from_str(s: &str) -> Result<Self, Self::Err> {
181 Ok(SystemContent::Text {
182 text: s.to_string(),
183 })
184 }
185}
186
187impl From<UserContent> for message::UserContent {
188 fn from(value: UserContent) -> Self {
189 match value {
190 UserContent::Text { text } => message::UserContent::text(text),
191 UserContent::ImageUrl { image_url } => message::UserContent::image(
192 image_url.url,
193 Some(message::ContentFormat::String),
194 None,
195 None,
196 ),
197 }
198 }
199}
200
201impl TryFrom<message::UserContent> for UserContent {
202 type Error = message::MessageError;
203
204 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
205 match content {
206 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
207 message::UserContent::Image(message::Image { data, format, .. }) => match format {
208 Some(message::ContentFormat::String) => Ok(UserContent::ImageUrl {
209 image_url: ImageUrl { url: data },
210 }),
211 _ => Err(message::MessageError::ConversionError(
212 "Huggingface only supports images as urls".into(),
213 )),
214 },
215 _ => Err(message::MessageError::ConversionError(
216 "Huggingface only supports text and images".into(),
217 )),
218 }
219 }
220}
221
222#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
223#[serde(tag = "role", rename_all = "lowercase")]
224pub enum Message {
225 System {
226 #[serde(deserialize_with = "string_or_one_or_many")]
227 content: OneOrMany<SystemContent>,
228 },
229 User {
230 #[serde(deserialize_with = "string_or_one_or_many")]
231 content: OneOrMany<UserContent>,
232 },
233 Assistant {
234 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
235 content: Vec<AssistantContent>,
236 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
237 tool_calls: Vec<ToolCall>,
238 },
239 #[serde(rename = "Tool")]
240 ToolResult {
241 name: String,
242 #[serde(skip_serializing_if = "Option::is_none")]
243 arguments: Option<serde_json::Value>,
244 #[serde(deserialize_with = "string_or_one_or_many")]
245 content: OneOrMany<String>,
246 },
247}
248
249impl Message {
250 pub fn system(content: &str) -> Self {
251 Message::System {
252 content: OneOrMany::one(SystemContent::Text {
253 text: content.to_string(),
254 }),
255 }
256 }
257}
258
259impl TryFrom<message::Message> for Vec<Message> {
260 type Error = message::MessageError;
261
262 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
263 match message {
264 message::Message::User { content } => {
265 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
266 .into_iter()
267 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
268
269 if !tool_results.is_empty() {
270 tool_results
271 .into_iter()
272 .map(|content| match content {
273 message::UserContent::ToolResult(message::ToolResult {
274 id,
275 content,
276 ..
277 }) => Ok::<_, message::MessageError>(Message::ToolResult {
278 name: id,
279 arguments: None,
280 content: content.try_map(|content| match content {
281 message::ToolResultContent::Text(message::Text { text }) => {
282 Ok(text)
283 }
284 _ => Err(message::MessageError::ConversionError(
285 "Tool result content does not support non-text".into(),
286 )),
287 })?,
288 }),
289 _ => unreachable!(),
290 })
291 .collect::<Result<Vec<_>, _>>()
292 } else {
293 let other_content = OneOrMany::many(other_content).expect(
294 "There must be other content here if there were no tool result content",
295 );
296
297 Ok(vec![Message::User {
298 content: other_content.try_map(|content| match content {
299 message::UserContent::Text(text) => {
300 Ok(UserContent::Text { text: text.text })
301 }
302 _ => Err(message::MessageError::ConversionError(
303 "Huggingface does not support non-text".into(),
304 )),
305 })?,
306 }])
307 }
308 }
309 message::Message::Assistant { content, .. } => {
310 let (text_content, tool_calls) = content.into_iter().fold(
311 (Vec::new(), Vec::new()),
312 |(mut texts, mut tools), content| {
313 match content {
314 message::AssistantContent::Text(text) => texts.push(text),
315 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
316 }
317 (texts, tools)
318 },
319 );
320
321 Ok(vec![Message::Assistant {
324 content: text_content
325 .into_iter()
326 .map(|content| AssistantContent::Text { text: content.text })
327 .collect::<Vec<_>>(),
328 tool_calls: tool_calls
329 .into_iter()
330 .map(|tool_call| tool_call.into())
331 .collect::<Vec<_>>(),
332 }])
333 }
334 }
335 }
336}
337
338impl TryFrom<Message> for message::Message {
339 type Error = message::MessageError;
340
341 fn try_from(message: Message) -> Result<Self, Self::Error> {
342 Ok(match message {
343 Message::User { content, .. } => message::Message::User {
344 content: content.map(|content| content.into()),
345 },
346 Message::Assistant {
347 content,
348 tool_calls,
349 ..
350 } => {
351 let mut content = content
352 .into_iter()
353 .map(|content| match content {
354 AssistantContent::Text { text } => message::AssistantContent::text(text),
355 })
356 .collect::<Vec<_>>();
357
358 content.extend(
359 tool_calls
360 .into_iter()
361 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
362 .collect::<Result<Vec<_>, _>>()?,
363 );
364
365 message::Message::Assistant {
366 id: None,
367 content: OneOrMany::many(content).map_err(|_| {
368 message::MessageError::ConversionError(
369 "Neither `content` nor `tool_calls` was provided to the Message"
370 .to_owned(),
371 )
372 })?,
373 }
374 }
375
376 Message::ToolResult { name, content, .. } => message::Message::User {
377 content: OneOrMany::one(message::UserContent::tool_result(
378 name,
379 content.map(message::ToolResultContent::text),
380 )),
381 },
382
383 Message::System { content, .. } => message::Message::User {
386 content: content.map(|c| match c {
387 SystemContent::Text { text } => message::UserContent::text(text),
388 }),
389 },
390 })
391 }
392}
393
394#[derive(Debug, Deserialize)]
395pub struct Choice {
396 pub finish_reason: String,
397 pub index: usize,
398 #[serde(default)]
399 pub logprobs: serde_json::Value,
400 pub message: Message,
401}
402
403#[derive(Debug, Deserialize, Clone)]
404pub struct Usage {
405 pub completion_tokens: i32,
406 pub prompt_tokens: i32,
407 pub total_tokens: i32,
408}
409
410#[derive(Debug, Deserialize)]
411pub struct CompletionResponse {
412 pub created: i32,
413 pub id: String,
414 pub model: String,
415 pub choices: Vec<Choice>,
416 #[serde(default, deserialize_with = "default_string_on_null")]
417 pub system_fingerprint: String,
418 pub usage: Usage,
419}
420
421fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
422where
423 D: Deserializer<'de>,
424{
425 match Option::<String>::deserialize(deserializer)? {
426 Some(value) => Ok(value), None => Ok(String::default()), }
429}
430
431impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
432 type Error = CompletionError;
433
434 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
435 let choice = response.choices.first().ok_or_else(|| {
436 CompletionError::ResponseError("Response contained no choices".to_owned())
437 })?;
438
439 let content = match &choice.message {
440 Message::Assistant {
441 content,
442 tool_calls,
443 ..
444 } => {
445 let mut content = content
446 .iter()
447 .map(|c| match c {
448 AssistantContent::Text { text } => message::AssistantContent::text(text),
449 })
450 .collect::<Vec<_>>();
451
452 content.extend(
453 tool_calls
454 .iter()
455 .map(|call| {
456 completion::AssistantContent::tool_call(
457 &call.id,
458 &call.function.name,
459 call.function.arguments.clone(),
460 )
461 })
462 .collect::<Vec<_>>(),
463 );
464 Ok(content)
465 }
466 _ => Err(CompletionError::ResponseError(
467 "Response did not contain a valid message or tool call".into(),
468 )),
469 }?;
470
471 let choice = OneOrMany::many(content).map_err(|_| {
472 CompletionError::ResponseError(
473 "Response contained no message or tool call (empty)".to_owned(),
474 )
475 })?;
476
477 Ok(completion::CompletionResponse {
478 choice,
479 raw_response: response,
480 })
481 }
482}
483
484#[derive(Clone)]
485pub struct CompletionModel {
486 pub(crate) client: Client,
487 pub model: String,
489}
490
491impl CompletionModel {
492 pub fn new(client: Client, model: &str) -> Self {
493 Self {
494 client,
495 model: model.to_string(),
496 }
497 }
498
499 pub(crate) fn create_request_body(
500 &self,
501 completion_request: &CompletionRequest,
502 ) -> Result<serde_json::Value, CompletionError> {
503 let mut full_history: Vec<Message> = match &completion_request.preamble {
504 Some(preamble) => vec![Message::system(preamble)],
505 None => vec![],
506 };
507 if let Some(docs) = completion_request.normalized_documents() {
508 let docs: Vec<Message> = docs.try_into()?;
509 full_history.extend(docs);
510 }
511
512 let chat_history: Vec<Message> = completion_request
513 .chat_history
514 .clone()
515 .into_iter()
516 .map(|message| message.try_into())
517 .collect::<Result<Vec<Vec<Message>>, _>>()?
518 .into_iter()
519 .flatten()
520 .collect();
521
522 full_history.extend(chat_history);
523
524 let model = self.client.sub_provider.model_identifier(&self.model);
525
526 let request = if completion_request.tools.is_empty() {
527 json!({
528 "model": model,
529 "messages": full_history,
530 "temperature": completion_request.temperature,
531 })
532 } else {
533 json!({
534 "model": model,
535 "messages": full_history,
536 "temperature": completion_request.temperature,
537 "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
538 "tool_choice": "auto",
539 })
540 };
541 Ok(request)
542 }
543}
544
545impl completion::CompletionModel for CompletionModel {
546 type Response = CompletionResponse;
547 type StreamingResponse = StreamingCompletionResponse;
548
549 #[cfg_attr(feature = "worker", worker::send)]
550 async fn completion(
551 &self,
552 completion_request: CompletionRequest,
553 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
554 let request = self.create_request_body(&completion_request)?;
555
556 let path = self.client.sub_provider.completion_endpoint(&self.model);
557
558 let request = if let Some(ref params) = completion_request.additional_params {
559 json_utils::merge(request, params.clone())
560 } else {
561 request
562 };
563
564 let response = self.client.post(&path).json(&request).send().await?;
565
566 if response.status().is_success() {
567 let t = response.text().await?;
568 tracing::debug!(target: "rig", "Huggingface completion error: {}", t);
569
570 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
571 ApiResponse::Ok(response) => {
572 tracing::info!(target: "rig",
573 "Huggingface completion token usage: {:?}",
574 format!("{:?}", response.usage)
575 );
576 response.try_into()
577 }
578 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
579 }
580 } else {
581 Err(CompletionError::ProviderError(format!(
582 "{}: {}",
583 response.status(),
584 response.text().await?
585 )))
586 }
587 }
588
589 #[cfg_attr(feature = "worker", worker::send)]
590 async fn stream(
591 &self,
592 request: CompletionRequest,
593 ) -> Result<
594 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
595 CompletionError,
596 > {
597 CompletionModel::stream(self, request).await
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    // Reports the JSON path of the failing field when deserialization fails.
    use serde_path_to_error::deserialize;

    /// Checks that wire-format messages deserialize in the shapes this module accepts:
    /// - assistant `content` as a bare string or as a list of parts;
    /// - `tool_calls` as `null`, absent, or a list with object-valued arguments;
    /// - a user message that mixes text and image parts.
    #[test]
    fn test_deserialize_message() {
        // Assistant message with string `content` and no `tool_calls` field.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message with array `content` and an explicit `null` `tool_calls`.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Assistant message with only a tool call; arguments are a JSON object, not a
        // string. It also carries `null` content and an extra `refusal` field.
        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // User message mixing a text part and an image-URL part.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        // Each parse panics with the JSON path, line and column of any failure.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message3: Message = {
            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
                &mut serde_json::Deserializer::from_str(assistant_message_json3);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        // String content becomes a single text part.
        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // `null` tool_calls deserializes to an empty Vec.
        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        // `null` content deserializes to an empty Vec; object arguments are kept as-is.
        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Parts keep their order: text first, then the image URL.
        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    /// Round-trip: rig message -> Huggingface message(s) -> rig message.
    /// Each step must preserve the text of a user and an assistant message.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original rig messages exactly.
        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    /// Round-trip in the other direction: Huggingface message -> rig message ->
    /// Huggingface message.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    /// Parses full responses captured from two sub-providers (judging by the fixture
    /// names and `model` fields). Both have string-encoded tool-call arguments and
    /// extra fields; the Novita one also has many `null`s.
    /// Only successful deserialization is checked; the parsed values are not asserted.
    #[test]
    fn test_responses() {
        // No `content`, no `logprobs`, no `system_fingerprint`, plus an extra `index`
        // field on the tool call.
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        // `null` content, logprobs and system_fingerprint, plus unknown extra fields.
        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let _firework_response: CompletionResponse = {
            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let _novita_response: CompletionResponse = {
            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };
    }
}