1use crate::array::Array;
9use crate::autograd::{Tensor, ComputeNode, OpKind, is_grad_enabled};
10use crate::autograd::backward;
11use crate::ops;
12use anyhow::Result;
13
14impl Tensor {
15 pub fn add(&self, other: &Tensor) -> Result<Self> {
17 let result = ops::add(&self.data, &other.data)?;
19
20 if !is_grad_enabled() || (!self.requires_grad && !other.requires_grad) {
22 return Ok(Tensor::new(result, false));
23 }
24
25 let requires_grad = self.requires_grad || other.requires_grad;
27 let backward_fn = Box::new(backward::add_backward);
28
29 let node = ComputeNode::new(
30 OpKind::Add,
31 vec![self.clone(), other.clone()],
32 Some(backward_fn),
33 );
34
35 Ok(Tensor::from_operation(result, node, requires_grad))
36 }
37
38 pub fn sub(&self, other: &Tensor) -> Result<Self> {
40 let result = ops::sub(&self.data, &other.data)?;
41
42 if !is_grad_enabled() || (!self.requires_grad && !other.requires_grad) {
43 return Ok(Tensor::new(result, false));
44 }
45
46 let requires_grad = self.requires_grad || other.requires_grad;
47 let backward_fn = Box::new(backward::sub_backward);
48
49 let node = ComputeNode::new(
50 OpKind::Sub,
51 vec![self.clone(), other.clone()],
52 Some(backward_fn),
53 );
54
55 Ok(Tensor::from_operation(result, node, requires_grad))
56 }
57
58 pub fn mul(&self, other: &Tensor) -> Result<Self> {
60 let result = ops::mul(&self.data, &other.data)?;
63
64 if !is_grad_enabled() || (!self.requires_grad && !other.requires_grad) {
65 return Ok(Tensor::new(result, false));
66 }
67
68 let requires_grad = self.requires_grad || other.requires_grad;
69 let backward_fn = Box::new(backward::mul_backward);
70
71 let node = ComputeNode::new(
72 OpKind::Mul,
73 vec![self.clone(), other.clone()],
74 Some(backward_fn),
75 );
76
77 Ok(Tensor::from_operation(result, node, requires_grad))
78 }
79
80 pub fn div(&self, other: &Tensor) -> Result<Self> {
82 let result = ops::div(&self.data, &other.data)?;
83
84 if !is_grad_enabled() || (!self.requires_grad && !other.requires_grad) {
85 return Ok(Tensor::new(result, false));
86 }
87
88 let requires_grad = self.requires_grad || other.requires_grad;
89 let backward_fn = Box::new(backward::div_backward);
90
91 let node = ComputeNode::new(
92 OpKind::Div,
93 vec![self.clone(), other.clone()],
94 Some(backward_fn),
95 );
96
97 Ok(Tensor::from_operation(result, node, requires_grad))
98 }
99
100 pub fn matmul(&self, other: &Tensor) -> Result<Self> {
102 let result = ops::matmul(&self.data, &other.data)?;
104
105 if !is_grad_enabled() || (!self.requires_grad && !other.requires_grad) {
106 return Ok(Tensor::new(result, false));
107 }
108
109 let requires_grad = self.requires_grad || other.requires_grad;
110 let backward_fn = Box::new(backward::matmul_backward);
111
112 let node = ComputeNode::new(
113 OpKind::MatMul,
114 vec![self.clone(), other.clone()],
115 Some(backward_fn),
116 );
117
118 Ok(Tensor::from_operation(result, node, requires_grad))
119 }
120
121 pub fn relu(&self) -> Result<Self> {
123 let result = Array::new(
125 self.data.shape.clone(),
126 self.data.data.iter().map(|&x| x.max(0.0)).collect()
127 );
128
129 if !is_grad_enabled() || !self.requires_grad {
130 return Ok(Tensor::new(result, false));
131 }
132
133 let backward_fn = Box::new(backward::relu_backward);
134
135 let node = ComputeNode::new(
136 OpKind::ReLU,
137 vec![self.clone()],
138 Some(backward_fn),
139 );
140
141 Ok(Tensor::from_operation(result, node, true))
142 }
143
144 pub fn sigmoid(&self) -> Result<Self> {
146 let result = Array::new(
148 self.data.shape.clone(),
149 self.data.data.iter()
150 .map(|&x| 1.0 / (1.0 + (-x).exp()))
151 .collect()
152 );
153
154 if !is_grad_enabled() || !self.requires_grad {
155 return Ok(Tensor::new(result, false));
156 }
157
158 let backward_fn = Box::new(backward::sigmoid_backward);
159
160 let node = ComputeNode::new(
161 OpKind::Sigmoid,
162 vec![self.clone()],
163 Some(backward_fn),
164 );
165
166 Ok(Tensor::from_operation(result, node, true))
167 }
168
169 pub fn exp(&self) -> Result<Self> {
171 let result = Array::new(
172 self.data.shape.clone(),
173 self.data.data.iter().map(|&x| x.exp()).collect()
174 );
175
176 if !is_grad_enabled() || !self.requires_grad {
177 return Ok(Tensor::new(result, false));
178 }
179
180 let backward_fn = Box::new(backward::exp_backward);
181
182 let node = ComputeNode::new(
183 OpKind::Exp,
184 vec![self.clone()],
185 Some(backward_fn),
186 );
187
188 Ok(Tensor::from_operation(result, node, true))
189 }
190
191 pub fn log(&self) -> Result<Self> {
193 let result = Array::new(
194 self.data.shape.clone(),
195 self.data.data.iter().map(|&x| x.ln()).collect()
196 );
197
198 if !is_grad_enabled() || !self.requires_grad {
199 return Ok(Tensor::new(result, false));
200 }
201
202 let backward_fn = Box::new(backward::log_backward);
203
204 let node = ComputeNode::new(
205 OpKind::Log,
206 vec![self.clone()],
207 Some(backward_fn),
208 );
209
210 Ok(Tensor::from_operation(result, node, true))
211 }
212
213 pub fn sum(&self) -> Result<Self> {
215 let result = ops::sum(&self.data, None)?;
217
218 if !is_grad_enabled() || !self.requires_grad {
219 return Ok(Tensor::new(result, false));
220 }
221
222 let backward_fn = Box::new(backward::sum_backward);
223
224 let node = ComputeNode::new(
225 OpKind::Sum { axis: None },
226 vec![self.clone()],
227 Some(backward_fn),
228 );
229
230 Ok(Tensor::from_operation(result, node, true))
231 }
232
233 pub fn mean(&self) -> Result<Self> {
235 let sum_val: f32 = self.data.data.iter().sum();
236 let n = self.data.data.len() as f32;
237 let result = Array::new(vec![1], vec![sum_val / n]);
238
239 if !is_grad_enabled() || !self.requires_grad {
240 return Ok(Tensor::new(result, false));
241 }
242
243 let backward_fn = Box::new(backward::mean_backward);
244
245 let node = ComputeNode::new(
246 OpKind::Mean { axis: None },
247 vec![self.clone()],
248 Some(backward_fn),
249 );
250
251 Ok(Tensor::from_operation(result, node, true))
252 }
253
254 pub fn mse_loss(&self, target: &Tensor) -> Result<Self> {
256 let diff_squared: f32 = self.data.data.iter()
258 .zip(target.data.data.iter())
259 .map(|(p, t)| (p - t).powi(2))
260 .sum();
261 let n = self.data.data.len() as f32;
262 let result = Array::new(vec![1], vec![diff_squared / n]);
263
264 if !is_grad_enabled() || (!self.requires_grad && !target.requires_grad) {
265 return Ok(Tensor::new(result, false));
266 }
267
268 let requires_grad = self.requires_grad || target.requires_grad;
269 let backward_fn = Box::new(backward::mse_backward);
270
271 let node = ComputeNode::new(
272 OpKind::MSE,
273 vec![self.clone(), target.clone()],
274 Some(backward_fn),
275 );
276
277 Ok(Tensor::from_operation(result, node, requires_grad))
278 }
279
280 pub fn cross_entropy_loss(&self, target: &Tensor) -> Result<Self> {
282 if self.data.shape.len() != 2 || target.data.shape.len() != 2 {
284 return Err(anyhow::anyhow!("CrossEntropy expect 2D tensors [batch, classes]"));
285 }
286
287 let batch_size = self.data.shape[0];
288 let num_classes = self.data.shape[1];
289
290 let mut total_loss = 0.0;
291
292 for i in 0..batch_size {
294 let start = i * num_classes;
295 let end = start + num_classes;
296 let logits = &self.data.data[start..end];
297 let targets = &target.data.data[start..end];
298
299 let max_val = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
301 let exp_sum: f32 = logits.iter().map(|x| (x - max_val).exp()).sum();
302
303 let sample_loss: f32 = logits.iter()
306 .zip(targets.iter())
307 .map(|(&x, &t)| {
308 let log_softmax = (x - max_val) - exp_sum.ln();
309 -t * log_softmax
310 })
311 .sum();
312
313 total_loss += sample_loss;
314 }
315
316 let mean_loss = total_loss / batch_size as f32;
318 let result = Array::new(vec![1], vec![mean_loss]);
319
320 if !is_grad_enabled() || (!self.requires_grad && !target.requires_grad) {
321 return Ok(Tensor::new(result, false));
322 }
323
324 let requires_grad = self.requires_grad || target.requires_grad;
325 let backward_fn = Box::new(backward::cross_entropy_backward);
326
327 let node = ComputeNode::new(
328 OpKind::CrossEntropy,
329 vec![self.clone(), target.clone()],
330 Some(backward_fn),
331 );
332
333 Ok(Tensor::from_operation(result, node, requires_grad))
334 }
335
336 pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Self> {
338 let result = ops::flatten(&self.data, start_dim, end_dim)?;
339
340 if !is_grad_enabled() || !self.requires_grad {
341 return Ok(Tensor::new(result, false));
342 }
343
344 let backward_fn = Box::new(backward::flatten_backward);
345 let node = ComputeNode::new(
346 OpKind::Flatten { start_dim, end_dim },
347 vec![self.clone()],
348 Some(backward_fn),
349 );
350 Ok(Tensor::from_operation(result, node, true))
351 }
352
353 pub fn reshape(&self, shape: Vec<usize>) -> Result<Self> {
355 let shape_isize: Vec<isize> = shape.iter().map(|&x| x as isize).collect();
356 let result = ops::reshape(&self.data, &shape_isize)?;
357
358 if !is_grad_enabled() || !self.requires_grad {
359 return Ok(Tensor::new(result, false));
360 }
361
362 let backward_fn = Box::new(backward::reshape_backward);
363 let node = ComputeNode::new(
364 OpKind::Reshape { shape },
365 vec![self.clone()],
366 Some(backward_fn),
367 );
368 Ok(Tensor::from_operation(result, node, true))
369 }
370
371 pub fn conv1d(&self, weight: &Tensor, bias: Option<&Tensor>, stride: usize, padding: usize) -> Result<Self> {
373 let bias_data = bias.map(|b| &b.data);
374 let result = ops::conv::conv1d(&self.data, &weight.data, bias_data, stride, padding)?;
375
376 let mut inputs = vec![self.clone(), weight.clone()];
377 let mut requires_grad = self.requires_grad || weight.requires_grad;
378
379 if let Some(b) = bias {
380 inputs.push(b.clone());
381 requires_grad = requires_grad || b.requires_grad;
382 }
383
384 if !is_grad_enabled() || !requires_grad {
385 return Ok(Tensor::new(result, false));
386 }
387
388 let backward_fn = Box::new(backward::conv1d_backward);
389 let node = ComputeNode::new(
390 OpKind::Conv1D { stride, padding },
391 inputs,
392 Some(backward_fn),
393 );
394 Ok(Tensor::from_operation(result, node, true))
395 }
396
397 #[allow(clippy::too_many_arguments)]
399 pub fn batch_norm(&self, running_mean: &mut Tensor, running_var: &mut Tensor, weight: &Tensor, bias: &Tensor, training: bool, momentum: f32, eps: f32) -> Result<Self> {
400 let result = ops::batchnorm::batch_norm(
401 &self.data,
402 &mut running_mean.data,
403 &mut running_var.data,
404 &weight.data,
405 &bias.data,
406 training,
407 momentum,
408 eps
409 )?;
410
411 let requires_grad = self.requires_grad || weight.requires_grad || bias.requires_grad;
412
413 if !is_grad_enabled() || !requires_grad {
414 return Ok(Tensor::new(result, false));
415 }
416
417 let inputs = vec![
420 self.clone(),
421 weight.clone(),
422 bias.clone(),
423 running_mean.clone(),
424 running_var.clone()
425 ];
426 let backward_fn = Box::new(backward::batchnorm_backward);
427
428 let node = ComputeNode::new(
429 OpKind::BatchNorm { training, momentum, eps },
430 inputs,
431 Some(backward_fn),
432 );
433 Ok(Tensor::from_operation(result, node, true))
434 }
435
436 pub fn dropout(&self, p: f32, training: bool) -> Result<Self> {
438 let result = ops::dropout::dropout(&self.data, p, training)?;
439
440 if !is_grad_enabled() || !self.requires_grad {
441 return Ok(Tensor::new(result, false));
442 }
443
444 let backward_fn = Box::new(backward::dropout_backward);
445 let node = ComputeNode::new(
446 OpKind::Dropout { p, training },
447 vec![self.clone()],
448 Some(backward_fn),
449 );
450 Ok(Tensor::from_operation(result, node, true))
451 }
452
453 pub fn pow(&self, exponent: f32) -> Result<Self> {
455 let exponent_arr = Array::new(vec![1], vec![exponent]);
456 let result = ops::pow(&self.data, &exponent_arr)?;
457
458 if !is_grad_enabled() || !self.requires_grad {
459 return Ok(Tensor::new(result, false));
460 }
461
462 let backward_fn = Box::new(backward::pow_backward);
463 let node = ComputeNode::new(
464 OpKind::Pow(exponent),
465 vec![self.clone()],
466 Some(backward_fn),
467 );
468 Ok(Tensor::from_operation(result, node, true))
469 }
470
471 pub fn sqrt(&self) -> Result<Self> {
473 let result = ops::sqrt(&self.data)?;
474
475 if !is_grad_enabled() || !self.requires_grad {
476 return Ok(Tensor::new(result, false));
477 }
478
479 let backward_fn = Box::new(backward::sqrt_backward);
480 let node = ComputeNode::new(
481 OpKind::Sqrt,
482 vec![self.clone()],
483 Some(backward_fn),
484 );
485 Ok(Tensor::from_operation(result, node, true))
486 }
487
488 pub fn sin(&self) -> Result<Self> {
490 let result = ops::sin(&self.data)?;
491
492 if !is_grad_enabled() || !self.requires_grad {
493 return Ok(Tensor::new(result, false));
494 }
495
496 let backward_fn = Box::new(backward::sin_backward);
497 let node = ComputeNode::new(
498 OpKind::Sin,
499 vec![self.clone()],
500 Some(backward_fn),
501 );
502 Ok(Tensor::from_operation(result, node, true))
503 }
504
505 pub fn cos(&self) -> Result<Self> {
507 let result = ops::cos(&self.data)?;
508
509 if !is_grad_enabled() || !self.requires_grad {
510 return Ok(Tensor::new(result, false));
511 }
512
513 let backward_fn = Box::new(backward::cos_backward);
514 let node = ComputeNode::new(
515 OpKind::Cos,
516 vec![self.clone()],
517 Some(backward_fn),
518 );
519 Ok(Tensor::from_operation(result, node, true))
520 }
521
522 pub fn tan(&self) -> Result<Self> {
524 let result = ops::tan(&self.data)?;
525
526 if !is_grad_enabled() || !self.requires_grad {
527 return Ok(Tensor::new(result, false));
528 }
529
530 let backward_fn = Box::new(backward::tan_backward);
531 let node = ComputeNode::new(
532 OpKind::Tan,
533 vec![self.clone()],
534 Some(backward_fn),
535 );
536 Ok(Tensor::from_operation(result, node, true))
537 }
538
539 pub fn tanh(&self) -> Result<Self> {
541 let result = ops::tanh(&self.data)?;
542
543 if !is_grad_enabled() || !self.requires_grad {
544 return Ok(Tensor::new(result, false));
545 }
546
547 let backward_fn = Box::new(backward::tanh_backward);
548 let node = ComputeNode::new(
549 OpKind::Tanh,
550 vec![self.clone()],
551 Some(backward_fn),
552 );
553 Ok(Tensor::from_operation(result, node, true))
554 }
555}