use std::{
collections::BTreeMap,
fs,
path::{Path, PathBuf},
process::Command,
};
use anyhow::{Context, ensure};
use burn_p2p_checkpoint::{ArtifactBuildSpec, ChunkingScheme, FsArtifactStore};
use burn_p2p_core::{
ArtifactDescriptor, ArtifactKind, AssignmentLease, CapabilityEstimate, ContentId, DatasetId,
DatasetManifest, DatasetView, DatasetViewId, FlattenedTensorPack, HeadId, MergePolicy,
MetricValue, Precision, StateBlob, SupportedWorkload, TrainingProtocol,
};
use burn_p2p_dataloader::{
CachedMicroShard, DatasetRegistration, DatasetSizing, MicroShardPlan, MicroShardPlanner,
MicroShardPlannerConfig, UpstreamAdapter,
};
use burn_p2p_experiment::{PatchSupport, RuntimePatch};
use burn_p2p_workload::{
DiLoCoInnerLoopReport, DiLoCoWorkload, EvalSplit, LeaseDataPipeline,
LeaseDataPipelineDescriptor, LeaseDataPipelineKind, MergeModelCandidate, MetricReport,
P2pWorkload, PatchOutcome, TrainError, TrainerCanonicalReconcileStrategy, WindowCtx,
WindowReport, local_upstream_root_for_pipeline, standard_contribution_weight,
};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use serde_json::Value;
mod worker;
pub use worker::PythonParameterPackPlanResponse;
use worker::{
PythonDiLoCoInnerLoopPathRequest, PythonDiLoCoInnerLoopResponse, PythonMergeCandidateRef,
PythonWorkerClient,
};
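/// Format tag recorded in every on-disk parameter-pack `manifest.json` and
/// checked again on read, so stale or foreign pack directories fail fast.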
const PYTHON_PARAMETER_PACK_FORMAT: &str = "burn-p2p-python-flattened-parameter-pack-v1";
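/// How the Python worker process is launched: the interpreter, extra module
/// search roots for `PYTHONPATH`, the `module:attr` factory that constructs
/// the workload, its JSON configuration, and additional environment variables.
///
/// A minimal sketch; the package path and factory name below are illustrative,
/// not part of this crate:
///
/// ```ignore
/// let runtime = PythonTorchRuntimeConfig::new(
///     "python3",
///     "my_package.train:make_workload",
///     serde_json::json!({ "hidden_dim": 128 }),
/// )
/// .with_module_search_root("/opt/my_package")
/// .with_env("TOKENIZERS_PARALLELISM", "false");
/// ```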
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PythonTorchRuntimeConfig {
pub python_executable: PathBuf,
pub module_search_roots: Vec<PathBuf>,
pub workload_factory: String,
pub workload_config: Value,
pub env: BTreeMap<String, String>,
}
impl PythonTorchRuntimeConfig {
pub fn new(
python_executable: impl Into<PathBuf>,
workload_factory: impl Into<String>,
workload_config: Value,
) -> Self {
Self {
python_executable: python_executable.into(),
module_search_roots: Vec::new(),
workload_factory: workload_factory.into(),
workload_config,
env: BTreeMap::new(),
}
}
pub fn with_module_search_root(mut self, root: impl Into<PathBuf>) -> Self {
self.module_search_roots.push(root.into());
self
}
pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.env.insert(key.into(), value.into());
self
}
}
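/// Local dataset description from which both the peer-facing
/// `DatasetRegistration` and the micro-shard plan are derived. `root` must be
/// a locally readable shard directory; it is advertised upstream as
/// `UpstreamAdapter::Local`.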
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PythonTorchDatasetConfig {
pub root: PathBuf,
pub dataset_id: DatasetId,
pub dataset_view_id: DatasetViewId,
pub source_uri: String,
pub format: String,
pub manifest_hash: ContentId,
pub preprocessing_hash: ContentId,
pub tokenizer_hash: Option<ContentId>,
pub sizing: DatasetSizing,
pub planner: MicroShardPlannerConfig,
pub microshards_per_batch: usize,
pub metadata: BTreeMap<String, String>,
}
impl PythonTorchDatasetConfig {
pub fn registration(&self) -> DatasetRegistration {
DatasetRegistration {
manifest: DatasetManifest {
dataset_id: self.dataset_id.clone(),
source_uri: self.source_uri.clone(),
format: self.format.clone(),
manifest_hash: self.manifest_hash.clone(),
metadata: self.metadata.clone(),
},
view: DatasetView {
dataset_view_id: self.dataset_view_id.clone(),
dataset_id: self.dataset_id.clone(),
preprocessing_hash: self.preprocessing_hash.clone(),
tokenizer_hash: self.tokenizer_hash.clone(),
manifest_hash: self.manifest_hash.clone(),
metadata: self.metadata.clone(),
},
upstream: UpstreamAdapter::Local {
root: self.root.display().to_string(),
},
}
}
pub fn plan(&self) -> anyhow::Result<MicroShardPlan> {
let registration = self.registration();
Ok(MicroShardPlanner::new(self.planner.clone())?
            .plan(&registration.view, self.sizing.clone())?)
}
}
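/// Top-level configuration for a Python/Torch workload. `new` picks
/// conservative defaults (fp32 `python-torch-safetensors` artifacts, 256 KiB
/// chunks, no patch support, DiLoCo disabled); the `with_*` builders opt in to
/// more.
///
/// A sketch, assuming `runtime`, `dataset`, `workload`, and `schema_hash` are
/// built as elsewhere in this module; the exclude glob is illustrative:
///
/// ```ignore
/// let config = PythonTorchWorkloadConfig::new(runtime, dataset, workload, schema_hash)?
///     .with_diloco_in_process()
///     .with_state_dict_filter(
///         PythonStateDictFilterConfig::all().with_exclude_glob("*.rotary_emb.*"),
///     );
/// config.validate()?;
/// ```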
#[derive(Clone, Debug)]
pub struct PythonTorchWorkloadConfig {
pub runtime: PythonTorchRuntimeConfig,
pub dataset: PythonTorchDatasetConfig,
pub supported_workload: SupportedWorkload,
pub model_schema_hash: ContentId,
pub artifact_record_format: String,
pub artifact_precision: Precision,
pub artifact_chunking: ChunkingScheme,
pub patch_support: PatchSupport,
pub diloco: PythonDiLoCoConfig,
pub state_dict_filter: PythonStateDictFilterConfig,
}
impl PythonTorchWorkloadConfig {
pub fn new(
runtime: PythonTorchRuntimeConfig,
dataset: PythonTorchDatasetConfig,
supported_workload: SupportedWorkload,
model_schema_hash: ContentId,
) -> anyhow::Result<Self> {
Ok(Self {
runtime,
dataset,
supported_workload,
model_schema_hash,
artifact_record_format: "python-torch-safetensors".to_owned(),
artifact_precision: Precision::Fp32,
artifact_chunking: ChunkingScheme::new(256 * 1024)?,
patch_support: PatchSupport {
hot: false,
warm: false,
cold: false,
},
diloco: PythonDiLoCoConfig::disabled(),
state_dict_filter: PythonStateDictFilterConfig::default(),
})
}
pub fn with_diloco_in_process(mut self) -> Self {
self.diloco = PythonDiLoCoConfig::in_process_worker();
self
}
pub fn with_diloco_checkpoint_command(
mut self,
command: PythonCheckpointCommandConfig,
) -> Self {
self.diloco = PythonDiLoCoConfig::checkpoint_command(command);
self
}
pub fn with_state_dict_filter(mut self, filter: PythonStateDictFilterConfig) -> Self {
self.state_dict_filter = filter;
self
}
pub fn validate(&self) -> anyhow::Result<()> {
        let (module, attr) = self
            .runtime
            .workload_factory
            .split_once(':')
            .context("python workload_factory must use module:attr form")?;
        ensure!(
            !module.trim().is_empty() && !attr.trim().is_empty(),
            "python workload_factory must include non-empty module and attr"
        );
self.diloco.validate()?;
self.state_dict_filter.validate()?;
Ok(())
}
pub fn validate_for_training_protocol(
&self,
protocol: &TrainingProtocol,
) -> anyhow::Result<()> {
protocol
.validate()
.context("validate training protocol semantics")?;
self.validate()?;
if matches!(protocol, TrainingProtocol::DiLoCo(_))
&& matches!(self.diloco.backend, PythonDiLoCoBackend::Disabled)
{
anyhow::bail!("Python DiLoCo backend is disabled but the training protocol is DiLoCo");
}
Ok(())
}
}
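/// Glob patterns over Python `state_dict` keys, forwarded verbatim to the
/// worker to select which tensors participate in parameter packs. An empty
/// filter (`all()`) selects everything; the matching itself happens on the
/// Python side, and this type only validates that patterns are non-empty and
/// NUL-free.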
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct PythonStateDictFilterConfig {
#[serde(default)]
pub include_globs: Vec<String>,
#[serde(default)]
pub exclude_globs: Vec<String>,
}
impl PythonStateDictFilterConfig {
pub fn all() -> Self {
Self::default()
}
pub fn with_include_glob(mut self, pattern: impl Into<String>) -> Self {
self.include_globs.push(pattern.into());
self
}
pub fn with_exclude_glob(mut self, pattern: impl Into<String>) -> Self {
self.exclude_globs.push(pattern.into());
self
}
pub fn validate(&self) -> anyhow::Result<()> {
validate_globs("include_globs", &self.include_globs)?;
validate_globs("exclude_globs", &self.exclude_globs)?;
Ok(())
}
}
fn validate_globs(label: &str, patterns: &[String]) -> anyhow::Result<()> {
for pattern in patterns {
ensure!(
!pattern.trim().is_empty(),
"state_dict_filter.{label} contains an empty pattern"
);
ensure!(
!pattern.contains('\0'),
"state_dict_filter.{label} pattern contains a NUL byte"
);
}
Ok(())
}
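/// Snapshot returned by `sanity_check_training_protocol`: the checked protocol
/// together with the effective DiLoCo backend, state-dict filter, and the
/// worker's parameter-pack plan.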
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PythonDeploymentSanityReport {
pub training_protocol: TrainingProtocol,
pub diloco_backend: PythonDiLoCoBackend,
pub state_dict_filter: PythonStateDictFilterConfig,
pub parameter_pack_plan: PythonParameterPackPlanResponse,
}
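/// DiLoCo bridge settings. `require_exact_steps` (enabled by every
/// constructor) rejects inner-loop responses whose completed step count
/// differs from the requested count.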
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct PythonDiLoCoConfig {
pub backend: PythonDiLoCoBackend,
pub require_exact_steps: bool,
}
impl PythonDiLoCoConfig {
pub fn disabled() -> Self {
Self {
backend: PythonDiLoCoBackend::Disabled,
require_exact_steps: true,
}
}
pub fn in_process_worker() -> Self {
Self {
backend: PythonDiLoCoBackend::InProcessWorker,
require_exact_steps: true,
}
}
pub fn checkpoint_command(command: PythonCheckpointCommandConfig) -> Self {
Self {
backend: PythonDiLoCoBackend::CheckpointCommand(command),
require_exact_steps: true,
}
}
pub fn validate(&self) -> anyhow::Result<()> {
if let PythonDiLoCoBackend::CheckpointCommand(command) = &self.backend {
command.validate()?;
}
Ok(())
}
}
impl Default for PythonDiLoCoConfig {
fn default() -> Self {
Self::disabled()
}
}
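/// Where DiLoCo inner loops execute: nowhere (`Disabled`), inside the
/// long-lived Python worker (`InProcessWorker`), or in a freshly spawned
/// external process driven by job/result manifests (`CheckpointCommand`).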
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum PythonDiLoCoBackend {
Disabled,
InProcessWorker,
CheckpointCommand(PythonCheckpointCommandConfig),
}
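/// External command used by the `CheckpointCommand` DiLoCo backend. Arguments
/// may contain the literal placeholders `{job_manifest}` and
/// `{result_manifest}`, which are substituted with the staged manifest paths
/// before spawning; the same paths are also exported as
/// `BURN_P2P_DILOCO_JOB_MANIFEST` and `BURN_P2P_DILOCO_RESULT_MANIFEST`, and
/// `module_search_roots` are appended to the worker `PYTHONPATH`.
///
/// A sketch; the trainer script name is illustrative:
///
/// ```ignore
/// let command = PythonCheckpointCommandConfig::new("python3")
///     .with_arg("diloco_trainer.py")
///     .with_arg("--job={job_manifest}")
///     .with_arg("--result={result_manifest}");
/// ```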
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct PythonCheckpointCommandConfig {
pub program: PathBuf,
pub args: Vec<String>,
pub env: BTreeMap<String, String>,
pub module_search_roots: Vec<PathBuf>,
}
impl PythonCheckpointCommandConfig {
pub fn new(program: impl Into<PathBuf>) -> Self {
Self {
program: program.into(),
args: Vec::new(),
env: BTreeMap::new(),
module_search_roots: Vec::new(),
}
}
pub fn with_arg(mut self, arg: impl Into<String>) -> Self {
self.args.push(arg.into());
self
}
pub fn with_module_search_root(mut self, root: impl Into<PathBuf>) -> Self {
self.module_search_roots.push(root.into());
self
}
pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.env.insert(key.into(), value.into());
self
}
pub fn validate(&self) -> anyhow::Result<()> {
ensure!(
!self.program.as_os_str().is_empty(),
"DiLoCo checkpoint command program must not be empty"
);
for arg in &self.args {
ensure!(
!arg.contains('\0'),
"DiLoCo checkpoint command args must not contain NUL bytes"
);
}
for (key, value) in &self.env {
ensure!(
!key.trim().is_empty() && !key.contains('\0') && !value.contains('\0'),
"DiLoCo checkpoint command env keys and values must be non-empty strings without NUL bytes"
);
}
Ok(())
}
}
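/// Serializable batch reference handed across the JSON boundary to Python.
/// `CachedMicroshardGroup` points at locally cached shard files, while
/// `MicroEpoch` carries a whole lease plus a pipeline-specific JSON payload.
/// Serialized with an internal `kind` tag in snake_case.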
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum PythonBatchRef {
CachedMicroshardGroup {
shard_paths: Vec<PathBuf>,
microshard_ids: Vec<String>,
ordinals: Vec<u32>,
bytes_len: u64,
},
MicroEpoch {
lease_id: String,
microshard_ids: Vec<String>,
ordinals: Vec<u32>,
bytes_len: u64,
pipeline_kind: LeaseDataPipelineKind,
payload: Value,
},
}
impl PythonBatchRef {
pub fn cached_microshard_group(group: &[CachedMicroShard]) -> Self {
Self::CachedMicroshardGroup {
shard_paths: group.iter().map(|entry| entry.path.clone()).collect(),
microshard_ids: group
.iter()
.map(|entry| entry.microshard.microshard_id.as_str().to_owned())
.collect(),
ordinals: group.iter().map(|entry| entry.microshard.ordinal).collect(),
bytes_len: group.iter().map(|entry| entry.bytes_len).sum(),
}
}
pub fn micro_epoch(
lease: &AssignmentLease,
cached_microshards: &[CachedMicroShard],
pipeline_kind: LeaseDataPipelineKind,
payload: Value,
) -> Self {
Self::MicroEpoch {
lease_id: lease.lease_id.as_str().to_owned(),
microshard_ids: lease
.microshards
.iter()
.map(|microshard_id| microshard_id.as_str().to_owned())
.collect(),
ordinals: cached_microshards
.iter()
.map(|entry| entry.microshard.ordinal)
.collect(),
bytes_len: cached_microshards.iter().map(|entry| entry.bytes_len).sum(),
pipeline_kind,
payload,
}
}
}
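/// `manifest.json` sidecar of an on-disk parameter pack. The pack directory
/// holds this manifest next to a raw little-endian f32 file (named by
/// `values_f32_le`, written as `values.f32le`); the file's byte length must be
/// exactly `parameter_count * 4`.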
#[derive(Clone, Debug, Serialize, Deserialize)]
struct PythonParameterPackManifest {
format: String,
model_schema_hash: String,
layout_hash: String,
parameter_count: usize,
values_f32_le: String,
}
#[derive(Clone, Debug, Serialize)]
struct PythonDiLoCoCheckpointJob<'a> {
protocol_version: u32,
job_id: String,
model_schema_hash: String,
base_parameter_pack_path: String,
output_parameter_pack_path: String,
result_manifest_path: String,
batches: &'a [PythonBatchRef],
num_inner_steps: u32,
require_exact_steps: bool,
state_dict_filter: &'a PythonStateDictFilterConfig,
    #[serde(skip_serializing_if = "Option::is_none")]
inner_optimizer_state_path: Option<String>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
struct PythonDiLoCoCheckpointResult {
steps_completed: u32,
#[serde(default)]
metrics: BTreeMap<String, MetricValue>,
#[serde(default, skip_serializing_if = "Option::is_none")]
local_parameter_pack_path: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
inner_optimizer_state_path: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
inner_optimizer_state_encoding: Option<String>,
}
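/// RAII handle to a model held inside the Python worker; dropping it asks the
/// worker to release the model.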
#[derive(Debug)]
pub struct PythonModelHandle {
id: String,
client: PythonWorkerClient,
}
impl PythonModelHandle {
fn new(id: String, client: PythonWorkerClient) -> Self {
Self { id, client }
}
fn id(&self) -> &str {
&self.id
}
}
impl Drop for PythonModelHandle {
fn drop(&mut self) {
self.client.release_model(&self.id);
}
}
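/// Rust-side facade over a spawned Python/Torch worker. Construction performs
/// the hello handshake (protocol versions 1 and 2 are accepted), verifies the
/// `diloco_checkpoint_job` capability when in-process DiLoCo is configured,
/// and probes the worker for its runtime device and capability estimate.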
#[derive(Clone, Debug)]
pub struct PythonTorchProject {
client: PythonWorkerClient,
config: PythonTorchWorkloadConfig,
data_pipeline: LeaseDataPipeline<String, PythonBatchRef>,
workload_name: String,
runtime_device: String,
capability: CapabilityEstimate,
}
impl PythonTorchProject {
pub fn new(config: PythonTorchWorkloadConfig) -> anyhow::Result<Self> {
let data_pipeline = Self::sharded_data_pipeline(&config.dataset);
Self::new_with_data_pipeline(config, data_pipeline)
}
pub fn new_with_data_pipeline(
config: PythonTorchWorkloadConfig,
data_pipeline: LeaseDataPipeline<String, PythonBatchRef>,
) -> anyhow::Result<Self> {
config.validate()?;
let client = PythonWorkerClient::spawn(&config.runtime)?;
let hello = client.hello()?;
        if !(1..=2).contains(&hello.protocol_version) {
            anyhow::bail!(
                "python worker protocol mismatch: expected 1 or 2, got {}",
                hello.protocol_version
            );
        }
if matches!(&config.diloco.backend, PythonDiLoCoBackend::InProcessWorker)
&& !hello
.capabilities
.iter()
.any(|capability| capability == "diloco_checkpoint_job")
{
anyhow::bail!(
"python worker does not advertise the diloco_checkpoint_job capability required by in-process DiLoCo"
);
}
let probe = client.capability_probe()?;
Ok(Self {
client,
config,
data_pipeline,
workload_name: hello.workload_name,
runtime_device: probe.runtime_device,
capability: probe.capability,
})
}
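    /// Preflight for a training protocol: validates the local configuration,
    /// probes the worker's parameter-pack plan, and rejects DiLoCo protocols
    /// whose filter selects zero floating-point parameters unless the worker
    /// declares custom parameter-pack hooks.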
pub fn sanity_check_training_protocol(
&self,
protocol: &TrainingProtocol,
) -> anyhow::Result<PythonDeploymentSanityReport> {
self.config.validate_for_training_protocol(protocol)?;
let parameter_pack_plan = self
.client
.parameter_pack_plan(&self.runtime_device, &self.config.state_dict_filter)
.context("probe Python parameter-pack plan")?;
if matches!(protocol, TrainingProtocol::DiLoCo(_))
&& !parameter_pack_plan.uses_custom_parameter_pack_hooks
&& parameter_pack_plan.parameter_count == 0
{
anyhow::bail!(
"Python DiLoCo parameter-pack filter selected zero floating-point parameters"
);
}
Ok(PythonDeploymentSanityReport {
training_protocol: protocol.clone(),
diloco_backend: self.config.diloco.backend.clone(),
state_dict_filter: self.config.state_dict_filter.clone(),
parameter_pack_plan,
})
}
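    /// Default pipeline over statically sharded local data: cached microshards
    /// are grouped into batches of `microshards_per_batch` (clamped to at
    /// least one; the final group may be short). Panics if the micro-shard
    /// plan cannot be built.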
pub fn sharded_data_pipeline(
dataset: &PythonTorchDatasetConfig,
) -> LeaseDataPipeline<String, PythonBatchRef> {
let registration = dataset.registration();
let microshard_plan = dataset.plan().expect("python dataset plan should resolve");
let group_size = dataset.microshards_per_batch.max(1);
LeaseDataPipeline::new(
LeaseDataPipelineDescriptor::new(
"python-sharded-dataset",
LeaseDataPipelineKind::ShardedStatic,
)
.with_metadata_entry("format", dataset.format.clone()),
move || Ok(registration.clone()),
move |_registration| Ok(microshard_plan.clone()),
move |_lease, cached_microshards, _device| {
let batch_count = cached_microshards.len().div_ceil(group_size).max(1);
let mut batches = Vec::with_capacity(batch_count);
for group in cached_microshards.chunks(group_size) {
batches.push(PythonBatchRef::cached_microshard_group(group));
}
Ok(batches)
},
)
}
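    /// Pipeline that turns an entire lease into a single `MicroEpoch` batch
    /// with a caller-computed JSON payload. `indexed_dataset_pipeline` and
    /// `generated_dataset_pipeline` are thin wrappers that fix the pipeline
    /// kind.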
pub fn micro_epoch_pipeline(
descriptor: LeaseDataPipelineDescriptor,
dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
+ Send
+ Sync
+ 'static,
payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
+ Send
+ Sync
+ 'static,
) -> LeaseDataPipeline<String, PythonBatchRef> {
let pipeline_kind = descriptor.kind;
LeaseDataPipeline::new(
descriptor,
dataset_registration,
microshard_plan,
move |lease, cached_microshards, _device| {
Ok(vec![PythonBatchRef::micro_epoch(
lease,
cached_microshards,
pipeline_kind,
payload(lease, cached_microshards)?,
)])
},
)
}
pub fn indexed_dataset_pipeline(
pipeline_name: impl Into<String>,
dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
+ Send
+ Sync
+ 'static,
payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
+ Send
+ Sync
+ 'static,
) -> LeaseDataPipeline<String, PythonBatchRef> {
Self::micro_epoch_pipeline(
LeaseDataPipelineDescriptor::new(pipeline_name, LeaseDataPipelineKind::IndexedDataset),
dataset_registration,
microshard_plan,
payload,
)
}
pub fn generated_dataset_pipeline(
pipeline_name: impl Into<String>,
dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
+ Send
+ Sync
+ 'static,
payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
+ Send
+ Sync
+ 'static,
) -> LeaseDataPipeline<String, PythonBatchRef> {
Self::micro_epoch_pipeline(
LeaseDataPipelineDescriptor::new(
pipeline_name,
LeaseDataPipelineKind::GeneratedDataset,
),
dataset_registration,
microshard_plan,
payload,
)
}
pub fn probe_capability(&self) -> &CapabilityEstimate {
&self.capability
}
pub fn runtime_device_name(&self) -> &str {
&self.runtime_device
}
pub fn workload_name(&self) -> &str {
&self.workload_name
}
pub fn data_pipeline_descriptor(&self) -> &LeaseDataPipelineDescriptor {
self.data_pipeline.descriptor()
}
pub fn data_pipeline_kind(&self) -> LeaseDataPipelineKind {
self.data_pipeline.kind()
}
pub fn data_pipeline_registration(&self) -> anyhow::Result<DatasetRegistration> {
self.data_pipeline.dataset_registration()
}
pub fn local_upstream_root(&self) -> anyhow::Result<Option<PathBuf>> {
local_upstream_root_for_pipeline(&self.data_pipeline)
}
    pub fn configured_shard_root(&self) -> &Path {
&self.config.dataset.root
}
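    /// Reads a parameter-pack directory (see `PythonParameterPackManifest`)
    /// back into a `FlattenedTensorPack`, checking the format tag and the
    /// values file's byte length.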
fn read_parameter_pack_dir(path: &Path) -> anyhow::Result<FlattenedTensorPack> {
let manifest_path = path.join("manifest.json");
let manifest: PythonParameterPackManifest =
serde_json::from_slice(&fs::read(&manifest_path).with_context(|| {
format!("read parameter-pack manifest {}", manifest_path.display())
})?)
.with_context(|| {
format!("decode parameter-pack manifest {}", manifest_path.display())
})?;
        ensure!(
            manifest.format == PYTHON_PARAMETER_PACK_FORMAT,
            "unsupported Python parameter-pack format {}, expected {PYTHON_PARAMETER_PACK_FORMAT}",
            manifest.format
        );
let values_path = path.join(&manifest.values_f32_le);
let value_bytes = fs::read(&values_path)
.with_context(|| format!("read parameter-pack values {}", values_path.display()))?;
ensure!(
value_bytes.len() == manifest.parameter_count * std::mem::size_of::<f32>(),
"parameter-pack byte length {} does not match parameter count {}",
value_bytes.len(),
manifest.parameter_count
);
let mut values = Vec::with_capacity(manifest.parameter_count);
for chunk in value_bytes.chunks_exact(std::mem::size_of::<f32>()) {
values.push(f32::from_le_bytes(
chunk.try_into().expect("chunk has 4 bytes"),
));
}
Ok(FlattenedTensorPack::new(
ContentId::new(manifest.model_schema_hash),
ContentId::new(manifest.layout_hash),
values,
))
}
fn write_parameter_pack_dir(path: &Path, pack: &FlattenedTensorPack) -> anyhow::Result<()> {
fs::create_dir_all(path)
.with_context(|| format!("create parameter-pack dir {}", path.display()))?;
let values_path = path.join("values.f32le");
let mut value_bytes = Vec::with_capacity(pack.values.len() * std::mem::size_of::<f32>());
for value in &pack.values {
value_bytes.extend_from_slice(&value.to_le_bytes());
}
fs::write(&values_path, value_bytes)
.with_context(|| format!("write parameter-pack values {}", values_path.display()))?;
let manifest = PythonParameterPackManifest {
format: PYTHON_PARAMETER_PACK_FORMAT.to_owned(),
model_schema_hash: pack.model_schema_hash.as_str().to_owned(),
layout_hash: pack.layout_hash.as_str().to_owned(),
parameter_count: pack.parameter_count(),
values_f32_le: "values.f32le".into(),
};
fs::write(
path.join("manifest.json"),
serde_json::to_vec_pretty(&manifest)?,
)
.with_context(|| format!("write parameter-pack manifest {}", path.display()))?;
Ok(())
}
fn write_state_blob_file(path: &Path, state: &StateBlob) -> anyhow::Result<()> {
fs::write(path, &state.bytes)
.with_context(|| format!("write optimizer state blob {}", path.display()))
}
fn read_state_blob_file(path: &Path, encoding: impl Into<String>) -> anyhow::Result<StateBlob> {
StateBlob::try_new(
encoding,
fs::read(path)
.with_context(|| format!("read optimizer state blob {}", path.display()))?,
)
.map_err(anyhow::Error::from)
}
fn validate_diloco_steps(
&self,
requested: u32,
response: &PythonDiLoCoInnerLoopResponse,
) -> Result<(), TrainError> {
if self.config.diloco.require_exact_steps && response.steps_completed != requested {
return Err(TrainError::new(format!(
"Python DiLoCo inner loop completed {} step(s), requested {}",
response.steps_completed, requested
)));
}
Ok(())
}
fn run_checkpoint_command(
&self,
command_config: &PythonCheckpointCommandConfig,
job_manifest_path: &Path,
result_manifest_path: &Path,
) -> anyhow::Result<()> {
let runtime_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("python");
let mut pythonpath_entries = vec![runtime_root];
pythonpath_entries.extend(self.config.runtime.module_search_roots.iter().cloned());
pythonpath_entries.extend(command_config.module_search_roots.iter().cloned());
let pythonpath = worker::join_pythonpath_for_command(&pythonpath_entries)?;
let mut command = Command::new(&command_config.program);
for arg in &command_config.args {
command.arg(
arg.replace(
"{job_manifest}",
job_manifest_path.to_string_lossy().as_ref(),
)
.replace(
"{result_manifest}",
result_manifest_path.to_string_lossy().as_ref(),
),
);
}
command
.env("BURN_P2P_DILOCO_JOB_MANIFEST", job_manifest_path)
.env("BURN_P2P_DILOCO_RESULT_MANIFEST", result_manifest_path);
if let Some(path) = pythonpath {
command.env("PYTHONPATH", path);
}
for (key, value) in &self.config.runtime.env {
command.env(key, value);
}
for (key, value) in &command_config.env {
command.env(key, value);
}
let status = command.status().with_context(|| {
format!(
"spawn Python DiLoCo checkpoint command {:?}",
command_config.program
)
})?;
ensure!(
status.success(),
"Python DiLoCo checkpoint command exited with {status}"
);
Ok(())
}
}
impl P2pWorkload for PythonTorchProject {
type Device = String;
type Model = PythonModelHandle;
type Batch = PythonBatchRef;
type WindowStats = BTreeMap<String, MetricValue>;
fn init_model(&self, device: &Self::Device) -> Self::Model {
let model_id = self
.client
.init_model(device)
.expect("python worker should initialize a model");
PythonModelHandle::new(model_id, self.client.clone())
}
fn benchmark(&self, _model: &Self::Model, _device: &Self::Device) -> CapabilityEstimate {
self.capability.clone()
}
fn train_window(
&self,
ctx: &mut WindowCtx<Self::Device, Self::Model, Self::Batch>,
) -> Result<WindowReport<Self::WindowStats>, TrainError> {
let mut metrics = self
.client
.train_window(ctx.model.id(), &ctx.batches)
.map_err(|error| TrainError::new(error.to_string()))?;
metrics.insert(
"batch_count".into(),
MetricValue::Integer(ctx.batches.len() as i64),
);
let examples_processed = ctx
.cached_microshards
.iter()
.map(|cached| cached.microshard.estimated_examples)
.sum::<u64>();
let tokens_processed = ctx
.cached_microshards
.iter()
.map(|cached| cached.microshard.estimated_tokens)
.sum::<u64>();
if examples_processed > 0 {
metrics.insert(
"examples_processed".into(),
MetricValue::Integer(examples_processed as i64),
);
}
if tokens_processed > 0 {
metrics.insert(
"tokens_processed".into(),
MetricValue::Integer(tokens_processed as i64),
);
}
if !ctx.cached_microshards.is_empty() {
metrics.insert(
"microshard_count".into(),
MetricValue::Integer(ctx.cached_microshards.len() as i64),
);
}
Ok(WindowReport {
contribution: None,
stats: metrics,
completed_at: Utc::now(),
})
}
fn evaluate(&self, model: &Self::Model, split: EvalSplit) -> MetricReport {
let metrics = self
.client
.evaluate(model.id(), split)
.unwrap_or_else(|error| {
BTreeMap::from([("python_error".into(), MetricValue::Text(error.to_string()))])
});
MetricReport {
metrics,
captured_at: Utc::now(),
}
}
fn apply_patch(&mut self, patch: &RuntimePatch) -> PatchOutcome {
self.client
.apply_patch(patch)
.unwrap_or_else(|error| PatchOutcome::Rejected(error.to_string()))
}
fn supported_patch_classes(&self) -> PatchSupport {
self.config.patch_support
}
fn runtime_device(&self) -> Self::Device {
self.runtime_device.clone()
}
fn dataset_registration(&self) -> anyhow::Result<DatasetRegistration> {
self.data_pipeline.dataset_registration()
}
    fn microshard_plan(
        &self,
        registration: &DatasetRegistration,
    ) -> anyhow::Result<MicroShardPlan> {
        self.data_pipeline.microshard_plan(registration)
    }
fn load_batches(
&self,
lease: &AssignmentLease,
cached_microshards: &[CachedMicroShard],
) -> anyhow::Result<Vec<Self::Batch>> {
self.data_pipeline
.load_batches(lease, cached_microshards, &self.runtime_device)
}
fn load_model_artifact(
&self,
model: Self::Model,
descriptor: &ArtifactDescriptor,
store: &FsArtifactStore,
_device: &Self::Device,
) -> anyhow::Result<Self::Model> {
let staged_dir = tempfile::Builder::new()
.prefix("burn-p2p-python-load-artifact")
.tempdir()?;
let staged_path = staged_dir.path().join("artifact.safetensors");
store.materialize_artifact_file(descriptor, &staged_path)?;
self.client
.load_model_artifact_path(model.id(), &staged_path)
.context("load python model artifact into worker")?;
Ok(model)
}
fn materialize_model_artifact(
&self,
model: &Self::Model,
artifact_kind: ArtifactKind,
head_id: HeadId,
base_head_id: Option<HeadId>,
store: &FsArtifactStore,
) -> anyhow::Result<ArtifactDescriptor> {
let staged_dir = tempfile::Builder::new()
.prefix("burn-p2p-python-materialized-artifact")
.tempdir()?;
let staged_path = staged_dir.path().join("artifact.safetensors");
self.client
.materialize_model_artifact_path(model.id(), &staged_path)
.context("materialize python model artifact")?;
let mut spec = ArtifactBuildSpec::new(
artifact_kind,
self.config.artifact_precision.clone(),
self.config.model_schema_hash.clone(),
self.config.artifact_record_format.clone(),
)
.with_head(head_id);
if let Some(base_head_id) = base_head_id {
spec = spec.with_base_head(base_head_id);
}
store
.store_artifact_file(&spec, &staged_path, self.config.artifact_chunking)
.map_err(Into::into)
}
fn contribution_metrics(
&self,
report: &WindowReport<Self::WindowStats>,
) -> BTreeMap<String, MetricValue> {
report.stats.clone()
}
fn contribution_weight(&self, report: &WindowReport<Self::WindowStats>) -> f64 {
standard_contribution_weight(&report.stats).unwrap_or(1.0)
}
fn reconcile_canonical_model(
&self,
local_model: &Self::Model,
canonical_model: Self::Model,
strategy: TrainerCanonicalReconcileStrategy,
) -> anyhow::Result<Self::Model> {
let returned_id = self.client.reconcile_canonical_model(
local_model.id(),
canonical_model.id(),
strategy,
)?;
debug_assert_eq!(returned_id, canonical_model.id());
Ok(canonical_model)
}
fn merge_candidate_models(
&self,
base_model: &Self::Model,
candidates: &[MergeModelCandidate<'_, Self::Model>],
policy: MergePolicy,
) -> anyhow::Result<Option<Self::Model>> {
let candidate_refs = candidates
.iter()
.map(|candidate| PythonMergeCandidateRef {
peer_id: candidate.peer_id.as_str(),
head_id: candidate.head_id.as_str(),
artifact_id: candidate.artifact_id.as_str(),
model_id: candidate.model.id(),
sample_weight: candidate.sample_weight,
quality_weight: candidate.quality_weight,
})
.collect::<Vec<_>>();
let merged =
self.client
.merge_candidate_models(base_model.id(), &candidate_refs, policy)?;
Ok(merged.map(|model_id| PythonModelHandle::new(model_id, self.client.clone())))
}
fn apply_single_root_ema(
&self,
base_model: &Self::Model,
merged_model: Self::Model,
policy: MergePolicy,
) -> anyhow::Result<Self::Model> {
let returned_id =
self.client
.apply_single_root_ema(base_model.id(), merged_model.id(), policy)?;
debug_assert_eq!(returned_id, merged_model.id());
Ok(merged_model)
}
fn supported_workload(&self) -> SupportedWorkload {
self.config.supported_workload.clone()
}
fn model_schema_hash(&self) -> ContentId {
self.config.model_schema_hash.clone()
}
}
impl DiLoCoWorkload for PythonTorchProject {
fn export_parameter_pack(&self, model: &Self::Model) -> anyhow::Result<FlattenedTensorPack> {
let staged_dir = tempfile::Builder::new()
.prefix("burn-p2p-python-export-pack")
.tempdir()?;
let pack_path = staged_dir.path().join("parameters");
self.client
.export_parameter_pack_path(
model.id(),
&pack_path,
&self.config.model_schema_hash,
&self.config.state_dict_filter,
)
.context("export Python model parameter pack")?;
let pack = Self::read_parameter_pack_dir(&pack_path)?;
ensure!(
pack.model_schema_hash == self.config.model_schema_hash,
"Python exported model schema {}, expected {}",
pack.model_schema_hash.as_str(),
self.config.model_schema_hash.as_str()
);
Ok(pack)
}
fn import_parameter_pack(
&self,
device: &Self::Device,
pack: &FlattenedTensorPack,
) -> anyhow::Result<Self::Model> {
ensure!(
pack.model_schema_hash == self.config.model_schema_hash,
"cannot import Python parameter pack for schema {}, expected {}",
pack.model_schema_hash.as_str(),
self.config.model_schema_hash.as_str()
);
let staged_dir = tempfile::Builder::new()
.prefix("burn-p2p-python-import-pack")
.tempdir()?;
let pack_path = staged_dir.path().join("parameters");
Self::write_parameter_pack_dir(&pack_path, pack)?;
let model_id = self
.client
.import_parameter_pack_path(device, &pack_path, &self.config.state_dict_filter)
.context("import Python model parameter pack")?;
Ok(PythonModelHandle::new(model_id, self.client.clone()))
}
fn run_inner_steps(
&self,
model: &Self::Model,
batches: &[Self::Batch],
num_inner_steps: u32,
inner_optimizer_state: Option<&StateBlob>,
) -> Result<DiLoCoInnerLoopReport, TrainError> {
if num_inner_steps > 0 && batches.is_empty() {
return Err(TrainError::new(
"Python DiLoCo inner loop requires at least one batch",
));
}
match &self.config.diloco.backend {
PythonDiLoCoBackend::Disabled => Err(TrainError::new(
"Python DiLoCo bridge is disabled for this workload",
)),
PythonDiLoCoBackend::InProcessWorker => {
self.run_worker_inner_loop(model, batches, num_inner_steps, inner_optimizer_state)
}
PythonDiLoCoBackend::CheckpointCommand(command) => self.run_command_inner_loop(
command,
model,
batches,
num_inner_steps,
inner_optimizer_state,
),
}
}
}
impl PythonTorchProject {
fn run_worker_inner_loop(
&self,
model: &PythonModelHandle,
batches: &[PythonBatchRef],
num_inner_steps: u32,
inner_optimizer_state: Option<&StateBlob>,
) -> Result<DiLoCoInnerLoopReport, TrainError> {
let staged_dir = tempfile::Builder::new()
.prefix("burn-p2p-python-diloco-worker")
.tempdir()
.map_err(|error| TrainError::new(error.to_string()))?;
let output_pack_path = staged_dir.path().join("local-parameters");
let inner_state_path = if let Some(state) = inner_optimizer_state {
let path = staged_dir.path().join("inner-state-input.blob");
Self::write_state_blob_file(&path, state)
.map_err(|error| TrainError::new(error.to_string()))?;
Some(path)
} else {
None
};
let response = self
.client
.run_inner_loop_path(PythonDiLoCoInnerLoopPathRequest {
model_id: model.id(),
batches,
num_inner_steps,
inner_optimizer_state_path: inner_state_path.as_deref(),
output_parameter_pack_path: &output_pack_path,
model_schema_hash: &self.config.model_schema_hash,
require_exact_steps: self.config.diloco.require_exact_steps,
state_dict_filter: &self.config.state_dict_filter,
})
.map_err(|error| TrainError::new(error.to_string()))?;
self.validate_diloco_steps(num_inner_steps, &response)?;
let local_parameters = Self::read_parameter_pack_dir(&output_pack_path)
.map_err(|error| TrainError::new(error.to_string()))?;
let inner_optimizer_state = response
.inner_optimizer_state_path
.as_ref()
.map(|path| {
Self::read_state_blob_file(
path,
response
.inner_optimizer_state_encoding
.clone()
.unwrap_or_else(|| "application/octet-stream".into()),
)
})
.transpose()
.map_err(|error| TrainError::new(error.to_string()))?;
Ok(DiLoCoInnerLoopReport {
local_parameters,
inner_optimizer_state,
steps_completed: response.steps_completed,
metrics: response.metrics,
})
}
fn run_command_inner_loop(
&self,
command_config: &PythonCheckpointCommandConfig,
model: &PythonModelHandle,
batches: &[PythonBatchRef],
num_inner_steps: u32,
inner_optimizer_state: Option<&StateBlob>,
) -> Result<DiLoCoInnerLoopReport, TrainError> {
let staged_dir = tempfile::Builder::new()
.prefix("burn-p2p-python-diloco-command")
.tempdir()
.map_err(|error| TrainError::new(error.to_string()))?;
let base_pack_path = staged_dir.path().join("base-parameters");
let output_pack_path = staged_dir.path().join("local-parameters");
let result_manifest_path = staged_dir.path().join("result.json");
let job_manifest_path = staged_dir.path().join("job.json");
self.client
.export_parameter_pack_path(
model.id(),
&base_pack_path,
&self.config.model_schema_hash,
&self.config.state_dict_filter,
)
.map_err(|error| TrainError::new(error.to_string()))?;
let inner_state_path = if let Some(state) = inner_optimizer_state {
let path = staged_dir.path().join("inner-state-input.blob");
Self::write_state_blob_file(&path, state)
.map_err(|error| TrainError::new(error.to_string()))?;
Some(path)
} else {
None
};
let job = PythonDiLoCoCheckpointJob {
protocol_version: 1,
job_id: format!(
"diloco-inner-{}",
Utc::now().timestamp_nanos_opt().unwrap_or(0)
),
model_schema_hash: self.config.model_schema_hash.as_str().to_owned(),
base_parameter_pack_path: base_pack_path.to_string_lossy().into_owned(),
output_parameter_pack_path: output_pack_path.to_string_lossy().into_owned(),
result_manifest_path: result_manifest_path.to_string_lossy().into_owned(),
batches,
num_inner_steps,
require_exact_steps: self.config.diloco.require_exact_steps,
state_dict_filter: &self.config.state_dict_filter,
inner_optimizer_state_path: inner_state_path
.as_ref()
.map(|path| path.to_string_lossy().into_owned()),
};
fs::write(
&job_manifest_path,
serde_json::to_vec_pretty(&job).map_err(|error| {
TrainError::new(format!("serialize Python DiLoCo command job: {error}"))
})?,
)
.map_err(|error| TrainError::new(error.to_string()))?;
self.run_checkpoint_command(command_config, &job_manifest_path, &result_manifest_path)
.map_err(|error| TrainError::new(error.to_string()))?;
let result: PythonDiLoCoCheckpointResult =
serde_json::from_slice(&fs::read(&result_manifest_path).map_err(|error| {
TrainError::new(format!(
"read Python DiLoCo command result {}: {error}",
result_manifest_path.display()
))
})?)
.map_err(|error| TrainError::new(format!("decode Python DiLoCo result: {error}")))?;
let response = PythonDiLoCoInnerLoopResponse {
steps_completed: result.steps_completed,
metrics: result.metrics,
inner_optimizer_state_path: result.inner_optimizer_state_path.map(PathBuf::from),
inner_optimizer_state_encoding: result.inner_optimizer_state_encoding,
};
self.validate_diloco_steps(num_inner_steps, &response)?;
let local_pack_path = result
.local_parameter_pack_path
.map(PathBuf::from)
.unwrap_or(output_pack_path);
let local_parameters = Self::read_parameter_pack_dir(&local_pack_path)
.map_err(|error| TrainError::new(error.to_string()))?;
let inner_optimizer_state = response
.inner_optimizer_state_path
.as_ref()
.map(|path| {
Self::read_state_blob_file(
path,
response
.inner_optimizer_state_encoding
.clone()
.unwrap_or_else(|| "application/octet-stream".into()),
)
})
.transpose()
.map_err(|error| TrainError::new(error.to_string()))?;
Ok(DiLoCoInnerLoopReport {
local_parameters,
inner_optimizer_state,
steps_completed: response.steps_completed,
metrics: response.metrics,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_workload_config() -> PythonTorchWorkloadConfig {
let runtime = PythonTorchRuntimeConfig::new(
"python3",
"demo_runtime:make_workload",
serde_json::json!({}),
);
let dataset = PythonTorchDatasetConfig {
root: PathBuf::from("/tmp/python-dataset"),
dataset_id: DatasetId::new("dataset-a"),
dataset_view_id: DatasetViewId::new("view-a"),
source_uri: "file:///tmp/python-dataset".into(),
format: "test".into(),
manifest_hash: ContentId::new("manifest-a"),
preprocessing_hash: ContentId::new("preprocess-a"),
tokenizer_hash: None,
sizing: DatasetSizing::default(),
planner: MicroShardPlannerConfig::default(),
microshards_per_batch: 1,
metadata: BTreeMap::new(),
};
let workload = burn_p2p_core::SupportedWorkloadBuilder::new(
burn_p2p_core::WorkloadId::new("python-test"),
"Python Test",
ContentId::new("program-a"),
ContentId::new("format-a"),
)
.build();
PythonTorchWorkloadConfig::new(runtime, dataset, workload, ContentId::new("schema-a"))
.expect("config")
}
#[test]
fn python_diloco_config_defaults_to_disabled_exact_steps() {
let config = PythonDiLoCoConfig::default();
assert!(matches!(config.backend, PythonDiLoCoBackend::Disabled));
assert!(config.require_exact_steps);
let command = PythonCheckpointCommandConfig::new("python3").with_arg("trainer.py");
let config = PythonDiLoCoConfig::checkpoint_command(command);
assert!(matches!(
config.backend,
PythonDiLoCoBackend::CheckpointCommand(_)
));
assert!(config.require_exact_steps);
}
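    // Companion check for PythonCheckpointCommandConfig::validate(): an empty
    // program path must be rejected before anything is spawned.
    #[test]
    fn checkpoint_command_rejects_empty_program() {
        let error = PythonCheckpointCommandConfig::new("")
            .validate()
            .expect_err("empty program should fail");
        assert!(error.to_string().contains("must not be empty"));
    }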
#[test]
fn state_dict_filter_rejects_empty_patterns() {
let filter = PythonStateDictFilterConfig::default().with_include_glob(" ");
let error = filter.validate().expect_err("empty glob should fail");
assert!(
error
.to_string()
.contains("state_dict_filter.include_globs contains an empty pattern")
);
}
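    // Companion check for the NUL-byte guard in validate_globs.
    #[test]
    fn state_dict_filter_rejects_nul_bytes() {
        let filter = PythonStateDictFilterConfig::default().with_exclude_glob("bad\0pattern");
        let error = filter.validate().expect_err("NUL byte should fail");
        assert!(error.to_string().contains("NUL byte"));
    }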
#[test]
fn workload_config_rejects_diloco_protocol_when_backend_disabled() {
let config = test_workload_config();
let protocol = TrainingProtocol::DiLoCo(burn_p2p_core::DiLoCoPolicy {
target_group_size: 1,
minimum_group_size: 1,
topology_policy: burn_p2p_core::DiLoCoTopologyPolicy {
fanout: 1,
..burn_p2p_core::DiLoCoTopologyPolicy::default()
},
..burn_p2p_core::DiLoCoPolicy::default()
});
let error = config
.validate_for_training_protocol(&protocol)
.expect_err("disabled backend should fail DiLoCo protocol");
assert!(
error
.to_string()
.contains("Python DiLoCo backend is disabled")
);
let enabled = config.with_diloco_in_process();
enabled
.validate_for_training_protocol(&protocol)
.expect("enabled in-process DiLoCo config");
}
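    // Companion check for validate(): a workload_factory without a colon must
    // be rejected before any worker is spawned.
    #[test]
    fn workload_config_rejects_factory_without_colon() {
        let mut config = test_workload_config();
        config.runtime.workload_factory = "demo_runtime".into();
        let error = config.validate().expect_err("missing colon should fail");
        assert!(error.to_string().contains("module:attr"));
    }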
#[test]
fn parameter_pack_sidecar_round_trips_flattened_values() {
let temp = tempfile::tempdir().expect("tempdir");
let pack = FlattenedTensorPack::new(
ContentId::new("schema-a"),
ContentId::new("layout-a"),
vec![0.25, -1.5, 3.0, 8.25],
);
PythonTorchProject::write_parameter_pack_dir(temp.path(), &pack).expect("write pack");
let decoded = PythonTorchProject::read_parameter_pack_dir(temp.path()).expect("read pack");
assert_eq!(decoded, pack);
}
}