1use yscv_tensor::{Tensor, TensorError};
2
3use crate::{ImageAugmentationOp, ImageAugmentationPipeline, ModelError};
4
5use super::helpers::{
6 LcgRng, class_balanced_sampling_weights, should_apply_probability, shuffle_indices,
7};
8use super::types::{BatchIterOptions, SamplingPolicy, SupervisedDataset};
9
/// Configuration for MixUp batch augmentation.
///
/// Fields are private; construct via [`MixUpConfig::new`] or [`Default`] and
/// adjust with the fallible `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct MixUpConfig {
    // Chance in [0, 1] that MixUp is applied to a given batch.
    probability: f32,
    // Lower bound in [0, 0.5] for the sampled blend factor lambda.
    lambda_min: f32,
}
16
17impl Default for MixUpConfig {
18 fn default() -> Self {
19 Self {
20 probability: 1.0,
21 lambda_min: 0.0,
22 }
23 }
24}
25
26impl MixUpConfig {
27 pub fn new() -> Self {
28 Self::default()
29 }
30
31 pub fn with_probability(mut self, probability: f32) -> Result<Self, ModelError> {
32 validate_mixup_probability(probability)?;
33 self.probability = probability;
34 Ok(self)
35 }
36
37 pub fn with_lambda_min(mut self, lambda_min: f32) -> Result<Self, ModelError> {
38 validate_mixup_lambda_min(lambda_min)?;
39 self.lambda_min = lambda_min;
40 Ok(self)
41 }
42
43 pub fn probability(&self) -> f32 {
44 self.probability
45 }
46
47 pub fn lambda_min(&self) -> f32 {
48 self.lambda_min
49 }
50}
51
/// Configuration for CutMix batch augmentation.
///
/// Fields are private; construct via [`CutMixConfig::new`] or [`Default`] and
/// adjust with the fallible `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct CutMixConfig {
    // Chance in [0, 1] that CutMix is applied to a given batch.
    probability: f32,
    // Smallest per-side patch fraction in [0, 1]; must not exceed the max.
    min_patch_fraction: f32,
    // Largest per-side patch fraction in [0, 1].
    max_patch_fraction: f32,
}
59
60impl Default for CutMixConfig {
61 fn default() -> Self {
62 Self {
63 probability: 1.0,
64 min_patch_fraction: 0.1,
65 max_patch_fraction: 0.5,
66 }
67 }
68}
69
70impl CutMixConfig {
71 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn with_probability(mut self, probability: f32) -> Result<Self, ModelError> {
76 validate_cutmix_probability(probability)?;
77 self.probability = probability;
78 Ok(self)
79 }
80
81 pub fn with_min_patch_fraction(mut self, min_patch_fraction: f32) -> Result<Self, ModelError> {
82 validate_cutmix_patch_fraction("min_patch_fraction", min_patch_fraction)?;
83 self.min_patch_fraction = min_patch_fraction;
84 if self.min_patch_fraction > self.max_patch_fraction {
85 return Err(ModelError::InvalidCutMixArgument {
86 field: "min_patch_fraction",
87 value: self.min_patch_fraction,
88 message: format!(
89 "min_patch_fraction must be <= max_patch_fraction ({})",
90 self.max_patch_fraction
91 ),
92 });
93 }
94 Ok(self)
95 }
96
97 pub fn with_max_patch_fraction(mut self, max_patch_fraction: f32) -> Result<Self, ModelError> {
98 validate_cutmix_patch_fraction("max_patch_fraction", max_patch_fraction)?;
99 self.max_patch_fraction = max_patch_fraction;
100 if self.min_patch_fraction > self.max_patch_fraction {
101 return Err(ModelError::InvalidCutMixArgument {
102 field: "max_patch_fraction",
103 value: self.max_patch_fraction,
104 message: format!(
105 "max_patch_fraction must be >= min_patch_fraction ({})",
106 self.min_patch_fraction
107 ),
108 });
109 }
110 Ok(self)
111 }
112
113 pub fn probability(&self) -> f32 {
114 self.probability
115 }
116
117 pub fn min_patch_fraction(&self) -> f32 {
118 self.min_patch_fraction
119 }
120
121 pub fn max_patch_fraction(&self) -> f32 {
122 self.max_patch_fraction
123 }
124}
125
// Result of applying MixUp/CutMix to one batch: the blended inputs and the
// correspondingly blended (soft) targets.
pub(super) struct MixUpBatch {
    pub(super) inputs: Tensor,
    pub(super) targets: Tensor,
}
130
131pub(super) fn validate_augmentation_compatibility(
132 inputs: &Tensor,
133 pipeline: &ImageAugmentationPipeline,
134) -> Result<(), ModelError> {
135 if inputs.rank() != 4 {
136 return Err(ModelError::InvalidAugmentationInputShape {
137 got: inputs.shape().to_vec(),
138 });
139 }
140 let channels = inputs.shape()[3];
141 for op in pipeline.ops() {
142 if let ImageAugmentationOp::ChannelNormalize { mean, std: _ } = op
143 && mean.len() != channels
144 {
145 return Err(ModelError::InvalidAugmentationArgument {
146 operation: "channel_normalize",
147 message: format!(
148 "channel count mismatch: dataset_channels={channels}, mean/std_len={}",
149 mean.len()
150 ),
151 });
152 }
153 }
154 Ok(())
155}
156
157pub(super) fn validate_mixup_config(config: &MixUpConfig) -> Result<(), ModelError> {
158 validate_mixup_probability(config.probability())?;
159 validate_mixup_lambda_min(config.lambda_min())?;
160 Ok(())
161}
162
163pub(super) fn validate_cutmix_config(config: &CutMixConfig) -> Result<(), ModelError> {
164 validate_cutmix_probability(config.probability())?;
165 validate_cutmix_patch_fraction("min_patch_fraction", config.min_patch_fraction())?;
166 validate_cutmix_patch_fraction("max_patch_fraction", config.max_patch_fraction())?;
167 if config.min_patch_fraction() > config.max_patch_fraction() {
168 return Err(ModelError::InvalidCutMixArgument {
169 field: "min_patch_fraction",
170 value: config.min_patch_fraction(),
171 message: format!(
172 "min_patch_fraction must be <= max_patch_fraction ({})",
173 config.max_patch_fraction()
174 ),
175 });
176 }
177 Ok(())
178}
179
180fn validate_mixup_probability(probability: f32) -> Result<(), ModelError> {
181 if !probability.is_finite() || !(0.0..=1.0).contains(&probability) {
182 return Err(ModelError::InvalidMixupArgument {
183 field: "probability",
184 value: probability,
185 message: "probability must be finite and in [0, 1]".to_string(),
186 });
187 }
188 Ok(())
189}
190
191fn validate_cutmix_probability(probability: f32) -> Result<(), ModelError> {
192 if !probability.is_finite() || !(0.0..=1.0).contains(&probability) {
193 return Err(ModelError::InvalidCutMixArgument {
194 field: "probability",
195 value: probability,
196 message: "probability must be finite and in [0, 1]".to_string(),
197 });
198 }
199 Ok(())
200}
201
202fn validate_cutmix_patch_fraction(field: &'static str, value: f32) -> Result<(), ModelError> {
203 if !value.is_finite() || !(0.0..=1.0).contains(&value) {
204 return Err(ModelError::InvalidCutMixArgument {
205 field,
206 value,
207 message: format!("{field} must be finite and in [0, 1]"),
208 });
209 }
210 Ok(())
211}
212
213pub(super) fn validate_cutmix_compatibility(inputs: &Tensor) -> Result<(), ModelError> {
214 if inputs.rank() != 4 {
215 return Err(ModelError::InvalidCutMixInputShape {
216 got: inputs.shape().to_vec(),
217 });
218 }
219 Ok(())
220}
221
222fn validate_mixup_lambda_min(lambda_min: f32) -> Result<(), ModelError> {
223 if !lambda_min.is_finite() || !(0.0..=0.5).contains(&lambda_min) {
224 return Err(ModelError::InvalidMixupArgument {
225 field: "lambda_min",
226 value: lambda_min,
227 message: "lambda_min must be finite and in [0, 0.5]".to_string(),
228 });
229 }
230 Ok(())
231}
232
233pub(super) fn apply_mixup_batch(
234 inputs: &Tensor,
235 targets: &Tensor,
236 config: &MixUpConfig,
237 seed: u64,
238) -> Result<MixUpBatch, ModelError> {
239 validate_mixup_config(config)?;
240 if inputs.rank() == 0 || targets.rank() == 0 {
241 return Err(ModelError::InvalidDatasetRank {
242 inputs_rank: inputs.rank(),
243 targets_rank: targets.rank(),
244 });
245 }
246 let batch_size = inputs.shape()[0];
247 if batch_size != targets.shape()[0] {
248 return Err(ModelError::DatasetShapeMismatch {
249 inputs: inputs.shape().to_vec(),
250 targets: targets.shape().to_vec(),
251 });
252 }
253 if batch_size < 2 {
254 return Ok(MixUpBatch {
255 inputs: inputs.clone(),
256 targets: targets.clone(),
257 });
258 }
259
260 let mut rng = LcgRng::new(seed);
261 if !should_apply_probability(config.probability(), &mut rng) {
262 return Ok(MixUpBatch {
263 inputs: inputs.clone(),
264 targets: targets.clone(),
265 });
266 }
267
268 let lambda =
269 config.lambda_min() + rng.next_unit_f64() as f32 * (1.0 - 2.0 * config.lambda_min());
270 let partner_indices = build_partner_indices(batch_size, seed ^ 0xA5A5_A5A5_5A5A_5A5A);
271
272 Ok(MixUpBatch {
273 inputs: blend_rows(inputs, &partner_indices, lambda)?,
274 targets: blend_rows(targets, &partner_indices, lambda)?,
275 })
276}
277
/// Applies CutMix to a rank-4 (N, H, W, C) batch: for each row, a rectangular
/// patch is copied in from a seed-chosen partner row, and the row's targets
/// are blended in proportion to the surviving (un-replaced) pixel area.
///
/// Batches with fewer than two rows, a failed probability draw, or any
/// zero-sized H/W/C dimension are returned unchanged.
///
/// # Errors
///
/// Fails on an invalid config, non-rank-4 inputs, rank-0 targets, mismatched
/// batch sizes, index-arithmetic overflow, or tensor reconstruction errors.
pub(super) fn apply_cutmix_batch(
    inputs: &Tensor,
    targets: &Tensor,
    config: &CutMixConfig,
    seed: u64,
) -> Result<MixUpBatch, ModelError> {
    validate_cutmix_config(config)?;
    validate_cutmix_compatibility(inputs)?;
    if targets.rank() == 0 {
        return Err(ModelError::InvalidDatasetRank {
            inputs_rank: inputs.rank(),
            targets_rank: targets.rank(),
        });
    }
    let batch_size = inputs.shape()[0];
    if batch_size != targets.shape()[0] {
        return Err(ModelError::DatasetShapeMismatch {
            inputs: inputs.shape().to_vec(),
            targets: targets.shape().to_vec(),
        });
    }
    // A single row has no partner to mix with.
    if batch_size < 2 {
        return Ok(MixUpBatch {
            inputs: inputs.clone(),
            targets: targets.clone(),
        });
    }

    let mut rng = LcgRng::new(seed);
    if !should_apply_probability(config.probability(), &mut rng) {
        return Ok(MixUpBatch {
            inputs: inputs.clone(),
            targets: targets.clone(),
        });
    }

    let height = inputs.shape()[1];
    let width = inputs.shape()[2];
    let channels = inputs.shape()[3];
    // Degenerate images carry no pixels to cut; return the batch untouched.
    if height == 0 || width == 0 || channels == 0 {
        return Ok(MixUpBatch {
            inputs: inputs.clone(),
            targets: targets.clone(),
        });
    }

    // Flat element counts per batch row, with overflow surfaced as errors.
    let input_row_width = height
        .checked_mul(width)
        .and_then(|value| value.checked_mul(channels))
        .ok_or_else(|| {
            ModelError::Tensor(TensorError::SizeOverflow {
                shape: inputs.shape().to_vec(),
            })
        })?;
    let target_row_width = targets.shape()[1..]
        .iter()
        .try_fold(1usize, |acc, dim| acc.checked_mul(*dim))
        .ok_or_else(|| {
            ModelError::Tensor(TensorError::SizeOverflow {
                shape: targets.shape().to_vec(),
            })
        })?;

    // Work on owned copies so untouched regions stay byte-identical.
    let mut mixed_inputs = inputs.data().to_vec();
    let mut mixed_targets = targets.data().to_vec();
    // XOR decorrelates the partner permutation from the patch/position draws.
    let partner_indices = build_partner_indices(batch_size, seed ^ 0x5A5A_A5A5_DEAD_BEEF);
    let total_pixels = (height * width) as f32;

    for (row_index, partner_index) in partner_indices.iter().enumerate() {
        // Per-row patch: the sampled fraction scales each side independently,
        // clamped to at least one pixel and at most the full extent.
        let patch_fraction = sample_cutmix_patch_fraction(config, &mut rng);
        let patch_height = ((height as f32 * patch_fraction).floor() as usize)
            .max(1)
            .min(height);
        let patch_width = ((width as f32 * patch_fraction).floor() as usize)
            .max(1)
            .min(width);
        // Uniform top-left corner such that the patch stays in bounds.
        let top = rng.next_usize(height - patch_height + 1);
        let left = rng.next_usize(width - patch_width + 1);

        let row_start = row_index.checked_mul(input_row_width).ok_or_else(|| {
            ModelError::Tensor(TensorError::SizeOverflow {
                shape: inputs.shape().to_vec(),
            })
        })?;
        let partner_start = partner_index.checked_mul(input_row_width).ok_or_else(|| {
            ModelError::Tensor(TensorError::SizeOverflow {
                shape: inputs.shape().to_vec(),
            })
        })?;

        // Overwrite the patch with the partner's pixels, one (H, W) position
        // (i.e. `channels` contiguous floats) at a time. Same offset on both
        // sides: the patch comes from the identical location in the partner.
        for y in 0..patch_height {
            for x in 0..patch_width {
                let pixel_offset = ((top + y) * width + (left + x)) * channels;
                let dst = row_start + pixel_offset;
                let src = partner_start + pixel_offset;
                mixed_inputs[dst..(dst + channels)]
                    .copy_from_slice(&inputs.data()[src..(src + channels)]);
            }
        }

        // lambda = fraction of the row's own pixels that survived; targets
        // are blended with the partner's targets in that proportion.
        let replaced_ratio = (patch_height * patch_width) as f32 / total_pixels;
        let lambda = 1.0 - replaced_ratio;
        let target_row_start = row_index.checked_mul(target_row_width).ok_or_else(|| {
            ModelError::Tensor(TensorError::SizeOverflow {
                shape: targets.shape().to_vec(),
            })
        })?;
        let partner_target_start =
            partner_index.checked_mul(target_row_width).ok_or_else(|| {
                ModelError::Tensor(TensorError::SizeOverflow {
                    shape: targets.shape().to_vec(),
                })
            })?;
        for offset in 0..target_row_width {
            mixed_targets[target_row_start + offset] = lambda
                * targets.data()[target_row_start + offset]
                + (1.0 - lambda) * targets.data()[partner_target_start + offset];
        }
    }

    Ok(MixUpBatch {
        inputs: Tensor::from_vec(inputs.shape().to_vec(), mixed_inputs)?,
        targets: Tensor::from_vec(targets.shape().to_vec(), mixed_targets)?,
    })
}
403
404fn sample_cutmix_patch_fraction(config: &CutMixConfig, rng: &mut LcgRng) -> f32 {
405 if (config.max_patch_fraction() - config.min_patch_fraction()).abs() <= f32::EPSILON {
406 return config.min_patch_fraction();
407 }
408 config.min_patch_fraction()
409 + rng.next_unit_f64() as f32 * (config.max_patch_fraction() - config.min_patch_fraction())
410}
411
412fn blend_rows(
413 tensor: &Tensor,
414 partner_indices: &[usize],
415 lambda: f32,
416) -> Result<Tensor, ModelError> {
417 if tensor.rank() == 0 {
418 return Err(ModelError::InvalidDatasetRank {
419 inputs_rank: tensor.rank(),
420 targets_rank: tensor.rank(),
421 });
422 }
423 let batch_size = tensor.shape()[0];
424 if partner_indices.len() != batch_size {
425 return Err(ModelError::DatasetShapeMismatch {
426 inputs: tensor.shape().to_vec(),
427 targets: vec![partner_indices.len()],
428 });
429 }
430 let row_width = tensor.shape()[1..]
431 .iter()
432 .try_fold(1usize, |acc, dim| acc.checked_mul(*dim))
433 .ok_or_else(|| {
434 ModelError::Tensor(TensorError::SizeOverflow {
435 shape: tensor.shape().to_vec(),
436 })
437 })?;
438
439 let mut out = vec![0.0f32; tensor.len()];
440 let left_weight = lambda;
441 let right_weight = 1.0 - lambda;
442 for (row_index, partner_index) in partner_indices.iter().enumerate() {
443 if *partner_index >= batch_size {
444 return Err(ModelError::DatasetShapeMismatch {
445 inputs: tensor.shape().to_vec(),
446 targets: vec![*partner_index, batch_size],
447 });
448 }
449 let row_start = row_index.checked_mul(row_width).ok_or_else(|| {
450 ModelError::Tensor(TensorError::SizeOverflow {
451 shape: tensor.shape().to_vec(),
452 })
453 })?;
454 let partner_start = partner_index.checked_mul(row_width).ok_or_else(|| {
455 ModelError::Tensor(TensorError::SizeOverflow {
456 shape: tensor.shape().to_vec(),
457 })
458 })?;
459 for offset in 0..row_width {
460 let dst = row_start + offset;
461 out[dst] = left_weight * tensor.data()[row_start + offset]
462 + right_weight * tensor.data()[partner_start + offset];
463 }
464 }
465 Tensor::from_vec(tensor.shape().to_vec(), out).map_err(Into::into)
466}
467
468fn build_partner_indices(batch_size: usize, seed: u64) -> Vec<usize> {
469 let mut partner_indices = (0..batch_size).collect::<Vec<_>>();
470 shuffle_indices(&mut partner_indices, seed);
471 if partner_indices
472 .iter()
473 .enumerate()
474 .all(|(index, partner)| index == *partner)
475 {
476 partner_indices.rotate_left(1);
477 }
478 partner_indices
479}
480
481pub(super) fn build_sample_order(
482 dataset: &SupervisedDataset,
483 options: &BatchIterOptions,
484) -> Result<Vec<usize>, ModelError> {
485 if let Some(policy) = options.sampling.as_ref() {
486 return build_sample_order_from_policy(dataset, policy);
487 }
488
489 let mut order = (0..dataset.len()).collect::<Vec<_>>();
490 if options.shuffle {
491 shuffle_indices(&mut order, options.shuffle_seed);
492 }
493 Ok(order)
494}
495
496fn build_sample_order_from_policy(
497 dataset: &SupervisedDataset,
498 policy: &SamplingPolicy,
499) -> Result<Vec<usize>, ModelError> {
500 let dataset_len = dataset.len();
501 match policy {
502 SamplingPolicy::Sequential => Ok((0..dataset_len).collect()),
503 SamplingPolicy::Shuffled { seed } => {
504 let mut order = (0..dataset_len).collect::<Vec<_>>();
505 shuffle_indices(&mut order, *seed);
506 Ok(order)
507 }
508 SamplingPolicy::BalancedByClass {
509 seed,
510 with_replacement,
511 } => {
512 let weights = class_balanced_sampling_weights(dataset.targets())?;
513 if *with_replacement {
514 sample_weighted_with_replacement(&weights, dataset_len, *seed)
515 } else {
516 sample_weighted_without_replacement(&weights, *seed)
517 }
518 }
519 SamplingPolicy::Weighted {
520 weights,
521 seed,
522 with_replacement,
523 } => {
524 validate_sampling_weights(weights, dataset_len)?;
525 if *with_replacement {
526 sample_weighted_with_replacement(weights, dataset_len, *seed)
527 } else {
528 sample_weighted_without_replacement(weights, *seed)
529 }
530 }
531 }
532}
533
534fn validate_sampling_weights(weights: &[f32], dataset_len: usize) -> Result<(), ModelError> {
535 if weights.len() != dataset_len {
536 return Err(ModelError::InvalidSamplingWeightsLength {
537 expected: dataset_len,
538 got: weights.len(),
539 });
540 }
541 let mut positive = false;
542 for (index, weight) in weights.iter().enumerate() {
543 if !weight.is_finite() || *weight < 0.0 {
544 return Err(ModelError::InvalidSamplingWeight {
545 index,
546 value: *weight,
547 });
548 }
549 if *weight > 0.0 {
550 positive = true;
551 }
552 }
553 if !positive && dataset_len > 0 {
554 return Err(ModelError::InvalidSamplingDistribution);
555 }
556 Ok(())
557}
558
559fn sample_weighted_with_replacement(
560 weights: &[f32],
561 draw_count: usize,
562 seed: u64,
563) -> Result<Vec<usize>, ModelError> {
564 if draw_count == 0 {
565 return Ok(Vec::new());
566 }
567
568 let mut cumulative = Vec::with_capacity(weights.len());
569 let mut total = 0.0f64;
570 for weight in weights {
571 total += *weight as f64;
572 cumulative.push(total);
573 }
574 if total <= 0.0 {
575 return Err(ModelError::InvalidSamplingDistribution);
576 }
577
578 let mut rng = LcgRng::new(seed);
579 let mut out = Vec::with_capacity(draw_count);
580 for _ in 0..draw_count {
581 let draw = rng.next_unit_f64() * total;
582 let mut sampled = cumulative.partition_point(|prefix| *prefix <= draw);
583 if sampled >= weights.len() {
584 sampled = weights.len() - 1;
585 }
586 out.push(sampled);
587 }
588 Ok(out)
589}
590
591fn sample_weighted_without_replacement(
592 weights: &[f32],
593 seed: u64,
594) -> Result<Vec<usize>, ModelError> {
595 if weights.is_empty() {
596 return Ok(Vec::new());
597 }
598
599 let mut rng = LcgRng::new(seed);
600 let mut keyed = Vec::with_capacity(weights.len());
601 for (index, weight) in weights.iter().enumerate() {
602 let key = if *weight == 0.0 {
603 0.0
604 } else {
605 let u = rng.next_unit_open_f64();
606 u.powf(1.0 / *weight as f64)
607 };
608 keyed.push((index, key));
609 }
610
611 keyed.sort_by(|left, right| {
612 right
613 .1
614 .total_cmp(&left.1)
615 .then_with(|| left.0.cmp(&right.0))
616 });
617 Ok(keyed.into_iter().map(|(index, _)| index).collect())
618}