1use super::client::Client;
2use crate::completion::GetTokenUsage;
3use crate::http_client::HttpClientExt;
4use crate::providers::openai::StreamingCompletionResponse;
5use crate::telemetry::SpanCombinator;
6use crate::{
7 OneOrMany,
8 completion::{self, CompletionError, CompletionRequest},
9 json_utils,
10 message::{self},
11 one_or_many::string_or_one_or_many,
12};
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14use serde_json::Value;
15use std::{convert::Infallible, str::FromStr};
16use tracing::{Level, enabled, info_span};
17use tracing_futures::Instrument;
18
19#[derive(Debug, Deserialize)]
20#[serde(untagged)]
21pub enum ApiResponse<T> {
22 Ok(T),
23 Err(Value),
24}
25
/// Model id `google/gemma-2-2b-it`.
pub const GEMMA_2: &str = "google/gemma-2-2b-it";
/// Model id `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const META_LLAMA_3_1: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model id `PowerInfer/SmallThinker-3B-Preview`.
pub const SMALLTHINKER_PREVIEW: &str = "PowerInfer/SmallThinker-3B-Preview";
/// Model id `Qwen/Qwen2.5-7B-Instruct`.
pub const QWEN2_5: &str = "Qwen/Qwen2.5-7B-Instruct";
/// Model id `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN2_5_CODER: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";

/// Model id `Qwen/Qwen2-VL-7B-Instruct`.
pub const QWEN2_VL: &str = "Qwen/Qwen2-VL-7B-Instruct";
/// Model id `Qwen/QVQ-72B-Preview`.
pub const QWEN_QVQ_PREVIEW: &str = "Qwen/QVQ-72B-Preview";
48
49#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
50pub struct Function {
51 name: String,
52 #[serde(
53 serialize_with = "json_utils::stringified_json::serialize",
54 deserialize_with = "deserialize_arguments"
55 )]
56 pub arguments: serde_json::Value,
57}
58
59fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
60where
61 D: Deserializer<'de>,
62{
63 let value = Value::deserialize(deserializer)?;
64
65 match value {
66 Value::String(s) => serde_json::from_str(&s).map_err(serde::de::Error::custom),
67 other => Ok(other),
68 }
69}
70
71impl From<Function> for message::ToolFunction {
72 fn from(value: Function) -> Self {
73 message::ToolFunction {
74 name: value.name,
75 arguments: value.arguments,
76 }
77 }
78}
79
80#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
81#[serde(rename_all = "lowercase")]
82pub enum ToolType {
83 #[default]
84 Function,
85}
86
87#[derive(Debug, Deserialize, Serialize, Clone)]
88pub struct ToolDefinition {
89 pub r#type: String,
90 pub function: completion::ToolDefinition,
91}
92
93impl From<completion::ToolDefinition> for ToolDefinition {
94 fn from(tool: completion::ToolDefinition) -> Self {
95 Self {
96 r#type: "function".into(),
97 function: tool,
98 }
99 }
100}
101
102#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
103pub struct ToolCall {
104 pub id: String,
105 pub r#type: ToolType,
106 pub function: Function,
107}
108
109impl From<ToolCall> for message::ToolCall {
110 fn from(value: ToolCall) -> Self {
111 message::ToolCall {
112 id: value.id,
113 call_id: None,
114 function: value.function.into(),
115 signature: None,
116 additional_params: None,
117 }
118 }
119}
120
121impl From<message::ToolCall> for ToolCall {
122 fn from(value: message::ToolCall) -> Self {
123 ToolCall {
124 id: value.id,
125 r#type: ToolType::Function,
126 function: Function {
127 name: value.function.name,
128 arguments: value.function.arguments,
129 },
130 }
131 }
132}
133
134#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
135pub struct ImageUrl {
136 url: String,
137}
138
139#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
140#[serde(tag = "type", rename_all = "lowercase")]
141pub enum UserContent {
142 Text {
143 text: String,
144 },
145 #[serde(rename = "image_url")]
146 ImageUrl {
147 image_url: ImageUrl,
148 },
149}
150
151impl FromStr for UserContent {
152 type Err = Infallible;
153
154 fn from_str(s: &str) -> Result<Self, Self::Err> {
155 Ok(UserContent::Text {
156 text: s.to_string(),
157 })
158 }
159}
160
161#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
162#[serde(tag = "type", rename_all = "lowercase")]
163pub enum AssistantContent {
164 Text { text: String },
165}
166
167impl FromStr for AssistantContent {
168 type Err = Infallible;
169
170 fn from_str(s: &str) -> Result<Self, Self::Err> {
171 Ok(AssistantContent::Text {
172 text: s.to_string(),
173 })
174 }
175}
176
177#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
178#[serde(tag = "type", rename_all = "lowercase")]
179pub enum SystemContent {
180 Text { text: String },
181}
182
183impl FromStr for SystemContent {
184 type Err = Infallible;
185
186 fn from_str(s: &str) -> Result<Self, Self::Err> {
187 Ok(SystemContent::Text {
188 text: s.to_string(),
189 })
190 }
191}
192
193impl From<UserContent> for message::UserContent {
194 fn from(value: UserContent) -> Self {
195 match value {
196 UserContent::Text { text } => message::UserContent::text(text),
197 UserContent::ImageUrl { image_url } => {
198 message::UserContent::image_url(image_url.url, None, None)
199 }
200 }
201 }
202}
203
204impl TryFrom<message::UserContent> for UserContent {
205 type Error = message::MessageError;
206
207 fn try_from(content: message::UserContent) -> Result<Self, Self::Error> {
208 match content {
209 message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }),
210 message::UserContent::Document(message::Document {
211 data: message::DocumentSourceKind::Raw(raw),
212 ..
213 }) => {
214 let text = String::from_utf8_lossy(raw.as_slice()).into();
215 Ok(UserContent::Text { text })
216 }
217 message::UserContent::Document(message::Document {
218 data:
219 message::DocumentSourceKind::Base64(text)
220 | message::DocumentSourceKind::String(text),
221 ..
222 }) => Ok(UserContent::Text { text }),
223 message::UserContent::Image(message::Image { data, .. }) => match data {
224 message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl {
225 image_url: ImageUrl { url },
226 }),
227 _ => Err(message::MessageError::ConversionError(
228 "Huggingface only supports images as urls".into(),
229 )),
230 },
231 _ => Err(message::MessageError::ConversionError(
232 "Huggingface only supports text and images".into(),
233 )),
234 }
235 }
236}
237
238#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
239#[serde(tag = "role", rename_all = "lowercase")]
240pub enum Message {
241 System {
242 #[serde(deserialize_with = "string_or_one_or_many")]
243 content: OneOrMany<SystemContent>,
244 },
245 User {
246 #[serde(deserialize_with = "string_or_one_or_many")]
247 content: OneOrMany<UserContent>,
248 },
249 Assistant {
250 #[serde(default, deserialize_with = "json_utils::string_or_vec")]
251 content: Vec<AssistantContent>,
252 #[serde(default, deserialize_with = "json_utils::null_or_vec")]
253 tool_calls: Vec<ToolCall>,
254 },
255 #[serde(rename = "tool", alias = "Tool")]
256 ToolResult {
257 name: String,
258 #[serde(skip_serializing_if = "Option::is_none")]
259 arguments: Option<serde_json::Value>,
260 #[serde(
261 deserialize_with = "string_or_one_or_many",
262 serialize_with = "serialize_tool_content"
263 )]
264 content: OneOrMany<String>,
265 },
266}
267
268fn serialize_tool_content<S>(content: &OneOrMany<String>, serializer: S) -> Result<S::Ok, S::Error>
269where
270 S: Serializer,
271{
272 let joined = content
274 .iter()
275 .map(String::as_str)
276 .collect::<Vec<_>>()
277 .join("\n");
278 serializer.serialize_str(&joined)
279}
280
281impl Message {
282 pub fn system(content: &str) -> Self {
283 Message::System {
284 content: OneOrMany::one(SystemContent::Text {
285 text: content.to_string(),
286 }),
287 }
288 }
289}
290
291impl TryFrom<message::Message> for Vec<Message> {
292 type Error = message::MessageError;
293
294 fn try_from(message: message::Message) -> Result<Vec<Message>, Self::Error> {
295 match message {
296 message::Message::System { content } => Ok(vec![Message::system(&content)]),
297 message::Message::User { content } => {
298 let (tool_results, other_content): (Vec<_>, Vec<_>) = content
299 .into_iter()
300 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
301
302 if !tool_results.is_empty() {
303 tool_results
304 .into_iter()
305 .map(|content| match content {
306 message::UserContent::ToolResult(message::ToolResult {
307 id,
308 content,
309 ..
310 }) => Ok::<_, message::MessageError>(Message::ToolResult {
311 name: id,
312 arguments: None,
313 content: content.try_map(|content| match content {
314 message::ToolResultContent::Text(message::Text { text }) => {
315 Ok(text)
316 }
317 _ => Err(message::MessageError::ConversionError(
318 "Tool result content does not support non-text".into(),
319 )),
320 })?,
321 }),
322 _ => unreachable!(),
323 })
324 .collect::<Result<Vec<_>, _>>()
325 } else {
326 let other_content = OneOrMany::many(other_content).expect(
327 "There must be other content here if there were no tool result content",
328 );
329
330 Ok(vec![Message::User {
331 content: other_content.try_map(|content| match content {
332 message::UserContent::Text(text) => {
333 Ok(UserContent::Text { text: text.text })
334 }
335 message::UserContent::Image(image) => {
336 let url = image.try_into_url()?;
337
338 Ok(UserContent::ImageUrl {
339 image_url: ImageUrl { url },
340 })
341 }
342 message::UserContent::Document(message::Document {
343 data: message::DocumentSourceKind::Raw(raw), ..
344 }) => {
345 let text = String::from_utf8_lossy(raw.as_slice()).into();
346 Ok(UserContent::Text { text })
347 }
348 message::UserContent::Document(message::Document {
349 data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), ..
350 }) => {
351 Ok(UserContent::Text { text })
352 }
353 _ => Err(message::MessageError::ConversionError(
354 "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(),
355 )),
356 })?,
357 }])
358 }
359 }
360 message::Message::Assistant { content, .. } => {
361 let mut text_content = Vec::new();
362 let mut tool_calls = Vec::new();
363
364 for content in content {
365 match content {
366 message::AssistantContent::Text(text) => text_content.push(text),
367 message::AssistantContent::ToolCall(tool_call) => {
368 tool_calls.push(tool_call)
369 }
370 message::AssistantContent::Reasoning(_) => {
371 }
374 message::AssistantContent::Image(_) => {
375 panic!("Image content is not supported on HuggingFace via Rig");
376 }
377 }
378 }
379
380 if text_content.is_empty() && tool_calls.is_empty() {
381 return Ok(vec![]);
382 }
383
384 Ok(vec![Message::Assistant {
385 content: text_content
386 .into_iter()
387 .map(|content| AssistantContent::Text { text: content.text })
388 .collect::<Vec<_>>(),
389 tool_calls: tool_calls
390 .into_iter()
391 .map(|tool_call| tool_call.into())
392 .collect::<Vec<_>>(),
393 }])
394 }
395 }
396 }
397}
398
399impl TryFrom<Message> for message::Message {
400 type Error = message::MessageError;
401
402 fn try_from(message: Message) -> Result<Self, Self::Error> {
403 Ok(match message {
404 Message::User { content, .. } => message::Message::User {
405 content: content.map(|content| content.into()),
406 },
407 Message::Assistant {
408 content,
409 tool_calls,
410 ..
411 } => {
412 let mut content = content
413 .into_iter()
414 .map(|content| match content {
415 AssistantContent::Text { text } => message::AssistantContent::text(text),
416 })
417 .collect::<Vec<_>>();
418
419 content.extend(
420 tool_calls
421 .into_iter()
422 .map(|tool_call| Ok(message::AssistantContent::ToolCall(tool_call.into())))
423 .collect::<Result<Vec<_>, _>>()?,
424 );
425
426 message::Message::Assistant {
427 id: None,
428 content: OneOrMany::many(content).map_err(|_| {
429 message::MessageError::ConversionError(
430 "Neither `content` nor `tool_calls` was provided to the Message"
431 .to_owned(),
432 )
433 })?,
434 }
435 }
436
437 Message::ToolResult { name, content, .. } => message::Message::User {
438 content: OneOrMany::one(message::UserContent::tool_result(
439 name,
440 content.map(message::ToolResultContent::text),
441 )),
442 },
443
444 Message::System { content, .. } => message::Message::User {
447 content: content.map(|c| match c {
448 SystemContent::Text { text } => message::UserContent::text(text),
449 }),
450 },
451 })
452 }
453}
454
455#[derive(Clone, Debug, Deserialize, Serialize)]
456pub struct Choice {
457 pub finish_reason: String,
458 pub index: usize,
459 #[serde(default)]
460 pub logprobs: serde_json::Value,
461 pub message: Message,
462}
463
464#[derive(Debug, Deserialize, Clone, Serialize)]
465pub struct Usage {
466 pub completion_tokens: i32,
467 pub prompt_tokens: i32,
468 pub total_tokens: i32,
469}
470
471impl GetTokenUsage for Usage {
472 fn token_usage(&self) -> Option<crate::completion::Usage> {
473 let mut usage = crate::completion::Usage::new();
474 usage.input_tokens = self.prompt_tokens as u64;
475 usage.output_tokens = self.completion_tokens as u64;
476 usage.total_tokens = self.total_tokens as u64;
477
478 Some(usage)
479 }
480}
481
482#[derive(Clone, Debug, Deserialize, Serialize)]
483pub struct CompletionResponse {
484 pub created: i32,
485 pub id: String,
486 pub model: String,
487 pub choices: Vec<Choice>,
488 #[serde(default, deserialize_with = "default_string_on_null")]
489 pub system_fingerprint: String,
490 pub usage: Usage,
491}
492
493impl crate::telemetry::ProviderResponseExt for CompletionResponse {
494 type OutputMessage = Choice;
495 type Usage = Usage;
496
497 fn get_response_id(&self) -> Option<String> {
498 Some(self.id.clone())
499 }
500
501 fn get_response_model_name(&self) -> Option<String> {
502 Some(self.model.clone())
503 }
504
505 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
506 self.choices.clone()
507 }
508
509 fn get_text_response(&self) -> Option<String> {
510 let text_response = self
511 .choices
512 .iter()
513 .filter_map(|x| {
514 let Message::User { ref content } = x.message else {
515 return None;
516 };
517
518 let text = content
519 .iter()
520 .filter_map(|x| {
521 if let UserContent::Text { text } = x {
522 Some(text.clone())
523 } else {
524 None
525 }
526 })
527 .collect::<Vec<String>>();
528
529 if text.is_empty() {
530 None
531 } else {
532 Some(text.join("\n"))
533 }
534 })
535 .collect::<Vec<String>>()
536 .join("\n");
537
538 if text_response.is_empty() {
539 None
540 } else {
541 Some(text_response)
542 }
543 }
544
545 fn get_usage(&self) -> Option<Self::Usage> {
546 Some(self.usage.clone())
547 }
548}
549
550fn default_string_on_null<'de, D>(deserializer: D) -> Result<String, D::Error>
551where
552 D: Deserializer<'de>,
553{
554 match Option::<String>::deserialize(deserializer)? {
555 Some(value) => Ok(value), None => Ok(String::default()), }
558}
559
560impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
561 type Error = CompletionError;
562
563 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
564 let choice = response.choices.first().ok_or_else(|| {
565 CompletionError::ResponseError("Response contained no choices".to_owned())
566 })?;
567
568 let content = match &choice.message {
569 Message::Assistant {
570 content,
571 tool_calls,
572 ..
573 } => {
574 let mut content = content
575 .iter()
576 .map(|c| match c {
577 AssistantContent::Text { text } => message::AssistantContent::text(text),
578 })
579 .collect::<Vec<_>>();
580
581 content.extend(
582 tool_calls
583 .iter()
584 .map(|call| {
585 completion::AssistantContent::tool_call(
586 &call.id,
587 &call.function.name,
588 call.function.arguments.clone(),
589 )
590 })
591 .collect::<Vec<_>>(),
592 );
593 Ok(content)
594 }
595 _ => Err(CompletionError::ResponseError(
596 "Response did not contain a valid message or tool call".into(),
597 )),
598 }?;
599
600 let choice = OneOrMany::many(content).map_err(|_| {
601 CompletionError::ResponseError(
602 "Response contained no message or tool call (empty)".to_owned(),
603 )
604 })?;
605
606 let usage = completion::Usage {
607 input_tokens: response.usage.prompt_tokens as u64,
608 output_tokens: response.usage.completion_tokens as u64,
609 total_tokens: response.usage.total_tokens as u64,
610 cached_input_tokens: 0,
611 cache_creation_input_tokens: 0,
612 };
613
614 Ok(completion::CompletionResponse {
615 choice,
616 usage,
617 raw_response: response,
618 message_id: None,
619 })
620 }
621}
622
623#[derive(Debug, Serialize, Deserialize)]
624pub(super) struct HuggingfaceCompletionRequest {
625 model: String,
626 pub messages: Vec<Message>,
627 #[serde(skip_serializing_if = "Option::is_none")]
628 temperature: Option<f64>,
629 #[serde(skip_serializing_if = "Vec::is_empty")]
630 tools: Vec<ToolDefinition>,
631 #[serde(skip_serializing_if = "Option::is_none")]
632 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
633 #[serde(flatten, skip_serializing_if = "Option::is_none")]
634 pub additional_params: Option<serde_json::Value>,
635}
636
637impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
638 type Error = CompletionError;
639
640 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
641 if req.output_schema.is_some() {
642 tracing::warn!("Structured outputs currently not supported for Huggingface");
643 }
644 let model = req.model.clone().unwrap_or_else(|| model.to_string());
645 let mut full_history: Vec<Message> = match &req.preamble {
646 Some(preamble) => vec![Message::system(preamble)],
647 None => vec![],
648 };
649 if let Some(docs) = req.normalized_documents() {
650 let docs: Vec<Message> = docs.try_into()?;
651 full_history.extend(docs);
652 }
653
654 let chat_history: Vec<Message> = req
655 .chat_history
656 .clone()
657 .into_iter()
658 .map(|message| message.try_into())
659 .collect::<Result<Vec<Vec<Message>>, _>>()?
660 .into_iter()
661 .flatten()
662 .collect();
663
664 full_history.extend(chat_history);
665
666 if full_history.is_empty() {
667 return Err(CompletionError::RequestError(
668 std::io::Error::new(
669 std::io::ErrorKind::InvalidInput,
670 "HuggingFace request has no provider-compatible messages after conversion",
671 )
672 .into(),
673 ));
674 }
675
676 let tool_choice = req
677 .tool_choice
678 .clone()
679 .map(crate::providers::openai::completion::ToolChoice::try_from)
680 .transpose()?;
681
682 Ok(Self {
683 model: model.to_string(),
684 messages: full_history,
685 temperature: req.temperature,
686 tools: req
687 .tools
688 .clone()
689 .into_iter()
690 .map(ToolDefinition::from)
691 .collect::<Vec<_>>(),
692 tool_choice,
693 additional_params: req.additional_params,
694 })
695 }
696}
697
698#[derive(Clone)]
699pub struct CompletionModel<T = reqwest::Client> {
700 pub(crate) client: Client<T>,
701 pub model: String,
703}
704
705impl<T> CompletionModel<T> {
706 pub fn new(client: Client<T>, model: &str) -> Self {
707 Self {
708 client,
709 model: model.to_string(),
710 }
711 }
712}
713
714impl<T> completion::CompletionModel for CompletionModel<T>
715where
716 T: HttpClientExt + Clone + 'static,
717{
718 type Response = CompletionResponse;
719 type StreamingResponse = StreamingCompletionResponse;
720
721 type Client = Client<T>;
722
723 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
724 Self::new(client.clone(), &model.into())
725 }
726
727 async fn completion(
728 &self,
729 completion_request: CompletionRequest,
730 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
731 let request_model = completion_request
732 .model
733 .clone()
734 .unwrap_or_else(|| self.model.clone());
735 let span = if tracing::Span::current().is_disabled() {
736 info_span!(
737 target: "rig::completions",
738 "chat",
739 gen_ai.operation.name = "chat",
740 gen_ai.provider.name = "huggingface",
741 gen_ai.request.model = &request_model,
742 gen_ai.system_instructions = &completion_request.preamble,
743 gen_ai.response.id = tracing::field::Empty,
744 gen_ai.response.model = tracing::field::Empty,
745 gen_ai.usage.output_tokens = tracing::field::Empty,
746 gen_ai.usage.input_tokens = tracing::field::Empty,
747 gen_ai.usage.cached_tokens = tracing::field::Empty,
748 )
749 } else {
750 tracing::Span::current()
751 };
752
753 let model = self.client.subprovider().model_identifier(&request_model);
754 let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
755
756 if enabled!(Level::TRACE) {
757 tracing::trace!(
758 target: "rig::completions",
759 "Huggingface completion request: {}",
760 serde_json::to_string_pretty(&request)?
761 );
762 }
763
764 let request = serde_json::to_vec(&request)?;
765
766 let path = self
767 .client
768 .subprovider()
769 .completion_endpoint(&request_model);
770 let request = self
771 .client
772 .post(&path)?
773 .header("Content-Type", "application/json")
774 .body(request)
775 .map_err(|e| CompletionError::HttpError(e.into()))?;
776
777 async move {
778 let response = self.client.send(request).await?;
779
780 if response.status().is_success() {
781 let bytes: Vec<u8> = response.into_body().await?;
782 let text = String::from_utf8_lossy(&bytes);
783
784 tracing::debug!(target: "rig", "Huggingface completion error: {}", text);
785
786 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&bytes)? {
787 ApiResponse::Ok(response) => {
788 if enabled!(Level::TRACE) {
789 tracing::trace!(
790 target: "rig::completions",
791 "Huggingface completion response: {}",
792 serde_json::to_string_pretty(&response)?
793 );
794 }
795
796 let span = tracing::Span::current();
797 span.record_token_usage(&response.usage);
798 span.record_response_metadata(&response);
799
800 response.try_into()
801 }
802 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())),
803 }
804 } else {
805 let status = response.status();
806 let text: Vec<u8> = response.into_body().await?;
807 let text: String = String::from_utf8_lossy(&text).into();
808
809 Err(CompletionError::ProviderError(format!(
810 "{}: {}",
811 status, text
812 )))
813 }
814 }
815 .instrument(span)
816 .await
817 }
818
819 async fn stream(
820 &self,
821 request: CompletionRequest,
822 ) -> Result<
823 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
824 CompletionError,
825 > {
826 CompletionModel::stream(self, request).await
827 }
828}
829
830#[cfg(test)]
831mod tests {
832 use super::*;
833 use serde_path_to_error::deserialize;
834
835 #[test]
836 fn test_huggingface_request_uses_request_model_override() {
837 let request = CompletionRequest {
838 model: Some("meta-llama/Meta-Llama-3.1-8B-Instruct".to_string()),
839 preamble: None,
840 chat_history: crate::OneOrMany::one("Hello".into()),
841 documents: vec![],
842 tools: vec![],
843 temperature: None,
844 max_tokens: None,
845 tool_choice: None,
846 additional_params: None,
847 output_schema: None,
848 };
849
850 let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
851 .expect("request conversion should succeed");
852 let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
853
854 assert_eq!(serialized["model"], "meta-llama/Meta-Llama-3.1-8B-Instruct");
855 }
856
857 #[test]
858 fn test_huggingface_request_uses_default_model_when_override_unset() {
859 let request = CompletionRequest {
860 model: None,
861 preamble: None,
862 chat_history: crate::OneOrMany::one("Hello".into()),
863 documents: vec![],
864 tools: vec![],
865 temperature: None,
866 max_tokens: None,
867 tool_choice: None,
868 additional_params: None,
869 output_schema: None,
870 };
871
872 let hf_request = HuggingfaceCompletionRequest::try_from(("mistralai/Mistral-7B", request))
873 .expect("request conversion should succeed");
874 let serialized = serde_json::to_value(hf_request).expect("serialization should succeed");
875
876 assert_eq!(serialized["model"], "mistralai/Mistral-7B");
877 }
878
879 #[test]
880 fn test_deserialize_message() {
881 let assistant_message_json = r#"
882 {
883 "role": "assistant",
884 "content": "\n\nHello there, how may I assist you today?"
885 }
886 "#;
887
888 let assistant_message_json2 = r#"
889 {
890 "role": "assistant",
891 "content": [
892 {
893 "type": "text",
894 "text": "\n\nHello there, how may I assist you today?"
895 }
896 ],
897 "tool_calls": null
898 }
899 "#;
900
901 let assistant_message_json3 = r#"
902 {
903 "role": "assistant",
904 "tool_calls": [
905 {
906 "id": "call_h89ipqYUjEpCPI6SxspMnoUU",
907 "type": "function",
908 "function": {
909 "name": "subtract",
910 "arguments": {"x": 2, "y": 5}
911 }
912 }
913 ],
914 "content": null,
915 "refusal": null
916 }
917 "#;
918
919 let user_message_json = r#"
920 {
921 "role": "user",
922 "content": [
923 {
924 "type": "text",
925 "text": "What's in this image?"
926 },
927 {
928 "type": "image_url",
929 "image_url": {
930 "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
931 }
932 }
933 ]
934 }
935 "#;
936
937 let assistant_message: Message = {
938 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
939 deserialize(jd).unwrap_or_else(|err| {
940 panic!(
941 "Deserialization error at {} ({}:{}): {}",
942 err.path(),
943 err.inner().line(),
944 err.inner().column(),
945 err
946 );
947 })
948 };
949
950 let assistant_message2: Message = {
951 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
952 deserialize(jd).unwrap_or_else(|err| {
953 panic!(
954 "Deserialization error at {} ({}:{}): {}",
955 err.path(),
956 err.inner().line(),
957 err.inner().column(),
958 err
959 );
960 })
961 };
962
963 let assistant_message3: Message = {
964 let jd: &mut serde_json::Deserializer<serde_json::de::StrRead<'_>> =
965 &mut serde_json::Deserializer::from_str(assistant_message_json3);
966 deserialize(jd).unwrap_or_else(|err| {
967 panic!(
968 "Deserialization error at {} ({}:{}): {}",
969 err.path(),
970 err.inner().line(),
971 err.inner().column(),
972 err
973 );
974 })
975 };
976
977 let user_message: Message = {
978 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
979 deserialize(jd).unwrap_or_else(|err| {
980 panic!(
981 "Deserialization error at {} ({}:{}): {}",
982 err.path(),
983 err.inner().line(),
984 err.inner().column(),
985 err
986 );
987 })
988 };
989
990 match assistant_message {
991 Message::Assistant { content, .. } => {
992 assert_eq!(
993 content[0],
994 AssistantContent::Text {
995 text: "\n\nHello there, how may I assist you today?".to_string()
996 }
997 );
998 }
999 _ => panic!("Expected assistant message"),
1000 }
1001
1002 match assistant_message2 {
1003 Message::Assistant {
1004 content,
1005 tool_calls,
1006 ..
1007 } => {
1008 assert_eq!(
1009 content[0],
1010 AssistantContent::Text {
1011 text: "\n\nHello there, how may I assist you today?".to_string()
1012 }
1013 );
1014
1015 assert_eq!(tool_calls, vec![]);
1016 }
1017 _ => panic!("Expected assistant message"),
1018 }
1019
1020 match assistant_message3 {
1021 Message::Assistant {
1022 content,
1023 tool_calls,
1024 ..
1025 } => {
1026 assert!(content.is_empty());
1027 assert_eq!(
1028 tool_calls[0],
1029 ToolCall {
1030 id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_string(),
1031 r#type: ToolType::Function,
1032 function: Function {
1033 name: "subtract".to_string(),
1034 arguments: serde_json::json!({"x": 2, "y": 5}),
1035 },
1036 }
1037 );
1038 }
1039 _ => panic!("Expected assistant message"),
1040 }
1041
1042 match user_message {
1043 Message::User { content, .. } => {
1044 let (first, second) = {
1045 let mut iter = content.into_iter();
1046 (iter.next().unwrap(), iter.next().unwrap())
1047 };
1048 assert_eq!(
1049 first,
1050 UserContent::Text {
1051 text: "What's in this image?".to_string()
1052 }
1053 );
1054 assert_eq!(second, UserContent::ImageUrl { image_url: ImageUrl { url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string() } });
1055 }
1056 _ => panic!("Expected user message"),
1057 }
1058 }
1059
1060 #[test]
1061 fn test_message_to_message_conversion() {
1062 let user_message = message::Message::User {
1063 content: OneOrMany::one(message::UserContent::text("Hello")),
1064 };
1065
1066 let assistant_message = message::Message::Assistant {
1067 id: None,
1068 content: OneOrMany::one(message::AssistantContent::text("Hi there!")),
1069 };
1070
1071 let converted_user_message: Vec<Message> = user_message.clone().try_into().unwrap();
1072 let converted_assistant_message: Vec<Message> =
1073 assistant_message.clone().try_into().unwrap();
1074
1075 match converted_user_message[0].clone() {
1076 Message::User { content, .. } => {
1077 assert_eq!(
1078 content.first(),
1079 UserContent::Text {
1080 text: "Hello".to_string()
1081 }
1082 );
1083 }
1084 _ => panic!("Expected user message"),
1085 }
1086
1087 match converted_assistant_message[0].clone() {
1088 Message::Assistant { content, .. } => {
1089 assert_eq!(
1090 content[0],
1091 AssistantContent::Text {
1092 text: "Hi there!".to_string()
1093 }
1094 );
1095 }
1096 _ => panic!("Expected assistant message"),
1097 }
1098
1099 let original_user_message: message::Message =
1100 converted_user_message[0].clone().try_into().unwrap();
1101 let original_assistant_message: message::Message =
1102 converted_assistant_message[0].clone().try_into().unwrap();
1103
1104 assert_eq!(original_user_message, user_message);
1105 assert_eq!(original_assistant_message, assistant_message);
1106 }
1107
1108 #[test]
1109 fn test_message_from_message_conversion() {
1110 let user_message = Message::User {
1111 content: OneOrMany::one(UserContent::Text {
1112 text: "Hello".to_string(),
1113 }),
1114 };
1115
1116 let assistant_message = Message::Assistant {
1117 content: vec![AssistantContent::Text {
1118 text: "Hi there!".to_string(),
1119 }],
1120 tool_calls: vec![],
1121 };
1122
1123 let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1124 let converted_assistant_message: message::Message =
1125 assistant_message.clone().try_into().unwrap();
1126
1127 match converted_user_message.clone() {
1128 message::Message::User { content } => {
1129 assert_eq!(content.first(), message::UserContent::text("Hello"));
1130 }
1131 _ => panic!("Expected user message"),
1132 }
1133
1134 match converted_assistant_message.clone() {
1135 message::Message::Assistant { content, .. } => {
1136 assert_eq!(
1137 content.first(),
1138 message::AssistantContent::text("Hi there!")
1139 );
1140 }
1141 _ => panic!("Expected assistant message"),
1142 }
1143
1144 let original_user_message: Vec<Message> = converted_user_message.try_into().unwrap();
1145 let original_assistant_message: Vec<Message> =
1146 converted_assistant_message.try_into().unwrap();
1147
1148 assert_eq!(original_user_message[0], user_message);
1149 assert_eq!(original_assistant_message[0], assistant_message);
1150 }
1151
1152 #[test]
1153 fn test_responses() {
1154 let fireworks_response_json = r#"
1155 {
1156 "choices": [
1157 {
1158 "finish_reason": "tool_calls",
1159 "index": 0,
1160 "message": {
1161 "role": "assistant",
1162 "tool_calls": [
1163 {
1164 "function": {
1165 "arguments": "{\"x\": 2, \"y\": 5}",
1166 "name": "subtract"
1167 },
1168 "id": "call_1BspL6mQqjKgvsQbH1TIYkHf",
1169 "index": 0,
1170 "type": "function"
1171 }
1172 ]
1173 }
1174 }
1175 ],
1176 "created": 1740704000,
1177 "id": "2a81f6a1-4866-42fb-9902-2655a2b5b1ff",
1178 "model": "accounts/fireworks/models/deepseek-v3",
1179 "object": "chat.completion",
1180 "usage": {
1181 "completion_tokens": 26,
1182 "prompt_tokens": 248,
1183 "total_tokens": 274
1184 }
1185 }
1186 "#;
1187
1188 let novita_response_json = r#"
1189 {
1190 "choices": [
1191 {
1192 "finish_reason": "tool_calls",
1193 "index": 0,
1194 "logprobs": null,
1195 "message": {
1196 "audio": null,
1197 "content": null,
1198 "function_call": null,
1199 "reasoning_content": null,
1200 "refusal": null,
1201 "role": "assistant",
1202 "tool_calls": [
1203 {
1204 "function": {
1205 "arguments": "{\"x\": \"2\", \"y\": \"5\"}",
1206 "name": "subtract"
1207 },
1208 "id": "chatcmpl-tool-f6d2af7c8dc041058f95e2c2eede45c5",
1209 "type": "function"
1210 }
1211 ]
1212 },
1213 "stop_reason": 128008
1214 }
1215 ],
1216 "created": 1740704592,
1217 "id": "chatcmpl-a92c60ae125c47c998ecdcb53387fed4",
1218 "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-fast",
1219 "object": "chat.completion",
1220 "prompt_logprobs": null,
1221 "service_tier": null,
1222 "system_fingerprint": null,
1223 "usage": {
1224 "completion_tokens": 28,
1225 "completion_tokens_details": null,
1226 "prompt_tokens": 335,
1227 "prompt_tokens_details": null,
1228 "total_tokens": 363
1229 }
1230 }
1231 "#;
1232
1233 let _firework_response: CompletionResponse = {
1234 let jd = &mut serde_json::Deserializer::from_str(fireworks_response_json);
1235 deserialize(jd).unwrap_or_else(|err| {
1236 panic!(
1237 "Deserialization error at {} ({}:{}): {}",
1238 err.path(),
1239 err.inner().line(),
1240 err.inner().column(),
1241 err
1242 );
1243 })
1244 };
1245
1246 let _novita_response: CompletionResponse = {
1247 let jd = &mut serde_json::Deserializer::from_str(novita_response_json);
1248 deserialize(jd).unwrap_or_else(|err| {
1249 panic!(
1250 "Deserialization error at {} ({}:{}): {}",
1251 err.path(),
1252 err.inner().line(),
1253 err.inner().column(),
1254 err
1255 );
1256 })
1257 };
1258 }
1259
1260 #[test]
1261 fn test_assistant_reasoning_is_silently_skipped() {
1262 let assistant = message::Message::Assistant {
1263 id: None,
1264 content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1265 };
1266
1267 let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1268 assert!(converted.is_empty());
1269 }
1270
1271 #[test]
1272 fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() {
1273 let assistant = message::Message::Assistant {
1274 id: None,
1275 content: OneOrMany::many(vec![
1276 message::AssistantContent::reasoning("hidden"),
1277 message::AssistantContent::text("visible"),
1278 message::AssistantContent::tool_call(
1279 "call_1",
1280 "subtract",
1281 serde_json::json!({"x": 2, "y": 1}),
1282 ),
1283 ])
1284 .expect("non-empty assistant content"),
1285 };
1286
1287 let converted: Vec<Message> = assistant.try_into().expect("conversion should work");
1288 assert_eq!(converted.len(), 1);
1289
1290 match &converted[0] {
1291 Message::Assistant {
1292 content,
1293 tool_calls,
1294 ..
1295 } => {
1296 assert_eq!(
1297 content,
1298 &vec![AssistantContent::Text {
1299 text: "visible".to_string()
1300 }]
1301 );
1302 assert_eq!(tool_calls.len(), 1);
1303 assert_eq!(tool_calls[0].id, "call_1");
1304 assert_eq!(tool_calls[0].function.name, "subtract");
1305 assert_eq!(
1306 tool_calls[0].function.arguments,
1307 serde_json::json!({"x": 2, "y": 1})
1308 );
1309 }
1310 _ => panic!("expected assistant message"),
1311 }
1312 }
1313
1314 #[test]
1315 fn test_request_conversion_errors_when_all_messages_are_filtered() {
1316 let request = completion::CompletionRequest {
1317 preamble: None,
1318 chat_history: OneOrMany::one(message::Message::Assistant {
1319 id: None,
1320 content: OneOrMany::one(message::AssistantContent::reasoning("hidden")),
1321 }),
1322 documents: vec![],
1323 tools: vec![],
1324 temperature: None,
1325 max_tokens: None,
1326 tool_choice: None,
1327 additional_params: None,
1328 model: None,
1329 output_schema: None,
1330 };
1331
1332 let result = HuggingfaceCompletionRequest::try_from(("meta/test-model", request));
1333 assert!(matches!(result, Err(CompletionError::RequestError(_))));
1334 }
1335}