ipfrs_cli/commands/
tensor.rs

1//! Tensor operation commands
2//!
3//! This module provides tensor-related operations:
4//! - `tensor_add` - Add tensor file
5//! - `tensor_get` - Get tensor by CID
6//! - `tensor_info` - Show tensor metadata
7//! - `tensor_export` - Export tensor to different format
8
9use anyhow::Result;
10
11use crate::output::{self, error, format_bytes, print_cid, print_header, print_kv, success};
12use crate::progress;
13
/// Summary of the tensors described by a safetensors JSON header.
#[derive(Debug)]
pub struct SafetensorsInfo {
    /// Number of entries in `tensors`.
    pub num_tensors: usize,
    /// One `(name, shape, dtype)` triple per tensor found in the header.
    pub tensors: Vec<(String, Vec<usize>, String)>,
}
20
21/// Extract metadata from safetensors file
22pub fn extract_safetensors_metadata(data: &[u8]) -> Option<SafetensorsInfo> {
23    // Safetensors format starts with an 8-byte header containing the JSON metadata length
24    if data.len() < 8 {
25        return None;
26    }
27
28    // Read the first 8 bytes as u64 (little-endian)
29    let metadata_len = u64::from_le_bytes([
30        data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
31    ]) as usize;
32
33    if data.len() < 8 + metadata_len {
34        return None;
35    }
36
37    // Extract and parse JSON metadata
38    let metadata_bytes = &data[8..8 + metadata_len];
39    let metadata_str = std::str::from_utf8(metadata_bytes).ok()?;
40    let metadata: serde_json::Value = serde_json::from_str(metadata_str).ok()?;
41
42    // Extract tensor information
43    let mut tensors = Vec::new();
44    if let Some(obj) = metadata.as_object() {
45        for (name, info) in obj {
46            if name == "__metadata__" {
47                continue; // Skip metadata field
48            }
49
50            if let Some(tensor_info) = info.as_object() {
51                let shape = tensor_info
52                    .get("shape")
53                    .and_then(|v| v.as_array())
54                    .map(|arr| {
55                        arr.iter()
56                            .filter_map(|v| v.as_u64().map(|n| n as usize))
57                            .collect::<Vec<_>>()
58                    })
59                    .unwrap_or_default();
60
61                let dtype = tensor_info
62                    .get("dtype")
63                    .and_then(|v| v.as_str())
64                    .unwrap_or("unknown")
65                    .to_string();
66
67                tensors.push((name.clone(), shape, dtype));
68            }
69        }
70    }
71
72    Some(SafetensorsInfo {
73        num_tensors: tensors.len(),
74        tensors,
75    })
76}
77
78/// Add tensor file
79pub async fn tensor_add(path: &str, format: &str) -> Result<()> {
80    use bytes::Bytes;
81    use ipfrs_core::Block;
82    use ipfrs_storage::{BlockStoreConfig, BlockStoreTrait, SledBlockStore};
83
84    let file_path = std::path::Path::new(path);
85    let filename = file_path
86        .file_name()
87        .map(|s| s.to_string_lossy().to_string())
88        .unwrap_or_else(|| path.to_string());
89
90    // Get file size for progress
91    let metadata = tokio::fs::metadata(path).await?;
92    let file_size = metadata.len();
93
94    let pb = progress::spinner(&format!("Reading tensor file {}", filename));
95
96    // Read tensor file
97    let data = tokio::fs::read(path).await?;
98    let bytes_data = Bytes::from(data.clone());
99
100    progress::finish_spinner_success(
101        &pb,
102        &format!("Read {} ({})", filename, format_bytes(file_size)),
103    );
104
105    // Try to extract tensor metadata if it's a safetensors file
106    let tensor_info = if path.ends_with(".safetensors") {
107        extract_safetensors_metadata(&data)
108    } else {
109        None
110    };
111
112    // Create block
113    let pb = progress::spinner("Creating tensor block");
114    let block = Block::new(bytes_data)?;
115    let cid = *block.cid();
116    progress::finish_spinner_success(&pb, "Tensor block created");
117
118    // Initialize storage
119    let config = BlockStoreConfig::default();
120    let store = SledBlockStore::new(config)?;
121
122    // Store block
123    let pb = progress::spinner("Storing tensor");
124    store.put(&block).await?;
125    progress::finish_spinner_success(&pb, "Tensor stored");
126
127    match format {
128        "json" => {
129            println!("{{");
130            println!("  \"path\": \"{}\",", path);
131            println!("  \"cid\": \"{}\",", cid);
132            println!("  \"size\": {}", block.size());
133            if let Some(info) = tensor_info {
134                println!("  ,\"metadata\": {{");
135                println!("    \"format\": \"safetensors\",");
136                println!("    \"tensors\": {}", info.num_tensors);
137                println!("  }}");
138            }
139            println!("}}");
140        }
141        _ => {
142            success(&format!("Added tensor {}", filename));
143            print_cid("CID", &cid.to_string());
144            print_kv("Size", &format_bytes(block.size()));
145            if let Some(info) = tensor_info {
146                print_kv("Format", "safetensors");
147                print_kv("Tensors", &info.num_tensors.to_string());
148            }
149        }
150    }
151
152    Ok(())
153}
154
155/// Get tensor by CID
156pub async fn tensor_get(cid_str: &str, output: Option<&str>) -> Result<()> {
157    use ipfrs_core::Cid;
158    use ipfrs_storage::{BlockStoreConfig, BlockStoreTrait, SledBlockStore};
159    use tokio::fs;
160
161    // Parse CID
162    let cid = cid_str
163        .parse::<Cid>()
164        .map_err(|e| anyhow::anyhow!("Invalid CID: {}", e))?;
165
166    let pb = progress::spinner(&format!("Retrieving tensor {}", cid));
167
168    // Initialize storage
169    let config = BlockStoreConfig::default();
170    let store = SledBlockStore::new(config)?;
171
172    // Retrieve block
173    match store.get(&cid).await? {
174        Some(block) => {
175            progress::finish_spinner_success(&pb, "Tensor retrieved");
176
177            let output_path = output.unwrap_or("tensor.safetensors");
178            fs::write(output_path, block.data()).await?;
179
180            success(&format!("Saved tensor to: {}", output_path));
181            print_kv("CID", &cid.to_string());
182            print_kv("Size", &format_bytes(block.size()));
183            Ok(())
184        }
185        None => {
186            progress::finish_spinner_error(&pb, "Tensor not found");
187            error(&format!("Tensor not found: {}", cid));
188            std::process::exit(1);
189        }
190    }
191}
192
193/// Show tensor metadata
194pub async fn tensor_info(cid_str: &str, format: &str) -> Result<()> {
195    use ipfrs_core::Cid;
196    use ipfrs_storage::{BlockStoreConfig, BlockStoreTrait, SledBlockStore};
197
198    // Parse CID
199    let cid = cid_str
200        .parse::<Cid>()
201        .map_err(|e| anyhow::anyhow!("Invalid CID: {}", e))?;
202
203    let pb = progress::spinner(&format!("Retrieving tensor metadata {}", cid));
204
205    // Initialize storage
206    let config = BlockStoreConfig::default();
207    let store = SledBlockStore::new(config)?;
208
209    // Retrieve block
210    match store.get(&cid).await? {
211        Some(block) => {
212            progress::finish_spinner_success(&pb, "Tensor metadata retrieved");
213
214            let tensor_info = extract_safetensors_metadata(block.data());
215
216            match format {
217                "json" => {
218                    println!("{{");
219                    println!("  \"cid\": \"{}\",", cid);
220                    println!("  \"size\": {},", block.size());
221                    if let Some(info) = tensor_info {
222                        println!("  \"format\": \"safetensors\",");
223                        println!("  \"num_tensors\": {},", info.num_tensors);
224                        println!("  \"tensors\": [");
225                        for (i, (name, shape, dtype)) in info.tensors.iter().enumerate() {
226                            print!("    {{");
227                            print!("\"name\": \"{}\", ", name);
228                            print!("\"shape\": {:?}, ", shape);
229                            print!("\"dtype\": \"{}\"", dtype);
230                            print!("}}");
231                            if i < info.tensors.len() - 1 {
232                                println!(",");
233                            } else {
234                                println!();
235                            }
236                        }
237                        println!("  ]");
238                    } else {
239                        println!("  \"format\": \"unknown\"");
240                    }
241                    println!("}}");
242                }
243                _ => {
244                    print_header(&format!("Tensor: {}", cid));
245                    print_kv("Size", &format_bytes(block.size()));
246                    if let Some(info) = tensor_info {
247                        print_kv("Format", "safetensors");
248                        print_kv("Number of tensors", &info.num_tensors.to_string());
249                        println!("\nTensors:");
250                        for (name, shape, dtype) in &info.tensors {
251                            println!("  {} {:?} ({})", name, shape, dtype);
252                        }
253                    } else {
254                        print_kv("Format", "unknown (raw binary)");
255                    }
256                }
257            }
258            Ok(())
259        }
260        None => {
261            progress::finish_spinner_error(&pb, "Tensor not found");
262            error(&format!("Tensor not found: {}", cid));
263            std::process::exit(1);
264        }
265    }
266}
267
268/// Export tensor to different format
269pub async fn tensor_export(cid_str: &str, output_path: &str, target_format: &str) -> Result<()> {
270    use ipfrs_core::Cid;
271    use ipfrs_storage::{BlockStoreConfig, BlockStoreTrait, SledBlockStore};
272    use tokio::fs;
273
274    // Parse CID
275    let cid = cid_str
276        .parse::<Cid>()
277        .map_err(|e| anyhow::anyhow!("Invalid CID: {}", e))?;
278
279    let pb = progress::spinner(&format!("Exporting tensor to {}", target_format));
280
281    // Initialize storage
282    let config = BlockStoreConfig::default();
283    let store = SledBlockStore::new(config)?;
284
285    // Retrieve block
286    match store.get(&cid).await? {
287        Some(block) => {
288            // For now, we just copy the data as-is
289            // In a real implementation, we would convert between formats
290            match target_format {
291                "safetensors" | "numpy" | "pytorch" => {
292                    fs::write(output_path, block.data()).await?;
293                    progress::finish_spinner_success(&pb, "Tensor exported");
294
295                    success(&format!("Exported tensor to {}", output_path));
296                    print_kv("Format", target_format);
297                    print_kv("Size", &format_bytes(block.size()));
298                }
299                _ => {
300                    progress::finish_spinner_error(&pb, "Unsupported format");
301                    error(&format!("Unsupported format: {}", target_format));
302                    output::info("Supported formats: safetensors, numpy, pytorch");
303                    std::process::exit(1);
304                }
305            }
306            Ok(())
307        }
308        None => {
309            progress::finish_spinner_error(&pb, "Tensor not found");
310            error(&format!("Tensor not found: {}", cid));
311            std::process::exit(1);
312        }
313    }
314}