// irithyll/stream/mod.rs
1//! Async streaming infrastructure for tokio-native sample ingestion.
2//!
3//! This module provides the async interface for running an [`SGBT`] model as
4//! a long-lived training service. Samples arrive through a bounded channel,
5//! the model trains incrementally on each one, and concurrent read-only
6//! prediction access is available at all times via [`Predictor`] handles.
7//!
8//! # Architecture
9//!
10//! ```text
11//! ┌──────────┐     mpsc      ┌───────────┐
12//! │ Senders  │───(bounded)──>│ AsyncSGBT │  (write lock per sample)
//! └──────────┘    channel    │  .run()   │
14//!                            └─────┬─────┘
15//!                                  │
16//!                       Arc<RwLock<SGBT<L>>>
17//!                                  │
18//!                            ┌─────┴─────┐
19//!                            │ Predictor │  (read lock per predict)
20//!                            └───────────┘
21//! ```
22//!
23//! # Lifecycle
24//!
25//! 1. Create an `AsyncSGBT` via [`new`](AsyncSGBT::new) or
26//!    [`with_capacity`](AsyncSGBT::with_capacity).
27//! 2. Clone sender handles via [`sender`](AsyncSGBT::sender) and predictor
28//!    handles via [`predictor`](AsyncSGBT::predictor).
29//! 3. Spawn [`run`](AsyncSGBT::run) (or [`run_with_callback`](AsyncSGBT::run_with_callback))
30//!    on a tokio task. The loop starts by dropping its internal sender copy
31//!    so that the channel closes cleanly once all external senders are dropped.
32//! 4. Feed samples from any number of async tasks.
33//! 5. Drop all senders to signal shutdown; the training loop drains remaining
34//!    buffered samples and returns `Ok(())`.
35//!
36//! # Example
37//!
38//! ```no_run
39//! use irithyll::{SGBTConfig, Sample};
40//! use irithyll::stream::AsyncSGBT;
41//!
42//! # async fn example() -> irithyll::error::Result<()> {
43//! let config = SGBTConfig::builder()
44//!     .n_steps(50)
45//!     .learning_rate(0.1)
46//!     .build()?;
47//!
48//! let mut runner = AsyncSGBT::new(config);
49//! let sender = runner.sender();
50//! let predictor = runner.predictor();
51//!
52//! // Spawn the training loop.
53//! let train_handle = tokio::spawn(async move { runner.run().await });
54//!
55//! // Feed samples from any async context.
56//! sender.send(Sample::new(vec![1.0, 2.0], 3.0)).await?;
57//!
58//! // Predict concurrently while training proceeds.
59//! let pred = predictor.predict(&[1.0, 2.0]);
60//!
61//! // Drop sender to signal shutdown; training loop returns Ok(()).
62//! drop(sender);
63//! train_handle.await.unwrap()?;
64//! # Ok(())
65//! # }
66//! ```
67
68pub mod adapters;
69pub mod channel;
70
71use std::fmt;
72use std::sync::Arc;
73
74use parking_lot::RwLock;
75use tracing::debug;
76
77use crate::ensemble::config::SGBTConfig;
78use crate::ensemble::SGBT;
79use crate::error::Result;
80use crate::loss::squared::SquaredLoss;
81use crate::loss::Loss;
82
83pub use adapters::{Prediction, PredictionStream};
84pub use channel::{SampleReceiver, SampleSender};
85
/// Default bounded channel capacity when none is specified.
///
/// Used by [`AsyncSGBT::new`] and [`AsyncSGBT::with_loss`]; callers that need
/// a different bound should use the `*_capacity` constructors instead.
const DEFAULT_CHANNEL_CAPACITY: usize = 1024;
88
89// ---------------------------------------------------------------------------
90// Predictor
91// ---------------------------------------------------------------------------
92
/// A concurrent, read-only prediction handle to a shared [`SGBT`] model.
///
/// Obtained via [`AsyncSGBT::predictor`]. Each prediction acquires a read lock
/// on the underlying `RwLock<SGBT<L>>`, allowing multiple predictors to operate
/// concurrently and in parallel with the training loop (which holds a write
/// lock only briefly per sample).
///
/// `Predictor` is `Clone`, `Send`, and `Sync` -- share it freely across tasks.
pub struct Predictor<L: Loss = SquaredLoss> {
    /// Shared model handle; every accessor takes a short-lived read lock.
    pub(crate) model: Arc<RwLock<SGBT<L>>>,
}
104
105// Manual Clone impl -- cloning the Arc doesn't require L: Clone.
106impl<L: Loss> Clone for Predictor<L> {
107    fn clone(&self) -> Self {
108        Self {
109            model: Arc::clone(&self.model),
110        }
111    }
112}
113
114impl<L: Loss> fmt::Debug for Predictor<L> {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.debug_struct("Predictor")
117            .field("n_samples_seen", &self.model.read().n_samples_seen())
118            .finish()
119    }
120}
121
122impl<L: Loss> Predictor<L> {
123    /// Predict the raw model output for a feature vector.
124    ///
125    /// Acquires a read lock on the shared model. Returns the unscaled
126    /// ensemble prediction (base + weighted sum of tree outputs).
127    #[inline]
128    pub fn predict(&self, features: &[f64]) -> f64 {
129        self.model.read().predict(features)
130    }
131
132    /// Predict with the loss function's transform applied.
133    ///
134    /// For regression (squared loss) this is identity; for binary
135    /// classification (logistic loss) this applies the sigmoid.
136    #[inline]
137    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
138        self.model.read().predict_transformed(features)
139    }
140
141    /// Number of samples the model has been trained on so far.
142    #[inline]
143    pub fn n_samples_seen(&self) -> u64 {
144        self.model.read().n_samples_seen()
145    }
146
147    /// Whether the model's base prediction has been initialized.
148    ///
149    /// Returns `false` until enough initial samples have been collected
150    /// (typically 50) to compute the base constant.
151    #[inline]
152    pub fn is_initialized(&self) -> bool {
153        self.model.read().is_initialized()
154    }
155}
156
157// ---------------------------------------------------------------------------
158// AsyncSGBT
159// ---------------------------------------------------------------------------
160
/// Async wrapper around [`SGBT`] for tokio-native streaming training.
///
/// `AsyncSGBT` owns the shared model and the receiving end of a bounded
/// sample channel. Call [`run`](Self::run) to start the training loop,
/// which consumes samples from the channel and trains incrementally.
///
/// Generic over `L: Loss` so the training loop benefits from monomorphized
/// gradient/hessian dispatch (no vtable overhead).
///
/// Prediction handles ([`Predictor`]) and sender handles ([`SampleSender`])
/// can be obtained before starting the loop and used concurrently from
/// other tasks.
///
/// # Shutdown
///
/// When [`run`](Self::run) is called, it drops the internal sender copy
/// so that the channel closes as soon as all external senders are dropped.
/// The loop then drains any remaining buffered samples and returns `Ok(())`.
pub struct AsyncSGBT<L: Loss = SquaredLoss> {
    /// Shared model, protected by a parking_lot RwLock.
    /// `Predictor` handles hold clones of this same `Arc`.
    model: Arc<RwLock<SGBT<L>>>,
    /// Receiving end of the sample channel.
    /// Wrapped in Option so `run()` can consume it exactly once.
    receiver: Option<SampleReceiver>,
    /// Sending end, kept so callers can clone it via `sender()`.
    /// Wrapped in Option so `run()` can drop it before entering the loop,
    /// ensuring the channel closes when all external senders are dropped.
    sender: Option<SampleSender>,
}
189
190impl AsyncSGBT<SquaredLoss> {
191    /// Create a new async SGBT runner with the default channel capacity (1024).
192    ///
193    /// Uses squared loss (regression). For other loss functions, use
194    /// [`with_loss`](AsyncSGBT::with_loss) or [`with_loss_and_capacity`](AsyncSGBT::with_loss_and_capacity).
195    pub fn new(config: SGBTConfig) -> Self {
196        Self::with_capacity(config, DEFAULT_CHANNEL_CAPACITY)
197    }
198
199    /// Create a new async SGBT runner with a custom channel capacity.
200    ///
201    /// Uses squared loss (regression).
202    pub fn with_capacity(config: SGBTConfig, capacity: usize) -> Self {
203        let model = SGBT::new(config);
204        let shared = Arc::new(RwLock::new(model));
205        let (sender, receiver) = channel::bounded(capacity);
206
207        Self {
208            model: shared,
209            receiver: Some(receiver),
210            sender: Some(sender),
211        }
212    }
213}
214
215impl<L: Loss> AsyncSGBT<L> {
216    /// Create a new async SGBT runner with a specific loss function.
217    ///
218    /// ```no_run
219    /// use irithyll::SGBTConfig;
220    /// use irithyll::stream::AsyncSGBT;
221    /// use irithyll::loss::logistic::LogisticLoss;
222    ///
223    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
224    /// let runner = AsyncSGBT::with_loss(config, LogisticLoss);
225    /// ```
226    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
227        Self::with_loss_and_capacity(config, loss, DEFAULT_CHANNEL_CAPACITY)
228    }
229
230    /// Create a new async SGBT runner with a specific loss and channel capacity.
231    pub fn with_loss_and_capacity(config: SGBTConfig, loss: L, capacity: usize) -> Self {
232        let model = SGBT::with_loss(config, loss);
233        let shared = Arc::new(RwLock::new(model));
234        let (sender, receiver) = channel::bounded(capacity);
235
236        Self {
237            model: shared,
238            receiver: Some(receiver),
239            sender: Some(sender),
240        }
241    }
242
243    /// Obtain a clonable sender handle for feeding samples into the channel.
244    ///
245    /// Multiple senders can be created (via `Clone`) and used from different
246    /// async tasks. The training loop runs until all external senders are
247    /// dropped.
248    ///
249    /// # Panics
250    ///
251    /// Panics if called after [`run`](Self::run) has already started, since
252    /// the internal sender is consumed at that point.
253    pub fn sender(&self) -> SampleSender {
254        self.sender
255            .as_ref()
256            .expect("sender() called after run() consumed the internal sender")
257            .clone()
258    }
259
260    /// Obtain a concurrent prediction handle to the shared model.
261    ///
262    /// The predictor can be cloned and used from any thread or task while
263    /// the training loop is running.
264    pub fn predictor(&self) -> Predictor<L> {
265        Predictor {
266            model: Arc::clone(&self.model),
267        }
268    }
269
270    /// Run the main training loop.
271    ///
272    /// Receives samples from the bounded channel and trains the model
273    /// incrementally. For each sample:
274    ///
275    /// 1. Acquire write lock on the shared `SGBT<L>`.
276    /// 2. Call `train_one(&sample)`.
277    /// 3. Release the lock.
278    ///
279    /// Before entering the loop, the internal sender is dropped so that the
280    /// channel closes cleanly when all external senders are dropped.
281    ///
282    /// Returns `Ok(())` when the channel closes (all senders have been
283    /// dropped and all buffered samples have been consumed).
284    ///
285    /// # Logging
286    ///
287    /// Emits a `tracing::debug!` message every 1000 samples with the
288    /// current sample count.
289    ///
290    /// # Panics
291    ///
292    /// Panics if called more than once (the receiver is consumed on first call).
293    pub async fn run(&mut self) -> Result<()> {
294        // Drop our sender so the channel closes when external senders drop.
295        self.sender.take();
296
297        let receiver = self
298            .receiver
299            .take()
300            .expect("run() called more than once: receiver already consumed");
301
302        self.run_inner(receiver, None::<fn(u64)>).await
303    }
304
305    /// Run the training loop with a callback invoked after each sample.
306    ///
307    /// Behaves identically to [`run`](Self::run), but calls `callback`
308    /// with the current `n_samples_seen()` count after training each sample.
309    /// Useful for progress bars, metrics collection, or adaptive control.
310    ///
311    /// The callback runs synchronously within the training task -- keep it
312    /// fast to avoid blocking the loop.
313    ///
314    /// # Panics
315    ///
316    /// Panics if called more than once (the receiver is consumed on first call).
317    pub async fn run_with_callback<F>(&mut self, callback: F) -> Result<()>
318    where
319        F: Fn(u64),
320    {
321        // Drop our sender so the channel closes when external senders drop.
322        self.sender.take();
323
324        let receiver = self
325            .receiver
326            .take()
327            .expect("run_with_callback() called more than once: receiver already consumed");
328
329        self.run_inner(receiver, Some(callback)).await
330    }
331
332    /// Internal training loop shared by `run` and `run_with_callback`.
333    async fn run_inner<F>(&self, mut receiver: SampleReceiver, callback: Option<F>) -> Result<()>
334    where
335        F: Fn(u64),
336    {
337        while let Some(sample) = receiver.recv().await {
338            let seen;
339            {
340                let mut model = self.model.write();
341                model.train_one(&sample);
342                seen = model.n_samples_seen();
343            }
344
345            if let Some(ref cb) = callback {
346                cb(seen);
347            }
348
349            if seen % 1000 == 0 {
350                debug!(samples_seen = seen, "async training progress");
351            }
352        }
353
354        let total = self.model.read().n_samples_seen();
355        debug!(total_samples = total, "async training loop completed");
356
357        Ok(())
358    }
359}
360
361// ---------------------------------------------------------------------------
362// Tests
363// ---------------------------------------------------------------------------
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ensemble::config::SGBTConfig;
    use crate::sample::Sample;

    use std::sync::atomic::{AtomicU64, Ordering};

    // Small, fast-to-train configuration shared by every test below.
    fn default_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    // Two-feature sample derived from a single scalar, with target 2x.
    fn sample(x: f64) -> Sample {
        Sample::new(vec![x, x * 0.5], x * 2.0)
    }

    // 1. Basic lifecycle: send samples, run loop, verify training.
    #[tokio::test]
    async fn basic_lifecycle() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let predictor = runner.predictor();

        // Initially untrained.
        assert_eq!(predictor.n_samples_seen(), 0);
        assert!(!predictor.is_initialized());

        let handle = tokio::spawn(async move { runner.run().await });

        for i in 0..20 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        // Dropping the only external sender closes the channel and ends run().
        drop(sender);

        handle.await.unwrap().unwrap();
        assert_eq!(predictor.n_samples_seen(), 20);
    }

    // 2. Predictor works concurrently with training.
    #[tokio::test]
    async fn concurrent_predict_during_training() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let predictor = runner.predictor();

        let pred_handle = tokio::spawn({
            let predictor = predictor.clone();
            async move {
                // Keep predicting while training runs.
                let mut predictions = Vec::new();
                for _ in 0..50 {
                    let p = predictor.predict(&[1.0, 0.5]);
                    predictions.push(p);
                    // Yield so the training task gets scheduled between reads.
                    tokio::task::yield_now().await;
                }
                predictions
            }
        });

        let train_handle = tokio::spawn(async move { runner.run().await });

        for i in 0..100 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        drop(sender);

        let predictions = pred_handle.await.unwrap();
        train_handle.await.unwrap().unwrap();

        // All predictions should be finite.
        assert!(predictions.iter().all(|p| p.is_finite()));
    }

    // 3. run returns Ok(()) when channel closes immediately (no samples).
    #[tokio::test]
    async fn run_returns_ok_on_empty_channel() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        // Drop the only external sender immediately.
        drop(sender);

        // run() drops the internal sender, so the channel is fully closed.
        let result = runner.run().await;
        assert!(result.is_ok());
        assert_eq!(runner.model.read().n_samples_seen(), 0);
    }

    // 4. with_capacity creates channel with specified size.
    #[tokio::test]
    async fn with_capacity_custom() {
        let mut runner = AsyncSGBT::with_capacity(default_config(), 2);
        let sender = runner.sender();

        let handle = tokio::spawn(async move { runner.run().await });

        // Channel capacity 2: should be able to send 2 without blocking.
        sender.send(sample(1.0)).await.unwrap();
        sender.send(sample(2.0)).await.unwrap();
        drop(sender);

        handle.await.unwrap().unwrap();
    }

    // 5. Multiple senders from different tasks.
    #[tokio::test]
    async fn multiple_senders() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender1 = runner.sender();
        let sender2 = runner.sender();
        let predictor = runner.predictor();

        let handle = tokio::spawn(async move { runner.run().await });

        let h1 = tokio::spawn(async move {
            for i in 0..10 {
                sender1.send(sample(i as f64)).await.unwrap();
            }
        });

        let h2 = tokio::spawn(async move {
            for i in 10..20 {
                sender2.send(sample(i as f64)).await.unwrap();
            }
        });

        h1.await.unwrap();
        h2.await.unwrap();

        // Both senders dropped (moved into tasks that completed).
        // run() already dropped its internal sender, so the channel closes.
        // Wait for the training loop to drain and finish.
        handle.await.unwrap().unwrap();

        assert_eq!(predictor.n_samples_seen(), 20);
    }

    // 6. run_with_callback invokes callback for each sample.
    #[tokio::test]
    async fn run_with_callback_invokes() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();

        let counter = Arc::new(AtomicU64::new(0));
        let counter_clone = Arc::clone(&counter);

        let handle = tokio::spawn(async move {
            runner
                .run_with_callback(move |_seen| {
                    counter_clone.fetch_add(1, Ordering::Relaxed);
                })
                .await
        });

        for i in 0..15 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        drop(sender);

        handle.await.unwrap().unwrap();
        // Exactly one callback invocation per sample.
        assert_eq!(counter.load(Ordering::Relaxed), 15);
    }

    // 7. Callback receives correct sample counts.
    #[tokio::test]
    async fn callback_receives_correct_counts() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();

        let counts = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let counts_clone = Arc::clone(&counts);

        let handle = tokio::spawn(async move {
            runner
                .run_with_callback(move |seen| {
                    counts_clone.lock().push(seen);
                })
                .await
        });

        for i in 0..5 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        drop(sender);

        handle.await.unwrap().unwrap();

        let recorded = counts.lock().clone();
        assert_eq!(recorded.len(), 5);
        // Counts should be monotonically increasing.
        for window in recorded.windows(2) {
            assert!(window[1] > window[0]);
        }
        assert_eq!(*recorded.last().unwrap(), 5);
    }

    // 8. Predictor clone is independent but sees same model.
    #[tokio::test]
    async fn predictor_clone_independent() {
        let runner = AsyncSGBT::new(default_config());
        let p1 = runner.predictor();
        let p2 = p1.clone();

        // Both should return the same prediction (same underlying model).
        let pred1 = p1.predict(&[1.0, 2.0]);
        let pred2 = p2.predict(&[1.0, 2.0]);
        assert!((pred1 - pred2).abs() < f64::EPSILON);
    }

    // 9. predict_transformed works through Predictor.
    #[tokio::test]
    async fn predictor_predict_transformed() {
        let runner = AsyncSGBT::new(default_config());
        let predictor = runner.predictor();

        // For squared loss, predict_transformed == predict (identity transform).
        let raw = predictor.predict(&[1.0, 2.0]);
        let transformed = predictor.predict_transformed(&[1.0, 2.0]);
        assert!((raw - transformed).abs() < f64::EPSILON);
    }

    // 10. Predictor is Send + Sync (compile-time check).
    #[test]
    fn predictor_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Predictor>();
    }

    // 11. AsyncSGBT is Send (required for tokio::spawn).
    #[test]
    fn async_sgbt_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<AsyncSGBT>();
    }

    // 12. Training actually improves predictions.
    // NOTE(review): this test relies on a 100 ms sleep for the training loop
    // to make progress; on a heavily loaded CI runner it could be flaky.
    #[tokio::test]
    async fn training_improves_predictions() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let predictor = runner.predictor();

        let handle = tokio::spawn(async move { runner.run().await });

        // Prediction before training.
        let pred_before = predictor.predict(&[5.0, 2.5]);

        // Send consistent data: target = 10.0.
        for _ in 0..100 {
            sender
                .send(Sample::new(vec![5.0, 2.5], 10.0))
                .await
                .unwrap();
        }

        // Give the training loop time to process.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let pred_after = predictor.predict(&[5.0, 2.5]);
        drop(sender);

        handle.await.unwrap().unwrap();

        // After training on constant target 10.0, prediction should be
        // closer to 10.0 than the initial (0.0).
        assert!(
            (pred_after - 10.0).abs() < (pred_before - 10.0).abs(),
            "prediction should improve: before={}, after={}, target=10.0",
            pred_before,
            pred_after
        );
    }

    // 13. with_loss creates async runner with custom loss.
    #[tokio::test]
    async fn with_loss_creates_runner() {
        use crate::loss::logistic::LogisticLoss;

        let config = default_config();
        let mut runner = AsyncSGBT::with_loss(config, LogisticLoss);
        let sender = runner.sender();
        let predictor = runner.predictor();

        // Sigmoid(0) = 0.5 for logistic loss
        let pred = predictor.predict_transformed(&[1.0, 2.0]);
        assert!(
            (pred - 0.5).abs() < 1e-6,
            "sigmoid(0) should be 0.5, got {}",
            pred
        );

        let handle = tokio::spawn(async move { runner.run().await });
        drop(sender);
        handle.await.unwrap().unwrap();
    }
}