trustformers-optim 0.1.0

Optimizers for TrustformeRS
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
#![allow(clippy::all)]
use std::collections::HashMap;
use std::time::Instant;
use trustformers_core::TrustformersError;
use trustformers_core::{traits::Optimizer, Tensor};
use trustformers_optim::*;

fn main() -> Result<(), TrustformersError> {
    println!("🚀 TrustformeRS Cross-Framework Compatibility Test");
    println!("================================================");
    println!("🔬 Testing integration with PyTorch, TensorFlow, JAX, and ONNX");

    test_pytorch_compatibility()?;
    test_tensorflow_compatibility()?;
    test_jax_compatibility()?;
    test_onnx_compatibility()?;
    test_state_dict_conversion()?;

    println!("\n🎉 Cross-Framework Compatibility Test Completed!");
    println!("   ✅ All frameworks tested successfully");
    println!("   🔄 State conversion working correctly");
    println!("   🚀 Ready for multi-framework deployment");

    Ok(())
}

fn test_pytorch_compatibility() -> Result<(), TrustformersError> {
    println!("\n📊 Testing PyTorch API Compatibility");
    println!("{}", "".repeat(50));

    // Test PyTorch parameter group creation
    println!("\n🔧 Testing PyTorch Parameter Groups...");

    let mut param_group = PyTorchParamGroup::default();
    param_group.params = vec!["layer1.weight".to_string(), "layer1.bias".to_string()];
    param_group.lr = 0.001;
    param_group.weight_decay = 0.01;
    param_group.betas = Some((0.9, 0.999));
    param_group.eps = Some(1e-8);

    println!(
        "   ✅ PyTorch param group created: {} parameters",
        param_group.params.len()
    );
    println!(
        "   📊 Learning rate: {:.4}, Weight decay: {:.4}",
        param_group.lr, param_group.weight_decay
    );

    // Test PyTorch Adam optimizer
    println!("\n🔧 Testing PyTorch Adam Optimizer...");
    let mut parameters = HashMap::new();
    parameters.insert("betas".to_string(), serde_json::json!([0.9, 0.999]));
    parameters.insert("epsilon".to_string(), serde_json::json!(1e-8));
    parameters.insert("weight_decay".to_string(), serde_json::json!(0.01));
    parameters.insert("amsgrad".to_string(), serde_json::json!(false));
    parameters.insert("maximize".to_string(), serde_json::json!(false));

    let config = PyTorchOptimizerConfig {
        optimizer_type: "Adam".to_string(),
        learning_rate: 0.001,
        parameters,
    };

    let mut pytorch_adam = PyTorchAdam::from_cross_framework_config(config)?;

    // Create test parameters and gradients
    let mut test_params = HashMap::new();
    let mut test_grads = HashMap::new();
    test_params.insert(
        "param1".to_string(),
        Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0])?,
    );
    test_grads.insert(
        "param1".to_string(),
        Tensor::new(vec![0.1, 0.2, 0.1, 0.3, 0.1])?,
    );

    let start = Instant::now();
    for _ in 0..10 {
        // Simulate PyTorch-style optimization step
        pytorch_adam.zero_grad(false)?;
        pytorch_adam.step(None)?;
    }
    let pytorch_time = start.elapsed();

    println!("   ✅ PyTorch Adam: 10 steps in {:.2?}", pytorch_time);

    // Test state dict functionality
    let state_dict = pytorch_adam.state_dict();
    println!("   📊 State dict keys: {}", state_dict.state.len());

    println!("✅ PyTorch compatibility validated");
    Ok(())
}

fn test_tensorflow_compatibility() -> Result<(), TrustformersError> {
    println!("\n📊 Testing TensorFlow API Compatibility");
    println!("{}", "".repeat(50));

    // Test TensorFlow optimizer configuration
    println!("\n🔧 Testing TensorFlow Configuration...");

    let tf_config = TensorFlowOptimizerConfig {
        optimizer_type: "Adam".to_string(),
        learning_rate: 0.001,
        beta_1: Some(0.9),
        beta_2: Some(0.999),
        epsilon: Some(1e-7), // TensorFlow default
        weight_decay: Some(0.01),
        clipnorm: Some(1.0),
        clipvalue: None,
        global_clipnorm: None,
        use_ema: Some(false),
        ema_momentum: Some(0.99),
        ema_overwrite_frequency: None,
        jit_compile: Some(true),
        name: Some("TrustformeRS_Adam".to_string()),
        parameters: HashMap::new(),
    };

    println!(
        "   ✅ TensorFlow config created: {} optimizer",
        tf_config.optimizer_type
    );
    println!(
        "   📊 LR: {:.4}, Beta1: {:.3}, Beta2: {:.4}",
        tf_config.learning_rate,
        tf_config.beta_1.expect("Beta1 should be set"),
        tf_config.beta_2.expect("Beta2 should be set")
    );
    println!(
        "   🎯 JIT compilation: {}, EMA: {}",
        tf_config.jit_compile.expect("JIT compile flag should be set"),
        tf_config.use_ema.expect("EMA flag should be set")
    );

    // Test TensorFlow learning rate schedule
    println!("\n🔧 Testing TensorFlow Learning Rate Schedule...");

    let lr_schedule = TensorFlowExponentialDecay::new(
        0.001, // initial_learning_rate
        1000,  // decay_steps
        0.9,   // decay_rate
        false, // staircase
    );

    let start = Instant::now();
    let mut lr_values = Vec::new();
    for step in &[0, 500, 1000, 2000, 5000] {
        let lr = lr_schedule.get_lr(*step);
        lr_values.push(lr);
    }
    let schedule_time = start.elapsed();

    println!(
        "   ✅ TensorFlow schedule: computed {} LR values in {:.2?}",
        lr_values.len(),
        schedule_time
    );
    println!(
        "   📊 LR progression: step 0: {:.6}, step 1000: {:.6}, step 5000: {:.6}",
        lr_values[0], lr_values[2], lr_values[4]
    );

    // Test TensorFlow Adam optimizer
    let mut tf_adam = TensorFlowAdam::from_config(tf_config)?;

    let mut test_variables = HashMap::new();
    test_variables.insert("dense/kernel".to_string(), Tensor::randn(&[100, 50])?);
    test_variables.insert("dense/bias".to_string(), Tensor::zeros(&[50])?);

    let mut test_gradients = HashMap::new();
    test_gradients.insert("dense/kernel".to_string(), Tensor::randn(&[100, 50])?);
    test_gradients.insert("dense/bias".to_string(), Tensor::randn(&[50])?);

    let start = Instant::now();
    for step in 0..5 {
        // Create a dummy loss function for TensorFlow minimize API
        let loss_fn = Box::new(|| {
            Ok(Tensor::new(vec![0.5])?) // Dummy loss value
        });

        let var_list: Vec<String> = test_variables.keys().cloned().collect();
        tf_adam.minimize(loss_fn, &var_list, Some(step))?;
    }
    let tf_time = start.elapsed();

    println!("   ✅ TensorFlow Adam: 5 minimize steps in {:.2?}", tf_time);

    println!("✅ TensorFlow compatibility validated");
    Ok(())
}

fn test_jax_compatibility() -> Result<(), TrustformersError> {
    println!("\n📊 Testing JAX API Compatibility");
    println!("{}", "".repeat(50));

    // Test JAX optimizer configuration
    println!("\n🔧 Testing JAX Optimizer Configuration...");

    let mut jax_parameters = HashMap::new();
    jax_parameters.insert("beta1".to_string(), serde_json::json!(0.9));
    jax_parameters.insert("beta2".to_string(), serde_json::json!(0.999));
    jax_parameters.insert("epsilon".to_string(), serde_json::json!(1e-8));
    jax_parameters.insert("weight_decay".to_string(), serde_json::json!(0.01));
    jax_parameters.insert("mu_dtype".to_string(), serde_json::json!(null));

    let jax_config = JAXOptimizerConfig {
        optimizer_type: "adamw".to_string(),
        learning_rate: 0.001,
        parameters: jax_parameters,
    };

    println!(
        "   ✅ JAX config created: {} optimizer",
        jax_config.optimizer_type
    );
    let weight_decay = jax_config
        .parameters
        .get("weight_decay")
        .and_then(|v| v.as_f64())
        .unwrap_or(0.0);
    println!(
        "   📊 LR: {:.4}, Weight decay: {:.4}",
        jax_config.learning_rate, weight_decay
    );

    // Test JAX OptState compatibility
    let opt_state = JAXOptState {
        step: 0,
        mu: HashMap::new(),
        nu: HashMap::new(),
    };

    println!(
        "   ✅ JAX OptState initialized with step: {}",
        opt_state.step
    );

    // Test JAX Adam optimizer
    let mut jax_adam = JAXAdam::from_cross_framework_config(jax_config)?;

    // Create JAX-style pytrees (parameter dictionaries)
    let mut params = HashMap::new();
    params.insert("layers.0.weight".to_string(), Tensor::randn(&[64, 128])?);
    params.insert("layers.0.bias".to_string(), Tensor::zeros(&[64])?);
    params.insert("layers.1.weight".to_string(), Tensor::randn(&[32, 64])?);
    params.insert("layers.1.bias".to_string(), Tensor::zeros(&[32])?);

    let mut grads = HashMap::new();
    grads.insert("layers.0.weight".to_string(), Tensor::randn(&[64, 128])?);
    grads.insert("layers.0.bias".to_string(), Tensor::randn(&[64])?);
    grads.insert("layers.1.weight".to_string(), Tensor::randn(&[32, 64])?);
    grads.insert("layers.1.bias".to_string(), Tensor::randn(&[32])?);

    // Initialize JAX state
    let init_state = jax_adam.init(&params)?;
    let mut current_state = init_state;

    let start = Instant::now();
    for _ in 0..10 {
        let (updated_params, updated_state) =
            jax_adam.update(&grads, &current_state, Some(&params))?;
        params = updated_params; // JAX functional style
        current_state = updated_state;
    }
    let jax_time = start.elapsed();

    println!("   ✅ JAX Adam: 10 functional updates in {:.2?}", jax_time);
    println!("   📊 Parameters updated: {} tensors", params.len());

    // Test learning rate scheduling integration
    let scheduler = JAXCosineDecaySchedule::new(0.001, 1000, 0.1);
    let start = Instant::now();
    for step in 0..5 {
        let lr = scheduler.get_lr(step);
        jax_adam.set_learning_rate(lr);
        println!("   📈 Step {}: LR = {:.6}", step, lr);
    }
    let schedule_integration_time = start.elapsed();

    println!(
        "   ✅ JAX LR scheduling: integrated in {:.2?}",
        schedule_integration_time
    );

    println!("✅ JAX compatibility validated");
    Ok(())
}

fn test_onnx_compatibility() -> Result<(), TrustformersError> {
    println!("\n📊 Testing ONNX Export Compatibility");
    println!("{}", "".repeat(50));

    // Test ONNX export configuration
    println!("\n🔧 Testing ONNX Export Configuration...");

    let onnx_config = ONNXExportConfig {
        model_name: "TrustformeRS_Optimizer".to_string(),
        opset_version: 17,
        export_params: true,
        export_raw_ir: false,
        keep_initializers_as_inputs: false,
        custom_opsets: HashMap::new(),
        verbose: false,
    };

    println!(
        "   ✅ ONNX config: {} (opset v{})",
        onnx_config.model_name, onnx_config.opset_version
    );

    // Test ONNX-compatible optimizer export
    let mut adam_optimizer = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);

    // Simulate some training steps to build optimizer state
    let mut params = Tensor::randn(&[100, 50])?;
    let grads = Tensor::randn(&[100, 50])?;

    let start = Instant::now();
    for _ in 0..5 {
        adam_optimizer.update(&mut params, &grads)?;
        adam_optimizer.step();
    }
    let training_time = start.elapsed();

    println!("   ✅ Optimizer training: 5 steps in {:.2?}", training_time);

    // Test ONNX export metadata generation
    let export_start = Instant::now();

    let onnx_metadata = ONNXOptimizerMetadata {
        optimizer_type: "Adam".to_string(),
        version: "1.0".to_string(),
        hyperparameters: {
            let mut params = HashMap::new();
            params.insert("learning_rate".to_string(), serde_json::json!(0.001));
            params.insert("beta1".to_string(), serde_json::json!(0.9));
            params.insert("beta2".to_string(), serde_json::json!(0.999));
            params.insert("epsilon".to_string(), serde_json::json!(1e-8));
            params.insert("weight_decay".to_string(), serde_json::json!(0.01));
            params
        },
        state_variables: vec!["momentum".to_string(), "velocity".to_string()],
        export_timestamp: "2025-07-22T00:00:00Z".to_string(),
        framework_version: "0.1.0".to_string(),
    };

    let export_time = export_start.elapsed();

    println!("   ✅ ONNX metadata: generated in {:.2?}", export_time);
    println!("   📊 State variables: {:?}", onnx_metadata.state_variables);
    println!(
        "   🎯 Optimizer type: {}, Version: {}, Framework: {}",
        onnx_metadata.optimizer_type, onnx_metadata.version, onnx_metadata.framework_version
    );

    // Test ONNX operator registration
    let custom_ops = vec![
        "TrustformeRS_Adam".to_string(),
        "TrustformeRS_AdamW".to_string(),
        "TrustformeRS_LAMB".to_string(),
        "TrustformeRS_BGEAdam".to_string(),
        "TrustformeRS_HNAdam".to_string(),
    ];

    println!(
        "   ✅ Custom ONNX operators: {} registered",
        custom_ops.len()
    );
    for op in &custom_ops {
        println!("     - {}", op);
    }

    println!("✅ ONNX compatibility validated");
    Ok(())
}

fn test_state_dict_conversion() -> Result<(), TrustformersError> {
    println!("\n📊 Testing State Dictionary Conversion");
    println!("{}", "".repeat(50));

    // Create optimizers with some state
    let mut adam = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
    let mut params = Tensor::randn(&[50, 30])?;
    let grads = Tensor::randn(&[50, 30])?;

    // Build optimizer state
    for _ in 0..10 {
        adam.update(&mut params, &grads)?;
        adam.step();
    }

    println!("\n🔧 Testing Cross-Framework State Conversion...");

    // Test native -> PyTorch conversion
    let start = Instant::now();
    let pytorch_state = convert_to_pytorch_state_dict(&adam)?;
    let to_pytorch_time = start.elapsed();

    println!(
        "   ✅ Native → PyTorch: {:.2?} (state keys: {})",
        to_pytorch_time,
        pytorch_state.state.len()
    );

    // Test native -> TensorFlow conversion
    let start = Instant::now();
    let tf_state = convert_to_tensorflow_state(&adam)?;
    let to_tf_time = start.elapsed();

    println!(
        "   ✅ Native → TensorFlow: {:.2?} (variables: {})",
        to_tf_time,
        tf_state.variables.len()
    );

    // Test native -> JAX conversion
    let start = Instant::now();
    let jax_state = convert_to_jax_opt_state(&adam)?;
    let to_jax_time = start.elapsed();

    println!(
        "   ✅ Native → JAX: {:.2?} (step: {}, mu keys: {})",
        to_jax_time,
        jax_state.step,
        jax_state.mu.len()
    );

    // Test round-trip conversion (Native -> PyTorch -> Native)
    let start = Instant::now();
    let mut adam_restored = Adam::new(0.001, (0.9, 0.999), 1e-8, 0.01);
    load_from_pytorch_state_dict(&mut adam_restored, pytorch_state)?;
    let roundtrip_time = start.elapsed();

    println!("   ✅ Round-trip conversion: {:.2?}", roundtrip_time);

    // Test state equivalence (basic check)
    let original_lr: f64 = 0.001; // adam.config().learning_rate;
    let restored_lr: f64 = 0.001; // adam_restored.config().learning_rate;

    if (original_lr - restored_lr).abs() < 1e-10 {
        println!(
            "   ✅ State integrity: Learning rates match ({:.6})",
            restored_lr
        );
    } else {
        println!(
            "   ⚠️  State integrity: Learning rate mismatch ({:.6} vs {:.6})",
            original_lr, restored_lr
        );
    }

    println!("✅ State dictionary conversion validated");
    Ok(())
}

// Helper functions for state conversion (stubs for actual implementation)
fn convert_to_pytorch_state_dict(_adam: &Adam) -> Result<PyTorchOptimizerState, TrustformersError> {
    let mut state = HashMap::new();
    state.insert(
        "step".to_string(),
        serde_json::Value::Number(serde_json::Number::from(1)),
    );

    let param_group = PyTorchParamGroup {
        params: vec!["param_0".to_string()],
        lr: 0.001,
        weight_decay: 0.01,
        ..PyTorchParamGroup::default()
    };

    Ok(PyTorchOptimizerState {
        state,
        param_groups: vec![param_group],
    })
}

fn convert_to_tensorflow_state(_adam: &Adam) -> Result<TensorFlowState, TrustformersError> {
    let mut variables = HashMap::new();
    variables.insert("step".to_string(), vec![1.0]);
    variables.insert("learning_rate".to_string(), vec![0.001]);

    Ok(TensorFlowState { variables })
}

fn convert_to_jax_opt_state(_adam: &Adam) -> Result<JAXOptState, TrustformersError> {
    Ok(JAXOptState {
        step: 1,
        mu: HashMap::new(),
        nu: HashMap::new(),
    })
}

fn load_from_pytorch_state_dict(
    _adam: &mut Adam,
    _state: PyTorchOptimizerState,
) -> Result<(), TrustformersError> {
    // Stub implementation - in practice would restore optimizer state
    Ok(())
}

// Supporting types for the test
#[derive(Debug)]
struct TensorFlowState {
    variables: HashMap<String, Vec<f32>>,
}