use std::collections::HashMap;
use mistralrs_quant::{QuantizedConfig, StaticLoraConfig};
use serde::{Deserialize, Serialize};
use crate::{
layers::{Activation, Phi4MMRopeScalingConfig},
vision_models::conformer::config::ConformerEncoderConfig,
};
/// Settings for the image embedding layer, stored in
/// `Phi4MMEmbdLayerConfig::image_embd_layer`.
///
/// `embedding_cls` and `enable_gradient_checkpointing` have no serde default and
/// must therefore be present when deserializing; every `Option` field may be
/// omitted, in which case it deserializes to `None`.
// NOTE(review): the meaning of the individual fields is not visible in this file;
// the names presumably mirror the upstream model configuration — confirm there.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMImageEmbedConfig {
/// Optional embedding width (presumably overrides a model-level default — TODO confirm).
pub n_embd: Option<usize>,
/// Optional crop size (units not visible here — presumably pixels; TODO confirm).
pub crop_size: Option<usize>,
/// Name of the embedding class; required.
pub embedding_cls: String,
/// Required flag; how (or whether) it is consumed is not visible in this file.
pub enable_gradient_checkpointing: bool,
/// Optional ordering string for the HD transform (accepted values not visible here).
pub hd_transform_order: Option<String>,
/// Optional name of the image-token compression class.
pub image_token_compression_cls: Option<String>,
/// Optional name of the projection class.
pub projection_cls: Option<String>,
/// Optional switch for the HD transform.
pub use_hd_transform: Option<bool>,
/// Optional switch for a learnable separator.
pub with_learnable_separator: Option<bool>,
}
/// Settings for the audio embedding layer, stored in
/// `Phi4MMEmbdLayerConfig::audio_embd_layer`.
///
/// Only `n_embd` is optional; all other fields carry no serde default and must
/// be present when deserializing.
// NOTE(review): the meaning of the individual fields is not visible in this file;
// the names presumably mirror the upstream model configuration — confirm there.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMAudioEmbedConfig {
/// Optional embedding width (presumably overrides a model-level default — TODO confirm).
pub n_embd: Option<usize>,
/// Required compression rate (exact semantics not visible here).
pub compression_rate: usize,
/// Required downsample rate (exact semantics not visible here).
pub downsample_rate: usize,
/// Name of the embedding class; required.
pub embedding_cls: String,
/// Name of the projection class; required.
pub projection_cls: String,
}
/// Groups the per-modality embedding layer settings used by
/// `Phi4MMConfig::embd_layer`.
///
/// Both members are optional, so either (or both) may be absent from the
/// deserialized input.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMEmbdLayerConfig {
/// Image embedding settings, if provided.
pub image_embd_layer: Option<Phi4MMImageEmbedConfig>,
/// Audio embedding settings, if provided.
pub audio_embd_layer: Option<Phi4MMAudioEmbedConfig>,
}
/// Optional image-processor settings, stored in `Phi4MMConfig::img_processor`.
///
/// Both fields are optional and deserialize to `None` when omitted.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMImgProcessorConfig {
/// Signed layer index.
// NOTE(review): the signed type suggests negative values may index from the
// end of a layer list — not established by this file; confirm at the use site.
pub layer_idx: Option<isize>,
/// Feature-type selector string (accepted values not visible here).
pub type_feature: Option<String>,
}
/// Audio-processor settings, stored in `Phi4MMConfig::audio_processor`.
///
/// Pairs a conformer encoder configuration with a name; both fields are
/// required when this struct is present in the input.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMAudioConfig {
/// Configuration of the conformer encoder.
pub config: ConformerEncoderConfig,
/// Name associated with this audio processor (how it is interpreted is not visible here).
pub name: String,
}
/// Top-level, serde-(de)serializable configuration for the Phi4MM model.
///
/// Fields that are not wrapped in `Option` carry no serde default and must be
/// present in the deserialized input; `Option` fields deserialize to `None`
/// when omitted.
// NOTE(review): field names presumably mirror the upstream model `config.json`;
// the per-field notes below state only what this file shows and hedge the rest.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Phi4MMConfig {
pub vocab_size: usize,
/// Numerator of the `head_dim()` computation.
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
/// Denominator of `head_dim()` and the fallback for `num_key_value_heads()`.
pub num_attention_heads: usize,
/// Optional key/value head count; read it through the `num_key_value_heads()`
/// method, which falls back to `num_attention_heads` when this is `None`.
pub num_key_value_heads: Option<usize>,
pub resid_pdrop: f64,
pub embd_pdrop: f64,
pub attention_dropout: f64,
pub hidden_act: Activation,
pub max_position_embeddings: usize,
pub original_max_position_embeddings: usize,
pub initializer_range: f64,
pub rms_norm_eps: f64,
pub use_cache: bool,
pub tie_word_embeddings: bool,
pub rope_theta: f64,
pub rope_scaling: Option<Phi4MMRopeScalingConfig>,
pub partial_rotary_factor: f64,
pub bos_token_id: usize,
pub eos_token_id: usize,
pub pad_token_id: usize,
// NOTE(review): declared as a float although the name suggests a token id —
// presumably matches how the upstream config encodes it; confirm before changing.
pub image_input_id: Option<f64>,
pub sliding_window: Option<usize>,
/// Per-modality embedding layer settings.
pub embd_layer: Phi4MMEmbdLayerConfig,
/// Optional image-processor settings.
pub img_processor: Option<Phi4MMImgProcessorConfig>,
/// Optional audio-processor settings.
pub audio_processor: Option<Phi4MMAudioConfig>,
/// Static LoRA configuration for the vision modality; returned by `loras()`
/// under the key `"vision"`.
pub vision_lora: StaticLoraConfig,
/// Static LoRA configuration for the speech modality.
pub speech_lora: StaticLoraConfig,
/// Optional quantization settings.
pub quantization_config: Option<QuantizedConfig>,
}
impl Phi4MMConfig {
    /// Returns the number of key/value heads.
    ///
    /// Falls back to `num_attention_heads` when the `num_key_value_heads`
    /// field is `None`.
    pub fn num_key_value_heads(&self) -> usize {
        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
    }

    /// Returns `hidden_size / num_attention_heads` using integer division.
    ///
    /// # Panics
    ///
    /// Panics if `num_attention_heads` is zero.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Returns every static LoRA configuration declared on this config, keyed
    /// by modality name.
    ///
    /// The map holds a clone of `speech_lora` under `"speech"` and a clone of
    /// `vision_lora` under `"vision"`.
    pub fn loras(&self) -> HashMap<String, StaticLoraConfig> {
        // Both adapters are required fields of the config, so both must be
        // reported; leaving one out would silently hide that adapter from callers.
        HashMap::from([
            ("speech".to_string(), self.speech_lora.clone()),
            ("vision".to_string(), self.vision_lora.clone()),
        ])
    }
}