//! burn_dragon_train/config/core.rs — core training configuration types
//! (parallelism, low-bit, wgpu runtime, and GDPO settings).
use std::fmt;

use burn::module::{Content, ModuleDisplay, ModuleDisplayDefault};
pub use burn_dragon_core::{
    BdhFiringTargetKind, BdhInitializationKind, BdhNeuronGainKind, BdhResidualScalingKind,
    BdhTopologyPriorKind, BitNetLowBitProtocol, LowBitActivationFormat, LowBitActivationGrouping,
    LowBitInferenceMode, LowBitSavedActivationMode, LowBitTargetModule, LowBitTrainingMode,
    LowBitWeightFormat, LowBitWeightGrouping, RhoCompressionConfig, RhoPrecisionConfig,
    SequenceKernelConfig,
};
use serde::{Deserialize, Serialize};
12
// Serde/`Default` helper functions for the parallelism configs below.
// Each is referenced by name in a `#[serde(default = "...")]` attribute,
// so signatures and names must stay stable.

/// Default `world_size`: a single process.
fn default_parallel_world_size() -> usize {
    1
}

/// Default process-group `size`: one rank per group.
fn default_parallel_group_size() -> usize {
    1
}

/// Default for `find_unused_parameters`: disabled.
fn default_find_unused_parameters() -> bool {
    false
}

/// Default for `gradient_as_bucket_view`: enabled.
fn default_gradient_as_bucket_view() -> bool {
    true
}

/// Default pipeline `stage_count`: one stage.
fn default_pipeline_stage_count() -> usize {
    1
}

/// Default pipeline `microbatches`: one per step.
fn default_pipeline_microbatches() -> usize {
    1
}

/// Default `virtual_stages_per_rank`: one virtual stage.
fn default_pipeline_virtual_stages_per_rank() -> usize {
    1
}

/// Default cap for `max_inflight_microbatches`: one.
fn default_pipeline_max_inflight_microbatches() -> usize {
    1
}
44
45#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
46#[serde(rename_all = "snake_case")]
47pub enum ParallelismKind {
48    #[default]
49    Single,
50    Ddp,
51    Fsdp,
52    TensorParallelNeuron,
53    Hybrid2D,
54}
55
56#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
57#[serde(rename_all = "snake_case")]
58pub enum ParallelCommunicationBackend {
59    #[default]
60    Auto,
61    Nccl,
62    Gloo,
63}
64
65#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
66#[serde(rename_all = "snake_case")]
67pub enum TensorParallelAxis {
68    #[default]
69    Neuron,
70}
71
72#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
73#[serde(rename_all = "snake_case")]
74pub enum TensorParallelPartitionKind {
75    #[default]
76    Contiguous,
77    HeadAligned,
78}
79
80#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
81#[serde(rename_all = "snake_case")]
82pub enum ParallelCheckpointFormat {
83    #[default]
84    UnshardedV1,
85    ShardedV2,
86}
87
88#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
89#[serde(rename_all = "snake_case")]
90pub enum FsdpMixedPrecisionKind {
91    #[default]
92    Disabled,
93    Bf16,
94    F16,
95}
96
97#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
98#[serde(rename_all = "snake_case")]
99pub enum PipelineScheduleKind {
100    Gpipe,
101    #[serde(rename = "interleaved_1f1b")]
102    #[default]
103    Interleaved1f1b,
104}
105
106#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
107#[serde(rename_all = "snake_case")]
108pub enum PipelinePartitionKind {
109    #[default]
110    LayerContiguous,
111}
112
113#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
114#[serde(rename_all = "snake_case")]
115pub enum PipelineCommunicationKind {
116    #[default]
117    ActivationTensor,
118    BlockResidualCache,
119}
120
121#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
122#[serde(rename_all = "snake_case")]
123pub enum PipelineSharedWeightSyncKind {
124    #[default]
125    AllReducePerStep,
126}
127
128#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
129#[serde(rename_all = "snake_case")]
130pub enum PipelineCachePolicy {
131    #[default]
132    Disabled,
133    ResidentBlockSummaries,
134}
135
136#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
137#[serde(rename_all = "snake_case")]
138pub enum PipelineCacheEvictionKind {
139    #[default]
140    StepBoundary,
141}
142
143#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
144#[serde(rename_all = "snake_case")]
145pub enum PipelineTransportDtype {
146    #[default]
147    Auto,
148    Fp32,
149    Bf16,
150    F16,
151}
152
153#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
154#[serde(default)]
155pub struct ParallelDataConfig {
156    #[serde(default = "default_parallel_group_size")]
157    pub size: usize,
158    pub backend: ParallelCommunicationBackend,
159    #[serde(default = "default_find_unused_parameters")]
160    pub find_unused_parameters: bool,
161    #[serde(default = "default_gradient_as_bucket_view")]
162    pub gradient_as_bucket_view: bool,
163    #[serde(default)]
164    pub collective_num_nodes: Option<u32>,
165    #[serde(default)]
166    pub collective_global_address: Option<String>,
167    #[serde(default)]
168    pub collective_node_address: Option<String>,
169    #[serde(default)]
170    pub collective_data_service_port: Option<u16>,
171}
172
173impl Default for ParallelDataConfig {
174    fn default() -> Self {
175        Self {
176            size: default_parallel_group_size(),
177            backend: ParallelCommunicationBackend::default(),
178            find_unused_parameters: default_find_unused_parameters(),
179            gradient_as_bucket_view: default_gradient_as_bucket_view(),
180            collective_num_nodes: None,
181            collective_global_address: None,
182            collective_node_address: None,
183            collective_data_service_port: None,
184        }
185    }
186}
187
188#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
189#[serde(default)]
190pub struct ParallelTensorConfig {
191    #[serde(default = "default_parallel_group_size")]
192    pub size: usize,
193    pub axis: TensorParallelAxis,
194    pub partition: TensorParallelPartitionKind,
195    pub sequence_parallel: bool,
196}
197
198impl Default for ParallelTensorConfig {
199    fn default() -> Self {
200        Self {
201            size: default_parallel_group_size(),
202            axis: TensorParallelAxis::default(),
203            partition: TensorParallelPartitionKind::default(),
204            sequence_parallel: false,
205        }
206    }
207}
208
209#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
210#[serde(default)]
211pub struct ParallelFsdpConfig {
212    pub enabled: bool,
213    pub reshard_after_forward: bool,
214    pub cpu_offload: bool,
215    pub mixed_precision: FsdpMixedPrecisionKind,
216}
217
218#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
219#[serde(default)]
220pub struct ParallelCheckpointConfig {
221    pub format: ParallelCheckpointFormat,
222    pub async_write: bool,
223}
224
225#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
226#[serde(default)]
227pub struct ParallelPipelineCacheConfig {
228    pub enabled: bool,
229    pub policy: PipelineCachePolicy,
230    pub reuse_across_backward: bool,
231    #[serde(default = "default_pipeline_max_inflight_microbatches")]
232    pub max_inflight_microbatches: usize,
233    pub eviction: PipelineCacheEvictionKind,
234    pub transport_dtype: PipelineTransportDtype,
235}
236
237impl Default for ParallelPipelineCacheConfig {
238    fn default() -> Self {
239        Self {
240            enabled: false,
241            policy: PipelineCachePolicy::default(),
242            reuse_across_backward: true,
243            max_inflight_microbatches: default_pipeline_max_inflight_microbatches(),
244            eviction: PipelineCacheEvictionKind::default(),
245            transport_dtype: PipelineTransportDtype::default(),
246        }
247    }
248}
249
250#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
251#[serde(default)]
252pub struct ParallelPipelineConfig {
253    pub enabled: bool,
254    #[serde(default = "default_pipeline_stage_count")]
255    pub stage_count: usize,
256    #[serde(default = "default_pipeline_virtual_stages_per_rank")]
257    pub virtual_stages_per_rank: usize,
258    pub schedule: PipelineScheduleKind,
259    #[serde(default = "default_pipeline_microbatches")]
260    pub microbatches: usize,
261    pub partition: PipelinePartitionKind,
262    pub activation_checkpointing: bool,
263    pub shared_weight_sync: PipelineSharedWeightSyncKind,
264    pub communication: PipelineCommunicationKind,
265    pub cache: ParallelPipelineCacheConfig,
266}
267
268impl Default for ParallelPipelineConfig {
269    fn default() -> Self {
270        Self {
271            enabled: false,
272            stage_count: default_pipeline_stage_count(),
273            virtual_stages_per_rank: default_pipeline_virtual_stages_per_rank(),
274            schedule: PipelineScheduleKind::default(),
275            microbatches: default_pipeline_microbatches(),
276            partition: PipelinePartitionKind::default(),
277            activation_checkpointing: false,
278            shared_weight_sync: PipelineSharedWeightSyncKind::default(),
279            communication: PipelineCommunicationKind::default(),
280            cache: ParallelPipelineCacheConfig::default(),
281        }
282    }
283}
284
285#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
286#[serde(default)]
287pub struct ParallelConfig {
288    pub mode: ParallelismKind,
289    #[serde(default = "default_parallel_world_size")]
290    pub world_size: usize,
291    pub data: ParallelDataConfig,
292    pub tensor: ParallelTensorConfig,
293    pub fsdp: ParallelFsdpConfig,
294    pub checkpoint: ParallelCheckpointConfig,
295    pub pipeline: ParallelPipelineConfig,
296}
297
298impl Default for ParallelConfig {
299    fn default() -> Self {
300        Self {
301            mode: ParallelismKind::Single,
302            world_size: default_parallel_world_size(),
303            data: ParallelDataConfig::default(),
304            tensor: ParallelTensorConfig::default(),
305            fsdp: ParallelFsdpConfig::default(),
306            checkpoint: ParallelCheckpointConfig::default(),
307            pipeline: ParallelPipelineConfig::default(),
308        }
309    }
310}
311
312#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
313pub struct LowBitModelSpec {
314    pub enabled: bool,
315    pub protocol: BitNetLowBitProtocol,
316    pub training_mode: LowBitTrainingMode,
317    pub inference_mode: LowBitInferenceMode,
318    pub weight_format: LowBitWeightFormat,
319    pub activation_format: LowBitActivationFormat,
320    pub decoder_x_mode: LowBitWeightFormat,
321    pub encoder_mode: Option<LowBitWeightFormat>,
322    pub activation_grouping: LowBitActivationGrouping,
323    pub weight_grouping: LowBitWeightGrouping,
324    pub strict_bitnet_reference: bool,
325    pub target_modules: Vec<LowBitTargetModule>,
326    pub rho_precision: RhoPrecisionConfig,
327    pub rho_compression: RhoCompressionConfig,
328}
329
330#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
331pub struct ModelSpec {
332    pub arch: String,
333    pub n_embd: usize,
334    pub n_head: usize,
335    pub n_layer: usize,
336    pub latent_total: usize,
337    pub latent_per_head: usize,
338    pub shared_layer_weights: bool,
339    pub sequence_kernel: SequenceKernelConfig,
340    #[serde(default)]
341    pub bdh_initialization_kind: BdhInitializationKind,
342    #[serde(default)]
343    pub bdh_residual_scaling_kind: BdhResidualScalingKind,
344    #[serde(default)]
345    pub bdh_neuron_gain_kind: BdhNeuronGainKind,
346    #[serde(default)]
347    pub bdh_topology_prior_kind: BdhTopologyPriorKind,
348    #[serde(default)]
349    pub bdh_firing_target_kind: BdhFiringTargetKind,
350    #[serde(default, skip_serializing_if = "Option::is_none")]
351    pub low_bit: Option<LowBitModelSpec>,
352}
353
354#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
355pub struct ParallelSpec {
356    pub mode: ParallelismKind,
357    pub world_size: usize,
358    pub data_parallel_size: usize,
359    pub tensor_parallel_size: usize,
360    pub tensor_parallel_axis: TensorParallelAxis,
361    pub tensor_parallel_partition: TensorParallelPartitionKind,
362    pub fsdp_enabled: bool,
363    pub checkpoint_format: ParallelCheckpointFormat,
364    pub collective_num_nodes: Option<u32>,
365    pub collective_global_address: Option<String>,
366    pub collective_node_address: Option<String>,
367    pub collective_data_service_port: Option<u16>,
368    #[serde(default)]
369    pub pipeline_enabled: bool,
370    #[serde(default)]
371    pub pipeline_stage_count: usize,
372    #[serde(default)]
373    pub pipeline_virtual_stages_per_rank: usize,
374    #[serde(default)]
375    pub pipeline_schedule: PipelineScheduleKind,
376    #[serde(default)]
377    pub pipeline_microbatches: usize,
378    #[serde(default)]
379    pub pipeline_partition: PipelinePartitionKind,
380    #[serde(default)]
381    pub pipeline_activation_checkpointing: bool,
382    #[serde(default)]
383    pub pipeline_shared_weight_sync: PipelineSharedWeightSyncKind,
384    #[serde(default)]
385    pub pipeline_communication: PipelineCommunicationKind,
386    #[serde(default)]
387    pub pipeline_cache_enabled: bool,
388    #[serde(default)]
389    pub pipeline_cache_policy: PipelineCachePolicy,
390    #[serde(default)]
391    pub pipeline_cache_reuse_across_backward: bool,
392    #[serde(default)]
393    pub pipeline_cache_max_inflight_microbatches: usize,
394    #[serde(default)]
395    pub pipeline_cache_eviction: PipelineCacheEvictionKind,
396    #[serde(default)]
397    pub pipeline_cache_transport_dtype: PipelineTransportDtype,
398}
399
400#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
401pub struct KernelSpec {
402    pub sequence_kernel: SequenceKernelConfig,
403    pub fused_kernels_enabled: bool,
404    pub rollout_fast_steps_per_slow_step: usize,
405    pub wgpu_fused_core_recurrent: Option<bool>,
406    pub wgpu_fused_core_rollout: Option<bool>,
407    #[serde(default, skip_serializing_if = "Option::is_none")]
408    pub low_bit_kernel_abi_version: Option<u32>,
409    #[serde(default, skip_serializing_if = "Option::is_none")]
410    pub low_bit_runtime: Option<String>,
411    #[serde(default, skip_serializing_if = "Option::is_none")]
412    pub low_bit_saved_activation_mode: Option<LowBitSavedActivationMode>,
413    #[serde(default, skip_serializing_if = "Option::is_none")]
414    pub low_bit_saved_activation_format: Option<String>,
415    #[serde(default, skip_serializing_if = "Option::is_none")]
416    pub low_bit_saved_activation_inventory: Option<LowBitSavedActivationInventorySpec>,
417    #[serde(default, skip_serializing_if = "Option::is_none")]
418    pub low_bit_native_supported: Option<bool>,
419    #[serde(default, skip_serializing_if = "Option::is_none")]
420    pub low_bit_memory: Option<LowBitMemorySpec>,
421}
422
423#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
424pub struct LowBitMemorySpec {
425    pub master_weight_bytes: u64,
426    pub execution_weight_bytes: u64,
427    pub activation_shell_bytes: u64,
428    pub saved_activation_bytes: u64,
429    pub rho_state_bytes: u64,
430    pub workspace_bytes: u64,
431    pub estimated_total_bytes: u64,
432}
433
434#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
435pub struct LowBitSavedActivationInventorySpec {
436    pub mode: LowBitSavedActivationMode,
437    pub format: String,
438    pub requires_rho_window_anchor: bool,
439    pub tensors: Vec<LowBitSavedActivationTensorSpec>,
440}
441
442#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
443pub struct LowBitSavedActivationTensorSpec {
444    pub name: String,
445    pub shape: Vec<usize>,
446    pub element_count: u64,
447    pub estimated_bytes: u64,
448    pub recompute_policy: String,
449}
450
451#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
452pub struct OptimizerSpec {
453    pub name: super::optimizer::OptimizerKind,
454    pub learning_rate: f64,
455    pub weight_decay: f32,
456    #[serde(default, skip_serializing_if = "Option::is_none")]
457    pub weight_decay_final: Option<f32>,
458    pub schedule_mode: super::optimizer::OptimizerScheduleMode,
459}
460
461#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
462pub struct StateAxisSpec {
463    pub name: String,
464    pub size: Option<usize>,
465}
466
467#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
468pub struct StateTensorSpec {
469    pub name: String,
470    pub axes: Vec<StateAxisSpec>,
471}
472
473#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
474pub struct LayerStateSpec {
475    pub layer_index: usize,
476    pub latent_total: usize,
477    pub latent_per_head: usize,
478    pub tensors: Vec<StateTensorSpec>,
479}
480
481#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
482pub struct StateLayout {
483    pub state_family: String,
484    pub position_tracked: bool,
485    pub layers: Vec<LayerStateSpec>,
486}
487
488#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
489#[serde(rename_all = "snake_case")]
490pub enum WgpuBackend {
491    #[default]
492    Auto,
493    Vulkan,
494    Dx12,
495    Metal,
496    #[serde(rename = "opengl")]
497    OpenGl,
498}
499
500#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
501pub enum WgpuMemoryConfig {
502    #[serde(rename = "subslices")]
503    #[default]
504    SubSlices,
505    Exclusive,
506}
507
508#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
509#[serde(rename_all = "snake_case")]
510pub enum WgpuGenerationExecutor {
511    #[default]
512    Baseline,
513    RolloutChunked,
514}
515
/// Default tokens per generation chunk (used by `#[serde(default = ...)]`).
fn default_generation_chunk_tokens() -> usize {
    8
}

/// Default device-side token-buffer size (used by `#[serde(default = ...)]`).
fn default_generation_device_buffer_tokens() -> usize {
    64
}
523
524#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
525#[serde(default)]
526pub struct WgpuInferenceConfig {
527    pub fused_core_recurrent: Option<bool>,
528    pub fused_core_rollout: Option<bool>,
529    pub generation_executor: WgpuGenerationExecutor,
530    #[serde(default = "default_generation_chunk_tokens")]
531    pub generation_chunk_tokens: usize,
532    #[serde(default = "default_generation_device_buffer_tokens")]
533    pub generation_device_buffer_tokens: usize,
534}
535
536impl Default for WgpuInferenceConfig {
537    fn default() -> Self {
538        Self {
539            fused_core_recurrent: None,
540            fused_core_rollout: None,
541            generation_executor: WgpuGenerationExecutor::Baseline,
542            generation_chunk_tokens: default_generation_chunk_tokens(),
543            generation_device_buffer_tokens: default_generation_device_buffer_tokens(),
544        }
545    }
546}
547
548#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
549#[serde(default)]
550pub struct WgpuTrainingConfig {
551    pub fused_core_recurrent: Option<bool>,
552    pub fused_core_rollout: Option<bool>,
553    pub startup_autotune: WgpuStartupAutotuneConfig,
554}
555
// Serde/`Default` helpers for `WgpuStartupAutotuneConfig`.

/// Smallest batch size the autotuner will try.
fn default_startup_autotune_min_batch_size() -> usize {
    1
}

/// Training steps probed per candidate batch size.
fn default_startup_autotune_probe_steps() -> usize {
    1
}

/// Whether the autotuner binary-searches between min and max.
fn default_startup_autotune_binary_search() -> bool {
    true
}
567
568#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
569#[serde(default)]
570pub struct WgpuStartupAutotuneConfig {
571    pub enabled: bool,
572    pub target_device_memory_mb: usize,
573    #[serde(default = "default_startup_autotune_min_batch_size")]
574    pub min_batch_size: usize,
575    pub max_batch_size: Option<usize>,
576    #[serde(default = "default_startup_autotune_probe_steps")]
577    pub probe_steps: usize,
578    #[serde(default = "default_startup_autotune_binary_search")]
579    pub binary_search: bool,
580}
581
582impl Default for WgpuStartupAutotuneConfig {
583    fn default() -> Self {
584        Self {
585            enabled: false,
586            target_device_memory_mb: 0,
587            min_batch_size: default_startup_autotune_min_batch_size(),
588            max_batch_size: None,
589            probe_steps: default_startup_autotune_probe_steps(),
590            binary_search: default_startup_autotune_binary_search(),
591        }
592    }
593}
594
595#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
596#[serde(default)]
597pub struct WgpuRuntimeConfig {
598    pub backend: WgpuBackend,
599    pub tasks_max: Option<usize>,
600    pub memory: WgpuMemoryConfig,
601    pub training: WgpuTrainingConfig,
602    pub inference: WgpuInferenceConfig,
603}
604
605#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
606#[serde(default)]
607pub struct GdpoConfig {
608    pub enabled: bool,
609    pub group_size: usize,
610    pub hard_weight: f32,
611    pub easy_weight: f32,
612    pub policy_weight: f32,
613    pub policy_clip_range: f32,
614    pub hard_gate: GdpoHardGate,
615    pub norm_epsilon: f32,
616    pub advantage_clip: f32,
617    pub advantage_ema_decay: f32,
618}
619
620impl Default for GdpoConfig {
621    fn default() -> Self {
622        Self {
623            enabled: false,
624            group_size: 1,
625            hard_weight: 1.0,
626            easy_weight: 1.0,
627            policy_weight: 1.0,
628            policy_clip_range: 0.2,
629            hard_gate: GdpoHardGate::Percentile { quantile: 0.5 },
630            norm_epsilon: 1e-6,
631            advantage_clip: 0.0,
632            advantage_ema_decay: 0.0,
633        }
634    }
635}
636
637#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
638#[serde(tag = "type", rename_all = "snake_case")]
639pub enum GdpoHardGate {
640    Off,
641    Fixed { threshold: f32 },
642    Percentile { quantile: f32 },
643}
644
645impl Default for GdpoHardGate {
646    fn default() -> Self {
647        Self::Percentile { quantile: 0.5 }
648    }
649}
650
651impl fmt::Display for GdpoHardGate {
652    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
653        match self {
654            Self::Off => write!(f, "off"),
655            Self::Fixed { threshold } => write!(f, "fixed(threshold={threshold:.4})"),
656            Self::Percentile { quantile } => write!(f, "percentile(quantile={quantile:.3})"),
657        }
658    }
659}
660
661impl ModuleDisplayDefault for GdpoHardGate {
662    fn content(&self, content: Content) -> Option<Content> {
663        content.add_formatted(self).optional()
664    }
665}
666
667impl ModuleDisplay for GdpoHardGate {}
668
669impl ModuleDisplayDefault for GdpoConfig {
670    fn content(&self, content: Content) -> Option<Content> {
671        content
672            .add("enabled", &self.enabled)
673            .add("group_size", &self.group_size)
674            .add("hard_weight", &self.hard_weight)
675            .add("easy_weight", &self.easy_weight)
676            .add("policy_weight", &self.policy_weight)
677            .add("policy_clip_range", &self.policy_clip_range)
678            .add("hard_gate", &self.hard_gate)
679            .add("norm_epsilon", &self.norm_epsilon)
680            .add("advantage_clip", &self.advantage_clip)
681            .add("advantage_ema_decay", &self.advantage_ema_decay)
682            .optional()
683    }
684}
685
686impl ModuleDisplay for GdpoConfig {}
687
688#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
689#[serde(rename_all = "snake_case")]
690pub enum VisionTeacherVariant {
691    #[default]
692    Vits,
693    Vitb,
694    Vitl,
695    Vitg,
696}