1use super::{
2 client::{ApiErrorResponse, ApiResponse, Client, Usage},
3 streaming::StreamingCompletionResponse,
4};
5use crate::message::{
6 self, AudioMediaType, DocumentMediaType, DocumentSourceKind, ImageDetail, MimeType,
7 VideoMediaType,
8};
9use crate::telemetry::SpanCombinator;
10use crate::{
11 OneOrMany,
12 completion::{self, CompletionError, CompletionRequest},
13 http_client::HttpClientExt,
14 json_utils,
15 one_or_many::string_or_one_or_many,
16 providers::openai,
17};
18use bytes::Bytes;
19use serde::{Deserialize, Serialize, Serializer};
20use std::collections::HashMap;
21use tracing::{Instrument, Level, enabled, info_span};
22
/// Model identifier string `qwen/qwq-32b`.
pub const QWEN_QWQ_32B: &str = "qwen/qwq-32b";
/// Model identifier string `anthropic/claude-3.7-sonnet`.
pub const CLAUDE_3_7_SONNET: &str = "anthropic/claude-3.7-sonnet";
/// Model identifier string `perplexity/sonar-pro`.
pub const PERPLEXITY_SONAR_PRO: &str = "perplexity/sonar-pro";
/// Model identifier string `google/gemini-2.0-flash-001`.
pub const GEMINI_FLASH_2_0: &str = "google/gemini-2.0-flash-001";
35
/// Data-collection policy carried by [`ProviderPreferences::data_collection`].
///
/// Serialized as a lowercase string (`"allow"` / `"deny"`). The default is
/// [`DataCollection::Allow`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum DataCollection {
    /// Default variant; serialized as `"allow"`.
    #[default]
    Allow,
    /// Serialized as `"deny"`.
    Deny,
}
53
/// Quantization level that can be listed in [`ProviderPreferences::quantizations`].
///
/// Every variant carries an explicit `rename`, which matches what the
/// container-level `rename_all = "lowercase"` would produce anyway.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Quantization {
    /// Serialized as `"int4"`.
    #[serde(rename = "int4")]
    Int4,
    /// Serialized as `"int8"`.
    #[serde(rename = "int8")]
    Int8,
    /// Serialized as `"fp16"`.
    #[serde(rename = "fp16")]
    Fp16,
    /// Serialized as `"bf16"`.
    #[serde(rename = "bf16")]
    Bf16,
    /// Serialized as `"fp32"`.
    #[serde(rename = "fp32")]
    Fp32,
    /// Serialized as `"fp8"`.
    #[serde(rename = "fp8")]
    Fp8,
    /// Serialized as `"unknown"`.
    #[serde(rename = "unknown")]
    Unknown,
}
82
/// Criterion by which providers are sorted.
///
/// Serialized as a lowercase string: `"price"`, `"throughput"` or `"latency"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderSortStrategy {
    /// Sort by price (used by [`ProviderPreferences::cheapest`]).
    Price,
    /// Sort by throughput (used by [`ProviderPreferences::fastest`]).
    Throughput,
    /// Sort by latency (used by [`ProviderPreferences::lowest_latency`]).
    Latency,
}
98
/// Partitioning option for a [`ProviderSortConfig`].
///
/// Serialized as a lowercase string: `"model"` or `"none"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SortPartition {
    /// Serialized as `"model"`.
    Model,
    /// Serialized as `"none"`.
    None,
}
108
/// Object form of a provider sort: a strategy plus an optional partition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderSortConfig {
    /// Strategy to sort by.
    pub by: ProviderSortStrategy,

    /// Optional partitioning; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub partition: Option<SortPartition>,
}
121
122impl ProviderSortConfig {
123 pub fn new(by: ProviderSortStrategy) -> Self {
125 Self {
126 by,
127 partition: None,
128 }
129 }
130
131 pub fn partition(mut self, partition: SortPartition) -> Self {
133 self.partition = Some(partition);
134 self
135 }
136}
137
/// Provider sort setting in either of its two wire shapes.
///
/// The enum is `untagged`: `Simple` serializes as the bare strategy string,
/// `Complex` as a [`ProviderSortConfig`] object. On deserialization the
/// variants are tried in declaration order.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ProviderSort {
    /// Just a strategy, e.g. `"price"`.
    Simple(ProviderSortStrategy),
    /// Strategy with an optional partition.
    Complex(ProviderSortConfig),
}
150
151impl From<ProviderSortStrategy> for ProviderSort {
152 fn from(strategy: ProviderSortStrategy) -> Self {
153 ProviderSort::Simple(strategy)
154 }
155}
156
157impl From<ProviderSortConfig> for ProviderSort {
158 fn from(config: ProviderSortConfig) -> Self {
159 ProviderSort::Complex(config)
160 }
161}
162
/// Throughput threshold used by [`ProviderPreferences::preferred_min_throughput`].
///
/// The enum is `untagged`: `Simple` serializes as a bare number, `Percentile`
/// as a [`PercentileThresholds`] object.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ThroughputThreshold {
    /// A single numeric threshold.
    Simple(f64),
    /// Per-percentile thresholds.
    Percentile(PercentileThresholds),
}
174
/// Latency threshold used by [`ProviderPreferences::preferred_max_latency`].
///
/// The enum is `untagged`: `Simple` serializes as a bare number, `Percentile`
/// as a [`PercentileThresholds`] object.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum LatencyThreshold {
    /// A single numeric threshold.
    Simple(f64),
    /// Per-percentile thresholds.
    Percentile(PercentileThresholds),
}
186
/// Optional threshold values keyed by percentile (`p50`, `p75`, `p90`, `p99`).
///
/// Unset fields are omitted from the serialized output.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct PercentileThresholds {
    /// Threshold for `p50`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p50: Option<f64>,
    /// Threshold for `p75`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p75: Option<f64>,
    /// Threshold for `p90`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p90: Option<f64>,
    /// Threshold for `p99`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub p99: Option<f64>,
}
203
204impl PercentileThresholds {
205 pub fn new() -> Self {
207 Self::default()
208 }
209
210 pub fn p50(mut self, value: f64) -> Self {
212 self.p50 = Some(value);
213 self
214 }
215
216 pub fn p75(mut self, value: f64) -> Self {
218 self.p75 = Some(value);
219 self
220 }
221
222 pub fn p90(mut self, value: f64) -> Self {
224 self.p90 = Some(value);
225 self
226 }
227
228 pub fn p99(mut self, value: f64) -> Self {
230 self.p99 = Some(value);
231 self
232 }
233}
234
/// Optional price ceilings used by [`ProviderPreferences::max_price`].
///
/// Unset fields are omitted from the serialized output.
// NOTE(review): the unit of each price is not visible in this file — confirm
// against the provider documentation before documenting it here.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct MaxPrice {
    /// Ceiling for the `prompt` price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<f64>,
    /// Ceiling for the `completion` price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion: Option<f64>,
    /// Ceiling for the `request` price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub request: Option<f64>,
    /// Ceiling for the `image` price.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image: Option<f64>,
}
254
255impl MaxPrice {
256 pub fn new() -> Self {
258 Self::default()
259 }
260
261 pub fn prompt(mut self, price: f64) -> Self {
263 self.prompt = Some(price);
264 self
265 }
266
267 pub fn completion(mut self, price: f64) -> Self {
269 self.completion = Some(price);
270 self
271 }
272
273 pub fn request(mut self, price: f64) -> Self {
275 self.request = Some(price);
276 self
277 }
278
279 pub fn image(mut self, price: f64) -> Self {
281 self.image = Some(price);
282 self
283 }
284}
285
/// Provider routing preferences.
///
/// Every field is optional and omitted from the serialized output when `None`.
/// [`ProviderPreferences::to_json`] wraps the serialized value under a
/// `"provider"` key.
// NOTE(review): the per-field descriptions below are inferred from the field
// names; confirm exact semantics against the OpenRouter provider-routing docs.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ProviderPreferences {
    /// Provider names, presumably in the order they should be tried.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub order: Option<Vec<String>>,

    /// Provider names, presumably an allow-list.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub only: Option<Vec<String>>,

    /// Provider names, presumably a deny-list.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ignore: Option<Vec<String>>,

    /// Flag for allowing fallbacks.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow_fallbacks: Option<bool>,

    /// Flag for requiring parameter support.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub require_parameters: Option<bool>,

    /// Data-collection policy.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data_collection: Option<DataCollection>,

    /// Zero-data-retention flag (see [`ProviderPreferences::zero_data_retention`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub zdr: Option<bool>,

    /// Sort setting, either a bare strategy or a full configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sort: Option<ProviderSort>,

    /// Preferred minimum throughput.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_min_throughput: Option<ThroughputThreshold>,

    /// Preferred maximum latency.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preferred_max_latency: Option<LatencyThreshold>,

    /// Price ceilings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_price: Option<MaxPrice>,

    /// Quantization levels.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantizations: Option<Vec<Quantization>>,
}
364
365impl ProviderPreferences {
366 pub fn new() -> Self {
368 Self::default()
369 }
370
371 pub fn order(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
387 self.order = Some(providers.into_iter().map(|p| p.into()).collect());
388 self
389 }
390
391 pub fn only(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
403 self.only = Some(providers.into_iter().map(|p| p.into()).collect());
404 self
405 }
406
407 pub fn ignore(mut self, providers: impl IntoIterator<Item = impl Into<String>>) -> Self {
418 self.ignore = Some(providers.into_iter().map(|p| p.into()).collect());
419 self
420 }
421
422 pub fn allow_fallbacks(mut self, allow: bool) -> Self {
427 self.allow_fallbacks = Some(allow);
428 self
429 }
430
431 pub fn require_parameters(mut self, require: bool) -> Self {
437 self.require_parameters = Some(require);
438 self
439 }
440
441 pub fn data_collection(mut self, policy: DataCollection) -> Self {
445 self.data_collection = Some(policy);
446 self
447 }
448
449 pub fn zdr(mut self, enable: bool) -> Self {
460 self.zdr = Some(enable);
461 self
462 }
463
464 pub fn sort(mut self, sort: impl Into<ProviderSort>) -> Self {
480 self.sort = Some(sort.into());
481 self
482 }
483
484 pub fn preferred_min_throughput(mut self, threshold: ThroughputThreshold) -> Self {
504 self.preferred_min_throughput = Some(threshold);
505 self
506 }
507
508 pub fn preferred_max_latency(mut self, threshold: LatencyThreshold) -> Self {
512 self.preferred_max_latency = Some(threshold);
513 self
514 }
515
516 pub fn max_price(mut self, price: MaxPrice) -> Self {
520 self.max_price = Some(price);
521 self
522 }
523
524 pub fn quantizations(mut self, quantizations: impl IntoIterator<Item = Quantization>) -> Self {
537 self.quantizations = Some(quantizations.into_iter().collect());
538 self
539 }
540
541 pub fn zero_data_retention(self) -> Self {
545 self.zdr(true)
546 }
547
548 pub fn fastest(self) -> Self {
550 self.sort(ProviderSortStrategy::Throughput)
551 }
552
553 pub fn cheapest(self) -> Self {
555 self.sort(ProviderSortStrategy::Price)
556 }
557
558 pub fn lowest_latency(self) -> Self {
560 self.sort(ProviderSortStrategy::Latency)
561 }
562
563 pub fn to_json(&self) -> serde_json::Value {
565 serde_json::json!({
566 "provider": self
567 })
568 }
569}
570
/// Raw, non-streaming completion response body.
///
/// Converted into the provider-agnostic `completion::CompletionResponse` by the
/// `TryFrom` implementation in this file, which only looks at the first entry
/// of `choices` and at `usage`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Response `id` field.
    pub id: String,
    /// Response `object` field.
    pub object: String,
    /// Response `created` field.
    pub created: u64,
    /// Response `model` field.
    pub model: String,
    /// Returned choices; the conversion fails if this is empty.
    pub choices: Vec<Choice>,
    /// Optional `system_fingerprint` field.
    pub system_fingerprint: Option<String>,
    /// Optional token usage; defaults are used when absent.
    pub usage: Option<Usage>,
}
584
585impl From<ApiErrorResponse> for CompletionError {
586 fn from(err: ApiErrorResponse) -> Self {
587 CompletionError::ProviderError(err.message)
588 }
589}
590
591impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
592 type Error = CompletionError;
593
594 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
595 let choice = response.choices.first().ok_or_else(|| {
596 CompletionError::ResponseError("Response contained no choices".to_owned())
597 })?;
598
599 let content = match &choice.message {
600 Message::Assistant {
601 content,
602 tool_calls,
603 reasoning,
604 reasoning_details,
605 ..
606 } => {
607 let mut content = content
608 .iter()
609 .map(|c| match c {
610 openai::AssistantContent::Text { text } => {
611 completion::AssistantContent::text(text)
612 }
613 openai::AssistantContent::Refusal { refusal } => {
614 completion::AssistantContent::text(refusal)
615 }
616 })
617 .collect::<Vec<_>>();
618
619 content.extend(tool_calls.iter().map(|call| {
620 completion::AssistantContent::tool_call(
621 &call.id,
622 &call.function.name,
623 call.function.arguments.clone(),
624 )
625 }));
626
627 let mut grouped_reasoning: HashMap<
628 Option<String>,
629 Vec<(usize, usize, message::ReasoningContent)>,
630 > = HashMap::new();
631 let mut reasoning_order: Vec<Option<String>> = Vec::new();
632 for (position, detail) in reasoning_details.iter().enumerate() {
633 let (reasoning_id, sort_index, parsed_content) = match detail {
634 ReasoningDetails::Summary {
635 id, index, summary, ..
636 } => (
637 id.clone(),
638 *index,
639 Some(message::ReasoningContent::Summary(summary.clone())),
640 ),
641 ReasoningDetails::Encrypted {
642 id, index, data, ..
643 } => (
644 id.clone(),
645 *index,
646 Some(message::ReasoningContent::Encrypted(data.clone())),
647 ),
648 ReasoningDetails::Text {
649 id,
650 index,
651 text,
652 signature,
653 ..
654 } => (
655 id.clone(),
656 *index,
657 text.as_ref().map(|text| message::ReasoningContent::Text {
658 text: text.clone(),
659 signature: signature.clone(),
660 }),
661 ),
662 };
663
664 let Some(parsed_content) = parsed_content else {
665 continue;
666 };
667 let sort_index = sort_index.unwrap_or(position);
668
669 let entry = grouped_reasoning.entry(reasoning_id.clone());
670 if matches!(entry, std::collections::hash_map::Entry::Vacant(_)) {
671 reasoning_order.push(reasoning_id);
672 }
673 entry
674 .or_default()
675 .push((sort_index, position, parsed_content));
676 }
677
678 if grouped_reasoning.is_empty() {
679 if let Some(reasoning) = reasoning {
680 content.push(completion::AssistantContent::reasoning(reasoning));
681 }
682 } else {
683 for reasoning_id in reasoning_order {
684 let Some(mut blocks) = grouped_reasoning.remove(&reasoning_id) else {
685 continue;
686 };
687 blocks.sort_by_key(|(index, position, _)| (*index, *position));
688 content.push(completion::AssistantContent::Reasoning(
689 message::Reasoning {
690 id: reasoning_id,
691 content: blocks
692 .into_iter()
693 .map(|(_, _, content)| content)
694 .collect::<Vec<_>>(),
695 },
696 ));
697 }
698 }
699
700 Ok(content)
701 }
702 _ => Err(CompletionError::ResponseError(
703 "Response did not contain a valid message or tool call".into(),
704 )),
705 }?;
706
707 let choice = OneOrMany::many(content).map_err(|_| {
708 CompletionError::ResponseError(
709 "Response contained no message or tool call (empty)".to_owned(),
710 )
711 })?;
712
713 let usage = response
714 .usage
715 .as_ref()
716 .map(|usage| completion::Usage {
717 input_tokens: usage.prompt_tokens as u64,
718 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
719 total_tokens: usage.total_tokens as u64,
720 cached_input_tokens: 0,
721 })
722 .unwrap_or_default();
723
724 Ok(completion::CompletionResponse {
725 choice,
726 usage,
727 raw_response: response,
728 message_id: None,
729 })
730 }
731}
732
/// One part of a user message.
///
/// Internally tagged by a `type` field. Variant names are converted to
/// `snake_case`; `ImageUrl` and `VideoUrl` additionally carry explicit renames
/// that match that conversion.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum UserContent {
    /// Plain text (`type: "text"`).
    Text { text: String },

    /// Image referenced by URL or data URI (`type: "image_url"`).
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },

    /// File attachment (`type: "file"`).
    File { file: FileContent },

    /// Audio input (`type: "input_audio"`).
    InputAudio { input_audio: openai::InputAudio },

    /// Video referenced by URL or data URI (`type: "video_url"`).
    #[serde(rename = "video_url")]
    VideoUrl { video_url: VideoUrlContent },
}
802
803impl UserContent {
804 pub fn text(text: impl Into<String>) -> Self {
806 UserContent::Text { text: text.into() }
807 }
808
809 pub fn image_url(url: impl Into<String>) -> Self {
811 UserContent::ImageUrl {
812 image_url: ImageUrl {
813 url: url.into(),
814 detail: None,
815 },
816 }
817 }
818
819 pub fn image_url_with_detail(url: impl Into<String>, detail: ImageDetail) -> Self {
821 UserContent::ImageUrl {
822 image_url: ImageUrl {
823 url: url.into(),
824 detail: Some(detail),
825 },
826 }
827 }
828
829 pub fn image_base64(
836 data: impl Into<String>,
837 mime_type: &str,
838 detail: Option<ImageDetail>,
839 ) -> Self {
840 let data_uri = format!("data:{};base64,{}", mime_type, data.into());
841 UserContent::ImageUrl {
842 image_url: ImageUrl {
843 url: data_uri,
844 detail,
845 },
846 }
847 }
848
849 pub fn file_url(url: impl Into<String>, filename: Option<String>) -> Self {
855 UserContent::File {
856 file: FileContent {
857 filename,
858 file_data: Some(url.into()),
859 },
860 }
861 }
862
863 pub fn file_base64(data: impl Into<String>, mime_type: &str, filename: Option<String>) -> Self {
870 let data_uri = format!("data:{};base64,{}", mime_type, data.into());
871 UserContent::File {
872 file: FileContent {
873 filename,
874 file_data: Some(data_uri),
875 },
876 }
877 }
878
879 pub fn audio_base64(data: impl Into<String>, format: AudioMediaType) -> Self {
887 UserContent::InputAudio {
888 input_audio: openai::InputAudio {
889 data: data.into(),
890 format,
891 },
892 }
893 }
894
895 pub fn video_url(url: impl Into<String>) -> Self {
902 UserContent::VideoUrl {
903 video_url: VideoUrlContent { url: url.into() },
904 }
905 }
906
907 pub fn video_base64(data: impl Into<String>, media_type: VideoMediaType) -> Self {
913 let mime = media_type.to_mime_type();
914 let data_uri = format!("data:{mime};base64,{}", data.into());
915 UserContent::VideoUrl {
916 video_url: VideoUrlContent { url: data_uri },
917 }
918 }
919}
920
921impl From<String> for UserContent {
922 fn from(text: String) -> Self {
923 UserContent::Text { text }
924 }
925}
926
927impl From<&str> for UserContent {
928 fn from(text: &str) -> Self {
929 UserContent::Text {
930 text: text.to_string(),
931 }
932 }
933}
934
935impl std::str::FromStr for UserContent {
936 type Err = std::convert::Infallible;
937
938 fn from_str(s: &str) -> Result<Self, Self::Err> {
939 Ok(UserContent::Text {
940 text: s.to_string(),
941 })
942 }
943}
944
/// Payload of [`UserContent::ImageUrl`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ImageUrl {
    /// Image location: a URL or a `data:` URI (see [`UserContent::image_base64`]).
    pub url: String,
    /// Optional detail setting; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<ImageDetail>,
}
954
/// Payload of [`UserContent::VideoUrl`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct VideoUrlContent {
    /// Video location: a URL or a `data:` URI (see [`UserContent::video_base64`]).
    pub url: String,
}
965
/// Payload of [`UserContent::File`].
///
/// Both fields are omitted from the serialized output when `None`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct FileContent {
    /// Optional file name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filename: Option<String>,
    /// File location or contents: a URL or a `data:` URI, as produced by
    /// [`UserContent::file_url`] and [`UserContent::file_base64`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_data: Option<String>,
}
981
982fn serialize_user_content<S>(
985 content: &OneOrMany<UserContent>,
986 serializer: S,
987) -> Result<S::Ok, S::Error>
988where
989 S: Serializer,
990{
991 if content.len() == 1
992 && let UserContent::Text { text } = content.first_ref()
993 {
994 return serializer.serialize_str(text);
995 }
996 content.serialize(serializer)
997}
998
999impl TryFrom<message::UserContent> for UserContent {
1000 type Error = message::MessageError;
1001
1002 fn try_from(value: message::UserContent) -> Result<Self, Self::Error> {
1003 match value {
1004 message::UserContent::Text(message::Text { text }) => Ok(UserContent::Text { text }),
1005
1006 message::UserContent::Image(message::Image {
1007 data,
1008 detail,
1009 media_type,
1010 ..
1011 }) => {
1012 let url = match data {
1013 DocumentSourceKind::Url(url) => url,
1014 DocumentSourceKind::Base64(data) => {
1015 let mime = media_type
1016 .ok_or_else(|| {
1017 message::MessageError::ConversionError(
1018 "Image media type required for base64 encoding".into(),
1019 )
1020 })?
1021 .to_mime_type();
1022 format!("data:{mime};base64,{data}")
1023 }
1024 DocumentSourceKind::Raw(_) => {
1025 return Err(message::MessageError::ConversionError(
1026 "Raw bytes not supported, encode as base64 first".into(),
1027 ));
1028 }
1029 DocumentSourceKind::String(_) => {
1030 return Err(message::MessageError::ConversionError(
1031 "String source not supported for images".into(),
1032 ));
1033 }
1034 DocumentSourceKind::Unknown => {
1035 return Err(message::MessageError::ConversionError(
1036 "Image has no data".into(),
1037 ));
1038 }
1039 };
1040 Ok(UserContent::ImageUrl {
1041 image_url: ImageUrl { url, detail },
1042 })
1043 }
1044
1045 message::UserContent::Document(message::Document {
1046 data, media_type, ..
1047 }) => match data {
1048 DocumentSourceKind::Url(url) => {
1049 let filename = media_type.as_ref().map(|mt| match mt {
1050 DocumentMediaType::PDF => "document.pdf",
1051 DocumentMediaType::TXT => "document.txt",
1052 DocumentMediaType::HTML => "document.html",
1053 DocumentMediaType::MARKDOWN => "document.md",
1054 DocumentMediaType::CSV => "document.csv",
1055 DocumentMediaType::XML => "document.xml",
1056 _ => "document",
1057 });
1058 Ok(UserContent::File {
1059 file: FileContent {
1060 filename: filename.map(String::from),
1061 file_data: Some(url),
1062 },
1063 })
1064 }
1065 DocumentSourceKind::Base64(data) => {
1066 let mime = media_type
1067 .as_ref()
1068 .map(|m| m.to_mime_type())
1069 .unwrap_or("application/pdf");
1070 let data_uri = format!("data:{mime};base64,{data}");
1071
1072 let filename = media_type.as_ref().map(|mt| match mt {
1073 DocumentMediaType::PDF => "document.pdf",
1074 DocumentMediaType::TXT => "document.txt",
1075 DocumentMediaType::HTML => "document.html",
1076 DocumentMediaType::MARKDOWN => "document.md",
1077 DocumentMediaType::CSV => "document.csv",
1078 DocumentMediaType::XML => "document.xml",
1079 _ => "document",
1080 });
1081
1082 Ok(UserContent::File {
1083 file: FileContent {
1084 filename: filename.map(String::from),
1085 file_data: Some(data_uri),
1086 },
1087 })
1088 }
1089 DocumentSourceKind::String(text) => Ok(UserContent::Text { text }),
1090 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
1091 "Raw bytes not supported for documents, encode as base64 first".into(),
1092 )),
1093 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
1094 "Document has no data".into(),
1095 )),
1096 },
1097
1098 message::UserContent::Audio(message::Audio {
1099 data, media_type, ..
1100 }) => match data {
1101 DocumentSourceKind::Base64(data) => {
1102 let format = media_type.ok_or_else(|| {
1103 message::MessageError::ConversionError(
1104 "Audio media type required for base64 encoding".into(),
1105 )
1106 })?;
1107 Ok(UserContent::InputAudio {
1108 input_audio: openai::InputAudio { data, format },
1109 })
1110 }
1111 DocumentSourceKind::Url(_) => Err(message::MessageError::ConversionError(
1112 "OpenRouter does not support audio URLs, encode as base64 first".into(),
1113 )),
1114 DocumentSourceKind::Raw(_) => Err(message::MessageError::ConversionError(
1115 "Raw bytes not supported for audio, encode as base64 first".into(),
1116 )),
1117 DocumentSourceKind::String(_) => Err(message::MessageError::ConversionError(
1118 "String source not supported for audio".into(),
1119 )),
1120 DocumentSourceKind::Unknown => Err(message::MessageError::ConversionError(
1121 "Audio has no data".into(),
1122 )),
1123 },
1124
1125 message::UserContent::Video(message::Video {
1126 data, media_type, ..
1127 }) => {
1128 let url = match data {
1129 DocumentSourceKind::Url(url) => url,
1130 DocumentSourceKind::Base64(data) => {
1131 let mime = media_type
1132 .ok_or_else(|| {
1133 message::MessageError::ConversionError(
1134 "Video media type required for base64 encoding".into(),
1135 )
1136 })?
1137 .to_mime_type();
1138 format!("data:{mime};base64,{data}")
1139 }
1140 DocumentSourceKind::Raw(_) => {
1141 return Err(message::MessageError::ConversionError(
1142 "Raw bytes not supported for video, encode as base64 first".into(),
1143 ));
1144 }
1145 DocumentSourceKind::String(_) => {
1146 return Err(message::MessageError::ConversionError(
1147 "String source not supported for video".into(),
1148 ));
1149 }
1150 DocumentSourceKind::Unknown => {
1151 return Err(message::MessageError::ConversionError(
1152 "Video has no data".into(),
1153 ));
1154 }
1155 };
1156 Ok(UserContent::VideoUrl {
1157 video_url: VideoUrlContent { url },
1158 })
1159 }
1160
1161 message::UserContent::ToolResult(_) => Err(message::MessageError::ConversionError(
1162 "Tool results should be handled as separate messages".into(),
1163 )),
1164 }
1165 }
1166}
1167
1168impl TryFrom<OneOrMany<message::UserContent>> for Vec<Message> {
1169 type Error = message::MessageError;
1170
1171 fn try_from(value: OneOrMany<message::UserContent>) -> Result<Self, Self::Error> {
1172 let (tool_results, other_content): (Vec<_>, Vec<_>) = value
1173 .into_iter()
1174 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
1175
1176 if !tool_results.is_empty() {
1179 tool_results
1180 .into_iter()
1181 .map(|content| match content {
1182 message::UserContent::ToolResult(tool_result) => Ok(Message::ToolResult {
1183 tool_call_id: tool_result.id,
1184 content: tool_result
1185 .content
1186 .into_iter()
1187 .map(|c| match c {
1188 message::ToolResultContent::Text(message::Text { text }) => text,
1189 message::ToolResultContent::Image(_) => {
1190 "[Image content not supported in tool results]".to_string()
1191 }
1192 })
1193 .collect::<Vec<_>>()
1194 .join("\n"),
1195 }),
1196 _ => unreachable!(),
1197 })
1198 .collect::<Result<Vec<_>, _>>()
1199 } else {
1200 let user_content: Vec<UserContent> = other_content
1201 .into_iter()
1202 .map(|content| content.try_into())
1203 .collect::<Result<Vec<_>, _>>()?;
1204
1205 let content = OneOrMany::many(user_content)
1206 .expect("There must be content here if there were no tool result content");
1207
1208 Ok(vec![Message::User {
1209 content,
1210 name: None,
1211 }])
1212 }
1213 }
1214}
1215
/// One entry of [`CompletionResponse::choices`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Choice `index` field.
    pub index: usize,
    /// Optional `native_finish_reason` field.
    pub native_finish_reason: Option<String>,
    /// The message of this choice.
    pub message: Message,
    /// Optional `finish_reason` field.
    pub finish_reason: Option<String>,
}
1227
/// A chat message, internally tagged by its `role` (lowercase variant name).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// `role: "system"`; the role `"developer"` is also accepted when deserializing.
    #[serde(alias = "developer")]
    System {
        /// Accepts either a plain string or a list when deserializing.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<openai::SystemContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "user"`.
    User {
        /// Accepts a plain string or a list when deserializing; a single text
        /// part is serialized back as a plain string (`serialize_user_content`).
        #[serde(
            deserialize_with = "string_or_one_or_many",
            serialize_with = "serialize_user_content"
        )]
        content: OneOrMany<UserContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// `role: "assistant"`.
    Assistant {
        /// Defaults to empty when absent; deserialized through `json_utils::string_or_vec`.
        #[serde(default, deserialize_with = "json_utils::string_or_vec")]
        content: Vec<openai::AssistantContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        refusal: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        audio: Option<openai::AudioAssistant>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        /// Defaults to empty when absent; deserialized through
        /// `json_utils::null_or_vec`; omitted from the output when empty.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<openai::ToolCall>,
        /// Plain-text reasoning, if any.
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning: Option<String>,
        /// Structured reasoning blocks; omitted from the output when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        reasoning_details: Vec<ReasoningDetails>,
    },
    /// `role: "tool"`: the textual result of a tool call.
    #[serde(rename = "tool")]
    ToolResult {
        tool_call_id: String,
        content: String,
    },
}
1278
1279impl Message {
1280 pub fn system(content: &str) -> Self {
1281 Message::System {
1282 content: OneOrMany::one(content.to_owned().into()),
1283 name: None,
1284 }
1285 }
1286}
1287
/// One structured reasoning block of an assistant message.
///
/// Internally tagged by `type`. Each variant has an explicit rename
/// (`reasoning.summary`, `reasoning.encrypted`, `reasoning.text`), which
/// overrides the container-level `rename_all`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ReasoningDetails {
    /// Summary block (`type: "reasoning.summary"`).
    #[serde(rename = "reasoning.summary")]
    Summary {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        summary: String,
    },
    /// Encrypted block (`type: "reasoning.encrypted"`).
    #[serde(rename = "reasoning.encrypted")]
    Encrypted {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        data: String,
    },
    /// Text block (`type: "reasoning.text"`); `text` and `signature` are optional.
    #[serde(rename = "reasoning.text")]
    Text {
        id: Option<String>,
        format: Option<String>,
        index: Option<usize>,
        text: Option<String>,
        signature: Option<String>,
    },
}
1314
/// Shapes accepted when parsing a tool call's `additional_params` JSON.
///
/// The enum is `untagged`, so variants are tried in declaration order: a
/// full [`ReasoningDetails`] value first, then the minimal form.
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[serde(untagged)]
enum ToolCallAdditionalParams {
    /// A complete reasoning detail block.
    ReasoningDetails(ReasoningDetails),
    /// Only an optional `id` and `format`.
    Minimal {
        id: Option<String>,
        format: Option<String>,
    },
}
1324
1325impl From<openai::UserContent> for UserContent {
1327 fn from(value: openai::UserContent) -> Self {
1328 match value {
1329 openai::UserContent::Text { text } => UserContent::Text { text },
1330 openai::UserContent::Image { image_url } => UserContent::ImageUrl {
1331 image_url: ImageUrl {
1332 url: image_url.url,
1333 detail: Some(image_url.detail),
1334 },
1335 },
1336 openai::UserContent::Audio { input_audio } => UserContent::InputAudio { input_audio },
1337 }
1338 }
1339}
1340
1341impl From<openai::Message> for Message {
1342 fn from(value: openai::Message) -> Self {
1343 match value {
1344 openai::Message::System { content, name } => Self::System { content, name },
1345 openai::Message::User { content, name } => {
1346 let converted_content = content.map(UserContent::from);
1348 Self::User {
1349 content: converted_content,
1350 name,
1351 }
1352 }
1353 openai::Message::Assistant {
1354 content,
1355 refusal,
1356 audio,
1357 name,
1358 tool_calls,
1359 } => Self::Assistant {
1360 content,
1361 refusal,
1362 audio,
1363 name,
1364 tool_calls,
1365 reasoning: None,
1366 reasoning_details: Vec::new(),
1367 },
1368 openai::Message::ToolResult {
1369 tool_call_id,
1370 content,
1371 } => Self::ToolResult {
1372 tool_call_id,
1373 content: content.as_text(),
1374 },
1375 }
1376 }
1377}
1378
1379impl TryFrom<OneOrMany<message::AssistantContent>> for Vec<Message> {
1380 type Error = message::MessageError;
1381
1382 fn try_from(value: OneOrMany<message::AssistantContent>) -> Result<Self, Self::Error> {
1383 let mut text_content = Vec::new();
1384 let mut tool_calls = Vec::new();
1385 let mut reasoning = None;
1386 let mut reasoning_details = Vec::new();
1387
1388 for content in value.into_iter() {
1389 match content {
1390 message::AssistantContent::Text(text) => text_content.push(text),
1391 message::AssistantContent::ToolCall(tool_call) => {
1392 if let Some(additional_params) = &tool_call.additional_params
1398 && let Ok(additional_params) =
1399 serde_json::from_value::<ToolCallAdditionalParams>(
1400 additional_params.clone(),
1401 )
1402 {
1403 match additional_params {
1404 ToolCallAdditionalParams::ReasoningDetails(full) => {
1405 reasoning_details.push(full);
1406 }
1407 ToolCallAdditionalParams::Minimal { id, format } => {
1408 let id = id.or_else(|| tool_call.call_id.clone());
1409 if let Some(signature) = &tool_call.signature
1410 && let Some(id) = id
1411 {
1412 reasoning_details.push(ReasoningDetails::Encrypted {
1413 id: Some(id),
1414 format,
1415 index: None,
1416 data: signature.clone(),
1417 })
1418 }
1419 }
1420 }
1421 } else if let Some(signature) = &tool_call.signature {
1422 reasoning_details.push(ReasoningDetails::Encrypted {
1423 id: tool_call.call_id.clone(),
1424 format: None,
1425 index: None,
1426 data: signature.clone(),
1427 });
1428 }
1429 tool_calls.push(tool_call.into())
1430 }
1431 message::AssistantContent::Reasoning(r) => {
1432 if r.content.is_empty() {
1433 let display = r.display_text();
1434 if !display.is_empty() {
1435 reasoning = Some(display);
1436 }
1437 } else {
1438 for reasoning_block in &r.content {
1439 let index = Some(reasoning_details.len());
1440 match reasoning_block {
1441 message::ReasoningContent::Text { text, signature } => {
1442 reasoning_details.push(ReasoningDetails::Text {
1443 id: r.id.clone(),
1444 format: None,
1445 index,
1446 text: Some(text.clone()),
1447 signature: signature.clone(),
1448 });
1449 }
1450 message::ReasoningContent::Summary(summary) => {
1451 reasoning_details.push(ReasoningDetails::Summary {
1452 id: r.id.clone(),
1453 format: None,
1454 index,
1455 summary: summary.clone(),
1456 });
1457 }
1458 message::ReasoningContent::Encrypted(data)
1459 | message::ReasoningContent::Redacted { data } => {
1460 reasoning_details.push(ReasoningDetails::Encrypted {
1461 id: r.id.clone(),
1462 format: None,
1463 index,
1464 data: data.clone(),
1465 });
1466 }
1467 }
1468 }
1469 }
1470 }
1471 message::AssistantContent::Image(_) => {
1472 return Err(Self::Error::ConversionError(
1473 "OpenRouter currently doesn't support images.".into(),
1474 ));
1475 }
1476 }
1477 }
1478
1479 Ok(vec![Message::Assistant {
1482 content: text_content
1483 .into_iter()
1484 .map(|content| content.text.into())
1485 .collect::<Vec<_>>(),
1486 refusal: None,
1487 audio: None,
1488 name: None,
1489 tool_calls,
1490 reasoning,
1491 reasoning_details,
1492 }])
1493 }
1494}
1495
1496impl TryFrom<message::Message> for Vec<Message> {
1499 type Error = message::MessageError;
1500
1501 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
1502 match message {
1503 message::Message::User { content } => {
1504 content.try_into()
1507 }
1508 message::Message::Assistant { content, .. } => content.try_into(),
1509 }
1510 }
1511}
1512
/// Tool-choice setting, convertible from `crate::message::ToolChoice`.
// NOTE(review): with `untagged`, serde serializes unit variants as a unit
// value (JSON `null`), so `None`, `Auto` and `Required` likely all serialize
// identically and `rename_all` has no visible effect — confirm whether
// string values such as "auto" were intended.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged, rename_all = "snake_case")]
pub enum ToolChoice {
    None,
    Auto,
    Required,
    /// One entry per specifically requested function.
    Function(Vec<ToolChoiceFunctionKind>),
}
1521
1522impl TryFrom<crate::message::ToolChoice> for ToolChoice {
1523 type Error = CompletionError;
1524
1525 fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
1526 let res = match value {
1527 crate::message::ToolChoice::None => Self::None,
1528 crate::message::ToolChoice::Auto => Self::Auto,
1529 crate::message::ToolChoice::Required => Self::Required,
1530 crate::message::ToolChoice::Specific { function_names } => {
1531 let vec: Vec<ToolChoiceFunctionKind> = function_names
1532 .into_iter()
1533 .map(|name| ToolChoiceFunctionKind::Function { name })
1534 .collect();
1535
1536 Self::Function(vec)
1537 }
1538 };
1539
1540 Ok(res)
1541 }
1542}
1543
1544#[derive(Debug, Serialize, Deserialize)]
1545#[serde(tag = "type", content = "function")]
1546pub enum ToolChoiceFunctionKind {
1547 Function { name: String },
1548}
1549
1550#[derive(Debug, Serialize, Deserialize)]
1551pub(super) struct OpenrouterCompletionRequest {
1552 model: String,
1553 pub messages: Vec<Message>,
1554 #[serde(skip_serializing_if = "Option::is_none")]
1555 temperature: Option<f64>,
1556 #[serde(skip_serializing_if = "Vec::is_empty")]
1557 tools: Vec<crate::providers::openai::completion::ToolDefinition>,
1558 #[serde(skip_serializing_if = "Option::is_none")]
1559 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
1560 #[serde(flatten, skip_serializing_if = "Option::is_none")]
1561 pub additional_params: Option<serde_json::Value>,
1562}
1563
/// Inputs for building an `OpenrouterCompletionRequest`.
pub struct OpenRouterRequestParams<'a> {
    /// Default model id, used only when `request.model` is `None`.
    pub model: &'a str,
    /// The provider-agnostic completion request to convert.
    pub request: CompletionRequest,
    /// When `true`, every converted tool definition is passed through `with_strict()`.
    pub strict_tools: bool,
}
1570
1571impl TryFrom<OpenRouterRequestParams<'_>> for OpenrouterCompletionRequest {
1572 type Error = CompletionError;
1573
1574 fn try_from(params: OpenRouterRequestParams) -> Result<Self, Self::Error> {
1575 let OpenRouterRequestParams {
1576 model,
1577 request: req,
1578 strict_tools,
1579 } = params;
1580 let model = req.model.clone().unwrap_or_else(|| model.to_string());
1581
1582 if req.output_schema.is_some() {
1583 tracing::warn!("Structured outputs currently not supported for OpenRouter");
1584 }
1585
1586 let mut full_history: Vec<Message> = match &req.preamble {
1587 Some(preamble) => vec![Message::system(preamble)],
1588 None => vec![],
1589 };
1590 if let Some(docs) = req.normalized_documents() {
1591 let docs: Vec<Message> = docs.try_into()?;
1592 full_history.extend(docs);
1593 }
1594
1595 let chat_history: Vec<Message> = req
1596 .chat_history
1597 .clone()
1598 .into_iter()
1599 .map(|message| message.try_into())
1600 .collect::<Result<Vec<Vec<Message>>, _>>()?
1601 .into_iter()
1602 .flatten()
1603 .collect();
1604
1605 full_history.extend(chat_history);
1606
1607 let tool_choice = req
1608 .tool_choice
1609 .clone()
1610 .map(crate::providers::openai::completion::ToolChoice::try_from)
1611 .transpose()?;
1612
1613 let tools: Vec<crate::providers::openai::completion::ToolDefinition> = req
1614 .tools
1615 .clone()
1616 .into_iter()
1617 .map(|tool| {
1618 let def = crate::providers::openai::completion::ToolDefinition::from(tool);
1619 if strict_tools { def.with_strict() } else { def }
1620 })
1621 .collect();
1622
1623 Ok(Self {
1624 model,
1625 messages: full_history,
1626 temperature: req.temperature,
1627 tools,
1628 tool_choice,
1629 additional_params: req.additional_params,
1630 })
1631 }
1632}
1633
1634impl TryFrom<(&str, CompletionRequest)> for OpenrouterCompletionRequest {
1635 type Error = CompletionError;
1636
1637 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
1638 let model = req.model.clone().unwrap_or_else(|| model.to_string());
1639 OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
1640 model: &model,
1641 request: req,
1642 strict_tools: false,
1643 })
1644 }
1645}
1646
/// Completion model handle for OpenRouter, generic over the HTTP client type `T`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Default model id; a model set on an individual request takes precedence.
    pub model: String,
    /// Whether tool definitions are converted with `with_strict()` when building requests.
    pub strict_tools: bool,
}
1655
1656impl<T> CompletionModel<T> {
1657 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
1658 Self {
1659 client,
1660 model: model.into(),
1661 strict_tools: false,
1662 }
1663 }
1664
1665 pub fn with_strict_tools(mut self) -> Self {
1674 self.strict_tools = true;
1675 self
1676 }
1677}
1678
1679impl<T> completion::CompletionModel for CompletionModel<T>
1680where
1681 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
1682{
1683 type Response = CompletionResponse;
1684 type StreamingResponse = StreamingCompletionResponse;
1685
1686 type Client = Client<T>;
1687
1688 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
1689 Self::new(client.clone(), model)
1690 }
1691
1692 async fn completion(
1693 &self,
1694 completion_request: CompletionRequest,
1695 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
1696 let request_model = completion_request
1697 .model
1698 .clone()
1699 .unwrap_or_else(|| self.model.clone());
1700 let preamble = completion_request.preamble.clone();
1701 let request = OpenrouterCompletionRequest::try_from(OpenRouterRequestParams {
1702 model: request_model.as_ref(),
1703 request: completion_request,
1704 strict_tools: self.strict_tools,
1705 })?;
1706
1707 if enabled!(Level::TRACE) {
1708 tracing::trace!(
1709 target: "rig::completions",
1710 "OpenRouter completion request: {}",
1711 serde_json::to_string_pretty(&request)?
1712 );
1713 }
1714
1715 let span = if tracing::Span::current().is_disabled() {
1716 info_span!(
1717 target: "rig::completions",
1718 "chat",
1719 gen_ai.operation.name = "chat",
1720 gen_ai.provider.name = "openrouter",
1721 gen_ai.request.model = &request_model,
1722 gen_ai.system_instructions = preamble,
1723 gen_ai.response.id = tracing::field::Empty,
1724 gen_ai.response.model = tracing::field::Empty,
1725 gen_ai.usage.output_tokens = tracing::field::Empty,
1726 gen_ai.usage.input_tokens = tracing::field::Empty,
1727 )
1728 } else {
1729 tracing::Span::current()
1730 };
1731
1732 let body = serde_json::to_vec(&request)?;
1733
1734 let req = self
1735 .client
1736 .post("/chat/completions")?
1737 .body(body)
1738 .map_err(|x| CompletionError::HttpError(x.into()))?;
1739
1740 async move {
1741 let response = self.client.send::<_, Bytes>(req).await?;
1742 let status = response.status();
1743 let response_body = response.into_body().into_future().await?.to_vec();
1744
1745 if status.is_success() {
1746 let parsed: ApiResponse<CompletionResponse> =
1747 serde_json::from_slice(&response_body).map_err(|e| {
1748 CompletionError::ResponseError(format!(
1749 "Failed to parse OpenRouter completion response: {}, response body: {}",
1750 e,
1751 String::from_utf8_lossy(&response_body)
1752 ))
1753 })?;
1754 match parsed {
1755 ApiResponse::Ok(response) => {
1756 let span = tracing::Span::current();
1757 span.record_token_usage(&response.usage);
1758 span.record("gen_ai.response.id", &response.id);
1759 span.record("gen_ai.response.model_name", &response.model);
1760
1761 tracing::debug!(target: "rig::completions",
1762 "OpenRouter response: {response:?}");
1763 response.try_into()
1764 }
1765 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
1766 }
1767 } else {
1768 Err(CompletionError::ProviderError(
1769 String::from_utf8_lossy(&response_body).to_string(),
1770 ))
1771 }
1772 }
1773 .instrument(span)
1774 .await
1775 }
1776
1777 async fn stream(
1778 &self,
1779 completion_request: CompletionRequest,
1780 ) -> Result<
1781 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
1782 CompletionError,
1783 > {
1784 CompletionModel::stream(self, completion_request).await
1785 }
1786}
1787
1788#[cfg(test)]
1789mod tests {
1790 use super::*;
1791 use serde_json::json;
1792
1793 #[test]
1794 fn test_openrouter_request_uses_request_model_override() {
1795 let request = CompletionRequest {
1796 model: Some("google/gemini-2.5-flash".to_string()),
1797 preamble: None,
1798 chat_history: crate::OneOrMany::one("Hello".into()),
1799 documents: vec![],
1800 tools: vec![],
1801 temperature: None,
1802 max_tokens: None,
1803 tool_choice: None,
1804 additional_params: None,
1805 output_schema: None,
1806 };
1807
1808 let openrouter_request =
1809 OpenrouterCompletionRequest::try_from(("openai/gpt-4o-mini", request))
1810 .expect("request conversion should succeed");
1811 let serialized =
1812 serde_json::to_value(openrouter_request).expect("serialization should succeed");
1813
1814 assert_eq!(serialized["model"], "google/gemini-2.5-flash");
1815 }
1816
1817 #[test]
1818 fn test_openrouter_request_uses_default_model_when_override_unset() {
1819 let request = CompletionRequest {
1820 model: None,
1821 preamble: None,
1822 chat_history: crate::OneOrMany::one("Hello".into()),
1823 documents: vec![],
1824 tools: vec![],
1825 temperature: None,
1826 max_tokens: None,
1827 tool_choice: None,
1828 additional_params: None,
1829 output_schema: None,
1830 };
1831
1832 let openrouter_request =
1833 OpenrouterCompletionRequest::try_from(("openai/gpt-4o-mini", request))
1834 .expect("request conversion should succeed");
1835 let serialized =
1836 serde_json::to_value(openrouter_request).expect("serialization should succeed");
1837
1838 assert_eq!(serialized["model"], "openai/gpt-4o-mini");
1839 }
1840
1841 #[test]
1842 fn test_completion_response_deserialization_gemini_flash() {
1843 let json = json!({
1845 "id": "gen-AAAAAAAAAA-AAAAAAAAAAAAAAAAAAAA",
1846 "provider": "Google",
1847 "model": "google/gemini-2.5-flash",
1848 "object": "chat.completion",
1849 "created": 1765971703u64,
1850 "choices": [{
1851 "logprobs": null,
1852 "finish_reason": "stop",
1853 "native_finish_reason": "STOP",
1854 "index": 0,
1855 "message": {
1856 "role": "assistant",
1857 "content": "CONTENT",
1858 "refusal": null,
1859 "reasoning": null
1860 }
1861 }],
1862 "usage": {
1863 "prompt_tokens": 669,
1864 "completion_tokens": 5,
1865 "total_tokens": 674
1866 }
1867 });
1868
1869 let response: CompletionResponse = serde_json::from_value(json).unwrap();
1870 assert_eq!(response.id, "gen-AAAAAAAAAA-AAAAAAAAAAAAAAAAAAAA");
1871 assert_eq!(response.model, "google/gemini-2.5-flash");
1872 assert_eq!(response.choices.len(), 1);
1873 assert_eq!(response.choices[0].finish_reason, Some("stop".to_string()));
1874 }
1875
1876 #[test]
1877 fn test_message_assistant_without_reasoning_details() {
1878 let json = json!({
1880 "role": "assistant",
1881 "content": "Hello world",
1882 "refusal": null,
1883 "reasoning": null
1884 });
1885
1886 let message: Message = serde_json::from_value(json).unwrap();
1887 match message {
1888 Message::Assistant {
1889 content,
1890 reasoning_details,
1891 ..
1892 } => {
1893 assert_eq!(content.len(), 1);
1894 assert!(reasoning_details.is_empty());
1895 }
1896 _ => panic!("Expected Assistant message"),
1897 }
1898 }
1899
1900 #[test]
1901 fn test_data_collection_serialization() {
1902 assert_eq!(
1903 serde_json::to_string(&DataCollection::Allow).unwrap(),
1904 r#""allow""#
1905 );
1906 assert_eq!(
1907 serde_json::to_string(&DataCollection::Deny).unwrap(),
1908 r#""deny""#
1909 );
1910 }
1911
1912 #[test]
1913 fn test_data_collection_default() {
1914 assert_eq!(DataCollection::default(), DataCollection::Allow);
1915 }
1916
1917 #[test]
1918 fn test_quantization_serialization() {
1919 assert_eq!(
1920 serde_json::to_string(&Quantization::Int4).unwrap(),
1921 r#""int4""#
1922 );
1923 assert_eq!(
1924 serde_json::to_string(&Quantization::Int8).unwrap(),
1925 r#""int8""#
1926 );
1927 assert_eq!(
1928 serde_json::to_string(&Quantization::Fp16).unwrap(),
1929 r#""fp16""#
1930 );
1931 assert_eq!(
1932 serde_json::to_string(&Quantization::Bf16).unwrap(),
1933 r#""bf16""#
1934 );
1935 assert_eq!(
1936 serde_json::to_string(&Quantization::Fp32).unwrap(),
1937 r#""fp32""#
1938 );
1939 assert_eq!(
1940 serde_json::to_string(&Quantization::Fp8).unwrap(),
1941 r#""fp8""#
1942 );
1943 assert_eq!(
1944 serde_json::to_string(&Quantization::Unknown).unwrap(),
1945 r#""unknown""#
1946 );
1947 }
1948
1949 #[test]
1950 fn test_provider_sort_strategy_serialization() {
1951 assert_eq!(
1952 serde_json::to_string(&ProviderSortStrategy::Price).unwrap(),
1953 r#""price""#
1954 );
1955 assert_eq!(
1956 serde_json::to_string(&ProviderSortStrategy::Throughput).unwrap(),
1957 r#""throughput""#
1958 );
1959 assert_eq!(
1960 serde_json::to_string(&ProviderSortStrategy::Latency).unwrap(),
1961 r#""latency""#
1962 );
1963 }
1964
1965 #[test]
1966 fn test_sort_partition_serialization() {
1967 assert_eq!(
1968 serde_json::to_string(&SortPartition::Model).unwrap(),
1969 r#""model""#
1970 );
1971 assert_eq!(
1972 serde_json::to_string(&SortPartition::None).unwrap(),
1973 r#""none""#
1974 );
1975 }
1976
1977 #[test]
1978 fn test_provider_sort_simple() {
1979 let sort = ProviderSort::Simple(ProviderSortStrategy::Latency);
1980 let json = serde_json::to_value(&sort).unwrap();
1981 assert_eq!(json, "latency");
1982 }
1983
1984 #[test]
1985 fn test_provider_sort_complex() {
1986 let sort = ProviderSort::Complex(
1987 ProviderSortConfig::new(ProviderSortStrategy::Price).partition(SortPartition::None),
1988 );
1989 let json = serde_json::to_value(&sort).unwrap();
1990 assert_eq!(json["by"], "price");
1991 assert_eq!(json["partition"], "none");
1992 }
1993
1994 #[test]
1995 fn test_provider_sort_complex_without_partition() {
1996 let sort = ProviderSort::Complex(ProviderSortConfig::new(ProviderSortStrategy::Throughput));
1997 let json = serde_json::to_value(&sort).unwrap();
1998 assert_eq!(json["by"], "throughput");
1999 assert!(json.get("partition").is_none());
2000 }
2001
2002 #[test]
2003 fn test_provider_sort_from_strategy() {
2004 let sort: ProviderSort = ProviderSortStrategy::Price.into();
2005 assert_eq!(sort, ProviderSort::Simple(ProviderSortStrategy::Price));
2006 }
2007
2008 #[test]
2009 fn test_provider_sort_from_config() {
2010 let config = ProviderSortConfig::new(ProviderSortStrategy::Latency);
2011 let sort: ProviderSort = config.into();
2012 match sort {
2013 ProviderSort::Complex(c) => assert_eq!(c.by, ProviderSortStrategy::Latency),
2014 _ => panic!("Expected Complex variant"),
2015 }
2016 }
2017
2018 #[test]
2019 fn test_percentile_thresholds_builder() {
2020 let thresholds = PercentileThresholds::new()
2021 .p50(10.0)
2022 .p75(25.0)
2023 .p90(50.0)
2024 .p99(100.0);
2025
2026 assert_eq!(thresholds.p50, Some(10.0));
2027 assert_eq!(thresholds.p75, Some(25.0));
2028 assert_eq!(thresholds.p90, Some(50.0));
2029 assert_eq!(thresholds.p99, Some(100.0));
2030 }
2031
2032 #[test]
2033 fn test_percentile_thresholds_default() {
2034 let thresholds = PercentileThresholds::default();
2035 assert_eq!(thresholds.p50, None);
2036 assert_eq!(thresholds.p75, None);
2037 assert_eq!(thresholds.p90, None);
2038 assert_eq!(thresholds.p99, None);
2039 }
2040
2041 #[test]
2042 fn test_throughput_threshold_simple() {
2043 let threshold = ThroughputThreshold::Simple(50.0);
2044 let json = serde_json::to_value(&threshold).unwrap();
2045 assert_eq!(json, 50.0);
2046 }
2047
2048 #[test]
2049 fn test_throughput_threshold_percentile() {
2050 let threshold = ThroughputThreshold::Percentile(PercentileThresholds::new().p90(50.0));
2051 let json = serde_json::to_value(&threshold).unwrap();
2052 assert_eq!(json["p90"], 50.0);
2053 }
2054
2055 #[test]
2056 fn test_latency_threshold_simple() {
2057 let threshold = LatencyThreshold::Simple(0.5);
2058 let json = serde_json::to_value(&threshold).unwrap();
2059 assert_eq!(json, 0.5);
2060 }
2061
2062 #[test]
2063 fn test_latency_threshold_percentile() {
2064 let threshold = LatencyThreshold::Percentile(PercentileThresholds::new().p50(0.1).p99(1.0));
2065 let json = serde_json::to_value(&threshold).unwrap();
2066 assert_eq!(json["p50"], 0.1);
2067 assert_eq!(json["p99"], 1.0);
2068 }
2069
2070 #[test]
2071 fn test_max_price_builder() {
2072 let price = MaxPrice::new().prompt(0.001).completion(0.002);
2073
2074 assert_eq!(price.prompt, Some(0.001));
2075 assert_eq!(price.completion, Some(0.002));
2076 assert_eq!(price.request, None);
2077 assert_eq!(price.image, None);
2078 }
2079
2080 #[test]
2081 fn test_max_price_all_fields() {
2082 let price = MaxPrice::new()
2083 .prompt(0.001)
2084 .completion(0.002)
2085 .request(0.01)
2086 .image(0.05);
2087
2088 let json = serde_json::to_value(&price).unwrap();
2089 assert_eq!(json["prompt"], 0.001);
2090 assert_eq!(json["completion"], 0.002);
2091 assert_eq!(json["request"], 0.01);
2092 assert_eq!(json["image"], 0.05);
2093 }
2094
2095 #[test]
2096 fn test_max_price_default() {
2097 let price = MaxPrice::default();
2098 assert_eq!(price.prompt, None);
2099 assert_eq!(price.completion, None);
2100 assert_eq!(price.request, None);
2101 assert_eq!(price.image, None);
2102 }
2103
2104 #[test]
2105 fn test_provider_preferences_default() {
2106 let prefs = ProviderPreferences::default();
2107 assert!(prefs.order.is_none());
2108 assert!(prefs.only.is_none());
2109 assert!(prefs.ignore.is_none());
2110 assert!(prefs.allow_fallbacks.is_none());
2111 assert!(prefs.require_parameters.is_none());
2112 assert!(prefs.data_collection.is_none());
2113 assert!(prefs.zdr.is_none());
2114 assert!(prefs.sort.is_none());
2115 assert!(prefs.preferred_min_throughput.is_none());
2116 assert!(prefs.preferred_max_latency.is_none());
2117 assert!(prefs.max_price.is_none());
2118 assert!(prefs.quantizations.is_none());
2119 }
2120
2121 #[test]
2122 fn test_provider_preferences_order_with_fallbacks() {
2123 let prefs = ProviderPreferences::new()
2124 .order(["anthropic", "openai"])
2125 .allow_fallbacks(true);
2126
2127 let json = prefs.to_json();
2128 let provider = &json["provider"];
2129
2130 assert_eq!(provider["order"], json!(["anthropic", "openai"]));
2131 assert_eq!(provider["allow_fallbacks"], true);
2132 }
2133
2134 #[test]
2135 fn test_provider_preferences_only_allowlist() {
2136 let prefs = ProviderPreferences::new()
2137 .only(["azure", "together"])
2138 .allow_fallbacks(false);
2139
2140 let json = prefs.to_json();
2141 let provider = &json["provider"];
2142
2143 assert_eq!(provider["only"], json!(["azure", "together"]));
2144 assert_eq!(provider["allow_fallbacks"], false);
2145 }
2146
2147 #[test]
2148 fn test_provider_preferences_ignore() {
2149 let prefs = ProviderPreferences::new().ignore(["deepinfra"]);
2150
2151 let json = prefs.to_json();
2152 let provider = &json["provider"];
2153
2154 assert_eq!(provider["ignore"], json!(["deepinfra"]));
2155 }
2156
2157 #[test]
2158 fn test_provider_preferences_sort_latency() {
2159 let prefs = ProviderPreferences::new().sort(ProviderSortStrategy::Latency);
2160
2161 let json = prefs.to_json();
2162 let provider = &json["provider"];
2163
2164 assert_eq!(provider["sort"], "latency");
2165 }
2166
2167 #[test]
2168 fn test_provider_preferences_price_with_throughput() {
2169 let prefs = ProviderPreferences::new()
2170 .sort(ProviderSortStrategy::Price)
2171 .preferred_min_throughput(ThroughputThreshold::Percentile(
2172 PercentileThresholds::new().p90(50.0),
2173 ));
2174
2175 let json = prefs.to_json();
2176 let provider = &json["provider"];
2177
2178 assert_eq!(provider["sort"], "price");
2179 assert_eq!(provider["preferred_min_throughput"]["p90"], 50.0);
2180 }
2181
2182 #[test]
2183 fn test_provider_preferences_require_parameters() {
2184 let prefs = ProviderPreferences::new().require_parameters(true);
2185
2186 let json = prefs.to_json();
2187 let provider = &json["provider"];
2188
2189 assert_eq!(provider["require_parameters"], true);
2190 }
2191
2192 #[test]
2193 fn test_provider_preferences_data_policy_and_zdr() {
2194 let prefs = ProviderPreferences::new()
2195 .data_collection(DataCollection::Deny)
2196 .zdr(true);
2197
2198 let json = prefs.to_json();
2199 let provider = &json["provider"];
2200
2201 assert_eq!(provider["data_collection"], "deny");
2202 assert_eq!(provider["zdr"], true);
2203 }
2204
2205 #[test]
2206 fn test_provider_preferences_quantizations() {
2207 let prefs =
2208 ProviderPreferences::new().quantizations([Quantization::Int8, Quantization::Fp16]);
2209
2210 let json = prefs.to_json();
2211 let provider = &json["provider"];
2212
2213 assert_eq!(provider["quantizations"], json!(["int8", "fp16"]));
2214 }
2215
2216 #[test]
2217 fn test_provider_preferences_convenience_methods() {
2218 let prefs = ProviderPreferences::new().zero_data_retention().fastest();
2219
2220 assert_eq!(prefs.zdr, Some(true));
2221 assert_eq!(
2222 prefs.sort,
2223 Some(ProviderSort::Simple(ProviderSortStrategy::Throughput))
2224 );
2225
2226 let prefs2 = ProviderPreferences::new().cheapest();
2227 assert_eq!(
2228 prefs2.sort,
2229 Some(ProviderSort::Simple(ProviderSortStrategy::Price))
2230 );
2231
2232 let prefs3 = ProviderPreferences::new().lowest_latency();
2233 assert_eq!(
2234 prefs3.sort,
2235 Some(ProviderSort::Simple(ProviderSortStrategy::Latency))
2236 );
2237 }
2238
2239 #[test]
2240 fn test_provider_preferences_serialization_skips_none() {
2241 let prefs = ProviderPreferences::new().sort(ProviderSortStrategy::Price);
2242
2243 let json = serde_json::to_value(&prefs).unwrap();
2244
2245 assert_eq!(json["sort"], "price");
2246 assert!(json.get("order").is_none());
2247 assert!(json.get("only").is_none());
2248 assert!(json.get("ignore").is_none());
2249 assert!(json.get("zdr").is_none());
2250 }
2251
2252 #[test]
2253 fn test_provider_preferences_deserialization() {
2254 let json = json!({
2255 "order": ["anthropic", "openai"],
2256 "sort": "throughput",
2257 "data_collection": "deny",
2258 "zdr": true,
2259 "quantizations": ["int8", "fp16"]
2260 });
2261
2262 let prefs: ProviderPreferences = serde_json::from_value(json).unwrap();
2263
2264 assert_eq!(
2265 prefs.order,
2266 Some(vec!["anthropic".to_string(), "openai".to_string()])
2267 );
2268 assert_eq!(
2269 prefs.sort,
2270 Some(ProviderSort::Simple(ProviderSortStrategy::Throughput))
2271 );
2272 assert_eq!(prefs.data_collection, Some(DataCollection::Deny));
2273 assert_eq!(prefs.zdr, Some(true));
2274 assert_eq!(
2275 prefs.quantizations,
2276 Some(vec![Quantization::Int8, Quantization::Fp16])
2277 );
2278 }
2279
2280 #[test]
2281 fn test_provider_preferences_deserialization_complex_sort() {
2282 let json = json!({
2283 "sort": {
2284 "by": "latency",
2285 "partition": "model"
2286 }
2287 });
2288
2289 let prefs: ProviderPreferences = serde_json::from_value(json).unwrap();
2290
2291 match prefs.sort {
2292 Some(ProviderSort::Complex(config)) => {
2293 assert_eq!(config.by, ProviderSortStrategy::Latency);
2294 assert_eq!(config.partition, Some(SortPartition::Model));
2295 }
2296 _ => panic!("Expected Complex sort variant"),
2297 }
2298 }
2299
2300 #[test]
2301 fn test_provider_preferences_full_integration() {
2302 let prefs = ProviderPreferences::new()
2303 .order(["anthropic", "openai"])
2304 .only(["anthropic", "openai", "google"])
2305 .sort(ProviderSortStrategy::Throughput)
2306 .data_collection(DataCollection::Deny)
2307 .zdr(true)
2308 .quantizations([Quantization::Int8])
2309 .allow_fallbacks(false);
2310
2311 let json = prefs.to_json();
2312
2313 assert!(json.get("provider").is_some());
2314 let provider = &json["provider"];
2315 assert_eq!(provider["order"], json!(["anthropic", "openai"]));
2316 assert_eq!(provider["only"], json!(["anthropic", "openai", "google"]));
2317 assert_eq!(provider["sort"], "throughput");
2318 assert_eq!(provider["data_collection"], "deny");
2319 assert_eq!(provider["zdr"], true);
2320 assert_eq!(provider["quantizations"], json!(["int8"]));
2321 assert_eq!(provider["allow_fallbacks"], false);
2322 }
2323
2324 #[test]
2325 fn test_provider_preferences_max_price() {
2326 let prefs =
2327 ProviderPreferences::new().max_price(MaxPrice::new().prompt(0.001).completion(0.002));
2328
2329 let json = prefs.to_json();
2330 let provider = &json["provider"];
2331
2332 assert_eq!(provider["max_price"]["prompt"], 0.001);
2333 assert_eq!(provider["max_price"]["completion"], 0.002);
2334 }
2335
2336 #[test]
2337 fn test_provider_preferences_preferred_max_latency() {
2338 let prefs = ProviderPreferences::new().preferred_max_latency(LatencyThreshold::Simple(0.5));
2339
2340 let json = prefs.to_json();
2341 let provider = &json["provider"];
2342
2343 assert_eq!(provider["preferred_max_latency"], 0.5);
2344 }
2345
2346 #[test]
2347 fn test_provider_preferences_empty_arrays() {
2348 let prefs = ProviderPreferences::new()
2349 .order(Vec::<String>::new())
2350 .quantizations(Vec::<Quantization>::new());
2351
2352 let json = prefs.to_json();
2353 let provider = &json["provider"];
2354
2355 assert_eq!(provider["order"], json!([]));
2356 assert_eq!(provider["quantizations"], json!([]));
2357 }
2358
2359 #[test]
2364 fn test_user_content_text_serialization() {
2365 let content = UserContent::text("Hello, world!");
2366 let json = serde_json::to_value(&content).unwrap();
2367
2368 assert_eq!(json["type"], "text");
2369 assert_eq!(json["text"], "Hello, world!");
2370 }
2371
2372 #[test]
2373 fn test_user_content_image_url_serialization() {
2374 let content = UserContent::image_url("https://example.com/image.png");
2375 let json = serde_json::to_value(&content).unwrap();
2376
2377 assert_eq!(json["type"], "image_url");
2378 assert_eq!(json["image_url"]["url"], "https://example.com/image.png");
2379 assert!(json["image_url"].get("detail").is_none());
2380 }
2381
2382 #[test]
2383 fn test_user_content_image_url_with_detail_serialization() {
2384 let content =
2385 UserContent::image_url_with_detail("https://example.com/image.png", ImageDetail::High);
2386 let json = serde_json::to_value(&content).unwrap();
2387
2388 assert_eq!(json["type"], "image_url");
2389 assert_eq!(json["image_url"]["url"], "https://example.com/image.png");
2390 assert_eq!(json["image_url"]["detail"], "high");
2391 }
2392
2393 #[test]
2394 fn test_user_content_image_base64_serialization() {
2395 let content = UserContent::image_base64("SGVsbG8=", "image/png", Some(ImageDetail::Low));
2396 let json = serde_json::to_value(&content).unwrap();
2397
2398 assert_eq!(json["type"], "image_url");
2399 assert_eq!(json["image_url"]["url"], "data:image/png;base64,SGVsbG8=");
2400 assert_eq!(json["image_url"]["detail"], "low");
2401 }
2402
2403 #[test]
2404 fn test_user_content_file_url_serialization() {
2405 let content = UserContent::file_url(
2406 "https://example.com/doc.pdf",
2407 Some("document.pdf".to_string()),
2408 );
2409 let json = serde_json::to_value(&content).unwrap();
2410
2411 assert_eq!(json["type"], "file");
2412 assert_eq!(json["file"]["file_data"], "https://example.com/doc.pdf");
2413 assert_eq!(json["file"]["filename"], "document.pdf");
2414 }
2415
2416 #[test]
2417 fn test_user_content_file_base64_serialization() {
2418 let content = UserContent::file_base64(
2419 "JVBERi0xLjQ=",
2420 "application/pdf",
2421 Some("report.pdf".to_string()),
2422 );
2423 let json = serde_json::to_value(&content).unwrap();
2424
2425 assert_eq!(json["type"], "file");
2426 assert_eq!(
2427 json["file"]["file_data"],
2428 "data:application/pdf;base64,JVBERi0xLjQ="
2429 );
2430 assert_eq!(json["file"]["filename"], "report.pdf");
2431 }
2432
2433 #[test]
2434 fn test_user_content_text_deserialization() {
2435 let json = json!({
2436 "type": "text",
2437 "text": "Hello!"
2438 });
2439
2440 let content: UserContent = serde_json::from_value(json).unwrap();
2441 assert_eq!(
2442 content,
2443 UserContent::Text {
2444 text: "Hello!".to_string()
2445 }
2446 );
2447 }
2448
2449 #[test]
2450 fn test_user_content_image_url_deserialization() {
2451 let json = json!({
2452 "type": "image_url",
2453 "image_url": {
2454 "url": "https://example.com/img.jpg",
2455 "detail": "high"
2456 }
2457 });
2458
2459 let content: UserContent = serde_json::from_value(json).unwrap();
2460 match content {
2461 UserContent::ImageUrl { image_url } => {
2462 assert_eq!(image_url.url, "https://example.com/img.jpg");
2463 assert_eq!(image_url.detail, Some(ImageDetail::High));
2464 }
2465 _ => panic!("Expected ImageUrl variant"),
2466 }
2467 }
2468
2469 #[test]
2470 fn test_user_content_file_deserialization() {
2471 let json = json!({
2472 "type": "file",
2473 "file": {
2474 "filename": "doc.pdf",
2475 "file_data": "https://example.com/doc.pdf"
2476 }
2477 });
2478
2479 let content: UserContent = serde_json::from_value(json).unwrap();
2480 match content {
2481 UserContent::File { file } => {
2482 assert_eq!(file.filename, Some("doc.pdf".to_string()));
2483 assert_eq!(
2484 file.file_data,
2485 Some("https://example.com/doc.pdf".to_string())
2486 );
2487 }
2488 _ => panic!("Expected File variant"),
2489 }
2490 }
2491
2492 #[test]
2493 fn test_message_user_with_text_serialization() {
2494 let message = Message::User {
2495 content: OneOrMany::one(UserContent::text("Hello")),
2496 name: None,
2497 };
2498 let json = serde_json::to_value(&message).unwrap();
2499
2500 assert_eq!(json["role"], "user");
2502 assert_eq!(json["content"], "Hello");
2503 }
2504
2505 #[test]
2506 fn test_message_user_with_mixed_content_serialization() {
2507 let message = Message::User {
2508 content: OneOrMany::many(vec![
2509 UserContent::text("Check this image:"),
2510 UserContent::image_url("https://example.com/img.png"),
2511 ])
2512 .unwrap(),
2513 name: None,
2514 };
2515 let json = serde_json::to_value(&message).unwrap();
2516
2517 assert_eq!(json["role"], "user");
2518 let content = json["content"].as_array().unwrap();
2519 assert_eq!(content.len(), 2);
2520 assert_eq!(content[0]["type"], "text");
2521 assert_eq!(content[1]["type"], "image_url");
2522 }
2523
2524 #[test]
2525 fn test_message_user_with_file_serialization() {
2526 let message = Message::User {
2527 content: OneOrMany::many(vec![
2528 UserContent::text("Analyze this PDF:"),
2529 UserContent::file_url(
2530 "https://example.com/doc.pdf",
2531 Some("document.pdf".to_string()),
2532 ),
2533 ])
2534 .unwrap(),
2535 name: None,
2536 };
2537 let json = serde_json::to_value(&message).unwrap();
2538
2539 assert_eq!(json["role"], "user");
2540 let content = json["content"].as_array().unwrap();
2541 assert_eq!(content.len(), 2);
2542 assert_eq!(content[0]["type"], "text");
2543 assert_eq!(content[1]["type"], "file");
2544 assert_eq!(
2545 content[1]["file"]["file_data"],
2546 "https://example.com/doc.pdf"
2547 );
2548 }
2549
2550 #[test]
2551 fn test_user_content_from_rig_text() {
2552 let rig_content = message::UserContent::Text(message::Text {
2553 text: "Hello".to_string(),
2554 });
2555 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2556
2557 assert_eq!(
2558 openrouter_content,
2559 UserContent::Text {
2560 text: "Hello".to_string()
2561 }
2562 );
2563 }
2564
2565 #[test]
2566 fn test_user_content_from_rig_image_url() {
2567 let rig_content = message::UserContent::Image(message::Image {
2568 data: DocumentSourceKind::Url("https://example.com/img.png".to_string()),
2569 media_type: Some(message::ImageMediaType::PNG),
2570 detail: Some(ImageDetail::High),
2571 additional_params: None,
2572 });
2573 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2574
2575 match openrouter_content {
2576 UserContent::ImageUrl { image_url } => {
2577 assert_eq!(image_url.url, "https://example.com/img.png");
2578 assert_eq!(image_url.detail, Some(ImageDetail::High));
2579 }
2580 _ => panic!("Expected ImageUrl variant"),
2581 }
2582 }
2583
2584 #[test]
2585 fn test_user_content_from_rig_image_base64() {
2586 let rig_content = message::UserContent::Image(message::Image {
2587 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2588 media_type: Some(message::ImageMediaType::JPEG),
2589 detail: Some(ImageDetail::Low),
2590 additional_params: None,
2591 });
2592 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2593
2594 match openrouter_content {
2595 UserContent::ImageUrl { image_url } => {
2596 assert_eq!(image_url.url, "data:image/jpeg;base64,SGVsbG8=");
2597 assert_eq!(image_url.detail, Some(ImageDetail::Low));
2598 }
2599 _ => panic!("Expected ImageUrl variant"),
2600 }
2601 }
2602
2603 #[test]
2604 fn test_user_content_from_rig_document_url() {
2605 let rig_content = message::UserContent::Document(message::Document {
2606 data: DocumentSourceKind::Url("https://example.com/doc.pdf".to_string()),
2607 media_type: Some(DocumentMediaType::PDF),
2608 additional_params: None,
2609 });
2610 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2611
2612 match openrouter_content {
2613 UserContent::File { file } => {
2614 assert_eq!(
2615 file.file_data,
2616 Some("https://example.com/doc.pdf".to_string())
2617 );
2618 assert_eq!(file.filename, Some("document.pdf".to_string()));
2619 }
2620 _ => panic!("Expected File variant"),
2621 }
2622 }
2623
2624 #[test]
2625 fn test_user_content_from_rig_document_base64() {
2626 let rig_content = message::UserContent::Document(message::Document {
2627 data: DocumentSourceKind::Base64("JVBERi0xLjQ=".to_string()),
2628 media_type: Some(DocumentMediaType::PDF),
2629 additional_params: None,
2630 });
2631 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2632
2633 match openrouter_content {
2634 UserContent::File { file } => {
2635 assert_eq!(
2636 file.file_data,
2637 Some("data:application/pdf;base64,JVBERi0xLjQ=".to_string())
2638 );
2639 assert_eq!(file.filename, Some("document.pdf".to_string()));
2640 }
2641 _ => panic!("Expected File variant"),
2642 }
2643 }
2644
2645 #[test]
2646 fn test_user_content_from_rig_document_string_becomes_text() {
2647 let rig_content = message::UserContent::Document(message::Document {
2648 data: DocumentSourceKind::String("Plain text document content".to_string()),
2649 media_type: Some(DocumentMediaType::TXT),
2650 additional_params: None,
2651 });
2652 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2653
2654 assert_eq!(
2655 openrouter_content,
2656 UserContent::Text {
2657 text: "Plain text document content".to_string()
2658 }
2659 );
2660 }
2661
2662 #[test]
2663 fn test_completion_response_with_reasoning_details_maps_to_typed_reasoning() {
2664 let json = json!({
2665 "id": "resp_123",
2666 "object": "chat.completion",
2667 "created": 1,
2668 "model": "openrouter/test-model",
2669 "choices": [{
2670 "index": 0,
2671 "finish_reason": "stop",
2672 "message": {
2673 "role": "assistant",
2674 "content": "hello",
2675 "reasoning": null,
2676 "reasoning_details": [
2677 {"type":"reasoning.summary","id":"rs_1","summary":"s1"},
2678 {"type":"reasoning.text","id":"rs_1","text":"t1","signature":"sig_1"},
2679 {"type":"reasoning.encrypted","id":"rs_1","data":"enc_1"}
2680 ]
2681 }
2682 }]
2683 });
2684
2685 let response: CompletionResponse = serde_json::from_value(json).unwrap();
2686 let converted: completion::CompletionResponse<CompletionResponse> =
2687 response.try_into().unwrap();
2688 let items: Vec<completion::AssistantContent> = converted.choice.into_iter().collect();
2689
2690 assert!(items.iter().any(|item| matches!(
2691 item,
2692 completion::AssistantContent::Reasoning(message::Reasoning { id: Some(id), content })
2693 if id == "rs_1" && content.len() == 3
2694 )));
2695 }
2696
2697 #[test]
2698 fn test_assistant_reasoning_emits_openrouter_reasoning_details() {
2699 let reasoning = message::Reasoning {
2700 id: Some("rs_2".to_string()),
2701 content: vec![
2702 message::ReasoningContent::Text {
2703 text: "step".to_string(),
2704 signature: Some("sig_step".to_string()),
2705 },
2706 message::ReasoningContent::Summary("summary".to_string()),
2707 message::ReasoningContent::Encrypted("enc_blob".to_string()),
2708 ],
2709 };
2710
2711 let messages = Vec::<Message>::try_from(OneOrMany::one(
2712 message::AssistantContent::Reasoning(reasoning),
2713 ))
2714 .unwrap();
2715 let Message::Assistant {
2716 reasoning,
2717 reasoning_details,
2718 ..
2719 } = messages.first().expect("assistant message")
2720 else {
2721 panic!("Expected assistant message");
2722 };
2723
2724 assert!(reasoning.is_none());
2725 assert_eq!(reasoning_details.len(), 3);
2726 assert!(matches!(
2727 reasoning_details.first(),
2728 Some(ReasoningDetails::Text {
2729 id: Some(id),
2730 text: Some(text),
2731 signature: Some(signature),
2732 ..
2733 }) if id == "rs_2" && text == "step" && signature == "sig_step"
2734 ));
2735 }
2736
2737 #[test]
2738 fn test_assistant_redacted_reasoning_emits_encrypted_detail_not_text() {
2739 let reasoning = message::Reasoning {
2740 id: Some("rs_redacted".to_string()),
2741 content: vec![message::ReasoningContent::Redacted {
2742 data: "opaque-redacted-data".to_string(),
2743 }],
2744 };
2745
2746 let messages = Vec::<Message>::try_from(OneOrMany::one(
2747 message::AssistantContent::Reasoning(reasoning),
2748 ))
2749 .unwrap();
2750
2751 let Message::Assistant {
2752 reasoning_details,
2753 reasoning,
2754 ..
2755 } = messages.first().expect("assistant message")
2756 else {
2757 panic!("Expected assistant message");
2758 };
2759
2760 assert!(reasoning.is_none());
2761 assert_eq!(reasoning_details.len(), 1);
2762 assert!(matches!(
2763 reasoning_details.first(),
2764 Some(ReasoningDetails::Encrypted {
2765 id: Some(id),
2766 data,
2767 ..
2768 }) if id == "rs_redacted" && data == "opaque-redacted-data"
2769 ));
2770 }
2771
2772 #[test]
2773 fn test_completion_response_reasoning_details_respects_index_ordering() {
2774 let json = json!({
2775 "id": "resp_ordering",
2776 "object": "chat.completion",
2777 "created": 1,
2778 "model": "openrouter/test-model",
2779 "choices": [{
2780 "index": 0,
2781 "finish_reason": "stop",
2782 "message": {
2783 "role": "assistant",
2784 "content": "hello",
2785 "reasoning": null,
2786 "reasoning_details": [
2787 {"type":"reasoning.summary","id":"rs_order","index":1,"summary":"second"},
2788 {"type":"reasoning.summary","id":"rs_order","index":0,"summary":"first"}
2789 ]
2790 }
2791 }]
2792 });
2793
2794 let response: CompletionResponse = serde_json::from_value(json).unwrap();
2795 let converted: completion::CompletionResponse<CompletionResponse> =
2796 response.try_into().unwrap();
2797 let items: Vec<completion::AssistantContent> = converted.choice.into_iter().collect();
2798 let reasoning_blocks: Vec<_> = items
2799 .into_iter()
2800 .filter_map(|item| match item {
2801 completion::AssistantContent::Reasoning(reasoning) => Some(reasoning),
2802 _ => None,
2803 })
2804 .collect();
2805
2806 assert_eq!(reasoning_blocks.len(), 1);
2807 assert_eq!(reasoning_blocks[0].id.as_deref(), Some("rs_order"));
2808 assert_eq!(
2809 reasoning_blocks[0].content,
2810 vec![
2811 message::ReasoningContent::Summary("first".to_string()),
2812 message::ReasoningContent::Summary("second".to_string()),
2813 ]
2814 );
2815 }
2816
2817 #[test]
2818 fn test_user_content_from_rig_image_missing_media_type_error() {
2819 let rig_content = message::UserContent::Image(message::Image {
2820 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2821 media_type: None, detail: None,
2823 additional_params: None,
2824 });
2825 let result: Result<UserContent, _> = rig_content.try_into();
2826
2827 assert!(result.is_err());
2828 let err = result.unwrap_err();
2829 assert!(err.to_string().contains("media type required"));
2830 }
2831
2832 #[test]
2833 fn test_user_content_from_rig_image_raw_bytes_error() {
2834 let rig_content = message::UserContent::Image(message::Image {
2835 data: DocumentSourceKind::Raw(vec![1, 2, 3]),
2836 media_type: Some(message::ImageMediaType::PNG),
2837 detail: None,
2838 additional_params: None,
2839 });
2840 let result: Result<UserContent, _> = rig_content.try_into();
2841
2842 assert!(result.is_err());
2843 let err = result.unwrap_err();
2844 assert!(err.to_string().contains("base64"));
2845 }
2846
2847 #[test]
2848 fn test_user_content_from_rig_video_url() {
2849 let rig_content = message::UserContent::Video(message::Video {
2850 data: DocumentSourceKind::Url("https://example.com/video.mp4".to_string()),
2851 media_type: Some(message::VideoMediaType::MP4),
2852 additional_params: None,
2853 });
2854 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2855
2856 match openrouter_content {
2857 UserContent::VideoUrl { video_url } => {
2858 assert_eq!(video_url.url, "https://example.com/video.mp4");
2859 }
2860 _ => panic!("Expected VideoUrl variant"),
2861 }
2862 }
2863
2864 #[test]
2865 fn test_user_content_from_rig_video_base64() {
2866 let rig_content = message::UserContent::Video(message::Video {
2867 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2868 media_type: Some(message::VideoMediaType::MP4),
2869 additional_params: None,
2870 });
2871 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2872
2873 match openrouter_content {
2874 UserContent::VideoUrl { video_url } => {
2875 assert_eq!(video_url.url, "data:video/mp4;base64,SGVsbG8=");
2876 }
2877 _ => panic!("Expected VideoUrl variant"),
2878 }
2879 }
2880
2881 #[test]
2882 fn test_user_content_from_rig_video_base64_missing_media_type_error() {
2883 let rig_content = message::UserContent::Video(message::Video {
2884 data: DocumentSourceKind::Base64("SGVsbG8=".to_string()),
2885 media_type: None,
2886 additional_params: None,
2887 });
2888 let result: Result<UserContent, _> = rig_content.try_into();
2889
2890 assert!(result.is_err());
2891 let err = result.unwrap_err();
2892 assert!(err.to_string().contains("media type"));
2893 }
2894
2895 #[test]
2896 fn test_user_content_from_rig_video_raw_bytes_error() {
2897 let rig_content = message::UserContent::Video(message::Video {
2898 data: DocumentSourceKind::Raw(vec![1, 2, 3]),
2899 media_type: Some(message::VideoMediaType::MP4),
2900 additional_params: None,
2901 });
2902 let result: Result<UserContent, _> = rig_content.try_into();
2903
2904 assert!(result.is_err());
2905 let err = result.unwrap_err();
2906 assert!(err.to_string().contains("base64"));
2907 }
2908
2909 #[test]
2910 fn test_user_content_from_rig_audio_base64() {
2911 let rig_content = message::UserContent::Audio(message::Audio {
2912 data: DocumentSourceKind::Base64("audiodata".to_string()),
2913 media_type: Some(message::AudioMediaType::MP3),
2914 additional_params: None,
2915 });
2916 let openrouter_content: UserContent = rig_content.try_into().unwrap();
2917
2918 match openrouter_content {
2919 UserContent::InputAudio { input_audio } => {
2920 assert_eq!(input_audio.data, "audiodata");
2921 assert_eq!(input_audio.format, message::AudioMediaType::MP3);
2922 }
2923 _ => panic!("Expected InputAudio variant"),
2924 }
2925 }
2926
2927 #[test]
2928 fn test_user_content_from_rig_audio_missing_media_type_error() {
2929 let rig_content = message::UserContent::Audio(message::Audio {
2930 data: DocumentSourceKind::Base64("audiodata".to_string()),
2931 media_type: None, additional_params: None,
2933 });
2934 let result: Result<UserContent, _> = rig_content.try_into();
2935
2936 assert!(result.is_err());
2937 let err = result.unwrap_err();
2938 assert!(err.to_string().contains("media type required"));
2939 }
2940
2941 #[test]
2942 fn test_user_content_from_rig_audio_url_error() {
2943 let rig_content = message::UserContent::Audio(message::Audio {
2944 data: DocumentSourceKind::Url("https://example.com/audio.wav".to_string()),
2945 media_type: Some(message::AudioMediaType::WAV),
2946 additional_params: None,
2947 });
2948 let result: Result<UserContent, _> = rig_content.try_into();
2949
2950 assert!(result.is_err());
2951 let err = result.unwrap_err();
2952 assert!(err.to_string().contains("base64"));
2953 }
2954
2955 #[test]
2956 fn test_user_content_from_rig_audio_raw_bytes_error() {
2957 let rig_content = message::UserContent::Audio(message::Audio {
2958 data: DocumentSourceKind::Raw(vec![1, 2, 3]),
2959 media_type: Some(message::AudioMediaType::WAV),
2960 additional_params: None,
2961 });
2962 let result: Result<UserContent, _> = rig_content.try_into();
2963
2964 assert!(result.is_err());
2965 let err = result.unwrap_err();
2966 assert!(err.to_string().contains("base64"));
2967 }
2968
2969 #[test]
2970 fn test_message_conversion_with_pdf() {
2971 let rig_message = message::Message::User {
2972 content: OneOrMany::many(vec![
2973 message::UserContent::Text(message::Text {
2974 text: "Summarize this document".to_string(),
2975 }),
2976 message::UserContent::Document(message::Document {
2977 data: DocumentSourceKind::Url("https://example.com/paper.pdf".to_string()),
2978 media_type: Some(DocumentMediaType::PDF),
2979 additional_params: None,
2980 }),
2981 ])
2982 .unwrap(),
2983 };
2984
2985 let openrouter_messages: Vec<Message> = rig_message.try_into().unwrap();
2986 assert_eq!(openrouter_messages.len(), 1);
2987
2988 match &openrouter_messages[0] {
2989 Message::User { content, .. } => {
2990 assert_eq!(content.len(), 2);
2991
2992 match content.first_ref() {
2994 UserContent::Text { text } => assert_eq!(text, "Summarize this document"),
2995 _ => panic!("Expected Text"),
2996 }
2997 }
2998 _ => panic!("Expected User message"),
2999 }
3000 }
3001
3002 #[test]
3003 fn test_user_content_from_string() {
3004 let content: UserContent = "Hello".into();
3005 assert_eq!(
3006 content,
3007 UserContent::Text {
3008 text: "Hello".to_string()
3009 }
3010 );
3011
3012 let content: UserContent = String::from("World").into();
3013 assert_eq!(
3014 content,
3015 UserContent::Text {
3016 text: "World".to_string()
3017 }
3018 );
3019 }
3020
3021 #[test]
3022 fn test_openai_user_content_conversion() {
3023 let openai_text = openai::UserContent::Text {
3025 text: "Hello".to_string(),
3026 };
3027 let converted: UserContent = openai_text.into();
3028 assert_eq!(
3029 converted,
3030 UserContent::Text {
3031 text: "Hello".to_string()
3032 }
3033 );
3034
3035 let openai_image = openai::UserContent::Image {
3036 image_url: openai::ImageUrl {
3037 url: "https://example.com/img.png".to_string(),
3038 detail: ImageDetail::Auto,
3039 },
3040 };
3041 let converted: UserContent = openai_image.into();
3042 match converted {
3043 UserContent::ImageUrl { image_url } => {
3044 assert_eq!(image_url.url, "https://example.com/img.png");
3045 assert_eq!(image_url.detail, Some(ImageDetail::Auto));
3046 }
3047 _ => panic!("Expected ImageUrl"),
3048 }
3049
3050 let openai_audio = openai::UserContent::Audio {
3051 input_audio: openai::InputAudio {
3052 data: "audiodata".to_string(),
3053 format: AudioMediaType::FLAC,
3054 },
3055 };
3056 let converted: UserContent = openai_audio.into();
3057 match converted {
3058 UserContent::InputAudio { input_audio } => {
3059 assert_eq!(input_audio.data, "audiodata");
3060 assert_eq!(input_audio.format, AudioMediaType::FLAC);
3061 }
3062 _ => panic!("Expected InputAudio"),
3063 }
3064 }
3065
3066 #[test]
3067 fn test_completion_response_reasoning_details_with_multiple_ids_stay_separate() {
3068 let json = json!({
3069 "id": "resp_multi_id",
3070 "object": "chat.completion",
3071 "created": 1,
3072 "model": "openrouter/test-model",
3073 "choices": [{
3074 "index": 0,
3075 "finish_reason": "stop",
3076 "message": {
3077 "role": "assistant",
3078 "content": "hello",
3079 "reasoning": null,
3080 "reasoning_details": [
3081 {"type":"reasoning.summary","id":"rs_a","summary":"a1"},
3082 {"type":"reasoning.summary","id":"rs_b","summary":"b1"},
3083 {"type":"reasoning.summary","id":"rs_a","summary":"a2"}
3084 ]
3085 }
3086 }]
3087 });
3088
3089 let response: CompletionResponse = serde_json::from_value(json).unwrap();
3090 let converted: completion::CompletionResponse<CompletionResponse> =
3091 response.try_into().unwrap();
3092 let items: Vec<completion::AssistantContent> = converted.choice.into_iter().collect();
3093 let reasoning_blocks: Vec<_> = items
3094 .into_iter()
3095 .filter_map(|item| match item {
3096 completion::AssistantContent::Reasoning(reasoning) => Some(reasoning),
3097 _ => None,
3098 })
3099 .collect();
3100
3101 assert_eq!(reasoning_blocks.len(), 2);
3102 assert_eq!(reasoning_blocks[0].id.as_deref(), Some("rs_a"));
3103 assert_eq!(
3104 reasoning_blocks[0].content,
3105 vec![
3106 message::ReasoningContent::Summary("a1".to_string()),
3107 message::ReasoningContent::Summary("a2".to_string()),
3108 ]
3109 );
3110 assert_eq!(reasoning_blocks[1].id.as_deref(), Some("rs_b"));
3111 assert_eq!(
3112 reasoning_blocks[1].content,
3113 vec![message::ReasoningContent::Summary("b1".to_string())]
3114 );
3115 }
3116
3117 #[test]
3118 fn test_user_content_audio_serialization() {
3119 let content = UserContent::audio_base64("SGVsbG8=", AudioMediaType::WAV);
3120 let json = serde_json::to_value(&content).unwrap();
3121
3122 assert_eq!(json["type"], "input_audio");
3123 assert_eq!(json["input_audio"]["data"], "SGVsbG8=");
3124 assert_eq!(json["input_audio"]["format"], "wav");
3125 }
3126
3127 #[test]
3128 fn test_user_content_audio_deserialization() {
3129 let json = json!({
3130 "type": "input_audio",
3131 "input_audio": {
3132 "data": "SGVsbG8=",
3133 "format": "wav"
3134 }
3135 });
3136
3137 let content: UserContent = serde_json::from_value(json).unwrap();
3138 match content {
3139 UserContent::InputAudio { input_audio } => {
3140 assert_eq!(input_audio.data, "SGVsbG8=");
3141 assert_eq!(input_audio.format, AudioMediaType::WAV);
3142 }
3143 _ => panic!("Expected InputAudio variant"),
3144 }
3145 }
3146
3147 #[test]
3148 fn test_message_user_with_audio_serialization() {
3149 let msg = Message::User {
3150 content: OneOrMany::many(vec![
3151 UserContent::text("Transcribe this audio:"),
3152 UserContent::audio_base64("SGVsbG8=", AudioMediaType::MP3),
3153 ])
3154 .unwrap(),
3155 name: None,
3156 };
3157 let json = serde_json::to_value(&msg).unwrap();
3158
3159 assert_eq!(json["role"], "user");
3160 let content = json["content"].as_array().unwrap();
3161 assert_eq!(content.len(), 2);
3162 assert_eq!(content[0]["type"], "text");
3163 assert_eq!(content[1]["type"], "input_audio");
3164 assert_eq!(content[1]["input_audio"]["data"], "SGVsbG8=");
3165 assert_eq!(content[1]["input_audio"]["format"], "mp3");
3166 }
3167
3168 #[test]
3169 fn test_user_content_video_url_serialization() {
3170 let content = UserContent::video_url("https://example.com/video.mp4");
3171 let json = serde_json::to_value(&content).unwrap();
3172
3173 assert_eq!(json["type"], "video_url");
3174 assert_eq!(json["video_url"]["url"], "https://example.com/video.mp4");
3175 }
3176
3177 #[test]
3178 fn test_user_content_video_base64_serialization() {
3179 let content = UserContent::video_base64("SGVsbG8=", VideoMediaType::MP4);
3180 let json = serde_json::to_value(&content).unwrap();
3181
3182 assert_eq!(json["type"], "video_url");
3183 assert_eq!(json["video_url"]["url"], "data:video/mp4;base64,SGVsbG8=");
3184 }
3185
3186 #[test]
3187 fn test_user_content_video_url_deserialization() {
3188 let json = json!({
3189 "type": "video_url",
3190 "video_url": {
3191 "url": "https://example.com/video.mp4"
3192 }
3193 });
3194
3195 let content: UserContent = serde_json::from_value(json).unwrap();
3196 match content {
3197 UserContent::VideoUrl { video_url } => {
3198 assert_eq!(video_url.url, "https://example.com/video.mp4");
3199 }
3200 _ => panic!("Expected VideoUrl variant"),
3201 }
3202 }
3203
3204 #[test]
3205 fn test_message_user_with_video_serialization() {
3206 let msg = Message::User {
3207 content: OneOrMany::many(vec![
3208 UserContent::text("Describe this video:"),
3209 UserContent::video_url("https://example.com/video.mp4"),
3210 ])
3211 .unwrap(),
3212 name: None,
3213 };
3214 let json = serde_json::to_value(&msg).unwrap();
3215
3216 assert_eq!(json["role"], "user");
3217 let content = json["content"].as_array().unwrap();
3218 assert_eq!(content.len(), 2);
3219 assert_eq!(content[0]["type"], "text");
3220 assert_eq!(content[1]["type"], "video_url");
3221 assert_eq!(
3222 content[1]["video_url"]["url"],
3223 "https://example.com/video.mp4"
3224 );
3225 }
3226
3227 #[test]
3228 fn test_user_content_video_url_no_media_type_needed() {
3229 let rig_content = message::UserContent::Video(message::Video {
3230 data: DocumentSourceKind::Url("https://example.com/video.mp4".to_string()),
3231 media_type: None,
3232 additional_params: None,
3233 });
3234 let openrouter_content: UserContent = rig_content.try_into().unwrap();
3235
3236 match openrouter_content {
3237 UserContent::VideoUrl { video_url } => {
3238 assert_eq!(video_url.url, "https://example.com/video.mp4");
3239 }
3240 _ => panic!("Expected VideoUrl variant"),
3241 }
3242 }
3243}