burn_dragon_train 0.4.0

Training utilities for burn_dragon
Documentation
use std::path::PathBuf;

use serde::Serialize;
use serde::de::DeserializeOwned;

use super::{VisionTrainingConfig, load_vision_training_config};

/// Returns the `config` directory located two levels above this crate's
/// manifest directory.
///
/// The base path comes from `CARGO_MANIFEST_DIR`, which `env!` resolves at
/// compile time. The result still contains the literal `..` components; it is
/// not canonicalized.
fn config_root() -> PathBuf {
    let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    // Climb out of the crate directory twice, then descend into `config`.
    for component in ["..", "..", "config"] {
        dir.push(component);
    }
    dir
}

/// Serializes `config` to a TOML string and parses that string back into a
/// fresh `T`.
///
/// # Panics
///
/// Panics with "serialize config" if TOML serialization fails, or with
/// "parse serialized config" if the produced TOML cannot be deserialized.
fn roundtrip_config<T: Serialize + DeserializeOwned>(config: &T) -> T {
    toml::from_str(&toml::to_string(config).expect("serialize config"))
        .expect("parse serialized config")
}

/// Loads each listed vision config file on its own and validates it. It then
/// round-trips the config through TOML serialization and validates the
/// round-tripped copy too.
///
/// Every failure message names the offending file, so a broken config can be
/// identified directly from the panic output.
#[test]
fn vision_configs_parse_serialize_validate() {
    let root = config_root();
    let files = [
        "vision/base.toml",
        "vision/identity/tiny.toml",
        "vision/mae/tiny.toml",
        "vision/croco/tiny.toml",
        "vision/lejepa/tiny.toml",
        "vision/saccade/tiny.toml",
    ];

    for file in files {
        // Each file is loaded as the sole entry in the path list.
        let paths = vec![root.join(file)];
        let config: VisionTrainingConfig =
            load_vision_training_config(&paths).unwrap_or_else(|err| {
                panic!("failed to load vision config from {paths:?}: {err}");
            });
        // Include the file name: without it, a failure here cannot be traced
        // back to one of the six configs above.
        config.validate().unwrap_or_else(|err| {
            panic!("vision config validation failed for {file}: {err}")
        });

        let roundtripped: VisionTrainingConfig = roundtrip_config(&config);
        roundtripped.validate().unwrap_or_else(|err| {
            panic!("roundtripped vision config validation failed for {file}: {err}")
        });
    }
}