burn_dragon_language 0.5.0

Language modeling components for burn_dragon
Documentation
use std::fs;
use std::path::{Path, PathBuf};

use anyhow::{Context, Result, anyhow};
use serde::de::DeserializeOwned;
use toml::Value;

/// Loads and merges a stack of configuration files, then deserializes the result into `T`.
///
/// Each entry of `paths` is loaded with its own `extends` chain resolved. The entries are then
/// deep-merged left to right (later files override earlier ones) via [`merge_values`]. The merged
/// tree is checked for removed options before deserialization.
///
/// # Errors
/// Fails if `paths` is empty, if any file cannot be canonicalized, read, or parsed, if an `extends`
/// cycle is found, if the merged tree still uses `training.fast_train`, or if the merged value
/// does not deserialize into `T`.
pub(crate) fn load_merged_toml<T>(paths: &[PathBuf]) -> Result<T>
where
    T: DeserializeOwned,
{
    let (base_path, overlay_paths) = paths
        .split_first()
        .ok_or_else(|| anyhow!("at least one configuration path is required"))?;

    let mut merged = load_value(base_path)?;
    for overlay_path in overlay_paths {
        merge_values(&mut merged, load_value(overlay_path)?);
    }

    reject_removed_training_fast_train(&merged)?;
    Ok(merged.try_into::<T>()?)
}

/// Rejects configurations that still use the removed `training.fast_train` option.
///
/// Walks the whole value tree, descending into both tables and arrays. It fails on the first
/// table whose `training` entry is itself a table containing a `fast_train` key.
///
/// # Errors
/// Returns an error naming the replacement settings when such a key is found.
fn reject_removed_training_fast_train(value: &Value) -> Result<()> {
    match value {
        Value::Table(table) => {
            let uses_fast_train = matches!(
                table.get("training"),
                Some(Value::Table(training)) if training.contains_key("fast_train")
            );
            if uses_fast_train {
                return Err(anyhow!(
                    "training.fast_train has been removed from the language config schema; use training.launch_mode and training.sequence_kernel_override explicitly"
                ));
            }
            table
                .values()
                .try_for_each(reject_removed_training_fast_train)
        }
        Value::Array(items) => items
            .iter()
            .try_for_each(reject_removed_training_fast_train),
        _ => Ok(()),
    }
}

/// Loads a single configuration file, with its `extends` chain fully resolved.
///
/// Every top-level file starts with an empty ancestry stack. Cycle detection therefore only
/// spans the `extends` chain reachable from `path`.
fn load_value(path: &Path) -> Result<Value> {
    load_value_recursive(path, &mut Vec::new())
}

fn load_value_recursive(path: &Path, stack: &mut Vec<PathBuf>) -> Result<Value> {
    let canonical = fs::canonicalize(path).with_context(|| {
        format!(
            "failed to canonicalize configuration file {}",
            path.display()
        )
    })?;
    if let Some(idx) = stack.iter().position(|seen| seen == &canonical) {
        let mut cycle = stack[idx..]
            .iter()
            .map(|path| path.display().to_string())
            .collect::<Vec<_>>();
        cycle.push(canonical.display().to_string());
        return Err(anyhow!(
            "config extends cycle detected: {}",
            cycle.join(" -> ")
        ));
    }

    stack.push(canonical.clone());
    let result = (|| {
        let content = fs::read_to_string(&canonical).with_context(|| {
            format!("failed to read configuration file {}", canonical.display())
        })?;
        let table: toml::value::Table = toml::from_str(&content)
            .with_context(|| format!("failed to parse {} as TOML", canonical.display()))?;
        let mut value = Value::Table(table);
        let extends = take_extends(&mut value)
            .with_context(|| format!("failed to parse extends in {}", canonical.display()))?;
        if let Some(extends) = extends {
            let base_dir = canonical.parent().unwrap_or_else(|| Path::new("."));
            let mut merged = Value::Table(toml::value::Table::new());
            for extend in extends {
                let extend_path = base_dir.join(extend);
                let base = load_value_recursive(&extend_path, stack)?;
                merge_values(&mut merged, base);
            }
            merge_values(&mut merged, value);
            Ok(merged)
        } else {
            Ok(value)
        }
    })();
    stack.pop();
    result
}

/// Removes the top-level `extends` key from `value` and returns the base paths it lists.
///
/// Returns `Ok(None)` when `value` is not a table or has no `extends` key. A single string
/// yields a one-element list. An array must contain only strings; they are returned in order.
///
/// # Errors
/// Fails if `extends` is neither a string nor an array, or on the first array entry that is not
/// a string. The key has already been removed from the table when an error is returned.
fn take_extends(value: &mut Value) -> Result<Option<Vec<PathBuf>>> {
    let raw = match value {
        Value::Table(table) => table.remove("extends"),
        _ => None,
    };

    let entries = match raw {
        None => return Ok(None),
        Some(Value::String(single)) => return Ok(Some(vec![PathBuf::from(single)])),
        Some(Value::Array(entries)) => entries,
        Some(other) => {
            return Err(anyhow!(
                "extends must be a string or array of strings, got {}",
                other.type_str()
            ));
        }
    };

    entries
        .into_iter()
        .map(|entry| match entry {
            Value::String(path) => Ok(PathBuf::from(path)),
            other => Err(anyhow!(
                "extends entries must be strings, got {}",
                other.type_str()
            )),
        })
        .collect::<Result<Vec<_>>>()
        .map(Some)
}

/// Deep-merges `overlay` into `base` in place.
///
/// * Table onto table: keys are merged recursively, and keys only present in `overlay` are
///   added.
/// * Any other pairing (scalar, array, or table vs. non-table): `overlay` replaces `base`
///   wholesale.
///
/// When the overlay table has a string `type` key, the base table is cleared before merging if
/// the base's own `type` differs, is not a string, or is absent while the table is non-empty.
/// This is presumably so that keys belonging to a previously selected variant do not carry over
/// into the newly selected one.
pub(crate) fn merge_values(base: &mut Value, overlay: Value) {
    let overlay_table = match overlay {
        Value::Table(table) => table,
        replacement => {
            *base = replacement;
            return;
        }
    };
    let base_table = match base {
        Value::Table(table) => table,
        slot => {
            *slot = Value::Table(overlay_table);
            return;
        }
    };

    if let Some(Value::String(incoming_type)) = overlay_table.get("type") {
        let must_reset = base_table
            .get("type")
            .map_or(!base_table.is_empty(), |existing| {
                existing.as_str() != Some(incoming_type.as_str())
            });
        if must_reset {
            base_table.clear();
        }
    }

    for (key, incoming) in overlay_table {
        if let Some(existing) = base_table.get_mut(&key) {
            merge_values(existing, incoming);
        } else {
            base_table.insert(key, incoming);
        }
    }
}