use std::{collections::BTreeMap, path::PathBuf};
use anyhow::Context;
use burn_p2p_checkpoint::{ArtifactBuildSpec, ChunkingScheme, FsArtifactStore};
use burn_p2p_core::{
ArtifactDescriptor, ArtifactKind, AssignmentLease, CapabilityEstimate, ContentId, DatasetId,
DatasetManifest, DatasetView, DatasetViewId, HeadId, MergePolicy, MetricValue, Precision,
SupportedWorkload,
};
use burn_p2p_dataloader::{
CachedMicroShard, DatasetRegistration, DatasetSizing, MicroShardPlan, MicroShardPlanner,
MicroShardPlannerConfig, UpstreamAdapter,
};
use burn_p2p_experiment::{PatchSupport, RuntimePatch};
use burn_p2p_workload::{
EvalSplit, LeaseDataPipeline, LeaseDataPipelineDescriptor, LeaseDataPipelineKind,
MergeModelCandidate, MetricReport, P2pWorkload, PatchOutcome, TrainError,
TrainerCanonicalReconcileStrategy, WindowCtx, WindowReport, local_upstream_root_for_pipeline,
standard_contribution_weight,
};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use serde_json::Value;
mod worker;
use worker::{PythonMergeCandidateRef, PythonWorkerClient};
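/// Configuration for the Python interpreter process that hosts the torch
/// workload: which executable to spawn, where to look for modules, which
/// factory builds the workload, and the JSON config plus environment handed
/// to it.
///
/// A minimal construction sketch; the interpreter path, factory string, and
/// config keys below are illustrative, not defaults of this crate:
///
/// ```ignore
/// use serde_json::json;
/// let runtime = PythonTorchRuntimeConfig::new(
///     "/usr/bin/python3",
///     "my_package.workloads:build_workload",
///     json!({ "learning_rate": 3e-4 }),
/// )
/// .with_module_search_root("/opt/my_package")
/// .with_env("PYTHONUNBUFFERED", "1");
/// ```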
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PythonTorchRuntimeConfig {
pub python_executable: PathBuf,
pub module_search_roots: Vec<PathBuf>,
pub workload_factory: String,
pub workload_config: Value,
pub env: BTreeMap<String, String>,
}
impl PythonTorchRuntimeConfig {
pub fn new(
python_executable: impl Into<PathBuf>,
workload_factory: impl Into<String>,
workload_config: Value,
) -> Self {
Self {
python_executable: python_executable.into(),
module_search_roots: Vec::new(),
workload_factory: workload_factory.into(),
workload_config,
env: BTreeMap::new(),
}
}
pub fn with_module_search_root(mut self, root: impl Into<PathBuf>) -> Self {
self.module_search_roots.push(root.into());
self
}
pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.env.insert(key.into(), value.into());
self
}
}
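/// Describes the on-disk dataset the Python workload trains on, together with
/// the identity hashes and planner settings used to derive microshards.
///
/// A usage sketch, assuming `dataset` is a fully populated config:
///
/// ```ignore
/// let registration = dataset.registration();
/// let plan = dataset.plan()?;
/// assert_eq!(registration.view.dataset_id, dataset.dataset_id);
/// ```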
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PythonTorchDatasetConfig {
pub root: PathBuf,
pub dataset_id: DatasetId,
pub dataset_view_id: DatasetViewId,
pub source_uri: String,
pub format: String,
pub manifest_hash: ContentId,
pub preprocessing_hash: ContentId,
pub tokenizer_hash: Option<ContentId>,
pub sizing: DatasetSizing,
pub planner: MicroShardPlannerConfig,
pub microshards_per_batch: usize,
pub metadata: BTreeMap<String, String>,
}
impl PythonTorchDatasetConfig {
pub fn registration(&self) -> DatasetRegistration {
DatasetRegistration {
manifest: DatasetManifest {
dataset_id: self.dataset_id.clone(),
source_uri: self.source_uri.clone(),
format: self.format.clone(),
manifest_hash: self.manifest_hash.clone(),
metadata: self.metadata.clone(),
},
view: DatasetView {
dataset_view_id: self.dataset_view_id.clone(),
dataset_id: self.dataset_id.clone(),
preprocessing_hash: self.preprocessing_hash.clone(),
tokenizer_hash: self.tokenizer_hash.clone(),
manifest_hash: self.manifest_hash.clone(),
metadata: self.metadata.clone(),
},
upstream: UpstreamAdapter::Local {
root: self.root.display().to_string(),
},
}
}
pub fn plan(&self) -> anyhow::Result<MicroShardPlan> {
let registration = self.registration();
Ok(MicroShardPlanner::new(self.planner.clone())?
.plan(®istration.view, self.sizing.clone())?)
}
}
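/// Top-level workload configuration combining the Python runtime, the
/// dataset, and artifact/identity settings. `new` fills in conservative
/// defaults: safetensors records, `Precision::Fp32`, 256 KiB chunks, and no
/// patch support; adjust the public fields afterwards if needed.
///
/// A sketch, assuming `runtime`, `dataset`, `supported_workload`, and
/// `schema_hash` are built elsewhere:
///
/// ```ignore
/// let mut config = PythonTorchWorkloadConfig::new(
///     runtime, dataset, supported_workload, schema_hash,
/// )?;
/// config.patch_support.hot = true; // opt in to hot patches
/// ```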
#[derive(Clone, Debug)]
pub struct PythonTorchWorkloadConfig {
pub runtime: PythonTorchRuntimeConfig,
pub dataset: PythonTorchDatasetConfig,
pub supported_workload: SupportedWorkload,
pub model_schema_hash: ContentId,
pub artifact_record_format: String,
pub artifact_precision: Precision,
pub artifact_chunking: ChunkingScheme,
pub patch_support: PatchSupport,
}
impl PythonTorchWorkloadConfig {
pub fn new(
runtime: PythonTorchRuntimeConfig,
dataset: PythonTorchDatasetConfig,
supported_workload: SupportedWorkload,
model_schema_hash: ContentId,
) -> anyhow::Result<Self> {
Ok(Self {
runtime,
dataset,
supported_workload,
model_schema_hash,
artifact_record_format: "python-torch-safetensors".to_owned(),
artifact_precision: Precision::Fp32,
artifact_chunking: ChunkingScheme::new(256 * 1024)?,
patch_support: PatchSupport {
hot: false,
warm: false,
cold: false,
},
})
}
}
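/// Batch reference shipped to the Python worker instead of raw tensors:
/// either a fixed-size group of cached microshard files, or a single
/// "micro-epoch" covering a whole lease plus a pipeline-specific JSON payload.
///
/// The enum serializes with an internal `kind` tag in snake_case, which the
/// worker can match on; a quick shape check:
///
/// ```ignore
/// let batch = PythonBatchRef::cached_microshard_group(&cached);
/// let wire = serde_json::to_value(&batch)?;
/// assert_eq!(wire["kind"], "cached_microshard_group");
/// ```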
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum PythonBatchRef {
CachedMicroshardGroup {
shard_paths: Vec<PathBuf>,
microshard_ids: Vec<String>,
ordinals: Vec<u32>,
bytes_len: u64,
},
MicroEpoch {
lease_id: String,
microshard_ids: Vec<String>,
ordinals: Vec<u32>,
bytes_len: u64,
pipeline_kind: LeaseDataPipelineKind,
payload: Value,
},
}
impl PythonBatchRef {
pub fn cached_microshard_group(group: &[CachedMicroShard]) -> Self {
Self::CachedMicroshardGroup {
shard_paths: group.iter().map(|entry| entry.path.clone()).collect(),
microshard_ids: group
.iter()
.map(|entry| entry.microshard.microshard_id.as_str().to_owned())
.collect(),
ordinals: group.iter().map(|entry| entry.microshard.ordinal).collect(),
bytes_len: group.iter().map(|entry| entry.bytes_len).sum(),
}
}
pub fn micro_epoch(
lease: &AssignmentLease,
cached_microshards: &[CachedMicroShard],
pipeline_kind: LeaseDataPipelineKind,
payload: Value,
) -> Self {
Self::MicroEpoch {
lease_id: lease.lease_id.as_str().to_owned(),
microshard_ids: lease
.microshards
.iter()
.map(|microshard_id| microshard_id.as_str().to_owned())
.collect(),
ordinals: cached_microshards
.iter()
.map(|entry| entry.microshard.ordinal)
.collect(),
bytes_len: cached_microshards.iter().map(|entry| entry.bytes_len).sum(),
pipeline_kind,
payload,
}
}
}
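/// Opaque handle to a model living inside the Python worker. Dropping the
/// handle releases the worker-side model, so handles own their remote state
/// RAII-style; keep the handle alive for as long as the model is needed.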
#[derive(Debug)]
pub struct PythonModelHandle {
id: String,
client: PythonWorkerClient,
}
impl PythonModelHandle {
fn new(id: String, client: PythonWorkerClient) -> Self {
Self { id, client }
}
fn id(&self) -> &str {
&self.id
}
}
impl Drop for PythonModelHandle {
fn drop(&mut self) {
self.client.release_model(&self.id);
}
}
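/// A [`P2pWorkload`] implementation that proxies training, evaluation, and
/// merging to a spawned Python torch worker speaking the version-1 worker
/// protocol.
///
/// A startup sketch, assuming `config` is a valid
/// [`PythonTorchWorkloadConfig`]:
///
/// ```ignore
/// let project = PythonTorchProject::new(config)?;
/// println!(
///     "workload `{}` on device {}",
///     project.workload_name(),
///     project.runtime_device_name(),
/// );
/// ```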
#[derive(Clone, Debug)]
pub struct PythonTorchProject {
client: PythonWorkerClient,
config: PythonTorchWorkloadConfig,
data_pipeline: LeaseDataPipeline<String, PythonBatchRef>,
workload_name: String,
runtime_device: String,
capability: CapabilityEstimate,
}
impl PythonTorchProject {
pub fn new(config: PythonTorchWorkloadConfig) -> anyhow::Result<Self> {
let data_pipeline = Self::sharded_data_pipeline(&config.dataset);
Self::new_with_data_pipeline(config, data_pipeline)
}
pub fn new_with_data_pipeline(
config: PythonTorchWorkloadConfig,
data_pipeline: LeaseDataPipeline<String, PythonBatchRef>,
) -> anyhow::Result<Self> {
let client = PythonWorkerClient::spawn(&config.runtime)?;
let hello = client.hello()?;
if hello.protocol_version != 1 {
anyhow::bail!(
"python worker protocol mismatch: expected 1, got {}",
hello.protocol_version
);
}
let probe = client.capability_probe()?;
Ok(Self {
client,
config,
data_pipeline,
workload_name: hello.workload_name,
runtime_device: probe.runtime_device,
capability: probe.capability,
})
}
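/// Builds the default pipeline over a static, pre-sharded local dataset:
/// cached microshards are grouped into batches of `microshards_per_batch`
/// (clamped to at least 1), so ten shards with a group size of 4 yield
/// batches of 4, 4, and 2.
///
/// # Panics
///
/// Panics if the dataset's microshard plan cannot be resolved.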
pub fn sharded_data_pipeline(
dataset: &PythonTorchDatasetConfig,
) -> LeaseDataPipeline<String, PythonBatchRef> {
let registration = dataset.registration();
let microshard_plan = dataset.plan().expect("python dataset plan should resolve");
let group_size = dataset.microshards_per_batch.max(1);
LeaseDataPipeline::new(
LeaseDataPipelineDescriptor::new(
"python-sharded-dataset",
LeaseDataPipelineKind::ShardedStatic,
)
.with_metadata_entry("format", dataset.format.clone()),
move || Ok(registration.clone()),
move |_registration| Ok(microshard_plan.clone()),
move |_lease, cached_microshards, _device| {
let batch_count = cached_microshards.len().div_ceil(group_size).max(1);
let mut batches = Vec::with_capacity(batch_count);
for group in cached_microshards.chunks(group_size) {
batches.push(PythonBatchRef::cached_microshard_group(group));
}
Ok(batches)
},
)
}
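/// Builds a pipeline that emits a single micro-epoch batch per lease,
/// attaching a caller-supplied JSON payload. A wiring sketch, assuming
/// `registration` and `plan` are captured clones from a dataset config:
///
/// ```ignore
/// let pipeline = PythonTorchProject::micro_epoch_pipeline(
///     LeaseDataPipelineDescriptor::new("demo", LeaseDataPipelineKind::IndexedDataset),
///     move || Ok(registration.clone()),
///     move |_reg| Ok(plan.clone()),
///     |lease, _cached| Ok(serde_json::json!({ "lease": lease.lease_id.as_str() })),
/// );
/// ```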
pub fn micro_epoch_pipeline(
descriptor: LeaseDataPipelineDescriptor,
dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
+ Send
+ Sync
+ 'static,
payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
+ Send
+ Sync
+ 'static,
) -> LeaseDataPipeline<String, PythonBatchRef> {
let pipeline_kind = descriptor.kind;
LeaseDataPipeline::new(
descriptor,
dataset_registration,
microshard_plan,
move |lease, cached_microshards, _device| {
Ok(vec![PythonBatchRef::micro_epoch(
lease,
cached_microshards,
pipeline_kind,
payload(lease, cached_microshards)?,
)])
},
)
}
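/// Convenience wrapper around [`Self::micro_epoch_pipeline`] for
/// [`LeaseDataPipelineKind::IndexedDataset`] pipelines.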
pub fn indexed_dataset_pipeline(
pipeline_name: impl Into<String>,
dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
+ Send
+ Sync
+ 'static,
payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
+ Send
+ Sync
+ 'static,
) -> LeaseDataPipeline<String, PythonBatchRef> {
Self::micro_epoch_pipeline(
LeaseDataPipelineDescriptor::new(pipeline_name, LeaseDataPipelineKind::IndexedDataset),
dataset_registration,
microshard_plan,
payload,
)
}
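/// Convenience wrapper around [`Self::micro_epoch_pipeline`] for
/// [`LeaseDataPipelineKind::GeneratedDataset`] pipelines.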
pub fn generated_dataset_pipeline(
pipeline_name: impl Into<String>,
dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
+ Send
+ Sync
+ 'static,
payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
+ Send
+ Sync
+ 'static,
) -> LeaseDataPipeline<String, PythonBatchRef> {
Self::micro_epoch_pipeline(
LeaseDataPipelineDescriptor::new(
pipeline_name,
LeaseDataPipelineKind::GeneratedDataset,
),
dataset_registration,
microshard_plan,
payload,
)
}
pub fn probe_capability(&self) -> &CapabilityEstimate {
&self.capability
}
pub fn runtime_device_name(&self) -> &str {
&self.runtime_device
}
pub fn workload_name(&self) -> &str {
&self.workload_name
}
pub fn data_pipeline_descriptor(&self) -> &LeaseDataPipelineDescriptor {
self.data_pipeline.descriptor()
}
pub fn data_pipeline_kind(&self) -> LeaseDataPipelineKind {
self.data_pipeline.kind()
}
pub fn data_pipeline_registration(&self) -> anyhow::Result<DatasetRegistration> {
self.data_pipeline.dataset_registration()
}
pub fn local_upstream_root(&self) -> anyhow::Result<Option<PathBuf>> {
local_upstream_root_for_pipeline(&self.data_pipeline)
}
pub fn configured_shard_root(&self) -> &std::path::Path {
&self.config.dataset.root
}
}
impl P2pWorkload for PythonTorchProject {
type Device = String;
type Model = PythonModelHandle;
type Batch = PythonBatchRef;
type WindowStats = BTreeMap<String, MetricValue>;
fn init_model(&self, device: &Self::Device) -> Self::Model {
let model_id = self
.client
.init_model(device)
.expect("python worker should initialize a model");
PythonModelHandle::new(model_id, self.client.clone())
}
fn benchmark(&self, _model: &Self::Model, _device: &Self::Device) -> CapabilityEstimate {
self.capability.clone()
}
fn train_window(
&self,
ctx: &mut WindowCtx<Self::Device, Self::Model, Self::Batch>,
) -> Result<WindowReport<Self::WindowStats>, TrainError> {
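// The worker reports its own training metrics; the Rust side layers
// bookkeeping counters (batch, example, token, and microshard counts)
// on top before emitting the window report.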
let mut metrics = self
.client
.train_window(ctx.model.id(), &ctx.batches)
.map_err(|error| TrainError::new(error.to_string()))?;
metrics.insert(
"batch_count".into(),
MetricValue::Integer(ctx.batches.len() as i64),
);
let examples_processed = ctx
.cached_microshards
.iter()
.map(|cached| cached.microshard.estimated_examples)
.sum::<u64>();
let tokens_processed = ctx
.cached_microshards
.iter()
.map(|cached| cached.microshard.estimated_tokens)
.sum::<u64>();
if examples_processed > 0 {
metrics.insert(
"examples_processed".into(),
MetricValue::Integer(examples_processed as i64),
);
}
if tokens_processed > 0 {
metrics.insert(
"tokens_processed".into(),
MetricValue::Integer(tokens_processed as i64),
);
}
if !ctx.cached_microshards.is_empty() {
metrics.insert(
"microshard_count".into(),
MetricValue::Integer(ctx.cached_microshards.len() as i64),
);
}
Ok(WindowReport {
contribution: None,
stats: metrics,
completed_at: Utc::now(),
})
}
fn evaluate(&self, model: &Self::Model, split: EvalSplit) -> MetricReport {
let metrics = self
.client
.evaluate(model.id(), split)
.unwrap_or_else(|error| {
BTreeMap::from([("python_error".into(), MetricValue::Text(error.to_string()))])
});
MetricReport {
metrics,
captured_at: Utc::now(),
}
}
fn apply_patch(&mut self, patch: &RuntimePatch) -> PatchOutcome {
self.client
.apply_patch(patch)
.unwrap_or_else(|error| PatchOutcome::Rejected(error.to_string()))
}
fn supported_patch_classes(&self) -> PatchSupport {
self.config.patch_support
}
fn runtime_device(&self) -> Self::Device {
self.runtime_device.clone()
}
fn dataset_registration(&self) -> anyhow::Result<DatasetRegistration> {
self.data_pipeline.dataset_registration()
}
fn microshard_plan(
&self,
registration: &DatasetRegistration,
) -> anyhow::Result<MicroShardPlan> {
self.data_pipeline.microshard_plan(registration)
}
fn load_batches(
&self,
lease: &AssignmentLease,
cached_microshards: &[CachedMicroShard],
) -> anyhow::Result<Vec<Self::Batch>> {
self.data_pipeline
.load_batches(lease, cached_microshards, &self.runtime_device)
}
fn load_model_artifact(
&self,
model: Self::Model,
descriptor: &ArtifactDescriptor,
store: &FsArtifactStore,
_device: &Self::Device,
) -> anyhow::Result<Self::Model> {
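// Stage the artifact to a temporary safetensors file, then ask the worker
// to load it into the existing model; the tempdir is cleaned up on drop.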
let staged_dir = tempfile::Builder::new()
.prefix("burn-p2p-python-load-artifact")
.tempdir()?;
let staged_path = staged_dir.path().join("artifact.safetensors");
store.materialize_artifact_file(descriptor, &staged_path)?;
self.client
.load_model_artifact_path(model.id(), &staged_path)
.context("load python model artifact into worker")?;
Ok(model)
}
fn materialize_model_artifact(
&self,
model: &Self::Model,
artifact_kind: ArtifactKind,
head_id: HeadId,
base_head_id: Option<HeadId>,
store: &FsArtifactStore,
) -> anyhow::Result<ArtifactDescriptor> {
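// Mirror image of `load_model_artifact`: the worker serializes the model
// to a staged safetensors file, which is then chunked into the store.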
let staged_dir = tempfile::Builder::new()
.prefix("burn-p2p-python-materialized-artifact")
.tempdir()?;
let staged_path = staged_dir.path().join("artifact.safetensors");
self.client
.materialize_model_artifact_path(model.id(), &staged_path)
.context("materialize python model artifact")?;
let mut spec = ArtifactBuildSpec::new(
artifact_kind,
self.config.artifact_precision.clone(),
self.config.model_schema_hash.clone(),
self.config.artifact_record_format.clone(),
)
.with_head(head_id);
if let Some(base_head_id) = base_head_id {
spec = spec.with_base_head(base_head_id);
}
store
.store_artifact_file(&spec, &staged_path, self.config.artifact_chunking)
.map_err(Into::into)
}
fn contribution_metrics(
&self,
report: &WindowReport<Self::WindowStats>,
) -> BTreeMap<String, MetricValue> {
report.stats.clone()
}
fn contribution_weight(&self, report: &WindowReport<Self::WindowStats>) -> f64 {
standard_contribution_weight(&report.stats).unwrap_or(1.0)
}
fn reconcile_canonical_model(
&self,
local_model: &Self::Model,
canonical_model: Self::Model,
strategy: TrainerCanonicalReconcileStrategy,
) -> anyhow::Result<Self::Model> {
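// The worker reconciles into the canonical model in place and echoes its
// id back; the canonical handle is returned so it replaces the local one.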
let returned_id = self.client.reconcile_canonical_model(
local_model.id(),
canonical_model.id(),
strategy,
)?;
debug_assert_eq!(returned_id, canonical_model.id());
Ok(canonical_model)
}
fn merge_candidate_models(
&self,
base_model: &Self::Model,
candidates: &[MergeModelCandidate<'_, Self::Model>],
policy: MergePolicy,
) -> anyhow::Result<Option<Self::Model>> {
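// Candidates cross the IPC boundary as id references; the worker returns
// a fresh model id when it produced a merged model, or nothing when the
// merge was skipped.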
let candidate_refs = candidates
.iter()
.map(|candidate| PythonMergeCandidateRef {
peer_id: candidate.peer_id.as_str(),
head_id: candidate.head_id.as_str(),
artifact_id: candidate.artifact_id.as_str(),
model_id: candidate.model.id(),
sample_weight: candidate.sample_weight,
quality_weight: candidate.quality_weight,
})
.collect::<Vec<_>>();
let merged =
self.client
.merge_candidate_models(base_model.id(), &candidate_refs, policy)?;
Ok(merged.map(|model_id| PythonModelHandle::new(model_id, self.client.clone())))
}
fn apply_single_root_ema(
&self,
base_model: &Self::Model,
merged_model: Self::Model,
policy: MergePolicy,
) -> anyhow::Result<Self::Model> {
let returned_id =
self.client
.apply_single_root_ema(base_model.id(), merged_model.id(), policy)?;
debug_assert_eq!(returned_id, merged_model.id());
Ok(merged_model)
}
fn supported_workload(&self) -> SupportedWorkload {
self.config.supported_workload.clone()
}
fn model_schema_hash(&self) -> ContentId {
self.config.model_schema_hash.clone()
}
}