1use super::nixl::{self, NixlDescriptor};
10use super::{MemoryDescriptor, StorageKind};
11use std::any::Any;
12use std::sync::Arc;
13
/// Describes a tensor's layout on top of a raw memory region.
///
/// Extends [`MemoryDescriptor`] with shape / stride / element-size metadata.
pub trait TensorDescriptor: MemoryDescriptor {
    /// Dimensions of the tensor, outermost first.
    fn shape(&self) -> &[usize];

    /// Per-dimension strides. Presumably expressed in elements (not bytes):
    /// the contiguity helpers in `TensorDescriptorExt` expect the innermost
    /// contiguous stride to be 1, which matches element units — TODO confirm.
    fn stride(&self) -> &[usize];

    /// Size in bytes of a single element.
    fn element_size(&self) -> usize;
}
41
42pub trait TensorDescriptorExt: TensorDescriptor {
48 fn numel(&self) -> usize {
50 self.shape().iter().product()
51 }
52
53 fn ndim(&self) -> usize {
55 self.shape().len()
56 }
57
58 fn is_contiguous(&self) -> bool {
64 let shape = self.shape();
65 let stride = self.stride();
66
67 if shape.is_empty() {
68 return true;
69 }
70
71 let mut expected_stride = 1;
72 for i in (0..shape.len()).rev() {
73 if stride[i] != expected_stride {
74 return false;
75 }
76 expected_stride *= shape[i];
77 }
78 true
79 }
80
81 fn contiguous_stride(&self) -> Vec<usize> {
86 let shape = self.shape();
87 if shape.is_empty() {
88 return vec![];
89 }
90
91 let mut stride = vec![1; shape.len()];
92 for i in (0..shape.len() - 1).rev() {
93 stride[i] = stride[i + 1] * shape[i + 1];
94 }
95 stride
96 }
97
98 fn cuda_device_id(&self) -> Option<usize> {
100 match self.storage_kind() {
101 StorageKind::Device(idx) => Some(idx as usize),
102 _ => None,
103 }
104 }
105}
106
/// Blanket impl: every `TensorDescriptor` (sized or `dyn`) gets the ext helpers.
impl<T: TensorDescriptor + ?Sized> TensorDescriptorExt for T {}
109
110impl nixl::NixlCompatible for Arc<dyn TensorDescriptor> {
115 fn nixl_params(&self) -> (*const u8, usize, nixl::MemType, u64) {
116 let storage = self.storage_kind();
117 let (mem_type, device_id) = match storage {
118 StorageKind::Device(idx) => (nixl::MemType::Vram, idx as u64),
119 StorageKind::System => (nixl::MemType::Dram, 0),
120 StorageKind::Pinned => (nixl::MemType::Dram, 0),
121 StorageKind::Disk(fd) => (nixl::MemType::File, fd),
122 };
123 (self.addr() as *const u8, self.size(), mem_type, device_id)
124 }
125}
126
127impl MemoryDescriptor for Arc<dyn TensorDescriptor> {
128 fn addr(&self) -> usize {
129 (**self).addr()
130 }
131
132 fn size(&self) -> usize {
133 (**self).size()
134 }
135
136 fn storage_kind(&self) -> StorageKind {
137 (**self).storage_kind()
138 }
139
140 fn as_any(&self) -> &dyn Any {
141 self
142 }
143
144 fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
145 None
146 }
147}
148
/// Forwards the tensor metadata through the `Arc`.
impl TensorDescriptor for Arc<dyn TensorDescriptor> {
    fn shape(&self) -> &[usize] {
        (**self).shape()
    }

    fn stride(&self) -> &[usize] {
        (**self).stride()
    }

    fn element_size(&self) -> usize {
        (**self).element_size()
    }
}
162
163impl nixl::NixlCompatible for Arc<dyn TensorDescriptor + Send + Sync> {
168 fn nixl_params(&self) -> (*const u8, usize, nixl::MemType, u64) {
169 let storage = self.storage_kind();
170 let (mem_type, device_id) = match storage {
171 StorageKind::Device(idx) => (nixl::MemType::Vram, idx as u64),
172 StorageKind::System => (nixl::MemType::Dram, 0),
173 StorageKind::Pinned => (nixl::MemType::Dram, 0),
174 StorageKind::Disk(fd) => (nixl::MemType::File, fd),
175 };
176 (self.addr() as *const u8, self.size(), mem_type, device_id)
177 }
178}
179
180impl MemoryDescriptor for Arc<dyn TensorDescriptor + Send + Sync> {
181 fn addr(&self) -> usize {
182 (**self).addr()
183 }
184
185 fn size(&self) -> usize {
186 (**self).size()
187 }
188
189 fn storage_kind(&self) -> StorageKind {
190 (**self).storage_kind()
191 }
192
193 fn as_any(&self) -> &dyn Any {
194 self
195 }
196
197 fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
198 None
199 }
200}
201
/// Forwards the tensor metadata through the `Arc` for the thread-safe
/// trait-object flavor.
impl TensorDescriptor for Arc<dyn TensorDescriptor + Send + Sync> {
    fn shape(&self) -> &[usize] {
        (**self).shape()
    }

    fn stride(&self) -> &[usize] {
        (**self).stride()
    }

    fn element_size(&self) -> usize {
        (**self).element_size()
    }
}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal host-memory tensor used to exercise the ext-trait helpers.
    #[derive(Debug)]
    struct TestTensor {
        addr: usize,
        size: usize,
        shape: Vec<usize>,
        stride: Vec<usize>,
        element_size: usize,
    }

    impl MemoryDescriptor for TestTensor {
        fn addr(&self) -> usize {
            self.addr
        }

        fn size(&self) -> usize {
            self.size
        }

        fn storage_kind(&self) -> StorageKind {
            StorageKind::System
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
            None
        }
    }

    impl TensorDescriptor for TestTensor {
        fn shape(&self) -> &[usize] {
            &self.shape
        }

        fn stride(&self) -> &[usize] {
            &self.stride
        }

        fn element_size(&self) -> usize {
            self.element_size
        }
    }

    #[test]
    fn test_numel() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![12, 4, 1],
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 24);
    }

    #[test]
    fn test_ndim() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![12, 4, 1],
            element_size: 4,
        };
        assert_eq!(tensor.ndim(), 3);
    }

    #[test]
    fn test_is_contiguous_true() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![12, 4, 1],
            element_size: 4,
        };
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_is_contiguous_false() {
        // Outer stride 24 leaves gaps between rows — not row-major packed.
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![24, 4, 1],
            element_size: 4,
        };
        assert!(!tensor.is_contiguous());
    }

    #[test]
    fn test_contiguous_stride() {
        // contiguous_stride ignores the actual strides and derives the
        // canonical row-major strides from the shape alone.
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![24, 4, 1],
            element_size: 4,
        };
        assert_eq!(tensor.contiguous_stride(), vec![12, 4, 1]);
    }

    #[test]
    fn test_empty_tensor() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 0,
            shape: vec![],
            stride: vec![],
            element_size: 4,
        };
        // Product of an empty shape is 1 (0-dim scalar convention).
        assert_eq!(tensor.numel(), 1);
        assert_eq!(tensor.ndim(), 0);
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_1d_tensor_contiguous() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 10 * 4,
            shape: vec![10],
            stride: vec![1],
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 10);
        assert_eq!(tensor.ndim(), 1);
        assert!(tensor.is_contiguous());
        assert_eq!(tensor.contiguous_stride(), vec![1]);
    }

    #[test]
    fn test_1d_tensor_non_contiguous() {
        // Stride 2 over a size-10 dim skips every other element.
        let tensor = TestTensor {
            addr: 0x1000,
            size: 10 * 4,
            shape: vec![10],
            stride: vec![2],
            element_size: 4,
        };
        assert!(!tensor.is_contiguous());
    }

    #[test]
    fn test_2d_tensor() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 6 * 4,
            shape: vec![2, 3],
            stride: vec![3, 1],
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 6);
        assert_eq!(tensor.ndim(), 2);
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_high_dimensional_tensor() {
        let shape = vec![2, 3, 4, 5, 6];
        let stride = vec![360, 120, 30, 6, 1];
        let numel: usize = shape.iter().product();
        let tensor = TestTensor {
            addr: 0x1000,
            size: numel * 4,
            shape,
            stride,
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 720);
        assert_eq!(tensor.ndim(), 5);
        assert!(tensor.is_contiguous());
        assert_eq!(tensor.contiguous_stride(), vec![360, 120, 30, 6, 1]);
    }

    #[test]
    fn test_tensor_with_size_1_dimensions() {
        // Canonical strides for [1, 3, 1, 4] are [12, 4, 4, 1].
        let tensor = TestTensor {
            addr: 0x1000,
            size: 12 * 4,
            shape: vec![1, 3, 1, 4],
            stride: vec![12, 4, 4, 1],
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 12);
        assert_eq!(tensor.ndim(), 4);
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_contiguous_stride_empty() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 0,
            shape: vec![],
            stride: vec![],
            element_size: 4,
        };
        assert!(tensor.contiguous_stride().is_empty());
    }

    #[test]
    fn test_contiguous_stride_1d() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 5 * 4,
            shape: vec![5],
            stride: vec![1],
            element_size: 4,
        };
        assert_eq!(tensor.contiguous_stride(), vec![1]);
    }

    #[test]
    fn test_cuda_device_id_system() {
        // TestTensor reports StorageKind::System, so no CUDA device id.
        let tensor = TestTensor {
            addr: 0x1000,
            size: 100,
            shape: vec![10],
            stride: vec![1],
            element_size: 4,
        };
        assert_eq!(tensor.cuda_device_id(), None);
    }

    /// Tensor variant that reports device (VRAM) storage with a device index.
    #[derive(Debug)]
    struct DeviceTensor {
        addr: usize,
        size: usize,
        shape: Vec<usize>,
        stride: Vec<usize>,
        element_size: usize,
        device_id: u32,
    }

    impl MemoryDescriptor for DeviceTensor {
        fn addr(&self) -> usize {
            self.addr
        }

        fn size(&self) -> usize {
            self.size
        }

        fn storage_kind(&self) -> StorageKind {
            StorageKind::Device(self.device_id)
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
            None
        }
    }

    impl TensorDescriptor for DeviceTensor {
        fn shape(&self) -> &[usize] {
            &self.shape
        }

        fn stride(&self) -> &[usize] {
            &self.stride
        }

        fn element_size(&self) -> usize {
            self.element_size
        }
    }

    #[test]
    fn test_cuda_device_id_device() {
        let tensor = DeviceTensor {
            addr: 0x1000,
            size: 100,
            shape: vec![10],
            stride: vec![1],
            element_size: 4,
            device_id: 2,
        };
        assert_eq!(tensor.cuda_device_id(), Some(2));
    }

    #[test]
    fn test_arc_tensor_descriptor() {
        // The Arc<dyn TensorDescriptor> impls should forward everything
        // to the inner tensor.
        let tensor = TestTensor {
            addr: 0x1000,
            size: 24 * 4,
            shape: vec![2, 3, 4],
            stride: vec![12, 4, 1],
            element_size: 4,
        };
        let arc: Arc<dyn TensorDescriptor> = Arc::new(tensor);

        assert_eq!(arc.addr(), 0x1000);
        assert_eq!(arc.size(), 24 * 4);
        assert_eq!(arc.shape(), &[2, 3, 4]);
        assert_eq!(arc.stride(), &[12, 4, 1]);
        assert_eq!(arc.element_size(), 4);
        assert_eq!(arc.storage_kind(), StorageKind::System);
        assert!(arc.nixl_descriptor().is_none());
    }

    #[test]
    fn test_arc_tensor_send_sync() {
        // Local type to exercise the `Arc<dyn TensorDescriptor + Send + Sync>`
        // impls specifically.
        struct SendSyncTensor {
            addr: usize,
            size: usize,
            shape: Vec<usize>,
            stride: Vec<usize>,
            element_size: usize,
        }

        unsafe impl Send for SendSyncTensor {}
        unsafe impl Sync for SendSyncTensor {}

        impl std::fmt::Debug for SendSyncTensor {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("SendSyncTensor").finish()
            }
        }

        impl MemoryDescriptor for SendSyncTensor {
            fn addr(&self) -> usize {
                self.addr
            }
            fn size(&self) -> usize {
                self.size
            }
            fn storage_kind(&self) -> StorageKind {
                StorageKind::System
            }
            fn as_any(&self) -> &dyn Any {
                self
            }
            fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
                None
            }
        }

        impl TensorDescriptor for SendSyncTensor {
            fn shape(&self) -> &[usize] {
                &self.shape
            }
            fn stride(&self) -> &[usize] {
                &self.stride
            }
            fn element_size(&self) -> usize {
                self.element_size
            }
        }

        let tensor = SendSyncTensor {
            addr: 0x2000,
            size: 100,
            shape: vec![10],
            stride: vec![1],
            element_size: 4,
        };
        let arc: Arc<dyn TensorDescriptor + Send + Sync> = Arc::new(tensor);

        assert_eq!(arc.addr(), 0x2000);
        assert_eq!(arc.size(), 100);
        assert_eq!(arc.shape(), &[10]);
        assert_eq!(arc.stride(), &[1]);
        assert_eq!(arc.element_size(), 4);
    }

    #[test]
    fn test_tensor_shape_stride_element_size() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 48,
            shape: vec![3, 4],
            stride: vec![4, 1],
            element_size: 4,
        };
        assert_eq!(tensor.shape(), &[3, 4]);
        assert_eq!(tensor.stride(), &[4, 1]);
        assert_eq!(tensor.element_size(), 4);
    }

    #[test]
    fn test_tensor_numel_single_element() {
        let tensor = TestTensor {
            addr: 0x1000,
            size: 4,
            shape: vec![1, 1, 1],
            stride: vec![1, 1, 1],
            element_size: 4,
        };
        assert_eq!(tensor.numel(), 1);
    }
}