1use super::{
2 client::{ApiErrorResponse, ApiResponse, Client, Usage},
3 streaming::StreamingCompletionResponse,
4};
5use crate::message::{
6 self, AudioMediaType, DocumentMediaType, DocumentSourceKind, ImageDetail, MimeType,
7 VideoMediaType,
8};
9use crate::telemetry::SpanCombinator;
10use crate::{
11 OneOrMany,
12 completion::{self, CompletionError, CompletionRequest},
13 http_client::HttpClientExt,
14 json_utils,
15 one_or_many::string_or_one_or_many,
16 providers::openai,
17};
18use bytes::Bytes;
19use serde::{Deserialize, Serialize, Serializer};
20use std::collections::HashMap;
21use tracing::{Instrument, Level, enabled, info_span};
22
/// Model identifier string `qwen/qwq-32b`.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// Model identifier string `anthropic/claude-3.7-sonnet`.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// Model identifier string `perplexity/sonar-pro`.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// Model identifier string `google/gemini-2.0-flash-001`.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
35
/// Data-collection policy carried in `ProviderPreferences`' `data_collection` field.
///
/// Serialized as the lowercase variant name (`"allow"` / `"deny"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum DataCollection {
    /// The `Default` variant; serialized as `"allow"`.
    #[default]
    Allow,
    /// Serialized as `"deny"`.
    Deny,
}
53
/// Quantization level tag used in `ProviderPreferences`' `quantizations` list.
///
/// Each variant carries an explicit `rename`, so the wire value is exactly the
/// string given in its attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Quantization {
    /// Serialized as `"int4"`.
    #[serde(rename = "int4")]
    Int4,
    /// Serialized as `"int8"`.
    #[serde(rename = "int8")]
    Int8,
    /// Serialized as `"fp16"`.
    #[serde(rename = "fp16")]
    Fp16,
    /// Serialized as `"bf16"`.
    #[serde(rename = "bf16")]
    Bf16,
    /// Serialized as `"fp32"`.
    #[serde(rename = "fp32")]
    Fp32,
    /// Serialized as `"fp8"`.
    #[serde(rename = "fp8")]
    Fp8,
    /// Serialized as `"unknown"`.
    #[serde(rename = "unknown")]
    Unknown,
}
82
/// Criterion by which providers are sorted.
///
/// Serialized as the lowercase variant name. The `cheapest`, `fastest` and
/// `lowest_latency` helpers on `ProviderPreferences` select `Price`,
/// `Throughput` and `Latency` respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderSortStrategy {
    /// Serialized as `"price"`.
    Price,
    /// Serialized as `"throughput"`.
    Throughput,
    /// Serialized as `"latency"`.
    Latency,
}
98
/// Partition setting attached to a [`ProviderSortConfig`].
///
/// Serialized as the lowercase variant name.
// NOTE(review): the exact routing semantics of each value are defined by the
// OpenRouter API and are not visible in this file — confirm against its docs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SortPartition {
    /// Serialized as `"model"`.
    Model,
    /// Serialized as `"none"`.
    None,
}
108
/// Object form of a provider sort: a strategy plus an optional partition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderSortConfig {
    /// Strategy to sort by; always serialized.
    pub by: ProviderSortStrategy,

    /// Optional partition; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub partition: Option<SortPartition>,
}
121
122impl ProviderSortConfig {
123 pub fn new(by: ProviderSortStrategy) -> Self {
125 Self {
126 by,
127 partition: None,
128 }
129 }
130
131 pub fn partition(mut self, partition: SortPartition) -> Self {
133 self.partition = Some(partition);
134 self
135 }
136}
137
/// Provider sort value, accepted either as a bare strategy or as a config object.
///
/// The enum is `untagged`: `Simple` serializes as the strategy string alone and
/// `Complex` as the [`ProviderSortConfig`] object.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ProviderSort {
    /// Bare strategy.
    Simple(ProviderSortStrategy),
    /// Strategy with an optional partition.
    Complex(ProviderSortConfig),
}
150
151impl From<ProviderSortStrategy> for ProviderSort {
152 fn from(strategy: ProviderSortStrategy) -> Self {
153 ProviderSort::Simple(strategy)
154 }
155}
156
157impl From<ProviderSortConfig> for ProviderSort {
158 fn from(config: ProviderSortConfig) -> Self {
159 ProviderSort::Complex(config)
160 }
161}
162
/// Throughput threshold: either a single number or per-percentile values.
///
/// `untagged`, so `Simple` serializes as a bare number and `Percentile` as an object.
// NOTE(review): the unit of the number is not visible in this file — confirm
// against the OpenRouter documentation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ThroughputThreshold {
    /// One threshold value.
    Simple(f64),
    /// Separate thresholds per percentile.
    Percentile(PercentileThresholds),
}
174
/// Latency threshold: either a single number or per-percentile values.
///
/// `untagged`, so `Simple` serializes as a bare number and `Percentile` as an object.
// NOTE(review): the unit of the number is not visible in this file — confirm
// against the OpenRouter documentation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum LatencyThreshold {
    /// One threshold value.
    Simple(f64),
    /// Separate thresholds per percentile.
    Percentile(PercentileThresholds),
}
186
/// Optional threshold values keyed by percentile (`p50`, `p75`, `p90`, `p99`).
///
/// Fields that are `None` are left out of the serialized output.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct PercentileThresholds {
    /// Threshold for the `p50` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p50: Option<f64>,
    /// Threshold for the `p75` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p75: Option<f64>,
    /// Threshold for the `p90` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p90: Option<f64>,
    /// Threshold for the `p99` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p99: Option<f64>,
}
203
204impl PercentileThresholds {
205 pub fn new() -> Self {
207 Self::default()
208 }
209
210 pub fn p50(mut self, value: f64) -> Self {
212 self.p50 = Some(value);
213 self
214 }
215
216 pub fn p75(mut self, value: f64) -> Self {
218 self.p75 = Some(value);
219 self
220 }
221
222 pub fn p90(mut self, value: f64) -> Self {
224 self.p90 = Some(value);
225 self
226 }
227
228 pub fn p99(mut self, value: f64) -> Self {
230 self.p99 = Some(value);
231 self
232 }
233}
234
/// Optional price ceilings, one per pricing dimension.
///
/// Fields that are `None` are left out of the serialized output.
// NOTE(review): currency and per-unit basis of each price are not visible in
// this file — confirm against the OpenRouter documentation.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct MaxPrice {
    /// Ceiling for the `prompt` price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<f64>,
    /// Ceiling for the `completion` price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion: Option<f64>,
    /// Ceiling for the `request` price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request: Option<f64>,
    /// Ceiling for the `image` price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image: Option<f64>,
}
254
255impl MaxPrice {
256 pub fn new() -> Self {
258 Self::default()
259 }
260
261 pub fn prompt(mut self, price: f64) -> Self {
263 self.prompt = Some(price);
264 self
265 }
266
267 pub fn completion(mut self, price: f64) -> Self {
269 self.completion = Some(price);
270 self
271 }
272
273 pub fn request(mut self, price: f64) -> Self {
275 self.request = Some(price);
276 self
277 }
278
279 pub fn image(mut self, price: f64) -> Self {
281 self.image = Some(price);
282 self
283 }
284}
285
/// Provider routing preferences for a request.
///
/// Every field is optional and is left out of the serialized output when `None`.
/// `to_json` nests the serialized value under a top-level `"provider"` key.
// NOTE(review): the routing semantics of each field are defined by the
// OpenRouter API, not by this file; the field notes below follow the field
// names and types only — confirm against the provider-routing documentation.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ProviderPreferences {
    /// List of provider names for the `order` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order: Option<Vec<String>>,

    /// List of provider names for the `only` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub only: Option<Vec<String>>,

    /// List of provider names for the `ignore` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ignore: Option<Vec<String>>,

    /// Flag for the `allow_fallbacks` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow_fallbacks: Option<bool>,

    /// Flag for the `require_parameters` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub require_parameters: Option<bool>,

    /// Data-collection policy (`"allow"` / `"deny"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data_collection: Option<DataCollection>,

    /// Flag for the `zdr` key; `zero_data_retention()` sets it to `true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zdr: Option<bool>,

    /// Sort preference, as a bare strategy or a config object.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sort: Option<ProviderSort>,

    /// Throughput threshold for the `preferred_min_throughput` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_min_throughput: Option<ThroughputThreshold>,

    /// Latency threshold for the `preferred_max_latency` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_max_latency: Option<LatencyThreshold>,

    /// Price ceilings for the `max_price` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_price: Option<MaxPrice>,

    /// Quantization levels for the `quantizations` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantizations: Option<Vec<Quantization>>,
}
364
365impl ProviderPreferences {
366 pub fn new() -> Self {
368 Self::default()
369 }
370
371 pub fn order(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
387 self.order = Some(providers.into_iter().map(|p| p.into()).collect());
388 self
389 }
390
391 pub fn only(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
403 self.only = Some(providers.into_iter().map(|p| p.into()).collect());
404 self
405 }
406
407 pub fn ignore(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
418 self.ignore = Some(providers.into_iter().map(|p| p.into()).collect());
419 self
420 }
421
422 pub fn allow_fallbacks(mut self, allow: bool) -> Self {
427 self.allow_fallbacks = Some(allow);
428 self
429 }
430
431 pub fn require_parameters(mut self, require: bool) -> Self {
437 self.require_parameters = Some(require);
438 self
439 }
440
441 pub fn data_collection(mut self, policy: DataCollection) -> Self {
445 self.data_collection = Some(policy);
446 self
447 }
448
449 pub fn zdr(mut self, enable: bool) -> Self {
460 self.zdr = Some(enable);
461 self
462 }
463
464 pub fn sort(mut self, sort: impl Into<ProviderSort>) -> Self {
480 self.sort = Some(sort.into());
481 self
482 }
483
484 pub fn preferred_min_throughput(mut self, threshold: ThroughputThreshold) -> Self {
504 self.preferred_min_throughput = Some(threshold);
505 self
506 }
507
508 pub fn preferred_max_latency(mut self, threshold: LatencyThreshold) -> Self {
512 self.preferred_max_latency = Some(threshold);
513 self
514 }
515
516 pub fn max_price(mut self, price: MaxPrice) -> Self {
520 self.max_price = Some(price);
521 self
522 }
523
524 pub fn quantizations(mut self, quantizations: impl IntoIterator<Item = Quantization>) -> Self {
537 self.quantizations = Some(quantizations.into_iter().collect());
538 self
539 }
540
541 pub fn zero_data_retention(self) -> Self {
545 self.zdr(true)
546 }
547
548 pub fn fastest(self) -> Self {
550 self.sort(ProviderSortStrategy::Throughput)
551 }
552
553 pub fn cheapest(self) -> Self {
555 self.sort(ProviderSortStrategy::Price)
556 }
557
558 pub fn lowest_latency(self) -> Self {
560 self.sort(ProviderSortStrategy::Latency)
561 }
562
563 pub fn to_json(&self) -> serde_json::Value {
565 serde_json::json!({
566 "provider": self
567 })
568 }
569}
570
/// Raw completion response body as deserialized from the provider.
///
/// Converted into the crate-level `completion::CompletionResponse` via `TryFrom`,
/// which reads only the first entry of `choices` and the `usage` block.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// Value of the `object` field.
    pub object: String,
    /// Value of the `created` field.
    // NOTE(review): presumably a Unix timestamp — confirm against the API docs.
    pub created: u64,
    /// Model name reported in the response.
    pub model: String,
    /// Returned choices; the `TryFrom` conversion fails when this is empty.
    pub choices: Vec<Choice>,
    /// Optional `system_fingerprint` value.
    pub system_fingerprint: Option<String>,
    /// Optional token usage; the conversion falls back to default usage when absent.
    pub usage: Option<Usage>,
}
584
585impl From<ApiErrorResponse> for CompletionError {
586 fn from(err: ApiErrorResponse) -> Self {
587 CompletionError::ProviderError(err.message)
588 }
589}
590
591impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
592 type Error = CompletionError;
593
594 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
595 let choice = response.choices.first().ok_or_else(|| {
596 CompletionError::ResponseError("Response contained no choices".to_owned())
597 })?;
598
599 let content = match &choice.message {
600 Message::Assistant {
601 content,
602 tool_calls,
603 reasoning,
604 reasoning_details,
605 ..
606 } => {
607 let mut content = content
608 .iter()
609 .map(|c| match c {
610 openai::AssistantContent::Text { text } => {
611 completion::AssistantContent::text(text)
612 }
613 openai::AssistantContent::Refusal { refusal } => {
614 completion::AssistantContent::text(refusal)
615 }
616 })
617 .collect::<Vec<_>>();
618
619 content.extend(tool_calls.iter().map(|call| {
620 completion::AssistantContent::tool_call(
621 &call.id,
622 &call.function.name,
623 call.function.arguments.clone(),
624 )
625 }));
626
627 let mut grouped_reasoning: HashMap<
628 Option<String>,
629 Vec<(usize, usize, message::ReasoningContent)>,
630 > = HashMap::new();
631 let mut reasoning_order: Vec<Option<String>> = Vec::new();
632 for (position, detail) in reasoning_details.iter().enumerate() {
633 let (reasoning_id, sort_index, parsed_content) = match detail {
634 ReasoningDetails::Summary {
635 id, index, summary, ..
636 } => (
637 id.clone(),
638 *index,
639 Some(message::ReasoningContent::Summary(summary.clone())),
640 ),
641 ReasoningDetails::Encrypted {
642 id, index, data, ..
643 } => (
644 id.clone(),
645 *index,
646 Some(message::ReasoningContent::Encrypted(data.clone())),
647 ),
648 ReasoningDetails::Text {
649 id,
650 index,
651 text,
652 signature,
653 ..
654 } => (
655 id.clone(),
656 *index,
657 text.as_ref().map(|text| message::ReasoningContent::Text {
658 text: text.clone(),
659 signature: signature.clone(),
660 }),
661 ),
662 };
663
664 let Some(parsed_content) = parsed_content else {
665 continue;
666 };
667 let sort_index = sort_index.unwrap_or(position);
668
669 let entry = grouped_reasoning.entry(reasoning_id.clone());
670 if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
671 reasoning_order.push(reasoning_id);
672 }
673 entry
674 .or_default()
675 .push((sort_index, position, parsed_content));
676 }
677
678 if grouped_reasoning.is_empty() {
679 if let Some(reasoning) = reasoning {
680 content.push(completion::AssistantContent::reasoning(reasoning));
681 }
682 } else {
683 for reasoning_id in reasoning_order {
684 let Some(mut blocks) = grouped_reasoning.remove(&reasoning_id) else {
685 continue;
686 };
687 blocks.sort_by_key(|(index, position, _)| (*index, *position));
688 content.push(completion::AssistantContent::Reasoning(
689 message::Reasoning {
690 id: reasoning_id,
691 content: blocks
692 .into_iter()
693 .map(|(_, _, content)| content)
694 .collect::<Vec<_>>(),
695 },
696 ));
697 }
698 }
699
700 Ok(content)
701 }
702 _ => Err(CompletionError::ResponseError(
703 "Response did not contain a valid message or tool call".into(),
704 )),
705 }?;
706
707 let choice = OneOrMany::many(content).map_err(|_| {
708 CompletionError::ResponseError(
709 "Response contained no message or tool call (empty)".to_owned(),
710 )
711 })?;
712
713 let usage = response
714 .usage
715 .as_ref()
716 .map(|usage| completion::Usage {
717 input_tokens: usage.prompt_tokens as u64,
718 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
719 total_tokens: usage.total_tokens as u64,
720 cached_input_tokens: 0,
721 cache_creation_input_tokens: 0,
722 })
723 .unwrap_or_default();
724
725 Ok(completion::CompletionResponse {
726 choice,
727 usage,
728 raw_response: response,
729 message_id: None,
730 })
731 }
732}
733
/// One content part of a user message.
///
/// Internally tagged by a `type` field whose value is the snake_case variant
/// name (`text`, `image_url`, `file`, `input_audio`, `video_url`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum UserContent {
    /// Plain text part (`type: "text"`).
    Text { text: String },

    /// Image part (`type: "image_url"`); the URL may also be a `data:` URI
    /// (see `image_base64`).
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },

    /// File part (`type: "file"`).
    File { file: FileContent },

    /// Audio part (`type: "input_audio"`), reusing the OpenAI `InputAudio` payload.
    InputAudio { input_audio: openai::InputAudio },

    /// Video part (`type: "video_url"`); the URL may also be a `data:` URI
    /// (see `video_base64`).
    #[serde(rename = "video_url")]
    VideoUrl { video_url: VideoUrlContent },
}
803
804impl UserContent {
805 pub fn text(text: impl Into<String>) -> Self {
807 UserContent::Text { text: text.into() }
808 }
809
810 pub fn image_url(url: impl Into<String>) -> Self {
812 UserContent::ImageUrl {
813 image_url: ImageUrl {
814 url: url.into(),
815 detail: None,
816 },
817 }
818 }
819
820 pub fn image_url_with_detail(url: impl Into<String>, detail: ImageDetail) -> Self {
822 UserContent::ImageUrl {
823 image_url: ImageUrl {
824 url: url.into(),
825 detail: Some(detail),
826 },
827 }
828 }
829
830 pub fn image_base64(
837 data: impl Into<String>,
838 mime_type: &str,
839 detail: Option<ImageDetail>,
840 ) -> Self {
841 let data_uri = format!("data:{};base64,{}", mime_type, data.into());
842 UserContent::ImageUrl {
843 image_url: ImageUrl {
844 url: data_uri,
845 detail,
846 },
847 }
848 }
849
850 pub fn file_url(url: impl Into<String>, filename: Option<String>) -> Self {
856 UserContent::File {
857 file: FileContent {
858 filename,
859 file_data: Some(url.into()),
860 },
861 }
862 }
863
864 pub fn file_base64(data: impl Into<String>, mime_type: &str, filename: Option<String>) -> Self {
871 let data_uri = format!("data:{};base64,{}", mime_type, data.into());
872 UserContent::File {
873 file: FileContent {
874 filename,
875 file_data: Some(data_uri),
876 },
877 }
878 }
879
880 pub fn audio_base64(data: impl Into<String>, format: AudioMediaType) -> Self {
888 UserContent::InputAudio {
889 input_audio: openai::InputAudio {
890 data: data.into(),
891 format,
892 },
893 }
894 }
895
896 pub fn video_url(url: impl Into<String>) -> Self {
903 UserContent::VideoUrl {
904 video_url: VideoUrlContent { url: url.into() },
905 }
906 }
907
908 pub fn video_base64(data: impl Into<String>, media_type: VideoMediaType) -> Self {
914 let mime = media_type.to_mime_type();
915 let data_uri = format!("data:{mime};base64,{}", data.into());
916 UserContent::VideoUrl {
917 video_url: VideoUrlContent { url: data_uri },
918 }
919 }
920}
921
922impl From<String> for UserContent {
923 fn from(text: String) -> Self {
924 UserContent::Text { text }
925 }
926}
927
928impl From<&str> for UserContent {
929 fn from(text: &str) -> Self {
930 UserContent::Text {
931 text: text.to_string(),
932 }
933 }
934}
935
936impl std::str::FromStr for UserContent {
937 type Err = std::convert::Infallible;
938
939 fn from_str(s: &str) -> Result<Self, Self::Err> {
940 Ok(UserContent::Text {
941 text: s.to_string(),
942 })
943 }
944}
945
/// Payload of an `image_url` content part.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Image location: a URL, or a `data:` URI as built by `UserContent::image_base64`.
    pub url: String,
    /// Optional detail level; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<ImageDetail>,
}
955
/// Payload of a `video_url` content part.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct VideoUrlContent {
    /// Video location: a URL, or a `data:` URI as built by `UserContent::video_base64`.
    pub url: String,
}
966
/// Payload of a `file` content part.
///
/// Both fields are left out of the serialized output when `None`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct FileContent {
    /// Optional file name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filename: Option<String>,
    /// File location or contents: the constructors in this file store either a
    /// URL or a `data:<mime>;base64,<data>` URI here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_data: Option<String>,
}
982
983fn serialize_user_content<S>(
986 content: &OneOrMany<UserContent>,
987 serializer: S,
988) -> Result<S::Ok, S::Error>
989where
990 S: Serializer,
991{
992 if content.len() == 1
993 && let UserContent::Text { text } = content.first_ref()
994 {
995 return serializer.serialize_str(text);
996 }
997 content.serialize(serializer)
998}
999
1000impl TryFrom<message::UserContent> for UserContent {
1001 type Error = message::MessageError;
1002
1003 fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
1004 match value {
1005 message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
1006
1007 message::UserContent::Image(message::Image {
1008 data,
1009 detail,
1010 media_type,
1011 ..
1012 }) => {
1013 let url = match data {
1014 DocumentSourceKind::Url(url) => url,
1015 DocumentSourceKind::Base64(data) => {
1016 let mime = media_type
1017 .ok_or_else(|| {
1018 message::MessageError::ConversionError(
1019 "Image media type required for base64 encoding".into(),
1020 )
1021 })?
1022 .to_mime_type();
1023 format!("data:{mime};base64,{data}")
1024 }
1025 DocumentSourceKind::Raw(_) => {
1026 return Err(message::MessageError::ConversionError(
1027 "Raw bytes not supported, encode as base64 first".into(),
1028 ));
1029 }
1030 DocumentSourceKind::String(_) => {
1031 return Err(message::MessageError::ConversionError(
1032 "String source not supported for images".into(),
1033 ));
1034 }
1035 DocumentSourceKind::Unknown => {
1036 return Err(message::MessageError::ConversionError(
1037 "Image has no data".into(),
1038 ));
1039 }
1040 };
1041 Ok(UserContent::ImageUrl {
1042 image_url: ImageUrl { url, detail },
1043 })
1044 }
1045
1046 message::UserContent::Document(message::Document {
1047 data, media_type, ..
1048 }) => match data {
1049 DocumentSourceKind::Url(url) => {
1050 let filename = media_type.as_ref().map(|mt| match mt {
1051 DocumentMediaType::PDF => "document.pdf",
1052 DocumentMediaType::TXT => "document.txt",
1053 DocumentMediaType::HTML => "document.html",
1054 DocumentMediaType::MARKDOWN => "document.md",
1055 DocumentMediaType::CSV => "document.csv",
1056 DocumentMediaType::XML => "document.xml",
1057 _ => "document",
1058 });
1059 Ok(UserContent::File {
1060 file: FileContent {
1061 filename: filename.map(String::from),
1062 file_data: Some(url),
1063 },
1064 })
1065 }
1066 DocumentSourceKind::Base64(data) => {
1067 let mime = media_type
1068 .as_ref()
1069 .map(|m| m.to_mime_type())
1070 .unwrap_or("application/pdf");
1071 let data_uri = format!("data:{mime};base64,{data}");
1072
1073 let filename = media_type.as_ref().map(|mt| match mt {
1074 DocumentMediaType::PDF => "document.pdf",
1075 DocumentMediaType::TXT => "document.txt",
1076 DocumentMediaType::HTML => "document.html",
1077 DocumentMediaType::MARKDOWN => "document.md",
1078 DocumentMediaType::CSV => "document.csv",
1079 DocumentMediaType::XML => "document.xml",
1080 _ => "document",
1081 });
1082
1083 Ok(UserContent::File {
1084 file: FileContent {
1085 filename: filename.map(String::from),
1086 file_data: Some(data_uri),
1087 },
1088 })
1089 }
1090 DocumentSourceKind::String(text) => Ok(UserContent::Text { text }),
1091 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
1092 "Raw bytes not supported for documents, encode as base64 first".into(),
1093 )),
1094 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
1095 "Document has no data".into(),
1096 )),
1097 },
1098
1099 message::UserContent::Audio(message::Audio {
1100 data, media_type, ..
1101 }) => match data {
1102 DocumentSourceKind::Base64(data) => {
1103 let format = media_type.ok_or_else(|| {
1104 message::MessageError::ConversionError(
1105 "Audio media type required for base64 encoding".into(),
1106 )
1107 })?;
1108 Ok(UserContent::InputAudio {
1109 input_audio: openai::InputAudio { data, format },
1110 })
1111 }
1112 DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
1113 "OpenRouter does not support audio URLs, encode as base64 first".into(),
1114 )),
1115 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
1116 "Raw bytes not supported for audio, encode as base64 first".into(),
1117 )),
1118 DocumentSourceKind::String(_) => Err(message::MessageError::ConversionError(
1119 "String source not supported for audio".into(),
1120 )),
1121 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
1122 "Audio has no data".into(),
1123 )),
1124 },
1125
1126 message::UserContent::Video(message::Video {
1127 data, media_type, ..
1128 }) => {
1129 let url = match data {
1130 DocumentSourceKind::Url(url) => url,
1131 DocumentSourceKind::Base64(data) => {
1132 let mime = media_type
1133 .ok_or_else(|| {
1134 message::MessageError::ConversionError(
1135 "Video media type required for base64 encoding".into(),
1136 )
1137 })?
1138 .to_mime_type();
1139 format!("data:{mime};base64,{data}")
1140 }
1141 DocumentSourceKind::Raw(_) => {
1142 return Err(message::MessageError::ConversionError(
1143 "Raw bytes not supported for video, encode as base64 first".into(),
1144 ));
1145 }
1146 DocumentSourceKind::String(_) => {
1147 return Err(message::MessageError::ConversionError(
1148 "String source not supported for video".into(),
1149 ));
1150 }
1151 DocumentSourceKind::Unknown => {
1152 return Err(message::MessageError::ConversionError(
1153 "Video has no data".into(),
1154 ));
1155 }
1156 };
1157 Ok(UserContent::VideoUrl {
1158 video_url: VideoUrlContent { url },
1159 })
1160 }
1161
1162 message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
1163 "Tool results should be handled as separate messages".into(),
1164 )),
1165 }
1166 }
1167}
1168
1169impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
1170 type Error = message::MessageError;
1171
1172 fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
1173 let (tool_results, other_content): (Vec<_>, Vec<_>) = value
1174 .into_iter()
1175 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
1176
1177 if !tool_results.is_empty() {
1180 tool_results
1181 .into_iter()
1182 .map(|content| match content {
1183 message::UserContent::ToolResult(tool_result) => Ok(Message::ToolResult {
1184 tool_call_id: tool_result.id,
1185 content: tool_result
1186 .content
1187 .into_iter()
1188 .map(|c| match c {
1189 message::ToolResultContent::Text(message::Text { text }) => text,
1190 message::ToolResultContent::Image(_) => {
1191 "[Image content not supported in tool results]".to_string()
1192 }
1193 })
1194 .collect::<Vec<_>>()
1195 .join("\n"),
1196 }),
1197 _ => Err(message::MessageError::ConversionError(
1198 "expected tool result content while converting OpenRouter input".into(),
1199 )),
1200 })
1201 .collect::<Result<Vec<_>, _>>()
1202 } else {
1203 let user_content: Vec<UserContent> = other_content
1204 .into_iter()
1205 .map(|content| content.try_into())
1206 .collect::<Result<Vec<_>, _>>()?;
1207
1208 let content = OneOrMany::many(user_content).map_err(|_| {
1209 message::MessageError::ConversionError(
1210 "OpenRouter user message did not contain any non-tool content".into(),
1211 )
1212 })?;
1213
1214 Ok(vec![Message::User {
1215 content,
1216 name: None,
1217 }])
1218 }
1219 }
1220}
1221
/// One entry of a completion response's `choices` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Value of the `index` field.
    pub index: usize,
    /// Optional `native_finish_reason` value.
    // NOTE(review): presumably the upstream provider's own finish reason — confirm.
    pub native_finish_reason: Option<String>,
    /// The message carried by this choice.
    pub message: Message,
    /// Optional `finish_reason` value.
    pub finish_reason: Option<String>,
}
1233
/// A chat message, internally tagged by a lowercase `role` field.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System message (`role: "system"`); `"developer"` is also accepted when
    /// deserializing.
    #[serde(alias = "developer")]
    System {
        /// Content, deserialized via `string_or_one_or_many`.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<openai::SystemContent>,
        /// Optional name; left out of the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// User message (`role: "user"`).
    User {
        /// Content parts. A lone text part is serialized as a bare string
        /// (see `serialize_user_content`).
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        /// Optional name; left out of the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Assistant message (`role: "assistant"`).
    Assistant {
        /// Text/refusal parts; defaults to empty when the field is missing.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<openai::AssistantContent>,
        /// Optional refusal text.
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        /// Optional audio payload.
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<openai::AudioAssistant>,
        /// Optional name.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Tool calls; defaults to empty when missing and is left out of the
        /// serialized output when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<openai::ToolCall>,
        /// Optional plain-string reasoning.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning: Option<String>,
        /// Structured reasoning blocks; left out of the serialized output when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        reasoning_details: Vec<ReasoningDetails>,
    },
    /// Tool result message (`role: "tool"`).
    #[serde(rename = "tool")]
    ToolResult {
        /// Id of the tool call this result belongs to.
        tool_call_id: String,
        /// Result text.
        content: String,
    },
}
1284
1285impl Message {
1286 pub fn system(content: &str) -> Self {
1287 Message::System {
1288 content: OneOrMany::one(content.to_owned().into()),
1289 name: None,
1290 }
1291 }
1292}
1293
/// One structured reasoning block attached to an assistant message.
///
/// Internally tagged by `type`; each variant sets its tag explicitly
/// (`reasoning.summary`, `reasoning.encrypted`, `reasoning.text`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningDetails {
    /// Summary block (`type: "reasoning.summary"`).
    #[serde(rename = "reasoning.summary")]
    Summary {
        /// Optional id; the response conversion groups blocks that share an id.
        id: Option<String>,
        /// Optional `format` value.
        format: Option<String>,
        /// Optional sort index within the group.
        index: Option<usize>,
        /// Summary text.
        summary: String,
    },
    /// Encrypted block (`type: "reasoning.encrypted"`).
    #[serde(rename = "reasoning.encrypted")]
    Encrypted {
        /// Optional id; the response conversion groups blocks that share an id.
        id: Option<String>,
        /// Optional `format` value.
        format: Option<String>,
        /// Optional sort index within the group.
        index: Option<usize>,
        /// Opaque payload.
        data: String,
    },
    /// Text block (`type: "reasoning.text"`).
    #[serde(rename = "reasoning.text")]
    Text {
        /// Optional id; the response conversion groups blocks that share an id.
        id: Option<String>,
        /// Optional `format` value.
        format: Option<String>,
        /// Optional sort index within the group.
        index: Option<usize>,
        /// Reasoning text; the response conversion skips blocks where this is `None`.
        text: Option<String>,
        /// Optional signature accompanying the text.
        signature: Option<String>,
    },
}
1320
/// Shapes accepted when decoding a tool call's `additional_params` JSON.
///
/// `untagged`, so variants are tried in declaration order: a full
/// [`ReasoningDetails`] object first, then the looser `Minimal` form.
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[serde(untagged)]
enum ToolCallAdditionalParams {
    /// A complete reasoning-details object.
    ReasoningDetails(ReasoningDetails),
    /// Only an optional `id` and `format`; both fields may be absent.
    Minimal {
        id: Option<String>,
        format: Option<String>,
    },
}
1330
1331impl From<openai::UserContent> for UserContent {
1333 fn from(value: openai::UserContent) -> Self {
1334 match value {
1335 openai::UserContent::Text { text } => UserContent::Text { text },
1336 openai::UserContent::Image { image_url } => UserContent::ImageUrl {
1337 image_url: ImageUrl {
1338 url: image_url.url,
1339 detail: Some(image_url.detail),
1340 },
1341 },
1342 openai::UserContent::Audio { input_audio } => UserContent::InputAudio { input_audio },
1343 }
1344 }
1345}
1346
1347impl From<openai::Message> for Message {
1348 fn from(value: openai::Message) -> Self {
1349 match value {
1350 openai::Message::System { content, name } => Self::System { content, name },
1351 openai::Message::User { content, name } => {
1352 let converted_content = content.map(UserContent::from);
1354 Self::User {
1355 content: converted_content,
1356 name,
1357 }
1358 }
1359 openai::Message::Assistant {
1360 content,
1361 reasoning,
1362 refusal,
1363 audio,
1364 name,
1365 tool_calls,
1366 } => Self::Assistant {
1367 content,
1368 refusal,
1369 audio,
1370 name,
1371 tool_calls,
1372 reasoning,
1373 reasoning_details: Vec::new(),
1374 },
1375 openai::Message::ToolResult {
1376 tool_call_id,
1377 content,
1378 } => Self::ToolResult {
1379 tool_call_id,
1380 content: content.as_text(),
1381 },
1382 }
1383 }
1384}
1385
impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Folds the content of one generic assistant message into a single
    /// `Message::Assistant`.
    ///
    /// Text parts become `content`, tool calls become `tool_calls`, and
    /// reasoning is emitted either as the plain `reasoning` string (for
    /// reasoning with no content blocks) or as `reasoning_details` entries.
    ///
    /// # Errors
    ///
    /// Returns `MessageError::ConversionError` when the content includes an image.
    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
        let mut text_content = Vec::new();
        let mut tool_calls = Vec::new();
        let mut reasoning = None;
        let mut reasoning_details = Vec::new();

        for content in value.into_iter() {
            match content {
                message::AssistantContent::Text(text) => text_content.push(text),
                message::AssistantContent::ToolCall(tool_call) => {
                    // If the call's additional params decode, they decide which
                    // reasoning detail (if any) is attached; the signature-only
                    // fallback below runs only when they are absent or undecodable.
                    if let Some(additional_params) = &tool_call.additional_params
                        && let Ok(additional_params) =
                            serde_json::from_value::<ToolCallAdditionalParams>(
                                additional_params.clone(),
                            )
                    {
                        match additional_params {
                            ToolCallAdditionalParams::ReasoningDetails(full) => {
                                reasoning_details.push(full);
                            }
                            ToolCallAdditionalParams::Minimal { id, format } => {
                                // Fall back to the call id when the params carry no id;
                                // nothing is emitted unless both an id and a signature exist.
                                let id = id.or_else(|| tool_call.call_id.clone());
                                if let Some(signature) = &tool_call.signature
                                    && let Some(id) = id
                                {
                                    reasoning_details.push(ReasoningDetails::Encrypted {
                                        id: Some(id),
                                        format,
                                        index: None,
                                        data: signature.clone(),
                                    })
                                }
                            }
                        }
                    } else if let Some(signature) = &tool_call.signature {
                        // No usable additional params: carry the signature as an
                        // encrypted detail keyed by the call id (which may be `None`).
                        reasoning_details.push(ReasoningDetails::Encrypted {
                            id: tool_call.call_id.clone(),
                            format: None,
                            index: None,
                            data: signature.clone(),
                        });
                    }
                    tool_calls.push(tool_call.into())
                }
                message::AssistantContent::Reasoning(r) => {
                    if r.content.is_empty() {
                        // No structured blocks: use the display text as the plain
                        // reasoning string. A later such item overwrites an earlier one.
                        let display = r.display_text();
                        if !display.is_empty() {
                            reasoning = Some(display);
                        }
                    } else {
                        for reasoning_block in &r.content {
                            // The index is the detail's position in the overall list,
                            // counting details already pushed by earlier items.
                            let index = Some(reasoning_details.len());
                            match reasoning_block {
                                message::ReasoningContent::Text { text, signature } => {
                                    reasoning_details.push(ReasoningDetails::Text {
                                        id: r.id.clone(),
                                        format: None,
                                        index,
                                        text: Some(text.clone()),
                                        signature: signature.clone(),
                                    });
                                }
                                message::ReasoningContent::Summary(summary) => {
                                    reasoning_details.push(ReasoningDetails::Summary {
                                        id: r.id.clone(),
                                        format: None,
                                        index,
                                        summary: summary.clone(),
                                    });
                                }
                                // Encrypted and redacted blocks share the encrypted wire form.
                                message::ReasoningContent::Encrypted(data)
                                | message::ReasoningContent::Redacted { data } => {
                                    reasoning_details.push(ReasoningDetails::Encrypted {
                                        id: r.id.clone(),
                                        format: None,
                                        index,
                                        data: data.clone(),
                                    });
                                }
                            }
                        }
                    }
                }
                message::AssistantContent::Image(_) => {
                    return Err(Self::Error::ConversionError(
                        "OpenRouter currently doesn't support images.".into(),
                    ));
                }
            }
        }

        // Everything is folded into one assistant message.
        Ok(vec![Message::Assistant {
            content: text_content
                .into_iter()
                .map(|content| content.text.into())
                .collect::<Vec<_>>(),
            refusal: None,
            audio: None,
            name: None,
            tool_calls,
            reasoning,
            reasoning_details,
        }])
    }
}
1502
1503impl TryFrom<message::Message> for Vec<Message> {
1506 type Error = message::MessageError;
1507
1508 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
1509 match message {
1510 message::Message::System { content } => Ok(vec![Message::System {
1511 content: OneOrMany::one(content.into()),
1512 name: None,
1513 }]),
1514 message::Message::User { content } => {
1515 content.try_into()
1518 }
1519 message::Message::Assistant { content, .. } => content.try_into(),
1520 }
1521 }
1522}
1523
/// Tool-choice setting built from the crate-level `message::ToolChoice`.
// NOTE(review): per serde's enum-representation docs, `untagged` serializes
// unit variants as unit (JSON `null`) and `rename_all` does not apply to them,
// so `None`, `Auto` and `Required` would all serialize identically. Confirm
// this is the intended wire format before using this type in a request body.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged, rename_all = "snake_case")]
pub enum ToolChoice {
    /// Converted from `message::ToolChoice::None`.
    None,
    /// Converted from `message::ToolChoice::Auto`.
    Auto,
    /// Converted from `message::ToolChoice::Required`.
    Required,
    /// Converted from `message::ToolChoice::Specific`, one entry per function name.
    Function(Vec<ToolChoiceFunctionKind>),
}
1532
1533impl TryFrom<crate::message::ToolChoice> for ToolChoice {
1534 type Error = CompletionError;
1535
1536 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
1537 let res = match value {
1538 crate::message::ToolChoice::None => Self::None,
1539 crate::message::ToolChoice::Auto => Self::Auto,
1540 crate::message::ToolChoice::Required => Self::Required,
1541 crate::message::ToolChoice::Specific { function_names } => {
1542 let vec: Vec<ToolChoiceFunctionKind> = function_names
1543 .into_iter()
1544 .map(|name| ToolChoiceFunctionKind::Function { name })
1545 .collect();
1546
1547 Self::Function(vec)
1548 }
1549 };
1550
1551 Ok(res)
1552 }
1553}
1554
1555#[derive(Debug, Serialize, Deserialize)]
1556#[serde(tag = "type", content = "function")]
1557pub enum ToolChoiceFunctionKind {
1558 Function { name: String },
1559}
1560
1561#[derive(Debug, Serialize, Deserialize)]
1562pub(super) struct OpenrouterCompletionRequest {
1563 model: String,
1564 pub messages: Vec<Message>,
1565 #[serde(skip_serializing_if = "Option::is_none")]
1566 temperature: Option<f64>,
1567 #[serde(skip_serializing_if = "Vec::is_empty")]
1568 tools: Vec<crate::providers::openai::completion::ToolDefinition>,
1569 #[serde(skip_serializing_if = "Option::is_none")]
1570 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
1571 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1572 pub additional_params: Option<serde_json::Value>,
1573}
1574
1575pub struct OpenRouterRequestParams<'a> {
1577 pub model: &'a str,
1578 pub request: CompletionRequest,
1579 pub strict_tools: bool,
1580}
1581
1582impl TryFrom<OpenRouterRequestParams<'_>> for OpenrouterCompletionRequest {
1583 type Error = CompletionError;
1584
1585 fn try_from(params: OpenRouterRequestParams) -> Result<Self, Self::Error> {
1586 let OpenRouterRequestParams {
1587 model,
1588 request: req,
1589 strict_tools,
1590 } = params;
1591 let model = req.model.clone().unwrap_or_else(|| model.to_string());
1592
1593 if req.output_schema.is_some() {
1594 tracing::warn!("Structured outputs currently not supported for OpenRouter");
1595 }
1596
1597 let mut full_history: Vec<Message> = match &req.preamble {
1598 Some(preamble) => vec![Message::system(preamble)],
1599 None => vec![],
1600 };
1601 if let Some(docs) = req.normalized_documents() {
1602 let docs: Vec<Message> = docs.try_into()?;
1603 full_history.extend(docs);
1604 }
1605
1606 let chat_history: Vec<Message> = req
1607 .chat_history
1608 .clone()
1609 .into_iter()
1610 .map(|message| message.try_into())
1611 .collect::<Result<Vec<Vec<Message>>, _>>()?
1612 .into_iter()
1613 .flatten()
1614 .collect();
1615
1616 full_history.extend(chat_history);
1617
1618 let tool_choice = req
1619 .tool_choice
1620 .clone()
1621 .map(crate::providers::openai::completion::ToolChoice::try_from)
1622 .transpose()?;
1623
1624 let tools: Vec<crate::providers::openai::completion::ToolDefinition> = req
1625 .tools
1626 .clone()
1627 .into_iter()
1628 .map(|tool| {
1629 let def = crate::providers::openai::completion::ToolDefinition::from(tool);
1630 if strict_tools { def.with_strict() } else { def }
1631 })
1632 .collect();
1633
1634 Ok(Self {
1635 model,
1636 messages: full_history,
1637 temperature: req.temperature,
1638 tools,
1639 tool_choice,
1640 additional_params: req.additional_params,
1641 })
1642 }
1643}
1644
1645impl TryFrom<(&str, CompletionRequest)> for OpenrouterCompletionRequest {
1646 type Error = CompletionError;
1647
1648 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
1649 let model = req.model.clone().unwrap_or_else(|| model.to_string());
1650 OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
1651 model: &model,
1652 request: req,
1653 strict_tools: false,
1654 })
1655 }
1656}
1657
1658#[derive(Clone)]
1659pub struct CompletionModel<T = reqwest::Client> {
1660 pub(crate) client: Client<T>,
1661 pub model: String,
1662 pub strict_tools: bool,
1665}
1666
1667impl<T> CompletionModel<T> {
1668 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
1669 Self {
1670 client,
1671 model: model.into(),
1672 strict_tools: false,
1673 }
1674 }
1675
1676 pub fn with_strict_tools(mut self) -> Self {
1685 self.strict_tools = true;
1686 self
1687 }
1688}
1689
1690impl<T> completion::CompletionModel for CompletionModel<T>
1691where
1692 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
1693{
1694 type Response = CompletionResponse;
1695 type StreamingResponse = StreamingCompletionResponse;
1696
1697 type Client = Client<T>;
1698
1699 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1700 Self::new(client.clone(), model)
1701 }
1702
1703 async fn completion(
1704 &self,
1705 completion_request: CompletionRequest,
1706 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1707 let request_model = completion_request
1708 .model
1709 .clone()
1710 .unwrap_or_else(|| self.model.clone());
1711 let preamble = completion_request.preamble.clone();
1712 let request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
1713 model: request_model.as_ref(),
1714 request: completion_request,
1715 strict_tools: self.strict_tools,
1716 })?;
1717
1718 if enabled!(Level::TRACE) {
1719 tracing::trace!(
1720 target: "rig::completions",
1721 "OpenRouter completion request: {}",
1722 serde_json::to_string_pretty(&request)?
1723 );
1724 }
1725
1726 let span = if tracing::Span::current().is_disabled() {
1727 info_span!(
1728 target: "rig::completions",
1729 "chat",
1730 gen_ai.operation.name = "chat",
1731 gen_ai.provider.name = "openrouter",
1732 gen_ai.request.model = &request_model,
1733 gen_ai.system_instructions = preamble,
1734 gen_ai.response.id = tracing::field::Empty,
1735 gen_ai.response.model = tracing::field::Empty,
1736 gen_ai.usage.output_tokens = tracing::field::Empty,
1737 gen_ai.usage.input_tokens = tracing::field::Empty,
1738 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
1739 )
1740 } else {
1741 tracing::Span::current()
1742 };
1743
1744 let body = serde_json::to_vec(&request)?;
1745
1746 let req = self
1747 .client
1748 .post("/chat/completions")?
1749 .body(body)
1750 .map_err(|x| CompletionError::HttpError(x.into()))?;
1751
1752 async move {
1753 let response = self.client.send::<_, Bytes>(req).await?;
1754 let status = response.status();
1755 let response_body = response.into_body().into_future().await?.to_vec();
1756
1757 if status.is_success() {
1758 let parsed: ApiResponse<CompletionResponse> =
1759 serde_json::from_slice(&response_body).map_err(|e| {
1760 CompletionError::ResponseError(format!(
1761 "Failed to parse OpenRouter completion response: {}, response body: {}",
1762 e,
1763 String::from_utf8_lossy(&response_body)
1764 ))
1765 })?;
1766 match parsed {
1767 ApiResponse::Ok(response) => {
1768 let span = tracing::Span::current();
1769 span.record_token_usage(&response.usage);
1770 span.record("gen_ai.response.id", &response.id);
1771 span.record("gen_ai.response.model", &response.model);
1772
1773 tracing::debug!(target: "rig::completions",
1774 "OpenRouter response: {response:?}");
1775 response.try_into()
1776 }
1777 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
1778 }
1779 } else {
1780 Err(CompletionError::ProviderError(
1781 String::from_utf8_lossy(&response_body).to_string(),
1782 ))
1783 }
1784 }
1785 .instrument(span)
1786 .await
1787 }
1788
1789 async fn stream(
1790 &self,
1791 completion_request: CompletionRequest,
1792 ) -> Result<
1793 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1794 CompletionError,
1795 > {
1796 CompletionModel::stream(self, completion_request).await
1797 }
1798}
1799
1800#[cfg(test)]
1801mod tests {
1802 use super::*;
1803 use serde_json::json;
1804
1805 #[test]
1806 fn test_openrouter_request_uses_request_model_override() {
1807 let request = CompletionRequest {
1808 model: Some("google/gemini-2.5-flash".to_string()),
1809 preamble: None,
1810 chat_history: crate::OneOrMany::one("Hello".into()),
1811 documents: vec![],
1812 tools: vec![],
1813 temperature: None,
1814 max_tokens: None,
1815 tool_choice: None,
1816 additional_params: None,
1817 output_schema: None,
1818 };
1819
1820 let openrouter_request =
1821 OpenrouterCompletionRequest::try_from(("openai/gpt-4o-mini", request))
1822 .expect("request conversion should succeed");
1823 let serialized =
1824 serde_json::to_value(openrouter_request).expect("serialization should succeed");
1825
1826 assert_eq!(serialized["model"], "google/gemini-2.5-flash");
1827 }
1828
1829 #[test]
1830 fn test_openrouter_request_uses_default_model_when_override_unset() {
1831 let request = CompletionRequest {
1832 model: None,
1833 preamble: None,
1834 chat_history: crate::OneOrMany::one("Hello".into()),
1835 documents: vec![],
1836 tools: vec![],
1837 temperature: None,
1838 max_tokens: None,
1839 tool_choice: None,
1840 additional_params: None,
1841 output_schema: None,
1842 };
1843
1844 let openrouter_request =
1845 OpenrouterCompletionRequest::try_from(("openai/gpt-4o-mini", request))
1846 .expect("request conversion should succeed");
1847 let serialized =
1848 serde_json::to_value(openrouter_request).expect("serialization should succeed");
1849
1850 assert_eq!(serialized["model"], "openai/gpt-4o-mini");
1851 }
1852
1853 #[test]
1854 fn test_completion_response_deserialization_gemini_flash() {
1855 let json = json!({
1857 "id": "gen-AAAAAAAAAA-AAAAAAAAAAAAAAAAAAAA",
1858 "provider": "Google",
1859 "model": "google/gemini-2.5-flash",
1860 "object": "chat.completion",
1861 "created": 1765971703u64,
1862 "choices": [{
1863 "logprobs": null,
1864 "finish_reason": "stop",
1865 "native_finish_reason": "STOP",
1866 "index": 0,
1867 "message": {
1868 "role": "assistant",
1869 "content": "CONTENT",
1870 "refusal": null,
1871 "reasoning": null
1872 }
1873 }],
1874 "usage": {
1875 "prompt_tokens": 669,
1876 "completion_tokens": 5,
1877 "total_tokens": 674
1878 }
1879 });
1880
1881 let response: CompletionResponse = serde_json::from_value(json).unwrap();
1882 assert_eq!(response.id, "gen-AAAAAAAAAA-AAAAAAAAAAAAAAAAAAAA");
1883 assert_eq!(response.model, "google/gemini-2.5-flash");
1884 assert_eq!(response.choices.len(), 1);
1885 assert_eq!(response.choices[0].finish_reason, Some("stop".to_string()));
1886 }
1887
1888 #[test]
1889 fn test_message_assistant_without_reasoning_details() {
1890 let json = json!({
1892 "role": "assistant",
1893 "content": "Hello world",
1894 "refusal": null,
1895 "reasoning": null
1896 });
1897
1898 let message: Message = serde_json::from_value(json).unwrap();
1899 match message {
1900 Message::Assistant {
1901 content,
1902 reasoning_details,
1903 ..
1904 } => {
1905 assert_eq!(content.len(), 1);
1906 assert!(reasoning_details.is_empty());
1907 }
1908 _ => panic!("Expected Assistant message"),
1909 }
1910 }
1911
1912 #[test]
1913 fn test_data_collection_serialization() {
1914 assert_eq!(
1915 serde_json::to_string(&DataCollection::Allow).unwrap(),
1916 r#""allow""#
1917 );
1918 assert_eq!(
1919 serde_json::to_string(&DataCollection::Deny).unwrap(),
1920 r#""deny""#
1921 );
1922 }
1923
1924 #[test]
1925 fn test_data_collection_default() {
1926 assert_eq!(DataCollection::default(), DataCollection::Allow);
1927 }
1928
1929 #[test]
1930 fn test_quantization_serialization() {
1931 assert_eq!(
1932 serde_json::to_string(&Quantization::Int4).unwrap(),
1933 r#""int4""#
1934 );
1935 assert_eq!(
1936 serde_json::to_string(&Quantization::Int8).unwrap(),
1937 r#""int8""#
1938 );
1939 assert_eq!(
1940 serde_json::to_string(&Quantization::Fp16).unwrap(),
1941 r#""fp16""#
1942 );
1943 assert_eq!(
1944 serde_json::to_string(&Quantization::Bf16).unwrap(),
1945 r#""bf16""#
1946 );
1947 assert_eq!(
1948 serde_json::to_string(&Quantization::Fp32).unwrap(),
1949 r#""fp32""#
1950 );
1951 assert_eq!(
1952 serde_json::to_string(&Quantization::Fp8).unwrap(),
1953 r#""fp8""#
1954 );
1955 assert_eq!(
1956 serde_json::to_string(&Quantization::Unknown).unwrap(),
1957 r#""unknown""#
1958 );
1959 }
1960
1961 #[test]
1962 fn test_provider_sort_strategy_serialization() {
1963 assert_eq!(
1964 serde_json::to_string(&ProviderSortStrategy::Price).unwrap(),
1965 r#""price""#
1966 );
1967 assert_eq!(
1968 serde_json::to_string(&ProviderSortStrategy::Throughput).unwrap(),
1969 r#""throughput""#
1970 );
1971 assert_eq!(
1972 serde_json::to_string(&ProviderSortStrategy::Latency).unwrap(),
1973 r#""latency""#
1974 );
1975 }
1976
1977 #[test]
1978 fn test_sort_partition_serialization() {
1979 assert_eq!(
1980 serde_json::to_string(&SortPartition::Model).unwrap(),
1981 r#""model""#
1982 );
1983 assert_eq!(
1984 serde_json::to_string(&SortPartition::None).unwrap(),
1985 r#""none""#
1986 );
1987 }
1988
1989 #[test]
1990 fn test_provider_sort_simple() {
1991 let sort = ProviderSort::Simple(ProviderSortStrategy::Latency);
1992 let json = serde_json::to_value(&sort).unwrap();
1993 assert_eq!(json, "latency");
1994 }
1995
1996 #[test]
1997 fn test_provider_sort_complex() {
1998 let sort = ProviderSort::Complex(
1999 ProviderSortConfig::new(ProviderSortStrategy::Price).partition(SortPartition::None),
2000 );
2001 let json = serde_json::to_value(&sort).unwrap();
2002 assert_eq!(json["by"], "price");
2003 assert_eq!(json["partition"], "none");
2004 }
2005
2006 #[test]
2007 fn test_provider_sort_complex_without_partition() {
2008 let sort = ProviderSort::Complex(ProviderSortConfig::new(ProviderSortStrategy::Throughput));
2009 let json = serde_json::to_value(&sort).unwrap();
2010 assert_eq!(json["by"], "throughput");
2011 assert!(json.get("partition").is_none());
2012 }
2013
2014 #[test]
2015 fn test_provider_sort_from_strategy() {
2016 let sort: ProviderSort = ProviderSortStrategy::Price.into();
2017 assert_eq!(sort, ProviderSort::Simple(ProviderSortStrategy::Price));
2018 }
2019
2020 #[test]
2021 fn test_provider_sort_from_config() {
2022 let config = ProviderSortConfig::new(ProviderSortStrategy::Latency);
2023 let sort: ProviderSort = config.into();
2024 match sort {
2025 ProviderSort::Complex(c) => assert_eq!(c.by, ProviderSortStrategy::Latency),
2026 _ => panic!("Expected Complex variant"),
2027 }
2028 }
2029
2030 #[test]
2031 fn test_percentile_thresholds_builder() {
2032 let thresholds = PercentileThresholds::new()
2033 .p50(10.0)
2034 .p75(25.0)
2035 .p90(50.0)
2036 .p99(100.0);
2037
2038 assert_eq!(thresholds.p50, Some(10.0));
2039 assert_eq!(thresholds.p75, Some(25.0));
2040 assert_eq!(thresholds.p90, Some(50.0));
2041 assert_eq!(thresholds.p99, Some(100.0));
2042 }
2043
2044 #[test]
2045 fn test_percentile_thresholds_default() {
2046 let thresholds = PercentileThresholds::default();
2047 assert_eq!(thresholds.p50, None);
2048 assert_eq!(thresholds.p75, None);
2049 assert_eq!(thresholds.p90, None);
2050 assert_eq!(thresholds.p99, None);
2051 }
2052
2053 #[test]
2054 fn test_throughput_threshold_simple() {
2055 let threshold = ThroughputThreshold::Simple(50.0);
2056 let json = serde_json::to_value(&threshold).unwrap();
2057 assert_eq!(json, 50.0);
2058 }
2059
2060 #[test]
2061 fn test_throughput_threshold_percentile() {
2062 let threshold = ThroughputThreshold::Percentile(PercentileThresholds::new().p90(50.0));
2063 let json = serde_json::to_value(&threshold).unwrap();
2064 assert_eq!(json["p90"], 50.0);
2065 }
2066
2067 #[test]
2068 fn test_latency_threshold_simple() {
2069 let threshold = LatencyThreshold::Simple(0.5);
2070 let json = serde_json::to_value(&threshold).unwrap();
2071 assert_eq!(json, 0.5);
2072 }
2073
2074 #[test]
2075 fn test_latency_threshold_percentile() {
2076 let threshold = LatencyThreshold::Percentile(PercentileThresholds::new().p50(0.1).p99(1.0));
2077 let json = serde_json::to_value(&threshold).unwrap();
2078 assert_eq!(json["p50"], 0.1);
2079 assert_eq!(json["p99"], 1.0);
2080 }
2081
2082 #[test]
2083 fn test_max_price_builder() {
2084 let price = MaxPrice::new().prompt(0.001).completion(0.002);
2085
2086 assert_eq!(price.prompt, Some(0.001));
2087 assert_eq!(price.completion, Some(0.002));
2088 assert_eq!(price.request, None);
2089 assert_eq!(price.image, None);
2090 }
2091
2092 #[test]
2093 fn test_max_price_all_fields() {
2094 let price = MaxPrice::new()
2095 .prompt(0.001)
2096 .completion(0.002)
2097 .request(0.01)
2098 .image(0.05);
2099
2100 let json = serde_json::to_value(&price).unwrap();
2101 assert_eq!(json["prompt"], 0.001);
2102 assert_eq!(json["completion"], 0.002);
2103 assert_eq!(json["request"], 0.01);
2104 assert_eq!(json["image"], 0.05);
2105 }
2106
2107 #[test]
2108 fn test_max_price_default() {
2109 let price = MaxPrice::default();
2110 assert_eq!(price.prompt, None);
2111 assert_eq!(price.completion, None);
2112 assert_eq!(price.request, None);
2113 assert_eq!(price.image, None);
2114 }
2115
2116 #[test]
2117 fn test_provider_preferences_default() {
2118 let prefs = ProviderPreferences::default();
2119 assert!(prefs.order.is_none());
2120 assert!(prefs.only.is_none());
2121 assert!(prefs.ignore.is_none());
2122 assert!(prefs.allow_fallbacks.is_none());
2123 assert!(prefs.require_parameters.is_none());
2124 assert!(prefs.data_collection.is_none());
2125 assert!(prefs.zdr.is_none());
2126 assert!(prefs.sort.is_none());
2127 assert!(prefs.preferred_min_throughput.is_none());
2128 assert!(prefs.preferred_max_latency.is_none());
2129 assert!(prefs.max_price.is_none());
2130 assert!(prefs.quantizations.is_none());
2131 }
2132
2133 #[test]
2134 fn test_provider_preferences_order_with_fallbacks() {
2135 let prefs = ProviderPreferences::new()
2136 .order(["anthropic", "openai"])
2137 .allow_fallbacks(true);
2138
2139 let json = prefs.to_json();
2140 let provider = &json["provider"];
2141
2142 assert_eq!(provider["order"], json!(["anthropic", "openai"]));
2143 assert_eq!(provider["allow_fallbacks"], true);
2144 }
2145
2146 #[test]
2147 fn test_provider_preferences_only_allowlist() {
2148 let prefs = ProviderPreferences::new()
2149 .only(["azure", "together"])
2150 .allow_fallbacks(false);
2151
2152 let json = prefs.to_json();
2153 let provider = &json["provider"];
2154
2155 assert_eq!(provider["only"], json!(["azure", "together"]));
2156 assert_eq!(provider["allow_fallbacks"], false);
2157 }
2158
2159 #[test]
2160 fn test_provider_preferences_ignore() {
2161 let prefs = ProviderPreferences::new().ignore(["deepinfra"]);
2162
2163 let json = prefs.to_json();
2164 let provider = &json["provider"];
2165
2166 assert_eq!(provider["ignore"], json!(["deepinfra"]));
2167 }
2168
2169 #[test]
2170 fn test_provider_preferences_sort_latency() {
2171 let prefs = ProviderPreferences::new().sort(ProviderSortStrategy::Latency);
2172
2173 let json = prefs.to_json();
2174 let provider = &json["provider"];
2175
2176 assert_eq!(provider["sort"], "latency");
2177 }
2178
2179 #[test]
2180 fn test_provider_preferences_price_with_throughput() {
2181 let prefs = ProviderPreferences::new()
2182 .sort(ProviderSortStrategy::Price)
2183 .preferred_min_throughput(ThroughputThreshold::Percentile(
2184 PercentileThresholds::new().p90(50.0),
2185 ));
2186
2187 let json = prefs.to_json();
2188 let provider = &json["provider"];
2189
2190 assert_eq!(provider["sort"], "price");
2191 assert_eq!(provider["preferred_min_throughput"]["p90"], 50.0);
2192 }
2193
2194 #[test]
2195 fn test_provider_preferences_require_parameters() {
2196 let prefs = ProviderPreferences::new().require_parameters(true);
2197
2198 let json = prefs.to_json();
2199 let provider = &json["provider"];
2200
2201 assert_eq!(provider["require_parameters"], true);
2202 }
2203
2204 #[test]
2205 fn test_provider_preferences_data_policy_and_zdr() {
2206 let prefs = ProviderPreferences::new()
2207 .data_collection(DataCollection::Deny)
2208 .zdr(true);
2209
2210 let json = prefs.to_json();
2211 let provider = &json["provider"];
2212
2213 assert_eq!(provider["data_collection"], "deny");
2214 assert_eq!(provider["zdr"], true);
2215 }
2216
2217 #[test]
2218 fn test_provider_preferences_quantizations() {
2219 let prefs =
2220 ProviderPreferences::new().quantizations([Quantization::Int8, Quantization::Fp16]);
2221
2222 let json = prefs.to_json();
2223 let provider = &json["provider"];
2224
2225 assert_eq!(provider["quantizations"], json!(["int8", "fp16"]));
2226 }
2227
2228 #[test]
2229 fn test_provider_preferences_convenience_methods() {
2230 let prefs = ProviderPreferences::new().zero_data_retention().fastest();
2231
2232 assert_eq!(prefs.zdr, Some(true));
2233 assert_eq!(
2234 prefs.sort,
2235 Some(ProviderSort::Simple(ProviderSortStrategy::Throughput))
2236 );
2237
2238 let prefs2 = ProviderPreferences::new().cheapest();
2239 assert_eq!(
2240 prefs2.sort,
2241 Some(ProviderSort::Simple(ProviderSortStrategy::Price))
2242 );
2243
2244 let prefs3 = ProviderPreferences::new().lowest_latency();
2245 assert_eq!(
2246 prefs3.sort,
2247 Some(ProviderSort::Simple(ProviderSortStrategy::Latency))
2248 );
2249 }
2250
2251 #[test]
2252 fn test_provider_preferences_serialization_skips_none() {
2253 let prefs = ProviderPreferences::new().sort(ProviderSortStrategy::Price);
2254
2255 let json = serde_json::to_value(&prefs).unwrap();
2256
2257 assert_eq!(json["sort"], "price");
2258 assert!(json.get("order").is_none());
2259 assert!(json.get("only").is_none());
2260 assert!(json.get("ignore").is_none());
2261 assert!(json.get("zdr").is_none());
2262 }
2263
2264 #[test]
2265 fn test_provider_preferences_deserialization() {
2266 let json = json!({
2267 "order": ["anthropic", "openai"],
2268 "sort": "throughput",
2269 "data_collection": "deny",
2270 "zdr": true,
2271 "quantizations": ["int8", "fp16"]
2272 });
2273
2274 let prefs: ProviderPreferences = serde_json::from_value(json).unwrap();
2275
2276 assert_eq!(
2277 prefs.order,
2278 Some(vec!["anthropic".to_string(), "openai".to_string()])
2279 );
2280 assert_eq!(
2281 prefs.sort,
2282 Some(ProviderSort::Simple(ProviderSortStrategy::Throughput))
2283 );
2284 assert_eq!(prefs.data_collection, Some(DataCollection::Deny));
2285 assert_eq!(prefs.zdr, Some(true));
2286 assert_eq!(
2287 prefs.quantizations,
2288 Some(vec![Quantization::Int8, Quantization::Fp16])
2289 );
2290 }
2291
2292 #[test]
2293 fn test_provider_preferences_deserialization_complex_sort() {
2294 let json = json!({
2295 "sort": {
2296 "by": "latency",
2297 "partition": "model"
2298 }
2299 });
2300
2301 let prefs: ProviderPreferences = serde_json::from_value(json).unwrap();
2302
2303 match prefs.sort {
2304 Some(ProviderSort::Complex(config)) => {
2305 assert_eq!(config.by, ProviderSortStrategy::Latency);
2306 assert_eq!(config.partition, Some(SortPartition::Model));
2307 }
2308 _ => panic!("Expected Complex sort variant"),
2309 }
2310 }
2311
2312 #[test]
2313 fn test_provider_preferences_full_integration() {
2314 let prefs = ProviderPreferences::new()
2315 .order(["anthropic", "openai"])
2316 .only(["anthropic", "openai", "google"])
2317 .sort(ProviderSortStrategy::Throughput)
2318 .data_collection(DataCollection::Deny)
2319 .zdr(true)
2320 .quantizations([Quantization::Int8])
2321 .allow_fallbacks(false);
2322
2323 let json = prefs.to_json();
2324
2325 assert!(json.get("provider").is_some());
2326 let provider = &json["provider"];
2327 assert_eq!(provider["order"], json!(["anthropic", "openai"]));
2328 assert_eq!(provider["only"], json!(["anthropic", "openai", "google"]));
2329 assert_eq!(provider["sort"], "throughput");
2330 assert_eq!(provider["data_collection"], "deny");
2331 assert_eq!(provider["zdr"], true);
2332 assert_eq!(provider["quantizations"], json!(["int8"]));
2333 assert_eq!(provider["allow_fallbacks"], false);
2334 }
2335
2336 #[test]
2337 fn test_provider_preferences_max_price() {
2338 let prefs =
2339 ProviderPreferences::new().max_price(MaxPrice::new().prompt(0.001).completion(0.002));
2340
2341 let json = prefs.to_json();
2342 let provider = &json["provider"];
2343
2344 assert_eq!(provider["max_price"]["prompt"], 0.001);
2345 assert_eq!(provider["max_price"]["completion"], 0.002);
2346 }
2347
2348 #[test]
2349 fn test_provider_preferences_preferred_max_latency() {
2350 let prefs = ProviderPreferences::new().preferred_max_latency(LatencyThreshold::Simple(0.5));
2351
2352 let json = prefs.to_json();
2353 let provider = &json["provider"];
2354
2355 assert_eq!(provider["preferred_max_latency"], 0.5);
2356 }
2357
2358 #[test]
2359 fn test_provider_preferences_empty_arrays() {
2360 let prefs = ProviderPreferences::new()
2361 .order(Vec::<String>::new())
2362 .quantizations(Vec::<Quantization>::new());
2363
2364 let json = prefs.to_json();
2365 let provider = &json["provider"];
2366
2367 assert_eq!(provider["order"], json!([]));
2368 assert_eq!(provider["quantizations"], json!([]));
2369 }
2370
2371 #[test]
2376 fn test_user_content_text_serialization() {
2377 let content = UserContent::text("Hello, world!");
2378 let json = serde_json::to_value(&content).unwrap();
2379
2380 assert_eq!(json["type"], "text");
2381 assert_eq!(json["text"], "Hello, world!");
2382 }
2383
2384 #[test]
2385 fn test_user_content_image_url_serialization() {
2386 let content = UserContent::image_url("https://example.com/image.png");
2387 let json = serde_json::to_value(&content).unwrap();
2388
2389 assert_eq!(json["type"], "image_url");
2390 assert_eq!(json["image_url"]["url"], "https://example.com/image.png");
2391 assert!(json["image_url"].get("detail").is_none());
2392 }
2393
2394 #[test]
2395 fn test_user_content_image_url_with_detail_serialization() {
2396 let content =
2397 UserContent::image_url_with_detail("https://example.com/image.png", ImageDetail::High);
2398 let json = serde_json::to_value(&content).unwrap();
2399
2400 assert_eq!(json["type"], "image_url");
2401 assert_eq!(json["image_url"]["url"], "https://example.com/image.png");
2402 assert_eq!(json["image_url"]["detail"], "high");
2403 }
2404
2405 #[test]
2406 fn test_user_content_image_base64_serialization() {
2407 let content = UserContent::image_base64("SGVsbG8=", "image/png", Some(ImageDetail::Low));
2408 let json = serde_json::to_value(&content).unwrap();
2409
2410 assert_eq!(json["type"], "image_url");
2411 assert_eq!(json["image_url"]["url"], "data:image/png;base64,SGVsbG8=");
2412 assert_eq!(json["image_url"]["detail"], "low");
2413 }
2414
2415 #[test]
2416 fn test_user_content_file_url_serialization() {
2417 let content = UserContent::file_url(
2418 "https://example.com/doc.pdf",
2419 Some("document.pdf".to_string()),
2420 );
2421 let json = serde_json::to_value(&content).unwrap();
2422
2423 assert_eq!(json["type"], "file");
2424 assert_eq!(json["file"]["file_data"], "https://example.com/doc.pdf");
2425 assert_eq!(json["file"]["filename"], "document.pdf");
2426 }
2427
2428 #[test]
2429 fn test_user_content_file_base64_serialization() {
2430 let content = UserContent::file_base64(
2431 "JVBERi0xLjQ=",
2432 "application/pdf",
2433 Some("report.pdf".to_string()),
2434 );
2435 let json = serde_json::to_value(&content).unwrap();
2436
2437 assert_eq!(json["type"], "file");
2438 assert_eq!(
2439 json["file"]["file_data"],
2440 "data:application/pdf;base64,JVBERi0xLjQ="
2441 );
2442 assert_eq!(json["file"]["filename"], "report.pdf");
2443 }
2444
2445 #[test]
2446 fn test_user_content_text_deserialization() {
2447 let json = json!({
2448 "type": "text",
2449 "text": "Hello!"
2450 });
2451
2452 let content: UserContent = serde_json::from_value(json).unwrap();
2453 assert_eq!(
2454 content,
2455 UserContent::Text {
2456 text: "Hello!".to_string()
2457 }
2458 );
2459 }
2460
2461 #[test]
2462 fn test_user_content_image_url_deserialization() {
2463 let json = json!({
2464 "type": "image_url",
2465 "image_url": {
2466 "url": "https://example.com/img.jpg",
2467 "detail": "high"
2468 }
2469 });
2470
2471 let content: UserContent = serde_json::from_value(json).unwrap();
2472 match content {
2473 UserContent::ImageUrl { image_url } => {
2474 assert_eq!(image_url.url, "https://example.com/img.jpg");
2475 assert_eq!(image_url.detail, Some(ImageDetail::High));
2476 }
2477 _ => panic!("Expected ImageUrl variant"),
2478 }
2479 }
2480
2481 #[test]
2482 fn test_user_content_file_deserialization() {
2483 let json = json!({
2484 "type": "file",
2485 "file": {
2486 "filename": "doc.pdf",
2487 "file_data": "https://example.com/doc.pdf"
2488 }
2489 });
2490
2491 let content: UserContent = serde_json::from_value(json).unwrap();
2492 match content {
2493 UserContent::File { file } => {
2494 assert_eq!(file.filename, Some("doc.pdf".to_string()));
2495 assert_eq!(
2496 file.file_data,
2497 Some("https://example.com/doc.pdf".to_string())
2498 );
2499 }
2500 _ => panic!("Expected File variant"),
2501 }
2502 }
2503
2504 #[test]
2505 fn test_message_user_with_text_serialization() {
2506 let message = Message::User {
2507 content: OneOrMany::one(UserContent::text("Hello")),
2508 name: None,
2509 };
2510 let json = serde_json::to_value(&message).unwrap();
2511
2512 assert_eq!(json["role"], "user");
2514 assert_eq!(json["content"], "Hello");
2515 }
2516
2517 #[test]
2518 fn test_message_user_with_mixed_content_serialization() {
2519 let message = Message::User {
2520 content: OneOrMany::many(vec![
2521 UserContent::text("Check this image:"),
2522 UserContent::image_url("https://example.com/img.png"),
2523 ])
2524 .unwrap(),
2525 name: None,
2526 };
2527 let json = serde_json::to_value(&message).unwrap();
2528
2529 assert_eq!(json["role"], "user");
2530 let content = json["content"].as_array().unwrap();
2531 assert_eq!(content.len(), 2);
2532 assert_eq!(content[0]["type"], "text");
2533 assert_eq!(content[1]["type"], "image_url");
2534 }
2535
2536 #[test]
2537 fn test_message_user_with_file_serialization() {
2538 let message = Message::User {
2539 content: OneOrMany::many(vec![
2540 UserContent::text("Analyze this PDF:"),
2541 UserContent::file_url(
2542 "https://example.com/doc.pdf",
2543 Some("document.pdf".to_string()),
2544 ),
2545 ])
2546 .unwrap(),
2547 name: None,
2548 };
2549 let json = serde_json::to_value(&message).unwrap();
2550
2551 assert_eq!(json["role"], "user");
2552 let content = json["content"].as_array().unwrap();
2553 assert_eq!(content.len(), 2);
2554 assert_eq!(content[0]["type"], "text");
2555 assert_eq!(content[1]["type"], "file");
2556 assert_eq!(
2557 content[1]["file"]["file_data"],
2558 "https://example.com/doc.pdf"
2559 );
2560 }
2561
2562 #[test]
2563 fn test_user_content_from_rig_text() {
2564 let rig_content = message::UserContent::Text(message::Text {
2565 text: "Hello".to_string(),
2566 });
2567 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2568
2569 assert_eq!(
2570 openrouter_content,
2571 UserContent::Text {
2572 text: "Hello".to_string()
2573 }
2574 );
2575 }
2576
2577 #[test]
2578 fn test_user_content_from_rig_image_url() {
2579 let rig_content = message::UserContent::Image(message::Image {
2580 data: DocumentSourceKind::Url("https://example.com/img.png".to_string()),
2581 media_type: Some(message::ImageMediaType::PNG),
2582 detail: Some(ImageDetail::High),
2583 additional_params: None,
2584 });
2585 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2586
2587 match openrouter_content {
2588 UserContent::ImageUrl { image_url } => {
2589 assert_eq!(image_url.url, "https://example.com/img.png");
2590 assert_eq!(image_url.detail, Some(ImageDetail::High));
2591 }
2592 _ => panic!("Expected ImageUrl variant"),
2593 }
2594 }
2595
2596 #[test]
2597 fn test_user_content_from_rig_image_base64() {
2598 let rig_content = message::UserContent::Image(message::Image {
2599 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2600 media_type: Some(message::ImageMediaType::JPEG),
2601 detail: Some(ImageDetail::Low),
2602 additional_params: None,
2603 });
2604 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2605
2606 match openrouter_content {
2607 UserContent::ImageUrl { image_url } => {
2608 assert_eq!(image_url.url, "data:image/jpeg;base64,SGVsbG8=");
2609 assert_eq!(image_url.detail, Some(ImageDetail::Low));
2610 }
2611 _ => panic!("Expected ImageUrl variant"),
2612 }
2613 }
2614
/// A PDF referenced by URL converts to `File`, carrying the URL as `file_data` and
/// the filename `document.pdf`.
#[test]
fn test_user_content_from_rig_document_url() {
    let document = message::Document {
        data: DocumentSourceKind::Url("https://example.com/doc.pdf".to_string()),
        media_type: Some(DocumentMediaType::PDF),
        additional_params: None,
    };

    let converted: UserContent = message::UserContent::Document(document)
        .try_into()
        .unwrap();
    let UserContent::File { file } = converted else {
        panic!("Expected File variant");
    };

    assert_eq!(
        file.file_data,
        Some("https://example.com/doc.pdf".to_string())
    );
    assert_eq!(file.filename, Some("document.pdf".to_string()));
}
2635
/// A base64 PDF converts to `File` whose `file_data` is a `data:application/pdf;base64,`
/// URI and whose filename is `document.pdf`.
#[test]
fn test_user_content_from_rig_document_base64() {
    let document = message::Document {
        data: DocumentSourceKind::Base64("JVBERi0xLjQ=".to_string()),
        media_type: Some(DocumentMediaType::PDF),
        additional_params: None,
    };

    let converted: UserContent = message::UserContent::Document(document)
        .try_into()
        .unwrap();
    let UserContent::File { file } = converted else {
        panic!("Expected File variant");
    };

    assert_eq!(
        file.file_data,
        Some("data:application/pdf;base64,JVBERi0xLjQ=".to_string())
    );
    assert_eq!(file.filename, Some("document.pdf".to_string()));
}
2656
/// A document supplied as an inline string is converted to plain `Text` content.
#[test]
fn test_user_content_from_rig_document_string_becomes_text() {
    let source = message::UserContent::Document(message::Document {
        data: DocumentSourceKind::String(String::from("Plain text document content")),
        media_type: Some(DocumentMediaType::TXT),
        additional_params: None,
    });
    let expected = UserContent::Text {
        text: String::from("Plain text document content"),
    };

    let converted: UserContent = source.try_into().unwrap();

    assert_eq!(converted, expected);
}
2673
/// Three `reasoning_details` entries sharing the id `rs_1` (summary, text, encrypted)
/// surface as a reasoning item with that id and three content parts.
#[test]
fn test_completion_response_with_reasoning_details_maps_to_typed_reasoning() {
    let payload = json!({
        "id": "resp_123",
        "object": "chat.completion",
        "created": 1,
        "model": "openrouter/test-model",
        "choices": [{
            "index": 0,
            "finish_reason": "stop",
            "message": {
                "role": "assistant",
                "content": "hello",
                "reasoning": null,
                "reasoning_details": [
                    {"type":"reasoning.summary","id":"rs_1","summary":"s1"},
                    {"type":"reasoning.text","id":"rs_1","text":"t1","signature":"sig_1"},
                    {"type":"reasoning.encrypted","id":"rs_1","data":"enc_1"}
                ]
            }
        }]
    });

    let raw: CompletionResponse = serde_json::from_value(payload).unwrap();
    let generic: completion::CompletionResponse<CompletionResponse> = raw.try_into().unwrap();
    let items: Vec<completion::AssistantContent> = generic.choice.into_iter().collect();

    // At least one item must be the reasoning block built from the three details.
    let has_typed_reasoning = items.iter().any(|item| match item {
        completion::AssistantContent::Reasoning(message::Reasoning {
            id: Some(id),
            content,
        }) => id == "rs_1" && content.len() == 3,
        _ => false,
    });
    assert!(has_typed_reasoning);
}
2708
/// A rig reasoning block with text, summary and encrypted parts becomes an assistant
/// message with no legacy `reasoning` value and three `reasoning_details`; the first
/// detail is the text part with the block id and signature.
#[test]
fn test_assistant_reasoning_emits_openrouter_reasoning_details() {
    let source = message::Reasoning {
        id: Some("rs_2".to_string()),
        content: vec![
            message::ReasoningContent::Text {
                text: "step".to_string(),
                signature: Some("sig_step".to_string()),
            },
            message::ReasoningContent::Summary("summary".to_string()),
            message::ReasoningContent::Encrypted("enc_blob".to_string()),
        ],
    };
    let assistant_content = OneOrMany::one(message::AssistantContent::Reasoning(source));

    let messages = Vec::<Message>::try_from(assistant_content).unwrap();
    let (legacy_reasoning, details) = match messages.first().expect("assistant message") {
        Message::Assistant {
            reasoning,
            reasoning_details,
            ..
        } => (reasoning, reasoning_details),
        _ => panic!("Expected assistant message"),
    };

    assert!(legacy_reasoning.is_none());
    assert_eq!(details.len(), 3);
    assert!(matches!(
        details.first(),
        Some(ReasoningDetails::Text {
            id: Some(id),
            text: Some(text),
            signature: Some(signature),
            ..
        }) if id == "rs_2" && text == "step" && signature == "sig_step"
    ));
}
2748
/// Redacted reasoning content is emitted as one `Encrypted` reasoning detail that
/// carries the block id and the opaque data; the legacy `reasoning` value stays unset.
#[test]
fn test_assistant_redacted_reasoning_emits_encrypted_detail_not_text() {
    let source = message::Reasoning {
        id: Some("rs_redacted".to_string()),
        content: vec![message::ReasoningContent::Redacted {
            data: "opaque-redacted-data".to_string(),
        }],
    };
    let assistant_content = OneOrMany::one(message::AssistantContent::Reasoning(source));

    let messages = Vec::<Message>::try_from(assistant_content).unwrap();
    let (legacy_reasoning, details) = match messages.first().expect("assistant message") {
        Message::Assistant {
            reasoning_details,
            reasoning,
            ..
        } => (reasoning, reasoning_details),
        _ => panic!("Expected assistant message"),
    };

    assert!(legacy_reasoning.is_none());
    assert_eq!(details.len(), 1);
    assert!(matches!(
        details.first(),
        Some(ReasoningDetails::Encrypted {
            id: Some(id),
            data,
            ..
        }) if id == "rs_redacted" && data == "opaque-redacted-data"
    ));
}
2783
/// Two summaries with the same id that arrive as index 1 then index 0 are merged into
/// one reasoning block whose content is ordered "first", "second".
#[test]
fn test_completion_response_reasoning_details_respects_index_ordering() {
    let payload = json!({
        "id": "resp_ordering",
        "object": "chat.completion",
        "created": 1,
        "model": "openrouter/test-model",
        "choices": [{
            "index": 0,
            "finish_reason": "stop",
            "message": {
                "role": "assistant",
                "content": "hello",
                "reasoning": null,
                "reasoning_details": [
                    {"type":"reasoning.summary","id":"rs_order","index":1,"summary":"second"},
                    {"type":"reasoning.summary","id":"rs_order","index":0,"summary":"first"}
                ]
            }
        }]
    });

    let raw: CompletionResponse = serde_json::from_value(payload).unwrap();
    let generic: completion::CompletionResponse<CompletionResponse> = raw.try_into().unwrap();
    let items: Vec<completion::AssistantContent> = generic.choice.into_iter().collect();

    // Keep only the reasoning items, preserving their order.
    let mut reasoning_blocks = Vec::new();
    for item in items {
        if let completion::AssistantContent::Reasoning(block) = item {
            reasoning_blocks.push(block);
        }
    }

    assert_eq!(reasoning_blocks.len(), 1);
    assert_eq!(reasoning_blocks[0].id.as_deref(), Some("rs_order"));
    assert_eq!(
        reasoning_blocks[0].content,
        vec![
            message::ReasoningContent::Summary("first".to_string()),
            message::ReasoningContent::Summary("second".to_string()),
        ]
    );
}
2828
/// A base64 image without a media type is rejected with an error mentioning
/// "media type required".
#[test]
fn test_user_content_from_rig_image_missing_media_type_error() {
    let source = message::UserContent::Image(message::Image {
        data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
        media_type: None,
        detail: None,
        additional_params: None,
    });

    let outcome: Result<UserContent, _> = source.try_into();

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("media type required"));
}
2843
/// An image supplied as raw bytes is rejected with an error mentioning "base64".
#[test]
fn test_user_content_from_rig_image_raw_bytes_error() {
    let source = message::UserContent::Image(message::Image {
        data: DocumentSourceKind::Raw(vec![1, 2, 3]),
        media_type: Some(message::ImageMediaType::PNG),
        detail: None,
        additional_params: None,
    });

    let outcome: Result<UserContent, _> = source.try_into();

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("base64"));
}
2858
/// A URL-backed video converts to `VideoUrl` with the URL unchanged.
#[test]
fn test_user_content_from_rig_video_url() {
    let video = message::Video {
        data: DocumentSourceKind::Url("https://example.com/video.mp4".to_string()),
        media_type: Some(message::VideoMediaType::MP4),
        additional_params: None,
    };

    let converted: UserContent = message::UserContent::Video(video).try_into().unwrap();
    let UserContent::VideoUrl { video_url } = converted else {
        panic!("Expected VideoUrl variant");
    };

    assert_eq!(video_url.url, "https://example.com/video.mp4");
}
2875
/// A base64 MP4 video converts to `VideoUrl` whose URL is a `data:video/mp4;base64,` URI.
#[test]
fn test_user_content_from_rig_video_base64() {
    let video = message::Video {
        data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
        media_type: Some(message::VideoMediaType::MP4),
        additional_params: None,
    };

    let converted: UserContent = message::UserContent::Video(video).try_into().unwrap();
    let UserContent::VideoUrl { video_url } = converted else {
        panic!("Expected VideoUrl variant");
    };

    assert_eq!(video_url.url, "data:video/mp4;base64,SGVsbG8=");
}
2892
/// A base64 video without a media type is rejected with an error mentioning "media type".
#[test]
fn test_user_content_from_rig_video_base64_missing_media_type_error() {
    let source = message::UserContent::Video(message::Video {
        data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
        media_type: None,
        additional_params: None,
    });

    let outcome: Result<UserContent, _> = source.try_into();

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("media type"));
}
2906
/// A video supplied as raw bytes is rejected with an error mentioning "base64".
#[test]
fn test_user_content_from_rig_video_raw_bytes_error() {
    let source = message::UserContent::Video(message::Video {
        data: DocumentSourceKind::Raw(vec![1, 2, 3]),
        media_type: Some(message::VideoMediaType::MP4),
        additional_params: None,
    });

    let outcome: Result<UserContent, _> = source.try_into();

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("base64"));
}
2920
/// Base64 MP3 audio converts to `InputAudio`, keeping the payload and the format.
#[test]
fn test_user_content_from_rig_audio_base64() {
    let audio = message::Audio {
        data: DocumentSourceKind::Base64("audiodata".to_string()),
        media_type: Some(message::AudioMediaType::MP3),
        additional_params: None,
    };

    let converted: UserContent = message::UserContent::Audio(audio).try_into().unwrap();
    let UserContent::InputAudio { input_audio } = converted else {
        panic!("Expected InputAudio variant");
    };

    assert_eq!(input_audio.data, "audiodata");
    assert_eq!(input_audio.format, message::AudioMediaType::MP3);
}
2938
/// Base64 audio without a media type is rejected with an error mentioning
/// "media type required".
#[test]
fn test_user_content_from_rig_audio_missing_media_type_error() {
    let source = message::UserContent::Audio(message::Audio {
        data: DocumentSourceKind::Base64("audiodata".to_string()),
        media_type: None,
        additional_params: None,
    });

    let outcome: Result<UserContent, _> = source.try_into();

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("media type required"));
}
2952
/// Audio referenced by URL is rejected with an error mentioning "base64".
#[test]
fn test_user_content_from_rig_audio_url_error() {
    let source = message::UserContent::Audio(message::Audio {
        data: DocumentSourceKind::Url("https://example.com/audio.wav".to_string()),
        media_type: Some(message::AudioMediaType::WAV),
        additional_params: None,
    });

    let outcome: Result<UserContent, _> = source.try_into();

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("base64"));
}
2966
/// Audio supplied as raw bytes is rejected with an error mentioning "base64".
#[test]
fn test_user_content_from_rig_audio_raw_bytes_error() {
    let source = message::UserContent::Audio(message::Audio {
        data: DocumentSourceKind::Raw(vec![1, 2, 3]),
        media_type: Some(message::AudioMediaType::WAV),
        additional_params: None,
    });

    let outcome: Result<UserContent, _> = source.try_into();

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("base64"));
}
2980
/// A rig user message holding a text part and a PDF URL converts to exactly one
/// OpenRouter user message whose two parts are, in order, the text and the file.
#[test]
fn test_message_conversion_with_pdf() {
    let rig_message = message::Message::User {
        content: OneOrMany::many(vec![
            message::UserContent::Text(message::Text {
                text: "Summarize this document".to_string(),
            }),
            message::UserContent::Document(message::Document {
                data: DocumentSourceKind::Url("https://example.com/paper.pdf".to_string()),
                media_type: Some(DocumentMediaType::PDF),
                additional_params: None,
            }),
        ])
        .unwrap(),
    };

    let openrouter_messages: Vec<Message> = rig_message.try_into().unwrap();
    assert_eq!(openrouter_messages.len(), 1);

    match &openrouter_messages[0] {
        Message::User { content, .. } => {
            assert_eq!(content.len(), 2);

            // First part should be the text prompt.
            match content.first_ref() {
                UserContent::Text { text } => assert_eq!(text, "Summarize this document"),
                _ => panic!("Expected Text"),
            }

            // Second part should be the PDF itself; without this check the test
            // would pass even if the document were converted to the wrong variant.
            match content.iter().nth(1) {
                Some(UserContent::File { file }) => assert_eq!(
                    file.file_data,
                    Some("https://example.com/paper.pdf".to_string())
                ),
                _ => panic!("Expected File"),
            }
        }
        _ => panic!("Expected User message"),
    }
}
3013
/// Both `&str` and `String` convert into `UserContent::Text` holding the same text.
#[test]
fn test_user_content_from_string() {
    // Borrowed string slice.
    let from_str: UserContent = "Hello".into();
    assert_eq!(
        from_str,
        UserContent::Text {
            text: String::from("Hello"),
        }
    );

    // Owned string.
    let from_owned: UserContent = String::from("World").into();
    assert_eq!(
        from_owned,
        UserContent::Text {
            text: String::from("World"),
        }
    );
}
3032
/// OpenAI-style user content (text, image, audio) converts into the matching
/// OpenRouter variants with field values carried over.
#[test]
fn test_openai_user_content_conversion() {
    // Text maps to `Text`.
    let text_part: UserContent = openai::UserContent::Text {
        text: "Hello".to_string(),
    }
    .into();
    assert_eq!(
        text_part,
        UserContent::Text {
            text: "Hello".to_string()
        }
    );

    // Image maps to `ImageUrl`; the detail level comes back wrapped in `Some`.
    let image_part: UserContent = openai::UserContent::Image {
        image_url: openai::ImageUrl {
            url: "https://example.com/img.png".to_string(),
            detail: ImageDetail::Auto,
        },
    }
    .into();
    let UserContent::ImageUrl { image_url } = image_part else {
        panic!("Expected ImageUrl");
    };
    assert_eq!(image_url.url, "https://example.com/img.png");
    assert_eq!(image_url.detail, Some(ImageDetail::Auto));

    // Audio maps to `InputAudio`.
    let audio_part: UserContent = openai::UserContent::Audio {
        input_audio: openai::InputAudio {
            data: "audiodata".to_string(),
            format: AudioMediaType::FLAC,
        },
    }
    .into();
    let UserContent::InputAudio { input_audio } = audio_part else {
        panic!("Expected InputAudio");
    };
    assert_eq!(input_audio.data, "audiodata");
    assert_eq!(input_audio.format, AudioMediaType::FLAC);
}
3077
/// Details with ids `rs_a`, `rs_b`, `rs_a` yield two reasoning blocks: `rs_a` with
/// both of its summaries, followed by `rs_b` with its single summary.
#[test]
fn test_completion_response_reasoning_details_with_multiple_ids_stay_separate() {
    let payload = json!({
        "id": "resp_multi_id",
        "object": "chat.completion",
        "created": 1,
        "model": "openrouter/test-model",
        "choices": [{
            "index": 0,
            "finish_reason": "stop",
            "message": {
                "role": "assistant",
                "content": "hello",
                "reasoning": null,
                "reasoning_details": [
                    {"type":"reasoning.summary","id":"rs_a","summary":"a1"},
                    {"type":"reasoning.summary","id":"rs_b","summary":"b1"},
                    {"type":"reasoning.summary","id":"rs_a","summary":"a2"}
                ]
            }
        }]
    });

    let raw: CompletionResponse = serde_json::from_value(payload).unwrap();
    let generic: completion::CompletionResponse<CompletionResponse> = raw.try_into().unwrap();
    let items: Vec<completion::AssistantContent> = generic.choice.into_iter().collect();

    // Keep only the reasoning items, preserving their order.
    let mut reasoning_blocks = Vec::new();
    for item in items {
        if let completion::AssistantContent::Reasoning(block) = item {
            reasoning_blocks.push(block);
        }
    }

    assert_eq!(reasoning_blocks.len(), 2);

    assert_eq!(reasoning_blocks[0].id.as_deref(), Some("rs_a"));
    assert_eq!(
        reasoning_blocks[0].content,
        vec![
            message::ReasoningContent::Summary("a1".to_string()),
            message::ReasoningContent::Summary("a2".to_string()),
        ]
    );

    assert_eq!(reasoning_blocks[1].id.as_deref(), Some("rs_b"));
    assert_eq!(
        reasoning_blocks[1].content,
        vec![message::ReasoningContent::Summary("b1".to_string())]
    );
}
3128
/// Audio content serializes with type `input_audio` and a nested `input_audio` object
/// holding `data` and the lowercase `format`.
#[test]
fn test_user_content_audio_serialization() {
    let audio = UserContent::audio_base64("SGVsbG8=", AudioMediaType::WAV);

    let serialized = serde_json::to_value(&audio).unwrap();
    let body = &serialized["input_audio"];

    assert_eq!(serialized["type"], "input_audio");
    assert_eq!(body["data"], "SGVsbG8=");
    assert_eq!(body["format"], "wav");
}
3138
/// An `input_audio` JSON object deserializes into `InputAudio` with the given data
/// and the WAV format.
#[test]
fn test_user_content_audio_deserialization() {
    let payload = json!({
        "type": "input_audio",
        "input_audio": {
            "data": "SGVsbG8=",
            "format": "wav"
        }
    });

    let parsed: UserContent = serde_json::from_value(payload).unwrap();
    let UserContent::InputAudio { input_audio } = parsed else {
        panic!("Expected InputAudio variant");
    };

    assert_eq!(input_audio.data, "SGVsbG8=");
    assert_eq!(input_audio.format, AudioMediaType::WAV);
}
3158
/// A user message with a text part and an MP3 audio part serializes to role `user`
/// and a two-element `content` array: `text` followed by `input_audio`.
#[test]
fn test_message_user_with_audio_serialization() {
    let parts = vec![
        UserContent::text("Transcribe this audio:"),
        UserContent::audio_base64("SGVsbG8=", AudioMediaType::MP3),
    ];
    let user_message = Message::User {
        content: OneOrMany::many(parts).unwrap(),
        name: None,
    };

    let serialized = serde_json::to_value(&user_message).unwrap();

    assert_eq!(serialized["role"], "user");
    let content_array = serialized["content"].as_array().unwrap();
    assert_eq!(content_array.len(), 2);
    assert_eq!(content_array[0]["type"], "text");
    assert_eq!(content_array[1]["type"], "input_audio");
    assert_eq!(content_array[1]["input_audio"]["data"], "SGVsbG8=");
    assert_eq!(content_array[1]["input_audio"]["format"], "mp3");
}
3179
/// Video URL content serializes with type `video_url` and the URL nested under
/// `video_url.url`.
#[test]
fn test_user_content_video_url_serialization() {
    let video = UserContent::video_url("https://example.com/video.mp4");

    let serialized = serde_json::to_value(&video).unwrap();

    assert_eq!(serialized["type"], "video_url");
    assert_eq!(
        serialized["video_url"]["url"],
        "https://example.com/video.mp4"
    );
}
3188
/// Base64 MP4 video content serializes with type `video_url` and a
/// `data:video/mp4;base64,` URI as the URL.
#[test]
fn test_user_content_video_base64_serialization() {
    let video = UserContent::video_base64("SGVsbG8=", VideoMediaType::MP4);

    let serialized = serde_json::to_value(&video).unwrap();

    assert_eq!(serialized["type"], "video_url");
    assert_eq!(
        serialized["video_url"]["url"],
        "data:video/mp4;base64,SGVsbG8="
    );
}
3197
/// A `video_url` JSON object deserializes into `VideoUrl` with the given URL.
#[test]
fn test_user_content_video_url_deserialization() {
    let payload = json!({
        "type": "video_url",
        "video_url": {
            "url": "https://example.com/video.mp4"
        }
    });

    let parsed: UserContent = serde_json::from_value(payload).unwrap();
    let UserContent::VideoUrl { video_url } = parsed else {
        panic!("Expected VideoUrl variant");
    };

    assert_eq!(video_url.url, "https://example.com/video.mp4");
}
3215
/// A user message with a text part and a video URL part serializes to role `user`
/// and a two-element `content` array: `text` followed by `video_url`.
#[test]
fn test_message_user_with_video_serialization() {
    let parts = vec![
        UserContent::text("Describe this video:"),
        UserContent::video_url("https://example.com/video.mp4"),
    ];
    let user_message = Message::User {
        content: OneOrMany::many(parts).unwrap(),
        name: None,
    };

    let serialized = serde_json::to_value(&user_message).unwrap();

    assert_eq!(serialized["role"], "user");
    let content_array = serialized["content"].as_array().unwrap();
    assert_eq!(content_array.len(), 2);
    assert_eq!(content_array[0]["type"], "text");
    assert_eq!(content_array[1]["type"], "video_url");
    assert_eq!(
        content_array[1]["video_url"]["url"],
        "https://example.com/video.mp4"
    );
}
3238
/// A URL-backed video converts to `VideoUrl` even when no media type is supplied.
#[test]
fn test_user_content_video_url_no_media_type_needed() {
    let video = message::Video {
        data: DocumentSourceKind::Url("https://example.com/video.mp4".to_string()),
        media_type: None,
        additional_params: None,
    };

    let converted: UserContent = message::UserContent::Video(video).try_into().unwrap();
    let UserContent::VideoUrl { video_url } = converted else {
        panic!("Expected VideoUrl variant");
    };

    assert_eq!(video_url.url, "https://example.com/video.mp4");
}
3255}