use std::collections::BTreeMap;
use std::sync::atomic::{AtomicBool, Ordering};
use serde::Deserialize;
use super::*;
/// Top-level provider/model configuration as deserialized from TOML.
///
/// Every section is optional. The container-level `#[serde(default)]` fills any
/// missing key from `ProvidersConfig::default()`: empty maps/lists, `None`, and a
/// `tier_defaults` whose `default` is `"mid"`.
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(default)]
pub struct ProvidersConfig {
    /// Default provider, if one is configured (presumably a key of `providers`).
    pub default_provider: Option<String>,
    /// Provider definitions keyed by name.
    pub providers: BTreeMap<String, ProviderDef>,
    /// Alias definitions keyed by alias.
    pub aliases: BTreeMap<String, AliasDef>,
    /// Per-alias tool-calling definitions (presumably keyed by alias — confirm).
    pub alias_tool_calling: BTreeMap<String, AliasToolCallingDef>,
    /// Model rows keyed by model id; `patch.models` entries target these ids.
    pub models: BTreeMap<String, ModelDef>,
    /// String key/value defaults; their meaning is defined by the consumer.
    pub qc_defaults: BTreeMap<String, String>,
    /// Ordered provider-inference rules.
    pub inference_rules: Vec<InferenceRule>,
    /// Ordered tier-assignment rules.
    pub tier_rules: Vec<TierRule>,
    /// Fallback tier settings.
    pub tier_defaults: TierDefaults,
    /// Raw TOML values grouped by a pattern key.
    pub model_defaults: BTreeMap<String, BTreeMap<String, toml::Value>>,
    /// Raw TOML values grouped by role name.
    pub model_roles: BTreeMap<String, BTreeMap<String, toml::Value>>,
    /// Routes to suppress.
    pub suppress: SuppressDef,
    /// Deep-merge patches layered on top of `models`.
    pub patch: PatchDef,
}
/// The `[patch]` section: partial TOML overlays for existing model rows.
///
/// A missing `models` key deserializes to an empty map via `PatchDef::default()`.
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
#[serde(default)]
pub struct PatchDef {
    /// Patch bodies keyed by the model id they target.
    pub models: BTreeMap<String, toml::Value>,
}
/// The `[suppress]` section.
///
/// A missing `routes` key deserializes to an empty list via `SuppressDef::default()`.
#[derive(Debug, Clone, Deserialize, Default, PartialEq, Eq)]
#[serde(default)]
pub struct SuppressDef {
    /// Route names to suppress; overlays append entries not already present.
    pub routes: Vec<String>,
}
impl ProvidersConfig {
pub fn is_empty(&self) -> bool {
self.default_provider.is_none()
&& self.providers.is_empty()
&& self.aliases.is_empty()
&& self.alias_tool_calling.is_empty()
&& self.models.is_empty()
&& self.qc_defaults.is_empty()
&& self.inference_rules.is_empty()
&& self.tier_rules.is_empty()
&& self.model_defaults.is_empty()
&& self.model_roles.is_empty()
&& self.suppress.routes.is_empty()
&& self.patch.models.is_empty()
&& self.tier_defaults.default == default_mid()
}
pub fn dangling_model_patches(&self) -> Vec<&str> {
self.patch
.models
.keys()
.filter(|id| !self.models.contains_key(*id))
.map(String::as_str)
.collect()
}
pub fn merge_from(&mut self, overlay: &ProvidersConfig) {
for (name, provider) in &overlay.providers {
match self.providers.get_mut(name) {
Some(existing) => existing.merge_from(provider),
None => {
self.providers.insert(name.clone(), provider.clone());
}
}
}
self.aliases.extend(overlay.aliases.clone());
self.alias_tool_calling
.extend(overlay.alias_tool_calling.clone());
self.models.extend(overlay.models.clone());
self.qc_defaults.extend(overlay.qc_defaults.clone());
if !overlay.patch.models.is_empty() || !self.patch.models.is_empty() {
for (id, patch) in &overlay.patch.models {
match self.patch.models.get_mut(id) {
Some(existing) => deep_merge_toml(existing, patch),
None => {
self.patch.models.insert(id.clone(), patch.clone());
}
}
}
apply_model_patches(&mut self.models, &self.patch.models);
}
if overlay.default_provider.is_some() {
self.default_provider = overlay.default_provider.clone();
}
if !overlay.inference_rules.is_empty() {
let mut merged = overlay.inference_rules.clone();
merged.extend(self.inference_rules.clone());
self.inference_rules = merged;
}
if !overlay.tier_rules.is_empty() {
let mut merged = overlay.tier_rules.clone();
merged.extend(self.tier_rules.clone());
self.tier_rules = merged;
}
if overlay.tier_defaults.default != default_mid() {
self.tier_defaults = overlay.tier_defaults.clone();
}
for (pattern, defaults) in &overlay.model_defaults {
self.model_defaults
.entry(pattern.clone())
.or_default()
.extend(defaults.clone());
}
for (role, defaults) in &overlay.model_roles {
self.model_roles
.entry(role.clone())
.or_default()
.extend(defaults.clone());
}
for route in &overlay.suppress.routes {
if !self.suppress.routes.contains(route) {
self.suppress.routes.push(route.clone());
}
}
}
}
/// Recursively merges `overlay` into `base` in place.
///
/// When both values are tables, each overlay key is either merged into the
/// matching base entry (recursing through nested tables) or inserted when the
/// base lacks it. Base keys missing from the overlay are kept. In every other
/// case (scalars, arrays, or mismatched types) the base value is replaced
/// wholesale by a clone of the overlay value.
fn deep_merge_toml(base: &mut toml::Value, overlay: &toml::Value) {
    if let (toml::Value::Table(target), toml::Value::Table(source)) = (&mut *base, overlay) {
        for (field, incoming) in source {
            if let Some(slot) = target.get_mut(field) {
                deep_merge_toml(slot, incoming);
            } else {
                target.insert(field.clone(), incoming.clone());
            }
        }
        return;
    }
    *base = overlay.clone();
}
/// Set once the first model-patch failure has been printed.
///
/// Only the first failure in the process is reported; later failures
/// (including repeats of the same bad patch on later re-application) are not printed.
static MODEL_PATCH_TYPE_ERROR_WARNED: AtomicBool = AtomicBool::new(false);

/// Deep-merges each patch in `patches` into the `models` row with the same id.
///
/// Patches whose id has no row in `models` are skipped. If a patched row does
/// not round-trip back into a `ModelDef`, the row is left unpatched and the error
/// is written to stderr, but only for the first such failure in the process.
fn apply_model_patches(
    models: &mut BTreeMap<String, ModelDef>,
    patches: &BTreeMap<String, toml::Value>,
) {
    for (id, patch) in patches {
        // One mutable lookup both finds the row and lets us overwrite it in
        // place, avoiding a second lookup and a key clone on insert.
        let Some(row) = models.get_mut(id) else {
            continue;
        };
        match patched_model_row(row, patch) {
            Ok(patched) => *row = patched,
            Err(error) => {
                // Relaxed is enough: the flag gates a one-off diagnostic and
                // publishes no other data. `swap` is atomic under any ordering,
                // so exactly one caller observes `false`.
                if !MODEL_PATCH_TYPE_ERROR_WARNED.swap(true, Ordering::Relaxed) {
                    eprintln!(
                        "[llm_config] invalid [patch.models.\"{id}\"] overlay \
                         (keeping the unpatched row): {error}"
                    );
                }
            }
        }
    }
}
/// Returns a copy of `base` with `patch` deep-merged over it.
///
/// The row is serialized to a `toml::Value`, merged via `deep_merge_toml`, and
/// deserialized back into a `ModelDef`. A patch that gives a field an
/// incompatible type therefore surfaces as an `Err`.
///
/// # Errors
///
/// Returns a message when `base` cannot be serialized to TOML, or when the merged
/// value no longer deserializes as a `ModelDef`.
fn patched_model_row(base: &ModelDef, patch: &toml::Value) -> Result<ModelDef, String> {
    let mut merged = match toml::Value::try_from(base) {
        Ok(value) => value,
        Err(error) => return Err(format!("serialize base row for patching: {error}")),
    };
    deep_merge_toml(&mut merged, patch);
    match ModelDef::deserialize(merged) {
        Ok(row) => Ok(row),
        Err(error) => Err(error.to_string()),
    }
}
#[derive(Debug, Clone, Deserialize)]
pub struct InferenceRule {
#[serde(default)]
pub pattern: Option<String>,
#[serde(default)]
pub contains: Option<String>,
#[serde(default)]
pub exact: Option<String>,
pub provider: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct TierRule {
#[serde(default)]
pub pattern: Option<String>,
#[serde(default)]
pub contains: Option<String>,
#[serde(default)]
pub exact: Option<String>,
pub tier: String,
}
/// Fallback tier settings.
///
/// A missing `default` key is filled from `TierDefaults::default()`, which sets it to `"mid"`.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct TierDefaults {
    /// Default tier name.
    pub default: String,
}
impl Default for TierDefaults {
fn default() -> Self {
Self {
default: default_mid(),
}
}
}
/// The built-in default tier name, `"mid"`.
///
/// `ProvidersConfig` also treats this value as "tier defaults not set".
fn default_mid() -> String {
    String::from("mid")
}