hecate/codegen/input_schema/
range.rs

1use schemars::{JsonSchema, json_schema};
2use std::borrow::Cow;
3use std::str::FromStr;
4
5use thiserror::Error;
6
7use serde::{Deserialize, Serialize, de::Error};
8
9use super::{RANGE_PATTERN, RawRepr};
10
11#[derive(Clone, Debug, PartialEq)]
12pub struct Range<T> {
13    pub start: T,
14    pub end: T,
15    raw: String,
16}
17
18impl<T> Serialize for Range<T> {
19    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
20    where
21        S: serde::Serializer,
22    {
23        serializer.serialize_str(&self.raw)
24    }
25}
26impl<T: JsonSchema> JsonSchema for Range<T> {
27    fn schema_name() -> Cow<'static, str> {
28        format!("Range<{}>", T::schema_name()).into()
29    }
30
31    fn schema_id() -> Cow<'static, str> {
32        format!("{}::Range<{}>", module_path!(), T::schema_id()).into()
33    }
34
35    fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
36        json_schema!({
37            "type": "string",
38            "pattern": RANGE_PATTERN
39        })
40    }
41
42    fn inline_schema() -> bool {
43        true
44    }
45}
46// impl<T> JsonSchema for Range<T> {
47//     fn schema_name() -> String {
48//         String::from("Range")
49//     }
50//
51//     fn json_schema(_gen: &mut schemars::r#gen::SchemaGenerator) -> schemars::schema::Schema {
52//         let mut schema = SchemaObject::default();
53//         //schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::String)));
54//         schema.subschemas = Some(Box::new(SubschemaValidation {
55//             one_of: Some(vec![
56//                 // Schema for string type
57//                 Schema::Object(SchemaObject {
58//                     instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::String))),
59//                     string: Some(Box::new(StringValidation {
60//                         pattern: Some(RANGE_PATTERN.to_string()),
61//                         ..Default::default()
62//                     })),
63//                     ..Default::default()
64//                 }),
65//                 // Schema for number type
66//                 Schema::Object(SchemaObject {
67//                     instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Number))),
68//                     ..Default::default()
69//                 }),
70//             ]),
71//             ..Default::default()
72//         }));
73//
74//         Schema::Object(schema)
75//     }
76// }
77
78impl<'de, T> Deserialize<'de> for Range<T>
79where
80    T: FromStr,
81    T::Err: std::error::Error + 'static,
82{
83    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
84    where
85        D: serde::Deserializer<'de>,
86    {
87        let raw: String = Deserialize::deserialize(deserializer)?;
88        let range: Range<T> = raw.parse().map_err(|e| D::Error::custom(e))?;
89
90        Ok(range)
91    }
92}
93
94impl<T> RawRepr for Range<T> {
95    /// Returns a raw string representation as written by the user.
96    fn raw(&self) -> &str {
97        &self.raw
98    }
99}
100
101#[derive(Debug, Error, PartialEq)]
102pub enum ParseRangeError<T: FromStr>
103where
104    T::Err: std::error::Error + 'static,
105{
106    #[error("Wrong number of elements for a range, it should be 2. Example: 0 .. 1")]
107    WrongArgNumber,
108    #[error("Failed to parse start bound: {0}")]
109    InvalidStart(#[source] T::Err),
110    #[error("Failed to parse end bound: {0}")]
111    InvalidEnd(#[source] T::Err),
112}
113
114use super::quantity;
115
116impl<T> FromStr for Range<T>
117where
118    T: FromStr,
119    T::Err: std::error::Error + 'static,
120{
121    type Err = ParseRangeError<T>;
122
123    fn from_str(s: &str) -> Result<Self, Self::Err> {
124        let parts: Vec<&str> = s.split("..").map(|part| part.trim()).collect();
125
126        if parts.len() != 2 {
127            return Err(ParseRangeError::WrongArgNumber);
128        }
129
130        let start_unit = quantity::get_unit(parts[0]);
131        let end_unit = quantity::get_unit(parts[1]);
132
133        let mut start = parts[0].to_string();
134        let mut end = parts[1].to_string();
135        if start_unit.is_some() && end_unit.is_none() {
136            end += start_unit.expect("already checked")
137        } else if start_unit.is_none() && end_unit.is_some() {
138            start += end_unit.expect("already checked")
139        }
140
141        let start: T = start
142            .parse()
143            .map_err(|e| ParseRangeError::InvalidStart(e))?;
144        let end: T = end.parse().map_err(|e| ParseRangeError::InvalidEnd(e))?;
145
146        Ok(Range {
147            start,
148            end,
149            raw: s.to_string(),
150        })
151    }
152}
153
#[cfg(test)]
mod tests {

    use super::*;

    use super::quantity::*;

    #[test]
    fn same_unit_when_only_one() {
        let raw = "0..1dm2";
        let range: Range<Area> = raw.parse().expect("valid range should be parsed");
        assert_eq!(
            range,
            Range {
                start: "0dm2".parse().unwrap(),
                end: "1dm2".parse().unwrap(),
                raw: raw.to_string()
            }
        );
    }

    #[test]
    fn end_inherits_start_unit() {
        let raw = "1dm2 .. 2";
        let range: Range<Area> = raw.parse().expect("valid range should be parsed");
        assert_eq!(
            range,
            Range {
                start: "1dm2".parse().unwrap(),
                end: "2dm2".parse().unwrap(),
                raw: raw.to_string()
            }
        );
    }

    #[test]
    fn different_units() {
        let raw = "10km^+2..1m2";
        let range: Range<Area> = raw.parse().expect("valid range should be parsed");
        assert_eq!(
            range,
            Range {
                start: "10km^+2".parse().unwrap(),
                end: "1m2".parse().unwrap(),
                raw: raw.to_string()
            }
        );
    }

    #[test]
    fn parse_valid_range() {
        let raw = "0 .. 1";
        let range: Range<i32> = raw.parse().expect("Valid range should be parsed");
        assert_eq!(
            range,
            Range {
                start: 0,
                end: 1,
                raw: raw.to_string()
            }
        )
    }

    #[test]
    fn serialize_range() {
        let range = Range {
            start: 0,
            end: 1,
            raw: "0 .. 1".to_string(),
        };
        let serialized = serde_json::to_string(&range).expect("Serialization should succeed");
        assert_eq!(serialized, "\"0 .. 1\"");
    }

    #[test]
    fn deserialize_range() {
        let raw = "\"0 .. 1\"";
        let range: Range<i32> = serde_json::from_str(raw).expect("Deserialization should succeed");
        assert_eq!(
            range,
            Range {
                start: 0,
                end: 1,
                raw: "0 .. 1".to_string()
            }
        );
    }

    #[test]
    fn deserialize_invalid_range_fails() {
        let result: Result<Range<i32>, _> = serde_json::from_str("\"0 .. 1 .. 2\"");
        assert!(result.is_err());
    }

    #[test]
    fn parse_invalid_range_wrong_arg_number() {
        let raw = "0 .. 1 .. 2";
        let result: Result<Range<i32>, _> = raw.parse();
        assert_eq!(result.unwrap_err(), ParseRangeError::WrongArgNumber);
    }

    #[test]
    fn parse_invalid_range_missing_separator() {
        let result: Result<Range<i32>, _> = "5".parse();
        assert_eq!(result.unwrap_err(), ParseRangeError::WrongArgNumber);
    }

    #[test]
    fn parse_invalid_range_invalid_start() {
        let raw = "a .. 1";
        let result: Result<Range<i32>, _> = raw.parse();
        assert_eq!(
            result.unwrap_err(),
            ParseRangeError::InvalidStart("a".parse::<i32>().unwrap_err())
        );
    }

    #[test]
    fn parse_invalid_range_invalid_end() {
        let raw = "0 .. b";
        let result: Result<Range<i32>, _> = raw.parse();
        assert_eq!(
            result.unwrap_err(),
            ParseRangeError::InvalidEnd("b".parse::<i32>().unwrap_err())
        );
    }

    #[test]
    fn parse_length_range() {
        let raw = "-1 .. 1";
        let range: Range<Length> = raw.parse().expect("Valid range should be parsed");
        assert_eq!(
            range,
            Range {
                start: "-1".parse().unwrap(),
                end: "1".parse().unwrap(),
                raw: raw.to_string()
            }
        )
    }
}
267}