1use crate::traits::BlockStore;
37use bytes::Bytes;
38use ipfrs_core::{Block, Cid, Error, Result};
39use serde::{Deserialize, Serialize};
40use std::collections::HashMap;
41use std::str::FromStr;
42use std::sync::Arc;
43
/// Element data type of a tensor, as named by the `dtype` field of a safetensors header
/// (parsed by the `FromStr` impl below).
///
/// NOTE(review): this enum is serialized into manifests with `oxicode`, which presumably
/// encodes enum variants by index like other bincode-style encoders. Append new variants at
/// the end rather than reordering existing ones, or previously stored manifests may decode
/// to the wrong dtype.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum DType {
    /// Header tag `"F32"`; 4 bytes per element.
    F32,
    /// Header tag `"F64"`; 8 bytes per element.
    F64,
    /// Header tag `"F16"`; 2 bytes per element.
    F16,
    /// Header tag `"BF16"`; 2 bytes per element.
    BF16,
    /// Header tag `"I8"`; 1 byte per element.
    I8,
    /// Header tag `"I16"`; 2 bytes per element.
    I16,
    /// Header tag `"I32"`; 4 bytes per element.
    I32,
    /// Header tag `"I64"`; 8 bytes per element.
    I64,
    /// Header tag `"U8"`; 1 byte per element.
    U8,
    /// Header tag `"U16"`; 2 bytes per element.
    U16,
    /// Header tag `"U32"`; 4 bytes per element.
    U32,
    /// Header tag `"U64"`; 8 bytes per element.
    U64,
    /// Header tag `"BOOL"`; 1 byte per element.
    Bool,
}
61
62impl DType {
63 pub fn size(&self) -> usize {
65 match self {
66 DType::F32 | DType::I32 | DType::U32 => 4,
67 DType::F64 | DType::I64 | DType::U64 => 8,
68 DType::F16 | DType::BF16 | DType::I16 | DType::U16 => 2,
69 DType::I8 | DType::U8 | DType::Bool => 1,
70 }
71 }
72}
73
74impl FromStr for DType {
75 type Err = String;
76
77 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
78 match s {
79 "F32" => Ok(DType::F32),
80 "F64" => Ok(DType::F64),
81 "F16" => Ok(DType::F16),
82 "BF16" => Ok(DType::BF16),
83 "I8" => Ok(DType::I8),
84 "I16" => Ok(DType::I16),
85 "I32" => Ok(DType::I32),
86 "I64" => Ok(DType::I64),
87 "U8" => Ok(DType::U8),
88 "U16" => Ok(DType::U16),
89 "U32" => Ok(DType::U32),
90 "U64" => Ok(DType::U64),
91 "BOOL" => Ok(DType::Bool),
92 _ => Err(format!("Unknown dtype: {s}")),
93 }
94 }
95}
96
/// Layout of a single tensor as described by its safetensors header entry.
///
/// NOTE(review): this struct is serialized into manifests with `oxicode`; bincode-style
/// encoders are presumably positional, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TensorInfo {
    /// Element type.
    pub dtype: DType,
    /// Dimensions of the tensor; an empty shape gives `numel() == 1`.
    pub shape: Vec<usize>,
    /// `(start, end)` byte range of the tensor's data, relative to the start of the data
    /// section that follows the header (this is how `import_from_bytes` slices it).
    pub data_offsets: (usize, usize),
}
107
108impl TensorInfo {
109 pub fn numel(&self) -> usize {
111 self.shape.iter().product()
112 }
113
114 pub fn size_bytes(&self) -> usize {
116 self.numel() * self.dtype.size()
117 }
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct SafetensorsHeader {
123 pub tensors: HashMap<String, TensorInfo>,
125 pub metadata: HashMap<String, String>,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct ChunkedTensor {
132 pub name: String,
134 pub info: TensorInfo,
136 #[serde(
138 serialize_with = "serialize_cid_vec",
139 deserialize_with = "deserialize_cid_vec"
140 )]
141 pub chunk_cids: Vec<Cid>,
142 pub chunk_size: usize,
144}
145
146fn serialize_cid_vec<S>(cids: &[Cid], serializer: S) -> std::result::Result<S::Ok, S::Error>
148where
149 S: serde::Serializer,
150{
151 use serde::ser::SerializeSeq;
152 let mut seq = serializer.serialize_seq(Some(cids.len()))?;
153 for cid in cids {
154 seq.serialize_element(&cid.to_bytes())?;
155 }
156 seq.end()
157}
158
159fn deserialize_cid_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<Cid>, D::Error>
160where
161 D: serde::Deserializer<'de>,
162{
163 let bytes_vec: Vec<Vec<u8>> = Deserialize::deserialize(deserializer)?;
164 bytes_vec
165 .into_iter()
166 .map(|bytes| Cid::try_from(bytes).map_err(serde::de::Error::custom))
167 .collect()
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct SafetensorsManifest {
173 pub name: String,
175 pub header: SafetensorsHeader,
177 pub tensors: HashMap<String, ChunkedTensor>,
179 pub total_size: u64,
181}
182
/// Parameters that control how [`SafetensorsStore`] splits tensor data into blocks.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Upper bound, in bytes, on the size of each stored chunk.
    pub chunk_size: usize,
    /// Compression toggle. NOTE(review): `SafetensorsStore::import_from_bytes` does not read
    /// this flag as written.
    pub compress: bool,
}

impl Default for ChunkConfig {
    /// 64 MiB chunks with compression disabled.
    fn default() -> Self {
        const SIXTY_FOUR_MIB: usize = 64 << 20;
        let compress = false;
        Self {
            chunk_size: SIXTY_FOUR_MIB,
            compress,
        }
    }
}
200
201pub struct SafetensorsStore<S: BlockStore> {
203 store: Arc<S>,
205 chunk_config: ChunkConfig,
207}
208
209impl<S: BlockStore> SafetensorsStore<S> {
210 pub fn new(store: Arc<S>) -> Self {
212 Self {
213 store,
214 chunk_config: ChunkConfig::default(),
215 }
216 }
217
218 pub fn with_config(store: Arc<S>, chunk_config: ChunkConfig) -> Self {
220 Self {
221 store,
222 chunk_config,
223 }
224 }
225
226 pub fn parse_header(data: &[u8]) -> Result<(SafetensorsHeader, usize)> {
228 if data.len() < 8 {
229 return Err(Error::Storage(
230 "File too small to be safetensors".to_string(),
231 ));
232 }
233
234 let header_size = u64::from_le_bytes([
236 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
237 ]) as usize;
238
239 if data.len() < 8 + header_size {
240 return Err(Error::Storage("Incomplete safetensors header".to_string()));
241 }
242
243 let header_bytes = &data[8..8 + header_size];
245 let header_json: serde_json::Value = serde_json::from_slice(header_bytes)
246 .map_err(|e| Error::Serialization(format!("Failed to parse header JSON: {e}")))?;
247
248 let mut tensors = HashMap::new();
249 let mut metadata = HashMap::new();
250
251 if let Some(obj) = header_json.as_object() {
253 for (key, value) in obj {
254 if key == "__metadata__" {
255 if let Some(meta_obj) = value.as_object() {
257 for (k, v) in meta_obj {
258 if let Some(s) = v.as_str() {
259 metadata.insert(k.clone(), s.to_string());
260 }
261 }
262 }
263 } else {
264 if let Some(tensor_obj) = value.as_object() {
266 let dtype_str = tensor_obj
267 .get("dtype")
268 .and_then(|v| v.as_str())
269 .ok_or_else(|| Error::Storage("Missing dtype".to_string()))?;
270
271 let dtype = dtype_str.parse::<DType>().map_err(Error::Storage)?;
272
273 let shape: Vec<usize> = tensor_obj
274 .get("shape")
275 .and_then(|v| v.as_array())
276 .ok_or_else(|| Error::Storage("Missing shape".to_string()))?
277 .iter()
278 .filter_map(|v| v.as_u64().map(|n| n as usize))
279 .collect();
280
281 let data_offsets = tensor_obj
282 .get("data_offsets")
283 .and_then(|v| v.as_array())
284 .ok_or_else(|| Error::Storage("Missing data_offsets".to_string()))?;
285
286 let start = data_offsets[0].as_u64().ok_or_else(|| {
287 Error::Storage("Invalid data_offsets start".to_string())
288 })? as usize;
289 let end = data_offsets[1]
290 .as_u64()
291 .ok_or_else(|| Error::Storage("Invalid data_offsets end".to_string()))?
292 as usize;
293
294 tensors.insert(
295 key.clone(),
296 TensorInfo {
297 dtype,
298 shape,
299 data_offsets: (start, end),
300 },
301 );
302 }
303 }
304 }
305 }
306
307 Ok((SafetensorsHeader { tensors, metadata }, 8 + header_size))
308 }
309
310 pub async fn import_from_bytes(&self, name: String, data: &[u8]) -> Result<Cid> {
312 let (header, data_offset) = Self::parse_header(data)?;
314
315 let data_section = &data[data_offset..];
316 let mut chunked_tensors = HashMap::new();
317 let mut total_size = 0u64;
318
319 for (tensor_name, tensor_info) in &header.tensors {
321 let (start, end) = tensor_info.data_offsets;
322 let tensor_data = &data_section[start..end];
323
324 let mut chunk_cids = Vec::new();
326 for chunk in tensor_data.chunks(self.chunk_config.chunk_size) {
327 let block = Block::new(Bytes::from(chunk.to_vec()))?;
328 let cid = *block.cid();
329 self.store.put(&block).await?;
330 chunk_cids.push(cid);
331 }
332
333 chunked_tensors.insert(
334 tensor_name.clone(),
335 ChunkedTensor {
336 name: tensor_name.clone(),
337 info: tensor_info.clone(),
338 chunk_cids,
339 chunk_size: self.chunk_config.chunk_size,
340 },
341 );
342
343 total_size += tensor_data.len() as u64;
344 }
345
346 let manifest = SafetensorsManifest {
348 name,
349 header,
350 tensors: chunked_tensors,
351 total_size,
352 };
353
354 let manifest_bytes = oxicode::serde::encode_to_vec(&manifest, oxicode::config::standard())
356 .map_err(|e| Error::Serialization(format!("Failed to serialize manifest: {e}")))?;
357
358 let manifest_block = Block::new(Bytes::from(manifest_bytes))?;
359 let manifest_cid = *manifest_block.cid();
360 self.store.put(&manifest_block).await?;
361
362 Ok(manifest_cid)
363 }
364
365 pub async fn load_manifest(&self, manifest_cid: &Cid) -> Result<SafetensorsManifest> {
367 let block = self
368 .store
369 .get(manifest_cid)
370 .await?
371 .ok_or_else(|| Error::NotFound(format!("Manifest not found: {manifest_cid}")))?;
372
373 let manifest: SafetensorsManifest =
374 oxicode::serde::decode_owned_from_slice(block.data(), oxicode::config::standard())
375 .map(|(v, _)| v)
376 .map_err(|e| {
377 Error::Serialization(format!("Failed to deserialize manifest: {e}"))
378 })?;
379
380 Ok(manifest)
381 }
382
383 pub async fn load_tensor(&self, manifest_cid: &Cid, tensor_name: &str) -> Result<Vec<u8>> {
385 let manifest = self.load_manifest(manifest_cid).await?;
386
387 let chunked_tensor = manifest
388 .tensors
389 .get(tensor_name)
390 .ok_or_else(|| Error::NotFound(format!("Tensor not found: {tensor_name}")))?;
391
392 let mut tensor_data = Vec::with_capacity(chunked_tensor.info.size_bytes());
394
395 for chunk_cid in &chunked_tensor.chunk_cids {
396 let chunk_block = self
397 .store
398 .get(chunk_cid)
399 .await?
400 .ok_or_else(|| Error::NotFound(format!("Chunk not found: {chunk_cid}")))?;
401
402 tensor_data.extend_from_slice(chunk_block.data());
403 }
404
405 Ok(tensor_data)
406 }
407
408 pub async fn load_tensors(
410 &self,
411 manifest_cid: &Cid,
412 tensor_names: &[&str],
413 ) -> Result<HashMap<String, Vec<u8>>> {
414 let _manifest = self.load_manifest(manifest_cid).await?;
415 let mut result = HashMap::new();
416
417 for &tensor_name in tensor_names {
418 let tensor_data = self.load_tensor(manifest_cid, tensor_name).await?;
419 result.insert(tensor_name.to_string(), tensor_data);
420 }
421
422 Ok(result)
423 }
424
425 pub async fn get_tensor_info(
427 &self,
428 manifest_cid: &Cid,
429 tensor_name: &str,
430 ) -> Result<TensorInfo> {
431 let manifest = self.load_manifest(manifest_cid).await?;
432
433 manifest
434 .tensors
435 .get(tensor_name)
436 .map(|ct| ct.info.clone())
437 .ok_or_else(|| Error::NotFound(format!("Tensor not found: {tensor_name}")))
438 }
439
440 pub async fn list_tensors(&self, manifest_cid: &Cid) -> Result<Vec<String>> {
442 let manifest = self.load_manifest(manifest_cid).await?;
443 Ok(manifest.tensors.keys().cloned().collect())
444 }
445
446 pub async fn get_model_stats(&self, manifest_cid: &Cid) -> Result<ModelStats> {
448 let manifest = self.load_manifest(manifest_cid).await?;
449
450 let tensor_count = manifest.tensors.len();
451 let total_parameters: usize = manifest.tensors.values().map(|ct| ct.info.numel()).sum();
452
453 let chunk_count: usize = manifest
454 .tensors
455 .values()
456 .map(|ct| ct.chunk_cids.len())
457 .sum();
458
459 Ok(ModelStats {
460 name: manifest.name,
461 tensor_count,
462 total_parameters,
463 total_size_bytes: manifest.total_size,
464 chunk_count,
465 avg_chunk_size: if chunk_count > 0 {
466 manifest.total_size / chunk_count as u64
467 } else {
468 0
469 },
470 })
471 }
472}
473
/// Summary statistics for an imported model, as computed by
/// `SafetensorsStore::get_model_stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelStats {
    /// Model name recorded in the manifest.
    pub name: String,
    /// Number of tensors in the manifest.
    pub tensor_count: usize,
    /// Sum of `TensorInfo::numel()` over all tensors.
    pub total_parameters: usize,
    /// Total tensor data bytes imported (the header is not counted).
    pub total_size_bytes: u64,
    /// Number of chunk blocks across all tensors.
    pub chunk_count: usize,
    /// `total_size_bytes / chunk_count` (integer division), or 0 when there are no chunks.
    pub avg_chunk_size: u64,
}
490
#[cfg(test)]
mod tests {
    use super::*;
    use crate::blockstore::{BlockStoreConfig, SledBlockStore};

    /// Encodes `header` and `payload` in the safetensors layout: an 8-byte little-endian
    /// header length, the JSON header bytes, then the raw tensor data.
    fn encode_safetensors(header: &str, payload: &[u8]) -> Vec<u8> {
        let mut data = Vec::with_capacity(8 + header.len() + payload.len());
        data.extend_from_slice(&(header.len() as u64).to_le_bytes());
        data.extend_from_slice(header.as_bytes());
        data.extend_from_slice(payload);
        data
    }

    #[test]
    fn test_dtype_size() {
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
        assert_eq!(DType::F16.size(), 2);
        assert_eq!(DType::BF16.size(), 2);
        assert_eq!(DType::I8.size(), 1);
        assert_eq!(DType::U64.size(), 8);
        assert_eq!(DType::Bool.size(), 1);
    }

    #[test]
    fn test_dtype_from_str() {
        assert_eq!("F32".parse::<DType>(), Ok(DType::F32));
        assert_eq!("BOOL".parse::<DType>(), Ok(DType::Bool));
        assert!("f32".parse::<DType>().is_err());
    }

    #[test]
    fn test_tensor_info_numel() {
        let info = TensorInfo {
            dtype: DType::F32,
            shape: vec![2, 3, 4],
            data_offsets: (0, 96),
        };

        assert_eq!(info.numel(), 24);
        assert_eq!(info.size_bytes(), 96);
    }

    #[tokio::test]
    async fn test_safetensors_store() {
        // Use the platform temp directory instead of a hard-coded `/tmp` so the test also
        // runs where that path does not exist (e.g. Windows).
        let config = BlockStoreConfig {
            path: std::env::temp_dir().join("ipfrs-safetensors-test"),
            cache_size: 100 * 1024 * 1024,
        };
        let _ = std::fs::remove_dir_all(&config.path);

        let store = Arc::new(SledBlockStore::new(config).unwrap());
        let safetensors_store = SafetensorsStore::new(store);

        let header = r#"{"tensor1":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}}"#;
        let payload: Vec<u8> = (0u8..16).collect();
        let data = encode_safetensors(header, &payload);

        let manifest_cid = safetensors_store
            .import_from_bytes("test_model".to_string(), &data)
            .await
            .unwrap();

        let manifest = safetensors_store
            .load_manifest(&manifest_cid)
            .await
            .unwrap();
        assert_eq!(manifest.name, "test_model");
        assert_eq!(manifest.tensors.len(), 1);

        // The stored chunks must round-trip to exactly the imported bytes.
        let tensor = safetensors_store
            .load_tensor(&manifest_cid, "tensor1")
            .await
            .unwrap();
        assert_eq!(tensor, payload);

        let names = safetensors_store.list_tensors(&manifest_cid).await.unwrap();
        assert_eq!(names, vec!["tensor1".to_string()]);

        let info = safetensors_store
            .get_tensor_info(&manifest_cid, "tensor1")
            .await
            .unwrap();
        assert_eq!(
            info,
            TensorInfo {
                dtype: DType::F32,
                shape: vec![2, 2],
                data_offsets: (0, 16),
            }
        );

        let stats = safetensors_store
            .get_model_stats(&manifest_cid)
            .await
            .unwrap();
        assert_eq!(stats.tensor_count, 1);
        assert_eq!(stats.total_parameters, 4);
        assert_eq!(stats.total_size_bytes, 16);
        assert_eq!(stats.chunk_count, 1);
    }

    #[test]
    fn test_parse_header() {
        let header = r#"{"tensor1":{"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}}"#;
        let data = encode_safetensors(header, &[]);

        let (parsed, offset) = SafetensorsStore::<SledBlockStore>::parse_header(&data).unwrap();
        assert_eq!(offset, 8 + header.len());
        assert_eq!(parsed.tensors.len(), 1);
        assert!(parsed.tensors.contains_key("tensor1"));

        let tensor_info = &parsed.tensors["tensor1"];
        assert_eq!(tensor_info.dtype, DType::F32);
        assert_eq!(tensor_info.shape, vec![2, 2]);
        assert_eq!(tensor_info.data_offsets, (0, 16));
    }

    #[test]
    fn test_parse_header_rejects_truncated_input() {
        // Shorter than the 8-byte length prefix.
        assert!(SafetensorsStore::<SledBlockStore>::parse_header(&[0u8; 4]).is_err());

        // Length prefix claims 100 header bytes but only 2 follow.
        let mut data = 100u64.to_le_bytes().to_vec();
        data.extend_from_slice(b"{}");
        assert!(SafetensorsStore::<SledBlockStore>::parse_header(&data).is_err());
    }
}