1use super::{
2 client::{ApiErrorResponse, ApiResponse, Client, Usage},
3 streaming::StreamingCompletionResponse,
4};
5use crate::message::{
6 self, AudioMediaType, DocumentMediaType, DocumentSourceKind, ImageDetail, MimeType,
7 VideoMediaType,
8};
9use crate::telemetry::SpanCombinator;
10use crate::{
11 OneOrMany,
12 completion::{self, CompletionError, CompletionRequest},
13 http_client::HttpClientExt,
14 json_utils,
15 one_or_many::string_or_one_or_many,
16 providers::openai,
17};
18use bytes::Bytes;
19use serde::{Deserialize, Serialize, Serializer};
20use std::collections::HashMap;
21use tracing::{Instrument, Level, enabled, info_span};
22
/// OpenRouter model identifier for Qwen QwQ 32B.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// OpenRouter model identifier for Anthropic Claude 3.7 Sonnet.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// OpenRouter model identifier for Perplexity Sonar Pro.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// OpenRouter model identifier for Google Gemini 2.0 Flash (the `-001` snapshot).
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
45#[serde(rename_all = "lowercase")]
46pub enum DataCollection {
47 #[default]
49 Allow,
50 Deny,
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
58#[serde(rename_all = "lowercase")]
59pub enum Quantization {
60 #[serde(rename = "int4")]
62 Int4,
63 #[serde(rename = "int8")]
65 Int8,
66 #[serde(rename = "fp16")]
68 Fp16,
69 #[serde(rename = "bf16")]
71 Bf16,
72 #[serde(rename = "fp32")]
74 Fp32,
75 #[serde(rename = "fp8")]
77 Fp8,
78 #[serde(rename = "unknown")]
80 Unknown,
81}
82
83#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
89#[serde(rename_all = "lowercase")]
90pub enum ProviderSortStrategy {
91 Price,
93 Throughput,
95 Latency,
97}
98
99#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
101#[serde(rename_all = "lowercase")]
102pub enum SortPartition {
103 Model,
105 None,
107}
108
109#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
113pub struct ProviderSortConfig {
114 pub by: ProviderSortStrategy,
116
117 #[serde(skip_serializing_if = "Option::is_none")]
119 pub partition: Option<SortPartition>,
120}
121
122impl ProviderSortConfig {
123 pub fn new(by: ProviderSortStrategy) -> Self {
125 Self {
126 by,
127 partition: None,
128 }
129 }
130
131 pub fn partition(mut self, partition: SortPartition) -> Self {
133 self.partition = Some(partition);
134 self
135 }
136}
137
138#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
143#[serde(untagged)]
144pub enum ProviderSort {
145 Simple(ProviderSortStrategy),
147 Complex(ProviderSortConfig),
149}
150
151impl From<ProviderSortStrategy> for ProviderSort {
152 fn from(strategy: ProviderSortStrategy) -> Self {
153 ProviderSort::Simple(strategy)
154 }
155}
156
157impl From<ProviderSortConfig> for ProviderSort {
158 fn from(config: ProviderSortConfig) -> Self {
159 ProviderSort::Complex(config)
160 }
161}
162
163#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
167#[serde(untagged)]
168pub enum ThroughputThreshold {
169 Simple(f64),
171 Percentile(PercentileThresholds),
173}
174
175#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
179#[serde(untagged)]
180pub enum LatencyThreshold {
181 Simple(f64),
183 Percentile(PercentileThresholds),
185}
186
187#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
189pub struct PercentileThresholds {
190 #[serde(skip_serializing_if = "Option::is_none")]
192 pub p50: Option<f64>,
193 #[serde(skip_serializing_if = "Option::is_none")]
195 pub p75: Option<f64>,
196 #[serde(skip_serializing_if = "Option::is_none")]
198 pub p90: Option<f64>,
199 #[serde(skip_serializing_if = "Option::is_none")]
201 pub p99: Option<f64>,
202}
203
204impl PercentileThresholds {
205 pub fn new() -> Self {
207 Self::default()
208 }
209
210 pub fn p50(mut self, value: f64) -> Self {
212 self.p50 = Some(value);
213 self
214 }
215
216 pub fn p75(mut self, value: f64) -> Self {
218 self.p75 = Some(value);
219 self
220 }
221
222 pub fn p90(mut self, value: f64) -> Self {
224 self.p90 = Some(value);
225 self
226 }
227
228 pub fn p99(mut self, value: f64) -> Self {
230 self.p99 = Some(value);
231 self
232 }
233}
234
235#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
240pub struct MaxPrice {
241 #[serde(skip_serializing_if = "Option::is_none")]
243 pub prompt: Option<f64>,
244 #[serde(skip_serializing_if = "Option::is_none")]
246 pub completion: Option<f64>,
247 #[serde(skip_serializing_if = "Option::is_none")]
249 pub request: Option<f64>,
250 #[serde(skip_serializing_if = "Option::is_none")]
252 pub image: Option<f64>,
253}
254
255impl MaxPrice {
256 pub fn new() -> Self {
258 Self::default()
259 }
260
261 pub fn prompt(mut self, price: f64) -> Self {
263 self.prompt = Some(price);
264 self
265 }
266
267 pub fn completion(mut self, price: f64) -> Self {
269 self.completion = Some(price);
270 self
271 }
272
273 pub fn request(mut self, price: f64) -> Self {
275 self.request = Some(price);
276 self
277 }
278
279 pub fn image(mut self, price: f64) -> Self {
281 self.image = Some(price);
282 self
283 }
284}
285
286#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
306pub struct ProviderPreferences {
307 #[serde(skip_serializing_if = "Option::is_none")]
311 pub order: Option<Vec<String>>,
312
313 #[serde(skip_serializing_if = "Option::is_none")]
315 pub only: Option<Vec<String>>,
316
317 #[serde(skip_serializing_if = "Option::is_none")]
319 pub ignore: Option<Vec<String>>,
320
321 #[serde(skip_serializing_if = "Option::is_none")]
324 pub allow_fallbacks: Option<bool>,
325
326 #[serde(skip_serializing_if = "Option::is_none")]
330 pub require_parameters: Option<bool>,
331
332 #[serde(skip_serializing_if = "Option::is_none")]
335 pub data_collection: Option<DataCollection>,
336
337 #[serde(skip_serializing_if = "Option::is_none")]
339 pub zdr: Option<bool>,
340
341 #[serde(skip_serializing_if = "Option::is_none")]
345 pub sort: Option<ProviderSort>,
346
347 #[serde(skip_serializing_if = "Option::is_none")]
349 pub preferred_min_throughput: Option<ThroughputThreshold>,
350
351 #[serde(skip_serializing_if = "Option::is_none")]
353 pub preferred_max_latency: Option<LatencyThreshold>,
354
355 #[serde(skip_serializing_if = "Option::is_none")]
357 pub max_price: Option<MaxPrice>,
358
359 #[serde(skip_serializing_if = "Option::is_none")]
362 pub quantizations: Option<Vec<Quantization>>,
363}
364
365impl ProviderPreferences {
366 pub fn new() -> Self {
368 Self::default()
369 }
370
371 pub fn order(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
387 self.order = Some(providers.into_iter().map(|p| p.into()).collect());
388 self
389 }
390
391 pub fn only(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
403 self.only = Some(providers.into_iter().map(|p| p.into()).collect());
404 self
405 }
406
407 pub fn ignore(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
418 self.ignore = Some(providers.into_iter().map(|p| p.into()).collect());
419 self
420 }
421
422 pub fn allow_fallbacks(mut self, allow: bool) -> Self {
427 self.allow_fallbacks = Some(allow);
428 self
429 }
430
431 pub fn require_parameters(mut self, require: bool) -> Self {
437 self.require_parameters = Some(require);
438 self
439 }
440
441 pub fn data_collection(mut self, policy: DataCollection) -> Self {
445 self.data_collection = Some(policy);
446 self
447 }
448
449 pub fn zdr(mut self, enable: bool) -> Self {
460 self.zdr = Some(enable);
461 self
462 }
463
464 pub fn sort(mut self, sort: impl Into<ProviderSort>) -> Self {
480 self.sort = Some(sort.into());
481 self
482 }
483
484 pub fn preferred_min_throughput(mut self, threshold: ThroughputThreshold) -> Self {
504 self.preferred_min_throughput = Some(threshold);
505 self
506 }
507
508 pub fn preferred_max_latency(mut self, threshold: LatencyThreshold) -> Self {
512 self.preferred_max_latency = Some(threshold);
513 self
514 }
515
516 pub fn max_price(mut self, price: MaxPrice) -> Self {
520 self.max_price = Some(price);
521 self
522 }
523
524 pub fn quantizations(mut self, quantizations: impl IntoIterator<Item = Quantization>) -> Self {
537 self.quantizations = Some(quantizations.into_iter().collect());
538 self
539 }
540
541 pub fn zero_data_retention(self) -> Self {
545 self.zdr(true)
546 }
547
548 pub fn fastest(self) -> Self {
550 self.sort(ProviderSortStrategy::Throughput)
551 }
552
553 pub fn cheapest(self) -> Self {
555 self.sort(ProviderSortStrategy::Price)
556 }
557
558 pub fn lowest_latency(self) -> Self {
560 self.sort(ProviderSortStrategy::Latency)
561 }
562
563 pub fn to_json(&self) -> serde_json::Value {
565 serde_json::json!({
566 "provider": self
567 })
568 }
569}
570
571#[derive(Debug, Serialize, Deserialize)]
575pub struct CompletionResponse {
576 pub id: String,
577 pub object: String,
578 pub created: u64,
579 pub model: String,
580 pub choices: Vec<Choice>,
581 pub system_fingerprint: Option<String>,
582 pub usage: Option<Usage>,
583}
584
585impl From<ApiErrorResponse> for CompletionError {
586 fn from(err: ApiErrorResponse) -> Self {
587 CompletionError::ProviderError(err.message)
588 }
589}
590
591impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
592 type Error = CompletionError;
593
594 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
595 let choice = response.choices.first().ok_or_else(|| {
596 CompletionError::ResponseError("Response contained no choices".to_owned())
597 })?;
598
599 let content = match &choice.message {
600 Message::Assistant {
601 content,
602 tool_calls,
603 reasoning,
604 reasoning_details,
605 ..
606 } => {
607 let mut content = content
608 .iter()
609 .map(|c| match c {
610 openai::AssistantContent::Text { text } => {
611 completion::AssistantContent::text(text)
612 }
613 openai::AssistantContent::Refusal { refusal } => {
614 completion::AssistantContent::text(refusal)
615 }
616 })
617 .collect::<Vec<_>>();
618
619 content.extend(tool_calls.iter().map(|call| {
620 completion::AssistantContent::tool_call(
621 &call.id,
622 &call.function.name,
623 call.function.arguments.clone(),
624 )
625 }));
626
627 let mut grouped_reasoning: HashMap<
628 Option<String>,
629 Vec<(usize, usize, message::ReasoningContent)>,
630 > = HashMap::new();
631 let mut reasoning_order: Vec<Option<String>> = Vec::new();
632 for (position, detail) in reasoning_details.iter().enumerate() {
633 let (reasoning_id, sort_index, parsed_content) = match detail {
634 ReasoningDetails::Summary {
635 id, index, summary, ..
636 } => (
637 id.clone(),
638 *index,
639 Some(message::ReasoningContent::Summary(summary.clone())),
640 ),
641 ReasoningDetails::Encrypted {
642 id, index, data, ..
643 } => (
644 id.clone(),
645 *index,
646 Some(message::ReasoningContent::Encrypted(data.clone())),
647 ),
648 ReasoningDetails::Text {
649 id,
650 index,
651 text,
652 signature,
653 ..
654 } => (
655 id.clone(),
656 *index,
657 text.as_ref().map(|text| message::ReasoningContent::Text {
658 text: text.clone(),
659 signature: signature.clone(),
660 }),
661 ),
662 };
663
664 let Some(parsed_content) = parsed_content else {
665 continue;
666 };
667 let sort_index = sort_index.unwrap_or(position);
668
669 let entry = grouped_reasoning.entry(reasoning_id.clone());
670 if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
671 reasoning_order.push(reasoning_id);
672 }
673 entry
674 .or_default()
675 .push((sort_index, position, parsed_content));
676 }
677
678 if grouped_reasoning.is_empty() {
679 if let Some(reasoning) = reasoning {
680 content.push(completion::AssistantContent::reasoning(reasoning));
681 }
682 } else {
683 for reasoning_id in reasoning_order {
684 let Some(mut blocks) = grouped_reasoning.remove(&reasoning_id) else {
685 continue;
686 };
687 blocks.sort_by_key(|(index, position, _)| (*index, *position));
688 content.push(completion::AssistantContent::Reasoning(
689 message::Reasoning {
690 id: reasoning_id,
691 content: blocks
692 .into_iter()
693 .map(|(_, _, content)| content)
694 .collect::<Vec<_>>(),
695 },
696 ));
697 }
698 }
699
700 Ok(content)
701 }
702 _ => Err(CompletionError::ResponseError(
703 "Response did not contain a valid message or tool call".into(),
704 )),
705 }?;
706
707 let choice = OneOrMany::many(content).map_err(|_| {
708 CompletionError::ResponseError(
709 "Response contained no message or tool call (empty)".to_owned(),
710 )
711 })?;
712
713 let usage = response
714 .usage
715 .as_ref()
716 .map(|usage| completion::Usage {
717 input_tokens: usage.prompt_tokens as u64,
718 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
719 total_tokens: usage.total_tokens as u64,
720 cached_input_tokens: 0,
721 cache_creation_input_tokens: 0,
722 })
723 .unwrap_or_default();
724
725 Ok(completion::CompletionResponse {
726 choice,
727 usage,
728 raw_response: response,
729 message_id: None,
730 })
731 }
732}
733
734#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
774#[serde(tag = "type", rename_all = "snake_case")]
775pub enum UserContent {
776 Text { text: String },
778
779 #[serde(rename = "image_url")]
783 ImageUrl { image_url: ImageUrl },
784
785 File { file: FileContent },
790
791 InputAudio { input_audio: openai::InputAudio },
795
796 #[serde(rename = "video_url")]
801 VideoUrl { video_url: VideoUrlContent },
802}
803
804impl UserContent {
805 pub fn text(text: impl Into<String>) -> Self {
807 UserContent::Text { text: text.into() }
808 }
809
810 pub fn image_url(url: impl Into<String>) -> Self {
812 UserContent::ImageUrl {
813 image_url: ImageUrl {
814 url: url.into(),
815 detail: None,
816 },
817 }
818 }
819
820 pub fn image_url_with_detail(url: impl Into<String>, detail: ImageDetail) -> Self {
822 UserContent::ImageUrl {
823 image_url: ImageUrl {
824 url: url.into(),
825 detail: Some(detail),
826 },
827 }
828 }
829
830 pub fn image_base64(
837 data: impl Into<String>,
838 mime_type: &str,
839 detail: Option<ImageDetail>,
840 ) -> Self {
841 let data_uri = format!("data:{};base64,{}", mime_type, data.into());
842 UserContent::ImageUrl {
843 image_url: ImageUrl {
844 url: data_uri,
845 detail,
846 },
847 }
848 }
849
850 pub fn file_url(url: impl Into<String>, filename: Option<String>) -> Self {
856 UserContent::File {
857 file: FileContent {
858 filename,
859 file_data: Some(url.into()),
860 },
861 }
862 }
863
864 pub fn file_base64(data: impl Into<String>, mime_type: &str, filename: Option<String>) -> Self {
871 let data_uri = format!("data:{};base64,{}", mime_type, data.into());
872 UserContent::File {
873 file: FileContent {
874 filename,
875 file_data: Some(data_uri),
876 },
877 }
878 }
879
880 pub fn audio_base64(data: impl Into<String>, format: AudioMediaType) -> Self {
888 UserContent::InputAudio {
889 input_audio: openai::InputAudio {
890 data: data.into(),
891 format,
892 },
893 }
894 }
895
896 pub fn video_url(url: impl Into<String>) -> Self {
903 UserContent::VideoUrl {
904 video_url: VideoUrlContent { url: url.into() },
905 }
906 }
907
908 pub fn video_base64(data: impl Into<String>, media_type: VideoMediaType) -> Self {
914 let mime = media_type.to_mime_type();
915 let data_uri = format!("data:{mime};base64,{}", data.into());
916 UserContent::VideoUrl {
917 video_url: VideoUrlContent { url: data_uri },
918 }
919 }
920}
921
922impl From<String> for UserContent {
923 fn from(text: String) -> Self {
924 UserContent::Text { text }
925 }
926}
927
928impl From<&str> for UserContent {
929 fn from(text: &str) -> Self {
930 UserContent::Text {
931 text: text.to_string(),
932 }
933 }
934}
935
936impl std::str::FromStr for UserContent {
937 type Err = std::convert::Infallible;
938
939 fn from_str(s: &str) -> Result<Self, Self::Err> {
940 Ok(UserContent::Text {
941 text: s.to_string(),
942 })
943 }
944}
945
946#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
948pub struct ImageUrl {
949 pub url: String,
951 #[serde(skip_serializing_if = "Option::is_none")]
953 pub detail: Option<ImageDetail>,
954}
955
956#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
962pub struct VideoUrlContent {
963 pub url: String,
965}
966
967#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
974pub struct FileContent {
975 #[serde(skip_serializing_if = "Option::is_none")]
977 pub filename: Option<String>,
978 #[serde(skip_serializing_if = "Option::is_none")]
980 pub file_data: Option<String>,
981}
982
983fn serialize_user_content<S>(
986 content: &OneOrMany<UserContent>,
987 serializer: S,
988) -> Result<S::Ok, S::Error>
989where
990 S: Serializer,
991{
992 if content.len() == 1
993 && let UserContent::Text { text } = content.first_ref()
994 {
995 return serializer.serialize_str(text);
996 }
997 content.serialize(serializer)
998}
999
1000impl TryFrom<message::UserContent> for UserContent {
1001 type Error = message::MessageError;
1002
1003 fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
1004 match value {
1005 message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
1006
1007 message::UserContent::Image(message::Image {
1008 data,
1009 detail,
1010 media_type,
1011 ..
1012 }) => {
1013 let url = match data {
1014 DocumentSourceKind::Url(url) => url,
1015 DocumentSourceKind::Base64(data) => {
1016 let mime = media_type
1017 .ok_or_else(|| {
1018 message::MessageError::ConversionError(
1019 "Image media type required for base64 encoding".into(),
1020 )
1021 })?
1022 .to_mime_type();
1023 format!("data:{mime};base64,{data}")
1024 }
1025 DocumentSourceKind::Raw(_) => {
1026 return Err(message::MessageError::ConversionError(
1027 "Raw bytes not supported, encode as base64 first".into(),
1028 ));
1029 }
1030 DocumentSourceKind::String(_) => {
1031 return Err(message::MessageError::ConversionError(
1032 "String source not supported for images".into(),
1033 ));
1034 }
1035 DocumentSourceKind::Unknown => {
1036 return Err(message::MessageError::ConversionError(
1037 "Image has no data".into(),
1038 ));
1039 }
1040 };
1041 Ok(UserContent::ImageUrl {
1042 image_url: ImageUrl { url, detail },
1043 })
1044 }
1045
1046 message::UserContent::Document(message::Document {
1047 data, media_type, ..
1048 }) => match data {
1049 DocumentSourceKind::Url(url) => {
1050 let filename = media_type.as_ref().map(|mt| match mt {
1051 DocumentMediaType::PDF => "document.pdf",
1052 DocumentMediaType::TXT => "document.txt",
1053 DocumentMediaType::HTML => "document.html",
1054 DocumentMediaType::MARKDOWN => "document.md",
1055 DocumentMediaType::CSV => "document.csv",
1056 DocumentMediaType::XML => "document.xml",
1057 _ => "document",
1058 });
1059 Ok(UserContent::File {
1060 file: FileContent {
1061 filename: filename.map(String::from),
1062 file_data: Some(url),
1063 },
1064 })
1065 }
1066 DocumentSourceKind::Base64(data) => {
1067 let mime = media_type
1068 .as_ref()
1069 .map(|m| m.to_mime_type())
1070 .unwrap_or("application/pdf");
1071 let data_uri = format!("data:{mime};base64,{data}");
1072
1073 let filename = media_type.as_ref().map(|mt| match mt {
1074 DocumentMediaType::PDF => "document.pdf",
1075 DocumentMediaType::TXT => "document.txt",
1076 DocumentMediaType::HTML => "document.html",
1077 DocumentMediaType::MARKDOWN => "document.md",
1078 DocumentMediaType::CSV => "document.csv",
1079 DocumentMediaType::XML => "document.xml",
1080 _ => "document",
1081 });
1082
1083 Ok(UserContent::File {
1084 file: FileContent {
1085 filename: filename.map(String::from),
1086 file_data: Some(data_uri),
1087 },
1088 })
1089 }
1090 DocumentSourceKind::String(text) => Ok(UserContent::Text { text }),
1091 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
1092 "Raw bytes not supported for documents, encode as base64 first".into(),
1093 )),
1094 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
1095 "Document has no data".into(),
1096 )),
1097 },
1098
1099 message::UserContent::Audio(message::Audio {
1100 data, media_type, ..
1101 }) => match data {
1102 DocumentSourceKind::Base64(data) => {
1103 let format = media_type.ok_or_else(|| {
1104 message::MessageError::ConversionError(
1105 "Audio media type required for base64 encoding".into(),
1106 )
1107 })?;
1108 Ok(UserContent::InputAudio {
1109 input_audio: openai::InputAudio { data, format },
1110 })
1111 }
1112 DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
1113 "OpenRouter does not support audio URLs, encode as base64 first".into(),
1114 )),
1115 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
1116 "Raw bytes not supported for audio, encode as base64 first".into(),
1117 )),
1118 DocumentSourceKind::String(_) => Err(message::MessageError::ConversionError(
1119 "String source not supported for audio".into(),
1120 )),
1121 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
1122 "Audio has no data".into(),
1123 )),
1124 },
1125
1126 message::UserContent::Video(message::Video {
1127 data, media_type, ..
1128 }) => {
1129 let url = match data {
1130 DocumentSourceKind::Url(url) => url,
1131 DocumentSourceKind::Base64(data) => {
1132 let mime = media_type
1133 .ok_or_else(|| {
1134 message::MessageError::ConversionError(
1135 "Video media type required for base64 encoding".into(),
1136 )
1137 })?
1138 .to_mime_type();
1139 format!("data:{mime};base64,{data}")
1140 }
1141 DocumentSourceKind::Raw(_) => {
1142 return Err(message::MessageError::ConversionError(
1143 "Raw bytes not supported for video, encode as base64 first".into(),
1144 ));
1145 }
1146 DocumentSourceKind::String(_) => {
1147 return Err(message::MessageError::ConversionError(
1148 "String source not supported for video".into(),
1149 ));
1150 }
1151 DocumentSourceKind::Unknown => {
1152 return Err(message::MessageError::ConversionError(
1153 "Video has no data".into(),
1154 ));
1155 }
1156 };
1157 Ok(UserContent::VideoUrl {
1158 video_url: VideoUrlContent { url },
1159 })
1160 }
1161
1162 message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
1163 "Tool results should be handled as separate messages".into(),
1164 )),
1165 }
1166 }
1167}
1168
1169impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
1170 type Error = message::MessageError;
1171
1172 fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
1173 let (tool_results, other_content): (Vec<_>, Vec<_>) = value
1174 .into_iter()
1175 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
1176
1177 if !tool_results.is_empty() {
1180 tool_results
1181 .into_iter()
1182 .map(|content| match content {
1183 message::UserContent::ToolResult(tool_result) => Ok(Message::ToolResult {
1184 tool_call_id: tool_result.id,
1185 content: tool_result
1186 .content
1187 .into_iter()
1188 .map(|c| match c {
1189 message::ToolResultContent::Text(message::Text { text }) => text,
1190 message::ToolResultContent::Image(_) => {
1191 "[Image content not supported in tool results]".to_string()
1192 }
1193 })
1194 .collect::<Vec<_>>()
1195 .join("\n"),
1196 }),
1197 _ => unreachable!(),
1198 })
1199 .collect::<Result<Vec<_>, _>>()
1200 } else {
1201 let user_content: Vec<UserContent> = other_content
1202 .into_iter()
1203 .map(|content| content.try_into())
1204 .collect::<Result<Vec<_>, _>>()?;
1205
1206 let content = OneOrMany::many(user_content)
1207 .expect("There must be content here if there were no tool result content");
1208
1209 Ok(vec![Message::User {
1210 content,
1211 name: None,
1212 }])
1213 }
1214 }
1215}
1216
1217#[derive(Debug, Deserialize, Serialize)]
1222pub struct Choice {
1223 pub index: usize,
1224 pub native_finish_reason: Option<String>,
1225 pub message: Message,
1226 pub finish_reason: Option<String>,
1227}
1228
/// A chat message in OpenRouter's wire format, internally tagged by `role`.
///
/// Roles serialize in lowercase (`"system"`, `"user"`, `"assistant"`), except
/// `ToolResult`, which uses `"tool"`. A `"developer"` role is also accepted when
/// deserializing, and maps to `System`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    #[serde(alias = "developer")]
    System {
        // NOTE(review): `string_or_one_or_many` is defined elsewhere; by its name it
        // presumably also accepts a bare string — confirm.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<openai::SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        // `serialize_user_content` writes a single text part as a bare string.
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        // Missing content deserializes as an empty list (`default`).
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<openai::AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<openai::AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // Missing tool calls deserialize as empty; an empty list is not serialized.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<openai::ToolCall>,
        // Plain-text reasoning; used as a fallback when there are no structured details.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning: Option<String>,
        // Structured reasoning entries; an empty list is not serialized.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        reasoning_details: Vec<ReasoningDetails>,
    },
    // Result of a tool call, sent with role `"tool"`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
1279
1280impl Message {
1281 pub fn system(content: &str) -> Self {
1282 Message::System {
1283 content: OneOrMany::one(content.to_owned().into()),
1284 name: None,
1285 }
1286 }
1287}
1288
1289#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
1290#[serde(tag = "type", rename_all = "snake_case")]
1291pub enum ReasoningDetails {
1292 #[serde(rename = "reasoning.summary")]
1293 Summary {
1294 id: Option<String>,
1295 format: Option<String>,
1296 index: Option<usize>,
1297 summary: String,
1298 },
1299 #[serde(rename = "reasoning.encrypted")]
1300 Encrypted {
1301 id: Option<String>,
1302 format: Option<String>,
1303 index: Option<usize>,
1304 data: String,
1305 },
1306 #[serde(rename = "reasoning.text")]
1307 Text {
1308 id: Option<String>,
1309 format: Option<String>,
1310 index: Option<usize>,
1311 text: Option<String>,
1312 signature: Option<String>,
1313 },
1314}
1315
1316#[derive(Debug, Deserialize, PartialEq, Clone)]
1317#[serde(untagged)]
1318enum ToolCallAdditionalParams {
1319 ReasoningDetails(ReasoningDetails),
1320 Minimal {
1321 id: Option<String>,
1322 format: Option<String>,
1323 },
1324}
1325
1326impl From<openai::UserContent> for UserContent {
1328 fn from(value: openai::UserContent) -> Self {
1329 match value {
1330 openai::UserContent::Text { text } => UserContent::Text { text },
1331 openai::UserContent::Image { image_url } => UserContent::ImageUrl {
1332 image_url: ImageUrl {
1333 url: image_url.url,
1334 detail: Some(image_url.detail),
1335 },
1336 },
1337 openai::UserContent::Audio { input_audio } => UserContent::InputAudio { input_audio },
1338 }
1339 }
1340}
1341
1342impl From<openai::Message> for Message {
1343 fn from(value: openai::Message) -> Self {
1344 match value {
1345 openai::Message::System { content, name } => Self::System { content, name },
1346 openai::Message::User { content, name } => {
1347 let converted_content = content.map(UserContent::from);
1349 Self::User {
1350 content: converted_content,
1351 name,
1352 }
1353 }
1354 openai::Message::Assistant {
1355 content,
1356 refusal,
1357 audio,
1358 name,
1359 tool_calls,
1360 } => Self::Assistant {
1361 content,
1362 refusal,
1363 audio,
1364 name,
1365 tool_calls,
1366 reasoning: None,
1367 reasoning_details: Vec::new(),
1368 },
1369 openai::Message::ToolResult {
1370 tool_call_id,
1371 content,
1372 } => Self::ToolResult {
1373 tool_call_id,
1374 content: content.as_text(),
1375 },
1376 }
1377 }
1378}
1379
impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Folds a rig assistant turn into a single OpenRouter assistant message.
    ///
    /// Text parts become `content` and tool calls become `tool_calls`. Reasoning is
    /// emitted as `reasoning_details`, or, for a reasoning block with no structured
    /// content, as the plain `reasoning` string. Tool calls may also contribute
    /// `reasoning.encrypted` details, taken from their `additional_params` and/or
    /// `signature`.
    ///
    /// # Errors
    /// Returns `MessageError::ConversionError` if the turn contains an image.
    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
        let mut text_content = Vec::new();
        let mut tool_calls = Vec::new();
        // Only the last non-empty display text is kept if several structured-less
        // reasoning blocks appear.
        let mut reasoning = None;
        let mut reasoning_details = Vec::new();

        for content in value.into_iter() {
            match content {
                message::AssistantContent::Text(text) => text_content.push(text),
                message::AssistantContent::ToolCall(tool_call) => {
                    // Try to recover reasoning metadata stashed in the tool call's
                    // `additional_params`; otherwise fall back to its bare signature.
                    if let Some(additional_params) = &tool_call.additional_params
                        && let Ok(additional_params) =
                            serde_json::from_value::<ToolCallAdditionalParams>(
                                additional_params.clone(),
                            )
                    {
                        match additional_params {
                            ToolCallAdditionalParams::ReasoningDetails(full) => {
                                reasoning_details.push(full);
                            }
                            ToolCallAdditionalParams::Minimal { id, format } => {
                                // Prefer the stored id; otherwise use the tool call's call_id.
                                // An entry is pushed only when both an id and a signature exist.
                                let id = id.or_else(|| tool_call.call_id.clone());
                                if let Some(signature) = &tool_call.signature
                                    && let Some(id) = id
                                {
                                    reasoning_details.push(ReasoningDetails::Encrypted {
                                        id: Some(id),
                                        format,
                                        index: None,
                                        data: signature.clone(),
                                    })
                                }
                            }
                        }
                    } else if let Some(signature) = &tool_call.signature {
                        // No (parseable) additional params: keep the signature as an
                        // encrypted detail keyed by the call id, which may be `None`.
                        reasoning_details.push(ReasoningDetails::Encrypted {
                            id: tool_call.call_id.clone(),
                            format: None,
                            index: None,
                            data: signature.clone(),
                        });
                    }
                    tool_calls.push(tool_call.into())
                }
                message::AssistantContent::Reasoning(r) => {
                    if r.content.is_empty() {
                        // No structured blocks: fall back to the display text, if any.
                        let display = r.display_text();
                        if !display.is_empty() {
                            reasoning = Some(display);
                        }
                    } else {
                        for reasoning_block in &r.content {
                            // `index` is the entry's position in the whole
                            // `reasoning_details` list, including entries contributed by
                            // tool calls, so it increases across the turn.
                            let index = Some(reasoning_details.len());
                            match reasoning_block {
                                message::ReasoningContent::Text { text, signature } => {
                                    reasoning_details.push(ReasoningDetails::Text {
                                        id: r.id.clone(),
                                        format: None,
                                        index,
                                        text: Some(text.clone()),
                                        signature: signature.clone(),
                                    });
                                }
                                message::ReasoningContent::Summary(summary) => {
                                    reasoning_details.push(ReasoningDetails::Summary {
                                        id: r.id.clone(),
                                        format: None,
                                        index,
                                        summary: summary.clone(),
                                    });
                                }
                                // Encrypted and redacted blocks share the same wire shape.
                                message::ReasoningContent::Encrypted(data)
                                | message::ReasoningContent::Redacted { data } => {
                                    reasoning_details.push(ReasoningDetails::Encrypted {
                                        id: r.id.clone(),
                                        format: None,
                                        index,
                                        data: data.clone(),
                                    });
                                }
                            }
                        }
                    }
                }
                message::AssistantContent::Image(_) => {
                    return Err(Self::Error::ConversionError(
                        "OpenRouter currently doesn't support images.".into(),
                    ));
                }
            }
        }

        // Always exactly one assistant message, even if every list is empty.
        Ok(vec![Message::Assistant {
            content: text_content
                .into_iter()
                .map(|content| content.text.into())
                .collect::<Vec<_>>(),
            refusal: None,
            audio: None,
            name: None,
            tool_calls,
            reasoning,
            reasoning_details,
        }])
    }
}
1496
1497impl TryFrom<message::Message> for Vec<Message> {
1500 type Error = message::MessageError;
1501
1502 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
1503 match message {
1504 message::Message::System { content } => Ok(vec![Message::System {
1505 content: OneOrMany::one(content.into()),
1506 name: None,
1507 }]),
1508 message::Message::User { content } => {
1509 content.try_into()
1512 }
1513 message::Message::Assistant { content, .. } => content.try_into(),
1514 }
1515 }
1516}
1517
1518#[derive(Debug, Serialize, Deserialize)]
1519#[serde(untagged, rename_all = "snake_case")]
1520pub enum ToolChoice {
1521 None,
1522 Auto,
1523 Required,
1524 Function(Vec<ToolChoiceFunctionKind>),
1525}
1526
1527impl TryFrom<crate::message::ToolChoice> for ToolChoice {
1528 type Error = CompletionError;
1529
1530 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
1531 let res = match value {
1532 crate::message::ToolChoice::None => Self::None,
1533 crate::message::ToolChoice::Auto => Self::Auto,
1534 crate::message::ToolChoice::Required => Self::Required,
1535 crate::message::ToolChoice::Specific { function_names } => {
1536 let vec: Vec<ToolChoiceFunctionKind> = function_names
1537 .into_iter()
1538 .map(|name| ToolChoiceFunctionKind::Function { name })
1539 .collect();
1540
1541 Self::Function(vec)
1542 }
1543 };
1544
1545 Ok(res)
1546 }
1547}
1548
1549#[derive(Debug, Serialize, Deserialize)]
1550#[serde(tag = "type", content = "function")]
1551pub enum ToolChoiceFunctionKind {
1552 Function { name: String },
1553}
1554
1555#[derive(Debug, Serialize, Deserialize)]
1556pub(super) struct OpenrouterCompletionRequest {
1557 model: String,
1558 pub messages: Vec<Message>,
1559 #[serde(skip_serializing_if = "Option::is_none")]
1560 temperature: Option<f64>,
1561 #[serde(skip_serializing_if = "Vec::is_empty")]
1562 tools: Vec<crate::providers::openai::completion::ToolDefinition>,
1563 #[serde(skip_serializing_if = "Option::is_none")]
1564 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
1565 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1566 pub additional_params: Option<serde_json::Value>,
1567}
1568
1569pub struct OpenRouterRequestParams<'a> {
1571 pub model: &'a str,
1572 pub request: CompletionRequest,
1573 pub strict_tools: bool,
1574}
1575
1576impl TryFrom<OpenRouterRequestParams<'_>> for OpenrouterCompletionRequest {
1577 type Error = CompletionError;
1578
1579 fn try_from(params: OpenRouterRequestParams) -> Result<Self, Self::Error> {
1580 let OpenRouterRequestParams {
1581 model,
1582 request: req,
1583 strict_tools,
1584 } = params;
1585 let model = req.model.clone().unwrap_or_else(|| model.to_string());
1586
1587 if req.output_schema.is_some() {
1588 tracing::warn!("Structured outputs currently not supported for OpenRouter");
1589 }
1590
1591 let mut full_history: Vec<Message> = match &req.preamble {
1592 Some(preamble) => vec![Message::system(preamble)],
1593 None => vec![],
1594 };
1595 if let Some(docs) = req.normalized_documents() {
1596 let docs: Vec<Message> = docs.try_into()?;
1597 full_history.extend(docs);
1598 }
1599
1600 let chat_history: Vec<Message> = req
1601 .chat_history
1602 .clone()
1603 .into_iter()
1604 .map(|message| message.try_into())
1605 .collect::<Result<Vec<Vec<Message>>, _>>()?
1606 .into_iter()
1607 .flatten()
1608 .collect();
1609
1610 full_history.extend(chat_history);
1611
1612 let tool_choice = req
1613 .tool_choice
1614 .clone()
1615 .map(crate::providers::openai::completion::ToolChoice::try_from)
1616 .transpose()?;
1617
1618 let tools: Vec<crate::providers::openai::completion::ToolDefinition> = req
1619 .tools
1620 .clone()
1621 .into_iter()
1622 .map(|tool| {
1623 let def = crate::providers::openai::completion::ToolDefinition::from(tool);
1624 if strict_tools { def.with_strict() } else { def }
1625 })
1626 .collect();
1627
1628 Ok(Self {
1629 model,
1630 messages: full_history,
1631 temperature: req.temperature,
1632 tools,
1633 tool_choice,
1634 additional_params: req.additional_params,
1635 })
1636 }
1637}
1638
1639impl TryFrom<(&str, CompletionRequest)> for OpenrouterCompletionRequest {
1640 type Error = CompletionError;
1641
1642 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
1643 let model = req.model.clone().unwrap_or_else(|| model.to_string());
1644 OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
1645 model: &model,
1646 request: req,
1647 strict_tools: false,
1648 })
1649 }
1650}
1651
1652#[derive(Clone)]
1653pub struct CompletionModel<T = reqwest::Client> {
1654 pub(crate) client: Client<T>,
1655 pub model: String,
1656 pub strict_tools: bool,
1659}
1660
1661impl<T> CompletionModel<T> {
1662 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
1663 Self {
1664 client,
1665 model: model.into(),
1666 strict_tools: false,
1667 }
1668 }
1669
1670 pub fn with_strict_tools(mut self) -> Self {
1679 self.strict_tools = true;
1680 self
1681 }
1682}
1683
1684impl<T> completion::CompletionModel for CompletionModel<T>
1685where
1686 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
1687{
1688 type Response = CompletionResponse;
1689 type StreamingResponse = StreamingCompletionResponse;
1690
1691 type Client = Client<T>;
1692
1693 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1694 Self::new(client.clone(), model)
1695 }
1696
1697 async fn completion(
1698 &self,
1699 completion_request: CompletionRequest,
1700 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1701 let request_model = completion_request
1702 .model
1703 .clone()
1704 .unwrap_or_else(|| self.model.clone());
1705 let preamble = completion_request.preamble.clone();
1706 let request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
1707 model: request_model.as_ref(),
1708 request: completion_request,
1709 strict_tools: self.strict_tools,
1710 })?;
1711
1712 if enabled!(Level::TRACE) {
1713 tracing::trace!(
1714 target: "rig::completions",
1715 "OpenRouter completion request: {}",
1716 serde_json::to_string_pretty(&request)?
1717 );
1718 }
1719
1720 let span = if tracing::Span::current().is_disabled() {
1721 info_span!(
1722 target: "rig::completions",
1723 "chat",
1724 gen_ai.operation.name = "chat",
1725 gen_ai.provider.name = "openrouter",
1726 gen_ai.request.model = &request_model,
1727 gen_ai.system_instructions = preamble,
1728 gen_ai.response.id = tracing::field::Empty,
1729 gen_ai.response.model = tracing::field::Empty,
1730 gen_ai.usage.output_tokens = tracing::field::Empty,
1731 gen_ai.usage.input_tokens = tracing::field::Empty,
1732 gen_ai.usage.cached_tokens = tracing::field::Empty,
1733 )
1734 } else {
1735 tracing::Span::current()
1736 };
1737
1738 let body = serde_json::to_vec(&request)?;
1739
1740 let req = self
1741 .client
1742 .post("/chat/completions")?
1743 .body(body)
1744 .map_err(|x| CompletionError::HttpError(x.into()))?;
1745
1746 async move {
1747 let response = self.client.send::<_, Bytes>(req).await?;
1748 let status = response.status();
1749 let response_body = response.into_body().into_future().await?.to_vec();
1750
1751 if status.is_success() {
1752 let parsed: ApiResponse<CompletionResponse> =
1753 serde_json::from_slice(&response_body).map_err(|e| {
1754 CompletionError::ResponseError(format!(
1755 "Failed to parse OpenRouter completion response: {}, response body: {}",
1756 e,
1757 String::from_utf8_lossy(&response_body)
1758 ))
1759 })?;
1760 match parsed {
1761 ApiResponse::Ok(response) => {
1762 let span = tracing::Span::current();
1763 span.record_token_usage(&response.usage);
1764 span.record("gen_ai.response.id", &response.id);
1765 span.record("gen_ai.response.model_name", &response.model);
1766
1767 tracing::debug!(target: "rig::completions",
1768 "OpenRouter response: {response:?}");
1769 response.try_into()
1770 }
1771 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
1772 }
1773 } else {
1774 Err(CompletionError::ProviderError(
1775 String::from_utf8_lossy(&response_body).to_string(),
1776 ))
1777 }
1778 }
1779 .instrument(span)
1780 .await
1781 }
1782
1783 async fn stream(
1784 &self,
1785 completion_request: CompletionRequest,
1786 ) -> Result<
1787 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1788 CompletionError,
1789 > {
1790 CompletionModel::stream(self, completion_request).await
1791 }
1792}
1793
1794#[cfg(test)]
1795mod tests {
1796 use super::*;
1797 use serde_json::json;
1798
1799 #[test]
1800 fn test_openrouter_request_uses_request_model_override() {
1801 let request = CompletionRequest {
1802 model: Some("google/gemini-2.5-flash".to_string()),
1803 preamble: None,
1804 chat_history: crate::OneOrMany::one("Hello".into()),
1805 documents: vec![],
1806 tools: vec![],
1807 temperature: None,
1808 max_tokens: None,
1809 tool_choice: None,
1810 additional_params: None,
1811 output_schema: None,
1812 };
1813
1814 let openrouter_request =
1815 OpenrouterCompletionRequest::try_from(("openai/gpt-4o-mini", request))
1816 .expect("request conversion should succeed");
1817 let serialized =
1818 serde_json::to_value(openrouter_request).expect("serialization should succeed");
1819
1820 assert_eq!(serialized["model"], "google/gemini-2.5-flash");
1821 }
1822
1823 #[test]
1824 fn test_openrouter_request_uses_default_model_when_override_unset() {
1825 let request = CompletionRequest {
1826 model: None,
1827 preamble: None,
1828 chat_history: crate::OneOrMany::one("Hello".into()),
1829 documents: vec![],
1830 tools: vec![],
1831 temperature: None,
1832 max_tokens: None,
1833 tool_choice: None,
1834 additional_params: None,
1835 output_schema: None,
1836 };
1837
1838 let openrouter_request =
1839 OpenrouterCompletionRequest::try_from(("openai/gpt-4o-mini", request))
1840 .expect("request conversion should succeed");
1841 let serialized =
1842 serde_json::to_value(openrouter_request).expect("serialization should succeed");
1843
1844 assert_eq!(serialized["model"], "openai/gpt-4o-mini");
1845 }
1846
1847 #[test]
1848 fn test_completion_response_deserialization_gemini_flash() {
1849 let json = json!({
1851 "id": "gen-AAAAAAAAAA-AAAAAAAAAAAAAAAAAAAA",
1852 "provider": "Google",
1853 "model": "google/gemini-2.5-flash",
1854 "object": "chat.completion",
1855 "created": 1765971703u64,
1856 "choices": [{
1857 "logprobs": null,
1858 "finish_reason": "stop",
1859 "native_finish_reason": "STOP",
1860 "index": 0,
1861 "message": {
1862 "role": "assistant",
1863 "content": "CONTENT",
1864 "refusal": null,
1865 "reasoning": null
1866 }
1867 }],
1868 "usage": {
1869 "prompt_tokens": 669,
1870 "completion_tokens": 5,
1871 "total_tokens": 674
1872 }
1873 });
1874
1875 let response: CompletionResponse = serde_json::from_value(json).unwrap();
1876 assert_eq!(response.id, "gen-AAAAAAAAAA-AAAAAAAAAAAAAAAAAAAA");
1877 assert_eq!(response.model, "google/gemini-2.5-flash");
1878 assert_eq!(response.choices.len(), 1);
1879 assert_eq!(response.choices[0].finish_reason, Some("stop".to_string()));
1880 }
1881
1882 #[test]
1883 fn test_message_assistant_without_reasoning_details() {
1884 let json = json!({
1886 "role": "assistant",
1887 "content": "Hello world",
1888 "refusal": null,
1889 "reasoning": null
1890 });
1891
1892 let message: Message = serde_json::from_value(json).unwrap();
1893 match message {
1894 Message::Assistant {
1895 content,
1896 reasoning_details,
1897 ..
1898 } => {
1899 assert_eq!(content.len(), 1);
1900 assert!(reasoning_details.is_empty());
1901 }
1902 _ => panic!("Expected Assistant message"),
1903 }
1904 }
1905
1906 #[test]
1907 fn test_data_collection_serialization() {
1908 assert_eq!(
1909 serde_json::to_string(&DataCollection::Allow).unwrap(),
1910 r#""allow""#
1911 );
1912 assert_eq!(
1913 serde_json::to_string(&DataCollection::Deny).unwrap(),
1914 r#""deny""#
1915 );
1916 }
1917
1918 #[test]
1919 fn test_data_collection_default() {
1920 assert_eq!(DataCollection::default(), DataCollection::Allow);
1921 }
1922
1923 #[test]
1924 fn test_quantization_serialization() {
1925 assert_eq!(
1926 serde_json::to_string(&Quantization::Int4).unwrap(),
1927 r#""int4""#
1928 );
1929 assert_eq!(
1930 serde_json::to_string(&Quantization::Int8).unwrap(),
1931 r#""int8""#
1932 );
1933 assert_eq!(
1934 serde_json::to_string(&Quantization::Fp16).unwrap(),
1935 r#""fp16""#
1936 );
1937 assert_eq!(
1938 serde_json::to_string(&Quantization::Bf16).unwrap(),
1939 r#""bf16""#
1940 );
1941 assert_eq!(
1942 serde_json::to_string(&Quantization::Fp32).unwrap(),
1943 r#""fp32""#
1944 );
1945 assert_eq!(
1946 serde_json::to_string(&Quantization::Fp8).unwrap(),
1947 r#""fp8""#
1948 );
1949 assert_eq!(
1950 serde_json::to_string(&Quantization::Unknown).unwrap(),
1951 r#""unknown""#
1952 );
1953 }
1954
1955 #[test]
1956 fn test_provider_sort_strategy_serialization() {
1957 assert_eq!(
1958 serde_json::to_string(&ProviderSortStrategy::Price).unwrap(),
1959 r#""price""#
1960 );
1961 assert_eq!(
1962 serde_json::to_string(&ProviderSortStrategy::Throughput).unwrap(),
1963 r#""throughput""#
1964 );
1965 assert_eq!(
1966 serde_json::to_string(&ProviderSortStrategy::Latency).unwrap(),
1967 r#""latency""#
1968 );
1969 }
1970
1971 #[test]
1972 fn test_sort_partition_serialization() {
1973 assert_eq!(
1974 serde_json::to_string(&SortPartition::Model).unwrap(),
1975 r#""model""#
1976 );
1977 assert_eq!(
1978 serde_json::to_string(&SortPartition::None).unwrap(),
1979 r#""none""#
1980 );
1981 }
1982
1983 #[test]
1984 fn test_provider_sort_simple() {
1985 let sort = ProviderSort::Simple(ProviderSortStrategy::Latency);
1986 let json = serde_json::to_value(&sort).unwrap();
1987 assert_eq!(json, "latency");
1988 }
1989
1990 #[test]
1991 fn test_provider_sort_complex() {
1992 let sort = ProviderSort::Complex(
1993 ProviderSortConfig::new(ProviderSortStrategy::Price).partition(SortPartition::None),
1994 );
1995 let json = serde_json::to_value(&sort).unwrap();
1996 assert_eq!(json["by"], "price");
1997 assert_eq!(json["partition"], "none");
1998 }
1999
2000 #[test]
2001 fn test_provider_sort_complex_without_partition() {
2002 let sort = ProviderSort::Complex(ProviderSortConfig::new(ProviderSortStrategy::Throughput));
2003 let json = serde_json::to_value(&sort).unwrap();
2004 assert_eq!(json["by"], "throughput");
2005 assert!(json.get("partition").is_none());
2006 }
2007
2008 #[test]
2009 fn test_provider_sort_from_strategy() {
2010 let sort: ProviderSort = ProviderSortStrategy::Price.into();
2011 assert_eq!(sort, ProviderSort::Simple(ProviderSortStrategy::Price));
2012 }
2013
2014 #[test]
2015 fn test_provider_sort_from_config() {
2016 let config = ProviderSortConfig::new(ProviderSortStrategy::Latency);
2017 let sort: ProviderSort = config.into();
2018 match sort {
2019 ProviderSort::Complex(c) => assert_eq!(c.by, ProviderSortStrategy::Latency),
2020 _ => panic!("Expected Complex variant"),
2021 }
2022 }
2023
2024 #[test]
2025 fn test_percentile_thresholds_builder() {
2026 let thresholds = PercentileThresholds::new()
2027 .p50(10.0)
2028 .p75(25.0)
2029 .p90(50.0)
2030 .p99(100.0);
2031
2032 assert_eq!(thresholds.p50, Some(10.0));
2033 assert_eq!(thresholds.p75, Some(25.0));
2034 assert_eq!(thresholds.p90, Some(50.0));
2035 assert_eq!(thresholds.p99, Some(100.0));
2036 }
2037
2038 #[test]
2039 fn test_percentile_thresholds_default() {
2040 let thresholds = PercentileThresholds::default();
2041 assert_eq!(thresholds.p50, None);
2042 assert_eq!(thresholds.p75, None);
2043 assert_eq!(thresholds.p90, None);
2044 assert_eq!(thresholds.p99, None);
2045 }
2046
2047 #[test]
2048 fn test_throughput_threshold_simple() {
2049 let threshold = ThroughputThreshold::Simple(50.0);
2050 let json = serde_json::to_value(&threshold).unwrap();
2051 assert_eq!(json, 50.0);
2052 }
2053
2054 #[test]
2055 fn test_throughput_threshold_percentile() {
2056 let threshold = ThroughputThreshold::Percentile(PercentileThresholds::new().p90(50.0));
2057 let json = serde_json::to_value(&threshold).unwrap();
2058 assert_eq!(json["p90"], 50.0);
2059 }
2060
2061 #[test]
2062 fn test_latency_threshold_simple() {
2063 let threshold = LatencyThreshold::Simple(0.5);
2064 let json = serde_json::to_value(&threshold).unwrap();
2065 assert_eq!(json, 0.5);
2066 }
2067
2068 #[test]
2069 fn test_latency_threshold_percentile() {
2070 let threshold = LatencyThreshold::Percentile(PercentileThresholds::new().p50(0.1).p99(1.0));
2071 let json = serde_json::to_value(&threshold).unwrap();
2072 assert_eq!(json["p50"], 0.1);
2073 assert_eq!(json["p99"], 1.0);
2074 }
2075
2076 #[test]
2077 fn test_max_price_builder() {
2078 let price = MaxPrice::new().prompt(0.001).completion(0.002);
2079
2080 assert_eq!(price.prompt, Some(0.001));
2081 assert_eq!(price.completion, Some(0.002));
2082 assert_eq!(price.request, None);
2083 assert_eq!(price.image, None);
2084 }
2085
2086 #[test]
2087 fn test_max_price_all_fields() {
2088 let price = MaxPrice::new()
2089 .prompt(0.001)
2090 .completion(0.002)
2091 .request(0.01)
2092 .image(0.05);
2093
2094 let json = serde_json::to_value(&price).unwrap();
2095 assert_eq!(json["prompt"], 0.001);
2096 assert_eq!(json["completion"], 0.002);
2097 assert_eq!(json["request"], 0.01);
2098 assert_eq!(json["image"], 0.05);
2099 }
2100
2101 #[test]
2102 fn test_max_price_default() {
2103 let price = MaxPrice::default();
2104 assert_eq!(price.prompt, None);
2105 assert_eq!(price.completion, None);
2106 assert_eq!(price.request, None);
2107 assert_eq!(price.image, None);
2108 }
2109
2110 #[test]
2111 fn test_provider_preferences_default() {
2112 let prefs = ProviderPreferences::default();
2113 assert!(prefs.order.is_none());
2114 assert!(prefs.only.is_none());
2115 assert!(prefs.ignore.is_none());
2116 assert!(prefs.allow_fallbacks.is_none());
2117 assert!(prefs.require_parameters.is_none());
2118 assert!(prefs.data_collection.is_none());
2119 assert!(prefs.zdr.is_none());
2120 assert!(prefs.sort.is_none());
2121 assert!(prefs.preferred_min_throughput.is_none());
2122 assert!(prefs.preferred_max_latency.is_none());
2123 assert!(prefs.max_price.is_none());
2124 assert!(prefs.quantizations.is_none());
2125 }
2126
2127 #[test]
2128 fn test_provider_preferences_order_with_fallbacks() {
2129 let prefs = ProviderPreferences::new()
2130 .order(["anthropic", "openai"])
2131 .allow_fallbacks(true);
2132
2133 let json = prefs.to_json();
2134 let provider = &json["provider"];
2135
2136 assert_eq!(provider["order"], json!(["anthropic", "openai"]));
2137 assert_eq!(provider["allow_fallbacks"], true);
2138 }
2139
2140 #[test]
2141 fn test_provider_preferences_only_allowlist() {
2142 let prefs = ProviderPreferences::new()
2143 .only(["azure", "together"])
2144 .allow_fallbacks(false);
2145
2146 let json = prefs.to_json();
2147 let provider = &json["provider"];
2148
2149 assert_eq!(provider["only"], json!(["azure", "together"]));
2150 assert_eq!(provider["allow_fallbacks"], false);
2151 }
2152
2153 #[test]
2154 fn test_provider_preferences_ignore() {
2155 let prefs = ProviderPreferences::new().ignore(["deepinfra"]);
2156
2157 let json = prefs.to_json();
2158 let provider = &json["provider"];
2159
2160 assert_eq!(provider["ignore"], json!(["deepinfra"]));
2161 }
2162
2163 #[test]
2164 fn test_provider_preferences_sort_latency() {
2165 let prefs = ProviderPreferences::new().sort(ProviderSortStrategy::Latency);
2166
2167 let json = prefs.to_json();
2168 let provider = &json["provider"];
2169
2170 assert_eq!(provider["sort"], "latency");
2171 }
2172
2173 #[test]
2174 fn test_provider_preferences_price_with_throughput() {
2175 let prefs = ProviderPreferences::new()
2176 .sort(ProviderSortStrategy::Price)
2177 .preferred_min_throughput(ThroughputThreshold::Percentile(
2178 PercentileThresholds::new().p90(50.0),
2179 ));
2180
2181 let json = prefs.to_json();
2182 let provider = &json["provider"];
2183
2184 assert_eq!(provider["sort"], "price");
2185 assert_eq!(provider["preferred_min_throughput"]["p90"], 50.0);
2186 }
2187
2188 #[test]
2189 fn test_provider_preferences_require_parameters() {
2190 let prefs = ProviderPreferences::new().require_parameters(true);
2191
2192 let json = prefs.to_json();
2193 let provider = &json["provider"];
2194
2195 assert_eq!(provider["require_parameters"], true);
2196 }
2197
2198 #[test]
2199 fn test_provider_preferences_data_policy_and_zdr() {
2200 let prefs = ProviderPreferences::new()
2201 .data_collection(DataCollection::Deny)
2202 .zdr(true);
2203
2204 let json = prefs.to_json();
2205 let provider = &json["provider"];
2206
2207 assert_eq!(provider["data_collection"], "deny");
2208 assert_eq!(provider["zdr"], true);
2209 }
2210
2211 #[test]
2212 fn test_provider_preferences_quantizations() {
2213 let prefs =
2214 ProviderPreferences::new().quantizations([Quantization::Int8, Quantization::Fp16]);
2215
2216 let json = prefs.to_json();
2217 let provider = &json["provider"];
2218
2219 assert_eq!(provider["quantizations"], json!(["int8", "fp16"]));
2220 }
2221
2222 #[test]
2223 fn test_provider_preferences_convenience_methods() {
2224 let prefs = ProviderPreferences::new().zero_data_retention().fastest();
2225
2226 assert_eq!(prefs.zdr, Some(true));
2227 assert_eq!(
2228 prefs.sort,
2229 Some(ProviderSort::Simple(ProviderSortStrategy::Throughput))
2230 );
2231
2232 let prefs2 = ProviderPreferences::new().cheapest();
2233 assert_eq!(
2234 prefs2.sort,
2235 Some(ProviderSort::Simple(ProviderSortStrategy::Price))
2236 );
2237
2238 let prefs3 = ProviderPreferences::new().lowest_latency();
2239 assert_eq!(
2240 prefs3.sort,
2241 Some(ProviderSort::Simple(ProviderSortStrategy::Latency))
2242 );
2243 }
2244
2245 #[test]
2246 fn test_provider_preferences_serialization_skips_none() {
2247 let prefs = ProviderPreferences::new().sort(ProviderSortStrategy::Price);
2248
2249 let json = serde_json::to_value(&prefs).unwrap();
2250
2251 assert_eq!(json["sort"], "price");
2252 assert!(json.get("order").is_none());
2253 assert!(json.get("only").is_none());
2254 assert!(json.get("ignore").is_none());
2255 assert!(json.get("zdr").is_none());
2256 }
2257
2258 #[test]
2259 fn test_provider_preferences_deserialization() {
2260 let json = json!({
2261 "order": ["anthropic", "openai"],
2262 "sort": "throughput",
2263 "data_collection": "deny",
2264 "zdr": true,
2265 "quantizations": ["int8", "fp16"]
2266 });
2267
2268 let prefs: ProviderPreferences = serde_json::from_value(json).unwrap();
2269
2270 assert_eq!(
2271 prefs.order,
2272 Some(vec!["anthropic".to_string(), "openai".to_string()])
2273 );
2274 assert_eq!(
2275 prefs.sort,
2276 Some(ProviderSort::Simple(ProviderSortStrategy::Throughput))
2277 );
2278 assert_eq!(prefs.data_collection, Some(DataCollection::Deny));
2279 assert_eq!(prefs.zdr, Some(true));
2280 assert_eq!(
2281 prefs.quantizations,
2282 Some(vec![Quantization::Int8, Quantization::Fp16])
2283 );
2284 }
2285
2286 #[test]
2287 fn test_provider_preferences_deserialization_complex_sort() {
2288 let json = json!({
2289 "sort": {
2290 "by": "latency",
2291 "partition": "model"
2292 }
2293 });
2294
2295 let prefs: ProviderPreferences = serde_json::from_value(json).unwrap();
2296
2297 match prefs.sort {
2298 Some(ProviderSort::Complex(config)) => {
2299 assert_eq!(config.by, ProviderSortStrategy::Latency);
2300 assert_eq!(config.partition, Some(SortPartition::Model));
2301 }
2302 _ => panic!("Expected Complex sort variant"),
2303 }
2304 }
2305
2306 #[test]
2307 fn test_provider_preferences_full_integration() {
2308 let prefs = ProviderPreferences::new()
2309 .order(["anthropic", "openai"])
2310 .only(["anthropic", "openai", "google"])
2311 .sort(ProviderSortStrategy::Throughput)
2312 .data_collection(DataCollection::Deny)
2313 .zdr(true)
2314 .quantizations([Quantization::Int8])
2315 .allow_fallbacks(false);
2316
2317 let json = prefs.to_json();
2318
2319 assert!(json.get("provider").is_some());
2320 let provider = &json["provider"];
2321 assert_eq!(provider["order"], json!(["anthropic", "openai"]));
2322 assert_eq!(provider["only"], json!(["anthropic", "openai", "google"]));
2323 assert_eq!(provider["sort"], "throughput");
2324 assert_eq!(provider["data_collection"], "deny");
2325 assert_eq!(provider["zdr"], true);
2326 assert_eq!(provider["quantizations"], json!(["int8"]));
2327 assert_eq!(provider["allow_fallbacks"], false);
2328 }
2329
2330 #[test]
2331 fn test_provider_preferences_max_price() {
2332 let prefs =
2333 ProviderPreferences::new().max_price(MaxPrice::new().prompt(0.001).completion(0.002));
2334
2335 let json = prefs.to_json();
2336 let provider = &json["provider"];
2337
2338 assert_eq!(provider["max_price"]["prompt"], 0.001);
2339 assert_eq!(provider["max_price"]["completion"], 0.002);
2340 }
2341
2342 #[test]
2343 fn test_provider_preferences_preferred_max_latency() {
2344 let prefs = ProviderPreferences::new().preferred_max_latency(LatencyThreshold::Simple(0.5));
2345
2346 let json = prefs.to_json();
2347 let provider = &json["provider"];
2348
2349 assert_eq!(provider["preferred_max_latency"], 0.5);
2350 }
2351
2352 #[test]
2353 fn test_provider_preferences_empty_arrays() {
2354 let prefs = ProviderPreferences::new()
2355 .order(Vec::<String>::new())
2356 .quantizations(Vec::<Quantization>::new());
2357
2358 let json = prefs.to_json();
2359 let provider = &json["provider"];
2360
2361 assert_eq!(provider["order"], json!([]));
2362 assert_eq!(provider["quantizations"], json!([]));
2363 }
2364
2365 #[test]
2370 fn test_user_content_text_serialization() {
2371 let content = UserContent::text("Hello, world!");
2372 let json = serde_json::to_value(&content).unwrap();
2373
2374 assert_eq!(json["type"], "text");
2375 assert_eq!(json["text"], "Hello, world!");
2376 }
2377
2378 #[test]
2379 fn test_user_content_image_url_serialization() {
2380 let content = UserContent::image_url("https://example.com/image.png");
2381 let json = serde_json::to_value(&content).unwrap();
2382
2383 assert_eq!(json["type"], "image_url");
2384 assert_eq!(json["image_url"]["url"], "https://example.com/image.png");
2385 assert!(json["image_url"].get("detail").is_none());
2386 }
2387
2388 #[test]
2389 fn test_user_content_image_url_with_detail_serialization() {
2390 let content =
2391 UserContent::image_url_with_detail("https://example.com/image.png", ImageDetail::High);
2392 let json = serde_json::to_value(&content).unwrap();
2393
2394 assert_eq!(json["type"], "image_url");
2395 assert_eq!(json["image_url"]["url"], "https://example.com/image.png");
2396 assert_eq!(json["image_url"]["detail"], "high");
2397 }
2398
2399 #[test]
2400 fn test_user_content_image_base64_serialization() {
2401 let content = UserContent::image_base64("SGVsbG8=", "image/png", Some(ImageDetail::Low));
2402 let json = serde_json::to_value(&content).unwrap();
2403
2404 assert_eq!(json["type"], "image_url");
2405 assert_eq!(json["image_url"]["url"], "data:image/png;base64,SGVsbG8=");
2406 assert_eq!(json["image_url"]["detail"], "low");
2407 }
2408
2409 #[test]
2410 fn test_user_content_file_url_serialization() {
2411 let content = UserContent::file_url(
2412 "https://example.com/doc.pdf",
2413 Some("document.pdf".to_string()),
2414 );
2415 let json = serde_json::to_value(&content).unwrap();
2416
2417 assert_eq!(json["type"], "file");
2418 assert_eq!(json["file"]["file_data"], "https://example.com/doc.pdf");
2419 assert_eq!(json["file"]["filename"], "document.pdf");
2420 }
2421
2422 #[test]
2423 fn test_user_content_file_base64_serialization() {
2424 let content = UserContent::file_base64(
2425 "JVBERi0xLjQ=",
2426 "application/pdf",
2427 Some("report.pdf".to_string()),
2428 );
2429 let json = serde_json::to_value(&content).unwrap();
2430
2431 assert_eq!(json["type"], "file");
2432 assert_eq!(
2433 json["file"]["file_data"],
2434 "data:application/pdf;base64,JVBERi0xLjQ="
2435 );
2436 assert_eq!(json["file"]["filename"], "report.pdf");
2437 }
2438
2439 #[test]
2440 fn test_user_content_text_deserialization() {
2441 let json = json!({
2442 "type": "text",
2443 "text": "Hello!"
2444 });
2445
2446 let content: UserContent = serde_json::from_value(json).unwrap();
2447 assert_eq!(
2448 content,
2449 UserContent::Text {
2450 text: "Hello!".to_string()
2451 }
2452 );
2453 }
2454
2455 #[test]
2456 fn test_user_content_image_url_deserialization() {
2457 let json = json!({
2458 "type": "image_url",
2459 "image_url": {
2460 "url": "https://example.com/img.jpg",
2461 "detail": "high"
2462 }
2463 });
2464
2465 let content: UserContent = serde_json::from_value(json).unwrap();
2466 match content {
2467 UserContent::ImageUrl { image_url } => {
2468 assert_eq!(image_url.url, "https://example.com/img.jpg");
2469 assert_eq!(image_url.detail, Some(ImageDetail::High));
2470 }
2471 _ => panic!("Expected ImageUrl variant"),
2472 }
2473 }
2474
2475 #[test]
2476 fn test_user_content_file_deserialization() {
2477 let json = json!({
2478 "type": "file",
2479 "file": {
2480 "filename": "doc.pdf",
2481 "file_data": "https://example.com/doc.pdf"
2482 }
2483 });
2484
2485 let content: UserContent = serde_json::from_value(json).unwrap();
2486 match content {
2487 UserContent::File { file } => {
2488 assert_eq!(file.filename, Some("doc.pdf".to_string()));
2489 assert_eq!(
2490 file.file_data,
2491 Some("https://example.com/doc.pdf".to_string())
2492 );
2493 }
2494 _ => panic!("Expected File variant"),
2495 }
2496 }
2497
2498 #[test]
2499 fn test_message_user_with_text_serialization() {
2500 let message = Message::User {
2501 content: OneOrMany::one(UserContent::text("Hello")),
2502 name: None,
2503 };
2504 let json = serde_json::to_value(&message).unwrap();
2505
2506 assert_eq!(json["role"], "user");
2508 assert_eq!(json["content"], "Hello");
2509 }
2510
2511 #[test]
2512 fn test_message_user_with_mixed_content_serialization() {
2513 let message = Message::User {
2514 content: OneOrMany::many(vec![
2515 UserContent::text("Check this image:"),
2516 UserContent::image_url("https://example.com/img.png"),
2517 ])
2518 .unwrap(),
2519 name: None,
2520 };
2521 let json = serde_json::to_value(&message).unwrap();
2522
2523 assert_eq!(json["role"], "user");
2524 let content = json["content"].as_array().unwrap();
2525 assert_eq!(content.len(), 2);
2526 assert_eq!(content[0]["type"], "text");
2527 assert_eq!(content[1]["type"], "image_url");
2528 }
2529
2530 #[test]
2531 fn test_message_user_with_file_serialization() {
2532 let message = Message::User {
2533 content: OneOrMany::many(vec![
2534 UserContent::text("Analyze this PDF:"),
2535 UserContent::file_url(
2536 "https://example.com/doc.pdf",
2537 Some("document.pdf".to_string()),
2538 ),
2539 ])
2540 .unwrap(),
2541 name: None,
2542 };
2543 let json = serde_json::to_value(&message).unwrap();
2544
2545 assert_eq!(json["role"], "user");
2546 let content = json["content"].as_array().unwrap();
2547 assert_eq!(content.len(), 2);
2548 assert_eq!(content[0]["type"], "text");
2549 assert_eq!(content[1]["type"], "file");
2550 assert_eq!(
2551 content[1]["file"]["file_data"],
2552 "https://example.com/doc.pdf"
2553 );
2554 }
2555
2556 #[test]
2557 fn test_user_content_from_rig_text() {
2558 let rig_content = message::UserContent::Text(message::Text {
2559 text: "Hello".to_string(),
2560 });
2561 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2562
2563 assert_eq!(
2564 openrouter_content,
2565 UserContent::Text {
2566 text: "Hello".to_string()
2567 }
2568 );
2569 }
2570
2571 #[test]
2572 fn test_user_content_from_rig_image_url() {
2573 let rig_content = message::UserContent::Image(message::Image {
2574 data: DocumentSourceKind::Url("https://example.com/img.png".to_string()),
2575 media_type: Some(message::ImageMediaType::PNG),
2576 detail: Some(ImageDetail::High),
2577 additional_params: None,
2578 });
2579 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2580
2581 match openrouter_content {
2582 UserContent::ImageUrl { image_url } => {
2583 assert_eq!(image_url.url, "https://example.com/img.png");
2584 assert_eq!(image_url.detail, Some(ImageDetail::High));
2585 }
2586 _ => panic!("Expected ImageUrl variant"),
2587 }
2588 }
2589
2590 #[test]
2591 fn test_user_content_from_rig_image_base64() {
2592 let rig_content = message::UserContent::Image(message::Image {
2593 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2594 media_type: Some(message::ImageMediaType::JPEG),
2595 detail: Some(ImageDetail::Low),
2596 additional_params: None,
2597 });
2598 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2599
2600 match openrouter_content {
2601 UserContent::ImageUrl { image_url } => {
2602 assert_eq!(image_url.url, "data:image/jpeg;base64,SGVsbG8=");
2603 assert_eq!(image_url.detail, Some(ImageDetail::Low));
2604 }
2605 _ => panic!("Expected ImageUrl variant"),
2606 }
2607 }
2608
2609 #[test]
2610 fn test_user_content_from_rig_document_url() {
2611 let rig_content = message::UserContent::Document(message::Document {
2612 data: DocumentSourceKind::Url("https://example.com/doc.pdf".to_string()),
2613 media_type: Some(DocumentMediaType::PDF),
2614 additional_params: None,
2615 });
2616 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2617
2618 match openrouter_content {
2619 UserContent::File { file } => {
2620 assert_eq!(
2621 file.file_data,
2622 Some("https://example.com/doc.pdf".to_string())
2623 );
2624 assert_eq!(file.filename, Some("document.pdf".to_string()));
2625 }
2626 _ => panic!("Expected File variant"),
2627 }
2628 }
2629
2630 #[test]
2631 fn test_user_content_from_rig_document_base64() {
2632 let rig_content = message::UserContent::Document(message::Document {
2633 data: DocumentSourceKind::Base64("JVBERi0xLjQ=".to_string()),
2634 media_type: Some(DocumentMediaType::PDF),
2635 additional_params: None,
2636 });
2637 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2638
2639 match openrouter_content {
2640 UserContent::File { file } => {
2641 assert_eq!(
2642 file.file_data,
2643 Some("data:application/pdf;base64,JVBERi0xLjQ=".to_string())
2644 );
2645 assert_eq!(file.filename, Some("document.pdf".to_string()));
2646 }
2647 _ => panic!("Expected File variant"),
2648 }
2649 }
2650
2651 #[test]
2652 fn test_user_content_from_rig_document_string_becomes_text() {
2653 let rig_content = message::UserContent::Document(message::Document {
2654 data: DocumentSourceKind::String("Plain text document content".to_string()),
2655 media_type: Some(DocumentMediaType::TXT),
2656 additional_params: None,
2657 });
2658 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2659
2660 assert_eq!(
2661 openrouter_content,
2662 UserContent::Text {
2663 text: "Plain text document content".to_string()
2664 }
2665 );
2666 }
2667
2668 #[test]
2669 fn test_completion_response_with_reasoning_details_maps_to_typed_reasoning() {
2670 let json = json!({
2671 "id": "resp_123",
2672 "object": "chat.completion",
2673 "created": 1,
2674 "model": "openrouter/test-model",
2675 "choices": [{
2676 "index": 0,
2677 "finish_reason": "stop",
2678 "message": {
2679 "role": "assistant",
2680 "content": "hello",
2681 "reasoning": null,
2682 "reasoning_details": [
2683 {"type":"reasoning.summary","id":"rs_1","summary":"s1"},
2684 {"type":"reasoning.text","id":"rs_1","text":"t1","signature":"sig_1"},
2685 {"type":"reasoning.encrypted","id":"rs_1","data":"enc_1"}
2686 ]
2687 }
2688 }]
2689 });
2690
2691 let response: CompletionResponse = serde_json::from_value(json).unwrap();
2692 let converted: completion::CompletionResponse<CompletionResponse> =
2693 response.try_into().unwrap();
2694 let items: Vec<completion::AssistantContent> = converted.choice.into_iter().collect();
2695
2696 assert!(items.iter().any(|item| matches!(
2697 item,
2698 completion::AssistantContent::Reasoning(message::Reasoning { id: Some(id), content })
2699 if id == "rs_1" && content.len() == 3
2700 )));
2701 }
2702
2703 #[test]
2704 fn test_assistant_reasoning_emits_openrouter_reasoning_details() {
2705 let reasoning = message::Reasoning {
2706 id: Some("rs_2".to_string()),
2707 content: vec![
2708 message::ReasoningContent::Text {
2709 text: "step".to_string(),
2710 signature: Some("sig_step".to_string()),
2711 },
2712 message::ReasoningContent::Summary("summary".to_string()),
2713 message::ReasoningContent::Encrypted("enc_blob".to_string()),
2714 ],
2715 };
2716
2717 let messages = Vec::<Message>::try_from(OneOrMany::one(
2718 message::AssistantContent::Reasoning(reasoning),
2719 ))
2720 .unwrap();
2721 let Message::Assistant {
2722 reasoning,
2723 reasoning_details,
2724 ..
2725 } = messages.first().expect("assistant message")
2726 else {
2727 panic!("Expected assistant message");
2728 };
2729
2730 assert!(reasoning.is_none());
2731 assert_eq!(reasoning_details.len(), 3);
2732 assert!(matches!(
2733 reasoning_details.first(),
2734 Some(ReasoningDetails::Text {
2735 id: Some(id),
2736 text: Some(text),
2737 signature: Some(signature),
2738 ..
2739 }) if id == "rs_2" && text == "step" && signature == "sig_step"
2740 ));
2741 }
2742
2743 #[test]
2744 fn test_assistant_redacted_reasoning_emits_encrypted_detail_not_text() {
2745 let reasoning = message::Reasoning {
2746 id: Some("rs_redacted".to_string()),
2747 content: vec![message::ReasoningContent::Redacted {
2748 data: "opaque-redacted-data".to_string(),
2749 }],
2750 };
2751
2752 let messages = Vec::<Message>::try_from(OneOrMany::one(
2753 message::AssistantContent::Reasoning(reasoning),
2754 ))
2755 .unwrap();
2756
2757 let Message::Assistant {
2758 reasoning_details,
2759 reasoning,
2760 ..
2761 } = messages.first().expect("assistant message")
2762 else {
2763 panic!("Expected assistant message");
2764 };
2765
2766 assert!(reasoning.is_none());
2767 assert_eq!(reasoning_details.len(), 1);
2768 assert!(matches!(
2769 reasoning_details.first(),
2770 Some(ReasoningDetails::Encrypted {
2771 id: Some(id),
2772 data,
2773 ..
2774 }) if id == "rs_redacted" && data == "opaque-redacted-data"
2775 ));
2776 }
2777
2778 #[test]
2779 fn test_completion_response_reasoning_details_respects_index_ordering() {
2780 let json = json!({
2781 "id": "resp_ordering",
2782 "object": "chat.completion",
2783 "created": 1,
2784 "model": "openrouter/test-model",
2785 "choices": [{
2786 "index": 0,
2787 "finish_reason": "stop",
2788 "message": {
2789 "role": "assistant",
2790 "content": "hello",
2791 "reasoning": null,
2792 "reasoning_details": [
2793 {"type":"reasoning.summary","id":"rs_order","index":1,"summary":"second"},
2794 {"type":"reasoning.summary","id":"rs_order","index":0,"summary":"first"}
2795 ]
2796 }
2797 }]
2798 });
2799
2800 let response: CompletionResponse = serde_json::from_value(json).unwrap();
2801 let converted: completion::CompletionResponse<CompletionResponse> =
2802 response.try_into().unwrap();
2803 let items: Vec<completion::AssistantContent> = converted.choice.into_iter().collect();
2804 let reasoning_blocks: Vec<_> = items
2805 .into_iter()
2806 .filter_map(|item| match item {
2807 completion::AssistantContent::Reasoning(reasoning) => Some(reasoning),
2808 _ => None,
2809 })
2810 .collect();
2811
2812 assert_eq!(reasoning_blocks.len(), 1);
2813 assert_eq!(reasoning_blocks[0].id.as_deref(), Some("rs_order"));
2814 assert_eq!(
2815 reasoning_blocks[0].content,
2816 vec![
2817 message::ReasoningContent::Summary("first".to_string()),
2818 message::ReasoningContent::Summary("second".to_string()),
2819 ]
2820 );
2821 }
2822
2823 #[test]
2824 fn test_user_content_from_rig_image_missing_media_type_error() {
2825 let rig_content = message::UserContent::Image(message::Image {
2826 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2827 media_type: None, detail: None,
2829 additional_params: None,
2830 });
2831 let result: Result<UserContent, _> = rig_content.try_into();
2832
2833 assert!(result.is_err());
2834 let err = result.unwrap_err();
2835 assert!(err.to_string().contains("media type required"));
2836 }
2837
2838 #[test]
2839 fn test_user_content_from_rig_image_raw_bytes_error() {
2840 let rig_content = message::UserContent::Image(message::Image {
2841 data: DocumentSourceKind::Raw(vec![1, 2, 3]),
2842 media_type: Some(message::ImageMediaType::PNG),
2843 detail: None,
2844 additional_params: None,
2845 });
2846 let result: Result<UserContent, _> = rig_content.try_into();
2847
2848 assert!(result.is_err());
2849 let err = result.unwrap_err();
2850 assert!(err.to_string().contains("base64"));
2851 }
2852
2853 #[test]
2854 fn test_user_content_from_rig_video_url() {
2855 let rig_content = message::UserContent::Video(message::Video {
2856 data: DocumentSourceKind::Url("https://example.com/video.mp4".to_string()),
2857 media_type: Some(message::VideoMediaType::MP4),
2858 additional_params: None,
2859 });
2860 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2861
2862 match openrouter_content {
2863 UserContent::VideoUrl { video_url } => {
2864 assert_eq!(video_url.url, "https://example.com/video.mp4");
2865 }
2866 _ => panic!("Expected VideoUrl variant"),
2867 }
2868 }
2869
2870 #[test]
2871 fn test_user_content_from_rig_video_base64() {
2872 let rig_content = message::UserContent::Video(message::Video {
2873 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2874 media_type: Some(message::VideoMediaType::MP4),
2875 additional_params: None,
2876 });
2877 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2878
2879 match openrouter_content {
2880 UserContent::VideoUrl { video_url } => {
2881 assert_eq!(video_url.url, "data:video/mp4;base64,SGVsbG8=");
2882 }
2883 _ => panic!("Expected VideoUrl variant"),
2884 }
2885 }
2886
2887 #[test]
2888 fn test_user_content_from_rig_video_base64_missing_media_type_error() {
2889 let rig_content = message::UserContent::Video(message::Video {
2890 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2891 media_type: None,
2892 additional_params: None,
2893 });
2894 let result: Result<UserContent, _> = rig_content.try_into();
2895
2896 assert!(result.is_err());
2897 let err = result.unwrap_err();
2898 assert!(err.to_string().contains("media type"));
2899 }
2900
2901 #[test]
2902 fn test_user_content_from_rig_video_raw_bytes_error() {
2903 let rig_content = message::UserContent::Video(message::Video {
2904 data: DocumentSourceKind::Raw(vec![1, 2, 3]),
2905 media_type: Some(message::VideoMediaType::MP4),
2906 additional_params: None,
2907 });
2908 let result: Result<UserContent, _> = rig_content.try_into();
2909
2910 assert!(result.is_err());
2911 let err = result.unwrap_err();
2912 assert!(err.to_string().contains("base64"));
2913 }
2914
2915 #[test]
2916 fn test_user_content_from_rig_audio_base64() {
2917 let rig_content = message::UserContent::Audio(message::Audio {
2918 data: DocumentSourceKind::Base64("audiodata".to_string()),
2919 media_type: Some(message::AudioMediaType::MP3),
2920 additional_params: None,
2921 });
2922 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2923
2924 match openrouter_content {
2925 UserContent::InputAudio { input_audio } => {
2926 assert_eq!(input_audio.data, "audiodata");
2927 assert_eq!(input_audio.format, message::AudioMediaType::MP3);
2928 }
2929 _ => panic!("Expected InputAudio variant"),
2930 }
2931 }
2932
2933 #[test]
2934 fn test_user_content_from_rig_audio_missing_media_type_error() {
2935 let rig_content = message::UserContent::Audio(message::Audio {
2936 data: DocumentSourceKind::Base64("audiodata".to_string()),
2937 media_type: None, additional_params: None,
2939 });
2940 let result: Result<UserContent, _> = rig_content.try_into();
2941
2942 assert!(result.is_err());
2943 let err = result.unwrap_err();
2944 assert!(err.to_string().contains("media type required"));
2945 }
2946
2947 #[test]
2948 fn test_user_content_from_rig_audio_url_error() {
2949 let rig_content = message::UserContent::Audio(message::Audio {
2950 data: DocumentSourceKind::Url("https://example.com/audio.wav".to_string()),
2951 media_type: Some(message::AudioMediaType::WAV),
2952 additional_params: None,
2953 });
2954 let result: Result<UserContent, _> = rig_content.try_into();
2955
2956 assert!(result.is_err());
2957 let err = result.unwrap_err();
2958 assert!(err.to_string().contains("base64"));
2959 }
2960
2961 #[test]
2962 fn test_user_content_from_rig_audio_raw_bytes_error() {
2963 let rig_content = message::UserContent::Audio(message::Audio {
2964 data: DocumentSourceKind::Raw(vec![1, 2, 3]),
2965 media_type: Some(message::AudioMediaType::WAV),
2966 additional_params: None,
2967 });
2968 let result: Result<UserContent, _> = rig_content.try_into();
2969
2970 assert!(result.is_err());
2971 let err = result.unwrap_err();
2972 assert!(err.to_string().contains("base64"));
2973 }
2974
2975 #[test]
2976 fn test_message_conversion_with_pdf() {
2977 let rig_message = message::Message::User {
2978 content: OneOrMany::many(vec![
2979 message::UserContent::Text(message::Text {
2980 text: "Summarize this document".to_string(),
2981 }),
2982 message::UserContent::Document(message::Document {
2983 data: DocumentSourceKind::Url("https://example.com/paper.pdf".to_string()),
2984 media_type: Some(DocumentMediaType::PDF),
2985 additional_params: None,
2986 }),
2987 ])
2988 .unwrap(),
2989 };
2990
2991 let openrouter_messages: Vec<Message> = rig_message.try_into().unwrap();
2992 assert_eq!(openrouter_messages.len(), 1);
2993
2994 match &openrouter_messages[0] {
2995 Message::User { content, .. } => {
2996 assert_eq!(content.len(), 2);
2997
2998 match content.first_ref() {
3000 UserContent::Text { text } => assert_eq!(text, "Summarize this document"),
3001 _ => panic!("Expected Text"),
3002 }
3003 }
3004 _ => panic!("Expected User message"),
3005 }
3006 }
3007
3008 #[test]
3009 fn test_user_content_from_string() {
3010 let content: UserContent = "Hello".into();
3011 assert_eq!(
3012 content,
3013 UserContent::Text {
3014 text: "Hello".to_string()
3015 }
3016 );
3017
3018 let content: UserContent = String::from("World").into();
3019 assert_eq!(
3020 content,
3021 UserContent::Text {
3022 text: "World".to_string()
3023 }
3024 );
3025 }
3026
3027 #[test]
3028 fn test_openai_user_content_conversion() {
3029 let openai_text = openai::UserContent::Text {
3031 text: "Hello".to_string(),
3032 };
3033 let converted: UserContent = openai_text.into();
3034 assert_eq!(
3035 converted,
3036 UserContent::Text {
3037 text: "Hello".to_string()
3038 }
3039 );
3040
3041 let openai_image = openai::UserContent::Image {
3042 image_url: openai::ImageUrl {
3043 url: "https://example.com/img.png".to_string(),
3044 detail: ImageDetail::Auto,
3045 },
3046 };
3047 let converted: UserContent = openai_image.into();
3048 match converted {
3049 UserContent::ImageUrl { image_url } => {
3050 assert_eq!(image_url.url, "https://example.com/img.png");
3051 assert_eq!(image_url.detail, Some(ImageDetail::Auto));
3052 }
3053 _ => panic!("Expected ImageUrl"),
3054 }
3055
3056 let openai_audio = openai::UserContent::Audio {
3057 input_audio: openai::InputAudio {
3058 data: "audiodata".to_string(),
3059 format: AudioMediaType::FLAC,
3060 },
3061 };
3062 let converted: UserContent = openai_audio.into();
3063 match converted {
3064 UserContent::InputAudio { input_audio } => {
3065 assert_eq!(input_audio.data, "audiodata");
3066 assert_eq!(input_audio.format, AudioMediaType::FLAC);
3067 }
3068 _ => panic!("Expected InputAudio"),
3069 }
3070 }
3071
3072 #[test]
3073 fn test_completion_response_reasoning_details_with_multiple_ids_stay_separate() {
3074 let json = json!({
3075 "id": "resp_multi_id",
3076 "object": "chat.completion",
3077 "created": 1,
3078 "model": "openrouter/test-model",
3079 "choices": [{
3080 "index": 0,
3081 "finish_reason": "stop",
3082 "message": {
3083 "role": "assistant",
3084 "content": "hello",
3085 "reasoning": null,
3086 "reasoning_details": [
3087 {"type":"reasoning.summary","id":"rs_a","summary":"a1"},
3088 {"type":"reasoning.summary","id":"rs_b","summary":"b1"},
3089 {"type":"reasoning.summary","id":"rs_a","summary":"a2"}
3090 ]
3091 }
3092 }]
3093 });
3094
3095 let response: CompletionResponse = serde_json::from_value(json).unwrap();
3096 let converted: completion::CompletionResponse<CompletionResponse> =
3097 response.try_into().unwrap();
3098 let items: Vec<completion::AssistantContent> = converted.choice.into_iter().collect();
3099 let reasoning_blocks: Vec<_> = items
3100 .into_iter()
3101 .filter_map(|item| match item {
3102 completion::AssistantContent::Reasoning(reasoning) => Some(reasoning),
3103 _ => None,
3104 })
3105 .collect();
3106
3107 assert_eq!(reasoning_blocks.len(), 2);
3108 assert_eq!(reasoning_blocks[0].id.as_deref(), Some("rs_a"));
3109 assert_eq!(
3110 reasoning_blocks[0].content,
3111 vec![
3112 message::ReasoningContent::Summary("a1".to_string()),
3113 message::ReasoningContent::Summary("a2".to_string()),
3114 ]
3115 );
3116 assert_eq!(reasoning_blocks[1].id.as_deref(), Some("rs_b"));
3117 assert_eq!(
3118 reasoning_blocks[1].content,
3119 vec![message::ReasoningContent::Summary("b1".to_string())]
3120 );
3121 }
3122
3123 #[test]
3124 fn test_user_content_audio_serialization() {
3125 let content = UserContent::audio_base64("SGVsbG8=", AudioMediaType::WAV);
3126 let json = serde_json::to_value(&content).unwrap();
3127
3128 assert_eq!(json["type"], "input_audio");
3129 assert_eq!(json["input_audio"]["data"], "SGVsbG8=");
3130 assert_eq!(json["input_audio"]["format"], "wav");
3131 }
3132
3133 #[test]
3134 fn test_user_content_audio_deserialization() {
3135 let json = json!({
3136 "type": "input_audio",
3137 "input_audio": {
3138 "data": "SGVsbG8=",
3139 "format": "wav"
3140 }
3141 });
3142
3143 let content: UserContent = serde_json::from_value(json).unwrap();
3144 match content {
3145 UserContent::InputAudio { input_audio } => {
3146 assert_eq!(input_audio.data, "SGVsbG8=");
3147 assert_eq!(input_audio.format, AudioMediaType::WAV);
3148 }
3149 _ => panic!("Expected InputAudio variant"),
3150 }
3151 }
3152
3153 #[test]
3154 fn test_message_user_with_audio_serialization() {
3155 let msg = Message::User {
3156 content: OneOrMany::many(vec![
3157 UserContent::text("Transcribe this audio:"),
3158 UserContent::audio_base64("SGVsbG8=", AudioMediaType::MP3),
3159 ])
3160 .unwrap(),
3161 name: None,
3162 };
3163 let json = serde_json::to_value(&msg).unwrap();
3164
3165 assert_eq!(json["role"], "user");
3166 let content = json["content"].as_array().unwrap();
3167 assert_eq!(content.len(), 2);
3168 assert_eq!(content[0]["type"], "text");
3169 assert_eq!(content[1]["type"], "input_audio");
3170 assert_eq!(content[1]["input_audio"]["data"], "SGVsbG8=");
3171 assert_eq!(content[1]["input_audio"]["format"], "mp3");
3172 }
3173
3174 #[test]
3175 fn test_user_content_video_url_serialization() {
3176 let content = UserContent::video_url("https://example.com/video.mp4");
3177 let json = serde_json::to_value(&content).unwrap();
3178
3179 assert_eq!(json["type"], "video_url");
3180 assert_eq!(json["video_url"]["url"], "https://example.com/video.mp4");
3181 }
3182
3183 #[test]
3184 fn test_user_content_video_base64_serialization() {
3185 let content = UserContent::video_base64("SGVsbG8=", VideoMediaType::MP4);
3186 let json = serde_json::to_value(&content).unwrap();
3187
3188 assert_eq!(json["type"], "video_url");
3189 assert_eq!(json["video_url"]["url"], "data:video/mp4;base64,SGVsbG8=");
3190 }
3191
3192 #[test]
3193 fn test_user_content_video_url_deserialization() {
3194 let json = json!({
3195 "type": "video_url",
3196 "video_url": {
3197 "url": "https://example.com/video.mp4"
3198 }
3199 });
3200
3201 let content: UserContent = serde_json::from_value(json).unwrap();
3202 match content {
3203 UserContent::VideoUrl { video_url } => {
3204 assert_eq!(video_url.url, "https://example.com/video.mp4");
3205 }
3206 _ => panic!("Expected VideoUrl variant"),
3207 }
3208 }
3209
3210 #[test]
3211 fn test_message_user_with_video_serialization() {
3212 let msg = Message::User {
3213 content: OneOrMany::many(vec![
3214 UserContent::text("Describe this video:"),
3215 UserContent::video_url("https://example.com/video.mp4"),
3216 ])
3217 .unwrap(),
3218 name: None,
3219 };
3220 let json = serde_json::to_value(&msg).unwrap();
3221
3222 assert_eq!(json["role"], "user");
3223 let content = json["content"].as_array().unwrap();
3224 assert_eq!(content.len(), 2);
3225 assert_eq!(content[0]["type"], "text");
3226 assert_eq!(content[1]["type"], "video_url");
3227 assert_eq!(
3228 content[1]["video_url"]["url"],
3229 "https://example.com/video.mp4"
3230 );
3231 }
3232
3233 #[test]
3234 fn test_user_content_video_url_no_media_type_needed() {
3235 let rig_content = message::UserContent::Video(message::Video {
3236 data: DocumentSourceKind::Url("https://example.com/video.mp4".to_string()),
3237 media_type: None,
3238 additional_params: None,
3239 });
3240 let openrouter_content: UserContent = rig_content.try_into().unwrap();
3241
3242 match openrouter_content {
3243 UserContent::VideoUrl { video_url } => {
3244 assert_eq!(video_url.url, "https://example.com/video.mp4");
3245 }
3246 _ => panic!("Expected VideoUrl variant"),
3247 }
3248 }
3249}