use cubecl_common::config::logger::{LogLevel, LoggerConfig};
#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct FusionConfig {
#[serde(default)]
pub logger: LoggerConfig<FusionLogLevel>,
#[serde(default)]
pub beam_search: BeamSearchConfig,
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct BeamSearchConfig {
#[serde(default = "default_max_blocks")]
pub max_blocks: usize,
}
impl Default for BeamSearchConfig {
fn default() -> Self {
Self {
max_blocks: default_max_blocks(),
}
}
}
fn default_max_blocks() -> usize {
5
}
#[derive(
Default,
Clone,
Copy,
Debug,
PartialEq,
Eq,
PartialOrd,
Ord,
serde::Serialize,
serde::Deserialize,
)]
pub enum FusionLogLevel {
#[default]
#[serde(rename = "disabled")]
Disabled,
#[serde(rename = "basic")]
Basic,
#[serde(rename = "medium")]
Medium,
#[serde(rename = "full")]
Full,
}
impl LogLevel for FusionLogLevel {}