Skip to main content

llmix_rs/
key_pool.rs

1use crate::error::KeyPoolExhaustedError;
2use crate::{LlmixError, LlmixResult};
3use std::collections::HashSet;
4use std::env;
5use std::sync::Mutex;
6
/// Mutable rotation state of a key pool, kept behind a mutex by its owner.
struct KeyPoolState {
    /// Keys that have been marked dead; these are raw credential strings.
    dead: HashSet<String>,
    /// Index of the currently selected key in the owner's key list.
    idx: usize,
}

// Hand-written instead of derived: a derived `Debug` would print the dead key
// strings verbatim, leaking API credentials into logs or error output.
impl std::fmt::Debug for KeyPoolState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyPoolState")
            .field("dead_count", &self.dead.len())
            .field("idx", &self.idx)
            .finish()
    }
}
12
13#[derive(Debug)]
14pub struct KeyPool {
15    keys: Vec<String>,
16    state: Mutex<KeyPoolState>,
17}
18
19impl KeyPool {
20    pub fn new(keys: Vec<String>) -> LlmixResult<Self> {
21        let mut cleaned = Vec::new();
22        let mut seen = HashSet::new();
23
24        for key in keys {
25            let trimmed = key.trim();
26            if trimmed.is_empty() {
27                continue;
28            }
29            if seen.insert(trimmed.to_owned()) {
30                cleaned.push(trimmed.to_owned());
31            }
32        }
33
34        if cleaned.is_empty() {
35            return Err(LlmixError::InvalidKeyPoolConfig(
36                "KeyPool requires at least one non-empty key".to_owned(),
37            ));
38        }
39
40        Ok(Self {
41            keys: cleaned,
42            state: Mutex::new(KeyPoolState {
43                dead: HashSet::new(),
44                idx: 0,
45            }),
46        })
47    }
48
49    pub fn select(&self) -> Result<String, KeyPoolExhaustedError> {
50        let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
51
52        if state.dead.len() >= self.keys.len() {
53            return Err(KeyPoolExhaustedError {
54                total_keys: self.keys.len(),
55            });
56        }
57
58        for offset in 0..self.keys.len() {
59            let idx = (state.idx + offset) % self.keys.len();
60            let key = &self.keys[idx];
61            if !state.dead.contains(key) {
62                state.idx = idx;
63                return Ok(key.clone());
64            }
65        }
66
67        Err(KeyPoolExhaustedError {
68            total_keys: self.keys.len(),
69        })
70    }
71
72    pub fn rotate(&self) {
73        let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
74
75        if state.dead.len() >= self.keys.len() {
76            return;
77        }
78
79        for offset in 1..=self.keys.len() {
80            let idx = (state.idx + offset) % self.keys.len();
81            if !state.dead.contains(&self.keys[idx]) {
82                state.idx = idx;
83                return;
84            }
85        }
86    }
87
88    pub fn mark_dead(&self, key: &str) -> LlmixResult<()> {
89        if !self.keys.iter().any(|candidate| candidate == key) {
90            return Err(LlmixError::UnknownKeyPoolKey(
91                "Key not in pool: cannot mark unknown key as dead".to_owned(),
92            ));
93        }
94
95        let mut state = self.state.lock().unwrap_or_else(|e| e.into_inner());
96        state.dead.insert(key.to_owned());
97        Ok(())
98    }
99
100    pub fn is_exhausted(&self) -> bool {
101        let state = self.state.lock().unwrap_or_else(|e| e.into_inner());
102        state.dead.len() >= self.keys.len()
103    }
104
105    pub fn alive_count(&self) -> usize {
106        let state = self.state.lock().unwrap_or_else(|e| e.into_inner());
107        self.keys.len() - state.dead.len()
108    }
109
110    pub fn total_count(&self) -> usize {
111        self.keys.len()
112    }
113}
114
115pub fn load_keys_from_env(provider: &str) -> LlmixResult<KeyPool> {
116    let provider_upper = provider
117        .chars()
118        .map(|character| match character {
119            'a'..='z' => character.to_ascii_uppercase(),
120            'A'..='Z' | '0'..='9' => character,
121            _ => '_',
122        })
123        .collect::<String>();
124    let keys_var = format!("{provider_upper}_KEYS");
125    let key_var = format!("{provider_upper}_API_KEY");
126
127    if let Ok(keys_raw) = env::var(&keys_var) {
128        if !keys_raw.trim().is_empty() {
129            return KeyPool::new(keys_raw.split(',').map(ToOwned::to_owned).collect());
130        }
131    }
132
133    if let Ok(single_key) = env::var(&key_var) {
134        if !single_key.trim().is_empty() {
135            return KeyPool::new(vec![single_key]);
136        }
137    }
138
139    Err(LlmixError::InvalidKeyPoolConfig(format!(
140        "No API keys found for {provider_upper}. Set {keys_var} (comma-separated) or {key_var}."
141    )))
142}