#![allow(missing_docs)]
use serde::Deserialize;
/// Counting-layer variant selected by the `counting_layer` config key.
///
/// Variants are deserialized from their snake_case names (`"count_lstm"`,
/// `"count_lstm_moe"`, `"count_lstm_v2"`); any other string is rejected by serde.
///
/// The enum carries no data, so it is `Copy` and `Hash`: callers can pass it by
/// value and use it as a map/set key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CountingLayer {
    /// Deserialized from `"count_lstm"`.
    CountLstm,
    /// Deserialized from `"count_lstm_moe"`.
    CountLstmMoe,
    /// Deserialized from `"count_lstm_v2"`.
    CountLstmV2,
}
/// Model configuration deserialized from a JSON config file.
///
/// Only `hidden_size` is required; every other key falls back to a serde
/// default when it is absent from the JSON.
#[derive(Debug, Clone, Deserialize)]
// NOTE(review): presumably silences "field is never read" warnings for fields
// that are not consumed yet; confirm this allowance is still needed.
#[allow(dead_code)]
pub struct FastinoConfig {
    /// Required key: deserialization fails when it is missing.
    /// NOTE(review): presumably the model's hidden dimension — confirm.
    pub hidden_size: usize,

    /// Optional counting-layer selection; `None` when the key is absent or `null`.
    #[serde(default)]
    pub counting_layer: Option<CountingLayer>,

    /// Maximum sequence length; filled from `default_max_len()` (512) when absent.
    #[serde(default = "default_max_len")]
    pub max_seq_length: usize,
}
/// Serde fallback for `FastinoConfig::max_seq_length` when the key is missing.
///
/// Declared `const` so the value can also be used in constant contexts.
const fn default_max_len() -> usize {
    const FALLBACK_MAX_SEQ_LENGTH: usize = 512;
    FALLBACK_MAX_SEQ_LENGTH
}
impl FastinoConfig {
    /// Loads a configuration from the JSON file at `path`.
    ///
    /// The whole file is read into memory as text and then parsed with
    /// `serde_json`. Optional keys missing from the JSON receive the serde
    /// defaults declared on the struct.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read as UTF-8 text, or if its
    /// contents do not deserialize into `FastinoConfig` (malformed JSON, a
    /// missing required key such as `hidden_size`, or an unknown
    /// `counting_layer` value). Both failure kinds are converted into
    /// `super::errors::Error` through its `From` implementations.
    pub fn from_path(path: &std::path::Path) -> Result<Self, super::errors::Error> {
        let raw_json = std::fs::read_to_string(path)?;
        let parsed: Self = serde_json::from_str(&raw_json)?;
        Ok(parsed)
    }
}
impl Default for FastinoConfig {
fn default() -> Self {
Self {
hidden_size: 768,
counting_layer: Some(CountingLayer::CountLstmV2),
max_seq_length: 512,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_minimal_config() {
        let json = r#"{"hidden_size": 768, "counting_layer": "count_lstm_v2"}"#;
        let cfg: FastinoConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.hidden_size, 768);
        assert_eq!(cfg.counting_layer, Some(CountingLayer::CountLstmV2));
        assert_eq!(cfg.max_seq_length, 512);
    }

    #[test]
    fn parses_all_three_counting_variants() {
        for (s, expected) in [
            ("count_lstm", CountingLayer::CountLstm),
            ("count_lstm_moe", CountingLayer::CountLstmMoe),
            ("count_lstm_v2", CountingLayer::CountLstmV2),
        ] {
            let json = format!(r#"{{"hidden_size": 768, "counting_layer": "{s}"}}"#);
            let cfg: FastinoConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(cfg.counting_layer, Some(expected));
        }
    }

    #[test]
    fn missing_counting_layer_is_optional_for_phase1() {
        let json = r#"{"hidden_size": 768}"#;
        let cfg: FastinoConfig = serde_json::from_str(json).unwrap();
        assert!(cfg.counting_layer.is_none());
    }

    #[test]
    fn explicit_null_counting_layer_is_none() {
        let json = r#"{"hidden_size": 768, "counting_layer": null}"#;
        let cfg: FastinoConfig = serde_json::from_str(json).unwrap();
        assert!(cfg.counting_layer.is_none());
    }

    #[test]
    fn explicit_max_seq_length_overrides_default() {
        let json = r#"{"hidden_size": 768, "max_seq_length": 1024}"#;
        let cfg: FastinoConfig = serde_json::from_str(json).unwrap();
        assert_eq!(cfg.max_seq_length, 1024);
    }

    #[test]
    fn rejects_unknown_counting_layer() {
        let json = r#"{"hidden_size": 768, "counting_layer": "count_gru"}"#;
        assert!(serde_json::from_str::<FastinoConfig>(json).is_err());
    }

    #[test]
    fn rejects_missing_hidden_size() {
        let json = r#"{"counting_layer": "count_lstm"}"#;
        assert!(serde_json::from_str::<FastinoConfig>(json).is_err());
    }

    #[test]
    fn default_agrees_with_serde_max_seq_length() {
        let cfg = FastinoConfig::default();
        assert_eq!(cfg.hidden_size, 768);
        assert_eq!(cfg.counting_layer, Some(CountingLayer::CountLstmV2));
        assert_eq!(cfg.max_seq_length, default_max_len());
    }

    #[test]
    fn from_path_reads_json_file() {
        // Process id keeps the file name unique across concurrent test runs.
        let path = std::env::temp_dir().join(format!(
            "fastino-config-test-{}.json",
            std::process::id()
        ));
        std::fs::write(&path, r#"{"hidden_size": 256, "max_seq_length": 128}"#)
            .expect("writing a temp config file should succeed");
        let loaded = FastinoConfig::from_path(&path).ok();
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);
        let cfg = loaded.expect("from_path should parse a valid config file");
        assert_eq!(cfg.hidden_size, 256);
        assert_eq!(cfg.max_seq_length, 128);
        assert!(cfg.counting_layer.is_none());
    }

    #[test]
    fn from_path_reports_missing_file() {
        let path = std::env::temp_dir().join(format!(
            "fastino-config-does-not-exist-{}.json",
            std::process::id()
        ));
        assert!(FastinoConfig::from_path(&path).is_err());
    }
}