use crate::automl::space::{categorical, int_range, linear_range, ParamMap, SearchSpace};
use irithyll_core::attention::AttentionMode;
use irithyll_core::learner::StreamingLearner;
use super::{Factory, FactoryError};
/// Hyper-parameter search space for the GLA (linear attention) learner.
///
/// `n_heads` is categorical over powers of two so sampled values are likely
/// divisors of common model widths; `forgetting_factor` and `warmup` are
/// shared with the other attention-family spaces.
fn linear_attention_search_space() -> SearchSpace {
    let heads = categorical(&[1u32, 2, 4, 8]);
    let decay = linear_range(0.95, 0.9999);
    let warm = int_range(5, 50);
    SearchSpace::builder()
        .param("n_heads", heads)
        .param("forgetting_factor", decay)
        .param("warmup", warm)
        .build()
        .expect("linear_attention_search_space: builder produces a valid space by construction")
}
/// Hyper-parameter search space for the DeltaProduct learner.
///
/// Adds `n_compositions` on top of the common head-count / forgetting-factor /
/// warmup parameters used by the other attention-family spaces.
fn delta_product_search_space() -> SearchSpace {
    let heads = int_range(1, 8);
    let compositions = int_range(1, 4);
    let decay = linear_range(0.95, 0.9999);
    let warm = int_range(5, 50);
    SearchSpace::builder()
        .param("n_heads", heads)
        .param("n_compositions", compositions)
        .param("forgetting_factor", decay)
        .param("warmup", warm)
        .build()
        .expect("delta_product_search_space: builder produces a valid space by construction")
}
/// Hyper-parameter search space for the RWKV-7 learner.
///
/// Uses the same head-count / forgetting-factor / warmup ranges as the other
/// attention-family spaces, with no extra parameters.
fn rwkv7_search_space() -> SearchSpace {
    let heads = int_range(1, 8);
    let decay = linear_range(0.95, 0.9999);
    let warm = int_range(5, 50);
    SearchSpace::builder()
        .param("n_heads", heads)
        .param("forgetting_factor", decay)
        .param("warmup", warm)
        .build()
        .expect("rwkv7_search_space: builder produces a valid space by construction")
}
impl Factory {
    /// Preset factory for the GLA streaming-attention learner.
    ///
    /// `d_model` is the input feature dimension; tunable hyper-parameters
    /// come from [`linear_attention_search_space`].
    pub fn attention(d_model: usize) -> Self {
        Self {
            algorithm: super::Algorithm::Attention,
            n_features: d_model,
            space: linear_attention_search_space(),
            warmup: 10,
            complexity: 8000,
            seed: 42,
            accuracy_based_pruning: false,
            proactive_prune_interval: None,
            prune_half_life: None,
            projection: None,
        }
    }

    /// Preset factory for the DeltaProduct streaming-attention learner.
    ///
    /// Shares the defaults of [`Factory::attention`] but draws
    /// hyper-parameters from [`delta_product_search_space`].
    pub fn delta_product(d_model: usize) -> Self {
        Self {
            algorithm: super::Algorithm::DeltaProduct,
            n_features: d_model,
            space: delta_product_search_space(),
            warmup: 10,
            complexity: 8000,
            seed: 42,
            accuracy_based_pruning: false,
            proactive_prune_interval: None,
            prune_half_life: None,
            projection: None,
        }
    }

    /// Preset factory for the RWKV-7 streaming-attention learner.
    ///
    /// Uses a lower complexity budget (5000) than the other attention
    /// presets; hyper-parameters come from [`rwkv7_search_space`].
    pub fn rwkv7(d_model: usize) -> Self {
        Self {
            algorithm: super::Algorithm::Rwkv7,
            n_features: d_model,
            space: rwkv7_search_space(),
            warmup: 10,
            complexity: 5000,
            seed: 42,
            accuracy_based_pruning: false,
            proactive_prune_interval: None,
            prune_half_life: None,
            projection: None,
        }
    }

    /// Rejects head counts that do not evenly divide the model dimension,
    /// since each head must own an equal slice of `d_model`.
    ///
    /// `d_model == 0` is let through so the config builder can report it as
    /// its own error; `n_heads == 0` is also let through (the modulo is only
    /// taken when `n_heads > 0`, avoiding a divide-by-zero) and likewise left
    /// to the builder to reject.
    fn ensure_heads_divide(d_model: usize, n_heads: usize) -> Result<(), FactoryError> {
        if d_model > 0 && n_heads > 0 && d_model % n_heads != 0 {
            return Err(FactoryError::IncompatibleArm {
                reason: format!("n_heads={n_heads} does not divide d_model={d_model}"),
            });
        }
        Ok(())
    }

    /// Builds a boxed [`StreamingLearner`] for one sampled hyper-parameter
    /// arm of an attention-family algorithm.
    ///
    /// # Errors
    ///
    /// Returns [`FactoryError::IncompatibleArm`] when a sampled `n_heads`
    /// cannot be parsed (Attention only, whose space is categorical) or does
    /// not divide `d_model`; propagates errors from the parameter lookups and
    /// the config builder.
    ///
    /// # Panics
    ///
    /// Panics if `self.algorithm` is not an attention-family variant — an
    /// internal invariant of the crate-private call sites.
    pub(crate) fn create_attention(
        &self,
        params: &ParamMap,
    ) -> Result<Box<dyn StreamingLearner>, FactoryError> {
        use crate::attention::{StreamingAttentionConfig, StreamingAttentionModel};
        match self.algorithm {
            super::Algorithm::Attention => {
                // The attention space stores n_heads as a categorical string;
                // parse it back into an integer.
                let n_heads_str = params.category("n_heads")?;
                let n_heads: usize =
                    n_heads_str
                        .as_str()
                        .parse()
                        .map_err(|_| FactoryError::IncompatibleArm {
                            reason: format!(
                                "attention: n_heads category '{n_heads_str}' is not a valid integer"
                            ),
                        })?;
                let forgetting_factor = params.float("forgetting_factor")?;
                let warmup = params.usize("warmup")?;
                let d_model = self.n_features;
                Self::ensure_heads_divide(d_model, n_heads)?;
                let attn_config = StreamingAttentionConfig::builder()
                    .d_model(d_model)
                    .n_heads(n_heads)
                    .mode(AttentionMode::GLA)
                    .forgetting_factor(forgetting_factor)
                    .warmup(warmup)
                    .build()?;
                Ok(Box::new(StreamingAttentionModel::new(attn_config)))
            }
            super::Algorithm::DeltaProduct => {
                // Integer-ranged space; clamp to at least one head/composition.
                let n_heads = params.usize("n_heads")?.max(1);
                let n_compositions = params.usize("n_compositions")?.max(1);
                let forgetting_factor = params.float("forgetting_factor")?;
                let warmup = params.usize("warmup")?;
                let d_model = self.n_features;
                Self::ensure_heads_divide(d_model, n_heads)?;
                let attn_config = StreamingAttentionConfig::builder()
                    .d_model(d_model)
                    .n_heads(n_heads)
                    .mode(AttentionMode::DeltaProduct {
                        n_compositions,
                        reflections: false,
                    })
                    .forgetting_factor(forgetting_factor)
                    .warmup(warmup)
                    .build()?;
                Ok(Box::new(StreamingAttentionModel::new(attn_config)))
            }
            super::Algorithm::Rwkv7 => {
                // Integer-ranged space; clamp to at least one head.
                let n_heads = params.usize("n_heads")?.max(1);
                let forgetting_factor = params.float("forgetting_factor")?;
                let warmup = params.usize("warmup")?;
                let d_model = self.n_features;
                Self::ensure_heads_divide(d_model, n_heads)?;
                let attn_config = StreamingAttentionConfig::builder()
                    .d_model(d_model)
                    .n_heads(n_heads)
                    .mode(AttentionMode::RWKV7)
                    .forgetting_factor(forgetting_factor)
                    .warmup(warmup)
                    .build()?;
                Ok(Box::new(StreamingAttentionModel::new(attn_config)))
            }
            // Internal invariant: dispatch elsewhere in the crate must route
            // only attention-family algorithms here.
            _ => panic!("create_attention called on non-attention algorithm"),
        }
    }
}