// laddu_core/parameters/map.rs

1use std::{collections::HashMap, fmt::Display, ops::Index};
2
3use serde::{Deserialize, Serialize};
4
5use super::Parameter;
6use crate::{
7    resources::{ParameterID, Parameters},
8    LadduError, LadduResult,
9};
10
11/// An ordered set of [`Parameter`]s.
12#[derive(Default, Debug, Clone)]
13pub struct ParameterMap {
14    parameters: Vec<Parameter>,
15    name_to_index: HashMap<String, usize>,
16}
17
18#[derive(Serialize, Deserialize)]
19struct ParameterMapSerde {
20    parameters: Vec<Parameter>,
21}
22
23impl Index<usize> for ParameterMap {
24    type Output = Parameter;
25
26    fn index(&self, index: usize) -> &Self::Output {
27        &self.parameters[index]
28    }
29}
30
31impl Index<&str> for ParameterMap {
32    type Output = Parameter;
33
34    fn index(&self, key: &str) -> &Self::Output {
35        self.get(key)
36            .unwrap_or_else(|| panic!("parameter '{key}' not found"))
37    }
38}
39
40impl IntoIterator for ParameterMap {
41    type Item = Parameter;
42    type IntoIter = std::vec::IntoIter<Parameter>;
43
44    fn into_iter(self) -> Self::IntoIter {
45        self.parameters.into_iter()
46    }
47}
48
49impl<'a> IntoIterator for &'a ParameterMap {
50    type Item = &'a Parameter;
51    type IntoIter = std::slice::Iter<'a, Parameter>;
52
53    fn into_iter(self) -> Self::IntoIter {
54        self.parameters.iter()
55    }
56}
57
58impl Serialize for ParameterMap {
59    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
60    where
61        S: serde::Serializer,
62    {
63        ParameterMapSerde {
64            parameters: self.parameters.clone(),
65        }
66        .serialize(serializer)
67    }
68}
69
70impl<'de> Deserialize<'de> for ParameterMap {
71    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
72    where
73        D: serde::Deserializer<'de>,
74    {
75        let serde = ParameterMapSerde::deserialize(deserializer)?;
76        Ok(Self::from_parameters(serde.parameters))
77    }
78}
79
80impl ParameterMap {
81    fn from_parameters(parameters: Vec<Parameter>) -> Self {
82        let name_to_index = parameters
83            .iter()
84            .enumerate()
85            .map(|(index, parameter)| (parameter.name(), index))
86            .collect();
87        Self {
88            parameters,
89            name_to_index,
90        }
91    }
92
93    /// Register a parameter into the ordered map and return its assembled [`ParameterID`].
94    pub fn register_parameter(&mut self, p: &Parameter) -> LadduResult<ParameterID> {
95        let name = p.name();
96        if name.is_empty() {
97            return Err(LadduError::UnregisteredParameter {
98                name: "<unnamed>".to_string(),
99                reason: "Parameter was not initialized with a name".to_string(),
100            });
101        }
102
103        if let Some((index, existing)) = self.get_indexed(&name) {
104            match (existing.fixed(), p.fixed()) {
105                (Some(a), Some(b)) if (a - b).abs() > f64::EPSILON => {
106                    return Err(LadduError::ParameterConflict {
107                        name,
108                        reason: "conflicting fixed values for the same parameter name".to_string(),
109                    });
110                }
111                (Some(_), None) => {
112                    return Err(LadduError::ParameterConflict {
113                        name,
114                        reason: "attempted to use a fixed parameter name as free".to_string(),
115                    });
116                }
117                (None, Some(_)) => {
118                    return Err(LadduError::ParameterConflict {
119                        name,
120                        reason: "attempted to use a free parameter name as fixed".to_string(),
121                    });
122                }
123                (Some(_), Some(_)) | (None, None) => return Ok(self.parameter_id(index)),
124            }
125        }
126
127        let index = self.parameters.len();
128        self.insert(p.clone());
129        Ok(self.parameter_id(index))
130    }
131
132    /// Return the assembled indices of all free parameters.
133    pub fn free_parameter_indices(&self) -> Vec<usize> {
134        (0..self.free().len()).collect()
135    }
136
137    /// Rename a single parameter in place.
138    pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
139        if old == new {
140            return Ok(());
141        }
142        if self.contains_key(new) {
143            return Err(LadduError::ParameterConflict {
144                name: new.to_string(),
145                reason: "rename target already exists".to_string(),
146            });
147        }
148        if let Some(index) = self.index(old) {
149            let parameter = self.parameters[index].clone();
150            parameter.set_name(new);
151            self.name_to_index.remove(old);
152            self.name_to_index.insert(new.to_string(), index);
153        } else {
154            self.assert_parameter_exists(old)?;
155        }
156        Ok(())
157    }
158
159    /// Rename multiple parameters in place.
160    pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
161        for (old, new) in mapping {
162            self.rename_parameter(old, new)?;
163        }
164        Ok(())
165    }
166
167    /// Fix a parameter to the supplied value.
168    pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
169        self.assert_parameter_exists(name)?;
170        if let Some(parameter) = self.get(name) {
171            parameter.set_fixed_value(Some(value));
172        }
173        Ok(())
174    }
175
176    /// Mark a parameter as free.
177    pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
178        self.assert_parameter_exists(name)?;
179        if let Some(parameter) = self.get(name) {
180            parameter.set_fixed_value(None);
181        }
182        Ok(())
183    }
184
185    /// Return whether a parameter with the given name exists.
186    pub fn contains_key(&self, name: &str) -> bool {
187        self.name_to_index.contains_key(name)
188    }
189
190    /// Return the storage index for a named parameter.
191    pub fn index(&self, name: &str) -> Option<usize> {
192        self.name_to_index.get(name).copied()
193    }
194
195    /// Insert or replace a parameter by name while preserving insertion order.
196    pub fn insert(&mut self, parameter: Parameter) -> Option<Parameter> {
197        let name = parameter.name();
198        if let Some(index) = self.index(&name) {
199            Some(std::mem::replace(&mut self.parameters[index], parameter))
200        } else {
201            let index = self.parameters.len();
202            self.parameters.push(parameter);
203            self.name_to_index.insert(name, index);
204            None
205        }
206    }
207
208    /// The number of parameters in the set.
209    pub fn len(&self) -> usize {
210        self.parameters.len()
211    }
212
213    /// Returns true if the parameter set has no elements.
214    pub fn is_empty(&self) -> bool {
215        self.parameters.is_empty()
216    }
217
218    /// Iterate over all parameters in the set.
219    pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
220        self.parameters.iter()
221    }
222
223    /// Get a parameter by name.
224    pub fn get(&self, key: &str) -> Option<&Parameter> {
225        self.index(key).map(|index| &self.parameters[index])
226    }
227
228    /// Get both the storage index and parameter for a given name.
229    pub fn get_indexed(&self, key: &str) -> Option<(usize, &Parameter)> {
230        self.index(key)
231            .map(|index| (index, &self.parameters[index]))
232    }
233
234    /// Get all parameter names in order.
235    pub fn names(&self) -> Vec<String> {
236        self.parameters.iter().map(Parameter::name).collect()
237    }
238
239    /// Filter the parameter set by a predicate.
240    pub fn filter(&self, predicate: impl Fn(&Parameter) -> bool) -> Self {
241        Self::from_parameters(
242            self.parameters
243                .iter()
244                .filter(|parameter| predicate(parameter))
245                .cloned()
246                .collect(),
247        )
248    }
249
250    /// Get a set containing only free parameters.
251    pub fn free(&self) -> Self {
252        self.filter(|p| p.is_free())
253    }
254
255    /// Get a set containing only fixed parameters.
256    pub fn fixed(&self) -> Self {
257        self.filter(|p| p.is_fixed())
258    }
259
260    /// Get a set containing only initialized parameters.
261    pub fn initialized(&self) -> Self {
262        self.filter(|p| p.initial().is_some())
263    }
264
265    /// Get a set containing only uninitialized parameters.
266    pub fn uninitialized(&self) -> Self {
267        self.filter(|p| p.initial().is_none())
268    }
269
270    /// Assemble free inputs into a full [`Parameters`] object.
271    ///
272    /// The resulting values are ordered with all free parameters first, followed by fixed ones.
273    pub fn assemble(&self, free_values: &[f64]) -> LadduResult<Parameters> {
274        let expected_free = self.free().len();
275        let n_fixed = self.fixed().len();
276        let mut values = vec![0.0; expected_free + n_fixed];
277        let mut storage_to_assembled = vec![0; self.len()];
278        let mut free_iter = free_values.iter();
279        let mut free_index = 0;
280        let mut fixed_index = expected_free;
281        for (storage_index, parameter) in self.parameters.iter().enumerate() {
282            if let Some(value) = parameter.fixed() {
283                values[fixed_index] = value;
284                storage_to_assembled[storage_index] = fixed_index;
285                fixed_index += 1;
286            } else if let Some(value) = free_iter.next() {
287                values[free_index] = *value;
288                storage_to_assembled[storage_index] = free_index;
289                free_index += 1;
290            } else {
291                return Err(LadduError::LengthMismatch {
292                    context: "parameter values".to_string(),
293                    expected: expected_free,
294                    actual: free_values.len(),
295                });
296            }
297        }
298        if free_iter.next().is_some() {
299            return Err(LadduError::LengthMismatch {
300                context: "parameter values".to_string(),
301                expected: expected_free,
302                actual: free_values.len(),
303            });
304        }
305        Ok(Parameters::new(values, expected_free, storage_to_assembled))
306    }
307
308    /// Merge two parameter maps.
309    ///
310    /// When parameters overlap, the state and value stored in `self` always take precedence over
311    /// entries from `other`.
312    pub fn merge(&self, other: &Self) -> (Self, Vec<usize>, Vec<usize>) {
313        let mut merged = self.clone();
314        let mut right_map = Vec::with_capacity(other.len());
315        for parameter in other {
316            let idx = merged.ensure_parameter(parameter.clone());
317            right_map.push(idx);
318        }
319        let left_map: Vec<usize> = (0..self.len())
320            .map(|index| merged.assembled_index(index))
321            .collect();
322        let right_map = right_map
323            .into_iter()
324            .map(|index| merged.assembled_index(index))
325            .collect();
326        (merged, left_map, right_map)
327    }
328
329    /// Extend a parameter map from another one.
330    ///
331    /// When both managers reference the same parameter, the value and fixed/free status from
332    /// `self` are retained.
333    pub fn extend_from(&self, other: &Self) -> (Self, Vec<usize>) {
334        let mut merged = self.clone();
335        let mut indices = Vec::with_capacity(other.len());
336        for parameter in other {
337            let idx = merged.ensure_parameter(parameter.clone());
338            indices.push(idx);
339        }
340        let indices = indices
341            .into_iter()
342            .map(|index| merged.assembled_index(index))
343            .collect();
344        (merged, indices)
345    }
346
347    fn ensure_parameter(&mut self, parameter: Parameter) -> usize {
348        let name = parameter.name();
349        if let Some(idx) = self.index(&name) {
350            return idx;
351        }
352        let idx = self.len();
353        self.insert(parameter);
354        idx
355    }
356
357    fn assembled_index(&self, storage_index: usize) -> usize {
358        let n_free = self
359            .parameters
360            .iter()
361            .filter(|parameter| parameter.is_free())
362            .count();
363        let preceding_in_group = self.parameters[..storage_index]
364            .iter()
365            .filter(|parameter| self.parameters[storage_index].is_free() == parameter.is_free())
366            .count();
367        if self.parameters[storage_index].is_free() {
368            preceding_in_group
369        } else {
370            n_free + preceding_in_group
371        }
372    }
373
374    fn parameter_id(&self, storage_index: usize) -> ParameterID {
375        if self.parameters[storage_index].is_fixed() {
376            ParameterID::Constant(storage_index)
377        } else {
378            ParameterID::Parameter(storage_index)
379        }
380    }
381
382    fn assert_parameter_exists(&self, name: &str) -> LadduResult<()> {
383        if self.contains_key(name) {
384            Ok(())
385        } else {
386            Err(LadduError::UnregisteredParameter {
387                name: name.to_string(),
388                reason: "parameter not found".to_string(),
389            })
390        }
391    }
392}
393
394impl Display for ParameterMap {
395    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396        writeln!(f, "ParameterMap:")?;
397        if self.parameters.is_empty() {
398            writeln!(f, "  <empty>")?;
399            return Ok(());
400        }
401        writeln!(f, "  free:")?;
402        let mut wrote_free = false;
403        for parameter in self
404            .parameters
405            .iter()
406            .filter(|parameter| parameter.is_free())
407        {
408            wrote_free = true;
409            writeln!(f, "    {}", parameter.name())?;
410        }
411        if !wrote_free {
412            writeln!(f, "    <none>")?;
413        }
414        writeln!(f, "  fixed:")?;
415        let mut wrote_fixed = false;
416        for parameter in self
417            .parameters
418            .iter()
419            .filter(|parameter| parameter.is_fixed())
420        {
421            wrote_fixed = true;
422            if let Some(value) = parameter.fixed() {
423                writeln!(f, "    {} = {}", parameter.name(), value)?;
424            }
425        }
426        if !wrote_fixed {
427            writeln!(f, "    <none>")?;
428        }
429        Ok(())
430    }
431}