use std::fmt;
/// Algorithmic task a Stoicheia configuration targets.
///
/// `Display` renders a fixed lowercase identifier per variant
/// (e.g. `SecondArgmax` → `2nd_argmax`). Width/fill flags requested by the
/// caller are not applied.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StoicheiaTask {
    /// Displayed as `2nd_argmax` (presumably: locate the second-largest input — TODO confirm).
    SecondArgmax,
    /// Displayed as `argmedian` (presumably: locate the median input — TODO confirm).
    Argmedian,
    /// Displayed as `median`.
    Median,
    /// Displayed as `longest_cycle`.
    LongestCycle,
}

impl fmt::Display for StoicheiaTask {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Emitted verbatim via `write_str`, so any padding/alignment on `f` is ignored.
        let label = match *self {
            Self::SecondArgmax => "2nd_argmax",
            Self::Argmedian => "argmedian",
            Self::Median => "median",
            Self::LongestCycle => "longest_cycle",
        };
        f.write_str(label)
    }
}
/// Model family used for a configuration; selects the parameter-count formula.
///
/// `Display` renders the lowercase family name (`rnn` / `transformer`) with no
/// padding applied.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StoicheiaArch {
    /// Recurrent model; displayed as `rnn`.
    Rnn,
    /// Transformer model; displayed as `transformer`.
    Transformer,
}

impl fmt::Display for StoicheiaArch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = if let Self::Transformer = *self {
            "transformer"
        } else {
            "rnn"
        };
        f.write_str(name)
    }
}
/// Kind of output head a model has; determines `StoicheiaConfig::output_size`.
///
/// `Display` renders the lowercase variant name (`distribution` / `scalar`),
/// matching the formatting of `StoicheiaTask` and `StoicheiaArch`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StoicheiaOutput {
    /// One output per sequence position (`output_size() == seq_len`).
    Distribution,
    /// A single output value (`output_size() == 1`).
    Scalar,
}

impl fmt::Display for StoicheiaOutput {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Distribution => write!(f, "distribution"),
            Self::Scalar => write!(f, "scalar"),
        }
    }
}
/// Complete shape description of one Stoicheia model.
///
/// Usually built with [`StoicheiaConfig::from_task`], which derives `arch`,
/// `output`, `num_layers`, `num_heads` and `input_range` from the task. All
/// fields are public, so callers may also build or adjust it directly.
#[derive(Debug, Clone)]
pub struct StoicheiaConfig {
    /// Hidden width (`h` in the parameter formulas); shown as `h=` in the `Display` summary.
    pub hidden_size: usize,
    /// Input sequence length; also the output size for a `Distribution` head.
    /// Shown as `n=` in the `Display` summary.
    pub seq_len: usize,
    /// Task this configuration targets.
    pub task: StoicheiaTask,
    /// Model family; selects which parameter-count formula applies.
    pub arch: StoicheiaArch,
    /// Output head kind; selects `output_size()`.
    pub output: StoicheiaOutput,
    /// Layer count: `from_task` sets 1 for RNNs and 2 for transformers.
    /// Only the transformer parameter count reads it.
    pub num_layers: usize,
    /// Attention-head count: `from_task` sets 0 for RNNs and 1 for transformers.
    /// Not read by `param_count`.
    pub num_heads: usize,
    /// `from_task` sets 0 for RNNs and `seq_len` for transformers; it multiplies
    /// `hidden_size` in the transformer parameter count (presumably an input
    /// vocabulary/embedding size — TODO confirm).
    pub input_range: usize,
}
impl StoicheiaConfig {
    /// Builds the configuration for `task` with the given hidden width and sequence length.
    ///
    /// The task fixes the architecture and output head:
    ///
    /// | task                        | arch          | output         |
    /// |-----------------------------|---------------|----------------|
    /// | `SecondArgmax`, `Argmedian` | `Rnn`         | `Distribution` |
    /// | `Median`                    | `Rnn`         | `Scalar`       |
    /// | `LongestCycle`              | `Transformer` | `Distribution` |
    ///
    /// The architecture then fixes the remaining shape fields. An RNN gets
    /// `num_layers = 1`, `num_heads = 0`, `input_range = 0`. A transformer gets
    /// `num_layers = 2`, `num_heads = 1`, `input_range = seq_len`.
    #[must_use]
    pub const fn from_task(task: StoicheiaTask, hidden_size: usize, seq_len: usize) -> Self {
        let (arch, output) = match task {
            StoicheiaTask::LongestCycle => {
                (StoicheiaArch::Transformer, StoicheiaOutput::Distribution)
            }
            StoicheiaTask::Median => (StoicheiaArch::Rnn, StoicheiaOutput::Scalar),
            StoicheiaTask::SecondArgmax | StoicheiaTask::Argmedian => {
                (StoicheiaArch::Rnn, StoicheiaOutput::Distribution)
            }
        };

        // All architecture-dependent shape fields are chosen in one place.
        let (num_layers, num_heads, input_range) = match arch {
            StoicheiaArch::Rnn => (1, 0, 0),
            StoicheiaArch::Transformer => (2, 1, seq_len),
        };

        Self {
            task,
            arch,
            output,
            hidden_size,
            seq_len,
            num_layers,
            num_heads,
            input_range,
        }
    }

    /// Number of model outputs: `seq_len` for a `Distribution` head, `1` for a `Scalar` head.
    #[must_use]
    pub const fn output_size(&self) -> usize {
        match self.output {
            StoicheiaOutput::Scalar => 1,
            StoicheiaOutput::Distribution => self.seq_len,
        }
    }

    /// Number of parameters implied by this configuration.
    ///
    /// With `h = hidden_size` and `o = output_size()`:
    /// - RNN: `h * (1 + h + o)`. The terms are presumably input→hidden,
    ///   hidden→hidden and hidden→output weights (TODO confirm).
    /// - Transformer: `input_range*h + seq_len*h + num_layers*4*h*h + o*h`.
    ///
    /// `num_heads` does not enter either formula. Arithmetic is plain `usize`.
    /// Under default build profiles it panics on overflow in debug and wraps in release.
    #[must_use]
    pub const fn param_count(&self) -> usize {
        let h = self.hidden_size;
        let out = self.output_size();
        match self.arch {
            StoicheiaArch::Rnn => h * (1 + h + out),
            StoicheiaArch::Transformer => {
                let input_term = self.input_range * h;
                let position_term = self.seq_len * h;
                let layer_term = self.num_layers * 4 * h * h;
                let output_term = out * h;
                input_term + position_term + layer_term + output_term
            }
        }
    }
}
impl fmt::Display for StoicheiaConfig {
    /// Writes a one-line summary of the form
    /// `<task>(h=<hidden_size>, n=<seq_len>, arch=<arch>, <param_count> params)`.
    ///
    /// The output head, layer count and head count are not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{task}(h={h}, n={n}, arch={arch}, {params} params)",
            task = self.task,
            h = self.hidden_size,
            n = self.seq_len,
            arch = self.arch,
            params = self.param_count(),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_task_rnn_distribution() {
        let cfg = StoicheiaConfig::from_task(StoicheiaTask::SecondArgmax, 16, 10);
        assert_eq!(cfg.arch, StoicheiaArch::Rnn);
        assert_eq!(cfg.output, StoicheiaOutput::Distribution);
        assert_eq!(cfg.output_size(), 10);
        assert_eq!(cfg.num_layers, 1);
        assert_eq!(cfg.num_heads, 0);
        assert_eq!(cfg.input_range, 0);
        assert_eq!(cfg.param_count(), 432);
    }

    #[test]
    fn from_task_argmedian() {
        // Argmedian shares SecondArgmax's routing: RNN with a distribution head.
        let cfg = StoicheiaConfig::from_task(StoicheiaTask::Argmedian, 3, 5);
        assert_eq!(cfg.task, StoicheiaTask::Argmedian);
        assert_eq!(cfg.arch, StoicheiaArch::Rnn);
        assert_eq!(cfg.output, StoicheiaOutput::Distribution);
        assert_eq!(cfg.output_size(), 5);
        assert_eq!(cfg.param_count(), 27);
    }

    #[test]
    fn from_task_rnn_scalar() {
        let cfg = StoicheiaConfig::from_task(StoicheiaTask::Median, 4, 3);
        assert_eq!(cfg.arch, StoicheiaArch::Rnn);
        assert_eq!(cfg.output, StoicheiaOutput::Scalar);
        assert_eq!(cfg.output_size(), 1);
        assert_eq!(cfg.param_count(), 24);
    }

    #[test]
    fn from_task_transformer() {
        let cfg = StoicheiaConfig::from_task(StoicheiaTask::LongestCycle, 4, 4);
        assert_eq!(cfg.arch, StoicheiaArch::Transformer);
        assert_eq!(cfg.output, StoicheiaOutput::Distribution);
        assert_eq!(cfg.num_layers, 2);
        assert_eq!(cfg.num_heads, 1);
        assert_eq!(cfg.input_range, 4);
        assert_eq!(cfg.param_count(), 176);
    }

    #[test]
    fn transformer_input_range_tracks_seq_len() {
        // h=2, n=3: 3*2 + 3*2 + 2*4*2*2 + 3*2 = 6 + 6 + 32 + 6.
        let cfg = StoicheiaConfig::from_task(StoicheiaTask::LongestCycle, 2, 3);
        assert_eq!(cfg.input_range, 3);
        assert_eq!(cfg.param_count(), 50);
    }

    #[test]
    fn task_and_arch_display_names() {
        assert_eq!(StoicheiaTask::SecondArgmax.to_string(), "2nd_argmax");
        assert_eq!(StoicheiaTask::Argmedian.to_string(), "argmedian");
        assert_eq!(StoicheiaTask::Median.to_string(), "median");
        assert_eq!(StoicheiaTask::LongestCycle.to_string(), "longest_cycle");
        assert_eq!(StoicheiaArch::Rnn.to_string(), "rnn");
        assert_eq!(StoicheiaArch::Transformer.to_string(), "transformer");
    }

    #[test]
    fn display_config() {
        let cfg = StoicheiaConfig::from_task(StoicheiaTask::SecondArgmax, 16, 10);
        let s = cfg.to_string();
        assert!(s.contains("2nd_argmax"));
        assert!(s.contains("432"));
        // Pin the full format so drift in field order or separators is caught.
        assert_eq!(s, "2nd_argmax(h=16, n=10, arch=rnn, 432 params)");
    }

    #[test]
    fn display_transformer_config() {
        let cfg = StoicheiaConfig::from_task(StoicheiaTask::LongestCycle, 4, 4);
        assert_eq!(
            cfg.to_string(),
            "longest_cycle(h=4, n=4, arch=transformer, 176 params)"
        );
    }

    #[test]
    fn blog_m2_2() {
        let cfg = StoicheiaConfig::from_task(StoicheiaTask::SecondArgmax, 2, 2);
        assert_eq!(cfg.param_count(), 10);
    }

    #[test]
    fn blog_m4_3() {
        let cfg = StoicheiaConfig::from_task(StoicheiaTask::SecondArgmax, 4, 3);
        assert_eq!(cfg.param_count(), 32);
    }
}