1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7 OneOrMany,
8 completion::{self, CompletionError, CompletionRequest},
9 json_utils,
10 message::{self},
11 one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22 Ok(T),
23 Err(Value),
24}
25
/// Model identifier `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model identifier `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model identifier `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model identifier `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model identifier `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// Model identifier `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model identifier `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51 name: String,
52 #[serde(
53 serialize_with = "json_utils::stringified_json::serialize",
54 deserialize_with = "json_utils::stringified_json::deserialize_maybe_stringified"
55 )]
56 pub arguments: serde_json::Value,
57}
58
59impl From<Function> for message::ToolFunction {
60 fn from(value: Function) -> Self {
61 message::ToolFunction {
62 name: value.name,
63 arguments: value.arguments,
64 }
65 }
66}
67
68#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
69#[serde(rename_all = "lowercase")]
70pub enum ToolType {
71 #[default]
72 Function,
73}
74
75#[derive(Debug, Deserialize, Serialize, Clone)]
76pub struct ToolDefinition {
77 pub r#type: String,
78 pub function: completion::ToolDefinition,
79}
80
81impl From<completion::ToolDefinition> for ToolDefinition {
82 fn from(tool: completion::ToolDefinition) -> Self {
83 Self {
84 r#type: "function".into(),
85 function: tool,
86 }
87 }
88}
89
90#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
91pub struct ToolCall {
92 pub id: String,
93 pub r#type: ToolType,
94 pub function: Function,
95}
96
97impl From<ToolCall> for message::ToolCall {
98 fn from(value: ToolCall) -> Self {
99 message::ToolCall {
100 id: value.id,
101 call_id: None,
102 function: value.function.into(),
103 signature: None,
104 additional_params: None,
105 }
106 }
107}
108
109impl From<message::ToolCall> for ToolCall {
110 fn from(value: message::ToolCall) -> Self {
111 ToolCall {
112 id: value.id,
113 r#type: ToolType::Function,
114 function: Function {
115 name: value.function.name,
116 arguments: value.function.arguments,
117 },
118 }
119 }
120}
121
122#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
123pub struct ImageUrl {
124 url: String,
125}
126
127#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
128#[serde(tag = "type", rename_all = "lowercase")]
129pub enum UserContent {
130 Text {
131 text: String,
132 },
133 #[serde(rename = "image_url")]
134 ImageUrl {
135 image_url: ImageUrl,
136 },
137}
138
139impl FromStr for UserContent {
140 type Err = Infallible;
141
142 fn from_str(s: &str) -> Result<Self, Self::Err> {
143 Ok(UserContent::Text {
144 text: s.to_string(),
145 })
146 }
147}
148
149#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
150#[serde(tag = "type", rename_all = "lowercase")]
151pub enum AssistantContent {
152 Text { text: String },
153}
154
155impl FromStr for AssistantContent {
156 type Err = Infallible;
157
158 fn from_str(s: &str) -> Result<Self, Self::Err> {
159 Ok(AssistantContent::Text {
160 text: s.to_string(),
161 })
162 }
163}
164
165#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
166#[serde(tag = "type", rename_all = "lowercase")]
167pub enum SystemContent {
168 Text { text: String },
169}
170
171impl FromStr for SystemContent {
172 type Err = Infallible;
173
174 fn from_str(s: &str) -> Result<Self, Self::Err> {
175 Ok(SystemContent::Text {
176 text: s.to_string(),
177 })
178 }
179}
180
181impl From<UserContent> for message::UserContent {
182 fn from(value: UserContent) -> Self {
183 match value {
184 UserContent::Text { text, .. } => message::UserContent::text(text),
185 UserContent::ImageUrl { image_url } => {
186 message::UserContent::image_url(image_url.url, None, None)
187 }
188 }
189 }
190}
191
192impl TryFrom<message::UserContent> for UserContent {
193 type Error = message::MessageError;
194
195 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
196 match content {
197 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
198 message::UserContent::Document(message::Document {
199 data: message::DocumentSourceKind::Raw(raw),
200 ..
201 }) => {
202 let text = String::from_utf8_lossy(raw.as_slice()).into();
203 Ok(UserContent::Text { text })
204 }
205 message::UserContent::Document(message::Document {
206 data:
207 message::DocumentSourceKind::Base64(text)
208 | message::DocumentSourceKind::String(text),
209 ..
210 }) => Ok(UserContent::Text { text }),
211 message::UserContent::Image(message::Image { data, .. }) => match data {
212 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
213 image_url: ImageUrl { url },
214 }),
215 _ => Err(message::MessageError::ConversionError(
216 "Huggingface only supports images as urls".into(),
217 )),
218 },
219 _ => Err(message::MessageError::ConversionError(
220 "Huggingface only supports text and images".into(),
221 )),
222 }
223 }
224}
225
226#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
227#[serde(tag = "role", rename_all = "lowercase")]
228pub enum Message {
229 System {
230 #[serde(deserialize_with = "string_or_one_or_many")]
231 content: OneOrMany<SystemContent>,
232 },
233 User {
234 #[serde(deserialize_with = "string_or_one_or_many")]
235 content: OneOrMany<UserContent>,
236 },
237 Assistant {
238 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
239 content: Vec<AssistantContent>,
240 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
241 tool_calls: Vec<ToolCall>,
242 },
243 #[serde(rename = "tool", alias = "Tool")]
244 ToolResult {
245 name: String,
246 #[serde(skip_serializing_if = "Option::is_none")]
247 arguments: Option<serde_json::Value>,
248 #[serde(
249 deserialize_with = "string_or_one_or_many",
250 serialize_with = "serialize_tool_content"
251 )]
252 content: OneOrMany<String>,
253 },
254}
255
256fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
257where
258 S: Serializer,
259{
260 let joined = content
262 .iter()
263 .map(String::as_str)
264 .collect::<Vec<_>>()
265 .join("\n");
266 serializer.serialize_str(&joined)
267}
268
269impl Message {
270 pub fn system(content: &str) -> Self {
271 Message::System {
272 content: OneOrMany::one(SystemContent::Text {
273 text: content.to_string(),
274 }),
275 }
276 }
277}
278
279impl TryFrom<message::Message> for Vec<Message> {
280 type Error = message::MessageError;
281
282 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
283 match message {
284 message::Message::System { content } => Ok(vec![Message::system(&content)]),
285 message::Message::User { content } => {
286 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
287 .into_iter()
288 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
289
290 if !tool_results.is_empty() {
291 tool_results
292 .into_iter()
293 .map(|content| match content {
294 message::UserContent::ToolResult(message::ToolResult {
295 id,
296 content,
297 ..
298 }) => Ok::<_, message::MessageError>(Message::ToolResult {
299 name: id,
300 arguments: None,
301 content: content.try_map(|content| match content {
302 message::ToolResultContent::Text(message::Text {
303 text,
304 ..
305 }) => Ok(text),
306 _ => Err(message::MessageError::ConversionError(
307 "Tool result content does not support non-text".into(),
308 )),
309 })?,
310 }),
311 _ => Err(message::MessageError::ConversionError(
312 "expected tool result content while converting HuggingFace input"
313 .into(),
314 )),
315 })
316 .collect::<Result<Vec<_>, _>>()
317 } else {
318 let other_content = OneOrMany::many(other_content).map_err(|_| {
319 message::MessageError::ConversionError(
320 "HuggingFace user message did not contain any non-tool content".into(),
321 )
322 })?;
323
324 Ok(vec![Message::User {
325 content: other_content.try_map(|content| match content {
326 message::UserContent::Text(text) => {
327 Ok(UserContent::Text { text: text.text })
328 }
329 message::UserContent::Image(image) => {
330 let url = image.try_into_url()?;
331
332 Ok(UserContent::ImageUrl {
333 image_url: ImageUrl { url },
334 })
335 }
336 message::UserContent::Document(message::Document {
337 data: message::DocumentSourceKind::Raw(raw), ..
338 }) => {
339 let text = String::from_utf8_lossy(raw.as_slice()).into();
340 Ok(UserContent::Text { text })
341 }
342 message::UserContent::Document(message::Document {
343 data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
344 }) => {
345 Ok(UserContent::Text { text })
346 }
347 _ => Err(message::MessageError::ConversionError(
348 "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
349 )),
350 })?,
351 }])
352 }
353 }
354 message::Message::Assistant { content, .. } => {
355 let mut text_content = Vec::new();
356 let mut tool_calls = Vec::new();
357
358 for content in content {
359 match content {
360 message::AssistantContent::Text(text) => text_content.push(text),
361 message::AssistantContent::ToolCall(tool_call) => {
362 tool_calls.push(tool_call)
363 }
364 message::AssistantContent::Reasoning(_) => {
365 }
368 message::AssistantContent::Image(_) => {
369 return Err(message::MessageError::ConversionError(
370 "HuggingFace assistant messages do not support image content"
371 .into(),
372 ));
373 }
374 }
375 }
376
377 if text_content.is_empty() && tool_calls.is_empty() {
378 return Ok(vec![]);
379 }
380
381 Ok(vec![Message::Assistant {
382 content: text_content
383 .into_iter()
384 .map(|content| AssistantContent::Text { text: content.text })
385 .collect::<Vec<_>>(),
386 tool_calls: tool_calls
387 .into_iter()
388 .map(|tool_call| tool_call.into())
389 .collect::<Vec<_>>(),
390 }])
391 }
392 }
393 }
394}
395
396impl TryFrom<Message> for message::Message {
397 type Error = message::MessageError;
398
399 fn try_from(message: Message) -> Result<Self, Self::Error> {
400 Ok(match message {
401 Message::User { content, .. } => message::Message::User {
402 content: content.map(|content| content.into()),
403 },
404 Message::Assistant {
405 content,
406 tool_calls,
407 ..
408 } => {
409 let mut content = content
410 .into_iter()
411 .map(|content| match content {
412 AssistantContent::Text { text, .. } => {
413 message::AssistantContent::text(text)
414 }
415 })
416 .collect::<Vec<_>>();
417
418 content.extend(
419 tool_calls
420 .into_iter()
421 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
422 .collect::<Result<Vec<_>, _>>()?,
423 );
424
425 message::Message::Assistant {
426 id: None,
427 content: OneOrMany::many(content).map_err(|_| {
428 message::MessageError::ConversionError(
429 "Neither `content` nor `tool_calls` was provided to the Message"
430 .to_owned(),
431 )
432 })?,
433 }
434 }
435
436 Message::ToolResult { name, content, .. } => message::Message::User {
437 content: OneOrMany::one(message::UserContent::tool_result(
438 name,
439 content.map(message::ToolResultContent::text),
440 )),
441 },
442
443 Message::System { content, .. } => message::Message::User {
446 content: content.map(|c| match c {
447 SystemContent::Text { text, .. } => message::UserContent::text(text),
448 }),
449 },
450 })
451 }
452}
453
454#[derive(Clone, Debug, Deserialize, Serialize)]
455pub struct Choice {
456 pub finish_reason: String,
457 pub index: usize,
458 #[serde(default)]
459 pub logprobs: serde_json::Value,
460 pub message: Message,
461}
462
463#[derive(Debug, Deserialize, Clone, Serialize)]
464pub struct Usage {
465 pub completion_tokens: i32,
466 pub prompt_tokens: i32,
467 pub total_tokens: i32,
468}
469
470impl GetTokenUsage for Usage {
471 fn token_usage(&self) -> Option<crate::completion::Usage> {
472 let mut usage = crate::completion::Usage::new();
473 usage.input_tokens = self.prompt_tokens as u64;
474 usage.output_tokens = self.completion_tokens as u64;
475 usage.total_tokens = self.total_tokens as u64;
476
477 Some(usage)
478 }
479}
480
481#[derive(Clone, Debug, Deserialize, Serialize)]
482pub struct CompletionResponse {
483 pub created: i32,
484 pub id: String,
485 pub model: String,
486 pub choices: Vec<Choice>,
487 #[serde(default, deserialize_with = "default_string_on_null")]
488 pub system_fingerprint: String,
489 pub usage: Usage,
490}
491
492impl crate::telemetry::ProviderResponseExt for CompletionResponse {
493 type OutputMessage = Choice;
494 type Usage = Usage;
495
496 fn get_response_id(&self) -> Option<String> {
497 Some(self.id.clone())
498 }
499
500 fn get_response_model_name(&self) -> Option<String> {
501 Some(self.model.clone())
502 }
503
504 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
505 self.choices.clone()
506 }
507
508 fn get_text_response(&self) -> Option<String> {
509 let text_response = self
510 .choices
511 .iter()
512 .filter_map(|x| {
513 let Message::User { ref content } = x.message else {
514 return None;
515 };
516
517 let text = content
518 .iter()
519 .filter_map(|x| {
520 if let UserContent::Text { text, .. } = x {
521 Some(text.clone())
522 } else {
523 None
524 }
525 })
526 .collect::<Vec<String>>();
527
528 if text.is_empty() {
529 None
530 } else {
531 Some(text.join("\n"))
532 }
533 })
534 .collect::<Vec<String>>()
535 .join("\n");
536
537 if text_response.is_empty() {
538 None
539 } else {
540 Some(text_response)
541 }
542 }
543
544 fn get_usage(&self) -> Option<Self::Usage> {
545 Some(self.usage.clone())
546 }
547}
548
549fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
550where
551 D: Deserializer<'de>,
552{
553 match Option::<String>::deserialize(deserializer)? {
554 Some(value) => Ok(value), None => Ok(String::default()), }
557}
558
559impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
560 type Error = CompletionError;
561
562 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
563 let choice = response.choices.first().ok_or_else(|| {
564 CompletionError::ResponseError("Response contained no choices".to_owned())
565 })?;
566
567 let content = match &choice.message {
568 Message::Assistant {
569 content,
570 tool_calls,
571 ..
572 } => {
573 let mut content = content
574 .iter()
575 .map(|c| match c {
576 AssistantContent::Text { text, .. } => {
577 message::AssistantContent::text(text)
578 }
579 })
580 .collect::<Vec<_>>();
581
582 content.extend(
583 tool_calls
584 .iter()
585 .map(|call| {
586 completion::AssistantContent::tool_call(
587 &call.id,
588 &call.function.name,
589 call.function.arguments.clone(),
590 )
591 })
592 .collect::<Vec<_>>(),
593 );
594 Ok(content)
595 }
596 _ => Err(CompletionError::ResponseError(
597 "Response did not contain a valid message or tool call".into(),
598 )),
599 }?;
600
601 let choice = OneOrMany::many(content).map_err(|_| {
602 CompletionError::ResponseError(
603 "Response contained no message or tool call (empty)".to_owned(),
604 )
605 })?;
606
607 let usage = completion::Usage {
608 input_tokens: response.usage.prompt_tokens as u64,
609 output_tokens: response.usage.completion_tokens as u64,
610 total_tokens: response.usage.total_tokens as u64,
611 cached_input_tokens: 0,
612 cache_creation_input_tokens: 0,
613 tool_use_prompt_tokens: 0,
614 reasoning_tokens: 0,
615 };
616
617 Ok(completion::CompletionResponse {
618 choice,
619 usage,
620 raw_response: response,
621 message_id: None,
622 })
623 }
624}
625
626#[derive(Debug, Serialize, Deserialize)]
627pub(super) struct HuggingfaceCompletionRequest {
628 model: String,
629 pub messages: Vec<Message>,
630 #[serde(skip_serializing_if = "Option::is_none")]
631 temperature: Option<f64>,
632 #[serde(skip_serializing_if = "Vec::is_empty")]
633 tools: Vec<ToolDefinition>,
634 #[serde(skip_serializing_if = "Option::is_none")]
635 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
636 #[serde(flatten, skip_serializing_if = "Option::is_none")]
637 pub additional_params: Option<serde_json::Value>,
638}
639
640impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
641 type Error = CompletionError;
642
643 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
644 if req.output_schema.is_some() {
645 tracing::warn!("Structured outputs currently not supported for Huggingface");
646 }
647 let model = req.model.clone().unwrap_or_else(|| model.to_string());
648 let mut full_history: Vec<Message> = match &req.preamble {
649 Some(preamble) => vec![Message::system(preamble)],
650 None => vec![],
651 };
652 if let Some(docs) = req.normalized_documents() {
653 let docs: Vec<Message> = docs.try_into()?;
654 full_history.extend(docs);
655 }
656
657 let chat_history: Vec<Message> = req
658 .chat_history
659 .clone()
660 .into_iter()
661 .map(|message| message.try_into())
662 .collect::<Result<Vec<Vec<Message>>, _>>()?
663 .into_iter()
664 .flatten()
665 .collect();
666
667 full_history.extend(chat_history);
668
669 if full_history.is_empty() {
670 return Err(CompletionError::RequestError(
671 std::io::Error::new(
672 std::io::ErrorKind::InvalidInput,
673 "HuggingFace request has no provider-compatible messages after conversion",
674 )
675 .into(),
676 ));
677 }
678
679 let tool_choice = req
680 .tool_choice
681 .clone()
682 .map(crate::providers::openai::completion::ToolChoice::try_from)
683 .transpose()?;
684
685 Ok(Self {
686 model: model.to_string(),
687 messages: full_history,
688 temperature: req.temperature,
689 tools: req
690 .tools
691 .clone()
692 .into_iter()
693 .map(ToolDefinition::from)
694 .collect::<Vec<_>>(),
695 tool_choice,
696 additional_params: req.additional_params,
697 })
698 }
699}
700
701#[derive(Clone)]
702pub struct CompletionModel<T = reqwest::Client> {
703 pub(crate) client: Client<T>,
704 pub model: String,
706}
707
708impl<T> CompletionModel<T> {
709 pub fn new(client: Client<T>, model: &str) -> Self {
710 Self {
711 client,
712 model: model.to_string(),
713 }
714 }
715}
716
717impl<T> completion::CompletionModel for CompletionModel<T>
718where
719 T: HttpClientExt + Clone + 'static,
720{
721 type Response = CompletionResponse;
722 type StreamingResponse = StreamingCompletionResponse;
723
724 type Client = Client<T>;
725
726 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
727 Self::new(client.clone(), &model.into())
728 }
729
730 async fn completion(
731 &self,
732 completion_request: CompletionRequest,
733 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
734 let request_model = completion_request
735 .model
736 .clone()
737 .unwrap_or_else(|| self.model.clone());
738 let span = if tracing::Span::current().is_disabled() {
739 info_span!(
740 target: "rig::completions",
741 "chat",
742 gen_ai.operation.name = "chat",
743 gen_ai.provider.name = "huggingface",
744 gen_ai.request.model = &request_model,
745 gen_ai.system_instructions = &completion_request.preamble,
746 gen_ai.response.id = tracing::field::Empty,
747 gen_ai.response.model = tracing::field::Empty,
748 gen_ai.usage.output_tokens = tracing::field::Empty,
749 gen_ai.usage.input_tokens = tracing::field::Empty,
750 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
751 )
752 } else {
753 tracing::Span::current()
754 };
755
756 let model = self.client.subprovider().model_identifier(&request_model);
757 let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
758
759 if enabled!(Level::TRACE) {
760 tracing::trace!(
761 target: "rig::completions",
762 "Huggingface completion request: {}",
763 serde_json::to_string_pretty(&request)?
764 );
765 }
766
767 let request = serde_json::to_vec(&request)?;
768
769 let path = self
770 .client
771 .subprovider()
772 .completion_endpoint(&request_model);
773 let request = self
774 .client
775 .post(&path)?
776 .header("Content-Type", "application/json")
777 .body(request)
778 .map_err(|e| CompletionError::HttpError(e.into()))?;
779
780 async move {
781 let response = self.client.send(request).await?;
782
783 if response.status().is_success() {
784 let bytes: Vec<u8> = response.into_body().await?;
785 let text = String::from_utf8_lossy(&bytes);
786
787 tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
788
789 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
790 ApiResponse::Ok(response) => {
791 if enabled!(Level::TRACE) {
792 tracing::trace!(
793 target: "rig::completions",
794 "Huggingface completion response: {}",
795 serde_json::to_string_pretty(&response)?
796 );
797 }
798
799 let span = tracing::Span::current();
800 span.record_token_usage(&response.usage);
801 span.record_response_metadata(&response);
802
803 response.try_into()
804 }
805 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
806 }
807 } else {
808 let status = response.status();
809 let text: Vec<u8> = response.into_body().await?;
810 let text: String = String::from_utf8_lossy(&text).into();
811
812 Err(CompletionError::ProviderError(format!(
813 "{}: {}",
814 status, text
815 )))
816 }
817 }
818 .instrument(span)
819 .await
820 }
821
822 async fn stream(
823 &self,
824 request: CompletionRequest,
825 ) -> Result<
826 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
827 CompletionError,
828 > {
829 CompletionModel::stream(self, request).await
830 }
831}
832
833#[cfg(test)]
834mod tests {
835 use super::*;
836 use serde_path_to_error::deserialize;
837
838 #[test]
839 fn test_huggingface_request_uses_request_model_override() {
840 let request = CompletionRequest {
841 model: Some("meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()),
842 preamble: None,
843 chat_history: crate::OneOrMany::one("Hello".into()),
844 documents: vec![],
845 tools: vec![],
846 temperature: None,
847 max_tokens: None,
848 tool_choice: None,
849 additional_params: None,
850 output_schema: None,
851 };
852
853 let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
854 .expect("request conversion should succeed");
855 let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
856
857 assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
858 }
859
860 #[test]
861 fn test_huggingface_request_uses_default_model_when_override_unset() {
862 let request = CompletionRequest {
863 model: None,
864 preamble: None,
865 chat_history: crate::OneOrMany::one("Hello".into()),
866 documents: vec![],
867 tools: vec![],
868 temperature: None,
869 max_tokens: None,
870 tool_choice: None,
871 additional_params: None,
872 output_schema: None,
873 };
874
875 let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
876 .expect("request conversion should succeed");
877 let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
878
879 assert_eq!(serialized["model"], "mistralai/Mistral-7B");
880 }
881
882 #[test]
883 fn test_deserialize_message() {
884 let assistant_message_json = r#"
885 {
886 "role": "assistant",
887 "content": "\n\nHello there, how may I assist you today?"
888 }
889 "#;
890
891 let assistant_message_json2 = r#"
892 {
893 "role": "assistant",
894 "content": [
895 {
896 "type": "text",
897 "text": "\n\nHello there, how may I assist you today?"
898 }
899 ],
900 "tool_calls": null
901 }
902 "#;
903
904 let assistant_message_json3 = r#"
905 {
906 "role": "assistant",
907 "tool_calls": [
908 {
909 "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
910 "type": "function",
911 "function": {
912 "name": "subtract",
913 "arguments": {"x": 2, "y": 5}
914 }
915 }
916 ],
917 "content": null,
918 "refusal": null
919 }
920 "#;
921
922 let user_message_json = r#"
923 {
924 "role": "user",
925 "content": [
926 {
927 "type": "text",
928 "text": "What's in this image?"
929 },
930 {
931 "type": "image_url",
932 "image_url": {
933 "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
934 }
935 }
936 ]
937 }
938 "#;
939
940 let assistant_message: Message = {
941 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
942 deserialize(jd).unwrap_or_else(|err| {
943 panic!(
944 "Deserialization error at {} ({}:{}): {}",
945 err.path(),
946 err.inner().line(),
947 err.inner().column(),
948 err
949 );
950 })
951 };
952
953 let assistant_message2: Message = {
954 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
955 deserialize(jd).unwrap_or_else(|err| {
956 panic!(
957 "Deserialization error at {} ({}:{}): {}",
958 err.path(),
959 err.inner().line(),
960 err.inner().column(),
961 err
962 );
963 })
964 };
965
966 let assistant_message3: Message = {
967 let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
968 &mut serde_json::Deserializer::from_str(assistant_message_json3);
969 deserialize(jd).unwrap_or_else(|err| {
970 panic!(
971 "Deserialization error at {} ({}:{}): {}",
972 err.path(),
973 err.inner().line(),
974 err.inner().column(),
975 err
976 );
977 })
978 };
979
980 let user_message: Message = {
981 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
982 deserialize(jd).unwrap_or_else(|err| {
983 panic!(
984 "Deserialization error at {} ({}:{}): {}",
985 err.path(),
986 err.inner().line(),
987 err.inner().column(),
988 err
989 );
990 })
991 };
992
993 match assistant_message {
994 Message::Assistant { content, .. } => {
995 assert_eq!(
996 content[0],
997 AssistantContent::Text {
998 text: "\n\nHello there, how may I assist you today?".to_string()
999 }
1000 );
1001 }
1002 _ => panic!("Expected assistant message"),
1003 }
1004
1005 match assistant_message2 {
1006 Message::Assistant {
1007 content,
1008 tool_calls,
1009 ..
1010 } => {
1011 assert_eq!(
1012 content[0],
1013 AssistantContent::Text {
1014 text: "\n\nHello there, how may I assist you today?".to_string()
1015 }
1016 );
1017
1018 assert_eq!(tool_calls, vec![]);
1019 }
1020 _ => panic!("Expected assistant message"),
1021 }
1022
1023 match assistant_message3 {
1024 Message::Assistant {
1025 content,
1026 tool_calls,
1027 ..
1028 } => {
1029 assert!(content.is_empty());
1030 assert_eq!(
1031 tool_calls[0],
1032 ToolCall {
1033 id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
1034 r#type: ToolType::Function,
1035 function: Function {
1036 name: "subtract".to_string(),
1037 arguments: serde_json::json!({"x": 2, "y": 5}),
1038 },
1039 }
1040 );
1041 }
1042 _ => panic!("Expected assistant message"),
1043 }
1044
1045 match user_message {
1046 Message::User { content, .. } => {
1047 let (first, second) = {
1048 let mut iter = content.into_iter();
1049 (iter.next().unwrap(), iter.next().unwrap())
1050 };
1051 assert_eq!(
1052 first,
1053 UserContent::Text {
1054 text: "What's in this image?".to_string()
1055 }
1056 );
1057 assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
1058 }
1059 _ => panic!("Expected user message"),
1060 }
1061 }
1062
1063 #[test]
1064 fn test_message_to_message_conversion() {
1065 let user_message = message::Message::User {
1066 content: OneOrMany::one(message::UserContent::text("Hello")),
1067 };
1068
1069 let assistant_message = message::Message::Assistant {
1070 id: None,
1071 content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
1072 };
1073
1074 let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
1075 let converted_assistant_message: Vec<Message> =
1076 assistant_message.clone().try_into().unwrap();
1077
1078 match converted_user_message[0].clone() {
1079 Message::User { content, .. } => {
1080 assert_eq!(
1081 content.first(),
1082 UserContent::Text {
1083 text: "Hello".to_string()
1084 }
1085 );
1086 }
1087 _ => panic!("Expected user message"),
1088 }
1089
1090 match converted_assistant_message[0].clone() {
1091 Message::Assistant { content, .. } => {
1092 assert_eq!(
1093 content[0],
1094 AssistantContent::Text {
1095 text: "Hi there!".to_string()
1096 }
1097 );
1098 }
1099 _ => panic!("Expected assistant message"),
1100 }
1101
1102 let original_user_message: message::Message =
1103 converted_user_message[0].clone().try_into().unwrap();
1104 let original_assistant_message: message::Message =
1105 converted_assistant_message[0].clone().try_into().unwrap();
1106
1107 assert_eq!(original_user_message, user_message);
1108 assert_eq!(original_assistant_message, assistant_message);
1109 }
1110
1111 #[test]
1112 fn test_message_from_message_conversion() {
1113 let user_message = Message::User {
1114 content: OneOrMany::one(UserContent::Text {
1115 text: "Hello".to_string(),
1116 }),
1117 };
1118
1119 let assistant_message = Message::Assistant {
1120 content: vec![AssistantContent::Text {
1121 text: "Hi there!".to_string(),
1122 }],
1123 tool_calls: vec![],
1124 };
1125
1126 let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1127 let converted_assistant_message: message::Message =
1128 assistant_message.clone().try_into().unwrap();
1129
1130 match converted_user_message.clone() {
1131 message::Message::User { content } => {
1132 assert_eq!(content.first(), message::UserContent::text("Hello"));
1133 }
1134 _ => panic!("Expected user message"),
1135 }
1136
1137 match converted_assistant_message.clone() {
1138 message::Message::Assistant { content, .. } => {
1139 assert_eq!(
1140 content.first(),
1141 message::AssistantContent::text("Hi there!")
1142 );
1143 }
1144 _ => panic!("Expected assistant message"),
1145 }
1146
1147 let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1148 let original_assistant_message: Vec<Message> =
1149 converted_assistant_message.try_into().unwrap();
1150
1151 assert_eq!(original_user_message[0], user_message);
1152 assert_eq!(original_assistant_message[0], assistant_message);
1153 }
1154
1155 #[test]
1156 fn test_responses() {
1157 let fireworks_response_json = r#"
1158 {
1159 "choices": [
1160 {
1161 "finish_reason": "tool_calls",
1162 "index": 0,
1163 "message": {
1164 "role": "assistant",
1165 "tool_calls": [
1166 {
1167 "function": {
1168 "arguments": "{\"x\": 2, \"y\": 5}",
1169 "name": "subtract"
1170 },
1171 "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
1172 "index": 0,
1173 "type": "function"
1174 }
1175 ]
1176 }
1177 }
1178 ],
1179 "created": 1740704000,
1180 "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
1181 "model": "accounts/fireworks/models/deepseek-v3",
1182 "object": "chat.completion",
1183 "usage": {
1184 "completion_tokens": 26,
1185 "prompt_tokens": 248,
1186 "total_tokens": 274
1187 }
1188 }
1189 "#;
1190
1191 let novita_response_json = r#"
1192 {
1193 "choices": [
1194 {
1195 "finish_reason": "tool_calls",
1196 "index": 0,
1197 "logprobs": null,
1198 "message": {
1199 "audio": null,
1200 "content": null,
1201 "function_call": null,
1202 "reasoning_content": null,
1203 "refusal": null,
1204 "role": "assistant",
1205 "tool_calls": [
1206 {
1207 "function": {
1208 "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
1209 "name": "subtract"
1210 },
1211 "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
1212 "type": "function"
1213 }
1214 ]
1215 },
1216 "stop_reason": 128008
1217 }
1218 ],
1219 "created": 1740704592,
1220 "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
1221 "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
1222 "object": "chat.completion",
1223 "prompt_logprobs": null,
1224 "service_tier": null,
1225 "system_fingerprint": null,
1226 "usage": {
1227 "completion_tokens": 28,
1228 "completion_tokens_details": null,
1229 "prompt_tokens": 335,
1230 "prompt_tokens_details": null,
1231 "total_tokens": 363
1232 }
1233 }
1234 "#;
1235
1236 let _firework_response: CompletionResponse = {
1237 let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
1238 deserialize(jd).unwrap_or_else(|err| {
1239 panic!(
1240 "Deserialization error at {} ({}:{}): {}",
1241 err.path(),
1242 err.inner().line(),
1243 err.inner().column(),
1244 err
1245 );
1246 })
1247 };
1248
1249 let _novita_response: CompletionResponse = {
1250 let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
1251 deserialize(jd).unwrap_or_else(|err| {
1252 panic!(
1253 "Deserialization error at {} ({}:{}): {}",
1254 err.path(),
1255 err.inner().line(),
1256 err.inner().column(),
1257 err
1258 );
1259 })
1260 };
1261 }
1262
1263 #[test]
1264 fn test_assistant_reasoning_is_silently_skipped() {
1265 let assistant = message::Message::Assistant {
1266 id: None,
1267 content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1268 };
1269
1270 let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1271 assert!(converted.is_empty());
1272 }
1273
1274 #[test]
1275 fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
1276 let assistant = message::Message::Assistant {
1277 id: None,
1278 content: OneOrMany::many(vec![
1279 message::AssistantContent::reasoning("hidden"),
1280 message::AssistantContent::text("visible"),
1281 message::AssistantContent::tool_call(
1282 "call_1",
1283 "subtract",
1284 serde_json::json!({"x": 2, "y": 1}),
1285 ),
1286 ])
1287 .expect("non-empty assistant content"),
1288 };
1289
1290 let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1291 assert_eq!(converted.len(), 1);
1292
1293 match &converted[0] {
1294 Message::Assistant {
1295 content,
1296 tool_calls,
1297 ..
1298 } => {
1299 assert_eq!(
1300 content,
1301 &vec![AssistantContent::Text {
1302 text: "visible".to_string()
1303 }]
1304 );
1305 assert_eq!(tool_calls.len(), 1);
1306 assert_eq!(tool_calls[0].id, "call_1");
1307 assert_eq!(tool_calls[0].function.name, "subtract");
1308 assert_eq!(
1309 tool_calls[0].function.arguments,
1310 serde_json::json!({"x": 2, "y": 1})
1311 );
1312 }
1313 _ => panic!("expected assistant message"),
1314 }
1315 }
1316
1317 #[test]
1318 fn test_request_conversion_errors_when_all_messages_are_filtered() {
1319 let request = completion::CompletionRequest {
1320 preamble: None,
1321 chat_history: OneOrMany::one(message::Message::Assistant {
1322 id: None,
1323 content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1324 }),
1325 documents: vec![],
1326 tools: vec![],
1327 temperature: None,
1328 max_tokens: None,
1329 tool_choice: None,
1330 additional_params: None,
1331 model: None,
1332 output_schema: None,
1333 };
1334
1335 let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
1336 assert!(matches!(result, Err(CompletionError::RequestError(_))));
1337 }
1338}