fluent_data/
streamer.rs

//! The [Streamer] continuously consumes data points and produces models.
//!
//! This module also provides two helpers that build the point iterator and
//! the model write closure a [Streamer] needs:
//! - [stdio] reads points from the standard input and writes models to the
//!   standard output;
//! - [channels] receives points from, and sends models to, `mpsc` channels.
6
7use std::{
8    error::Error,
9    io,
10    ops::Deref,
11    sync::mpsc::{Receiver, Sender},
12};
13
14use crate::{
15    algorithm::Algo,
16    model::{Ball, Model},
17};
18use serde::{de::DeserializeOwned, Serialize};
19use serde_json::{json, Map, Value};
20
/// Reads points from an `In` source and writes the resulting models to an `Out` sink.
///
/// `In` yields one JSON-encoded point per item, and `Out` receives one
/// JSON-encoded model per fitted point; see [`Streamer::run`].
///
/// ```no_run
/// use std::error::Error;
///
/// use fluent_data::{algorithm::Algo, model::Model, space, streamer::{Streamer, self}};
///
/// fn main() -> Result<(), Box<dyn Error>> {
///     let algo = Algo::new(space::euclid_dist, space::real_combine);
///     let mut model = Model::new(space::euclid_dist);
///     let (points, write) = streamer::stdio();
///     let streamer = Streamer::new(points, write);
///     Streamer::run(streamer, algo, &mut model)?;
///     Ok(())
/// }
/// ```
// Trait bounds are declared on the `impl` block only: repeating them on the
// struct adds no safety and forces every mention of the type to restate them.
pub struct Streamer<In, Out> {
    points: In,
    write: Out,
}
44
45impl<In, Out> Streamer<In, Out>
46where
47    In: Iterator<Item = Result<String, Box<dyn Error>>>,
48    Out: FnMut(String) -> Result<(), Box<dyn Error>>,
49{
50    /// builds a new streamer instance.
51    pub fn new(points: In, write: Out) -> Self {
52        Self { points, write }
53    }
54
55    /// Infinitely reads points from `In` source and write model changes to `Out` sink.
56    pub fn run<Point: PartialEq + Serialize + DeserializeOwned + 'static>(
57        mut streamer: Streamer<In, Out>,
58        algo: Algo<Point>,
59        model: &mut Model<Point>,
60    ) -> Result<(), Box<dyn Error>> {
61        for input in streamer.points {
62            let point_str = input?;
63            let point: Point = serde_json::from_str(&point_str)?;
64            algo.fit(model, point);
65            let balls = serialize_model(model);
66            let output = serde_json::to_string(&balls)?;
67            (streamer.write)(output)?;
68        }
69        Ok(())
70    }
71}
72
73fn serialize_model<Point: PartialEq + Serialize + 'static>(
74    model: &Model<Point>,
75) -> Vec<Map<String, Value>> {
76    let balls: Vec<_> = model
77        .iter_balls()
78        .map(|data| serialize_ball(data))
79        .collect();
80    balls
81}
82
83fn serialize_ball<Point: PartialEq + Serialize>(
84    data: impl Deref<Target = Ball<Point>>,
85) -> Map<String, Value> {
86    let mut map = Map::new();
87    map.insert("center".into(), json!(data.center()));
88    map.insert("radius".into(), json!(data.radius()));
89    map.insert("weight".into(), json!(data.weight()));
90    map
91}
92
/// Returns a point iterator / model writer pair backed by the standard streams.
///
/// The iterator yields each line read from standard input. The writer prints
/// each model on its own line to standard output.
///
/// Write failures on standard output (for example a closed pipe) are returned
/// as an `Err` instead of panicking as `println!` would, so [`Streamer::run`]
/// stops and reports the error.
pub fn stdio() -> (
    impl Iterator<Item = Result<String, Box<dyn Error>>>,
    impl FnMut(String) -> Result<(), Box<dyn Error>>,
) {
    // Needed for `writeln!` on `io::Stdout` (provides `write_fmt`).
    use std::io::Write as _;

    let points = io::stdin()
        .lines()
        .map(|line| -> Result<String, Box<dyn Error>> { Ok(line?) });
    let write = |model: String| -> Result<(), Box<dyn Error>> {
        writeln!(io::stdout(), "{}", model)?;
        Ok(())
    };
    (points, write)
}
107
/// Returns a point iterator / model writer pair backed by mpsc channels.
///
/// The iterator yields every string received on `point_receiver` and ends once
/// every sender of that channel has been dropped. The writer sends each model on
/// `model_producer` and returns an error if the receiving side is gone.
pub fn channels(
    point_receiver: Receiver<String>,
    model_producer: Sender<String>,
) -> (
    impl Iterator<Item = Result<String, Box<dyn Error>>>,
    impl FnMut(String) -> Result<(), Box<dyn Error>>,
) {
    let write = move |model: String| -> Result<(), Box<dyn Error>> {
        model_producer.send(model).map_err(Box::<dyn Error>::from)
    };
    let points = point_receiver
        .into_iter()
        .map(Ok::<String, Box<dyn Error>>);
    (points, write)
}
123
#[cfg(test)]
mod tests {

    use std::sync::mpsc;

    use crate::{space, streamer::*};

    #[test]
    fn test_serialize_ball() {
        let obj = serialize_ball(&Ball::new(vec![3., 5.1], 4.7, 0.999));
        let json = serde_json::to_string(&obj).unwrap();
        assert_eq!(
            r#"{"center":[3.0,5.1],"radius":2.16794833886788,"weight":0.999}"#,
            json
        );
    }

    #[test]
    fn test_serialize_model() {
        let mut model = Model::new(space::euclid_dist);
        let v = model.add_ball(Ball::new(vec![3., 5.1], 4.7, 0.999), vec![]);
        model.add_ball(Ball::new(vec![1.2, 6.], 1.3, 3.998), vec![v.as_neighbor()]);
        let obj = serialize_model(&model);
        let json = serde_json::to_string(&obj).unwrap();
        assert_eq!(
            r#"[{"center":[3.0,5.1],"radius":2.16794833886788,"weight":0.999},{"center":[1.2,6.0],"radius":1.140175425099138,"weight":3.998}]"#,
            json
        );
    }

    #[test]
    fn test_streamer() {
        let algo = Algo::new(space::euclid_dist, space::real_combine);
        let mut model = Model::new(space::euclid_dist);
        let points = vec![Ok(String::from("[1.0,1.0]"))].into_iter();
        let mut result = String::new();
        let write = |s| {
            result = s;
            Ok(())
        };
        let streamer = Streamer::new(points, write);
        // `expect` reports the actual error on failure, unlike a bare `panic!()`.
        Streamer::run(streamer, algo, &mut model).expect("streaming a valid point should succeed");
        assert_eq!(
            r#"[{"center":[1.0,1.0],"radius":null,"weight":0.0}]"#,
            result
        );
    }

    #[test]
    fn test_streamer_rejects_invalid_point() {
        let algo = Algo::new(space::euclid_dist, space::real_combine);
        let mut model = Model::new(space::euclid_dist);
        let points = vec![Ok(String::from("not json"))].into_iter();
        let mut writes = 0;
        let write = |_: String| -> Result<(), Box<dyn std::error::Error>> {
            writes += 1;
            Ok(())
        };
        let streamer = Streamer::new(points, write);
        assert!(Streamer::run(streamer, algo, &mut model).is_err());
        // The parse error must stop the stream before anything is written.
        assert_eq!(0, writes);
    }

    #[test]
    fn test_channels() {
        let (point_producer, point_receiver) = mpsc::channel();
        let (model_producer, model_receiver) = mpsc::channel();
        let (mut points, mut write) = channels(point_receiver, model_producer);
        point_producer.send(String::from("point")).unwrap();
        let p = points.next().unwrap().unwrap();
        assert_eq!("point", p);
        write(String::from("model")).unwrap();
        let m = model_receiver.recv().unwrap();
        assert_eq!("model", m);
        // Once the model receiver is gone, the writer must report an error.
        drop(model_receiver);
        assert!(write(String::from("lost")).is_err());
    }
}
186}