1use serde::{Deserialize, Deserializer, Serialize};
2use serde_json::{Value, json};
3use std::{convert::Infallible, str::FromStr};
4
5use super::client::Client;
6use crate::providers::openai::StreamingCompletionResponse;
7use crate::{
8 OneOrMany,
9 completion::{self, CompletionError, CompletionRequest},
10 json_utils,
11 message::{self},
12 one_or_many::string_or_one_or_many,
13};
14
/// Result of parsing a HuggingFace response body: either the expected payload `T`
/// or, when the body does not match `T`, the raw JSON value (which `completion`
/// reports as a `CompletionError::ProviderError`).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Body deserialized successfully as `T`.
    Ok(T),
    /// Any other JSON body; because the enum is untagged this is only tried after `Ok` fails.
    Err(Value),
}
21
/// HuggingFace model id `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// HuggingFace model id `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// HuggingFace model id `microsoft/phi-4`.
pub const PHI_4: &str = "microsoft/phi-4";
/// HuggingFace model id `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// HuggingFace model id `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// HuggingFace model id `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// HuggingFace model id `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// HuggingFace model id `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
47
/// Name and arguments of the function invoked by a [`ToolCall`].
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Function {
    // NOTE(review): `name` is private while `arguments` is `pub`; confirm the asymmetry
    // is intentional.
    name: String,
    /// Call arguments. On input this may be a JSON value or a string containing JSON;
    /// `deserialize_arguments` decodes the latter into a `Value`.
    #[serde(deserialize_with = "deserialize_arguments")]
    pub arguments: serde_json::Value,
}
54
55fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
56where
57 D: Deserializer<'de>,
58{
59 let value = Value::deserialize(deserializer)?;
60
61 match value {
62 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
63 other => Ok(other),
64 }
65}
66
67impl From<Function> for message::ToolFunction {
68 fn from(value: Function) -> Self {
69 message::ToolFunction {
70 name: value.name,
71 arguments: value.arguments,
72 }
73 }
74}
75
/// Kind of a tool call. `Function` is the only variant and the default; it is
/// (de)serialized in lowercase as `"function"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
82
/// Entry of the request's `tools` array: a `type` string (set to `"function"` by the
/// `From` conversion) plus rig's generic tool definition under `function`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
88
89impl From<completion::ToolDefinition> for ToolDefinition {
90 fn from(tool: completion::ToolDefinition) -> Self {
91 Self {
92 r#type: "function".into(),
93 function: tool,
94 }
95 }
96}
97
/// A tool call as it appears in an assistant message's `tool_calls` array.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Provider-assigned id of the call.
    pub id: String,
    pub r#type: ToolType,
    pub function: Function,
}
104
105impl From<ToolCall> for message::ToolCall {
106 fn from(value: ToolCall) -> Self {
107 message::ToolCall {
108 id: value.id,
109 call_id: None,
110 function: value.function.into(),
111 }
112 }
113}
114
115impl From<message::ToolCall> for ToolCall {
116 fn from(value: message::ToolCall) -> Self {
117 ToolCall {
118 id: value.id,
119 r#type: ToolType::Function,
120 function: Function {
121 name: value.function.name,
122 arguments: value.function.arguments,
123 },
124 }
125 }
126}
127
/// The `{"url": ...}` object carried by an `image_url` content part.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    url: String,
}
132
/// One content part of a user message, tagged by `type`: `"text"` (via
/// `rename_all = "lowercase"`) or `"image_url"` (explicit rename).
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
144
145impl FromStr for UserContent {
146 type Err = Infallible;
147
148 fn from_str(s: &str) -> Result<Self, Self::Err> {
149 Ok(UserContent::Text {
150 text: s.to_string(),
151 })
152 }
153}
154
/// One content part of an assistant message; only text parts (`"type": "text"`) exist.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
}
160
161impl FromStr for AssistantContent {
162 type Err = Infallible;
163
164 fn from_str(s: &str) -> Result<Self, Self::Err> {
165 Ok(AssistantContent::Text {
166 text: s.to_string(),
167 })
168 }
169}
170
/// One content part of a system message; only text parts (`"type": "text"`) exist.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    Text { text: String },
}
176
177impl FromStr for SystemContent {
178 type Err = Infallible;
179
180 fn from_str(s: &str) -> Result<Self, Self::Err> {
181 Ok(SystemContent::Text {
182 text: s.to_string(),
183 })
184 }
185}
186
187impl From<UserContent> for message::UserContent {
188 fn from(value: UserContent) -> Self {
189 match value {
190 UserContent::Text { text } => message::UserContent::text(text),
191 UserContent::ImageUrl { image_url } => message::UserContent::image(
192 image_url.url,
193 Some(message::ContentFormat::String),
194 None,
195 None,
196 ),
197 }
198 }
199}
200
201impl TryFrom<message::UserContent> for UserContent {
202 type Error = message::MessageError;
203
204 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
205 match content {
206 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
207 message::UserContent::Image(message::Image { data, format, .. }) => match format {
208 Some(message::ContentFormat::String) => Ok(UserContent::ImageUrl {
209 image_url: ImageUrl { url: data },
210 }),
211 _ => Err(message::MessageError::ConversionError(
212 "Huggingface only supports images as urls".into(),
213 )),
214 },
215 _ => Err(message::MessageError::ConversionError(
216 "Huggingface only supports text and images".into(),
217 )),
218 }
219 }
220}
221
222#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
223#[serde(tag = "role", rename_all = "lowercase")]
224pub enum Message {
225 System {
226 #[serde(deserialize_with = "string_or_one_or_many")]
227 content: OneOrMany<SystemContent>,
228 },
229 User {
230 #[serde(deserialize_with = "string_or_one_or_many")]
231 content: OneOrMany<UserContent>,
232 },
233 Assistant {
234 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
235 content: Vec<AssistantContent>,
236 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
237 tool_calls: Vec<ToolCall>,
238 },
239 #[serde(rename = "Tool")]
240 ToolResult {
241 name: String,
242 #[serde(skip_serializing_if = "Option::is_none")]
243 arguments: Option<serde_json::Value>,
244 #[serde(deserialize_with = "string_or_one_or_many")]
245 content: OneOrMany<String>,
246 },
247}
248
249impl Message {
250 pub fn system(content: &str) -> Self {
251 Message::System {
252 content: OneOrMany::one(SystemContent::Text {
253 text: content.to_string(),
254 }),
255 }
256 }
257}
258
259impl TryFrom<message::Message> for Vec<Message> {
260 type Error = message::MessageError;
261
262 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
263 match message {
264 message::Message::User { content } => {
265 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
266 .into_iter()
267 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
268
269 if !tool_results.is_empty() {
270 tool_results
271 .into_iter()
272 .map(|content| match content {
273 message::UserContent::ToolResult(message::ToolResult {
274 id,
275 content,
276 ..
277 }) => Ok::<_, message::MessageError>(Message::ToolResult {
278 name: id,
279 arguments: None,
280 content: content.try_map(|content| match content {
281 message::ToolResultContent::Text(message::Text { text }) => {
282 Ok(text)
283 }
284 _ => Err(message::MessageError::ConversionError(
285 "Tool result content does not support non-text".into(),
286 )),
287 })?,
288 }),
289 _ => unreachable!(),
290 })
291 .collect::<Result<Vec<_>, _>>()
292 } else {
293 let other_content = OneOrMany::many(other_content).expect(
294 "There must be other content here if there were no tool result content",
295 );
296
297 Ok(vec![Message::User {
298 content: other_content.try_map(|content| match content {
299 message::UserContent::Text(text) => {
300 Ok(UserContent::Text { text: text.text })
301 }
302 _ => Err(message::MessageError::ConversionError(
303 "Huggingface does not support non-text".into(),
304 )),
305 })?,
306 }])
307 }
308 }
309 message::Message::Assistant { content, .. } => {
310 let (text_content, tool_calls) = content.into_iter().fold(
311 (Vec::new(), Vec::new()),
312 |(mut texts, mut tools), content| {
313 match content {
314 message::AssistantContent::Text(text) => texts.push(text),
315 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
316 message::AssistantContent::Reasoning(_) => {
317 unimplemented!("Reasoning is not supported on HuggingFace via Rig");
318 }
319 }
320 (texts, tools)
321 },
322 );
323
324 Ok(vec![Message::Assistant {
327 content: text_content
328 .into_iter()
329 .map(|content| AssistantContent::Text { text: content.text })
330 .collect::<Vec<_>>(),
331 tool_calls: tool_calls
332 .into_iter()
333 .map(|tool_call| tool_call.into())
334 .collect::<Vec<_>>(),
335 }])
336 }
337 }
338 }
339}
340
341impl TryFrom<Message> for message::Message {
342 type Error = message::MessageError;
343
344 fn try_from(message: Message) -> Result<Self, Self::Error> {
345 Ok(match message {
346 Message::User { content, .. } => message::Message::User {
347 content: content.map(|content| content.into()),
348 },
349 Message::Assistant {
350 content,
351 tool_calls,
352 ..
353 } => {
354 let mut content = content
355 .into_iter()
356 .map(|content| match content {
357 AssistantContent::Text { text } => message::AssistantContent::text(text),
358 })
359 .collect::<Vec<_>>();
360
361 content.extend(
362 tool_calls
363 .into_iter()
364 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
365 .collect::<Result<Vec<_>, _>>()?,
366 );
367
368 message::Message::Assistant {
369 id: None,
370 content: OneOrMany::many(content).map_err(|_| {
371 message::MessageError::ConversionError(
372 "Neither `content` nor `tool_calls` was provided to the Message"
373 .to_owned(),
374 )
375 })?,
376 }
377 }
378
379 Message::ToolResult { name, content, .. } => message::Message::User {
380 content: OneOrMany::one(message::UserContent::tool_result(
381 name,
382 content.map(message::ToolResultContent::text),
383 )),
384 },
385
386 Message::System { content, .. } => message::Message::User {
389 content: content.map(|c| match c {
390 SystemContent::Text { text } => message::UserContent::text(text),
391 }),
392 },
393 })
394 }
395}
396
/// One entry of the response's `choices` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub finish_reason: String,
    pub index: usize,
    /// Defaults to `Value::Null` when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    pub message: Message,
}
405
/// Token counts reported by the provider. They are converted to rig's `u64` usage
/// with `as` casts, so a negative value would wrap.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    pub completion_tokens: i32,
    pub prompt_tokens: i32,
    pub total_tokens: i32,
}
412
/// Raw (non-streaming) chat completion response body.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    // NOTE(review): presumably a Unix timestamp in seconds; as `i32` it stops
    // deserializing after 2038-01-19 — consider `i64` (a breaking field-type change).
    pub created: i32,
    pub id: String,
    pub model: String,
    pub choices: Vec<Choice>,
    /// Empty string when the field is missing or `null`.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    pub usage: Usage,
}
423
424fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
425where
426 D: Deserializer<'de>,
427{
428 match Option::<String>::deserialize(deserializer)? {
429 Some(value) => Ok(value), None => Ok(String::default()), }
432}
433
434impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
435 type Error = CompletionError;
436
437 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
438 let choice = response.choices.first().ok_or_else(|| {
439 CompletionError::ResponseError("Response contained no choices".to_owned())
440 })?;
441
442 let content = match &choice.message {
443 Message::Assistant {
444 content,
445 tool_calls,
446 ..
447 } => {
448 let mut content = content
449 .iter()
450 .map(|c| match c {
451 AssistantContent::Text { text } => message::AssistantContent::text(text),
452 })
453 .collect::<Vec<_>>();
454
455 content.extend(
456 tool_calls
457 .iter()
458 .map(|call| {
459 completion::AssistantContent::tool_call(
460 &call.id,
461 &call.function.name,
462 call.function.arguments.clone(),
463 )
464 })
465 .collect::<Vec<_>>(),
466 );
467 Ok(content)
468 }
469 _ => Err(CompletionError::ResponseError(
470 "Response did not contain a valid message or tool call".into(),
471 )),
472 }?;
473
474 let choice = OneOrMany::many(content).map_err(|_| {
475 CompletionError::ResponseError(
476 "Response contained no message or tool call (empty)".to_owned(),
477 )
478 })?;
479
480 let usage = completion::Usage {
481 input_tokens: response.usage.prompt_tokens as u64,
482 output_tokens: response.usage.completion_tokens as u64,
483 total_tokens: response.usage.total_tokens as u64,
484 };
485
486 Ok(completion::CompletionResponse {
487 choice,
488 usage,
489 raw_response: response,
490 })
491 }
492}
493
/// Handle for a HuggingFace chat completion model.
#[derive(Clone)]
pub struct CompletionModel {
    pub(crate) client: Client,
    /// Model id, resolved per sub-provider through `model_identifier` and
    /// `completion_endpoint`.
    pub model: String,
}
500
501impl CompletionModel {
502 pub fn new(client: Client, model: &str) -> Self {
503 Self {
504 client,
505 model: model.to_string(),
506 }
507 }
508
509 pub(crate) fn create_request_body(
510 &self,
511 completion_request: &CompletionRequest,
512 ) -> Result<serde_json::Value, CompletionError> {
513 let mut full_history: Vec<Message> = match &completion_request.preamble {
514 Some(preamble) => vec![Message::system(preamble)],
515 None => vec![],
516 };
517 if let Some(docs) = completion_request.normalized_documents() {
518 let docs: Vec<Message> = docs.try_into()?;
519 full_history.extend(docs);
520 }
521
522 let chat_history: Vec<Message> = completion_request
523 .chat_history
524 .clone()
525 .into_iter()
526 .map(|message| message.try_into())
527 .collect::<Result<Vec<Vec<Message>>, _>>()?
528 .into_iter()
529 .flatten()
530 .collect();
531
532 full_history.extend(chat_history);
533
534 let model = self.client.sub_provider.model_identifier(&self.model);
535
536 let request = if completion_request.tools.is_empty() {
537 json!({
538 "model": model,
539 "messages": full_history,
540 "temperature": completion_request.temperature,
541 })
542 } else {
543 json!({
544 "model": model,
545 "messages": full_history,
546 "temperature": completion_request.temperature,
547 "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
548 "tool_choice": "auto",
549 })
550 };
551 Ok(request)
552 }
553}
554
555impl completion::CompletionModel for CompletionModel {
556 type Response = CompletionResponse;
557 type StreamingResponse = StreamingCompletionResponse;
558
559 #[cfg_attr(feature = "worker", worker::send)]
560 async fn completion(
561 &self,
562 completion_request: CompletionRequest,
563 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
564 let request = self.create_request_body(&completion_request)?;
565
566 let path = self.client.sub_provider.completion_endpoint(&self.model);
567
568 let request = if let Some(ref params) = completion_request.additional_params {
569 json_utils::merge(request, params.clone())
570 } else {
571 request
572 };
573
574 let response = self.client.post(&path).json(&request).send().await?;
575
576 if response.status().is_success() {
577 let t = response.text().await?;
578 tracing::debug!(target: "rig", "Huggingface completion error: {}", t);
579
580 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
581 ApiResponse::Ok(response) => {
582 tracing::info!(target: "rig",
583 "Huggingface completion token usage: {:?}",
584 format!("{:?}", response.usage)
585 );
586 response.try_into()
587 }
588 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
589 }
590 } else {
591 Err(CompletionError::ProviderError(format!(
592 "{}: {}",
593 response.status(),
594 response.text().await?
595 )))
596 }
597 }
598
599 #[cfg_attr(feature = "worker", worker::send)]
600 async fn stream(
601 &self,
602 request: CompletionRequest,
603 ) -> Result<
604 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
605 CompletionError,
606 > {
607 CompletionModel::stream(self, request).await
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Parses `json` into `T`, panicking with the JSON path and line/column on failure.
    fn parse<T>(json: &str) -> T
    where
        T: for<'de> Deserialize<'de>,
    {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse(assistant_message_json);
        let assistant_message2: Message = parse(assistant_message_json2);
        let assistant_message3: Message = parse(assistant_message_json3);
        let user_message: Message = parse(user_message_json);

        let greeting = AssistantContent::Text {
            text: "\n\nHello there, how may I assist you today?".to_string(),
        };

        match assistant_message {
            Message::Assistant { content, .. } => assert_eq!(content[0], greeting),
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(content[0], greeting);
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_empty());
                let expected = ToolCall {
                    id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                    r#type: ToolType::Function,
                    function: Function {
                        name: "subtract".to_string(),
                        arguments: serde_json::json!({"x": 2, "y": 5}),
                    },
                };
                assert_eq!(tool_calls[0], expected);
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let mut parts = content.into_iter();
                let first = parts.next().unwrap();
                let second = parts.next().unwrap();

                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => assert_eq!(
                content.first(),
                UserContent::Text {
                    text: "Hello".to_string()
                }
            ),
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => assert_eq!(
                content[0],
                AssistantContent::Text {
                    text: "Hi there!".to_string()
                }
            ),
            _ => panic!("Expected assistant message"),
        }

        let round_trip_user: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let round_trip_assistant: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(round_trip_user, user_message);
        assert_eq!(round_trip_assistant, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"))
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => assert_eq!(
                content.first(),
                message::AssistantContent::text("Hi there!")
            ),
            _ => panic!("Expected assistant message"),
        }

        let round_trip_user: Vec<Message> = converted_user_message.try_into().unwrap();
        let round_trip_assistant: Vec<Message> = converted_assistant_message.try_into().unwrap();

        assert_eq!(round_trip_user[0], user_message);
        assert_eq!(round_trip_assistant[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        // Both provider payloads must deserialize; the values themselves are not inspected.
        let _firework_response: CompletionResponse = parse(fireworks_response_json);
        let _novita_response: CompletionResponse = parse(novita_response_json);
    }
}