use std::collections::HashMap;
use axonml_autograd::Variable;
use axonml_nn::{BatchNorm1d, Dropout, GELU, Linear, Module, Parameter, ReLU, Sequential};
#[cfg(test)]
use axonml_tensor::Tensor;
use super::colossus;
/// Multi-head network that runs one input through two parallel branches,
/// fuses their features into a shared embedding and reads four linear heads
/// off that embedding.
///
/// The layer widths mentioned below are the ones wired up in `Gaia::new`.
pub struct Gaia {
/// First branch: `Linear -> ReLU -> Linear`, ending in 256 features.
safety_encoder: Sequential,
/// Second branch: three `Linear` layers with batch norm and dropout after the
/// first two, ending in 128 features.
validation_net: Sequential,
/// `Linear(384, 256) -> ReLU` applied to the two branch outputs joined together.
fusion: Sequential,
/// Linear head producing 5 values from the fused embedding.
validation_head: Linear,
/// Linear head producing 1 value from the fused embedding.
safety_head: Linear,
/// Linear head producing 8 values from the fused embedding.
override_head: Linear,
/// Linear head producing 4 values from the fused embedding.
confidence_head: Linear,
/// Flag reported by `is_training`; starts as `true` and is updated by `set_training`.
training: bool,
}
/// Width of the input consumed by both `Gaia` branches:
/// `colossus::TOTAL_SPECIALIST_DIM` plus 256.
///
/// NOTE(review): the extra 256 presumably is the width of the Colossus
/// embedding that `forward_parts` appends to the specialist features — confirm
/// against the `colossus` module.
pub const GAIA_INPUT_DIM: usize = colossus::TOTAL_SPECIALIST_DIM + 256;
impl Default for Gaia {
fn default() -> Self {
Self::new()
}
}
impl Gaia {
    /// Builds a new model with the `training` flag set to `true`.
    ///
    /// Both branches take `GAIA_INPUT_DIM` features. The first branch ends in
    /// 256 features, the second in 128; their joined 384 features are fused
    /// into a 256-wide embedding that feeds four linear heads of width
    /// 5, 1, 8 and 4.
    pub fn new() -> Self {
        // Layer widths, named once so the wiring below is easier to follow.
        const HIDDEN: usize = 512;
        const SAFETY_FEATURES: usize = 256;
        const VALIDATION_MID: usize = 256;
        const VALIDATION_FEATURES: usize = 128;
        const EMBEDDING: usize = 256;

        let safety_encoder = Sequential::new()
            .add(Linear::new(GAIA_INPUT_DIM, HIDDEN))
            .add(ReLU)
            .add(Linear::new(HIDDEN, SAFETY_FEATURES));

        let validation_net = Sequential::new()
            .add(Linear::new(GAIA_INPUT_DIM, HIDDEN))
            .add(BatchNorm1d::new(HIDDEN))
            .add(GELU)
            .add(Dropout::new(0.3))
            .add(Linear::new(HIDDEN, VALIDATION_MID))
            .add(BatchNorm1d::new(VALIDATION_MID))
            .add(ReLU)
            .add(Dropout::new(0.2))
            .add(Linear::new(VALIDATION_MID, VALIDATION_FEATURES))
            .add(ReLU);

        // The fusion layer consumes both branch outputs side by side (256 + 128 = 384).
        let fusion = Sequential::new()
            .add(Linear::new(SAFETY_FEATURES + VALIDATION_FEATURES, EMBEDDING))
            .add(ReLU);

        // Field expressions are evaluated in written order, so the heads are
        // still created after the three sequential stages.
        Self {
            safety_encoder,
            validation_net,
            fusion,
            validation_head: Linear::new(EMBEDDING, 5),
            safety_head: Linear::new(EMBEDDING, 1),
            override_head: Linear::new(EMBEDDING, 8),
            confidence_head: Linear::new(EMBEDDING, 4),
            training: true,
        }
    }

    /// Joins the two inputs with `aquilo::concat_variables` and runs
    /// [`Gaia::forward_all`] on the result.
    ///
    /// The batch size is read from dimension 0 of `specialist_features`.
    /// Returns the same five-tuple as `forward_all`.
    ///
    /// # Panics
    ///
    /// Indexing `shape()[0]` panics if `specialist_features` has an empty shape.
    pub fn forward_parts(
        &self,
        specialist_features: &Variable,
        colossus_embedding: &Variable,
    ) -> (Variable, Variable, Variable, Variable, Variable) {
        let rows = specialist_features.shape()[0];
        let pieces = [specialist_features, colossus_embedding];
        let joined = super::aquilo::concat_variables(&pieces, rows);
        self.forward_all(&joined)
    }

    /// Runs the full network on an already assembled input.
    ///
    /// Returns `(validation, safety, override, confidence, embedding)`. The
    /// first four are the outputs of the linear heads (no activation is applied
    /// in this function); `embedding` is the fused representation the heads
    /// read from.
    ///
    /// # Panics
    ///
    /// Indexing `shape()[0]` panics if `input` has an empty shape.
    pub fn forward_all(
        &self,
        input: &Variable,
    ) -> (Variable, Variable, Variable, Variable, Variable) {
        let rows = input.shape()[0];

        // Both branches read the same input; the safety branch runs first.
        let safety_features = self.safety_encoder.forward(input);
        let validation_features = self.validation_net.forward(input);

        let joined =
            super::aquilo::concat_variables(&[&safety_features, &validation_features], rows);
        let embedding = self.fusion.forward(&joined);

        let validation = self.validation_head.forward(&embedding);
        let safety = self.safety_head.forward(&embedding);
        let override_out = self.override_head.forward(&embedding);
        let confidence = self.confidence_head.forward(&embedding);

        (validation, safety, override_out, confidence, embedding)
    }

    /// Width of the fused embedding returned as the last element of `forward_all`.
    pub fn embedding_dim() -> usize {
        256
    }

    /// Combined width of the four head outputs.
    pub fn output_dim() -> usize {
        // validation + safety + override + confidence
        5 + 1 + 8 + 4
    }
}
impl Module for Gaia {
    /// Runs the whole network and returns only the validation head's output;
    /// the other heads and the embedding are computed and discarded.
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_all(input).0
    }

    /// Collects every parameter, ordered as: safety encoder, validation net,
    /// fusion, then the validation, safety, override and confidence heads.
    fn parameters(&self) -> Vec<Parameter> {
        self.safety_encoder
            .parameters()
            .into_iter()
            .chain(self.validation_net.parameters())
            .chain(self.fusion.parameters())
            .chain(self.validation_head.parameters())
            .chain(self.safety_head.parameters())
            .chain(self.override_head.parameters())
            .chain(self.confidence_head.parameters())
            .collect()
    }

    /// Collects every parameter keyed as `<field>.<inner name>`, where
    /// `<field>` is the name of the sub-module field that owns it.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut named = HashMap::new();
        named.extend(
            self.safety_encoder
                .named_parameters()
                .into_iter()
                .map(|(key, param)| (format!("safety_encoder.{key}"), param)),
        );
        named.extend(
            self.validation_net
                .named_parameters()
                .into_iter()
                .map(|(key, param)| (format!("validation_net.{key}"), param)),
        );
        named.extend(
            self.fusion
                .named_parameters()
                .into_iter()
                .map(|(key, param)| (format!("fusion.{key}"), param)),
        );
        named.extend(
            self.validation_head
                .named_parameters()
                .into_iter()
                .map(|(key, param)| (format!("validation_head.{key}"), param)),
        );
        named.extend(
            self.safety_head
                .named_parameters()
                .into_iter()
                .map(|(key, param)| (format!("safety_head.{key}"), param)),
        );
        named.extend(
            self.override_head
                .named_parameters()
                .into_iter()
                .map(|(key, param)| (format!("override_head.{key}"), param)),
        );
        named.extend(
            self.confidence_head
                .named_parameters()
                .into_iter()
                .map(|(key, param)| (format!("confidence_head.{key}"), param)),
        );
        named
    }

    /// Records the mode on the model and forwards it to the three sequential
    /// stages. The four linear heads are not notified.
    fn set_training(&mut self, training: bool) {
        self.training = training;
        for stage in [
            &mut self.safety_encoder,
            &mut self.validation_net,
            &mut self.fusion,
        ] {
            stage.set_training(training);
        }
    }

    /// Returns the mode last stored by `set_training` (`true` after construction).
    fn is_training(&self) -> bool {
        self.training
    }

    /// Static display name of this module.
    fn name(&self) -> &'static str {
        "Gaia"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of rows in every test input.
    const BATCH: usize = 2;

    /// Builds a `[rows, cols]` variable filled with ones, created via
    /// `Variable::new(.., false)`.
    fn ones(rows: usize, cols: usize) -> Variable {
        let tensor = Tensor::from_vec(vec![1.0; rows * cols], &[rows, cols]).unwrap();
        Variable::new(tensor, false)
    }

    /// Every output of `forward_parts` has the expected `[batch, width]` shape.
    #[test]
    fn test_gaia_output_shapes() {
        let gaia = Gaia::new();
        let specialist = ones(BATCH, colossus::TOTAL_SPECIALIST_DIM);
        let colossus_emb = ones(BATCH, 256);

        let (val, safety, override_out, conf, emb) = gaia.forward_parts(&specialist, &colossus_emb);

        assert_eq!(val.shape(), vec![BATCH, 5]);
        assert_eq!(safety.shape(), vec![BATCH, 1]);
        assert_eq!(override_out.shape(), vec![BATCH, 8]);
        assert_eq!(conf.shape(), vec![BATCH, 4]);
        assert_eq!(emb.shape(), vec![BATCH, 256]);
    }

    /// `Module::forward` accepts a pre-joined input and yields the validation output.
    #[test]
    fn test_gaia_concat_forward() {
        let gaia = Gaia::new();
        let joined = ones(BATCH, GAIA_INPUT_DIM);

        let output = gaia.forward(&joined);

        assert_eq!(output.shape(), vec![BATCH, 5]);
    }

    /// The total number of scalar parameters stays inside the expected band.
    #[test]
    fn test_gaia_parameter_count() {
        let gaia = Gaia::new();
        let total: usize = gaia.parameters().iter().map(|param| param.numel()).sum();
        assert!(
            1_500_000 < total && total < 3_000_000,
            "Gaia has {} params, expected ~2.2M",
            total
        );
    }
}