1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
use std::sync::{Arc, RwLock};
use crate::{
Device, Shape, Tensor, TensorError,
tensor::{TensorData, TensorInner},
};
impl Tensor {
    /// Deep clone the data from this tensor to a new tensor.
    ///
    /// The value buffer and, when present, the gradient buffer are copied into freshly
    /// allocated storage. Later in-place updates to either tensor are therefore not observed
    /// by the other.
    ///
    /// # Notes
    /// * Because the cloned tensor is a new tensor, it will not be connected to the previous
    ///   computation graph (its parent list is empty).
    /// * A variable-backed tensor stays a variable in the clone; a plain tensor stays plain.
    ///
    /// # Returns
    /// * `Ok(tensor)` - The cloned tensor if successful.
    /// * `Err(TensorError)` - The error when cloning the tensor.
    ///
    /// # Errors
    /// Fails if one of the internal locks is poisoned or if copying a buffer fails.
    pub fn deep_clone(&self) -> Result<Self, TensorError> {
        let inner = match &*self.data.inner.read()? {
            // `detach` alone shares storage with the source tensor; `copy` allocates a new
            // buffer. Detaching first keeps `copy` from recording an op on the old graph.
            TensorInner::Tensor(tensor) => TensorInner::Tensor(tensor.detach().copy()?),
            // Building a `Var` from a detached (non-variable) tensor materializes a new buffer.
            TensorInner::Var(var) => {
                TensorInner::Var(candle_core::Var::from_tensor(&var.detach())?)
            }
        };
        let device = self.data.device.read()?.clone();
        // Copy the gradient too, so the clone's gradient does not alias the original's storage.
        let grad = self
            .data
            .grad
            .read()?
            .as_ref()
            .map(|grad| grad.detach().copy())
            .transpose()?;
        Ok(Self {
            data: Arc::new(TensorData {
                inner: RwLock::new(inner),
                device: RwLock::new(device),
                grad: RwLock::new(grad),
                parents: RwLock::new(vec![]),
            }),
        })
    }

    /// Create a new tensor with random values uniformly distributed in the specified range.
    ///
    /// When `grad_enabled` is true, the values are stored as a variable and a zero-filled
    /// gradient of the same shape is allocated. Otherwise a plain tensor with no gradient is
    /// created. The new tensor has no parents.
    ///
    /// # Parameters
    /// * `low` - The lower bound of the uniform distribution.
    /// * `high` - The upper bound of the uniform distribution.
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to store the tensor.
    /// * `grad_enabled` - Whether to enable gradient tracking for the tensor.
    ///
    /// # Returns
    /// * `Ok(tensor)` - The new tensor if successful.
    /// * `Err(TensorError)` - The error when creating the tensor.
    pub fn from_random_uniform<T>(
        low: T,
        high: T,
        shape: &Shape,
        device: &Device,
        grad_enabled: bool,
    ) -> Result<Self, TensorError>
    where
        T: candle_core::FloatDType,
    {
        // Only trainable (variable-backed) tensors carry a gradient buffer.
        let (inner, grad) = if grad_enabled {
            let var = candle_core::Var::rand(low, high, shape, device)?;
            let grad = var.zeros_like()?;
            (TensorInner::Var(var), Some(grad))
        } else {
            let tensor = candle_core::Tensor::rand(low, high, shape, device)?;
            (TensorInner::Tensor(tensor), None)
        };
        Ok(Self {
            data: Arc::new(TensorData {
                inner: RwLock::new(inner),
                device: RwLock::new(device.clone()),
                grad: RwLock::new(grad),
                parents: RwLock::new(vec![]),
            }),
        })
    }
}