use serde::{Deserialize, Deserializer};
use std::collections::HashMap;
use crate::config::{code_predictor_config::CodePredictorConfig, rope_config::RopeScaling};
/// Per-speaker dialect marker, as stored in `TalkerConfig::spk_is_dialect`.
///
/// `NoDialect` means the speaker has no associated dialect; `Dialect` carries
/// the dialect name verbatim.
#[derive(Debug, Clone, PartialEq)]
pub enum DialectValue {
    /// The speaker has no dialect.
    NoDialect,
    /// The speaker uses the named dialect.
    Dialect(String),
}

impl DialectValue {
    /// Borrows the dialect name, or returns `None` for [`DialectValue::NoDialect`].
    pub fn as_dialect(&self) -> Option<&str> {
        match self {
            Self::Dialect(dialect_name) => Some(dialect_name.as_str()),
            Self::NoDialect => None,
        }
    }

    /// Returns `true` exactly when this value is [`DialectValue::NoDialect`].
    pub fn is_no_dialect(&self) -> bool {
        // `as_dialect` yields `None` only for the `NoDialect` variant.
        self.as_dialect().is_none()
    }
}
impl<'de> Deserialize<'de> for DialectValue {
    /// Deserializes a dialect marker from either the boolean `false`
    /// (→ [`DialectValue::NoDialect`]) or a dialect-name string
    /// (→ [`DialectValue::Dialect`]).
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error for the boolean `true` (which has no
    /// meaning as a dialect marker) and for any value that is neither a
    /// boolean nor a string.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        use serde::de::{self, Unexpected, Visitor};

        struct DialectValueVisitor;

        impl<'de> Visitor<'de> for DialectValueVisitor {
            type Value = DialectValue;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a boolean false or a dialect name string")
            }

            fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                if v {
                    // Previously both branches returned `NoDialect`, silently
                    // accepting `true`. Only `false` is a valid "no dialect" marker.
                    Err(E::invalid_value(Unexpected::Bool(v), &self))
                } else {
                    Ok(DialectValue::NoDialect)
                }
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(DialectValue::Dialect(v.to_string()))
            }

            // Owned-string fast path: reuse the allocation instead of copying.
            fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                Ok(DialectValue::Dialect(v))
            }
        }

        // `deserialize_any` lets self-describing formats dispatch to either
        // `visit_bool` or `visit_str`/`visit_string`.
        deserializer.deserialize_any(DialectValueVisitor)
    }
}
/// Configuration for the talker model, deserialized with serde.
///
/// Every field may be omitted from the input:
/// - scalars fall back to the matching `default_*` helper;
/// - `bool`/`f64` fields marked `#[serde(default)]` fall back to `false`/`0.0`;
/// - `Option` fields fall back to `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct TalkerConfig {
    /// Nested code-predictor configuration; `CodePredictorConfig::default()` when absent.
    #[serde(default)]
    pub code_predictor_config: CodePredictorConfig,

    // --- transformer shape ---
    /// Vocabulary size (default 3072).
    #[serde(default = "default_vocab_size")]
    pub vocab_size: usize,
    /// Hidden size (default 1024).
    #[serde(default = "default_hidden_size")]
    pub hidden_size: usize,
    /// Intermediate size (default 2048).
    #[serde(default = "default_intermediate_size")]
    pub intermediate_size: usize,
    /// Number of hidden layers (default 20).
    #[serde(default = "default_num_hidden_layers")]
    pub num_hidden_layers: usize,
    /// Number of attention heads (default 16).
    #[serde(default = "default_num_attention_heads")]
    pub num_attention_heads: usize,
    /// Number of key/value heads (default 2).
    #[serde(default = "default_num_key_value_heads")]
    pub num_key_value_heads: usize,
    /// Per-head dimension (default 128).
    #[serde(default = "default_head_dim")]
    pub head_dim: usize,
    /// Activation name (default `"silu"`).
    #[serde(default = "default_hidden_act")]
    pub hidden_act: String,
    /// Maximum position embeddings (default 32768).
    #[serde(default = "default_max_position_embeddings")]
    pub max_position_embeddings: usize,
    /// Initializer range (default 0.02).
    #[serde(default = "default_initializer_range")]
    pub initializer_range: f64,
    /// RMS-norm epsilon (default 1e-6).
    #[serde(default = "default_rms_norm_eps")]
    pub rms_norm_eps: f64,

    // --- rotary embeddings & attention ---
    /// RoPE base (default 10000.0).
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f64,
    /// Optional RoPE scaling; `None` when absent.
    #[serde(default)]
    pub rope_scaling: Option<RopeScaling>,
    /// Whether attention projections carry a bias (default `false`).
    #[serde(default)]
    pub attention_bias: bool,
    /// Attention dropout (default `0.0`).
    #[serde(default)]
    pub attention_dropout: f64,

    // --- codec / text vocabulary ---
    /// Number of code groups (default 32).
    #[serde(default = "default_num_code_groups")]
    pub num_code_groups: usize,
    /// Text hidden size (default 2048).
    #[serde(default = "default_text_hidden_size")]
    pub text_hidden_size: usize,
    /// Text vocabulary size (default 151936).
    #[serde(default = "default_text_vocab_size")]
    pub text_vocab_size: usize,

    // --- special codec token ids ---
    /// Codec EOS token id (default 4198).
    #[serde(default = "default_codec_eos_token_id")]
    pub codec_eos_token_id: usize,
    /// Codec BOS id (default 4197).
    #[serde(default = "default_codec_bos_id")]
    pub codec_bos_id: usize,
    /// Codec pad id (default 4196).
    #[serde(default = "default_codec_pad_id")]
    pub codec_pad_id: usize,
    /// Codec "think" id (default 4202).
    #[serde(default = "default_codec_think_id")]
    pub codec_think_id: usize,
    /// Codec "nothink" id (default 4203).
    #[serde(default = "default_codec_nothink_id")]
    pub codec_nothink_id: usize,
    /// Codec think-BOS id (default 4204).
    #[serde(default = "default_codec_think_bos_id")]
    pub codec_think_bos_id: usize,
    /// Codec think-EOS id (default 4205).
    #[serde(default = "default_codec_think_eos_id")]
    pub codec_think_eos_id: usize,

    // --- optional lookup tables ---
    /// Speaker name → id table, if provided.
    #[serde(default)]
    pub spk_id: Option<HashMap<String, usize>>,
    /// Speaker name → dialect marker (`false` or a dialect name string), if provided.
    #[serde(default)]
    pub spk_is_dialect: Option<HashMap<String, DialectValue>>,
    /// Language name → codec id table, if provided.
    #[serde(default)]
    pub codec_language_id: Option<HashMap<String, usize>>,

    // --- windowing & padding ---
    /// Sliding-window enable flag (default `false`).
    #[serde(default)]
    pub use_sliding_window: bool,
    /// Sliding-window size, if provided.
    #[serde(default)]
    pub sliding_window: Option<usize>,
    /// Pad token id, if provided.
    #[serde(default)]
    pub pad_token_id: Option<usize>,
}
/// Fallback for `TalkerConfig::vocab_size`.
fn default_vocab_size() -> usize {
    3_072
}

/// Fallback for `TalkerConfig::hidden_size`.
fn default_hidden_size() -> usize {
    1_024
}

/// Fallback for `TalkerConfig::intermediate_size`.
fn default_intermediate_size() -> usize {
    2_048
}

/// Fallback for `TalkerConfig::num_hidden_layers`.
fn default_num_hidden_layers() -> usize {
    20
}

/// Fallback for `TalkerConfig::num_attention_heads`.
fn default_num_attention_heads() -> usize {
    16
}

/// Fallback for `TalkerConfig::num_key_value_heads`.
fn default_num_key_value_heads() -> usize {
    2
}

/// Fallback for `TalkerConfig::head_dim`.
fn default_head_dim() -> usize {
    128
}

/// Fallback for `TalkerConfig::hidden_act`.
fn default_hidden_act() -> String {
    String::from("silu")
}

/// Fallback for `TalkerConfig::max_position_embeddings`.
fn default_max_position_embeddings() -> usize {
    32_768
}

/// Fallback for `TalkerConfig::initializer_range`.
fn default_initializer_range() -> f64 {
    0.02
}

/// Fallback for `TalkerConfig::rms_norm_eps`.
fn default_rms_norm_eps() -> f64 {
    1.0e-6
}

/// Fallback for `TalkerConfig::rope_theta`.
fn default_rope_theta() -> f64 {
    10_000.0
}

/// Fallback for `TalkerConfig::num_code_groups`.
fn default_num_code_groups() -> usize {
    32
}

/// Fallback for `TalkerConfig::text_hidden_size`.
fn default_text_hidden_size() -> usize {
    2_048
}

/// Fallback for `TalkerConfig::text_vocab_size`.
fn default_text_vocab_size() -> usize {
    151_936
}

/// Fallback for `TalkerConfig::codec_eos_token_id`.
fn default_codec_eos_token_id() -> usize {
    4_198
}

/// Fallback for `TalkerConfig::codec_bos_id`.
fn default_codec_bos_id() -> usize {
    4_197
}

/// Fallback for `TalkerConfig::codec_pad_id`.
fn default_codec_pad_id() -> usize {
    4_196
}

/// Fallback for `TalkerConfig::codec_think_id`.
fn default_codec_think_id() -> usize {
    4_202
}

/// Fallback for `TalkerConfig::codec_nothink_id`.
fn default_codec_nothink_id() -> usize {
    4_203
}

/// Fallback for `TalkerConfig::codec_think_bos_id`.
fn default_codec_think_bos_id() -> usize {
    4_204
}

/// Fallback for `TalkerConfig::codec_think_eos_id`.
fn default_codec_think_eos_id() -> usize {
    4_205
}
impl Default for TalkerConfig {
    /// Builds a config using the same per-field fallbacks the serde
    /// `#[serde(default…)]` attributes apply to missing fields.
    fn default() -> Self {
        Self {
            // Nested sub-config.
            code_predictor_config: CodePredictorConfig::default(),

            // Transformer shape.
            vocab_size: default_vocab_size(),
            hidden_size: default_hidden_size(),
            intermediate_size: default_intermediate_size(),
            num_hidden_layers: default_num_hidden_layers(),
            num_attention_heads: default_num_attention_heads(),
            num_key_value_heads: default_num_key_value_heads(),
            head_dim: default_head_dim(),
            hidden_act: default_hidden_act(),
            max_position_embeddings: default_max_position_embeddings(),
            initializer_range: default_initializer_range(),
            rms_norm_eps: default_rms_norm_eps(),

            // Rotary embeddings & attention.
            rope_theta: default_rope_theta(),
            rope_scaling: None,
            attention_bias: false,
            attention_dropout: 0.0,

            // Codec / text vocabulary.
            num_code_groups: default_num_code_groups(),
            text_hidden_size: default_text_hidden_size(),
            text_vocab_size: default_text_vocab_size(),

            // Special codec token ids.
            codec_eos_token_id: default_codec_eos_token_id(),
            codec_bos_id: default_codec_bos_id(),
            codec_pad_id: default_codec_pad_id(),
            codec_think_id: default_codec_think_id(),
            codec_nothink_id: default_codec_nothink_id(),
            codec_think_bos_id: default_codec_think_bos_id(),
            codec_think_eos_id: default_codec_think_eos_id(),

            // Optional lookup tables start empty (absent).
            spk_id: None,
            spk_is_dialect: None,
            codec_language_id: None,

            // Windowing & padding.
            use_sliding_window: false,
            sliding_window: None,
            pad_token_id: None,
        }
    }
}
impl TalkerConfig {
    /// Per-head attention dimension, read directly from the `head_dim` field.
    ///
    /// Inherent counterpart of the `AttentionConfig::head_dim` trait method, so
    /// callers do not need the trait in scope.
    pub fn head_dim(&self) -> usize {
        let dim = self.head_dim;
        dim
    }
}
/// Exposes the talker's attention-related fields to the shared attention layer.
impl crate::nn::attention::config::AttentionConfig for TalkerConfig {
    fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    fn num_attention_heads(&self) -> usize {
        self.num_attention_heads
    }

    fn num_key_value_heads(&self) -> usize {
        self.num_key_value_heads
    }

    fn head_dim(&self) -> usize {
        self.head_dim
    }

    fn attention_bias(&self) -> bool {
        self.attention_bias
    }

    fn rms_norm_eps(&self) -> f64 {
        self.rms_norm_eps
    }

    /// Effective sliding-window size: `Some` only when `use_sliding_window`
    /// is set and a window size was configured.
    ///
    /// Previously `sliding_window` was returned unconditionally, so a config
    /// carrying a window size with `use_sliding_window: false` would still
    /// get windowed attention. Hugging Face Qwen2/Qwen3 configs apply this
    /// same gating.
    fn sliding_window(&self) -> Option<usize> {
        self.sliding_window.filter(|_| self.use_sliding_window)
    }
}