entrenar/lora/adapter/
merge_pipeline.rs1use super::error::AdapterError;
7use super::merge_export::{merge_and_collect, merge_qlora_and_collect};
8use crate::hf_pipeline::publish::config::PublishConfig;
9use crate::hf_pipeline::publish::publisher::HfPublisher;
10use crate::hf_pipeline::publish::result::{PublishError, PublishResult};
11use crate::lora::{LoRALayer, QLoRALayer};
12use std::path::Path;
13
14#[derive(Debug, Clone)]
16pub struct MergePublishResult {
17 pub layers_merged: usize,
19 pub publish: PublishResult,
21}
22
23pub fn merge_export_publish(
25 layers: &[(&str, &LoRALayer)],
26 publish_config: PublishConfig,
27 output_dir: impl AsRef<Path>,
28) -> Result<MergePublishResult, MergePublishError> {
29 let output_dir = output_dir.as_ref();
30 let filename = "model.safetensors";
31
32 let merged = merge_and_collect(layers);
34 let layers_merged = merged.layers_merged;
35
36 let export_path = output_dir.join(filename);
38 std::fs::create_dir_all(output_dir)
39 .map_err(|e| MergePublishError::Merge(AdapterError::Io(e)))?;
40 merged.save_safetensors(&export_path).map_err(MergePublishError::Merge)?;
41
42 let publisher = HfPublisher::new(publish_config).map_err(MergePublishError::Publish)?;
44 let files: Vec<(&Path, &str)> = vec![(&export_path, filename)];
45
46 let publish = publisher.publish(&files, None).map_err(MergePublishError::Publish)?;
47
48 Ok(MergePublishResult { layers_merged, publish })
49}
50
51pub fn merge_qlora_export_publish(
53 layers: &[(&str, &QLoRALayer)],
54 publish_config: PublishConfig,
55 output_dir: impl AsRef<Path>,
56) -> Result<MergePublishResult, MergePublishError> {
57 let output_dir = output_dir.as_ref();
58 let filename = "model.safetensors";
59
60 let merged = merge_qlora_and_collect(layers);
61 let layers_merged = merged.layers_merged;
62
63 let export_path = output_dir.join(filename);
64 std::fs::create_dir_all(output_dir)
65 .map_err(|e| MergePublishError::Merge(AdapterError::Io(e)))?;
66 merged.save_safetensors(&export_path).map_err(MergePublishError::Merge)?;
67
68 let publisher = HfPublisher::new(publish_config).map_err(MergePublishError::Publish)?;
69 let files: Vec<(&Path, &str)> = vec![(&export_path, filename)];
70
71 let publish = publisher.publish(&files, None).map_err(MergePublishError::Publish)?;
72
73 Ok(MergePublishResult { layers_merged, publish })
74}
75
76#[derive(Debug, thiserror::Error)]
78pub enum MergePublishError {
79 #[error("Merge/export failed: {0}")]
81 Merge(AdapterError),
82
83 #[error("Publish failed: {0}")]
85 Publish(PublishError),
86}