Skip to main content

irithyll_core/
sample.rs

//! Core observation trait and zero-copy sample types.

// ---------------------------------------------------------------------------
// Observation trait -- zero-copy training interface
// ---------------------------------------------------------------------------

/// Trait for anything that can be used as a training observation.
///
/// Implementors provide a feature slice and a target value. The optional
/// [`weight`](Self::weight) method defaults to 1.0 for uniform weighting.
///
/// This trait enables zero-copy training: callers can pass borrowed slices
/// directly via [`SampleRef`] or tuple impls without allocating.
///
/// # Built-in implementations
///
/// | Type | Allocates? | Weight |
/// |------|-----------|--------|
/// | `SampleRef<'a>` | No -- borrows `&[f64]` | stored `weight` field |
/// | `(&[f64], f64)` | No -- tuple of slice + target | default (1.0) |
/// | `(Vec<f64>, f64)` | Owns `Vec<f64>` (requires `alloc` feature) | default (1.0) |
/// | `(&Vec<f64>, f64)` | No -- borrows the `Vec` (requires `alloc` feature) | default (1.0) |
/// | `Sample` | Owns `Vec<f64>` (requires `alloc` feature) | stored `weight` field |
/// | `&Sample` | No -- borrows a `Sample` (requires `alloc` feature) | stored `weight` field |
pub trait Observation {
    /// The feature values for this observation.
    fn features(&self) -> &[f64];
    /// The target value (regression) or class label (classification).
    fn target(&self) -> f64;
    /// Optional sample weight. Defaults to 1.0 (uniform).
    ///
    /// Implementors that carry their own weight override this method; the
    /// tuple implementations rely on this default.
    fn weight(&self) -> f64 {
        1.0
    }
}

// ---------------------------------------------------------------------------
// SampleRef -- zero-copy borrowing observation
// ---------------------------------------------------------------------------

/// A borrowed observation that avoids `Vec<f64>` allocation.
///
/// Use this when features are already available as a contiguous slice
/// (e.g., from Arrow arrays, memory-mapped data, or pre-allocated buffers).
///
/// The struct is `Copy`: it holds only a slice reference and two `f64`
/// values, so passing it by value never duplicates the feature data.
/// Equality compares the feature values, target and weight element-wise
/// with `f64` semantics (a `NaN` anywhere makes two samples unequal).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SampleRef<'a> {
    /// Borrowed feature slice.
    pub features: &'a [f64],
    /// Target value.
    pub target: f64,
    /// Sample weight (default 1.0).
    pub weight: f64,
}

51impl<'a> SampleRef<'a> {
52    /// Create a new sample reference with unit weight.
53    #[inline]
54    pub fn new(features: &'a [f64], target: f64) -> Self {
55        Self {
56            features,
57            target,
58            weight: 1.0,
59        }
60    }
61
62    /// Create a new sample reference with explicit weight.
63    #[inline]
64    pub fn weighted(features: &'a [f64], target: f64, weight: f64) -> Self {
65        Self {
66            features,
67            target,
68            weight,
69        }
70    }
71}
72
73impl<'a> Observation for SampleRef<'a> {
74    #[inline]
75    fn features(&self) -> &[f64] {
76        self.features
77    }
78    #[inline]
79    fn target(&self) -> f64 {
80        self.target
81    }
82    #[inline]
83    fn weight(&self) -> f64 {
84        self.weight
85    }
86}

// ---------------------------------------------------------------------------
// Tuple impls -- quick-and-dirty observations
// ---------------------------------------------------------------------------

92impl Observation for (&[f64], f64) {
93    #[inline]
94    fn features(&self) -> &[f64] {
95        self.0
96    }
97    #[inline]
98    fn target(&self) -> f64 {
99        self.1
100    }
101}
102
/// A `(features, target)` pair that owns its feature vector.
///
/// Only available with the `alloc` feature. The tuple has no weight
/// component, so [`Observation::weight`] keeps its default of 1.0.
#[cfg(feature = "alloc")]
impl Observation for (alloc::vec::Vec<f64>, f64) {
    /// Borrows the owned vector as a slice.
    #[inline]
    fn features(&self) -> &[f64] {
        &self.0
    }

    /// Returns the second tuple element.
    #[inline]
    fn target(&self) -> f64 {
        self.1
    }
}

/// A `(features, target)` pair that borrows an existing `Vec<f64>`.
///
/// Only available with the `alloc` feature. This lets callers write
/// `(&vec, target)` without slicing first. The tuple has no weight
/// component, so [`Observation::weight`] keeps its default of 1.0.
#[cfg(feature = "alloc")]
impl Observation for (&alloc::vec::Vec<f64>, f64) {
    /// Returns the borrowed vector's contents as a slice.
    #[inline]
    fn features(&self) -> &[f64] {
        self.0.as_slice()
    }

    /// Returns the second tuple element.
    #[inline]
    fn target(&self) -> f64 {
        self.1
    }
}

// ---------------------------------------------------------------------------
// Sample -- owned observation (requires alloc)
// ---------------------------------------------------------------------------

/// A single owned observation with feature vector and target value.
///
/// For regression, `target` is the continuous value to predict.
/// For binary classification, `target` is 0.0 or 1.0.
/// For multi-class, `target` is the class index as f64 (0.0, 1.0, 2.0, ...).
///
/// Only available with the `alloc` feature. With the `serde` feature the
/// type also derives `Serialize` and `Deserialize`. Equality compares
/// features, target and weight with `f64` semantics (a `NaN` anywhere makes
/// two samples unequal).
#[cfg(feature = "alloc")]
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Sample {
    /// Feature values for this observation.
    pub features: alloc::vec::Vec<f64>,
    /// Target value (regression) or class label (classification).
    pub target: f64,
    /// Optional sample weight (default 1.0).
    pub weight: f64,
}

#[cfg(feature = "alloc")]
impl Sample {
    /// Create a new sample with unit weight.
    ///
    /// Equivalent to [`Sample::weighted`] with a weight of `1.0`. The
    /// feature vector is moved in, not copied.
    #[inline]
    #[must_use]
    pub fn new(features: alloc::vec::Vec<f64>, target: f64) -> Self {
        Self::weighted(features, target, 1.0)
    }

    /// Create a new sample with explicit weight.
    ///
    /// The weight is stored exactly as given; no range or finiteness
    /// check is performed here.
    #[inline]
    #[must_use]
    pub fn weighted(features: alloc::vec::Vec<f64>, target: f64, weight: f64) -> Self {
        Self {
            features,
            target,
            weight,
        }
    }

    /// Number of features in this sample.
    #[inline]
    #[must_use]
    pub fn n_features(&self) -> usize {
        self.features.len()
    }
}

/// Exposes an owned [`Sample`] through the [`Observation`] interface.
///
/// The weight reported is the value held in the `weight` field rather than
/// the trait default.
#[cfg(feature = "alloc")]
impl Observation for Sample {
    /// Borrows the owned feature vector as a slice.
    #[inline]
    fn features(&self) -> &[f64] {
        self.features.as_slice()
    }

    /// Returns the stored target value.
    #[inline]
    fn target(&self) -> f64 {
        self.target
    }

    /// Returns the stored sample weight.
    #[inline]
    fn weight(&self) -> f64 {
        self.weight
    }
}

/// Lets a shared reference to a [`Sample`] be used wherever an
/// [`Observation`] is expected by value, so callers need not clone.
///
/// All three accessors read through the reference to the underlying
/// sample's fields.
#[cfg(feature = "alloc")]
impl Observation for &Sample {
    /// Borrows the referenced sample's feature vector as a slice.
    #[inline]
    fn features(&self) -> &[f64] {
        self.features.as_slice()
    }

    /// Returns the referenced sample's target value.
    #[inline]
    fn target(&self) -> f64 {
        self.target
    }

    /// Returns the referenced sample's weight.
    #[inline]
    fn weight(&self) -> f64 {
        self.weight
    }
}

// ---------------------------------------------------------------------------
// From impls -- conversion convenience
// ---------------------------------------------------------------------------

213impl<'a> From<(&'a [f64], f64)> for SampleRef<'a> {
214    fn from((features, target): (&'a [f64], f64)) -> Self {
215        SampleRef::new(features, target)
216    }
217}
218
/// Unit tests for the observation types defined in this module.
///
/// Tests that need `Vec` are gated on the `alloc` feature, matching the
/// items they exercise.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sample_ref_new() {
        let features = [1.0, 2.0, 3.0];
        let obs = SampleRef::new(&features, 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_sample_ref_weighted() {
        let features = [1.0, 2.0];
        let obs = SampleRef::weighted(&features, 5.0, 0.5);
        assert_eq!(obs.features(), &[1.0, 2.0]);
        assert_eq!(obs.target(), 5.0);
        assert!((obs.weight() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_sample_ref_is_copy() {
        let features = [1.0, 2.0];
        let first = SampleRef::weighted(&features, 5.0, 0.5);
        // Using `first` after the assignment only compiles because the type is `Copy`.
        let second = first;
        assert_eq!(first.features(), second.features());
        assert_eq!(first.target(), second.target());
        assert!((first.weight() - second.weight()).abs() < 1e-12);
    }

    #[test]
    fn test_sample_ref_empty_features() {
        let obs = SampleRef::new(&[], 0.0);
        assert!(obs.features().is_empty());
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_tuple_slice_observation() {
        let features = [1.0, 2.0, 3.0];
        let obs = (&features[..], 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_from_tuple_to_sample_ref() {
        let features = [1.0, 2.0];
        let obs: SampleRef = (&features[..], 5.0).into();
        assert_eq!(obs.features(), &[1.0, 2.0]);
        assert_eq!(obs.target(), 5.0);
        // The conversion goes through `SampleRef::new`, hence unit weight.
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_vec_tuple_observation() {
        let obs = (alloc::vec![1.0, 2.0, 3.0], 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_vec_ref_tuple_observation() {
        let v = alloc::vec![1.0, 2.0, 3.0];
        let obs = (&v, 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_sample_new() {
        let sample = Sample::new(alloc::vec![1.0, 2.0], 3.0);
        assert_eq!(sample.n_features(), 2);
        assert_eq!(sample.features(), &[1.0, 2.0]);
        assert_eq!(sample.target(), 3.0);
        assert!((sample.weight() - 1.0).abs() < 1e-12);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_sample_weighted() {
        let sample = Sample::weighted(alloc::vec![4.0], 1.0, 2.5);
        assert_eq!(sample.n_features(), 1);
        assert_eq!(sample.features(), &[4.0]);
        assert_eq!(sample.target(), 1.0);
        assert!((sample.weight() - 2.5).abs() < 1e-12);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_sample_reference_observation() {
        let sample = Sample::weighted(alloc::vec![4.0, 5.0], 1.0, 2.5);
        let by_ref = &sample;
        // Fully qualified calls pin the `&Sample` impl rather than auto-deref.
        assert_eq!(<&Sample as Observation>::features(&by_ref), &[4.0, 5.0]);
        assert_eq!(<&Sample as Observation>::target(&by_ref), 1.0);
        assert!((<&Sample as Observation>::weight(&by_ref) - 2.5).abs() < 1e-12);
    }
}