Skip to main content

codex_helper_core/
runtime_candidate_state.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::balance::{ProviderBalanceSnapshot, StationRoutingBalanceSummary};
6use crate::routing_ir::{RouteCandidate, RoutePlanTemplate};
7use crate::runtime_identity::RuntimeUpstreamIdentity;
8use crate::state::{LbConfigView, LbUpstreamView, PassiveUpstreamHealth, StationHealth};
9
10#[derive(Debug, Clone, Copy, Default)]
11pub struct RouteRuntimeSignalInputs<'a> {
12    pub station_health: Option<&'a HashMap<String, StationHealth>>,
13    pub load_balancers: Option<&'a HashMap<String, LbConfigView>>,
14    pub provider_balances: Option<&'a HashMap<String, Vec<ProviderBalanceSnapshot>>>,
15    pub now_ms: u64,
16}
17
18impl<'a> RouteRuntimeSignalInputs<'a> {
19    pub fn empty(now_ms: u64) -> Self {
20        Self {
21            station_health: None,
22            load_balancers: None,
23            provider_balances: None,
24            now_ms,
25        }
26    }
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
30pub struct RouteCandidateRuntimeSignals {
31    pub identity: RuntimeUpstreamIdentity,
32    #[serde(default, skip_serializing_if = "Option::is_none")]
33    pub passive_health: Option<PassiveUpstreamHealth>,
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub load_balancer: Option<LbUpstreamView>,
36    #[serde(
37        default,
38        skip_serializing_if = "StationRoutingBalanceSummary::is_empty"
39    )]
40    pub balance: StationRoutingBalanceSummary,
41}
42
43impl RoutePlanTemplate {
44    pub fn candidate_runtime_signals(
45        &self,
46        candidate: &RouteCandidate,
47        inputs: &RouteRuntimeSignalInputs<'_>,
48    ) -> RouteCandidateRuntimeSignals {
49        let identity = self.candidate_identity(candidate);
50        RouteCandidateRuntimeSignals {
51            passive_health: candidate_passive_health(&identity, inputs.station_health),
52            load_balancer: candidate_load_balancer_state(&identity, inputs.load_balancers),
53            balance: candidate_balance_summary(&identity, inputs.provider_balances, inputs.now_ms),
54            identity,
55        }
56    }
57
58    pub fn candidate_runtime_signal_view(
59        &self,
60        inputs: &RouteRuntimeSignalInputs<'_>,
61    ) -> Vec<RouteCandidateRuntimeSignals> {
62        self.candidates
63            .iter()
64            .map(|candidate| self.candidate_runtime_signals(candidate, inputs))
65            .collect()
66    }
67}
68
69fn candidate_passive_health(
70    identity: &RuntimeUpstreamIdentity,
71    station_health: Option<&HashMap<String, StationHealth>>,
72) -> Option<PassiveUpstreamHealth> {
73    let compatibility = identity.compatibility.as_ref()?;
74    station_health?
75        .get(compatibility.station_name.as_str())?
76        .upstreams
77        .iter()
78        .find(|upstream| upstream.base_url == identity.base_url)
79        .and_then(|upstream| upstream.passive.clone())
80}
81
82fn candidate_load_balancer_state(
83    identity: &RuntimeUpstreamIdentity,
84    load_balancers: Option<&HashMap<String, LbConfigView>>,
85) -> Option<LbUpstreamView> {
86    let compatibility = identity.compatibility.as_ref()?;
87    load_balancers?
88        .get(compatibility.station_name.as_str())?
89        .upstreams
90        .get(compatibility.upstream_index)
91        .cloned()
92}
93
94fn candidate_balance_summary(
95    identity: &RuntimeUpstreamIdentity,
96    provider_balances: Option<&HashMap<String, Vec<ProviderBalanceSnapshot>>>,
97    now_ms: u64,
98) -> StationRoutingBalanceSummary {
99    let Some(compatibility) = identity.compatibility.as_ref() else {
100        return StationRoutingBalanceSummary::default();
101    };
102    let Some(snapshots) =
103        provider_balances.and_then(|balances| balances.get(compatibility.station_name.as_str()))
104    else {
105        return StationRoutingBalanceSummary::default();
106    };
107
108    StationRoutingBalanceSummary::from_snapshot_iter_at(
109        snapshots
110            .iter()
111            .filter(|snapshot| balance_snapshot_matches_candidate(snapshot, identity)),
112        now_ms,
113    )
114}
115
116fn balance_snapshot_matches_candidate(
117    snapshot: &ProviderBalanceSnapshot,
118    identity: &RuntimeUpstreamIdentity,
119) -> bool {
120    let Some(compatibility) = identity.compatibility.as_ref() else {
121        return false;
122    };
123    snapshot.provider_id == identity.provider_endpoint.provider_id
124        && snapshot.station_name.as_deref() == Some(compatibility.station_name.as_str())
125        && snapshot.upstream_index == Some(compatibility.upstream_index)
126}
127
128#[cfg(test)]
129mod tests {
130    use super::*;
131    use crate::balance::BalanceSnapshotStatus;
132    use crate::config::{
133        ProviderConfigV4, ProviderEndpointV4, ServiceConfig, ServiceViewV4, UpstreamAuth,
134        UpstreamConfig,
135    };
136    use crate::routing_ir::{compile_legacy_route_plan_template, compile_v4_route_plan_template};
137    use crate::state::{PassiveHealthState, UpstreamHealth};
138    use std::collections::BTreeMap;
139
140    fn provider(base_url: &str) -> ProviderConfigV4 {
141        ProviderConfigV4 {
142            base_url: Some(base_url.to_string()),
143            ..ProviderConfigV4::default()
144        }
145    }
146
147    fn passive_health(state: PassiveHealthState, score: u8) -> PassiveUpstreamHealth {
148        PassiveUpstreamHealth {
149            score,
150            state,
151            observed_at_ms: 100,
152            last_failure_at_ms: Some(100),
153            consecutive_failures: 1,
154            ..PassiveUpstreamHealth::default()
155        }
156    }
157
158    fn balance_snapshot(
159        provider_id: &str,
160        station_name: &str,
161        upstream_index: usize,
162        exhausted: bool,
163    ) -> ProviderBalanceSnapshot {
164        ProviderBalanceSnapshot {
165            provider_id: provider_id.to_string(),
166            station_name: Some(station_name.to_string()),
167            upstream_index: Some(upstream_index),
168            source: "test".to_string(),
169            fetched_at_ms: 100,
170            stale_after_ms: Some(200),
171            status: if exhausted {
172                BalanceSnapshotStatus::Exhausted
173            } else {
174                BalanceSnapshotStatus::Ok
175            },
176            exhausted: Some(exhausted),
177            ..ProviderBalanceSnapshot::default()
178        }
179    }
180
181    #[test]
182    fn route_candidate_runtime_signals_attach_existing_legacy_state() {
183        let view = ServiceViewV4 {
184            providers: BTreeMap::from([(
185                "input".to_string(),
186                provider("https://input.example/v1"),
187            )]),
188            ..ServiceViewV4::default()
189        };
190        let template = compile_v4_route_plan_template("codex", &view).expect("route template");
191
192        let station_health = HashMap::from([(
193            "routing".to_string(),
194            StationHealth {
195                checked_at_ms: 100,
196                upstreams: vec![UpstreamHealth {
197                    base_url: "https://input.example/v1".to_string(),
198                    passive: Some(passive_health(PassiveHealthState::Failing, 20)),
199                    ..UpstreamHealth::default()
200                }],
201            },
202        )]);
203        let load_balancers = HashMap::from([(
204            "routing".to_string(),
205            LbConfigView {
206                last_good_index: None,
207                upstreams: vec![LbUpstreamView {
208                    failure_count: 3,
209                    cooldown_remaining_secs: Some(11),
210                    usage_exhausted: true,
211                }],
212            },
213        )]);
214        let provider_balances = HashMap::from([(
215            "routing".to_string(),
216            vec![balance_snapshot("input", "routing", 0, true)],
217        )]);
218        let inputs = RouteRuntimeSignalInputs {
219            station_health: Some(&station_health),
220            load_balancers: Some(&load_balancers),
221            provider_balances: Some(&provider_balances),
222            now_ms: 150,
223        };
224
225        let signals = template.candidate_runtime_signal_view(&inputs);
226
227        assert_eq!(signals.len(), 1);
228        assert_eq!(
229            signals[0].identity.provider_endpoint.stable_key(),
230            "codex/input/default"
231        );
232        assert!(signals[0].identity.compatibility.is_none());
233        assert!(signals[0].passive_health.is_none());
234        assert!(signals[0].load_balancer.is_none());
235        assert!(signals[0].balance.is_empty());
236    }
237
238    #[test]
239    fn route_candidate_runtime_signals_disambiguate_multi_endpoint_provider_by_legacy_index() {
240        let mut endpoints = BTreeMap::new();
241        endpoints.insert(
242            "slow".to_string(),
243            ProviderEndpointV4 {
244                base_url: "https://slow.example/v1".to_string(),
245                enabled: true,
246                priority: 10,
247                tags: BTreeMap::new(),
248                supported_models: BTreeMap::new(),
249                model_mapping: BTreeMap::new(),
250            },
251        );
252        endpoints.insert(
253            "fast".to_string(),
254            ProviderEndpointV4 {
255                base_url: "https://fast.example/v1".to_string(),
256                enabled: true,
257                priority: 0,
258                tags: BTreeMap::new(),
259                supported_models: BTreeMap::new(),
260                model_mapping: BTreeMap::new(),
261            },
262        );
263        let view = ServiceViewV4 {
264            providers: BTreeMap::from([(
265                "input".to_string(),
266                ProviderConfigV4 {
267                    endpoints,
268                    auth: UpstreamAuth::default(),
269                    ..ProviderConfigV4::default()
270                },
271            )]),
272            ..ServiceViewV4::default()
273        };
274        let template = compile_v4_route_plan_template("codex", &view).expect("route template");
275        let provider_balances = HashMap::from([(
276            "routing".to_string(),
277            vec![
278                balance_snapshot("input", "routing", 0, false),
279                balance_snapshot("input", "routing", 1, true),
280            ],
281        )]);
282        let load_balancers = HashMap::from([(
283            "routing".to_string(),
284            LbConfigView {
285                last_good_index: Some(0),
286                upstreams: vec![
287                    LbUpstreamView {
288                        failure_count: 0,
289                        cooldown_remaining_secs: None,
290                        usage_exhausted: false,
291                    },
292                    LbUpstreamView {
293                        failure_count: 3,
294                        cooldown_remaining_secs: Some(30),
295                        usage_exhausted: true,
296                    },
297                ],
298            },
299        )]);
300        let inputs = RouteRuntimeSignalInputs {
301            load_balancers: Some(&load_balancers),
302            provider_balances: Some(&provider_balances),
303            now_ms: 150,
304            ..RouteRuntimeSignalInputs::default()
305        };
306
307        let signals = template.candidate_runtime_signal_view(&inputs);
308
309        assert_eq!(
310            signals
311                .iter()
312                .map(|signal| signal.identity.provider_endpoint.stable_key())
313                .collect::<Vec<_>>(),
314            vec!["codex/input/fast", "codex/input/slow"]
315        );
316        assert!(signals[0].identity.compatibility.is_none());
317        assert!(signals[1].identity.compatibility.is_none());
318        assert!(signals[0].balance.is_empty());
319        assert!(signals[1].balance.is_empty());
320        assert!(signals[0].load_balancer.is_none());
321        assert!(signals[1].load_balancer.is_none());
322    }
323
324    #[test]
325    fn route_candidate_runtime_signals_keep_legacy_station_compatibility_reads() {
326        let service = ServiceConfig {
327            name: "primary".to_string(),
328            alias: Some("Primary".to_string()),
329            enabled: true,
330            level: 1,
331            upstreams: vec![UpstreamConfig {
332                base_url: "https://legacy.example/v1".to_string(),
333                auth: UpstreamAuth::default(),
334                tags: HashMap::from([
335                    ("provider_id".to_string(), "legacy-provider".to_string()),
336                    ("endpoint_id".to_string(), "legacy-endpoint".to_string()),
337                ]),
338                supported_models: HashMap::new(),
339                model_mapping: HashMap::new(),
340            }],
341        };
342        let template = compile_legacy_route_plan_template("codex", [&service]);
343
344        let station_health = HashMap::from([(
345            "primary".to_string(),
346            StationHealth {
347                checked_at_ms: 100,
348                upstreams: vec![UpstreamHealth {
349                    base_url: "https://legacy.example/v1".to_string(),
350                    passive: Some(passive_health(PassiveHealthState::Degraded, 60)),
351                    ..UpstreamHealth::default()
352                }],
353            },
354        )]);
355        let load_balancers = HashMap::from([(
356            "primary".to_string(),
357            LbConfigView {
358                last_good_index: Some(0),
359                upstreams: vec![LbUpstreamView {
360                    failure_count: 1,
361                    cooldown_remaining_secs: None,
362                    usage_exhausted: false,
363                }],
364            },
365        )]);
366        let provider_balances = HashMap::from([(
367            "primary".to_string(),
368            vec![balance_snapshot("legacy-provider", "primary", 0, false)],
369        )]);
370        let inputs = RouteRuntimeSignalInputs {
371            station_health: Some(&station_health),
372            load_balancers: Some(&load_balancers),
373            provider_balances: Some(&provider_balances),
374            now_ms: 150,
375        };
376
377        let signals = template.candidate_runtime_signal_view(&inputs);
378
379        assert_eq!(signals.len(), 1);
380        assert_eq!(
381            signals[0].identity.provider_endpoint.stable_key(),
382            "codex/legacy-provider/legacy-endpoint"
383        );
384        assert_eq!(
385            signals[0]
386                .identity
387                .compatibility
388                .as_ref()
389                .map(crate::runtime_identity::LegacyUpstreamKey::stable_key)
390                .as_deref(),
391            Some("codex/primary/0")
392        );
393        assert_eq!(
394            signals[0]
395                .passive_health
396                .as_ref()
397                .map(|health| health.state),
398            Some(PassiveHealthState::Degraded)
399        );
400        assert_eq!(
401            signals[0]
402                .load_balancer
403                .as_ref()
404                .map(|view| view.failure_count),
405            Some(1)
406        );
407        assert_eq!(signals[0].balance.ok, 1);
408        assert_eq!(signals[0].balance.routing_snapshots, 1);
409    }
410}