use crate::diagnostics::{TraceEvent, TraceLevel};
use crate::error::AphelionResult;
use crate::graph::BuildGraph;
use crate::pipeline::{BuildContext, PipelineStage};
use std::time::SystemTime;
/// Settings for training-mode acceleration.
///
/// [`AccelerationStage`] copies these values into every graph node's metadata
/// under `accel.*` keys when it runs in training mode.
#[derive(Debug, Clone)]
pub struct TrainingAccelConfig {
    /// Compression ratio, recorded as `accel.compression_ratio`.
    ///
    /// [`TrainingAccelConfig::new`] keeps this within `[0.01, 1.0]`. The field is
    /// public, so values assigned directly are not validated.
    pub compression_ratio: f32,
    /// Determinism flag, recorded as `accel.deterministic`. Defaults to `true`.
    pub deterministic: bool,
    /// Optional seed, recorded as `accel.seed` when present. Defaults to `None`.
    pub seed: Option<u64>,
    /// Mixed-precision flag, recorded as `accel.mixed_precision` only when `true`.
    /// Defaults to `false`.
    pub mixed_precision: bool,
}

impl Default for TrainingAccelConfig {
    /// Returns a configuration with ratio `0.1`, determinism on, no seed and
    /// mixed precision off.
    fn default() -> Self {
        Self {
            compression_ratio: 0.1,
            deterministic: true,
            seed: None,
            mixed_precision: false,
        }
    }
}

impl TrainingAccelConfig {
    /// Smallest compression ratio accepted by [`TrainingAccelConfig::new`].
    const MIN_COMPRESSION_RATIO: f32 = 0.01;
    /// Largest compression ratio accepted by [`TrainingAccelConfig::new`].
    const MAX_COMPRESSION_RATIO: f32 = 1.0;

    /// Creates a configuration with the given compression ratio and default
    /// values for every other field.
    ///
    /// The ratio is clamped to `[0.01, 1.0]`; infinities therefore saturate at
    /// the nearest bound. A NaN ratio is replaced by the default ratio.
    pub fn new(compression_ratio: f32) -> Self {
        let defaults = Self::default();
        // `f32::clamp` returns NaN when called on NaN, so it cannot be relied on
        // to sanitise the input by itself.
        let compression_ratio = if compression_ratio.is_nan() {
            defaults.compression_ratio
        } else {
            compression_ratio.clamp(Self::MIN_COMPRESSION_RATIO, Self::MAX_COMPRESSION_RATIO)
        };
        Self {
            compression_ratio,
            ..defaults
        }
    }

    /// Stores `seed` and switches determinism on.
    pub fn with_seed(mut self, seed: u64) -> Self {
        // `deterministic` is a public field and may have been cleared by the caller.
        self.deterministic = true;
        self.seed = Some(seed);
        self
    }

    /// Switches mixed precision on.
    pub fn with_mixed_precision(mut self) -> Self {
        self.mixed_precision = true;
        self
    }
}
/// Settings for inference-mode acceleration.
///
/// [`AccelerationStage`] copies these values into every graph node's metadata
/// under `accel.*` keys when it runs in inference mode.
#[derive(Debug, Clone)]
pub struct InferenceAccelConfig {
    /// Batch size, recorded as `accel.batch_size`. Defaults to `1`.
    pub batch_size: usize,
    /// Ternary-layer flag, recorded as `accel.ternary_layers`. Defaults to `true`.
    pub use_ternary_layers: bool,
    /// KV-cache flag, recorded as `accel.kv_cache` only when `true`.
    /// Defaults to `false`.
    pub use_kv_cache: bool,
    /// Maximum sequence length, recorded as `accel.max_seq_len` when the KV cache
    /// is enabled and a value is present. Defaults to `None`.
    pub max_seq_len: Option<usize>,
}

impl Default for InferenceAccelConfig {
    /// Returns a configuration with batch size `1`, ternary layers on and the
    /// KV cache off.
    fn default() -> Self {
        InferenceAccelConfig {
            batch_size: 1,
            use_ternary_layers: true,
            use_kv_cache: false,
            max_seq_len: None,
        }
    }
}

impl InferenceAccelConfig {
    /// Creates a configuration with the given batch size and default values for
    /// every other field. A batch size of `0` is raised to `1`.
    pub fn new(batch_size: usize) -> Self {
        let batch_size = if batch_size == 0 { 1 } else { batch_size };
        Self {
            batch_size,
            ..Self::default()
        }
    }

    /// Enables the KV cache and stores `max_seq_len` alongside it.
    pub fn with_kv_cache(self, max_seq_len: usize) -> Self {
        Self {
            use_kv_cache: true,
            max_seq_len: Some(max_seq_len),
            ..self
        }
    }

    /// Switches ternary layers off.
    pub fn without_ternary_layers(self) -> Self {
        Self {
            use_ternary_layers: false,
            ..self
        }
    }
}
/// Selects which kind of acceleration an [`AccelerationStage`] applies, carrying
/// the configuration for that kind.
#[derive(Debug, Clone)]
pub enum AccelMode {
/// Training-mode acceleration with its settings.
Training(TrainingAccelConfig),
/// Inference-mode acceleration with its settings.
Inference(InferenceAccelConfig),
}
/// Pipeline stage that writes `accel.*` metadata onto every node of a build graph.
///
/// The stage runs in exactly one mode, fixed at construction time.
#[derive(Debug, Clone)]
pub struct AccelerationStage {
// Mode and configuration chosen by the constructor; private, so it cannot be
// changed from outside this module after construction.
mode: AccelMode,
}
impl AccelerationStage {
pub fn for_training(compression_ratio: f32) -> Self {
Self {
mode: AccelMode::Training(TrainingAccelConfig::new(compression_ratio)),
}
}
pub fn with_training_config(config: TrainingAccelConfig) -> Self {
Self {
mode: AccelMode::Training(config),
}
}
pub fn for_inference(batch_size: usize) -> Self {
Self {
mode: AccelMode::Inference(InferenceAccelConfig::new(batch_size)),
}
}
pub fn with_inference_config(config: InferenceAccelConfig) -> Self {
Self {
mode: AccelMode::Inference(config),
}
}
pub fn is_training(&self) -> bool {
matches!(self.mode, AccelMode::Training(_))
}
pub fn is_inference(&self) -> bool {
matches!(self.mode, AccelMode::Inference(_))
}
fn apply_training_acceleration(
&self,
ctx: &BuildContext,
graph: &mut BuildGraph,
config: &TrainingAccelConfig,
) -> AphelionResult<()> {
for node in &mut graph.nodes {
node.metadata.insert(
"accel.mode".to_string(),
serde_json::Value::String("training".to_string()),
);
node.metadata.insert(
"accel.compression_ratio".to_string(),
serde_json::Value::Number(
serde_json::Number::from_f64(config.compression_ratio as f64)
.unwrap_or_else(|| serde_json::Number::from(0)),
),
);
node.metadata.insert(
"accel.deterministic".to_string(),
serde_json::Value::Bool(config.deterministic),
);
if let Some(seed) = config.seed {
node.metadata.insert(
"accel.seed".to_string(),
serde_json::Value::Number(serde_json::Number::from(seed)),
);
}
if config.mixed_precision {
node.metadata.insert(
"accel.mixed_precision".to_string(),
serde_json::Value::Bool(true),
);
}
}
ctx.trace.record(TraceEvent {
id: "stage.acceleration.training".to_string(),
message: format!(
"Applied training acceleration: compression_ratio={}, deterministic={}, nodes={}",
config.compression_ratio,
config.deterministic,
graph.nodes.len()
),
timestamp: SystemTime::now(),
level: TraceLevel::Info,
span_id: None,
trace_id: None,
});
Ok(())
}
fn apply_inference_acceleration(
&self,
ctx: &BuildContext,
graph: &mut BuildGraph,
config: &InferenceAccelConfig,
) -> AphelionResult<()> {
for node in &mut graph.nodes {
node.metadata.insert(
"accel.mode".to_string(),
serde_json::Value::String("inference".to_string()),
);
node.metadata.insert(
"accel.batch_size".to_string(),
serde_json::Value::Number(serde_json::Number::from(config.batch_size)),
);
node.metadata.insert(
"accel.ternary_layers".to_string(),
serde_json::Value::Bool(config.use_ternary_layers),
);
if config.use_kv_cache {
node.metadata
.insert("accel.kv_cache".to_string(), serde_json::Value::Bool(true));
if let Some(max_seq_len) = config.max_seq_len {
node.metadata.insert(
"accel.max_seq_len".to_string(),
serde_json::Value::Number(serde_json::Number::from(max_seq_len)),
);
}
}
}
ctx.trace.record(TraceEvent {
id: "stage.acceleration.inference".to_string(),
message: format!(
"Applied inference acceleration: batch_size={}, ternary={}, kv_cache={}, nodes={}",
config.batch_size,
config.use_ternary_layers,
config.use_kv_cache,
graph.nodes.len()
),
timestamp: SystemTime::now(),
level: TraceLevel::Info,
span_id: None,
trace_id: None,
});
Ok(())
}
}
impl PipelineStage for AccelerationStage {
    /// Returns the stage name, which depends on the configured mode.
    fn name(&self) -> &str {
        match self.mode {
            AccelMode::Training(_) => "tritter-acceleration-training",
            AccelMode::Inference(_) => "tritter-acceleration-inference",
        }
    }

    /// Applies the configured acceleration to every node in `graph` and records
    /// a trace event through `ctx`.
    fn execute(&self, ctx: &BuildContext, graph: &mut BuildGraph) -> AphelionResult<()> {
        match self.mode {
            AccelMode::Training(ref training) => {
                self.apply_training_acceleration(ctx, graph, training)
            }
            AccelMode::Inference(ref inference) => {
                self.apply_inference_acceleration(ctx, graph, inference)
            }
        }
    }
}
#[cfg(feature = "tokio")]
impl crate::pipeline::AsyncPipelineStage for AccelerationStage {
    /// Reports the same name as the synchronous [`PipelineStage`] implementation.
    fn name(&self) -> &str {
        PipelineStage::name(self)
    }

    /// Wraps the synchronous [`PipelineStage::execute`] in a boxed future.
    ///
    /// The future contains no await points: the whole stage runs when the
    /// future is polled.
    fn execute_async<'a>(
        &'a self,
        ctx: &'a BuildContext<'_>,
        graph: &'a mut BuildGraph,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = AphelionResult<()>> + Send + 'a>> {
        let run = async move { PipelineStage::execute(self, ctx, graph) };
        Box::pin(run)
    }
}
/// Builds a pre-build hook that records the gradient-compression settings.
///
/// Each call of the returned closure records one `Info` trace event with id
/// `hook.gradient_compression.setup` and returns `Ok(())`. The ratio is
/// reported exactly as passed in; this function does not validate it.
pub fn gradient_compression_pre_hook(
    compression_ratio: f32,
    seed: u64,
) -> impl Fn(&BuildContext) -> AphelionResult<()> + Send + Sync + 'static {
    move |ctx| {
        let message = format!(
            "Gradient compression initialized: ratio={}, seed={}",
            compression_ratio, seed
        );
        let event = TraceEvent {
            id: String::from("hook.gradient_compression.setup"),
            message,
            timestamp: SystemTime::now(),
            level: TraceLevel::Info,
            span_id: None,
            trace_id: None,
        };
        ctx.trace.record(event);
        Ok(())
    }
}
/// Builds a post-build hook that reports how many nodes carry acceleration
/// metadata.
///
/// A node counts as configured when its metadata contains the `accel.mode` key.
/// The returned closure records exactly one trace event per call and always
/// returns `Ok(())`:
///
/// * `Warn`, id `hook.gradient_compression.warning`, when the graph has nodes
///   but none of them is configured;
/// * `Info`, id `hook.gradient_compression.validated`, otherwise. This includes
///   an empty graph and a graph where only some nodes are configured.
pub fn gradient_compression_post_hook(
) -> impl Fn(&BuildContext, &BuildGraph) -> AphelionResult<()> + Send + Sync + 'static {
    |ctx, graph| {
        let configured = graph
            .nodes
            .iter()
            .filter(|node| node.metadata.contains_key("accel.mode"))
            .count();

        // Pick the event contents first so the event itself is built in one place.
        let (id, message, level) = if configured == 0 && !graph.nodes.is_empty() {
            (
                "hook.gradient_compression.warning",
                "No nodes have acceleration metadata".to_string(),
                TraceLevel::Warn,
            )
        } else {
            (
                "hook.gradient_compression.validated",
                format!(
                    "Acceleration validated: {}/{} nodes configured",
                    configured,
                    graph.nodes.len()
                ),
                TraceLevel::Info,
            )
        };

        ctx.trace.record(TraceEvent {
            id: id.to_string(),
            message,
            timestamp: SystemTime::now(),
            level,
            span_id: None,
            trace_id: None,
        });
        Ok(())
    }
}
pub fn training_pipeline(compression_ratio: f32) -> crate::pipeline::BuildPipeline {
crate::pipeline::BuildPipeline::new()
.with_stage(Box::new(crate::pipeline::ValidationStage))
.with_stage(Box::new(AccelerationStage::for_training(compression_ratio)))
.with_stage(Box::new(crate::pipeline::HashingStage))
}
pub fn inference_pipeline(batch_size: usize) -> crate::pipeline::BuildPipeline {
crate::pipeline::BuildPipeline::new()
.with_stage(Box::new(AccelerationStage::for_inference(batch_size)))
.with_stage(Box::new(crate::pipeline::HashingStage))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::NullBackend;
    use crate::config::ModelConfig;
    use crate::diagnostics::InMemoryTraceSink;

    /// Builds a graph holding one `linear` / `1.0` node per supplied name.
    fn linear_graph(names: &[&'static str]) -> BuildGraph {
        let mut graph = BuildGraph::default();
        for &name in names {
            graph.add_node(name, ModelConfig::new("linear", "1.0"));
        }
        graph
    }

    /// Returns `true` when any recorded event's message contains `needle`.
    fn has_message(trace: &InMemoryTraceSink, needle: &str) -> bool {
        let events = trace.events();
        events.iter().any(|event| event.message.contains(needle))
    }

    /// Asserts that every node in `graph` carries every key in `keys`.
    fn assert_metadata_keys(graph: &BuildGraph, keys: &[&str]) {
        for node in &graph.nodes {
            for key in keys {
                assert!(node.metadata.contains_key(*key));
            }
        }
    }

    // --- configuration builders ---

    #[test]
    fn test_training_config_new() {
        let TrainingAccelConfig {
            compression_ratio,
            deterministic,
            mixed_precision,
            ..
        } = TrainingAccelConfig::new(0.1);
        assert!((compression_ratio - 0.1).abs() < f32::EPSILON);
        assert!(deterministic);
        assert!(!mixed_precision);
    }

    #[test]
    fn test_training_config_with_seed() {
        let seeded = TrainingAccelConfig::new(0.1).with_seed(42);
        assert_eq!(seeded.seed, Some(42));
        assert!(seeded.deterministic);
    }

    #[test]
    fn test_training_config_with_mixed_precision() {
        let mixed = TrainingAccelConfig::new(0.1).with_mixed_precision();
        assert!(mixed.mixed_precision);
    }

    #[test]
    fn test_inference_config_new() {
        let InferenceAccelConfig {
            batch_size,
            use_ternary_layers,
            use_kv_cache,
            ..
        } = InferenceAccelConfig::new(32);
        assert_eq!(batch_size, 32);
        assert!(use_ternary_layers);
        assert!(!use_kv_cache);
    }

    #[test]
    fn test_inference_config_with_kv_cache() {
        let cached = InferenceAccelConfig::new(32).with_kv_cache(2048);
        assert!(cached.use_kv_cache);
        assert_eq!(cached.max_seq_len, Some(2048));
    }

    #[test]
    fn test_inference_config_without_ternary() {
        let dense = InferenceAccelConfig::new(32).without_ternary_layers();
        assert!(!dense.use_ternary_layers);
    }

    // --- stage construction and execution ---

    #[test]
    fn test_acceleration_stage_for_training() {
        let stage = AccelerationStage::for_training(0.1);
        assert!(stage.is_training());
        assert!(!stage.is_inference());
        assert_eq!(stage.name(), "tritter-acceleration-training");
    }

    #[test]
    fn test_acceleration_stage_for_inference() {
        let stage = AccelerationStage::for_inference(32);
        assert!(!stage.is_training());
        assert!(stage.is_inference());
        assert_eq!(stage.name(), "tritter-acceleration-inference");
    }

    #[test]
    fn test_acceleration_stage_execute_training() {
        let stage = AccelerationStage::for_training(0.1);
        let backend = NullBackend::cpu();
        let trace = InMemoryTraceSink::new();
        let ctx = BuildContext::new(&backend, &trace);
        let mut graph = linear_graph(&["linear1", "linear2"]);

        let outcome = stage.execute(&ctx, &mut graph);

        assert!(outcome.is_ok());
        assert_metadata_keys(
            &graph,
            &[
                "accel.mode",
                "accel.compression_ratio",
                "accel.deterministic",
            ],
        );
        assert!(has_message(&trace, "Applied training acceleration"));
    }

    #[test]
    fn test_acceleration_stage_execute_inference() {
        let stage = AccelerationStage::for_inference(32);
        let backend = NullBackend::cpu();
        let trace = InMemoryTraceSink::new();
        let ctx = BuildContext::new(&backend, &trace);
        let mut graph = linear_graph(&["linear1"]);

        let outcome = stage.execute(&ctx, &mut graph);

        assert!(outcome.is_ok());
        assert_metadata_keys(
            &graph,
            &["accel.mode", "accel.batch_size", "accel.ternary_layers"],
        );
        assert!(has_message(&trace, "Applied inference acceleration"));
    }

    // --- hooks ---

    #[test]
    fn test_gradient_compression_pre_hook() {
        let hook = gradient_compression_pre_hook(0.1, 42);
        let backend = NullBackend::cpu();
        let trace = InMemoryTraceSink::new();
        let ctx = BuildContext::new(&backend, &trace);

        let outcome = hook(&ctx);

        assert!(outcome.is_ok());
        assert!(has_message(&trace, "ratio=0.1"));
        assert!(has_message(&trace, "seed=42"));
    }

    #[test]
    fn test_gradient_compression_post_hook() {
        let hook = gradient_compression_post_hook();
        let backend = NullBackend::cpu();

        // Prepare a graph that already carries acceleration metadata.
        let setup_trace = InMemoryTraceSink::new();
        let setup_ctx = BuildContext::new(&backend, &setup_trace);
        let mut graph = linear_graph(&["linear"]);
        let stage = AccelerationStage::for_training(0.1);
        stage.execute(&setup_ctx, &mut graph).unwrap();

        // Use a fresh sink so only the hook's own event is inspected.
        let hook_trace = InMemoryTraceSink::new();
        let hook_ctx = BuildContext::new(&backend, &hook_trace);
        let outcome = hook(&hook_ctx, &graph);

        assert!(outcome.is_ok());
        assert!(has_message(&hook_trace, "validated"));
    }

    // --- prebuilt pipelines ---

    #[test]
    fn test_training_pipeline() {
        let pipeline = training_pipeline(0.1);
        let backend = NullBackend::cpu();
        let trace = InMemoryTraceSink::new();
        let ctx = BuildContext::new(&backend, &trace);
        let graph = linear_graph(&["linear"]);

        let outcome = pipeline.execute(&ctx, graph);

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_inference_pipeline() {
        let pipeline = inference_pipeline(32);
        let backend = NullBackend::cpu();
        let trace = InMemoryTraceSink::new();
        let ctx = BuildContext::new(&backend, &trace);
        let graph = linear_graph(&["linear"]);

        let outcome = pipeline.execute(&ctx, graph);

        assert!(outcome.is_ok());
    }

    // --- derived traits and enum shape ---

    #[test]
    fn test_acceleration_stage_clone() {
        let original = AccelerationStage::for_training(0.1);
        let copy = original.clone();
        assert!(copy.is_training());
    }

    #[test]
    fn test_accel_mode_variants() {
        let training = AccelMode::Training(TrainingAccelConfig::default());
        let inference = AccelMode::Inference(InferenceAccelConfig::default());
        assert!(matches!(training, AccelMode::Training(_)));
        assert!(matches!(inference, AccelMode::Inference(_)));
    }
}