1use anyhow::Result;
10
11use crate::output::{self, error, format_bytes, print_cid, print_header, print_kv, success};
12use crate::progress;
13
/// Summary of the tensors described by a safetensors header.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SafetensorsInfo {
    /// Number of tensor entries found; equal to `tensors.len()` when produced
    /// by [`extract_safetensors_metadata`].
    pub num_tensors: usize,
    /// One `(name, shape, dtype)` tuple per tensor entry in the header.
    pub tensors: Vec<(String, Vec<usize>, String)>,
}
20
21pub fn extract_safetensors_metadata(data: &[u8]) -> Option<SafetensorsInfo> {
23 if data.len() < 8 {
25 return None;
26 }
27
28 let metadata_len = u64::from_le_bytes([
30 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
31 ]) as usize;
32
33 if data.len() < 8 + metadata_len {
34 return None;
35 }
36
37 let metadata_bytes = &data[8..8 + metadata_len];
39 let metadata_str = std::str::from_utf8(metadata_bytes).ok()?;
40 let metadata: serde_json::Value = serde_json::from_str(metadata_str).ok()?;
41
42 let mut tensors = Vec::new();
44 if let Some(obj) = metadata.as_object() {
45 for (name, info) in obj {
46 if name == "__metadata__" {
47 continue; }
49
50 if let Some(tensor_info) = info.as_object() {
51 let shape = tensor_info
52 .get("shape")
53 .and_then(|v| v.as_array())
54 .map(|arr| {
55 arr.iter()
56 .filter_map(|v| v.as_u64().map(|n| n as usize))
57 .collect::<Vec<_>>()
58 })
59 .unwrap_or_default();
60
61 let dtype = tensor_info
62 .get("dtype")
63 .and_then(|v| v.as_str())
64 .unwrap_or("unknown")
65 .to_string();
66
67 tensors.push((name.clone(), shape, dtype));
68 }
69 }
70 }
71
72 Some(SafetensorsInfo {
73 num_tensors: tensors.len(),
74 tensors,
75 })
76}
77
78pub async fn tensor_add(path: &str, format: &str) -> Result<()> {
80 use bytes::Bytes;
81 use ipfrs_core::Block;
82 use ipfrs_storage::{BlockStoreConfig, BlockStoreTrait, SledBlockStore};
83
84 let file_path = std::path::Path::new(path);
85 let filename = file_path
86 .file_name()
87 .map(|s| s.to_string_lossy().to_string())
88 .unwrap_or_else(|| path.to_string());
89
90 let metadata = tokio::fs::metadata(path).await?;
92 let file_size = metadata.len();
93
94 let pb = progress::spinner(&format!("Reading tensor file {}", filename));
95
96 let data = tokio::fs::read(path).await?;
98 let bytes_data = Bytes::from(data.clone());
99
100 progress::finish_spinner_success(
101 &pb,
102 &format!("Read {} ({})", filename, format_bytes(file_size)),
103 );
104
105 let tensor_info = if path.ends_with(".safetensors") {
107 extract_safetensors_metadata(&data)
108 } else {
109 None
110 };
111
112 let pb = progress::spinner("Creating tensor block");
114 let block = Block::new(bytes_data)?;
115 let cid = *block.cid();
116 progress::finish_spinner_success(&pb, "Tensor block created");
117
118 let config = BlockStoreConfig::default();
120 let store = SledBlockStore::new(config)?;
121
122 let pb = progress::spinner("Storing tensor");
124 store.put(&block).await?;
125 progress::finish_spinner_success(&pb, "Tensor stored");
126
127 match format {
128 "json" => {
129 println!("{{");
130 println!(" \"path\": \"{}\",", path);
131 println!(" \"cid\": \"{}\",", cid);
132 println!(" \"size\": {}", block.size());
133 if let Some(info) = tensor_info {
134 println!(" ,\"metadata\": {{");
135 println!(" \"format\": \"safetensors\",");
136 println!(" \"tensors\": {}", info.num_tensors);
137 println!(" }}");
138 }
139 println!("}}");
140 }
141 _ => {
142 success(&format!("Added tensor {}", filename));
143 print_cid("CID", &cid.to_string());
144 print_kv("Size", &format_bytes(block.size()));
145 if let Some(info) = tensor_info {
146 print_kv("Format", "safetensors");
147 print_kv("Tensors", &info.num_tensors.to_string());
148 }
149 }
150 }
151
152 Ok(())
153}
154
155pub async fn tensor_get(cid_str: &str, output: Option<&str>) -> Result<()> {
157 use ipfrs_core::Cid;
158 use ipfrs_storage::{BlockStoreConfig, BlockStoreTrait, SledBlockStore};
159 use tokio::fs;
160
161 let cid = cid_str
163 .parse::<Cid>()
164 .map_err(|e| anyhow::anyhow!("Invalid CID: {}", e))?;
165
166 let pb = progress::spinner(&format!("Retrieving tensor {}", cid));
167
168 let config = BlockStoreConfig::default();
170 let store = SledBlockStore::new(config)?;
171
172 match store.get(&cid).await? {
174 Some(block) => {
175 progress::finish_spinner_success(&pb, "Tensor retrieved");
176
177 let output_path = output.unwrap_or("tensor.safetensors");
178 fs::write(output_path, block.data()).await?;
179
180 success(&format!("Saved tensor to: {}", output_path));
181 print_kv("CID", &cid.to_string());
182 print_kv("Size", &format_bytes(block.size()));
183 Ok(())
184 }
185 None => {
186 progress::finish_spinner_error(&pb, "Tensor not found");
187 error(&format!("Tensor not found: {}", cid));
188 std::process::exit(1);
189 }
190 }
191}
192
193pub async fn tensor_info(cid_str: &str, format: &str) -> Result<()> {
195 use ipfrs_core::Cid;
196 use ipfrs_storage::{BlockStoreConfig, BlockStoreTrait, SledBlockStore};
197
198 let cid = cid_str
200 .parse::<Cid>()
201 .map_err(|e| anyhow::anyhow!("Invalid CID: {}", e))?;
202
203 let pb = progress::spinner(&format!("Retrieving tensor metadata {}", cid));
204
205 let config = BlockStoreConfig::default();
207 let store = SledBlockStore::new(config)?;
208
209 match store.get(&cid).await? {
211 Some(block) => {
212 progress::finish_spinner_success(&pb, "Tensor metadata retrieved");
213
214 let tensor_info = extract_safetensors_metadata(block.data());
215
216 match format {
217 "json" => {
218 println!("{{");
219 println!(" \"cid\": \"{}\",", cid);
220 println!(" \"size\": {},", block.size());
221 if let Some(info) = tensor_info {
222 println!(" \"format\": \"safetensors\",");
223 println!(" \"num_tensors\": {},", info.num_tensors);
224 println!(" \"tensors\": [");
225 for (i, (name, shape, dtype)) in info.tensors.iter().enumerate() {
226 print!(" {{");
227 print!("\"name\": \"{}\", ", name);
228 print!("\"shape\": {:?}, ", shape);
229 print!("\"dtype\": \"{}\"", dtype);
230 print!("}}");
231 if i < info.tensors.len() - 1 {
232 println!(",");
233 } else {
234 println!();
235 }
236 }
237 println!(" ]");
238 } else {
239 println!(" \"format\": \"unknown\"");
240 }
241 println!("}}");
242 }
243 _ => {
244 print_header(&format!("Tensor: {}", cid));
245 print_kv("Size", &format_bytes(block.size()));
246 if let Some(info) = tensor_info {
247 print_kv("Format", "safetensors");
248 print_kv("Number of tensors", &info.num_tensors.to_string());
249 println!("\nTensors:");
250 for (name, shape, dtype) in &info.tensors {
251 println!(" {} {:?} ({})", name, shape, dtype);
252 }
253 } else {
254 print_kv("Format", "unknown (raw binary)");
255 }
256 }
257 }
258 Ok(())
259 }
260 None => {
261 progress::finish_spinner_error(&pb, "Tensor not found");
262 error(&format!("Tensor not found: {}", cid));
263 std::process::exit(1);
264 }
265 }
266}
267
268pub async fn tensor_export(cid_str: &str, output_path: &str, target_format: &str) -> Result<()> {
270 use ipfrs_core::Cid;
271 use ipfrs_storage::{BlockStoreConfig, BlockStoreTrait, SledBlockStore};
272 use tokio::fs;
273
274 let cid = cid_str
276 .parse::<Cid>()
277 .map_err(|e| anyhow::anyhow!("Invalid CID: {}", e))?;
278
279 let pb = progress::spinner(&format!("Exporting tensor to {}", target_format));
280
281 let config = BlockStoreConfig::default();
283 let store = SledBlockStore::new(config)?;
284
285 match store.get(&cid).await? {
287 Some(block) => {
288 match target_format {
291 "safetensors" | "numpy" | "pytorch" => {
292 fs::write(output_path, block.data()).await?;
293 progress::finish_spinner_success(&pb, "Tensor exported");
294
295 success(&format!("Exported tensor to {}", output_path));
296 print_kv("Format", target_format);
297 print_kv("Size", &format_bytes(block.size()));
298 }
299 _ => {
300 progress::finish_spinner_error(&pb, "Unsupported format");
301 error(&format!("Unsupported format: {}", target_format));
302 output::info("Supported formats: safetensors, numpy, pytorch");
303 std::process::exit(1);
304 }
305 }
306 Ok(())
307 }
308 None => {
309 progress::finish_spinner_error(&pb, "Tensor not found");
310 error(&format!("Tensor not found: {}", cid));
311 std::process::exit(1);
312 }
313 }
314}