1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7 OneOrMany,
8 completion::{self, CompletionError, CompletionRequest},
9 json_utils,
10 message::{self},
11 one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22 Ok(T),
23 Err(Value),
24}
25
/// Model identifier `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model identifier `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model identifier `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model identifier `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model identifier `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// Model identifier `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model identifier `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51 name: String,
52 #[serde(
53 serialize_with = "json_utils::stringified_json::serialize",
54 deserialize_with = "deserialize_arguments"
55 )]
56 pub arguments: serde_json::Value,
57}
58
59fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
60where
61 D: Deserializer<'de>,
62{
63 let value = Value::deserialize(deserializer)?;
64
65 match value {
66 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
67 other => Ok(other),
68 }
69}
70
71impl From<Function> for message::ToolFunction {
72 fn from(value: Function) -> Self {
73 message::ToolFunction {
74 name: value.name,
75 arguments: value.arguments,
76 }
77 }
78}
79
80#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
81#[serde(rename_all = "lowercase")]
82pub enum ToolType {
83 #[default]
84 Function,
85}
86
87#[derive(Debug, Deserialize, Serialize, Clone)]
88pub struct ToolDefinition {
89 pub r#type: String,
90 pub function: completion::ToolDefinition,
91}
92
93impl From<completion::ToolDefinition> for ToolDefinition {
94 fn from(tool: completion::ToolDefinition) -> Self {
95 Self {
96 r#type: "function".into(),
97 function: tool,
98 }
99 }
100}
101
102#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
103pub struct ToolCall {
104 pub id: String,
105 pub r#type: ToolType,
106 pub function: Function,
107}
108
109impl From<ToolCall> for message::ToolCall {
110 fn from(value: ToolCall) -> Self {
111 message::ToolCall {
112 id: value.id,
113 call_id: None,
114 function: value.function.into(),
115 signature: None,
116 additional_params: None,
117 }
118 }
119}
120
121impl From<message::ToolCall> for ToolCall {
122 fn from(value: message::ToolCall) -> Self {
123 ToolCall {
124 id: value.id,
125 r#type: ToolType::Function,
126 function: Function {
127 name: value.function.name,
128 arguments: value.function.arguments,
129 },
130 }
131 }
132}
133
134#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
135pub struct ImageUrl {
136 url: String,
137}
138
139#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
140#[serde(tag = "type", rename_all = "lowercase")]
141pub enum UserContent {
142 Text {
143 text: String,
144 },
145 #[serde(rename = "image_url")]
146 ImageUrl {
147 image_url: ImageUrl,
148 },
149}
150
151impl FromStr for UserContent {
152 type Err = Infallible;
153
154 fn from_str(s: &str) -> Result<Self, Self::Err> {
155 Ok(UserContent::Text {
156 text: s.to_string(),
157 })
158 }
159}
160
161#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
162#[serde(tag = "type", rename_all = "lowercase")]
163pub enum AssistantContent {
164 Text { text: String },
165}
166
167impl FromStr for AssistantContent {
168 type Err = Infallible;
169
170 fn from_str(s: &str) -> Result<Self, Self::Err> {
171 Ok(AssistantContent::Text {
172 text: s.to_string(),
173 })
174 }
175}
176
177#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
178#[serde(tag = "type", rename_all = "lowercase")]
179pub enum SystemContent {
180 Text { text: String },
181}
182
183impl FromStr for SystemContent {
184 type Err = Infallible;
185
186 fn from_str(s: &str) -> Result<Self, Self::Err> {
187 Ok(SystemContent::Text {
188 text: s.to_string(),
189 })
190 }
191}
192
193impl From<UserContent> for message::UserContent {
194 fn from(value: UserContent) -> Self {
195 match value {
196 UserContent::Text { text } => message::UserContent::text(text),
197 UserContent::ImageUrl { image_url } => {
198 message::UserContent::image_url(image_url.url, None, None)
199 }
200 }
201 }
202}
203
204impl TryFrom<message::UserContent> for UserContent {
205 type Error = message::MessageError;
206
207 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
208 match content {
209 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
210 message::UserContent::Document(message::Document {
211 data: message::DocumentSourceKind::Raw(raw),
212 ..
213 }) => {
214 let text = String::from_utf8_lossy(raw.as_slice()).into();
215 Ok(UserContent::Text { text })
216 }
217 message::UserContent::Document(message::Document {
218 data:
219 message::DocumentSourceKind::Base64(text)
220 | message::DocumentSourceKind::String(text),
221 ..
222 }) => Ok(UserContent::Text { text }),
223 message::UserContent::Image(message::Image { data, .. }) => match data {
224 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
225 image_url: ImageUrl { url },
226 }),
227 _ => Err(message::MessageError::ConversionError(
228 "Huggingface only supports images as urls".into(),
229 )),
230 },
231 _ => Err(message::MessageError::ConversionError(
232 "Huggingface only supports text and images".into(),
233 )),
234 }
235 }
236}
237
238#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
239#[serde(tag = "role", rename_all = "lowercase")]
240pub enum Message {
241 System {
242 #[serde(deserialize_with = "string_or_one_or_many")]
243 content: OneOrMany<SystemContent>,
244 },
245 User {
246 #[serde(deserialize_with = "string_or_one_or_many")]
247 content: OneOrMany<UserContent>,
248 },
249 Assistant {
250 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
251 content: Vec<AssistantContent>,
252 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
253 tool_calls: Vec<ToolCall>,
254 },
255 #[serde(rename = "tool", alias = "Tool")]
256 ToolResult {
257 name: String,
258 #[serde(skip_serializing_if = "Option::is_none")]
259 arguments: Option<serde_json::Value>,
260 #[serde(
261 deserialize_with = "string_or_one_or_many",
262 serialize_with = "serialize_tool_content"
263 )]
264 content: OneOrMany<String>,
265 },
266}
267
268fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
269where
270 S: Serializer,
271{
272 let joined = content
274 .iter()
275 .map(String::as_str)
276 .collect::<Vec<_>>()
277 .join("\n");
278 serializer.serialize_str(&joined)
279}
280
281impl Message {
282 pub fn system(content: &str) -> Self {
283 Message::System {
284 content: OneOrMany::one(SystemContent::Text {
285 text: content.to_string(),
286 }),
287 }
288 }
289}
290
291impl TryFrom<message::Message> for Vec<Message> {
292 type Error = message::MessageError;
293
294 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
295 match message {
296 message::Message::System { content } => Ok(vec![Message::system(&content)]),
297 message::Message::User { content } => {
298 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
299 .into_iter()
300 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
301
302 if !tool_results.is_empty() {
303 tool_results
304 .into_iter()
305 .map(|content| match content {
306 message::UserContent::ToolResult(message::ToolResult {
307 id,
308 content,
309 ..
310 }) => Ok::<_, message::MessageError>(Message::ToolResult {
311 name: id,
312 arguments: None,
313 content: content.try_map(|content| match content {
314 message::ToolResultContent::Text(message::Text { text }) => {
315 Ok(text)
316 }
317 _ => Err(message::MessageError::ConversionError(
318 "Tool result content does not support non-text".into(),
319 )),
320 })?,
321 }),
322 _ => unreachable!(),
323 })
324 .collect::<Result<Vec<_>, _>>()
325 } else {
326 let other_content = OneOrMany::many(other_content).expect(
327 "There must be other content here if there were no tool result content",
328 );
329
330 Ok(vec![Message::User {
331 content: other_content.try_map(|content| match content {
332 message::UserContent::Text(text) => {
333 Ok(UserContent::Text { text: text.text })
334 }
335 message::UserContent::Image(image) => {
336 let url = image.try_into_url()?;
337
338 Ok(UserContent::ImageUrl {
339 image_url: ImageUrl { url },
340 })
341 }
342 message::UserContent::Document(message::Document {
343 data: message::DocumentSourceKind::Raw(raw), ..
344 }) => {
345 let text = String::from_utf8_lossy(raw.as_slice()).into();
346 Ok(UserContent::Text { text })
347 }
348 message::UserContent::Document(message::Document {
349 data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
350 }) => {
351 Ok(UserContent::Text { text })
352 }
353 _ => Err(message::MessageError::ConversionError(
354 "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
355 )),
356 })?,
357 }])
358 }
359 }
360 message::Message::Assistant { content, .. } => {
361 let mut text_content = Vec::new();
362 let mut tool_calls = Vec::new();
363
364 for content in content {
365 match content {
366 message::AssistantContent::Text(text) => text_content.push(text),
367 message::AssistantContent::ToolCall(tool_call) => {
368 tool_calls.push(tool_call)
369 }
370 message::AssistantContent::Reasoning(_) => {
371 }
374 message::AssistantContent::Image(_) => {
375 panic!("Image content is not supported on HuggingFace via Rig");
376 }
377 }
378 }
379
380 if text_content.is_empty() && tool_calls.is_empty() {
381 return Ok(vec![]);
382 }
383
384 Ok(vec![Message::Assistant {
385 content: text_content
386 .into_iter()
387 .map(|content| AssistantContent::Text { text: content.text })
388 .collect::<Vec<_>>(),
389 tool_calls: tool_calls
390 .into_iter()
391 .map(|tool_call| tool_call.into())
392 .collect::<Vec<_>>(),
393 }])
394 }
395 }
396 }
397}
398
399impl TryFrom<Message> for message::Message {
400 type Error = message::MessageError;
401
402 fn try_from(message: Message) -> Result<Self, Self::Error> {
403 Ok(match message {
404 Message::User { content, .. } => message::Message::User {
405 content: content.map(|content| content.into()),
406 },
407 Message::Assistant {
408 content,
409 tool_calls,
410 ..
411 } => {
412 let mut content = content
413 .into_iter()
414 .map(|content| match content {
415 AssistantContent::Text { text } => message::AssistantContent::text(text),
416 })
417 .collect::<Vec<_>>();
418
419 content.extend(
420 tool_calls
421 .into_iter()
422 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
423 .collect::<Result<Vec<_>, _>>()?,
424 );
425
426 message::Message::Assistant {
427 id: None,
428 content: OneOrMany::many(content).map_err(|_| {
429 message::MessageError::ConversionError(
430 "Neither `content` nor `tool_calls` was provided to the Message"
431 .to_owned(),
432 )
433 })?,
434 }
435 }
436
437 Message::ToolResult { name, content, .. } => message::Message::User {
438 content: OneOrMany::one(message::UserContent::tool_result(
439 name,
440 content.map(message::ToolResultContent::text),
441 )),
442 },
443
444 Message::System { content, .. } => message::Message::User {
447 content: content.map(|c| match c {
448 SystemContent::Text { text } => message::UserContent::text(text),
449 }),
450 },
451 })
452 }
453}
454
455#[derive(Clone, Debug, Deserialize, Serialize)]
456pub struct Choice {
457 pub finish_reason: String,
458 pub index: usize,
459 #[serde(default)]
460 pub logprobs: serde_json::Value,
461 pub message: Message,
462}
463
464#[derive(Debug, Deserialize, Clone, Serialize)]
465pub struct Usage {
466 pub completion_tokens: i32,
467 pub prompt_tokens: i32,
468 pub total_tokens: i32,
469}
470
471impl GetTokenUsage for Usage {
472 fn token_usage(&self) -> Option<crate::completion::Usage> {
473 let mut usage = crate::completion::Usage::new();
474 usage.input_tokens = self.prompt_tokens as u64;
475 usage.output_tokens = self.completion_tokens as u64;
476 usage.total_tokens = self.total_tokens as u64;
477
478 Some(usage)
479 }
480}
481
482#[derive(Clone, Debug, Deserialize, Serialize)]
483pub struct CompletionResponse {
484 pub created: i32,
485 pub id: String,
486 pub model: String,
487 pub choices: Vec<Choice>,
488 #[serde(default, deserialize_with = "default_string_on_null")]
489 pub system_fingerprint: String,
490 pub usage: Usage,
491}
492
493impl crate::telemetry::ProviderResponseExt for CompletionResponse {
494 type OutputMessage = Choice;
495 type Usage = Usage;
496
497 fn get_response_id(&self) -> Option<String> {
498 Some(self.id.clone())
499 }
500
501 fn get_response_model_name(&self) -> Option<String> {
502 Some(self.model.clone())
503 }
504
505 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
506 self.choices.clone()
507 }
508
509 fn get_text_response(&self) -> Option<String> {
510 let text_response = self
511 .choices
512 .iter()
513 .filter_map(|x| {
514 let Message::User { ref content } = x.message else {
515 return None;
516 };
517
518 let text = content
519 .iter()
520 .filter_map(|x| {
521 if let UserContent::Text { text } = x {
522 Some(text.clone())
523 } else {
524 None
525 }
526 })
527 .collect::<Vec<String>>();
528
529 if text.is_empty() {
530 None
531 } else {
532 Some(text.join("\n"))
533 }
534 })
535 .collect::<Vec<String>>()
536 .join("\n");
537
538 if text_response.is_empty() {
539 None
540 } else {
541 Some(text_response)
542 }
543 }
544
545 fn get_usage(&self) -> Option<Self::Usage> {
546 Some(self.usage.clone())
547 }
548}
549
550fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
551where
552 D: Deserializer<'de>,
553{
554 match Option::<String>::deserialize(deserializer)? {
555 Some(value) => Ok(value), None => Ok(String::default()), }
558}
559
560impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
561 type Error = CompletionError;
562
563 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
564 let choice = response.choices.first().ok_or_else(|| {
565 CompletionError::ResponseError("Response contained no choices".to_owned())
566 })?;
567
568 let content = match &choice.message {
569 Message::Assistant {
570 content,
571 tool_calls,
572 ..
573 } => {
574 let mut content = content
575 .iter()
576 .map(|c| match c {
577 AssistantContent::Text { text } => message::AssistantContent::text(text),
578 })
579 .collect::<Vec<_>>();
580
581 content.extend(
582 tool_calls
583 .iter()
584 .map(|call| {
585 completion::AssistantContent::tool_call(
586 &call.id,
587 &call.function.name,
588 call.function.arguments.clone(),
589 )
590 })
591 .collect::<Vec<_>>(),
592 );
593 Ok(content)
594 }
595 _ => Err(CompletionError::ResponseError(
596 "Response did not contain a valid message or tool call".into(),
597 )),
598 }?;
599
600 let choice = OneOrMany::many(content).map_err(|_| {
601 CompletionError::ResponseError(
602 "Response contained no message or tool call (empty)".to_owned(),
603 )
604 })?;
605
606 let usage = completion::Usage {
607 input_tokens: response.usage.prompt_tokens as u64,
608 output_tokens: response.usage.completion_tokens as u64,
609 total_tokens: response.usage.total_tokens as u64,
610 cached_input_tokens: 0,
611 };
612
613 Ok(completion::CompletionResponse {
614 choice,
615 usage,
616 raw_response: response,
617 message_id: None,
618 })
619 }
620}
621
622#[derive(Debug, Serialize, Deserialize)]
623pub(super) struct HuggingfaceCompletionRequest {
624 model: String,
625 pub messages: Vec<Message>,
626 #[serde(skip_serializing_if = "Option::is_none")]
627 temperature: Option<f64>,
628 #[serde(skip_serializing_if = "Vec::is_empty")]
629 tools: Vec<ToolDefinition>,
630 #[serde(skip_serializing_if = "Option::is_none")]
631 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
632 #[serde(flatten, skip_serializing_if = "Option::is_none")]
633 pub additional_params: Option<serde_json::Value>,
634}
635
636impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
637 type Error = CompletionError;
638
639 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
640 if req.output_schema.is_some() {
641 tracing::warn!("Structured outputs currently not supported for Huggingface");
642 }
643 let model = req.model.clone().unwrap_or_else(|| model.to_string());
644 let mut full_history: Vec<Message> = match &req.preamble {
645 Some(preamble) => vec![Message::system(preamble)],
646 None => vec![],
647 };
648 if let Some(docs) = req.normalized_documents() {
649 let docs: Vec<Message> = docs.try_into()?;
650 full_history.extend(docs);
651 }
652
653 let chat_history: Vec<Message> = req
654 .chat_history
655 .clone()
656 .into_iter()
657 .map(|message| message.try_into())
658 .collect::<Result<Vec<Vec<Message>>, _>>()?
659 .into_iter()
660 .flatten()
661 .collect();
662
663 full_history.extend(chat_history);
664
665 if full_history.is_empty() {
666 return Err(CompletionError::RequestError(
667 std::io::Error::new(
668 std::io::ErrorKind::InvalidInput,
669 "HuggingFace request has no provider-compatible messages after conversion",
670 )
671 .into(),
672 ));
673 }
674
675 let tool_choice = req
676 .tool_choice
677 .clone()
678 .map(crate::providers::openai::completion::ToolChoice::try_from)
679 .transpose()?;
680
681 Ok(Self {
682 model: model.to_string(),
683 messages: full_history,
684 temperature: req.temperature,
685 tools: req
686 .tools
687 .clone()
688 .into_iter()
689 .map(ToolDefinition::from)
690 .collect::<Vec<_>>(),
691 tool_choice,
692 additional_params: req.additional_params,
693 })
694 }
695}
696
697#[derive(Clone)]
698pub struct CompletionModel<T = reqwest::Client> {
699 pub(crate) client: Client<T>,
700 pub model: String,
702}
703
704impl<T> CompletionModel<T> {
705 pub fn new(client: Client<T>, model: &str) -> Self {
706 Self {
707 client,
708 model: model.to_string(),
709 }
710 }
711}
712
713impl<T> completion::CompletionModel for CompletionModel<T>
714where
715 T: HttpClientExt + Clone + 'static,
716{
717 type Response = CompletionResponse;
718 type StreamingResponse = StreamingCompletionResponse;
719
720 type Client = Client<T>;
721
722 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
723 Self::new(client.clone(), &model.into())
724 }
725
726 async fn completion(
727 &self,
728 completion_request: CompletionRequest,
729 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
730 let request_model = completion_request
731 .model
732 .clone()
733 .unwrap_or_else(|| self.model.clone());
734 let span = if tracing::Span::current().is_disabled() {
735 info_span!(
736 target: "rig::completions",
737 "chat",
738 gen_ai.operation.name = "chat",
739 gen_ai.provider.name = "huggingface",
740 gen_ai.request.model = &request_model,
741 gen_ai.system_instructions = &completion_request.preamble,
742 gen_ai.response.id = tracing::field::Empty,
743 gen_ai.response.model = tracing::field::Empty,
744 gen_ai.usage.output_tokens = tracing::field::Empty,
745 gen_ai.usage.input_tokens = tracing::field::Empty,
746 gen_ai.usage.cached_tokens = tracing::field::Empty,
747 )
748 } else {
749 tracing::Span::current()
750 };
751
752 let model = self.client.subprovider().model_identifier(&request_model);
753 let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
754
755 if enabled!(Level::TRACE) {
756 tracing::trace!(
757 target: "rig::completions",
758 "Huggingface completion request: {}",
759 serde_json::to_string_pretty(&request)?
760 );
761 }
762
763 let request = serde_json::to_vec(&request)?;
764
765 let path = self
766 .client
767 .subprovider()
768 .completion_endpoint(&request_model);
769 let request = self
770 .client
771 .post(&path)?
772 .header("Content-Type", "application/json")
773 .body(request)
774 .map_err(|e| CompletionError::HttpError(e.into()))?;
775
776 async move {
777 let response = self.client.send(request).await?;
778
779 if response.status().is_success() {
780 let bytes: Vec<u8> = response.into_body().await?;
781 let text = String::from_utf8_lossy(&bytes);
782
783 tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
784
785 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
786 ApiResponse::Ok(response) => {
787 if enabled!(Level::TRACE) {
788 tracing::trace!(
789 target: "rig::completions",
790 "Huggingface completion response: {}",
791 serde_json::to_string_pretty(&response)?
792 );
793 }
794
795 let span = tracing::Span::current();
796 span.record_token_usage(&response.usage);
797 span.record_response_metadata(&response);
798
799 response.try_into()
800 }
801 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
802 }
803 } else {
804 let status = response.status();
805 let text: Vec<u8> = response.into_body().await?;
806 let text: String = String::from_utf8_lossy(&text).into();
807
808 Err(CompletionError::ProviderError(format!(
809 "{}: {}",
810 status, text
811 )))
812 }
813 }
814 .instrument(span)
815 .await
816 }
817
818 async fn stream(
819 &self,
820 request: CompletionRequest,
821 ) -> Result<
822 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
823 CompletionError,
824 > {
825 CompletionModel::stream(self, request).await
826 }
827}
828
829#[cfg(test)]
830mod tests {
831 use super::*;
832 use serde_path_to_error::deserialize;
833
834 #[test]
835 fn test_huggingface_request_uses_request_model_override() {
836 let request = CompletionRequest {
837 model: Some("meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()),
838 preamble: None,
839 chat_history: crate::OneOrMany::one("Hello".into()),
840 documents: vec![],
841 tools: vec![],
842 temperature: None,
843 max_tokens: None,
844 tool_choice: None,
845 additional_params: None,
846 output_schema: None,
847 };
848
849 let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
850 .expect("request conversion should succeed");
851 let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
852
853 assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
854 }
855
856 #[test]
857 fn test_huggingface_request_uses_default_model_when_override_unset() {
858 let request = CompletionRequest {
859 model: None,
860 preamble: None,
861 chat_history: crate::OneOrMany::one("Hello".into()),
862 documents: vec![],
863 tools: vec![],
864 temperature: None,
865 max_tokens: None,
866 tool_choice: None,
867 additional_params: None,
868 output_schema: None,
869 };
870
871 let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
872 .expect("request conversion should succeed");
873 let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
874
875 assert_eq!(serialized["model"], "mistralai/Mistral-7B");
876 }
877
878 #[test]
879 fn test_deserialize_message() {
880 let assistant_message_json = r#"
881 {
882 "role": "assistant",
883 "content": "\n\nHello there, how may I assist you today?"
884 }
885 "#;
886
887 let assistant_message_json2 = r#"
888 {
889 "role": "assistant",
890 "content": [
891 {
892 "type": "text",
893 "text": "\n\nHello there, how may I assist you today?"
894 }
895 ],
896 "tool_calls": null
897 }
898 "#;
899
900 let assistant_message_json3 = r#"
901 {
902 "role": "assistant",
903 "tool_calls": [
904 {
905 "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
906 "type": "function",
907 "function": {
908 "name": "subtract",
909 "arguments": {"x": 2, "y": 5}
910 }
911 }
912 ],
913 "content": null,
914 "refusal": null
915 }
916 "#;
917
918 let user_message_json = r#"
919 {
920 "role": "user",
921 "content": [
922 {
923 "type": "text",
924 "text": "What's in this image?"
925 },
926 {
927 "type": "image_url",
928 "image_url": {
929 "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
930 }
931 }
932 ]
933 }
934 "#;
935
936 let assistant_message: Message = {
937 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
938 deserialize(jd).unwrap_or_else(|err| {
939 panic!(
940 "Deserialization error at {} ({}:{}): {}",
941 err.path(),
942 err.inner().line(),
943 err.inner().column(),
944 err
945 );
946 })
947 };
948
949 let assistant_message2: Message = {
950 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
951 deserialize(jd).unwrap_or_else(|err| {
952 panic!(
953 "Deserialization error at {} ({}:{}): {}",
954 err.path(),
955 err.inner().line(),
956 err.inner().column(),
957 err
958 );
959 })
960 };
961
962 let assistant_message3: Message = {
963 let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
964 &mut serde_json::Deserializer::from_str(assistant_message_json3);
965 deserialize(jd).unwrap_or_else(|err| {
966 panic!(
967 "Deserialization error at {} ({}:{}): {}",
968 err.path(),
969 err.inner().line(),
970 err.inner().column(),
971 err
972 );
973 })
974 };
975
976 let user_message: Message = {
977 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
978 deserialize(jd).unwrap_or_else(|err| {
979 panic!(
980 "Deserialization error at {} ({}:{}): {}",
981 err.path(),
982 err.inner().line(),
983 err.inner().column(),
984 err
985 );
986 })
987 };
988
989 match assistant_message {
990 Message::Assistant { content, .. } => {
991 assert_eq!(
992 content[0],
993 AssistantContent::Text {
994 text: "\n\nHello there, how may I assist you today?".to_string()
995 }
996 );
997 }
998 _ => panic!("Expected assistant message"),
999 }
1000
1001 match assistant_message2 {
1002 Message::Assistant {
1003 content,
1004 tool_calls,
1005 ..
1006 } => {
1007 assert_eq!(
1008 content[0],
1009 AssistantContent::Text {
1010 text: "\n\nHello there, how may I assist you today?".to_string()
1011 }
1012 );
1013
1014 assert_eq!(tool_calls, vec![]);
1015 }
1016 _ => panic!("Expected assistant message"),
1017 }
1018
1019 match assistant_message3 {
1020 Message::Assistant {
1021 content,
1022 tool_calls,
1023 ..
1024 } => {
1025 assert!(content.is_empty());
1026 assert_eq!(
1027 tool_calls[0],
1028 ToolCall {
1029 id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
1030 r#type: ToolType::Function,
1031 function: Function {
1032 name: "subtract".to_string(),
1033 arguments: serde_json::json!({"x": 2, "y": 5}),
1034 },
1035 }
1036 );
1037 }
1038 _ => panic!("Expected assistant message"),
1039 }
1040
1041 match user_message {
1042 Message::User { content, .. } => {
1043 let (first, second) = {
1044 let mut iter = content.into_iter();
1045 (iter.next().unwrap(), iter.next().unwrap())
1046 };
1047 assert_eq!(
1048 first,
1049 UserContent::Text {
1050 text: "What's in this image?".to_string()
1051 }
1052 );
1053 assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
1054 }
1055 _ => panic!("Expected user message"),
1056 }
1057 }
1058
1059 #[test]
1060 fn test_message_to_message_conversion() {
1061 let user_message = message::Message::User {
1062 content: OneOrMany::one(message::UserContent::text("Hello")),
1063 };
1064
1065 let assistant_message = message::Message::Assistant {
1066 id: None,
1067 content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
1068 };
1069
1070 let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
1071 let converted_assistant_message: Vec<Message> =
1072 assistant_message.clone().try_into().unwrap();
1073
1074 match converted_user_message[0].clone() {
1075 Message::User { content, .. } => {
1076 assert_eq!(
1077 content.first(),
1078 UserContent::Text {
1079 text: "Hello".to_string()
1080 }
1081 );
1082 }
1083 _ => panic!("Expected user message"),
1084 }
1085
1086 match converted_assistant_message[0].clone() {
1087 Message::Assistant { content, .. } => {
1088 assert_eq!(
1089 content[0],
1090 AssistantContent::Text {
1091 text: "Hi there!".to_string()
1092 }
1093 );
1094 }
1095 _ => panic!("Expected assistant message"),
1096 }
1097
1098 let original_user_message: message::Message =
1099 converted_user_message[0].clone().try_into().unwrap();
1100 let original_assistant_message: message::Message =
1101 converted_assistant_message[0].clone().try_into().unwrap();
1102
1103 assert_eq!(original_user_message, user_message);
1104 assert_eq!(original_assistant_message, assistant_message);
1105 }
1106
1107 #[test]
1108 fn test_message_from_message_conversion() {
1109 let user_message = Message::User {
1110 content: OneOrMany::one(UserContent::Text {
1111 text: "Hello".to_string(),
1112 }),
1113 };
1114
1115 let assistant_message = Message::Assistant {
1116 content: vec![AssistantContent::Text {
1117 text: "Hi there!".to_string(),
1118 }],
1119 tool_calls: vec![],
1120 };
1121
1122 let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1123 let converted_assistant_message: message::Message =
1124 assistant_message.clone().try_into().unwrap();
1125
1126 match converted_user_message.clone() {
1127 message::Message::User { content } => {
1128 assert_eq!(content.first(), message::UserContent::text("Hello"));
1129 }
1130 _ => panic!("Expected user message"),
1131 }
1132
1133 match converted_assistant_message.clone() {
1134 message::Message::Assistant { content, .. } => {
1135 assert_eq!(
1136 content.first(),
1137 message::AssistantContent::text("Hi there!")
1138 );
1139 }
1140 _ => panic!("Expected assistant message"),
1141 }
1142
1143 let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1144 let original_assistant_message: Vec<Message> =
1145 converted_assistant_message.try_into().unwrap();
1146
1147 assert_eq!(original_user_message[0], user_message);
1148 assert_eq!(original_assistant_message[0], assistant_message);
1149 }
1150
1151 #[test]
1152 fn test_responses() {
1153 let fireworks_response_json = r#"
1154 {
1155 "choices": [
1156 {
1157 "finish_reason": "tool_calls",
1158 "index": 0,
1159 "message": {
1160 "role": "assistant",
1161 "tool_calls": [
1162 {
1163 "function": {
1164 "arguments": "{\"x\": 2, \"y\": 5}",
1165 "name": "subtract"
1166 },
1167 "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
1168 "index": 0,
1169 "type": "function"
1170 }
1171 ]
1172 }
1173 }
1174 ],
1175 "created": 1740704000,
1176 "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
1177 "model": "accounts/fireworks/models/deepseek-v3",
1178 "object": "chat.completion",
1179 "usage": {
1180 "completion_tokens": 26,
1181 "prompt_tokens": 248,
1182 "total_tokens": 274
1183 }
1184 }
1185 "#;
1186
1187 let novita_response_json = r#"
1188 {
1189 "choices": [
1190 {
1191 "finish_reason": "tool_calls",
1192 "index": 0,
1193 "logprobs": null,
1194 "message": {
1195 "audio": null,
1196 "content": null,
1197 "function_call": null,
1198 "reasoning_content": null,
1199 "refusal": null,
1200 "role": "assistant",
1201 "tool_calls": [
1202 {
1203 "function": {
1204 "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
1205 "name": "subtract"
1206 },
1207 "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
1208 "type": "function"
1209 }
1210 ]
1211 },
1212 "stop_reason": 128008
1213 }
1214 ],
1215 "created": 1740704592,
1216 "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
1217 "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
1218 "object": "chat.completion",
1219 "prompt_logprobs": null,
1220 "service_tier": null,
1221 "system_fingerprint": null,
1222 "usage": {
1223 "completion_tokens": 28,
1224 "completion_tokens_details": null,
1225 "prompt_tokens": 335,
1226 "prompt_tokens_details": null,
1227 "total_tokens": 363
1228 }
1229 }
1230 "#;
1231
1232 let _firework_response: CompletionResponse = {
1233 let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
1234 deserialize(jd).unwrap_or_else(|err| {
1235 panic!(
1236 "Deserialization error at {} ({}:{}): {}",
1237 err.path(),
1238 err.inner().line(),
1239 err.inner().column(),
1240 err
1241 );
1242 })
1243 };
1244
1245 let _novita_response: CompletionResponse = {
1246 let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
1247 deserialize(jd).unwrap_or_else(|err| {
1248 panic!(
1249 "Deserialization error at {} ({}:{}): {}",
1250 err.path(),
1251 err.inner().line(),
1252 err.inner().column(),
1253 err
1254 );
1255 })
1256 };
1257 }
1258
1259 #[test]
1260 fn test_assistant_reasoning_is_silently_skipped() {
1261 let assistant = message::Message::Assistant {
1262 id: None,
1263 content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1264 };
1265
1266 let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1267 assert!(converted.is_empty());
1268 }
1269
1270 #[test]
1271 fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
1272 let assistant = message::Message::Assistant {
1273 id: None,
1274 content: OneOrMany::many(vec![
1275 message::AssistantContent::reasoning("hidden"),
1276 message::AssistantContent::text("visible"),
1277 message::AssistantContent::tool_call(
1278 "call_1",
1279 "subtract",
1280 serde_json::json!({"x": 2, "y": 1}),
1281 ),
1282 ])
1283 .expect("non-empty assistant content"),
1284 };
1285
1286 let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1287 assert_eq!(converted.len(), 1);
1288
1289 match &converted[0] {
1290 Message::Assistant {
1291 content,
1292 tool_calls,
1293 ..
1294 } => {
1295 assert_eq!(
1296 content,
1297 &vec![AssistantContent::Text {
1298 text: "visible".to_string()
1299 }]
1300 );
1301 assert_eq!(tool_calls.len(), 1);
1302 assert_eq!(tool_calls[0].id, "call_1");
1303 assert_eq!(tool_calls[0].function.name, "subtract");
1304 assert_eq!(
1305 tool_calls[0].function.arguments,
1306 serde_json::json!({"x": 2, "y": 1})
1307 );
1308 }
1309 _ => panic!("expected assistant message"),
1310 }
1311 }
1312
1313 #[test]
1314 fn test_request_conversion_errors_when_all_messages_are_filtered() {
1315 let request = completion::CompletionRequest {
1316 preamble: None,
1317 chat_history: OneOrMany::one(message::Message::Assistant {
1318 id: None,
1319 content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1320 }),
1321 documents: vec![],
1322 tools: vec![],
1323 temperature: None,
1324 max_tokens: None,
1325 tool_choice: None,
1326 additional_params: None,
1327 model: None,
1328 output_schema: None,
1329 };
1330
1331 let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
1332 assert!(matches!(result, Err(CompletionError::RequestError(_))));
1333 }
1334}