// dynamo_memory/tensor.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Tensor abstraction built on top of MemoryDescriptor.
5//!
6//! A tensor is memory with shape, stride, and element size metadata.
7//! The underlying memory could be externally owned, self-owned, or a view.
8
9use super::nixl::{self, NixlDescriptor};
10use super::{MemoryDescriptor, StorageKind};
11use std::any::Any;
12use std::sync::Arc;
13
/// A tensor is memory with shape, stride, and element size metadata.
///
/// This trait extends [`MemoryDescriptor`] with tensor-specific metadata.
/// The underlying memory could be externally owned, self-owned, or a view.
///
/// # Shape and Stride
///
/// - `shape()` returns the number of elements in each dimension
/// - `stride()` returns the number of elements to skip when incrementing each dimension
/// - `element_size()` returns the number of bytes per element
///
/// Strides are expressed in *elements*, not bytes; multiply by
/// `element_size()` to obtain byte offsets.
///
/// Implementations are expected to keep `shape().len() == stride().len()`;
/// the helpers in `TensorDescriptorExt` index `stride` by `shape` length.
///
/// For a contiguous tensor with shape `[2, 3, 4]`:
/// - stride would be `[12, 4, 1]` (row-major/C order)
/// - total elements = 2 * 3 * 4 = 24
/// - total bytes = 24 * element_size()
pub trait TensorDescriptor: MemoryDescriptor {
    /// Shape of the tensor (number of elements per dimension).
    fn shape(&self) -> &[usize];

    /// Stride of the tensor (elements to skip per dimension).
    ///
    /// `stride[i]` indicates how many elements to skip when incrementing dimension `i`.
    fn stride(&self) -> &[usize];

    /// Number of bytes per element.
    fn element_size(&self) -> usize;
}
41
42// =============================================================================
43// Helper methods for TensorDescriptor
44// =============================================================================
45
46/// Extension trait providing helper methods for tensor descriptors.
47pub trait TensorDescriptorExt: TensorDescriptor {
48    /// Total number of elements in the tensor (product of shape).
49    fn numel(&self) -> usize {
50        self.shape().iter().product()
51    }
52
53    /// Number of dimensions (rank).
54    fn ndim(&self) -> usize {
55        self.shape().len()
56    }
57
58    /// Check if tensor is contiguous in memory (row-major/C order).
59    ///
60    /// A tensor is contiguous if its strides follow the pattern where
61    /// the last dimension has stride 1, and each preceding dimension
62    /// has stride equal to the product of all following dimensions.
63    fn is_contiguous(&self) -> bool {
64        let shape = self.shape();
65        let stride = self.stride();
66
67        if shape.is_empty() {
68            return true;
69        }
70
71        let mut expected_stride = 1;
72        for i in (0..shape.len()).rev() {
73            if stride[i] != expected_stride {
74                return false;
75            }
76            expected_stride *= shape[i];
77        }
78        true
79    }
80
81    /// Compute the contiguous stride for the current shape.
82    ///
83    /// Returns the stride that would make this tensor contiguous
84    /// (row-major/C order).
85    fn contiguous_stride(&self) -> Vec<usize> {
86        let shape = self.shape();
87        if shape.is_empty() {
88            return vec![];
89        }
90
91        let mut stride = vec![1; shape.len()];
92        for i in (0..shape.len() - 1).rev() {
93            stride[i] = stride[i + 1] * shape[i + 1];
94        }
95        stride
96    }
97
98    /// Returns the CUDA device ID if the tensor is on a CUDA device.
99    fn cuda_device_id(&self) -> Option<usize> {
100        match self.storage_kind() {
101            StorageKind::Device(idx) => Some(idx as usize),
102            _ => None,
103        }
104    }
105}
106
107// Blanket impl for all TensorDescriptor types
108impl<T: TensorDescriptor + ?Sized> TensorDescriptorExt for T {}
109
110// =============================================================================
111// Arc<dyn TensorDescriptor> support for NixlRegisterExt
112// =============================================================================
113
114impl nixl::NixlCompatible for Arc<dyn TensorDescriptor> {
115    fn nixl_params(&self) -> (*const u8, usize, nixl::MemType, u64) {
116        let storage = self.storage_kind();
117        let (mem_type, device_id) = match storage {
118            StorageKind::Device(idx) => (nixl::MemType::Vram, idx as u64),
119            StorageKind::System => (nixl::MemType::Dram, 0),
120            StorageKind::Pinned => (nixl::MemType::Dram, 0),
121            StorageKind::Disk(fd) => (nixl::MemType::File, fd),
122        };
123        (self.addr() as *const u8, self.size(), mem_type, device_id)
124    }
125}
126
127impl MemoryDescriptor for Arc<dyn TensorDescriptor> {
128    fn addr(&self) -> usize {
129        (**self).addr()
130    }
131
132    fn size(&self) -> usize {
133        (**self).size()
134    }
135
136    fn storage_kind(&self) -> StorageKind {
137        (**self).storage_kind()
138    }
139
140    fn as_any(&self) -> &dyn Any {
141        self
142    }
143
144    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
145        None
146    }
147}
148
149impl TensorDescriptor for Arc<dyn TensorDescriptor> {
150    fn shape(&self) -> &[usize] {
151        (**self).shape()
152    }
153
154    fn stride(&self) -> &[usize] {
155        (**self).stride()
156    }
157
158    fn element_size(&self) -> usize {
159        (**self).element_size()
160    }
161}
162
163// =============================================================================
164// Arc<dyn TensorDescriptor + Send + Sync> support
165// =============================================================================
166
167impl nixl::NixlCompatible for Arc<dyn TensorDescriptor + Send + Sync> {
168    fn nixl_params(&self) -> (*const u8, usize, nixl::MemType, u64) {
169        let storage = self.storage_kind();
170        let (mem_type, device_id) = match storage {
171            StorageKind::Device(idx) => (nixl::MemType::Vram, idx as u64),
172            StorageKind::System => (nixl::MemType::Dram, 0),
173            StorageKind::Pinned => (nixl::MemType::Dram, 0),
174            StorageKind::Disk(fd) => (nixl::MemType::File, fd),
175        };
176        (self.addr() as *const u8, self.size(), mem_type, device_id)
177    }
178}
179
180impl MemoryDescriptor for Arc<dyn TensorDescriptor + Send + Sync> {
181    fn addr(&self) -> usize {
182        (**self).addr()
183    }
184
185    fn size(&self) -> usize {
186        (**self).size()
187    }
188
189    fn storage_kind(&self) -> StorageKind {
190        (**self).storage_kind()
191    }
192
193    fn as_any(&self) -> &dyn Any {
194        self
195    }
196
197    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
198        None
199    }
200}
201
202impl TensorDescriptor for Arc<dyn TensorDescriptor + Send + Sync> {
203    fn shape(&self) -> &[usize] {
204        (**self).shape()
205    }
206
207    fn stride(&self) -> &[usize] {
208        (**self).stride()
209    }
210
211    fn element_size(&self) -> usize {
212        (**self).element_size()
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Simple test tensor for unit tests.
    ///
    /// Every field is `usize` / `Vec<usize>`, so the type is automatically
    /// `Send + Sync` and can back bounded trait objects without any
    /// `unsafe impl` blocks.
    #[derive(Debug)]
    struct TestTensor {
        addr: usize,
        size: usize,
        shape: Vec<usize>,
        stride: Vec<usize>,
        element_size: usize,
    }

    impl MemoryDescriptor for TestTensor {
        fn addr(&self) -> usize {
            self.addr
        }

        fn size(&self) -> usize {
            self.size
        }

        fn storage_kind(&self) -> StorageKind {
            StorageKind::System
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
            None
        }
    }

    impl TensorDescriptor for TestTensor {
        fn shape(&self) -> &[usize] {
            &self.shape
        }

        fn stride(&self) -> &[usize] {
            &self.stride
        }

        fn element_size(&self) -> usize {
            self.element_size
        }
    }

    #[test]
    fn test_numel() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4, // 24 elements * 4 bytes
            shape: vec![2, 3, 4],
            stride: vec![12, 4, 1],
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 24);
    }

    #[test]
    fn test_ndim() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![12, 4, 1],
            element_size: 4,
        };
        assert_eq!(tensor.ndim(), 3);
    }

    #[test]
    fn test_is_contiguous_true() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![12, 4, 1], // Contiguous stride
            element_size: 4,
        };
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_is_contiguous_false() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![24, 4, 1], // Non-contiguous (gap between first dim)
            element_size: 4,
        };
        assert!(!tensor.is_contiguous());
    }

    #[test]
    fn test_contiguous_stride() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![24, 4, 1], // Non-contiguous
            element_size: 4,
        };
        assert_eq!(tensor.contiguous_stride(), vec![12, 4, 1]);
    }

    #[test]
    fn test_empty_tensor() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 0,
            shape: vec![],
            stride: vec![],
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 1); // Empty product is 1
        assert_eq!(tensor.ndim(), 0);
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_1d_tensor_contiguous() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 10 * 4,
            shape: vec![10],
            stride: vec![1],
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 10);
        assert_eq!(tensor.ndim(), 1);
        assert!(tensor.is_contiguous());
        assert_eq!(tensor.contiguous_stride(), vec![1]);
    }

    #[test]
    fn test_1d_tensor_non_contiguous() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 10 * 4,
            shape: vec![10],
            stride: vec![2], // Strided access (every other element)
            element_size: 4,
        };
        assert!(!tensor.is_contiguous());
    }

    #[test]
    fn test_2d_tensor() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 6 * 4,
            shape: vec![2, 3],
            stride: vec![3, 1],
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 6);
        assert_eq!(tensor.ndim(), 2);
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_high_dimensional_tensor() {
        // 5D tensor: [2, 3, 4, 5, 6]
        let shape = vec![2, 3, 4, 5, 6];
        // Contiguous stride: [360, 120, 30, 6, 1]
        let stride = vec![360, 120, 30, 6, 1];
        let numel: usize = shape.iter().product();
        let tensor = TestTensor {
            addr: 0x1000,
            size: numel * 4,
            shape,
            stride,
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 720);
        assert_eq!(tensor.ndim(), 5);
        assert!(tensor.is_contiguous());
        assert_eq!(tensor.contiguous_stride(), vec![360, 120, 30, 6, 1]);
    }

    #[test]
    fn test_tensor_with_size_1_dimensions() {
        // Shape with singleton dimensions: [1, 3, 1, 4]
        let tensor = TestTensor {
            addr: 0x1000,
            size: 12 * 4,
            shape: vec![1, 3, 1, 4],
            stride: vec![12, 4, 4, 1], // Contiguous for this shape
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 12);
        assert_eq!(tensor.ndim(), 4);
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_contiguous_stride_empty() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 0,
            shape: vec![],
            stride: vec![],
            element_size: 4,
        };
        assert!(tensor.contiguous_stride().is_empty());
    }

    #[test]
    fn test_contiguous_stride_1d() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 5 * 4,
            shape: vec![5],
            stride: vec![1],
            element_size: 4,
        };
        assert_eq!(tensor.contiguous_stride(), vec![1]);
    }

    #[test]
    fn test_cuda_device_id_system() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 100,
            shape: vec![10],
            stride: vec![1],
            element_size: 4,
        };
        assert_eq!(tensor.cuda_device_id(), None);
    }

    /// Test tensor that reports Device storage kind
    #[derive(Debug)]
    struct DeviceTensor {
        addr: usize,
        size: usize,
        shape: Vec<usize>,
        stride: Vec<usize>,
        element_size: usize,
        device_id: u32,
    }

    impl MemoryDescriptor for DeviceTensor {
        fn addr(&self) -> usize {
            self.addr
        }

        fn size(&self) -> usize {
            self.size
        }

        fn storage_kind(&self) -> StorageKind {
            StorageKind::Device(self.device_id)
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
            None
        }
    }

    impl TensorDescriptor for DeviceTensor {
        fn shape(&self) -> &[usize] {
            &self.shape
        }

        fn stride(&self) -> &[usize] {
            &self.stride
        }

        fn element_size(&self) -> usize {
            self.element_size
        }
    }

    #[test]
    fn test_cuda_device_id_device() {
        let tensor = DeviceTensor {
            addr: 0x1000,
            size: 100,
            shape: vec![10],
            stride: vec![1],
            element_size: 4,
            device_id: 2,
        };
        assert_eq!(tensor.cuda_device_id(), Some(2));
    }

    #[test]
    fn test_arc_tensor_descriptor() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![12, 4, 1],
            element_size: 4,
        };
        let arc: Arc<dyn TensorDescriptor> = Arc::new(tensor);

        assert_eq!(arc.addr(), 0x1000);
        assert_eq!(arc.size(), 24 * 4);
        assert_eq!(arc.shape(), &[2, 3, 4]);
        assert_eq!(arc.stride(), &[12, 4, 1]);
        assert_eq!(arc.element_size(), 4);
        assert_eq!(arc.storage_kind(), StorageKind::System);
        assert!(arc.nixl_descriptor().is_none());
    }

    #[test]
    fn test_arc_tensor_send_sync() {
        // TestTensor's fields are all Send + Sync, so the compiler derives
        // Send + Sync automatically and the bounded trait object coerces
        // directly. (An earlier version hand-rolled a separate tensor type
        // with redundant `unsafe impl Send`/`unsafe impl Sync` blocks.)
        let tensor = TestTensor {
            addr: 0x2000,
            size: 100,
            shape: vec![10],
            stride: vec![1],
            element_size: 4,
        };
        let arc: Arc<dyn TensorDescriptor + Send + Sync> = Arc::new(tensor);

        assert_eq!(arc.addr(), 0x2000);
        assert_eq!(arc.size(), 100);
        assert_eq!(arc.shape(), &[10]);
        assert_eq!(arc.stride(), &[1]);
        assert_eq!(arc.element_size(), 4);
    }

    #[test]
    fn test_tensor_shape_stride_element_size() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 48,
            shape: vec![3, 4],
            stride: vec![4, 1],
            element_size: 4,
        };
        assert_eq!(tensor.shape(), &[3, 4]);
        assert_eq!(tensor.stride(), &[4, 1]);
        assert_eq!(tensor.element_size(), 4);
    }

    #[test]
    fn test_tensor_numel_single_element() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 4,
            shape: vec![1, 1, 1],
            stride: vec![1, 1, 1],
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 1);
    }
}