Skip to main content

openjd_model/template/
task_parameters.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

//! Task parameter definitions per spec §3.4.
6
7use super::constrained_strings::Identifier;
8use super::parameters::FlexInt;
9use crate::format_string::FormatString;
10use serde::Deserialize;
11
12/// §3.4.1 TaskParameterDefinition — discriminated union on `type`.
13#[derive(Debug, Clone, Deserialize)]
14#[serde(tag = "type")]
15#[allow(non_camel_case_types)]
16pub enum TaskParameterDefinition {
17    INT(IntTaskParameterDefinition),
18    FLOAT(FloatTaskParameterDefinition),
19    STRING(StringTaskParameterDefinition),
20    PATH(PathTaskParameterDefinition),
21    #[serde(rename = "CHUNK[INT]")]
22    CHUNK_INT(ChunkIntTaskParameterDefinition),
23}
24
25impl TaskParameterDefinition {
26    pub fn task_param_type(&self) -> crate::types::TaskParameterType {
27        use crate::types::TaskParameterType;
28        match self {
29            Self::INT(_) => TaskParameterType::Int,
30            Self::FLOAT(_) => TaskParameterType::Float,
31            Self::STRING(_) => TaskParameterType::String,
32            Self::PATH(_) => TaskParameterType::Path,
33            Self::CHUNK_INT(_) => TaskParameterType::ChunkInt,
34        }
35    }
36
37    pub fn name(&self) -> &str {
38        match self {
39            Self::INT(p) => p.name.as_str(),
40            Self::FLOAT(p) => p.name.as_str(),
41            Self::STRING(p) => p.name.as_str(),
42            Self::PATH(p) => p.name.as_str(),
43            Self::CHUNK_INT(p) => p.name.as_str(),
44        }
45    }
46}
47
48/// Int range: either a list of values or a range expression string.
49#[derive(Debug, Clone)]
50pub enum IntRange {
51    List(Vec<FlexInt>),
52    Expression(FormatString),
53}
54
55impl<'de> Deserialize<'de> for IntRange {
56    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
57        let val = serde_json::Value::deserialize(deserializer)?;
58        match val {
59            serde_json::Value::Array(seq) => {
60                let items: Result<Vec<FlexInt>, _> = seq
61                    .into_iter()
62                    .map(|v| serde_json::from_value(v).map_err(serde::de::Error::custom))
63                    .collect();
64                Ok(IntRange::List(items?))
65            }
66            serde_json::Value::String(s) => FormatString::new(&s)
67                .map(IntRange::Expression)
68                .map_err(serde::de::Error::custom),
69            _ => Err(serde::de::Error::custom(
70                "Expected list or string for range",
71            )),
72        }
73    }
74}
75
76/// Range that can be a list or a single expression string (EXPR extension).
77/// Concrete types to avoid derive conflicts with FormatString.
78
79#[derive(Debug, Clone)]
80pub enum StringRange {
81    List(Vec<FormatString>),
82    Expression(FormatString),
83}
84
85impl<'de> Deserialize<'de> for StringRange {
86    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
87        let val = serde_json::Value::deserialize(deserializer)?;
88        match &val {
89            serde_json::Value::Array(_) => {
90                let items: Vec<FormatString> =
91                    serde_json::from_value(val).map_err(serde::de::Error::custom)?;
92                Ok(StringRange::List(items))
93            }
94            serde_json::Value::String(s) => FormatString::new(s)
95                .map(StringRange::Expression)
96                .map_err(serde::de::Error::custom),
97            _ => Err(serde::de::Error::custom(
98                "Expected list or string for range",
99            )),
100        }
101    }
102}
103
104/// A float range list item: either a literal float or a format string.
105#[derive(Debug, Clone)]
106pub enum FloatRangeItem {
107    Float(f64),
108    FormatString(FormatString),
109}
110
111impl<'de> Deserialize<'de> for FloatRangeItem {
112    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
113        let val = serde_json::Value::deserialize(deserializer)?;
114        match &val {
115            serde_json::Value::Number(n) => {
116                let f = n
117                    .as_f64()
118                    .ok_or_else(|| serde::de::Error::custom("Invalid number in float range"))?;
119                super::parameters::reject_nan_inf(f).map_err(serde::de::Error::custom)?;
120                Ok(FloatRangeItem::Float(f))
121            }
122            serde_json::Value::String(s) => FormatString::new(s)
123                .map(FloatRangeItem::FormatString)
124                .map_err(serde::de::Error::custom),
125            _ => Err(serde::de::Error::custom(
126                "Expected number or string in float range",
127            )),
128        }
129    }
130}
131
132#[derive(Debug, Clone)]
133pub enum FloatRange {
134    List(Vec<FloatRangeItem>),
135    Expression(FormatString),
136}
137
138impl<'de> Deserialize<'de> for FloatRange {
139    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
140        let val = serde_json::Value::deserialize(deserializer)?;
141        match &val {
142            serde_json::Value::Array(_) => {
143                let items: Vec<FloatRangeItem> =
144                    serde_json::from_value(val).map_err(serde::de::Error::custom)?;
145                Ok(FloatRange::List(items))
146            }
147            serde_json::Value::String(s) => FormatString::new(s)
148                .map(FloatRange::Expression)
149                .map_err(serde::de::Error::custom),
150            _ => Err(serde::de::Error::custom(
151                "Expected list or string for range",
152            )),
153        }
154    }
155}
156
157/// §3.4.1.1 IntTaskParameterDefinition
158#[derive(Debug, Clone, Deserialize)]
159#[serde(rename_all = "camelCase", deny_unknown_fields)]
160pub struct IntTaskParameterDefinition {
161    pub name: Identifier,
162    pub range: IntRange,
163}
164
165/// §3.4.1.2 FloatTaskParameterDefinition
166#[derive(Debug, Clone, Deserialize)]
167#[serde(rename_all = "camelCase", deny_unknown_fields)]
168pub struct FloatTaskParameterDefinition {
169    pub name: Identifier,
170    pub range: FloatRange,
171}
172
173/// §3.4.1.3 StringTaskParameterDefinition
174#[derive(Debug, Clone, Deserialize)]
175#[serde(rename_all = "camelCase", deny_unknown_fields)]
176pub struct StringTaskParameterDefinition {
177    pub name: Identifier,
178    pub range: StringRange,
179}
180
181/// §3.4.1.4 PathTaskParameterDefinition
182#[derive(Debug, Clone, Deserialize)]
183#[serde(rename_all = "camelCase", deny_unknown_fields)]
184pub struct PathTaskParameterDefinition {
185    pub name: Identifier,
186    pub range: StringRange,
187}
188
189/// §3.4.1.5 ChunkIntTaskParameterDefinition (TASK_CHUNKING extension)
190#[derive(Debug, Clone, Deserialize)]
191#[serde(rename_all = "camelCase", deny_unknown_fields)]
192pub struct ChunkIntTaskParameterDefinition {
193    pub name: Identifier,
194    pub range: IntRange,
195    pub chunks: ChunksDefinition,
196}
197
198/// An integer value or a format string (e.g. `{{Param.ChunkSize}}`).
199///
200/// Accepts:
201/// - YAML integer → `IntOrFormatString::Int(n)`
202/// - String that parses as i64 → `IntOrFormatString::Int(n)`
203/// - String containing `{{…}}` → `IntOrFormatString::FormatString(fs)`
204/// - Boolean/null → error
205#[derive(Debug, Clone)]
206pub enum IntOrFormatString {
207    Int(i64),
208    FormatString(FormatString),
209}
210
211impl IntOrFormatString {
212    /// Return the integer value if this is a literal, or `None` if it's a format string.
213    pub fn as_i64(&self) -> Option<i64> {
214        match self {
215            Self::Int(n) => Some(*n),
216            Self::FormatString(_) => None,
217        }
218    }
219}
220
221impl<'de> Deserialize<'de> for IntOrFormatString {
222    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
223        let val = serde_json::Value::deserialize(deserializer)?;
224        match &val {
225            serde_json::Value::Number(n) => {
226                if let Some(i) = n.as_i64() {
227                    Ok(Self::Int(i))
228                } else if let Some(f) = n.as_f64() {
229                    if f.fract() == 0.0 {
230                        Ok(Self::Int(f as i64))
231                    } else {
232                        Err(serde::de::Error::custom(format!(
233                            "Expected integer, got float: {f}"
234                        )))
235                    }
236                } else {
237                    Err(serde::de::Error::custom("Invalid number"))
238                }
239            }
240            serde_json::Value::String(s) => {
241                // If it contains format string interpolation, treat as FormatString
242                if s.contains("{{") {
243                    FormatString::new(s)
244                        .map(Self::FormatString)
245                        .map_err(serde::de::Error::custom)
246                } else {
247                    // Try parsing as integer
248                    s.trim().parse::<i64>().map(Self::Int).map_err(|_| {
249                        serde::de::Error::custom(format!("Cannot parse '{s}' as integer"))
250                    })
251                }
252            }
253            serde_json::Value::Bool(_) => {
254                Err(serde::de::Error::custom("Expected integer, got boolean"))
255            }
256            serde_json::Value::Null => Err(serde::de::Error::custom("Expected integer, got null")),
257            _ => Err(serde::de::Error::custom("Expected integer or string")),
258        }
259    }
260}
261
262/// Chunks configuration for `CHUNK[INT]` parameters.
263#[derive(Debug, Clone, Deserialize)]
264#[serde(rename_all = "camelCase", deny_unknown_fields)]
265pub struct ChunksDefinition {
266    pub default_task_count: IntOrFormatString,
267    pub target_runtime_seconds: Option<IntOrFormatString>,
268    pub range_constraint: RangeConstraint,
269}
270
271#[derive(Debug, Clone, PartialEq, Eq, Deserialize, serde::Serialize)]
272#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
273pub enum RangeConstraint {
274    Contiguous,
275    Noncontiguous,
276}
277
278/// §3.4 StepParameterSpaceDefinition
279#[derive(Debug, Clone, Deserialize)]
280#[serde(rename_all = "camelCase", deny_unknown_fields)]
281pub struct StepParameterSpaceDefinition {
282    pub task_parameter_definitions: Vec<TaskParameterDefinition>,
283    pub combination: Option<String>,
284}