hdp_primitives/task/datalake/
compute.rs

1use serde::de::{self, MapAccess, Visitor};
2use serde::{Deserialize, Deserializer, Serialize};
3use std::fmt;
4
5use crate::aggregate_fn::{AggregationFunction, FunctionContext};
6
7/// [`Computation`] is a structure that contains the aggregate function id and context
8#[derive(Debug, PartialEq, Eq, Clone, Serialize)]
9#[serde(rename_all = "camelCase")]
10pub struct Computation {
11    pub aggregate_fn_id: AggregationFunction,
12    pub aggregate_fn_ctx: FunctionContext,
13}
14
15impl Computation {
16    pub fn new(
17        aggregate_fn_id: AggregationFunction,
18        aggregate_fn_ctx: Option<FunctionContext>,
19    ) -> Self {
20        let aggregate_fn_ctn_parsed = aggregate_fn_ctx.unwrap_or_default();
21        Self {
22            aggregate_fn_id,
23            aggregate_fn_ctx: aggregate_fn_ctn_parsed,
24        }
25    }
26}
27
28impl<'de> Deserialize<'de> for Computation {
29    fn deserialize<D>(deserializer: D) -> Result<Computation, D::Error>
30    where
31        D: Deserializer<'de>,
32    {
33        #[derive(Deserialize)]
34        #[serde(field_identifier, rename_all = "camelCase")]
35        enum Field {
36            AggregateFnId,
37            AggregateFnCtx,
38        }
39
40        struct ComputationVisitor;
41
42        impl<'de> Visitor<'de> for ComputationVisitor {
43            type Value = Computation;
44
45            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
46                formatter.write_str("struct Computation")
47            }
48
49            fn visit_map<V>(self, mut map: V) -> Result<Computation, V::Error>
50            where
51                V: MapAccess<'de>,
52            {
53                let mut aggregate_fn_id = None;
54                let mut aggregate_fn_ctx = None;
55                while let Some(key) = map.next_key()? {
56                    match key {
57                        Field::AggregateFnId => {
58                            if aggregate_fn_id.is_some() {
59                                return Err(de::Error::duplicate_field("aggregateFnId"));
60                            }
61                            aggregate_fn_id = Some(map.next_value()?);
62                        }
63                        Field::AggregateFnCtx => {
64                            if aggregate_fn_ctx.is_some() {
65                                return Err(de::Error::duplicate_field("aggregateFnCtx"));
66                            }
67                            aggregate_fn_ctx = Some(map.next_value()?);
68                        }
69                    }
70                }
71                let aggregate_fn_id =
72                    aggregate_fn_id.ok_or_else(|| de::Error::missing_field("aggregateFnId"))?;
73                let aggregate_fn_ctx = aggregate_fn_ctx.unwrap_or_default();
74                Ok(Computation {
75                    aggregate_fn_id,
76                    aggregate_fn_ctx,
77                })
78            }
79        }
80
81        const FIELDS: &[&str] = &["aggregateFnId", "aggregateFnCtx"];
82        deserializer.deserialize_struct("Computation", FIELDS, ComputationVisitor)
83    }
84}