1use radiate_error::RadiateError;
2#[cfg(feature = "serde")]
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::hash::Hash;
6use std::iter::Sum;
7use std::ops::{Add, Div, Index, Mul, Sub};
8use std::sync::Arc;
9
10pub trait Scored {
13 fn score(&self) -> Option<&Score>;
14}
15
/// A fitness score made of one or more objective values.
///
/// The values live behind an `Arc<[f32]>`, so cloning a `Score` only bumps a
/// reference count. A score holding more than one value is multi-objective.
/// The `Default` score is empty (it holds no objectives).
#[derive(Clone, PartialEq, Default)]
// Same memory layout as the single `Arc<[f32]>` field.
#[repr(transparent)]
pub struct Score {
    values: Arc<[f32]>,
}

impl Score {
    /// Builds a score from a vector of objective values.
    ///
    /// # Panics
    ///
    /// Panics with `"Score value cannot be NaN"` if any value is `NaN`.
    pub fn from_vec(values: Vec<f32>) -> Self {
        if values.iter().any(|v| v.is_nan()) {
            panic!("Score value cannot be NaN");
        }

        Score {
            values: Arc::from(values),
        }
    }

    /// Returns `true` when the score holds more than one objective.
    pub fn is_multi_objective(&self) -> bool {
        self.values.len() > 1
    }

    /// Returns the objective at `idx`, or `None` if `idx` is out of range.
    pub fn objective(&self, idx: usize) -> Option<&f32> {
        self.values.get(idx)
    }

    /// Borrows all objective values as a slice.
    pub fn as_slice(&self) -> &[f32] {
        &self.values
    }

    /// Returns the first objective, or `f32::NAN` for an empty score.
    pub fn as_f32(&self) -> f32 {
        self.values.first().copied().unwrap_or(f32::NAN)
    }

    /// Returns the first objective widened to `f64`, or `NaN` for an empty score.
    pub fn as_f64(&self) -> f64 {
        f64::from(self.as_f32())
    }

    /// Returns the first objective cast to `i32`.
    ///
    /// Uses Rust's saturating float-to-int `as` cast, so an empty score (whose
    /// [`as_f32`](Self::as_f32) is `NaN`) yields `0` instead of panicking.
    pub fn as_i32(&self) -> i32 {
        self.as_f32() as i32
    }

    /// Formats the first objective with `f32`'s `Display`; an empty score gives `"NaN"`.
    pub fn as_string(&self) -> String {
        self.as_f32().to_string()
    }

    /// Returns the first objective cast to `usize`.
    ///
    /// The saturating `as` cast maps negative values and `NaN` (empty score) to `0`.
    pub fn as_usize(&self) -> usize {
        self.as_f32() as usize
    }

    /// Iterates over the objective values in order.
    pub fn iter(&self) -> impl Iterator<Item = &f32> + '_ {
        self.values.iter()
    }

    /// Number of objectives in the score.
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns `true` when the score holds no objectives.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }
}
81
82impl AsRef<[f32]> for Score {
83 fn as_ref(&self) -> &[f32] {
84 &self.values
85 }
86}
87
88impl PartialOrd for Score {
89 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
90 self.values.partial_cmp(&other.values)
91 }
92}
93
94impl Debug for Score {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 write!(f, "{:?}", self.values)
97 }
98}
99
100impl Hash for Score {
101 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
102 let mut hash: usize = 0;
103
104 for value in self.values.iter() {
105 let value_hash = value.to_bits();
106 hash = hash.wrapping_add(value_hash as usize);
107 }
108
109 hash.hash(state);
110 }
111}
112
113impl Index<usize> for Score {
114 type Output = f32;
115 fn index(&self, index: usize) -> &Self::Output {
116 &self.values[index]
117 }
118}
119
120impl Into<Vec<f32>> for Score {
121 fn into(self) -> Vec<f32> {
122 self.values.to_vec()
123 }
124}
125
126impl From<f32> for Score {
127 fn from(value: f32) -> Self {
128 if value.is_nan() {
129 panic!("Score value cannot be NaN")
130 }
131
132 Score {
133 values: Arc::from(vec![value]),
134 }
135 }
136}
137
138impl TryFrom<i16> for Score {
139 type Error = RadiateError;
140
141 fn try_from(value: i16) -> Result<Self, Self::Error> {
142 Ok(Score {
143 values: Arc::from(vec![value as f32]),
144 })
145 }
146}
147
148impl From<f64> for Score {
149 fn from(value: f64) -> Self {
150 if value.is_nan() {
151 panic!("Score value cannot be NaN")
152 }
153
154 Score {
155 values: Arc::from(vec![value as f32]),
156 }
157 }
158}
159
160impl From<i32> for Score {
161 fn from(value: i32) -> Self {
162 Score {
163 values: Arc::from(vec![value as f32]),
164 }
165 }
166}
167
168impl From<i64> for Score {
169 fn from(value: i64) -> Self {
170 Score {
171 values: Arc::from(vec![value as f32]),
172 }
173 }
174}
175
176impl From<usize> for Score {
177 fn from(value: usize) -> Self {
178 Score {
179 values: Arc::from(vec![value as f32]),
180 }
181 }
182}
183
184impl From<String> for Score {
185 fn from(value: String) -> Self {
186 Score {
187 values: Arc::from(vec![
188 value.parse::<f32>().expect("Failed to parse string to f32"),
189 ]),
190 }
191 }
192}
193
194impl From<&str> for Score {
195 fn from(value: &str) -> Self {
196 Score {
197 values: Arc::from(vec![
198 value.parse::<f32>().expect("Failed to parse string to f32"),
199 ]),
200 }
201 }
202}
203
204impl From<Vec<f32>> for Score {
205 fn from(value: Vec<f32>) -> Self {
206 Score::from_vec(value)
207 }
208}
209
210impl From<Vec<f64>> for Score {
211 fn from(value: Vec<f64>) -> Self {
212 Score::from_vec(value.into_iter().map(|v| v as f32).collect())
213 }
214}
215
216impl From<Vec<i32>> for Score {
217 fn from(value: Vec<i32>) -> Self {
218 Score::from_vec(value.into_iter().map(|v| v as f32).collect())
219 }
220}
221
222impl From<Vec<i64>> for Score {
223 fn from(value: Vec<i64>) -> Self {
224 Score::from_vec(value.into_iter().map(|v| v as f32).collect())
225 }
226}
227
228impl From<Vec<usize>> for Score {
229 fn from(value: Vec<usize>) -> Self {
230 Score::from_vec(value.into_iter().map(|v| v as f32).collect())
231 }
232}
233
234impl From<Vec<String>> for Score {
235 fn from(value: Vec<String>) -> Self {
236 Score::from_vec(
237 value
238 .into_iter()
239 .map(|v| v.parse::<f32>().unwrap())
240 .collect(),
241 )
242 }
243}
244
245impl From<Vec<&str>> for Score {
246 fn from(value: Vec<&str>) -> Self {
247 Score::from_vec(
248 value
249 .into_iter()
250 .map(|v| v.parse::<f32>().unwrap())
251 .collect(),
252 )
253 }
254}
255
256impl Add for Score {
257 type Output = Self;
258
259 fn add(self, other: Self) -> Self {
260 if self.values.is_empty() {
261 return other;
262 }
263
264 let mut values = Vec::with_capacity(self.values.len());
265
266 for i in 0..self.values.len() {
267 values.push(self.values[i] + other.values[i]);
268 }
269
270 Score {
271 values: Arc::from(values),
272 }
273 }
274}
275
276impl Add<f32> for Score {
277 type Output = Self;
278
279 fn add(self, other: f32) -> Self {
280 if self.values.is_empty() {
281 return Score::from(other);
282 }
283
284 let mut values = Vec::with_capacity(self.values.len());
285 for i in 0..self.values.len() {
286 values.push(self.values[i] + other);
287 }
288
289 Score {
290 values: Arc::from(self.values),
291 }
292 }
293}
294
295impl Sub for Score {
296 type Output = Self;
297
298 fn sub(self, other: Self) -> Self {
299 if self.values.is_empty() {
300 return other;
301 }
302
303 let mut values = Vec::with_capacity(self.values.len());
304
305 for i in 0..self.values.len() {
306 values.push(self.values[i] - other.values[i]);
307 }
308
309 Score {
310 values: Arc::from(values),
311 }
312 }
313}
314
315impl Sub<f32> for Score {
316 type Output = Self;
317
318 fn sub(self, other: f32) -> Self {
319 if self.values.is_empty() {
320 return Score::from(-other);
321 }
322
323 let mut values = Vec::with_capacity(self.values.len());
324 for i in 0..self.values.len() {
325 values.push(self.values[i] - other);
326 }
327
328 Score {
329 values: Arc::from(values),
330 }
331 }
332}
333
334impl Mul for Score {
335 type Output = Self;
336
337 fn mul(self, other: Self) -> Self {
338 if self.values.is_empty() {
339 return other;
340 }
341
342 let mut values = Vec::with_capacity(self.values.len());
343 for i in 0..self.values.len() {
344 values.push(self.values[i] * other.values[i]);
345 }
346
347 Score {
348 values: Arc::from(values),
349 }
350 }
351}
352
353impl Mul<f32> for Score {
354 type Output = Self;
355
356 fn mul(self, other: f32) -> Self {
357 if self.values.is_empty() {
358 return Score::from(other);
359 }
360
361 let mut values = Vec::with_capacity(self.values.len());
362 for i in 0..self.values.len() {
363 values.push(self.values[i] * other);
364 }
365
366 Score {
367 values: Arc::from(values),
368 }
369 }
370}
371
372impl Mul<Score> for f32 {
373 type Output = Score;
374
375 fn mul(self, other: Score) -> Score {
376 if other.values.is_empty() {
377 return Score::from(self);
378 }
379
380 let mut values = Vec::with_capacity(other.values.len());
381 for i in 0..other.values.len() {
382 values.push(other.values[i] * self);
383 }
384
385 Score {
386 values: Arc::from(values),
387 }
388 }
389}
390
391impl Div for Score {
392 type Output = Self;
393
394 fn div(self, other: Self) -> Self {
395 if self.values.is_empty() {
396 return other;
397 }
398
399 let mut values = Vec::with_capacity(self.values.len());
400 for i in 0..self.values.len() {
401 values.push(self.values[i] / other.values[i]);
402 }
403
404 Score {
405 values: Arc::from(values),
406 }
407 }
408}
409
410impl Div<f32> for Score {
411 type Output = Self;
412
413 fn div(self, other: f32) -> Self {
414 if self.values.is_empty() {
415 return Score::from(other);
416 }
417
418 let mut values = Vec::with_capacity(self.values.len());
419 for i in 0..self.values.len() {
420 values.push(self.values[i] / other);
421 }
422
423 Score {
424 values: Arc::from(values),
425 }
426 }
427}
428
429impl Sum for Score {
430 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
431 let mut values = vec![];
432
433 for score in iter {
434 for (i, value) in score.values.iter().enumerate() {
435 if values.len() <= i {
436 values.push(*value);
437 } else {
438 values[i] += value;
439 }
440 }
441 }
442
443 Score {
444 values: Arc::from(values),
445 }
446 }
447}
448
449impl<'a> Sum<&'a Score> for Score {
450 fn sum<I: Iterator<Item = &'a Score>>(iter: I) -> Self {
451 let mut values = vec![];
452
453 for score in iter {
454 for (i, value) in score.values.iter().enumerate() {
455 if values.len() <= i {
456 values.push(*value);
457 } else {
458 values[i] += value;
459 }
460 }
461 }
462
463 Score {
464 values: Arc::from(values),
465 }
466 }
467}
468
#[cfg(feature = "serde")]
impl Serialize for Score {
    /// Serializes the score as a plain sequence of its `f32` objective values.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        Serialize::serialize(&self.values[..], serializer)
    }
}
478
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Score {
    /// Deserializes a score from a sequence of `f32` values.
    ///
    /// Returns a custom deserialization error if any value is `NaN`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let values: Vec<f32> = Deserialize::deserialize(deserializer)?;

        if values.iter().any(|v| v.is_nan()) {
            return Err(serde::de::Error::custom("Score value cannot be NaN"));
        }

        Ok(Score {
            values: values.into(),
        })
    }
}
497
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_from_vec() {
        let score = Score::from(vec![1.0_f32, 2.0, 3.0]);
        assert_eq!(score.values.len(), 3);
    }

    #[test]
    fn test_score_from_usize() {
        // The suffix matters: an unsuffixed `3` falls back to `i32` and would
        // exercise `From<i32>` instead of `From<usize>`.
        let score = Score::from(3_usize);
        assert_eq!(score.values.len(), 1);
        assert_eq!(score.as_f32(), 3.0);
        assert_eq!(score.as_i32(), 3);
    }

    #[test]
    fn test_score_from_f32() {
        // Unsuffixed float literals default to `f64`, so pin the type explicitly.
        let score = Score::from(1.0_f32);
        assert_eq!(score.as_f32(), 1.0);
        assert_eq!(score.as_i32(), 1)
    }

    #[test]
    fn test_score_from_f64() {
        let score = Score::from(2.5_f64);
        assert_eq!(score.as_f32(), 2.5);
        assert_eq!(score.as_f64(), 2.5);
    }

    #[test]
    fn test_score_from_i32() {
        let score = Score::from(-5_i32);
        assert_eq!(score.as_f32(), -5.0);
        assert_eq!(score.as_i32(), -5);
    }

    #[test]
    fn test_score_from_str() {
        let score = Score::from("4.5");
        assert_eq!(score.as_f32(), 4.5);
    }

    #[test]
    fn test_score_add() {
        let score1 = Score::from(vec![1.0, 2.0, 3.0]);
        let score2 = Score::from(vec![4.0, 5.0, 6.0]);
        let score3 = score1 + score2;

        assert_eq!(score3.values.len(), 3);
        assert_eq!(score3.as_f32(), 5.0);
        assert_eq!(score3[0], 5.0);
        assert_eq!(score3[1], 7.0);
        assert_eq!(score3[2], 9.0);
    }

    #[test]
    fn test_score_sub() {
        let score1 = Score::from(vec![5.0, 7.0, 9.0]);
        let score2 = Score::from(vec![4.0, 5.0, 6.0]);
        let score3 = score1 - score2;
        assert_eq!(score3.values.len(), 3);
        assert_eq!(score3.as_f32(), 1.0);
        assert_eq!(score3[0], 1.0);
        assert_eq!(score3[1], 2.0);
        assert_eq!(score3[2], 3.0);
    }

    #[test]
    fn test_score_mul() {
        let score1 = Score::from(vec![1.0, 2.0, 3.0]);
        let score2 = Score::from(vec![4.0, 5.0, 6.0]);
        let score3 = score1 * score2;
        assert_eq!(score3.values.len(), 3);
        assert_eq!(score3.as_f32(), 4.0);
        assert_eq!(score3[0], 4.0);
        assert_eq!(score3[1], 10.0);
        assert_eq!(score3[2], 18.0);
    }

    #[test]
    fn test_score_div() {
        let score1 = Score::from(vec![4.0, 8.0, 12.0]);
        let score2 = Score::from(vec![2.0, 4.0, 6.0]);
        let score3 = score1 / score2;
        assert_eq!(score3.values.len(), 3);
        assert_eq!(score3.as_f32(), 2.0);
        assert_eq!(score3[0], 2.0);
        assert_eq!(score3[1], 2.0);
        assert_eq!(score3[2], 2.0);
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_score_can_serialize() {
        let score = Score::from(vec![1.0, 2.0, 3.0]);
        let serialized = serde_json::to_string(&score).expect("Failed to serialize Score");
        let deserialized: Score =
            serde_json::from_str(&serialized).expect("Failed to deserialize Score");
        assert_eq!(score, deserialized);
    }
}