1use std::{collections::BTreeMap, path::PathBuf};
2
3use anyhow::Context;
4use burn_p2p_checkpoint::{ArtifactBuildSpec, ChunkingScheme, FsArtifactStore};
5use burn_p2p_core::{
6 ArtifactDescriptor, ArtifactKind, AssignmentLease, CapabilityEstimate, ContentId, DatasetId,
7 DatasetManifest, DatasetView, DatasetViewId, HeadId, MergePolicy, MetricValue, Precision,
8 SupportedWorkload,
9};
10use burn_p2p_dataloader::{
11 CachedMicroShard, DatasetRegistration, DatasetSizing, MicroShardPlan, MicroShardPlanner,
12 MicroShardPlannerConfig, UpstreamAdapter,
13};
14use burn_p2p_experiment::{PatchSupport, RuntimePatch};
15use burn_p2p_workload::{
16 EvalSplit, LeaseDataPipeline, LeaseDataPipelineDescriptor, LeaseDataPipelineKind,
17 MergeModelCandidate, MetricReport, P2pWorkload, PatchOutcome, TrainError,
18 TrainerCanonicalReconcileStrategy, WindowCtx, WindowReport, local_upstream_root_for_pipeline,
19 standard_contribution_weight,
20};
21use chrono::Utc;
22use serde::{Deserialize, Serialize};
23use serde_json::Value;
24
25mod worker;
26
27use worker::{PythonMergeCandidateRef, PythonWorkerClient};
28
/// Configuration for spawning the external Python worker process that hosts
/// the torch workload.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PythonTorchRuntimeConfig {
    /// Path to the Python interpreter used to launch the worker.
    pub python_executable: PathBuf,
    /// Extra directories added to the worker's module search path
    /// (see `with_module_search_root`).
    pub module_search_roots: Vec<PathBuf>,
    /// Identifier of the factory that constructs the workload inside the
    /// worker. NOTE(review): the exact format (dotted path?) is defined by
    /// the Python side — confirm against the worker protocol.
    pub workload_factory: String,
    /// Opaque JSON configuration forwarded verbatim to the workload factory.
    pub workload_config: Value,
    /// Environment variables set for the spawned worker process.
    pub env: BTreeMap<String, String>,
}
43
44impl PythonTorchRuntimeConfig {
45 pub fn new(
47 python_executable: impl Into<PathBuf>,
48 workload_factory: impl Into<String>,
49 workload_config: Value,
50 ) -> Self {
51 Self {
52 python_executable: python_executable.into(),
53 module_search_roots: Vec::new(),
54 workload_factory: workload_factory.into(),
55 workload_config,
56 env: BTreeMap::new(),
57 }
58 }
59
60 pub fn with_module_search_root(mut self, root: impl Into<PathBuf>) -> Self {
62 self.module_search_roots.push(root.into());
63 self
64 }
65
66 pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
68 self.env.insert(key.into(), value.into());
69 self
70 }
71}
72
/// Configuration describing the local dataset the Python workload trains on,
/// plus how to plan it into microshards.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PythonTorchDatasetConfig {
    /// Local filesystem root of the dataset; used as the `Local` upstream
    /// adapter root in `registration()`.
    pub root: PathBuf,
    /// Logical dataset identity advertised in the manifest.
    pub dataset_id: DatasetId,
    /// Identity of the derived view built over the dataset.
    pub dataset_view_id: DatasetViewId,
    /// Source URI recorded in the manifest (informational).
    pub source_uri: String,
    /// On-disk/record format name, also attached as pipeline metadata.
    pub format: String,
    /// Content hash of the dataset manifest; shared by manifest and view.
    pub manifest_hash: ContentId,
    /// Content hash of the preprocessing applied to produce the view.
    pub preprocessing_hash: ContentId,
    /// Content hash of the tokenizer, when the view is tokenized.
    pub tokenizer_hash: Option<ContentId>,
    /// Sizing inputs consumed by the microshard planner.
    pub sizing: DatasetSizing,
    /// Planner configuration used by `plan()`.
    pub planner: MicroShardPlannerConfig,
    /// How many microshards are grouped into one batch by the default
    /// sharded pipeline (clamped to at least 1 at use sites).
    pub microshards_per_batch: usize,
    /// Free-form metadata copied onto both the manifest and the view.
    pub metadata: BTreeMap<String, String>,
}
101
102impl PythonTorchDatasetConfig {
103 pub fn registration(&self) -> DatasetRegistration {
105 DatasetRegistration {
106 manifest: DatasetManifest {
107 dataset_id: self.dataset_id.clone(),
108 source_uri: self.source_uri.clone(),
109 format: self.format.clone(),
110 manifest_hash: self.manifest_hash.clone(),
111 metadata: self.metadata.clone(),
112 },
113 view: DatasetView {
114 dataset_view_id: self.dataset_view_id.clone(),
115 dataset_id: self.dataset_id.clone(),
116 preprocessing_hash: self.preprocessing_hash.clone(),
117 tokenizer_hash: self.tokenizer_hash.clone(),
118 manifest_hash: self.manifest_hash.clone(),
119 metadata: self.metadata.clone(),
120 },
121 upstream: UpstreamAdapter::Local {
122 root: self.root.display().to_string(),
123 },
124 }
125 }
126
127 pub fn plan(&self) -> anyhow::Result<MicroShardPlan> {
129 let registration = self.registration();
130 Ok(MicroShardPlanner::new(self.planner.clone())?
131 .plan(®istration.view, self.sizing.clone())?)
132 }
133}
134
/// Top-level configuration for a Python torch workload: worker runtime,
/// dataset, and artifact/patch settings.
#[derive(Clone, Debug)]
pub struct PythonTorchWorkloadConfig {
    /// How to spawn and configure the Python worker process.
    pub runtime: PythonTorchRuntimeConfig,
    /// Dataset definition and microshard planning inputs.
    pub dataset: PythonTorchDatasetConfig,
    /// Workload capability advertised via `P2pWorkload::supported_workload`.
    pub supported_workload: SupportedWorkload,
    /// Content hash of the model schema, stamped onto built artifacts.
    pub model_schema_hash: ContentId,
    /// Record format name stamped onto built artifacts
    /// (defaults to "python-torch-safetensors" in `new`).
    pub artifact_record_format: String,
    /// Numeric precision recorded on built artifacts (defaults to fp32).
    pub artifact_precision: Precision,
    /// Chunking scheme used when storing artifacts (defaults to 256 KiB).
    pub artifact_chunking: ChunkingScheme,
    /// Which runtime patch classes the workload accepts
    /// (all disabled by default in `new`).
    pub patch_support: PatchSupport,
}
155
156impl PythonTorchWorkloadConfig {
157 pub fn new(
159 runtime: PythonTorchRuntimeConfig,
160 dataset: PythonTorchDatasetConfig,
161 supported_workload: SupportedWorkload,
162 model_schema_hash: ContentId,
163 ) -> anyhow::Result<Self> {
164 Ok(Self {
165 runtime,
166 dataset,
167 supported_workload,
168 model_schema_hash,
169 artifact_record_format: "python-torch-safetensors".to_owned(),
170 artifact_precision: Precision::Fp32,
171 artifact_chunking: ChunkingScheme::new(256 * 1024)?,
172 patch_support: PatchSupport {
173 hot: false,
174 warm: false,
175 cold: false,
176 },
177 })
178 }
179}
180
/// Serializable reference to a batch handed to the Python worker. The worker
/// loads the actual data itself; the Rust side only ships descriptors.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum PythonBatchRef {
    /// A group of locally cached microshard files forming one batch.
    CachedMicroshardGroup {
        /// Local paths of the cached shard files, in group order.
        shard_paths: Vec<PathBuf>,
        /// Microshard ids parallel to `shard_paths`.
        microshard_ids: Vec<String>,
        /// Microshard ordinals parallel to `shard_paths`.
        ordinals: Vec<u32>,
        /// Total byte length across the group.
        bytes_len: u64,
    },
    /// One batch covering an entire assignment lease ("micro epoch"), with a
    /// pipeline-specific payload the Python side interprets.
    MicroEpoch {
        /// Id of the lease this epoch covers.
        lease_id: String,
        /// Microshard ids taken from the lease itself.
        microshard_ids: Vec<String>,
        /// Ordinals of the microshards that are actually cached locally.
        ordinals: Vec<u32>,
        /// Total byte length of the cached microshards.
        bytes_len: u64,
        /// Kind of the data pipeline that produced this batch.
        pipeline_kind: LeaseDataPipelineKind,
        /// Opaque JSON payload forwarded to the worker.
        payload: Value,
    },
}
212
213impl PythonBatchRef {
214 pub fn cached_microshard_group(group: &[CachedMicroShard]) -> Self {
216 Self::CachedMicroshardGroup {
217 shard_paths: group.iter().map(|entry| entry.path.clone()).collect(),
218 microshard_ids: group
219 .iter()
220 .map(|entry| entry.microshard.microshard_id.as_str().to_owned())
221 .collect(),
222 ordinals: group.iter().map(|entry| entry.microshard.ordinal).collect(),
223 bytes_len: group.iter().map(|entry| entry.bytes_len).sum(),
224 }
225 }
226
227 pub fn micro_epoch(
229 lease: &AssignmentLease,
230 cached_microshards: &[CachedMicroShard],
231 pipeline_kind: LeaseDataPipelineKind,
232 payload: Value,
233 ) -> Self {
234 Self::MicroEpoch {
235 lease_id: lease.lease_id.as_str().to_owned(),
236 microshard_ids: lease
237 .microshards
238 .iter()
239 .map(|microshard_id| microshard_id.as_str().to_owned())
240 .collect(),
241 ordinals: cached_microshards
242 .iter()
243 .map(|entry| entry.microshard.ordinal)
244 .collect(),
245 bytes_len: cached_microshards.iter().map(|entry| entry.bytes_len).sum(),
246 pipeline_kind,
247 payload,
248 }
249 }
250}
251
/// Owning handle to a model that lives inside the Python worker. Dropping
/// the handle releases the worker-side model (see the `Drop` impl).
#[derive(Debug)]
pub struct PythonModelHandle {
    // Worker-assigned model identifier.
    id: String,
    // Client used to talk to (and eventually release the model in) the worker.
    client: PythonWorkerClient,
}
258
259impl PythonModelHandle {
260 fn new(id: String, client: PythonWorkerClient) -> Self {
261 Self { id, client }
262 }
263
264 fn id(&self) -> &str {
265 &self.id
266 }
267}
268
impl Drop for PythonModelHandle {
    fn drop(&mut self) {
        // Release the worker-side model as soon as the handle goes away so
        // the Python process does not accumulate dead models. Presumably
        // best-effort: any failure inside Drop cannot be surfaced anyway.
        self.client.release_model(&self.id);
    }
}
274
/// A P2P workload backed by an external Python torch worker process.
///
/// Owns the worker client, the data pipeline that turns leases into
/// `PythonBatchRef`s, and the capability/device info probed at spawn time.
#[derive(Clone, Debug)]
pub struct PythonTorchProject {
    // RPC-style client for the spawned Python worker.
    client: PythonWorkerClient,
    // Full workload configuration (runtime, dataset, artifact settings).
    config: PythonTorchWorkloadConfig,
    // Turns assignment leases + cached microshards into batch references.
    data_pipeline: LeaseDataPipeline<String, PythonBatchRef>,
    // Workload name reported by the worker's hello handshake.
    workload_name: String,
    // Device name reported by the worker's capability probe.
    runtime_device: String,
    // Capability estimate captured once at spawn time.
    capability: CapabilityEstimate,
}
285
impl PythonTorchProject {
    /// Spawns the Python worker and wires it to the default sharded data
    /// pipeline derived from `config.dataset`.
    ///
    /// # Errors
    /// Propagates worker spawn/handshake failures from
    /// `new_with_data_pipeline`.
    pub fn new(config: PythonTorchWorkloadConfig) -> anyhow::Result<Self> {
        let data_pipeline = Self::sharded_data_pipeline(&config.dataset);
        Self::new_with_data_pipeline(config, data_pipeline)
    }

    /// Spawns the Python worker, validates the wire protocol, and probes
    /// capability before constructing the project with a caller-supplied
    /// data pipeline.
    ///
    /// # Errors
    /// Fails if the worker cannot be spawned, the handshake fails, the
    /// worker speaks a protocol version other than 1, or the capability
    /// probe fails.
    pub fn new_with_data_pipeline(
        config: PythonTorchWorkloadConfig,
        data_pipeline: LeaseDataPipeline<String, PythonBatchRef>,
    ) -> anyhow::Result<Self> {
        let client = PythonWorkerClient::spawn(&config.runtime)?;
        let hello = client.hello()?;
        // Hard protocol gate: this side only speaks version 1.
        if hello.protocol_version != 1 {
            anyhow::bail!(
                "python worker protocol mismatch: expected 1, got {}",
                hello.protocol_version
            );
        }
        let probe = client.capability_probe()?;
        Ok(Self {
            client,
            config,
            data_pipeline,
            workload_name: hello.workload_name,
            runtime_device: probe.runtime_device,
            capability: probe.capability,
        })
    }

    /// Builds the default static-sharded pipeline: registration and plan
    /// come straight from `dataset`, and cached microshards are grouped into
    /// batches of `microshards_per_batch` (clamped to >= 1).
    ///
    /// # Panics
    /// Panics if the dataset's microshard plan cannot be produced — the
    /// plan is resolved eagerly here, before the pipeline closures run.
    pub fn sharded_data_pipeline(
        dataset: &PythonTorchDatasetConfig,
    ) -> LeaseDataPipeline<String, PythonBatchRef> {
        let registration = dataset.registration();
        let microshard_plan = dataset.plan().expect("python dataset plan should resolve");
        // Clamp so chunking never sees a zero group size.
        let group_size = dataset.microshards_per_batch.max(1);
        LeaseDataPipeline::new(
            LeaseDataPipelineDescriptor::new(
                "python-sharded-dataset",
                LeaseDataPipelineKind::ShardedStatic,
            )
            .with_metadata_entry("format", dataset.format.clone()),
            // Registration and plan were captured above; the closures just
            // hand out clones.
            move || Ok(registration.clone()),
            move |_registration| Ok(microshard_plan.clone()),
            move |_lease, cached_microshards, _device| {
                let batch_count = cached_microshards.len().div_ceil(group_size).max(1);
                let mut batches = Vec::with_capacity(batch_count);
                for group in cached_microshards.chunks(group_size) {
                    batches.push(PythonBatchRef::cached_microshard_group(group));
                }
                Ok(batches)
            },
        )
    }

    /// Builds a pipeline that emits exactly one `MicroEpoch` batch per
    /// lease, carrying a caller-computed JSON payload. The descriptor's kind
    /// is captured and stamped onto each batch.
    pub fn micro_epoch_pipeline(
        descriptor: LeaseDataPipelineDescriptor,
        dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
        microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
        + Send
        + Sync
        + 'static,
        payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
        + Send
        + Sync
        + 'static,
    ) -> LeaseDataPipeline<String, PythonBatchRef> {
        // Capture the kind before the descriptor is moved into the pipeline.
        let pipeline_kind = descriptor.kind;
        LeaseDataPipeline::new(
            descriptor,
            dataset_registration,
            microshard_plan,
            move |lease, cached_microshards, _device| {
                Ok(vec![PythonBatchRef::micro_epoch(
                    lease,
                    cached_microshards,
                    pipeline_kind,
                    payload(lease, cached_microshards)?,
                )])
            },
        )
    }

    /// Convenience wrapper: a micro-epoch pipeline with kind
    /// `IndexedDataset`.
    pub fn indexed_dataset_pipeline(
        pipeline_name: impl Into<String>,
        dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
        microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
        + Send
        + Sync
        + 'static,
        payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
        + Send
        + Sync
        + 'static,
    ) -> LeaseDataPipeline<String, PythonBatchRef> {
        Self::micro_epoch_pipeline(
            LeaseDataPipelineDescriptor::new(pipeline_name, LeaseDataPipelineKind::IndexedDataset),
            dataset_registration,
            microshard_plan,
            payload,
        )
    }

    /// Convenience wrapper: a micro-epoch pipeline with kind
    /// `GeneratedDataset`.
    pub fn generated_dataset_pipeline(
        pipeline_name: impl Into<String>,
        dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
        microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
        + Send
        + Sync
        + 'static,
        payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
        + Send
        + Sync
        + 'static,
    ) -> LeaseDataPipeline<String, PythonBatchRef> {
        Self::micro_epoch_pipeline(
            LeaseDataPipelineDescriptor::new(
                pipeline_name,
                LeaseDataPipelineKind::GeneratedDataset,
            ),
            dataset_registration,
            microshard_plan,
            payload,
        )
    }

    /// Capability estimate probed from the worker at spawn time.
    pub fn probe_capability(&self) -> &CapabilityEstimate {
        &self.capability
    }

    /// Device name the worker reported it runs on.
    pub fn runtime_device_name(&self) -> &str {
        &self.runtime_device
    }

    /// Workload name from the worker's hello handshake.
    pub fn workload_name(&self) -> &str {
        &self.workload_name
    }

    /// Descriptor of the active data pipeline.
    pub fn data_pipeline_descriptor(&self) -> &LeaseDataPipelineDescriptor {
        self.data_pipeline.descriptor()
    }

    /// Kind of the active data pipeline.
    pub fn data_pipeline_kind(&self) -> LeaseDataPipelineKind {
        self.data_pipeline.kind()
    }

    /// Dataset registration as produced by the active data pipeline.
    pub fn data_pipeline_registration(&self) -> anyhow::Result<DatasetRegistration> {
        self.data_pipeline.dataset_registration()
    }

    /// Local upstream root of the active pipeline, when it has one
    /// (i.e. when the upstream adapter is local).
    pub fn local_upstream_root(&self) -> anyhow::Result<Option<PathBuf>> {
        local_upstream_root_for_pipeline(&self.data_pipeline)
    }

    /// The shard root from configuration. NOTE(review): this may differ
    /// from `local_upstream_root()` when a custom pipeline was supplied.
    pub fn configured_shard_root(&self) -> &std::path::Path {
        &self.config.dataset.root
    }
}
460
461impl P2pWorkload for PythonTorchProject {
462 type Device = String;
463 type Model = PythonModelHandle;
464 type Batch = PythonBatchRef;
465 type WindowStats = BTreeMap<String, MetricValue>;
466
467 fn init_model(&self, device: &Self::Device) -> Self::Model {
468 let model_id = self
469 .client
470 .init_model(device)
471 .expect("python worker should initialize a model");
472 PythonModelHandle::new(model_id, self.client.clone())
473 }
474
475 fn benchmark(&self, _model: &Self::Model, _device: &Self::Device) -> CapabilityEstimate {
476 self.capability.clone()
477 }
478
479 fn train_window(
480 &self,
481 ctx: &mut WindowCtx<Self::Device, Self::Model, Self::Batch>,
482 ) -> Result<WindowReport<Self::WindowStats>, TrainError> {
483 let mut metrics = self
484 .client
485 .train_window(ctx.model.id(), &ctx.batches)
486 .map_err(|error| TrainError::new(error.to_string()))?;
487 metrics.insert(
488 "batch_count".into(),
489 MetricValue::Integer(ctx.batches.len() as i64),
490 );
491 let examples_processed = ctx
492 .cached_microshards
493 .iter()
494 .map(|cached| cached.microshard.estimated_examples)
495 .sum::<u64>();
496 let tokens_processed = ctx
497 .cached_microshards
498 .iter()
499 .map(|cached| cached.microshard.estimated_tokens)
500 .sum::<u64>();
501 if examples_processed > 0 {
502 metrics.insert(
503 "examples_processed".into(),
504 MetricValue::Integer(examples_processed as i64),
505 );
506 }
507 if tokens_processed > 0 {
508 metrics.insert(
509 "tokens_processed".into(),
510 MetricValue::Integer(tokens_processed as i64),
511 );
512 }
513 if !ctx.cached_microshards.is_empty() {
514 metrics.insert(
515 "microshard_count".into(),
516 MetricValue::Integer(ctx.cached_microshards.len() as i64),
517 );
518 }
519 Ok(WindowReport {
520 contribution: None,
521 stats: metrics,
522 completed_at: Utc::now(),
523 })
524 }
525
526 fn evaluate(&self, model: &Self::Model, split: EvalSplit) -> MetricReport {
527 let metrics = self
528 .client
529 .evaluate(model.id(), split)
530 .unwrap_or_else(|error| {
531 BTreeMap::from([("python_error".into(), MetricValue::Text(error.to_string()))])
532 });
533 MetricReport {
534 metrics,
535 captured_at: Utc::now(),
536 }
537 }
538
539 fn apply_patch(&mut self, patch: &RuntimePatch) -> PatchOutcome {
540 self.client
541 .apply_patch(patch)
542 .unwrap_or_else(|error| PatchOutcome::Rejected(error.to_string()))
543 }
544
545 fn supported_patch_classes(&self) -> PatchSupport {
546 self.config.patch_support
547 }
548
549 fn runtime_device(&self) -> Self::Device {
550 self.runtime_device.clone()
551 }
552
553 fn dataset_registration(&self) -> anyhow::Result<DatasetRegistration> {
554 self.data_pipeline.dataset_registration()
555 }
556
557 fn microshard_plan(
558 &self,
559 _registration: &DatasetRegistration,
560 ) -> anyhow::Result<MicroShardPlan> {
561 self.data_pipeline.microshard_plan(_registration)
562 }
563
564 fn load_batches(
565 &self,
566 lease: &AssignmentLease,
567 cached_microshards: &[CachedMicroShard],
568 ) -> anyhow::Result<Vec<Self::Batch>> {
569 self.data_pipeline
570 .load_batches(lease, cached_microshards, &self.runtime_device)
571 }
572
573 fn load_model_artifact(
574 &self,
575 model: Self::Model,
576 descriptor: &ArtifactDescriptor,
577 store: &FsArtifactStore,
578 _device: &Self::Device,
579 ) -> anyhow::Result<Self::Model> {
580 let staged_dir = tempfile::Builder::new()
581 .prefix("burn-p2p-python-load-artifact")
582 .tempdir()?;
583 let staged_path = staged_dir.path().join("artifact.safetensors");
584 store.materialize_artifact_file(descriptor, &staged_path)?;
585 self.client
586 .load_model_artifact_path(model.id(), &staged_path)
587 .context("load python model artifact into worker")?;
588 Ok(model)
589 }
590
591 fn materialize_model_artifact(
592 &self,
593 model: &Self::Model,
594 artifact_kind: ArtifactKind,
595 head_id: HeadId,
596 base_head_id: Option<HeadId>,
597 store: &FsArtifactStore,
598 ) -> anyhow::Result<ArtifactDescriptor> {
599 let staged_dir = tempfile::Builder::new()
600 .prefix("burn-p2p-python-materialized-artifact")
601 .tempdir()?;
602 let staged_path = staged_dir.path().join("artifact.safetensors");
603 self.client
604 .materialize_model_artifact_path(model.id(), &staged_path)
605 .context("materialize python model artifact")?;
606 let mut spec = ArtifactBuildSpec::new(
607 artifact_kind,
608 self.config.artifact_precision.clone(),
609 self.config.model_schema_hash.clone(),
610 self.config.artifact_record_format.clone(),
611 )
612 .with_head(head_id);
613 if let Some(base_head_id) = base_head_id {
614 spec = spec.with_base_head(base_head_id);
615 }
616 store
617 .store_artifact_file(&spec, &staged_path, self.config.artifact_chunking)
618 .map_err(Into::into)
619 }
620
621 fn contribution_metrics(
622 &self,
623 report: &WindowReport<Self::WindowStats>,
624 ) -> BTreeMap<String, MetricValue> {
625 report.stats.clone()
626 }
627
628 fn contribution_weight(&self, report: &WindowReport<Self::WindowStats>) -> f64 {
629 standard_contribution_weight(&report.stats).unwrap_or(1.0)
630 }
631
632 fn reconcile_canonical_model(
633 &self,
634 local_model: &Self::Model,
635 canonical_model: Self::Model,
636 strategy: TrainerCanonicalReconcileStrategy,
637 ) -> anyhow::Result<Self::Model> {
638 let canonical_model = canonical_model;
639 let returned_id = self.client.reconcile_canonical_model(
640 local_model.id(),
641 canonical_model.id(),
642 strategy,
643 )?;
644 debug_assert_eq!(returned_id, canonical_model.id());
645 Ok(canonical_model)
646 }
647
648 fn merge_candidate_models(
649 &self,
650 base_model: &Self::Model,
651 candidates: &[MergeModelCandidate<'_, Self::Model>],
652 policy: MergePolicy,
653 ) -> anyhow::Result<Option<Self::Model>> {
654 let candidate_refs = candidates
655 .iter()
656 .map(|candidate| PythonMergeCandidateRef {
657 peer_id: candidate.peer_id.as_str(),
658 head_id: candidate.head_id.as_str(),
659 artifact_id: candidate.artifact_id.as_str(),
660 model_id: candidate.model.id(),
661 sample_weight: candidate.sample_weight,
662 quality_weight: candidate.quality_weight,
663 })
664 .collect::<Vec<_>>();
665 let merged =
666 self.client
667 .merge_candidate_models(base_model.id(), &candidate_refs, policy)?;
668 Ok(merged.map(|model_id| PythonModelHandle::new(model_id, self.client.clone())))
669 }
670
671 fn apply_single_root_ema(
672 &self,
673 base_model: &Self::Model,
674 merged_model: Self::Model,
675 policy: MergePolicy,
676 ) -> anyhow::Result<Self::Model> {
677 let merged_model = merged_model;
678 let returned_id =
679 self.client
680 .apply_single_root_ema(base_model.id(), merged_model.id(), policy)?;
681 debug_assert_eq!(returned_id, merged_model.id());
682 Ok(merged_model)
683 }
684
685 fn supported_workload(&self) -> SupportedWorkload {
686 self.config.supported_workload.clone()
687 }
688
689 fn model_schema_hash(&self) -> ContentId {
690 self.config.model_schema_hash.clone()
691 }
692}