1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7 OneOrMany,
8 completion::{self, CompletionError, CompletionRequest},
9 json_utils,
10 message::{self},
11 one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22 Ok(T),
23 Err(Value),
24}
25
/// Model id `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model id `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model id `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model id `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model id `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// Model id `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model id `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51 name: String,
52 #[serde(
53 serialize_with = "json_utils::stringified_json::serialize",
54 deserialize_with = "json_utils::stringified_json::deserialize_maybe_stringified"
55 )]
56 pub arguments: serde_json::Value,
57}
58
59impl From<Function> for message::ToolFunction {
60 fn from(value: Function) -> Self {
61 message::ToolFunction {
62 name: value.name,
63 arguments: value.arguments,
64 }
65 }
66}
67
68#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
69#[serde(rename_all = "lowercase")]
70pub enum ToolType {
71 #[default]
72 Function,
73}
74
75#[derive(Debug, Deserialize, Serialize, Clone)]
76pub struct ToolDefinition {
77 pub r#type: String,
78 pub function: completion::ToolDefinition,
79}
80
81impl From<completion::ToolDefinition> for ToolDefinition {
82 fn from(tool: completion::ToolDefinition) -> Self {
83 Self {
84 r#type: "function".into(),
85 function: tool,
86 }
87 }
88}
89
90#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
91pub struct ToolCall {
92 pub id: String,
93 pub r#type: ToolType,
94 pub function: Function,
95}
96
97impl From<ToolCall> for message::ToolCall {
98 fn from(value: ToolCall) -> Self {
99 message::ToolCall {
100 id: value.id,
101 call_id: None,
102 function: value.function.into(),
103 signature: None,
104 additional_params: None,
105 }
106 }
107}
108
109impl From<message::ToolCall> for ToolCall {
110 fn from(value: message::ToolCall) -> Self {
111 ToolCall {
112 id: value.id,
113 r#type: ToolType::Function,
114 function: Function {
115 name: value.function.name,
116 arguments: value.function.arguments,
117 },
118 }
119 }
120}
121
122#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
123pub struct ImageUrl {
124 url: String,
125}
126
127#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
128#[serde(tag = "type", rename_all = "lowercase")]
129pub enum UserContent {
130 Text {
131 text: String,
132 },
133 #[serde(rename = "image_url")]
134 ImageUrl {
135 image_url: ImageUrl,
136 },
137}
138
139impl FromStr for UserContent {
140 type Err = Infallible;
141
142 fn from_str(s: &str) -> Result<Self, Self::Err> {
143 Ok(UserContent::Text {
144 text: s.to_string(),
145 })
146 }
147}
148
149#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
150#[serde(tag = "type", rename_all = "lowercase")]
151pub enum AssistantContent {
152 Text { text: String },
153}
154
155impl FromStr for AssistantContent {
156 type Err = Infallible;
157
158 fn from_str(s: &str) -> Result<Self, Self::Err> {
159 Ok(AssistantContent::Text {
160 text: s.to_string(),
161 })
162 }
163}
164
165#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
166#[serde(tag = "type", rename_all = "lowercase")]
167pub enum SystemContent {
168 Text { text: String },
169}
170
171impl FromStr for SystemContent {
172 type Err = Infallible;
173
174 fn from_str(s: &str) -> Result<Self, Self::Err> {
175 Ok(SystemContent::Text {
176 text: s.to_string(),
177 })
178 }
179}
180
181impl From<UserContent> for message::UserContent {
182 fn from(value: UserContent) -> Self {
183 match value {
184 UserContent::Text { text } => message::UserContent::text(text),
185 UserContent::ImageUrl { image_url } => {
186 message::UserContent::image_url(image_url.url, None, None)
187 }
188 }
189 }
190}
191
192impl TryFrom<message::UserContent> for UserContent {
193 type Error = message::MessageError;
194
195 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
196 match content {
197 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
198 message::UserContent::Document(message::Document {
199 data: message::DocumentSourceKind::Raw(raw),
200 ..
201 }) => {
202 let text = String::from_utf8_lossy(raw.as_slice()).into();
203 Ok(UserContent::Text { text })
204 }
205 message::UserContent::Document(message::Document {
206 data:
207 message::DocumentSourceKind::Base64(text)
208 | message::DocumentSourceKind::String(text),
209 ..
210 }) => Ok(UserContent::Text { text }),
211 message::UserContent::Image(message::Image { data, .. }) => match data {
212 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
213 image_url: ImageUrl { url },
214 }),
215 _ => Err(message::MessageError::ConversionError(
216 "Huggingface only supports images as urls".into(),
217 )),
218 },
219 _ => Err(message::MessageError::ConversionError(
220 "Huggingface only supports text and images".into(),
221 )),
222 }
223 }
224}
225
226#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
227#[serde(tag = "role", rename_all = "lowercase")]
228pub enum Message {
229 System {
230 #[serde(deserialize_with = "string_or_one_or_many")]
231 content: OneOrMany<SystemContent>,
232 },
233 User {
234 #[serde(deserialize_with = "string_or_one_or_many")]
235 content: OneOrMany<UserContent>,
236 },
237 Assistant {
238 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
239 content: Vec<AssistantContent>,
240 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
241 tool_calls: Vec<ToolCall>,
242 },
243 #[serde(rename = "tool", alias = "Tool")]
244 ToolResult {
245 name: String,
246 #[serde(skip_serializing_if = "Option::is_none")]
247 arguments: Option<serde_json::Value>,
248 #[serde(
249 deserialize_with = "string_or_one_or_many",
250 serialize_with = "serialize_tool_content"
251 )]
252 content: OneOrMany<String>,
253 },
254}
255
256fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
257where
258 S: Serializer,
259{
260 let joined = content
262 .iter()
263 .map(String::as_str)
264 .collect::<Vec<_>>()
265 .join("\n");
266 serializer.serialize_str(&joined)
267}
268
269impl Message {
270 pub fn system(content: &str) -> Self {
271 Message::System {
272 content: OneOrMany::one(SystemContent::Text {
273 text: content.to_string(),
274 }),
275 }
276 }
277}
278
279impl TryFrom<message::Message> for Vec<Message> {
280 type Error = message::MessageError;
281
282 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
283 match message {
284 message::Message::System { content } => Ok(vec![Message::system(&content)]),
285 message::Message::User { content } => {
286 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
287 .into_iter()
288 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
289
290 if !tool_results.is_empty() {
291 tool_results
292 .into_iter()
293 .map(|content| match content {
294 message::UserContent::ToolResult(message::ToolResult {
295 id,
296 content,
297 ..
298 }) => Ok::<_, message::MessageError>(Message::ToolResult {
299 name: id,
300 arguments: None,
301 content: content.try_map(|content| match content {
302 message::ToolResultContent::Text(message::Text { text }) => {
303 Ok(text)
304 }
305 _ => Err(message::MessageError::ConversionError(
306 "Tool result content does not support non-text".into(),
307 )),
308 })?,
309 }),
310 _ => Err(message::MessageError::ConversionError(
311 "expected tool result content while converting HuggingFace input"
312 .into(),
313 )),
314 })
315 .collect::<Result<Vec<_>, _>>()
316 } else {
317 let other_content = OneOrMany::many(other_content).map_err(|_| {
318 message::MessageError::ConversionError(
319 "HuggingFace user message did not contain any non-tool content".into(),
320 )
321 })?;
322
323 Ok(vec![Message::User {
324 content: other_content.try_map(|content| match content {
325 message::UserContent::Text(text) => {
326 Ok(UserContent::Text { text: text.text })
327 }
328 message::UserContent::Image(image) => {
329 let url = image.try_into_url()?;
330
331 Ok(UserContent::ImageUrl {
332 image_url: ImageUrl { url },
333 })
334 }
335 message::UserContent::Document(message::Document {
336 data: message::DocumentSourceKind::Raw(raw), ..
337 }) => {
338 let text = String::from_utf8_lossy(raw.as_slice()).into();
339 Ok(UserContent::Text { text })
340 }
341 message::UserContent::Document(message::Document {
342 data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
343 }) => {
344 Ok(UserContent::Text { text })
345 }
346 _ => Err(message::MessageError::ConversionError(
347 "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
348 )),
349 })?,
350 }])
351 }
352 }
353 message::Message::Assistant { content, .. } => {
354 let mut text_content = Vec::new();
355 let mut tool_calls = Vec::new();
356
357 for content in content {
358 match content {
359 message::AssistantContent::Text(text) => text_content.push(text),
360 message::AssistantContent::ToolCall(tool_call) => {
361 tool_calls.push(tool_call)
362 }
363 message::AssistantContent::Reasoning(_) => {
364 }
367 message::AssistantContent::Image(_) => {
368 return Err(message::MessageError::ConversionError(
369 "HuggingFace assistant messages do not support image content"
370 .into(),
371 ));
372 }
373 }
374 }
375
376 if text_content.is_empty() && tool_calls.is_empty() {
377 return Ok(vec![]);
378 }
379
380 Ok(vec![Message::Assistant {
381 content: text_content
382 .into_iter()
383 .map(|content| AssistantContent::Text { text: content.text })
384 .collect::<Vec<_>>(),
385 tool_calls: tool_calls
386 .into_iter()
387 .map(|tool_call| tool_call.into())
388 .collect::<Vec<_>>(),
389 }])
390 }
391 }
392 }
393}
394
395impl TryFrom<Message> for message::Message {
396 type Error = message::MessageError;
397
398 fn try_from(message: Message) -> Result<Self, Self::Error> {
399 Ok(match message {
400 Message::User { content, .. } => message::Message::User {
401 content: content.map(|content| content.into()),
402 },
403 Message::Assistant {
404 content,
405 tool_calls,
406 ..
407 } => {
408 let mut content = content
409 .into_iter()
410 .map(|content| match content {
411 AssistantContent::Text { text } => message::AssistantContent::text(text),
412 })
413 .collect::<Vec<_>>();
414
415 content.extend(
416 tool_calls
417 .into_iter()
418 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
419 .collect::<Result<Vec<_>, _>>()?,
420 );
421
422 message::Message::Assistant {
423 id: None,
424 content: OneOrMany::many(content).map_err(|_| {
425 message::MessageError::ConversionError(
426 "Neither `content` nor `tool_calls` was provided to the Message"
427 .to_owned(),
428 )
429 })?,
430 }
431 }
432
433 Message::ToolResult { name, content, .. } => message::Message::User {
434 content: OneOrMany::one(message::UserContent::tool_result(
435 name,
436 content.map(message::ToolResultContent::text),
437 )),
438 },
439
440 Message::System { content, .. } => message::Message::User {
443 content: content.map(|c| match c {
444 SystemContent::Text { text } => message::UserContent::text(text),
445 }),
446 },
447 })
448 }
449}
450
451#[derive(Clone, Debug, Deserialize, Serialize)]
452pub struct Choice {
453 pub finish_reason: String,
454 pub index: usize,
455 #[serde(default)]
456 pub logprobs: serde_json::Value,
457 pub message: Message,
458}
459
460#[derive(Debug, Deserialize, Clone, Serialize)]
461pub struct Usage {
462 pub completion_tokens: i32,
463 pub prompt_tokens: i32,
464 pub total_tokens: i32,
465}
466
467impl GetTokenUsage for Usage {
468 fn token_usage(&self) -> Option<crate::completion::Usage> {
469 let mut usage = crate::completion::Usage::new();
470 usage.input_tokens = self.prompt_tokens as u64;
471 usage.output_tokens = self.completion_tokens as u64;
472 usage.total_tokens = self.total_tokens as u64;
473
474 Some(usage)
475 }
476}
477
478#[derive(Clone, Debug, Deserialize, Serialize)]
479pub struct CompletionResponse {
480 pub created: i32,
481 pub id: String,
482 pub model: String,
483 pub choices: Vec<Choice>,
484 #[serde(default, deserialize_with = "default_string_on_null")]
485 pub system_fingerprint: String,
486 pub usage: Usage,
487}
488
489impl crate::telemetry::ProviderResponseExt for CompletionResponse {
490 type OutputMessage = Choice;
491 type Usage = Usage;
492
493 fn get_response_id(&self) -> Option<String> {
494 Some(self.id.clone())
495 }
496
497 fn get_response_model_name(&self) -> Option<String> {
498 Some(self.model.clone())
499 }
500
501 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
502 self.choices.clone()
503 }
504
505 fn get_text_response(&self) -> Option<String> {
506 let text_response = self
507 .choices
508 .iter()
509 .filter_map(|x| {
510 let Message::User { ref content } = x.message else {
511 return None;
512 };
513
514 let text = content
515 .iter()
516 .filter_map(|x| {
517 if let UserContent::Text { text } = x {
518 Some(text.clone())
519 } else {
520 None
521 }
522 })
523 .collect::<Vec<String>>();
524
525 if text.is_empty() {
526 None
527 } else {
528 Some(text.join("\n"))
529 }
530 })
531 .collect::<Vec<String>>()
532 .join("\n");
533
534 if text_response.is_empty() {
535 None
536 } else {
537 Some(text_response)
538 }
539 }
540
541 fn get_usage(&self) -> Option<Self::Usage> {
542 Some(self.usage.clone())
543 }
544}
545
546fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
547where
548 D: Deserializer<'de>,
549{
550 match Option::<String>::deserialize(deserializer)? {
551 Some(value) => Ok(value), None => Ok(String::default()), }
554}
555
556impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
557 type Error = CompletionError;
558
559 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
560 let choice = response.choices.first().ok_or_else(|| {
561 CompletionError::ResponseError("Response contained no choices".to_owned())
562 })?;
563
564 let content = match &choice.message {
565 Message::Assistant {
566 content,
567 tool_calls,
568 ..
569 } => {
570 let mut content = content
571 .iter()
572 .map(|c| match c {
573 AssistantContent::Text { text } => message::AssistantContent::text(text),
574 })
575 .collect::<Vec<_>>();
576
577 content.extend(
578 tool_calls
579 .iter()
580 .map(|call| {
581 completion::AssistantContent::tool_call(
582 &call.id,
583 &call.function.name,
584 call.function.arguments.clone(),
585 )
586 })
587 .collect::<Vec<_>>(),
588 );
589 Ok(content)
590 }
591 _ => Err(CompletionError::ResponseError(
592 "Response did not contain a valid message or tool call".into(),
593 )),
594 }?;
595
596 let choice = OneOrMany::many(content).map_err(|_| {
597 CompletionError::ResponseError(
598 "Response contained no message or tool call (empty)".to_owned(),
599 )
600 })?;
601
602 let usage = completion::Usage {
603 input_tokens: response.usage.prompt_tokens as u64,
604 output_tokens: response.usage.completion_tokens as u64,
605 total_tokens: response.usage.total_tokens as u64,
606 cached_input_tokens: 0,
607 cache_creation_input_tokens: 0,
608 };
609
610 Ok(completion::CompletionResponse {
611 choice,
612 usage,
613 raw_response: response,
614 message_id: None,
615 })
616 }
617}
618
619#[derive(Debug, Serialize, Deserialize)]
620pub(super) struct HuggingfaceCompletionRequest {
621 model: String,
622 pub messages: Vec<Message>,
623 #[serde(skip_serializing_if = "Option::is_none")]
624 temperature: Option<f64>,
625 #[serde(skip_serializing_if = "Vec::is_empty")]
626 tools: Vec<ToolDefinition>,
627 #[serde(skip_serializing_if = "Option::is_none")]
628 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
629 #[serde(flatten, skip_serializing_if = "Option::is_none")]
630 pub additional_params: Option<serde_json::Value>,
631}
632
633impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
634 type Error = CompletionError;
635
636 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
637 if req.output_schema.is_some() {
638 tracing::warn!("Structured outputs currently not supported for Huggingface");
639 }
640 let model = req.model.clone().unwrap_or_else(|| model.to_string());
641 let mut full_history: Vec<Message> = match &req.preamble {
642 Some(preamble) => vec![Message::system(preamble)],
643 None => vec![],
644 };
645 if let Some(docs) = req.normalized_documents() {
646 let docs: Vec<Message> = docs.try_into()?;
647 full_history.extend(docs);
648 }
649
650 let chat_history: Vec<Message> = req
651 .chat_history
652 .clone()
653 .into_iter()
654 .map(|message| message.try_into())
655 .collect::<Result<Vec<Vec<Message>>, _>>()?
656 .into_iter()
657 .flatten()
658 .collect();
659
660 full_history.extend(chat_history);
661
662 if full_history.is_empty() {
663 return Err(CompletionError::RequestError(
664 std::io::Error::new(
665 std::io::ErrorKind::InvalidInput,
666 "HuggingFace request has no provider-compatible messages after conversion",
667 )
668 .into(),
669 ));
670 }
671
672 let tool_choice = req
673 .tool_choice
674 .clone()
675 .map(crate::providers::openai::completion::ToolChoice::try_from)
676 .transpose()?;
677
678 Ok(Self {
679 model: model.to_string(),
680 messages: full_history,
681 temperature: req.temperature,
682 tools: req
683 .tools
684 .clone()
685 .into_iter()
686 .map(ToolDefinition::from)
687 .collect::<Vec<_>>(),
688 tool_choice,
689 additional_params: req.additional_params,
690 })
691 }
692}
693
694#[derive(Clone)]
695pub struct CompletionModel<T = reqwest::Client> {
696 pub(crate) client: Client<T>,
697 pub model: String,
699}
700
701impl<T> CompletionModel<T> {
702 pub fn new(client: Client<T>, model: &str) -> Self {
703 Self {
704 client,
705 model: model.to_string(),
706 }
707 }
708}
709
710impl<T> completion::CompletionModel for CompletionModel<T>
711where
712 T: HttpClientExt + Clone + 'static,
713{
714 type Response = CompletionResponse;
715 type StreamingResponse = StreamingCompletionResponse;
716
717 type Client = Client<T>;
718
719 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
720 Self::new(client.clone(), &model.into())
721 }
722
723 async fn completion(
724 &self,
725 completion_request: CompletionRequest,
726 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
727 let request_model = completion_request
728 .model
729 .clone()
730 .unwrap_or_else(|| self.model.clone());
731 let span = if tracing::Span::current().is_disabled() {
732 info_span!(
733 target: "rig::completions",
734 "chat",
735 gen_ai.operation.name = "chat",
736 gen_ai.provider.name = "huggingface",
737 gen_ai.request.model = &request_model,
738 gen_ai.system_instructions = &completion_request.preamble,
739 gen_ai.response.id = tracing::field::Empty,
740 gen_ai.response.model = tracing::field::Empty,
741 gen_ai.usage.output_tokens = tracing::field::Empty,
742 gen_ai.usage.input_tokens = tracing::field::Empty,
743 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
744 )
745 } else {
746 tracing::Span::current()
747 };
748
749 let model = self.client.subprovider().model_identifier(&request_model);
750 let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
751
752 if enabled!(Level::TRACE) {
753 tracing::trace!(
754 target: "rig::completions",
755 "Huggingface completion request: {}",
756 serde_json::to_string_pretty(&request)?
757 );
758 }
759
760 let request = serde_json::to_vec(&request)?;
761
762 let path = self
763 .client
764 .subprovider()
765 .completion_endpoint(&request_model);
766 let request = self
767 .client
768 .post(&path)?
769 .header("Content-Type", "application/json")
770 .body(request)
771 .map_err(|e| CompletionError::HttpError(e.into()))?;
772
773 async move {
774 let response = self.client.send(request).await?;
775
776 if response.status().is_success() {
777 let bytes: Vec<u8> = response.into_body().await?;
778 let text = String::from_utf8_lossy(&bytes);
779
780 tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
781
782 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
783 ApiResponse::Ok(response) => {
784 if enabled!(Level::TRACE) {
785 tracing::trace!(
786 target: "rig::completions",
787 "Huggingface completion response: {}",
788 serde_json::to_string_pretty(&response)?
789 );
790 }
791
792 let span = tracing::Span::current();
793 span.record_token_usage(&response.usage);
794 span.record_response_metadata(&response);
795
796 response.try_into()
797 }
798 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
799 }
800 } else {
801 let status = response.status();
802 let text: Vec<u8> = response.into_body().await?;
803 let text: String = String::from_utf8_lossy(&text).into();
804
805 Err(CompletionError::ProviderError(format!(
806 "{}: {}",
807 status, text
808 )))
809 }
810 }
811 .instrument(span)
812 .await
813 }
814
815 async fn stream(
816 &self,
817 request: CompletionRequest,
818 ) -> Result<
819 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
820 CompletionError,
821 > {
822 CompletionModel::stream(self, request).await
823 }
824}
825
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes `json` into `T`. On failure it panics with the JSON path,
    /// line and column of the error, so broken fixtures are easy to locate.
    fn parse_json<T: serde::de::DeserializeOwned>(json: &str) -> T {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!(
                "Deserialization error at {} ({}:{}): {}",
                err.path(),
                err.inner().line(),
                err.inner().column(),
                err
            )
        })
    }

    /// Minimal request with the given model override and chat history; every
    /// other field is empty or `None`.
    fn request_with(
        model: Option<&str>,
        chat_history: OneOrMany<message::Message>,
    ) -> CompletionRequest {
        CompletionRequest {
            model: model.map(str::to_string),
            preamble: None,
            chat_history,
            documents: vec![],
            tools: vec![],
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    #[test]
    fn test_huggingface_request_uses_request_model_override() {
        let request = request_with(
            Some("meta-llama/Meta-Llama-3.1-8B-Instruct"),
            OneOrMany::one("Hello".into()),
        );

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
    }

    #[test]
    fn test_huggingface_request_uses_default_model_when_override_unset() {
        let request = request_with(None, OneOrMany::one("Hello".into()));

        let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
            .expect("request conversion should succeed");
        let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");

        assert_eq!(serialized["model"], "mistralai/Mistral-7B");
    }

    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                }
            ],
            "tool_calls": null
        }
        "#;

        let assistant_message_json3 = r#"
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
                    "type": "function",
                    "function": {
                        "name": "subtract",
                        "arguments": {"x": 2, "y": 5}
                    }
                }
            ],
            "content": null,
            "refusal": null
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    }
                }
            ]
        }
        "#;

        let assistant_message: Message = parse_json(assistant_message_json);
        let assistant_message2: Message = parse_json(assistant_message_json2);
        let assistant_message3: Message = parse_json(assistant_message_json3);
        let user_message: Message = parse_json(user_message_json);

        let greeting = AssistantContent::Text {
            text: "\n\nHello there, how may I assist you today?".to_string(),
        };

        match assistant_message {
            Message::Assistant { content, .. } => assert_eq!(content[0], greeting),
            _ => panic!("Expected assistant message"),
        }

        match assistant_message2 {
            Message::Assistant {
                content,
                tool_calls,
            } => {
                assert_eq!(content[0], greeting);
                assert!(tool_calls.is_empty());
            }
            _ => panic!("Expected assistant message"),
        }

        match assistant_message3 {
            Message::Assistant {
                content,
                tool_calls,
            } => {
                assert!(content.is_empty());
                assert_eq!(
                    tool_calls[0],
                    ToolCall {
                        id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
                        r#type: ToolType::Function,
                        function: Function {
                            name: "subtract".to_string(),
                            arguments: serde_json::json!({"x": 2, "y": 5}),
                        },
                    }
                );
            }
            _ => panic!("Expected assistant message"),
        }

        match user_message {
            Message::User { content } => {
                let mut parts = content.into_iter();
                let (first, second) = (parts.next().unwrap(), parts.next().unwrap());
                assert_eq!(
                    first,
                    UserContent::Text {
                        text: "What's in this image?".to_string()
                    }
                );
                assert_eq!(
                    second,
                    UserContent::ImageUrl {
                        image_url: ImageUrl {
                            url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string()
                        }
                    }
                );
            }
            _ => panic!("Expected user message"),
        }
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message = message::Message::User {
            content: OneOrMany::one(message::UserContent::text("Hello")),
        };

        let assistant_message = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
        };

        let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
        let converted_assistant_message: Vec<Message> =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message[0].clone() {
            Message::User { content } => assert_eq!(
                content.first(),
                UserContent::Text {
                    text: "Hello".to_string()
                }
            ),
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message[0].clone() {
            Message::Assistant { content, .. } => assert_eq!(
                content[0],
                AssistantContent::Text {
                    text: "Hi there!".to_string()
                }
            ),
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: message::Message =
            converted_user_message[0].clone().try_into().unwrap();
        let original_assistant_message: message::Message =
            converted_assistant_message[0].clone().try_into().unwrap();

        assert_eq!(original_user_message, user_message);
        assert_eq!(original_assistant_message, assistant_message);
    }

    #[test]
    fn test_message_from_message_conversion() {
        let user_message = Message::User {
            content: OneOrMany::one(UserContent::Text {
                text: "Hello".to_string(),
            }),
        };

        let assistant_message = Message::Assistant {
            content: vec![AssistantContent::Text {
                text: "Hi there!".to_string(),
            }],
            tool_calls: vec![],
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.first(), message::UserContent::text("Hello"));
            }
            _ => panic!("Expected user message"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(
                    content.first(),
                    message::AssistantContent::text("Hi there!")
                );
            }
            _ => panic!("Expected assistant message"),
        }

        let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
        let original_assistant_message: Vec<Message> =
            converted_assistant_message.try_into().unwrap();

        assert_eq!(original_user_message[0], user_message);
        assert_eq!(original_assistant_message[0], assistant_message);
    }

    #[test]
    fn test_responses() {
        let fireworks_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": 2, \"y\": 5}",
                                    "name": "subtract"
                                },
                                "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
                                "index": 0,
                                "type": "function"
                            }
                        ]
                    }
                }
            ],
            "created": 1740704000,
            "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
            "model": "accounts/fireworks/models/deepseek-v3",
            "object": "chat.completion",
            "usage": {
                "completion_tokens": 26,
                "prompt_tokens": 248,
                "total_tokens": 274
            }
        }
        "#;

        let novita_response_json = r#"
        {
            "choices": [
                {
                    "finish_reason": "tool_calls",
                    "index": 0,
                    "logprobs": null,
                    "message": {
                        "audio": null,
                        "content": null,
                        "function_call": null,
                        "reasoning_content": null,
                        "refusal": null,
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "function": {
                                    "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
                                    "name": "subtract"
                                },
                                "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
                                "type": "function"
                            }
                        ]
                    },
                    "stop_reason": 128008
                }
            ],
            "created": 1740704592,
            "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
            "object": "chat.completion",
            "prompt_logprobs": null,
            "service_tier": null,
            "system_fingerprint": null,
            "usage": {
                "completion_tokens": 28,
                "completion_tokens_details": null,
                "prompt_tokens": 335,
                "prompt_tokens_details": null,
                "total_tokens": 363
            }
        }
        "#;

        let _firework_response: CompletionResponse = parse_json(fireworks_response_json);
        let _novita_response: CompletionResponse = parse_json(novita_response_json);
    }

    #[test]
    fn test_assistant_reasoning_is_silently_skipped() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert!(converted.is_empty());
    }

    #[test]
    fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
        let assistant = message::Message::Assistant {
            id: None,
            content: OneOrMany::many(vec![
                message::AssistantContent::reasoning("hidden"),
                message::AssistantContent::text("visible"),
                message::AssistantContent::tool_call(
                    "call_1",
                    "subtract",
                    serde_json::json!({"x": 2, "y": 1}),
                ),
            ])
            .expect("non-empty assistant content"),
        };

        let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
        assert_eq!(converted.len(), 1);

        match &converted[0] {
            Message::Assistant {
                content,
                tool_calls,
            } => {
                assert_eq!(
                    content,
                    &vec![AssistantContent::Text {
                        text: "visible".to_string()
                    }]
                );
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "call_1");
                assert_eq!(tool_calls[0].function.name, "subtract");
                assert_eq!(
                    tool_calls[0].function.arguments,
                    serde_json::json!({"x": 2, "y": 1})
                );
            }
            _ => panic!("expected assistant message"),
        }
    }

    #[test]
    fn test_request_conversion_errors_when_all_messages_are_filtered() {
        let request = request_with(
            None,
            OneOrMany::one(message::Message::Assistant {
                id: None,
                content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
            }),
        );

        let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
        assert!(matches!(result, Err(CompletionError::RequestError(_))));
    }
}
1331}