1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7 OneOrMany,
8 completion::{self, CompletionError, CompletionRequest},
9 json_utils,
10 message::{self},
11 one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::info_span;
17
/// Envelope for a provider reply that is either a well-formed payload or something else.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`, which accepts any JSON value.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The body deserialized successfully as `T`.
    Ok(T),
    /// The body did not match `T`; the raw JSON value is kept as-is.
    Err(Value),
}
24
/// Model identifier string `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model identifier string `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model identifier string `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model identifier string `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model identifier string `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

// NOTE(review): the two identifiers below look like vision-capable models judging by
// their names — confirm before relying on that.
/// Model identifier string `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model identifier string `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
47
/// The function part of a tool call: which function is invoked and with what arguments.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function being called.
    name: String,
    /// Call arguments as a JSON value.
    ///
    /// Written out through `json_utils::stringified_json::serialize` and read back with
    /// `deserialize_arguments`, which accepts either a JSON-encoded string or an inline
    /// JSON value.
    #[serde(
        serialize_with = "json_utils::stringified_json::serialize",
        deserialize_with = "deserialize_arguments"
    )]
    pub arguments: serde_json::Value,
}
57
58fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
59where
60 D: Deserializer<'de>,
61{
62 let value = Value::deserialize(deserializer)?;
63
64 match value {
65 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
66 other => Ok(other),
67 }
68}
69
70impl From<Function> for message::ToolFunction {
71 fn from(value: Function) -> Self {
72 message::ToolFunction {
73 name: value.name,
74 arguments: value.arguments,
75 }
76 }
77}
78
/// Kind of tool referenced by a tool call.
///
/// Variants are (de)serialized in lowercase, so `Function` appears on the wire as
/// `"function"`. It is also the `Default` value.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A callable function tool.
    #[default]
    Function,
}
85
/// A tool definition as sent in the request's `tools` list: a type tag plus the
/// generic `completion::ToolDefinition` it wraps.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind tag; the `From<completion::ToolDefinition>` impl in this file always
    /// sets it to `"function"`.
    pub r#type: String,
    /// The wrapped tool description.
    pub function: completion::ToolDefinition,
}
91
92impl From<completion::ToolDefinition> for ToolDefinition {
93 fn from(tool: completion::ToolDefinition) -> Self {
94 Self {
95 r#type: "function".into(),
96 function: tool,
97 }
98 }
99}
100
/// A tool call carried by an assistant message.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Kind of tool being called (only `ToolType::Function` exists).
    pub r#type: ToolType,
    /// The function name and arguments of the call.
    pub function: Function,
}
107
108impl From<ToolCall> for message::ToolCall {
109 fn from(value: ToolCall) -> Self {
110 message::ToolCall {
111 id: value.id,
112 call_id: None,
113 function: value.function.into(),
114 }
115 }
116}
117
118impl From<message::ToolCall> for ToolCall {
119 fn from(value: message::ToolCall) -> Self {
120 ToolCall {
121 id: value.id,
122 r#type: ToolType::Function,
123 function: Function {
124 name: value.function.name,
125 arguments: value.function.arguments,
126 },
127 }
128 }
129}
130
/// Payload of an `image_url` content part.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Location of the image as a URL string.
    url: String,
}
135
/// One content part of a user message.
///
/// Internally tagged by a `type` field: `"text"` for [`UserContent::Text`] and
/// `"image_url"` for [`UserContent::ImageUrl`].
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image referenced by URL. Renamed explicitly because `rename_all = "lowercase"`
    /// would otherwise produce `"imageurl"`.
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
147
148impl FromStr for UserContent {
149 type Err = Infallible;
150
151 fn from_str(s: &str) -> Result<Self, Self::Err> {
152 Ok(UserContent::Text {
153 text: s.to_string(),
154 })
155 }
156}
157
/// One content part of an assistant message.
///
/// Internally tagged by a `type` field; the only variant is serialized as `"text"`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Plain text.
    Text { text: String },
}
163
164impl FromStr for AssistantContent {
165 type Err = Infallible;
166
167 fn from_str(s: &str) -> Result<Self, Self::Err> {
168 Ok(AssistantContent::Text {
169 text: s.to_string(),
170 })
171 }
172}
173
/// One content part of a system message.
///
/// Internally tagged by a `type` field; the only variant is serialized as `"text"`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    /// Plain text.
    Text { text: String },
}
179
180impl FromStr for SystemContent {
181 type Err = Infallible;
182
183 fn from_str(s: &str) -> Result<Self, Self::Err> {
184 Ok(SystemContent::Text {
185 text: s.to_string(),
186 })
187 }
188}
189
190impl From<UserContent> for message::UserContent {
191 fn from(value: UserContent) -> Self {
192 match value {
193 UserContent::Text { text } => message::UserContent::text(text),
194 UserContent::ImageUrl { image_url } => {
195 message::UserContent::image_url(image_url.url, None, None)
196 }
197 }
198 }
199}
200
201impl TryFrom<message::UserContent> for UserContent {
202 type Error = message::MessageError;
203
204 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
205 match content {
206 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
207 message::UserContent::Document(message::Document {
208 data: message::DocumentSourceKind::Raw(raw),
209 ..
210 }) => {
211 let text = String::from_utf8_lossy(raw.as_slice()).into();
212 Ok(UserContent::Text { text })
213 }
214 message::UserContent::Document(message::Document {
215 data:
216 message::DocumentSourceKind::Base64(text)
217 | message::DocumentSourceKind::String(text),
218 ..
219 }) => Ok(UserContent::Text { text }),
220 message::UserContent::Image(message::Image { data, .. }) => match data {
221 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
222 image_url: ImageUrl { url },
223 }),
224 _ => Err(message::MessageError::ConversionError(
225 "Huggingface only supports images as urls".into(),
226 )),
227 },
228 _ => Err(message::MessageError::ConversionError(
229 "Huggingface only supports text and images".into(),
230 )),
231 }
232 }
233}
234
/// A chat message in the provider wire format.
///
/// Internally tagged by a `role` field whose value is the lowercase variant name,
/// except for [`Message::ToolResult`], which uses `"tool"`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System instructions.
    System {
        /// Deserialized via `string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    /// A user turn.
    User {
        /// Deserialized via `string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    /// An assistant turn; either list may be empty.
    Assistant {
        /// Text parts. Defaults to empty when the field is missing; deserialized via
        /// `json_utils::string_or_vec`.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        /// Tool calls. Defaults to empty when the field is missing; deserialized via
        /// `json_utils::null_or_vec`.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// The result of a tool invocation; role `"tool"` (`"Tool"` is also accepted when
    /// deserializing).
    #[serde(rename = "tool", alias = "Tool")]
    ToolResult {
        name: String,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        /// Serialized as a single newline-joined string (see `serialize_tool_content`).
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_tool_content"
        )]
        content: OneOrMany<String>,
    },
}
264
265fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
266where
267 S: Serializer,
268{
269 let joined = content
271 .iter()
272 .map(String::as_str)
273 .collect::<Vec<_>>()
274 .join("\n");
275 serializer.serialize_str(&joined)
276}
277
278impl Message {
279 pub fn system(content: &str) -> Self {
280 Message::System {
281 content: OneOrMany::one(SystemContent::Text {
282 text: content.to_string(),
283 }),
284 }
285 }
286}
287
288impl TryFrom<message::Message> for Vec<Message> {
289 type Error = message::MessageError;
290
291 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
292 match message {
293 message::Message::User { content } => {
294 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
295 .into_iter()
296 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
297
298 if !tool_results.is_empty() {
299 tool_results
300 .into_iter()
301 .map(|content| match content {
302 message::UserContent::ToolResult(message::ToolResult {
303 id,
304 content,
305 ..
306 }) => Ok::<_, message::MessageError>(Message::ToolResult {
307 name: id,
308 arguments: None,
309 content: content.try_map(|content| match content {
310 message::ToolResultContent::Text(message::Text { text }) => {
311 Ok(text)
312 }
313 _ => Err(message::MessageError::ConversionError(
314 "Tool result content does not support non-text".into(),
315 )),
316 })?,
317 }),
318 _ => unreachable!(),
319 })
320 .collect::<Result<Vec<_>, _>>()
321 } else {
322 let other_content = OneOrMany::many(other_content).expect(
323 "There must be other content here if there were no tool result content",
324 );
325
326 Ok(vec![Message::User {
327 content: other_content.try_map(|content| match content {
328 message::UserContent::Text(text) => {
329 Ok(UserContent::Text { text: text.text })
330 }
331 message::UserContent::Image(image) => {
332 let url = image.try_into_url()?;
333
334 Ok(UserContent::ImageUrl {
335 image_url: ImageUrl { url },
336 })
337 }
338 message::UserContent::Document(message::Document {
339 data: message::DocumentSourceKind::Raw(raw), ..
340 }) => {
341 let text = String::from_utf8_lossy(raw.as_slice()).into();
342 Ok(UserContent::Text { text })
343 }
344 message::UserContent::Document(message::Document {
345 data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
346 }) => {
347 Ok(UserContent::Text { text })
348 }
349 _ => Err(message::MessageError::ConversionError(
350 "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
351 )),
352 })?,
353 }])
354 }
355 }
356 message::Message::Assistant { content, .. } => {
357 let (text_content, tool_calls) = content.into_iter().fold(
358 (Vec::new(), Vec::new()),
359 |(mut texts, mut tools), content| {
360 match content {
361 message::AssistantContent::Text(text) => texts.push(text),
362 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
363 message::AssistantContent::Reasoning(_) => {
364 unimplemented!("Reasoning is not supported on HuggingFace via Rig");
365 }
366 message::AssistantContent::Image(_) => {
367 unimplemented!(
368 "Image content is not supported on HuggingFace via Rig"
369 );
370 }
371 }
372 (texts, tools)
373 },
374 );
375
376 Ok(vec![Message::Assistant {
379 content: text_content
380 .into_iter()
381 .map(|content| AssistantContent::Text { text: content.text })
382 .collect::<Vec<_>>(),
383 tool_calls: tool_calls
384 .into_iter()
385 .map(|tool_call| tool_call.into())
386 .collect::<Vec<_>>(),
387 }])
388 }
389 }
390 }
391}
392
393impl TryFrom<Message> for message::Message {
394 type Error = message::MessageError;
395
396 fn try_from(message: Message) -> Result<Self, Self::Error> {
397 Ok(match message {
398 Message::User { content, .. } => message::Message::User {
399 content: content.map(|content| content.into()),
400 },
401 Message::Assistant {
402 content,
403 tool_calls,
404 ..
405 } => {
406 let mut content = content
407 .into_iter()
408 .map(|content| match content {
409 AssistantContent::Text { text } => message::AssistantContent::text(text),
410 })
411 .collect::<Vec<_>>();
412
413 content.extend(
414 tool_calls
415 .into_iter()
416 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
417 .collect::<Result<Vec<_>, _>>()?,
418 );
419
420 message::Message::Assistant {
421 id: None,
422 content: OneOrMany::many(content).map_err(|_| {
423 message::MessageError::ConversionError(
424 "Neither `content` nor `tool_calls` was provided to the Message"
425 .to_owned(),
426 )
427 })?,
428 }
429 }
430
431 Message::ToolResult { name, content, .. } => message::Message::User {
432 content: OneOrMany::one(message::UserContent::tool_result(
433 name,
434 content.map(message::ToolResultContent::text),
435 )),
436 },
437
438 Message::System { content, .. } => message::Message::User {
441 content: content.map(|c| match c {
442 SystemContent::Text { text } => message::UserContent::text(text),
443 }),
444 },
445 })
446 }
447}
448
/// One entry of a completion response's `choices` list.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Reason reported for the end of generation, kept as the raw string.
    pub finish_reason: String,
    /// Index of this choice as reported in the response.
    pub index: usize,
    /// Log-probability data kept as raw JSON; defaults to `null` when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    /// The message of this choice.
    pub message: Message,
}
457
/// Token accounting reported with a completion response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Value of the `completion_tokens` field.
    pub completion_tokens: i32,
    /// Value of the `prompt_tokens` field.
    pub prompt_tokens: i32,
    /// Value of the `total_tokens` field.
    pub total_tokens: i32,
}
464
465impl GetTokenUsage for Usage {
466 fn token_usage(&self) -> Option<crate::completion::Usage> {
467 let mut usage = crate::completion::Usage::new();
468 usage.input_tokens = self.prompt_tokens as u64;
469 usage.output_tokens = self.completion_tokens as u64;
470 usage.total_tokens = self.total_tokens as u64;
471
472 Some(usage)
473 }
474}
475
/// Body of a successful chat completion response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Value of the `created` field.
    // NOTE(review): presumably a Unix timestamp in seconds; `i32` would overflow for
    // values past 2038 — confirm against the API.
    pub created: i32,
    /// Response identifier.
    pub id: String,
    /// Model name reported in the response.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<Choice>,
    /// Defaults to an empty string when the field is missing or `null`.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    /// Token accounting for the request.
    pub usage: Usage,
}
486
487impl crate::telemetry::ProviderResponseExt for CompletionResponse {
488 type OutputMessage = Choice;
489 type Usage = Usage;
490
491 fn get_response_id(&self) -> Option<String> {
492 Some(self.id.clone())
493 }
494
495 fn get_response_model_name(&self) -> Option<String> {
496 Some(self.model.clone())
497 }
498
499 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
500 self.choices.clone()
501 }
502
503 fn get_text_response(&self) -> Option<String> {
504 let text_response = self
505 .choices
506 .iter()
507 .filter_map(|x| {
508 let Message::User { ref content } = x.message else {
509 return None;
510 };
511
512 let text = content
513 .iter()
514 .filter_map(|x| {
515 if let UserContent::Text { text } = x {
516 Some(text.clone())
517 } else {
518 None
519 }
520 })
521 .collect::<Vec<String>>();
522
523 if text.is_empty() {
524 None
525 } else {
526 Some(text.join("\n"))
527 }
528 })
529 .collect::<Vec<String>>()
530 .join("\n");
531
532 if text_response.is_empty() {
533 None
534 } else {
535 Some(text_response)
536 }
537 }
538
539 fn get_usage(&self) -> Option<Self::Usage> {
540 Some(self.usage.clone())
541 }
542}
543
544fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
545where
546 D: Deserializer<'de>,
547{
548 match Option::<String>::deserialize(deserializer)? {
549 Some(value) => Ok(value), None => Ok(String::default()), }
552}
553
554impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
555 type Error = CompletionError;
556
557 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
558 let choice = response.choices.first().ok_or_else(|| {
559 CompletionError::ResponseError("Response contained no choices".to_owned())
560 })?;
561
562 let content = match &choice.message {
563 Message::Assistant {
564 content,
565 tool_calls,
566 ..
567 } => {
568 let mut content = content
569 .iter()
570 .map(|c| match c {
571 AssistantContent::Text { text } => message::AssistantContent::text(text),
572 })
573 .collect::<Vec<_>>();
574
575 content.extend(
576 tool_calls
577 .iter()
578 .map(|call| {
579 completion::AssistantContent::tool_call(
580 &call.id,
581 &call.function.name,
582 call.function.arguments.clone(),
583 )
584 })
585 .collect::<Vec<_>>(),
586 );
587 Ok(content)
588 }
589 _ => Err(CompletionError::ResponseError(
590 "Response did not contain a valid message or tool call".into(),
591 )),
592 }?;
593
594 let choice = OneOrMany::many(content).map_err(|_| {
595 CompletionError::ResponseError(
596 "Response contained no message or tool call (empty)".to_owned(),
597 )
598 })?;
599
600 let usage = completion::Usage {
601 input_tokens: response.usage.prompt_tokens as u64,
602 output_tokens: response.usage.completion_tokens as u64,
603 total_tokens: response.usage.total_tokens as u64,
604 };
605
606 Ok(completion::CompletionResponse {
607 choice,
608 usage,
609 raw_response: response,
610 })
611 }
612}
613
614#[derive(Debug, Serialize, Deserialize)]
615pub(super) struct HuggingfaceCompletionRequest {
616 model: String,
617 pub messages: Vec<Message>,
618 #[serde(flatten, skip_serializing_if = "Option::is_none")]
619 temperature: Option<f64>,
620 #[serde(skip_serializing_if = "Vec::is_empty")]
621 tools: Vec<ToolDefinition>,
622 #[serde(flatten, skip_serializing_if = "Option::is_none")]
623 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
624 #[serde(flatten, skip_serializing_if = "Option::is_none")]
625 pub additional_params: Option<serde_json::Value>,
626}
627
628impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
629 type Error = CompletionError;
630
631 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
632 let mut full_history: Vec<Message> = match &req.preamble {
633 Some(preamble) => vec![Message::system(preamble)],
634 None => vec![],
635 };
636 if let Some(docs) = req.normalized_documents() {
637 let docs: Vec<Message> = docs.try_into()?;
638 full_history.extend(docs);
639 }
640
641 let chat_history: Vec<Message> = req
642 .chat_history
643 .clone()
644 .into_iter()
645 .map(|message| message.try_into())
646 .collect::<Result<Vec<Vec<Message>>, _>>()?
647 .into_iter()
648 .flatten()
649 .collect();
650
651 full_history.extend(chat_history);
652
653 let tool_choice = req
654 .tool_choice
655 .clone()
656 .map(crate::providers::openai::completion::ToolChoice::try_from)
657 .transpose()?;
658
659 Ok(Self {
660 model: model.to_string(),
661 messages: full_history,
662 temperature: req.temperature,
663 tools: req
664 .tools
665 .clone()
666 .into_iter()
667 .map(ToolDefinition::from)
668 .collect::<Vec<_>>(),
669 tool_choice,
670 additional_params: req.additional_params,
671 })
672 }
673}
674
/// A completion model bound to a client.
///
/// `T` is the HTTP client type used by the wrapped `Client` and defaults to
/// `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to send requests.
    pub(crate) client: Client<T>,
    /// Model name given at construction time.
    pub model: String,
}
681
682impl<T> CompletionModel<T> {
683 pub fn new(client: Client<T>, model: &str) -> Self {
684 Self {
685 client,
686 model: model.to_string(),
687 }
688 }
689}
690
691impl<T> completion::CompletionModel for CompletionModel<T>
692where
693 T: HttpClientExt + Clone + 'static,
694{
695 type Response = CompletionResponse;
696 type StreamingResponse = StreamingCompletionResponse;
697
698 type Client = Client<T>;
699
700 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
701 Self::new(client.clone(), &model.into())
702 }
703
704 #[cfg_attr(feature = "worker", worker::send)]
705 async fn completion(
706 &self,
707 completion_request: CompletionRequest,
708 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
709 let span = if tracing::Span::current().is_disabled() {
710 info_span!(
711 target: "rig::completions",
712 "chat",
713 gen_ai.operation.name = "chat",
714 gen_ai.provider.name = "huggingface",
715 gen_ai.request.model = self.model,
716 gen_ai.system_instructions = &completion_request.preamble,
717 gen_ai.response.id = tracing::field::Empty,
718 gen_ai.response.model = tracing::field::Empty,
719 gen_ai.usage.output_tokens = tracing::field::Empty,
720 gen_ai.usage.input_tokens = tracing::field::Empty,
721 gen_ai.input.messages = tracing::field::Empty,
722 gen_ai.output.messages = tracing::field::Empty,
723 )
724 } else {
725 tracing::Span::current()
726 };
727
728 let model = self.client.subprovider().model_identifier(&self.model);
729 let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
730
731 span.record_model_input(&request.messages);
732
733 let request = serde_json::to_vec(&request)?;
734
735 let path = self.client.subprovider().completion_endpoint(&self.model);
736 let request = self
737 .client
738 .post(&path)?
739 .header("Content-Type", "application/json")
740 .body(request)
741 .map_err(|e| CompletionError::HttpError(e.into()))?;
742
743 let response = self.client.send(request).await?;
744
745 if response.status().is_success() {
746 let bytes: Vec<u8> = response.into_body().await?;
747 let text = String::from_utf8_lossy(&bytes);
748
749 tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
750
751 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
752 ApiResponse::Ok(response) => {
753 let span = tracing::Span::current();
754 span.record_token_usage(&response.usage);
755 span.record_model_output(&response.choices);
756 span.record_response_metadata(&response);
757
758 response.try_into()
759 }
760 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
761 }
762 } else {
763 let status = response.status();
764 let text: Vec<u8> = response.into_body().await?;
765 let text: String = String::from_utf8_lossy(&text).into();
766
767 Err(CompletionError::ProviderError(format!(
768 "{}: {}",
769 status, text
770 )))
771 }
772 }
773
774 #[cfg_attr(feature = "worker", worker::send)]
775 async fn stream(
776 &self,
777 request: CompletionRequest,
778 ) -> Result<
779 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
780 CompletionError,
781 > {
782 CompletionModel::stream(self, request).await
783 }
784}
785
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Parses `json` into `T`, panicking with the failing path, the position in the
    /// input and the underlying error when deserialization fails.
    fn parse_or_panic<'a, T>(json: &'a str) -> T
    where
        T: serde::Deserialize<'a>,
    {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    /// Assistant and user messages deserialize from the accepted content shapes:
    /// a bare string, a list of parts, and `null` content alongside tool calls.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
 {
 "role": "assistant",
 "content": "\n\nHello there, how may I assist you today?"
 }
 "#;

        let assistant_message_json2 = r#"
 {
 "role": "assistant",
 "content": [
 {
 "type": "text",
 "text": "\n\nHello there, how may I assist you today?"
 }
 ],
 "tool_calls": null
 }
 "#;

        let assistant_message_json3 = r#"
 {
 "role": "assistant",
 "tool_calls": [
 {
 "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
 "type": "function",
 "function": {
 "name": "subtract",
 "arguments": {"x": 2, "y": 5}
 }
 }
 ],
 "content": null,
 "refusal": null
 }
 "#;

        let user_message_json = r#"
 {
 "role": "user",
 "content": [
 {
 "type": "text",
 "text": "What's in this image?"
 },
 {
 "type": "image_url",
 "image_url": {
 "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
 }
 }
 ]
 }
 "#;

        let assistant_message: Message = parse_or_panic(assistant_message_json);
        let assistant_message2: Message = parse_or_panic(assistant_message_json2);
        let assistant_message3: Message = parse_or_panic(assistant_message_json3);
        let user_message: Message = parse_or_panic(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    /// Generic messages convert to provider messages and back without loss.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    /// Provider messages convert to generic messages and back without loss.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    /// Sample response bodies from two sub-providers deserialize, including one with
    /// no `system_fingerprint` and one where it is `null`.
    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
 {
 "choices": [
 {
 "finish_reason": "tool_calls",
 "index": 0,
 "message": {
 "role": "assistant",
 "tool_calls": [
 {
 "function": {
 "arguments": "{\"x\": 2, \"y\": 5}",
 "name": "subtract"
 },
 "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
 "index": 0,
 "type": "function"
 }
 ]
 }
 }
 ],
 "created": 1740704000,
 "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
 "model": "accounts/fireworks/models/deepseek-v3",
 "object": "chat.completion",
 "usage": {
 "completion_tokens": 26,
 "prompt_tokens": 248,
 "total_tokens": 274
 }
 }
 "#;

        let novita_response_json = r#"
 {
 "choices": [
 {
 "finish_reason": "tool_calls",
 "index": 0,
 "logprobs": null,
 "message": {
 "audio": null,
 "content": null,
 "function_call": null,
 "reasoning_content": null,
 "refusal": null,
 "role": "assistant",
 "tool_calls": [
 {
 "function": {
 "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
 "name": "subtract"
 },
 "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
 "type": "function"
 }
 ]
 },
 "stop_reason": 128008
 }
 ],
 "created": 1740704592,
 "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
 "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
 "object": "chat.completion",
 "prompt_logprobs": null,
 "service_tier": null,
 "system_fingerprint": null,
 "usage": {
 "completion_tokens": 28,
 "completion_tokens_details": null,
 "prompt_tokens": 335,
 "prompt_tokens_details": null,
 "total_tokens": 363
 }
 }
 "#;

        let _firework_response: CompletionResponse = parse_or_panic(fireworks_response_json);
        let _novita_response: CompletionResponse = parse_or_panic(novita_response_json);
    }
}