// hyperinfer_router/deployment.rs

1use hyperinfer_core::Provider;
2use serde::{Deserialize, Serialize};
3use sha2::{Digest, Sha256};
4use std::collections::HashMap;
5use std::sync::Arc;
6
/// A single routable deployment: one provider/model/credential combination
/// registered under a `model_name`.
///
/// Instances are normally created with `Deployment::new` and customised through
/// the `with_*` builder methods. All fields are public and (de)serialized by serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Deployment {
    /// Identifier of the deployment. `Deployment::new` and `with_base_url` set it
    /// to the hex-encoded SHA-256 of provider, model, base URL and API key reference.
    pub id: String,
    /// Name under which `DeploymentPool` groups this deployment. Not part of the id hash.
    pub model_name: String,
    /// Provider of the deployment; part of the id hash.
    pub provider: Provider,
    /// Model identifier; part of the id hash.
    /// NOTE(review): presumably the provider-side model name — confirm.
    pub model: String,
    /// Reference to the API key; part of the id hash.
    /// NOTE(review): presumably a lookup key rather than the secret itself — confirm.
    pub api_key_ref: String,
    /// Optional base URL; `None` unless set via `with_base_url`. Part of the id hash.
    pub base_url: Option<String>,
    /// Weight of the deployment; `Deployment::new` sets it to 1.
    /// NOTE(review): presumably a relative load-balancing weight — confirm with the router.
    pub weight: u32,
    /// Optional limit; `None` by default.
    /// NOTE(review): presumably requests per minute — confirm.
    pub rpm_limit: Option<u64>,
    /// Optional limit; `None` by default.
    /// NOTE(review): presumably tokens per minute — confirm.
    pub tpm_limit: Option<u64>,
    /// Optional input cost; `None` by default.
    /// NOTE(review): unit (per 1k tokens?) and currency are not visible here — confirm.
    pub input_cost_per_1k: Option<f64>,
    /// Optional output cost; `None` by default.
    /// NOTE(review): unit (per 1k tokens?) and currency are not visible here — confirm.
    pub output_cost_per_1k: Option<f64>,
    /// Sort key used by `DeploymentPool`, which keeps each group ascending by this
    /// value. `Deployment::new` sets it to 0.
    pub order: u32,
    /// Free-form key/value tags; empty by default.
    pub tags: HashMap<String, String>,
}
23
24impl Deployment {
25    pub fn new(model_name: String, provider: Provider, model: String, api_key_ref: String) -> Self {
26        let id = Self::generate_id(&provider, &model, &None, &api_key_ref);
27        Self {
28            id,
29            model_name,
30            provider,
31            model,
32            api_key_ref,
33            base_url: None,
34            weight: 1,
35            rpm_limit: None,
36            tpm_limit: None,
37            input_cost_per_1k: None,
38            output_cost_per_1k: None,
39            order: 0,
40            tags: HashMap::new(),
41        }
42    }
43
44    pub fn generate_id(
45        provider: &Provider,
46        model: &str,
47        base_url: &Option<String>,
48        api_key_ref: &str,
49    ) -> String {
50        let base_url_str = base_url.as_deref().unwrap_or("");
51        let input = format!("{}:{}:{}:{}", provider, model, base_url_str, api_key_ref);
52        let mut hasher = Sha256::new();
53        hasher.update(input.as_bytes());
54        let result = hasher.finalize();
55        hex::encode(result)
56    }
57
58    pub fn with_base_url(mut self, base_url: String) -> Self {
59        self.id = Self::generate_id(
60            &self.provider,
61            &self.model,
62            &Some(base_url.clone()),
63            &self.api_key_ref,
64        );
65        self.base_url = Some(base_url);
66        self
67    }
68
69    pub fn with_weight(mut self, weight: u32) -> Self {
70        self.weight = weight;
71        self
72    }
73
74    pub fn with_rpm_limit(mut self, rpm_limit: u64) -> Self {
75        self.rpm_limit = Some(rpm_limit);
76        self
77    }
78
79    pub fn with_tpm_limit(mut self, tpm_limit: u64) -> Self {
80        self.tpm_limit = Some(tpm_limit);
81        self
82    }
83
84    pub fn with_input_cost(mut self, cost: f64) -> Self {
85        self.input_cost_per_1k = Some(cost);
86        self
87    }
88
89    pub fn with_output_cost(mut self, cost: f64) -> Self {
90        self.output_cost_per_1k = Some(cost);
91        self
92    }
93
94    pub fn with_order(mut self, order: u32) -> Self {
95        self.order = order;
96        self
97    }
98
99    pub fn with_tag(mut self, key: String, value: String) -> Self {
100        self.tags.insert(key, value);
101        self
102    }
103}
104
105#[derive(Debug, Clone)]
106pub struct DeploymentPool {
107    deployments: HashMap<String, Vec<Arc<Deployment>>>,
108}
109
110impl DeploymentPool {
111    pub fn new() -> Self {
112        Self {
113            deployments: HashMap::new(),
114        }
115    }
116
117    pub fn add(&mut self, deployment: Deployment) {
118        let entry = self
119            .deployments
120            .entry(deployment.model_name.clone())
121            .or_default();
122        entry.push(Arc::new(deployment));
123        entry.sort_by_key(|d| d.order);
124    }
125
126    pub fn remove(&mut self, id: &str) -> bool {
127        let mut found_key = None;
128
129        for (key, deployments) in self.deployments.iter_mut() {
130            let initial_len = deployments.len();
131            deployments.retain(|d| d.id != id);
132            if deployments.len() < initial_len {
133                found_key = Some(key.clone());
134                break;
135            }
136        }
137
138        if let Some(ref key) = found_key {
139            if self.deployments.get(key).is_some_and(|v| v.is_empty()) {
140                self.deployments.remove(key);
141            }
142        }
143
144        found_key.is_some()
145    }
146
147    pub fn get(&self, model_name: &str) -> Option<&[Arc<Deployment>]> {
148        self.deployments.get(model_name).map(|v| v.as_slice())
149    }
150
151    pub fn model_names(&self) -> Vec<String> {
152        self.deployments.keys().cloned().collect()
153    }
154
155    pub fn rebuild(&mut self) {
156        for deployments in self.deployments.values_mut() {
157            deployments.sort_by_key(|d| d.order);
158        }
159    }
160
161    pub fn is_empty(&self) -> bool {
162        self.deployments.is_empty()
163    }
164
165    pub fn total_deployments(&self) -> usize {
166        self.deployments.values().map(|v| v.len()).sum()
167    }
168}
169
170impl Default for DeploymentPool {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deployment from string slices so the tests do not repeat
    /// `.to_string()` on every argument.
    fn deployment(
        model_name: &str,
        provider: Provider,
        model: &str,
        api_key_ref: &str,
    ) -> Deployment {
        Deployment::new(
            model_name.to_string(),
            provider,
            model.to_string(),
            api_key_ref.to_string(),
        )
    }

    /// OpenAI `gpt-4` deployment registered under `gpt-4`, varying only the key.
    fn gpt4(api_key_ref: &str) -> Deployment {
        deployment("gpt-4", Provider::OpenAI, "gpt-4", api_key_ref)
    }

    /// Anthropic `claude-3-opus` deployment registered under `claude-3`.
    fn claude3(api_key_ref: &str) -> Deployment {
        deployment("claude-3", Provider::Anthropic, "claude-3-opus", api_key_ref)
    }

    #[test]
    fn test_deployment_id_determinism() {
        assert_eq!(gpt4("key1").id, gpt4("key1").id);
    }

    #[test]
    fn test_deployment_id_differs_by_api_key() {
        assert_ne!(gpt4("key1").id, gpt4("key2").id);
    }

    #[test]
    fn test_deployment_id_differs_by_base_url() {
        let plain = gpt4("key1");
        let custom = gpt4("key1").with_base_url("https://custom.api.com".to_string());
        assert_ne!(plain.id, custom.id);
    }

    #[test]
    fn test_deployment_id_differs_by_provider() {
        let openai = deployment("model", Provider::OpenAI, "model", "key1");
        let anthropic = deployment("model", Provider::Anthropic, "model", "key1");
        assert_ne!(openai.id, anthropic.id);
    }

    #[test]
    fn test_deployment_id_matches_generate_id() {
        let url = "https://custom.api.com".to_string();
        let d = gpt4("key1").with_base_url(url.clone());
        let expected = Deployment::generate_id(&Provider::OpenAI, "gpt-4", &Some(url), "key1");
        assert_eq!(d.id, expected);
    }

    #[test]
    fn test_deployment_id_unaffected_by_other_builders() {
        let base_id = gpt4("key1").id;
        let d = gpt4("key1")
            .with_weight(5)
            .with_rpm_limit(1000)
            .with_tpm_limit(50000)
            .with_input_cost(0.03)
            .with_output_cost(0.06)
            .with_order(10)
            .with_tag("env".to_string(), "prod".to_string());
        assert_eq!(d.id, base_id);
    }

    #[test]
    fn test_deployment_builder_defaults() {
        let d = gpt4("key1");
        assert_eq!(d.weight, 1);
        assert_eq!(d.order, 0);
        assert!(d.base_url.is_none());
        assert!(d.rpm_limit.is_none());
        assert!(d.tpm_limit.is_none());
        assert!(d.input_cost_per_1k.is_none());
        assert!(d.output_cost_per_1k.is_none());
        assert!(d.tags.is_empty());
    }

    #[test]
    fn test_deployment_builder_chain() {
        let d = gpt4("key1")
            .with_base_url("https://api.openai.com".to_string())
            .with_weight(5)
            .with_rpm_limit(1000)
            .with_tpm_limit(50000)
            .with_input_cost(0.03)
            .with_output_cost(0.06)
            .with_order(10)
            .with_tag("env".to_string(), "prod".to_string())
            .with_tag("region".to_string(), "us-east".to_string());

        assert_eq!(d.base_url, Some("https://api.openai.com".to_string()));
        assert_eq!(d.weight, 5);
        assert_eq!(d.rpm_limit, Some(1000));
        assert_eq!(d.tpm_limit, Some(50000));
        assert_eq!(d.input_cost_per_1k, Some(0.03));
        assert_eq!(d.output_cost_per_1k, Some(0.06));
        assert_eq!(d.order, 10);
        assert_eq!(d.tags.len(), 2);
        assert_eq!(d.tags.get("env"), Some(&"prod".to_string()));
        assert_eq!(d.tags.get("region"), Some(&"us-east".to_string()));
    }

    #[test]
    fn test_deployment_serialization_roundtrip() {
        let d = gpt4("key1")
            .with_weight(3)
            .with_rpm_limit(500)
            .with_tag("env".to_string(), "test".to_string());

        let json = serde_json::to_string(&d).unwrap();
        let deserialized: Deployment = serde_json::from_str(&json).unwrap();

        // Compare every field so a dropped or renamed serde field is caught.
        assert_eq!(d.id, deserialized.id);
        assert_eq!(d.model_name, deserialized.model_name);
        assert_eq!(d.provider, deserialized.provider);
        assert_eq!(d.model, deserialized.model);
        assert_eq!(d.api_key_ref, deserialized.api_key_ref);
        assert_eq!(d.base_url, deserialized.base_url);
        assert_eq!(d.weight, deserialized.weight);
        assert_eq!(d.rpm_limit, deserialized.rpm_limit);
        assert_eq!(d.tpm_limit, deserialized.tpm_limit);
        assert_eq!(d.input_cost_per_1k, deserialized.input_cost_per_1k);
        assert_eq!(d.output_cost_per_1k, deserialized.output_cost_per_1k);
        assert_eq!(d.order, deserialized.order);
        assert_eq!(d.tags, deserialized.tags);
    }

    #[test]
    fn test_pool_grouping_by_model_name() {
        let mut pool = DeploymentPool::new();
        pool.add(gpt4("key1"));
        pool.add(gpt4("key2"));
        pool.add(claude3("key3"));

        assert_eq!(pool.get("gpt-4").unwrap().len(), 2);
        assert_eq!(pool.get("claude-3").unwrap().len(), 1);
        assert_eq!(pool.total_deployments(), 3);
    }

    #[test]
    fn test_pool_get_unknown_model() {
        let mut pool = DeploymentPool::new();
        assert!(pool.is_empty());
        assert!(pool.get("gpt-4").is_none());

        pool.add(gpt4("key1"));
        assert!(!pool.is_empty());
        assert!(pool.get("claude-3").is_none());
    }

    #[test]
    fn test_pool_ordering_by_order_field() {
        let mut pool = DeploymentPool::new();
        pool.add(gpt4("key1").with_order(3));
        pool.add(gpt4("key2").with_order(1));
        pool.add(gpt4("key3").with_order(2));

        let orders: Vec<u32> = pool.get("gpt-4").unwrap().iter().map(|d| d.order).collect();
        assert_eq!(orders, vec![1, 2, 3]);
    }

    #[test]
    fn test_pool_equal_order_keeps_insertion_order() {
        let mut pool = DeploymentPool::new();
        pool.add(gpt4("key1"));
        pool.add(gpt4("key2"));
        pool.add(gpt4("key3"));

        let keys: Vec<&str> = pool
            .get("gpt-4")
            .unwrap()
            .iter()
            .map(|d| d.api_key_ref.as_str())
            .collect();
        assert_eq!(keys, vec!["key1", "key2", "key3"]);
    }

    #[test]
    fn test_pool_remove() {
        let mut pool = DeploymentPool::new();
        let d1 = gpt4("key1");
        let d1_id = d1.id.clone();

        pool.add(d1);
        pool.add(gpt4("key2"));
        assert_eq!(pool.total_deployments(), 2);

        assert!(pool.remove(&d1_id));
        assert_eq!(pool.total_deployments(), 1);

        assert!(!pool.remove(&d1_id));
    }

    #[test]
    fn test_pool_remove_drops_empty_group() {
        let mut pool = DeploymentPool::new();
        let d = gpt4("key1");
        let id = d.id.clone();
        pool.add(d);

        assert!(pool.remove(&id));
        assert!(pool.is_empty());
        assert!(pool.get("gpt-4").is_none());
        assert!(pool.model_names().is_empty());
        assert_eq!(pool.total_deployments(), 0);
    }

    #[test]
    fn test_pool_model_names() {
        let mut pool = DeploymentPool::new();
        pool.add(gpt4("key1"));
        pool.add(claude3("key2"));
        pool.add(gpt4("key3"));

        let mut names = pool.model_names();
        names.sort();
        assert_eq!(names, vec!["claude-3".to_string(), "gpt-4".to_string()]);
    }

    #[test]
    fn test_pool_rebuild() {
        let mut pool = DeploymentPool::new();
        pool.add(gpt4("key2").with_order(2));
        pool.add(gpt4("key1").with_order(1));

        // `add` already sorts, so scramble the group through the private field
        // (reachable from this child module); otherwise a broken `rebuild` could
        // never make this test fail.
        pool.deployments
            .get_mut("gpt-4")
            .expect("group was just populated")
            .reverse();
        let scrambled: Vec<u32> = pool.get("gpt-4").unwrap().iter().map(|d| d.order).collect();
        assert_eq!(scrambled, vec![2, 1]);

        pool.rebuild();

        let orders: Vec<u32> = pool.get("gpt-4").unwrap().iter().map(|d| d.order).collect();
        assert_eq!(orders, vec![1, 2]);
    }
}