quantrs2_ml/pytorch_api/
use crate::error::{MLError, Result};
use crate::scirs2_integration::SciRS2Array;
use scirs2_core::ndarray::{ArrayD, IxDyn};

/// Common interface implemented by all loss functions in this module.
pub trait QuantumLoss: Send + Sync {
    /// Computes the loss of `predictions` against `targets`, returning it
    /// as a scalar-valued `SciRS2Array`.
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array>;

    /// Returns the human-readable name of this loss function.
    fn name(&self) -> &str;
}

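/// Mean squared error: L = (1/n) * sum_i (pred_i - target_i)^2.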
pub struct QuantumMSELoss;

impl QuantumLoss for QuantumMSELoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let diff = predictions.data.clone() - &targets.data;
        let squared_diff = &diff * &diff;
        let mse = squared_diff.mean().ok_or_else(|| {
            MLError::InvalidConfiguration("Cannot compute mean of empty array".to_string())
        })?;

        let loss_data = ArrayD::from_elem(IxDyn(&[]), mse);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "MSELoss"
    }
}

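/// Cross-entropy with a numerically stable softmax taken over all elements
/// of `predictions` (not per batch row): L = -sum_i target_i * ln(softmax(pred)_i).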
pub struct QuantumCrossEntropyLoss;

impl QuantumLoss for QuantumCrossEntropyLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        // Subtract the maximum before exponentiating to avoid overflow.
        let max_val = predictions
            .data
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_preds = predictions.data.mapv(|x| (x - max_val).exp());
        let sum_exp = exp_preds.sum();

        // Compute log-softmax directly as x - max - ln(sum(exp(x - max))),
        // which avoids ln(0) when individual softmax entries underflow.
        let log_sum_exp = sum_exp.ln();
        let log_softmax = predictions.data.mapv(|x| x - max_val - log_sum_exp);
        let cross_entropy = -(&targets.data * &log_softmax).sum();

        let loss_data = ArrayD::from_elem(IxDyn(&[]), cross_entropy);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "CrossEntropyLoss"
    }
}

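/// Binary cross-entropy on probabilities, with inputs clamped to
/// [eps, 1 - eps]: L = -(1/n) * sum_i [y_i * ln(p_i) + (1 - y_i) * ln(1 - p_i)].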
pub struct QuantumBCELoss;

impl QuantumLoss for QuantumBCELoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let eps = 1e-7;
        let mut loss = 0.0;
        let n = predictions.data.len() as f64;

        for (pred, target) in predictions.data.iter().zip(targets.data.iter()) {
            // Clamp to keep ln() away from 0 and 1.
            let p = pred.clamp(eps, 1.0 - eps);
            loss -= target * p.ln() + (1.0 - target) * (1.0 - p).ln();
        }

        let output = ArrayD::from_elem(IxDyn(&[1]), loss / n);
        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "BCELoss"
    }
}

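/// Binary cross-entropy on raw logits, using the numerically stable form
/// max(x, 0) - x * y + ln(1 + exp(-|x|)), equivalent to BCE(sigmoid(x), y).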
pub struct QuantumBCEWithLogitsLoss;

impl QuantumLoss for QuantumBCEWithLogitsLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let mut loss = 0.0;
        let n = predictions.data.len() as f64;

        for (logit, target) in predictions.data.iter().zip(targets.data.iter()) {
            let max_val = logit.max(0.0);
            loss += max_val - logit * target + (1.0 + (-logit.abs()).exp()).ln();
        }

        let output = ArrayD::from_elem(IxDyn(&[1]), loss / n);
        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "BCEWithLogitsLoss"
    }
}

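/// Mean absolute error: L = (1/n) * sum_i |pred_i - target_i|.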
pub struct QuantumL1Loss;

impl QuantumLoss for QuantumL1Loss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let mut loss = 0.0;
        let n = predictions.data.len() as f64;

        for (pred, target) in predictions.data.iter().zip(targets.data.iter()) {
            loss += (pred - target).abs();
        }

        let output = ArrayD::from_elem(IxDyn(&[1]), loss / n);
        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "L1Loss"
    }
}

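/// Smooth L1 (Huber-style) loss with d = |pred - target|: quadratic
/// (0.5 * d^2 / beta) for d < beta, linear (d - 0.5 * beta) otherwise.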
pub struct QuantumSmoothL1Loss {
    beta: f64,
}

impl QuantumSmoothL1Loss {
    pub fn new(beta: f64) -> Self {
        Self { beta }
    }
}

impl Default for QuantumSmoothL1Loss {
    fn default() -> Self {
        Self { beta: 1.0 }
    }
}

impl QuantumLoss for QuantumSmoothL1Loss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let mut loss = 0.0;
        let n = predictions.data.len() as f64;

        for (pred, target) in predictions.data.iter().zip(targets.data.iter()) {
            let diff = (pred - target).abs();
            if diff < self.beta {
                loss += 0.5 * diff * diff / self.beta;
            } else {
                loss += diff - 0.5 * self.beta;
            }
        }

        let output = ArrayD::from_elem(IxDyn(&[1]), loss / n);
        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "SmoothL1Loss"
    }
}

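/// Negative log-likelihood over log-probabilities: expects `predictions` of
/// shape (batch_size, num_classes) and `targets` holding class indices;
/// L = -(1/B) * sum_b predictions[b, target_b].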
pub struct QuantumNLLLoss;

impl QuantumLoss for QuantumNLLLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let shape = predictions.data.shape();
        if shape.len() != 2 {
            return Err(MLError::InvalidConfiguration(
                "NLLLoss expects 2D predictions (batch_size, num_classes)".to_string(),
            ));
        }

        let batch_size = shape[0];
        let num_classes = shape[1];
        let mut loss = 0.0;

        for b in 0..batch_size {
            let target_class = targets.data[[b]] as usize;
            // Reject out-of-range class indices instead of panicking on the
            // subsequent array access.
            if target_class >= num_classes {
                return Err(MLError::InvalidConfiguration(format!(
                    "target class {} is out of range for {} classes",
                    target_class, num_classes
                )));
            }
            loss -= predictions.data[[b, target_class]];
        }

        let output = ArrayD::from_elem(IxDyn(&[1]), loss / batch_size as f64);
        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "NLLLoss"
    }
}

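/// KL divergence KL(p || q) = sum_i p_i * (ln(p_i) - log_q_i), where
/// `predictions` holds log-probabilities log_q and `targets` holds
/// probabilities p; terms with p_i = 0 contribute nothing.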
pub struct QuantumKLDivLoss {
    reduction: String,
}

impl QuantumKLDivLoss {
    pub fn new() -> Self {
        Self {
            reduction: "mean".to_string(),
        }
    }

    /// Builder-style setter for the reduction mode ("mean" or "sum").
    pub fn reduction(mut self, reduction: &str) -> Self {
        self.reduction = reduction.to_string();
        self
    }
}

impl Default for QuantumKLDivLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantumLoss for QuantumKLDivLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let mut loss = 0.0;

        for (log_q, p) in predictions.data.iter().zip(targets.data.iter()) {
            // Skip zero-probability targets: lim_{p -> 0} p * ln(p) = 0.
            if *p > 0.0 {
                loss += p * (p.ln() - log_q);
            }
        }

        let output = match self.reduction.as_str() {
            "sum" => ArrayD::from_elem(IxDyn(&[1]), loss),
            "mean" => ArrayD::from_elem(IxDyn(&[1]), loss / predictions.data.len() as f64),
            // Unrecognized reduction modes fall back to the unreduced sum.
            _ => ArrayD::from_elem(IxDyn(&[1]), loss),
        };

        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "KLDivLoss"
    }
}
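
// A minimal test sketch exercising two of the losses above against
// hand-computed values. It relies only on the `SciRS2Array::new(data,
// requires_grad)` constructor already used throughout this file.
#[cfg(test)]
mod tests {
    use super::*;

    // Every loss here returns a single element, either 0-d or shape [1].
    fn scalar_of(arr: &SciRS2Array) -> f64 {
        *arr.data.iter().next().expect("loss array should be non-empty")
    }

    fn array_1d(values: Vec<f64>) -> SciRS2Array {
        let len = values.len();
        SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[len]), values).expect("shape matches data"),
            false,
        )
    }

    #[test]
    fn mse_matches_hand_computation() {
        let preds = array_1d(vec![1.0, 3.0]);
        let targets = array_1d(vec![0.0, 1.0]);
        let loss = QuantumMSELoss.forward(&preds, &targets).unwrap();
        // ((1 - 0)^2 + (3 - 1)^2) / 2 = 2.5
        assert!((scalar_of(&loss) - 2.5).abs() < 1e-12);
    }

    #[test]
    fn l1_matches_hand_computation() {
        let preds = array_1d(vec![1.0, 3.0]);
        let targets = array_1d(vec![0.0, 1.0]);
        let loss = QuantumL1Loss.forward(&preds, &targets).unwrap();
        // (|1 - 0| + |3 - 1|) / 2 = 1.5
        assert!((scalar_of(&loss) - 1.5).abs() < 1e-12);
    }
}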