1use std::{
4 collections::HashMap,
5 fs::File,
6 path::Path,
7};
8
9use anyhow::Result;
10use safetensors::{Dtype, SafeTensors, View};
11use serde::{Deserialize, Serialize};
12
13use crate::{
14 cake::{Node, Topology},
15 utils, ModelType,
16};
17
18#[derive(Debug, Serialize, Deserialize)]
19struct Index {
20 pub weight_map: HashMap<String, String>,
21}
22
23impl Index {
24 pub fn new() -> Self {
25 let weight_map = HashMap::new();
26 Self { weight_map }
27 }
28}
29
30#[derive(Debug)]
31struct TensorStore {
32 dtype: Dtype,
33 shape: Vec<usize>,
34 data: Vec<u8>,
35}
36
37impl View for TensorStore {
38 fn dtype(&self) -> Dtype {
39 self.dtype
40 }
41
42 fn shape(&self) -> &[usize] {
43 &self.shape
44 }
45
46 fn data(&self) -> std::borrow::Cow<'_, [u8]> {
47 std::borrow::Cow::from(&self.data)
48 }
49
50 fn data_len(&self) -> usize {
51 self.data.len()
52 }
53}
54
55fn load_index(data_path: &Path) -> Result<Index> {
56 let tensors_index_path = data_path.join("model.safetensors.index.json");
57
58 if tensors_index_path.exists() {
59 let tensors_index_data = std::fs::read_to_string(tensors_index_path)?;
60 let tensors_index: Index = serde_json::from_str(&tensors_index_data)?;
61 Ok(tensors_index)
62 } else {
63 let single_path = data_path.join("model.safetensors");
64 if !single_path.exists() {
65 anyhow::bail!(
66 "neither model.safetensors.index.json nor model.safetensors found in {}",
67 data_path.display()
68 );
69 }
70
71 log::info!("no index file found, generating from model.safetensors ...");
72
73 let file = File::open(&single_path)?;
74 let buffer = unsafe { memmap2::MmapOptions::new().map(&file)? };
75 let tensors = SafeTensors::deserialize(&buffer)?;
76
77 let mut index = Index::new();
78 for (name, _) in tensors.tensors() {
79 index
80 .weight_map
81 .insert(name.to_string(), "model.safetensors".to_string());
82 }
83
84 Ok(index)
85 }
86}
87
88fn reduce_for_worker(
89 index: &Index,
90 worker: &Node,
91) -> Result<(Index, HashMap<String, Vec<String>>)> {
92 log::info!("worker: {}", &worker.host);
93
94 let mut reduced: HashMap<String, Vec<String>> = HashMap::new();
95 let mut new_index = Index::new();
96
97 for (layer_full_name, filename) in &index.weight_map {
98 if worker.is_text_model_layer_owner(layer_full_name) {
99 if let Some(layers) = reduced.get_mut(filename) {
100 layers.push(layer_full_name.to_string());
101 } else {
102 reduced.insert(filename.to_string(), vec![layer_full_name.to_string()]);
103 }
104
105 new_index.weight_map.insert(
106 layer_full_name.to_string(),
107 "reduced.safetensors".to_string(),
108 );
109 }
110 }
111
112 Ok((new_index, reduced))
113}
114
115fn create_new_metadata(
116 data_path: &Path,
117 reduced: &HashMap<String, Vec<String>>,
118) -> Result<HashMap<String, TensorStore>> {
119 let mut metadata: HashMap<String, TensorStore> = HashMap::new();
120
121 for (filename, tensor_names) in reduced {
122 let filepath = data_path.join(filename);
123
124 log::info!("loading {} ...", filepath.display());
125
126 let file = File::open(&filepath)?;
127 let buffer = unsafe { memmap2::MmapOptions::new().map(&file)? };
128 let tensors = SafeTensors::deserialize(&buffer)?;
129
130 log::info!(" extracting {} tensors", tensor_names.len());
131
132 for tensor_name in tensor_names {
133 let tensor = tensors.tensor(tensor_name)?;
134 metadata.insert(
135 tensor_name.to_string(),
136 TensorStore {
137 dtype: tensor.dtype(),
138 shape: tensor.shape().to_vec(),
139 data: tensor.data().to_vec(),
140 },
141 );
142 }
143
144 drop(tensors);
145 drop(buffer);
146 }
147
148 Ok(metadata)
149}
150
151pub fn split_model(
156 model_path: &Path,
157 topology_path: &str,
158 worker: Option<&str>,
159 output: &Path,
160) -> Result<()> {
161 let topology = Topology::from_path(topology_path, &ModelType::TextModel)?;
162 let index = load_index(model_path)?;
163
164 log::info!("index has {} tensors", index.weight_map.len());
165
166 let selected_workers: Vec<String> = if let Some(name) = worker {
167 vec![name.to_string()]
168 } else {
169 topology.keys().map(|s| s.to_string()).collect()
170 };
171
172 log::info!("processing {} workers", selected_workers.len());
173
174 for worker_name in &selected_workers {
175 log::info!("processing worker {worker_name} ...");
176
177 let worker_node = topology
178 .get(worker_name)
179 .ok_or_else(|| anyhow!("can't find worker '{}' in topology", worker_name))?;
180
181 let (new_index, reduced) = reduce_for_worker(&index, worker_node)?;
182
183 log::info!("compacting {} tensors ...", new_index.weight_map.len());
184
185 let metadata = create_new_metadata(model_path, &reduced)?;
186
187 let bundle_name = format!("{worker_name}-node");
188 let output_path = output.join(&bundle_name);
189 let model_output_path = output_path.join("model");
190 if !output_path.exists() {
191 log::info!("creating {}", model_output_path.display());
192 std::fs::create_dir_all(&model_output_path)?;
193 } else {
194 log::info!("saving model to {}", model_output_path.display());
195 }
196
197 let new_index_path = model_output_path.join("model.safetensors.index.json");
198
199 log::info!("saving new index to {} ...", new_index_path.display());
200
201 let new_index_data = serde_json::to_string_pretty(&new_index)?;
202 std::fs::write(&new_index_path, new_index_data)?;
203
204 let new_tensors_path = model_output_path.join("reduced.safetensors");
205
206 log::info!(
207 "saving reduced tensors to {} ...",
208 new_tensors_path.display()
209 );
210
211 safetensors::serialize_to_file(metadata, None, &new_tensors_path)?;
212
213 let loaded = utils::load_safetensors_paths_from_index(new_index_path)?;
215 assert_eq!(loaded.len(), 1);
216 let file = File::open(&loaded[0])?;
217 let buffer = unsafe { memmap2::MmapOptions::new().map(&file)? };
218 let _ = SafeTensors::deserialize(&buffer)?;
219
220 let new_topology_path = output_path.join("topology.yml");
221
222 log::info!(
223 "saving worker topology to {} ...",
224 new_topology_path.display()
225 );
226
227 let mut new_topology: HashMap<String, &Node> = HashMap::new();
228 new_topology.insert(worker_name.to_string(), worker_node);
229 let new_topology_data = serde_yaml::to_string(&new_topology)?;
230 std::fs::write(&new_topology_path, new_topology_data)?;
231 }
232
233 Ok(())
234}