1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::providers::openai::StreamingCompletionResponse;
4use crate::telemetry::SpanCombinator;
5use crate::{
6 OneOrMany,
7 completion::{self, CompletionError, CompletionRequest},
8 json_utils,
9 message::{self},
10 one_or_many::string_or_one_or_many,
11};
12use serde::{Deserialize, Deserializer, Serialize};
13use serde_json::{Value, json};
14use std::{convert::Infallible, str::FromStr};
15use tracing::info_span;
16
/// Response envelope: either the expected payload `T` or any other JSON value.
///
/// With `#[serde(untagged)]`, serde tries `Ok(T)` first. Any JSON document that does not
/// deserialize as `T` (for example a provider error body) ends up in `Err` as the raw value.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The body deserialized successfully as `T`.
    Ok(T),
    /// The body was valid JSON but did not match `T`; kept verbatim.
    Err(Value),
}
23
/// Model ID `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model ID `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model ID `microsoft/phi-4`.
pub const PHI_4: &str = "microsoft/phi-4";
/// Model ID `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model ID `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model ID `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// Model ID `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model ID `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
49
50#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
51pub struct Function {
52 name: String,
53 #[serde(deserialize_with = "deserialize_arguments")]
54 pub arguments: serde_json::Value,
55}
56
57fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
58where
59 D: Deserializer<'de>,
60{
61 let value = Value::deserialize(deserializer)?;
62
63 match value {
64 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
65 other => Ok(other),
66 }
67}
68
69impl From<Function> for message::ToolFunction {
70 fn from(value: Function) -> Self {
71 message::ToolFunction {
72 name: value.name,
73 arguments: value.arguments,
74 }
75 }
76}
77
/// Kind of a tool call, serialized in lowercase.
///
/// `Function` (serialized as `"function"`) is the only variant and the default.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
84
/// A tool advertised to the model in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` by the `From<completion::ToolDefinition>` conversion.
    pub r#type: String,
    /// rig's generic tool definition, embedded as-is.
    pub function: completion::ToolDefinition,
}
90
91impl From<completion::ToolDefinition> for ToolDefinition {
92 fn from(tool: completion::ToolDefinition) -> Self {
93 Self {
94 r#type: "function".into(),
95 function: tool,
96 }
97 }
98}
99
/// A tool call as it appears in an assistant message's `tool_calls` array.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Provider-assigned identifier of this call.
    pub id: String,
    /// Call kind (always `"function"`).
    pub r#type: ToolType,
    /// The function name and arguments.
    pub function: Function,
}
106
107impl From<ToolCall> for message::ToolCall {
108 fn from(value: ToolCall) -> Self {
109 message::ToolCall {
110 id: value.id,
111 call_id: None,
112 function: value.function.into(),
113 }
114 }
115}
116
117impl From<message::ToolCall> for ToolCall {
118 fn from(value: message::ToolCall) -> Self {
119 ToolCall {
120 id: value.id,
121 r#type: ToolType::Function,
122 function: Function {
123 name: value.function.name,
124 arguments: value.function.arguments,
125 },
126 }
127 }
128}
129
/// The `image_url` payload of a user content part: `{ "url": ... }`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    url: String,
}
134
/// A part of a user message, tagged by its `type` field.
///
/// Serialized as `{"type": "text", "text": ...}` or `{"type": "image_url", "image_url": {...}}`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    Text {
        text: String,
    },
    // Explicit rename: `rename_all = "lowercase"` alone would produce "imageurl".
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
146
147impl FromStr for UserContent {
148 type Err = Infallible;
149
150 fn from_str(s: &str) -> Result<Self, Self::Err> {
151 Ok(UserContent::Text {
152 text: s.to_string(),
153 })
154 }
155}
156
/// A part of an assistant message, tagged by `type`; only `"text"` parts are modeled.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    Text { text: String },
}
162
163impl FromStr for AssistantContent {
164 type Err = Infallible;
165
166 fn from_str(s: &str) -> Result<Self, Self::Err> {
167 Ok(AssistantContent::Text {
168 text: s.to_string(),
169 })
170 }
171}
172
/// A part of a system message, tagged by `type`; only `"text"` parts are modeled.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    Text { text: String },
}
178
179impl FromStr for SystemContent {
180 type Err = Infallible;
181
182 fn from_str(s: &str) -> Result<Self, Self::Err> {
183 Ok(SystemContent::Text {
184 text: s.to_string(),
185 })
186 }
187}
188
189impl From<UserContent> for message::UserContent {
190 fn from(value: UserContent) -> Self {
191 match value {
192 UserContent::Text { text } => message::UserContent::text(text),
193 UserContent::ImageUrl { image_url } => {
194 message::UserContent::image_url(image_url.url, None, None)
195 }
196 }
197 }
198}
199
200impl TryFrom<message::UserContent> for UserContent {
201 type Error = message::MessageError;
202
203 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
204 match content {
205 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
206 message::UserContent::Document(message::Document {
207 data: message::DocumentSourceKind::Raw(raw),
208 ..
209 }) => {
210 let text = String::from_utf8_lossy(raw.as_slice()).into();
211 Ok(UserContent::Text { text })
212 }
213 message::UserContent::Document(message::Document {
214 data:
215 message::DocumentSourceKind::Base64(text)
216 | message::DocumentSourceKind::String(text),
217 ..
218 }) => Ok(UserContent::Text { text }),
219 message::UserContent::Image(message::Image { data, .. }) => match data {
220 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
221 image_url: ImageUrl { url },
222 }),
223 _ => Err(message::MessageError::ConversionError(
224 "Huggingface only supports images as urls".into(),
225 )),
226 },
227 _ => Err(message::MessageError::ConversionError(
228 "Huggingface only supports text and images".into(),
229 )),
230 }
231 }
232}
233
/// A chat message in the HuggingFace wire format, tagged by its `role` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "system"`; `content` is deserialized via `string_or_one_or_many`.
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    /// `role: "user"`; `content` is deserialized via `string_or_one_or_many`.
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    /// `role: "assistant"`. Both fields default to empty when absent; `content` may also be a
    /// bare string, and `tool_calls` may be `null`.
    Assistant {
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Tool output sent back to the model.
    ///
    /// NOTE(review): the role is serialized as the capitalized `"Tool"`, unlike the other
    /// lowercase roles (and the lowercase `"tool"` role OpenAI-compatible APIs use). Confirm
    /// the target backends accept it.
    /// NOTE(review): when converting from `message::Message`, `name` holds the tool call id.
    #[serde(rename = "Tool")]
    ToolResult {
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<String>,
    },
}
260
261impl Message {
262 pub fn system(content: &str) -> Self {
263 Message::System {
264 content: OneOrMany::one(SystemContent::Text {
265 text: content.to_string(),
266 }),
267 }
268 }
269}
270
271impl TryFrom<message::Message> for Vec<Message> {
272 type Error = message::MessageError;
273
274 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
275 match message {
276 message::Message::User { content } => {
277 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
278 .into_iter()
279 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
280
281 if !tool_results.is_empty() {
282 tool_results
283 .into_iter()
284 .map(|content| match content {
285 message::UserContent::ToolResult(message::ToolResult {
286 id,
287 content,
288 ..
289 }) => Ok::<_, message::MessageError>(Message::ToolResult {
290 name: id,
291 arguments: None,
292 content: content.try_map(|content| match content {
293 message::ToolResultContent::Text(message::Text { text }) => {
294 Ok(text)
295 }
296 _ => Err(message::MessageError::ConversionError(
297 "Tool result content does not support non-text".into(),
298 )),
299 })?,
300 }),
301 _ => unreachable!(),
302 })
303 .collect::<Result<Vec<_>, _>>()
304 } else {
305 let other_content = OneOrMany::many(other_content).expect(
306 "There must be other content here if there were no tool result content",
307 );
308
309 Ok(vec![Message::User {
310 content: other_content.try_map(|content| match content {
311 message::UserContent::Text(text) => {
312 Ok(UserContent::Text { text: text.text })
313 }
314 message::UserContent::Image(image) => {
315 let url = image.try_into_url()?;
316
317 Ok(UserContent::ImageUrl {
318 image_url: ImageUrl { url },
319 })
320 }
321 message::UserContent::Document(message::Document {
322 data: message::DocumentSourceKind::Raw(raw), ..
323 }) => {
324 let text = String::from_utf8_lossy(raw.as_slice()).into();
325 Ok(UserContent::Text { text })
326 }
327 message::UserContent::Document(message::Document {
328 data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
329 }) => {
330 Ok(UserContent::Text { text })
331 }
332 _ => Err(message::MessageError::ConversionError(
333 "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
334 )),
335 })?,
336 }])
337 }
338 }
339 message::Message::Assistant { content, .. } => {
340 let (text_content, tool_calls) = content.into_iter().fold(
341 (Vec::new(), Vec::new()),
342 |(mut texts, mut tools), content| {
343 match content {
344 message::AssistantContent::Text(text) => texts.push(text),
345 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
346 message::AssistantContent::Reasoning(_) => {
347 unimplemented!("Reasoning is not supported on HuggingFace via Rig");
348 }
349 }
350 (texts, tools)
351 },
352 );
353
354 Ok(vec![Message::Assistant {
357 content: text_content
358 .into_iter()
359 .map(|content| AssistantContent::Text { text: content.text })
360 .collect::<Vec<_>>(),
361 tool_calls: tool_calls
362 .into_iter()
363 .map(|tool_call| tool_call.into())
364 .collect::<Vec<_>>(),
365 }])
366 }
367 }
368 }
369}
370
371impl TryFrom<Message> for message::Message {
372 type Error = message::MessageError;
373
374 fn try_from(message: Message) -> Result<Self, Self::Error> {
375 Ok(match message {
376 Message::User { content, .. } => message::Message::User {
377 content: content.map(|content| content.into()),
378 },
379 Message::Assistant {
380 content,
381 tool_calls,
382 ..
383 } => {
384 let mut content = content
385 .into_iter()
386 .map(|content| match content {
387 AssistantContent::Text { text } => message::AssistantContent::text(text),
388 })
389 .collect::<Vec<_>>();
390
391 content.extend(
392 tool_calls
393 .into_iter()
394 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
395 .collect::<Result<Vec<_>, _>>()?,
396 );
397
398 message::Message::Assistant {
399 id: None,
400 content: OneOrMany::many(content).map_err(|_| {
401 message::MessageError::ConversionError(
402 "Neither `content` nor `tool_calls` was provided to the Message"
403 .to_owned(),
404 )
405 })?,
406 }
407 }
408
409 Message::ToolResult { name, content, .. } => message::Message::User {
410 content: OneOrMany::one(message::UserContent::tool_result(
411 name,
412 content.map(message::ToolResultContent::text),
413 )),
414 },
415
416 Message::System { content, .. } => message::Message::User {
419 content: content.map(|c| match c {
420 SystemContent::Text { text } => message::UserContent::text(text),
421 }),
422 },
423 })
424 }
425}
426
/// One completion choice from the response's `choices` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Why generation stopped (e.g. `"tool_calls"`); required in the payload.
    pub finish_reason: String,
    /// Position of this choice in `choices`.
    pub index: usize,
    /// Raw log-probability data; `Value::Null` when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    /// The generated message.
    pub message: Message,
}
435
/// Token accounting reported by the provider.
///
/// Counts are `i32` here and are cast with `as u64` when converted to rig's usage type.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    pub completion_tokens: i32,
    pub prompt_tokens: i32,
    pub total_tokens: i32,
}
442
443impl GetTokenUsage for Usage {
444 fn token_usage(&self) -> Option<crate::completion::Usage> {
445 let mut usage = crate::completion::Usage::new();
446 usage.input_tokens = self.prompt_tokens as u64;
447 usage.output_tokens = self.completion_tokens as u64;
448 usage.total_tokens = self.total_tokens as u64;
449
450 Some(usage)
451 }
452}
453
/// A chat-completion response body.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Creation timestamp as sent by the provider.
    pub created: i32,
    /// Provider response id.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Generated choices; only the first is used when converting to rig's response type.
    pub choices: Vec<Choice>,
    /// Empty string when the field is missing or `null`.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    /// Token accounting.
    pub usage: Usage,
}
464
465impl crate::telemetry::ProviderResponseExt for CompletionResponse {
466 type OutputMessage = Choice;
467 type Usage = Usage;
468
469 fn get_response_id(&self) -> Option<String> {
470 Some(self.id.clone())
471 }
472
473 fn get_response_model_name(&self) -> Option<String> {
474 Some(self.model.clone())
475 }
476
477 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
478 self.choices.clone()
479 }
480
481 fn get_text_response(&self) -> Option<String> {
482 let text_response = self
483 .choices
484 .iter()
485 .filter_map(|x| {
486 let Message::User { ref content } = x.message else {
487 return None;
488 };
489
490 let text = content
491 .iter()
492 .filter_map(|x| {
493 if let UserContent::Text { text } = x {
494 Some(text.clone())
495 } else {
496 None
497 }
498 })
499 .collect::<Vec<String>>();
500
501 if text.is_empty() {
502 None
503 } else {
504 Some(text.join("\n"))
505 }
506 })
507 .collect::<Vec<String>>()
508 .join("\n");
509
510 if text_response.is_empty() {
511 None
512 } else {
513 Some(text_response)
514 }
515 }
516
517 fn get_usage(&self) -> Option<Self::Usage> {
518 Some(self.usage.clone())
519 }
520}
521
522fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
523where
524 D: Deserializer<'de>,
525{
526 match Option::<String>::deserialize(deserializer)? {
527 Some(value) => Ok(value), None => Ok(String::default()), }
530}
531
532impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
533 type Error = CompletionError;
534
535 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
536 let choice = response.choices.first().ok_or_else(|| {
537 CompletionError::ResponseError("Response contained no choices".to_owned())
538 })?;
539
540 let content = match &choice.message {
541 Message::Assistant {
542 content,
543 tool_calls,
544 ..
545 } => {
546 let mut content = content
547 .iter()
548 .map(|c| match c {
549 AssistantContent::Text { text } => message::AssistantContent::text(text),
550 })
551 .collect::<Vec<_>>();
552
553 content.extend(
554 tool_calls
555 .iter()
556 .map(|call| {
557 completion::AssistantContent::tool_call(
558 &call.id,
559 &call.function.name,
560 call.function.arguments.clone(),
561 )
562 })
563 .collect::<Vec<_>>(),
564 );
565 Ok(content)
566 }
567 _ => Err(CompletionError::ResponseError(
568 "Response did not contain a valid message or tool call".into(),
569 )),
570 }?;
571
572 let choice = OneOrMany::many(content).map_err(|_| {
573 CompletionError::ResponseError(
574 "Response contained no message or tool call (empty)".to_owned(),
575 )
576 })?;
577
578 let usage = completion::Usage {
579 input_tokens: response.usage.prompt_tokens as u64,
580 output_tokens: response.usage.completion_tokens as u64,
581 total_tokens: response.usage.total_tokens as u64,
582 };
583
584 Ok(completion::CompletionResponse {
585 choice,
586 usage,
587 raw_response: response,
588 })
589 }
590}
591
/// HuggingFace chat-completion model: a client plus the model name to request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model name as given by the caller. Before sending, it is mapped through the client's
    /// `sub_provider` (`model_identifier` / `completion_endpoint`).
    pub model: String,
}
598
599impl<T> CompletionModel<T> {
600 pub fn new(client: Client<T>, model: &str) -> Self {
601 Self {
602 client,
603 model: model.to_string(),
604 }
605 }
606
607 pub(crate) fn create_request_body(
608 &self,
609 completion_request: &CompletionRequest,
610 ) -> Result<serde_json::Value, CompletionError> {
611 let mut full_history: Vec<Message> = match &completion_request.preamble {
612 Some(preamble) => vec![Message::system(preamble)],
613 None => vec![],
614 };
615 if let Some(docs) = completion_request.normalized_documents() {
616 let docs: Vec<Message> = docs.try_into()?;
617 full_history.extend(docs);
618 }
619
620 let chat_history: Vec<Message> = completion_request
621 .chat_history
622 .clone()
623 .into_iter()
624 .map(|message| message.try_into())
625 .collect::<Result<Vec<Vec<Message>>, _>>()?
626 .into_iter()
627 .flatten()
628 .collect();
629
630 full_history.extend(chat_history);
631
632 let model = self.client.sub_provider.model_identifier(&self.model);
633
634 let tool_choice = completion_request
635 .tool_choice
636 .clone()
637 .map(crate::providers::openai::completion::ToolChoice::try_from)
638 .transpose()?;
639
640 let request = if completion_request.tools.is_empty() {
641 json!({
642 "model": model,
643 "messages": full_history,
644 "temperature": completion_request.temperature,
645 })
646 } else {
647 json!({
648 "model": model,
649 "messages": full_history,
650 "temperature": completion_request.temperature,
651 "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
652 "tool_choice": tool_choice,
653 })
654 };
655 Ok(request)
656 }
657}
658
659impl completion::CompletionModel for CompletionModel<reqwest::Client> {
660 type Response = CompletionResponse;
661 type StreamingResponse = StreamingCompletionResponse;
662
663 #[cfg_attr(feature = "worker", worker::send)]
664 async fn completion(
665 &self,
666 completion_request: CompletionRequest,
667 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
668 let span = if tracing::Span::current().is_disabled() {
669 info_span!(
670 target: "rig::completions",
671 "chat",
672 gen_ai.operation.name = "chat",
673 gen_ai.provider.name = "huggingface",
674 gen_ai.request.model = self.model,
675 gen_ai.system_instructions = &completion_request.preamble,
676 gen_ai.response.id = tracing::field::Empty,
677 gen_ai.response.model = tracing::field::Empty,
678 gen_ai.usage.output_tokens = tracing::field::Empty,
679 gen_ai.usage.input_tokens = tracing::field::Empty,
680 gen_ai.input.messages = tracing::field::Empty,
681 gen_ai.output.messages = tracing::field::Empty,
682 )
683 } else {
684 tracing::Span::current()
685 };
686 let request = self.create_request_body(&completion_request)?;
687 span.record_model_input(&request.get("messages"));
688
689 let path = self.client.sub_provider.completion_endpoint(&self.model);
690
691 let request = if let Some(ref params) = completion_request.additional_params {
692 json_utils::merge(request, params.clone())
693 } else {
694 request
695 };
696
697 let request = serde_json::to_vec(&request)?;
698
699 let request = self
700 .client
701 .post(&path)?
702 .header("Content-Type", "application/json")
703 .body(request)
704 .map_err(|e| CompletionError::HttpError(e.into()))?;
705
706 let response = self.client.send(request).await?;
707
708 if response.status().is_success() {
709 let bytes: Vec<u8> = response.into_body().await?;
710 let text = String::from_utf8_lossy(&bytes);
711
712 tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
713
714 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
715 ApiResponse::Ok(response) => {
716 let span = tracing::Span::current();
717 span.record_token_usage(&response.usage);
718 span.record_model_output(&response.choices);
719 span.record_response_metadata(&response);
720
721 response.try_into()
722 }
723 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
724 }
725 } else {
726 let status = response.status();
727 let text: Vec<u8> = response.into_body().await?;
728 let text: String = String::from_utf8_lossy(&text).into();
729
730 Err(CompletionError::ProviderError(format!(
731 "{}: {}",
732 status, text
733 )))
734 }
735 }
736
737 #[cfg_attr(feature = "worker", worker::send)]
738 async fn stream(
739 &self,
740 request: CompletionRequest,
741 ) -> Result<
742 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
743 CompletionError,
744 > {
745 CompletionModel::stream(self, request).await
746 }
747}
748
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`, panicking with the failing JSON path and line/column so a
    /// broken fixture points straight at the offending field.
    fn parse<T: serde::de::DeserializeOwned>(json: &str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        // Assistant content given as a bare string.
        let assistant_text_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant content given as typed parts, with `tool_calls: null`.
        let assistant_parts_json = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        // Tool-call-only assistant message: `content: null`, object-valued arguments.
        let assistant_tool_call_json = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        // Multimodal user message: a text part followed by an image URL part.
        let user_multimodal_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_text: Message = parse(assistant_text_json);
        let assistant_parts: Message = parse(assistant_parts_json);
        let assistant_tool_call: Message = parse(assistant_tool_call_json);
        let user_multimodal: Message = parse(user_multimodal_json);

        let greeting = AssistantContent::Text {
            text: "\n\nHello there, how may I assist you today?".to_string(),
        };

        let Message::Assistant { content, .. } = assistant_text else {
            panic!("Expected assistant message");
        };
        assert_eq!(content[0], greeting);

        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_parts
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(content[0], greeting);
        assert_eq!(tool_calls, vec![]);

        let Message::Assistant {
            content,
            tool_calls,
            ..
        } = assistant_tool_call
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        let Message::User { content, .. } = user_multimodal else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let first = parts.next().unwrap();
        let second = parts.next().unwrap();
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(
            second,
            UserContent::ImageUrl {
                image_url: ImageUrl {
                    url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string()
                }
            }
        );
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };
        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant: Vec<Message> = assistant_message.clone().try_into().unwrap();

        let Message::User { content, .. } = converted_user[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        // Round-trip back to rig's generic messages.
        let round_tripped_user: message::Message = converted_user[0].clone().try_into().unwrap();
        let round_tripped_assistant: message::Message =
            converted_assistant[0].clone().try_into().unwrap();

        assert_eq!(round_tripped_user, user_message);
        assert_eq!(round_tripped_assistant, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };
        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant: message::Message = assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content, .. } = converted_assistant.clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        // Round-trip back to HuggingFace messages.
        let round_tripped_user: Vec<Message> = converted_user.try_into().unwrap();
        let round_tripped_assistant: Vec<Message> = converted_assistant.try_into().unwrap();

        assert_eq!(round_tripped_user[0], user_message);
        assert_eq!(round_tripped_assistant[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        // Fireworks: no `logprobs`/`system_fingerprint`, string-encoded arguments.
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": 2, \"y\": 5}",
                                    "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        // Novita: many explicit nulls, including `system_fingerprint`.
        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let _firework_response: CompletionResponse = parse(fireworks_response_json);
        let _novita_response: CompletionResponse = parse(novita_response_json);
    }
}