1use ndarray::{Array1, Array2, Array3};
2use ndarray_linalg::{Inverse, Solve, SVD};
3
/// Marker trait implemented by the zero-sized score-selector types below
/// (`LogScore`, `CRPScore`, ...); used as a type parameter to pick a
/// scoring rule at compile time.
pub trait Score: 'static + Clone + Copy + Default {}
6
/// Marker type selecting the log score (uncensored targets).
#[derive(Clone, Copy, Default)]
pub struct LogScore;
impl Score for LogScore {}
11
/// Marker type selecting the CRP score (uncensored targets).
// NOTE(review): presumably the continuous ranked probability score — confirm.
#[derive(Clone, Copy, Default)]
pub struct CRPScore;
impl Score for CRPScore {}
16
/// Marker type selecting the log score for censored targets
/// (see [`SurvivalData`] / [`CensoredScorable`]).
#[derive(Clone, Copy, Default)]
pub struct LogScoreCensored;
impl Score for LogScoreCensored {}
21
/// Marker type selecting the CRP score for censored targets
/// (see [`SurvivalData`] / [`CensoredScorable`]).
#[derive(Clone, Copy, Default)]
pub struct CRPScoreCensored;
impl Score for CRPScoreCensored {}
26
/// Survival-analysis targets: parallel arrays of observation times and
/// event indicators, one entry per observation.
#[derive(Debug, Clone)]
pub struct SurvivalData {
    /// Event indicator per observation. `true` = event observed,
    /// `false` = censored (inferred from `uncensored()` setting all
    /// entries `true` — confirm censoring convention with score impls).
    pub event: Array1<bool>,
    /// Observed (or censoring) time per observation.
    pub time: Array1<f64>,
}
35
36impl SurvivalData {
37 pub fn new(event: Array1<bool>, time: Array1<f64>) -> Self {
39 SurvivalData { event, time }
40 }
41
42 pub fn from_arrays(time: &Array1<f64>, event: &Array1<f64>) -> Self {
44 let event_bool = event.mapv(|e| e > 0.5);
45 SurvivalData {
46 event: event_bool,
47 time: time.clone(),
48 }
49 }
50
51 pub fn uncensored(time: Array1<f64>) -> Self {
53 let event = Array1::from_elem(time.len(), true);
54 SurvivalData { event, time }
55 }
56
57 pub fn len(&self) -> usize {
59 self.time.len()
60 }
61
62 pub fn is_empty(&self) -> bool {
64 self.time.is_empty()
65 }
66}
67
68pub trait CensoredScorable<S: Score> {
70 fn censored_score(&self, y: &SurvivalData) -> Array1<f64>;
72
73 fn censored_d_score(&self, y: &SurvivalData) -> Array2<f64>;
75
76 fn censored_metric(&self) -> Array3<f64>;
78
79 fn total_censored_score(&self, y: &SurvivalData, sample_weight: Option<&Array1<f64>>) -> f64 {
81 let scores = self.censored_score(y);
82 if let Some(weights) = sample_weight {
83 (scores * weights).sum() / weights.sum()
84 } else {
85 scores.mean().unwrap_or(0.0)
86 }
87 }
88
89 fn censored_grad(&self, y: &SurvivalData, natural: bool) -> Array2<f64>
91 where
92 Self: Sized,
93 {
94 let grad = self.censored_d_score(y);
95 if !natural {
96 return grad;
97 }
98
99 let metric = self.censored_metric();
100 let n_obs = grad.nrows();
101 let mut natural_grad = Array2::zeros(grad.raw_dim());
102
103 for i in 0..n_obs {
104 let g_i = grad.row(i).to_owned();
105 let metric_i = metric.index_axis(ndarray::Axis(0), i).to_owned();
106
107 if let Ok(ng_i) = metric_i.solve_into(g_i.clone()) {
109 if ng_i.iter().all(|&v| v.is_finite()) {
110 natural_grad.row_mut(i).assign(&ng_i);
111 continue;
112 }
113 }
114
115 if let Ok(inv_metric_i) = metric_i.inv() {
117 let result = inv_metric_i.dot(&grad.row(i));
118 if result.iter().all(|&v| v.is_finite()) {
119 natural_grad.row_mut(i).assign(&result);
120 continue;
121 }
122 }
123
124 if let Some(pinv_metric_i) = pinv(&metric_i) {
126 let result = pinv_metric_i.dot(&grad.row(i));
127 if result.iter().all(|&v| v.is_finite()) {
128 natural_grad.row_mut(i).assign(&result);
129 continue;
130 }
131 }
132
133 natural_grad.row_mut(i).assign(&(&grad.row(i) * 0.99));
135 }
136 natural_grad
137 }
138}
139
140pub fn natural_gradient_regularized(
152 grad: &Array2<f64>,
153 metric: &Array3<f64>,
154 reg: f64,
155) -> Array2<f64> {
156 let n_obs = grad.nrows();
157 let n_params = grad.ncols();
158 let mut natural_grad = Array2::zeros(grad.raw_dim());
159
160 for i in 0..n_obs {
161 let g_i = grad.row(i).to_owned();
162 let mut metric_i = metric.index_axis(ndarray::Axis(0), i).to_owned();
163
164 if reg > 0.0 {
166 for j in 0..n_params {
167 metric_i[[j, j]] += reg;
168 }
169 }
170
171 if let Ok(ng_i) = metric_i.solve_into(g_i.clone()) {
173 if ng_i.iter().all(|&v| v.is_finite()) {
174 natural_grad.row_mut(i).assign(&ng_i);
175 continue;
176 }
177 }
178
179 if let Ok(inv_metric_i) = metric_i.inv() {
181 let result = inv_metric_i.dot(&grad.row(i));
182 if result.iter().all(|&v| v.is_finite()) {
183 natural_grad.row_mut(i).assign(&result);
184 continue;
185 }
186 }
187
188 if let Some(pinv_metric_i) = pinv(&metric_i) {
190 let result = pinv_metric_i.dot(&grad.row(i));
191 if result.iter().all(|&v| v.is_finite()) {
192 natural_grad.row_mut(i).assign(&result);
193 continue;
194 }
195 }
196
197 natural_grad.row_mut(i).assign(&(&grad.row(i) * 0.99));
199 }
200 natural_grad
201}
202
203fn pinv(matrix: &Array2<f64>) -> Option<Array2<f64>> {
206 let rcond = 1e-15; let (u, s, vt) = matrix.svd(true, true).ok()?;
211 let u = u?;
212 let vt = vt?;
213
214 let max_sv = s.iter().cloned().fold(0.0_f64, f64::max);
216 let cutoff = rcond * max_sv;
217
218 let s_pinv: Array1<f64> = s.mapv(|sv| if sv > cutoff { 1.0 / sv } else { 0.0 });
220
221 let n = s_pinv.len();
224 let mut result = Array2::zeros((vt.ncols(), u.nrows()));
225
226 for i in 0..n {
227 for j in 0..u.nrows() {
228 for k in 0..vt.ncols() {
229 result[[k, j]] += vt[[i, k]] * s_pinv[i] * u[[j, i]];
230 }
231 }
232 }
233
234 Some(result)
235}
236
237pub trait Scorable<S: Score> {
239 fn score(&self, y: &Array1<f64>) -> Array1<f64>;
241
242 fn d_score(&self, y: &Array1<f64>) -> Array2<f64>;
244
245 fn metric(&self) -> Array3<f64>;
247
248 fn total_score(&self, y: &Array1<f64>, sample_weight: Option<&Array1<f64>>) -> f64 {
250 let scores = self.score(y);
251 if let Some(weights) = sample_weight {
252 (scores * weights).sum() / weights.sum()
253 } else {
254 scores.mean().unwrap_or(0.0)
255 }
256 }
257
258 fn grad(&self, y: &Array1<f64>, natural: bool) -> Array2<f64> {
264 let grad = self.d_score(y);
265 if !natural {
266 return grad;
267 }
268
269 let metric = self.metric();
270 let n_obs = grad.nrows();
271 let mut natural_grad = Array2::zeros(grad.raw_dim());
272
273 for i in 0..n_obs {
274 let g_i = grad.row(i).to_owned();
275 let metric_i = metric.index_axis(ndarray::Axis(0), i).to_owned();
276
277 if let Ok(ng_i) = metric_i.solve_into(g_i.clone()) {
279 if ng_i.iter().all(|&v| v.is_finite()) {
281 natural_grad.row_mut(i).assign(&ng_i);
282 continue;
283 }
284 }
285
286 if let Ok(inv_metric_i) = metric_i.inv() {
288 let result = inv_metric_i.dot(&grad.row(i));
289 if result.iter().all(|&v| v.is_finite()) {
290 natural_grad.row_mut(i).assign(&result);
291 continue;
292 }
293 }
294
295 if let Some(pinv_metric_i) = pinv(&metric_i) {
297 let result = pinv_metric_i.dot(&grad.row(i));
298 if result.iter().all(|&v| v.is_finite()) {
299 natural_grad.row_mut(i).assign(&result);
300 continue;
301 }
302 }
303
304 natural_grad.row_mut(i).assign(&(&grad.row(i) * 0.99));
307 }
308 natural_grad
309 }
310}