// irithyll/stream/mod.rs
1//! Async streaming infrastructure for tokio-native sample ingestion.
2//!
3//! This module provides the async interface for running an [`SGBT`] model as
4//! a long-lived training service. Samples arrive through a bounded channel,
5//! the model trains incrementally on each one, and concurrent read-only
6//! prediction access is available at all times via [`Predictor`] handles.
7//!
8//! # Architecture
9//!
10//! ```text
11//! ┌──────────┐ mpsc ┌───────────┐
12//! │ Senders │───(bounded)──>│ AsyncSGBT │ (write lock per sample)
13//! └──────────┘ channel │ .run() │
14//! └─────┬─────┘
15//! │
16//! Arc<RwLock<SGBT<L>>>
17//! │
18//! ┌─────┴─────┐
19//! │ Predictor │ (read lock per predict)
20//! └───────────┘
21//! ```
22//!
23//! # Lifecycle
24//!
25//! 1. Create an `AsyncSGBT` via [`new`](AsyncSGBT::new) or
26//! [`with_capacity`](AsyncSGBT::with_capacity).
27//! 2. Clone sender handles via [`sender`](AsyncSGBT::sender) and predictor
28//! handles via [`predictor`](AsyncSGBT::predictor).
29//! 3. Spawn [`run`](AsyncSGBT::run) (or [`run_with_callback`](AsyncSGBT::run_with_callback))
30//! on a tokio task. The loop starts by dropping its internal sender copy
31//! so that the channel closes cleanly once all external senders are dropped.
32//! 4. Feed samples from any number of async tasks.
33//! 5. Drop all senders to signal shutdown; the training loop drains remaining
34//! buffered samples and returns `Ok(())`.
35//!
36//! # Example
37//!
38//! ```no_run
39//! use irithyll::{SGBTConfig, Sample};
40//! use irithyll::stream::AsyncSGBT;
41//!
42//! # async fn example() -> irithyll::error::Result<()> {
43//! let config = SGBTConfig::builder()
44//! .n_steps(50)
45//! .learning_rate(0.1)
46//! .build()?;
47//!
48//! let mut runner = AsyncSGBT::new(config);
49//! let sender = runner.sender();
50//! let predictor = runner.predictor();
51//!
52//! // Spawn the training loop.
53//! let train_handle = tokio::spawn(async move { runner.run().await });
54//!
55//! // Feed samples from any async context.
56//! sender.send(Sample::new(vec![1.0, 2.0], 3.0)).await?;
57//!
58//! // Predict concurrently while training proceeds.
59//! let pred = predictor.predict(&[1.0, 2.0]);
60//!
61//! // Drop sender to signal shutdown; training loop returns Ok(()).
62//! drop(sender);
63//! train_handle.await.unwrap()?;
64//! # Ok(())
65//! # }
66//! ```
67
68pub mod adapters;
69pub mod channel;
70
71use std::fmt;
72use std::sync::Arc;
73
74use parking_lot::RwLock;
75use tracing::debug;
76
77use crate::ensemble::config::SGBTConfig;
78use crate::ensemble::SGBT;
79use crate::error::Result;
80use crate::loss::squared::SquaredLoss;
81use crate::loss::Loss;
82
83pub use adapters::{Prediction, PredictionStream};
84pub use channel::{SampleReceiver, SampleSender};
85
/// Default bounded channel capacity when none is specified.
///
/// Bounds the number of samples buffered between senders and the training
/// loop; a bounded channel provides backpressure instead of unbounded
/// memory growth when producers outpace training.
const DEFAULT_CHANNEL_CAPACITY: usize = 1024;
88
89// ---------------------------------------------------------------------------
90// Predictor
91// ---------------------------------------------------------------------------
92
/// A concurrent, read-only prediction handle to a shared [`SGBT`] model.
///
/// Obtained via [`AsyncSGBT::predictor`]. Each prediction acquires a read lock
/// on the underlying `RwLock<SGBT<L>>`, allowing multiple predictors to operate
/// concurrently and in parallel with the training loop (which holds a write
/// lock only briefly per sample).
///
/// `Predictor` is `Clone`, `Send`, and `Sync` -- share it freely across tasks.
pub struct Predictor<L: Loss = SquaredLoss> {
    /// Shared model -- the same `Arc` held by [`AsyncSGBT`] and its
    /// training loop; predictions only ever take the read side of the lock.
    pub(crate) model: Arc<RwLock<SGBT<L>>>,
}
104
105// Manual Clone impl -- cloning the Arc doesn't require L: Clone.
106impl<L: Loss> Clone for Predictor<L> {
107 fn clone(&self) -> Self {
108 Self {
109 model: Arc::clone(&self.model),
110 }
111 }
112}
113
114impl<L: Loss> fmt::Debug for Predictor<L> {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 f.debug_struct("Predictor")
117 .field("n_samples_seen", &self.model.read().n_samples_seen())
118 .finish()
119 }
120}
121
122impl<L: Loss> Predictor<L> {
123 /// Predict the raw model output for a feature vector.
124 ///
125 /// Acquires a read lock on the shared model. Returns the unscaled
126 /// ensemble prediction (base + weighted sum of tree outputs).
127 #[inline]
128 pub fn predict(&self, features: &[f64]) -> f64 {
129 self.model.read().predict(features)
130 }
131
132 /// Predict with the loss function's transform applied.
133 ///
134 /// For regression (squared loss) this is identity; for binary
135 /// classification (logistic loss) this applies the sigmoid.
136 #[inline]
137 pub fn predict_transformed(&self, features: &[f64]) -> f64 {
138 self.model.read().predict_transformed(features)
139 }
140
141 /// Number of samples the model has been trained on so far.
142 #[inline]
143 pub fn n_samples_seen(&self) -> u64 {
144 self.model.read().n_samples_seen()
145 }
146
147 /// Whether the model's base prediction has been initialized.
148 ///
149 /// Returns `false` until enough initial samples have been collected
150 /// (typically 50) to compute the base constant.
151 #[inline]
152 pub fn is_initialized(&self) -> bool {
153 self.model.read().is_initialized()
154 }
155}
156
157// ---------------------------------------------------------------------------
158// AsyncSGBT
159// ---------------------------------------------------------------------------
160
/// Async wrapper around [`SGBT`] for tokio-native streaming training.
///
/// `AsyncSGBT` owns the shared model and the receiving end of a bounded
/// sample channel. Call [`run`](Self::run) to start the training loop,
/// which consumes samples from the channel and trains incrementally.
///
/// Generic over `L: Loss` so the training loop benefits from monomorphized
/// gradient/hessian dispatch (no vtable overhead).
///
/// Prediction handles ([`Predictor`]) and sender handles ([`SampleSender`])
/// can be obtained before starting the loop and used concurrently from
/// other tasks.
///
/// # Shutdown
///
/// When [`run`](Self::run) is called, it drops the internal sender copy
/// so that the channel closes as soon as all external senders are dropped.
/// The loop then drains any remaining buffered samples and returns `Ok(())`.
pub struct AsyncSGBT<L: Loss = SquaredLoss> {
    /// Shared model, protected by a parking_lot RwLock. The same `Arc` is
    /// handed out to every [`Predictor`].
    model: Arc<RwLock<SGBT<L>>>,
    /// Receiving end of the sample channel. Wrapped in Option so `run()`
    /// can consume it exactly once (a second call panics).
    receiver: Option<SampleReceiver>,
    /// Sending end, kept so callers can clone it via `sender()`.
    /// Wrapped in Option so `run()` can drop it before entering the loop,
    /// ensuring the channel closes when all external senders are dropped.
    sender: Option<SampleSender>,
}
189
190impl AsyncSGBT<SquaredLoss> {
191 /// Create a new async SGBT runner with the default channel capacity (1024).
192 ///
193 /// Uses squared loss (regression). For other loss functions, use
194 /// [`with_loss`](AsyncSGBT::with_loss) or [`with_loss_and_capacity`](AsyncSGBT::with_loss_and_capacity).
195 pub fn new(config: SGBTConfig) -> Self {
196 Self::with_capacity(config, DEFAULT_CHANNEL_CAPACITY)
197 }
198
199 /// Create a new async SGBT runner with a custom channel capacity.
200 ///
201 /// Uses squared loss (regression).
202 pub fn with_capacity(config: SGBTConfig, capacity: usize) -> Self {
203 let model = SGBT::new(config);
204 let shared = Arc::new(RwLock::new(model));
205 let (sender, receiver) = channel::bounded(capacity);
206
207 Self {
208 model: shared,
209 receiver: Some(receiver),
210 sender: Some(sender),
211 }
212 }
213}
214
215impl<L: Loss> AsyncSGBT<L> {
216 /// Create a new async SGBT runner with a specific loss function.
217 ///
218 /// ```no_run
219 /// use irithyll::SGBTConfig;
220 /// use irithyll::stream::AsyncSGBT;
221 /// use irithyll::loss::logistic::LogisticLoss;
222 ///
223 /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
224 /// let runner = AsyncSGBT::with_loss(config, LogisticLoss);
225 /// ```
226 pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
227 Self::with_loss_and_capacity(config, loss, DEFAULT_CHANNEL_CAPACITY)
228 }
229
230 /// Create a new async SGBT runner with a specific loss and channel capacity.
231 pub fn with_loss_and_capacity(config: SGBTConfig, loss: L, capacity: usize) -> Self {
232 let model = SGBT::with_loss(config, loss);
233 let shared = Arc::new(RwLock::new(model));
234 let (sender, receiver) = channel::bounded(capacity);
235
236 Self {
237 model: shared,
238 receiver: Some(receiver),
239 sender: Some(sender),
240 }
241 }
242
243 /// Obtain a clonable sender handle for feeding samples into the channel.
244 ///
245 /// Multiple senders can be created (via `Clone`) and used from different
246 /// async tasks. The training loop runs until all external senders are
247 /// dropped.
248 ///
249 /// # Panics
250 ///
251 /// Panics if called after [`run`](Self::run) has already started, since
252 /// the internal sender is consumed at that point.
253 pub fn sender(&self) -> SampleSender {
254 self.sender
255 .as_ref()
256 .expect("sender() called after run() consumed the internal sender")
257 .clone()
258 }
259
260 /// Obtain a concurrent prediction handle to the shared model.
261 ///
262 /// The predictor can be cloned and used from any thread or task while
263 /// the training loop is running.
264 pub fn predictor(&self) -> Predictor<L> {
265 Predictor {
266 model: Arc::clone(&self.model),
267 }
268 }
269
270 /// Run the main training loop.
271 ///
272 /// Receives samples from the bounded channel and trains the model
273 /// incrementally. For each sample:
274 ///
275 /// 1. Acquire write lock on the shared `SGBT<L>`.
276 /// 2. Call `train_one(&sample)`.
277 /// 3. Release the lock.
278 ///
279 /// Before entering the loop, the internal sender is dropped so that the
280 /// channel closes cleanly when all external senders are dropped.
281 ///
282 /// Returns `Ok(())` when the channel closes (all senders have been
283 /// dropped and all buffered samples have been consumed).
284 ///
285 /// # Logging
286 ///
287 /// Emits a `tracing::debug!` message every 1000 samples with the
288 /// current sample count.
289 ///
290 /// # Panics
291 ///
292 /// Panics if called more than once (the receiver is consumed on first call).
293 pub async fn run(&mut self) -> Result<()> {
294 // Drop our sender so the channel closes when external senders drop.
295 self.sender.take();
296
297 let receiver = self
298 .receiver
299 .take()
300 .expect("run() called more than once: receiver already consumed");
301
302 self.run_inner(receiver, None::<fn(u64)>).await
303 }
304
305 /// Run the training loop with a callback invoked after each sample.
306 ///
307 /// Behaves identically to [`run`](Self::run), but calls `callback`
308 /// with the current `n_samples_seen()` count after training each sample.
309 /// Useful for progress bars, metrics collection, or adaptive control.
310 ///
311 /// The callback runs synchronously within the training task -- keep it
312 /// fast to avoid blocking the loop.
313 ///
314 /// # Panics
315 ///
316 /// Panics if called more than once (the receiver is consumed on first call).
317 pub async fn run_with_callback<F>(&mut self, callback: F) -> Result<()>
318 where
319 F: Fn(u64),
320 {
321 // Drop our sender so the channel closes when external senders drop.
322 self.sender.take();
323
324 let receiver = self
325 .receiver
326 .take()
327 .expect("run_with_callback() called more than once: receiver already consumed");
328
329 self.run_inner(receiver, Some(callback)).await
330 }
331
332 /// Internal training loop shared by `run` and `run_with_callback`.
333 async fn run_inner<F>(&self, mut receiver: SampleReceiver, callback: Option<F>) -> Result<()>
334 where
335 F: Fn(u64),
336 {
337 while let Some(sample) = receiver.recv().await {
338 let seen;
339 {
340 let mut model = self.model.write();
341 model.train_one(&sample);
342 seen = model.n_samples_seen();
343 }
344
345 if let Some(ref cb) = callback {
346 cb(seen);
347 }
348
349 if seen % 1000 == 0 {
350 debug!(samples_seen = seen, "async training progress");
351 }
352 }
353
354 let total = self.model.read().n_samples_seen();
355 debug!(total_samples = total, "async training loop completed");
356
357 Ok(())
358 }
359}
360
361// ---------------------------------------------------------------------------
362// Tests
363// ---------------------------------------------------------------------------
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ensemble::config::SGBTConfig;
    use crate::sample::Sample;

    use std::sync::atomic::{AtomicU64, Ordering};

    fn default_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    fn sample(x: f64) -> Sample {
        Sample::new(vec![x, x * 0.5], x * 2.0)
    }

    // 1. Basic lifecycle: send samples, run loop, verify training.
    #[tokio::test]
    async fn basic_lifecycle() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let predictor = runner.predictor();

        // Initially untrained.
        assert_eq!(predictor.n_samples_seen(), 0);
        assert!(!predictor.is_initialized());

        let handle = tokio::spawn(async move { runner.run().await });

        for i in 0..20 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        drop(sender);

        handle.await.unwrap().unwrap();
        assert_eq!(predictor.n_samples_seen(), 20);
    }

    // 2. Predictor works concurrently with training.
    #[tokio::test]
    async fn concurrent_predict_during_training() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let predictor = runner.predictor();

        let pred_handle = tokio::spawn({
            let predictor = predictor.clone();
            async move {
                // Keep predicting while training runs.
                let mut predictions = Vec::new();
                for _ in 0..50 {
                    let p = predictor.predict(&[1.0, 0.5]);
                    predictions.push(p);
                    tokio::task::yield_now().await;
                }
                predictions
            }
        });

        let train_handle = tokio::spawn(async move { runner.run().await });

        for i in 0..100 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        drop(sender);

        let predictions = pred_handle.await.unwrap();
        train_handle.await.unwrap().unwrap();

        // All predictions should be finite.
        assert!(predictions.iter().all(|p| p.is_finite()));
    }

    // 3. run returns Ok(()) when channel closes immediately (no samples).
    #[tokio::test]
    async fn run_returns_ok_on_empty_channel() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        // Drop the only external sender immediately.
        drop(sender);

        // run() drops the internal sender, so the channel is fully closed.
        let result = runner.run().await;
        assert!(result.is_ok());
        assert_eq!(runner.model.read().n_samples_seen(), 0);
    }

    // 4. with_capacity creates channel with specified size.
    #[tokio::test]
    async fn with_capacity_custom() {
        let mut runner = AsyncSGBT::with_capacity(default_config(), 2);
        let sender = runner.sender();

        let handle = tokio::spawn(async move { runner.run().await });

        // Channel capacity 2: should be able to send 2 without blocking.
        sender.send(sample(1.0)).await.unwrap();
        sender.send(sample(2.0)).await.unwrap();
        drop(sender);

        handle.await.unwrap().unwrap();
    }

    // 5. Multiple senders from different tasks.
    #[tokio::test]
    async fn multiple_senders() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender1 = runner.sender();
        let sender2 = runner.sender();
        let predictor = runner.predictor();

        let handle = tokio::spawn(async move { runner.run().await });

        let h1 = tokio::spawn(async move {
            for i in 0..10 {
                sender1.send(sample(i as f64)).await.unwrap();
            }
        });

        let h2 = tokio::spawn(async move {
            for i in 10..20 {
                sender2.send(sample(i as f64)).await.unwrap();
            }
        });

        h1.await.unwrap();
        h2.await.unwrap();

        // Both senders dropped (moved into tasks that completed).
        // run() already dropped its internal sender, so the channel closes.
        // Wait for the training loop to drain and finish.
        handle.await.unwrap().unwrap();

        assert_eq!(predictor.n_samples_seen(), 20);
    }

    // 6. run_with_callback invokes callback for each sample.
    #[tokio::test]
    async fn run_with_callback_invokes() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();

        let counter = Arc::new(AtomicU64::new(0));
        let counter_clone = Arc::clone(&counter);

        let handle = tokio::spawn(async move {
            runner
                .run_with_callback(move |_seen| {
                    counter_clone.fetch_add(1, Ordering::Relaxed);
                })
                .await
        });

        for i in 0..15 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        drop(sender);

        handle.await.unwrap().unwrap();
        assert_eq!(counter.load(Ordering::Relaxed), 15);
    }

    // 7. Callback receives correct sample counts.
    #[tokio::test]
    async fn callback_receives_correct_counts() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();

        let counts = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let counts_clone = Arc::clone(&counts);

        let handle = tokio::spawn(async move {
            runner
                .run_with_callback(move |seen| {
                    counts_clone.lock().push(seen);
                })
                .await
        });

        for i in 0..5 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        drop(sender);

        handle.await.unwrap().unwrap();

        let recorded = counts.lock().clone();
        assert_eq!(recorded.len(), 5);
        // Counts should be monotonically increasing.
        for window in recorded.windows(2) {
            assert!(window[1] > window[0]);
        }
        assert_eq!(*recorded.last().unwrap(), 5);
    }

    // 8. Predictor clone is independent but sees same model.
    #[tokio::test]
    async fn predictor_clone_independent() {
        let runner = AsyncSGBT::new(default_config());
        let p1 = runner.predictor();
        let p2 = p1.clone();

        // Both should return the same prediction (same underlying model).
        let pred1 = p1.predict(&[1.0, 2.0]);
        let pred2 = p2.predict(&[1.0, 2.0]);
        assert!((pred1 - pred2).abs() < f64::EPSILON);
    }

    // 9. predict_transformed works through Predictor.
    #[tokio::test]
    async fn predictor_predict_transformed() {
        let runner = AsyncSGBT::new(default_config());
        let predictor = runner.predictor();

        // For squared loss, predict_transformed == predict (identity transform).
        let raw = predictor.predict(&[1.0, 2.0]);
        let transformed = predictor.predict_transformed(&[1.0, 2.0]);
        assert!((raw - transformed).abs() < f64::EPSILON);
    }

    // 10. Predictor is Send + Sync (compile-time check).
    #[test]
    fn predictor_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Predictor>();
    }

    // 11. AsyncSGBT is Send (required for tokio::spawn).
    #[test]
    fn async_sgbt_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<AsyncSGBT>();
    }

    // 12. Training actually improves predictions.
    //
    // Deterministic version: instead of sleeping a fixed 100 ms and hoping
    // the training loop has caught up (a race that could fail spuriously
    // under load, since the assertion below is a strict inequality), we drop
    // the sender and await the loop so ALL samples are guaranteed processed
    // before measuring `pred_after`.
    #[tokio::test]
    async fn training_improves_predictions() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let predictor = runner.predictor();

        let handle = tokio::spawn(async move { runner.run().await });

        // Prediction before training (no samples sent yet, so the model is
        // untouched even though the loop is already running).
        let pred_before = predictor.predict(&[5.0, 2.5]);

        // Send consistent data: target = 10.0.
        for _ in 0..100 {
            sender
                .send(Sample::new(vec![5.0, 2.5], 10.0))
                .await
                .unwrap();
        }

        // Close the channel and wait for the loop to drain every sample.
        drop(sender);
        handle.await.unwrap().unwrap();

        let pred_after = predictor.predict(&[5.0, 2.5]);

        // After training on constant target 10.0, prediction should be
        // closer to 10.0 than the initial (0.0).
        assert!(
            (pred_after - 10.0).abs() < (pred_before - 10.0).abs(),
            "prediction should improve: before={}, after={}, target=10.0",
            pred_before,
            pred_after
        );
    }

    // 13. with_loss creates async runner with custom loss.
    #[tokio::test]
    async fn with_loss_creates_runner() {
        use crate::loss::logistic::LogisticLoss;

        let config = default_config();
        let mut runner = AsyncSGBT::with_loss(config, LogisticLoss);
        let sender = runner.sender();
        let predictor = runner.predictor();

        // Sigmoid(0) = 0.5 for logistic loss
        let pred = predictor.predict_transformed(&[1.0, 2.0]);
        assert!(
            (pred - 0.5).abs() < 1e-6,
            "sigmoid(0) should be 0.5, got {}",
            pred
        );

        let handle = tokio::spawn(async move { runner.run().await });
        drop(sender);
        handle.await.unwrap().unwrap();
    }
}