use std::collections::HashSet;
use std::sync::Mutex;
use serde_json::{Map, Value};
use crate::error::UnsupportedModelError;
mod defaults;
#[cfg(test)]
mod tests;
pub use defaults::MODEL_SAMPLING_DEFAULTS;
/// Model names for which `fire_one_shot_info` has already emitted its hint.
///
/// Wrapped in `Option` so the static can be initialized with a constant
/// expression. The set itself is created lazily (via `get_or_insert_with`)
/// the first time the hint logic takes the lock.
static INFO_LOGGED: Mutex<Option<HashSet<String>>> = Mutex::new(None);
/// Returns the recommended sampling parameters registered for `model`.
///
/// The result is an owned copy of the model's entry in
/// [`MODEL_SAMPLING_DEFAULTS`]. When `model` has no entry, an empty map is
/// returned instead of an error.
pub fn get_sampling_defaults(model: &str) -> Map<String, Value> {
    MODEL_SAMPLING_DEFAULTS
        .get(model)
        .map_or_else(Map::new, |entry| entry.clone())
}
/// Resolves which sampling parameters should be applied for `model`.
///
/// * `strict == true`: returns the model's registered defaults.
/// * `strict == false`: always returns an empty map. If the model does have
///   registered defaults, an informational hint suggesting strict mode is
///   logged via `fire_one_shot_info`.
///
/// A model counts as having defaults only when [`get_sampling_defaults`]
/// returns a non-empty map.
///
/// # Errors
///
/// Returns [`UnsupportedModelError`] when `strict` is set and `model` has no
/// registered defaults.
pub fn apply_sampling_defaults(
    model: &str,
    strict: bool,
) -> Result<Map<String, Value>, UnsupportedModelError> {
    let registered = get_sampling_defaults(model);
    let has_defaults = !registered.is_empty();

    if !strict {
        // Lenient mode never injects parameters; it only nudges the caller
        // toward strict mode when there is something worth opting into.
        if has_defaults {
            fire_one_shot_info(model);
        }
        return Ok(Map::new());
    }

    if has_defaults {
        Ok(registered)
    } else {
        Err(UnsupportedModelError::new(model))
    }
}
/// Logs an info-level hint that `model` has recommended sampling defaults
/// available through strict mode, at most once per distinct model name.
///
/// Already-announced names are tracked in [`INFO_LOGGED`]; the set is created
/// on first use.
fn fire_one_shot_info(model: &str) {
    let newly_seen = {
        // Poisoning only records that some thread panicked while holding the
        // lock. The worst outcome for a set of announced names is a missed or
        // repeated hint, so recover the guard instead of silently disabling
        // the hint for the rest of the process.
        let mut guard = INFO_LOGGED
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let logged = guard.get_or_insert_with(HashSet::new);
        // Check first so repeated calls for the same model don't allocate.
        let unseen = !logged.contains(model);
        if unseen {
            logged.insert(model.to_owned());
        }
        unseen
    };

    // Log after the guard is dropped so a slow or panicking logger can neither
    // stall other callers nor poison the lock.
    if newly_seen {
        log::info!(
            "Model '{}' has recommended sampling defaults. \
             Consider opting in with strict mode for optimal behavior.",
            model
        );
    }
}