1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7 OneOrMany,
8 completion::{self, CompletionError, CompletionRequest},
9 json_utils,
10 message::{self},
11 one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
/// Envelope used when decoding a Hugging Face response body.
///
/// Deserialized as `untagged`: serde first tries to decode the body as `T` and, if that
/// fails, keeps the raw JSON in `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// The body decoded successfully as `T`.
    Ok(T),
    /// Any JSON that did not decode as `T` (presumably a provider error payload — TODO confirm).
    Err(Value),
}
25
/// Model id `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model id `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model id `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model id `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model id `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// Model id `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model id `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
/// The function portion of a tool call exchanged with Hugging Face.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function being called.
    name: String,
    /// Call arguments. Serialized as a JSON-encoded string; deserialization accepts either a
    /// JSON-encoded string or an inline JSON value (the tests exercise both forms).
    #[serde(
        serialize_with = "json_utils::stringified_json::serialize",
        deserialize_with = "json_utils::stringified_json::deserialize_maybe_stringified"
    )]
    pub arguments: serde_json::Value,
}
58
59impl From<Function> for message::ToolFunction {
60 fn from(value: Function) -> Self {
61 message::ToolFunction {
62 name: value.name,
63 arguments: value.arguments,
64 }
65 }
66}
67
/// Kind of a tool call; `function` is the only kind modelled (serialized lowercase).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// Serialized as `"function"`; also the `Default`.
    #[default]
    Function,
}
74
/// A tool advertised to the model in a request.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` by the `From<completion::ToolDefinition>` impl.
    pub r#type: String,
    /// The Rig tool definition (name, description, parameters).
    pub function: completion::ToolDefinition,
}
80
81impl From<completion::ToolDefinition> for ToolDefinition {
82 fn from(tool: completion::ToolDefinition) -> Self {
83 Self {
84 r#type: "function".into(),
85 function: tool,
86 }
87 }
88}
89
/// A tool call emitted by the model (or replayed back to it in chat history).
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Provider-assigned id of this call.
    pub id: String,
    /// Always `function` for the calls modelled here.
    pub r#type: ToolType,
    /// Function name and arguments.
    pub function: Function,
}
96
97impl From<ToolCall> for message::ToolCall {
98 fn from(value: ToolCall) -> Self {
99 message::ToolCall {
100 id: value.id,
101 call_id: None,
102 function: value.function.into(),
103 signature: None,
104 additional_params: None,
105 }
106 }
107}
108
109impl From<message::ToolCall> for ToolCall {
110 fn from(value: message::ToolCall) -> Self {
111 ToolCall {
112 id: value.id,
113 r#type: ToolType::Function,
114 function: Function {
115 name: value.function.name,
116 arguments: value.function.arguments,
117 },
118 }
119 }
120}
121
/// Wrapper object for an image reference, serialized as `{"url": ...}`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// The image location (a regular URL, or whatever `Image::try_into_url` produced).
    url: String,
}
126
/// One part of a user message, internally tagged by `"type"`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// `{"type": "text", "text": ...}`
    Text {
        text: String,
    },
    /// `{"type": "image_url", "image_url": {"url": ...}}`
    #[serde(rename = "image_url")]
    ImageUrl {
        image_url: ImageUrl,
    },
}
138
139impl FromStr for UserContent {
140 type Err = Infallible;
141
142 fn from_str(s: &str) -> Result<Self, Self::Err> {
143 Ok(UserContent::Text {
144 text: s.to_string(),
145 })
146 }
147}
148
/// One part of an assistant message, internally tagged by `"type"`; only text is modelled.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AssistantContent {
    /// `{"type": "text", "text": ...}`
    Text { text: String },
}
154
155impl FromStr for AssistantContent {
156 type Err = Infallible;
157
158 fn from_str(s: &str) -> Result<Self, Self::Err> {
159 Ok(AssistantContent::Text {
160 text: s.to_string(),
161 })
162 }
163}
164
/// One part of a system message, internally tagged by `"type"`; only text is modelled.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SystemContent {
    /// `{"type": "text", "text": ...}`
    Text { text: String },
}
170
171impl FromStr for SystemContent {
172 type Err = Infallible;
173
174 fn from_str(s: &str) -> Result<Self, Self::Err> {
175 Ok(SystemContent::Text {
176 text: s.to_string(),
177 })
178 }
179}
180
181impl From<UserContent> for message::UserContent {
182 fn from(value: UserContent) -> Self {
183 match value {
184 UserContent::Text { text } => message::UserContent::text(text),
185 UserContent::ImageUrl { image_url } => {
186 message::UserContent::image_url(image_url.url, None, None)
187 }
188 }
189 }
190}
191
192impl TryFrom<message::UserContent> for UserContent {
193 type Error = message::MessageError;
194
195 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
196 match content {
197 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
198 message::UserContent::Document(message::Document {
199 data: message::DocumentSourceKind::Raw(raw),
200 ..
201 }) => {
202 let text = String::from_utf8_lossy(raw.as_slice()).into();
203 Ok(UserContent::Text { text })
204 }
205 message::UserContent::Document(message::Document {
206 data:
207 message::DocumentSourceKind::Base64(text)
208 | message::DocumentSourceKind::String(text),
209 ..
210 }) => Ok(UserContent::Text { text }),
211 message::UserContent::Image(message::Image { data, .. }) => match data {
212 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
213 image_url: ImageUrl { url },
214 }),
215 _ => Err(message::MessageError::ConversionError(
216 "Huggingface only supports images as urls".into(),
217 )),
218 },
219 _ => Err(message::MessageError::ConversionError(
220 "Huggingface only supports text and images".into(),
221 )),
222 }
223 }
224}
225
/// A chat message in Hugging Face's (OpenAI-style) schema, internally tagged by `"role"`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "system"`. `content` may arrive as a plain string or as one/many parts.
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<SystemContent>,
    },
    /// `role: "user"`. `content` may arrive as a plain string or as one/many parts.
    User {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<UserContent>,
    },
    /// `role: "assistant"`. Both fields default to empty when absent.
    Assistant {
        // Accepts a plain string or a list of parts.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<AssistantContent>,
        // `null` deserializes to an empty vector (see the `"tool_calls": null` test case).
        #[serde(default, deserialize_with = "json_utils::null_or_vec")]
        tool_calls: Vec<ToolCall>,
    },
    /// `role: "tool"` (`"Tool"` is also accepted on input): the result of a tool call.
    #[serde(rename = "tool", alias = "Tool")]
    ToolResult {
        /// When converted from a Rig message this holds the tool result's id.
        name: String,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        arguments: Option<serde_json::Value>,
        /// Serialized as a single newline-joined string.
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_tool_content"
        )]
        content: OneOrMany<String>,
    },
}
255
256fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
257where
258 S: Serializer,
259{
260 let joined = content
262 .iter()
263 .map(String::as_str)
264 .collect::<Vec<_>>()
265 .join("\n");
266 serializer.serialize_str(&joined)
267}
268
269impl Message {
270 pub fn system(content: &str) -> Self {
271 Message::System {
272 content: OneOrMany::one(SystemContent::Text {
273 text: content.to_string(),
274 }),
275 }
276 }
277}
278
279impl TryFrom<message::Message> for Vec<Message> {
280 type Error = message::MessageError;
281
282 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
283 match message {
284 message::Message::System { content } => Ok(vec![Message::system(&content)]),
285 message::Message::User { content } => {
286 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
287 .into_iter()
288 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
289
290 if !tool_results.is_empty() {
291 tool_results
292 .into_iter()
293 .map(|content| match content {
294 message::UserContent::ToolResult(message::ToolResult {
295 id,
296 content,
297 ..
298 }) => Ok::<_, message::MessageError>(Message::ToolResult {
299 name: id,
300 arguments: None,
301 content: content.try_map(|content| match content {
302 message::ToolResultContent::Text(message::Text { text }) => {
303 Ok(text)
304 }
305 _ => Err(message::MessageError::ConversionError(
306 "Tool result content does not support non-text".into(),
307 )),
308 })?,
309 }),
310 _ => unreachable!(),
311 })
312 .collect::<Result<Vec<_>, _>>()
313 } else {
314 let other_content = OneOrMany::many(other_content).expect(
315 "There must be other content here if there were no tool result content",
316 );
317
318 Ok(vec![Message::User {
319 content: other_content.try_map(|content| match content {
320 message::UserContent::Text(text) => {
321 Ok(UserContent::Text { text: text.text })
322 }
323 message::UserContent::Image(image) => {
324 let url = image.try_into_url()?;
325
326 Ok(UserContent::ImageUrl {
327 image_url: ImageUrl { url },
328 })
329 }
330 message::UserContent::Document(message::Document {
331 data: message::DocumentSourceKind::Raw(raw), ..
332 }) => {
333 let text = String::from_utf8_lossy(raw.as_slice()).into();
334 Ok(UserContent::Text { text })
335 }
336 message::UserContent::Document(message::Document {
337 data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
338 }) => {
339 Ok(UserContent::Text { text })
340 }
341 _ => Err(message::MessageError::ConversionError(
342 "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
343 )),
344 })?,
345 }])
346 }
347 }
348 message::Message::Assistant { content, .. } => {
349 let mut text_content = Vec::new();
350 let mut tool_calls = Vec::new();
351
352 for content in content {
353 match content {
354 message::AssistantContent::Text(text) => text_content.push(text),
355 message::AssistantContent::ToolCall(tool_call) => {
356 tool_calls.push(tool_call)
357 }
358 message::AssistantContent::Reasoning(_) => {
359 }
362 message::AssistantContent::Image(_) => {
363 panic!("Image content is not supported on HuggingFace via Rig");
364 }
365 }
366 }
367
368 if text_content.is_empty() && tool_calls.is_empty() {
369 return Ok(vec![]);
370 }
371
372 Ok(vec![Message::Assistant {
373 content: text_content
374 .into_iter()
375 .map(|content| AssistantContent::Text { text: content.text })
376 .collect::<Vec<_>>(),
377 tool_calls: tool_calls
378 .into_iter()
379 .map(|tool_call| tool_call.into())
380 .collect::<Vec<_>>(),
381 }])
382 }
383 }
384 }
385}
386
387impl TryFrom<Message> for message::Message {
388 type Error = message::MessageError;
389
390 fn try_from(message: Message) -> Result<Self, Self::Error> {
391 Ok(match message {
392 Message::User { content, .. } => message::Message::User {
393 content: content.map(|content| content.into()),
394 },
395 Message::Assistant {
396 content,
397 tool_calls,
398 ..
399 } => {
400 let mut content = content
401 .into_iter()
402 .map(|content| match content {
403 AssistantContent::Text { text } => message::AssistantContent::text(text),
404 })
405 .collect::<Vec<_>>();
406
407 content.extend(
408 tool_calls
409 .into_iter()
410 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
411 .collect::<Result<Vec<_>, _>>()?,
412 );
413
414 message::Message::Assistant {
415 id: None,
416 content: OneOrMany::many(content).map_err(|_| {
417 message::MessageError::ConversionError(
418 "Neither `content` nor `tool_calls` was provided to the Message"
419 .to_owned(),
420 )
421 })?,
422 }
423 }
424
425 Message::ToolResult { name, content, .. } => message::Message::User {
426 content: OneOrMany::one(message::UserContent::tool_result(
427 name,
428 content.map(message::ToolResultContent::text),
429 )),
430 },
431
432 Message::System { content, .. } => message::Message::User {
435 content: content.map(|c| match c {
436 SystemContent::Text { text } => message::UserContent::text(text),
437 }),
438 },
439 })
440 }
441}
442
/// One completion choice returned by the provider.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Why generation stopped, as reported by the provider (e.g. `"tool_calls"`).
    pub finish_reason: String,
    /// Index the provider reported for this choice.
    pub index: usize,
    /// Raw log-probability payload; `Value::Null` when the field is absent.
    #[serde(default)]
    pub logprobs: serde_json::Value,
    /// The generated message.
    pub message: Message,
}
451
/// Token accounting as reported by the provider.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct Usage {
    /// Tokens generated in the response.
    pub completion_tokens: i32,
    /// Tokens consumed by the prompt.
    pub prompt_tokens: i32,
    /// Total tokens as reported by the provider.
    pub total_tokens: i32,
}
458
459impl GetTokenUsage for Usage {
460 fn token_usage(&self) -> Option<crate::completion::Usage> {
461 let mut usage = crate::completion::Usage::new();
462 usage.input_tokens = self.prompt_tokens as u64;
463 usage.output_tokens = self.completion_tokens as u64;
464 usage.total_tokens = self.total_tokens as u64;
465
466 Some(usage)
467 }
468}
469
/// A (non-streaming) chat completion response body.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Creation time as sent by the provider (presumably a Unix timestamp — TODO confirm).
    pub created: i32,
    /// Provider response id.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<Choice>,
    /// Empty string when the field is absent or `null`.
    #[serde(default, deserialize_with = "default_string_on_null")]
    pub system_fingerprint: String,
    /// Token accounting for the request.
    pub usage: Usage,
}
480
481impl crate::telemetry::ProviderResponseExt for CompletionResponse {
482 type OutputMessage = Choice;
483 type Usage = Usage;
484
485 fn get_response_id(&self) -> Option<String> {
486 Some(self.id.clone())
487 }
488
489 fn get_response_model_name(&self) -> Option<String> {
490 Some(self.model.clone())
491 }
492
493 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
494 self.choices.clone()
495 }
496
497 fn get_text_response(&self) -> Option<String> {
498 let text_response = self
499 .choices
500 .iter()
501 .filter_map(|x| {
502 let Message::User { ref content } = x.message else {
503 return None;
504 };
505
506 let text = content
507 .iter()
508 .filter_map(|x| {
509 if let UserContent::Text { text } = x {
510 Some(text.clone())
511 } else {
512 None
513 }
514 })
515 .collect::<Vec<String>>();
516
517 if text.is_empty() {
518 None
519 } else {
520 Some(text.join("\n"))
521 }
522 })
523 .collect::<Vec<String>>()
524 .join("\n");
525
526 if text_response.is_empty() {
527 None
528 } else {
529 Some(text_response)
530 }
531 }
532
533 fn get_usage(&self) -> Option<Self::Usage> {
534 Some(self.usage.clone())
535 }
536}
537
538fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
539where
540 D: Deserializer<'de>,
541{
542 match Option::<String>::deserialize(deserializer)? {
543 Some(value) => Ok(value), None => Ok(String::default()), }
546}
547
548impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
549 type Error = CompletionError;
550
551 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
552 let choice = response.choices.first().ok_or_else(|| {
553 CompletionError::ResponseError("Response contained no choices".to_owned())
554 })?;
555
556 let content = match &choice.message {
557 Message::Assistant {
558 content,
559 tool_calls,
560 ..
561 } => {
562 let mut content = content
563 .iter()
564 .map(|c| match c {
565 AssistantContent::Text { text } => message::AssistantContent::text(text),
566 })
567 .collect::<Vec<_>>();
568
569 content.extend(
570 tool_calls
571 .iter()
572 .map(|call| {
573 completion::AssistantContent::tool_call(
574 &call.id,
575 &call.function.name,
576 call.function.arguments.clone(),
577 )
578 })
579 .collect::<Vec<_>>(),
580 );
581 Ok(content)
582 }
583 _ => Err(CompletionError::ResponseError(
584 "Response did not contain a valid message or tool call".into(),
585 )),
586 }?;
587
588 let choice = OneOrMany::many(content).map_err(|_| {
589 CompletionError::ResponseError(
590 "Response contained no message or tool call (empty)".to_owned(),
591 )
592 })?;
593
594 let usage = completion::Usage {
595 input_tokens: response.usage.prompt_tokens as u64,
596 output_tokens: response.usage.completion_tokens as u64,
597 total_tokens: response.usage.total_tokens as u64,
598 cached_input_tokens: 0,
599 cache_creation_input_tokens: 0,
600 };
601
602 Ok(completion::CompletionResponse {
603 choice,
604 usage,
605 raw_response: response,
606 message_id: None,
607 })
608 }
609}
610
/// JSON body of a chat completion request sent to Hugging Face.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HuggingfaceCompletionRequest {
    /// Model identifier sent to the provider.
    model: String,
    /// Full conversation: preamble, documents, then chat history.
    pub messages: Vec<Message>,
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Reuses the OpenAI provider's tool-choice type; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra provider parameters, flattened into the top-level JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
624
625impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
626 type Error = CompletionError;
627
628 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
629 if req.output_schema.is_some() {
630 tracing::warn!("Structured outputs currently not supported for Huggingface");
631 }
632 let model = req.model.clone().unwrap_or_else(|| model.to_string());
633 let mut full_history: Vec<Message> = match &req.preamble {
634 Some(preamble) => vec![Message::system(preamble)],
635 None => vec![],
636 };
637 if let Some(docs) = req.normalized_documents() {
638 let docs: Vec<Message> = docs.try_into()?;
639 full_history.extend(docs);
640 }
641
642 let chat_history: Vec<Message> = req
643 .chat_history
644 .clone()
645 .into_iter()
646 .map(|message| message.try_into())
647 .collect::<Result<Vec<Vec<Message>>, _>>()?
648 .into_iter()
649 .flatten()
650 .collect();
651
652 full_history.extend(chat_history);
653
654 if full_history.is_empty() {
655 return Err(CompletionError::RequestError(
656 std::io::Error::new(
657 std::io::ErrorKind::InvalidInput,
658 "HuggingFace request has no provider-compatible messages after conversion",
659 )
660 .into(),
661 ));
662 }
663
664 let tool_choice = req
665 .tool_choice
666 .clone()
667 .map(crate::providers::openai::completion::ToolChoice::try_from)
668 .transpose()?;
669
670 Ok(Self {
671 model: model.to_string(),
672 messages: full_history,
673 temperature: req.temperature,
674 tools: req
675 .tools
676 .clone()
677 .into_iter()
678 .map(ToolDefinition::from)
679 .collect::<Vec<_>>(),
680 tool_choice,
681 additional_params: req.additional_params,
682 })
683 }
684}
685
/// Hugging Face chat completion model bound to a client.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// HTTP client; also provides the sub-provider used for endpoints and model ids.
    pub(crate) client: Client<T>,
    /// Default model name, used when a request does not override it.
    pub model: String,
}
692
693impl<T> CompletionModel<T> {
694 pub fn new(client: Client<T>, model: &str) -> Self {
695 Self {
696 client,
697 model: model.to_string(),
698 }
699 }
700}
701
702impl<T> completion::CompletionModel for CompletionModel<T>
703where
704 T: HttpClientExt + Clone + 'static,
705{
706 type Response = CompletionResponse;
707 type StreamingResponse = StreamingCompletionResponse;
708
709 type Client = Client<T>;
710
711 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
712 Self::new(client.clone(), &model.into())
713 }
714
715 async fn completion(
716 &self,
717 completion_request: CompletionRequest,
718 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
719 let request_model = completion_request
720 .model
721 .clone()
722 .unwrap_or_else(|| self.model.clone());
723 let span = if tracing::Span::current().is_disabled() {
724 info_span!(
725 target: "rig::completions",
726 "chat",
727 gen_ai.operation.name = "chat",
728 gen_ai.provider.name = "huggingface",
729 gen_ai.request.model = &request_model,
730 gen_ai.system_instructions = &completion_request.preamble,
731 gen_ai.response.id = tracing::field::Empty,
732 gen_ai.response.model = tracing::field::Empty,
733 gen_ai.usage.output_tokens = tracing::field::Empty,
734 gen_ai.usage.input_tokens = tracing::field::Empty,
735 gen_ai.usage.cached_tokens = tracing::field::Empty,
736 )
737 } else {
738 tracing::Span::current()
739 };
740
741 let model = self.client.subprovider().model_identifier(&request_model);
742 let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
743
744 if enabled!(Level::TRACE) {
745 tracing::trace!(
746 target: "rig::completions",
747 "Huggingface completion request: {}",
748 serde_json::to_string_pretty(&request)?
749 );
750 }
751
752 let request = serde_json::to_vec(&request)?;
753
754 let path = self
755 .client
756 .subprovider()
757 .completion_endpoint(&request_model);
758 let request = self
759 .client
760 .post(&path)?
761 .header("Content-Type", "application/json")
762 .body(request)
763 .map_err(|e| CompletionError::HttpError(e.into()))?;
764
765 async move {
766 let response = self.client.send(request).await?;
767
768 if response.status().is_success() {
769 let bytes: Vec<u8> = response.into_body().await?;
770 let text = String::from_utf8_lossy(&bytes);
771
772 tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
773
774 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
775 ApiResponse::Ok(response) => {
776 if enabled!(Level::TRACE) {
777 tracing::trace!(
778 target: "rig::completions",
779 "Huggingface completion response: {}",
780 serde_json::to_string_pretty(&response)?
781 );
782 }
783
784 let span = tracing::Span::current();
785 span.record_token_usage(&response.usage);
786 span.record_response_metadata(&response);
787
788 response.try_into()
789 }
790 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
791 }
792 } else {
793 let status = response.status();
794 let text: Vec<u8> = response.into_body().await?;
795 let text: String = String::from_utf8_lossy(&text).into();
796
797 Err(CompletionError::ProviderError(format!(
798 "{}: {}",
799 status, text
800 )))
801 }
802 }
803 .instrument(span)
804 .await
805 }
806
807 async fn stream(
808 &self,
809 request: CompletionRequest,
810 ) -> Result<
811 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
812 CompletionError,
813 > {
814 CompletionModel::stream(self, request).await
815 }
816}
817
// Unit tests for request construction, response/message deserialization, and the
// Rig <-> Hugging Face message conversions defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    // A model set on the request overrides the model passed to `try_from`.
    #[test]
    fn test_huggingface_request_uses_request_model_override() {
        let request = CompletionRequest {
            model: Some("meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()),
            preamble: None,
            chat_history: crate::OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
    }

    // Without a request-level model, the model passed to `try_from` is used.
    #[test]
    fn test_huggingface_request_uses_default_model_when_override_unset() {
        let request = CompletionRequest {
            model: None,
            preamble: None,
            chat_history: crate::OneOrMany::one("Hello".into()),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        };

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "mistralai/Mistral-7B");
    }

    // `Message` accepts these wire shapes:
    // - string assistant content;
    // - list assistant content with `tool_calls: null`;
    // - `content: null` with inline-JSON tool arguments;
    // - mixed text/image user content.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        // `serde_path_to_error` reports the JSON path of any failure in the panic message.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let assistant_message3: Message = {
            let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
                &mut serde_json::Deserializer::from_str(assistant_message_json3);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        match assistant_message {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "\n\nHello there, how may I assist you today?".to_string()
                    }
                );

                // `"tool_calls": null` becomes an empty vector.
                assert_eq!(tool_calls, vec![]);
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // `"content": null` becomes an empty vector.
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content, .. } => {
                let (first, second) = {
                    let mut iter = content.into_iter();
                    (iter.next().unwrap(), iter.next().unwrap())
                };
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
            }
            _ => panic!("Expected user message"),
        }
    }

    // Rig -> Hugging Face -> Rig round-trips simple user and assistant text messages.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content, .. } => {
                assert_eq!(
                    content.first(),
                    UserContent::Text {
                        text: "Hello".to_string()
                    }
                );
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => {
                assert_eq!(
                    content[0],
                    AssistantContent::Text {
                        text: "Hi there!".to_string()
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    // Hugging Face -> Rig -> Hugging Face round-trips simple user and assistant text messages.
    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    // Full responses from two sub-providers (Fireworks and Novita payload shapes) deserialize,
    // including stringified tool arguments, unknown extra fields, and a null fingerprint.
    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": 2, \"y\": 5}",
                                    "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let _firework_response: CompletionResponse = {
            let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };

        let _novita_response: CompletionResponse = {
            let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!(
                    "Deserialization error at {} ({}:{}): {}",
                    err.path(),
                    err.inner().line(),
                    err.inner().column(),
                    err
                );
            })
        };
    }

    // An assistant message with only reasoning converts to no messages at all.
    #[test]
    fn test_assistant_reasoning_is_silently_skipped() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }

    // Reasoning is dropped while text and tool calls in the same message are kept.
    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    // A request whose history converts to zero messages is rejected with `RequestError`.
    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        let request = completion::CompletionRequest {
            preamble: None,
            chat_history: OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            model: None,
            output_schema: None,
        };

        let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}