1use std::{
8 error::Error,
9 io,
10 ops::Deref,
11 sync::mpsc::{Receiver, Sender},
12};
13
14use crate::{
15 algorithm::Algo,
16 model::{Ball, Model},
17};
18use serde::{de::DeserializeOwned, Serialize};
19use serde_json::{json, Map, Value};
20
/// Couples a source of serialized points with a sink for serialized models.
///
/// `points` yields one JSON-encoded point per item and `write` receives the
/// JSON-encoded model produced after each point is fitted (see
/// [`Streamer::run`]). The trait bounds are declared on the `impl` block that
/// needs them rather than on the struct itself, so the bare type can be named
/// without repeating them.
pub struct Streamer<In, Out> {
    /// Source of JSON-encoded points.
    points: In,
    /// Sink receiving the JSON-encoded model after every point.
    write: Out,
}
44
45impl<In, Out> Streamer<In, Out>
46where
47 In: Iterator<Item = Result<String, Box<dyn Error>>>,
48 Out: FnMut(String) -> Result<(), Box<dyn Error>>,
49{
50 pub fn new(points: In, write: Out) -> Self {
52 Self { points, write }
53 }
54
55 pub fn run<Point: PartialEq + Serialize + DeserializeOwned + 'static>(
57 mut streamer: Streamer<In, Out>,
58 algo: Algo<Point>,
59 model: &mut Model<Point>,
60 ) -> Result<(), Box<dyn Error>> {
61 for input in streamer.points {
62 let point_str = input?;
63 let point: Point = serde_json::from_str(&point_str)?;
64 algo.fit(model, point);
65 let balls = serialize_model(model);
66 let output = serde_json::to_string(&balls)?;
67 (streamer.write)(output)?;
68 }
69 Ok(())
70 }
71}
72
73fn serialize_model<Point: PartialEq + Serialize + 'static>(
74 model: &Model<Point>,
75) -> Vec<Map<String, Value>> {
76 let balls: Vec<_> = model
77 .iter_balls()
78 .map(|data| serialize_ball(data))
79 .collect();
80 balls
81}
82
83fn serialize_ball<Point: PartialEq + Serialize>(
84 data: impl Deref<Target = Ball<Point>>,
85) -> Map<String, Value> {
86 let mut map = Map::new();
87 map.insert("center".into(), json!(data.center()));
88 map.insert("radius".into(), json!(data.radius()));
89 map.insert("weight".into(), json!(data.weight()));
90 map
91}
92
/// Builds a point source that reads stdin and a model sink that writes to stdout.
///
/// Each line of standard input is yielded as one serialized point, and read
/// errors are passed through as `Err`. The sink writes each serialized model on
/// its own line and flushes, so a downstream reader sees every update promptly.
///
/// Write failures, such as a closed pipe, are returned as `Err` instead of
/// panicking as `println!` would. This lets `Streamer::run` report them.
pub fn stdio() -> (
    impl Iterator<Item = Result<String, Box<dyn Error>>>,
    impl FnMut(String) -> Result<(), Box<dyn Error>>,
) {
    let points = io::stdin()
        .lines()
        .map(|line| -> Result<String, Box<dyn Error>> { Ok(line?) });
    let write = |model: String| -> Result<(), Box<dyn Error>> {
        // Hold the lock for the whole line so concurrent writers cannot interleave.
        let mut out = io::stdout().lock();
        io::Write::write_all(&mut out, model.as_bytes())?;
        io::Write::write_all(&mut out, b"\n")?;
        io::Write::flush(&mut out)?;
        Ok(())
    };
    (points, write)
}
107
/// Builds a point source and a model sink backed by `mpsc` channels.
///
/// Every `String` received on `point_receiver` is yielded wrapped in `Ok`. The
/// iterator ends once every sender for that channel has been dropped. The
/// returned closure forwards each serialized model to `model_producer` and
/// returns `Err` if the receiving side has been dropped.
pub fn channels(
    point_receiver: Receiver<String>,
    model_producer: Sender<String>,
) -> (
    impl Iterator<Item = Result<String, Box<dyn Error>>>,
    impl FnMut(String) -> Result<(), Box<dyn Error>>,
) {
    let sink = move |model: String| -> Result<(), Box<dyn Error>> {
        model_producer.send(model).map_err(Into::into)
    };
    let source = point_receiver
        .into_iter()
        .map(Ok::<String, Box<dyn Error>>);
    (source, sink)
}
123
#[cfg(test)]
mod tests {

    use std::sync::mpsc;

    use crate::{space, streamer::*};

    #[test]
    fn test_serialize_ball() {
        let obj = serialize_ball(&Ball::new(vec![3., 5.1], 4.7, 0.999));
        let json = serde_json::to_string(&obj).unwrap();
        assert_eq!(
            r#"{"center":[3.0,5.1],"radius":2.16794833886788,"weight":0.999}"#,
            json
        );
    }

    #[test]
    fn test_serialize_model() {
        let mut model = Model::new(space::euclid_dist);
        let v = model.add_ball(Ball::new(vec![3., 5.1], 4.7, 0.999), vec![]);
        model.add_ball(Ball::new(vec![1.2, 6.], 1.3, 3.998), vec![v.as_neighbor()]);
        let obj = serialize_model(&model);
        let json = serde_json::to_string(&obj).unwrap();
        assert_eq!(
            r#"[{"center":[3.0,5.1],"radius":2.16794833886788,"weight":0.999},{"center":[1.2,6.0],"radius":1.140175425099138,"weight":3.998}]"#,
            json
        );
    }

    #[test]
    fn test_streamer() {
        let algo = Algo::new(space::euclid_dist, space::real_combine);
        let mut model = Model::new(space::euclid_dist);
        let points = vec![Ok(String::from("[1.0,1.0]"))].into_iter();
        let mut result = String::new();
        let write = |s| {
            result = s;
            Ok(())
        };
        let streamer = Streamer::new(points, write);
        // `expect` surfaces the actual error on failure instead of a bare panic.
        Streamer::run(streamer, algo, &mut model).expect("streaming a valid point should succeed");
        assert_eq!(
            r#"[{"center":[1.0,1.0],"radius":null,"weight":0.0}]"#,
            result
        );
    }

    #[test]
    fn test_streamer_stops_on_malformed_point() {
        let algo = Algo::new(space::euclid_dist, space::real_combine);
        let mut model = Model::new(space::euclid_dist);
        // Not valid JSON, so deserialization must fail before anything is written.
        let points = vec![Ok(String::from("not a point"))].into_iter();
        let mut writes = 0;
        let write = |_: String| {
            writes += 1;
            Ok(())
        };
        let streamer = Streamer::new(points, write);
        assert!(Streamer::run(streamer, algo, &mut model).is_err());
        assert_eq!(0, writes);
    }

    #[test]
    fn test_channels() {
        let (point_producer, point_receiver) = mpsc::channel();
        let (model_producer, model_receiver) = mpsc::channel();
        let (mut points, mut write) = channels(point_receiver, model_producer);
        point_producer.send(String::from("point")).unwrap();
        let p = points.next().unwrap().unwrap();
        assert_eq!("point", p);
        write(String::from("model")).unwrap();
        let m = model_receiver.recv().unwrap();
        assert_eq!("model", m);
    }
}