Skip to main content

agent_proxy_rust_model_router/
lib.rs

1//! Model routing and channel selection middleware.
2//!
3//! Implements the selection strategy from specs/0003-channel-model.md Phase 1:
4//! `FlatFee` channels with quota > 0 and healthy are preferred; `Metered` channels
5//! serve as fallback. Health tracking uses a simple binary model with 60s
6//! cooldown.
7
8#![forbid(unsafe_code)]
9#![warn(missing_docs, missing_debug_implementations)]
10
11mod types;
12
13use std::{sync::Arc, time::Duration};
14
15use agent_proxy_rust_core::{
16    ProxyError,
17    extensions::{EXT_SELECTED_CHANNEL, EXT_SELECTED_MAPPING},
18    middleware::ProxyMiddleware,
19    types::{ApiFormat, ChannelConfig, ConnectionContext, ProxyRequest, ProxyResponse},
20};
21use agent_proxy_rust_storage::ProtocolEntry;
22use agent_proxy_rust_storage::Storage;
23use arc_swap::ArcSwap;
24use async_trait::async_trait;
25use dashmap::DashMap;
26use secrecy::ExposeSecret;
27use tracing::{debug, warn};
28pub use types::{
29    BillingDimension, ChannelBilling, ChannelHealth, ChannelState, ExhaustedAction, Pricing,
30    PricingTier, Quota, QuotaUsage, TierPrice,
31};
32
33/// Cooldown period before an unhealthy channel is retried.
34const COOLDOWN: Duration = Duration::from_secs(60);
35
/// Parsed in-memory representation of a channel with its model mappings.
///
/// Instances are built from storage rows by
/// [`ModelRouterMiddleware::from_storage`] and
/// [`reload_channels_from_storage`].
#[derive(Debug, Clone)]
pub struct ResolvedChannel {
    /// Channel ID. This is the key used by the router's health map and
    /// API-key override map.
    pub channel_id: String,
    /// Human-readable channel name.
    pub channel_name: String,
    /// API key for upstream requests, as loaded from storage. The router
    /// prefers an entry in the shared API-key override map when one exists.
    pub api_key: secrecy::SecretString,
    /// Supported protocols parsed from the channel's JSON configuration.
    /// The loaders strip trailing slashes from each `base_url` and turn an
    /// empty `rewrite_path` into `None`. The first entry is the fallback
    /// target protocol.
    pub protocols: Vec<ProtocolEntry>,
    /// Whether the channel is enabled. Disabled channels are never candidates.
    pub enabled: bool,
    /// Optional protocol override. When set it must name one of `protocols`.
    pub force_protocol: Option<String>,
    /// Routing priority (higher = selected first).
    pub priority: u32,
    /// Model mappings bound to this channel (only enabled, parseable ones).
    pub mappings: Vec<ResolvedMapping>,
}
56
57impl ResolvedChannel {
58    /// Returns the protocol identifiers supported by this channel.
59    #[allow(dead_code)]
60    fn supported_protocols(&self) -> Vec<&str> {
61        self.protocols.iter().map(|p| p.protocol.as_str()).collect()
62    }
63}
64
/// Parsed in-memory representation of a model mapping.
///
/// A mapping ties a client-facing model name to the upstream model name
/// used on one channel, together with its billing configuration.
#[derive(Debug, Clone)]
pub struct ResolvedMapping {
    /// Mapping ID for quota tracking (key of the router's quota-usage map).
    pub mapping_id: String,
    /// Client-facing model name; matched against the request's `model` field.
    pub client_name: String,
    /// Upstream model name sent to the API in place of `client_name`.
    pub upstream_name: String,
    /// Billing type (flat-fee or metered).
    pub billing: ChannelBilling,
    /// Protocols this mapping is valid for. Empty = all protocols (backward compatible).
    pub allowed_protocols: Vec<String>,
}
79
/// Lightweight mapping info stored in the context extension.
///
/// Written under `EXT_SELECTED_MAPPING` by the router's `on_request` and
/// read back in `on_response`.
#[derive(Debug, Clone)]
pub struct SelectedMappingInfo {
    /// Channel ID for cost tracking.
    pub channel_id: String,
    /// Model mapping ID for quota tracking.
    pub mapping_id: String,
    /// Client-facing model name.
    pub client_name: String,
    /// Upstream model name sent to the API.
    pub upstream_name: String,
    /// Whether this mapping uses flat-fee billing.
    pub is_flat_fee: bool,
    /// Pricing snapshot at selection time (metered only; `None` for flat-fee).
    pub pricing: Option<Pricing>,
    /// Serialized pricing for audit trail. Flat-fee mappings store the
    /// literal `{"type":"flat_fee"}`.
    pub pricing_snapshot_json: String,
}
98
/// Channel selection and model routing middleware.
///
/// All state sits behind `Arc`s so it can be shared with the admin API
/// through the accessor methods.
pub struct ModelRouterMiddleware {
    /// Atomically swappable snapshot of all loaded channels.
    channels: Arc<ArcSwap<Vec<ResolvedChannel>>>,
    /// Per-channel health state, keyed by channel ID.
    health: Arc<DashMap<String, ChannelState>>,
    /// Per-mapping-id quota consumption counters. Keys match `model_mappings.id`.
    quota_usage: Arc<DashMap<String, QuotaUsage>>,
    /// Shared API key overrides: populated at startup from DB, updated at
    /// runtime by the admin API. The router looks up keys here first,
    /// falling back to the `ResolvedChannel::api_key`.
    channel_api_keys: Arc<DashMap<String, secrecy::SecretString>>,
}
110
111impl std::fmt::Debug for ModelRouterMiddleware {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        f.debug_struct("ModelRouterMiddleware")
114            .field("channels", &self.channels.load())
115            .field("health", &self.health)
116            .field("quota_usage", &self.quota_usage)
117            .field("channel_api_keys", &"<DashMap>")
118            .finish()
119    }
120}
121
122impl ModelRouterMiddleware {
123    /// Creates a new [`ModelRouterMiddleware`] from a storage backend.
124    ///
125    /// Loads all enabled channels and their model mappings, parsing the
126    /// storage-layer string fields into typed enums.
127    ///
128    /// # Errors
129    ///
130    /// Returns `ProxyError` if the storage backend fails or if any channel
131    /// has an unrecognized protocol string.
132    pub async fn from_storage(storage: Arc<dyn Storage>) -> Result<Self, ProxyError> {
133        let storage_channels = storage
134            .list_channels(None)
135            .await
136            .map_err(|e| ProxyError::Internal(e.into()))?;
137
138        let mut channels = Vec::with_capacity(storage_channels.len());
139
140        for ch in storage_channels {
141            // Parse protocols JSON into typed entries; skip channels with no protocols
142            let protocols: Vec<ProtocolEntry> =
143                serde_json::from_str(&ch.protocols).unwrap_or_default();
144            if protocols.is_empty() {
145                warn!(
146                    channel = %ch.id,
147                    "channel has no protocols configured, skipping"
148                );
149                continue;
150            }
151
152            let storage_mappings = storage
153                .list_mappings(&ch.id)
154                .await
155                .map_err(|e| ProxyError::Internal(e.into()))?;
156
157            let mappings: Vec<ResolvedMapping> = storage_mappings
158                .into_iter()
159                .filter(|m| m.enabled)
160                .filter_map(|m| {
161                    let billing = ChannelBilling::from_storage(&m.billing, &m.pricing_json)
162                        .map_err(|e| {
163                            warn!(
164                                channel = %ch.id,
165                                mapping = %m.id,
166                                error = %e,
167                                "failed to parse mapping billing/pricing, skipping"
168                            );
169                        })
170                        .ok()?;
171                    let allowed_protocols: Vec<String> =
172                        serde_json::from_str(&m.protocols).unwrap_or_default();
173                    Some(ResolvedMapping {
174                        mapping_id: m.id,
175                        client_name: m.client_name,
176                        upstream_name: m.upstream_name,
177                        billing,
178                        allowed_protocols,
179                    })
180                })
181                .collect();
182
183            // Normalize: trim trailing slashes and treat empty rewrite_path as None
184            let protocols: Vec<ProtocolEntry> = protocols
185                .into_iter()
186                .map(|mut p| {
187                    p.base_url = p.base_url.trim_end_matches('/').to_string();
188                    p.rewrite_path = p.rewrite_path.filter(|rp| !rp.is_empty());
189                    p
190                })
191                .collect();
192
193            channels.push(ResolvedChannel {
194                channel_id: ch.id,
195                channel_name: ch.name,
196                api_key: ch.api_key,
197                protocols,
198                enabled: ch.enabled,
199                force_protocol: ch.force_protocol,
200                priority: ch.priority,
201                mappings,
202            });
203        }
204
205        // Pre-populate the in-memory health map: channels without API keys
206        // start as Unhealthy so the router won't select them. Also persist
207        // the status to DB so the Admin API reflects it.
208        let health: Arc<DashMap<String, ChannelState>> = Arc::new(DashMap::new());
209        for ch in &channels {
210            if ch.api_key.expose_secret().is_empty() {
211                // Mark unhealthy in-memory (3 consecutive failures → Unhealthy + cooldown)
212                health
213                    .entry(ch.channel_id.clone())
214                    .or_default()
215                    .mark_unhealthy();
216                // Persist to DB for admin UI display
217                let _ = storage.record_channel_failure(&ch.channel_id).await;
218                let _ = storage.record_channel_failure(&ch.channel_id).await;
219                let _ = storage.record_channel_failure(&ch.channel_id).await;
220                tracing::info!(channel=%ch.channel_id, name=%ch.channel_name, "no API key — Unavailable");
221            }
222        }
223
224        // Build the shared API-key override map from the DB values.
225        // When the admin API updates a key, it writes here so the router
226        // picks up the new key without a restart.
227        let channel_api_keys: Arc<DashMap<String, secrecy::SecretString>> =
228            Arc::new(DashMap::new());
229        for ch in &channels {
230            if !ch.api_key.expose_secret().is_empty() {
231                channel_api_keys.insert(ch.channel_id.clone(), ch.api_key.clone());
232            }
233        }
234
235        Ok(Self {
236            channels: Arc::new(ArcSwap::from_pointee(channels)),
237            health,
238            quota_usage: Arc::new(DashMap::new()),
239            channel_api_keys,
240        })
241    }
242
    /// Returns a reference to the in-memory health map.
    ///
    /// The map is keyed by channel ID. Callers that need their own handle
    /// can clone the returned `Arc`.
    #[must_use]
    pub fn health_map(&self) -> &Arc<DashMap<String, ChannelState>> {
        &self.health
    }
248
    /// Returns a reference to the shared API-key override map.
    ///
    /// The map is keyed by channel ID. The admin API writes updated keys
    /// here so the router picks them up at request time without needing a
    /// restart.
    #[must_use]
    pub fn api_key_map(&self) -> &Arc<DashMap<String, secrecy::SecretString>> {
        &self.channel_api_keys
    }
257
    /// Returns a clone of the atomic channel list Arc so the admin API can
    /// trigger a hot-reload after mutations (priority, enabled, etc.).
    ///
    /// The returned handle points at the same `ArcSwap` the router reads
    /// from, so storing a new list through it is visible to later requests.
    #[must_use]
    pub fn channels_swap(&self) -> Arc<ArcSwap<Vec<ResolvedChannel>>> {
        Arc::clone(&self.channels)
    }
264
265    /// Finds all candidate mappings for a given client model name.
266    fn find_candidates<'c>(
267        channels: &'c [ResolvedChannel],
268        client_name: &str,
269    ) -> Vec<(&'c ResolvedChannel, &'c ResolvedMapping)> {
270        let mut candidates = Vec::new();
271        for ch in channels {
272            if !ch.enabled {
273                continue;
274            }
275            for m in &ch.mappings {
276                if m.client_name == client_name {
277                    candidates.push((ch, m));
278                }
279            }
280        }
281        candidates
282    }
283
284    /// Applies the Phase 1 selection strategy.
285    fn select_channel<'a>(
286        &self,
287        candidates: &[(&'a ResolvedChannel, &'a ResolvedMapping)],
288        client_name: &str,
289    ) -> Result<(&'a ResolvedChannel, &'a ResolvedMapping), ProxyError> {
290        let (mut flatfee, mut metered): (Vec<_>, Vec<_>) = candidates
291            .iter()
292            .partition(|(_, m)| m.billing.is_flat_fee());
293
294        // Sort by channel priority: higher = selected first
295        flatfee.sort_by_key(|(ch, _)| std::cmp::Reverse(ch.priority));
296        metered.sort_by_key(|(ch, _)| std::cmp::Reverse(ch.priority));
297
298        // Phase 1: try FlatFee channels first
299        for (ch, m) in &flatfee {
300            if !self.has_api_key(&ch.channel_id) {
301                debug!(
302                    channel = %ch.channel_id,
303                    "skipping flat-fee channel: no API key configured"
304                );
305                continue;
306            }
307            if !self.is_healthy(&ch.channel_id) {
308                continue;
309            }
310            if let ChannelBilling::FlatFee {
311                on_exhausted,
312                quota,
313                ..
314            } = &m.billing
315            {
316                // Check actual monthly consumption against quota
317                let within_quota = self
318                    .quota_usage
319                    .entry(m.mapping_id.clone())
320                    .or_default()
321                    .is_within_quota(quota.as_ref());
322
323                if within_quota {
324                    return Ok((ch, m));
325                }
326                if *on_exhausted == ExhaustedAction::Block {
327                    debug!(
328                        channel = %ch.channel_id,
329                        model = %client_name,
330                        "flat-fee channel quota exhausted, blocking"
331                    );
332                    return Err(ProxyError::ChannelSelection {
333                        model: client_name.to_owned(),
334                    });
335                }
336            }
337        }
338
339        // Phase 1: try Metered channels
340        for (ch, m) in &metered {
341            if !self.has_api_key(&ch.channel_id) {
342                debug!(
343                    channel = %ch.channel_id,
344                    "skipping metered channel: no API key configured"
345                );
346                continue;
347            }
348            if self.is_healthy(&ch.channel_id) {
349                return Ok((ch, m));
350            }
351        }
352
353        // All unhealthy — try any channel past cooldown that has a valid key
354        for (ch, m) in candidates {
355            if !self.has_api_key(&ch.channel_id) {
356                continue;
357            }
358            if self.is_tryable_past_cooldown(&ch.channel_id) {
359                warn!(
360                    channel = %ch.channel_id,
361                    model = %client_name,
362                    "all channels unhealthy, retrying past cooldown"
363                );
364                return Ok((ch, m));
365            }
366        }
367
368        Err(ProxyError::ChannelSelection {
369            model: client_name.to_owned(),
370        })
371    }
372
373    /// Returns `true` when a channel has a usable API key.
374    ///
375    /// Checks the runtime override map first (so admin API key updates take
376    /// effect immediately), then falls back to the key loaded from storage
377    /// at startup. Channels without any key are permanently excluded from
378    /// selection to avoid repeated 401 failures.
379    fn has_api_key(&self, channel_id: &str) -> bool {
380        // Runtime override from admin API takes precedence
381        if let Some(key) = self.channel_api_keys.get(channel_id) {
382            return !key.expose_secret().is_empty();
383        }
384        // Check the key loaded from storage at startup
385        self.channels
386            .load()
387            .iter()
388            .any(|ch| ch.channel_id == channel_id && !ch.api_key.expose_secret().is_empty())
389    }
390
391    fn is_healthy(&self, channel_id: &str) -> bool {
392        self.health
393            .get(channel_id)
394            .is_none_or(|s| s.is_tryable(COOLDOWN))
395    }
396
397    fn is_tryable_past_cooldown(&self, channel_id: &str) -> bool {
398        self.health
399            .get(channel_id)
400            .is_none_or(|s| s.is_tryable(COOLDOWN))
401    }
402
403    fn mark_healthy(&self, channel_id: &str) {
404        if let Some(mut state) = self.health.get_mut(channel_id) {
405            state.record_success();
406        }
407    }
408
409    /// Records a request failure. After 3 consecutive failures the
410    /// channel is marked Unhealthy with a 60 s cooldown.
411    fn record_failure(&self, channel_id: &str) {
412        let mut state = self.health.entry(channel_id.to_owned()).or_default();
413        state.record_failure();
414    }
415
416    /// Forces a channel to Unhealthy immediately (e.g. 5xx server error).
417    fn mark_unhealthy_immediate(&self, channel_id: &str) {
418        self.health
419            .entry(channel_id.to_owned())
420            .or_default()
421            .mark_unhealthy();
422    }
423}
424
425/// Hot-reloads the in-memory channel list from storage and atomically swaps
426/// it into `channels_swap`.
427///
428/// Called by the admin API after channel mutations (create, update, delete)
429/// so that priority, enabled, protocol, and mapping changes take effect
430/// immediately without requiring a proxy restart.
431///
432/// # Errors
433///
434/// Returns `ProxyError::Internal` if the storage backend fails or if any
435/// channel has an unrecognized protocol string.
436pub async fn reload_channels_from_storage(
437    storage: &dyn Storage,
438    channels_swap: &ArcSwap<Vec<ResolvedChannel>>,
439) -> Result<(), ProxyError> {
440    let storage_channels = storage
441        .list_channels(None)
442        .await
443        .map_err(|e| ProxyError::Internal(e.into()))?;
444
445    let mut channels = Vec::with_capacity(storage_channels.len());
446
447    for ch in storage_channels {
448        let protocols: Vec<ProtocolEntry> = serde_json::from_str(&ch.protocols).unwrap_or_default();
449        if protocols.is_empty() {
450            warn!(
451                channel = %ch.id,
452                "channel has no protocols configured, skipping"
453            );
454            continue;
455        }
456
457        let storage_mappings = storage
458            .list_mappings(&ch.id)
459            .await
460            .map_err(|e| ProxyError::Internal(e.into()))?;
461
462        let mappings: Vec<ResolvedMapping> = storage_mappings
463            .into_iter()
464            .filter(|m| m.enabled)
465            .filter_map(|m| {
466                let billing = ChannelBilling::from_storage(&m.billing, &m.pricing_json)
467                    .map_err(|e| {
468                        warn!(
469                            channel = %ch.id,
470                            mapping = %m.id,
471                            error = %e,
472                            "failed to parse mapping billing/pricing, skipping"
473                        );
474                    })
475                    .ok()?;
476                let allowed_protocols: Vec<String> =
477                    serde_json::from_str(&m.protocols).unwrap_or_default();
478                Some(ResolvedMapping {
479                    mapping_id: m.id,
480                    client_name: m.client_name,
481                    upstream_name: m.upstream_name,
482                    billing,
483                    allowed_protocols,
484                })
485            })
486            .collect();
487
488        let protocols: Vec<ProtocolEntry> = protocols
489            .into_iter()
490            .map(|mut p| {
491                p.base_url = p.base_url.trim_end_matches('/').to_string();
492                p.rewrite_path = p.rewrite_path.filter(|rp| !rp.is_empty());
493                p
494            })
495            .collect();
496
497        channels.push(ResolvedChannel {
498            channel_id: ch.id,
499            channel_name: ch.name,
500            api_key: ch.api_key,
501            protocols,
502            enabled: ch.enabled,
503            force_protocol: ch.force_protocol,
504            priority: ch.priority,
505            mappings,
506        });
507    }
508
509    channels_swap.store(Arc::new(channels));
510    tracing::info!(count = channels_swap.load().len(), "channels hot-reloaded");
511    Ok(())
512}
513
514#[async_trait]
515impl ProxyMiddleware for ModelRouterMiddleware {
516    #[allow(clippy::too_many_lines)]
517    async fn on_request(
518        &self,
519        req: &mut ProxyRequest,
520        ctx: &mut ConnectionContext,
521    ) -> Result<(), ProxyError> {
522        let mut body: serde_json::Value =
523            serde_json::from_slice(&req.body).map_err(|e| ProxyError::BadRequest(e.to_string()))?;
524
525        let client_name = body
526            .get("model")
527            .and_then(|v| v.as_str())
528            .map(String::from)
529            .unwrap_or_default();
530
531        if client_name.is_empty() {
532            return Err(ProxyError::BadRequest(
533                "request body missing 'model' field".into(),
534            ));
535        }
536
537        // Hold the Arc guard for the entire request so references remain valid.
538        let channels = self.channels.load();
539        let candidates = Self::find_candidates(&channels, &client_name);
540
541        if candidates.is_empty() {
542            return Err(ProxyError::ChannelSelection { model: client_name });
543        }
544
545        let (channel, mapping) = self.select_channel(&candidates, &client_name)?;
546
547        debug!(
548            channel = %channel.channel_id,
549            client_model = %client_name,
550            upstream_model = %mapping.upstream_name,
551            "selected channel"
552        );
553
554        // Replace model name in body
555        if let Some(model_field) = body.get_mut("model") {
556            *model_field = serde_json::Value::String(mapping.upstream_name.clone());
557        }
558        let new_body =
559            serde_json::to_vec(&body).map_err(|e| ProxyError::BadRequest(e.to_string()))?;
560        req.body = bytes::Bytes::from(new_body);
561
562        // Determine target protocol using the 3-step resolution
563        let mut target_protocol = resolve_target_protocol(
564            channel.force_protocol.as_deref(),
565            ctx.detected_format,
566            &channel.protocols,
567        )?;
568
569        // ── Protocol-model compatibility check ───────────────────────
570        //
571        // If the mapping declares protocol constraints (e.g. a model only
572        // works on openai_chat), validate that the resolved protocol is
573        // compatible. When it isn't, switch to the first protocol that
574        // both the mapping and the channel support — the bridge middleware
575        // will handle format conversion.
576        if !mapping.allowed_protocols.is_empty() {
577            let target_str = protocol_to_str(target_protocol);
578            if !mapping.allowed_protocols.iter().any(|p| p == target_str) {
579                // Resolved protocol is not in the mapping's allowed list.
580                // Find the first protocol the channel supports that the
581                // mapping also allows.
582                let compatible = channel.protocols.iter().find(|pe| {
583                    mapping
584                        .allowed_protocols
585                        .iter()
586                        .any(|ap| ap == &pe.protocol)
587                });
588                if let Some(entry) = compatible {
589                    debug!(
590                        channel = %channel.channel_id,
591                        mapping = %mapping.mapping_id,
592                        resolved = %target_str,
593                        switched_to = %entry.protocol,
594                        "mapping protocol constraint: switching target protocol"
595                    );
596                    target_protocol = parse_protocol(&entry.protocol)?;
597                } else {
598                    let channel_prots: Vec<&str> = channel
599                        .protocols
600                        .iter()
601                        .map(|p| p.protocol.as_str())
602                        .collect();
603                    return Err(ProxyError::Internal(anyhow::anyhow!(
604                        "mapping '{}' protocol constraint {:?} incompatible with channel protocols {channel_prots:?}",
605                        mapping.mapping_id,
606                        mapping.allowed_protocols,
607                    )));
608                }
609            }
610        }
611
612        ctx.target_protocol = Some(target_protocol);
613
614        // Resolve upstream URL from protocols entries
615        let (base_url, rewrite_path) = resolve_upstream_url(target_protocol, &channel.protocols)?;
616
617        // Look up the API key from the shared override map first (so admin
618        // API updates take effect without a restart), falling back to the
619        // key that was loaded at startup.
620        let api_key = self
621            .channel_api_keys
622            .get(&channel.channel_id)
623            .map_or_else(|| channel.api_key.clone(), |r| r.clone());
624
625        // Write ChannelConfig to extensions
626        ctx.insert(
627            EXT_SELECTED_CHANNEL,
628            ChannelConfig {
629                url: base_url,
630                api_key,
631                protocol: target_protocol,
632                name: channel.channel_name.clone(),
633                rewrite_path,
634            },
635        );
636
637        // Extract pricing snapshot from billing
638        let (pricing, snapshot_json) = match &mapping.billing {
639            ChannelBilling::Metered { pricing } => {
640                let json = serde_json::to_string(pricing).unwrap_or_default();
641                (Some(pricing.clone()), json)
642            }
643            ChannelBilling::FlatFee { .. } => (None, r#"{"type":"flat_fee"}"#.to_string()),
644        };
645
646        // Write selected mapping info to extensions
647        ctx.insert(
648            EXT_SELECTED_MAPPING,
649            SelectedMappingInfo {
650                channel_id: channel.channel_id.clone(),
651                mapping_id: mapping.mapping_id.clone(),
652                client_name: mapping.client_name.clone(),
653                upstream_name: mapping.upstream_name.clone(),
654                is_flat_fee: mapping.billing.is_flat_fee(),
655                pricing,
656                pricing_snapshot_json: snapshot_json,
657            },
658        );
659
660        Ok(())
661    }
662
663    async fn on_response(
664        &self,
665        res: &mut ProxyResponse,
666        ctx: &ConnectionContext,
667    ) -> Result<(), ProxyError> {
668        let channel_id = ctx
669            .get::<ChannelConfig>(EXT_SELECTED_CHANNEL)
670            .map(|ch| ch.name.clone())
671            .unwrap_or_default();
672
673        if channel_id.is_empty() {
674            return Ok(());
675        }
676
677        // Record quota usage for the selected mapping
678        if let Some(mapping_info) = ctx.get::<SelectedMappingInfo>(EXT_SELECTED_MAPPING)
679            && mapping_info.is_flat_fee
680        {
681            let token_count =
682                serde_json::from_slice(&res.body).map_or(0, |body| extract_token_count(&body));
683            self.quota_usage
684                .entry(mapping_info.mapping_id.clone())
685                .or_default()
686                .record_usage(token_count);
687        }
688
689        if res.status.is_server_error() || res.status == http::StatusCode::UNAUTHORIZED {
690            // 5xx: immediate unhealthy — server is down
691            // 401: authentication failure — API key is missing, invalid, or expired
692            warn!(
693                channel = %channel_id,
694                status = %res.status,
695                "upstream {}, marking channel unhealthy immediately",
696                if res.status.is_server_error() { "5xx" } else { "401 Unauthorized" }
697            );
698            self.mark_unhealthy_immediate(&channel_id);
699        } else if res.status.is_client_error() && res.status.as_u16() != 429 {
700            // 4xx (except 401, 429): client errors — not the channel's fault
701            debug!(
702                channel = %channel_id,
703                status = %res.status,
704                "client error, not counting as channel failure"
705            );
706        } else if res.status == http::StatusCode::TOO_MANY_REQUESTS {
707            // 429: rate limit — counts as a failure
708            warn!(
709                channel = %channel_id,
710                "upstream 429 rate limit, recording failure"
711            );
712            self.record_failure(&channel_id);
713        } else {
714            // 2xx: success
715            self.mark_healthy(&channel_id);
716        }
717
718        Ok(())
719    }
720
    /// Returns the fixed identifier of this middleware, `"model-router"`.
    fn name(&self) -> &'static str {
        "model-router"
    }
724}
725
726// ── Helpers ─────────────────────────────────────────────────────────
727
728/// Resolves the target protocol for a request using a 3-step strategy:
729///
730/// 1. If `force_protocol` is set, validate it exists in `protocols` and use it.
731/// 2. Otherwise, if the client's `detected_format` matches a protocol in `protocols`,
732///    use it (passthrough, no conversion).
733/// 3. Otherwise, fall back to the first protocol in `protocols`.
734///
735/// # Errors
736///
737/// Returns `ProxyError::Internal` if `force_protocol` is set but not found in
738/// `protocols`, if the matched protocol string is unrecognized, or if
739/// `protocols` is empty.
740fn resolve_target_protocol(
741    force_protocol: Option<&str>,
742    detected_format: Option<ApiFormat>,
743    protocols: &[ProtocolEntry],
744) -> Result<ApiFormat, ProxyError> {
745    // Step 1: force_protocol must exist in protocols
746    if let Some(fp) = force_protocol {
747        let target = parse_protocol(fp)?;
748        let target_str = protocol_to_str(target);
749        if !protocols.iter().any(|p| p.protocol == target_str) {
750            return Err(ProxyError::Internal(anyhow::anyhow!(
751                "force_protocol '{fp}' not found in channel protocols"
752            )));
753        }
754        return Ok(target);
755    }
756
757    // Step 2: if client protocol is supported, passthrough
758    if let Some(df) = detected_format {
759        let df_str = protocol_to_str(df);
760        if !df_str.is_empty() && protocols.iter().any(|p| p.protocol == df_str) {
761            return Ok(df);
762        }
763    }
764
765    // Step 3: fallback to first protocol
766    if let Some(first) = protocols.first()
767        && !first.protocol.is_empty()
768    {
769        return parse_protocol(&first.protocol);
770    }
771
772    Err(ProxyError::Internal(anyhow::anyhow!(
773        "channel has no protocols configured"
774    )))
775}
776
777/// Resolves the upstream URL for a given protocol from the channel's `protocols` entries.
778///
779/// Returns the `(base_url, rewrite_path)` tuple for the matching protocol entry.
780/// `rewrite_path` is `None` when the entry does not specify a path rewrite — the
781/// original request path should be passed through.
782///
783/// # Errors
784///
785/// Returns `ProxyError::Internal` if no entry matches the target protocol or
786/// if the matched entry has an empty `base_url`.
787fn resolve_upstream_url(
788    protocol: ApiFormat,
789    protocols: &[ProtocolEntry],
790) -> Result<(String, Option<String>), ProxyError> {
791    let target = protocol_to_str(protocol);
792
793    let entry = protocols
794        .iter()
795        .find(|e| e.protocol == target)
796        .ok_or_else(|| {
797            ProxyError::Internal(anyhow::anyhow!(
798                "no protocol entry for '{target}' in channel protocols"
799            ))
800        })?;
801
802    if entry.base_url.is_empty() {
803        return Err(ProxyError::Internal(anyhow::anyhow!(
804            "protocol entry '{target}' has empty base_url"
805        )));
806    }
807
808    Ok((entry.base_url.clone(), entry.rewrite_path.clone()))
809}
810
811/// Returns the `snake_case` string representation for an [`ApiFormat`] variant,
812/// matching the format used in the `protocols` JSON column.
813fn protocol_to_str(protocol: ApiFormat) -> &'static str {
814    match protocol {
815        ApiFormat::AnthropicMessages => "anthropic_messages",
816        ApiFormat::OpenaiChat => "openai_chat",
817        ApiFormat::OpenaiResponses => "openai_responses",
818        _ => "",
819    }
820}
821
822fn parse_protocol(s: &str) -> Result<ApiFormat, ProxyError> {
823    match s {
824        "anthropic_messages" => Ok(ApiFormat::AnthropicMessages),
825        "openai_chat" => Ok(ApiFormat::OpenaiChat),
826        "openai_responses" => Ok(ApiFormat::OpenaiResponses),
827        other => Err(ProxyError::Internal(anyhow::anyhow!(
828            "unknown protocol in storage: {other}"
829        ))),
830    }
831}
832
833/// Returns the total token count from an upstream response body for quota tracking.
834fn extract_token_count(body: &serde_json::Value) -> u64 {
835    body.get("usage").map_or(0, |u| {
836        u.get("input_tokens")
837            .and_then(serde_json::Value::as_u64)
838            .unwrap_or(0)
839            + u.get("output_tokens")
840                .and_then(serde_json::Value::as_u64)
841                .unwrap_or(0)
842    })
843}
844
845// ── Tests ───────────────────────────────────────────────────────────
846
#[cfg(test)]
// Test-only lint relaxations: fixtures and assertions unwrap freely, and
// expired cooldowns are simulated by subtracting from `Instant::now()`.
#[allow(
    clippy::unwrap_used,
    clippy::unwrap_in_result,
    clippy::unchecked_duration_subtraction,
    clippy::panic
)]
mod tests {
    use std::time::{Duration, Instant};

    use super::*;
    use crate::types::ChannelHealth;

    // ── Fixtures ────────────────────────────────────────────────

    /// Builds an enabled channel with a non-empty test API key, no forced
    /// protocol and priority 0.
    fn make_channel(
        id: &str,
        name: &str,
        protocols: Vec<ProtocolEntry>,
        mappings: Vec<ResolvedMapping>,
    ) -> ResolvedChannel {
        ResolvedChannel {
            channel_id: id.into(),
            channel_name: name.into(),
            api_key: secrecy::SecretString::from("sk-test"),
            protocols,
            enabled: true,
            force_protocol: None,
            priority: 0,
            mappings,
        }
    }

    /// Builds an enabled channel named after its id, with the given stored
    /// API key (which may be empty) and priority 10.
    fn make_channel_with_key(
        id: &str,
        api_key: &str,
        protocols: Vec<ProtocolEntry>,
        mappings: Vec<ResolvedMapping>,
    ) -> ResolvedChannel {
        ResolvedChannel {
            api_key: secrecy::SecretString::from(api_key),
            priority: 10,
            ..make_channel(id, id, protocols, mappings)
        }
    }

    /// Builds a flat-fee mapping with an unlimited quota.
    fn make_mapping_flatfee(
        client: &str,
        upstream: &str,
        exhausted: ExhaustedAction,
    ) -> ResolvedMapping {
        ResolvedMapping {
            mapping_id: format!("test:{client}"),
            client_name: client.into(),
            upstream_name: upstream.into(),
            billing: ChannelBilling::FlatFee {
                monthly_cost_hint: None,
                quota: Some(Quota::Unlimited),
                on_exhausted: exhausted,
            },
            allowed_protocols: Vec::new(),
        }
    }

    /// Builds a flat-fee mapping for `claude-sonnet` whose request quota is
    /// zero per month, with the given exhaustion policy.
    fn make_mapping_flatfee_exhausted(
        mapping_id: &str,
        exhausted: ExhaustedAction,
    ) -> ResolvedMapping {
        ResolvedMapping {
            mapping_id: mapping_id.into(),
            client_name: "claude-sonnet".into(),
            upstream_name: "claude-sonnet-4-7".into(),
            billing: ChannelBilling::FlatFee {
                monthly_cost_hint: None,
                quota: Some(Quota::MaxRequests { per_month: 0 }),
                on_exhausted: exhausted,
            },
            allowed_protocols: Vec::new(),
        }
    }

    /// Builds a metered mapping with fixed per-token USD pricing.
    fn make_mapping_metered(client: &str, upstream: &str) -> ResolvedMapping {
        ResolvedMapping {
            mapping_id: format!("test:{client}"),
            client_name: client.into(),
            upstream_name: upstream.into(),
            billing: ChannelBilling::Metered {
                pricing: Pricing::PerToken {
                    input_per_mtok: 3.0,
                    output_per_mtok: 15.0,
                    cache_write_per_mtok: None,
                    cache_read_per_mtok: None,
                    thinking_per_mtok: None,
                    currency: "USD".to_string(),
                },
            },
            allowed_protocols: Vec::new(),
        }
    }

    /// Builds a single-entry protocol list without a path rewrite.
    fn make_protocols(protocol: ApiFormat, base_url: &str) -> Vec<ProtocolEntry> {
        vec![ProtocolEntry {
            protocol: protocol_to_str(protocol).to_string(),
            base_url: base_url.to_string(),
            rewrite_path: None,
        }]
    }

    /// Builds protocol entries from `(protocol, base_url)` pairs, none of
    /// which has a path rewrite.
    fn make_protocol_entries(entries: &[(&str, &str)]) -> Vec<ProtocolEntry> {
        entries
            .iter()
            .map(|&(protocol, base_url)| ProtocolEntry {
                protocol: protocol.to_owned(),
                base_url: base_url.to_owned(),
                rewrite_path: None,
            })
            .collect()
    }

    /// Builds a middleware over `channels` with empty health, quota-usage and
    /// runtime API-key maps.
    fn make_middleware(channels: Vec<ResolvedChannel>) -> ModelRouterMiddleware {
        ModelRouterMiddleware {
            channels: Arc::new(ArcSwap::from_pointee(channels)),
            health: Arc::new(DashMap::new()),
            quota_usage: Arc::new(DashMap::new()),
            channel_api_keys: Arc::new(DashMap::new()),
        }
    }

    /// Returns an instant one second further in the past than `COOLDOWN`, so
    /// the fixtures keep working if the cooldown constant changes.
    fn past_cooldown() -> Instant {
        Instant::now() - (COOLDOWN + Duration::from_secs(1))
    }

    // ── Selection strategy ──────────────────────────────────────

    /// A flat-fee channel with unlimited quota is chosen over a metered one.
    #[test]
    fn test_select_flatfee_has_quota_and_healthy() {
        let mw = make_middleware(vec![
            make_channel(
                "sub",
                "Subscription",
                make_protocols(ApiFormat::AnthropicMessages, "https://sub.example.com"),
                vec![make_mapping_flatfee(
                    "claude-sonnet",
                    "claude-sonnet-4-7",
                    ExhaustedAction::FallbackToMetered,
                )],
            ),
            make_channel(
                "metered",
                "Metered",
                make_protocols(ApiFormat::AnthropicMessages, "https://metered.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-opus-4-7")],
            ),
        ]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let (ch, m) = mw.select_channel(&candidates, "claude-sonnet").unwrap();
        assert_eq!(ch.channel_id, "sub");
        assert!(m.billing.is_flat_fee());
    }

    /// With a zero-quota flat-fee channel set to fall back, the metered
    /// channel is selected.
    #[test]
    fn test_select_metered_when_flatfee_exhausted_fallback() {
        let mw = make_middleware(vec![
            make_channel(
                "sub-exhausted",
                "Subscription",
                make_protocols(ApiFormat::AnthropicMessages, "https://sub.example.com"),
                vec![make_mapping_flatfee_exhausted(
                    "flatfee-exhausted",
                    ExhaustedAction::FallbackToMetered,
                )],
            ),
            make_channel(
                "metered",
                "Metered",
                make_protocols(ApiFormat::AnthropicMessages, "https://metered.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-opus-4-7")],
            ),
        ]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let (ch, _m) = mw.select_channel(&candidates, "claude-sonnet").unwrap();
        assert_eq!(ch.channel_id, "metered");
    }

    /// A lone zero-quota flat-fee channel with the `Block` policy yields a
    /// channel-selection error.
    #[test]
    fn test_select_block_when_flatfee_exhausted_block() {
        let mw = make_middleware(vec![make_channel(
            "sub-blocked",
            "Subscription",
            make_protocols(ApiFormat::AnthropicMessages, "https://sub.example.com"),
            vec![make_mapping_flatfee_exhausted(
                "flatfee-blocked",
                ExhaustedAction::Block,
            )],
        )]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let err = mw.select_channel(&candidates, "claude-sonnet").unwrap_err();
        assert!(matches!(err, ProxyError::ChannelSelection { .. }));
    }

    /// Selection fails when every candidate has just been marked unhealthy.
    #[test]
    fn test_select_all_unhealthy_returns_error() {
        let mw = make_middleware(vec![
            make_channel(
                "m1",
                "Metered1",
                make_protocols(ApiFormat::AnthropicMessages, "https://m1.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-opus-4-7")],
            ),
            make_channel(
                "m2",
                "Metered2",
                make_protocols(ApiFormat::AnthropicMessages, "https://m2.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-haiku-4-5")],
            ),
        ]);

        mw.mark_unhealthy_immediate("m1");
        mw.mark_unhealthy_immediate("m2");

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let err = mw.select_channel(&candidates, "claude-sonnet").unwrap_err();
        assert!(matches!(err, ProxyError::ChannelSelection { .. }));
    }

    /// A model that no channel maps produces no candidates.
    #[test]
    fn test_no_candidates_for_unknown_model() {
        let mw = make_middleware(vec![make_channel(
            "m1",
            "Metered1",
            make_protocols(ApiFormat::AnthropicMessages, "https://m1.example.com"),
            vec![make_mapping_metered("claude-sonnet", "claude-opus-4-7")],
        )]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "nonexistent-model");
        assert!(candidates.is_empty());
    }

    /// A disabled channel is never offered as a candidate.
    #[test]
    fn test_disabled_channel_skipped() {
        let mw = make_middleware(vec![ResolvedChannel {
            enabled: false,
            ..make_channel(
                "disabled",
                "Disabled",
                make_protocols(
                    ApiFormat::AnthropicMessages,
                    "https://disabled.example.com",
                ),
                vec![make_mapping_metered("claude-sonnet", "claude-opus-4-7")],
            )
        }]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        assert!(candidates.is_empty());
    }

    // ── resolve_upstream_url ────────────────────────────────────

    #[test]
    fn test_resolve_upstream_url_returns_base_url_and_rewrite_path() {
        let protocols = vec![ProtocolEntry {
            protocol: "openai_chat".into(),
            base_url: "https://api.deepseek.com".into(),
            rewrite_path: Some("/chat/completions".into()),
        }];
        let (base_url, rewrite_path) =
            resolve_upstream_url(ApiFormat::OpenaiChat, &protocols).unwrap();
        assert_eq!(base_url, "https://api.deepseek.com");
        assert_eq!(rewrite_path, Some("/chat/completions".into()));
    }

    #[test]
    fn test_resolve_upstream_url_no_rewrite_path() {
        let protocols = make_protocol_entries(&[("openai_chat", "https://api.deepseek.com")]);
        let (base_url, rewrite_path) =
            resolve_upstream_url(ApiFormat::OpenaiChat, &protocols).unwrap();
        assert_eq!(base_url, "https://api.deepseek.com");
        assert_eq!(rewrite_path, None);
    }

    #[test]
    fn test_resolve_upstream_url_no_matching_protocol() {
        let protocols = make_protocol_entries(&[("openai_chat", "https://api.deepseek.com")]);
        let result = resolve_upstream_url(ApiFormat::AnthropicMessages, &protocols);
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_upstream_url_empty_base_url() {
        let protocols = make_protocol_entries(&[("openai_chat", "")]);
        let result = resolve_upstream_url(ApiFormat::OpenaiChat, &protocols);
        assert!(result.is_err());
    }

    // ── resolve_target_protocol ─────────────────────────────────

    /// A forced protocol that the channel supports wins over the client's.
    #[test]
    fn test_resolve_target_protocol_force_valid() {
        let protocols = make_protocol_entries(&[
            ("openai_chat", "https://api.example.com"),
            ("anthropic_messages", "https://api.example.com/anthropic"),
        ]);
        let result = resolve_target_protocol(
            Some("openai_chat"),
            Some(ApiFormat::AnthropicMessages),
            &protocols,
        )
        .unwrap();
        assert_eq!(result, ApiFormat::OpenaiChat);
    }

    #[test]
    fn test_resolve_target_protocol_force_not_in_protocols() {
        let protocols = make_protocol_entries(&[("openai_chat", "https://api.example.com")]);
        let result = resolve_target_protocol(Some("anthropic_messages"), None, &protocols);
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_target_protocol_passthrough_client_match() {
        let protocols = make_protocol_entries(&[
            ("openai_chat", "https://api.example.com"),
            ("anthropic_messages", "https://api.example.com/anthropic"),
        ]);
        let result =
            resolve_target_protocol(None, Some(ApiFormat::AnthropicMessages), &protocols).unwrap();
        assert_eq!(result, ApiFormat::AnthropicMessages);
    }

    #[test]
    fn test_resolve_target_protocol_fallback_to_first() {
        let protocols = make_protocol_entries(&[
            ("openai_chat", "https://api.example.com"),
            ("anthropic_messages", "https://api.example.com/anthropic"),
        ]);
        // Client sends a protocol not in the list → fallback to first
        let result =
            resolve_target_protocol(None, Some(ApiFormat::OpenaiResponses), &protocols).unwrap();
        assert_eq!(result, ApiFormat::OpenaiChat);
    }

    #[test]
    fn test_resolve_target_protocol_no_client_format() {
        let protocols = make_protocol_entries(&[("openai_chat", "https://api.example.com")]);
        // No detected_format → fallback to first
        let result = resolve_target_protocol(None, None, &protocols).unwrap();
        assert_eq!(result, ApiFormat::OpenaiChat);
    }

    #[test]
    fn test_resolve_target_protocol_empty_protocols() {
        let result = resolve_target_protocol(None, Some(ApiFormat::AnthropicMessages), &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_target_protocol_force_with_empty_protocols() {
        let result = resolve_target_protocol(Some("anthropic_messages"), None, &[]);
        assert!(result.is_err());
    }

    // ── Protocol name conversion ────────────────────────────────

    /// Every named protocol survives a to-string / parse round trip.
    #[test]
    fn test_protocol_name_round_trip() {
        for format in [
            ApiFormat::AnthropicMessages,
            ApiFormat::OpenaiChat,
            ApiFormat::OpenaiResponses,
        ] {
            assert_eq!(parse_protocol(protocol_to_str(format)).unwrap(), format);
        }
    }

    #[test]
    fn test_parse_protocol_rejects_unknown_name() {
        assert!(parse_protocol("grpc").is_err());
        assert!(parse_protocol("").is_err());
    }

    // ── extract_token_count ─────────────────────────────────────

    #[test]
    fn test_extract_token_count_sums_input_and_output() {
        let body = serde_json::json!({
            "usage": { "input_tokens": 10, "output_tokens": 5 }
        });
        assert_eq!(extract_token_count(&body), 15);
    }

    #[test]
    fn test_extract_token_count_without_usage_is_zero() {
        let body = serde_json::json!({ "id": "msg_1" });
        assert_eq!(extract_token_count(&body), 0);
    }

    // ── Health tracking ─────────────────────────────────────────

    #[test]
    fn test_health_mark_unhealthy_then_healthy() {
        let mw = make_middleware(vec![]);
        mw.mark_unhealthy_immediate("ch1");
        assert!(!mw.is_healthy("ch1"));

        mw.mark_healthy("ch1");
        assert!(mw.is_healthy("ch1"));
    }

    /// An unhealthy channel whose failure is older than the cooldown is
    /// reported healthy again.
    #[test]
    fn test_health_cooldown_expired() {
        let mw = make_middleware(vec![]);
        mw.health.insert(
            "ch1".to_owned(),
            ChannelState {
                health: ChannelHealth::Unhealthy,
                consecutive_failures: 0,
                failed_at: Some(past_cooldown()),
            },
        );
        assert!(mw.is_healthy("ch1"));
    }

    // ── API key filtering ───────────────────────────────────────

    /// Both channels are candidates, but selection skips the one whose
    /// stored API key is empty.
    #[test]
    fn test_channel_with_empty_api_key_is_skipped() {
        let mw = make_middleware(vec![
            make_channel_with_key(
                "no-key",
                "",
                make_protocols(ApiFormat::AnthropicMessages, "https://no-key.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v1")],
            ),
            make_channel_with_key(
                "has-key",
                "sk-valid",
                make_protocols(ApiFormat::AnthropicMessages, "https://has-key.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v2")],
            ),
        ]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        assert_eq!(candidates.len(), 2);
        let (ch, _m) = mw.select_channel(&candidates, "claude-sonnet").unwrap();
        assert_eq!(ch.channel_id, "has-key", "should skip channel with empty API key");
    }

    #[test]
    fn test_all_channels_empty_key_returns_error() {
        let mw = make_middleware(vec![make_channel_with_key(
            "no-key-1",
            "",
            make_protocols(ApiFormat::AnthropicMessages, "https://no1.example.com"),
            vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v1")],
        )]);

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        let err = mw.select_channel(&candidates, "claude-sonnet").unwrap_err();
        assert!(
            matches!(err, ProxyError::ChannelSelection { .. }),
            "should error when no channel has a valid API key"
        );
    }

    /// A key inserted into `channel_api_keys` makes a channel with an empty
    /// stored key count as having an API key.
    #[test]
    fn test_has_api_key_runtime_override() {
        let mw = make_middleware(vec![make_channel_with_key(
            "no-key-stored",
            "",
            make_protocols(ApiFormat::AnthropicMessages, "https://no-key.example.com"),
            vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v1")],
        )]);

        // Without override: should be excluded
        assert!(!mw.has_api_key("no-key-stored"));

        // With runtime override: should be available
        mw.channel_api_keys
            .insert("no-key-stored".into(), secrecy::SecretString::from("sk-override"));
        assert!(mw.has_api_key("no-key-stored"));
    }

    #[test]
    fn test_empty_key_skipped_in_fallback_phase() {
        // All channels unhealthy but "has-key" past cooldown, "no-key" also past cooldown
        let mw = make_middleware(vec![
            make_channel_with_key(
                "no-key",
                "",
                make_protocols(ApiFormat::AnthropicMessages, "https://no-key.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v1")],
            ),
            make_channel_with_key(
                "has-key",
                "sk-valid",
                make_protocols(ApiFormat::AnthropicMessages, "https://has-key.example.com"),
                vec![make_mapping_metered("claude-sonnet", "claude-sonnet-v2")],
            ),
        ]);

        // Mark both unhealthy, with the failure older than the cooldown
        for ch_id in ["no-key", "has-key"] {
            mw.health.insert(
                ch_id.to_owned(),
                ChannelState {
                    health: ChannelHealth::Unhealthy,
                    consecutive_failures: 1,
                    failed_at: Some(past_cooldown()),
                },
            );
        }

        let channels = mw.channels.load();
        let candidates = ModelRouterMiddleware::find_candidates(&channels, "claude-sonnet");
        // Fallback should skip "no-key" and pick "has-key"
        let (ch, _m) = mw.select_channel(&candidates, "claude-sonnet").unwrap();
        assert_eq!(ch.channel_id, "has-key", "fallback should skip empty-key channel");
    }
}