1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
use crate::autograd::BackwardOp;
use crate::storage::Storage;
use crate::Tensor;
use rayon::prelude::*;
use std::sync::Arc;
/// Backward op for the embedding lookup: scatter-adds the output gradient
/// rows back into the gradient of the embedding weight matrix.
#[derive(Debug)]
pub struct EmbeddingBackward {
    // Index tensor from the forward pass (indices encoded as f32).
    pub input: Tensor, // indices
    // The (NumEmbeddings, EmbeddingDim) weight table that receives the gradient.
    pub weight: Tensor,
    // Number of rows in the weight table; indices at or above this are skipped.
    pub num_embeddings: usize,
    // Length of each embedding vector (row width of the weight table).
    pub embedding_dim: usize,
    // Index whose rows receive no gradient (its forward output is all zeros).
    pub padding_idx: Option<usize>,
}
impl BackwardOp for EmbeddingBackward {
    /// Accumulates the gradient of the embedding lookup into `weight.grad`.
    ///
    /// `grad` is shaped (InputShape..., EmbeddingDim) and `self.input` carries
    /// the f32-encoded indices of shape (InputShape...). Each gradient row is
    /// scatter-added into the weight-gradient row selected by its index; rows
    /// whose index equals `padding_idx`, or that fall outside the table, are
    /// skipped (mirroring the forward lookup).
    fn backward(&self, grad: &Tensor) {
        if !self.weight.requires_grad() {
            return;
        }
        let dim = self.embedding_dim;
        // Dense (NumEmbeddings, EmbeddingDim) accumulator. A sparse update
        // would be cheaper for large vocabularies, but a dense buffer keeps
        // Tensor::accumulate_grad usable as-is.
        let mut acc = vec![0.0; self.num_embeddings * dim];
        // With the wgpu backend enabled, tensors may live on the GPU; copy
        // them to the CPU before reading their raw data.
        #[cfg(feature = "wgpu_backend")]
        let (grad, input) = {
            let g = if grad.storage().device().is_wgpu() {
                grad.to_cpu()
            } else {
                grad.clone()
            };
            let i = if self.input.storage().device().is_wgpu() {
                self.input.to_cpu()
            } else {
                self.input.clone()
            };
            (g, i)
        };
        #[cfg(not(feature = "wgpu_backend"))]
        let input = self.input.clone();
        let grad_guard = grad.data();
        let grad_data = &*grad_guard;
        // Indices are stored as f32; each one is cast to usize below.
        let index_guard = input.data();
        let index_data = &*index_guard;
        // Sanity check: exactly one gradient row of length `dim` per index.
        if grad_data.len() != index_data.len() * dim {
            panic!("Embedding backward shape mismatch");
        }
        // Serial scatter-add: parallelizing would require atomics (or
        // per-thread buffers) on `acc`, so this stays single-threaded.
        for (row, &raw_idx) in index_data.iter().enumerate() {
            let idx = raw_idx as usize;
            if self.padding_idx == Some(idx) {
                // Padding rows receive no gradient.
                continue;
            }
            if idx >= self.num_embeddings {
                // Out-of-range index: ignored, matching the forward pass.
                continue;
            }
            let src = &grad_data[row * dim..row * dim + dim];
            let dst = &mut acc[idx * dim..idx * dim + dim];
            for (d, &g) in dst.iter_mut().zip(src) {
                *d += g;
            }
        }
        let weight_grad = Tensor::new(&acc, self.weight.shape());
        self.weight.accumulate_grad(&weight_grad);
        self.weight.backward_step();
    }
}
/// Looks up embedding vectors for a tensor of indices.
///
/// `input` holds the indices (encoded as f32, arbitrary shape); `weight` is
/// the (NumEmbeddings, EmbeddingDim) table. The result has shape
/// (InputShape..., EmbeddingDim). Rows whose index equals `padding_idx` are
/// all zeros, and the output participates in autograd when `weight` requires
/// gradients (the indices themselves never receive gradients).
///
/// `_max_norm`, `_norm_type`, `_scale_grad_by_freq` and `_sparse` are
/// accepted for API parity but currently ignored.
///
/// # Panics
/// Panics when `weight` is not 2D, or when any index is negative or
/// `>= NumEmbeddings` (previously such indices silently produced a zero row
/// or read row 0, hiding caller bugs; PyTorch raises an error here).
pub fn embedding(
    input: &Tensor,
    weight: &Tensor,
    padding_idx: Option<usize>,
    _max_norm: Option<f32>,
    _norm_type: f32,
    _scale_grad_by_freq: bool,
    _sparse: bool,
) -> Tensor {
    let weight_shape = weight.shape();
    if weight_shape.len() != 2 {
        panic!("Embedding weight must be 2D");
    }
    let num_embeddings = weight_shape[0];
    let embedding_dim = weight_shape[1];
    // With the wgpu backend enabled, tensors may live on the GPU; copy them
    // to the CPU before reading their raw data.
    #[cfg(feature = "wgpu_backend")]
    let (input, weight) = {
        let i = if input.storage().device().is_wgpu() {
            input.to_cpu()
        } else {
            input.clone()
        };
        let w = if weight.storage().device().is_wgpu() {
            weight.to_cpu()
        } else {
            weight.clone()
        };
        (i, w)
    };
    #[cfg(not(feature = "wgpu_backend"))]
    let (input, weight) = (input.clone(), weight.clone());
    let input_guard = input.data();
    let input_data = &*input_guard;
    let weight_guard = weight.data();
    let weight_data = &*weight_guard;
    let num_indices = input_data.len();
    // Validate every index up front so the parallel lookup below can assume
    // they are in range. Note `f32 as usize` saturates negatives to 0, so the
    // negative case must be checked on the float itself.
    for &idx_f in input_data.iter() {
        if idx_f < 0.0 || (idx_f as usize) >= num_embeddings {
            panic!(
                "Index {} out of bounds for embedding size {}",
                idx_f, num_embeddings
            );
        }
    }
    let mut output_data = vec![0.0; num_indices * embedding_dim];
    // Parallel lookup: each output row copies its embedding vector.
    output_data
        .par_chunks_mut(embedding_dim)
        .enumerate()
        .for_each(|(i, out_row)| {
            let idx = input_data[i] as usize;
            if padding_idx == Some(idx) {
                // Padding rows stay zero (output_data is zero-initialized).
                return;
            }
            let weight_offset = idx * embedding_dim;
            out_row.copy_from_slice(&weight_data[weight_offset..weight_offset + embedding_dim]);
        });
    let mut output_shape = input.shape().to_vec();
    output_shape.push(embedding_dim);
    let storage = Storage::new(output_data);
    let mut tensor = Tensor::new_with_storage(storage, &output_shape);
    // Hook up autograd only when the weight participates in it.
    if weight.requires_grad() {
        tensor.set_requires_grad_mut(true);
        tensor.set_op(Arc::new(EmbeddingBackward {
            input: input.clone(),
            weight: weight.clone(),
            num_embeddings,
            embedding_dim,
            padding_idx,
        }));
    }
    tensor
}