Skip to main content

ndrs/tensor/
handle.rs

//! Tensor handle types (`RcTensor`, `ArcTensor`) and their basic methods.
2use super::Tensor;
3use crate::{DType, Device};
4use anyhow::{Context, Result, anyhow, bail};
5use parking_lot::ReentrantMutex;
6use std::cell::RefCell;
7use std::rc::Rc;
8use std::sync::Arc;
9
10/// 引用计数(非线程安全)张量句柄。
11#[derive(Clone, Debug)]
12pub struct RcTensor(pub Rc<RefCell<Tensor>>);
13
14impl RcTensor {
15    /// 获取内部 `RefCell<Tensor>` 的只读引用(不获取锁)
16    pub fn lock(&self) -> &RefCell<Tensor> {
17        &*self.0
18    }
19
20    /// 从 `Tensor` 创建新的 `RcTensor`
21    pub fn from_tensor(t: Tensor) -> Self {
22        RcTensor(Rc::new(RefCell::new(t)))
23    }
24
25    /// 获取内部句柄的克隆
26    pub fn into_inner(self) -> Rc<RefCell<Tensor>> {
27        self.0
28    }
29}
30
31impl From<Tensor> for RcTensor {
32    fn from(t: Tensor) -> Self {
33        RcTensor::from_tensor(t)
34    }
35}
36
37/// 原子引用计数(线程安全)张量句柄。
38#[derive(Clone, Debug)]
39pub struct ArcTensor(pub Arc<ReentrantMutex<RefCell<Tensor>>>);
40
41impl ArcTensor {
42    /// 获取内部数据的互斥锁守卫
43    pub fn lock(&self) -> parking_lot::ReentrantMutexGuard<RefCell<Tensor>> {
44        self.0.lock()
45    }
46
47    /// 从 `Tensor` 创建新的 `ArcTensor`
48    pub fn from_tensor(t: Tensor) -> Self {
49        ArcTensor(Arc::new(ReentrantMutex::new(RefCell::new(t))))
50    }
51
52    /// 获取内部句柄的克隆
53    pub fn into_inner(self) -> Arc<ReentrantMutex<RefCell<Tensor>>> {
54        self.0
55    }
56}
57
58impl From<Tensor> for ArcTensor {
59    fn from(t: Tensor) -> Self {
60        ArcTensor::from_tensor(t)
61    }
62}
63
64impl ArcTensor {
65    pub fn shape(&self) -> Vec<usize> {
66        self.0.lock().borrow().shape().to_vec()
67    }
68    pub fn dtype(&self) -> DType {
69        self.0.lock().borrow().dtype()
70    }
71    pub fn device(&self) -> Device {
72        self.0.lock().borrow().device()
73    }
74}