use std::pin::Pin;
use std::task::{Context, Poll};
use futures_core::Stream;
use super::channel::SampleReceiver;
use super::Predictor;
/// One scored sample emitted by [`PredictionStream`].
///
/// Carries both model outputs (`raw` and `transformed`) alongside the
/// original sample's target and feature vector so downstream consumers
/// can compute errors/metrics without re-joining against the input.
//
// `PartialEq` added so predictions can be compared directly in tests and
// metric code; all fields (`f64`, `Vec<f64>`) support it. `Eq`/`Hash` are
// intentionally omitted because of the float fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Prediction {
    /// Untransformed model score for the sample's features.
    pub raw: f64,
    /// Score after the predictor's output transform (e.g. link function).
    pub transformed: f64,
    /// Ground-truth target copied from the input sample.
    pub target: f64,
    /// Feature vector of the input sample (moved in, not cloned).
    pub features: Vec<f64>,
}
/// A [`Stream`] adapter that turns incoming samples into [`Prediction`]s.
///
/// Each sample pulled from `receiver` is scored twice (raw and
/// transformed) by `predictor` before being yielded; the stream ends
/// when the sample channel is closed and drained.
pub struct PredictionStream {
    // Receiving end of the sample channel; `poll_recv(None)` terminates the stream.
    receiver: SampleReceiver,
    // Scoring handle; presumably wraps a shared model (tests use
    // `Arc<RwLock<SGBT>>`) — see `Predictor` for details.
    predictor: Predictor,
}
impl PredictionStream {
pub fn new(receiver: SampleReceiver, predictor: Predictor) -> Self {
Self {
receiver,
predictor,
}
}
}
impl Stream for PredictionStream {
    type Item = Prediction;

    /// Polls the sample channel and scores whatever arrives.
    ///
    /// Readiness is delegated entirely to `receiver.poll_recv`:
    /// `Pending` and end-of-channel (`Ready(None)`) pass through
    /// untouched, while a received sample is scored (raw and
    /// transformed) and repackaged with its target and features.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `PredictionStream` is `Unpin` (see the `stream_is_unpin` test),
        // so we can safely take a plain `&mut` out of the pin.
        let stream = self.get_mut();
        stream.receiver.poll_recv(cx).map(|maybe_sample| {
            maybe_sample.map(|sample| {
                // Score before moving `sample.features` into the output.
                let raw = stream.predictor.predict(&sample.features);
                let transformed = stream.predictor.predict_transformed(&sample.features);
                Prediction {
                    raw,
                    transformed,
                    target: sample.target,
                    features: sample.features,
                }
            })
        })
    }

    /// Unbounded: we cannot know how many samples are still to come.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ensemble::config::SGBTConfig;
    use crate::ensemble::SGBT;
    use crate::sample::Sample;
    use crate::stream::channel;
    use parking_lot::RwLock;
    use std::sync::Arc;
    use futures_core::Stream;
    use std::pin::Pin;

    /// Small, fast-to-train model configuration shared by every test.
    fn default_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    /// Builds a model plus a `Predictor` over the same shared handle, so a
    /// test can keep training through the returned `Arc` while the stream
    /// scores samples against the live model state.
    fn make_predictor(config: SGBTConfig) -> (Arc<RwLock<SGBT>>, Predictor) {
        let model = SGBT::new(config);
        let shared = Arc::new(RwLock::new(model));
        let predictor = Predictor {
            model: Arc::clone(&shared),
        };
        (shared, predictor)
    }

    /// Convenience fixture: two correlated features with target = 2x.
    fn sample(x: f64) -> Sample {
        Sample::new(vec![x, x * 0.5], x * 2.0)
    }

    /// A sent sample comes back as exactly one prediction with finite
    /// scores plus the sample's own target and feature count.
    #[tokio::test]
    async fn stream_yields_predictions() {
        let (_shared, predictor) = make_predictor(default_config());
        let (tx, rx) = channel::bounded(16);
        let mut stream = PredictionStream::new(rx, predictor);
        tx.send(sample(1.0)).await.unwrap();
        drop(tx);
        let pred = poll_next(&mut stream).await;
        assert!(pred.is_some());
        let pred = pred.unwrap();
        assert!((pred.target - 2.0).abs() < f64::EPSILON);
        assert_eq!(pred.features.len(), 2);
        assert!(pred.raw.is_finite());
        assert!(pred.transformed.is_finite());
    }

    /// Dropping the sender before anything is sent ends the stream with `None`.
    #[tokio::test]
    async fn stream_terminates_on_close() {
        let (_shared, predictor) = make_predictor(default_config());
        let (tx, rx) = channel::bounded(16);
        drop(tx);
        let mut stream = PredictionStream::new(rx, predictor);
        let pred = poll_next(&mut stream).await;
        assert!(pred.is_none());
    }

    /// Features and target pass through the stream unchanged, including
    /// for weighted samples.
    #[tokio::test]
    async fn prediction_preserves_sample_data() {
        let (_shared, predictor) = make_predictor(default_config());
        let (tx, rx) = channel::bounded(16);
        let mut stream = PredictionStream::new(rx, predictor);
        let s = Sample::weighted(vec![3.25, 2.71], 42.0, 1.5);
        tx.send(s).await.unwrap();
        drop(tx);
        let pred = poll_next(&mut stream).await.unwrap();
        assert!((pred.features[0] - 3.25).abs() < f64::EPSILON);
        assert!((pred.features[1] - 2.71).abs() < f64::EPSILON);
        assert!((pred.target - 42.0).abs() < f64::EPSILON);
    }

    /// Every buffered sample is scored: five in, five out, then `None`.
    #[tokio::test]
    async fn multiple_predictions() {
        let (_shared, predictor) = make_predictor(default_config());
        let (tx, rx) = channel::bounded(16);
        let mut stream = PredictionStream::new(rx, predictor);
        for i in 0..5 {
            tx.send(sample(i as f64)).await.unwrap();
        }
        drop(tx);
        let mut count = 0;
        while let Some(pred) = poll_next(&mut stream).await {
            assert!(pred.raw.is_finite());
            count += 1;
        }
        assert_eq!(count, 5);
    }

    /// Training the shared model between two identical samples should move
    /// the prediction — unless the model was already saturated at the first
    /// poll (the `|| pred_before == 0` escape hatch covers a cold model).
    #[tokio::test]
    async fn predictions_reflect_model_state() {
        let (shared, predictor) = make_predictor(default_config());
        let (tx, rx) = channel::bounded(16);
        let mut stream = PredictionStream::new(rx, predictor);
        tx.send(sample(1.0)).await.unwrap();
        let pred_before = poll_next(&mut stream).await.unwrap();
        {
            // Train while holding the write lock; the stream uses the same
            // Arc, so the next prediction sees the updated model.
            let mut model = shared.write();
            for i in 0..100 {
                model.train_one(&Sample::new(vec![1.0, 0.5], 10.0 + i as f64 * 0.01));
            }
        }
        tx.send(sample(1.0)).await.unwrap();
        drop(tx);
        let pred_after = poll_next(&mut stream).await.unwrap();
        assert!(
            (pred_after.raw - pred_before.raw).abs() > f64::EPSILON
                || pred_before.raw.abs() < f64::EPSILON,
            "predictions should reflect training: before={}, after={}",
            pred_before.raw,
            pred_after.raw
        );
    }

    /// `Debug` output names the score fields.
    #[test]
    fn prediction_debug() {
        let p = Prediction {
            raw: 1.0,
            transformed: 0.73,
            target: 1.0,
            features: vec![0.5, 0.6],
        };
        let dbg = format!("{:?}", p);
        assert!(dbg.contains("raw"));
        assert!(dbg.contains("transformed"));
    }

    /// `Clone` deep-copies the feature vector.
    #[test]
    fn prediction_clone() {
        let p = Prediction {
            raw: 2.5,
            transformed: 0.92,
            target: 1.0,
            features: vec![1.0, 2.0, 3.0],
        };
        let p2 = p.clone();
        assert!((p2.raw - 2.5).abs() < f64::EPSILON);
        assert_eq!(p2.features.len(), 3);
    }

    /// The stream advertises an unknown length: (0, None).
    #[tokio::test]
    async fn stream_size_hint() {
        let (_shared, predictor) = make_predictor(default_config());
        let (_tx, rx) = channel::bounded(16);
        let stream = PredictionStream::new(rx, predictor);
        let (lo, hi) = stream.size_hint();
        assert_eq!(lo, 0);
        assert!(hi.is_none());
    }

    /// Same as `stream_terminates_on_close`, but via the reusable fixture path.
    #[tokio::test]
    async fn empty_stream() {
        let (_shared, predictor) = make_predictor(default_config());
        let (tx, rx) = channel::bounded(16);
        drop(tx);
        let mut stream = PredictionStream::new(rx, predictor);
        assert!(poll_next(&mut stream).await.is_none());
    }

    /// Compile-time guarantee that `poll_next`'s `get_mut()` is sound:
    /// the stream must be `Unpin`.
    #[test]
    fn stream_is_unpin() {
        fn assert_unpin<T: Unpin>() {}
        assert_unpin::<PredictionStream>();
    }

    /// Awaits a single item from any `Unpin` stream by adapting
    /// `poll_next` into a future.
    async fn poll_next<S: Stream + Unpin>(stream: &mut S) -> Option<S::Item> {
        use std::future::poll_fn;
        poll_fn(|cx| Pin::new(&mut *stream).poll_next(cx)).await
    }
}