1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7 OneOrMany,
8 completion::{self, CompletionError, CompletionRequest},
9 json_utils,
10 message::{self},
11 one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::{Value, json};
15use std::{convert::Infallible, str::FromStr};
16use tracing::info_span;
17
/// Envelope for a HuggingFace API response body.
///
/// `#[serde(untagged)]` tries the variants in declaration order: the body is first
/// parsed as `T`, and only if that fails is it kept verbatim as a raw JSON [`Value`].
/// Because any valid JSON parses as a `Value`, `Err` also catches success-shaped
/// bodies that merely fail to match `T`'s schema.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Body that parsed as the expected response type.
    Ok(T),
    /// Body that did not parse as `T`, preserved as raw JSON (e.g. an error payload).
    Err(Value),
}
24
/// HuggingFace model id `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// HuggingFace model id `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// HuggingFace model id `microsoft/phi-4`.
pub const PHI_4: &str = "microsoft/phi-4";
/// HuggingFace model id `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// HuggingFace model id `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// HuggingFace model id `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// HuggingFace model id `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// HuggingFace model id `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
50
/// The `function` payload of a tool call: the tool name plus its JSON arguments.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the tool being invoked.
    name: String,
    /// Tool arguments. Serialized through `json_utils::stringified_json::serialize`
    /// (presumably as a JSON-encoded string — confirm in `json_utils`); deserialized
    /// via `deserialize_arguments`, which accepts either an inline JSON value or a
    /// string containing JSON.
    #[serde(
        serialize_with = "json_utils::stringified_json::serialize",
        deserialize_with = "deserialize_arguments"
    )]
    pub arguments: serde_json::Value,
}
60
61fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
62where
63 D: Deserializer<'de>,
64{
65 let value = Value::deserialize(deserializer)?;
66
67 match value {
68 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
69 other => Ok(other),
70 }
71}
72
73impl From<Function> for message::ToolFunction {
74 fn from(value: Function) -> Self {
75 message::ToolFunction {
76 name: value.name,
77 arguments: value.arguments,
78 }
79 }
80}
81
/// Kind of tool referenced by a tool call; serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool — the only supported kind, and the default.
    #[default]
    Function,
}
88
/// A tool advertised in the request's `tools` array: `{"type": ..., "function": ...}`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` when built from a Rig tool definition.
    pub r#type: String,
    /// Rig's provider-agnostic description of the tool.
    pub function: completion::ToolDefinition,
}
94
95impl From<completion::ToolDefinition> for ToolDefinition {
96 fn from(tool: completion::ToolDefinition) -> Self {
97 Self {
98 r#type: "function".into(),
99 function: tool,
100 }
101 }
102}
103
/// A tool call issued by the assistant.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Provider-assigned call id, echoed back when returning the tool's result.
    pub id: String,
    /// Tool kind (always `function`).
    pub r#type: ToolType,
    /// Called function name and arguments.
    pub function: Function,
}
110
111impl From<ToolCall> for message::ToolCall {
112 fn from(value: ToolCall) -> Self {
113 message::ToolCall {
114 id: value.id,
115 call_id: None,
116 function: value.function.into(),
117 }
118 }
119}
120
121impl From<message::ToolCall> for ToolCall {
122 fn from(value: message::ToolCall) -> Self {
123 ToolCall {
124 id: value.id,
125 r#type: ToolType::Function,
126 function: Function {
127 name: value.function.name,
128 arguments: value.function.arguments,
129 },
130 }
131 }
132}
133
/// Wrapper object for an image reference: `{"url": "..."}`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Image location as sent to the provider.
    url: String,
}
138
/// One part of a user message, tagged by `type` (`"text"` or `"image_url"`).
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text part.
    Text {
        text: String,
    },
    /// Image referenced by URL.
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
150
151impl FromStr for UserContent {
152 type Err = Infallible;
153
154 fn from_str(s: &str) -> Result<Self, Self::Err> {
155 Ok(UserContent::Text {
156 text: s.to_string(),
157 })
158 }
159}
160
/// One part of an assistant message, tagged by `type`; only `"text"` is supported.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Plain text part.
    Text { text: String },
}
166
167impl FromStr for AssistantContent {
168 type Err = Infallible;
169
170 fn from_str(s: &str) -> Result<Self, Self::Err> {
171 Ok(AssistantContent::Text {
172 text: s.to_string(),
173 })
174 }
175}
176
/// One part of a system message, tagged by `type`; only `"text"` is supported.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    /// Plain text part.
    Text { text: String },
}
182
183impl FromStr for SystemContent {
184 type Err = Infallible;
185
186 fn from_str(s: &str) -> Result<Self, Self::Err> {
187 Ok(SystemContent::Text {
188 text: s.to_string(),
189 })
190 }
191}
192
193impl From<UserContent> for message::UserContent {
194 fn from(value: UserContent) -> Self {
195 match value {
196 UserContent::Text { text } => message::UserContent::text(text),
197 UserContent::ImageUrl { image_url } => {
198 message::UserContent::image_url(image_url.url, None, None)
199 }
200 }
201 }
202}
203
204impl TryFrom<message::UserContent> for UserContent {
205 type Error = message::MessageError;
206
207 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
208 match content {
209 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
210 message::UserContent::Document(message::Document {
211 data: message::DocumentSourceKind::Raw(raw),
212 ..
213 }) => {
214 let text = String::from_utf8_lossy(raw.as_slice()).into();
215 Ok(UserContent::Text { text })
216 }
217 message::UserContent::Document(message::Document {
218 data:
219 message::DocumentSourceKind::Base64(text)
220 | message::DocumentSourceKind::String(text),
221 ..
222 }) => Ok(UserContent::Text { text }),
223 message::UserContent::Image(message::Image { data, .. }) => match data {
224 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
225 image_url: ImageUrl { url },
226 }),
227 _ => Err(message::MessageError::ConversionError(
228 "Huggingface only supports images as urls".into(),
229 )),
230 },
231 _ => Err(message::MessageError::ConversionError(
232 "Huggingface only supports text and images".into(),
233 )),
234 }
235 }
236}
237
/// A chat message in the HuggingFace wire format, internally tagged by `role`
/// (`system`, `user`, `assistant`, `tool`).
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `system` message; `content` is deserialized with `string_or_one_or_many`
    /// (a bare string or one/many content parts).
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    /// `user` message; `content` is deserialized the same way as for `System`.
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    /// `assistant` message. Both fields default to empty when absent; per the tests
    /// below, `content` may also be a bare string or `null`, and `tool_calls` may be `null`.
    Assistant {
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Tool output, serialized with role `tool` (also accepts `Tool` on input).
    #[serde(rename = "tool", alias = "Tool")]
    ToolResult {
        /// Set to the originating tool call's id when converting from Rig messages.
        name: String,
        /// Omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        /// Serialized as a single string with parts joined by `\n`.
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_tool_content"
        )]
        content: OneOrMany<String>,
    },
}
267
268fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
269where
270 S: Serializer,
271{
272 let joined = content
274 .iter()
275 .map(String::as_str)
276 .collect::<Vec<_>>()
277 .join("\n");
278 serializer.serialize_str(&joined)
279}
280
281impl Message {
282 pub fn system(content: &str) -> Self {
283 Message::System {
284 content: OneOrMany::one(SystemContent::Text {
285 text: content.to_string(),
286 }),
287 }
288 }
289}
290
291impl TryFrom<message::Message> for Vec<Message> {
292 type Error = message::MessageError;
293
294 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
295 match message {
296 message::Message::User { content } => {
297 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
298 .into_iter()
299 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
300
301 if !tool_results.is_empty() {
302 tool_results
303 .into_iter()
304 .map(|content| match content {
305 message::UserContent::ToolResult(message::ToolResult {
306 id,
307 content,
308 ..
309 }) => Ok::<_, message::MessageError>(Message::ToolResult {
310 name: id,
311 arguments: None,
312 content: content.try_map(|content| match content {
313 message::ToolResultContent::Text(message::Text { text }) => {
314 Ok(text)
315 }
316 _ => Err(message::MessageError::ConversionError(
317 "Tool result content does not support non-text".into(),
318 )),
319 })?,
320 }),
321 _ => unreachable!(),
322 })
323 .collect::<Result<Vec<_>, _>>()
324 } else {
325 let other_content = OneOrMany::many(other_content).expect(
326 "There must be other content here if there were no tool result content",
327 );
328
329 Ok(vec![Message::User {
330 content: other_content.try_map(|content| match content {
331 message::UserContent::Text(text) => {
332 Ok(UserContent::Text { text: text.text })
333 }
334 message::UserContent::Image(image) => {
335 let url = image.try_into_url()?;
336
337 Ok(UserContent::ImageUrl {
338 image_url: ImageUrl { url },
339 })
340 }
341 message::UserContent::Document(message::Document {
342 data: message::DocumentSourceKind::Raw(raw), ..
343 }) => {
344 let text = String::from_utf8_lossy(raw.as_slice()).into();
345 Ok(UserContent::Text { text })
346 }
347 message::UserContent::Document(message::Document {
348 data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
349 }) => {
350 Ok(UserContent::Text { text })
351 }
352 _ => Err(message::MessageError::ConversionError(
353 "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
354 )),
355 })?,
356 }])
357 }
358 }
359 message::Message::Assistant { content, .. } => {
360 let (text_content, tool_calls) = content.into_iter().fold(
361 (Vec::new(), Vec::new()),
362 |(mut texts, mut tools), content| {
363 match content {
364 message::AssistantContent::Text(text) => texts.push(text),
365 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
366 message::AssistantContent::Reasoning(_) => {
367 unimplemented!("Reasoning is not supported on HuggingFace via Rig");
368 }
369 }
370 (texts, tools)
371 },
372 );
373
374 Ok(vec![Message::Assistant {
377 content: text_content
378 .into_iter()
379 .map(|content| AssistantContent::Text { text: content.text })
380 .collect::<Vec<_>>(),
381 tool_calls: tool_calls
382 .into_iter()
383 .map(|tool_call| tool_call.into())
384 .collect::<Vec<_>>(),
385 }])
386 }
387 }
388 }
389}
390
391impl TryFrom<Message> for message::Message {
392 type Error = message::MessageError;
393
394 fn try_from(message: Message) -> Result<Self, Self::Error> {
395 Ok(match message {
396 Message::User { content, .. } => message::Message::User {
397 content: content.map(|content| content.into()),
398 },
399 Message::Assistant {
400 content,
401 tool_calls,
402 ..
403 } => {
404 let mut content = content
405 .into_iter()
406 .map(|content| match content {
407 AssistantContent::Text { text } => message::AssistantContent::text(text),
408 })
409 .collect::<Vec<_>>();
410
411 content.extend(
412 tool_calls
413 .into_iter()
414 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
415 .collect::<Result<Vec<_>, _>>()?,
416 );
417
418 message::Message::Assistant {
419 id: None,
420 content: OneOrMany::many(content).map_err(|_| {
421 message::MessageError::ConversionError(
422 "Neither `content` nor `tool_calls` was provided to the Message"
423 .to_owned(),
424 )
425 })?,
426 }
427 }
428
429 Message::ToolResult { name, content, .. } => message::Message::User {
430 content: OneOrMany::one(message::UserContent::tool_result(
431 name,
432 content.map(message::ToolResultContent::text),
433 )),
434 },
435
436 Message::System { content, .. } => message::Message::User {
439 content: content.map(|c| match c {
440 SystemContent::Text { text } => message::UserContent::text(text),
441 }),
442 },
443 })
444 }
445}
446
/// One entry of a completion response's `choices` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Why generation stopped (e.g. `"tool_calls"` in the tests below).
    pub finish_reason: String,
    /// Position of this choice in the response.
    pub index: usize,
    /// Log-probability payload; `Value::Null` when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    /// The generated message.
    pub message: Message,
}
455
/// Token accounting reported by the provider.
// NOTE(review): counts are `i32` and are converted with `as u64`, so a negative
// value would wrap — presumably never sent by providers; confirm.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Tokens generated in the completion.
    pub completion_tokens: i32,
    /// Tokens consumed by the prompt.
    pub prompt_tokens: i32,
    /// Total tokens as reported by the provider.
    pub total_tokens: i32,
}
462
463impl GetTokenUsage for Usage {
464 fn token_usage(&self) -> Option<crate::completion::Usage> {
465 let mut usage = crate::completion::Usage::new();
466 usage.input_tokens = self.prompt_tokens as u64;
467 usage.output_tokens = self.completion_tokens as u64;
468 usage.total_tokens = self.total_tokens as u64;
469
470 Some(usage)
471 }
472}
473
/// A non-streaming chat completion response body.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Creation time — presumably Unix seconds (tests use e.g. `1740704000`).
    // NOTE(review): an `i32` seconds timestamp overflows in January 2038; consider
    // widening to `i64` in a breaking release.
    pub created: i32,
    /// Provider response id.
    pub id: String,
    /// Model that served the request, as reported by the provider.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<Choice>,
    /// Empty string when the field is missing or `null`.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    /// Token accounting.
    pub usage: Usage,
}
484
485impl crate::telemetry::ProviderResponseExt for CompletionResponse {
486 type OutputMessage = Choice;
487 type Usage = Usage;
488
489 fn get_response_id(&self) -> Option<String> {
490 Some(self.id.clone())
491 }
492
493 fn get_response_model_name(&self) -> Option<String> {
494 Some(self.model.clone())
495 }
496
497 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
498 self.choices.clone()
499 }
500
501 fn get_text_response(&self) -> Option<String> {
502 let text_response = self
503 .choices
504 .iter()
505 .filter_map(|x| {
506 let Message::User { ref content } = x.message else {
507 return None;
508 };
509
510 let text = content
511 .iter()
512 .filter_map(|x| {
513 if let UserContent::Text { text } = x {
514 Some(text.clone())
515 } else {
516 None
517 }
518 })
519 .collect::<Vec<String>>();
520
521 if text.is_empty() {
522 None
523 } else {
524 Some(text.join("\n"))
525 }
526 })
527 .collect::<Vec<String>>()
528 .join("\n");
529
530 if text_response.is_empty() {
531 None
532 } else {
533 Some(text_response)
534 }
535 }
536
537 fn get_usage(&self) -> Option<Self::Usage> {
538 Some(self.usage.clone())
539 }
540}
541
542fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
543where
544 D: Deserializer<'de>,
545{
546 match Option::<String>::deserialize(deserializer)? {
547 Some(value) => Ok(value), None => Ok(String::default()), }
550}
551
552impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
553 type Error = CompletionError;
554
555 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
556 let choice = response.choices.first().ok_or_else(|| {
557 CompletionError::ResponseError("Response contained no choices".to_owned())
558 })?;
559
560 let content = match &choice.message {
561 Message::Assistant {
562 content,
563 tool_calls,
564 ..
565 } => {
566 let mut content = content
567 .iter()
568 .map(|c| match c {
569 AssistantContent::Text { text } => message::AssistantContent::text(text),
570 })
571 .collect::<Vec<_>>();
572
573 content.extend(
574 tool_calls
575 .iter()
576 .map(|call| {
577 completion::AssistantContent::tool_call(
578 &call.id,
579 &call.function.name,
580 call.function.arguments.clone(),
581 )
582 })
583 .collect::<Vec<_>>(),
584 );
585 Ok(content)
586 }
587 _ => Err(CompletionError::ResponseError(
588 "Response did not contain a valid message or tool call".into(),
589 )),
590 }?;
591
592 let choice = OneOrMany::many(content).map_err(|_| {
593 CompletionError::ResponseError(
594 "Response contained no message or tool call (empty)".to_owned(),
595 )
596 })?;
597
598 let usage = completion::Usage {
599 input_tokens: response.usage.prompt_tokens as u64,
600 output_tokens: response.usage.completion_tokens as u64,
601 total_tokens: response.usage.total_tokens as u64,
602 };
603
604 Ok(completion::CompletionResponse {
605 choice,
606 usage,
607 raw_response: response,
608 })
609 }
610}
611
/// A HuggingFace chat completion model bound to a client.
///
/// `T` is the underlying HTTP client type (defaults to `reqwest::Client`).
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model name, mapped through the client's sub-provider when building requests.
    pub model: String,
}
618
619impl<T> CompletionModel<T> {
620 pub fn new(client: Client<T>, model: &str) -> Self {
621 Self {
622 client,
623 model: model.to_string(),
624 }
625 }
626
627 pub(crate) fn create_request_body(
628 &self,
629 completion_request: &CompletionRequest,
630 ) -> Result<serde_json::Value, CompletionError> {
631 let mut full_history: Vec<Message> = match &completion_request.preamble {
632 Some(preamble) => vec![Message::system(preamble)],
633 None => vec![],
634 };
635 if let Some(docs) = completion_request.normalized_documents() {
636 let docs: Vec<Message> = docs.try_into()?;
637 full_history.extend(docs);
638 }
639
640 let chat_history: Vec<Message> = completion_request
641 .chat_history
642 .clone()
643 .into_iter()
644 .map(|message| message.try_into())
645 .collect::<Result<Vec<Vec<Message>>, _>>()?
646 .into_iter()
647 .flatten()
648 .collect();
649
650 full_history.extend(chat_history);
651
652 let model = self.client.sub_provider.model_identifier(&self.model);
653
654 let tool_choice = completion_request
655 .tool_choice
656 .clone()
657 .map(crate::providers::openai::completion::ToolChoice::try_from)
658 .transpose()?;
659
660 let request = if completion_request.tools.is_empty() {
661 json!({
662 "model": model,
663 "messages": full_history,
664 "temperature": completion_request.temperature,
665 })
666 } else {
667 json!({
668 "model": model,
669 "messages": full_history,
670 "temperature": completion_request.temperature,
671 "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
672 "tool_choice": tool_choice,
673 })
674 };
675 Ok(request)
676 }
677}
678
679impl<T> completion::CompletionModel for CompletionModel<T>
680where
681 T: HttpClientExt + Clone + 'static,
682{
683 type Response = CompletionResponse;
684 type StreamingResponse = StreamingCompletionResponse;
685
686 #[cfg_attr(feature = "worker", worker::send)]
687 async fn completion(
688 &self,
689 completion_request: CompletionRequest,
690 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
691 let span = if tracing::Span::current().is_disabled() {
692 info_span!(
693 target: "rig::completions",
694 "chat",
695 gen_ai.operation.name = "chat",
696 gen_ai.provider.name = "huggingface",
697 gen_ai.request.model = self.model,
698 gen_ai.system_instructions = &completion_request.preamble,
699 gen_ai.response.id = tracing::field::Empty,
700 gen_ai.response.model = tracing::field::Empty,
701 gen_ai.usage.output_tokens = tracing::field::Empty,
702 gen_ai.usage.input_tokens = tracing::field::Empty,
703 gen_ai.input.messages = tracing::field::Empty,
704 gen_ai.output.messages = tracing::field::Empty,
705 )
706 } else {
707 tracing::Span::current()
708 };
709 let request = self.create_request_body(&completion_request)?;
710 span.record_model_input(&request.get("messages"));
711
712 let path = self.client.sub_provider.completion_endpoint(&self.model);
713
714 let request = if let Some(ref params) = completion_request.additional_params {
715 json_utils::merge(request, params.clone())
716 } else {
717 request
718 };
719
720 let request = serde_json::to_vec(&request)?;
721
722 let request = self
723 .client
724 .post(&path)?
725 .header("Content-Type", "application/json")
726 .body(request)
727 .map_err(|e| CompletionError::HttpError(e.into()))?;
728
729 let response = self.client.send(request).await?;
730
731 if response.status().is_success() {
732 let bytes: Vec<u8> = response.into_body().await?;
733 let text = String::from_utf8_lossy(&bytes);
734
735 tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
736
737 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
738 ApiResponse::Ok(response) => {
739 let span = tracing::Span::current();
740 span.record_token_usage(&response.usage);
741 span.record_model_output(&response.choices);
742 span.record_response_metadata(&response);
743
744 response.try_into()
745 }
746 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
747 }
748 } else {
749 let status = response.status();
750 let text: Vec<u8> = response.into_body().await?;
751 let text: String = String::from_utf8_lossy(&text).into();
752
753 Err(CompletionError::ProviderError(format!(
754 "{}: {}",
755 status, text
756 )))
757 }
758 }
759
760 #[cfg_attr(feature = "worker", worker::send)]
761 async fn stream(
762 &self,
763 request: CompletionRequest,
764 ) -> Result<
765 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
766 CompletionError,
767 > {
768 CompletionModel::stream(self, request).await
769 }
770}
771
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`. On failure, panics with the JSON path and the
    /// line/column of the first error.
    fn parse_json<T: serde::de::DeserializeOwned>(json: &str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse_json(assistant_message_json);
        let assistant_message2: Message = parse_json(assistant_message_json2);
        let assistant_message3: Message = parse_json(assistant_message_json3);
        let user_message: Message = parse_json(user_message_json);

        let Message::Assistant { content, .. } = assistant_message else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );

        let Message::Assistant {
            content,
            tool_calls,
        } = assistant_message2
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "\n\nHello there, how may I assist you today?".to_string()
            }
        );
        assert_eq!(tool_calls, vec![]);

        let Message::Assistant {
            content,
            tool_calls,
        } = assistant_message3
        else {
            panic!("Expected assistant message");
        };
        assert!(content.is_empty());
        assert_eq!(
            tool_calls[0],
            ToolCall {
                id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                r#type: ToolType::Function,
                function: Function {
                    name: "subtract".to_string(),
                    arguments: serde_json::json!({"x": 2, "y": 5}),
                },
            }
        );

        let Message::User { content } = user_message else {
            panic!("Expected user message");
        };
        let mut parts = content.into_iter();
        let first = parts.next().unwrap();
        let second = parts.next().unwrap();
        assert_eq!(
            first,
            UserContent::Text {
                text: "What's in this image?".to_string()
            }
        );
        assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        let Message::User { content } = converted_user_message[0].clone() else {
            panic!("Expected user message");
        };
        assert_eq!(
            content.first(),
            UserContent::Text {
                text: "Hello".to_string()
            }
        );

        let Message::Assistant { content, .. } = converted_assistant_message[0].clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content[0],
            AssistantContent::Text {
                text: "Hi there!".to_string()
            }
        );

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.first(), message::UserContent::text("Hello"));

        let message::Message::Assistant { content, .. } = converted_assistant_message.clone()
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(
            content.first(),
            message::AssistantContent::text("Hi there!")
        );

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let _firework_response: CompletionResponse = parse_json(fireworks_response_json);
        let _novita_response: CompletionResponse = parse_json(novita_response_json);
    }
}