1use hyperinfer_core::Provider;
2use serde::{Deserialize, Serialize};
3use sha2::{Digest, Sha256};
4use std::collections::HashMap;
5use std::sync::Arc;
6
/// A single upstream endpoint that can serve requests for a logical model name.
///
/// Values are normally built with [`Deployment::new`] and refined with the
/// `with_*` builder methods. The struct is serializable, so field names and
/// order are part of its serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Deployment {
    /// Hex-encoded SHA-256 digest produced by [`Deployment::generate_id`] from
    /// the provider, model, base URL and API key reference. `model_name` is
    /// not part of the digest.
    pub id: String,
    /// Name under which [`DeploymentPool`] groups this deployment.
    pub model_name: String,
    /// Provider that serves this deployment.
    pub provider: Provider,
    /// Model identifier that is hashed into `id`.
    /// NOTE(review): presumably the provider-side model name — confirm.
    pub model: String,
    /// NOTE(review): presumably a reference to a stored API key rather than
    /// the secret itself — confirm against the caller.
    pub api_key_ref: String,
    /// Optional base URL; `None` unless set through `with_base_url`.
    pub base_url: Option<String>,
    /// Defaults to `1`.
    /// NOTE(review): presumably a relative load-balancing weight — confirm.
    pub weight: u32,
    /// NOTE(review): presumably a requests-per-minute cap — confirm.
    pub rpm_limit: Option<u64>,
    /// NOTE(review): presumably a tokens-per-minute cap — confirm.
    pub tpm_limit: Option<u64>,
    /// NOTE(review): presumably the cost per 1k input tokens — confirm units.
    pub input_cost_per_1k: Option<f64>,
    /// NOTE(review): presumably the cost per 1k output tokens — confirm units.
    pub output_cost_per_1k: Option<f64>,
    /// Sort key used by [`DeploymentPool`]: lower values come first within a
    /// model-name group. Defaults to `0`.
    pub order: u32,
    /// Free-form key/value labels attached to the deployment.
    pub tags: HashMap<String, String>,
}
23
24impl Deployment {
25 pub fn new(model_name: String, provider: Provider, model: String, api_key_ref: String) -> Self {
26 let id = Self::generate_id(&provider, &model, &None, &api_key_ref);
27 Self {
28 id,
29 model_name,
30 provider,
31 model,
32 api_key_ref,
33 base_url: None,
34 weight: 1,
35 rpm_limit: None,
36 tpm_limit: None,
37 input_cost_per_1k: None,
38 output_cost_per_1k: None,
39 order: 0,
40 tags: HashMap::new(),
41 }
42 }
43
44 pub fn generate_id(
45 provider: &Provider,
46 model: &str,
47 base_url: &Option<String>,
48 api_key_ref: &str,
49 ) -> String {
50 let base_url_str = base_url.as_deref().unwrap_or("");
51 let input = format!("{}:{}:{}:{}", provider, model, base_url_str, api_key_ref);
52 let mut hasher = Sha256::new();
53 hasher.update(input.as_bytes());
54 let result = hasher.finalize();
55 hex::encode(result)
56 }
57
58 pub fn with_base_url(mut self, base_url: String) -> Self {
59 self.id = Self::generate_id(
60 &self.provider,
61 &self.model,
62 &Some(base_url.clone()),
63 &self.api_key_ref,
64 );
65 self.base_url = Some(base_url);
66 self
67 }
68
69 pub fn with_weight(mut self, weight: u32) -> Self {
70 self.weight = weight;
71 self
72 }
73
74 pub fn with_rpm_limit(mut self, rpm_limit: u64) -> Self {
75 self.rpm_limit = Some(rpm_limit);
76 self
77 }
78
79 pub fn with_tpm_limit(mut self, tpm_limit: u64) -> Self {
80 self.tpm_limit = Some(tpm_limit);
81 self
82 }
83
84 pub fn with_input_cost(mut self, cost: f64) -> Self {
85 self.input_cost_per_1k = Some(cost);
86 self
87 }
88
89 pub fn with_output_cost(mut self, cost: f64) -> Self {
90 self.output_cost_per_1k = Some(cost);
91 self
92 }
93
94 pub fn with_order(mut self, order: u32) -> Self {
95 self.order = order;
96 self
97 }
98
99 pub fn with_tag(mut self, key: String, value: String) -> Self {
100 self.tags.insert(key, value);
101 self
102 }
103}
104
/// Collection of deployments grouped by their `model_name`.
///
/// Deployments are stored behind `Arc`, so cloning the pool clones the map
/// and the `Arc` handles rather than the deployments themselves.
#[derive(Debug, Clone)]
pub struct DeploymentPool {
    // Keyed by `Deployment::model_name`; `add` and `rebuild` keep each group
    // sorted by ascending `order`.
    deployments: HashMap<String, Vec<Arc<Deployment>>>,
}
109
110impl DeploymentPool {
111 pub fn new() -> Self {
112 Self {
113 deployments: HashMap::new(),
114 }
115 }
116
117 pub fn add(&mut self, deployment: Deployment) {
118 let entry = self
119 .deployments
120 .entry(deployment.model_name.clone())
121 .or_default();
122 entry.push(Arc::new(deployment));
123 entry.sort_by_key(|d| d.order);
124 }
125
126 pub fn remove(&mut self, id: &str) -> bool {
127 let mut found_key = None;
128
129 for (key, deployments) in self.deployments.iter_mut() {
130 let initial_len = deployments.len();
131 deployments.retain(|d| d.id != id);
132 if deployments.len() < initial_len {
133 found_key = Some(key.clone());
134 break;
135 }
136 }
137
138 if let Some(ref key) = found_key {
139 if self.deployments.get(key).is_some_and(|v| v.is_empty()) {
140 self.deployments.remove(key);
141 }
142 }
143
144 found_key.is_some()
145 }
146
147 pub fn get(&self, model_name: &str) -> Option<&[Arc<Deployment>]> {
148 self.deployments.get(model_name).map(|v| v.as_slice())
149 }
150
151 pub fn model_names(&self) -> Vec<String> {
152 self.deployments.keys().cloned().collect()
153 }
154
155 pub fn rebuild(&mut self) {
156 for deployments in self.deployments.values_mut() {
157 deployments.sort_by_key(|d| d.order);
158 }
159 }
160
161 pub fn is_empty(&self) -> bool {
162 self.deployments.is_empty()
163 }
164
165 pub fn total_deployments(&self) -> usize {
166 self.deployments.values().map(|v| v.len()).sum()
167 }
168}
169
170impl Default for DeploymentPool {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `Deployment::new` that accepts string slices.
    fn deployment(
        model_name: &str,
        provider: Provider,
        model: &str,
        api_key_ref: &str,
    ) -> Deployment {
        Deployment::new(
            model_name.to_string(),
            provider,
            model.to_string(),
            api_key_ref.to_string(),
        )
    }

    /// OpenAI `gpt-4` deployment; most tests only vary the API key reference.
    fn gpt4(api_key_ref: &str) -> Deployment {
        deployment("gpt-4", Provider::OpenAI, "gpt-4", api_key_ref)
    }

    #[test]
    fn test_deployment_id_determinism() {
        let first = gpt4("key1");
        let second = gpt4("key1");
        assert_eq!(first.id, second.id);
    }

    #[test]
    fn test_deployment_id_differs_by_api_key() {
        let first = gpt4("key1");
        let second = gpt4("key2");
        assert_ne!(first.id, second.id);
    }

    #[test]
    fn test_deployment_id_differs_by_base_url() {
        let plain = gpt4("key1");
        let custom = gpt4("key1").with_base_url("https://custom.api.com".to_string());
        assert_ne!(plain.id, custom.id);
    }

    #[test]
    fn test_deployment_id_differs_by_provider() {
        let openai = deployment("model", Provider::OpenAI, "model", "key1");
        let anthropic = deployment("model", Provider::Anthropic, "model", "key1");
        assert_ne!(openai.id, anthropic.id);
    }

    #[test]
    fn test_deployment_builder_defaults() {
        let d = gpt4("key1");

        assert_eq!(d.weight, 1);
        assert_eq!(d.order, 0);
        assert!(d.base_url.is_none());
        assert!(d.rpm_limit.is_none());
        assert!(d.tpm_limit.is_none());
        assert!(d.input_cost_per_1k.is_none());
        assert!(d.output_cost_per_1k.is_none());
        assert!(d.tags.is_empty());
    }

    #[test]
    fn test_deployment_builder_chain() {
        let d = gpt4("key1")
            .with_base_url("https://api.openai.com".to_string())
            .with_weight(5)
            .with_rpm_limit(1000)
            .with_tpm_limit(50000)
            .with_input_cost(0.03)
            .with_output_cost(0.06)
            .with_order(10)
            .with_tag("env".to_string(), "prod".to_string())
            .with_tag("region".to_string(), "us-east".to_string());

        assert_eq!(d.base_url, Some("https://api.openai.com".to_string()));
        assert_eq!(d.weight, 5);
        assert_eq!(d.rpm_limit, Some(1000));
        assert_eq!(d.tpm_limit, Some(50000));
        assert_eq!(d.input_cost_per_1k, Some(0.03));
        assert_eq!(d.output_cost_per_1k, Some(0.06));
        assert_eq!(d.order, 10);
        assert_eq!(d.tags.len(), 2);
        assert_eq!(d.tags.get("env"), Some(&"prod".to_string()));
        assert_eq!(d.tags.get("region"), Some(&"us-east".to_string()));
    }

    #[test]
    fn test_deployment_serialization_roundtrip() {
        let original = gpt4("key1")
            .with_weight(3)
            .with_rpm_limit(500)
            .with_tag("env".to_string(), "test".to_string());

        let json = serde_json::to_string(&original).unwrap();
        let restored: Deployment = serde_json::from_str(&json).unwrap();

        assert_eq!(original.id, restored.id);
        assert_eq!(original.model_name, restored.model_name);
        assert_eq!(original.provider, restored.provider);
        assert_eq!(original.model, restored.model);
        assert_eq!(original.weight, restored.weight);
        assert_eq!(original.rpm_limit, restored.rpm_limit);
        assert_eq!(original.tags, restored.tags);
    }

    #[test]
    fn test_pool_grouping_by_model_name() {
        let mut pool = DeploymentPool::new();
        pool.add(gpt4("key1"));
        pool.add(gpt4("key2"));
        pool.add(deployment(
            "claude-3",
            Provider::Anthropic,
            "claude-3-opus",
            "key3",
        ));

        assert_eq!(pool.get("gpt-4").unwrap().len(), 2);
        assert_eq!(pool.get("claude-3").unwrap().len(), 1);
        assert_eq!(pool.total_deployments(), 3);
    }

    #[test]
    fn test_pool_ordering_by_order_field() {
        let mut pool = DeploymentPool::new();
        // Inserted out of order on purpose; the pool must sort ascending.
        for (key, order) in [("key1", 3), ("key2", 1), ("key3", 2)] {
            pool.add(gpt4(key).with_order(order));
        }

        let orders: Vec<u32> = pool
            .get("gpt-4")
            .unwrap()
            .iter()
            .map(|d| d.order)
            .collect();
        assert_eq!(orders, vec![1, 2, 3]);
    }

    #[test]
    fn test_pool_remove() {
        let mut pool = DeploymentPool::new();
        let target = gpt4("key1");
        let target_id = target.id.clone();

        pool.add(target);
        pool.add(gpt4("key2"));
        assert_eq!(pool.total_deployments(), 2);

        assert!(pool.remove(&target_id));
        assert_eq!(pool.total_deployments(), 1);

        // A second removal of the same ID finds nothing.
        assert!(!pool.remove(&target_id));
    }

    #[test]
    fn test_pool_model_names() {
        let mut pool = DeploymentPool::new();
        pool.add(gpt4("key1"));
        pool.add(deployment(
            "claude-3",
            Provider::Anthropic,
            "claude-3-opus",
            "key2",
        ));
        pool.add(gpt4("key3"));

        // Map iteration order is unspecified, so sort before comparing.
        let mut names = pool.model_names();
        names.sort();
        assert_eq!(names, vec!["claude-3".to_string(), "gpt-4".to_string()]);
    }

    #[test]
    fn test_pool_rebuild() {
        let mut pool = DeploymentPool::new();
        let low = gpt4("key1").with_order(1);
        let high = gpt4("key2").with_order(2);

        pool.add(high);
        pool.add(low);

        let before = pool.get("gpt-4").unwrap();
        assert_eq!(before[0].order, 1);
        assert_eq!(before[1].order, 2);

        pool.rebuild();

        let after = pool.get("gpt-4").unwrap();
        assert_eq!(after[0].order, 1);
        assert_eq!(after[1].order, 2);
    }
}