quantrs2_ml/pytorch_api/

//! Loss functions for PyTorch-like API

use crate::error::{MLError, Result};
use crate::scirs2_integration::SciRS2Array;
use scirs2_core::ndarray::{ArrayD, IxDyn};

/// Loss functions for quantum ML
pub trait QuantumLoss: Send + Sync {
    /// Compute the scalar loss for the given predictions and targets
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array>;

    /// Human-readable name of the loss function
    fn name(&self) -> &str;
}

/// Mean Squared Error loss
pub struct QuantumMSELoss;

impl QuantumLoss for QuantumMSELoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let diff = predictions.data.clone() - &targets.data;
        let squared_diff = &diff * &diff;
        let mse = squared_diff.mean().ok_or_else(|| {
            MLError::InvalidConfiguration("Cannot compute mean of empty array".to_string())
        })?;

        // Reduce to a 0-d scalar, matching PyTorch's default "mean" reduction.
        let loss_data = ArrayD::from_elem(IxDyn(&[]), mse);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "MSELoss"
    }
}

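// A minimal hand-worked sanity check for QuantumMSELoss; a sketch that
// assumes SciRS2Array::new(data, requires_grad) and its `.data` field behave
// as used above, with the expected value computed by hand.
#[cfg(test)]
mod mse_loss_tests {
    use super::*;

    #[test]
    fn mse_of_unit_errors_is_one() {
        // Predictions [0, 2] vs targets [1, 1]: both squared errors are 1,
        // so the mean squared error is exactly 1.0.
        let preds = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.0, 2.0]).unwrap(),
            false,
        );
        let targets = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 1.0]).unwrap(),
            false,
        );
        let loss = QuantumMSELoss.forward(&preds, &targets).unwrap();
        let v = *loss.data.iter().next().unwrap();
        assert!((v - 1.0).abs() < 1e-12);
    }
}
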
/// Cross Entropy loss
pub struct QuantumCrossEntropyLoss;

impl QuantumLoss for QuantumCrossEntropyLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        // Numerically stable log-softmax: subtracting the max before
        // exponentiating prevents overflow, and computing log-softmax
        // directly (rather than ln(softmax)) avoids ln(0) when a probability
        // underflows. Note that the softmax is taken over the entire
        // flattened array, i.e. as a single distribution.
        let max_val = predictions
            .data
            .iter()
            .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_preds = predictions.data.mapv(|x| (x - max_val).exp());
        let sum_exp = exp_preds.sum();
        let log_softmax = predictions.data.mapv(|x| x - max_val - sum_exp.ln());
        let cross_entropy = -(&targets.data * &log_softmax).sum();

        let loss_data = ArrayD::from_elem(IxDyn(&[]), cross_entropy);
        Ok(SciRS2Array::new(loss_data, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "CrossEntropyLoss"
    }
}

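// A hand-worked check for QuantumCrossEntropyLoss under the same assumptions
// as the MSE test above: equal logits give a uniform softmax, so the loss
// against a one-hot target is ln(2).
#[cfg(test)]
mod cross_entropy_loss_tests {
    use super::*;

    #[test]
    fn uniform_softmax_gives_ln_two() {
        let preds = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 1.0]).unwrap(),
            false,
        );
        let targets = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 0.0]).unwrap(),
            false,
        );
        let loss = QuantumCrossEntropyLoss.forward(&preds, &targets).unwrap();
        let v = *loss.data.iter().next().unwrap();
        assert!((v - 2.0f64.ln()).abs() < 1e-12);
    }
}
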
/// Binary Cross Entropy Loss
///
/// Expects predictions to be probabilities in [0, 1].
pub struct QuantumBCELoss;

impl QuantumLoss for QuantumBCELoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        // Clamp probabilities away from 0 and 1 so ln() stays finite.
        let eps = 1e-7;
        let mut loss = 0.0;
        let n = predictions.data.len() as f64;

        for (pred, target) in predictions.data.iter().zip(targets.data.iter()) {
            let p = pred.clamp(eps, 1.0 - eps);
            loss -= target * p.ln() + (1.0 - target) * (1.0 - p).ln();
        }

        // 0-d scalar output, consistent with the other losses.
        let output = ArrayD::from_elem(IxDyn(&[]), loss / n);
        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "BCELoss"
    }
}

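// A hand-worked check for QuantumBCELoss, same assumptions as above. With
// predictions [0.8, 0.2] and targets [1, 0], both terms equal -ln(0.8), so
// the mean is -ln(0.8) (the clamp is inactive at these values).
#[cfg(test)]
mod bce_loss_tests {
    use super::*;

    #[test]
    fn bce_matches_hand_computation() {
        let preds = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.8, 0.2]).unwrap(),
            false,
        );
        let targets = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 0.0]).unwrap(),
            false,
        );
        let loss = QuantumBCELoss.forward(&preds, &targets).unwrap();
        let v = *loss.data.iter().next().unwrap();
        assert!((v + 0.8f64.ln()).abs() < 1e-12);
    }
}
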
/// Binary Cross Entropy with Logits Loss
///
/// Combines a sigmoid with BCE in a numerically stable form.
pub struct QuantumBCEWithLogitsLoss;

impl QuantumLoss for QuantumBCEWithLogitsLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let mut loss = 0.0;
        let n = predictions.data.len() as f64;

        for (logit, target) in predictions.data.iter().zip(targets.data.iter()) {
            // Stable form of -[t*ln(sigmoid(l)) + (1-t)*ln(1-sigmoid(l))]:
            // max(l, 0) - l*t + ln(1 + exp(-|l|)).
            let max_val = logit.max(0.0);
            loss += max_val - logit * target + (1.0 + (-logit.abs()).exp()).ln();
        }

        // 0-d scalar output, consistent with the other losses.
        let output = ArrayD::from_elem(IxDyn(&[]), loss / n);
        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "BCEWithLogitsLoss"
    }
}

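// A hand-worked check for QuantumBCEWithLogitsLoss, same assumptions as
// above. A logit of 0 is sigmoid probability 0.5, so the BCE against
// target 1 is -ln(0.5) = ln(2).
#[cfg(test)]
mod bce_with_logits_loss_tests {
    use super::*;

    #[test]
    fn zero_logit_gives_ln_two() {
        let preds = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[1]), vec![0.0]).unwrap(),
            false,
        );
        let targets = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[1]), vec![1.0]).unwrap(),
            false,
        );
        let loss = QuantumBCEWithLogitsLoss.forward(&preds, &targets).unwrap();
        let v = *loss.data.iter().next().unwrap();
        assert!((v - 2.0f64.ln()).abs() < 1e-12);
    }
}
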
/// L1 Loss (Mean Absolute Error)
pub struct QuantumL1Loss;

impl QuantumLoss for QuantumL1Loss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let mut loss = 0.0;
        let n = predictions.data.len() as f64;

        for (pred, target) in predictions.data.iter().zip(targets.data.iter()) {
            loss += (pred - target).abs();
        }

        // 0-d scalar output, consistent with the other losses.
        let output = ArrayD::from_elem(IxDyn(&[]), loss / n);
        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "L1Loss"
    }
}

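// A hand-worked check for QuantumL1Loss, same assumptions as above:
// absolute errors 1 and 2 average to 1.5.
#[cfg(test)]
mod l1_loss_tests {
    use super::*;

    #[test]
    fn mean_absolute_error() {
        let preds = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![1.0, 3.0]).unwrap(),
            false,
        );
        let targets = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![2.0, 1.0]).unwrap(),
            false,
        );
        let loss = QuantumL1Loss.forward(&preds, &targets).unwrap();
        let v = *loss.data.iter().next().unwrap();
        assert!((v - 1.5).abs() < 1e-12);
    }
}
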
/// Smooth L1 Loss (Huber Loss)
///
/// Quadratic for absolute errors below `beta`, linear above, which makes it
/// less sensitive to outliers than MSE.
pub struct QuantumSmoothL1Loss {
    beta: f64,
}

impl QuantumSmoothL1Loss {
    /// Create a new smooth L1 loss with the given transition point `beta`
    pub fn new(beta: f64) -> Self {
        Self { beta }
    }
}

impl Default for QuantumSmoothL1Loss {
    fn default() -> Self {
        Self { beta: 1.0 }
    }
}

impl QuantumLoss for QuantumSmoothL1Loss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let mut loss = 0.0;
        let n = predictions.data.len() as f64;

        for (pred, target) in predictions.data.iter().zip(targets.data.iter()) {
            let diff = (pred - target).abs();
            if diff < self.beta {
                loss += 0.5 * diff * diff / self.beta;
            } else {
                loss += diff - 0.5 * self.beta;
            }
        }

        // 0-d scalar output, consistent with the other losses.
        let output = ArrayD::from_elem(IxDyn(&[]), loss / n);
        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "SmoothL1Loss"
    }
}

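// A hand-worked check for QuantumSmoothL1Loss with the default beta = 1.0,
// same assumptions as above: |0.5| falls in the quadratic branch
// (0.5 * 0.25 = 0.125), |2.0| in the linear branch (2.0 - 0.5 = 1.5),
// so the mean is 0.8125.
#[cfg(test)]
mod smooth_l1_loss_tests {
    use super::*;

    #[test]
    fn quadratic_and_linear_branches() {
        let preds = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.5, 2.0]).unwrap(),
            false,
        );
        let targets = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.0, 0.0]).unwrap(),
            false,
        );
        let loss = QuantumSmoothL1Loss::default()
            .forward(&preds, &targets)
            .unwrap();
        let v = *loss.data.iter().next().unwrap();
        assert!((v - 0.8125).abs() < 1e-12);
    }
}
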
/// Negative Log Likelihood Loss
///
/// Expects predictions to be log-probabilities, e.g. the output of a
/// log-softmax layer.
pub struct QuantumNLLLoss;

impl QuantumLoss for QuantumNLLLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let shape = predictions.data.shape();
        if shape.len() != 2 {
            return Err(MLError::InvalidConfiguration(
                "NLLLoss expects 2D predictions (batch_size, num_classes)".to_string(),
            ));
        }

        let batch_size = shape[0];
        let num_classes = shape[1];
        let mut loss = 0.0;

        for b in 0..batch_size {
            // Targets hold class indices; reject out-of-range classes instead
            // of panicking on the array index below.
            let target_class = targets.data[[b]] as usize;
            if target_class >= num_classes {
                return Err(MLError::InvalidConfiguration(format!(
                    "Target class {} out of range for {} classes",
                    target_class, num_classes
                )));
            }
            loss -= predictions.data[[b, target_class]];
        }

        // 0-d scalar output, consistent with the other losses.
        let output = ArrayD::from_elem(IxDyn(&[]), loss / batch_size as f64);
        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "NLLLoss"
    }
}

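// A hand-worked check for QuantumNLLLoss, same assumptions as above.
// Predictions are log-probabilities with shape (batch, classes); targets
// hold class indices, so the loss is the mean of 0.1 and 0.3.
#[cfg(test)]
mod nll_loss_tests {
    use super::*;

    #[test]
    fn picks_target_log_probabilities() {
        let preds = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![-0.1, -2.3, -1.2, -0.3]).unwrap(),
            false,
        );
        let targets = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.0, 1.0]).unwrap(),
            false,
        );
        let loss = QuantumNLLLoss.forward(&preds, &targets).unwrap();
        let v = *loss.data.iter().next().unwrap();
        assert!((v - 0.2).abs() < 1e-12);
    }
}
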
/// Kullback-Leibler Divergence Loss
///
/// Expects predictions to be log-probabilities (log q) and targets to be
/// probabilities (p), matching PyTorch's KLDivLoss convention.
pub struct QuantumKLDivLoss {
    reduction: String,
}

impl QuantumKLDivLoss {
    /// Create new KL divergence loss with "mean" reduction
    pub fn new() -> Self {
        Self {
            reduction: "mean".to_string(),
        }
    }

    /// Set reduction type ("mean" or "sum")
    pub fn reduction(mut self, reduction: &str) -> Self {
        self.reduction = reduction.to_string();
        self
    }
}

impl Default for QuantumKLDivLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantumLoss for QuantumKLDivLoss {
    fn forward(&self, predictions: &SciRS2Array, targets: &SciRS2Array) -> Result<SciRS2Array> {
        let mut loss = 0.0;

        for (log_q, p) in predictions.data.iter().zip(targets.data.iter()) {
            // Terms with p == 0 contribute nothing (0 * ln 0 := 0).
            if *p > 0.0 {
                loss += p * (p.ln() - log_q);
            }
        }

        // 0-d scalar output; unrecognized reduction strings fall back to "sum".
        let output = match self.reduction.as_str() {
            "sum" => ArrayD::from_elem(IxDyn(&[]), loss),
            "mean" => ArrayD::from_elem(IxDyn(&[]), loss / predictions.data.len() as f64),
            _ => ArrayD::from_elem(IxDyn(&[]), loss),
        };

        Ok(SciRS2Array::new(output, predictions.requires_grad))
    }

    fn name(&self) -> &str {
        "KLDivLoss"
    }
}
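
// A hand-worked check for QuantumKLDivLoss, same assumptions as above.
// Predictions are log-probabilities; with p = [0.5, 0.5] and
// q = [0.25, 0.75], the summed divergence is 0.5 * ln(4/3), and the
// default "mean" reduction halves it.
#[cfg(test)]
mod kl_div_loss_tests {
    use super::*;

    #[test]
    fn mean_reduced_divergence() {
        let preds = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.25f64.ln(), 0.75f64.ln()]).unwrap(),
            false,
        );
        let targets = SciRS2Array::new(
            ArrayD::from_shape_vec(IxDyn(&[2]), vec![0.5, 0.5]).unwrap(),
            false,
        );
        let loss = QuantumKLDivLoss::new().forward(&preds, &targets).unwrap();
        let v = *loss.data.iter().next().unwrap();
        assert!((v - 0.25 * (4.0f64 / 3.0).ln()).abs() < 1e-12);
    }
}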