pub mod adapters;
pub mod channel;
use std::fmt;
use std::sync::Arc;
use parking_lot::RwLock;
use tracing::debug;
use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::error::Result;
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
pub use adapters::{Prediction, PredictionStream};
pub use channel::{SampleReceiver, SampleSender};
/// Default bound for the sample channel created by [`AsyncSGBT::new`] /
/// [`AsyncSGBT::with_loss`]; senders back-pressure once this many samples
/// are queued ahead of the training loop.
const DEFAULT_CHANNEL_CAPACITY: usize = 1024;
/// Cheap, cloneable read-only handle to the shared model.
///
/// Obtained from [`AsyncSGBT::predictor`]. Each prediction takes a read
/// lock on the model, so predictions can run concurrently with each other
/// and only briefly contend with the training loop's write lock.
pub struct Predictor<L: Loss = SquaredLoss> {
    // Shared with the owning `AsyncSGBT`; `RwLock` lets many readers
    // predict while the training loop takes exclusive write access.
    pub(crate) model: Arc<RwLock<SGBT<L>>>,
}
// Manual impl instead of `#[derive(Clone)]`: the derive would add an
// `L: Clone` bound, which is unnecessary — cloning only bumps the `Arc`
// refcount, so every clone observes the same underlying model.
impl<L: Loss> Clone for Predictor<L> {
    fn clone(&self) -> Self {
        let model = Arc::clone(&self.model);
        Self { model }
    }
}
impl<L: Loss> fmt::Debug for Predictor<L> {
    /// Reports the cumulative sample count instead of the model internals.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Take the read lock once, up front, rather than inside the
        // builder chain.
        let n_samples_seen = self.model.read().n_samples_seen();
        f.debug_struct("Predictor")
            .field("n_samples_seen", &n_samples_seen)
            .finish()
    }
}
impl<L: Loss> Predictor<L> {
    /// Raw (untransformed) model output for `features`.
    #[inline]
    pub fn predict(&self, features: &[f64]) -> f64 {
        let guard = self.model.read();
        guard.predict(features)
    }

    /// Model output passed through the loss's transformation
    /// (e.g. squashed to a probability for a logistic loss).
    #[inline]
    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
        let guard = self.model.read();
        guard.predict_transformed(features)
    }

    /// Total number of samples the shared model has trained on so far.
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        let guard = self.model.read();
        guard.n_samples_seen()
    }

    /// Whether the shared model has been initialized yet.
    #[inline]
    pub fn is_initialized(&self) -> bool {
        let guard = self.model.read();
        guard.is_initialized()
    }
}
/// Owns the shared model plus both ends of the sample channel.
///
/// `receiver` and `sender` are `Option` so `run()` can move them out of
/// `&mut self`: taking the receiver makes a second `run()` call panic
/// explicitly, and dropping the internal sender lets the channel close
/// once every handle returned by `sender()` has been dropped.
pub struct AsyncSGBT<L: Loss = SquaredLoss> {
    model: Arc<RwLock<SGBT<L>>>,
    receiver: Option<SampleReceiver>,
    sender: Option<SampleSender>,
}
impl AsyncSGBT<SquaredLoss> {
    /// Runner with the default squared loss and the default channel capacity.
    pub fn new(config: SGBTConfig) -> Self {
        Self::with_capacity(config, DEFAULT_CHANNEL_CAPACITY)
    }

    /// Runner with the default squared loss and an explicit bound on the
    /// number of samples queued ahead of the training loop.
    pub fn with_capacity(config: SGBTConfig, capacity: usize) -> Self {
        let (sender, receiver) = channel::bounded(capacity);
        Self {
            model: Arc::new(RwLock::new(SGBT::new(config))),
            receiver: Some(receiver),
            sender: Some(sender),
        }
    }
}
impl<L: Loss> AsyncSGBT<L> {
    /// Runner with a custom loss and the default channel capacity.
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        Self::with_loss_and_capacity(config, loss, DEFAULT_CHANNEL_CAPACITY)
    }

    /// Runner with a custom loss and an explicit bound on the number of
    /// samples queued ahead of the training loop.
    pub fn with_loss_and_capacity(config: SGBTConfig, loss: L, capacity: usize) -> Self {
        let (sender, receiver) = channel::bounded(capacity);
        Self {
            model: Arc::new(RwLock::new(SGBT::with_loss(config, loss))),
            receiver: Some(receiver),
            sender: Some(sender),
        }
    }

    /// A producer handle for feeding samples into the training loop.
    ///
    /// # Panics
    /// Panics once `run()`/`run_with_callback()` has started, because the
    /// internal sender has been consumed by then.
    pub fn sender(&self) -> SampleSender {
        match self.sender {
            Some(ref s) => s.clone(),
            None => panic!("sender() called after run() consumed the internal sender"),
        }
    }

    /// Read-only prediction handle sharing the same model as the trainer.
    pub fn predictor(&self) -> Predictor<L> {
        let model = Arc::clone(&self.model);
        Predictor { model }
    }

    /// Consume samples until the channel closes, with no progress callback.
    ///
    /// # Panics
    /// Panics if called more than once (the receiver is consumed).
    pub async fn run(&mut self) -> Result<()> {
        // Drop the internal sender so the channel can close once all
        // external senders are gone.
        drop(self.sender.take());
        let receiver = self
            .receiver
            .take()
            .expect("run() called more than once: receiver already consumed");
        self.run_inner(receiver, None::<fn(u64)>).await
    }

    /// Like [`Self::run`], but invokes `callback` with the cumulative
    /// sample count after every trained sample.
    ///
    /// # Panics
    /// Panics if a run method has already been called.
    pub async fn run_with_callback<F>(&mut self, callback: F) -> Result<()>
    where
        F: Fn(u64),
    {
        drop(self.sender.take());
        let receiver = self
            .receiver
            .take()
            .expect("run_with_callback() called more than once: receiver already consumed");
        self.run_inner(receiver, Some(callback)).await
    }

    /// Core loop: train on each received sample under a short write lock.
    async fn run_inner<F>(&self, mut receiver: SampleReceiver, callback: Option<F>) -> Result<()>
    where
        F: Fn(u64),
    {
        while let Some(sample) = receiver.recv().await {
            // Scope the write lock to the single training step so
            // predictors are not blocked while the callback runs.
            let seen = {
                let mut model = self.model.write();
                model.train_one(&sample);
                model.n_samples_seen()
            };
            if let Some(cb) = callback.as_ref() {
                cb(seen);
            }
            if seen % 1000 == 0 {
                debug!(samples_seen = seen, "async training progress");
            }
        }
        let total = self.model.read().n_samples_seen();
        debug!(total_samples = total, "async training loop completed");
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ensemble::config::SGBTConfig;
    use crate::sample::Sample;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Small, fast config so every test trains in milliseconds.
    fn default_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    /// Two-feature sample `[x, x/2]` with target `2x`.
    fn sample(x: f64) -> Sample {
        Sample::new(vec![x, x * 0.5], x * 2.0)
    }

    #[tokio::test]
    async fn basic_lifecycle() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let predictor = runner.predictor();
        assert_eq!(predictor.n_samples_seen(), 0);
        assert!(!predictor.is_initialized());
        let handle = tokio::spawn(async move { runner.run().await });
        for i in 0..20 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        // Dropping the last sender closes the channel, which ends run().
        drop(sender);
        handle.await.unwrap().unwrap();
        assert_eq!(predictor.n_samples_seen(), 20);
    }

    #[tokio::test]
    async fn concurrent_predict_during_training() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let predictor = runner.predictor();
        // Predictions run on one task while training consumes on another;
        // they synchronize only through the model's RwLock.
        let pred_handle = tokio::spawn({
            let predictor = predictor.clone();
            async move {
                let mut predictions = Vec::new();
                for _ in 0..50 {
                    let p = predictor.predict(&[1.0, 0.5]);
                    predictions.push(p);
                    tokio::task::yield_now().await;
                }
                predictions
            }
        });
        let train_handle = tokio::spawn(async move { runner.run().await });
        for i in 0..100 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        drop(sender);
        let predictions = pred_handle.await.unwrap();
        train_handle.await.unwrap().unwrap();
        // Mid-training predictions may vary, but must never be NaN/inf.
        assert!(predictions.iter().all(|p| p.is_finite()));
    }

    #[tokio::test]
    async fn run_returns_ok_on_empty_channel() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        drop(sender);
        // With every sender gone, run() should return Ok immediately.
        let result = runner.run().await;
        assert!(result.is_ok());
        assert_eq!(runner.model.read().n_samples_seen(), 0);
    }

    #[tokio::test]
    async fn with_capacity_custom() {
        // A tiny bound still works; senders just back-pressure.
        let mut runner = AsyncSGBT::with_capacity(default_config(), 2);
        let sender = runner.sender();
        let handle = tokio::spawn(async move { runner.run().await });
        sender.send(sample(1.0)).await.unwrap();
        sender.send(sample(2.0)).await.unwrap();
        drop(sender);
        handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn multiple_senders() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender1 = runner.sender();
        let sender2 = runner.sender();
        let predictor = runner.predictor();
        let handle = tokio::spawn(async move { runner.run().await });
        let h1 = tokio::spawn(async move {
            for i in 0..10 {
                sender1.send(sample(i as f64)).await.unwrap();
            }
        });
        let h2 = tokio::spawn(async move {
            for i in 10..20 {
                sender2.send(sample(i as f64)).await.unwrap();
            }
        });
        h1.await.unwrap();
        h2.await.unwrap();
        // The channel closes only after both producer tasks drop their senders.
        handle.await.unwrap().unwrap();
        assert_eq!(predictor.n_samples_seen(), 20);
    }

    #[tokio::test]
    async fn run_with_callback_invokes() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let counter = Arc::new(AtomicU64::new(0));
        let counter_clone = Arc::clone(&counter);
        let handle = tokio::spawn(async move {
            runner
                .run_with_callback(move |_seen| {
                    counter_clone.fetch_add(1, Ordering::Relaxed);
                })
                .await
        });
        for i in 0..15 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        drop(sender);
        handle.await.unwrap().unwrap();
        // Exactly one callback per trained sample.
        assert_eq!(counter.load(Ordering::Relaxed), 15);
    }

    #[tokio::test]
    async fn callback_receives_correct_counts() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let counts = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let counts_clone = Arc::clone(&counts);
        let handle = tokio::spawn(async move {
            runner
                .run_with_callback(move |seen| {
                    counts_clone.lock().push(seen);
                })
                .await
        });
        for i in 0..5 {
            sender.send(sample(i as f64)).await.unwrap();
        }
        drop(sender);
        handle.await.unwrap().unwrap();
        let recorded = counts.lock().clone();
        assert_eq!(recorded.len(), 5);
        // Counts must be strictly increasing and end at the total.
        for window in recorded.windows(2) {
            assert!(window[1] > window[0]);
        }
        assert_eq!(*recorded.last().unwrap(), 5);
    }

    #[tokio::test]
    async fn predictor_clone_independent() {
        let runner = AsyncSGBT::new(default_config());
        let p1 = runner.predictor();
        let p2 = p1.clone();
        // Clones share the same model, so predictions agree exactly.
        let pred1 = p1.predict(&[1.0, 2.0]);
        let pred2 = p2.predict(&[1.0, 2.0]);
        assert!((pred1 - pred2).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn predictor_predict_transformed() {
        let runner = AsyncSGBT::new(default_config());
        let predictor = runner.predictor();
        // Squared loss uses an identity transform, so both paths agree.
        let raw = predictor.predict(&[1.0, 2.0]);
        let transformed = predictor.predict_transformed(&[1.0, 2.0]);
        assert!((raw - transformed).abs() < f64::EPSILON);
    }

    #[test]
    fn predictor_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Predictor>();
    }

    #[test]
    fn async_sgbt_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<AsyncSGBT>();
    }

    #[tokio::test]
    async fn training_improves_predictions() {
        let mut runner = AsyncSGBT::new(default_config());
        let sender = runner.sender();
        let predictor = runner.predictor();
        let handle = tokio::spawn(async move { runner.run().await });
        // Baseline prediction before any samples have been sent.
        let pred_before = predictor.predict(&[5.0, 2.5]);
        for _ in 0..100 {
            sender
                .send(Sample::new(vec![5.0, 2.5], 10.0))
                .await
                .unwrap();
        }
        // Close the channel and wait for the training loop to drain it, so
        // pred_after deterministically reflects all 100 samples. (A fixed
        // sleep here was timing-dependent and could fail the strict `<`
        // assertion on a slow runner before any sample was processed.)
        drop(sender);
        handle.await.unwrap().unwrap();
        let pred_after = predictor.predict(&[5.0, 2.5]);
        assert!(
            (pred_after - 10.0).abs() < (pred_before - 10.0).abs(),
            "prediction should improve: before={}, after={}, target=10.0",
            pred_before,
            pred_after
        );
    }

    #[tokio::test]
    async fn with_loss_creates_runner() {
        use crate::loss::logistic::LogisticLoss;
        let config = default_config();
        let mut runner = AsyncSGBT::with_loss(config, LogisticLoss);
        let sender = runner.sender();
        let predictor = runner.predictor();
        // An untrained logistic model outputs raw 0 => sigmoid(0) = 0.5.
        let pred = predictor.predict_transformed(&[1.0, 2.0]);
        assert!(
            (pred - 0.5).abs() < 1e-6,
            "sigmoid(0) should be 0.5, got {}",
            pred
        );
        let handle = tokio::spawn(async move { runner.run().await });
        drop(sender);
        handle.await.unwrap().unwrap();
    }
}