/// A data point exposing a feature slice, a scalar target, and a scalar weight.
///
/// Implementors must provide [`features`](Observation::features) and
/// [`target`](Observation::target); [`weight`](Observation::weight) comes with
/// a default implementation.
pub trait Observation {
    /// Borrows the feature values of this observation.
    fn features(&self) -> &[f64];

    /// Returns the target value of this observation.
    fn target(&self) -> f64;

    /// Returns the weight of this observation.
    ///
    /// Implementors that do not override this method report a weight of `1.0`.
    fn weight(&self) -> f64 {
        1.0_f64
    }
}
32
/// A lightweight, copyable observation that borrows its feature values.
///
/// All three fields are public, so a value can be built with a struct literal
/// as well as through the constructors. Equality compares the feature values
/// element-wise together with `target` and `weight`, using `f64` semantics
/// (so a `NaN` anywhere makes two samples unequal).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SampleRef<'a> {
    /// Feature values, borrowed for the lifetime `'a`.
    pub features: &'a [f64],
    /// Target value associated with the features.
    pub target: f64,
    /// Weight attached to this observation.
    pub weight: f64,
}
50
51impl<'a> SampleRef<'a> {
52 #[inline]
54 pub fn new(features: &'a [f64], target: f64) -> Self {
55 Self {
56 features,
57 target,
58 weight: 1.0,
59 }
60 }
61
62 #[inline]
64 pub fn weighted(features: &'a [f64], target: f64, weight: f64) -> Self {
65 Self {
66 features,
67 target,
68 weight,
69 }
70 }
71}
72
73impl<'a> Observation for SampleRef<'a> {
74 #[inline]
75 fn features(&self) -> &[f64] {
76 self.features
77 }
78 #[inline]
79 fn target(&self) -> f64 {
80 self.target
81 }
82 #[inline]
83 fn weight(&self) -> f64 {
84 self.weight
85 }
86}
87
88impl Observation for (&[f64], f64) {
93 #[inline]
94 fn features(&self) -> &[f64] {
95 self.0
96 }
97 #[inline]
98 fn target(&self) -> f64 {
99 self.1
100 }
101}
102
/// Lets an owned `(features, target)` pair act as an [`Observation`].
///
/// Only compiled when the `alloc` feature is enabled. `weight` is not
/// overridden here, so the trait's default implementation is used.
#[cfg(feature = "alloc")]
impl Observation for (alloc::vec::Vec<f64>, f64) {
    /// Borrows the contents of the vector in the first tuple element.
    #[inline]
    fn features(&self) -> &[f64] {
        let (values, _) = self;
        values.as_slice()
    }

    /// Returns the scalar stored in the second tuple element.
    #[inline]
    fn target(&self) -> f64 {
        let (_, target) = self;
        *target
    }
}
114
/// Lets a `(features, target)` pair holding a borrowed vector act as an
/// [`Observation`].
///
/// Only compiled when the `alloc` feature is enabled. `weight` is not
/// overridden here, so the trait's default implementation is used.
#[cfg(feature = "alloc")]
impl Observation for (&alloc::vec::Vec<f64>, f64) {
    /// Borrows the contents of the vector referenced by the first tuple element.
    #[inline]
    fn features(&self) -> &[f64] {
        let (values, _) = *self;
        values.as_slice()
    }

    /// Returns the scalar stored in the second tuple element.
    #[inline]
    fn target(&self) -> f64 {
        let (_, target) = *self;
        target
    }
}
126
127impl<'a> From<(&'a [f64], f64)> for SampleRef<'a> {
132 fn from((features, target): (&'a [f64], f64)) -> Self {
133 SampleRef::new(features, target)
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// `SampleRef::new` keeps features and target and reports a weight of 1.0.
    #[test]
    fn test_sample_ref_new() {
        let features = [1.0, 2.0, 3.0];
        let obs = SampleRef::new(&features, 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    /// `SampleRef::weighted` keeps the caller-supplied weight.
    #[test]
    fn test_sample_ref_weighted() {
        let features = [1.0, 2.0];
        let obs = SampleRef::weighted(&features, 5.0, 0.5);
        assert_eq!(obs.features(), &[1.0, 2.0]);
        assert_eq!(obs.target(), 5.0);
        assert!((obs.weight() - 0.5).abs() < 1e-12);
    }

    /// A `(&[f64], f64)` tuple works as an observation with the default weight.
    #[test]
    fn test_tuple_slice_observation() {
        let features = [1.0, 2.0, 3.0];
        let obs = (&features[..], 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    /// Converting a tuple into a `SampleRef` keeps both elements and yields
    /// the same weight as `SampleRef::new`.
    #[test]
    fn test_from_tuple_to_sample_ref() {
        let features = [1.0, 2.0];
        let obs: SampleRef = (&features[..], 5.0).into();
        assert_eq!(obs.features(), &[1.0, 2.0]);
        assert_eq!(obs.target(), 5.0);
        // The conversion goes through `SampleRef::new`, so the weight must be 1.0.
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    /// An owned `(Vec<f64>, f64)` tuple works as an observation with the
    /// default weight.
    #[cfg(feature = "alloc")]
    #[test]
    fn test_vec_tuple_observation() {
        let obs = (alloc::vec![1.0, 2.0, 3.0], 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    /// A `(&Vec<f64>, f64)` tuple works as an observation with the default
    /// weight.
    #[cfg(feature = "alloc")]
    #[test]
    fn test_vec_ref_tuple_observation() {
        let v = alloc::vec![1.0, 2.0, 3.0];
        let obs = (&v, 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }
}