1use std::fmt;
2
3use burn::module::{Content, ModuleDisplay, ModuleDisplayDefault};
4pub use burn_dragon_core::{
5 BdhFiringTargetKind, BdhInitializationKind, BdhNeuronGainKind, BdhResidualScalingKind,
6 BdhTopologyPriorKind, BitNetLowBitProtocol, LowBitActivationFormat, LowBitActivationGrouping,
7 LowBitInferenceMode, LowBitSavedActivationMode, LowBitTargetModule, LowBitTrainingMode,
8 LowBitWeightFormat, LowBitWeightGrouping, RhoCompressionConfig, RhoPrecisionConfig,
9 SequenceKernelConfig,
10};
11use serde::{Deserialize, Serialize};
12
// Serde default helpers.
//
// These functions back the `#[serde(default = "...")]` attributes on the
// config structs below, so fields omitted from a serialized config
// deserialize to these values.

/// Default distributed world size: a single process.
fn default_parallel_world_size() -> usize {
    1
}

/// Default size of a data-/tensor-parallel group: one rank.
fn default_parallel_group_size() -> usize {
    1
}

/// `find_unused_parameters` defaults to off.
fn default_find_unused_parameters() -> bool {
    false
}

/// `gradient_as_bucket_view` defaults to on.
fn default_gradient_as_bucket_view() -> bool {
    true
}

/// Default pipeline depth: one stage (effectively no pipelining).
fn default_pipeline_stage_count() -> usize {
    1
}

/// Default number of microbatches per pipeline step.
fn default_pipeline_microbatches() -> usize {
    1
}

/// Default number of virtual pipeline stages hosted on each rank.
fn default_pipeline_virtual_stages_per_rank() -> usize {
    1
}

/// Default cap on simultaneously in-flight microbatches in the pipeline cache.
fn default_pipeline_max_inflight_microbatches() -> usize {
    1
}
44
/// Top-level parallelism strategy for a run. Serialized in snake_case.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ParallelismKind {
    /// No parallelism: one process, one device.
    #[default]
    Single,
    /// Distributed data parallelism.
    Ddp,
    /// Fully sharded data parallelism.
    Fsdp,
    /// Tensor parallelism along the neuron axis.
    TensorParallelNeuron,
    /// Combined data + tensor parallelism.
    // NOTE(review): `rename_all = "snake_case"` turns `Hybrid2D` into
    // "hybrid2_d" (not "hybrid_2d") on the wire — confirm this matches
    // existing config files before adding an explicit rename.
    Hybrid2D,
}
55
/// Backend used for collective communication between ranks.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ParallelCommunicationBackend {
    /// Choose a backend automatically at runtime.
    #[default]
    Auto,
    /// NCCL backend.
    Nccl,
    /// Gloo backend.
    Gloo,
}

/// Model axis that tensor parallelism shards across.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum TensorParallelAxis {
    /// Shard along the neuron dimension (currently the only option).
    #[default]
    Neuron,
}
71
/// How tensor-parallel shards are cut along the chosen axis.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum TensorParallelPartitionKind {
    /// Contiguous slices of the axis.
    #[default]
    Contiguous,
    /// Slices aligned to head boundaries.
    HeadAligned,
}

/// On-disk layout version for checkpoints written by parallel runs.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ParallelCheckpointFormat {
    /// Single unsharded checkpoint (v1).
    #[default]
    UnshardedV1,
    /// Sharded checkpoint (v2).
    ShardedV2,
}

/// Mixed-precision policy for FSDP.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum FsdpMixedPrecisionKind {
    /// Full precision (mixed precision off).
    #[default]
    Disabled,
    /// bfloat16.
    Bf16,
    /// float16.
    F16,
}
96
/// Microbatch schedule used by pipeline parallelism.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum PipelineScheduleKind {
    /// GPipe-style schedule.
    Gpipe,
    // Explicit rename: the snake_case rule alone would emit
    // "interleaved1f1b" (no underscore before the digit).
    #[serde(rename = "interleaved_1f1b")]
    #[default]
    Interleaved1f1b,
}

/// How model layers are partitioned across pipeline stages.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum PipelinePartitionKind {
    /// Contiguous runs of layers per stage (currently the only option).
    #[default]
    LayerContiguous,
}

/// Payload exchanged between adjacent pipeline stages.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum PipelineCommunicationKind {
    /// Send raw activation tensors.
    #[default]
    ActivationTensor,
    /// Send via the block residual cache.
    BlockResidualCache,
}
120
/// How weights shared across pipeline stages are kept in sync.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum PipelineSharedWeightSyncKind {
    /// All-reduce shared weights every step (currently the only option).
    #[default]
    AllReducePerStep,
}

/// Caching policy for pipeline inter-stage data.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum PipelineCachePolicy {
    /// No caching.
    #[default]
    Disabled,
    /// Keep block summaries resident in the cache.
    ResidentBlockSummaries,
}

/// When cached pipeline entries are evicted.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum PipelineCacheEvictionKind {
    /// Evict at step boundaries (currently the only option).
    #[default]
    StepBoundary,
}

/// Dtype used for tensors sent between pipeline stages.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum PipelineTransportDtype {
    /// Let the runtime pick.
    #[default]
    Auto,
    Fp32,
    Bf16,
    F16,
}
152
/// Data-parallel group configuration.
///
/// The container-level `#[serde(default)]` fills any omitted field from
/// [`Default`]; the per-field `default = "..."` attributes name the same
/// values explicitly.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(default)]
pub struct ParallelDataConfig {
    /// Number of ranks in the data-parallel group.
    #[serde(default = "default_parallel_group_size")]
    pub size: usize,
    /// Collective-communication backend.
    pub backend: ParallelCommunicationBackend,
    /// Mirrors DDP's `find_unused_parameters` option.
    #[serde(default = "default_find_unused_parameters")]
    pub find_unused_parameters: bool,
    /// Mirrors DDP's `gradient_as_bucket_view` option.
    #[serde(default = "default_gradient_as_bucket_view")]
    pub gradient_as_bucket_view: bool,
    // Optional multi-node collective endpoints; `None` leaves the runtime's
    // own defaults in place.
    #[serde(default)]
    pub collective_num_nodes: Option<u32>,
    #[serde(default)]
    pub collective_global_address: Option<String>,
    #[serde(default)]
    pub collective_node_address: Option<String>,
    #[serde(default)]
    pub collective_data_service_port: Option<u16>,
}
172
173impl Default for ParallelDataConfig {
174 fn default() -> Self {
175 Self {
176 size: default_parallel_group_size(),
177 backend: ParallelCommunicationBackend::default(),
178 find_unused_parameters: default_find_unused_parameters(),
179 gradient_as_bucket_view: default_gradient_as_bucket_view(),
180 collective_num_nodes: None,
181 collective_global_address: None,
182 collective_node_address: None,
183 collective_data_service_port: None,
184 }
185 }
186}
187
/// Tensor-parallel group configuration.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(default)]
pub struct ParallelTensorConfig {
    /// Number of ranks in the tensor-parallel group.
    #[serde(default = "default_parallel_group_size")]
    pub size: usize,
    /// Axis to shard across.
    pub axis: TensorParallelAxis,
    /// How shards are cut along that axis.
    pub partition: TensorParallelPartitionKind,
    /// Whether sequence parallelism is also enabled.
    pub sequence_parallel: bool,
}
197
198impl Default for ParallelTensorConfig {
199 fn default() -> Self {
200 Self {
201 size: default_parallel_group_size(),
202 axis: TensorParallelAxis::default(),
203 partition: TensorParallelPartitionKind::default(),
204 sequence_parallel: false,
205 }
206 }
207}
208
/// Options for fully sharded data parallelism.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
#[serde(default)]
pub struct ParallelFsdpConfig {
    pub enabled: bool,
    /// Reshard parameters again after the forward pass.
    pub reshard_after_forward: bool,
    /// Offload to CPU memory.
    pub cpu_offload: bool,
    /// Mixed-precision policy (disabled by default).
    pub mixed_precision: FsdpMixedPrecisionKind,
}

/// Checkpoint options for parallel runs.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
#[serde(default)]
pub struct ParallelCheckpointConfig {
    /// On-disk checkpoint layout.
    pub format: ParallelCheckpointFormat,
    /// Write checkpoints asynchronously.
    pub async_write: bool,
}
224
/// Cache settings for pipeline inter-stage data.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(default)]
pub struct ParallelPipelineCacheConfig {
    pub enabled: bool,
    /// What the cache retains.
    pub policy: PipelineCachePolicy,
    /// Reuse cached entries across the backward pass.
    pub reuse_across_backward: bool,
    /// Cap on microbatches in flight at once.
    #[serde(default = "default_pipeline_max_inflight_microbatches")]
    pub max_inflight_microbatches: usize,
    /// When entries are evicted.
    pub eviction: PipelineCacheEvictionKind,
    /// Dtype for inter-stage transport.
    pub transport_dtype: PipelineTransportDtype,
}
236
237impl Default for ParallelPipelineCacheConfig {
238 fn default() -> Self {
239 Self {
240 enabled: false,
241 policy: PipelineCachePolicy::default(),
242 reuse_across_backward: true,
243 max_inflight_microbatches: default_pipeline_max_inflight_microbatches(),
244 eviction: PipelineCacheEvictionKind::default(),
245 transport_dtype: PipelineTransportDtype::default(),
246 }
247 }
248}
249
/// Pipeline-parallelism configuration.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(default)]
pub struct ParallelPipelineConfig {
    pub enabled: bool,
    /// Number of pipeline stages.
    #[serde(default = "default_pipeline_stage_count")]
    pub stage_count: usize,
    /// Virtual stages hosted per rank (for interleaved schedules).
    #[serde(default = "default_pipeline_virtual_stages_per_rank")]
    pub virtual_stages_per_rank: usize,
    /// Microbatch schedule.
    pub schedule: PipelineScheduleKind,
    /// Microbatches per optimizer step.
    #[serde(default = "default_pipeline_microbatches")]
    pub microbatches: usize,
    /// How layers are split across stages.
    pub partition: PipelinePartitionKind,
    /// Recompute activations in backward instead of storing them.
    pub activation_checkpointing: bool,
    /// Sync policy for weights shared across stages.
    pub shared_weight_sync: PipelineSharedWeightSyncKind,
    /// Inter-stage payload kind.
    pub communication: PipelineCommunicationKind,
    /// Inter-stage cache settings.
    pub cache: ParallelPipelineCacheConfig,
}
267
268impl Default for ParallelPipelineConfig {
269 fn default() -> Self {
270 Self {
271 enabled: false,
272 stage_count: default_pipeline_stage_count(),
273 virtual_stages_per_rank: default_pipeline_virtual_stages_per_rank(),
274 schedule: PipelineScheduleKind::default(),
275 microbatches: default_pipeline_microbatches(),
276 partition: PipelinePartitionKind::default(),
277 activation_checkpointing: false,
278 shared_weight_sync: PipelineSharedWeightSyncKind::default(),
279 communication: PipelineCommunicationKind::default(),
280 cache: ParallelPipelineCacheConfig::default(),
281 }
282 }
283}
284
/// Top-level parallelism configuration: mode selection plus one sub-config
/// per strategy.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(default)]
pub struct ParallelConfig {
    /// Which parallelism strategy is active.
    pub mode: ParallelismKind,
    /// Total number of processes.
    #[serde(default = "default_parallel_world_size")]
    pub world_size: usize,
    pub data: ParallelDataConfig,
    pub tensor: ParallelTensorConfig,
    pub fsdp: ParallelFsdpConfig,
    pub checkpoint: ParallelCheckpointConfig,
    pub pipeline: ParallelPipelineConfig,
}
297
298impl Default for ParallelConfig {
299 fn default() -> Self {
300 Self {
301 mode: ParallelismKind::Single,
302 world_size: default_parallel_world_size(),
303 data: ParallelDataConfig::default(),
304 tensor: ParallelTensorConfig::default(),
305 fsdp: ParallelFsdpConfig::default(),
306 checkpoint: ParallelCheckpointConfig::default(),
307 pipeline: ParallelPipelineConfig::default(),
308 }
309 }
310}
311
/// Low-bit (BitNet-protocol) quantization settings recorded in the model
/// spec. All enum types come from `burn_dragon_core`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct LowBitModelSpec {
    pub enabled: bool,
    /// Low-bit protocol variant.
    pub protocol: BitNetLowBitProtocol,
    pub training_mode: LowBitTrainingMode,
    pub inference_mode: LowBitInferenceMode,
    /// Weight quantization format.
    pub weight_format: LowBitWeightFormat,
    /// Activation quantization format.
    pub activation_format: LowBitActivationFormat,
    /// Weight format used on the decoder x-path.
    pub decoder_x_mode: LowBitWeightFormat,
    /// Optional distinct weight format for the encoder.
    pub encoder_mode: Option<LowBitWeightFormat>,
    pub activation_grouping: LowBitActivationGrouping,
    pub weight_grouping: LowBitWeightGrouping,
    /// Follow the BitNet reference implementation strictly.
    pub strict_bitnet_reference: bool,
    /// Which module kinds quantization applies to.
    pub target_modules: Vec<LowBitTargetModule>,
    pub rho_precision: RhoPrecisionConfig,
    pub rho_compression: RhoCompressionConfig,
}
329
/// Architecture hyperparameters recorded for a model.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ModelSpec {
    /// Architecture identifier string.
    pub arch: String,
    /// Embedding width.
    pub n_embd: usize,
    /// Number of heads.
    pub n_head: usize,
    /// Number of layers.
    pub n_layer: usize,
    /// Total latent dimension.
    pub latent_total: usize,
    /// Latent dimension per head.
    pub latent_per_head: usize,
    /// Whether all layers share one weight set.
    pub shared_layer_weights: bool,
    /// Sequence-kernel selection (from `burn_dragon_core`).
    pub sequence_kernel: SequenceKernelConfig,
    // The BDH knobs below take their own `Default` when absent, so specs
    // serialized without them still deserialize.
    #[serde(default)]
    pub bdh_initialization_kind: BdhInitializationKind,
    #[serde(default)]
    pub bdh_residual_scaling_kind: BdhResidualScalingKind,
    #[serde(default)]
    pub bdh_neuron_gain_kind: BdhNeuronGainKind,
    #[serde(default)]
    pub bdh_topology_prior_kind: BdhTopologyPriorKind,
    #[serde(default)]
    pub bdh_firing_target_kind: BdhFiringTargetKind,
    /// Present only when low-bit execution is configured; omitted from
    /// serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub low_bit: Option<LowBitModelSpec>,
}
353
/// Flattened, serializable record of the parallelism settings in effect.
///
/// Mirrors [`ParallelConfig`] (and its pipeline sub-config) field-for-field.
/// The `#[serde(default)]` attributes on the `pipeline_*` fields let records
/// that omit them deserialize with default values.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct ParallelSpec {
    pub mode: ParallelismKind,
    pub world_size: usize,
    pub data_parallel_size: usize,
    pub tensor_parallel_size: usize,
    pub tensor_parallel_axis: TensorParallelAxis,
    pub tensor_parallel_partition: TensorParallelPartitionKind,
    pub fsdp_enabled: bool,
    pub checkpoint_format: ParallelCheckpointFormat,
    // Multi-node collective endpoints (see ParallelDataConfig).
    pub collective_num_nodes: Option<u32>,
    pub collective_global_address: Option<String>,
    pub collective_node_address: Option<String>,
    pub collective_data_service_port: Option<u16>,
    // Pipeline settings (see ParallelPipelineConfig / cache config).
    #[serde(default)]
    pub pipeline_enabled: bool,
    #[serde(default)]
    pub pipeline_stage_count: usize,
    #[serde(default)]
    pub pipeline_virtual_stages_per_rank: usize,
    #[serde(default)]
    pub pipeline_schedule: PipelineScheduleKind,
    #[serde(default)]
    pub pipeline_microbatches: usize,
    #[serde(default)]
    pub pipeline_partition: PipelinePartitionKind,
    #[serde(default)]
    pub pipeline_activation_checkpointing: bool,
    #[serde(default)]
    pub pipeline_shared_weight_sync: PipelineSharedWeightSyncKind,
    #[serde(default)]
    pub pipeline_communication: PipelineCommunicationKind,
    #[serde(default)]
    pub pipeline_cache_enabled: bool,
    #[serde(default)]
    pub pipeline_cache_policy: PipelineCachePolicy,
    #[serde(default)]
    pub pipeline_cache_reuse_across_backward: bool,
    #[serde(default)]
    pub pipeline_cache_max_inflight_microbatches: usize,
    #[serde(default)]
    pub pipeline_cache_eviction: PipelineCacheEvictionKind,
    #[serde(default)]
    pub pipeline_cache_transport_dtype: PipelineTransportDtype,
}
399
/// Record of kernel/runtime execution settings.
///
/// The `low_bit_*` fields are optional and skipped entirely when `None`, so
/// specs written without low-bit support round-trip unchanged.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct KernelSpec {
    /// Sequence-kernel selection (from `burn_dragon_core`).
    pub sequence_kernel: SequenceKernelConfig,
    pub fused_kernels_enabled: bool,
    /// Fast steps executed per slow step during rollout.
    pub rollout_fast_steps_per_slow_step: usize,
    // `None` means "not specified / backend decides" for the wgpu fused
    // kernel toggles.
    pub wgpu_fused_core_recurrent: Option<bool>,
    pub wgpu_fused_core_rollout: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub low_bit_kernel_abi_version: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub low_bit_runtime: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub low_bit_saved_activation_mode: Option<LowBitSavedActivationMode>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub low_bit_saved_activation_format: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub low_bit_saved_activation_inventory: Option<LowBitSavedActivationInventorySpec>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub low_bit_native_supported: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub low_bit_memory: Option<LowBitMemorySpec>,
}
422
/// Byte-level memory accounting for low-bit execution.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct LowBitMemorySpec {
    pub master_weight_bytes: u64,
    pub execution_weight_bytes: u64,
    pub activation_shell_bytes: u64,
    pub saved_activation_bytes: u64,
    pub rho_state_bytes: u64,
    pub workspace_bytes: u64,
    /// Estimated sum of the above.
    pub estimated_total_bytes: u64,
}

/// Inventory of activations saved for the backward pass under low-bit mode.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct LowBitSavedActivationInventorySpec {
    pub mode: LowBitSavedActivationMode,
    /// Format identifier string.
    pub format: String,
    pub requires_rho_window_anchor: bool,
    pub tensors: Vec<LowBitSavedActivationTensorSpec>,
}

/// One saved-activation tensor entry in the inventory.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct LowBitSavedActivationTensorSpec {
    pub name: String,
    pub shape: Vec<usize>,
    pub element_count: u64,
    pub estimated_bytes: u64,
    /// Policy identifier string describing if/how this tensor is recomputed.
    pub recompute_policy: String,
}
450
/// Optimizer settings snapshot. Derives `PartialEq` only (no `Eq`) because
/// of the floating-point fields.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct OptimizerSpec {
    pub name: super::optimizer::OptimizerKind,
    pub learning_rate: f64,
    pub weight_decay: f32,
    /// Optional final weight decay for scheduled decay; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub weight_decay_final: Option<f32>,
    pub schedule_mode: super::optimizer::OptimizerScheduleMode,
}

/// One named axis of a state tensor; `size` is `None` when not fixed.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct StateAxisSpec {
    pub name: String,
    pub size: Option<usize>,
}

/// A named state tensor and its axes.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct StateTensorSpec {
    pub name: String,
    pub axes: Vec<StateAxisSpec>,
}

/// State tensors belonging to one model layer.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct LayerStateSpec {
    pub layer_index: usize,
    pub latent_total: usize,
    pub latent_per_head: usize,
    pub tensors: Vec<StateTensorSpec>,
}

/// Full description of a model's recurrent-state layout, layer by layer.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct StateLayout {
    /// Identifier for the family of state this layout describes.
    pub state_family: String,
    /// Whether a position counter is tracked alongside the state.
    pub position_tracked: bool,
    pub layers: Vec<LayerStateSpec>,
}
487
/// Graphics API backing the wgpu runtime.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
#[serde(rename_all = "snake_case")]
pub enum WgpuBackend {
    /// Let wgpu pick.
    #[default]
    Auto,
    Vulkan,
    Dx12,
    Metal,
    // Explicit rename: snake_case alone would serialize as "open_gl".
    #[serde(rename = "opengl")]
    OpenGl,
}

/// Memory management strategy for the wgpu runtime.
// NOTE(review): unlike the other enums in this file, there is no
// `rename_all = "snake_case"` here, so `Exclusive` serializes as
// "Exclusive" (PascalCase). Probably an oversight, but changing it would
// break existing configs — confirm before adding the attribute.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
pub enum WgpuMemoryConfig {
    #[serde(rename = "subslices")]
    #[default]
    SubSlices,
    Exclusive,
}

/// Executor used for token generation on wgpu.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
#[serde(rename_all = "snake_case")]
pub enum WgpuGenerationExecutor {
    /// Token-at-a-time baseline.
    #[default]
    Baseline,
    /// Chunked rollout execution.
    RolloutChunked,
}
515
/// Serde default: tokens generated per chunk by the chunked executor.
fn default_generation_chunk_tokens() -> usize {
    8
}

/// Serde default: tokens buffered on-device during generation.
fn default_generation_device_buffer_tokens() -> usize {
    64
}
523
/// Inference-side wgpu settings.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(default)]
pub struct WgpuInferenceConfig {
    // `None` means "not specified / backend decides" for the fused-kernel
    // toggles.
    pub fused_core_recurrent: Option<bool>,
    pub fused_core_rollout: Option<bool>,
    /// Which generation executor to use.
    pub generation_executor: WgpuGenerationExecutor,
    /// Tokens per chunk for the chunked executor.
    #[serde(default = "default_generation_chunk_tokens")]
    pub generation_chunk_tokens: usize,
    /// Tokens buffered on-device during generation.
    #[serde(default = "default_generation_device_buffer_tokens")]
    pub generation_device_buffer_tokens: usize,
}
535
536impl Default for WgpuInferenceConfig {
537 fn default() -> Self {
538 Self {
539 fused_core_recurrent: None,
540 fused_core_rollout: None,
541 generation_executor: WgpuGenerationExecutor::Baseline,
542 generation_chunk_tokens: default_generation_chunk_tokens(),
543 generation_device_buffer_tokens: default_generation_device_buffer_tokens(),
544 }
545 }
546}
547
/// Training-side wgpu settings.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
#[serde(default)]
pub struct WgpuTrainingConfig {
    // `None` means "not specified / backend decides" for the fused-kernel
    // toggles.
    pub fused_core_recurrent: Option<bool>,
    pub fused_core_rollout: Option<bool>,
    /// Batch-size autotuning performed at startup.
    pub startup_autotune: WgpuStartupAutotuneConfig,
}

// Serde defaults for the startup autotuner.

/// Smallest batch size the autotuner will try.
fn default_startup_autotune_min_batch_size() -> usize {
    1
}

/// Probe steps executed per candidate batch size.
fn default_startup_autotune_probe_steps() -> usize {
    1
}

/// Binary search over batch sizes is on by default.
fn default_startup_autotune_binary_search() -> bool {
    true
}
567
/// Startup batch-size autotuning settings for wgpu training.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(default)]
pub struct WgpuStartupAutotuneConfig {
    pub enabled: bool,
    /// Device-memory budget in MiB the autotuner targets (0 = unset).
    pub target_device_memory_mb: usize,
    /// Lower bound of the search range.
    #[serde(default = "default_startup_autotune_min_batch_size")]
    pub min_batch_size: usize,
    /// Optional upper bound of the search range.
    pub max_batch_size: Option<usize>,
    /// Probe steps executed per candidate batch size.
    #[serde(default = "default_startup_autotune_probe_steps")]
    pub probe_steps: usize,
    /// Use binary search instead of a linear scan.
    pub binary_search: bool,
}
581
582impl Default for WgpuStartupAutotuneConfig {
583 fn default() -> Self {
584 Self {
585 enabled: false,
586 target_device_memory_mb: 0,
587 min_batch_size: default_startup_autotune_min_batch_size(),
588 max_batch_size: None,
589 probe_steps: default_startup_autotune_probe_steps(),
590 binary_search: default_startup_autotune_binary_search(),
591 }
592 }
593}
594
/// Top-level wgpu runtime configuration.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
#[serde(default)]
pub struct WgpuRuntimeConfig {
    /// Graphics API to use.
    pub backend: WgpuBackend,
    /// Optional cap on concurrent tasks; `None` leaves it to the runtime.
    pub tasks_max: Option<usize>,
    /// Memory management strategy.
    pub memory: WgpuMemoryConfig,
    pub training: WgpuTrainingConfig,
    pub inference: WgpuInferenceConfig,
}
604
/// GDPO objective configuration.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(default)]
pub struct GdpoConfig {
    pub enabled: bool,
    /// Samples per group.
    pub group_size: usize,
    /// Loss weight for hard samples.
    pub hard_weight: f32,
    /// Loss weight for easy samples.
    pub easy_weight: f32,
    /// Weight of the policy term.
    pub policy_weight: f32,
    /// Clip range for the policy ratio.
    pub policy_clip_range: f32,
    /// Rule deciding which samples count as hard.
    pub hard_gate: GdpoHardGate,
    /// Epsilon added during normalization.
    pub norm_epsilon: f32,
    /// Advantage clipping magnitude (0.0 = disabled — TODO confirm).
    pub advantage_clip: f32,
    /// EMA decay for advantages (0.0 = disabled — TODO confirm).
    pub advantage_ema_decay: f32,
}
619
620impl Default for GdpoConfig {
621 fn default() -> Self {
622 Self {
623 enabled: false,
624 group_size: 1,
625 hard_weight: 1.0,
626 easy_weight: 1.0,
627 policy_weight: 1.0,
628 policy_clip_range: 0.2,
629 hard_gate: GdpoHardGate::Percentile { quantile: 0.5 },
630 norm_epsilon: 1e-6,
631 advantage_clip: 0.0,
632 advantage_ema_decay: 0.0,
633 }
634 }
635}
636
/// Threshold rule used by GDPO's hard gate.
///
/// Serialized internally tagged, e.g. `{"type": "fixed", "threshold": 0.5}`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum GdpoHardGate {
    /// Gate disabled.
    Off,
    /// Fixed threshold value.
    Fixed { threshold: f32 },
    /// Threshold taken at the given quantile.
    Percentile { quantile: f32 },
}

// Manual impl: #[derive(Default)] cannot mark a data-carrying variant as
// the default.
impl Default for GdpoHardGate {
    fn default() -> Self {
        Self::Percentile { quantile: 0.5 }
    }
}
650
651impl fmt::Display for GdpoHardGate {
652 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
653 match self {
654 Self::Off => write!(f, "off"),
655 Self::Fixed { threshold } => write!(f, "fixed(threshold={threshold:.4})"),
656 Self::Percentile { quantile } => write!(f, "percentile(quantile={quantile:.3})"),
657 }
658 }
659}
660
// Integration with burn's module display machinery.

impl ModuleDisplayDefault for GdpoHardGate {
    /// Renders the gate via its `Display` impl inside burn's content tree.
    fn content(&self, content: Content) -> Option<Content> {
        content.add_formatted(self).optional()
    }
}

impl ModuleDisplay for GdpoHardGate {}

impl ModuleDisplayDefault for GdpoConfig {
    /// Lists every GDPO field as a named attribute in burn's display output.
    fn content(&self, content: Content) -> Option<Content> {
        content
            .add("enabled", &self.enabled)
            .add("group_size", &self.group_size)
            .add("hard_weight", &self.hard_weight)
            .add("easy_weight", &self.easy_weight)
            .add("policy_weight", &self.policy_weight)
            .add("policy_clip_range", &self.policy_clip_range)
            .add("hard_gate", &self.hard_gate)
            .add("norm_epsilon", &self.norm_epsilon)
            .add("advantage_clip", &self.advantage_clip)
            .add("advantage_ema_decay", &self.advantage_ema_decay)
            .optional()
    }
}

impl ModuleDisplay for GdpoConfig {}
687
/// Vision teacher model size. Serialized as "vits", "vitb", "vitl", "vitg"
/// (presumably ViT-S/B/L/g — TODO confirm against the teacher loader).
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum VisionTeacherVariant {
    #[default]
    Vits,
    Vitb,
    Vitl,
    Vitg,
}