1use std::collections::{HashMap, HashSet};
5use std::error::Error;
6use std::fmt::{Display, Formatter};
7use std::marker::PhantomData;
8
9use rtlola_frontend::mir::{OutputReference, Stream, StreamReference};
10use rtlola_frontend::RtLolaMir;
11
12use crate::monitor::{Change, Total, TotalIncremental, VerdictRepresentation};
13use crate::time::{OutputTimeRepresentation, TimeConversion};
14use crate::{Value, ValueConvertError};
15
16pub trait NewVerdictFactory<
18 MonitorOutput: VerdictRepresentation,
19 OutputTime: OutputTimeRepresentation,
20>: VerdictFactory<MonitorOutput, OutputTime> + Sized
21{
22 type CreationData;
24 type CreationError;
26
27 fn new(ir: &RtLolaMir, data: Self::CreationData) -> Result<Self, Self::CreationError>;
29}
30
31pub trait VerdictFactory<MonitorOutput: VerdictRepresentation, OutputTime: OutputTimeRepresentation>
34{
35 type Record;
37
38 type Error: Error + 'static;
40
41 fn get_verdict(
43 &mut self,
44 rec: MonitorOutput,
45 ts: OutputTime::InnerTime,
46 ) -> Result<Self::Record, Self::Error>;
47}
48
49pub trait AssociatedVerdictFactory<
51 MonitorOutput: VerdictRepresentation,
52 OutputTime: OutputTimeRepresentation,
53>
54{
55 type Factory: NewVerdictFactory<MonitorOutput, OutputTime>;
57}
58
59#[derive(Debug, Clone)]
62pub enum StreamValue {
63 Stream(Option<Value>),
65 Instances(HashMap<Vec<Value>, Value>),
67}
68
69pub trait FromValues: Sized {
71 type OutputTime;
73
74 fn streams() -> Vec<String>;
76
77 fn construct(ts: Self::OutputTime, data: Vec<StreamValue>) -> Result<Self, FromValuesError>;
80}
81
82impl<V, ExpectedTime, MonitorTime> AssociatedVerdictFactory<Total, MonitorTime> for V
83where
84 V: FromValues<OutputTime = ExpectedTime>,
85 MonitorTime: TimeConversion<ExpectedTime>,
86{
87 type Factory = StructVerdictFactory<V>;
88}
89
90impl<V, ExpectedTime, MonitorTime> AssociatedVerdictFactory<TotalIncremental, MonitorTime> for V
91where
92 V: FromValues<OutputTime = ExpectedTime>,
93 MonitorTime: TimeConversion<ExpectedTime>,
94{
95 type Factory = StructVerdictFactory<V>;
96}
97
98#[derive(Debug)]
100pub enum FromValuesError {
101 ValueConversion(ValueConvertError),
103 ExpectedValue {
105 stream_name: String,
107 },
108 InvalidHashMap {
110 stream_name: String,
112 expected_num_params: usize,
114 got_number_params: usize,
116 },
117 StreamKindMismatch,
119}
120
121impl Display for FromValuesError {
122 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
123 match self {
124 FromValuesError::ValueConversion(v) => write!(f, "{}", v),
125 FromValuesError::ExpectedValue { stream_name } => {
126 write!(
127 f,
128 "The value for stream {} was expected to exist but was not present in the monitor verdict.",
129 stream_name
130 )
131 }
132 FromValuesError::InvalidHashMap {
133 stream_name,
134 expected_num_params,
135 got_number_params,
136 } => {
137 write!(
138 f,
139 "Mismatch in the number of parameters of stream {}\nExpected {} parameters, but got {}",
140 stream_name, expected_num_params, got_number_params
141 )
142 }
143 FromValuesError::StreamKindMismatch => {
144 write!(
145 f,
146 "Expected a parameterized stream but got a non-parameterized stream or vice-versa."
147 )
148 }
149 }
150 }
151}
152
153impl Error for FromValuesError {
154 fn source(&self) -> Option<&(dyn Error + 'static)> {
155 match self {
156 FromValuesError::ValueConversion(e) => Some(e),
157 FromValuesError::ExpectedValue { .. } => None,
158 FromValuesError::InvalidHashMap { .. } => None,
159 FromValuesError::StreamKindMismatch => None,
160 }
161 }
162}
163
164impl From<ValueConvertError> for FromValuesError {
165 fn from(value: ValueConvertError) -> Self {
166 Self::ValueConversion(value)
167 }
168}
169
170#[derive(Debug)]
172pub enum StructVerdictError {
173 UnknownStream(String),
175 ValueError(FromValuesError),
177}
178impl Display for StructVerdictError {
179 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
180 match self {
181 StructVerdictError::UnknownStream(field) => {
182 write!(f, "No stream found for struct field: {}", field)
183 }
184 StructVerdictError::ValueError(ve) => write!(f, "{}", ve),
185 }
186 }
187}
188
189impl Error for StructVerdictError {
190 fn source(&self) -> Option<&(dyn Error + 'static)> {
191 match self {
192 StructVerdictError::UnknownStream(_) => None,
193 StructVerdictError::ValueError(ve) => Some(ve),
194 }
195 }
196}
197
198impl From<FromValuesError> for StructVerdictError {
199 fn from(value: FromValuesError) -> Self {
200 Self::ValueError(value)
201 }
202}
203
204#[derive(Debug, Clone)]
206pub struct StructVerdictFactory<V: FromValues> {
207 map: Vec<StreamReference>,
209 map_inv: HashMap<StreamReference, usize>,
211 parameterized_streams: HashSet<OutputReference>,
212 inner: PhantomData<V>,
213}
214
215impl<V: FromValues> StructVerdictFactory<V> {
216 pub fn new(ir: &RtLolaMir) -> Result<Self, StructVerdictError> {
218 let map: Vec<StreamReference> = V::streams()
219 .iter()
220 .map(|name| {
221 ir.get_stream_by_name(name)
222 .map(|s| s.as_stream_ref())
223 .or_else(|| {
224 name.starts_with("trigger_")
225 .then(|| name.split_once('_'))
226 .flatten()
227 .and_then(|(_, trg_idx)| trg_idx.parse::<usize>().ok())
228 .and_then(|trg_idx| {
229 ir.triggers.get(trg_idx).map(|trg| trg.output_reference)
230 })
231 })
232 .ok_or_else(|| StructVerdictError::UnknownStream(name.to_string()))
233 })
234 .collect::<Result<_, _>>()?;
235 let map_inv = map.iter().enumerate().map(|(idx, sr)| (*sr, idx)).collect();
236 let parameterized_streams = ir
237 .outputs
238 .iter()
239 .filter(|os| os.is_parameterized())
240 .map(|o| o.reference.out_ix())
241 .collect();
242 Ok(Self {
243 map,
244 map_inv,
245 parameterized_streams,
246 inner: Default::default(),
247 })
248 }
249}
250
251impl<O, I, V> VerdictFactory<Total, O> for StructVerdictFactory<V>
252where
253 V: FromValues<OutputTime = I>,
254 O: OutputTimeRepresentation + TimeConversion<I>,
255{
256 type Error = StructVerdictError;
257 type Record = V;
258
259 fn get_verdict(&mut self, rec: Total, ts: O::InnerTime) -> Result<Self::Record, Self::Error> {
260 let values: Vec<StreamValue> = self
261 .map
262 .iter()
263 .map(|sr| match sr {
264 StreamReference::In(i) => StreamValue::Stream(rec.inputs[*i].clone()),
265 StreamReference::Out(o) if !self.parameterized_streams.contains(o) => {
266 StreamValue::Stream(rec.outputs[*o][0].1.clone())
267 }
268 StreamReference::Out(o) => StreamValue::Instances(
269 rec.outputs[*o]
270 .iter()
271 .filter(|(_, value)| value.is_some())
272 .map(|(inst, val)| (inst.clone().unwrap(), val.clone().unwrap()))
273 .collect(),
274 ),
275 })
276 .collect();
277 let time = O::into(ts);
278 Ok(V::construct(time, values)?)
279 }
280}
281
282impl<O, I, V> NewVerdictFactory<Total, O> for StructVerdictFactory<V>
283where
284 V: FromValues<OutputTime = I>,
285 O: OutputTimeRepresentation + TimeConversion<I>,
286{
287 type CreationData = ();
288 type CreationError = StructVerdictError;
289
290 fn new(ir: &RtLolaMir, _data: Self::CreationData) -> Result<Self, Self::Error> {
291 Self::new(ir)
292 }
293}
294
295impl<O, I, V> VerdictFactory<TotalIncremental, O> for StructVerdictFactory<V>
296where
297 V: FromValues<OutputTime = I>,
298 O: OutputTimeRepresentation + TimeConversion<I>,
299{
300 type Error = StructVerdictError;
301 type Record = V;
302
303 fn get_verdict(
304 &mut self,
305 rec: TotalIncremental,
306 ts: O::InnerTime,
307 ) -> Result<Self::Record, Self::Error> {
308 let mut values: Vec<StreamValue> = self
309 .map
310 .iter()
311 .map(|sr| {
312 if sr.is_output() && self.parameterized_streams.contains(&sr.out_ix()) {
313 StreamValue::Instances(HashMap::new())
314 } else {
315 StreamValue::Stream(None)
316 }
317 })
318 .collect();
319
320 for (ir, v) in rec.inputs {
321 if let Some(idx) = self.map_inv.get(&StreamReference::In(ir)) {
322 values[*idx] = StreamValue::Stream(Some(v));
323 }
324 }
325 for (or, changes) in rec.outputs {
326 if let Some(idx) = self.map_inv.get(&StreamReference::Out(or)) {
327 if self.parameterized_streams.contains(&or) {
328 let StreamValue::Instances(res) = &mut values[*idx] else {
329 unreachable!("Mapping did not work!");
330 };
331 for change in changes {
332 if let Change::Value(p, v) = change {
333 res.insert(p.unwrap(), v);
334 }
335 }
336 } else {
337 let value = changes.into_iter().find_map(|change| {
338 if let Change::Value(_, v) = change {
339 Some(v)
340 } else {
341 None
342 }
343 });
344 values[*idx] = StreamValue::Stream(value);
345 }
346 }
347 }
348 let time = O::into(ts);
349 Ok(V::construct(time, values)?)
350 }
351}
352
353impl<O, I, V> NewVerdictFactory<TotalIncremental, O> for StructVerdictFactory<V>
354where
355 V: FromValues<OutputTime = I>,
356 O: OutputTimeRepresentation + TimeConversion<I>,
357{
358 type CreationData = ();
359 type CreationError = StructVerdictError;
360
361 fn new(ir: &RtLolaMir, _data: Self::CreationData) -> Result<Self, Self::Error> {
362 Self::new(ir)
363 }
364}