1use std::{collections::HashMap, fmt::Display, ops::Index};
2
3use serde::{Deserialize, Serialize};
4
5use super::Parameter;
6use crate::{
7 resources::{ParameterID, Parameters},
8 LadduError, LadduResult,
9};
10
11#[derive(Default, Debug, Clone)]
13pub struct ParameterMap {
14 parameters: Vec<Parameter>,
15 name_to_index: HashMap<String, usize>,
16}
17
18#[derive(Serialize, Deserialize)]
19struct ParameterMapSerde {
20 parameters: Vec<Parameter>,
21}
22
23impl Index<usize> for ParameterMap {
24 type Output = Parameter;
25
26 fn index(&self, index: usize) -> &Self::Output {
27 &self.parameters[index]
28 }
29}
30
31impl Index<&str> for ParameterMap {
32 type Output = Parameter;
33
34 fn index(&self, key: &str) -> &Self::Output {
35 self.get(key)
36 .unwrap_or_else(|| panic!("parameter '{key}' not found"))
37 }
38}
39
40impl IntoIterator for ParameterMap {
41 type Item = Parameter;
42 type IntoIter = std::vec::IntoIter<Parameter>;
43
44 fn into_iter(self) -> Self::IntoIter {
45 self.parameters.into_iter()
46 }
47}
48
49impl<'a> IntoIterator for &'a ParameterMap {
50 type Item = &'a Parameter;
51 type IntoIter = std::slice::Iter<'a, Parameter>;
52
53 fn into_iter(self) -> Self::IntoIter {
54 self.parameters.iter()
55 }
56}
57
58impl Serialize for ParameterMap {
59 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
60 where
61 S: serde::Serializer,
62 {
63 ParameterMapSerde {
64 parameters: self.parameters.clone(),
65 }
66 .serialize(serializer)
67 }
68}
69
70impl<'de> Deserialize<'de> for ParameterMap {
71 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
72 where
73 D: serde::Deserializer<'de>,
74 {
75 let serde = ParameterMapSerde::deserialize(deserializer)?;
76 Ok(Self::from_parameters(serde.parameters))
77 }
78}
79
80impl ParameterMap {
81 fn from_parameters(parameters: Vec<Parameter>) -> Self {
82 let name_to_index = parameters
83 .iter()
84 .enumerate()
85 .map(|(index, parameter)| (parameter.name(), index))
86 .collect();
87 Self {
88 parameters,
89 name_to_index,
90 }
91 }
92
93 pub fn register_parameter(&mut self, p: &Parameter) -> LadduResult<ParameterID> {
95 let name = p.name();
96 if name.is_empty() {
97 return Err(LadduError::UnregisteredParameter {
98 name: "<unnamed>".to_string(),
99 reason: "Parameter was not initialized with a name".to_string(),
100 });
101 }
102
103 if let Some((index, existing)) = self.get_indexed(&name) {
104 match (existing.fixed(), p.fixed()) {
105 (Some(a), Some(b)) if (a - b).abs() > f64::EPSILON => {
106 return Err(LadduError::ParameterConflict {
107 name,
108 reason: "conflicting fixed values for the same parameter name".to_string(),
109 });
110 }
111 (Some(_), None) => {
112 return Err(LadduError::ParameterConflict {
113 name,
114 reason: "attempted to use a fixed parameter name as free".to_string(),
115 });
116 }
117 (None, Some(_)) => {
118 return Err(LadduError::ParameterConflict {
119 name,
120 reason: "attempted to use a free parameter name as fixed".to_string(),
121 });
122 }
123 (Some(_), Some(_)) | (None, None) => return Ok(self.parameter_id(index)),
124 }
125 }
126
127 let index = self.parameters.len();
128 self.insert(p.clone());
129 Ok(self.parameter_id(index))
130 }
131
132 pub fn free_parameter_indices(&self) -> Vec<usize> {
134 (0..self.free().len()).collect()
135 }
136
137 pub fn rename_parameter(&mut self, old: &str, new: &str) -> LadduResult<()> {
139 if old == new {
140 return Ok(());
141 }
142 if self.contains_key(new) {
143 return Err(LadduError::ParameterConflict {
144 name: new.to_string(),
145 reason: "rename target already exists".to_string(),
146 });
147 }
148 if let Some(index) = self.index(old) {
149 let parameter = self.parameters[index].clone();
150 parameter.set_name(new);
151 self.name_to_index.remove(old);
152 self.name_to_index.insert(new.to_string(), index);
153 } else {
154 self.assert_parameter_exists(old)?;
155 }
156 Ok(())
157 }
158
159 pub fn rename_parameters(&mut self, mapping: &HashMap<String, String>) -> LadduResult<()> {
161 for (old, new) in mapping {
162 self.rename_parameter(old, new)?;
163 }
164 Ok(())
165 }
166
167 pub fn fix_parameter(&self, name: &str, value: f64) -> LadduResult<()> {
169 self.assert_parameter_exists(name)?;
170 if let Some(parameter) = self.get(name) {
171 parameter.set_fixed_value(Some(value));
172 }
173 Ok(())
174 }
175
176 pub fn free_parameter(&self, name: &str) -> LadduResult<()> {
178 self.assert_parameter_exists(name)?;
179 if let Some(parameter) = self.get(name) {
180 parameter.set_fixed_value(None);
181 }
182 Ok(())
183 }
184
185 pub fn contains_key(&self, name: &str) -> bool {
187 self.name_to_index.contains_key(name)
188 }
189
190 pub fn index(&self, name: &str) -> Option<usize> {
192 self.name_to_index.get(name).copied()
193 }
194
195 pub fn insert(&mut self, parameter: Parameter) -> Option<Parameter> {
197 let name = parameter.name();
198 if let Some(index) = self.index(&name) {
199 Some(std::mem::replace(&mut self.parameters[index], parameter))
200 } else {
201 let index = self.parameters.len();
202 self.parameters.push(parameter);
203 self.name_to_index.insert(name, index);
204 None
205 }
206 }
207
208 pub fn len(&self) -> usize {
210 self.parameters.len()
211 }
212
213 pub fn is_empty(&self) -> bool {
215 self.parameters.is_empty()
216 }
217
218 pub fn iter(&self) -> std::slice::Iter<'_, Parameter> {
220 self.parameters.iter()
221 }
222
223 pub fn get(&self, key: &str) -> Option<&Parameter> {
225 self.index(key).map(|index| &self.parameters[index])
226 }
227
228 pub fn get_indexed(&self, key: &str) -> Option<(usize, &Parameter)> {
230 self.index(key)
231 .map(|index| (index, &self.parameters[index]))
232 }
233
234 pub fn names(&self) -> Vec<String> {
236 self.parameters.iter().map(Parameter::name).collect()
237 }
238
239 pub fn filter(&self, predicate: impl Fn(&Parameter) -> bool) -> Self {
241 Self::from_parameters(
242 self.parameters
243 .iter()
244 .filter(|parameter| predicate(parameter))
245 .cloned()
246 .collect(),
247 )
248 }
249
250 pub fn free(&self) -> Self {
252 self.filter(|p| p.is_free())
253 }
254
255 pub fn fixed(&self) -> Self {
257 self.filter(|p| p.is_fixed())
258 }
259
260 pub fn initialized(&self) -> Self {
262 self.filter(|p| p.initial().is_some())
263 }
264
265 pub fn uninitialized(&self) -> Self {
267 self.filter(|p| p.initial().is_none())
268 }
269
270 pub fn assemble(&self, free_values: &[f64]) -> LadduResult<Parameters> {
274 let expected_free = self.free().len();
275 let n_fixed = self.fixed().len();
276 let mut values = vec![0.0; expected_free + n_fixed];
277 let mut storage_to_assembled = vec![0; self.len()];
278 let mut free_iter = free_values.iter();
279 let mut free_index = 0;
280 let mut fixed_index = expected_free;
281 for (storage_index, parameter) in self.parameters.iter().enumerate() {
282 if let Some(value) = parameter.fixed() {
283 values[fixed_index] = value;
284 storage_to_assembled[storage_index] = fixed_index;
285 fixed_index += 1;
286 } else if let Some(value) = free_iter.next() {
287 values[free_index] = *value;
288 storage_to_assembled[storage_index] = free_index;
289 free_index += 1;
290 } else {
291 return Err(LadduError::LengthMismatch {
292 context: "parameter values".to_string(),
293 expected: expected_free,
294 actual: free_values.len(),
295 });
296 }
297 }
298 if free_iter.next().is_some() {
299 return Err(LadduError::LengthMismatch {
300 context: "parameter values".to_string(),
301 expected: expected_free,
302 actual: free_values.len(),
303 });
304 }
305 Ok(Parameters::new(values, expected_free, storage_to_assembled))
306 }
307
308 pub fn merge(&self, other: &Self) -> (Self, Vec<usize>, Vec<usize>) {
313 let mut merged = self.clone();
314 let mut right_map = Vec::with_capacity(other.len());
315 for parameter in other {
316 let idx = merged.ensure_parameter(parameter.clone());
317 right_map.push(idx);
318 }
319 let left_map: Vec<usize> = (0..self.len())
320 .map(|index| merged.assembled_index(index))
321 .collect();
322 let right_map = right_map
323 .into_iter()
324 .map(|index| merged.assembled_index(index))
325 .collect();
326 (merged, left_map, right_map)
327 }
328
329 pub fn extend_from(&self, other: &Self) -> (Self, Vec<usize>) {
334 let mut merged = self.clone();
335 let mut indices = Vec::with_capacity(other.len());
336 for parameter in other {
337 let idx = merged.ensure_parameter(parameter.clone());
338 indices.push(idx);
339 }
340 let indices = indices
341 .into_iter()
342 .map(|index| merged.assembled_index(index))
343 .collect();
344 (merged, indices)
345 }
346
347 fn ensure_parameter(&mut self, parameter: Parameter) -> usize {
348 let name = parameter.name();
349 if let Some(idx) = self.index(&name) {
350 return idx;
351 }
352 let idx = self.len();
353 self.insert(parameter);
354 idx
355 }
356
357 fn assembled_index(&self, storage_index: usize) -> usize {
358 let n_free = self
359 .parameters
360 .iter()
361 .filter(|parameter| parameter.is_free())
362 .count();
363 let preceding_in_group = self.parameters[..storage_index]
364 .iter()
365 .filter(|parameter| self.parameters[storage_index].is_free() == parameter.is_free())
366 .count();
367 if self.parameters[storage_index].is_free() {
368 preceding_in_group
369 } else {
370 n_free + preceding_in_group
371 }
372 }
373
374 fn parameter_id(&self, storage_index: usize) -> ParameterID {
375 if self.parameters[storage_index].is_fixed() {
376 ParameterID::Constant(storage_index)
377 } else {
378 ParameterID::Parameter(storage_index)
379 }
380 }
381
382 fn assert_parameter_exists(&self, name: &str) -> LadduResult<()> {
383 if self.contains_key(name) {
384 Ok(())
385 } else {
386 Err(LadduError::UnregisteredParameter {
387 name: name.to_string(),
388 reason: "parameter not found".to_string(),
389 })
390 }
391 }
392}
393
394impl Display for ParameterMap {
395 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396 writeln!(f, "ParameterMap:")?;
397 if self.parameters.is_empty() {
398 writeln!(f, " <empty>")?;
399 return Ok(());
400 }
401 writeln!(f, " free:")?;
402 let mut wrote_free = false;
403 for parameter in self
404 .parameters
405 .iter()
406 .filter(|parameter| parameter.is_free())
407 {
408 wrote_free = true;
409 writeln!(f, " {}", parameter.name())?;
410 }
411 if !wrote_free {
412 writeln!(f, " <none>")?;
413 }
414 writeln!(f, " fixed:")?;
415 let mut wrote_fixed = false;
416 for parameter in self
417 .parameters
418 .iter()
419 .filter(|parameter| parameter.is_fixed())
420 {
421 wrote_fixed = true;
422 if let Some(value) = parameter.fixed() {
423 writeln!(f, " {} = {}", parameter.name(), value)?;
424 }
425 }
426 if !wrote_fixed {
427 writeln!(f, " <none>")?;
428 }
429 Ok(())
430 }
431}