Skip to main content

irithyll_core/
sample.rs

1//! Core observation trait and zero-copy sample types.
2
3// ---------------------------------------------------------------------------
4// Observation trait -- zero-copy training interface
5// ---------------------------------------------------------------------------
6
/// Trait for anything that can be used as a training observation.
///
/// Implementors provide a feature slice and a target value. The optional
/// [`weight`](Self::weight) method defaults to 1.0 for uniform weighting.
///
/// This trait enables zero-copy training: callers can pass borrowed slices
/// directly via [`SampleRef`] or tuple impls without allocating.
///
/// # Built-in implementations
///
/// | Type | Allocates? |
/// |------|-----------|
/// | `SampleRef<'a>` | No -- borrows `&[f64]` |
/// | `(&[f64], f64)` | No -- tuple of slice + target |
/// | `(Vec<f64>, f64)` | Owns `Vec<f64>` (requires `alloc` feature) |
/// | `(&Vec<f64>, f64)` | No -- borrows an existing `Vec<f64>` (requires `alloc` feature) |
///
/// Of the implementations listed above, only [`SampleRef`] carries a
/// per-sample weight; the tuple impls rely on the default unit weight.
pub trait Observation {
    /// The feature values for this observation.
    fn features(&self) -> &[f64];

    /// The target value (regression) or class label (classification).
    fn target(&self) -> f64;

    /// Optional sample weight. Defaults to 1.0 (uniform).
    ///
    /// Override this to give individual observations more or less influence.
    fn weight(&self) -> f64 {
        1.0
    }
}
32
33// ---------------------------------------------------------------------------
34// SampleRef -- zero-copy borrowing observation
35// ---------------------------------------------------------------------------
36
/// Zero-copy observation that borrows its feature vector.
///
/// Reach for this type when the features already live in a contiguous
/// `f64` slice (for example Arrow arrays, memory-mapped data, or reusable
/// buffers), so no `Vec<f64>` has to be allocated per sample.
///
/// The struct is `Copy`: it holds only a slice reference and two `f64`s.
#[derive(Clone, Copy, Debug)]
pub struct SampleRef<'a> {
    /// Feature values, borrowed from the caller for lifetime `'a`.
    pub features: &'a [f64],
    /// Regression target or class label.
    pub target: f64,
    /// Per-sample weight; the unit-weight constructor sets this to 1.0.
    pub weight: f64,
}
50
51impl<'a> SampleRef<'a> {
52    /// Create a new sample reference with unit weight.
53    #[inline]
54    pub fn new(features: &'a [f64], target: f64) -> Self {
55        Self {
56            features,
57            target,
58            weight: 1.0,
59        }
60    }
61
62    /// Create a new sample reference with explicit weight.
63    #[inline]
64    pub fn weighted(features: &'a [f64], target: f64, weight: f64) -> Self {
65        Self {
66            features,
67            target,
68            weight,
69        }
70    }
71}
72
73impl<'a> Observation for SampleRef<'a> {
74    #[inline]
75    fn features(&self) -> &[f64] {
76        self.features
77    }
78    #[inline]
79    fn target(&self) -> f64 {
80        self.target
81    }
82    #[inline]
83    fn weight(&self) -> f64 {
84        self.weight
85    }
86}
87
88// ---------------------------------------------------------------------------
89// Tuple impls -- quick-and-dirty observations
90// ---------------------------------------------------------------------------
91
92impl Observation for (&[f64], f64) {
93    #[inline]
94    fn features(&self) -> &[f64] {
95        self.0
96    }
97    #[inline]
98    fn target(&self) -> f64 {
99        self.1
100    }
101}
102
/// Owning `(features, target)` pair; the tuple holds its own `Vec<f64>`.
/// Uses the default unit weight.
#[cfg(feature = "alloc")]
impl Observation for (alloc::vec::Vec<f64>, f64) {
    #[inline]
    fn features(&self) -> &[f64] {
        self.0.as_slice()
    }

    #[inline]
    fn target(&self) -> f64 {
        let (_, target) = self;
        *target
    }
}
114
/// Borrows an existing `Vec<f64>`, so `(&v, y)` works without writing
/// `&v[..]`. Uses the default unit weight.
#[cfg(feature = "alloc")]
impl Observation for (&alloc::vec::Vec<f64>, f64) {
    #[inline]
    fn features(&self) -> &[f64] {
        self.0.as_slice()
    }

    #[inline]
    fn target(&self) -> f64 {
        let (_, target) = *self;
        target
    }
}
126
127// ---------------------------------------------------------------------------
128// From impls -- conversion convenience
129// ---------------------------------------------------------------------------
130
131impl<'a> From<(&'a [f64], f64)> for SampleRef<'a> {
132    fn from((features, target): (&'a [f64], f64)) -> Self {
133        SampleRef::new(features, target)
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Sums `weight * target` over heterogeneous observations through
    /// dynamic dispatch; guards against `Observation` losing object safety.
    fn weighted_target_sum(observations: &[&dyn Observation]) -> f64 {
        observations.iter().map(|o| o.weight() * o.target()).sum()
    }

    #[test]
    fn test_sample_ref_new() {
        let features = [1.0, 2.0, 3.0];
        let obs = SampleRef::new(&features, 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_sample_ref_weighted() {
        let features = [1.0, 2.0];
        let obs = SampleRef::weighted(&features, 5.0, 0.5);
        assert_eq!(obs.features(), &[1.0, 2.0]);
        assert_eq!(obs.target(), 5.0);
        assert!((obs.weight() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_sample_ref_is_copy() {
        let features = [1.0, 2.0];
        let a = SampleRef::weighted(&features, 3.0, 2.0);
        let b = a;
        // `a` is still usable after the assignment because SampleRef is Copy.
        assert_eq!(a.features(), b.features());
        assert_eq!(a.target(), b.target());
        assert_eq!(a.weight(), b.weight());
    }

    #[test]
    fn test_tuple_slice_observation() {
        let features = [1.0, 2.0, 3.0];
        let obs = (&features[..], 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_from_tuple_to_sample_ref() {
        let features = [1.0, 2.0];
        let obs: SampleRef = (&features[..], 5.0).into();
        assert_eq!(obs.features(), &[1.0, 2.0]);
        assert_eq!(obs.target(), 5.0);
        // The conversion must produce a unit-weight sample.
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_dyn_dispatch_mixed_observations() {
        let features = [1.0];
        let weighted = SampleRef::weighted(&features, 2.0, 3.0);
        let tuple = (&features[..], 4.0);
        let observations: [&dyn Observation; 2] = [&weighted, &tuple];
        // 3.0 * 2.0 + 1.0 * 4.0
        let total = weighted_target_sum(&observations);
        assert!((total - 10.0).abs() < 1e-12);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_vec_tuple_observation() {
        let obs = (alloc::vec![1.0, 2.0, 3.0], 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_vec_ref_tuple_observation() {
        let v = alloc::vec![1.0, 2.0, 3.0];
        let obs = (&v, 42.0);
        assert_eq!(obs.features(), &[1.0, 2.0, 3.0]);
        assert_eq!(obs.target(), 42.0);
        assert!((obs.weight() - 1.0).abs() < 1e-12);
    }
}