// burn_p2p_python/lib.rs

use std::{collections::BTreeMap, path::PathBuf};

use anyhow::Context;
use burn_p2p_checkpoint::{ArtifactBuildSpec, ChunkingScheme, FsArtifactStore};
use burn_p2p_core::{
    ArtifactDescriptor, ArtifactKind, AssignmentLease, CapabilityEstimate, ContentId, DatasetId,
    DatasetManifest, DatasetView, DatasetViewId, HeadId, MergePolicy, MetricValue, Precision,
    SupportedWorkload,
};
use burn_p2p_dataloader::{
    CachedMicroShard, DatasetRegistration, DatasetSizing, MicroShardPlan, MicroShardPlanner,
    MicroShardPlannerConfig, UpstreamAdapter,
};
use burn_p2p_experiment::{PatchSupport, RuntimePatch};
use burn_p2p_workload::{
    EvalSplit, LeaseDataPipeline, LeaseDataPipelineDescriptor, LeaseDataPipelineKind,
    MergeModelCandidate, MetricReport, P2pWorkload, PatchOutcome, TrainError,
    TrainerCanonicalReconcileStrategy, WindowCtx, WindowReport, local_upstream_root_for_pipeline,
    standard_contribution_weight,
};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use serde_json::Value;

mod worker;

use worker::{PythonMergeCandidateRef, PythonWorkerClient};

#[derive(Clone, Debug, Serialize, Deserialize)]
/// Configures how the Rust runtime launches the Python worker process.
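///
/// A minimal construction sketch. The factory path
/// `train.torch_workload:build` and the JSON keys below are hypothetical
/// placeholders, not values defined by this crate.
///
/// ```no_run
/// use burn_p2p_python::PythonTorchRuntimeConfig;
///
/// let runtime = PythonTorchRuntimeConfig::new(
///     "python3",
///     "train.torch_workload:build",
///     serde_json::json!({ "learning_rate": 1e-4 }),
/// )
/// .with_module_search_root("./python")
/// .with_env("OMP_NUM_THREADS", "1");
/// ```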
pub struct PythonTorchRuntimeConfig {
    /// Python executable used to launch the runtime worker.
    pub python_executable: PathBuf,
    /// Additional module roots appended to `PYTHONPATH`.
    pub module_search_roots: Vec<PathBuf>,
    /// Python workload factory in `module:attr` form.
    pub workload_factory: String,
    /// JSON config passed directly to the Python workload.
    pub workload_config: Value,
    /// Additional environment variables for the worker process.
    pub env: BTreeMap<String, String>,
}

impl PythonTorchRuntimeConfig {
    /// Creates a new config for the provided Python workload factory.
    pub fn new(
        python_executable: impl Into<PathBuf>,
        workload_factory: impl Into<String>,
        workload_config: Value,
    ) -> Self {
        Self {
            python_executable: python_executable.into(),
            module_search_roots: Vec::new(),
            workload_factory: workload_factory.into(),
            workload_config,
            env: BTreeMap::new(),
        }
    }

    /// Adds one extra Python import root.
    pub fn with_module_search_root(mut self, root: impl Into<PathBuf>) -> Self {
        self.module_search_roots.push(root.into());
        self
    }

    /// Adds one environment override for the worker process.
    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env.insert(key.into(), value.into());
        self
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
/// Declares the shard-backed dataset view exposed to the p2p runtime.
pub struct PythonTorchDatasetConfig {
    /// Root containing `fetch-manifest.json` and shard files.
    pub root: PathBuf,
    /// Stable dataset id.
    pub dataset_id: DatasetId,
    /// Stable dataset view id.
    pub dataset_view_id: DatasetViewId,
    /// Source URI surfaced in dataset metadata.
    pub source_uri: String,
    /// Dataset format tag.
    pub format: String,
    /// Dataset manifest hash.
    pub manifest_hash: ContentId,
    /// Preprocessing hash for the view.
    pub preprocessing_hash: ContentId,
    /// Optional tokenizer hash.
    pub tokenizer_hash: Option<ContentId>,
    /// Dataset sizing used to plan microshards.
    pub sizing: DatasetSizing,
    /// Planner config used to derive microshard ids.
    pub planner: MicroShardPlannerConfig,
    /// Number of cached microshards grouped into one Python batch ref.
    pub microshards_per_batch: usize,
    /// Arbitrary dataset metadata propagated into the registration.
    pub metadata: BTreeMap<String, String>,
}

impl PythonTorchDatasetConfig {
    /// Returns a local-upstream dataset registration.
    pub fn registration(&self) -> DatasetRegistration {
        DatasetRegistration {
            manifest: DatasetManifest {
                dataset_id: self.dataset_id.clone(),
                source_uri: self.source_uri.clone(),
                format: self.format.clone(),
                manifest_hash: self.manifest_hash.clone(),
                metadata: self.metadata.clone(),
            },
            view: DatasetView {
                dataset_view_id: self.dataset_view_id.clone(),
                dataset_id: self.dataset_id.clone(),
                preprocessing_hash: self.preprocessing_hash.clone(),
                tokenizer_hash: self.tokenizer_hash.clone(),
                manifest_hash: self.manifest_hash.clone(),
                metadata: self.metadata.clone(),
            },
            upstream: UpstreamAdapter::Local {
                root: self.root.display().to_string(),
            },
        }
    }

    /// Plans the microshards for this dataset view.
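    ///
    /// A usage sketch, assuming a `PythonTorchDatasetConfig` was assembled
    /// elsewhere:
    ///
    /// ```no_run
    /// use burn_p2p_python::PythonTorchDatasetConfig;
    ///
    /// fn prepare(dataset: &PythonTorchDatasetConfig) -> anyhow::Result<()> {
    ///     let _registration = dataset.registration();
    ///     let _plan = dataset.plan()?;
    ///     Ok(())
    /// }
    /// ```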
    pub fn plan(&self) -> anyhow::Result<MicroShardPlan> {
        let registration = self.registration();
        Ok(MicroShardPlanner::new(self.planner.clone())?
            .plan(&registration.view, self.sizing.clone())?)
    }
}

#[derive(Clone, Debug)]
/// Static workload identity and artifact settings for one Python/Torch runtime.
pub struct PythonTorchWorkloadConfig {
    /// Python worker launch config.
    pub runtime: PythonTorchRuntimeConfig,
    /// Dataset/shard config.
    pub dataset: PythonTorchDatasetConfig,
    /// Workload identity published into the release manifest.
    pub supported_workload: SupportedWorkload,
    /// Stable model schema hash.
    pub model_schema_hash: ContentId,
    /// Artifact record format tag.
    pub artifact_record_format: String,
    /// Descriptor precision published with model artifacts.
    pub artifact_precision: Precision,
    /// Chunking policy for stored model artifacts.
    pub artifact_chunking: ChunkingScheme,
    /// Patch support advertised by the workload.
    pub patch_support: PatchSupport,
}

impl PythonTorchWorkloadConfig {
    /// Creates a new config with a safetensors-backed artifact default.
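    ///
    /// A construction sketch; the runtime, dataset, workload identity, and
    /// schema hash are assumed to come from elsewhere, and the 64 KiB
    /// chunking override is illustrative only.
    ///
    /// ```no_run
    /// use burn_p2p_checkpoint::ChunkingScheme;
    /// use burn_p2p_core::{ContentId, SupportedWorkload};
    /// use burn_p2p_python::{
    ///     PythonTorchDatasetConfig, PythonTorchRuntimeConfig, PythonTorchWorkloadConfig,
    /// };
    ///
    /// fn workload_config(
    ///     runtime: PythonTorchRuntimeConfig,
    ///     dataset: PythonTorchDatasetConfig,
    ///     supported: SupportedWorkload,
    ///     schema_hash: ContentId,
    /// ) -> anyhow::Result<PythonTorchWorkloadConfig> {
    ///     let mut config = PythonTorchWorkloadConfig::new(runtime, dataset, supported, schema_hash)?;
    ///     config.artifact_chunking = ChunkingScheme::new(64 * 1024)?;
    ///     Ok(config)
    /// }
    /// ```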
    pub fn new(
        runtime: PythonTorchRuntimeConfig,
        dataset: PythonTorchDatasetConfig,
        supported_workload: SupportedWorkload,
        model_schema_hash: ContentId,
    ) -> anyhow::Result<Self> {
        Ok(Self {
            runtime,
            dataset,
            supported_workload,
            model_schema_hash,
            artifact_record_format: "python-torch-safetensors".to_owned(),
            artifact_precision: Precision::Fp32,
            artifact_chunking: ChunkingScheme::new(256 * 1024)?,
            patch_support: PatchSupport {
                hot: false,
                warm: false,
                cold: false,
            },
        })
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
/// One lease-loaded micro-epoch batch passed to the Python worker.
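///
/// A consumption sketch that only touches fields this enum actually carries:
///
/// ```no_run
/// use burn_p2p_python::PythonBatchRef;
///
/// fn describe(batch: &PythonBatchRef) -> String {
///     match batch {
///         PythonBatchRef::CachedMicroshardGroup { shard_paths, bytes_len, .. } => {
///             format!("{} shard files, {} bytes", shard_paths.len(), bytes_len)
///         }
///         PythonBatchRef::MicroEpoch { lease_id, microshard_ids, .. } => {
///             format!("lease {} over {} microshards", lease_id, microshard_ids.len())
///         }
///     }
/// }
/// ```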
pub enum PythonBatchRef {
    /// Cached shard group materialized by the Rust shard cache.
    CachedMicroshardGroup {
        /// Paths of the cached microshards that should be consumed together.
        shard_paths: Vec<PathBuf>,
        /// Stable microshard ids included in this batch group.
        microshard_ids: Vec<String>,
        /// Ordinals of the grouped microshards.
        ordinals: Vec<u32>,
        /// Total byte size represented by the group.
        bytes_len: u64,
    },
    /// Lease-scoped micro-epoch descriptor rebuilt inside Python.
    MicroEpoch {
        /// Lease id driving this micro-epoch.
        lease_id: String,
        /// Stable microshard ids included in the lease.
        microshard_ids: Vec<String>,
        /// Ordinals of the microshards included in the lease.
        ordinals: Vec<u32>,
        /// Total byte size represented by the cached microshards, when known.
        bytes_len: u64,
        /// High-level pipeline kind associated with this lease.
        pipeline_kind: LeaseDataPipelineKind,
        /// Opaque workload-specific payload consumed by Python.
        payload: Value,
    },
}

impl PythonBatchRef {
    /// Builds a shard-backed batch ref.
    pub fn cached_microshard_group(group: &[CachedMicroShard]) -> Self {
        Self::CachedMicroshardGroup {
            shard_paths: group.iter().map(|entry| entry.path.clone()).collect(),
            microshard_ids: group
                .iter()
                .map(|entry| entry.microshard.microshard_id.as_str().to_owned())
                .collect(),
            ordinals: group.iter().map(|entry| entry.microshard.ordinal).collect(),
            bytes_len: group.iter().map(|entry| entry.bytes_len).sum(),
        }
    }

    /// Builds a generic micro-epoch descriptor for Python-side dataloader reconstruction.
    pub fn micro_epoch(
        lease: &AssignmentLease,
        cached_microshards: &[CachedMicroShard],
        pipeline_kind: LeaseDataPipelineKind,
        payload: Value,
    ) -> Self {
        Self::MicroEpoch {
            lease_id: lease.lease_id.as_str().to_owned(),
            microshard_ids: lease
                .microshards
                .iter()
                .map(|microshard_id| microshard_id.as_str().to_owned())
                .collect(),
            ordinals: cached_microshards
                .iter()
                .map(|entry| entry.microshard.ordinal)
                .collect(),
            bytes_len: cached_microshards.iter().map(|entry| entry.bytes_len).sum(),
            pipeline_kind,
            payload,
        }
    }
}

#[derive(Debug)]
/// Worker-owned model/optimizer state referenced by one opaque handle id.
pub struct PythonModelHandle {
    id: String,
    client: PythonWorkerClient,
}

impl PythonModelHandle {
    fn new(id: String, client: PythonWorkerClient) -> Self {
        Self { id, client }
    }

    fn id(&self) -> &str {
        &self.id
    }
}

impl Drop for PythonModelHandle {
    fn drop(&mut self) {
        self.client.release_model(&self.id);
    }
}

#[derive(Clone, Debug)]
/// Python/Torch-backed workload bridge implemented on top of the generic p2p runtime.
pub struct PythonTorchProject {
    client: PythonWorkerClient,
    config: PythonTorchWorkloadConfig,
    data_pipeline: LeaseDataPipeline<String, PythonBatchRef>,
    workload_name: String,
    runtime_device: String,
    capability: CapabilityEstimate,
}

impl PythonTorchProject {
    /// Spawns the backing Python worker and probes its runtime capability.
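    ///
    /// A spawn sketch, assuming a fully assembled `PythonTorchWorkloadConfig`:
    ///
    /// ```no_run
    /// use burn_p2p_python::{PythonTorchProject, PythonTorchWorkloadConfig};
    ///
    /// fn spawn(config: PythonTorchWorkloadConfig) -> anyhow::Result<PythonTorchProject> {
    ///     let project = PythonTorchProject::new(config)?;
    ///     println!(
    ///         "workload {} running on {}",
    ///         project.workload_name(),
    ///         project.runtime_device_name(),
    ///     );
    ///     Ok(project)
    /// }
    /// ```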
    pub fn new(config: PythonTorchWorkloadConfig) -> anyhow::Result<Self> {
        let data_pipeline = Self::sharded_data_pipeline(&config.dataset);
        Self::new_with_data_pipeline(config, data_pipeline)
    }

    /// Spawns the backing Python worker with an explicit lease data pipeline.
    pub fn new_with_data_pipeline(
        config: PythonTorchWorkloadConfig,
        data_pipeline: LeaseDataPipeline<String, PythonBatchRef>,
    ) -> anyhow::Result<Self> {
        let client = PythonWorkerClient::spawn(&config.runtime)?;
        let hello = client.hello()?;
        if hello.protocol_version != 1 {
            anyhow::bail!(
                "python worker protocol mismatch: expected 1, got {}",
                hello.protocol_version
            );
        }
        let probe = client.capability_probe()?;
        Ok(Self {
            client,
            config,
            data_pipeline,
            workload_name: hello.workload_name,
            runtime_device: probe.runtime_device,
            capability: probe.capability,
        })
    }

    /// Returns the default shard-backed lease data pipeline for one dataset config.
    pub fn sharded_data_pipeline(
        dataset: &PythonTorchDatasetConfig,
    ) -> LeaseDataPipeline<String, PythonBatchRef> {
        let registration = dataset.registration();
        let microshard_plan = dataset.plan().expect("python dataset plan should resolve");
        let group_size = dataset.microshards_per_batch.max(1);
        LeaseDataPipeline::new(
            LeaseDataPipelineDescriptor::new(
                "python-sharded-dataset",
                LeaseDataPipelineKind::ShardedStatic,
            )
            .with_metadata_entry("format", dataset.format.clone()),
            move || Ok(registration.clone()),
            move |_registration| Ok(microshard_plan.clone()),
            move |_lease, cached_microshards, _device| {
                let batch_count = cached_microshards.len().div_ceil(group_size).max(1);
                let mut batches = Vec::with_capacity(batch_count);
                for group in cached_microshards.chunks(group_size) {
                    batches.push(PythonBatchRef::cached_microshard_group(group));
                }
                Ok(batches)
            },
        )
    }

    /// Builds a generic Python micro-epoch pipeline backed by workload-defined payloads.
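    ///
    /// A construction sketch, assuming the registration and plan were
    /// resolved elsewhere; the pipeline name and JSON payload keys are
    /// illustrative only:
    ///
    /// ```no_run
    /// use burn_p2p_dataloader::{DatasetRegistration, MicroShardPlan};
    /// use burn_p2p_python::{PythonBatchRef, PythonTorchProject};
    /// use burn_p2p_workload::{
    ///     LeaseDataPipeline, LeaseDataPipelineDescriptor, LeaseDataPipelineKind,
    /// };
    ///
    /// fn build_pipeline(
    ///     registration: DatasetRegistration,
    ///     plan: MicroShardPlan,
    /// ) -> LeaseDataPipeline<String, PythonBatchRef> {
    ///     PythonTorchProject::micro_epoch_pipeline(
    ///         LeaseDataPipelineDescriptor::new(
    ///             "custom-micro-epoch",
    ///             LeaseDataPipelineKind::IndexedDataset,
    ///         ),
    ///         move || Ok(registration.clone()),
    ///         move |_registration| Ok(plan.clone()),
    ///         |lease, cached_microshards| {
    ///             Ok(serde_json::json!({
    ///                 "lease_id": lease.lease_id.as_str(),
    ///                 "microshard_count": cached_microshards.len(),
    ///             }))
    ///         },
    ///     )
    /// }
    /// ```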
    pub fn micro_epoch_pipeline(
        descriptor: LeaseDataPipelineDescriptor,
        dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
        microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
        + Send
        + Sync
        + 'static,
        payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
        + Send
        + Sync
        + 'static,
    ) -> LeaseDataPipeline<String, PythonBatchRef> {
        let pipeline_kind = descriptor.kind;
        LeaseDataPipeline::new(
            descriptor,
            dataset_registration,
            microshard_plan,
            move |lease, cached_microshards, _device| {
                Ok(vec![PythonBatchRef::micro_epoch(
                    lease,
                    cached_microshards,
                    pipeline_kind,
                    payload(lease, cached_microshards)?,
                )])
            },
        )
    }

    /// Builds a Python pipeline for existing torch `Dataset`/`Sampler`-style data flows.
    pub fn indexed_dataset_pipeline(
        pipeline_name: impl Into<String>,
        dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
        microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
        + Send
        + Sync
        + 'static,
        payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
        + Send
        + Sync
        + 'static,
    ) -> LeaseDataPipeline<String, PythonBatchRef> {
        Self::micro_epoch_pipeline(
            LeaseDataPipelineDescriptor::new(pipeline_name, LeaseDataPipelineKind::IndexedDataset),
            dataset_registration,
            microshard_plan,
            payload,
        )
    }

    /// Builds a Python pipeline for deterministic synthetic or recipe-driven data generation.
    pub fn generated_dataset_pipeline(
        pipeline_name: impl Into<String>,
        dataset_registration: impl Fn() -> anyhow::Result<DatasetRegistration> + Send + Sync + 'static,
        microshard_plan: impl Fn(&DatasetRegistration) -> anyhow::Result<MicroShardPlan>
        + Send
        + Sync
        + 'static,
        payload: impl Fn(&AssignmentLease, &[CachedMicroShard]) -> anyhow::Result<Value>
        + Send
        + Sync
        + 'static,
    ) -> LeaseDataPipeline<String, PythonBatchRef> {
        Self::micro_epoch_pipeline(
            LeaseDataPipelineDescriptor::new(
                pipeline_name,
                LeaseDataPipelineKind::GeneratedDataset,
            ),
            dataset_registration,
            microshard_plan,
            payload,
        )
    }

    /// Returns the worker-advertised capability estimate.
    pub fn probe_capability(&self) -> &CapabilityEstimate {
        &self.capability
    }

    /// Returns the resolved runtime device tag.
    pub fn runtime_device_name(&self) -> &str {
        &self.runtime_device
    }

    /// Returns the Python-side workload name advertised by the worker.
    pub fn workload_name(&self) -> &str {
        &self.workload_name
    }

    /// Returns the static lease/micro-epoch data pipeline descriptor.
    pub fn data_pipeline_descriptor(&self) -> &LeaseDataPipelineDescriptor {
        self.data_pipeline.descriptor()
    }

    /// Returns the configured lease/micro-epoch pipeline kind.
    pub fn data_pipeline_kind(&self) -> LeaseDataPipelineKind {
        self.data_pipeline.kind()
    }

    /// Returns the dataset registration backing the current pipeline.
    pub fn data_pipeline_registration(&self) -> anyhow::Result<DatasetRegistration> {
        self.data_pipeline.dataset_registration()
    }

    /// Returns the local upstream root when the current pipeline is backed by
    /// a `Local` dataset registration.
    pub fn local_upstream_root(&self) -> anyhow::Result<Option<PathBuf>> {
        local_upstream_root_for_pipeline(&self.data_pipeline)
    }

    /// Returns the configured shard root for the default sharded Python dataset
    /// config. This is configuration data only and may be unrelated to the
    /// active pipeline when `new_with_data_pipeline(...)` is used.
    pub fn configured_shard_root(&self) -> &std::path::Path {
        &self.config.dataset.root
    }
}

impl P2pWorkload for PythonTorchProject {
    type Device = String;
    type Model = PythonModelHandle;
    type Batch = PythonBatchRef;
    type WindowStats = BTreeMap<String, MetricValue>;

    fn init_model(&self, device: &Self::Device) -> Self::Model {
        let model_id = self
            .client
            .init_model(device)
            .expect("python worker should initialize a model");
        PythonModelHandle::new(model_id, self.client.clone())
    }

    fn benchmark(&self, _model: &Self::Model, _device: &Self::Device) -> CapabilityEstimate {
        self.capability.clone()
    }

    fn train_window(
        &self,
        ctx: &mut WindowCtx<Self::Device, Self::Model, Self::Batch>,
    ) -> Result<WindowReport<Self::WindowStats>, TrainError> {
        let mut metrics = self
            .client
            .train_window(ctx.model.id(), &ctx.batches)
            .map_err(|error| TrainError::new(error.to_string()))?;
        // Enrich the worker-reported metrics with counters derived from the
        // Rust-side window context.
        metrics.insert(
            "batch_count".into(),
            MetricValue::Integer(ctx.batches.len() as i64),
        );
        let examples_processed = ctx
            .cached_microshards
            .iter()
            .map(|cached| cached.microshard.estimated_examples)
            .sum::<u64>();
        let tokens_processed = ctx
            .cached_microshards
            .iter()
            .map(|cached| cached.microshard.estimated_tokens)
            .sum::<u64>();
        // Example/token counters are only reported when the shard estimates
        // are non-zero.
        if examples_processed > 0 {
            metrics.insert(
                "examples_processed".into(),
                MetricValue::Integer(examples_processed as i64),
            );
        }
        if tokens_processed > 0 {
            metrics.insert(
                "tokens_processed".into(),
                MetricValue::Integer(tokens_processed as i64),
            );
        }
        if !ctx.cached_microshards.is_empty() {
            metrics.insert(
                "microshard_count".into(),
                MetricValue::Integer(ctx.cached_microshards.len() as i64),
            );
        }
        Ok(WindowReport {
            contribution: None,
            stats: metrics,
            completed_at: Utc::now(),
        })
    }

    fn evaluate(&self, model: &Self::Model, split: EvalSplit) -> MetricReport {
        let metrics = self
            .client
            .evaluate(model.id(), split)
            .unwrap_or_else(|error| {
                BTreeMap::from([("python_error".into(), MetricValue::Text(error.to_string()))])
            });
        MetricReport {
            metrics,
            captured_at: Utc::now(),
        }
    }

    fn apply_patch(&mut self, patch: &RuntimePatch) -> PatchOutcome {
        self.client
            .apply_patch(patch)
            .unwrap_or_else(|error| PatchOutcome::Rejected(error.to_string()))
    }

    fn supported_patch_classes(&self) -> PatchSupport {
        self.config.patch_support
    }

    fn runtime_device(&self) -> Self::Device {
        self.runtime_device.clone()
    }

    fn dataset_registration(&self) -> anyhow::Result<DatasetRegistration> {
        self.data_pipeline.dataset_registration()
    }

    fn microshard_plan(
        &self,
        registration: &DatasetRegistration,
    ) -> anyhow::Result<MicroShardPlan> {
        self.data_pipeline.microshard_plan(registration)
    }

    fn load_batches(
        &self,
        lease: &AssignmentLease,
        cached_microshards: &[CachedMicroShard],
    ) -> anyhow::Result<Vec<Self::Batch>> {
        self.data_pipeline
            .load_batches(lease, cached_microshards, &self.runtime_device)
    }

    fn load_model_artifact(
        &self,
        model: Self::Model,
        descriptor: &ArtifactDescriptor,
        store: &FsArtifactStore,
        _device: &Self::Device,
    ) -> anyhow::Result<Self::Model> {
        let staged_dir = tempfile::Builder::new()
            .prefix("burn-p2p-python-load-artifact")
            .tempdir()?;
        let staged_path = staged_dir.path().join("artifact.safetensors");
        store.materialize_artifact_file(descriptor, &staged_path)?;
        self.client
            .load_model_artifact_path(model.id(), &staged_path)
            .context("load python model artifact into worker")?;
        Ok(model)
    }

    fn materialize_model_artifact(
        &self,
        model: &Self::Model,
        artifact_kind: ArtifactKind,
        head_id: HeadId,
        base_head_id: Option<HeadId>,
        store: &FsArtifactStore,
    ) -> anyhow::Result<ArtifactDescriptor> {
        let staged_dir = tempfile::Builder::new()
            .prefix("burn-p2p-python-materialized-artifact")
            .tempdir()?;
        let staged_path = staged_dir.path().join("artifact.safetensors");
        self.client
            .materialize_model_artifact_path(model.id(), &staged_path)
            .context("materialize python model artifact")?;
        let mut spec = ArtifactBuildSpec::new(
            artifact_kind,
            self.config.artifact_precision.clone(),
            self.config.model_schema_hash.clone(),
            self.config.artifact_record_format.clone(),
        )
        .with_head(head_id);
        if let Some(base_head_id) = base_head_id {
            spec = spec.with_base_head(base_head_id);
        }
        store
            .store_artifact_file(&spec, &staged_path, self.config.artifact_chunking)
            .map_err(Into::into)
    }

    fn contribution_metrics(
        &self,
        report: &WindowReport<Self::WindowStats>,
    ) -> BTreeMap<String, MetricValue> {
        report.stats.clone()
    }

    fn contribution_weight(&self, report: &WindowReport<Self::WindowStats>) -> f64 {
        standard_contribution_weight(&report.stats).unwrap_or(1.0)
    }

    fn reconcile_canonical_model(
        &self,
        local_model: &Self::Model,
        canonical_model: Self::Model,
        strategy: TrainerCanonicalReconcileStrategy,
    ) -> anyhow::Result<Self::Model> {
        let returned_id = self.client.reconcile_canonical_model(
            local_model.id(),
            canonical_model.id(),
            strategy,
        )?;
        debug_assert_eq!(returned_id, canonical_model.id());
        Ok(canonical_model)
    }

    fn merge_candidate_models(
        &self,
        base_model: &Self::Model,
        candidates: &[MergeModelCandidate<'_, Self::Model>],
        policy: MergePolicy,
    ) -> anyhow::Result<Option<Self::Model>> {
        let candidate_refs = candidates
            .iter()
            .map(|candidate| PythonMergeCandidateRef {
                peer_id: candidate.peer_id.as_str(),
                head_id: candidate.head_id.as_str(),
                artifact_id: candidate.artifact_id.as_str(),
                model_id: candidate.model.id(),
                sample_weight: candidate.sample_weight,
                quality_weight: candidate.quality_weight,
            })
            .collect::<Vec<_>>();
        let merged =
            self.client
                .merge_candidate_models(base_model.id(), &candidate_refs, policy)?;
        Ok(merged.map(|model_id| PythonModelHandle::new(model_id, self.client.clone())))
    }

    fn apply_single_root_ema(
        &self,
        base_model: &Self::Model,
        merged_model: Self::Model,
        policy: MergePolicy,
    ) -> anyhow::Result<Self::Model> {
        let returned_id =
            self.client
                .apply_single_root_ema(base_model.id(), merged_model.id(), policy)?;
        debug_assert_eq!(returned_id, merged_model.id());
        Ok(merged_model)
    }

    fn supported_workload(&self) -> SupportedWorkload {
        self.config.supported_workload.clone()
    }

    fn model_schema_hash(&self) -> ContentId {
        self.config.model_schema_hash.clone()
    }
}