1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7 OneOrMany,
8 completion::{self, CompletionError, CompletionRequest},
9 json_utils,
10 message::{self},
11 one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
/// Envelope for a provider reply that is either a payload of type `T` or an error body.
///
/// The enum is `untagged`, so serde tries `Ok(T)` first and falls back to capturing the
/// body as a raw JSON [`Value`] in `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The body deserialized successfully as `T`.
    Ok(T),
    /// The body did not match `T`; it is kept as raw JSON.
    Err(Value),
}
25
/// Model identifier `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model identifier `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model identifier `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model identifier `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model identifier `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// Model identifier `Qwen/Qwen2-VL-7B-Instruct` (presumably a vision-language model — confirm).
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model identifier `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
/// Function-call payload of a tool call: the function name plus its JSON arguments.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function being called.
    name: String,
    /// Call arguments.
    ///
    /// Written out through `json_utils::stringified_json::serialize`; on input both a
    /// JSON-encoded string and an inline JSON value are accepted (see
    /// `deserialize_arguments`).
    #[serde(
        serialize_with = "json_utils::stringified_json::serialize",
        deserialize_with = "deserialize_arguments"
    )]
    pub arguments: serde_json::Value,
}
58
59fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
60where
61 D: Deserializer<'de>,
62{
63 let value = Value::deserialize(deserializer)?;
64
65 match value {
66 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
67 other => Ok(other),
68 }
69}
70
71impl From<Function> for message::ToolFunction {
72 fn from(value: Function) -> Self {
73 message::ToolFunction {
74 name: value.name,
75 arguments: value.arguments,
76 }
77 }
78}
79
/// Kind of tool referenced by a tool call; serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A callable function; the only variant, and the default.
    #[default]
    Function,
}
86
/// Tool declaration sent with a request: a type label wrapping the generic definition.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind label; the `From<completion::ToolDefinition>` impl sets it to `"function"`.
    pub r#type: String,
    /// The generic tool definition being wrapped.
    pub function: completion::ToolDefinition,
}
92
93impl From<completion::ToolDefinition> for ToolDefinition {
94 fn from(tool: completion::ToolDefinition) -> Self {
95 Self {
96 r#type: "function".into(),
97 function: tool,
98 }
99 }
100}
101
/// A tool invocation carried on an assistant message.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Kind of tool being called.
    pub r#type: ToolType,
    /// The function name and arguments of the call.
    pub function: Function,
}
108
109impl From<ToolCall> for message::ToolCall {
110 fn from(value: ToolCall) -> Self {
111 message::ToolCall {
112 id: value.id,
113 call_id: None,
114 function: value.function.into(),
115 }
116 }
117}
118
119impl From<message::ToolCall> for ToolCall {
120 fn from(value: message::ToolCall) -> Self {
121 ToolCall {
122 id: value.id,
123 r#type: ToolType::Function,
124 function: Function {
125 name: value.function.name,
126 arguments: value.function.arguments,
127 },
128 }
129 }
130}
131
/// Wrapper object holding the URL of an image attached to a user message.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// The image URL string.
    url: String,
}
136
/// One content part of a user message, internally tagged by a `type` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text part (`"type": "text"`).
    Text {
        text: String,
    },
    /// Image part; explicitly renamed so the tag is `"image_url"` rather than the
    /// lowercase default.
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
148
149impl FromStr for UserContent {
150 type Err = Infallible;
151
152 fn from_str(s: &str) -> Result<Self, Self::Err> {
153 Ok(UserContent::Text {
154 text: s.to_string(),
155 })
156 }
157}
158
/// One content part of an assistant message, internally tagged by a `type` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// Plain text part (`"type": "text"`).
    Text { text: String },
}
164
165impl FromStr for AssistantContent {
166 type Err = Infallible;
167
168 fn from_str(s: &str) -> Result<Self, Self::Err> {
169 Ok(AssistantContent::Text {
170 text: s.to_string(),
171 })
172 }
173}
174
/// One content part of a system message, internally tagged by a `type` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    /// Plain text part (`"type": "text"`).
    Text { text: String },
}
180
181impl FromStr for SystemContent {
182 type Err = Infallible;
183
184 fn from_str(s: &str) -> Result<Self, Self::Err> {
185 Ok(SystemContent::Text {
186 text: s.to_string(),
187 })
188 }
189}
190
191impl From<UserContent> for message::UserContent {
192 fn from(value: UserContent) -> Self {
193 match value {
194 UserContent::Text { text } => message::UserContent::text(text),
195 UserContent::ImageUrl { image_url } => {
196 message::UserContent::image_url(image_url.url, None, None)
197 }
198 }
199 }
200}
201
202impl TryFrom<message::UserContent> for UserContent {
203 type Error = message::MessageError;
204
205 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
206 match content {
207 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
208 message::UserContent::Document(message::Document {
209 data: message::DocumentSourceKind::Raw(raw),
210 ..
211 }) => {
212 let text = String::from_utf8_lossy(raw.as_slice()).into();
213 Ok(UserContent::Text { text })
214 }
215 message::UserContent::Document(message::Document {
216 data:
217 message::DocumentSourceKind::Base64(text)
218 | message::DocumentSourceKind::String(text),
219 ..
220 }) => Ok(UserContent::Text { text }),
221 message::UserContent::Image(message::Image { data, .. }) => match data {
222 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
223 image_url: ImageUrl { url },
224 }),
225 _ => Err(message::MessageError::ConversionError(
226 "Huggingface only supports images as urls".into(),
227 )),
228 },
229 _ => Err(message::MessageError::ConversionError(
230 "Huggingface only supports text and images".into(),
231 )),
232 }
233 }
234}
235
/// A chat message in wire format, internally tagged by a lowercase `role` field.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System message (`"role": "system"`).
    System {
        /// Content parts, deserialized via `string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    /// User message (`"role": "user"`).
    User {
        /// Content parts, deserialized via `string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    /// Assistant message (`"role": "assistant"`).
    Assistant {
        /// Text parts; defaults to empty when the field is missing.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        /// Tool calls; defaults to empty when the field is missing.
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// Tool result message; written as `"role": "tool"`, and `"Tool"` is also accepted
    /// on input.
    #[serde(rename = "tool", alias = "Tool")]
    ToolResult {
        /// Name field of the tool result; the conversion from generic messages in this
        /// file fills it with the tool result id.
        name: String,
        /// Optional arguments, omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        /// Result text; written out as a single string by `serialize_tool_content`.
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_tool_content"
        )]
        content: OneOrMany<String>,
    },
}
265
266fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
267where
268 S: Serializer,
269{
270 let joined = content
272 .iter()
273 .map(String::as_str)
274 .collect::<Vec<_>>()
275 .join("\n");
276 serializer.serialize_str(&joined)
277}
278
279impl Message {
280 pub fn system(content: &str) -> Self {
281 Message::System {
282 content: OneOrMany::one(SystemContent::Text {
283 text: content.to_string(),
284 }),
285 }
286 }
287}
288
289impl TryFrom<message::Message> for Vec<Message> {
290 type Error = message::MessageError;
291
292 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
293 match message {
294 message::Message::User { content } => {
295 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
296 .into_iter()
297 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
298
299 if !tool_results.is_empty() {
300 tool_results
301 .into_iter()
302 .map(|content| match content {
303 message::UserContent::ToolResult(message::ToolResult {
304 id,
305 content,
306 ..
307 }) => Ok::<_, message::MessageError>(Message::ToolResult {
308 name: id,
309 arguments: None,
310 content: content.try_map(|content| match content {
311 message::ToolResultContent::Text(message::Text { text }) => {
312 Ok(text)
313 }
314 _ => Err(message::MessageError::ConversionError(
315 "Tool result content does not support non-text".into(),
316 )),
317 })?,
318 }),
319 _ => unreachable!(),
320 })
321 .collect::<Result<Vec<_>, _>>()
322 } else {
323 let other_content = OneOrMany::many(other_content).expect(
324 "There must be other content here if there were no tool result content",
325 );
326
327 Ok(vec![Message::User {
328 content: other_content.try_map(|content| match content {
329 message::UserContent::Text(text) => {
330 Ok(UserContent::Text { text: text.text })
331 }
332 message::UserContent::Image(image) => {
333 let url = image.try_into_url()?;
334
335 Ok(UserContent::ImageUrl {
336 image_url: ImageUrl { url },
337 })
338 }
339 message::UserContent::Document(message::Document {
340 data: message::DocumentSourceKind::Raw(raw), ..
341 }) => {
342 let text = String::from_utf8_lossy(raw.as_slice()).into();
343 Ok(UserContent::Text { text })
344 }
345 message::UserContent::Document(message::Document {
346 data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
347 }) => {
348 Ok(UserContent::Text { text })
349 }
350 _ => Err(message::MessageError::ConversionError(
351 "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
352 )),
353 })?,
354 }])
355 }
356 }
357 message::Message::Assistant { content, .. } => {
358 let (text_content, tool_calls) = content.into_iter().fold(
359 (Vec::new(), Vec::new()),
360 |(mut texts, mut tools), content| {
361 match content {
362 message::AssistantContent::Text(text) => texts.push(text),
363 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
364 message::AssistantContent::Reasoning(_) => {
365 unimplemented!("Reasoning is not supported on HuggingFace via Rig");
366 }
367 message::AssistantContent::Image(_) => {
368 unimplemented!(
369 "Image content is not supported on HuggingFace via Rig"
370 );
371 }
372 }
373 (texts, tools)
374 },
375 );
376
377 Ok(vec![Message::Assistant {
380 content: text_content
381 .into_iter()
382 .map(|content| AssistantContent::Text { text: content.text })
383 .collect::<Vec<_>>(),
384 tool_calls: tool_calls
385 .into_iter()
386 .map(|tool_call| tool_call.into())
387 .collect::<Vec<_>>(),
388 }])
389 }
390 }
391 }
392}
393
394impl TryFrom<Message> for message::Message {
395 type Error = message::MessageError;
396
397 fn try_from(message: Message) -> Result<Self, Self::Error> {
398 Ok(match message {
399 Message::User { content, .. } => message::Message::User {
400 content: content.map(|content| content.into()),
401 },
402 Message::Assistant {
403 content,
404 tool_calls,
405 ..
406 } => {
407 let mut content = content
408 .into_iter()
409 .map(|content| match content {
410 AssistantContent::Text { text } => message::AssistantContent::text(text),
411 })
412 .collect::<Vec<_>>();
413
414 content.extend(
415 tool_calls
416 .into_iter()
417 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
418 .collect::<Result<Vec<_>, _>>()?,
419 );
420
421 message::Message::Assistant {
422 id: None,
423 content: OneOrMany::many(content).map_err(|_| {
424 message::MessageError::ConversionError(
425 "Neither `content` nor `tool_calls` was provided to the Message"
426 .to_owned(),
427 )
428 })?,
429 }
430 }
431
432 Message::ToolResult { name, content, .. } => message::Message::User {
433 content: OneOrMany::one(message::UserContent::tool_result(
434 name,
435 content.map(message::ToolResultContent::text),
436 )),
437 },
438
439 Message::System { content, .. } => message::Message::User {
442 content: content.map(|c| match c {
443 SystemContent::Text { text } => message::UserContent::text(text),
444 }),
445 },
446 })
447 }
448}
449
/// One completion choice from a response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Finish reason string as sent by the provider.
    pub finish_reason: String,
    /// Index of this choice in the response.
    pub index: usize,
    /// Log-probability data kept as raw JSON; defaults to JSON null when missing.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    /// The message for this choice.
    pub message: Message,
}
458
/// Token accounting for a completion response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Completion token count.
    pub completion_tokens: i32,
    /// Prompt token count.
    pub prompt_tokens: i32,
    /// Total token count.
    pub total_tokens: i32,
}
465
466impl GetTokenUsage for Usage {
467 fn token_usage(&self) -> Option<crate::completion::Usage> {
468 let mut usage = crate::completion::Usage::new();
469 usage.input_tokens = self.prompt_tokens as u64;
470 usage.output_tokens = self.completion_tokens as u64;
471 usage.total_tokens = self.total_tokens as u64;
472
473 Some(usage)
474 }
475}
476
/// A chat completion response body.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Creation time as sent by the provider (presumably a Unix timestamp in seconds — confirm).
    pub created: i32,
    /// Response identifier.
    pub id: String,
    /// Model name reported in the response.
    pub model: String,
    /// The completion choices.
    pub choices: Vec<Choice>,
    /// System fingerprint; a missing or `null` field becomes an empty string.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    /// Token usage for the request.
    pub usage: Usage,
}
487
488impl crate::telemetry::ProviderResponseExt for CompletionResponse {
489 type OutputMessage = Choice;
490 type Usage = Usage;
491
492 fn get_response_id(&self) -> Option<String> {
493 Some(self.id.clone())
494 }
495
496 fn get_response_model_name(&self) -> Option<String> {
497 Some(self.model.clone())
498 }
499
500 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
501 self.choices.clone()
502 }
503
504 fn get_text_response(&self) -> Option<String> {
505 let text_response = self
506 .choices
507 .iter()
508 .filter_map(|x| {
509 let Message::User { ref content } = x.message else {
510 return None;
511 };
512
513 let text = content
514 .iter()
515 .filter_map(|x| {
516 if let UserContent::Text { text } = x {
517 Some(text.clone())
518 } else {
519 None
520 }
521 })
522 .collect::<Vec<String>>();
523
524 if text.is_empty() {
525 None
526 } else {
527 Some(text.join("\n"))
528 }
529 })
530 .collect::<Vec<String>>()
531 .join("\n");
532
533 if text_response.is_empty() {
534 None
535 } else {
536 Some(text_response)
537 }
538 }
539
540 fn get_usage(&self) -> Option<Self::Usage> {
541 Some(self.usage.clone())
542 }
543}
544
545fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
546where
547 D: Deserializer<'de>,
548{
549 match Option::<String>::deserialize(deserializer)? {
550 Some(value) => Ok(value), None => Ok(String::default()), }
553}
554
555impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
556 type Error = CompletionError;
557
558 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
559 let choice = response.choices.first().ok_or_else(|| {
560 CompletionError::ResponseError("Response contained no choices".to_owned())
561 })?;
562
563 let content = match &choice.message {
564 Message::Assistant {
565 content,
566 tool_calls,
567 ..
568 } => {
569 let mut content = content
570 .iter()
571 .map(|c| match c {
572 AssistantContent::Text { text } => message::AssistantContent::text(text),
573 })
574 .collect::<Vec<_>>();
575
576 content.extend(
577 tool_calls
578 .iter()
579 .map(|call| {
580 completion::AssistantContent::tool_call(
581 &call.id,
582 &call.function.name,
583 call.function.arguments.clone(),
584 )
585 })
586 .collect::<Vec<_>>(),
587 );
588 Ok(content)
589 }
590 _ => Err(CompletionError::ResponseError(
591 "Response did not contain a valid message or tool call".into(),
592 )),
593 }?;
594
595 let choice = OneOrMany::many(content).map_err(|_| {
596 CompletionError::ResponseError(
597 "Response contained no message or tool call (empty)".to_owned(),
598 )
599 })?;
600
601 let usage = completion::Usage {
602 input_tokens: response.usage.prompt_tokens as u64,
603 output_tokens: response.usage.completion_tokens as u64,
604 total_tokens: response.usage.total_tokens as u64,
605 };
606
607 Ok(completion::CompletionResponse {
608 choice,
609 usage,
610 raw_response: response,
611 })
612 }
613}
614
615#[derive(Debug, Serialize, Deserialize)]
616pub(super) struct HuggingfaceCompletionRequest {
617 model: String,
618 pub messages: Vec<Message>,
619 #[serde(flatten, skip_serializing_if = "Option::is_none")]
620 temperature: Option<f64>,
621 #[serde(skip_serializing_if = "Vec::is_empty")]
622 tools: Vec<ToolDefinition>,
623 #[serde(flatten, skip_serializing_if = "Option::is_none")]
624 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
625 #[serde(flatten, skip_serializing_if = "Option::is_none")]
626 pub additional_params: Option<serde_json::Value>,
627}
628
629impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
630 type Error = CompletionError;
631
632 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
633 let mut full_history: Vec<Message> = match &req.preamble {
634 Some(preamble) => vec![Message::system(preamble)],
635 None => vec![],
636 };
637 if let Some(docs) = req.normalized_documents() {
638 let docs: Vec<Message> = docs.try_into()?;
639 full_history.extend(docs);
640 }
641
642 let chat_history: Vec<Message> = req
643 .chat_history
644 .clone()
645 .into_iter()
646 .map(|message| message.try_into())
647 .collect::<Result<Vec<Vec<Message>>, _>>()?
648 .into_iter()
649 .flatten()
650 .collect();
651
652 full_history.extend(chat_history);
653
654 let tool_choice = req
655 .tool_choice
656 .clone()
657 .map(crate::providers::openai::completion::ToolChoice::try_from)
658 .transpose()?;
659
660 Ok(Self {
661 model: model.to_string(),
662 messages: full_history,
663 temperature: req.temperature,
664 tools: req
665 .tools
666 .clone()
667 .into_iter()
668 .map(ToolDefinition::from)
669 .collect::<Vec<_>>(),
670 tool_choice,
671 additional_params: req.additional_params,
672 })
673 }
674}
675
/// Completion model handle: a provider client paired with a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name this handle targets.
    pub model: String,
}
682
683impl<T> CompletionModel<T> {
684 pub fn new(client: Client<T>, model: &str) -> Self {
685 Self {
686 client,
687 model: model.to_string(),
688 }
689 }
690}
691
692impl<T> completion::CompletionModel for CompletionModel<T>
693where
694 T: HttpClientExt + Clone + 'static,
695{
696 type Response = CompletionResponse;
697 type StreamingResponse = StreamingCompletionResponse;
698
699 type Client = Client<T>;
700
701 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
702 Self::new(client.clone(), &model.into())
703 }
704
705 #[cfg_attr(feature = "worker", worker::send)]
706 async fn completion(
707 &self,
708 completion_request: CompletionRequest,
709 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
710 let span = if tracing::Span::current().is_disabled() {
711 info_span!(
712 target: "rig::completions",
713 "chat",
714 gen_ai.operation.name = "chat",
715 gen_ai.provider.name = "huggingface",
716 gen_ai.request.model = self.model,
717 gen_ai.system_instructions = &completion_request.preamble,
718 gen_ai.response.id = tracing::field::Empty,
719 gen_ai.response.model = tracing::field::Empty,
720 gen_ai.usage.output_tokens = tracing::field::Empty,
721 gen_ai.usage.input_tokens = tracing::field::Empty,
722 )
723 } else {
724 tracing::Span::current()
725 };
726
727 let model = self.client.subprovider().model_identifier(&self.model);
728 let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
729
730 if enabled!(Level::TRACE) {
731 tracing::trace!(
732 target: "rig::completions",
733 "Huggingface completion request: {}",
734 serde_json::to_string_pretty(&request)?
735 );
736 }
737
738 let request = serde_json::to_vec(&request)?;
739
740 let path = self.client.subprovider().completion_endpoint(&self.model);
741 let request = self
742 .client
743 .post(&path)?
744 .header("Content-Type", "application/json")
745 .body(request)
746 .map_err(|e| CompletionError::HttpError(e.into()))?;
747
748 async move {
749 let response = self.client.send(request).await?;
750
751 if response.status().is_success() {
752 let bytes: Vec<u8> = response.into_body().await?;
753 let text = String::from_utf8_lossy(&bytes);
754
755 tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
756
757 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
758 ApiResponse::Ok(response) => {
759 if enabled!(Level::TRACE) {
760 tracing::trace!(
761 target: "rig::completions",
762 "Huggingface completion response: {}",
763 serde_json::to_string_pretty(&response)?
764 );
765 }
766
767 let span = tracing::Span::current();
768 span.record_token_usage(&response.usage);
769 span.record_response_metadata(&response);
770
771 response.try_into()
772 }
773 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
774 }
775 } else {
776 let status = response.status();
777 let text: Vec<u8> = response.into_body().await?;
778 let text: String = String::from_utf8_lossy(&text).into();
779
780 Err(CompletionError::ProviderError(format!(
781 "{}: {}",
782 status, text
783 )))
784 }
785 }
786 .instrument(span)
787 .await
788 }
789
790 #[cfg_attr(feature = "worker", worker::send)]
791 async fn stream(
792 &self,
793 request: CompletionRequest,
794 ) -> Result<
795 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
796 CompletionError,
797 > {
798 CompletionModel::stream(self, request).await
799 }
800}
801
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`, panicking with the failing path and position on error.
    fn parse_or_panic<T>(json: &str) -> T
    where
        T: serde::de::DeserializeOwned,
    {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse_or_panic(assistant_message_json);
        let assistant_message2: Message = parse_or_panic(assistant_message_json2);
        let assistant_message3: Message = parse_or_panic(assistant_message_json3);
        let user_message: Message = parse_or_panic(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": 2, \"y\": 5}",
                                    "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let firework_response: CompletionResponse = parse_or_panic(fireworks_response_json);
        let novita_response: CompletionResponse = parse_or_panic(novita_response_json);

        // `system_fingerprint` is absent in one payload and `null` in the other.
        assert_eq!(firework_response.system_fingerprint, "");
        assert_eq!(novita_response.system_fingerprint, "");

        assert_eq!(firework_response.usage.total_tokens, 274);
        assert_eq!(novita_response.usage.total_tokens, 363);

        // Stringified tool-call arguments must come back as structured JSON.
        match &firework_response.choices[0].message {
            Message::Assistant {
                content,
                tool_calls,
            } => {
                assert!(content.is_empty());
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 5})
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match &novita_response.choices[0].message {
            Message::Assistant {
                content,
                tool_calls,
            } => {
                assert!(content.is_empty());
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": "2", "y": "5"})
                );
            }
            _ => panic!("Expected assistant message"),
        }
    }
}