aprender-core 0.29.1

Next-generation machine learning library in pure Rust

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks which `YamlValue` variant `parse_scalar` produces for each
    /// scalar spelling exercised here: integer, float, booleans, `null`, and
    /// both double-quoted and bare strings.
    #[test]
    fn test_parse_scalar_types() {
        // Numbers.
        let integer = parse_scalar("42");
        assert!(matches!(integer, YamlValue::Int(42)));
        let float = parse_scalar("3.14");
        assert!(matches!(float, YamlValue::Float(_)));

        // Booleans.
        let yes = parse_scalar("true");
        assert!(matches!(yes, YamlValue::Bool(true)));
        let no = parse_scalar("false");
        assert!(matches!(no, YamlValue::Bool(false)));

        // Null literal.
        let nothing = parse_scalar("null");
        assert!(matches!(nothing, YamlValue::Null));

        // Strings, with and without surrounding double quotes.
        let quoted = parse_scalar("\"hello\"");
        assert!(matches!(quoted, YamlValue::String(_)));
        let bare = parse_scalar("hello");
        assert!(matches!(bare, YamlValue::String(_)));
    }

    /// A flat mapping of `key: value` pairs is readable through `get_str`,
    /// for both bare and double-quoted values.
    #[test]
    fn test_parse_simple_yaml() {
        let source = r#"
family: qwen2
display_name: "Qwen2"
vendor: Alibaba
"#;
        let doc = parse_yaml(source).expect("parse yaml");

        // (key, expected string value) pairs, checked in document order.
        let expected = [
            ("family", "qwen2"),
            ("display_name", "Qwen2"),
            ("vendor", "Alibaba"),
        ];
        for &(key, value) in expected.iter() {
            assert_eq!(doc.get_str(key), Some(value));
        }
    }

    /// An indented block under a key becomes a nested value whose entries are
    /// reachable through the typed getters.
    #[test]
    fn test_parse_nested_yaml() {
        let source = r#"
constraints:
  attention_type: gqa
  has_bias: true
  activation: silu
"#;
        let doc = parse_yaml(source).expect("parse yaml");
        let nested = doc.get("constraints").expect("constraints");

        // String-typed entry.
        let attention = nested.get_str("attention_type");
        assert_eq!(attention, Some("gqa"));

        // Boolean-typed entry.
        let bias = nested.get_bool("has_bias");
        assert_eq!(bias, Some(true));
    }

    /// A `- item` list under a key parses into a sequence that keeps its
    /// length and the value of its first element.
    #[test]
    fn test_parse_sequence_yaml() {
        let source = r#"
architectures:
  - Qwen2ForCausalLM
  - Qwen2ForSequenceClassification
"#;
        let doc = parse_yaml(source).expect("parse yaml");

        let entry = doc.get("architectures");
        let items = entry
            .and_then(|value| value.as_sequence())
            .expect("architectures");

        assert_eq!(items.len(), 2);
        let first = &items[0];
        assert_eq!(first.as_str(), Some("Qwen2ForCausalLM"));
    }

    /// Loads the `qwen2.yaml` family contract and spot-checks its metadata
    /// and the `0.5b` size variant.
    ///
    /// The fixture is looked up at `../../contracts` relative to
    /// `CARGO_MANIFEST_DIR`. When the file is not there the test is skipped,
    /// and the skip is reported on stderr instead of passing silently.
    #[test]
    fn test_load_qwen2_yaml() {
        let contracts_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../contracts");
        let yaml_path = contracts_dir.join("model-families/qwen2.yaml");
        if !yaml_path.exists() {
            // Best-effort test: no fixture means nothing to verify.
            eprintln!(
                "skipping test_load_qwen2_yaml: {} not found",
                yaml_path.display()
            );
            return;
        }

        let config = load_family_yaml(&yaml_path).expect("load qwen2 yaml");
        assert_eq!(config.family, "qwen2");
        assert_eq!(config.vendor, "Alibaba");
        assert!(!config.size_variants.is_empty());
        assert!(config.size_variants.contains_key("0.5b"));

        // Indexing cannot panic here: presence was asserted just above.
        let half_b = &config.size_variants["0.5b"];
        assert_eq!(half_b.hidden_dim, 896);
        assert_eq!(half_b.num_layers, 24);
    }

    /// Loads the family registry from the contracts directory and checks that
    /// it is non-empty and lists `qwen2`.
    ///
    /// The directory is looked up at `../../contracts` relative to
    /// `CARGO_MANIFEST_DIR`. When its `model-families` subdirectory is not
    /// there the test is skipped, and the skip is reported on stderr instead
    /// of passing silently.
    #[test]
    fn test_load_family_registry() {
        let contracts_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../contracts");
        let families_dir = contracts_dir.join("model-families");
        if !families_dir.exists() {
            // Best-effort test: no fixture means nothing to verify.
            eprintln!(
                "skipping test_load_family_registry: {} not found",
                families_dir.display()
            );
            return;
        }

        let registry = load_family_registry(&contracts_dir).expect("load registry");
        assert!(
            !registry.is_empty(),
            "Registry should have at least one family"
        );

        let names = registry.family_names();
        assert!(names.contains(&"qwen2"), "Registry should contain qwen2");
    }

    /// In a tensor template's `per_layer` map, a quoted name yields a present
    /// entry holding `Some`, while an explicit `null` yields a present entry
    /// holding `None`.
    #[test]
    fn test_parse_null_values() {
        let source = r#"
tensor_template:
  embedding: model.embed_tokens.weight
  per_layer:
    q_proj: "model.layers.{n}.self_attn.q_proj.weight"
    q_proj_bias: null
"#;
        let doc = parse_yaml(source).expect("parse yaml");
        let node = doc.get("tensor_template").expect("template");
        let template = parse_tensor_template(node).expect("parse template");

        // The key exists and carries a tensor name.
        let weight = template.per_layer.get("q_proj").expect("q_proj");
        assert!(weight.is_some());

        // The key exists too, but its value is the null marker.
        let bias = template.per_layer.get("q_proj_bias").expect("q_proj_bias");
        assert!(bias.is_none());
    }

    /// A single-quoted scalar parses to a string whose content excludes the
    /// surrounding quotes.
    #[test]
    fn test_parse_scalar_quoted_single() {
        // One parse and one unconditional check: a non-string result fails
        // the test outright instead of skipping the content assertion.
        match parse_scalar("'hello'") {
            YamlValue::String(s) => assert_eq!(s, "hello"),
            _ => panic!("expected YamlValue::String for single-quoted input"),
        }
    }

    /// The `~` shorthand parses to the null variant.
    #[test]
    fn test_parse_scalar_tilde_null() {
        let tilde = parse_scalar("~");
        assert!(matches!(tilde, YamlValue::Null));
    }

    /// `yes` and `no` are accepted as boolean spellings.
    #[test]
    fn test_parse_scalar_yes_no() {
        let affirmative = parse_scalar("yes");
        assert!(matches!(affirmative, YamlValue::Bool(true)));

        let negative = parse_scalar("no");
        assert!(matches!(negative, YamlValue::Bool(false)));
    }

    /// `as_i64` returns integers unchanged, yields `3` for the float `3.9`,
    /// and gives `None` for a string.
    #[test]
    fn test_yaml_value_as_i64() {
        let from_int = YamlValue::Int(42).as_i64();
        assert_eq!(from_int, Some(42));

        // The fractional part is dropped rather than rounded up.
        let from_float = YamlValue::Float(3.9).as_i64();
        assert_eq!(from_float, Some(3));

        let from_string = YamlValue::String(String::from("nope")).as_i64();
        assert_eq!(from_string, None);
    }

    /// `as_f64` returns floats unchanged and widens integers to `f64`.
    #[test]
    fn test_yaml_value_as_f64() {
        // 2.5 is exactly representable in binary floating point and, unlike
        // 3.14, does not trip Clippy's `approx_constant` lint.
        let v = YamlValue::Float(2.5);
        assert!((v.as_f64().expect("f64") - 2.5).abs() < f64::EPSILON);

        let v = YamlValue::Int(42);
        assert!((v.as_f64().expect("f64") - 42.0).abs() < f64::EPSILON);
    }

    /// `as_bool` yields the wrapped flag for booleans and `None` for a string.
    #[test]
    fn test_yaml_value_as_bool() {
        let from_bool = YamlValue::Bool(true).as_bool();
        assert_eq!(from_bool, Some(true));

        let from_string = YamlValue::String(String::from("nope")).as_bool();
        assert_eq!(from_string, None);
    }

    /// `as_usize` accepts a non-negative integer and rejects a negative one.
    #[test]
    fn test_yaml_value_as_usize() {
        let positive = YamlValue::Int(10).as_usize();
        assert_eq!(positive, Some(10));

        let negative = YamlValue::Int(-1).as_usize();
        assert_eq!(negative, None);
    }

    /// `as_sequence` is `Some` for a sequence value and `None` for a scalar.
    #[test]
    fn test_yaml_value_as_sequence() {
        let list = YamlValue::Sequence(vec![YamlValue::Int(1)]);
        let scalar = YamlValue::Int(1);

        assert!(list.as_sequence().is_some());
        assert!(scalar.as_sequence().is_none());
    }

    /// `as_mapping` is `Some` for a mapping value and `None` for a scalar.
    #[test]
    fn test_yaml_value_as_mapping() {
        let entries = vec![(String::from("key"), YamlValue::Int(1))];
        let map = YamlValue::Mapping(entries);
        let scalar = YamlValue::Int(1);

        assert!(map.as_mapping().is_some());
        assert!(scalar.as_mapping().is_none());
    }

    /// Every keyed getter returns `None` when called on a scalar value.
    #[test]
    fn test_yaml_value_get_on_non_mapping() {
        let scalar = YamlValue::Int(42);
        let key = "key";

        // Untyped lookup first, then each typed convenience getter.
        assert!(scalar.get(key).is_none());
        assert!(scalar.get_str(key).is_none());
        assert!(scalar.get_usize(key).is_none());
        assert!(scalar.get_f64(key).is_none());
        assert!(scalar.get_bool(key).is_none());
    }

    /// `as_str` gives `None` for an integer value.
    #[test]
    fn test_yaml_value_as_str_on_non_string() {
        let number = YamlValue::Int(42);
        let text = number.as_str();
        assert!(text.is_none());
    }

    /// Empty input parses successfully and exposes no keys.
    #[test]
    fn test_parse_empty_yaml() {
        let doc = parse_yaml("").expect("parse yaml");
        let lookup = doc.get("anything");
        assert!(lookup.is_none());
    }

    /// Input made only of `#` comment lines parses successfully and exposes
    /// no keys.
    #[test]
    fn test_parse_yaml_comments_only() {
        let source = "# just a comment\n# another comment\n";
        let doc = parse_yaml(source).expect("parse yaml");
        let lookup = doc.get("anything");
        assert!(lookup.is_none());
    }

    /// The default contracts directory is `contracts` directly under the
    /// given project root.
    #[test]
    fn test_default_contracts_dir() {
        let project_root = Path::new("/tmp/project");
        let expected = Path::new("/tmp/project/contracts");
        assert_eq!(default_contracts_dir(project_root), expected);
    }

    /// Loading from a directory that does not exist succeeds with an empty
    /// registry rather than returning an error.
    #[test]
    fn test_load_family_registry_missing_dir() {
        let missing = Path::new("/nonexistent/path");
        let outcome = load_family_registry(missing);
        assert!(outcome.is_ok());

        let registry = outcome.expect("registry");
        assert!(registry.is_empty());
    }

    /// A family document that has no top-level `family` key is rejected.
    #[test]
    fn test_parse_family_yaml_missing_family() {
        // Note the absence of a `family:` line; the remaining sections are
        // filled in, presumably so that the missing key is the only fault.
        let source = r#"
display_name: "Test"
vendor: Test
architectures:
  - TestArch
hf_pattern: "test/*"
size_variants:
  small:
    parameters: "1B"
    hidden_dim: 768
    num_layers: 12
    num_heads: 12
    intermediate_dim: 3072
    vocab_size: 30000
constraints:
  attention_type: mha
tensor_template:
  embedding: "model.embed.weight"
"#;
        let origin = Path::new("test.yaml");
        let outcome = parse_family_yaml(source, origin);
        assert!(outcome.is_err());
    }

    /// A flat `shape_template` mapping is carried into `shapes` with each
    /// shape expression kept as the original string.
    #[test]
    fn test_parse_shape_template_flat() {
        let source = r#"
shape_template:
  embedding: "[vocab_size, hidden_dim]"
  lm_head: "[vocab_size, hidden_dim]"
"#;
        let doc = parse_yaml(source).expect("parse yaml");
        let node = doc.get("shape_template").expect("shape_template");
        let parsed = parse_shape_template(node);

        // Both entries carry the same shape expression.
        let expected = "[vocab_size, hidden_dim]".to_string();
        assert_eq!(parsed.shapes.get("embedding"), Some(&expected));
        assert_eq!(parsed.shapes.get("lm_head"), Some(&expected));
    }

    /// A `chat_template` section yields its format, BOS token, and the
    /// entries of the nested `special_tokens` map.
    #[test]
    fn test_parse_chat_template() {
        let source = r#"
chat_template:
  format: chatml
  template: "test template"
  bos_token: "<s>"
  eos_token: "</s>"
  special_tokens:
    pad: "<pad>"
"#;
        let doc = parse_yaml(source).expect("parse yaml");
        let node = doc.get("chat_template").expect("chat_template");
        let chat = parse_chat_template(node).expect("parse chat template");

        assert_eq!(chat.format, "chatml");
        assert_eq!(chat.bos_token, "<s>");

        let pad = "<pad>".to_string();
        assert_eq!(chat.special_tokens.get("pad"), Some(&pad));
    }

    /// A `certification` section yields its playbook path, CSV family key,
    /// and the entries of the nested `size_categories` map.
    #[test]
    fn test_parse_certification() {
        let source = r#"
certification:
  playbook_path: "path/to/playbook.yaml"
  csv_family_key: "test"
  size_categories:
    small: tiny
    large: xlarge
"#;
        let doc = parse_yaml(source).expect("parse yaml");
        let node = doc.get("certification").expect("certification");
        let parsed = parse_certification(node).expect("parse certification");

        assert_eq!(parsed.playbook_path, "path/to/playbook.yaml");
        assert_eq!(parsed.csv_family_key, "test");

        let tiny = "tiny".to_string();
        assert_eq!(parsed.size_categories.get("small"), Some(&tiny));
    }
}