use crate::train::prelude::*;
use burn_dragon_train::train::runtime::{
DeviceMemoryUsage, cleanup_device_memory, device_memory_usage_safe,
};
/// Outcome of one memory probe at a single candidate batch size.
#[derive(Debug, Clone, Serialize)]
pub struct StartupAutotuneProbe {
    /// The batch size that was probed.
    pub batch_size: usize,
    /// Peak reserved device memory seen during the probe, in MiB
    /// (`None` when no usage reading was available or the probe panicked).
    pub reserved_mb: Option<f64>,
    /// Peak in-use device memory seen during the probe, in MiB
    /// (`None` when no usage reading was available or the probe panicked).
    pub in_use_mb: Option<f64>,
    /// Whether peak usage stayed at or under the configured memory target.
    pub fit_target: bool,
    /// Probe outcome label: "fit", "over_target", or "probe_failed".
    pub status: String,
}
/// Serialized summary of a startup batch-size autotune run.
#[derive(Debug, Clone, Serialize)]
pub struct StartupAutotuneReport {
    /// Backend identifier the autotune ran against (wgpu variants only).
    pub backend_name: String,
    /// Device-memory budget the probes were measured against, in MiB.
    pub target_device_memory_mb: usize,
    /// Desired effective batch size (batch * accumulation), if configured (> 0).
    pub target_effective_batch_size: Option<usize>,
    /// Smallest batch size considered (clamped to at least 1).
    pub min_batch_size: usize,
    /// Largest batch size considered (at least `min_batch_size`).
    pub max_batch_size: usize,
    /// Number of training steps executed per probe (clamped to at least 1).
    pub probe_steps: usize,
    /// Largest probed batch size that fit within the memory target.
    pub resolved_batch_size: usize,
    /// Gradient-accumulation steps derived from the resolved batch size.
    pub resolved_gradient_accumulation_steps: usize,
    /// `resolved_batch_size * resolved_gradient_accumulation_steps` (saturating).
    pub resolved_effective_batch_size: usize,
    /// Every probe performed, in execution order.
    pub probes: Vec<StartupAutotuneProbe>,
}
/// Resolves a safe startup batch size for wgpu backends by probing device
/// memory usage at increasing candidate batch sizes.
///
/// Phase 1 doubles the batch size from `min` until a probe exceeds the memory
/// target or `max` is reached; phase 2 optionally binary-searches the bracket
/// between the last fitting and first failing size.
///
/// Returns `Ok(None)` when autotuning is disabled or the backend is not wgpu.
///
/// # Errors
/// Returns an error when no candidate between the configured min/max fits
/// within the target device-memory budget, or when model-config construction
/// fails.
pub fn resolve_startup_batch_size<B>(
    config: &TrainingConfig,
    dataset: &Arc<Dataset>,
    backend_name: &str,
    device: &B::Device,
) -> Result<Option<StartupAutotuneReport>>
where
    B: AutodiffBackend + Clone + 'static,
    B::Device: Clone + 'static,
{
    let autotune = &config.wgpu.training.startup_autotune;
    // Autotuning is wgpu-only; other backends keep the configured batch size.
    if !autotune.enabled || !backend_name.starts_with("wgpu") {
        return Ok(None);
    }
    let min_batch_size = autotune.min_batch_size.max(1);
    let max_batch_size = autotune
        .max_batch_size
        .unwrap_or(config.training.batch_size)
        .max(min_batch_size);
    // A configured target of 0 is treated as "not set".
    let target_effective_batch_size = config
        .training
        .target_effective_batch_size
        .filter(|value| *value > 0);
    let training_kernel_block_size =
        crate::train::utils::effective_training_kernel_block_size(&config.training);
    let tokenizer = dataset.tokenizer();
    let mut model_config = build_model_config_with_tokenizer(
        &config.model,
        training_kernel_block_size,
        tokenizer.as_ref(),
    )?;
    apply_wgpu_fused_core_override(
        &mut model_config,
        backend_name,
        WgpuFusedCoreOverride {
            recurrent: config.wgpu.training.fused_core_recurrent,
            rollout: config.wgpu.training.fused_core_rollout,
        },
    );
    let summary_event_token_ids = model_config.summary_memory.write_trigger_token_ids.clone();
    let mut probes = Vec::new();
    let target_bytes = (autotune.target_device_memory_mb as u64).saturating_mul(1024 * 1024);
    // Phase 1: exponential ramp to bracket the first failing batch size.
    let (mut low, mut high) = (min_batch_size, max_batch_size);
    let mut best_fit = None;
    for candidate in startup_candidate_sequence(min_batch_size, max_batch_size) {
        let probe = probe_batch_size::<B>(ProbeBatchRequest {
            dataset,
            model_config: &model_config,
            block_size: config.training.block_size,
            tbptt_chunk_size: config.training.tbptt_chunk_size,
            batch_size: candidate,
            probe_steps: autotune.probe_steps.max(1),
            target_bytes,
            summary_event_token_ids: summary_event_token_ids.as_deref(),
            device,
        });
        let fit_target = probe.fit_target;
        probes.push(probe);
        if fit_target {
            best_fit = Some(candidate);
            low = candidate;
            if candidate == max_batch_size {
                break;
            }
        } else {
            high = candidate;
            break;
        }
    }
    // Phase 2: optional binary search inside the (low, high) bracket.
    if autotune.binary_search && best_fit.is_some() && high > low + 1 {
        while high > low + 1 {
            let candidate = low + ((high - low) / 2);
            let probe = probe_batch_size::<B>(ProbeBatchRequest {
                dataset,
                model_config: &model_config,
                block_size: config.training.block_size,
                tbptt_chunk_size: config.training.tbptt_chunk_size,
                batch_size: candidate,
                probe_steps: autotune.probe_steps.max(1),
                target_bytes,
                summary_event_token_ids: summary_event_token_ids.as_deref(),
                device,
            });
            let fit_target = probe.fit_target;
            probes.push(probe);
            if fit_target {
                best_fit = Some(candidate);
                low = candidate;
            } else {
                high = candidate;
            }
        }
    }
    let Some(resolved_batch_size) = best_fit else {
        return Err(anyhow!(
            "startup autotune could not find a safe batch size between {} and {} under target {} MiB; probes={}",
            min_batch_size,
            max_batch_size,
            autotune.target_device_memory_mb,
            format_probe_summary(&probes)
        ));
    };
    // Derive the accumulation values once and reuse them for both the log line
    // and the report (previously recomputed three times with the same inputs).
    let resolved_gradient_accumulation_steps = resolve_gradient_accumulation_steps(
        resolved_batch_size,
        config.training.gradient_accumulation_steps,
        target_effective_batch_size,
    );
    let resolved_effective_batch_size =
        resolved_batch_size.saturating_mul(resolved_gradient_accumulation_steps);
    info!(
        "startup autotune: resolved batch_size={} grad_accumulation_steps={} effective_batch_size={} (target={} MiB, probed {} candidates)",
        resolved_batch_size,
        resolved_gradient_accumulation_steps,
        resolved_effective_batch_size,
        autotune.target_device_memory_mb,
        probes.len(),
    );
    Ok(Some(StartupAutotuneReport {
        backend_name: backend_name.to_string(),
        target_device_memory_mb: autotune.target_device_memory_mb,
        target_effective_batch_size,
        min_batch_size,
        max_batch_size,
        probe_steps: autotune.probe_steps.max(1),
        resolved_batch_size,
        resolved_gradient_accumulation_steps,
        resolved_effective_batch_size,
        probes,
    }))
}
/// Runs `probe_steps` training steps at the requested batch size and reports
/// the peak device-memory usage observed, classifying the candidate as
/// "fit", "over_target", or "probe_failed".
fn probe_batch_size<B>(request: ProbeBatchRequest<'_, B>) -> StartupAutotuneProbe
where
    B: AutodiffBackend + Clone + 'static,
    B::Device: Clone + 'static,
{
    let ProbeBatchRequest {
        dataset,
        model_config,
        block_size,
        tbptt_chunk_size,
        batch_size,
        probe_steps,
        target_bytes,
        summary_event_token_ids,
        device,
    } = request;
    // Start from a clean allocator state so prior probes don't inflate readings.
    cleanup_device_memory::<B>(device, false);
    // An oversized batch may panic deep inside the backend (presumably an
    // out-of-memory condition — confirm against backend behavior); catch the
    // panic and classify the candidate as failed instead of aborting startup.
    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let model = LanguageTrainModel::new(BDH::<B>::new(model_config.clone(), device))
            .with_tbptt_chunk_size(tbptt_chunk_size);
        let mut peak_usage: Option<DeviceMemoryUsage> = None;
        for _ in 0..probe_steps {
            let batch = sample_batch_with_shape::<B, _>(
                &**dataset,
                DatasetSplit::Train,
                batch_size,
                block_size,
                summary_event_token_ids,
                0,
                device,
            );
            let output = burn_train::TrainStep::step(&model, batch);
            drop(output);
            // Wait for pending device work before sampling memory usage.
            let _ = B::sync(device);
            if let Some(usage) = device_memory_usage_safe::<B>(device) {
                // Keep whichever snapshot has the larger max(reserved, in_use).
                peak_usage = Some(match peak_usage {
                    Some(current)
                        if current.reserved_bytes.max(current.in_use_bytes)
                            >= usage.reserved_bytes.max(usage.in_use_bytes) =>
                    {
                        current
                    }
                    _ => usage,
                });
            }
        }
        // Release the probe model's memory before the next candidate runs.
        drop(model);
        cleanup_device_memory::<B>(device, false);
        peak_usage
    }));
    match result {
        Ok(peak_usage) => {
            // No usage reading counts as not fitting: safety can't be proven.
            let fit_target = peak_usage
                .map(|usage| usage.reserved_bytes.max(usage.in_use_bytes) <= target_bytes)
                .unwrap_or(false);
            StartupAutotuneProbe {
                batch_size,
                reserved_mb: peak_usage.map(DeviceMemoryUsage::reserved_mb),
                in_use_mb: peak_usage.map(DeviceMemoryUsage::in_use_mb),
                fit_target,
                status: if fit_target {
                    "fit".to_string()
                } else {
                    "over_target".to_string()
                },
            }
        }
        Err(_) => {
            // Best-effort release of whatever the panicked probe left allocated.
            cleanup_device_memory::<B>(device, false);
            StartupAutotuneProbe {
                batch_size,
                reserved_mb: None,
                in_use_mb: None,
                fit_target: false,
                status: "probe_failed".to_string(),
            }
        }
    }
}
/// Borrowed inputs for a single `probe_batch_size` run.
struct ProbeBatchRequest<'a, B: AutodiffBackend> {
    dataset: &'a Arc<Dataset>,
    model_config: &'a BDHConfig,
    // Sequence length per sample for the probe batches.
    block_size: usize,
    // Optional truncated-BPTT chunk size forwarded to the train model.
    tbptt_chunk_size: Option<usize>,
    // Candidate batch size under evaluation.
    batch_size: usize,
    // Number of training steps to execute (caller clamps to >= 1).
    probe_steps: usize,
    // Device-memory budget in bytes that peak usage is compared against.
    target_bytes: u64,
    summary_event_token_ids: Option<&'a [u32]>,
    device: &'a B::Device,
}
/// Builds the phase-1 probe schedule: start at `min_batch_size` (at least 1),
/// keep doubling, and clamp the final step to `max_batch_size`.
///
/// When the start is already at or above the cap, the sequence is just the
/// single starting value.
fn startup_candidate_sequence(min_batch_size: usize, max_batch_size: usize) -> Vec<usize> {
    let first = min_batch_size.max(1);
    std::iter::successors(Some(first), |&prev| {
        if prev >= max_batch_size {
            return None;
        }
        let doubled = prev.saturating_mul(2).min(max_batch_size);
        // Guard against a stalled sequence (cannot happen for prev >= 1,
        // but keeps the iterator provably finite).
        (doubled != prev).then_some(doubled)
    })
    .collect()
}
/// Derives the gradient-accumulation step count for a resolved batch size.
///
/// With a target effective batch size, returns the smallest accumulation
/// count whose product with the batch size reaches the target (ceiling
/// division, never below 1). Without one, the configured value is used,
/// clamped to at least 1.
pub fn resolve_gradient_accumulation_steps(
    resolved_batch_size: usize,
    configured_gradient_accumulation_steps: usize,
    target_effective_batch_size: Option<usize>,
) -> usize {
    // Guard against a zero batch size so the division below is well-defined.
    let batch = resolved_batch_size.max(1);
    if let Some(target) = target_effective_batch_size {
        target.max(batch).div_ceil(batch).max(1)
    } else {
        configured_gradient_accumulation_steps.max(1)
    }
}
/// Renders probes as a compact comma-separated diagnostic string, e.g.
/// `bs8:fit:1024.0/900.5MiB,bs16:over_target` — memory figures are included
/// only when both readings are present.
fn format_probe_summary(probes: &[StartupAutotuneProbe]) -> String {
    let mut entries = Vec::with_capacity(probes.len());
    for probe in probes {
        let entry = if let (Some(reserved), Some(in_use)) = (probe.reserved_mb, probe.in_use_mb) {
            format!(
                "bs{}:{}:{reserved:.1}/{in_use:.1}MiB",
                probe.batch_size, probe.status
            )
        } else {
            format!("bs{}:{}", probe.batch_size, probe.status)
        };
        entries.push(entry);
    }
    entries.join(",")
}
#[cfg(test)]
mod tests {
    use super::{resolve_gradient_accumulation_steps, startup_candidate_sequence};
    // The probe schedule doubles from min and clamps the final step at max.
    #[test]
    fn startup_candidate_sequence_doubles_and_caps_at_max() {
        assert_eq!(startup_candidate_sequence(4, 4), vec![4]);
        assert_eq!(startup_candidate_sequence(4, 20), vec![4, 8, 16, 20]);
        assert_eq!(startup_candidate_sequence(3, 24), vec![3, 6, 12, 24]);
    }
    // Without a target, the configured value wins; with one, accumulation is
    // the ceiling of target / batch size.
    #[test]
    fn resolve_gradient_accumulation_steps_ceil_divides_to_target_effective_batch() {
        assert_eq!(resolve_gradient_accumulation_steps(64, 1, None), 1);
        assert_eq!(resolve_gradient_accumulation_steps(64, 3, None), 3);
        assert_eq!(resolve_gradient_accumulation_steps(64, 1, Some(64)), 1);
        assert_eq!(resolve_gradient_accumulation_steps(21, 1, Some(64)), 4);
        assert_eq!(resolve_gradient_accumulation_steps(16, 1, Some(128)), 8);
    }
}