1use crate::error::KeyPoolExhaustedError;
2use crate::{LlmixError, LlmixResult};
3use std::collections::HashSet;
4use std::env;
5use std::sync::Mutex;
6
/// Mutable bookkeeping for a [`KeyPool`]; stored inside the pool's `Mutex`.
#[derive(Debug)]
struct KeyPoolState {
    /// Keys that have been marked dead. `KeyPool::mark_dead` only inserts keys that are
    /// present in `KeyPool::keys`, so this is always a subset of the pool's keys.
    dead: HashSet<String>,
    /// Cursor: index into `KeyPool::keys` of the currently selected key. `select` and
    /// `rotate` scan forward from here (wrapping around) and move it onto an alive key.
    idx: usize,
}
12
/// A pool of API keys with a rotating cursor and a set of keys marked dead.
///
/// All methods take `&self`; the mutable part (cursor and dead set) lives behind a
/// [`Mutex`], so one pool can be shared between threads.
#[derive(Debug)]
pub struct KeyPool {
    /// Keys as cleaned by `KeyPool::new`: trimmed, non-empty, de-duplicated, in their
    /// original order. `new` refuses to build a pool with no keys, so this is never empty.
    keys: Vec<String>,
    /// Cursor and dead-key set, guarded by a mutex.
    state: Mutex<KeyPoolState>,
}
18
19impl KeyPool {
20 pub fn new(keys: Vec<String>) -> LlmixResult<Self> {
21 let mut cleaned = Vec::new();
22 let mut seen = HashSet::new();
23
24 for key in keys {
25 let trimmed = key.trim();
26 if trimmed.is_empty() {
27 continue;
28 }
29 if seen.insert(trimmed.to_owned()) {
30 cleaned.push(trimmed.to_owned());
31 }
32 }
33
34 if cleaned.is_empty() {
35 return Err(LlmixError::InvalidKeyPoolConfig(
36 "KeyPool requires at least one non-empty key".to_owned(),
37 ));
38 }
39
40 Ok(Self {
41 keys: cleaned,
42 state: Mutex::new(KeyPoolState {
43 dead: HashSet::new(),
44 idx: 0,
45 }),
46 })
47 }
48
49 pub fn select(&self) -> Result<String, KeyPoolExhaustedError> {
50 let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
51
52 if state.dead.len() >= self.keys.len() {
53 return Err(KeyPoolExhaustedError {
54 total_keys: self.keys.len(),
55 });
56 }
57
58 for offset in 0..self.keys.len() {
59 let idx = (state.idx + offset) % self.keys.len();
60 let key = &self.keys[idx];
61 if !state.dead.contains(key) {
62 state.idx = idx;
63 return Ok(key.clone());
64 }
65 }
66
67 Err(KeyPoolExhaustedError {
68 total_keys: self.keys.len(),
69 })
70 }
71
72 pub fn rotate(&self) {
73 let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
74
75 if state.dead.len() >= self.keys.len() {
76 return;
77 }
78
79 for offset in 1..=self.keys.len() {
80 let idx = (state.idx + offset) % self.keys.len();
81 if !state.dead.contains(&self.keys[idx]) {
82 state.idx = idx;
83 return;
84 }
85 }
86 }
87
88 pub fn mark_dead(&self, key: &str) -> LlmixResult<()> {
89 if !self.keys.iter().any(|candidate| candidate == key) {
90 return Err(LlmixError::UnknownKeyPoolKey(
91 "Key not in pool: cannot mark unknown key as dead".to_owned(),
92 ));
93 }
94
95 let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
96 state.dead.insert(key.to_owned());
97 Ok(())
98 }
99
100 pub fn is_exhausted(&self) -> bool {
101 let state = self.state.lock().unwrap_or_else(|e| e.into_inner());
102 state.dead.len() >= self.keys.len()
103 }
104
105 pub fn alive_count(&self) -> usize {
106 let state = self.state.lock().unwrap_or_else(|e| e.into_inner());
107 self.keys.len() - state.dead.len()
108 }
109
110 pub fn total_count(&self) -> usize {
111 self.keys.len()
112 }
113}
114
115pub fn load_keys_from_env(provider: &str) -> LlmixResult<KeyPool> {
116 let provider_upper = provider
117 .chars()
118 .map(|character| match character {
119 'a'..='z' => character.to_ascii_uppercase(),
120 'A'..='Z' | '0'..='9' => character,
121 _ => '_',
122 })
123 .collect::<String>();
124 let keys_var = format!("{provider_upper}_KEYS");
125 let key_var = format!("{provider_upper}_API_KEY");
126
127 if let Ok(keys_raw) = env::var(&keys_var) {
128 if !keys_raw.trim().is_empty() {
129 return KeyPool::new(keys_raw.split(',').map(ToOwned::to_owned).collect());
130 }
131 }
132
133 if let Ok(single_key) = env::var(&key_var) {
134 if !single_key.trim().is_empty() {
135 return KeyPool::new(vec![single_key]);
136 }
137 }
138
139 Err(LlmixError::InvalidKeyPoolConfig(format!(
140 "No API keys found for {provider_upper}. Set {keys_var} (comma-separated) or {key_var}."
141 )))
142}