use std::collections::HashSet;
use std::sync::OnceLock;
use regex::Regex;
use serde_json::Value;
/// One registry row: a model-name pattern plus the request parameters that
/// are stripped for any model whose name (after removing a `provider/` prefix)
/// matches it.
#[derive(Debug)]
struct LockedEntry {
    /// Matched against the model name with any `provider/` prefix removed.
    pattern: Regex,
    /// Parameter names removed from request bodies sent to matching models.
    locked: &'static [&'static str],
}
/// Returns the table of model-name patterns and the parameters each matching
/// model family does not accept.
///
/// The table is built on first call and cached in a `OnceLock` for the rest of
/// the process. Entries are checked in the order listed here.
///
/// # Panics
/// Panics on first use if one of the hard-coded patterns is not a valid regex.
fn registry() -> &'static [LockedEntry] {
    /// Parameters locked for the `kimi-k2.*` family.
    const KIMI_K2_LOCKED: &[&str] = &[
        "temperature",
        "top_p",
        "top_k",
        "n",
        "presence_penalty",
        "frequency_penalty",
    ];
    /// Parameters locked for the OpenAI `o1` and `o3` families (identical lists).
    const OPENAI_REASONING_LOCKED: &[&str] = &[
        "temperature",
        "top_p",
        "presence_penalty",
        "frequency_penalty",
        "logprobs",
        "logit_bias",
    ];

    static TABLE: OnceLock<Vec<LockedEntry>> = OnceLock::new();
    TABLE.get_or_init(|| {
        let family = |pattern: &str, locked: &'static [&'static str]| LockedEntry {
            pattern: Regex::new(pattern).expect("static regex"),
            locked,
        };
        vec![
            family(r"^kimi-k2\.", KIMI_K2_LOCKED),
            family(r"^o1", OPENAI_REASONING_LOCKED),
            family(r"^o3", OPENAI_REASONING_LOCKED),
        ]
    })
}
/// Drops any provider prefix from a model name, keeping only the text after
/// the last `/` (e.g. `openai/o1-mini` becomes `o1-mini`).
///
/// A name without `/` is returned unchanged; a name ending in `/` yields `""`.
fn normalise(model_name: &str) -> &str {
    match model_name.rsplit_once('/') {
        Some((_, tail)) => tail,
        None => model_name,
    }
}
/// Returns the set of parameters locked for `model_name`.
///
/// Any `provider/` prefix is removed first, so `openai/o1-mini` resolves the
/// same way as `o1-mini`. The locked lists of every matching registry entry
/// are combined. An empty model name yields an empty set.
pub fn locked_params_for_model(model_name: &str) -> HashSet<&'static str> {
    if model_name.is_empty() {
        return HashSet::new();
    }
    let bare = normalise(model_name);
    registry()
        .iter()
        .filter(|entry| entry.pattern.is_match(bare))
        .flat_map(|entry| entry.locked.iter().copied())
        .collect()
}
/// Reports whether `parameter` is locked for `model_name`.
///
/// This gives the same answer as checking membership in
/// `locked_params_for_model`, but stops at the first matching entry that
/// lists the parameter and builds no set. An empty model name is never
/// locked.
pub fn is_locked(model_name: &str, parameter: &str) -> bool {
    if model_name.is_empty() {
        return false;
    }
    let bare = normalise(model_name);
    registry()
        .iter()
        .any(|entry| entry.pattern.is_match(bare) && entry.locked.contains(&parameter))
}
/// Removes every parameter locked for `model_name` from a JSON request body.
///
/// `body` is modified in place. The function returns the `(name, value)`
/// pairs that were actually present and removed, so a caller can warn about
/// them. The pairs are sorted by parameter name, so the result is the same on
/// every run. The result is empty when the model has no locked parameters or
/// when `body` is not a JSON object.
pub fn apply_sampling_params(body: &mut Value, model_name: &str) -> Vec<(String, Value)> {
    let locked = locked_params_for_model(model_name);
    if locked.is_empty() {
        return Vec::new();
    }
    let Some(obj) = body.as_object_mut() else {
        return Vec::new();
    };
    // `HashSet` uses a randomly seeded hasher, so its iteration order varies
    // between runs. Sort the names so the returned list (and any warning built
    // from it) is deterministic.
    let mut names: Vec<&'static str> = locked.into_iter().collect();
    names.sort_unstable();
    names
        .into_iter()
        .filter_map(|name| obj.remove(name).map(|value| (name.to_owned(), value)))
        .collect()
}
/// Lists every registered pattern, as its source text, together with its
/// locked parameters, in registry order.
pub fn registered_patterns() -> Vec<(String, Vec<&'static str>)> {
    let entries = registry();
    let mut listing = Vec::with_capacity(entries.len());
    for entry in entries {
        let source = entry.pattern.as_str().to_owned();
        listing.push((source, entry.locked.to_vec()));
    }
    listing
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parameters the `kimi-k2.*` family is expected to lock.
    const KIMI_K2_EXPECTED: [&str; 6] = [
        "temperature",
        "top_p",
        "top_k",
        "n",
        "presence_penalty",
        "frequency_penalty",
    ];

    #[test]
    fn empty_model_has_no_locked_params() {
        let locked = locked_params_for_model("");
        assert!(locked.is_empty());
    }

    #[test]
    fn unrelated_model_has_no_locked_params() {
        for model in ["gpt-4o", "claude-sonnet-4-5", "gemini-2.5-pro"] {
            assert!(locked_params_for_model(model).is_empty(), "{model} should lock nothing");
        }
    }

    #[test]
    fn kimi_k2_6_locks_six_params() {
        let locked = locked_params_for_model("kimi-k2.6");
        for param in KIMI_K2_EXPECTED {
            assert!(locked.contains(param), "kimi-k2.6 should lock {param}");
        }
        assert_eq!(locked.len(), KIMI_K2_EXPECTED.len());
    }

    #[test]
    fn kimi_k2_8_also_locked() {
        assert!(is_locked("kimi-k2.8", "temperature"));
    }

    #[test]
    fn kimi_k1_not_locked() {
        let model = "kimi-k1.5";
        assert!(!is_locked(model, "temperature"));
        assert!(locked_params_for_model(model).is_empty());
    }

    #[test]
    fn o1_family_locks_six_params() {
        let locked = locked_params_for_model("o1");
        for param in ["temperature", "logprobs", "logit_bias"] {
            assert!(locked.contains(param), "o1 should lock {param}");
        }
        assert_eq!(locked.len(), 6);
    }

    #[test]
    fn o1_mini_and_o1_preview_match_pattern() {
        assert!(is_locked("o1-mini", "temperature"));
        assert!(is_locked("o1-preview", "logit_bias"));
    }

    #[test]
    fn o3_family_locks_same_set_as_o1() {
        assert_eq!(locked_params_for_model("o1"), locked_params_for_model("o3-mini"));
    }

    #[test]
    fn apply_strips_locked_params_from_object() {
        let mut body = json!({
            "model": "kimi-k2.6",
            "messages": [{"role": "user", "content": "hi"}],
            "temperature": 0.5,
            "top_p": 0.9,
            "max_tokens": 2048,
        });
        let removed = apply_sampling_params(&mut body, "kimi-k2.6");
        let removed_names: HashSet<&str> = removed.iter().map(|(name, _)| name.as_str()).collect();
        for stripped in ["temperature", "top_p"] {
            assert!(removed_names.contains(stripped), "{stripped} should be reported");
            assert!(body.get(stripped).is_none(), "{stripped} should be gone");
        }
        assert!(body.get("max_tokens").is_some());
    }

    #[test]
    fn apply_no_op_for_unlocked_model() {
        let mut body = json!({
            "model": "gpt-4o",
            "temperature": 0.5,
        });
        assert!(apply_sampling_params(&mut body, "gpt-4o").is_empty());
        assert_eq!(body["temperature"], 0.5);
    }

    #[test]
    fn apply_handles_non_object_body_gracefully() {
        let mut body = json!("not an object");
        assert!(apply_sampling_params(&mut body, "kimi-k2.6").is_empty());
    }

    #[test]
    fn registered_patterns_returns_all_three_families() {
        let patterns = registered_patterns();
        assert_eq!(patterns.len(), 3);
        for expected in [r"^kimi-k2\.", r"^o1", r"^o3"] {
            assert!(
                patterns.iter().any(|(source, _)| source == expected),
                "missing pattern {expected}"
            );
        }
    }

    #[test]
    fn slug_form_strips_provider_prefix_for_openrouter() {
        let cases = [
            ("openai/o1-mini", "temperature"),
            ("openai/o3", "logprobs"),
            ("moonshot/kimi-k2.6", "top_p"),
        ];
        for (model, param) in cases {
            assert!(is_locked(model, param), "{model} should lock {param}");
        }
    }

    #[test]
    fn slug_form_does_not_widen_match_for_unrelated_models() {
        let models = [
            "openai/gpt-4o-mini",
            "anthropic/claude-sonnet-4-5",
            "google/gemini-2.5-pro",
        ];
        for model in models {
            assert!(!is_locked(model, "temperature"), "{model} should not lock temperature");
        }
    }

    #[test]
    fn slug_normalisation_idempotent_for_direct_model_names() {
        assert_eq!(
            locked_params_for_model("o1-mini"),
            locked_params_for_model("openai/o1-mini")
        );
    }

    #[test]
    fn apply_returns_removed_values_for_warning() {
        let mut body = json!({
            "model": "o1-mini",
            "temperature": 0.7,
            "logprobs": true,
        });
        let by_name: std::collections::HashMap<String, Value> =
            apply_sampling_params(&mut body, "o1-mini").into_iter().collect();
        assert_eq!(by_name.get("temperature"), Some(&json!(0.7)));
        assert_eq!(by_name.get("logprobs"), Some(&json!(true)));
    }
}