1#![forbid(unsafe_code)]
9#![warn(missing_docs, missing_debug_implementations)]
10
11mod types;
12
13use std::{sync::Arc, time::Duration};
14
15use agent_proxy_rust_core::{
16 ProxyError,
17 extensions::{EXT_SELECTED_CHANNEL, EXT_SELECTED_MAPPING},
18 middleware::ProxyMiddleware,
19 types::{ApiFormat, ChannelConfig, ConnectionContext, ProxyRequest, ProxyResponse},
20};
21use agent_proxy_rust_storage::ProtocolEntry;
22use agent_proxy_rust_storage::Storage;
23use arc_swap::ArcSwap;
24use async_trait::async_trait;
25use dashmap::DashMap;
26use secrecy::ExposeSecret;
27use tracing::{debug, warn};
28pub use types::{
29 BillingDimension, ChannelBilling, ChannelHealth, ChannelState, ExhaustedAction, Pricing,
30 PricingTier, Quota, QuotaUsage, TierPrice,
31};
32
/// Cooldown window (60 seconds) handed to `ChannelState::is_tryable` whenever
/// the router decides whether a channel may be tried.
const COOLDOWN: Duration = Duration::from_secs(60);
35
/// A channel loaded from storage together with its usable model mappings.
#[derive(Debug, Clone)]
pub struct ResolvedChannel {
    /// Storage identifier of the channel; key for health tracking and API-key lookups.
    pub channel_id: String,
    /// Channel name; copied into the `ChannelConfig` attached to the request context.
    pub channel_name: String,
    /// API key as stored. Channels with an empty key are skipped during selection
    /// unless a non-empty runtime key exists for the same channel id.
    pub api_key: secrecy::SecretString,
    /// Upstream protocol entries. The loaders in this file strip trailing `/` from
    /// `base_url` and turn an empty `rewrite_path` into `None`.
    pub protocols: Vec<ProtocolEntry>,
    /// Disabled channels are ignored when candidates are collected.
    pub enabled: bool,
    /// When set, this protocol is resolved first, ahead of the client's detected format.
    pub force_protocol: Option<String>,
    /// Selection priority; higher values are tried first within a billing class.
    pub priority: u32,
    /// Enabled mappings whose billing configuration parsed successfully.
    pub mappings: Vec<ResolvedMapping>,
}
56
57impl ResolvedChannel {
58 #[allow(dead_code)]
60 fn supported_protocols(&self) -> Vec<&str> {
61 self.protocols.iter().map(|p| p.protocol.as_str()).collect()
62 }
63}
64
/// A client-model to upstream-model mapping that belongs to one channel.
#[derive(Debug, Clone)]
pub struct ResolvedMapping {
    /// Storage identifier of the mapping; key for flat-fee quota usage tracking.
    pub mapping_id: String,
    /// Model name the client sends; compared against the request body's `model` field.
    pub client_name: String,
    /// Model name written into the request body's `model` field before forwarding.
    pub upstream_name: String,
    /// Billing configuration parsed from the stored billing/pricing columns.
    pub billing: ChannelBilling,
    /// Protocol names this mapping may use; an empty list means no constraint.
    pub allowed_protocols: Vec<String>,
}
79
/// Routing decision stored in the connection context under `EXT_SELECTED_MAPPING`.
#[derive(Debug, Clone)]
pub struct SelectedMappingInfo {
    /// Identifier of the selected channel.
    pub channel_id: String,
    /// Identifier of the selected mapping.
    pub mapping_id: String,
    /// Model name requested by the client.
    pub client_name: String,
    /// Model name sent upstream.
    pub upstream_name: String,
    /// `true` when the selected mapping uses flat-fee billing.
    pub is_flat_fee: bool,
    /// Pricing of a metered mapping; `None` for flat-fee mappings.
    pub pricing: Option<Pricing>,
    /// JSON snapshot of the pricing; `{"type":"flat_fee"}` for flat-fee mappings.
    pub pricing_snapshot_json: String,
}
98
/// Middleware that maps the requested model onto a channel and upstream model,
/// and tracks channel health and flat-fee quota usage.
pub struct ModelRouterMiddleware {
    /// Current channel list; replaceable as a whole through the shared `ArcSwap`.
    channels: Arc<ArcSwap<Vec<ResolvedChannel>>>,
    /// Health state per channel, keyed by channel id.
    health: Arc<DashMap<String, ChannelState>>,
    /// Flat-fee quota usage, keyed by mapping id.
    quota_usage: Arc<DashMap<String, QuotaUsage>>,
    /// API keys keyed by channel id; consulted before the key stored on the channel.
    channel_api_keys: Arc<DashMap<String, secrecy::SecretString>>,
}
110
111impl std::fmt::Debug for ModelRouterMiddleware {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 f.debug_struct("ModelRouterMiddleware")
114 .field("channels", &self.channels.load())
115 .field("health", &self.health)
116 .field("quota_usage", &self.quota_usage)
117 .field("channel_api_keys", &"<DashMap>")
118 .finish()
119 }
120}
121
122impl ModelRouterMiddleware {
123 pub async fn from_storage(storage: Arc<dyn Storage>) -> Result<Self, ProxyError> {
133 let storage_channels = storage
134 .list_channels(None)
135 .await
136 .map_err(|e| ProxyError::Internal(e.into()))?;
137
138 let mut channels = Vec::with_capacity(storage_channels.len());
139
140 for ch in storage_channels {
141 let protocols: Vec<ProtocolEntry> =
143 serde_json::from_str(&ch.protocols).unwrap_or_default();
144 if protocols.is_empty() {
145 warn!(
146 channel = %ch.id,
147 "channel has no protocols configured, skipping"
148 );
149 continue;
150 }
151
152 let storage_mappings = storage
153 .list_mappings(&ch.id)
154 .await
155 .map_err(|e| ProxyError::Internal(e.into()))?;
156
157 let mappings: Vec<ResolvedMapping> = storage_mappings
158 .into_iter()
159 .filter(|m| m.enabled)
160 .filter_map(|m| {
161 let billing = ChannelBilling::from_storage(&m.billing, &m.pricing_json)
162 .map_err(|e| {
163 warn!(
164 channel = %ch.id,
165 mapping = %m.id,
166 error = %e,
167 "failed to parse mapping billing/pricing, skipping"
168 );
169 })
170 .ok()?;
171 let allowed_protocols: Vec<String> =
172 serde_json::from_str(&m.protocols).unwrap_or_default();
173 Some(ResolvedMapping {
174 mapping_id: m.id,
175 client_name: m.client_name,
176 upstream_name: m.upstream_name,
177 billing,
178 allowed_protocols,
179 })
180 })
181 .collect();
182
183 let protocols: Vec<ProtocolEntry> = protocols
185 .into_iter()
186 .map(|mut p| {
187 p.base_url = p.base_url.trim_end_matches('/').to_string();
188 p.rewrite_path = p.rewrite_path.filter(|rp| !rp.is_empty());
189 p
190 })
191 .collect();
192
193 channels.push(ResolvedChannel {
194 channel_id: ch.id,
195 channel_name: ch.name,
196 api_key: ch.api_key,
197 protocols,
198 enabled: ch.enabled,
199 force_protocol: ch.force_protocol,
200 priority: ch.priority,
201 mappings,
202 });
203 }
204
205 let health: Arc<DashMap<String, ChannelState>> = Arc::new(DashMap::new());
209 for ch in &channels {
210 if ch.api_key.expose_secret().is_empty() {
211 health
213 .entry(ch.channel_id.clone())
214 .or_default()
215 .mark_unhealthy();
216 let _ = storage.record_channel_failure(&ch.channel_id).await;
218 let _ = storage.record_channel_failure(&ch.channel_id).await;
219 let _ = storage.record_channel_failure(&ch.channel_id).await;
220 tracing::info!(channel=%ch.channel_id, name=%ch.channel_name, "no API key — Unavailable");
221 }
222 }
223
224 let channel_api_keys: Arc<DashMap<String, secrecy::SecretString>> =
228 Arc::new(DashMap::new());
229 for ch in &channels {
230 if !ch.api_key.expose_secret().is_empty() {
231 channel_api_keys.insert(ch.channel_id.clone(), ch.api_key.clone());
232 }
233 }
234
235 Ok(Self {
236 channels: Arc::new(ArcSwap::from_pointee(channels)),
237 health,
238 quota_usage: Arc::new(DashMap::new()),
239 channel_api_keys,
240 })
241 }
242
/// Returns the shared per-channel health map (keyed by channel id).
#[must_use]
pub fn health_map(&self) -> &Arc<DashMap<String, ChannelState>> {
    &self.health
}
248
/// Returns the shared API-key map (keyed by channel id). Entries in this map are
/// consulted before the key stored on the channel itself.
#[must_use]
pub fn api_key_map(&self) -> &Arc<DashMap<String, secrecy::SecretString>> {
    &self.channel_api_keys
}
257
/// Returns a new handle to the swappable channel list, e.g. for passing to
/// `reload_channels_from_storage`.
#[must_use]
pub fn channels_swap(&self) -> Arc<ArcSwap<Vec<ResolvedChannel>>> {
    Arc::clone(&self.channels)
}
264
265 fn find_candidates<'c>(
267 channels: &'c [ResolvedChannel],
268 client_name: &str,
269 ) -> Vec<(&'c ResolvedChannel, &'c ResolvedMapping)> {
270 let mut candidates = Vec::new();
271 for ch in channels {
272 if !ch.enabled {
273 continue;
274 }
275 for m in &ch.mappings {
276 if m.client_name == client_name {
277 candidates.push((ch, m));
278 }
279 }
280 }
281 candidates
282 }
283
284 fn select_channel<'a>(
286 &self,
287 candidates: &[(&'a ResolvedChannel, &'a ResolvedMapping)],
288 client_name: &str,
289 ) -> Result<(&'a ResolvedChannel, &'a ResolvedMapping), ProxyError> {
290 let (mut flatfee, mut metered): (Vec<_>, Vec<_>) = candidates
291 .iter()
292 .partition(|(_, m)| m.billing.is_flat_fee());
293
294 flatfee.sort_by_key(|(ch, _)| std::cmp::Reverse(ch.priority));
296 metered.sort_by_key(|(ch, _)| std::cmp::Reverse(ch.priority));
297
298 for (ch, m) in &flatfee {
300 if !self.has_api_key(&ch.channel_id) {
301 debug!(
302 channel = %ch.channel_id,
303 "skipping flat-fee channel: no API key configured"
304 );
305 continue;
306 }
307 if !self.is_healthy(&ch.channel_id) {
308 continue;
309 }
310 if let ChannelBilling::FlatFee {
311 on_exhausted,
312 quota,
313 ..
314 } = &m.billing
315 {
316 let within_quota = self
318 .quota_usage
319 .entry(m.mapping_id.clone())
320 .or_default()
321 .is_within_quota(quota.as_ref());
322
323 if within_quota {
324 return Ok((ch, m));
325 }
326 if *on_exhausted == ExhaustedAction::Block {
327 debug!(
328 channel = %ch.channel_id,
329 model = %client_name,
330 "flat-fee channel quota exhausted, blocking"
331 );
332 return Err(ProxyError::ChannelSelection {
333 model: client_name.to_owned(),
334 });
335 }
336 }
337 }
338
339 for (ch, m) in &metered {
341 if !self.has_api_key(&ch.channel_id) {
342 debug!(
343 channel = %ch.channel_id,
344 "skipping metered channel: no API key configured"
345 );
346 continue;
347 }
348 if self.is_healthy(&ch.channel_id) {
349 return Ok((ch, m));
350 }
351 }
352
353 for (ch, m) in candidates {
355 if !self.has_api_key(&ch.channel_id) {
356 continue;
357 }
358 if self.is_tryable_past_cooldown(&ch.channel_id) {
359 warn!(
360 channel = %ch.channel_id,
361 model = %client_name,
362 "all channels unhealthy, retrying past cooldown"
363 );
364 return Ok((ch, m));
365 }
366 }
367
368 Err(ProxyError::ChannelSelection {
369 model: client_name.to_owned(),
370 })
371 }
372
373 fn has_api_key(&self, channel_id: &str) -> bool {
380 if let Some(key) = self.channel_api_keys.get(channel_id) {
382 return !key.expose_secret().is_empty();
383 }
384 self.channels
386 .load()
387 .iter()
388 .any(|ch| ch.channel_id == channel_id && !ch.api_key.expose_secret().is_empty())
389 }
390
391 fn is_healthy(&self, channel_id: &str) -> bool {
392 self.health
393 .get(channel_id)
394 .is_none_or(|s| s.is_tryable(COOLDOWN))
395 }
396
397 fn is_tryable_past_cooldown(&self, channel_id: &str) -> bool {
398 self.health
399 .get(channel_id)
400 .is_none_or(|s| s.is_tryable(COOLDOWN))
401 }
402
403 fn mark_healthy(&self, channel_id: &str) {
404 if let Some(mut state) = self.health.get_mut(channel_id) {
405 state.record_success();
406 }
407 }
408
409 fn record_failure(&self, channel_id: &str) {
412 let mut state = self.health.entry(channel_id.to_owned()).or_default();
413 state.record_failure();
414 }
415
416 fn mark_unhealthy_immediate(&self, channel_id: &str) {
418 self.health
419 .entry(channel_id.to_owned())
420 .or_default()
421 .mark_unhealthy();
422 }
423}
424
425pub async fn reload_channels_from_storage(
437 storage: &dyn Storage,
438 channels_swap: &ArcSwap<Vec<ResolvedChannel>>,
439) -> Result<(), ProxyError> {
440 let storage_channels = storage
441 .list_channels(None)
442 .await
443 .map_err(|e| ProxyError::Internal(e.into()))?;
444
445 let mut channels = Vec::with_capacity(storage_channels.len());
446
447 for ch in storage_channels {
448 let protocols: Vec<ProtocolEntry> = serde_json::from_str(&ch.protocols).unwrap_or_default();
449 if protocols.is_empty() {
450 warn!(
451 channel = %ch.id,
452 "channel has no protocols configured, skipping"
453 );
454 continue;
455 }
456
457 let storage_mappings = storage
458 .list_mappings(&ch.id)
459 .await
460 .map_err(|e| ProxyError::Internal(e.into()))?;
461
462 let mappings: Vec<ResolvedMapping> = storage_mappings
463 .into_iter()
464 .filter(|m| m.enabled)
465 .filter_map(|m| {
466 let billing = ChannelBilling::from_storage(&m.billing, &m.pricing_json)
467 .map_err(|e| {
468 warn!(
469 channel = %ch.id,
470 mapping = %m.id,
471 error = %e,
472 "failed to parse mapping billing/pricing, skipping"
473 );
474 })
475 .ok()?;
476 let allowed_protocols: Vec<String> =
477 serde_json::from_str(&m.protocols).unwrap_or_default();
478 Some(ResolvedMapping {
479 mapping_id: m.id,
480 client_name: m.client_name,
481 upstream_name: m.upstream_name,
482 billing,
483 allowed_protocols,
484 })
485 })
486 .collect();
487
488 let protocols: Vec<ProtocolEntry> = protocols
489 .into_iter()
490 .map(|mut p| {
491 p.base_url = p.base_url.trim_end_matches('/').to_string();
492 p.rewrite_path = p.rewrite_path.filter(|rp| !rp.is_empty());
493 p
494 })
495 .collect();
496
497 channels.push(ResolvedChannel {
498 channel_id: ch.id,
499 channel_name: ch.name,
500 api_key: ch.api_key,
501 protocols,
502 enabled: ch.enabled,
503 force_protocol: ch.force_protocol,
504 priority: ch.priority,
505 mappings,
506 });
507 }
508
509 channels_swap.store(Arc::new(channels));
510 tracing::info!(count = channels_swap.load().len(), "channels hot-reloaded");
511 Ok(())
512}
513
514#[async_trait]
515impl ProxyMiddleware for ModelRouterMiddleware {
516 #[allow(clippy::too_many_lines)]
517 async fn on_request(
518 &self,
519 req: &mut ProxyRequest,
520 ctx: &mut ConnectionContext,
521 ) -> Result<(), ProxyError> {
522 let mut body: serde_json::Value =
523 serde_json::from_slice(&req.body).map_err(|e| ProxyError::BadRequest(e.to_string()))?;
524
525 let client_name = body
526 .get("model")
527 .and_then(|v| v.as_str())
528 .map(String::from)
529 .unwrap_or_default();
530
531 if client_name.is_empty() {
532 return Err(ProxyError::BadRequest(
533 "request body missing 'model' field".into(),
534 ));
535 }
536
537 let channels = self.channels.load();
539 let candidates = Self::find_candidates(&channels, &client_name);
540
541 if candidates.is_empty() {
542 return Err(ProxyError::ChannelSelection { model: client_name });
543 }
544
545 let (channel, mapping) = self.select_channel(&candidates, &client_name)?;
546
547 debug!(
548 channel = %channel.channel_id,
549 client_model = %client_name,
550 upstream_model = %mapping.upstream_name,
551 "selected channel"
552 );
553
554 if let Some(model_field) = body.get_mut("model") {
556 *model_field = serde_json::Value::String(mapping.upstream_name.clone());
557 }
558 let new_body =
559 serde_json::to_vec(&body).map_err(|e| ProxyError::BadRequest(e.to_string()))?;
560 req.body = bytes::Bytes::from(new_body);
561
562 let mut target_protocol = resolve_target_protocol(
564 channel.force_protocol.as_deref(),
565 ctx.detected_format,
566 &channel.protocols,
567 )?;
568
569 if !mapping.allowed_protocols.is_empty() {
577 let target_str = protocol_to_str(target_protocol);
578 if !mapping.allowed_protocols.iter().any(|p| p == target_str) {
579 let compatible = channel.protocols.iter().find(|pe| {
583 mapping
584 .allowed_protocols
585 .iter()
586 .any(|ap| ap == &pe.protocol)
587 });
588 if let Some(entry) = compatible {
589 debug!(
590 channel = %channel.channel_id,
591 mapping = %mapping.mapping_id,
592 resolved = %target_str,
593 switched_to = %entry.protocol,
594 "mapping protocol constraint: switching target protocol"
595 );
596 target_protocol = parse_protocol(&entry.protocol)?;
597 } else {
598 let channel_prots: Vec<&str> = channel
599 .protocols
600 .iter()
601 .map(|p| p.protocol.as_str())
602 .collect();
603 return Err(ProxyError::Internal(anyhow::anyhow!(
604 "mapping '{}' protocol constraint {:?} incompatible with channel protocols {channel_prots:?}",
605 mapping.mapping_id,
606 mapping.allowed_protocols,
607 )));
608 }
609 }
610 }
611
612 ctx.target_protocol = Some(target_protocol);
613
614 let (base_url, rewrite_path) = resolve_upstream_url(target_protocol, &channel.protocols)?;
616
617 let api_key = self
621 .channel_api_keys
622 .get(&channel.channel_id)
623 .map_or_else(|| channel.api_key.clone(), |r| r.clone());
624
625 ctx.insert(
627 EXT_SELECTED_CHANNEL,
628 ChannelConfig {
629 url: base_url,
630 api_key,
631 protocol: target_protocol,
632 name: channel.channel_name.clone(),
633 rewrite_path,
634 },
635 );
636
637 let (pricing, snapshot_json) = match &mapping.billing {
639 ChannelBilling::Metered { pricing } => {
640 let json = serde_json::to_string(pricing).unwrap_or_default();
641 (Some(pricing.clone()), json)
642 }
643 ChannelBilling::FlatFee { .. } => (None, r#"{"type":"flat_fee"}"#.to_string()),
644 };
645
646 ctx.insert(
648 EXT_SELECTED_MAPPING,
649 SelectedMappingInfo {
650 channel_id: channel.channel_id.clone(),
651 mapping_id: mapping.mapping_id.clone(),
652 client_name: mapping.client_name.clone(),
653 upstream_name: mapping.upstream_name.clone(),
654 is_flat_fee: mapping.billing.is_flat_fee(),
655 pricing,
656 pricing_snapshot_json: snapshot_json,
657 },
658 );
659
660 Ok(())
661 }
662
663 async fn on_response(
664 &self,
665 res: &mut ProxyResponse,
666 ctx: &ConnectionContext,
667 ) -> Result<(), ProxyError> {
668 let channel_id = ctx
669 .get::<ChannelConfig>(EXT_SELECTED_CHANNEL)
670 .map(|ch| ch.name.clone())
671 .unwrap_or_default();
672
673 if channel_id.is_empty() {
674 return Ok(());
675 }
676
677 if let Some(mapping_info) = ctx.get::<SelectedMappingInfo>(EXT_SELECTED_MAPPING)
679 && mapping_info.is_flat_fee
680 {
681 let token_count =
682 serde_json::from_slice(&res.body).map_or(0, |body| extract_token_count(&body));
683 self.quota_usage
684 .entry(mapping_info.mapping_id.clone())
685 .or_default()
686 .record_usage(token_count);
687 }
688
689 if res.status.is_server_error() || res.status == http::StatusCode::UNAUTHORIZED {
690 warn!(
693 channel = %channel_id,
694 status = %res.status,
695 "upstream {}, marking channel unhealthy immediately",
696 if res.status.is_server_error() { "5xx" } else { "401 Unauthorized" }
697 );
698 self.mark_unhealthy_immediate(&channel_id);
699 } else if res.status.is_client_error() && res.status.as_u16() != 429 {
700 debug!(
702 channel = %channel_id,
703 status = %res.status,
704 "client error, not counting as channel failure"
705 );
706 } else if res.status == http::StatusCode::TOO_MANY_REQUESTS {
707 warn!(
709 channel = %channel_id,
710 "upstream 429 rate limit, recording failure"
711 );
712 self.record_failure(&channel_id);
713 } else {
714 self.mark_healthy(&channel_id);
716 }
717
718 Ok(())
719 }
720
/// Returns the static identifier of this middleware: `"model-router"`.
fn name(&self) -> &'static str {
    "model-router"
}
724}
725
726fn resolve_target_protocol(
741 force_protocol: Option<&str>,
742 detected_format: Option<ApiFormat>,
743 protocols: &[ProtocolEntry],
744) -> Result<ApiFormat, ProxyError> {
745 if let Some(fp) = force_protocol {
747 let target = parse_protocol(fp)?;
748 let target_str = protocol_to_str(target);
749 if !protocols.iter().any(|p| p.protocol == target_str) {
750 return Err(ProxyError::Internal(anyhow::anyhow!(
751 "force_protocol '{fp}' not found in channel protocols"
752 )));
753 }
754 return Ok(target);
755 }
756
757 if let Some(df) = detected_format {
759 let df_str = protocol_to_str(df);
760 if !df_str.is_empty() && protocols.iter().any(|p| p.protocol == df_str) {
761 return Ok(df);
762 }
763 }
764
765 if let Some(first) = protocols.first()
767 && !first.protocol.is_empty()
768 {
769 return parse_protocol(&first.protocol);
770 }
771
772 Err(ProxyError::Internal(anyhow::anyhow!(
773 "channel has no protocols configured"
774 )))
775}
776
777fn resolve_upstream_url(
788 protocol: ApiFormat,
789 protocols: &[ProtocolEntry],
790) -> Result<(String, Option<String>), ProxyError> {
791 let target = protocol_to_str(protocol);
792
793 let entry = protocols
794 .iter()
795 .find(|e| e.protocol == target)
796 .ok_or_else(|| {
797 ProxyError::Internal(anyhow::anyhow!(
798 "no protocol entry for '{target}' in channel protocols"
799 ))
800 })?;
801
802 if entry.base_url.is_empty() {
803 return Err(ProxyError::Internal(anyhow::anyhow!(
804 "protocol entry '{target}' has empty base_url"
805 )));
806 }
807
808 Ok((entry.base_url.clone(), entry.rewrite_path.clone()))
809}
810
/// Maps an `ApiFormat` to the protocol name used in stored configuration.
///
/// Variants other than the three listed here map to the empty string.
fn protocol_to_str(protocol: ApiFormat) -> &'static str {
    match protocol {
        ApiFormat::AnthropicMessages => "anthropic_messages",
        ApiFormat::OpenaiChat => "openai_chat",
        ApiFormat::OpenaiResponses => "openai_responses",
        _ => "",
    }
}
821
822fn parse_protocol(s: &str) -> Result<ApiFormat, ProxyError> {
823 match s {
824 "anthropic_messages" => Ok(ApiFormat::AnthropicMessages),
825 "openai_chat" => Ok(ApiFormat::OpenaiChat),
826 "openai_responses" => Ok(ApiFormat::OpenaiResponses),
827 other => Err(ProxyError::Internal(anyhow::anyhow!(
828 "unknown protocol in storage: {other}"
829 ))),
830 }
831}
832
833fn extract_token_count(body: &serde_json::Value) -> u64 {
835 body.get("usage").map_or(0, |u| {
836 u.get("input_tokens")
837 .and_then(serde_json::Value::as_u64)
838 .unwrap_or(0)
839 + u.get("output_tokens")
840 .and_then(serde_json::Value::as_u64)
841 .unwrap_or(0)
842 })
843}
844
// Unit tests for channel selection, protocol/URL resolution, health tracking
// and API-key handling.
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::unwrap_in_result,
    clippy::unchecked_duration_subtraction,
    clippy::panic
)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::types::ChannelHealth;

    /// Builds an enabled channel with key `sk-test`, priority 0 and no forced protocol.
    fn make_channel(
        id: &str,
        name: &str,
        protocols: Vec<ProtocolEntry>,
        mappings: Vec<ResolvedMapping>,
    ) -> ResolvedChannel {
        ResolvedChannel {
            channel_id: id.into(),
            channel_name: name.into(),
            api_key: secrecy::SecretString::from("sk-test"),
            protocols,
            enabled: true,
            force_protocol: None,
            priority: 0,
            mappings,
        }
    }

    /// Builds an unconstrained flat-fee mapping with an unlimited quota.
    fn make_mapping_flatfee(
        client: &str,
        upstream: &str,
        exhausted: ExhaustedAction,
    ) -> ResolvedMapping {
        ResolvedMapping {
            mapping_id: format!("test:{client}"),
            client_name: client.into(),
            upstream_name: upstream.into(),
            billing: ChannelBilling::FlatFee {
                monthly_cost_hint: None,
                quota: Some(Quota::Unlimited),
                on_exhausted: exhausted,
            },
            allowed_protocols: Vec::new(),
        }
    }

    /// Builds a single-entry protocol list without a rewrite path.
    fn make_protocols(protocol: ApiFormat, base_url: &str) -> Vec<ProtocolEntry> {
        vec![ProtocolEntry {
            protocol: protocol_to_str(protocol).to_string(),
            base_url: base_url.to_string(),
            rewrite_path: None,
        }]
    }

    /// Builds an unconstrained metered mapping with per-token USD pricing.
    fn make_mapping_metered(client: &str, upstream: &str) -> ResolvedMapping {
        ResolvedMapping {
            mapping_id: format!("test:{client}"),
            client_name: client.into(),
            upstream_name: upstream.into(),
            billing: ChannelBilling::Metered {
                pricing: Pricing::PerToken {
                    input_per_mtok: 3.0,
                    output_per_mtok: 15.0,
                    cache_write_per_mtok: None,
                    cache_read_per_mtok: None,
                    thinking_per_mtok: None,
                    currency: "USD".to_string(),
                },
            },
            allowed_protocols: Vec::new(),
        }
    }

    /// Builds a middleware over `channels` with empty health, quota and key maps.
    fn make_middleware(channels: Vec<ResolvedChannel>) -> ModelRouterMiddleware {
        ModelRouterMiddleware {
            channels: Arc::new(ArcSwap::from_pointee(channels)),
            health: Arc::new(DashMap::new()),
            quota_usage: Arc::new(DashMap::new()),
            channel_api_keys: Arc::new(DashMap::new()),
        }
    }

    // A healthy flat-fee channel within quota is preferred over a metered one.
    #[test]
    fn test_select_flatfee_has_quota_and_healthy() {
        let mw = make_middleware(vec![
            make_channel(
                "sub",
                "Subscription",
                make_protocols(ApiFormat::AnthropicMessages, "https://sub.example.com"),
                vec![make_mapping_flatfee(
                    "claude-sonnet",
                    "claude-sonnet-4-7",
                    ExhaustedAction::FallbackToMetered,
                )],
            ),
            make_channel(
                "metered",
                "Metered",
                make_protocols(ApiFormat::AnthropicMessages, "https://metered.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-opus-4-7")],
            ),
        ]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let (ch, m) = mw.select_channel(&candidates, "claude-sonnet").unwrap();
        assert_eq!(ch.channel_id, "sub");
        assert!(m.billing.is_flat_fee());
    }

    // An exhausted flat-fee mapping with `FallbackToMetered` yields to the metered channel.
    #[test]
    fn test_select_metered_when_flatfee_exhausted_fallback() {
        let mw = make_middleware(vec![
            make_channel(
                "sub-exhausted",
                "Subscription",
                make_protocols(ApiFormat::AnthropicMessages, "https://sub.example.com"),
                vec![ResolvedMapping {
                    mapping_id: "flatfee-exhausted".into(),
                    client_name: "claude-sonnet".into(),
                    upstream_name: "claude-sonnet-4-7".into(),
                    billing: ChannelBilling::FlatFee {
                        monthly_cost_hint: None,
                        quota: Some(Quota::MaxRequests { per_month: 0 }),
                        on_exhausted: ExhaustedAction::FallbackToMetered,
                    },
                    allowed_protocols: Vec::new(),
                }],
            ),
            make_channel(
                "metered",
                "Metered",
                make_protocols(ApiFormat::AnthropicMessages, "https://metered.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-opus-4-7")],
            ),
        ]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let (ch, _m) = mw.select_channel(&candidates, "claude-sonnet").unwrap();
        assert_eq!(ch.channel_id, "metered");
    }

    // An exhausted flat-fee mapping with `Block` makes selection fail.
    #[test]
    fn test_select_block_when_flatfee_exhausted_block() {
        let mw = make_middleware(vec![make_channel(
            "sub-blocked",
            "Subscription",
            make_protocols(ApiFormat::AnthropicMessages, "https://sub.example.com"),
            vec![ResolvedMapping {
                mapping_id: "flatfee-blocked".into(),
                client_name: "claude-sonnet".into(),
                upstream_name: "claude-sonnet-4-7".into(),
                billing: ChannelBilling::FlatFee {
                    monthly_cost_hint: None,
                    quota: Some(Quota::MaxRequests { per_month: 0 }),
                    on_exhausted: ExhaustedAction::Block,
                },
                allowed_protocols: Vec::new(),
            }],
        )]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let err = mw.select_channel(&candidates, "claude-sonnet").unwrap_err();
        assert!(matches!(err, ProxyError::ChannelSelection { .. }));
    }

    // Selection fails when every candidate has just been marked unhealthy.
    #[test]
    fn test_select_all_unhealthy_returns_error() {
        let mw = make_middleware(vec![
            make_channel(
                "m1",
                "Metered1",
                make_protocols(ApiFormat::AnthropicMessages, "https://m1.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-opus-4-7")],
            ),
            make_channel(
                "m2",
                "Metered2",
                make_protocols(ApiFormat::AnthropicMessages, "https://m2.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-haiku-4-5")],
            ),
        ]);

        mw.mark_unhealthy_immediate("m1");
        mw.mark_unhealthy_immediate("m2");

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let err = mw.select_channel(&candidates, "claude-sonnet").unwrap_err();
        assert!(matches!(err, ProxyError::ChannelSelection { .. }));
    }

    // A model without any mapping produces no candidates.
    #[test]
    fn test_no_candidates_for_unknown_model() {
        let mw = make_middleware(vec![make_channel(
            "m1",
            "Metered1",
            make_protocols(ApiFormat::AnthropicMessages, "https://m1.example.com"),
            vec![make_mapping_metered("claude-sonnet", "claude-opus-4-7")],
        )]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "nonexistent-model");
        assert!(candidates.is_empty());
    }

    // Mappings of a disabled channel are never offered as candidates.
    #[test]
    fn test_disabled_channel_skipped() {
        let mw = ModelRouterMiddleware {
            quota_usage: Arc::new(DashMap::new()),
            channels: Arc::new(ArcSwap::from_pointee(vec![ResolvedChannel {
                channel_id: "disabled".into(),
                channel_name: "Disabled".into(),
                api_key: secrecy::SecretString::from("sk-test"),
                protocols: make_protocols(
                    ApiFormat::AnthropicMessages,
                    "https://disabled.example.com",
                ),
                enabled: false,
                force_protocol: None,
                priority: 0,
                mappings: vec![make_mapping_metered("claude-sonnet", "claude-opus-4-7")],
            }])),
            health: Arc::new(DashMap::new()),
            channel_api_keys: Arc::new(DashMap::new()),
        };

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        assert!(candidates.is_empty());
    }

    // The matching entry's base URL and rewrite path are both returned.
    #[test]
    fn test_resolve_upstream_url_returns_base_url_and_rewrite_path() {
        let protocols = vec![ProtocolEntry {
            protocol: "openai_chat".into(),
            base_url: "https://api.deepseek.com".into(),
            rewrite_path: Some("/chat/completions".into()),
        }];
        let (base_url, rewrite_path) =
            resolve_upstream_url(ApiFormat::OpenaiChat, &protocols).unwrap();
        assert_eq!(base_url, "https://api.deepseek.com");
        assert_eq!(rewrite_path, Some("/chat/completions".into()));
    }

    // A missing rewrite path is reported as `None`.
    #[test]
    fn test_resolve_upstream_url_no_rewrite_path() {
        let protocols = vec![ProtocolEntry {
            protocol: "openai_chat".into(),
            base_url: "https://api.deepseek.com".into(),
            rewrite_path: None,
        }];
        let (base_url, rewrite_path) =
            resolve_upstream_url(ApiFormat::OpenaiChat, &protocols).unwrap();
        assert_eq!(base_url, "https://api.deepseek.com");
        assert_eq!(rewrite_path, None);
    }

    // Asking for a protocol the channel does not list is an error.
    #[test]
    fn test_resolve_upstream_url_no_matching_protocol() {
        let protocols = vec![ProtocolEntry {
            protocol: "openai_chat".into(),
            base_url: "https://api.deepseek.com".into(),
            rewrite_path: None,
        }];
        let result = resolve_upstream_url(ApiFormat::AnthropicMessages, &protocols);
        assert!(result.is_err());
    }

    // A matching entry with an empty base URL is an error.
    #[test]
    fn test_resolve_upstream_url_empty_base_url() {
        let protocols = vec![ProtocolEntry {
            protocol: "openai_chat".into(),
            base_url: String::new(),
            rewrite_path: None,
        }];
        let result = resolve_upstream_url(ApiFormat::OpenaiChat, &protocols);
        assert!(result.is_err());
    }

    /// Builds protocol entries from `(protocol, base_url)` pairs, without rewrite paths.
    fn make_protocol_entries(entries: &[(&str, &str)]) -> Vec<ProtocolEntry> {
        entries
            .iter()
            .map(|&(protocol, base_url)| ProtocolEntry {
                protocol: protocol.to_owned(),
                base_url: base_url.to_owned(),
                rewrite_path: None,
            })
            .collect()
    }

    // A forced protocol listed by the channel wins over the detected format.
    #[test]
    fn test_resolve_target_protocol_force_valid() {
        let protocols = make_protocol_entries(&[
            ("openai_chat", "https://api.example.com"),
            ("anthropic_messages", "https://api.example.com/anthropic"),
        ]);
        let result = resolve_target_protocol(
            Some("openai_chat"),
            Some(ApiFormat::AnthropicMessages),
            &protocols,
        )
        .unwrap();
        assert_eq!(result, ApiFormat::OpenaiChat);
    }

    // A forced protocol the channel does not list is an error.
    #[test]
    fn test_resolve_target_protocol_force_not_in_protocols() {
        let protocols = make_protocol_entries(&[("openai_chat", "https://api.example.com")]);
        let result = resolve_target_protocol(Some("anthropic_messages"), None, &protocols);
        assert!(result.is_err());
    }

    // Without forcing, the client's format is used when the channel lists it.
    #[test]
    fn test_resolve_target_protocol_passthrough_client_match() {
        let protocols = make_protocol_entries(&[
            ("openai_chat", "https://api.example.com"),
            ("anthropic_messages", "https://api.example.com/anthropic"),
        ]);
        let result =
            resolve_target_protocol(None, Some(ApiFormat::AnthropicMessages), &protocols).unwrap();
        assert_eq!(result, ApiFormat::AnthropicMessages);
    }

    // A client format the channel does not list falls back to the first entry.
    #[test]
    fn test_resolve_target_protocol_fallback_to_first() {
        let protocols = make_protocol_entries(&[
            ("openai_chat", "https://api.example.com"),
            ("anthropic_messages", "https://api.example.com/anthropic"),
        ]);
        let result =
            resolve_target_protocol(None, Some(ApiFormat::OpenaiResponses), &protocols).unwrap();
        assert_eq!(result, ApiFormat::OpenaiChat);
    }

    // With no detected format the first entry is used.
    #[test]
    fn test_resolve_target_protocol_no_client_format() {
        let protocols = make_protocol_entries(&[("openai_chat", "https://api.example.com")]);
        let result = resolve_target_protocol(None, None, &protocols).unwrap();
        assert_eq!(result, ApiFormat::OpenaiChat);
    }

    // An empty protocol list is an error even when a format was detected.
    #[test]
    fn test_resolve_target_protocol_empty_protocols() {
        let result = resolve_target_protocol(None, Some(ApiFormat::AnthropicMessages), &[]);
        assert!(result.is_err());
    }

    // A forced protocol cannot be satisfied by an empty protocol list.
    #[test]
    fn test_resolve_target_protocol_force_with_empty_protocols() {
        let result = resolve_target_protocol(Some("anthropic_messages"), None, &[]);
        assert!(result.is_err());
    }

    // Marking unhealthy and then recording a success restores health.
    #[test]
    fn test_health_mark_unhealthy_then_healthy() {
        let mw = make_middleware(vec![]);
        mw.mark_unhealthy_immediate("ch1");
        assert!(!mw.is_healthy("ch1"));

        mw.mark_healthy("ch1");
        assert!(mw.is_healthy("ch1"));
    }

    // An unhealthy channel whose failure is 61 s old (past the 60 s cooldown) counts as healthy.
    #[test]
    fn test_health_cooldown_expired() {
        let mw = make_middleware(vec![]);
        mw.health.insert(
            "ch1".to_owned(),
            ChannelState {
                health: ChannelHealth::Unhealthy,
                consecutive_failures: 0,
                failed_at: Some(std::time::Instant::now() - Duration::from_secs(61)),
            },
        );
        assert!(mw.is_healthy("ch1"));
    }

    /// Builds an enabled channel (name = id, priority 10) with the given API key.
    fn make_channel_with_key(
        id: &str,
        api_key: &str,
        protocols: Vec<ProtocolEntry>,
        mappings: Vec<ResolvedMapping>,
    ) -> ResolvedChannel {
        ResolvedChannel {
            channel_id: id.into(),
            channel_name: id.into(),
            api_key: secrecy::SecretString::from(api_key),
            protocols,
            enabled: true,
            force_protocol: None,
            priority: 10,
            mappings,
        }
    }

    // A channel with an empty key is still a candidate but is never selected.
    #[test]
    fn test_channel_with_empty_api_key_is_skipped() {
        let mw = make_middleware(vec![
            make_channel_with_key(
                "no-key",
                "",
                make_protocols(ApiFormat::AnthropicMessages, "https://no-key.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v1")],
            ),
            make_channel_with_key(
                "has-key",
                "sk-valid",
                make_protocols(ApiFormat::AnthropicMessages, "https://has-key.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v2")],
            ),
        ]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        assert_eq!(candidates.len(), 2);
        let (ch, _m) = mw.select_channel(&candidates, "claude-sonnet").unwrap();
        assert_eq!(ch.channel_id, "has-key", "should skip channel with empty API key");
    }

    // Selection fails when the only candidate has no API key.
    #[test]
    fn test_all_channels_empty_key_returns_error() {
        let mw = make_middleware(vec![make_channel_with_key(
            "no-key-1",
            "",
            make_protocols(ApiFormat::AnthropicMessages, "https://no1.example.com"),
            vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v1")],
        )]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let err = mw.select_channel(&candidates, "claude-sonnet").unwrap_err();
        assert!(
            matches!(err, ProxyError::ChannelSelection { .. }),
            "should error when no channel has a valid API key"
        );
    }

    // A key inserted into the runtime map makes a keyless channel usable.
    #[test]
    fn test_has_api_key_runtime_override() {
        let mw = make_middleware(vec![make_channel_with_key(
            "no-key-stored",
            "",
            make_protocols(ApiFormat::AnthropicMessages, "https://no-key.example.com"),
            vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v1")],
        )]);

        assert!(!mw.has_api_key("no-key-stored"));

        mw.channel_api_keys
            .insert("no-key-stored".into(), secrecy::SecretString::from("sk-override"));
        assert!(mw.has_api_key("no-key-stored"));
    }

    // With both channels unhealthy but past cooldown, the keyless one is still skipped.
    // NOTE(review): `test_health_cooldown_expired` shows `is_healthy` already accepts
    // a 61 s old failure, so selection here may succeed in the metered pass rather
    // than in the final fallback loop — confirm which phase this is meant to cover.
    #[test]
    fn test_empty_key_skipped_in_fallback_phase() {
        let mw = make_middleware(vec![
            make_channel_with_key(
                "no-key",
                "",
                make_protocols(ApiFormat::AnthropicMessages, "https://no-key.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v1")],
            ),
            make_channel_with_key(
                "has-key",
                "sk-valid",
                make_protocols(ApiFormat::AnthropicMessages, "https://has-key.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v2")],
            ),
        ]);

        for ch_id in ["no-key", "has-key"] {
            mw.health.insert(
                ch_id.to_owned(),
                ChannelState {
                    health: ChannelHealth::Unhealthy,
                    consecutive_failures: 1,
                    failed_at: Some(std::time::Instant::now() - Duration::from_secs(61)),
                },
            );
        }

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let (ch, _m) = mw.select_channel(&candidates, "claude-sonnet").unwrap();
        assert_eq!(ch.channel_id, "has-key", "fallback should skip empty-key channel");
    }
}