1use std::{
15    fmt::{self, Display},
16    str,
17};
18
19use derive_getters::Getters;
20use serde::{
21    de::{Error as SerdeError, IntoDeserializer},
22    Deserialize, Deserializer, Serialize, Serializer,
23};
24
25use derive_builder::Builder;
26
27use crate::{error::Error, types::StructField};
28
29use super::types::{StructType, Type};
30
/// Spec id that `PartitionSpec::builder()` assigns when `spec_id` is not set explicitly.
pub static DEFAULT_PARTITION_SPEC_ID: i32 = 0_i32;
32
33#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
34#[serde(rename_all = "lowercase", remote = "Self")]
35pub enum Transform {
37    Identity,
39    Bucket(u32),
41    Truncate(u32),
43    Year,
45    Month,
47    Day,
49    Hour,
51    Void,
53}
54
55impl<'de> Deserialize<'de> for Transform {
56    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
57    where
58        D: Deserializer<'de>,
59    {
60        let s = String::deserialize(deserializer)?;
61        if s.starts_with("bucket") {
62            deserialize_bucket(s.into_deserializer())
63        } else if s.starts_with("truncate") {
64            deserialize_truncate(s.into_deserializer())
65        } else {
66            Transform::deserialize(s.into_deserializer())
67        }
68    }
69}
70
71impl Serialize for Transform {
72    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
73    where
74        S: serde::Serializer,
75    {
76        match self {
77            Transform::Bucket(bucket) => serialize_bucket(bucket, serializer),
78            Transform::Truncate(truncate) => serialize_truncate(truncate, serializer),
79            x => Transform::serialize(x, serializer),
80        }
81    }
82}
83
84fn deserialize_bucket<'de, D>(deserializer: D) -> Result<Transform, D::Error>
85where
86    D: Deserializer<'de>,
87{
88    let bucket = String::deserialize(deserializer)?
89        .trim_start_matches(r"bucket[")
90        .trim_end_matches(']')
91        .to_owned();
92
93    bucket
94        .parse()
95        .map(Transform::Bucket)
96        .map_err(D::Error::custom)
97}
98
99fn serialize_bucket<S>(value: &u32, serializer: S) -> Result<S::Ok, S::Error>
100where
101    S: Serializer,
102{
103    serializer.serialize_str(&format!("bucket[{value}]"))
104}
105
106fn deserialize_truncate<'de, D>(deserializer: D) -> Result<Transform, D::Error>
107where
108    D: Deserializer<'de>,
109{
110    let truncate = String::deserialize(deserializer)?
111        .trim_start_matches(r"truncate[")
112        .trim_end_matches(']')
113        .to_owned();
114
115    truncate
116        .parse()
117        .map(Transform::Truncate)
118        .map_err(D::Error::custom)
119}
120
121fn serialize_truncate<S>(value: &u32, serializer: S) -> Result<S::Ok, S::Error>
122where
123    S: Serializer,
124{
125    serializer.serialize_str(&format!("truncate[{value}]"))
126}
127
128impl Display for Transform {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        match self {
131            Transform::Identity => write!(f, "identity"),
132            Transform::Year => write!(f, "year"),
133            Transform::Month => write!(f, "month"),
134            Transform::Day => write!(f, "day"),
135            Transform::Hour => write!(f, "hour"),
136            Transform::Bucket(i) => write!(f, "bucket[{i}]"),
137            Transform::Truncate(i) => write!(f, "truncate[{i}]"),
138            Transform::Void => write!(f, "void"),
139        }
140    }
141}
142
143#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Getters)]
144#[serde(rename_all = "kebab-case")]
145pub struct PartitionField {
147    source_id: i32,
149    field_id: i32,
152    name: String,
154    transform: Transform,
156}
157
158impl PartitionField {
159    pub fn new(source_id: i32, field_id: i32, name: &str, transform: Transform) -> Self {
161        Self {
162            source_id,
163            field_id,
164            name: name.to_string(),
165            transform,
166        }
167    }
168}
169
170#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Default, Builder, Getters)]
171#[serde(rename_all = "kebab-case")]
172#[builder(setter(prefix = "with"))]
173pub struct PartitionSpec {
175    #[builder(default = "DEFAULT_PARTITION_SPEC_ID")]
177    spec_id: i32,
178    #[builder(setter(each(name = "with_partition_field")))]
180    fields: Vec<PartitionField>,
181}
182
183impl PartitionSpec {
184    pub fn builder() -> PartitionSpecBuilder {
186        PartitionSpecBuilder::default()
187    }
188    pub fn data_types(&self, schema: &StructType) -> Result<Vec<Type>, Error> {
190        self.fields
191            .iter()
192            .map(|field| {
193                schema
194                    .get(field.source_id as usize)
195                    .ok_or(Error::NotFound(format!("Schema field {}", field.name)))
196                    .and_then(|x| x.field_type.clone().tranform(&field.transform))
197            })
198            .collect::<Result<Vec<_>, Error>>()
199    }
200}
201
202impl fmt::Display for PartitionSpec {
203    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204        write!(
205            f,
206            "{}",
207            &serde_json::to_string(self).map_err(|_| fmt::Error)?,
208        )
209    }
210}
211
212impl str::FromStr for PartitionSpec {
213    type Err = Error;
214    fn from_str(s: &str) -> Result<Self, Self::Err> {
215        serde_json::from_str(s).map_err(Error::from)
216    }
217}
218
219#[derive(Debug)]
222pub struct BoundPartitionField<'a> {
223    partition_field: &'a PartitionField,
224    struct_field: &'a StructField,
225}
226
227impl<'a> BoundPartitionField<'a> {
228    pub fn new(partition_field: &'a PartitionField, struct_field: &'a StructField) -> Self {
234        Self {
235            partition_field,
236            struct_field,
237        }
238    }
239
240    pub fn name(&self) -> &str {
242        &self.partition_field.name
243    }
244
245    pub fn source_name(&self) -> &str {
247        &self.struct_field.name
248    }
249
250    pub fn field_type(&self) -> &Type {
252        &self.struct_field.field_type
253    }
254
255    pub fn transform(&self) -> &Transform {
257        &self.partition_field.transform
258    }
259
260    pub fn field_id(&self) -> i32 {
262        self.partition_field.field_id
263    }
264
265    pub fn source_id(&self) -> i32 {
267        self.partition_field.source_id
268    }
269
270    pub fn required(&self) -> bool {
272        self.struct_field.required
273    }
274
275    pub fn partition_field(&self) -> &PartitionField {
277        self.partition_field
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a spec covering a unit transform and both bracketed transforms, checks
    /// every field, then checks that serializing and re-parsing gives the same spec.
    #[test]
    fn partition_spec() {
        let spec_json = r#"
        {
        "spec-id": 1,
        "fields": [ {
            "source-id": 4,
            "field-id": 1000,
            "name": "ts_day",
            "transform": "day"
            }, {
            "source-id": 1,
            "field-id": 1001,
            "name": "id_bucket",
            "transform": "bucket[16]"
            }, {
            "source-id": 2,
            "field-id": 1002,
            "name": "id_truncate",
            "transform": "truncate[4]"
            } ]
        }
        "#;

        let partition_spec: PartitionSpec = serde_json::from_str(spec_json).unwrap();
        assert_eq!(1, partition_spec.spec_id);
        assert_eq!(3, partition_spec.fields.len());

        assert_eq!(4, partition_spec.fields[0].source_id);
        assert_eq!(1000, partition_spec.fields[0].field_id);
        assert_eq!("ts_day", partition_spec.fields[0].name);
        assert_eq!(Transform::Day, partition_spec.fields[0].transform);

        assert_eq!(1, partition_spec.fields[1].source_id);
        assert_eq!(1001, partition_spec.fields[1].field_id);
        assert_eq!("id_bucket", partition_spec.fields[1].name);
        assert_eq!(Transform::Bucket(16), partition_spec.fields[1].transform);

        assert_eq!(2, partition_spec.fields[2].source_id);
        assert_eq!(1002, partition_spec.fields[2].field_id);
        assert_eq!("id_truncate", partition_spec.fields[2].name);
        assert_eq!(Transform::Truncate(4), partition_spec.fields[2].transform);

        // Round trip: the custom bracketed serializers must emit strings the
        // parsers accept again.
        let serialized = serde_json::to_string(&partition_spec).unwrap();
        let reparsed: PartitionSpec = serde_json::from_str(&serialized).unwrap();
        assert_eq!(partition_spec, reparsed);
    }

    /// The JSON form of every transform is its `Display` text, and it parses back.
    #[test]
    fn transform_serialized_form_matches_display() {
        let transforms = [
            Transform::Identity,
            Transform::Bucket(16),
            Transform::Truncate(4),
            Transform::Year,
            Transform::Month,
            Transform::Day,
            Transform::Hour,
            Transform::Void,
        ];
        for transform in transforms {
            let json = serde_json::to_string(&transform).unwrap();
            assert_eq!(format!("\"{transform}\""), json);
            let parsed: Transform = serde_json::from_str(&json).unwrap();
            assert_eq!(transform, parsed);
        }
    }

    /// Unknown names and non-numeric parameters are rejected.
    #[test]
    fn transform_rejects_unknown_or_non_numeric() {
        assert!(serde_json::from_str::<Transform>("\"unknown\"").is_err());
        assert!(serde_json::from_str::<Transform>("\"bucket[abc]\"").is_err());
        assert!(serde_json::from_str::<Transform>("\"truncate[]\"").is_err());
    }
}