1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7 OneOrMany,
8 completion::{self, CompletionError, CompletionRequest},
9 json_utils,
10 message::{self},
11 one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22 Ok(T),
23 Err(Value),
24}
25
// HuggingFace repository identifiers for commonly used chat models.

/// Model id `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model id `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model id `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model id `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model id `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// Model id `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model id `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51 name: String,
52 #[serde(
53 serialize_with = "json_utils::stringified_json::serialize",
54 deserialize_with = "deserialize_arguments"
55 )]
56 pub arguments: serde_json::Value,
57}
58
59fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
60where
61 D: Deserializer<'de>,
62{
63 let value = Value::deserialize(deserializer)?;
64
65 match value {
66 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
67 other => Ok(other),
68 }
69}
70
71impl From<Function> for message::ToolFunction {
72 fn from(value: Function) -> Self {
73 message::ToolFunction {
74 name: value.name,
75 arguments: value.arguments,
76 }
77 }
78}
79
80#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
81#[serde(rename_all = "lowercase")]
82pub enum ToolType {
83 #[default]
84 Function,
85}
86
87#[derive(Debug, Deserialize, Serialize, Clone)]
88pub struct ToolDefinition {
89 pub r#type: String,
90 pub function: completion::ToolDefinition,
91}
92
93impl From<completion::ToolDefinition> for ToolDefinition {
94 fn from(tool: completion::ToolDefinition) -> Self {
95 Self {
96 r#type: "function".into(),
97 function: tool,
98 }
99 }
100}
101
102#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
103pub struct ToolCall {
104 pub id: String,
105 pub r#type: ToolType,
106 pub function: Function,
107}
108
109impl From<ToolCall> for message::ToolCall {
110 fn from(value: ToolCall) -> Self {
111 message::ToolCall {
112 id: value.id,
113 call_id: None,
114 function: value.function.into(),
115 signature: None,
116 additional_params: None,
117 }
118 }
119}
120
121impl From<message::ToolCall> for ToolCall {
122 fn from(value: message::ToolCall) -> Self {
123 ToolCall {
124 id: value.id,
125 r#type: ToolType::Function,
126 function: Function {
127 name: value.function.name,
128 arguments: value.function.arguments,
129 },
130 }
131 }
132}
133
134#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
135pub struct ImageUrl {
136 url: String,
137}
138
139#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
140#[serde(tag = "type", rename_all = "lowercase")]
141pub enum UserContent {
142 Text {
143 text: String,
144 },
145 #[serde(rename = "image_url")]
146 ImageUrl {
147 image_url: ImageUrl,
148 },
149}
150
151impl FromStr for UserContent {
152 type Err = Infallible;
153
154 fn from_str(s: &str) -> Result<Self, Self::Err> {
155 Ok(UserContent::Text {
156 text: s.to_string(),
157 })
158 }
159}
160
161#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
162#[serde(tag = "type", rename_all = "lowercase")]
163pub enum AssistantContent {
164 Text { text: String },
165}
166
167impl FromStr for AssistantContent {
168 type Err = Infallible;
169
170 fn from_str(s: &str) -> Result<Self, Self::Err> {
171 Ok(AssistantContent::Text {
172 text: s.to_string(),
173 })
174 }
175}
176
177#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
178#[serde(tag = "type", rename_all = "lowercase")]
179pub enum SystemContent {
180 Text { text: String },
181}
182
183impl FromStr for SystemContent {
184 type Err = Infallible;
185
186 fn from_str(s: &str) -> Result<Self, Self::Err> {
187 Ok(SystemContent::Text {
188 text: s.to_string(),
189 })
190 }
191}
192
193impl From<UserContent> for message::UserContent {
194 fn from(value: UserContent) -> Self {
195 match value {
196 UserContent::Text { text } => message::UserContent::text(text),
197 UserContent::ImageUrl { image_url } => {
198 message::UserContent::image_url(image_url.url, None, None)
199 }
200 }
201 }
202}
203
204impl TryFrom<message::UserContent> for UserContent {
205 type Error = message::MessageError;
206
207 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
208 match content {
209 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
210 message::UserContent::Document(message::Document {
211 data: message::DocumentSourceKind::Raw(raw),
212 ..
213 }) => {
214 let text = String::from_utf8_lossy(raw.as_slice()).into();
215 Ok(UserContent::Text { text })
216 }
217 message::UserContent::Document(message::Document {
218 data:
219 message::DocumentSourceKind::Base64(text)
220 | message::DocumentSourceKind::String(text),
221 ..
222 }) => Ok(UserContent::Text { text }),
223 message::UserContent::Image(message::Image { data, .. }) => match data {
224 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
225 image_url: ImageUrl { url },
226 }),
227 _ => Err(message::MessageError::ConversionError(
228 "Huggingface only supports images as urls".into(),
229 )),
230 },
231 _ => Err(message::MessageError::ConversionError(
232 "Huggingface only supports text and images".into(),
233 )),
234 }
235 }
236}
237
238#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
239#[serde(tag = "role", rename_all = "lowercase")]
240pub enum Message {
241 System {
242 #[serde(deserialize_with = "string_or_one_or_many")]
243 content: OneOrMany<SystemContent>,
244 },
245 User {
246 #[serde(deserialize_with = "string_or_one_or_many")]
247 content: OneOrMany<UserContent>,
248 },
249 Assistant {
250 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
251 content: Vec<AssistantContent>,
252 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
253 tool_calls: Vec<ToolCall>,
254 },
255 #[serde(rename = "tool", alias = "Tool")]
256 ToolResult {
257 name: String,
258 #[serde(skip_serializing_if = "Option::is_none")]
259 arguments: Option<serde_json::Value>,
260 #[serde(
261 deserialize_with = "string_or_one_or_many",
262 serialize_with = "serialize_tool_content"
263 )]
264 content: OneOrMany<String>,
265 },
266}
267
268fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
269where
270 S: Serializer,
271{
272 let joined = content
274 .iter()
275 .map(String::as_str)
276 .collect::<Vec<_>>()
277 .join("\n");
278 serializer.serialize_str(&joined)
279}
280
281impl Message {
282 pub fn system(content: &str) -> Self {
283 Message::System {
284 content: OneOrMany::one(SystemContent::Text {
285 text: content.to_string(),
286 }),
287 }
288 }
289}
290
291impl TryFrom<message::Message> for Vec<Message> {
292 type Error = message::MessageError;
293
294 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
295 match message {
296 message::Message::User { content } => {
297 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
298 .into_iter()
299 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
300
301 if !tool_results.is_empty() {
302 tool_results
303 .into_iter()
304 .map(|content| match content {
305 message::UserContent::ToolResult(message::ToolResult {
306 id,
307 content,
308 ..
309 }) => Ok::<_, message::MessageError>(Message::ToolResult {
310 name: id,
311 arguments: None,
312 content: content.try_map(|content| match content {
313 message::ToolResultContent::Text(message::Text { text }) => {
314 Ok(text)
315 }
316 _ => Err(message::MessageError::ConversionError(
317 "Tool result content does not support non-text".into(),
318 )),
319 })?,
320 }),
321 _ => unreachable!(),
322 })
323 .collect::<Result<Vec<_>, _>>()
324 } else {
325 let other_content = OneOrMany::many(other_content).expect(
326 "There must be other content here if there were no tool result content",
327 );
328
329 Ok(vec![Message::User {
330 content: other_content.try_map(|content| match content {
331 message::UserContent::Text(text) => {
332 Ok(UserContent::Text { text: text.text })
333 }
334 message::UserContent::Image(image) => {
335 let url = image.try_into_url()?;
336
337 Ok(UserContent::ImageUrl {
338 image_url: ImageUrl { url },
339 })
340 }
341 message::UserContent::Document(message::Document {
342 data: message::DocumentSourceKind::Raw(raw), ..
343 }) => {
344 let text = String::from_utf8_lossy(raw.as_slice()).into();
345 Ok(UserContent::Text { text })
346 }
347 message::UserContent::Document(message::Document {
348 data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
349 }) => {
350 Ok(UserContent::Text { text })
351 }
352 _ => Err(message::MessageError::ConversionError(
353 "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
354 )),
355 })?,
356 }])
357 }
358 }
359 message::Message::Assistant { content, .. } => {
360 let mut text_content = Vec::new();
361 let mut tool_calls = Vec::new();
362
363 for content in content {
364 match content {
365 message::AssistantContent::Text(text) => text_content.push(text),
366 message::AssistantContent::ToolCall(tool_call) => {
367 tool_calls.push(tool_call)
368 }
369 message::AssistantContent::Reasoning(_) => {
370 }
373 message::AssistantContent::Image(_) => {
374 panic!("Image content is not supported on HuggingFace via Rig");
375 }
376 }
377 }
378
379 if text_content.is_empty() && tool_calls.is_empty() {
380 return Ok(vec![]);
381 }
382
383 Ok(vec![Message::Assistant {
384 content: text_content
385 .into_iter()
386 .map(|content| AssistantContent::Text { text: content.text })
387 .collect::<Vec<_>>(),
388 tool_calls: tool_calls
389 .into_iter()
390 .map(|tool_call| tool_call.into())
391 .collect::<Vec<_>>(),
392 }])
393 }
394 }
395 }
396}
397
398impl TryFrom<Message> for message::Message {
399 type Error = message::MessageError;
400
401 fn try_from(message: Message) -> Result<Self, Self::Error> {
402 Ok(match message {
403 Message::User { content, .. } => message::Message::User {
404 content: content.map(|content| content.into()),
405 },
406 Message::Assistant {
407 content,
408 tool_calls,
409 ..
410 } => {
411 let mut content = content
412 .into_iter()
413 .map(|content| match content {
414 AssistantContent::Text { text } => message::AssistantContent::text(text),
415 })
416 .collect::<Vec<_>>();
417
418 content.extend(
419 tool_calls
420 .into_iter()
421 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
422 .collect::<Result<Vec<_>, _>>()?,
423 );
424
425 message::Message::Assistant {
426 id: None,
427 content: OneOrMany::many(content).map_err(|_| {
428 message::MessageError::ConversionError(
429 "Neither `content` nor `tool_calls` was provided to the Message"
430 .to_owned(),
431 )
432 })?,
433 }
434 }
435
436 Message::ToolResult { name, content, .. } => message::Message::User {
437 content: OneOrMany::one(message::UserContent::tool_result(
438 name,
439 content.map(message::ToolResultContent::text),
440 )),
441 },
442
443 Message::System { content, .. } => message::Message::User {
446 content: content.map(|c| match c {
447 SystemContent::Text { text } => message::UserContent::text(text),
448 }),
449 },
450 })
451 }
452}
453
454#[derive(Clone, Debug, Deserialize, Serialize)]
455pub struct Choice {
456 pub finish_reason: String,
457 pub index: usize,
458 #[serde(default)]
459 pub logprobs: serde_json::Value,
460 pub message: Message,
461}
462
463#[derive(Debug, Deserialize, Clone, Serialize)]
464pub struct Usage {
465 pub completion_tokens: i32,
466 pub prompt_tokens: i32,
467 pub total_tokens: i32,
468}
469
470impl GetTokenUsage for Usage {
471 fn token_usage(&self) -> Option<crate::completion::Usage> {
472 let mut usage = crate::completion::Usage::new();
473 usage.input_tokens = self.prompt_tokens as u64;
474 usage.output_tokens = self.completion_tokens as u64;
475 usage.total_tokens = self.total_tokens as u64;
476
477 Some(usage)
478 }
479}
480
481#[derive(Clone, Debug, Deserialize, Serialize)]
482pub struct CompletionResponse {
483 pub created: i32,
484 pub id: String,
485 pub model: String,
486 pub choices: Vec<Choice>,
487 #[serde(default, deserialize_with = "default_string_on_null")]
488 pub system_fingerprint: String,
489 pub usage: Usage,
490}
491
492impl crate::telemetry::ProviderResponseExt for CompletionResponse {
493 type OutputMessage = Choice;
494 type Usage = Usage;
495
496 fn get_response_id(&self) -> Option<String> {
497 Some(self.id.clone())
498 }
499
500 fn get_response_model_name(&self) -> Option<String> {
501 Some(self.model.clone())
502 }
503
504 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
505 self.choices.clone()
506 }
507
508 fn get_text_response(&self) -> Option<String> {
509 let text_response = self
510 .choices
511 .iter()
512 .filter_map(|x| {
513 let Message::User { ref content } = x.message else {
514 return None;
515 };
516
517 let text = content
518 .iter()
519 .filter_map(|x| {
520 if let UserContent::Text { text } = x {
521 Some(text.clone())
522 } else {
523 None
524 }
525 })
526 .collect::<Vec<String>>();
527
528 if text.is_empty() {
529 None
530 } else {
531 Some(text.join("\n"))
532 }
533 })
534 .collect::<Vec<String>>()
535 .join("\n");
536
537 if text_response.is_empty() {
538 None
539 } else {
540 Some(text_response)
541 }
542 }
543
544 fn get_usage(&self) -> Option<Self::Usage> {
545 Some(self.usage.clone())
546 }
547}
548
549fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
550where
551 D: Deserializer<'de>,
552{
553 match Option::<String>::deserialize(deserializer)? {
554 Some(value) => Ok(value), None => Ok(String::default()), }
557}
558
559impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
560 type Error = CompletionError;
561
562 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
563 let choice = response.choices.first().ok_or_else(|| {
564 CompletionError::ResponseError("Response contained no choices".to_owned())
565 })?;
566
567 let content = match &choice.message {
568 Message::Assistant {
569 content,
570 tool_calls,
571 ..
572 } => {
573 let mut content = content
574 .iter()
575 .map(|c| match c {
576 AssistantContent::Text { text } => message::AssistantContent::text(text),
577 })
578 .collect::<Vec<_>>();
579
580 content.extend(
581 tool_calls
582 .iter()
583 .map(|call| {
584 completion::AssistantContent::tool_call(
585 &call.id,
586 &call.function.name,
587 call.function.arguments.clone(),
588 )
589 })
590 .collect::<Vec<_>>(),
591 );
592 Ok(content)
593 }
594 _ => Err(CompletionError::ResponseError(
595 "Response did not contain a valid message or tool call".into(),
596 )),
597 }?;
598
599 let choice = OneOrMany::many(content).map_err(|_| {
600 CompletionError::ResponseError(
601 "Response contained no message or tool call (empty)".to_owned(),
602 )
603 })?;
604
605 let usage = completion::Usage {
606 input_tokens: response.usage.prompt_tokens as u64,
607 output_tokens: response.usage.completion_tokens as u64,
608 total_tokens: response.usage.total_tokens as u64,
609 cached_input_tokens: 0,
610 };
611
612 Ok(completion::CompletionResponse {
613 choice,
614 usage,
615 raw_response: response,
616 message_id: None,
617 })
618 }
619}
620
621#[derive(Debug, Serialize, Deserialize)]
622pub(super) struct HuggingfaceCompletionRequest {
623 model: String,
624 pub messages: Vec<Message>,
625 #[serde(skip_serializing_if = "Option::is_none")]
626 temperature: Option<f64>,
627 #[serde(skip_serializing_if = "Vec::is_empty")]
628 tools: Vec<ToolDefinition>,
629 #[serde(skip_serializing_if = "Option::is_none")]
630 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
631 #[serde(flatten, skip_serializing_if = "Option::is_none")]
632 pub additional_params: Option<serde_json::Value>,
633}
634
635impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
636 type Error = CompletionError;
637
638 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
639 if req.output_schema.is_some() {
640 tracing::warn!("Structured outputs currently not supported for Huggingface");
641 }
642 let model = req.model.clone().unwrap_or_else(|| model.to_string());
643 let mut full_history: Vec<Message> = match &req.preamble {
644 Some(preamble) => vec![Message::system(preamble)],
645 None => vec![],
646 };
647 if let Some(docs) = req.normalized_documents() {
648 let docs: Vec<Message> = docs.try_into()?;
649 full_history.extend(docs);
650 }
651
652 let chat_history: Vec<Message> = req
653 .chat_history
654 .clone()
655 .into_iter()
656 .map(|message| message.try_into())
657 .collect::<Result<Vec<Vec<Message>>, _>>()?
658 .into_iter()
659 .flatten()
660 .collect();
661
662 full_history.extend(chat_history);
663
664 if full_history.is_empty() {
665 return Err(CompletionError::RequestError(
666 std::io::Error::new(
667 std::io::ErrorKind::InvalidInput,
668 "HuggingFace request has no provider-compatible messages after conversion",
669 )
670 .into(),
671 ));
672 }
673
674 let tool_choice = req
675 .tool_choice
676 .clone()
677 .map(crate::providers::openai::completion::ToolChoice::try_from)
678 .transpose()?;
679
680 Ok(Self {
681 model: model.to_string(),
682 messages: full_history,
683 temperature: req.temperature,
684 tools: req
685 .tools
686 .clone()
687 .into_iter()
688 .map(ToolDefinition::from)
689 .collect::<Vec<_>>(),
690 tool_choice,
691 additional_params: req.additional_params,
692 })
693 }
694}
695
696#[derive(Clone)]
697pub struct CompletionModel<T = reqwest::Client> {
698 pub(crate) client: Client<T>,
699 pub model: String,
701}
702
703impl<T> CompletionModel<T> {
704 pub fn new(client: Client<T>, model: &str) -> Self {
705 Self {
706 client,
707 model: model.to_string(),
708 }
709 }
710}
711
712impl<T> completion::CompletionModel for CompletionModel<T>
713where
714 T: HttpClientExt + Clone + 'static,
715{
716 type Response = CompletionResponse;
717 type StreamingResponse = StreamingCompletionResponse;
718
719 type Client = Client<T>;
720
721 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
722 Self::new(client.clone(), &model.into())
723 }
724
725 async fn completion(
726 &self,
727 completion_request: CompletionRequest,
728 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
729 let request_model = completion_request
730 .model
731 .clone()
732 .unwrap_or_else(|| self.model.clone());
733 let span = if tracing::Span::current().is_disabled() {
734 info_span!(
735 target: "rig::completions",
736 "chat",
737 gen_ai.operation.name = "chat",
738 gen_ai.provider.name = "huggingface",
739 gen_ai.request.model = &request_model,
740 gen_ai.system_instructions = &completion_request.preamble,
741 gen_ai.response.id = tracing::field::Empty,
742 gen_ai.response.model = tracing::field::Empty,
743 gen_ai.usage.output_tokens = tracing::field::Empty,
744 gen_ai.usage.input_tokens = tracing::field::Empty,
745 )
746 } else {
747 tracing::Span::current()
748 };
749
750 let model = self.client.subprovider().model_identifier(&request_model);
751 let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
752
753 if enabled!(Level::TRACE) {
754 tracing::trace!(
755 target: "rig::completions",
756 "Huggingface completion request: {}",
757 serde_json::to_string_pretty(&request)?
758 );
759 }
760
761 let request = serde_json::to_vec(&request)?;
762
763 let path = self
764 .client
765 .subprovider()
766 .completion_endpoint(&request_model);
767 let request = self
768 .client
769 .post(&path)?
770 .header("Content-Type", "application/json")
771 .body(request)
772 .map_err(|e| CompletionError::HttpError(e.into()))?;
773
774 async move {
775 let response = self.client.send(request).await?;
776
777 if response.status().is_success() {
778 let bytes: Vec<u8> = response.into_body().await?;
779 let text = String::from_utf8_lossy(&bytes);
780
781 tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
782
783 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
784 ApiResponse::Ok(response) => {
785 if enabled!(Level::TRACE) {
786 tracing::trace!(
787 target: "rig::completions",
788 "Huggingface completion response: {}",
789 serde_json::to_string_pretty(&response)?
790 );
791 }
792
793 let span = tracing::Span::current();
794 span.record_token_usage(&response.usage);
795 span.record_response_metadata(&response);
796
797 response.try_into()
798 }
799 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
800 }
801 } else {
802 let status = response.status();
803 let text: Vec<u8> = response.into_body().await?;
804 let text: String = String::from_utf8_lossy(&text).into();
805
806 Err(CompletionError::ProviderError(format!(
807 "{}: {}",
808 status, text
809 )))
810 }
811 }
812 .instrument(span)
813 .await
814 }
815
816 async fn stream(
817 &self,
818 request: CompletionRequest,
819 ) -> Result<
820 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
821 CompletionError,
822 > {
823 CompletionModel::stream(self, request).await
824 }
825}
826
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`, panicking with the failing JSON path and
    /// line/column so fixture mismatches are easy to locate.
    fn parse_json<T: serde::de::DeserializeOwned>(json: &str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            );
        })
    }

    /// Builds a request with the given model override and chat history and
    /// every other option unset.
    fn make_request(
        model: Option<&str>,
        chat_history: OneOrMany<message::Message>,
    ) -> CompletionRequest {
        CompletionRequest {
            model: model.map(str::to_string),
            preamble: None,
            chat_history,
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    #[test]
    fn test_huggingface_request_uses_request_model_override() {
        let request = make_request(
            Some("meta-llama/Meta-Llama-3.1-8B-Instruct"),
            OneOrMany::one("Hello".into()),
        );

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
    }

    #[test]
    fn test_huggingface_request_uses_default_model_when_override_unset() {
        let request = make_request(None, OneOrMany::one("Hello".into()));

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "mistralai/Mistral-7B");
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse_json(assistant_message_json);
        let assistant_message2: Message = parse_json(assistant_message_json2);
        let assistant_message3: Message = parse_json(assistant_message_json3);
        let user_message: Message = parse_json(user_message_json);

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                "arguments": "{\"x\": 2, \"y\": 5}",
                                "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let _firework_response: CompletionResponse = parse_json(fireworks_response_json);
        let _novita_response: CompletionResponse = parse_json(novita_response_json);
    }

    #[test]
    fn test_assistant_reasoning_is_silently_skipped() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }

    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        let request = make_request(
            None,
            OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
        );

        let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}