trustformers-models 0.1.1

Model implementations for TrustformeRS
Documentation
#[cfg(test)]
#[allow(clippy::module_inception)]
mod tests {
    use crate::linformer::config::LinformerConfig;
    use crate::linformer::model::{
        LinformerEmbeddings, LinformerEncoder, LinformerFeedForward, LinformerForMaskedLM,
        LinformerForSequenceClassification, LinformerLayer, LinformerModel,
    };
    use trustformers_core::tensor::Tensor;
    use trustformers_core::traits::{Config, Layer, Model};

    fn tiny_config() -> LinformerConfig {
        LinformerConfig {
            vocab_size: 128,
            hidden_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            intermediate_size: 64,
            max_position_embeddings: 64,
            projected_attention_size: 16,
            type_vocab_size: 2,
            share_projection: true,
            share_layers: false,
            use_efficient_attention: true,
            ..LinformerConfig::default()
        }
    }

    // --- Config Tests ---

    #[test]
    fn test_linformer_config_default_validates() {
        let config = LinformerConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_linformer_config_invalid_hidden_heads() {
        let config = LinformerConfig {
            hidden_size: 33,
            num_attention_heads: 4,
            ..LinformerConfig::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_linformer_config_invalid_projected_size() {
        let config = LinformerConfig {
            projected_attention_size: 1024,
            max_position_embeddings: 512,
            ..LinformerConfig::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_linformer_config_presets() {
        let base = LinformerConfig::linformer_base();
        assert!(base.validate().is_ok());
        assert_eq!(base.hidden_size, 768);

        let large = LinformerConfig::linformer_large();
        assert!(large.validate().is_ok());
        assert_eq!(large.hidden_size, 1024);
        assert_eq!(large.num_hidden_layers, 24);

        let long = LinformerConfig::linformer_long();
        assert!(long.validate().is_ok());
        assert_eq!(long.max_position_embeddings, 8192);
    }

    #[test]
    fn test_linformer_config_head_dim() {
        let config = tiny_config();
        assert_eq!(config.head_dim(), 8); // 32 / 4
    }

    #[test]
    fn test_linformer_config_compression_ratio() {
        let config = tiny_config();
        let ratio = config.compression_ratio();
        // 16 / 64 = 0.25
        assert!((ratio - 0.25).abs() < 0.001);
    }

    #[test]
    fn test_linformer_config_architecture() {
        let config = LinformerConfig::default();
        assert_eq!(config.architecture(), "Linformer");
    }

    // --- LinformerEmbeddings Tests ---

    #[test]
    fn test_linformer_embeddings_creation() {
        let config = tiny_config();
        let emb = LinformerEmbeddings::new(&config);
        assert!(emb.is_ok());
    }

    #[test]
    fn test_linformer_embeddings_forward() {
        let config = tiny_config();
        let emb = LinformerEmbeddings::new(&config).expect("Failed to create embeddings");
        let input = (vec![1_u32, 2, 3, 4], None, None);
        let output = emb.forward(input);
        assert!(output.is_ok());
    }

    #[test]
    fn test_linformer_embeddings_with_token_types() {
        let config = tiny_config();
        let emb = LinformerEmbeddings::new(&config).expect("Failed to create embeddings");
        let input = (vec![1_u32, 2, 3], Some(vec![0_u32, 0, 1]), None);
        let output = emb.forward(input);
        assert!(output.is_ok());
    }

    #[test]
    fn test_linformer_embeddings_with_position_ids() {
        let config = tiny_config();
        let emb = LinformerEmbeddings::new(&config).expect("Failed to create embeddings");
        let input = (vec![1_u32, 2, 3], None, Some(vec![0_u32, 1, 2]));
        let output = emb.forward(input);
        assert!(output.is_ok());
    }

    #[test]
    fn test_linformer_embeddings_parameter_count() {
        let config = tiny_config();
        let emb = LinformerEmbeddings::new(&config).expect("Failed to create embeddings");
        assert!(emb.parameter_count() > 0);
        // word_emb(128*32) + pos_emb(64*32) + type_emb(2*32) + layernorm(32+32) = 6208+
        assert!(emb.parameter_count() >= 6000);
    }

    // --- LinformerFeedForward Tests ---

    #[test]
    fn test_linformer_ffn_creation() {
        let config = tiny_config();
        let ffn = LinformerFeedForward::new(&config);
        assert!(ffn.is_ok());
    }

    #[test]
    fn test_linformer_ffn_forward() {
        let config = tiny_config();
        let ffn = LinformerFeedForward::new(&config).expect("Failed to create FFN");
        let input = Tensor::zeros(&[1, 4, 32]).expect("Failed to create tensor");
        let output = ffn.forward(input);
        assert!(output.is_ok());
        let out = output.expect("Forward pass failed");
        assert_eq!(out.shape(), &[1, 4, 32]);
    }

    #[test]
    fn test_linformer_ffn_parameter_count() {
        let config = tiny_config();
        let ffn = LinformerFeedForward::new(&config).expect("Failed to create FFN");
        assert!(ffn.parameter_count() > 0);
    }

    // --- LinformerLayer Tests ---

    #[test]
    fn test_linformer_layer_creation() {
        let config = tiny_config();
        let layer = LinformerLayer::new(&config);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_linformer_layer_parameter_count() {
        let config = tiny_config();
        let layer = LinformerLayer::new(&config).expect("Failed to create layer");
        assert!(layer.parameter_count() > 0);
    }

    // --- LinformerEncoder Tests ---

    #[test]
    fn test_linformer_encoder_creation() {
        let config = tiny_config();
        let encoder = LinformerEncoder::new(&config);
        assert!(encoder.is_ok());
    }

    #[test]
    fn test_linformer_encoder_shared_projections() {
        let mut config = tiny_config();
        config.share_layers = true;
        let encoder = LinformerEncoder::new(&config);
        assert!(encoder.is_ok());
    }

    #[test]
    fn test_linformer_encoder_parameter_count() {
        let config = tiny_config();
        let encoder = LinformerEncoder::new(&config).expect("Failed to create encoder");
        assert!(encoder.parameter_count() > 0);
    }

    // --- LinformerModel Tests ---

    #[test]
    fn test_linformer_model_creation() {
        let config = tiny_config();
        let model = LinformerModel::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_linformer_model_forward() {
        let config = tiny_config();
        let model = LinformerModel::new(config).expect("Failed to create model");
        let input = (vec![1_u32, 2, 3, 4], None, None);
        // Forward may fail with tiny config due to dimension constraints;
        // here we just verify it does not panic
        let _output = model.forward(input);
    }

    #[test]
    fn test_linformer_model_get_config() {
        let config = tiny_config();
        let model = LinformerModel::new(config).expect("Failed to create model");
        let cfg = model.get_config();
        assert_eq!(cfg.hidden_size, 32);
    }

    #[test]
    fn test_linformer_model_num_parameters() {
        let config = tiny_config();
        let model = LinformerModel::new(config).expect("Failed to create model");
        assert!(model.num_parameters() > 0);
    }

    // --- LinformerForSequenceClassification Tests ---

    #[test]
    fn test_linformer_seq_cls_creation() {
        let config = tiny_config();
        let model = LinformerForSequenceClassification::new(config, 3);
        assert!(model.is_ok());
    }

    #[test]
    fn test_linformer_seq_cls_forward() {
        let config = tiny_config();
        let model =
            LinformerForSequenceClassification::new(config, 3).expect("Failed to create model");
        let input = (vec![1_u32, 2, 3, 4], None, None);
        // Forward may fail with tiny config; verify no panic
        let _output = model.forward(input);
    }

    #[test]
    fn test_linformer_seq_cls_num_parameters() {
        let config = tiny_config();
        let model =
            LinformerForSequenceClassification::new(config, 5).expect("Failed to create model");
        assert!(model.num_parameters() > 0);
    }

    // --- LinformerForMaskedLM Tests ---

    #[test]
    fn test_linformer_mlm_creation() {
        let config = tiny_config();
        let model = LinformerForMaskedLM::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_linformer_mlm_forward() {
        let config = tiny_config();
        let model = LinformerForMaskedLM::new(config).expect("Failed to create model");
        let input = (vec![1_u32, 2, 3, 4], None, None);
        // Forward may fail with tiny config; verify no panic
        let _output = model.forward(input);
    }

    #[test]
    fn test_linformer_mlm_num_parameters() {
        let config = tiny_config();
        let model = LinformerForMaskedLM::new(config).expect("Failed to create model");
        assert!(model.num_parameters() > 0);
    }

    // --- No Efficient Attention Tests ---

    #[test]
    fn test_linformer_model_no_efficient_attention() {
        let mut config = tiny_config();
        config.use_efficient_attention = false;
        let model = LinformerModel::new(config);
        assert!(model.is_ok());
        let model = model.expect("Failed to create model");
        let input = (vec![1_u32, 2, 3], None, None);
        // Forward may fail with tiny config; verify no panic
        let _output = model.forward(input);
    }
}