Skip to main content

entrenar/lora/adapter/
merge_pipeline.rs

1//! Merge-Export-Publish pipeline (feature-gated: hub-publish)
2//!
3//! Merges LoRA/QLoRA adapters into base weights, exports, and publishes
4//! to HuggingFace Hub.
5
6use super::error::AdapterError;
7use super::merge_export::{merge_and_collect, merge_qlora_and_collect};
8use crate::hf_pipeline::publish::config::PublishConfig;
9use crate::hf_pipeline::publish::publisher::HfPublisher;
10use crate::hf_pipeline::publish::result::{PublishError, PublishResult};
11use crate::lora::{LoRALayer, QLoRALayer};
12use std::path::Path;
13
14/// Result of merge-export-publish pipeline
15#[derive(Debug, Clone)]
16pub struct MergePublishResult {
17    /// Number of layers merged
18    pub layers_merged: usize,
19    /// Publish result
20    pub publish: PublishResult,
21}
22
23/// Merge LoRA adapters, export as SafeTensors, and publish to HuggingFace Hub
24pub fn merge_export_publish(
25    layers: &[(&str, &LoRALayer)],
26    publish_config: PublishConfig,
27    output_dir: impl AsRef<Path>,
28) -> Result<MergePublishResult, MergePublishError> {
29    let output_dir = output_dir.as_ref();
30    let filename = "model.safetensors";
31
32    // Step 1: Merge
33    let merged = merge_and_collect(layers);
34    let layers_merged = merged.layers_merged;
35
36    // Step 2: Export as SafeTensors
37    let export_path = output_dir.join(filename);
38    std::fs::create_dir_all(output_dir)
39        .map_err(|e| MergePublishError::Merge(AdapterError::Io(e)))?;
40    merged.save_safetensors(&export_path).map_err(MergePublishError::Merge)?;
41
42    // Step 3: Publish
43    let publisher = HfPublisher::new(publish_config).map_err(MergePublishError::Publish)?;
44    let files: Vec<(&Path, &str)> = vec![(&export_path, filename)];
45
46    let publish = publisher.publish(&files, None).map_err(MergePublishError::Publish)?;
47
48    Ok(MergePublishResult { layers_merged, publish })
49}
50
51/// Merge QLoRA adapters, export as SafeTensors, and publish to HuggingFace Hub
52pub fn merge_qlora_export_publish(
53    layers: &[(&str, &QLoRALayer)],
54    publish_config: PublishConfig,
55    output_dir: impl AsRef<Path>,
56) -> Result<MergePublishResult, MergePublishError> {
57    let output_dir = output_dir.as_ref();
58    let filename = "model.safetensors";
59
60    let merged = merge_qlora_and_collect(layers);
61    let layers_merged = merged.layers_merged;
62
63    let export_path = output_dir.join(filename);
64    std::fs::create_dir_all(output_dir)
65        .map_err(|e| MergePublishError::Merge(AdapterError::Io(e)))?;
66    merged.save_safetensors(&export_path).map_err(MergePublishError::Merge)?;
67
68    let publisher = HfPublisher::new(publish_config).map_err(MergePublishError::Publish)?;
69    let files: Vec<(&Path, &str)> = vec![(&export_path, filename)];
70
71    let publish = publisher.publish(&files, None).map_err(MergePublishError::Publish)?;
72
73    Ok(MergePublishResult { layers_merged, publish })
74}
75
76/// Errors from the merge-export-publish pipeline
77#[derive(Debug, thiserror::Error)]
78pub enum MergePublishError {
79    /// Merge/export phase failed
80    #[error("Merge/export failed: {0}")]
81    Merge(AdapterError),
82
83    /// Publish phase failed
84    #[error("Publish failed: {0}")]
85    Publish(PublishError),
86}