use crate::block::Block;
use crate::error::{Error, Result};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
32pub enum TensorDtype {
33 F32,
35 F16,
37 F64,
39 I8,
41 I32,
43 I64,
45 U8,
47 U32,
49 Bool,
51}
52
53impl TensorDtype {
54 #[inline]
56 pub fn size_bytes(&self) -> usize {
57 match self {
58 TensorDtype::F32 => 4,
59 TensorDtype::F16 => 2,
60 TensorDtype::F64 => 8,
61 TensorDtype::I8 => 1,
62 TensorDtype::I32 => 4,
63 TensorDtype::I64 => 8,
64 TensorDtype::U8 => 1,
65 TensorDtype::U32 => 4,
66 TensorDtype::Bool => 1,
67 }
68 }
69
70 #[inline]
72 pub fn name(&self) -> &'static str {
73 match self {
74 TensorDtype::F32 => "float32",
75 TensorDtype::F16 => "float16",
76 TensorDtype::F64 => "float64",
77 TensorDtype::I8 => "int8",
78 TensorDtype::I32 => "int32",
79 TensorDtype::I64 => "int64",
80 TensorDtype::U8 => "uint8",
81 TensorDtype::U32 => "uint32",
82 TensorDtype::Bool => "bool",
83 }
84 }
85}
86
87#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
89pub struct TensorShape {
90 dims: Vec<usize>,
91}
92
93impl TensorShape {
94 pub fn new(dims: Vec<usize>) -> Self {
96 Self { dims }
97 }
98
99 pub fn scalar() -> Self {
101 Self { dims: vec![] }
102 }
103
104 #[inline]
106 pub fn dims(&self) -> &[usize] {
107 &self.dims
108 }
109
110 #[inline]
112 pub fn rank(&self) -> usize {
113 self.dims.len()
114 }
115
116 #[inline]
118 pub fn element_count(&self) -> usize {
119 if self.dims.is_empty() {
120 1
121 } else {
122 self.dims.iter().product()
123 }
124 }
125
126 #[inline]
128 pub fn is_scalar(&self) -> bool {
129 self.dims.is_empty()
130 }
131
132 #[inline]
134 pub fn is_vector(&self) -> bool {
135 self.dims.len() == 1
136 }
137
138 #[inline]
140 pub fn is_matrix(&self) -> bool {
141 self.dims.len() == 2
142 }
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct TensorMetadata {
148 pub shape: TensorShape,
150 pub dtype: TensorDtype,
152 pub name: Option<String>,
154 pub metadata: std::collections::BTreeMap<String, String>,
156}
157
158impl TensorMetadata {
159 pub fn new(shape: TensorShape, dtype: TensorDtype) -> Self {
161 Self {
162 shape,
163 dtype,
164 name: None,
165 metadata: std::collections::BTreeMap::new(),
166 }
167 }
168
169 pub fn with_name(mut self, name: String) -> Self {
171 self.name = Some(name);
172 self
173 }
174
175 pub fn with_metadata(mut self, key: String, value: String) -> Self {
177 self.metadata.insert(key, value);
178 self
179 }
180
181 pub fn expected_size(&self) -> usize {
183 self.shape.element_count() * self.dtype.size_bytes()
184 }
185}
186
187#[derive(Debug, Clone)]
193pub struct TensorBlock {
194 block: Block,
196 metadata: TensorMetadata,
198}
199
200impl TensorBlock {
201 pub fn new(data: Bytes, shape: TensorShape, dtype: TensorDtype) -> Result<Self> {
231 let metadata = TensorMetadata::new(shape, dtype);
232
233 let expected_size = metadata.expected_size();
235 if data.len() != expected_size {
236 return Err(Error::InvalidData(format!(
237 "Tensor data size mismatch: expected {} bytes, got {}",
238 expected_size,
239 data.len()
240 )));
241 }
242
243 let block = Block::new(data)?;
245
246 Ok(Self { block, metadata })
247 }
248
249 pub fn with_metadata(data: Bytes, metadata: TensorMetadata) -> Result<Self> {
251 let expected_size = metadata.expected_size();
252 if data.len() != expected_size {
253 return Err(Error::InvalidData(format!(
254 "Tensor data size mismatch: expected {} bytes, got {}",
255 expected_size,
256 data.len()
257 )));
258 }
259
260 let block = Block::new(data)?;
261 Ok(Self { block, metadata })
262 }
263
264 pub fn block(&self) -> &Block {
266 &self.block
267 }
268
269 pub fn metadata(&self) -> &TensorMetadata {
271 &self.metadata
272 }
273
274 pub fn shape(&self) -> &TensorShape {
276 &self.metadata.shape
277 }
278
279 pub fn dtype(&self) -> TensorDtype {
281 self.metadata.dtype
282 }
283
284 pub fn element_count(&self) -> usize {
286 self.metadata.shape.element_count()
287 }
288
289 pub fn cid(&self) -> &crate::cid::Cid {
291 self.block.cid()
292 }
293
294 pub fn data(&self) -> &Bytes {
296 self.block.data()
297 }
298
299 pub fn into_parts(self) -> (Block, TensorMetadata) {
301 (self.block, self.metadata)
302 }
303
304 pub fn verify(&self) -> Result<bool> {
306 self.block.verify()
307 }
308
309 pub fn reshape(&self, new_shape: TensorShape) -> Result<Self> {
311 if new_shape.element_count() != self.element_count() {
312 return Err(Error::InvalidInput(format!(
313 "Cannot reshape tensor with {} elements to shape with {} elements",
314 self.element_count(),
315 new_shape.element_count()
316 )));
317 }
318
319 let new_metadata = TensorMetadata {
320 shape: new_shape,
321 dtype: self.metadata.dtype,
322 name: self.metadata.name.clone(),
323 metadata: self.metadata.metadata.clone(),
324 };
325
326 Ok(Self {
327 block: self.block.clone(),
328 metadata: new_metadata,
329 })
330 }
331
332 pub fn size_bytes(&self) -> usize {
334 self.data().len()
335 }
336
337 pub fn is_scalar(&self) -> bool {
339 self.shape().is_scalar()
340 }
341
342 pub fn is_vector(&self) -> bool {
344 self.shape().is_vector()
345 }
346
347 pub fn is_matrix(&self) -> bool {
349 self.shape().is_matrix()
350 }
351}
352
353impl TensorBlock {
355 pub fn from_f32_slice(data: &[f32], shape: TensorShape) -> Result<Self> {
357 if data.len() != shape.element_count() {
358 return Err(Error::InvalidInput(format!(
359 "Data length {} doesn't match shape element count {}",
360 data.len(),
361 shape.element_count()
362 )));
363 }
364
365 let bytes: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
366 Self::new(Bytes::from(bytes), shape, TensorDtype::F32)
367 }
368
369 pub fn from_f64_slice(data: &[f64], shape: TensorShape) -> Result<Self> {
371 if data.len() != shape.element_count() {
372 return Err(Error::InvalidInput(format!(
373 "Data length {} doesn't match shape element count {}",
374 data.len(),
375 shape.element_count()
376 )));
377 }
378
379 let bytes: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
380 Self::new(Bytes::from(bytes), shape, TensorDtype::F64)
381 }
382
383 pub fn from_i32_slice(data: &[i32], shape: TensorShape) -> Result<Self> {
385 if data.len() != shape.element_count() {
386 return Err(Error::InvalidInput(format!(
387 "Data length {} doesn't match shape element count {}",
388 data.len(),
389 shape.element_count()
390 )));
391 }
392
393 let bytes: Vec<u8> = data.iter().flat_map(|&i| i.to_le_bytes()).collect();
394 Self::new(Bytes::from(bytes), shape, TensorDtype::I32)
395 }
396
397 pub fn from_i64_slice(data: &[i64], shape: TensorShape) -> Result<Self> {
399 if data.len() != shape.element_count() {
400 return Err(Error::InvalidInput(format!(
401 "Data length {} doesn't match shape element count {}",
402 data.len(),
403 shape.element_count()
404 )));
405 }
406
407 let bytes: Vec<u8> = data.iter().flat_map(|&i| i.to_le_bytes()).collect();
408 Self::new(Bytes::from(bytes), shape, TensorDtype::I64)
409 }
410
411 pub fn from_u8_slice(data: &[u8], shape: TensorShape) -> Result<Self> {
413 if data.len() != shape.element_count() {
414 return Err(Error::InvalidInput(format!(
415 "Data length {} doesn't match shape element count {}",
416 data.len(),
417 shape.element_count()
418 )));
419 }
420
421 Self::new(Bytes::copy_from_slice(data), shape, TensorDtype::U8)
422 }
423
424 pub fn to_f32_vec(&self) -> Result<Vec<f32>> {
426 if self.dtype() != TensorDtype::F32 {
427 return Err(Error::InvalidInput(format!(
428 "Cannot convert {} tensor to f32",
429 self.dtype().name()
430 )));
431 }
432
433 let data = self.data();
434 let mut result = Vec::with_capacity(self.element_count());
435
436 for chunk in data.chunks_exact(4) {
437 let bytes: [u8; 4] = chunk.try_into().unwrap();
438 result.push(f32::from_le_bytes(bytes));
439 }
440
441 Ok(result)
442 }
443
444 pub fn to_f64_vec(&self) -> Result<Vec<f64>> {
446 if self.dtype() != TensorDtype::F64 {
447 return Err(Error::InvalidInput(format!(
448 "Cannot convert {} tensor to f64",
449 self.dtype().name()
450 )));
451 }
452
453 let data = self.data();
454 let mut result = Vec::with_capacity(self.element_count());
455
456 for chunk in data.chunks_exact(8) {
457 let bytes: [u8; 8] = chunk.try_into().unwrap();
458 result.push(f64::from_le_bytes(bytes));
459 }
460
461 Ok(result)
462 }
463
464 pub fn to_i32_vec(&self) -> Result<Vec<i32>> {
466 if self.dtype() != TensorDtype::I32 {
467 return Err(Error::InvalidInput(format!(
468 "Cannot convert {} tensor to i32",
469 self.dtype().name()
470 )));
471 }
472
473 let data = self.data();
474 let mut result = Vec::with_capacity(self.element_count());
475
476 for chunk in data.chunks_exact(4) {
477 let bytes: [u8; 4] = chunk.try_into().unwrap();
478 result.push(i32::from_le_bytes(bytes));
479 }
480
481 Ok(result)
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `values` as consecutive little-endian `f32` bytes.
    fn f32_le_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    #[test]
    fn test_tensor_dtype_sizes() {
        assert_eq!(TensorDtype::F32.size_bytes(), 4);
        assert_eq!(TensorDtype::F16.size_bytes(), 2);
        assert_eq!(TensorDtype::F64.size_bytes(), 8);
        assert_eq!(TensorDtype::I8.size_bytes(), 1);
        assert_eq!(TensorDtype::I32.size_bytes(), 4);
        assert_eq!(TensorDtype::I64.size_bytes(), 8);
        assert_eq!(TensorDtype::U8.size_bytes(), 1);
        assert_eq!(TensorDtype::U32.size_bytes(), 4);
        assert_eq!(TensorDtype::Bool.size_bytes(), 1);
    }

    #[test]
    fn test_tensor_shape() {
        let shape = TensorShape::new(vec![2, 3, 4]);
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.element_count(), 24);
        assert!(!shape.is_scalar());
        assert!(!shape.is_vector());
        assert!(!shape.is_matrix());

        let scalar = TensorShape::scalar();
        assert!(scalar.is_scalar());
        assert_eq!(scalar.element_count(), 1);
    }

    #[test]
    fn test_tensor_block_creation() {
        let shape = TensorShape::new(vec![2, 2]);
        let data = f32_le_bytes(&[1.0, 2.0, 3.0, 4.0]);

        let tensor = TensorBlock::new(Bytes::from(data), shape, TensorDtype::F32).unwrap();

        assert_eq!(tensor.element_count(), 4);
        assert_eq!(tensor.dtype(), TensorDtype::F32);
        assert_eq!(tensor.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_tensor_size_validation() {
        // Three f32 values cannot fill a 2x2 tensor.
        let shape = TensorShape::new(vec![2, 2]);
        let data = f32_le_bytes(&[1.0, 2.0, 3.0]);

        let result = TensorBlock::new(Bytes::from(data), shape, TensorDtype::F32);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_metadata() {
        let shape = TensorShape::new(vec![10, 20]);
        let metadata = TensorMetadata::new(shape, TensorDtype::F32)
            .with_name("layer1.weight".to_string())
            .with_metadata("requires_grad".to_string(), "true".to_string());

        assert_eq!(metadata.name, Some("layer1.weight".to_string()));
        assert_eq!(metadata.expected_size(), 10 * 20 * 4);
    }

    #[test]
    fn test_tensor_block_with_metadata() {
        let metadata = TensorMetadata::new(TensorShape::new(vec![2]), TensorDtype::F32)
            .with_name("bias".to_string());
        let data = f32_le_bytes(&[0.5, -0.5]);

        let tensor = TensorBlock::with_metadata(Bytes::from(data), metadata).unwrap();

        assert_eq!(tensor.metadata().name.as_deref(), Some("bias"));
        assert_eq!(tensor.size_bytes(), 8);
    }

    #[test]
    fn test_tensor_from_f32_slice() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = TensorShape::new(vec![2, 3]);

        let tensor = TensorBlock::from_f32_slice(&data, shape).unwrap();
        assert_eq!(tensor.element_count(), 6);
        assert_eq!(tensor.dtype(), TensorDtype::F32);

        let recovered = tensor.to_f32_vec().unwrap();
        assert_eq!(recovered, data);
    }

    #[test]
    fn test_tensor_from_f64_slice() {
        let data = vec![1.5f64, -2.25, 3.0];
        let shape = TensorShape::new(vec![3]);

        let tensor = TensorBlock::from_f64_slice(&data, shape).unwrap();
        assert_eq!(tensor.dtype(), TensorDtype::F64);
        assert_eq!(tensor.size_bytes(), 3 * 8);

        let recovered = tensor.to_f64_vec().unwrap();
        assert_eq!(recovered, data);
    }

    #[test]
    fn test_tensor_from_i32_slice() {
        let data = vec![10i32, 20, 30, 40];
        let shape = TensorShape::new(vec![2, 2]);

        let tensor = TensorBlock::from_i32_slice(&data, shape).unwrap();
        assert_eq!(tensor.element_count(), 4);

        let recovered = tensor.to_i32_vec().unwrap();
        assert_eq!(recovered, data);
    }

    #[test]
    fn test_tensor_from_u8_slice() {
        let data = [1u8, 2, 3];

        let tensor = TensorBlock::from_u8_slice(&data, TensorShape::new(vec![3])).unwrap();
        assert_eq!(tensor.dtype(), TensorDtype::U8);
        assert_eq!(tensor.size_bytes(), 3);
    }

    #[test]
    fn test_tensor_from_slice_length_mismatch() {
        // Three values cannot fill a 2x2 shape.
        let data = vec![1.0f32, 2.0, 3.0];
        let result = TensorBlock::from_f32_slice(&data, TensorShape::new(vec![2, 2]));
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_reshape() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = TensorShape::new(vec![2, 3]);
        let tensor = TensorBlock::from_f32_slice(&data, shape).unwrap();

        let reshaped = tensor.reshape(TensorShape::new(vec![3, 2])).unwrap();
        assert_eq!(reshaped.shape().dims(), &[3, 2]);
        assert_eq!(reshaped.element_count(), 6);

        // Reshaping must leave the underlying values untouched.
        let recovered = reshaped.to_f32_vec().unwrap();
        assert_eq!(recovered, data);
    }

    #[test]
    fn test_tensor_reshape_invalid() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let shape = TensorShape::new(vec![2, 2]);
        let tensor = TensorBlock::from_f32_slice(&data, shape).unwrap();

        // 4 elements cannot be viewed as a 3x2 tensor.
        let result = tensor.reshape(TensorShape::new(vec![3, 2]));
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_type_checks() {
        let data = vec![1.0f32, 2.0];
        let tensor = TensorBlock::from_f32_slice(&data, TensorShape::new(vec![2])).unwrap();
        assert!(tensor.is_vector());
        assert!(!tensor.is_matrix());
        assert!(!tensor.is_scalar());

        let matrix = TensorBlock::from_f32_slice(&data, TensorShape::new(vec![1, 2])).unwrap();
        assert!(matrix.is_matrix());
    }

    #[test]
    fn test_tensor_to_vec_wrong_dtype() {
        let data = vec![1i32, 2, 3];
        let tensor = TensorBlock::from_i32_slice(&data, TensorShape::new(vec![3])).unwrap();

        // An I32 tensor must not decode as f32.
        let result = tensor.to_f32_vec();
        assert!(result.is_err());
    }
}