1use super::{
2 client::{ApiErrorResponse, ApiResponse, Client, Usage},
3 streaming::StreamingCompletionResponse,
4};
5use crate::message::{
6 self, AudioMediaType, DocumentMediaType, DocumentSourceKind, ImageDetail, MimeType,
7 VideoMediaType,
8};
9use crate::telemetry::SpanCombinator;
10use crate::{
11 OneOrMany,
12 completion::{self, CompletionError, CompletionRequest},
13 http_client::HttpClientExt,
14 json_utils,
15 one_or_many::string_or_one_or_many,
16 providers::openai,
17};
18use bytes::Bytes;
19use serde::{Deserialize, Serialize, Serializer};
20use std::collections::HashMap;
21use tracing::{Instrument, Level, enabled, info_span};
22
/// OpenRouter model identifier `qwen/qwq-32b`.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// OpenRouter model identifier `anthropic/claude-3.7-sonnet`.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// OpenRouter model identifier `perplexity/sonar-pro`.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// OpenRouter model identifier `google/gemini-2.0-flash-001`.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
35
/// Data-collection policy placed in [`ProviderPreferences::data_collection`].
///
/// Serialized in lowercase: `"allow"` or `"deny"`. [`DataCollection::Allow`] is the `Default`.
/// The exact routing effect is defined by OpenRouter's provider-routing API (presumably
/// `deny` excludes providers that store/train on request data — confirm against its docs).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum DataCollection {
    /// Serialized as `"allow"`; the default.
    #[default]
    Allow,
    /// Serialized as `"deny"`.
    Deny,
}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
58#[serde(rename_all = "lowercase")]
59pub enum Quantization {
60 #[serde(rename = "int4")]
62 Int4,
63 #[serde(rename = "int8")]
65 Int8,
66 #[serde(rename = "fp16")]
68 Fp16,
69 #[serde(rename = "bf16")]
71 Bf16,
72 #[serde(rename = "fp32")]
74 Fp32,
75 #[serde(rename = "fp8")]
77 Fp8,
78 #[serde(rename = "unknown")]
80 Unknown,
81}
82
/// Attribute OpenRouter should sort candidate providers by.
///
/// Serialized in lowercase (`"price"`, `"throughput"`, `"latency"`). It can be used on its
/// own via [`ProviderSort::Simple`] or inside a [`ProviderSortConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderSortStrategy {
    /// Sort by price (`"price"`).
    Price,
    /// Sort by throughput (`"throughput"`).
    Throughput,
    /// Sort by latency (`"latency"`).
    Latency,
}
98
/// Partitioning mode for a [`ProviderSortConfig`], serialized as `"model"` or `"none"`.
///
/// NOTE(review): presumably `model` sorts providers within each model of a fallback list and
/// `none` sorts across all of them — confirm against OpenRouter's provider-routing docs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SortPartition {
    /// Serialized as `"model"`.
    Model,
    /// Serialized as `"none"` (this is an enum variant, not `Option::None`).
    None,
}
108
/// Object form of the provider `sort` option: a strategy plus an optional partition.
///
/// Serializes as `{"by": ..., "partition": ...}`; `partition` is omitted when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderSortConfig {
    /// Attribute to sort providers by.
    pub by: ProviderSortStrategy,

    /// Optional partitioning mode; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub partition: Option<SortPartition>,
}
121
122impl ProviderSortConfig {
123 pub fn new(by: ProviderSortStrategy) -> Self {
125 Self {
126 by,
127 partition: None,
128 }
129 }
130
131 pub fn partition(mut self, partition: SortPartition) -> Self {
133 self.partition = Some(partition);
134 self
135 }
136}
137
/// Value of the provider `sort` option: either a bare strategy or a full configuration.
///
/// `#[serde(untagged)]`: `Simple` serializes as the bare strategy string (e.g. `"price"`) and
/// `Complex` as the [`ProviderSortConfig`] object. Both variants can be built with `.into()`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ProviderSort {
    /// A bare strategy, serialized as a string.
    Simple(ProviderSortStrategy),
    /// A strategy with options, serialized as an object.
    Complex(ProviderSortConfig),
}
150
151impl From<ProviderSortStrategy> for ProviderSort {
152 fn from(strategy: ProviderSortStrategy) -> Self {
153 ProviderSort::Simple(strategy)
154 }
155}
156
157impl From<ProviderSortConfig> for ProviderSort {
158 fn from(config: ProviderSortConfig) -> Self {
159 ProviderSort::Complex(config)
160 }
161}
162
/// Minimum-throughput preference: a single number or per-percentile thresholds.
///
/// `#[serde(untagged)]`: serializes as a bare number (`Simple`) or as a
/// [`PercentileThresholds`] object (`Percentile`).
/// NOTE(review): the unit is presumably tokens per second — confirm with OpenRouter's docs.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ThroughputThreshold {
    /// One threshold value, serialized as a number.
    Simple(f64),
    /// Per-percentile thresholds, serialized as an object.
    Percentile(PercentileThresholds),
}
174
/// Maximum-latency preference: a single number or per-percentile thresholds.
///
/// `#[serde(untagged)]`: serializes as a bare number (`Simple`) or as a
/// [`PercentileThresholds`] object (`Percentile`).
/// NOTE(review): the unit is presumably seconds — confirm with OpenRouter's docs.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum LatencyThreshold {
    /// One threshold value, serialized as a number.
    Simple(f64),
    /// Per-percentile thresholds, serialized as an object.
    Percentile(PercentileThresholds),
}
186
/// Per-percentile threshold values used by [`ThroughputThreshold`] and [`LatencyThreshold`].
///
/// Every field is optional and omitted from the JSON when `None`; `Default` leaves all unset.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct PercentileThresholds {
    /// Threshold for the 50th percentile (`p50`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p50: Option<f64>,
    /// Threshold for the 75th percentile (`p75`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p75: Option<f64>,
    /// Threshold for the 90th percentile (`p90`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p90: Option<f64>,
    /// Threshold for the 99th percentile (`p99`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p99: Option<f64>,
}
203
204impl PercentileThresholds {
205 pub fn new() -> Self {
207 Self::default()
208 }
209
210 pub fn p50(mut self, value: f64) -> Self {
212 self.p50 = Some(value);
213 self
214 }
215
216 pub fn p75(mut self, value: f64) -> Self {
218 self.p75 = Some(value);
219 self
220 }
221
222 pub fn p90(mut self, value: f64) -> Self {
224 self.p90 = Some(value);
225 self
226 }
227
228 pub fn p99(mut self, value: f64) -> Self {
230 self.p99 = Some(value);
231 self
232 }
233}
234
/// Upper price limits a provider must satisfy to be selected.
///
/// Every field is optional and omitted from the JSON when `None`; `Default` leaves all unset.
/// NOTE(review): units follow OpenRouter's pricing API (presumably USD per million tokens for
/// `prompt`/`completion`, per request/image otherwise) — confirm against its docs.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct MaxPrice {
    /// Maximum prompt (input) price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<f64>,
    /// Maximum completion (output) price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion: Option<f64>,
    /// Maximum per-request price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request: Option<f64>,
    /// Maximum image price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image: Option<f64>,
}
254
255impl MaxPrice {
256 pub fn new() -> Self {
258 Self::default()
259 }
260
261 pub fn prompt(mut self, price: f64) -> Self {
263 self.prompt = Some(price);
264 self
265 }
266
267 pub fn completion(mut self, price: f64) -> Self {
269 self.completion = Some(price);
270 self
271 }
272
273 pub fn request(mut self, price: f64) -> Self {
275 self.request = Some(price);
276 self
277 }
278
279 pub fn image(mut self, price: f64) -> Self {
281 self.image = Some(price);
282 self
283 }
284}
285
/// Provider-routing preferences for an OpenRouter request.
///
/// Every field is optional and omitted from the serialized JSON when `None`, so only the
/// preferences that were set are sent. Build it with [`ProviderPreferences::new`] and the
/// chained setters. [`ProviderPreferences::to_json`] wraps it as `{"provider": ...}`.
///
/// Field semantics come from OpenRouter's provider-routing API. The per-field notes below are
/// summaries of that API; confirm details against its documentation.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ProviderPreferences {
    /// Provider names in preferred order (`order`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order: Option<Vec<String>>,

    /// Providers to restrict routing to (`only`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub only: Option<Vec<String>>,

    /// Providers to exclude (`ignore`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ignore: Option<Vec<String>>,

    /// Whether other providers may be used when preferred ones are unavailable
    /// (`allow_fallbacks`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow_fallbacks: Option<bool>,

    /// Whether only providers supporting every request parameter may be used
    /// (`require_parameters`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub require_parameters: Option<bool>,

    /// Data-collection policy (`data_collection`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data_collection: Option<DataCollection>,

    /// Zero-data-retention requirement (`zdr`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zdr: Option<bool>,

    /// Provider sort order, as a bare strategy or a configuration object (`sort`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sort: Option<ProviderSort>,

    /// Preferred minimum throughput (`preferred_min_throughput`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_min_throughput: Option<ThroughputThreshold>,

    /// Preferred maximum latency (`preferred_max_latency`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_max_latency: Option<LatencyThreshold>,

    /// Price ceilings (`max_price`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_price: Option<MaxPrice>,

    /// Acceptable model quantizations (`quantizations`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantizations: Option<Vec<Quantization>>,
}
364
365impl ProviderPreferences {
366 pub fn new() -> Self {
368 Self::default()
369 }
370
371 pub fn order(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
387 self.order = Some(providers.into_iter().map(|p| p.into()).collect());
388 self
389 }
390
391 pub fn only(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
403 self.only = Some(providers.into_iter().map(|p| p.into()).collect());
404 self
405 }
406
407 pub fn ignore(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
418 self.ignore = Some(providers.into_iter().map(|p| p.into()).collect());
419 self
420 }
421
422 pub fn allow_fallbacks(mut self, allow: bool) -> Self {
427 self.allow_fallbacks = Some(allow);
428 self
429 }
430
431 pub fn require_parameters(mut self, require: bool) -> Self {
437 self.require_parameters = Some(require);
438 self
439 }
440
441 pub fn data_collection(mut self, policy: DataCollection) -> Self {
445 self.data_collection = Some(policy);
446 self
447 }
448
449 pub fn zdr(mut self, enable: bool) -> Self {
460 self.zdr = Some(enable);
461 self
462 }
463
464 pub fn sort(mut self, sort: impl Into<ProviderSort>) -> Self {
480 self.sort = Some(sort.into());
481 self
482 }
483
484 pub fn preferred_min_throughput(mut self, threshold: ThroughputThreshold) -> Self {
504 self.preferred_min_throughput = Some(threshold);
505 self
506 }
507
508 pub fn preferred_max_latency(mut self, threshold: LatencyThreshold) -> Self {
512 self.preferred_max_latency = Some(threshold);
513 self
514 }
515
516 pub fn max_price(mut self, price: MaxPrice) -> Self {
520 self.max_price = Some(price);
521 self
522 }
523
524 pub fn quantizations(mut self, quantizations: impl IntoIterator<Item = Quantization>) -> Self {
537 self.quantizations = Some(quantizations.into_iter().collect());
538 self
539 }
540
541 pub fn zero_data_retention(self) -> Self {
545 self.zdr(true)
546 }
547
548 pub fn fastest(self) -> Self {
550 self.sort(ProviderSortStrategy::Throughput)
551 }
552
553 pub fn cheapest(self) -> Self {
555 self.sort(ProviderSortStrategy::Price)
556 }
557
558 pub fn lowest_latency(self) -> Self {
560 self.sort(ProviderSortStrategy::Latency)
561 }
562
563 pub fn to_json(&self) -> serde_json::Value {
565 serde_json::json!({
566 "provider": self
567 })
568 }
569}
570
/// Chat-completion response body in OpenRouter's (OpenAI-compatible) shape.
///
/// Converted into rig's generic response by the `TryFrom` impl below. That conversion only
/// looks at the first entry of `choices`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub system_fingerprint: Option<String>,
    pub usage: Option<Usage>,
}
584
585impl From<ApiErrorResponse> for CompletionError {
586 fn from(err: ApiErrorResponse) -> Self {
587 CompletionError::ProviderError(err.message)
588 }
589}
590
591impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
592 type Error = CompletionError;
593
594 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
595 let choice = response.choices.first().ok_or_else(|| {
596 CompletionError::ResponseError("Response contained no choices".to_owned())
597 })?;
598
599 let content = match &choice.message {
600 Message::Assistant {
601 content,
602 tool_calls,
603 reasoning,
604 reasoning_details,
605 ..
606 } => {
607 let mut content = content
608 .iter()
609 .map(|c| match c {
610 openai::AssistantContent::Text { text } => {
611 completion::AssistantContent::text(text)
612 }
613 openai::AssistantContent::Refusal { refusal } => {
614 completion::AssistantContent::text(refusal)
615 }
616 })
617 .collect::<Vec<_>>();
618
619 content.extend(tool_calls.iter().map(|call| {
620 completion::AssistantContent::tool_call(
621 &call.id,
622 &call.function.name,
623 call.function.arguments.clone(),
624 )
625 }));
626
627 let mut grouped_reasoning: HashMap<
628 Option<String>,
629 Vec<(usize, usize, message::ReasoningContent)>,
630 > = HashMap::new();
631 let mut reasoning_order: Vec<Option<String>> = Vec::new();
632 for (position, detail) in reasoning_details.iter().enumerate() {
633 let (reasoning_id, sort_index, parsed_content) = match detail {
634 ReasoningDetails::Summary {
635 id, index, summary, ..
636 } => (
637 id.clone(),
638 *index,
639 Some(message::ReasoningContent::Summary(summary.clone())),
640 ),
641 ReasoningDetails::Encrypted {
642 id, index, data, ..
643 } => (
644 id.clone(),
645 *index,
646 Some(message::ReasoningContent::Encrypted(data.clone())),
647 ),
648 ReasoningDetails::Text {
649 id,
650 index,
651 text,
652 signature,
653 ..
654 } => (
655 id.clone(),
656 *index,
657 text.as_ref().map(|text| message::ReasoningContent::Text {
658 text: text.clone(),
659 signature: signature.clone(),
660 }),
661 ),
662 };
663
664 let Some(parsed_content) = parsed_content else {
665 continue;
666 };
667 let sort_index = sort_index.unwrap_or(position);
668
669 let entry = grouped_reasoning.entry(reasoning_id.clone());
670 if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
671 reasoning_order.push(reasoning_id);
672 }
673 entry
674 .or_default()
675 .push((sort_index, position, parsed_content));
676 }
677
678 if grouped_reasoning.is_empty() {
679 if let Some(reasoning) = reasoning {
680 content.push(completion::AssistantContent::reasoning(reasoning));
681 }
682 } else {
683 for reasoning_id in reasoning_order {
684 let Some(mut blocks) = grouped_reasoning.remove(&reasoning_id) else {
685 continue;
686 };
687 blocks.sort_by_key(|(index, position, _)| (*index, *position));
688 content.push(completion::AssistantContent::Reasoning(
689 message::Reasoning {
690 id: reasoning_id,
691 content: blocks
692 .into_iter()
693 .map(|(_, _, content)| content)
694 .collect::<Vec<_>>(),
695 },
696 ));
697 }
698 }
699
700 Ok(content)
701 }
702 _ => Err(CompletionError::ResponseError(
703 "Response did not contain a valid message or tool call".into(),
704 )),
705 }?;
706
707 let choice = OneOrMany::many(content).map_err(|_| {
708 CompletionError::ResponseError(
709 "Response contained no message or tool call (empty)".to_owned(),
710 )
711 })?;
712
713 let usage = response
714 .usage
715 .as_ref()
716 .map(|usage| completion::Usage {
717 input_tokens: usage.prompt_tokens as u64,
718 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
719 total_tokens: usage.total_tokens as u64,
720 cached_input_tokens: 0,
721 })
722 .unwrap_or_default();
723
724 Ok(completion::CompletionResponse {
725 choice,
726 usage,
727 raw_response: response,
728 message_id: None,
729 })
730 }
731}
732
/// One part of a user message in OpenRouter's request format.
///
/// Internally tagged by `"type"`: `"text"`, `"image_url"`, `"file"`, `"input_audio"` and
/// `"video_url"`. Each variant's payload sits under the field named by that variant.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum UserContent {
    /// Plain text.
    Text { text: String },

    /// An image given by URL (which may be a `data:` URI).
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },

    /// A file given as a URL or `data:` URI, with an optional filename.
    File { file: FileContent },

    /// Base64 audio, reusing the OpenAI provider's `InputAudio` payload.
    InputAudio { input_audio: openai::InputAudio },

    /// A video given by URL (which may be a `data:` URI).
    #[serde(rename = "video_url")]
    VideoUrl { video_url: VideoUrlContent },
}
802
803impl UserContent {
804 pub fn text(text: impl Into<String>) -> Self {
806 UserContent::Text { text: text.into() }
807 }
808
809 pub fn image_url(url: impl Into<String>) -> Self {
811 UserContent::ImageUrl {
812 image_url: ImageUrl {
813 url: url.into(),
814 detail: None,
815 },
816 }
817 }
818
819 pub fn image_url_with_detail(url: impl Into<String>, detail: ImageDetail) -> Self {
821 UserContent::ImageUrl {
822 image_url: ImageUrl {
823 url: url.into(),
824 detail: Some(detail),
825 },
826 }
827 }
828
829 pub fn image_base64(
836 data: impl Into<String>,
837 mime_type: &str,
838 detail: Option<ImageDetail>,
839 ) -> Self {
840 let data_uri = format!("data:{};base64,{}", mime_type, data.into());
841 UserContent::ImageUrl {
842 image_url: ImageUrl {
843 url: data_uri,
844 detail,
845 },
846 }
847 }
848
849 pub fn file_url(url: impl Into<String>, filename: Option<String>) -> Self {
855 UserContent::File {
856 file: FileContent {
857 filename,
858 file_data: Some(url.into()),
859 },
860 }
861 }
862
863 pub fn file_base64(data: impl Into<String>, mime_type: &str, filename: Option<String>) -> Self {
870 let data_uri = format!("data:{};base64,{}", mime_type, data.into());
871 UserContent::File {
872 file: FileContent {
873 filename,
874 file_data: Some(data_uri),
875 },
876 }
877 }
878
879 pub fn audio_base64(data: impl Into<String>, format: AudioMediaType) -> Self {
887 UserContent::InputAudio {
888 input_audio: openai::InputAudio {
889 data: data.into(),
890 format,
891 },
892 }
893 }
894
895 pub fn video_url(url: impl Into<String>) -> Self {
902 UserContent::VideoUrl {
903 video_url: VideoUrlContent { url: url.into() },
904 }
905 }
906
907 pub fn video_base64(data: impl Into<String>, media_type: VideoMediaType) -> Self {
913 let mime = media_type.to_mime_type();
914 let data_uri = format!("data:{mime};base64,{}", data.into());
915 UserContent::VideoUrl {
916 video_url: VideoUrlContent { url: data_uri },
917 }
918 }
919}
920
921impl From<String> for UserContent {
922 fn from(text: String) -> Self {
923 UserContent::Text { text }
924 }
925}
926
927impl From<&str> for UserContent {
928 fn from(text: &str) -> Self {
929 UserContent::Text {
930 text: text.to_string(),
931 }
932 }
933}
934
935impl std::str::FromStr for UserContent {
936 type Err = std::convert::Infallible;
937
938 fn from_str(s: &str) -> Result<Self, Self::Err> {
939 Ok(UserContent::Text {
940 text: s.to_string(),
941 })
942 }
943}
944
/// Payload of [`UserContent::ImageUrl`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Image location: a regular URL or a `data:<mime>;base64,...` URI.
    pub url: String,
    /// Optional detail hint; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<ImageDetail>,
}
954
/// Payload of [`UserContent::VideoUrl`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct VideoUrlContent {
    /// Video location: a regular URL or a `data:<mime>;base64,...` URI.
    pub url: String,
}
965
/// Payload of [`UserContent::File`]. Both fields are omitted from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct FileContent {
    /// Optional filename sent alongside the data.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filename: Option<String>,
    /// File contents as a URL or a `data:<mime>;base64,...` URI.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_data: Option<String>,
}
981
982fn serialize_user_content<S>(
985 content: &OneOrMany<UserContent>,
986 serializer: S,
987) -> Result<S::Ok, S::Error>
988where
989 S: Serializer,
990{
991 if content.len() == 1
992 && let UserContent::Text { text } = content.first_ref()
993 {
994 return serializer.serialize_str(text);
995 }
996 content.serialize(serializer)
997}
998
999impl TryFrom<message::UserContent> for UserContent {
1000 type Error = message::MessageError;
1001
1002 fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
1003 match value {
1004 message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
1005
1006 message::UserContent::Image(message::Image {
1007 data,
1008 detail,
1009 media_type,
1010 ..
1011 }) => {
1012 let url = match data {
1013 DocumentSourceKind::Url(url) => url,
1014 DocumentSourceKind::Base64(data) => {
1015 let mime = media_type
1016 .ok_or_else(|| {
1017 message::MessageError::ConversionError(
1018 "Image media type required for base64 encoding".into(),
1019 )
1020 })?
1021 .to_mime_type();
1022 format!("data:{mime};base64,{data}")
1023 }
1024 DocumentSourceKind::Raw(_) => {
1025 return Err(message::MessageError::ConversionError(
1026 "Raw bytes not supported, encode as base64 first".into(),
1027 ));
1028 }
1029 DocumentSourceKind::String(_) => {
1030 return Err(message::MessageError::ConversionError(
1031 "String source not supported for images".into(),
1032 ));
1033 }
1034 DocumentSourceKind::Unknown => {
1035 return Err(message::MessageError::ConversionError(
1036 "Image has no data".into(),
1037 ));
1038 }
1039 };
1040 Ok(UserContent::ImageUrl {
1041 image_url: ImageUrl { url, detail },
1042 })
1043 }
1044
1045 message::UserContent::Document(message::Document {
1046 data, media_type, ..
1047 }) => match data {
1048 DocumentSourceKind::Url(url) => {
1049 let filename = media_type.as_ref().map(|mt| match mt {
1050 DocumentMediaType::PDF => "document.pdf",
1051 DocumentMediaType::TXT => "document.txt",
1052 DocumentMediaType::HTML => "document.html",
1053 DocumentMediaType::MARKDOWN => "document.md",
1054 DocumentMediaType::CSV => "document.csv",
1055 DocumentMediaType::XML => "document.xml",
1056 _ => "document",
1057 });
1058 Ok(UserContent::File {
1059 file: FileContent {
1060 filename: filename.map(String::from),
1061 file_data: Some(url),
1062 },
1063 })
1064 }
1065 DocumentSourceKind::Base64(data) => {
1066 let mime = media_type
1067 .as_ref()
1068 .map(|m| m.to_mime_type())
1069 .unwrap_or("application/pdf");
1070 let data_uri = format!("data:{mime};base64,{data}");
1071
1072 let filename = media_type.as_ref().map(|mt| match mt {
1073 DocumentMediaType::PDF => "document.pdf",
1074 DocumentMediaType::TXT => "document.txt",
1075 DocumentMediaType::HTML => "document.html",
1076 DocumentMediaType::MARKDOWN => "document.md",
1077 DocumentMediaType::CSV => "document.csv",
1078 DocumentMediaType::XML => "document.xml",
1079 _ => "document",
1080 });
1081
1082 Ok(UserContent::File {
1083 file: FileContent {
1084 filename: filename.map(String::from),
1085 file_data: Some(data_uri),
1086 },
1087 })
1088 }
1089 DocumentSourceKind::String(text) => Ok(UserContent::Text { text }),
1090 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
1091 "Raw bytes not supported for documents, encode as base64 first".into(),
1092 )),
1093 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
1094 "Document has no data".into(),
1095 )),
1096 },
1097
1098 message::UserContent::Audio(message::Audio {
1099 data, media_type, ..
1100 }) => match data {
1101 DocumentSourceKind::Base64(data) => {
1102 let format = media_type.ok_or_else(|| {
1103 message::MessageError::ConversionError(
1104 "Audio media type required for base64 encoding".into(),
1105 )
1106 })?;
1107 Ok(UserContent::InputAudio {
1108 input_audio: openai::InputAudio { data, format },
1109 })
1110 }
1111 DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
1112 "OpenRouter does not support audio URLs, encode as base64 first".into(),
1113 )),
1114 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
1115 "Raw bytes not supported for audio, encode as base64 first".into(),
1116 )),
1117 DocumentSourceKind::String(_) => Err(message::MessageError::ConversionError(
1118 "String source not supported for audio".into(),
1119 )),
1120 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
1121 "Audio has no data".into(),
1122 )),
1123 },
1124
1125 message::UserContent::Video(message::Video {
1126 data, media_type, ..
1127 }) => {
1128 let url = match data {
1129 DocumentSourceKind::Url(url) => url,
1130 DocumentSourceKind::Base64(data) => {
1131 let mime = media_type
1132 .ok_or_else(|| {
1133 message::MessageError::ConversionError(
1134 "Video media type required for base64 encoding".into(),
1135 )
1136 })?
1137 .to_mime_type();
1138 format!("data:{mime};base64,{data}")
1139 }
1140 DocumentSourceKind::Raw(_) => {
1141 return Err(message::MessageError::ConversionError(
1142 "Raw bytes not supported for video, encode as base64 first".into(),
1143 ));
1144 }
1145 DocumentSourceKind::String(_) => {
1146 return Err(message::MessageError::ConversionError(
1147 "String source not supported for video".into(),
1148 ));
1149 }
1150 DocumentSourceKind::Unknown => {
1151 return Err(message::MessageError::ConversionError(
1152 "Video has no data".into(),
1153 ));
1154 }
1155 };
1156 Ok(UserContent::VideoUrl {
1157 video_url: VideoUrlContent { url },
1158 })
1159 }
1160
1161 message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
1162 "Tool results should be handled as separate messages".into(),
1163 )),
1164 }
1165 }
1166}
1167
1168impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
1169 type Error = message::MessageError;
1170
1171 fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
1172 let (tool_results, other_content): (Vec<_>, Vec<_>) = value
1173 .into_iter()
1174 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
1175
1176 if !tool_results.is_empty() {
1179 tool_results
1180 .into_iter()
1181 .map(|content| match content {
1182 message::UserContent::ToolResult(tool_result) => Ok(Message::ToolResult {
1183 tool_call_id: tool_result.id,
1184 content: tool_result
1185 .content
1186 .into_iter()
1187 .map(|c| match c {
1188 message::ToolResultContent::Text(message::Text { text }) => text,
1189 message::ToolResultContent::Image(_) => {
1190 "[Image content not supported in tool results]".to_string()
1191 }
1192 })
1193 .collect::<Vec<_>>()
1194 .join("\n"),
1195 }),
1196 _ => unreachable!(),
1197 })
1198 .collect::<Result<Vec<_>, _>>()
1199 } else {
1200 let user_content: Vec<UserContent> = other_content
1201 .into_iter()
1202 .map(|content| content.try_into())
1203 .collect::<Result<Vec<_>, _>>()?;
1204
1205 let content = OneOrMany::many(user_content)
1206 .expect("There must be content here if there were no tool result content");
1207
1208 Ok(vec![Message::User {
1209 content,
1210 name: None,
1211 }])
1212 }
1213 }
1214}
1215
/// One entry of [`CompletionResponse::choices`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: usize,
    /// Finish reason as reported by the underlying provider, if any.
    pub native_finish_reason: Option<String>,
    /// The generated message.
    pub message: Message,
    /// Finish reason as reported by OpenRouter, if any.
    pub finish_reason: Option<String>,
}
1227
/// A chat message in OpenRouter's (OpenAI-compatible) format.
///
/// Internally tagged by `"role"`: `"system"` (also accepted as `"developer"` on input),
/// `"user"`, `"assistant"` and `"tool"`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// System prompt. `content` may arrive as a string or a list.
    #[serde(alias = "developer")]
    System {
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<openai::SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// User turn. `content` may arrive as a string or a list. A lone text part is serialized
    /// back as a bare string by `serialize_user_content`.
    User {
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Assistant turn, including tool calls and OpenRouter's reasoning extensions.
    Assistant {
        // Missing content defaults to empty; a string or a list is accepted.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<openai::AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<openai::AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        // `null` or a missing field becomes an empty list; an empty list is not serialized.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<openai::ToolCall>,
        /// Plain-text reasoning, when the model returns it.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning: Option<String>,
        /// Structured reasoning blocks; preferred over `reasoning` when converting responses.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        reasoning_details: Vec<ReasoningDetails>,
    },
    /// Result of a tool call, serialized with role `"tool"`.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
1278
1279impl Message {
1280 pub fn system(content: &str) -> Self {
1281 Message::System {
1282 content: OneOrMany::one(content.to_owned().into()),
1283 name: None,
1284 }
1285 }
1286}
1287
/// One structured reasoning block from OpenRouter's `reasoning_details` array.
///
/// Internally tagged by `"type"`. The explicit per-variant renames override `rename_all`, so
/// the tags are `"reasoning.summary"`, `"reasoning.encrypted"` and `"reasoning.text"`. In every
/// variant, `id` groups related blocks and `index` orders them within a group. The response
/// conversion falls back to list position when `index` is `None`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningDetails {
    /// A reasoning summary.
    #[serde(rename = "reasoning.summary")]
    Summary {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        summary: String,
    },
    /// Opaque reasoning data, round-tripped as-is (presumably provider-encrypted — confirm).
    #[serde(rename = "reasoning.encrypted")]
    Encrypted {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        data: String,
    },
    /// Reasoning text with an optional signature. `text` may be absent; such blocks are
    /// skipped when converting responses.
    #[serde(rename = "reasoning.text")]
    Text {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        text: Option<String>,
        signature: Option<String>,
    },
}
1314
/// Shapes recognised in a tool call's `additional_params` when rebuilding `reasoning_details`.
///
/// `#[serde(untagged)]`, so variants are tried in declaration order:
/// - A full [`ReasoningDetails`] block is taken as-is.
/// - Otherwise `Minimal` captures an optional `id` and `format`.
///
/// Both `Minimal` fields are optional and unknown fields are not denied. Presumably any JSON
/// object therefore falls through to `Minimal`.
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[serde(untagged)]
enum ToolCallAdditionalParams {
    ReasoningDetails(ReasoningDetails),
    Minimal {
        id: Option<String>,
        format: Option<String>,
    },
}
1324
1325impl From<openai::UserContent> for UserContent {
1327 fn from(value: openai::UserContent) -> Self {
1328 match value {
1329 openai::UserContent::Text { text } => UserContent::Text { text },
1330 openai::UserContent::Image { image_url } => UserContent::ImageUrl {
1331 image_url: ImageUrl {
1332 url: image_url.url,
1333 detail: Some(image_url.detail),
1334 },
1335 },
1336 openai::UserContent::Audio { input_audio } => UserContent::InputAudio { input_audio },
1337 }
1338 }
1339}
1340
1341impl From<openai::Message> for Message {
1342 fn from(value: openai::Message) -> Self {
1343 match value {
1344 openai::Message::System { content, name } => Self::System { content, name },
1345 openai::Message::User { content, name } => {
1346 let converted_content = content.map(UserContent::from);
1348 Self::User {
1349 content: converted_content,
1350 name,
1351 }
1352 }
1353 openai::Message::Assistant {
1354 content,
1355 refusal,
1356 audio,
1357 name,
1358 tool_calls,
1359 } => Self::Assistant {
1360 content,
1361 refusal,
1362 audio,
1363 name,
1364 tool_calls,
1365 reasoning: None,
1366 reasoning_details: Vec::new(),
1367 },
1368 openai::Message::ToolResult {
1369 tool_call_id,
1370 content,
1371 } => Self::ToolResult {
1372 tool_call_id,
1373 content: content.as_text(),
1374 },
1375 }
1376 }
1377}
1378
impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
    type Error = message::MessageError;

    /// Converts an assistant turn into a single OpenRouter `assistant` message.
    ///
    /// - Text parts become `content`, in order.
    /// - Tool calls become `tool_calls`. Their `additional_params` and `signature` may also add
    ///   entries to `reasoning_details` (see the comments below).
    /// - Reasoning with structured blocks adds one `reasoning_details` entry per block. Each
    ///   entry's `index` is the length of `reasoning_details` at the time it is pushed.
    /// - Reasoning without blocks sets the plain `reasoning` string from its display text
    ///   (when non-empty). A later such item overwrites an earlier one.
    ///
    /// # Errors
    /// Returns [`message::MessageError::ConversionError`] if the turn contains an image.
    fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
        let mut text_content = Vec::new();
        let mut tool_calls = Vec::new();
        let mut reasoning = None;
        let mut reasoning_details = Vec::new();

        for content in value.into_iter() {
            match content {
                message::AssistantContent::Text(text) => text_content.push(text),
                message::AssistantContent::ToolCall(tool_call) => {
                    // Reconstruct reasoning attached to this tool call. `additional_params`
                    // may hold a full reasoning block or just an id/format pair (presumably
                    // round-tripped from an earlier response — confirm). Params that fail to
                    // parse fall through to the signature-only branch below.
                    if let Some(additional_params) = &tool_call.additional_params
                        && let Ok(additional_params) =
                            serde_json::from_value::<ToolCallAdditionalParams>(
                                additional_params.clone(),
                            )
                    {
                        match additional_params {
                            ToolCallAdditionalParams::ReasoningDetails(full) => {
                                reasoning_details.push(full);
                            }
                            ToolCallAdditionalParams::Minimal { id, format } => {
                                // The id falls back to the tool call's `call_id`.
                                // NOTE(review): unlike the `else` branch below, the signature
                                // is dropped when neither source provides an id.
                                let id = id.or_else(|| tool_call.call_id.clone());
                                if let Some(signature) = &tool_call.signature
                                    && let Some(id) = id
                                {
                                    reasoning_details.push(ReasoningDetails::Encrypted {
                                        id: Some(id),
                                        format,
                                        index: None,
                                        data: signature.clone(),
                                    })
                                }
                            }
                        }
                    } else if let Some(signature) = &tool_call.signature {
                        // No usable params: keep the signature as an encrypted block keyed by
                        // `call_id` (which may be `None`).
                        reasoning_details.push(ReasoningDetails::Encrypted {
                            id: tool_call.call_id.clone(),
                            format: None,
                            index: None,
                            data: signature.clone(),
                        });
                    }
                    tool_calls.push(tool_call.into())
                }
                message::AssistantContent::Reasoning(r) => {
                    if r.content.is_empty() {
                        // No structured blocks: fall back to the plain `reasoning` string.
                        let display = r.display_text();
                        if !display.is_empty() {
                            reasoning = Some(display);
                        }
                    } else {
                        for reasoning_block in &r.content {
                            // The index is the block's position in the outgoing
                            // `reasoning_details`, which includes entries produced by tool calls.
                            let index = Some(reasoning_details.len());
                            match reasoning_block {
                                message::ReasoningContent::Text { text, signature } => {
                                    reasoning_details.push(ReasoningDetails::Text {
                                        id: r.id.clone(),
                                        format: None,
                                        index,
                                        text: Some(text.clone()),
                                        signature: signature.clone(),
                                    });
                                }
                                message::ReasoningContent::Summary(summary) => {
                                    reasoning_details.push(ReasoningDetails::Summary {
                                        id: r.id.clone(),
                                        format: None,
                                        index,
                                        summary: summary.clone(),
                                    });
                                }
                                // Redacted and encrypted blocks are both sent as
                                // `reasoning.encrypted`.
                                message::ReasoningContent::Encrypted(data)
                                | message::ReasoningContent::Redacted { data } => {
                                    reasoning_details.push(ReasoningDetails::Encrypted {
                                        id: r.id.clone(),
                                        format: None,
                                        index,
                                        data: data.clone(),
                                    });
                                }
                            }
                        }
                    }
                }
                message::AssistantContent::Image(_) => {
                    return Err(Self::Error::ConversionError(
                        "OpenRouter currently doesn't support images.".into(),
                    ));
                }
            }
        }

        Ok(vec![Message::Assistant {
            // Text parts become OpenAI-format assistant text content.
            content: text_content
                .into_iter()
                .map(|content| content.text.into())
                .collect::<Vec<_>>(),
            refusal: None,
            audio: None,
            name: None,
            tool_calls,
            reasoning,
            reasoning_details,
        }])
    }
}
1495
1496impl TryFrom<message::Message> for Vec<Message> {
1499 type Error = message::MessageError;
1500
1501 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
1502 match message {
1503 message::Message::System { content } => Ok(vec![Message::System {
1504 content: OneOrMany::one(content.into()),
1505 name: None,
1506 }]),
1507 message::Message::User { content } => {
1508 content.try_into()
1511 }
1512 message::Message::Assistant { content, .. } => content.try_into(),
1513 }
1514 }
1515}
1516
1517#[derive(Debug, Serialize, Deserialize)]
1518#[serde(untagged, rename_all = "snake_case")]
1519pub enum ToolChoice {
1520 None,
1521 Auto,
1522 Required,
1523 Function(Vec<ToolChoiceFunctionKind>),
1524}
1525
1526impl TryFrom<crate::message::ToolChoice> for ToolChoice {
1527 type Error = CompletionError;
1528
1529 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
1530 let res = match value {
1531 crate::message::ToolChoice::None => Self::None,
1532 crate::message::ToolChoice::Auto => Self::Auto,
1533 crate::message::ToolChoice::Required => Self::Required,
1534 crate::message::ToolChoice::Specific { function_names } => {
1535 let vec: Vec<ToolChoiceFunctionKind> = function_names
1536 .into_iter()
1537 .map(|name| ToolChoiceFunctionKind::Function { name })
1538 .collect();
1539
1540 Self::Function(vec)
1541 }
1542 };
1543
1544 Ok(res)
1545 }
1546}
1547
1548#[derive(Debug, Serialize, Deserialize)]
1549#[serde(tag = "type", content = "function")]
1550pub enum ToolChoiceFunctionKind {
1551 Function { name: String },
1552}
1553
1554#[derive(Debug, Serialize, Deserialize)]
1555pub(super) struct OpenrouterCompletionRequest {
1556 model: String,
1557 pub messages: Vec<Message>,
1558 #[serde(skip_serializing_if = "Option::is_none")]
1559 temperature: Option<f64>,
1560 #[serde(skip_serializing_if = "Vec::is_empty")]
1561 tools: Vec<crate::providers::openai::completion::ToolDefinition>,
1562 #[serde(skip_serializing_if = "Option::is_none")]
1563 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
1564 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1565 pub additional_params: Option<serde_json::Value>,
1566}
1567
1568pub struct OpenRouterRequestParams<'a> {
1570 pub model: &'a str,
1571 pub request: CompletionRequest,
1572 pub strict_tools: bool,
1573}
1574
1575impl TryFrom<OpenRouterRequestParams<'_>> for OpenrouterCompletionRequest {
1576 type Error = CompletionError;
1577
1578 fn try_from(params: OpenRouterRequestParams) -> Result<Self, Self::Error> {
1579 let OpenRouterRequestParams {
1580 model,
1581 request: req,
1582 strict_tools,
1583 } = params;
1584 let model = req.model.clone().unwrap_or_else(|| model.to_string());
1585
1586 if req.output_schema.is_some() {
1587 tracing::warn!("Structured outputs currently not supported for OpenRouter");
1588 }
1589
1590 let mut full_history: Vec<Message> = match &req.preamble {
1591 Some(preamble) => vec![Message::system(preamble)],
1592 None => vec![],
1593 };
1594 if let Some(docs) = req.normalized_documents() {
1595 let docs: Vec<Message> = docs.try_into()?;
1596 full_history.extend(docs);
1597 }
1598
1599 let chat_history: Vec<Message> = req
1600 .chat_history
1601 .clone()
1602 .into_iter()
1603 .map(|message| message.try_into())
1604 .collect::<Result<Vec<Vec<Message>>, _>>()?
1605 .into_iter()
1606 .flatten()
1607 .collect();
1608
1609 full_history.extend(chat_history);
1610
1611 let tool_choice = req
1612 .tool_choice
1613 .clone()
1614 .map(crate::providers::openai::completion::ToolChoice::try_from)
1615 .transpose()?;
1616
1617 let tools: Vec<crate::providers::openai::completion::ToolDefinition> = req
1618 .tools
1619 .clone()
1620 .into_iter()
1621 .map(|tool| {
1622 let def = crate::providers::openai::completion::ToolDefinition::from(tool);
1623 if strict_tools { def.with_strict() } else { def }
1624 })
1625 .collect();
1626
1627 Ok(Self {
1628 model,
1629 messages: full_history,
1630 temperature: req.temperature,
1631 tools,
1632 tool_choice,
1633 additional_params: req.additional_params,
1634 })
1635 }
1636}
1637
1638impl TryFrom<(&str, CompletionRequest)> for OpenrouterCompletionRequest {
1639 type Error = CompletionError;
1640
1641 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
1642 let model = req.model.clone().unwrap_or_else(|| model.to_string());
1643 OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
1644 model: &model,
1645 request: req,
1646 strict_tools: false,
1647 })
1648 }
1649}
1650
1651#[derive(Clone)]
1652pub struct CompletionModel<T = reqwest::Client> {
1653 pub(crate) client: Client<T>,
1654 pub model: String,
1655 pub strict_tools: bool,
1658}
1659
1660impl<T> CompletionModel<T> {
1661 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
1662 Self {
1663 client,
1664 model: model.into(),
1665 strict_tools: false,
1666 }
1667 }
1668
1669 pub fn with_strict_tools(mut self) -> Self {
1678 self.strict_tools = true;
1679 self
1680 }
1681}
1682
1683impl<T> completion::CompletionModel for CompletionModel<T>
1684where
1685 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
1686{
1687 type Response = CompletionResponse;
1688 type StreamingResponse = StreamingCompletionResponse;
1689
1690 type Client = Client<T>;
1691
1692 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1693 Self::new(client.clone(), model)
1694 }
1695
1696 async fn completion(
1697 &self,
1698 completion_request: CompletionRequest,
1699 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1700 let request_model = completion_request
1701 .model
1702 .clone()
1703 .unwrap_or_else(|| self.model.clone());
1704 let preamble = completion_request.preamble.clone();
1705 let request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
1706 model: request_model.as_ref(),
1707 request: completion_request,
1708 strict_tools: self.strict_tools,
1709 })?;
1710
1711 if enabled!(Level::TRACE) {
1712 tracing::trace!(
1713 target: "rig::completions",
1714 "OpenRouter completion request: {}",
1715 serde_json::to_string_pretty(&request)?
1716 );
1717 }
1718
1719 let span = if tracing::Span::current().is_disabled() {
1720 info_span!(
1721 target: "rig::completions",
1722 "chat",
1723 gen_ai.operation.name = "chat",
1724 gen_ai.provider.name = "openrouter",
1725 gen_ai.request.model = &request_model,
1726 gen_ai.system_instructions = preamble,
1727 gen_ai.response.id = tracing::field::Empty,
1728 gen_ai.response.model = tracing::field::Empty,
1729 gen_ai.usage.output_tokens = tracing::field::Empty,
1730 gen_ai.usage.input_tokens = tracing::field::Empty,
1731 gen_ai.usage.cached_tokens = tracing::field::Empty,
1732 )
1733 } else {
1734 tracing::Span::current()
1735 };
1736
1737 let body = serde_json::to_vec(&request)?;
1738
1739 let req = self
1740 .client
1741 .post("/chat/completions")?
1742 .body(body)
1743 .map_err(|x| CompletionError::HttpError(x.into()))?;
1744
1745 async move {
1746 let response = self.client.send::<_, Bytes>(req).await?;
1747 let status = response.status();
1748 let response_body = response.into_body().into_future().await?.to_vec();
1749
1750 if status.is_success() {
1751 let parsed: ApiResponse<CompletionResponse> =
1752 serde_json::from_slice(&response_body).map_err(|e| {
1753 CompletionError::ResponseError(format!(
1754 "Failed to parse OpenRouter completion response: {}, response body: {}",
1755 e,
1756 String::from_utf8_lossy(&response_body)
1757 ))
1758 })?;
1759 match parsed {
1760 ApiResponse::Ok(response) => {
1761 let span = tracing::Span::current();
1762 span.record_token_usage(&response.usage);
1763 span.record("gen_ai.response.id", &response.id);
1764 span.record("gen_ai.response.model_name", &response.model);
1765
1766 tracing::debug!(target: "rig::completions",
1767 "OpenRouter response: {response:?}");
1768 response.try_into()
1769 }
1770 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
1771 }
1772 } else {
1773 Err(CompletionError::ProviderError(
1774 String::from_utf8_lossy(&response_body).to_string(),
1775 ))
1776 }
1777 }
1778 .instrument(span)
1779 .await
1780 }
1781
1782 async fn stream(
1783 &self,
1784 completion_request: CompletionRequest,
1785 ) -> Result<
1786 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1787 CompletionError,
1788 > {
1789 CompletionModel::stream(self, completion_request).await
1790 }
1791}
1792
1793#[cfg(test)]
1794mod tests {
1795 use super::*;
1796 use serde_json::json;
1797
1798 #[test]
1799 fn test_openrouter_request_uses_request_model_override() {
1800 let request = CompletionRequest {
1801 model: Some("google/gemini-2.5-flash".to_string()),
1802 preamble: None,
1803 chat_history: crate::OneOrMany::one("Hello".into()),
1804 documents: vec![],
1805 tools: vec![],
1806 temperature: None,
1807 max_tokens: None,
1808 tool_choice: None,
1809 additional_params: None,
1810 output_schema: None,
1811 };
1812
1813 let openrouter_request =
1814 OpenrouterCompletionRequest::try_from(("openai/gpt-4o-mini", request))
1815 .expect("request conversion should succeed");
1816 let serialized =
1817 serde_json::to_value(openrouter_request).expect("serialization should succeed");
1818
1819 assert_eq!(serialized["model"], "google/gemini-2.5-flash");
1820 }
1821
1822 #[test]
1823 fn test_openrouter_request_uses_default_model_when_override_unset() {
1824 let request = CompletionRequest {
1825 model: None,
1826 preamble: None,
1827 chat_history: crate::OneOrMany::one("Hello".into()),
1828 documents: vec![],
1829 tools: vec![],
1830 temperature: None,
1831 max_tokens: None,
1832 tool_choice: None,
1833 additional_params: None,
1834 output_schema: None,
1835 };
1836
1837 let openrouter_request =
1838 OpenrouterCompletionRequest::try_from(("openai/gpt-4o-mini", request))
1839 .expect("request conversion should succeed");
1840 let serialized =
1841 serde_json::to_value(openrouter_request).expect("serialization should succeed");
1842
1843 assert_eq!(serialized["model"], "openai/gpt-4o-mini");
1844 }
1845
1846 #[test]
1847 fn test_completion_response_deserialization_gemini_flash() {
1848 let json = json!({
1850 "id": "gen-AAAAAAAAAA-AAAAAAAAAAAAAAAAAAAA",
1851 "provider": "Google",
1852 "model": "google/gemini-2.5-flash",
1853 "object": "chat.completion",
1854 "created": 1765971703u64,
1855 "choices": [{
1856 "logprobs": null,
1857 "finish_reason": "stop",
1858 "native_finish_reason": "STOP",
1859 "index": 0,
1860 "message": {
1861 "role": "assistant",
1862 "content": "CONTENT",
1863 "refusal": null,
1864 "reasoning": null
1865 }
1866 }],
1867 "usage": {
1868 "prompt_tokens": 669,
1869 "completion_tokens": 5,
1870 "total_tokens": 674
1871 }
1872 });
1873
1874 let response: CompletionResponse = serde_json::from_value(json).unwrap();
1875 assert_eq!(response.id, "gen-AAAAAAAAAA-AAAAAAAAAAAAAAAAAAAA");
1876 assert_eq!(response.model, "google/gemini-2.5-flash");
1877 assert_eq!(response.choices.len(), 1);
1878 assert_eq!(response.choices[0].finish_reason, Some("stop".to_string()));
1879 }
1880
1881 #[test]
1882 fn test_message_assistant_without_reasoning_details() {
1883 let json = json!({
1885 "role": "assistant",
1886 "content": "Hello world",
1887 "refusal": null,
1888 "reasoning": null
1889 });
1890
1891 let message: Message = serde_json::from_value(json).unwrap();
1892 match message {
1893 Message::Assistant {
1894 content,
1895 reasoning_details,
1896 ..
1897 } => {
1898 assert_eq!(content.len(), 1);
1899 assert!(reasoning_details.is_empty());
1900 }
1901 _ => panic!("Expected Assistant message"),
1902 }
1903 }
1904
1905 #[test]
1906 fn test_data_collection_serialization() {
1907 assert_eq!(
1908 serde_json::to_string(&DataCollection::Allow).unwrap(),
1909 r#""allow""#
1910 );
1911 assert_eq!(
1912 serde_json::to_string(&DataCollection::Deny).unwrap(),
1913 r#""deny""#
1914 );
1915 }
1916
1917 #[test]
1918 fn test_data_collection_default() {
1919 assert_eq!(DataCollection::default(), DataCollection::Allow);
1920 }
1921
1922 #[test]
1923 fn test_quantization_serialization() {
1924 assert_eq!(
1925 serde_json::to_string(&Quantization::Int4).unwrap(),
1926 r#""int4""#
1927 );
1928 assert_eq!(
1929 serde_json::to_string(&Quantization::Int8).unwrap(),
1930 r#""int8""#
1931 );
1932 assert_eq!(
1933 serde_json::to_string(&Quantization::Fp16).unwrap(),
1934 r#""fp16""#
1935 );
1936 assert_eq!(
1937 serde_json::to_string(&Quantization::Bf16).unwrap(),
1938 r#""bf16""#
1939 );
1940 assert_eq!(
1941 serde_json::to_string(&Quantization::Fp32).unwrap(),
1942 r#""fp32""#
1943 );
1944 assert_eq!(
1945 serde_json::to_string(&Quantization::Fp8).unwrap(),
1946 r#""fp8""#
1947 );
1948 assert_eq!(
1949 serde_json::to_string(&Quantization::Unknown).unwrap(),
1950 r#""unknown""#
1951 );
1952 }
1953
1954 #[test]
1955 fn test_provider_sort_strategy_serialization() {
1956 assert_eq!(
1957 serde_json::to_string(&ProviderSortStrategy::Price).unwrap(),
1958 r#""price""#
1959 );
1960 assert_eq!(
1961 serde_json::to_string(&ProviderSortStrategy::Throughput).unwrap(),
1962 r#""throughput""#
1963 );
1964 assert_eq!(
1965 serde_json::to_string(&ProviderSortStrategy::Latency).unwrap(),
1966 r#""latency""#
1967 );
1968 }
1969
1970 #[test]
1971 fn test_sort_partition_serialization() {
1972 assert_eq!(
1973 serde_json::to_string(&SortPartition::Model).unwrap(),
1974 r#""model""#
1975 );
1976 assert_eq!(
1977 serde_json::to_string(&SortPartition::None).unwrap(),
1978 r#""none""#
1979 );
1980 }
1981
1982 #[test]
1983 fn test_provider_sort_simple() {
1984 let sort = ProviderSort::Simple(ProviderSortStrategy::Latency);
1985 let json = serde_json::to_value(&sort).unwrap();
1986 assert_eq!(json, "latency");
1987 }
1988
1989 #[test]
1990 fn test_provider_sort_complex() {
1991 let sort = ProviderSort::Complex(
1992 ProviderSortConfig::new(ProviderSortStrategy::Price).partition(SortPartition::None),
1993 );
1994 let json = serde_json::to_value(&sort).unwrap();
1995 assert_eq!(json["by"], "price");
1996 assert_eq!(json["partition"], "none");
1997 }
1998
1999 #[test]
2000 fn test_provider_sort_complex_without_partition() {
2001 let sort = ProviderSort::Complex(ProviderSortConfig::new(ProviderSortStrategy::Throughput));
2002 let json = serde_json::to_value(&sort).unwrap();
2003 assert_eq!(json["by"], "throughput");
2004 assert!(json.get("partition").is_none());
2005 }
2006
2007 #[test]
2008 fn test_provider_sort_from_strategy() {
2009 let sort: ProviderSort = ProviderSortStrategy::Price.into();
2010 assert_eq!(sort, ProviderSort::Simple(ProviderSortStrategy::Price));
2011 }
2012
2013 #[test]
2014 fn test_provider_sort_from_config() {
2015 let config = ProviderSortConfig::new(ProviderSortStrategy::Latency);
2016 let sort: ProviderSort = config.into();
2017 match sort {
2018 ProviderSort::Complex(c) => assert_eq!(c.by, ProviderSortStrategy::Latency),
2019 _ => panic!("Expected Complex variant"),
2020 }
2021 }
2022
2023 #[test]
2024 fn test_percentile_thresholds_builder() {
2025 let thresholds = PercentileThresholds::new()
2026 .p50(10.0)
2027 .p75(25.0)
2028 .p90(50.0)
2029 .p99(100.0);
2030
2031 assert_eq!(thresholds.p50, Some(10.0));
2032 assert_eq!(thresholds.p75, Some(25.0));
2033 assert_eq!(thresholds.p90, Some(50.0));
2034 assert_eq!(thresholds.p99, Some(100.0));
2035 }
2036
2037 #[test]
2038 fn test_percentile_thresholds_default() {
2039 let thresholds = PercentileThresholds::default();
2040 assert_eq!(thresholds.p50, None);
2041 assert_eq!(thresholds.p75, None);
2042 assert_eq!(thresholds.p90, None);
2043 assert_eq!(thresholds.p99, None);
2044 }
2045
2046 #[test]
2047 fn test_throughput_threshold_simple() {
2048 let threshold = ThroughputThreshold::Simple(50.0);
2049 let json = serde_json::to_value(&threshold).unwrap();
2050 assert_eq!(json, 50.0);
2051 }
2052
2053 #[test]
2054 fn test_throughput_threshold_percentile() {
2055 let threshold = ThroughputThreshold::Percentile(PercentileThresholds::new().p90(50.0));
2056 let json = serde_json::to_value(&threshold).unwrap();
2057 assert_eq!(json["p90"], 50.0);
2058 }
2059
2060 #[test]
2061 fn test_latency_threshold_simple() {
2062 let threshold = LatencyThreshold::Simple(0.5);
2063 let json = serde_json::to_value(&threshold).unwrap();
2064 assert_eq!(json, 0.5);
2065 }
2066
2067 #[test]
2068 fn test_latency_threshold_percentile() {
2069 let threshold = LatencyThreshold::Percentile(PercentileThresholds::new().p50(0.1).p99(1.0));
2070 let json = serde_json::to_value(&threshold).unwrap();
2071 assert_eq!(json["p50"], 0.1);
2072 assert_eq!(json["p99"], 1.0);
2073 }
2074
2075 #[test]
2076 fn test_max_price_builder() {
2077 let price = MaxPrice::new().prompt(0.001).completion(0.002);
2078
2079 assert_eq!(price.prompt, Some(0.001));
2080 assert_eq!(price.completion, Some(0.002));
2081 assert_eq!(price.request, None);
2082 assert_eq!(price.image, None);
2083 }
2084
2085 #[test]
2086 fn test_max_price_all_fields() {
2087 let price = MaxPrice::new()
2088 .prompt(0.001)
2089 .completion(0.002)
2090 .request(0.01)
2091 .image(0.05);
2092
2093 let json = serde_json::to_value(&price).unwrap();
2094 assert_eq!(json["prompt"], 0.001);
2095 assert_eq!(json["completion"], 0.002);
2096 assert_eq!(json["request"], 0.01);
2097 assert_eq!(json["image"], 0.05);
2098 }
2099
2100 #[test]
2101 fn test_max_price_default() {
2102 let price = MaxPrice::default();
2103 assert_eq!(price.prompt, None);
2104 assert_eq!(price.completion, None);
2105 assert_eq!(price.request, None);
2106 assert_eq!(price.image, None);
2107 }
2108
2109 #[test]
2110 fn test_provider_preferences_default() {
2111 let prefs = ProviderPreferences::default();
2112 assert!(prefs.order.is_none());
2113 assert!(prefs.only.is_none());
2114 assert!(prefs.ignore.is_none());
2115 assert!(prefs.allow_fallbacks.is_none());
2116 assert!(prefs.require_parameters.is_none());
2117 assert!(prefs.data_collection.is_none());
2118 assert!(prefs.zdr.is_none());
2119 assert!(prefs.sort.is_none());
2120 assert!(prefs.preferred_min_throughput.is_none());
2121 assert!(prefs.preferred_max_latency.is_none());
2122 assert!(prefs.max_price.is_none());
2123 assert!(prefs.quantizations.is_none());
2124 }
2125
2126 #[test]
2127 fn test_provider_preferences_order_with_fallbacks() {
2128 let prefs = ProviderPreferences::new()
2129 .order(["anthropic", "openai"])
2130 .allow_fallbacks(true);
2131
2132 let json = prefs.to_json();
2133 let provider = &json["provider"];
2134
2135 assert_eq!(provider["order"], json!(["anthropic", "openai"]));
2136 assert_eq!(provider["allow_fallbacks"], true);
2137 }
2138
2139 #[test]
2140 fn test_provider_preferences_only_allowlist() {
2141 let prefs = ProviderPreferences::new()
2142 .only(["azure", "together"])
2143 .allow_fallbacks(false);
2144
2145 let json = prefs.to_json();
2146 let provider = &json["provider"];
2147
2148 assert_eq!(provider["only"], json!(["azure", "together"]));
2149 assert_eq!(provider["allow_fallbacks"], false);
2150 }
2151
2152 #[test]
2153 fn test_provider_preferences_ignore() {
2154 let prefs = ProviderPreferences::new().ignore(["deepinfra"]);
2155
2156 let json = prefs.to_json();
2157 let provider = &json["provider"];
2158
2159 assert_eq!(provider["ignore"], json!(["deepinfra"]));
2160 }
2161
2162 #[test]
2163 fn test_provider_preferences_sort_latency() {
2164 let prefs = ProviderPreferences::new().sort(ProviderSortStrategy::Latency);
2165
2166 let json = prefs.to_json();
2167 let provider = &json["provider"];
2168
2169 assert_eq!(provider["sort"], "latency");
2170 }
2171
2172 #[test]
2173 fn test_provider_preferences_price_with_throughput() {
2174 let prefs = ProviderPreferences::new()
2175 .sort(ProviderSortStrategy::Price)
2176 .preferred_min_throughput(ThroughputThreshold::Percentile(
2177 PercentileThresholds::new().p90(50.0),
2178 ));
2179
2180 let json = prefs.to_json();
2181 let provider = &json["provider"];
2182
2183 assert_eq!(provider["sort"], "price");
2184 assert_eq!(provider["preferred_min_throughput"]["p90"], 50.0);
2185 }
2186
2187 #[test]
2188 fn test_provider_preferences_require_parameters() {
2189 let prefs = ProviderPreferences::new().require_parameters(true);
2190
2191 let json = prefs.to_json();
2192 let provider = &json["provider"];
2193
2194 assert_eq!(provider["require_parameters"], true);
2195 }
2196
2197 #[test]
2198 fn test_provider_preferences_data_policy_and_zdr() {
2199 let prefs = ProviderPreferences::new()
2200 .data_collection(DataCollection::Deny)
2201 .zdr(true);
2202
2203 let json = prefs.to_json();
2204 let provider = &json["provider"];
2205
2206 assert_eq!(provider["data_collection"], "deny");
2207 assert_eq!(provider["zdr"], true);
2208 }
2209
2210 #[test]
2211 fn test_provider_preferences_quantizations() {
2212 let prefs =
2213 ProviderPreferences::new().quantizations([Quantization::Int8, Quantization::Fp16]);
2214
2215 let json = prefs.to_json();
2216 let provider = &json["provider"];
2217
2218 assert_eq!(provider["quantizations"], json!(["int8", "fp16"]));
2219 }
2220
2221 #[test]
2222 fn test_provider_preferences_convenience_methods() {
2223 let prefs = ProviderPreferences::new().zero_data_retention().fastest();
2224
2225 assert_eq!(prefs.zdr, Some(true));
2226 assert_eq!(
2227 prefs.sort,
2228 Some(ProviderSort::Simple(ProviderSortStrategy::Throughput))
2229 );
2230
2231 let prefs2 = ProviderPreferences::new().cheapest();
2232 assert_eq!(
2233 prefs2.sort,
2234 Some(ProviderSort::Simple(ProviderSortStrategy::Price))
2235 );
2236
2237 let prefs3 = ProviderPreferences::new().lowest_latency();
2238 assert_eq!(
2239 prefs3.sort,
2240 Some(ProviderSort::Simple(ProviderSortStrategy::Latency))
2241 );
2242 }
2243
2244 #[test]
2245 fn test_provider_preferences_serialization_skips_none() {
2246 let prefs = ProviderPreferences::new().sort(ProviderSortStrategy::Price);
2247
2248 let json = serde_json::to_value(&prefs).unwrap();
2249
2250 assert_eq!(json["sort"], "price");
2251 assert!(json.get("order").is_none());
2252 assert!(json.get("only").is_none());
2253 assert!(json.get("ignore").is_none());
2254 assert!(json.get("zdr").is_none());
2255 }
2256
2257 #[test]
2258 fn test_provider_preferences_deserialization() {
2259 let json = json!({
2260 "order": ["anthropic", "openai"],
2261 "sort": "throughput",
2262 "data_collection": "deny",
2263 "zdr": true,
2264 "quantizations": ["int8", "fp16"]
2265 });
2266
2267 let prefs: ProviderPreferences = serde_json::from_value(json).unwrap();
2268
2269 assert_eq!(
2270 prefs.order,
2271 Some(vec!["anthropic".to_string(), "openai".to_string()])
2272 );
2273 assert_eq!(
2274 prefs.sort,
2275 Some(ProviderSort::Simple(ProviderSortStrategy::Throughput))
2276 );
2277 assert_eq!(prefs.data_collection, Some(DataCollection::Deny));
2278 assert_eq!(prefs.zdr, Some(true));
2279 assert_eq!(
2280 prefs.quantizations,
2281 Some(vec![Quantization::Int8, Quantization::Fp16])
2282 );
2283 }
2284
2285 #[test]
2286 fn test_provider_preferences_deserialization_complex_sort() {
2287 let json = json!({
2288 "sort": {
2289 "by": "latency",
2290 "partition": "model"
2291 }
2292 });
2293
2294 let prefs: ProviderPreferences = serde_json::from_value(json).unwrap();
2295
2296 match prefs.sort {
2297 Some(ProviderSort::Complex(config)) => {
2298 assert_eq!(config.by, ProviderSortStrategy::Latency);
2299 assert_eq!(config.partition, Some(SortPartition::Model));
2300 }
2301 _ => panic!("Expected Complex sort variant"),
2302 }
2303 }
2304
2305 #[test]
2306 fn test_provider_preferences_full_integration() {
2307 let prefs = ProviderPreferences::new()
2308 .order(["anthropic", "openai"])
2309 .only(["anthropic", "openai", "google"])
2310 .sort(ProviderSortStrategy::Throughput)
2311 .data_collection(DataCollection::Deny)
2312 .zdr(true)
2313 .quantizations([Quantization::Int8])
2314 .allow_fallbacks(false);
2315
2316 let json = prefs.to_json();
2317
2318 assert!(json.get("provider").is_some());
2319 let provider = &json["provider"];
2320 assert_eq!(provider["order"], json!(["anthropic", "openai"]));
2321 assert_eq!(provider["only"], json!(["anthropic", "openai", "google"]));
2322 assert_eq!(provider["sort"], "throughput");
2323 assert_eq!(provider["data_collection"], "deny");
2324 assert_eq!(provider["zdr"], true);
2325 assert_eq!(provider["quantizations"], json!(["int8"]));
2326 assert_eq!(provider["allow_fallbacks"], false);
2327 }
2328
2329 #[test]
2330 fn test_provider_preferences_max_price() {
2331 let prefs =
2332 ProviderPreferences::new().max_price(MaxPrice::new().prompt(0.001).completion(0.002));
2333
2334 let json = prefs.to_json();
2335 let provider = &json["provider"];
2336
2337 assert_eq!(provider["max_price"]["prompt"], 0.001);
2338 assert_eq!(provider["max_price"]["completion"], 0.002);
2339 }
2340
2341 #[test]
2342 fn test_provider_preferences_preferred_max_latency() {
2343 let prefs = ProviderPreferences::new().preferred_max_latency(LatencyThreshold::Simple(0.5));
2344
2345 let json = prefs.to_json();
2346 let provider = &json["provider"];
2347
2348 assert_eq!(provider["preferred_max_latency"], 0.5);
2349 }
2350
2351 #[test]
2352 fn test_provider_preferences_empty_arrays() {
2353 let prefs = ProviderPreferences::new()
2354 .order(Vec::<String>::new())
2355 .quantizations(Vec::<Quantization>::new());
2356
2357 let json = prefs.to_json();
2358 let provider = &json["provider"];
2359
2360 assert_eq!(provider["order"], json!([]));
2361 assert_eq!(provider["quantizations"], json!([]));
2362 }
2363
2364 #[test]
2369 fn test_user_content_text_serialization() {
2370 let content = UserContent::text("Hello, world!");
2371 let json = serde_json::to_value(&content).unwrap();
2372
2373 assert_eq!(json["type"], "text");
2374 assert_eq!(json["text"], "Hello, world!");
2375 }
2376
2377 #[test]
2378 fn test_user_content_image_url_serialization() {
2379 let content = UserContent::image_url("https://example.com/image.png");
2380 let json = serde_json::to_value(&content).unwrap();
2381
2382 assert_eq!(json["type"], "image_url");
2383 assert_eq!(json["image_url"]["url"], "https://example.com/image.png");
2384 assert!(json["image_url"].get("detail").is_none());
2385 }
2386
2387 #[test]
2388 fn test_user_content_image_url_with_detail_serialization() {
2389 let content =
2390 UserContent::image_url_with_detail("https://example.com/image.png", ImageDetail::High);
2391 let json = serde_json::to_value(&content).unwrap();
2392
2393 assert_eq!(json["type"], "image_url");
2394 assert_eq!(json["image_url"]["url"], "https://example.com/image.png");
2395 assert_eq!(json["image_url"]["detail"], "high");
2396 }
2397
2398 #[test]
2399 fn test_user_content_image_base64_serialization() {
2400 let content = UserContent::image_base64("SGVsbG8=", "image/png", Some(ImageDetail::Low));
2401 let json = serde_json::to_value(&content).unwrap();
2402
2403 assert_eq!(json["type"], "image_url");
2404 assert_eq!(json["image_url"]["url"], "data:image/png;base64,SGVsbG8=");
2405 assert_eq!(json["image_url"]["detail"], "low");
2406 }
2407
2408 #[test]
2409 fn test_user_content_file_url_serialization() {
2410 let content = UserContent::file_url(
2411 "https://example.com/doc.pdf",
2412 Some("document.pdf".to_string()),
2413 );
2414 let json = serde_json::to_value(&content).unwrap();
2415
2416 assert_eq!(json["type"], "file");
2417 assert_eq!(json["file"]["file_data"], "https://example.com/doc.pdf");
2418 assert_eq!(json["file"]["filename"], "document.pdf");
2419 }
2420
2421 #[test]
2422 fn test_user_content_file_base64_serialization() {
2423 let content = UserContent::file_base64(
2424 "JVBERi0xLjQ=",
2425 "application/pdf",
2426 Some("report.pdf".to_string()),
2427 );
2428 let json = serde_json::to_value(&content).unwrap();
2429
2430 assert_eq!(json["type"], "file");
2431 assert_eq!(
2432 json["file"]["file_data"],
2433 "data:application/pdf;base64,JVBERi0xLjQ="
2434 );
2435 assert_eq!(json["file"]["filename"], "report.pdf");
2436 }
2437
2438 #[test]
2439 fn test_user_content_text_deserialization() {
2440 let json = json!({
2441 "type": "text",
2442 "text": "Hello!"
2443 });
2444
2445 let content: UserContent = serde_json::from_value(json).unwrap();
2446 assert_eq!(
2447 content,
2448 UserContent::Text {
2449 text: "Hello!".to_string()
2450 }
2451 );
2452 }
2453
2454 #[test]
2455 fn test_user_content_image_url_deserialization() {
2456 let json = json!({
2457 "type": "image_url",
2458 "image_url": {
2459 "url": "https://example.com/img.jpg",
2460 "detail": "high"
2461 }
2462 });
2463
2464 let content: UserContent = serde_json::from_value(json).unwrap();
2465 match content {
2466 UserContent::ImageUrl { image_url } => {
2467 assert_eq!(image_url.url, "https://example.com/img.jpg");
2468 assert_eq!(image_url.detail, Some(ImageDetail::High));
2469 }
2470 _ => panic!("Expected ImageUrl variant"),
2471 }
2472 }
2473
2474 #[test]
2475 fn test_user_content_file_deserialization() {
2476 let json = json!({
2477 "type": "file",
2478 "file": {
2479 "filename": "doc.pdf",
2480 "file_data": "https://example.com/doc.pdf"
2481 }
2482 });
2483
2484 let content: UserContent = serde_json::from_value(json).unwrap();
2485 match content {
2486 UserContent::File { file } => {
2487 assert_eq!(file.filename, Some("doc.pdf".to_string()));
2488 assert_eq!(
2489 file.file_data,
2490 Some("https://example.com/doc.pdf".to_string())
2491 );
2492 }
2493 _ => panic!("Expected File variant"),
2494 }
2495 }
2496
2497 #[test]
2498 fn test_message_user_with_text_serialization() {
2499 let message = Message::User {
2500 content: OneOrMany::one(UserContent::text("Hello")),
2501 name: None,
2502 };
2503 let json = serde_json::to_value(&message).unwrap();
2504
2505 assert_eq!(json["role"], "user");
2507 assert_eq!(json["content"], "Hello");
2508 }
2509
2510 #[test]
2511 fn test_message_user_with_mixed_content_serialization() {
2512 let message = Message::User {
2513 content: OneOrMany::many(vec![
2514 UserContent::text("Check this image:"),
2515 UserContent::image_url("https://example.com/img.png"),
2516 ])
2517 .unwrap(),
2518 name: None,
2519 };
2520 let json = serde_json::to_value(&message).unwrap();
2521
2522 assert_eq!(json["role"], "user");
2523 let content = json["content"].as_array().unwrap();
2524 assert_eq!(content.len(), 2);
2525 assert_eq!(content[0]["type"], "text");
2526 assert_eq!(content[1]["type"], "image_url");
2527 }
2528
2529 #[test]
2530 fn test_message_user_with_file_serialization() {
2531 let message = Message::User {
2532 content: OneOrMany::many(vec![
2533 UserContent::text("Analyze this PDF:"),
2534 UserContent::file_url(
2535 "https://example.com/doc.pdf",
2536 Some("document.pdf".to_string()),
2537 ),
2538 ])
2539 .unwrap(),
2540 name: None,
2541 };
2542 let json = serde_json::to_value(&message).unwrap();
2543
2544 assert_eq!(json["role"], "user");
2545 let content = json["content"].as_array().unwrap();
2546 assert_eq!(content.len(), 2);
2547 assert_eq!(content[0]["type"], "text");
2548 assert_eq!(content[1]["type"], "file");
2549 assert_eq!(
2550 content[1]["file"]["file_data"],
2551 "https://example.com/doc.pdf"
2552 );
2553 }
2554
2555 #[test]
2556 fn test_user_content_from_rig_text() {
2557 let rig_content = message::UserContent::Text(message::Text {
2558 text: "Hello".to_string(),
2559 });
2560 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2561
2562 assert_eq!(
2563 openrouter_content,
2564 UserContent::Text {
2565 text: "Hello".to_string()
2566 }
2567 );
2568 }
2569
2570 #[test]
2571 fn test_user_content_from_rig_image_url() {
2572 let rig_content = message::UserContent::Image(message::Image {
2573 data: DocumentSourceKind::Url("https://example.com/img.png".to_string()),
2574 media_type: Some(message::ImageMediaType::PNG),
2575 detail: Some(ImageDetail::High),
2576 additional_params: None,
2577 });
2578 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2579
2580 match openrouter_content {
2581 UserContent::ImageUrl { image_url } => {
2582 assert_eq!(image_url.url, "https://example.com/img.png");
2583 assert_eq!(image_url.detail, Some(ImageDetail::High));
2584 }
2585 _ => panic!("Expected ImageUrl variant"),
2586 }
2587 }
2588
2589 #[test]
2590 fn test_user_content_from_rig_image_base64() {
2591 let rig_content = message::UserContent::Image(message::Image {
2592 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2593 media_type: Some(message::ImageMediaType::JPEG),
2594 detail: Some(ImageDetail::Low),
2595 additional_params: None,
2596 });
2597 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2598
2599 match openrouter_content {
2600 UserContent::ImageUrl { image_url } => {
2601 assert_eq!(image_url.url, "data:image/jpeg;base64,SGVsbG8=");
2602 assert_eq!(image_url.detail, Some(ImageDetail::Low));
2603 }
2604 _ => panic!("Expected ImageUrl variant"),
2605 }
2606 }
2607
2608 #[test]
2609 fn test_user_content_from_rig_document_url() {
2610 let rig_content = message::UserContent::Document(message::Document {
2611 data: DocumentSourceKind::Url("https://example.com/doc.pdf".to_string()),
2612 media_type: Some(DocumentMediaType::PDF),
2613 additional_params: None,
2614 });
2615 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2616
2617 match openrouter_content {
2618 UserContent::File { file } => {
2619 assert_eq!(
2620 file.file_data,
2621 Some("https://example.com/doc.pdf".to_string())
2622 );
2623 assert_eq!(file.filename, Some("document.pdf".to_string()));
2624 }
2625 _ => panic!("Expected File variant"),
2626 }
2627 }
2628
2629 #[test]
2630 fn test_user_content_from_rig_document_base64() {
2631 let rig_content = message::UserContent::Document(message::Document {
2632 data: DocumentSourceKind::Base64("JVBERi0xLjQ=".to_string()),
2633 media_type: Some(DocumentMediaType::PDF),
2634 additional_params: None,
2635 });
2636 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2637
2638 match openrouter_content {
2639 UserContent::File { file } => {
2640 assert_eq!(
2641 file.file_data,
2642 Some("data:application/pdf;base64,JVBERi0xLjQ=".to_string())
2643 );
2644 assert_eq!(file.filename, Some("document.pdf".to_string()));
2645 }
2646 _ => panic!("Expected File variant"),
2647 }
2648 }
2649
2650 #[test]
2651 fn test_user_content_from_rig_document_string_becomes_text() {
2652 let rig_content = message::UserContent::Document(message::Document {
2653 data: DocumentSourceKind::String("Plain text document content".to_string()),
2654 media_type: Some(DocumentMediaType::TXT),
2655 additional_params: None,
2656 });
2657 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2658
2659 assert_eq!(
2660 openrouter_content,
2661 UserContent::Text {
2662 text: "Plain text document content".to_string()
2663 }
2664 );
2665 }
2666
2667 #[test]
2668 fn test_completion_response_with_reasoning_details_maps_to_typed_reasoning() {
2669 let json = json!({
2670 "id": "resp_123",
2671 "object": "chat.completion",
2672 "created": 1,
2673 "model": "openrouter/test-model",
2674 "choices": [{
2675 "index": 0,
2676 "finish_reason": "stop",
2677 "message": {
2678 "role": "assistant",
2679 "content": "hello",
2680 "reasoning": null,
2681 "reasoning_details": [
2682 {"type":"reasoning.summary","id":"rs_1","summary":"s1"},
2683 {"type":"reasoning.text","id":"rs_1","text":"t1","signature":"sig_1"},
2684 {"type":"reasoning.encrypted","id":"rs_1","data":"enc_1"}
2685 ]
2686 }
2687 }]
2688 });
2689
2690 let response: CompletionResponse = serde_json::from_value(json).unwrap();
2691 let converted: completion::CompletionResponse<CompletionResponse> =
2692 response.try_into().unwrap();
2693 let items: Vec<completion::AssistantContent> = converted.choice.into_iter().collect();
2694
2695 assert!(items.iter().any(|item| matches!(
2696 item,
2697 completion::AssistantContent::Reasoning(message::Reasoning { id: Some(id), content })
2698 if id == "rs_1" && content.len() == 3
2699 )));
2700 }
2701
2702 #[test]
2703 fn test_assistant_reasoning_emits_openrouter_reasoning_details() {
2704 let reasoning = message::Reasoning {
2705 id: Some("rs_2".to_string()),
2706 content: vec![
2707 message::ReasoningContent::Text {
2708 text: "step".to_string(),
2709 signature: Some("sig_step".to_string()),
2710 },
2711 message::ReasoningContent::Summary("summary".to_string()),
2712 message::ReasoningContent::Encrypted("enc_blob".to_string()),
2713 ],
2714 };
2715
2716 let messages = Vec::<Message>::try_from(OneOrMany::one(
2717 message::AssistantContent::Reasoning(reasoning),
2718 ))
2719 .unwrap();
2720 let Message::Assistant {
2721 reasoning,
2722 reasoning_details,
2723 ..
2724 } = messages.first().expect("assistant message")
2725 else {
2726 panic!("Expected assistant message");
2727 };
2728
2729 assert!(reasoning.is_none());
2730 assert_eq!(reasoning_details.len(), 3);
2731 assert!(matches!(
2732 reasoning_details.first(),
2733 Some(ReasoningDetails::Text {
2734 id: Some(id),
2735 text: Some(text),
2736 signature: Some(signature),
2737 ..
2738 }) if id == "rs_2" && text == "step" && signature == "sig_step"
2739 ));
2740 }
2741
2742 #[test]
2743 fn test_assistant_redacted_reasoning_emits_encrypted_detail_not_text() {
2744 let reasoning = message::Reasoning {
2745 id: Some("rs_redacted".to_string()),
2746 content: vec![message::ReasoningContent::Redacted {
2747 data: "opaque-redacted-data".to_string(),
2748 }],
2749 };
2750
2751 let messages = Vec::<Message>::try_from(OneOrMany::one(
2752 message::AssistantContent::Reasoning(reasoning),
2753 ))
2754 .unwrap();
2755
2756 let Message::Assistant {
2757 reasoning_details,
2758 reasoning,
2759 ..
2760 } = messages.first().expect("assistant message")
2761 else {
2762 panic!("Expected assistant message");
2763 };
2764
2765 assert!(reasoning.is_none());
2766 assert_eq!(reasoning_details.len(), 1);
2767 assert!(matches!(
2768 reasoning_details.first(),
2769 Some(ReasoningDetails::Encrypted {
2770 id: Some(id),
2771 data,
2772 ..
2773 }) if id == "rs_redacted" && data == "opaque-redacted-data"
2774 ));
2775 }
2776
2777 #[test]
2778 fn test_completion_response_reasoning_details_respects_index_ordering() {
2779 let json = json!({
2780 "id": "resp_ordering",
2781 "object": "chat.completion",
2782 "created": 1,
2783 "model": "openrouter/test-model",
2784 "choices": [{
2785 "index": 0,
2786 "finish_reason": "stop",
2787 "message": {
2788 "role": "assistant",
2789 "content": "hello",
2790 "reasoning": null,
2791 "reasoning_details": [
2792 {"type":"reasoning.summary","id":"rs_order","index":1,"summary":"second"},
2793 {"type":"reasoning.summary","id":"rs_order","index":0,"summary":"first"}
2794 ]
2795 }
2796 }]
2797 });
2798
2799 let response: CompletionResponse = serde_json::from_value(json).unwrap();
2800 let converted: completion::CompletionResponse<CompletionResponse> =
2801 response.try_into().unwrap();
2802 let items: Vec<completion::AssistantContent> = converted.choice.into_iter().collect();
2803 let reasoning_blocks: Vec<_> = items
2804 .into_iter()
2805 .filter_map(|item| match item {
2806 completion::AssistantContent::Reasoning(reasoning) => Some(reasoning),
2807 _ => None,
2808 })
2809 .collect();
2810
2811 assert_eq!(reasoning_blocks.len(), 1);
2812 assert_eq!(reasoning_blocks[0].id.as_deref(), Some("rs_order"));
2813 assert_eq!(
2814 reasoning_blocks[0].content,
2815 vec![
2816 message::ReasoningContent::Summary("first".to_string()),
2817 message::ReasoningContent::Summary("second".to_string()),
2818 ]
2819 );
2820 }
2821
2822 #[test]
2823 fn test_user_content_from_rig_image_missing_media_type_error() {
2824 let rig_content = message::UserContent::Image(message::Image {
2825 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2826 media_type: None, detail: None,
2828 additional_params: None,
2829 });
2830 let result: Result<UserContent, _> = rig_content.try_into();
2831
2832 assert!(result.is_err());
2833 let err = result.unwrap_err();
2834 assert!(err.to_string().contains("media type required"));
2835 }
2836
2837 #[test]
2838 fn test_user_content_from_rig_image_raw_bytes_error() {
2839 let rig_content = message::UserContent::Image(message::Image {
2840 data: DocumentSourceKind::Raw(vec![1, 2, 3]),
2841 media_type: Some(message::ImageMediaType::PNG),
2842 detail: None,
2843 additional_params: None,
2844 });
2845 let result: Result<UserContent, _> = rig_content.try_into();
2846
2847 assert!(result.is_err());
2848 let err = result.unwrap_err();
2849 assert!(err.to_string().contains("base64"));
2850 }
2851
2852 #[test]
2853 fn test_user_content_from_rig_video_url() {
2854 let rig_content = message::UserContent::Video(message::Video {
2855 data: DocumentSourceKind::Url("https://example.com/video.mp4".to_string()),
2856 media_type: Some(message::VideoMediaType::MP4),
2857 additional_params: None,
2858 });
2859 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2860
2861 match openrouter_content {
2862 UserContent::VideoUrl { video_url } => {
2863 assert_eq!(video_url.url, "https://example.com/video.mp4");
2864 }
2865 _ => panic!("Expected VideoUrl variant"),
2866 }
2867 }
2868
2869 #[test]
2870 fn test_user_content_from_rig_video_base64() {
2871 let rig_content = message::UserContent::Video(message::Video {
2872 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2873 media_type: Some(message::VideoMediaType::MP4),
2874 additional_params: None,
2875 });
2876 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2877
2878 match openrouter_content {
2879 UserContent::VideoUrl { video_url } => {
2880 assert_eq!(video_url.url, "data:video/mp4;base64,SGVsbG8=");
2881 }
2882 _ => panic!("Expected VideoUrl variant"),
2883 }
2884 }
2885
2886 #[test]
2887 fn test_user_content_from_rig_video_base64_missing_media_type_error() {
2888 let rig_content = message::UserContent::Video(message::Video {
2889 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2890 media_type: None,
2891 additional_params: None,
2892 });
2893 let result: Result<UserContent, _> = rig_content.try_into();
2894
2895 assert!(result.is_err());
2896 let err = result.unwrap_err();
2897 assert!(err.to_string().contains("media type"));
2898 }
2899
2900 #[test]
2901 fn test_user_content_from_rig_video_raw_bytes_error() {
2902 let rig_content = message::UserContent::Video(message::Video {
2903 data: DocumentSourceKind::Raw(vec![1, 2, 3]),
2904 media_type: Some(message::VideoMediaType::MP4),
2905 additional_params: None,
2906 });
2907 let result: Result<UserContent, _> = rig_content.try_into();
2908
2909 assert!(result.is_err());
2910 let err = result.unwrap_err();
2911 assert!(err.to_string().contains("base64"));
2912 }
2913
2914 #[test]
2915 fn test_user_content_from_rig_audio_base64() {
2916 let rig_content = message::UserContent::Audio(message::Audio {
2917 data: DocumentSourceKind::Base64("audiodata".to_string()),
2918 media_type: Some(message::AudioMediaType::MP3),
2919 additional_params: None,
2920 });
2921 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2922
2923 match openrouter_content {
2924 UserContent::InputAudio { input_audio } => {
2925 assert_eq!(input_audio.data, "audiodata");
2926 assert_eq!(input_audio.format, message::AudioMediaType::MP3);
2927 }
2928 _ => panic!("Expected InputAudio variant"),
2929 }
2930 }
2931
2932 #[test]
2933 fn test_user_content_from_rig_audio_missing_media_type_error() {
2934 let rig_content = message::UserContent::Audio(message::Audio {
2935 data: DocumentSourceKind::Base64("audiodata".to_string()),
2936 media_type: None, additional_params: None,
2938 });
2939 let result: Result<UserContent, _> = rig_content.try_into();
2940
2941 assert!(result.is_err());
2942 let err = result.unwrap_err();
2943 assert!(err.to_string().contains("media type required"));
2944 }
2945
2946 #[test]
2947 fn test_user_content_from_rig_audio_url_error() {
2948 let rig_content = message::UserContent::Audio(message::Audio {
2949 data: DocumentSourceKind::Url("https://example.com/audio.wav".to_string()),
2950 media_type: Some(message::AudioMediaType::WAV),
2951 additional_params: None,
2952 });
2953 let result: Result<UserContent, _> = rig_content.try_into();
2954
2955 assert!(result.is_err());
2956 let err = result.unwrap_err();
2957 assert!(err.to_string().contains("base64"));
2958 }
2959
2960 #[test]
2961 fn test_user_content_from_rig_audio_raw_bytes_error() {
2962 let rig_content = message::UserContent::Audio(message::Audio {
2963 data: DocumentSourceKind::Raw(vec![1, 2, 3]),
2964 media_type: Some(message::AudioMediaType::WAV),
2965 additional_params: None,
2966 });
2967 let result: Result<UserContent, _> = rig_content.try_into();
2968
2969 assert!(result.is_err());
2970 let err = result.unwrap_err();
2971 assert!(err.to_string().contains("base64"));
2972 }
2973
2974 #[test]
2975 fn test_message_conversion_with_pdf() {
2976 let rig_message = message::Message::User {
2977 content: OneOrMany::many(vec![
2978 message::UserContent::Text(message::Text {
2979 text: "Summarize this document".to_string(),
2980 }),
2981 message::UserContent::Document(message::Document {
2982 data: DocumentSourceKind::Url("https://example.com/paper.pdf".to_string()),
2983 media_type: Some(DocumentMediaType::PDF),
2984 additional_params: None,
2985 }),
2986 ])
2987 .unwrap(),
2988 };
2989
2990 let openrouter_messages: Vec<Message> = rig_message.try_into().unwrap();
2991 assert_eq!(openrouter_messages.len(), 1);
2992
2993 match &openrouter_messages[0] {
2994 Message::User { content, .. } => {
2995 assert_eq!(content.len(), 2);
2996
2997 match content.first_ref() {
2999 UserContent::Text { text } => assert_eq!(text, "Summarize this document"),
3000 _ => panic!("Expected Text"),
3001 }
3002 }
3003 _ => panic!("Expected User message"),
3004 }
3005 }
3006
3007 #[test]
3008 fn test_user_content_from_string() {
3009 let content: UserContent = "Hello".into();
3010 assert_eq!(
3011 content,
3012 UserContent::Text {
3013 text: "Hello".to_string()
3014 }
3015 );
3016
3017 let content: UserContent = String::from("World").into();
3018 assert_eq!(
3019 content,
3020 UserContent::Text {
3021 text: "World".to_string()
3022 }
3023 );
3024 }
3025
3026 #[test]
3027 fn test_openai_user_content_conversion() {
3028 let openai_text = openai::UserContent::Text {
3030 text: "Hello".to_string(),
3031 };
3032 let converted: UserContent = openai_text.into();
3033 assert_eq!(
3034 converted,
3035 UserContent::Text {
3036 text: "Hello".to_string()
3037 }
3038 );
3039
3040 let openai_image = openai::UserContent::Image {
3041 image_url: openai::ImageUrl {
3042 url: "https://example.com/img.png".to_string(),
3043 detail: ImageDetail::Auto,
3044 },
3045 };
3046 let converted: UserContent = openai_image.into();
3047 match converted {
3048 UserContent::ImageUrl { image_url } => {
3049 assert_eq!(image_url.url, "https://example.com/img.png");
3050 assert_eq!(image_url.detail, Some(ImageDetail::Auto));
3051 }
3052 _ => panic!("Expected ImageUrl"),
3053 }
3054
3055 let openai_audio = openai::UserContent::Audio {
3056 input_audio: openai::InputAudio {
3057 data: "audiodata".to_string(),
3058 format: AudioMediaType::FLAC,
3059 },
3060 };
3061 let converted: UserContent = openai_audio.into();
3062 match converted {
3063 UserContent::InputAudio { input_audio } => {
3064 assert_eq!(input_audio.data, "audiodata");
3065 assert_eq!(input_audio.format, AudioMediaType::FLAC);
3066 }
3067 _ => panic!("Expected InputAudio"),
3068 }
3069 }
3070
3071 #[test]
3072 fn test_completion_response_reasoning_details_with_multiple_ids_stay_separate() {
3073 let json = json!({
3074 "id": "resp_multi_id",
3075 "object": "chat.completion",
3076 "created": 1,
3077 "model": "openrouter/test-model",
3078 "choices": [{
3079 "index": 0,
3080 "finish_reason": "stop",
3081 "message": {
3082 "role": "assistant",
3083 "content": "hello",
3084 "reasoning": null,
3085 "reasoning_details": [
3086 {"type":"reasoning.summary","id":"rs_a","summary":"a1"},
3087 {"type":"reasoning.summary","id":"rs_b","summary":"b1"},
3088 {"type":"reasoning.summary","id":"rs_a","summary":"a2"}
3089 ]
3090 }
3091 }]
3092 });
3093
3094 let response: CompletionResponse = serde_json::from_value(json).unwrap();
3095 let converted: completion::CompletionResponse<CompletionResponse> =
3096 response.try_into().unwrap();
3097 let items: Vec<completion::AssistantContent> = converted.choice.into_iter().collect();
3098 let reasoning_blocks: Vec<_> = items
3099 .into_iter()
3100 .filter_map(|item| match item {
3101 completion::AssistantContent::Reasoning(reasoning) => Some(reasoning),
3102 _ => None,
3103 })
3104 .collect();
3105
3106 assert_eq!(reasoning_blocks.len(), 2);
3107 assert_eq!(reasoning_blocks[0].id.as_deref(), Some("rs_a"));
3108 assert_eq!(
3109 reasoning_blocks[0].content,
3110 vec![
3111 message::ReasoningContent::Summary("a1".to_string()),
3112 message::ReasoningContent::Summary("a2".to_string()),
3113 ]
3114 );
3115 assert_eq!(reasoning_blocks[1].id.as_deref(), Some("rs_b"));
3116 assert_eq!(
3117 reasoning_blocks[1].content,
3118 vec![message::ReasoningContent::Summary("b1".to_string())]
3119 );
3120 }
3121
3122 #[test]
3123 fn test_user_content_audio_serialization() {
3124 let content = UserContent::audio_base64("SGVsbG8=", AudioMediaType::WAV);
3125 let json = serde_json::to_value(&content).unwrap();
3126
3127 assert_eq!(json["type"], "input_audio");
3128 assert_eq!(json["input_audio"]["data"], "SGVsbG8=");
3129 assert_eq!(json["input_audio"]["format"], "wav");
3130 }
3131
3132 #[test]
3133 fn test_user_content_audio_deserialization() {
3134 let json = json!({
3135 "type": "input_audio",
3136 "input_audio": {
3137 "data": "SGVsbG8=",
3138 "format": "wav"
3139 }
3140 });
3141
3142 let content: UserContent = serde_json::from_value(json).unwrap();
3143 match content {
3144 UserContent::InputAudio { input_audio } => {
3145 assert_eq!(input_audio.data, "SGVsbG8=");
3146 assert_eq!(input_audio.format, AudioMediaType::WAV);
3147 }
3148 _ => panic!("Expected InputAudio variant"),
3149 }
3150 }
3151
3152 #[test]
3153 fn test_message_user_with_audio_serialization() {
3154 let msg = Message::User {
3155 content: OneOrMany::many(vec![
3156 UserContent::text("Transcribe this audio:"),
3157 UserContent::audio_base64("SGVsbG8=", AudioMediaType::MP3),
3158 ])
3159 .unwrap(),
3160 name: None,
3161 };
3162 let json = serde_json::to_value(&msg).unwrap();
3163
3164 assert_eq!(json["role"], "user");
3165 let content = json["content"].as_array().unwrap();
3166 assert_eq!(content.len(), 2);
3167 assert_eq!(content[0]["type"], "text");
3168 assert_eq!(content[1]["type"], "input_audio");
3169 assert_eq!(content[1]["input_audio"]["data"], "SGVsbG8=");
3170 assert_eq!(content[1]["input_audio"]["format"], "mp3");
3171 }
3172
3173 #[test]
3174 fn test_user_content_video_url_serialization() {
3175 let content = UserContent::video_url("https://example.com/video.mp4");
3176 let json = serde_json::to_value(&content).unwrap();
3177
3178 assert_eq!(json["type"], "video_url");
3179 assert_eq!(json["video_url"]["url"], "https://example.com/video.mp4");
3180 }
3181
3182 #[test]
3183 fn test_user_content_video_base64_serialization() {
3184 let content = UserContent::video_base64("SGVsbG8=", VideoMediaType::MP4);
3185 let json = serde_json::to_value(&content).unwrap();
3186
3187 assert_eq!(json["type"], "video_url");
3188 assert_eq!(json["video_url"]["url"], "data:video/mp4;base64,SGVsbG8=");
3189 }
3190
3191 #[test]
3192 fn test_user_content_video_url_deserialization() {
3193 let json = json!({
3194 "type": "video_url",
3195 "video_url": {
3196 "url": "https://example.com/video.mp4"
3197 }
3198 });
3199
3200 let content: UserContent = serde_json::from_value(json).unwrap();
3201 match content {
3202 UserContent::VideoUrl { video_url } => {
3203 assert_eq!(video_url.url, "https://example.com/video.mp4");
3204 }
3205 _ => panic!("Expected VideoUrl variant"),
3206 }
3207 }
3208
3209 #[test]
3210 fn test_message_user_with_video_serialization() {
3211 let msg = Message::User {
3212 content: OneOrMany::many(vec![
3213 UserContent::text("Describe this video:"),
3214 UserContent::video_url("https://example.com/video.mp4"),
3215 ])
3216 .unwrap(),
3217 name: None,
3218 };
3219 let json = serde_json::to_value(&msg).unwrap();
3220
3221 assert_eq!(json["role"], "user");
3222 let content = json["content"].as_array().unwrap();
3223 assert_eq!(content.len(), 2);
3224 assert_eq!(content[0]["type"], "text");
3225 assert_eq!(content[1]["type"], "video_url");
3226 assert_eq!(
3227 content[1]["video_url"]["url"],
3228 "https://example.com/video.mp4"
3229 );
3230 }
3231
3232 #[test]
3233 fn test_user_content_video_url_no_media_type_needed() {
3234 let rig_content = message::UserContent::Video(message::Video {
3235 data: DocumentSourceKind::Url("https://example.com/video.mp4".to_string()),
3236 media_type: None,
3237 additional_params: None,
3238 });
3239 let openrouter_content: UserContent = rig_content.try_into().unwrap();
3240
3241 match openrouter_content {
3242 UserContent::VideoUrl { video_url } => {
3243 assert_eq!(video_url.url, "https://example.com/video.mp4");
3244 }
3245 _ => panic!("Expected VideoUrl variant"),
3246 }
3247 }
3248}