// cake_core/utils/split.rs
1//! Model splitting utility — creates per-worker model bundles from a full model.
2
use std::{collections::HashMap, fs::File, path::Path};

use anyhow::{anyhow, ensure, Result};
use safetensors::{Dtype, SafeTensors, View};
use serde::{Deserialize, Serialize};

use crate::{
    cake::{Node, Topology},
    utils, ModelType,
};
17
18#[derive(Debug, Serialize, Deserialize)]
19struct Index {
20    pub weight_map: HashMap<String, String>,
21}
22
23impl Index {
24    pub fn new() -> Self {
25        let weight_map = HashMap::new();
26        Self { weight_map }
27    }
28}
29
30#[derive(Debug)]
31struct TensorStore {
32    dtype: Dtype,
33    shape: Vec<usize>,
34    data: Vec<u8>,
35}
36
37impl View for TensorStore {
38    fn dtype(&self) -> Dtype {
39        self.dtype
40    }
41
42    fn shape(&self) -> &[usize] {
43        &self.shape
44    }
45
46    fn data(&self) -> std::borrow::Cow<'_, [u8]> {
47        std::borrow::Cow::from(&self.data)
48    }
49
50    fn data_len(&self) -> usize {
51        self.data.len()
52    }
53}
54
55fn load_index(data_path: &Path) -> Result<Index> {
56    let tensors_index_path = data_path.join("model.safetensors.index.json");
57
58    if tensors_index_path.exists() {
59        let tensors_index_data = std::fs::read_to_string(tensors_index_path)?;
60        let tensors_index: Index = serde_json::from_str(&tensors_index_data)?;
61        Ok(tensors_index)
62    } else {
63        let single_path = data_path.join("model.safetensors");
64        if !single_path.exists() {
65            anyhow::bail!(
66                "neither model.safetensors.index.json nor model.safetensors found in {}",
67                data_path.display()
68            );
69        }
70
71        log::info!("no index file found, generating from model.safetensors ...");
72
73        let file = File::open(&single_path)?;
74        let buffer = unsafe { memmap2::MmapOptions::new().map(&file)? };
75        let tensors = SafeTensors::deserialize(&buffer)?;
76
77        let mut index = Index::new();
78        for (name, _) in tensors.tensors() {
79            index
80                .weight_map
81                .insert(name.to_string(), "model.safetensors".to_string());
82        }
83
84        Ok(index)
85    }
86}
87
88fn reduce_for_worker(
89    index: &Index,
90    worker: &Node,
91) -> Result<(Index, HashMap<String, Vec<String>>)> {
92    log::info!("worker: {}", &worker.host);
93
94    let mut reduced: HashMap<String, Vec<String>> = HashMap::new();
95    let mut new_index = Index::new();
96
97    for (layer_full_name, filename) in &index.weight_map {
98        if worker.is_text_model_layer_owner(layer_full_name) {
99            if let Some(layers) = reduced.get_mut(filename) {
100                layers.push(layer_full_name.to_string());
101            } else {
102                reduced.insert(filename.to_string(), vec![layer_full_name.to_string()]);
103            }
104
105            new_index.weight_map.insert(
106                layer_full_name.to_string(),
107                "reduced.safetensors".to_string(),
108            );
109        }
110    }
111
112    Ok((new_index, reduced))
113}
114
115fn create_new_metadata(
116    data_path: &Path,
117    reduced: &HashMap<String, Vec<String>>,
118) -> Result<HashMap<String, TensorStore>> {
119    let mut metadata: HashMap<String, TensorStore> = HashMap::new();
120
121    for (filename, tensor_names) in reduced {
122        let filepath = data_path.join(filename);
123
124        log::info!("loading {} ...", filepath.display());
125
126        let file = File::open(&filepath)?;
127        let buffer = unsafe { memmap2::MmapOptions::new().map(&file)? };
128        let tensors = SafeTensors::deserialize(&buffer)?;
129
130        log::info!("  extracting {} tensors", tensor_names.len());
131
132        for tensor_name in tensor_names {
133            let tensor = tensors.tensor(tensor_name)?;
134            metadata.insert(
135                tensor_name.to_string(),
136                TensorStore {
137                    dtype: tensor.dtype(),
138                    shape: tensor.shape().to_vec(),
139                    data: tensor.data().to_vec(),
140                },
141            );
142        }
143
144        drop(tensors);
145        drop(buffer);
146    }
147
148    Ok(metadata)
149}
150
151/// Split a model into per-worker bundles.
152///
153/// Each bundle contains a reduced safetensors file with only the worker's assigned tensors,
154/// a matching index file, and the worker's topology.
155pub fn split_model(
156    model_path: &Path,
157    topology_path: &str,
158    worker: Option<&str>,
159    output: &Path,
160) -> Result<()> {
161    let topology = Topology::from_path(topology_path, &ModelType::TextModel)?;
162    let index = load_index(model_path)?;
163
164    log::info!("index has {} tensors", index.weight_map.len());
165
166    let selected_workers: Vec<String> = if let Some(name) = worker {
167        vec![name.to_string()]
168    } else {
169        topology.keys().map(|s| s.to_string()).collect()
170    };
171
172    log::info!("processing {} workers", selected_workers.len());
173
174    for worker_name in &selected_workers {
175        log::info!("processing worker {worker_name} ...");
176
177        let worker_node = topology
178            .get(worker_name)
179            .ok_or_else(|| anyhow!("can't find worker '{}' in topology", worker_name))?;
180
181        let (new_index, reduced) = reduce_for_worker(&index, worker_node)?;
182
183        log::info!("compacting {} tensors ...", new_index.weight_map.len());
184
185        let metadata = create_new_metadata(model_path, &reduced)?;
186
187        let bundle_name = format!("{worker_name}-node");
188        let output_path = output.join(&bundle_name);
189        let model_output_path = output_path.join("model");
190        if !output_path.exists() {
191            log::info!("creating {}", model_output_path.display());
192            std::fs::create_dir_all(&model_output_path)?;
193        } else {
194            log::info!("saving model to {}", model_output_path.display());
195        }
196
197        let new_index_path = model_output_path.join("model.safetensors.index.json");
198
199        log::info!("saving new index to {} ...", new_index_path.display());
200
201        let new_index_data = serde_json::to_string_pretty(&new_index)?;
202        std::fs::write(&new_index_path, new_index_data)?;
203
204        let new_tensors_path = model_output_path.join("reduced.safetensors");
205
206        log::info!(
207            "saving reduced tensors to {} ...",
208            new_tensors_path.display()
209        );
210
211        safetensors::serialize_to_file(metadata, None, &new_tensors_path)?;
212
213        // Verify the output is readable.
214        let loaded = utils::load_safetensors_paths_from_index(new_index_path)?;
215        assert_eq!(loaded.len(), 1);
216        let file = File::open(&loaded[0])?;
217        let buffer = unsafe { memmap2::MmapOptions::new().map(&file)? };
218        let _ = SafeTensors::deserialize(&buffer)?;
219
220        let new_topology_path = output_path.join("topology.yml");
221
222        log::info!(
223            "saving worker topology to {} ...",
224            new_topology_path.display()
225        );
226
227        let mut new_topology: HashMap<String, &Node> = HashMap::new();
228        new_topology.insert(worker_name.to_string(), worker_node);
229        let new_topology_data = serde_yaml::to_string(&new_topology)?;
230        std::fs::write(&new_topology_path, new_topology_data)?;
231    }
232
233    Ok(())
234}