Skip to main content

demes/
specification.rs

1//! Implement the demes technical
2//! [specification](https://popsim-consortium.github.io/demes-spec-docs/main/specification.html)
3//! in terms of rust structs.
4
5use crate::error::OpaqueYamlError;
6use crate::time::*;
7use crate::CloningRate;
8use crate::DemeSize;
9use crate::DemesError;
10use crate::InputCloningRate;
11use crate::InputDemeSize;
12use crate::InputMigrationRate;
13use crate::InputProportion;
14use crate::InputSelfingRate;
15use crate::MigrationRate;
16use crate::Proportion;
17use crate::SelfingRate;
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20use std::convert::TryFrom;
21use std::fmt::Display;
22use std::io::Read;
23
// Look up a deme by name: map the name to an index via `$deme_map`, then
// fetch that index from `$demes`. Evaluates to `None` when either step fails.
macro_rules! get_deme {
    ($name: expr, $deme_map: expr, $demes: expr) => {
        $deme_map
            .get($name)
            .and_then(|deme_index| $demes.get(*deme_index))
    };
}
32
33// Divide all times by the scaling factor
34fn rescale_input_time(input: Option<InputTime>, scaling_factor: f64) -> Option<InputTime> {
35    input.map(|time| (f64::from(time) / scaling_factor).into())
36}
37
38// Divide all sizes by the scaling factor
39fn rescale_input_size(input: Option<InputDemeSize>, scaling_factor: f64) -> Option<InputDemeSize> {
40    input.map(|size| (f64::from(size) / scaling_factor).into())
41}
42
43// Multiply all migration rates by the scaling factor
44fn rescale_input_migration_rate(
45    input: Option<InputMigrationRate>,
46    scaling_factor: f64,
47) -> Option<InputMigrationRate> {
48    input.map(|rate| (f64::from(rate) * scaling_factor).into())
49}
50
51fn size_at_details<F: Into<f64>>(
52    time: F,
53    epoch_start_time: f64,
54    epoch_end_time: f64,
55    epoch_start_size: f64,
56    epoch_end_size: f64,
57    size_function: SizeFunction,
58) -> Result<Option<f64>, DemesError> {
59    let time: f64 = time.into();
60    Time::try_from(time)
61        .map_err(|_| DemesError::EpochError(format!("invalid time value: {time:?}")))?;
62
63    if time == f64::INFINITY && epoch_start_time == f64::INFINITY {
64        return Ok(Some(epoch_start_size));
65    };
66    if time < epoch_end_time || time >= epoch_start_time {
67        return Ok(None);
68    }
69    let time_span = epoch_start_time - epoch_end_time;
70    let dt = epoch_start_time - time;
71    let size = match size_function {
72        SizeFunction::Constant => return Ok(Some(epoch_end_size)),
73        SizeFunction::Linear => {
74            epoch_start_size + dt * (epoch_end_size - epoch_start_size) / time_span
75        }
76        SizeFunction::Exponential => {
77            let r = (epoch_end_size / epoch_start_size).ln() / time_span;
78            epoch_start_size * (r * dt).exp()
79        }
80    };
81    Ok(Some(size))
82}
83
84/// Specify how deme sizes change during an [`Epoch`](crate::Epoch).
85///
86/// # Examples
87///
88/// ```
89/// let yaml = "
90/// time_units: years
91/// generation_time: 25
92/// description:
93///   A deme of 50 individuals that grew to 100 individuals
94///   in the last 100 years.
95///   Default behavior is that size changes are exponential.
96/// demes:
97///  - name: deme
98///    epochs:
99///     - start_size: 50
100///       end_time: 100
101///     - start_size: 50
102///       end_size: 100
103/// ";
104/// let graph = demes::loads(yaml).unwrap();
105/// let deme = graph.get_deme_from_name("deme").unwrap();
106/// assert_eq!(deme.num_epochs(), 2);
107/// let last_epoch = deme.get_epoch(1).unwrap();
108/// assert!(matches!(last_epoch.size_function(),
109///                  demes::SizeFunction::Exponential));
110/// let first_epoch = deme.get_epoch(0).unwrap();
111/// assert!(matches!(first_epoch.size_function(),
112///                  demes::SizeFunction::Constant));
113/// ```
114///
115/// Let's change the function to linear for the second
116/// epoch:
117///
118/// ```
119/// let yaml = "
120/// time_units: years
121/// generation_time: 25
122/// description:
123///   A deme of 50 individuals that grew to 100 individuals
124///   in the last 100 years.
125/// demes:
126///  - name: deme
127///    epochs:
128///     - start_size: 50
129///       end_time: 100
130///     - start_size: 50
131///       end_size: 100
132///       size_function: linear
133/// ";
134/// let graph = demes::loads(yaml).unwrap();
135/// let deme = graph.get_deme_from_name("deme").unwrap();
136/// let last_epoch = deme.get_epoch(1).unwrap();
137/// assert!(matches!(last_epoch.size_function(),
138///                  demes::SizeFunction::Linear));
139/// ```
140#[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)]
141#[serde(rename_all = "lowercase")]
142#[non_exhaustive]
143pub enum SizeFunction {
144    #[allow(missing_docs)]
145    Constant,
146    #[allow(missing_docs)]
147    Exponential,
148    #[allow(missing_docs)]
149    Linear,
150}
151
/// Owned copy of a graph's raw input text, tagged with its serialization
/// format. It appears to be the owned counterpart of [`InputFormat`].
// NOTE(review): `Json` and `Toml` carry `allow(dead_code)`, which suggests they
// are not constructed in every build configuration (e.g. feature-gated
// parsers) — confirm.
#[derive(Clone, Debug)]
enum InputFormatInternal {
    Yaml(String),
    #[allow(dead_code)]
    Json(String),
    #[allow(dead_code)]
    Toml(String),
}
160
/// The string input format for a graph.
///
/// Each variant borrows the raw text of a graph and records which
/// serialization format that text is written in.
#[derive(Debug)]
#[non_exhaustive]
pub enum InputFormat<'graph> {
    /// Input is YAML
    Yaml(&'graph str),
    /// Input is JSON
    Json(&'graph str),
    /// Input is TOML
    Toml(&'graph str),
}

impl InputFormat<'_> {
    /// Get the input data as [str], whatever its format.
    pub fn to_str(&self) -> &str {
        match *self {
            Self::Yaml(text) | Self::Json(text) | Self::Toml(text) => text,
        }
    }
}
183
184impl Display for SizeFunction {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        let value = match self {
187            SizeFunction::Constant => "constant",
188            SizeFunction::Linear => "linear",
189            SizeFunction::Exponential => "exponential",
190        };
191        write!(f, "{value}")
192    }
193}
194
/// Identifies a deme either by its index or by its name.
#[derive(Copy, Clone, Debug)]
pub enum DemeId<'name> {
    /// The index of a deme
    Index(usize),
    /// The name of a deme
    Name(&'name str),
}

impl From<usize> for DemeId<'_> {
    /// Wrap an index as [`DemeId::Index`].
    fn from(index: usize) -> Self {
        DemeId::Index(index)
    }
}

impl<'name> From<&'name str> for DemeId<'name> {
    /// Wrap a borrowed name as [`DemeId::Name`].
    fn from(name: &'name str) -> Self {
        DemeId::Name(name)
    }
}
216
/// An unresolved migration epoch.
///
/// All input migrations are resolved to [`AsymmetricMigration`](crate::AsymmetricMigration)
/// instances.
///
/// A migration is either symmetric (set `demes`) or asymmetric (set both
/// `source` and `dest`); in both forms `rate` is required. Specifying `demes`
/// together with `source` or `dest` is rejected during resolution.
///
/// # Examples
///
/// ## [`GraphBuilder`](crate::GraphBuilder)
///
/// This type supports member field initialization using defaults.
/// This form of initialization is used in:
///
/// * [`GraphDefaults`](crate::GraphDefaults)
///
/// ```
/// let _ = demes::UnresolvedMigration{source: Some("A".to_string()),
///                                    dest: Some("B".to_string()),
///                                    rate: Some(0.2.into()),
///                                    ..Default::default()
///                                    };
/// ```
#[derive(Clone, Default, Debug, Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct UnresolvedMigration {
    /// The demes involved in symmetric migration epochs
    pub demes: Option<Vec<String>>,
    /// The source deme of an asymmetric migration epoch
    pub source: Option<String>,
    /// The destination deme of an asymmetric migration epoch
    pub dest: Option<String>,
    /// The start time of a migration epoch
    pub start_time: Option<InputTime>,
    /// The end time of a migration epoch
    pub end_time: Option<InputTime>,
    /// The rate during a migration epoch
    pub rate: Option<InputMigrationRate>,
}
254
255impl UnresolvedMigration {
256    fn validate(&self) -> Result<(), DemesError> {
257        if let Some(value) = self.start_time {
258            Time::try_from(value).map_err(|_| {
259                DemesError::MigrationError(format!("invalid start_time: {value:?}"))
260            })?;
261        }
262        if let Some(value) = self.end_time {
263            Time::try_from(value)
264                .map_err(|_| DemesError::MigrationError(format!("invalid end_time: {value:?}")))?;
265        }
266        Ok(())
267    }
268
269    fn valid_asymmetric_or_err(&self) -> Result<(), DemesError> {
270        let source = self
271            .source
272            .as_ref()
273            .ok_or_else(|| DemesError::MigrationError("source is none".to_string()))?;
274
275        let dest = self
276            .dest
277            .as_ref()
278            .ok_or_else(|| DemesError::MigrationError("dest is none".to_string()))?;
279
280        self.rate.ok_or_else(|| {
281            DemesError::MigrationError(format!(
282                "rate frmm source: {source} to dest: {dest} is None",
283            ))
284        })?;
285
286        Ok(())
287    }
288
289    fn valid_symmetric_or_err(&self) -> Result<(), DemesError> {
290        let demes = self
291            .demes
292            .as_ref()
293            .ok_or_else(|| DemesError::MigrationError("demes is None".to_string()))?;
294        self.rate.ok_or_else(|| {
295            DemesError::MigrationError(format!("migration rate among {demes:?} is None",))
296        })?;
297        Ok(())
298    }
299
300    fn resolved_rate_or_err(&self) -> Result<InputMigrationRate, DemesError> {
301        self.rate
302            .ok_or_else(|| DemesError::MigrationError("migration rate not resolved".to_string()))
303    }
304
305    fn resolved_dest_or_err(&self) -> Result<String, DemesError> {
306        match &self.dest {
307            Some(dest) => Ok(dest.to_string()),
308            None => Err(DemesError::MigrationError(
309                "migration dest not resolved".to_string(),
310            )),
311        }
312    }
313
314    fn resolved_source_or_err(&self) -> Result<String, DemesError> {
315        match &self.source {
316            Some(source) => Ok(source.to_string()),
317            None => Err(DemesError::MigrationError(
318                "migration source not resolved".to_string(),
319            )),
320        }
321    }
322
323    /// Set the source deme
324    ///
325    /// See ['GraphBuilder'].
326    ///
327    /// # Examples
328    ///
329    /// ```
330    /// let _ = demes::UnresolvedMigration::default().set_source("A");
331    /// ```
332    pub fn set_source<A>(self, source: A) -> Self
333    where
334        A: AsRef<str>,
335    {
336        Self {
337            source: Some(source.as_ref().to_owned()),
338            ..self
339        }
340    }
341
342    /// Set the destination deme
343    ///
344    /// See ['GraphBuilder'].
345    ///
346    /// # Examples
347    ///
348    /// ```
349    /// let _ = demes::UnresolvedMigration::default().set_dest("A");
350    /// ```
351    pub fn set_dest<A>(self, dest: A) -> Self
352    where
353        A: AsRef<str>,
354    {
355        Self {
356            dest: Some(dest.as_ref().to_owned()),
357            ..self
358        }
359    }
360
361    /// Set the demes
362    ///
363    /// See ['GraphBuilder'].
364    ///
365    /// # Examples
366    ///
367    /// ```
368    /// let _ = demes::UnresolvedMigration::default().set_demes(["A", "B"].as_slice());
369    /// ```
370    pub fn set_demes<I, A>(self, d: I) -> Self
371    where
372        I: IntoIterator<Item = A>,
373        A: AsRef<str>,
374    {
375        Self {
376            demes: Some(
377                d.into_iter()
378                    .map(|a| a.as_ref().to_owned())
379                    .collect::<Vec<_>>(),
380            ),
381            ..self
382        }
383    }
384
385    /// Set the start time
386    ///
387    /// See ['GraphBuilder'].
388    ///
389    /// # Examples
390    ///
391    /// ```
392    /// let _ = demes::UnresolvedMigration::default().set_start_time(1.0);
393    /// ```
394    pub fn set_start_time<T>(self, time: T) -> Self
395    where
396        T: Into<InputTime>,
397    {
398        Self {
399            start_time: Some(time.into()),
400            ..self
401        }
402    }
403
404    /// Set the end time
405    ///
406    /// See ['GraphBuilder'].
407    ///
408    /// # Examples
409    ///
410    /// ```
411    /// let _ = demes::UnresolvedMigration::default().set_end_time(10.);
412    /// ```
413    pub fn set_end_time<T>(self, time: T) -> Self
414    where
415        T: Into<InputTime>,
416    {
417        Self {
418            end_time: Some(time.into()),
419            ..self
420        }
421    }
422
423    /// Set the symmetric migration rate among all `demes`.
424    ///
425    /// See ['GraphBuilder'].
426    ///
427    /// # Examples
428    ///
429    /// ```
430    /// let _ = demes::UnresolvedMigration::default().set_rate(0.3333);
431    /// ```
432    pub fn set_rate<R>(self, rate: R) -> Self
433    where
434        R: Into<InputMigrationRate>,
435    {
436        Self {
437            rate: Some(rate.into()),
438            ..self
439        }
440    }
441
442    fn rescale(&mut self, scaling_factor: f64) -> Result<(), DemesError> {
443        self.start_time = rescale_input_time(self.start_time, scaling_factor);
444        self.end_time = rescale_input_time(self.end_time, scaling_factor);
445        self.rate = rescale_input_migration_rate(self.rate, scaling_factor);
446        Ok(())
447    }
448}
449
/// An asymmetric migration epoch.
///
/// All input migrations are resolved to asymmetric migration instances.
/// Fields are private; use the accessor methods to read them.
#[derive(Clone, Debug, Serialize, Eq, PartialEq)]
pub struct AsymmetricMigration {
    // Name of the deme migrants come from.
    source: String,
    // Name of the deme migrants go to.
    dest: String,
    // Resolved migration rate from `source` to `dest`.
    rate: MigrationRate,
    // Resolved start time of the migration epoch.
    start_time: Time,
    // Resolved end time of the migration epoch.
    end_time: Time,
}
461
462impl AsymmetricMigration {
463    fn resolved_time_to_generations(
464        &mut self,
465        generation_time: GenerationTime,
466        rounding: fn(Time, GenerationTime) -> Time,
467    ) -> Result<(), DemesError> {
468        self.start_time = convert_resolved_time_to_generations(
469            generation_time,
470            rounding,
471            DemesError::MigrationError,
472            "start_time is not resolved",
473            Some(self.start_time),
474        )?;
475        self.end_time = convert_resolved_time_to_generations(
476            generation_time,
477            rounding,
478            DemesError::MigrationError,
479            "end_time is not resolved",
480            Some(self.end_time),
481        )?;
482
483        if self.end_time >= self.start_time {
484            Err(DemesError::MigrationError(
485                "conversion of migration times to generations resulted in a zero-length epoch"
486                    .to_string(),
487            ))
488        } else {
489            Ok(())
490        }
491    }
492
493    /// Get name of the source deme
494    pub fn source(&self) -> &str {
495        &self.source
496    }
497
498    /// Get name of the destination deme
499    pub fn dest(&self) -> &str {
500        &self.dest
501    }
502
503    /// Get the resolved migration rate
504    pub fn rate(&self) -> MigrationRate {
505        self.rate
506    }
507
508    /// Resolved start [`Time`](crate::Time) of the migration epoch
509    pub fn start_time(&self) -> Time {
510        self.start_time
511    }
512
513    /// Resolved end [`Time`](crate::Time) of the migration epoch
514    pub fn end_time(&self) -> Time {
515        self.end_time
516    }
517
518    /// Resolved time interval of the migration epoch
519    pub fn time_interval(&self) -> TimeInterval {
520        TimeInterval::new(self.start_time(), self.end_time())
521    }
522}
523
// An input migration whose shape has been checked by
// `TryFrom<UnresolvedMigration>`:
// * `Asymmetric` wraps a migration with `source`, `dest`, and `rate` set
//   (and `demes` cleared);
// * `Symmetric` wraps a migration with `demes` and `rate` set and no
//   `source`/`dest`.
#[derive(Clone, Debug)]
enum Migration {
    Asymmetric(UnresolvedMigration),
    Symmetric(UnresolvedMigration),
}
529
530impl TryFrom<UnresolvedMigration> for Migration {
531    type Error = DemesError;
532
533    fn try_from(value: UnresolvedMigration) -> Result<Self, Self::Error> {
534        if value.demes.is_none() {
535            if value.source.is_none() || value.dest.is_none() {
536                Err(DemesError::MigrationError(
537                    "a migration must specify either demes or source and dest".to_string(),
538                ))
539            } else {
540                value.valid_asymmetric_or_err()?;
541                Ok(Migration::Asymmetric(UnresolvedMigration {
542                    demes: None,
543                    source: Some(value.resolved_source_or_err()?),
544                    dest: Some(value.resolved_dest_or_err()?),
545                    rate: Some(value.resolved_rate_or_err()?),
546                    start_time: value.start_time,
547                    end_time: value.end_time,
548                }))
549            }
550        } else if value.source.is_some() || value.dest.is_some() {
551            Err(DemesError::MigrationError(
552                "a migration must specify either demes or source and dest, but not both"
553                    .to_string(),
554            ))
555        } else {
556            value.valid_symmetric_or_err()?;
557            Ok(Migration::Symmetric(value))
558        }
559    }
560}
561
562impl From<Migration> for UnresolvedMigration {
563    fn from(value: Migration) -> Self {
564        match value {
565            Migration::Symmetric(s) => s,
566            Migration::Asymmetric(a) => a,
567        }
568    }
569}
570
571impl From<AsymmetricMigration> for UnresolvedMigration {
572    fn from(value: AsymmetricMigration) -> Self {
573        Self {
574            demes: None,
575            source: Some(value.source().to_owned()),
576            dest: Some(value.dest.to_owned()),
577            start_time: Some(value.start_time().into()),
578            end_time: Some(value.end_time().into()),
579            rate: Some(f64::from(value.rate()).into()),
580        }
581    }
582}
583
/// A resolved Pulse event
// NOTE(review): `deny_unknown_fields` only affects deserialization, and this
// type derives `Serialize` only, so the attribute currently has no effect.
#[derive(Clone, Debug, Serialize, Eq, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct Pulse {
    // Names of the source demes.
    sources: Vec<String>,
    // Name of the destination deme.
    dest: String,
    // Resolved time of the pulse.
    time: Time,
    // One proportion per entry in `sources`.
    proportions: Vec<Proportion>,
}
593
/// An unresolved Pulse event.
///
/// All fields are optional on input. Resolution turns this into a [`Pulse`],
/// and fails if any field is still missing.
#[derive(Clone, Default, Debug, Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct UnresolvedPulse {
    /// Names of the source demes.
    #[allow(missing_docs)]
    pub sources: Option<Vec<String>>,
    /// Name of the destination deme.
    #[allow(missing_docs)]
    pub dest: Option<String>,
    /// Time of the pulse.
    #[allow(missing_docs)]
    pub time: Option<InputTime>,
    /// Proportion contributed by each source. Validation requires exactly
    /// one entry per source and a total in `[1e-9, 1 + 1e-9)`.
    #[allow(missing_docs)]
    pub proportions: Option<Vec<InputProportion>>,
}
607
608impl TryFrom<UnresolvedPulse> for Pulse {
609    type Error = DemesError;
610    fn try_from(value: UnresolvedPulse) -> Result<Self, Self::Error> {
611        let input_proportions = value.proportions.ok_or_else(|| {
612            DemesError::PulseError("pulse proportions are unresolved".to_string())
613        })?;
614        let mut proportions = vec![];
615        for p in input_proportions {
616            proportions.push(Proportion::try_from(p)?);
617        }
618        Ok(Self {
619            sources: value.sources.ok_or_else(|| {
620                DemesError::PulseError("pulse sources are unresolved".to_string())
621            })?,
622            dest: value
623                .dest
624                .ok_or_else(|| DemesError::PulseError("pulse dest are unresolved".to_string()))?,
625            time: value
626                .time
627                .ok_or_else(|| DemesError::PulseError("pulse time are unresolved".to_string()))?
628                .try_into()?,
629            proportions,
630        })
631    }
632}
633
634impl From<Pulse> for UnresolvedPulse {
635    fn from(value: Pulse) -> Self {
636        Self {
637            sources: Some(value.sources),
638            dest: Some(value.dest),
639            time: Some(f64::from(value.time).into()),
640            proportions: Some(
641                value
642                    .proportions
643                    .into_iter()
644                    .map(|p| f64::from(p).into())
645                    .collect::<Vec<_>>(),
646            ),
647        }
648    }
649}
650
651impl UnresolvedPulse {
652    fn validate_as_default(&self) -> Result<(), DemesError> {
653        if let Some(value) = self.time {
654            Time::try_from(value)
655                .map_err(|_| DemesError::PulseError(format!("invalid time: {value:?}")))?;
656        }
657
658        if let Some(proportions) = &self.proportions {
659            for v in proportions {
660                if Proportion::try_from(*v).is_err() {
661                    return Err(DemesError::PulseError(format!(
662                        "invalid proportion: {:?}",
663                        *v
664                    )));
665                }
666            }
667        }
668
669        Ok(())
670    }
671
672    fn get_proportions(&self) -> Result<&[InputProportion], DemesError> {
673        Ok(self
674            .proportions
675            .as_ref()
676            .ok_or_else(|| DemesError::PulseError("proportions are None".to_string()))?)
677    }
678
679    fn get_time(&self) -> Result<Time, DemesError> {
680        self.time
681            .ok_or_else(|| DemesError::PulseError("time is None".to_string()))?
682            .try_into()
683    }
684
685    fn get_sources(&self) -> Result<&[String], DemesError> {
686        Ok(self
687            .sources
688            .as_ref()
689            .ok_or_else(|| DemesError::PulseError("sources are None".to_string()))?)
690    }
691
692    fn get_dest(&self) -> Result<&str, DemesError> {
693        Ok(self
694            .dest
695            .as_ref()
696            .ok_or_else(|| DemesError::PulseError("pulse dest is None".to_string()))?)
697    }
698
699    fn resolve(&mut self, defaults: &GraphDefaults) -> Result<(), DemesError> {
700        defaults.apply_pulse_defaults(self);
701        Ok(())
702    }
703
704    fn validate_proportions(&self, sources: &[String]) -> Result<(), DemesError> {
705        if self.proportions.is_none() {
706            return Err(DemesError::PulseError("proportions is None".to_string()));
707        }
708        let proportions = self.get_proportions()?;
709        for p in proportions.iter() {
710            Proportion::try_from(*p)
711                .map_err(|_| DemesError::PulseError(format!("invalid proportion: {:?}", *p)))?;
712        }
713        if proportions.len() != sources.len() {
714            return Err(DemesError::PulseError(format!("number of sources must equal number of proportions; got {} source and {} proportions", sources.len(), proportions.len())));
715        }
716
717        let sum_proportions = proportions
718            .iter()
719            .fold(0.0, |sum, &proportion| sum + f64::from(proportion));
720
721        if !(1e-9..1.0 + 1e-9).contains(&sum_proportions) {
722            return Err(DemesError::PulseError(format!(
723                "pulse proportions must sum to 0.0 < p < 1.0, got: {sum_proportions}",
724            )));
725        }
726
727        Ok(())
728    }
729
730    fn validate_pulse_time(
731        &self,
732        deme_map: &DemeMap,
733        demes: &[UnresolvedDeme],
734        time: Time,
735        dest: &str,
736        sources: &[String],
737    ) -> Result<(), DemesError> {
738        if !time.is_valid_pulse_time() {
739            return Err(DemesError::PulseError(format!(
740                "invalid pulse time: {}",
741                f64::from(time)
742            )));
743        }
744
745        for source_name in sources {
746            let source = get_deme!(source_name, deme_map, demes).ok_or_else(|| {
747                DemesError::PulseError(format!("invalid pulse source: {source_name}"))
748            })?;
749
750            let ti = source.get_time_interval()?;
751
752            if !ti.contains_exclusive_start_inclusive_end(time) {
753                return Err(DemesError::PulseError(format!(
754                    "pulse at time: {time:?} does not overlap with source: {source_name}",
755                )));
756            }
757        }
758
759        let dest_deme = get_deme!(dest, deme_map, demes)
760            .ok_or_else(|| DemesError::PulseError(format!("invalid pulse dest: {dest}")))?;
761        let ti = dest_deme.get_time_interval()?;
762        if !ti.contains_inclusive_start_exclusive_end(time) {
763            return Err(DemesError::PulseError(format!(
764                "pulse at time: {:?} does not overlap with dest: {}",
765                time, dest_deme.name,
766            )));
767        }
768
769        Ok(())
770    }
771
772    fn validate_destination_deme_existence(
773        &self,
774        dest: &str,
775        deme_map: &DemeMap,
776        demes: &[UnresolvedDeme],
777        time: Time,
778    ) -> Result<(), DemesError> {
779        match get_deme!(dest, deme_map, demes) {
780            Some(d) => {
781                let t = d.get_time_interval()?;
782                if !t.contains_inclusive(time) {
783                    return Err(DemesError::PulseError(format!(
784                        "destination deme {dest} does not exist at time of pulse",
785                    )));
786                }
787                Ok(())
788            }
789            None => Err(DemesError::PulseError(format!(
790                "pulse deme {dest} is invalid",
791            ))),
792        }
793    }
794
795    fn dest_is_not_source(&self, dest: &str, sources: &[String]) -> Result<(), DemesError> {
796        if sources.iter().any(|s| s.as_str() == dest) {
797            Err(DemesError::PulseError(format!(
798                "dest: {dest} is also listed as a source",
799            )))
800        } else {
801            Ok(())
802        }
803    }
804
805    fn sources_are_unique(&self, sources: &[String]) -> Result<(), DemesError> {
806        let mut unique_sources = HashSet::<String>::default();
807        for source in sources {
808            if unique_sources.contains(source) {
809                return Err(DemesError::PulseError(format!(
810                    "source: {source} listed multiple times",
811                )));
812            }
813            unique_sources.insert(source.clone());
814        }
815        Ok(())
816    }
817
818    fn validate(&self, deme_map: &DemeMap, demes: &[UnresolvedDeme]) -> Result<(), DemesError> {
819        let dest = self.get_dest()?;
820        let sources = self.get_sources()?;
821        let time = self.get_time()?;
822        self.validate_proportions(sources)?;
823
824        sources.iter().try_for_each(|source| {
825            self.validate_destination_deme_existence(source, deme_map, demes, time)
826        })?;
827
828        self.validate_destination_deme_existence(dest, deme_map, demes, time)?;
829        self.dest_is_not_source(dest, sources)?;
830        self.sources_are_unique(sources)?;
831        self.validate_pulse_time(deme_map, demes, time, dest, sources)
832    }
833
834    fn rescale(&mut self, scaling_factor: f64) -> Result<(), DemesError> {
835        self.time = rescale_input_time(self.time, scaling_factor);
836        Ok(())
837    }
838}
839
840impl Pulse {
841    fn resolved_time_to_generations(
842        &mut self,
843        generation_time: GenerationTime,
844        rounding: fn(Time, GenerationTime) -> Time,
845    ) -> Result<(), DemesError> {
846        self.time = convert_resolved_time_to_generations(
847            generation_time,
848            rounding,
849            DemesError::PulseError,
850            "pulse time is note resolved",
851            Some(self.time),
852        )?;
853        Ok(())
854    }
855
856    /// Resolved time of the pulse
857    pub fn time(&self) -> Time {
858        self.time
859    }
860
861    /// Resolved pulse source demes as slice
862    pub fn sources(&self) -> &[String] {
863        &self.sources
864    }
865
866    /// Resolved pulse destination deme
867    pub fn dest(&self) -> &str {
868        &self.dest
869    }
870
871    /// Resolved pulse proportions
872    pub fn proportions(&self) -> &[Proportion] {
873        &self.proportions
874    }
875}
876
/// HDM representation of an epoch.
///
/// Every field is optional on input. `end_time` and `start_size` must be
/// filled in during resolution before a resolved [`Epoch`](crate::Epoch) can
/// be built.
///
/// Direct construction of this type is useful in:
/// * [`DemeDefaults`](crate::DemeDefaults)
/// * [`GraphDefaults`](crate::GraphDefaults)
///
/// # Examples
///
/// This type supports field initialization with defaults:
///
/// ```
/// let _ = demes::UnresolvedEpoch{
///              start_size: Some(demes::InputDemeSize::from(1e6)),
///              ..Default::default()
///              };
/// ```
///
/// Type inference improves ergonomics:
///
/// ```
/// let _ = demes::UnresolvedEpoch{
///              start_size: Some(1e6.into()),
///              ..Default::default()
///              };
/// ```
// NOTE(review): `skip_serializing_if` has no effect while this type derives
// only `Deserialize`.
#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct UnresolvedEpoch {
    /// End time of the epoch.
    #[allow(missing_docs)]
    pub end_time: Option<InputTime>,
    // NOTE: the Option is for input. An actual value must be put in via resolution.
    /// Deme size at the start of the epoch.
    #[allow(missing_docs)]
    pub start_size: Option<InputDemeSize>,
    // NOTE: the Option is for input. An actual value must be put in via resolution.
    /// Deme size at the end of the epoch.
    #[allow(missing_docs)]
    pub end_size: Option<InputDemeSize>,
    /// How the size changes between `start_size` and `end_size`.
    #[allow(missing_docs)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size_function: Option<crate::specification::SizeFunction>,
    /// Cloning rate during the epoch.
    #[allow(missing_docs)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cloning_rate: Option<InputCloningRate>,
    /// Selfing rate during the epoch.
    #[allow(missing_docs)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub selfing_rate: Option<InputSelfingRate>,
}
923
924impl From<Epoch> for UnresolvedEpoch {
925    fn from(value: Epoch) -> Self {
926        Self {
927            end_time: Some(value.end_time.into()),
928            start_size: Some(f64::from(value.start_size).into()),
929            end_size: Some(f64::from(value.end_size).into()),
930            size_function: Some(value.size_function),
931            cloning_rate: Some(f64::from(value.cloning_rate).into()),
932            selfing_rate: Some(f64::from(value.selfing_rate).into()),
933        }
934    }
935}
936
937impl UnresolvedEpoch {
938    fn validate_as_default(&self) -> Result<(), DemesError> {
939        if let Some(value) = self.end_time {
940            Time::try_from(value)
941                .map_err(|_| DemesError::EpochError(format!("invalid end_time: {value:?}")))?;
942        }
943        if let Some(value) = self.start_size {
944            DemeSize::try_from(value)
945                .map_err(|_| DemesError::EpochError(format!("invalid start_size: {value:?}")))?;
946        }
947        if let Some(value) = self.end_size {
948            DemeSize::try_from(value)
949                .map_err(|_| DemesError::EpochError(format!("invalid end_size: {value:?}")))?;
950        }
951        if let Some(value) = self.cloning_rate {
952            CloningRate::try_from(value)
953                .map_err(|_| DemesError::EpochError(format!("invalid cloning_rate: {value:?}")))?;
954        }
955        if let Some(value) = self.selfing_rate {
956            SelfingRate::try_from(value)
957                .map_err(|_| DemesError::EpochError(format!("invalid selfing_rate: {value:?}")))?;
958        }
959        Ok(())
960    }
961
962    fn rescale(&mut self, scaling_factor: f64) -> Result<(), DemesError> {
963        self.end_time = rescale_input_time(self.end_time, scaling_factor);
964        self.start_size = rescale_input_size(self.start_size, scaling_factor);
965        self.end_size = rescale_input_size(self.end_size, scaling_factor);
966
967        self.cloning_rate = self.cloning_rate.map_or_else(
968            || None,
969            |c| {
970                if f64::from(c) > 0.0 {
971                    Some(c)
972                } else {
973                    None
974                }
975            },
976        );
977        self.selfing_rate = self.selfing_rate.map_or_else(
978            || None,
979            |s| {
980                if f64::from(s) > 0.0 {
981                    Some(s)
982                } else {
983                    None
984                }
985            },
986        );
987        Ok(())
988    }
989}
990
/// A resolved epoch.
///
/// Every field holds a validated value. Instances are built from an
/// `UnresolvedEpoch` plus an explicitly supplied start time.
#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)]
// NOTE(review): `deny_unknown_fields` only affects deserialization and this
// type does not derive `Deserialize`, so the attribute currently has no effect.
#[serde(deny_unknown_fields)]
pub struct Epoch {
    // Omitted from serialized output. When built via `TryFrom<UnresolvedDeme>`,
    // this equals the previous epoch's end time, or the deme's start time for
    // the first epoch.
    #[serde(skip)]
    start_time: Time,
    // End of the epoch's time interval; validation requires it to be strictly
    // less than `start_time`.
    end_time: Time,
    // Deme size at `start_time`.
    start_size: DemeSize,
    // Deme size at `end_time`.
    end_size: DemeSize,
    // How size changes between `start_size` and `end_size`.
    size_function: SizeFunction,
    cloning_rate: CloningRate,
    selfing_rate: SelfingRate,
}
1004
1005impl Epoch {
1006    fn new_from_unresolved(
1007        start_time: InputTime,
1008        unresolved: UnresolvedEpoch,
1009    ) -> Result<Self, DemesError> {
1010        Ok(Self {
1011            start_time: start_time.try_into().map_err(|_| {
1012                DemesError::EpochError(format!("invalid start_time: {start_time:?}"))
1013            })?,
1014            end_time: unresolved
1015                .end_time
1016                .ok_or_else(|| DemesError::EpochError("end_time unresolved".to_string()))?
1017                .try_into()
1018                .map_err(|_| {
1019                    DemesError::EpochError(format!("invalid end_time: {:?}", unresolved.end_time))
1020                })?,
1021            start_size: unresolved
1022                .start_size
1023                .ok_or_else(|| DemesError::EpochError("end_time unresolved".to_string()))?
1024                .try_into()
1025                .map_err(|_| {
1026                    DemesError::EpochError(format!(
1027                        "invalid start_size: {:?}",
1028                        unresolved.start_size
1029                    ))
1030                })?,
1031            end_size: unresolved
1032                .end_size
1033                .ok_or_else(|| DemesError::EpochError("end_time unresolved".to_string()))?
1034                .try_into()
1035                .map_err(|_| {
1036                    DemesError::EpochError(format!(
1037                        "invalid cloning_rate: {:?}",
1038                        unresolved.cloning_rate
1039                    ))
1040                })?,
1041            size_function: unresolved
1042                .size_function
1043                .ok_or_else(|| DemesError::EpochError("end_time unresolved".to_string()))?,
1044            cloning_rate: unresolved
1045                .cloning_rate
1046                .ok_or_else(|| DemesError::EpochError("end_time unresolved".to_string()))?
1047                .try_into()
1048                .map_err(|_| {
1049                    DemesError::EpochError(format!(
1050                        "invalid cloning_rate: {:?}",
1051                        unresolved.cloning_rate
1052                    ))
1053                })?,
1054
1055            selfing_rate: unresolved
1056                .selfing_rate
1057                .ok_or_else(|| DemesError::EpochError("end_time unresolved".to_string()))?
1058                .try_into()?,
1059        })
1060    }
1061
1062    fn resolved_time_to_generations(
1063        &mut self,
1064        generation_time: GenerationTime,
1065        rounding: fn(Time, GenerationTime) -> Time,
1066    ) -> Result<(), DemesError> {
1067        self.start_time = convert_resolved_time_to_generations(
1068            generation_time,
1069            rounding,
1070            DemesError::EpochError,
1071            "start_time is unresolved",
1072            Some(self.start_time),
1073        )?;
1074        self.end_time = convert_resolved_time_to_generations(
1075            generation_time,
1076            rounding,
1077            DemesError::EpochError,
1078            "end_time is unresolved",
1079            Some(self.end_time),
1080        )?;
1081        Ok(())
1082    }
1083
1084    /// The resolved size function
1085    pub fn size_function(&self) -> SizeFunction {
1086        self.size_function
1087    }
1088
1089    /// The resolved selfing rate
1090    pub fn selfing_rate(&self) -> SelfingRate {
1091        self.selfing_rate
1092    }
1093
1094    /// The resolved cloning rate
1095    pub fn cloning_rate(&self) -> CloningRate {
1096        self.cloning_rate
1097    }
1098
1099    /// The resolved start time
1100    pub fn start_time(&self) -> Time {
1101        self.start_time
1102    }
1103
1104    /// The resolved end time
1105    pub fn end_time(&self) -> Time {
1106        self.end_time
1107    }
1108
1109    /// The resolved start size
1110    pub fn start_size(&self) -> DemeSize {
1111        self.start_size
1112    }
1113
1114    /// The resolved end size
1115    pub fn end_size(&self) -> DemeSize {
1116        self.end_size
1117    }
1118
1119    /// The resolved time interval
1120    pub fn time_interval(&self) -> TimeInterval {
1121        TimeInterval::new(self.start_time(), self.end_time())
1122    }
1123
1124    /// Size of Epoch at a given time
1125    ///
1126    /// # Returns
1127    ///
1128    /// * `Some(size)` if `time` falls within the epoch's time interval.
1129    /// * `None` if `time` is a valid time but outside of the epochs' time
1130    ///   interval.
1131    ///
1132    /// # Errors
1133    ///
1134    /// * If `time` fails to convert into [`Time`].
1135    /// * If conversion from [`f64`] to [`DemeSize`] fails
1136    ///   during calculation of size change function.
1137    pub fn size_at<F: Into<f64>>(&self, time: F) -> Result<Option<DemeSize>, DemesError> {
1138        match size_at_details(
1139            time,
1140            self.start_time.into(),
1141            self.end_time.into(),
1142            self.start_size.into(),
1143            self.end_size().into(),
1144            self.size_function,
1145        )? {
1146            None => Ok(None),
1147            Some(size) => match DemeSize::try_from(size) {
1148                Ok(size) => Ok(Some(size)),
1149                Err(_) => Err(DemesError::EpochError(format!(
1150                    "size calculation led to invalid size: {size}"
1151                ))),
1152            },
1153        }
1154    }
1155}
1156
1157impl UnresolvedEpoch {
1158    fn resolve_size_function(
1159        &mut self,
1160        defaults: &GraphDefaults,
1161        deme_defaults: &DemeDefaults,
1162    ) -> Option<()> {
1163        if self.size_function.is_some() {
1164            return Some(());
1165        }
1166
1167        if self.start_size? == self.end_size? {
1168            self.size_function = Some(SizeFunction::Constant);
1169        } else {
1170            self.size_function =
1171                defaults.apply_epoch_size_function_defaults(self.size_function, deme_defaults);
1172        }
1173
1174        Some(())
1175    }
1176
1177    fn resolve_selfing_rate(&mut self, defaults: &GraphDefaults, deme_defaults: &DemeDefaults) {
1178        if self.selfing_rate.is_none() {
1179            self.selfing_rate = match deme_defaults.epoch.selfing_rate {
1180                Some(selfing_rate) => Some(selfing_rate),
1181                None => match defaults.epoch.selfing_rate {
1182                    Some(selfing_rate) => Some(selfing_rate),
1183                    None => Some(InputSelfingRate::default()),
1184                },
1185            }
1186        }
1187    }
1188
1189    fn resolve_cloning_rate(&mut self, defaults: &GraphDefaults, deme_defaults: &DemeDefaults) {
1190        if self.cloning_rate.is_none() {
1191            self.cloning_rate = match deme_defaults.epoch.cloning_rate {
1192                Some(cloning_rate) => Some(cloning_rate),
1193                None => match defaults.epoch.cloning_rate {
1194                    Some(cloning_rate) => Some(cloning_rate),
1195                    None => Some(InputCloningRate::default()),
1196                },
1197            }
1198        }
1199    }
1200
1201    fn resolve(
1202        &mut self,
1203        defaults: &GraphDefaults,
1204        deme_defaults: &DemeDefaults,
1205    ) -> Result<(), DemesError> {
1206        self.resolve_selfing_rate(defaults, deme_defaults);
1207        self.resolve_cloning_rate(defaults, deme_defaults);
1208        self.resolve_size_function(defaults, deme_defaults)
1209            .ok_or_else(|| DemesError::EpochError("failed to resolve size_function".to_string()))
1210    }
1211
1212    fn validate_end_time(&self, index: usize, deme_name: &str) -> Result<(), DemesError> {
1213        if self.end_time.is_none() {
1214            Err(DemesError::EpochError(format!(
1215                "deme {deme_name}, epoch {index}: end time is None",
1216            )))
1217        } else {
1218            Ok(())
1219        }
1220    }
1221
1222    fn validate_cloning_rate(&self, index: usize, deme_name: &str) -> Result<(), DemesError> {
1223        match self.cloning_rate {
1224            Some(value) => {
1225                if CloningRate::try_from(value).is_err() {
1226                    Err(DemesError::EpochError(format!(
1227                        "deme {deme_name}, epoch {index}: invalid cloning_rate: {value:?}"
1228                    )))
1229                } else {
1230                    Ok(())
1231                }
1232            }
1233            None => Err(DemesError::EpochError(format!(
1234                "deme {deme_name}, epoch {index}:cloning_rate is None",
1235            ))),
1236        }
1237    }
1238
1239    fn validate_selfing_rate(&self, index: usize, deme_name: &str) -> Result<(), DemesError> {
1240        match self.selfing_rate {
1241            Some(value) => SelfingRate::try_from(value)
1242                .map_err(|_| DemesError::EpochError(format!("invalid selfing_rate: {value:?}")))
1243                .map(|_| ()),
1244            None => Err(DemesError::EpochError(format!(
1245                "deme {deme_name}, epoch {index}: selfing_rate is None",
1246            ))),
1247        }
1248    }
1249
1250    fn validate_size_function(
1251        &self,
1252        index: usize,
1253        deme_name: &str,
1254        start_size: InputDemeSize,
1255        end_size: InputDemeSize,
1256    ) -> Result<(), DemesError> {
1257        let size_function = self.size_function.ok_or_else(|| {
1258            DemesError::EpochError(format!(
1259                "deme {deme_name}, epoch {index}:size function is None",
1260            ))
1261        })?;
1262
1263        let is_constant = matches!(size_function, SizeFunction::Constant);
1264
1265        if (is_constant && start_size != end_size) || (!is_constant && start_size == end_size) {
1266            Err(DemesError::EpochError(format!(
1267                "deme {}, index{}: start_size ({:?}) == end_size ({:?}) paired with invalid size_function: {}",
1268                deme_name, index, self.start_size, self.end_size, size_function
1269            )))
1270        } else {
1271            Ok(())
1272        }
1273    }
1274
1275    fn validate(&self, index: usize, deme_name: &str) -> Result<(), DemesError> {
1276        let start_size = self.start_size.ok_or_else(|| {
1277            DemesError::EpochError(format!(
1278                "deme {deme_name}, epoch {index}: start_size is None",
1279            ))
1280        })?;
1281        DemeSize::try_from(start_size)
1282            .map_err(|_| DemesError::EpochError(format!("invalid start_size: {start_size:?}")))?;
1283        let end_size = self.end_size.ok_or_else(|| {
1284            DemesError::EpochError(format!("deme {deme_name}, epoch {index}: end_size is None",))
1285        })?;
1286        DemeSize::try_from(end_size)
1287            .map_err(|_| DemesError::EpochError(format!("invalid end_size: {end_size:?}")))?;
1288        self.validate_end_time(index, deme_name)?;
1289        self.validate_cloning_rate(index, deme_name)?;
1290        self.validate_selfing_rate(index, deme_name)?;
1291        self.validate_size_function(index, deme_name, start_size, end_size)
1292    }
1293}
1294
/// Unresolved form of a deme: deserialized input (or builder output) whose
/// optional fields are filled in during resolution before conversion into
/// a resolved `Deme`.
#[derive(Default, Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub(crate) struct UnresolvedDeme {
    // Deme name (required in the input).
    name: String,
    // Free-text description; empty string when absent from the input.
    #[serde(default = "String::default")]
    description: String,
    #[allow(missing_docs)]
    // NOTE: we use option here because
    // an empty vector in the input means
    // "no ancestors" (i.e., the demes themselves are
    // the most ancient).
    // When there are toplevel deme defaults,
    // we only fill them in when this value is None
    pub ancestors: Option<Vec<String>>,
    // Ancestry proportions, one per ancestor; `None` until resolved.
    #[allow(missing_docs)]
    pub proportions: Option<Vec<InputProportion>>,
    // Deme start time; `None` until resolved.
    #[allow(missing_docs)]
    pub start_time: Option<InputTime>,
    // Deme-level defaults; empty when absent from the input.
    // NOTE(review): `skip_serializing` has no effect because this type does not
    // derive `Serialize`.
    #[serde(default = "DemeDefaults::default")]
    #[serde(skip_serializing)]
    #[allow(missing_docs)]
    pub defaults: DemeDefaults,
    // Epochs in input order; an absent key yields an empty vector.
    #[serde(default = "Vec::<UnresolvedEpoch>::default")]
    epochs: Vec<UnresolvedEpoch>,

    // Never read from input (`#[serde(skip)]`); presumably populated while
    // resolving the graph — TODO confirm.
    #[serde(skip)]
    ancestor_map: DemeMap,
    // Never read from input (`#[serde(skip)]`); presumably the positions of the
    // ancestors in the graph's deme list — TODO confirm.
    #[serde(skip)]
    ancestor_indexes: Vec<usize>,
}
1325
1326impl From<Deme> for UnresolvedDeme {
1327    fn from(value: Deme) -> Self {
1328        let epochs: Vec<UnresolvedEpoch> = value
1329            .epochs
1330            .into_iter()
1331            .map(UnresolvedEpoch::from)
1332            .collect::<Vec<_>>();
1333        Self {
1334            name: value.name,
1335            description: value.description,
1336            ancestor_map: DemeMap::default(),
1337            ancestor_indexes: vec![],
1338            epochs,
1339            start_time: Some(f64::from(value.start_time).into()),
1340            proportions: Some(
1341                value
1342                    .proportions
1343                    .into_iter()
1344                    .map(|p| f64::from(p).into())
1345                    .collect::<Vec<_>>(),
1346            ),
1347            ancestors: Some(value.ancestors),
1348            defaults: DemeDefaults::default(),
1349        }
1350    }
1351}
1352
/// A resolved deme.
///
/// Built from an `UnresolvedDeme` via `TryFrom`. `ancestor_map` and
/// `ancestor_indexes` are omitted from serialized output.
#[derive(Clone, Debug, Serialize)]
pub struct Deme {
    // Deme name.
    name: String,
    // Free-text description (may be empty).
    description: String,
    // Not serialized; exposed through the deprecated `Deme::ancestors`.
    #[serde(skip)]
    ancestor_map: DemeMap,
    // Not serialized; exposed through `Deme::ancestor_indexes`.
    #[serde(skip)]
    ancestor_indexes: Vec<usize>,
    // Resolved epochs in input order (end times strictly decreasing).
    epochs: Vec<Epoch>,
    // Names of ancestor demes; empty if there are none.
    ancestors: Vec<String>,
    // Ancestry proportions, in the same order as `ancestors`.
    proportions: Vec<Proportion>,
    // Resolved start time of the deme.
    start_time: Time,
}
1367
1368impl Deme {
1369    fn resolved_time_to_generations(
1370        &mut self,
1371        generation_time: GenerationTime,
1372        rounding: fn(Time, GenerationTime) -> Time,
1373    ) -> Result<(), DemesError> {
1374        self.start_time = convert_resolved_time_to_generations(
1375            generation_time,
1376            rounding,
1377            DemesError::DemeError,
1378            &format!("start_time unresolved for deme: {}", self.name),
1379            Some(self.start_time),
1380        )?;
1381        self.epochs
1382            .iter_mut()
1383            .try_for_each(|epoch| epoch.resolved_time_to_generations(generation_time, rounding))?;
1384
1385        let valid = |w: (Time, Time)| {
1386            if w.1 >= w.0 {
1387                Err(DemesError::EpochError(
1388                    "conversion to generations resulted in an invalid Epoch".to_string(),
1389                ))
1390            } else {
1391                Ok(())
1392            }
1393        };
1394
1395        self.start_times()
1396            .zip(self.end_times())
1397            .try_for_each(valid)?;
1398
1399        self.end_times()
1400            .take(self.num_epochs() - 1)
1401            .zip(self.end_times().skip(1))
1402            .try_for_each(valid)?;
1403
1404        Ok(())
1405    }
1406
1407    /// Iterator over resolved epoch start times
1408    pub fn start_times(&self) -> impl Iterator<Item = Time> + '_ {
1409        self.epochs.iter().map(|e| e.start_time)
1410    }
1411
1412    /// The resolved start time
1413    pub fn start_time(&self) -> Time {
1414        self.start_time
1415    }
1416
1417    /// Deme name
1418    pub fn name(&self) -> &str {
1419        &self.name
1420    }
1421
1422    /// The resolved time interval
1423    pub fn time_interval(&self) -> TimeInterval {
1424        TimeInterval::new(self.start_time(), self.end_time())
1425    }
1426
1427    /// Number of ancestors
1428    pub fn num_ancestors(&self) -> usize {
1429        self.ancestors.len()
1430    }
1431
1432    /// Iterator over resolved epoch start sizes.
1433    pub fn start_sizes(&self) -> impl Iterator<Item = DemeSize> + '_ {
1434        self.epochs.iter().map(|e| e.start_size)
1435    }
1436
1437    /// Iterator over resolved epoch end sizes
1438    pub fn end_sizes(&self) -> impl Iterator<Item = DemeSize> + '_ {
1439        self.epochs.iter().map(|e| e.end_size)
1440    }
1441
1442    /// Itertor over resolved epoch end times
1443    pub fn end_times(&self) -> impl Iterator<Item = Time> + '_ {
1444        self.epochs.iter().map(|e| e.end_time)
1445    }
1446
1447    /// End time of the deme.
1448    ///
1449    /// Obtained from the value stored in the most
1450    /// recent epoch.
1451    pub fn end_time(&self) -> Time {
1452        assert!(!self.epochs.is_empty());
1453        self.epochs[self.epochs.len() - 1].end_time()
1454    }
1455
1456    /// Hash map of ancestor name to ancestor deme
1457    #[deprecated(note = "Use Deme::ancestor_names and/or Deme::ancestor_indexes instead")]
1458    pub fn ancestors(&self) -> &DemeMap {
1459        &self.ancestor_map
1460    }
1461
1462    /// Resolved start size
1463    pub fn start_size(&self) -> DemeSize {
1464        self.epochs[0].start_size()
1465    }
1466
1467    /// Resolved end size
1468    pub fn end_size(&self) -> DemeSize {
1469        assert!(!self.epochs.is_empty());
1470        self.epochs[self.epochs.len() - 1].end_size()
1471    }
1472
1473    /// Names of ancestor demes.
1474    ///
1475    /// Empty if no ancestors.
1476    pub fn ancestor_names(&self) -> &[String] {
1477        &self.ancestors
1478    }
1479
1480    /// Indexes of ancestor demes.
1481    ///
1482    /// Empty if no ancestors.
1483    pub fn ancestor_indexes(&self) -> &[usize] {
1484        debug_assert_eq!(self.ancestor_indexes.len(), self.ancestors.len());
1485        &self.ancestor_indexes
1486    }
1487
1488    /// Description string
1489    pub fn description(&self) -> &str {
1490        &self.description
1491    }
1492
1493    /// Obtain the number of [`Epoch`](crate::Epoch) instances.
1494    ///
1495    /// # Examples
1496    ///
1497    /// See [`here`](crate::SizeFunction).
1498    pub fn num_epochs(&self) -> usize {
1499        self.epochs.len()
1500    }
1501
1502    /// Resolved epochs
1503    pub fn epochs(&self) -> &[Epoch] {
1504        &self.epochs
1505    }
1506
1507    /// Returns a refernce to an epoch [`Epoch`](crate::Epoch) at index `epoch`.
1508    ///
1509    /// # Examples
1510    ///
1511    /// See [`here`](crate::SizeFunction) for examples.
1512    pub fn get_epoch(&self, epoch: usize) -> Option<&Epoch> {
1513        self.epochs.get(epoch)
1514    }
1515
1516    /// Resolved proportions
1517    pub fn proportions(&self) -> &[Proportion] {
1518        &self.proportions
1519    }
1520
1521    /// Size of Deme at a given time
1522    ///
1523    /// # Errors
1524    ///
1525    /// See [`Epoch::size_at`] for details.
1526    pub fn size_at<F: Into<f64>>(&self, time: F) -> Result<Option<DemeSize>, DemesError> {
1527        let time: f64 = time.into();
1528        Time::try_from(time)
1529            .map_err(|_| DemesError::DemeError(format!("invalid time: {time:?}")))?;
1530
1531        if time == f64::INFINITY && self.start_time == f64::INFINITY {
1532            return Ok(Some(self.epochs()[0].start_size));
1533        };
1534
1535        let epoch = self.epochs().iter().find(|x| {
1536            x.time_interval()
1537                .contains_exclusive_start_inclusive_end(time)
1538        });
1539
1540        match epoch {
1541            None => Ok(None),
1542            Some(e) => Ok(e.size_at(time)?),
1543        }
1544    }
1545}
1546
1547impl TryFrom<UnresolvedDeme> for Deme {
1548    type Error = DemesError;
1549
1550    fn try_from(value: UnresolvedDeme) -> Result<Self, Self::Error> {
1551        let mut epochs = vec![];
1552        let start_time = value.start_time.ok_or_else(|| {
1553            DemesError::DemeError(format!("deme {} start_time is not resolved", value.name))
1554        })?;
1555        let mut epoch_start_time = start_time;
1556        for hdm_epoch in value.epochs.into_iter() {
1557            let end_time = hdm_epoch
1558                .end_time
1559                .ok_or_else(|| DemesError::EpochError("epoch end time unresolved".to_string()))?;
1560            let e = Epoch::new_from_unresolved(epoch_start_time, hdm_epoch)?;
1561            epoch_start_time = end_time;
1562            epochs.push(e);
1563        }
1564        let input_proportions = value.proportions.ok_or_else(|| {
1565            DemesError::PulseError("pulse proportions are unresolved".to_string())
1566        })?;
1567        let mut proportions = vec![];
1568        for p in input_proportions {
1569            proportions.push(Proportion::try_from(p)?);
1570        }
1571        Ok(Self {
1572            description: value.description,
1573            ancestor_map: value.ancestor_map,
1574            ancestor_indexes: value.ancestor_indexes,
1575            epochs,
1576            ancestors: value.ancestors.ok_or_else(|| {
1577                DemesError::DemeError(format!("deme {} ancestors are not resolved", value.name))
1578            })?,
1579            proportions,
1580            start_time: start_time.try_into().map_err(|_| {
1581                DemesError::DemeError(format!("invalid start_time: {start_time:?}"))
1582            })?,
1583            name: value.name,
1584        })
1585    }
1586}
1587
1588impl PartialEq for Deme {
1589    fn eq(&self, other: &Self) -> bool {
1590        self.name == other.name
1591            && self.description == other.description
1592            && self.ancestors == other.ancestors
1593            && self.proportions == other.proportions
1594            && self.start_time == other.start_time
1595            && self.epochs == other.epochs
1596            && self.ancestor_map == other.ancestor_map
1597    }
1598}
1599
/// HDM data for a [`Deme`](crate::Deme)
///
/// Holds the history-related inputs of a deme (ancestry, start time and
/// deme-level defaults), separate from its epochs and name.
#[derive(Default, Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct UnresolvedDemeHistory {
    /// Names of ancestor demes, if specified.
    #[allow(missing_docs)]
    // NOTE: we use option here because
    // an empty vector in the input means
    // "no ancestors" (i.e., the demes themselves are
    // the most ancient).
    // When there are toplevel deme defaults,
    // we only fill them in when this value is None
    pub ancestors: Option<Vec<String>>,
    /// Ancestry proportions, if specified.
    #[allow(missing_docs)]
    pub proportions: Option<Vec<InputProportion>>,
    /// Deme start time, if specified.
    #[allow(missing_docs)]
    pub start_time: Option<InputTime>,
    /// Deme-level defaults; empty when absent from the input.
    // NOTE(review): `skip_serializing` has no effect because this type does not
    // derive `Serialize`.
    #[serde(default = "DemeDefaults::default")]
    #[serde(skip_serializing)]
    #[allow(missing_docs)]
    pub defaults: DemeDefaults,
}
1621
1622impl PartialEq for UnresolvedDeme {
1623    fn eq(&self, other: &Self) -> bool {
1624        self.name == other.name
1625            && self.description == other.description
1626            && self.ancestors == other.ancestors
1627            && self.proportions == other.proportions
1628            && self.start_time == other.start_time
1629            && self.epochs == other.epochs
1630            && self.ancestor_map == other.ancestor_map
1631    }
1632}
1633
1634impl Eq for UnresolvedDeme {}
1635
1636impl UnresolvedDeme {
1637    pub(crate) fn new_via_builder(
1638        name: &str,
1639        epochs: Vec<UnresolvedEpoch>,
1640        history: UnresolvedDemeHistory,
1641        description: Option<&str>,
1642    ) -> Self {
1643        let description = match description {
1644            Some(desc) => desc.to_string(),
1645            None => String::default(),
1646        };
1647        Self {
1648            name: name.to_string(),
1649            epochs,
1650            start_time: history.start_time,
1651            ancestors: history.ancestors,
1652            proportions: history.proportions,
1653            defaults: history.defaults,
1654            description,
1655            ..Default::default()
1656        }
1657    }
1658
    /// Resolve the deme's `start_time` and every epoch's `end_time`, then
    /// validate the resulting time ordering.
    ///
    /// Steps, in order:
    /// 1. A missing `start_time` is taken from the graph-level deme default,
    ///    else `InputTime::default_deme_start_time()`.
    /// 2. A deme with no ancestors must keep the default start time.
    /// 3. With exactly one ancestor and the default start time, the start
    ///    time becomes that ancestor's end time. The result must pass
    ///    `err_if_not_valid_deme_start_time`.
    /// 4. Each ancestor's time interval must contain this deme's start time
    ///    (per `contains_start_time`).
    /// 5. A missing end time on the last epoch comes from deme defaults, then
    ///    graph defaults, then `InputTime::default_epoch_end_time()`.
    /// 6. Present end times are checked with `Time::try_from`. Missing ones
    ///    come from deme defaults, then graph defaults, and may stay `None`.
    /// 7. Every end time must then be present and finite, and must be strictly
    ///    less than the previous boundary, starting from the deme's start time.
    ///
    /// `deme_map` maps deme names to indexes into `demes`, as used by `get_deme!`.
    ///
    /// # Errors
    ///
    /// Returns `DemesError::DemeError` or `DemesError::EpochError` when any
    /// requirement above fails, when `ancestors` is `None`, when an ancestor
    /// name is missing from `deme_map`, or when `epochs` is empty.
    fn resolve_times(
        &mut self,
        deme_map: &DemeMap,
        demes: &[UnresolvedDeme],
        defaults: &GraphDefaults,
    ) -> Result<(), DemesError> {
        // apply top-level default if it exists,
        // otherwise fall back to the default deme start time

        self.start_time = match self.start_time {
            Some(start_time) => Some(start_time),
            None => match defaults.deme.start_time {
                Some(start_time) => Some(start_time),
                None => Some(InputTime::default_deme_start_time()),
            },
        };

        // A deme without ancestors may not have a non-default start time.
        if self
            .ancestors
            .as_ref()
            .ok_or_else(|| DemesError::DemeError("unexpected None for deme ancestors".to_string()))?
            .is_empty()
            && self.start_time_resolved_or(|| {
                DemesError::DemeError(format!("deme {}: start_time unresolved", self.name))
            })? != InputTime::default_deme_start_time()
        {
            return Err(DemesError::DemeError(format!(
                "deme {} has finite start time but no ancestors",
                self.name
            )));
        }

        // With a single ancestor, a default start time is replaced by the
        // ancestor's end time.
        if self.get_num_ancestors()? == 1 {
            let first_ancestor_name = &self.get_ancestor_names()?[0];

            let deme_start_time = match self.start_time {
                Some(start_time) => {
                    if start_time == InputTime::default_deme_start_time() {
                        let first_ancestor_deme = get_deme!(first_ancestor_name, deme_map, demes)
                            .ok_or_else(|| {
                            DemesError::DemeError(
                                "fatal error: ancestor maps to no Deme object".to_string(),
                            )
                        })?;
                        first_ancestor_deme.get_end_time()?.into()
                    } else {
                        start_time
                    }
                }
                // Unreachable in practice: start_time was set to Some above.
                None => InputTime::default_deme_start_time(),
            };

            deme_start_time.err_if_not_valid_deme_start_time()?;
            self.start_time = Some(deme_start_time);
        }

        // Every ancestor must exist at this deme's start time.
        for ancestor in self.get_ancestor_names()?.iter() {
            let a = get_deme!(ancestor, deme_map, demes).ok_or_else(|| {
                DemesError::DemeError(format!(
                    "ancestor {ancestor} not present in global deme map",
                ))
            })?;
            let t = a.get_time_interval()?;
            if !t.contains_start_time(self.get_start_time()?) {
                return Err(DemesError::DemeError(format!(
                    "Ancestor {} does not exist at deme {}'s start_time",
                    ancestor, self.name
                )));
            }
        }

        // last epoch end time defaults to 0,
        // unless defaults are specified
        let last_epoch_ref = self
            .epochs
            .last_mut()
            .ok_or_else(|| DemesError::DemeError("epochs are empty".to_string()))?;
        if last_epoch_ref.end_time.is_none() {
            last_epoch_ref.end_time = match self.defaults.epoch.end_time {
                Some(end_time) => Some(end_time),
                None => match defaults.epoch.end_time {
                    Some(end_time) => Some(end_time),
                    None => Some(InputTime::default_epoch_end_time()),
                },
            }
        }

        // Validate the end times that are present, and fill missing ones from
        // deme defaults, then graph defaults. An end time may still be None
        // afterwards; the loop below reports that as an error.
        for epoch in self.epochs.iter_mut() {
            match epoch.end_time {
                Some(end_time) => {
                    Time::try_from(end_time).map_err(|_| {
                        DemesError::EpochError(format!("invalid end_time: {end_time:?}"))
                    })?;
                }
                None => {
                    epoch.end_time = match self.defaults.epoch.end_time {
                        Some(end_time) => Some(end_time),
                        None => defaults.epoch.end_time,
                    }
                }
            }
        }

        // End times must be present, finite, and strictly decreasing, starting
        // below the deme's start time.
        let mut last_time = f64::from(self.get_start_time()?);
        for (i, epoch) in self.epochs.iter().enumerate() {
            let end_time = f64::from(epoch.end_time.ok_or_else(|| {
                DemesError::EpochError(format!(
                    "deme: {}, epoch: {i} end time must be specified",
                    self.name
                ))
            })?);

            if !end_time.is_finite() {
                return Err(DemesError::EpochError(format!(
                    "invalid end_time: {end_time:?}"
                )));
            }

            if end_time >= last_time {
                return Err(DemesError::EpochError(
                    "Epoch end times must be listed in decreasing order".to_string(),
                ));
            }
            last_time = end_time;
            // Final check that the end time converts into a valid `Time`
            // (this also covers values filled in from defaults above).
            Time::try_from(
                epoch
                    .end_time
                    .ok_or_else(|| DemesError::EpochError("end_time is None".to_string()))?,
            )
            .map_err(|_| {
                DemesError::EpochError(format!("invalid end_time: {:?}", epoch.end_time))
            })?;
        }

        Ok(())
    }
1795
1796    fn resolve_first_epoch_sizes(
1797        &mut self,
1798        defaults: &GraphDefaults,
1799    ) -> Result<Option<InputDemeSize>, DemesError> {
1800        let self_defaults = self.defaults.clone();
1801        let epoch_sizes = {
1802            let temp_epoch = self.epochs.get_mut(0).ok_or_else(|| {
1803                DemesError::DemeError(format!("deme {} has no epochs", self.name))
1804            })?;
1805
1806            temp_epoch.start_size = match temp_epoch.start_size {
1807                Some(start_size) => Some(start_size),
1808                None => self_defaults.epoch.start_size,
1809            };
1810            temp_epoch.end_size = match temp_epoch.end_size {
1811                Some(end_size) => Some(end_size),
1812                None => self_defaults.epoch.end_size,
1813            };
1814
1815            defaults.apply_epoch_size_defaults(temp_epoch);
1816            if temp_epoch.start_size.is_none() && temp_epoch.end_size.is_none() {
1817                return Err(DemesError::EpochError(format!(
1818                    "first epoch of deme {} must define one or both of start_size and end_size",
1819                    self.name
1820                )));
1821            }
1822            if temp_epoch.start_size.is_none() {
1823                temp_epoch.start_size = temp_epoch.end_size;
1824            }
1825            if temp_epoch.end_size.is_none() {
1826                temp_epoch.end_size = temp_epoch.start_size;
1827            }
1828            // temp_epoch.clone()
1829            (temp_epoch.start_size, temp_epoch.end_size)
1830        };
1831
1832        let epoch_start_size = epoch_sizes.0.ok_or_else(|| {
1833            DemesError::EpochError(format!(
1834                "first epoch of {} has unresolved start_size",
1835                self.name
1836            ))
1837        })?;
1838        let epoch_end_size = epoch_sizes.1.ok_or_else(|| {
1839            DemesError::EpochError(format!(
1840                "first epoch of {} has unresolved end_size",
1841                self.name
1842            ))
1843        })?;
1844
1845        let start_time = self.start_time.ok_or_else(|| {
1846            DemesError::EpochError(format!("deme {} start_time is None", self.name))
1847        })?;
1848
1849        if start_time == InputTime::default_deme_start_time() && epoch_sizes.0 != epoch_sizes.1 {
1850            let msg = format!(
1851                    "first epoch of deme {} cannot have varying size and an infinite time interval: start_size = {}, end_size = {}",
1852                    self.name, f64::from(epoch_start_size), f64::from(epoch_end_size),
1853                );
1854            return Err(DemesError::EpochError(msg));
1855        }
1856
1857        Ok(Some(epoch_end_size))
1858    }
1859
1860    fn resolve_sizes(&mut self, defaults: &GraphDefaults) -> Result<(), DemesError> {
1861        let mut last_end_size = self.resolve_first_epoch_sizes(defaults)?;
1862        let local_defaults = self.defaults.clone();
1863        for epoch in self.epochs.iter_mut().skip(1) {
1864            match epoch.start_size {
1865                Some(_) => (),
1866                None => match local_defaults.epoch.start_size {
1867                    Some(start_size) => epoch.start_size = Some(start_size),
1868                    None => match defaults.epoch.start_size {
1869                        Some(start_size) => epoch.start_size = Some(start_size),
1870                        None => epoch.start_size = last_end_size,
1871                    },
1872                },
1873            }
1874            match epoch.end_size {
1875                Some(_) => (),
1876                None => match local_defaults.epoch.end_size {
1877                    Some(end_size) => epoch.end_size = Some(end_size),
1878                    None => match defaults.epoch.end_size {
1879                        Some(end_size) => epoch.end_size = Some(end_size),
1880                        None => epoch.end_size = epoch.start_size,
1881                    },
1882                },
1883            }
1884            last_end_size = epoch.end_size;
1885        }
1886        Ok(())
1887    }
1888
1889    fn resolve_proportions(&mut self) -> Result<(), DemesError> {
1890        let num_ancestors = self.get_num_ancestors()?;
1891
1892        let proportions = self
1893            .proportions
1894            .as_mut()
1895            .ok_or_else(|| DemesError::DemeError("proportions is None".to_string()))?;
1896
1897        if proportions.is_empty() && num_ancestors == 1 {
1898            proportions.push(InputProportion::from(1.0));
1899        }
1900
1901        if num_ancestors != proportions.len() {
1902            return Err(DemesError::DemeError(format!(
1903                "deme {} ancestors and proportions have different lengths",
1904                self.name
1905            )));
1906        }
1907        Ok(())
1908    }
1909
1910    fn check_empty_epochs(&mut self) {
1911        if self.epochs.is_empty() {
1912            self.epochs.push(UnresolvedEpoch::default());
1913        }
1914    }
1915
1916    fn apply_toplevel_defaults(&mut self, defaults: &GraphDefaults) {
1917        if self.ancestors.is_none() {
1918            self.ancestors = match &defaults.deme.ancestors {
1919                Some(ancestors) => Some(ancestors.to_vec()),
1920                None => Some(vec![]),
1921            }
1922        }
1923
1924        if self.proportions.is_none() {
1925            self.proportions = match &defaults.deme.proportions {
1926                Some(proportions) => Some(proportions.to_vec()),
1927                None => Some(vec![]),
1928            }
1929        }
1930    }
1931
1932    fn validate_ancestor_uniqueness(&self, deme_map: &DemeMap) -> Result<(), DemesError> {
1933        match &self.ancestors {
1934            Some(ancestors) => {
1935                let mut ancestor_set = HashSet::<String>::default();
1936                for ancestor in ancestors {
1937                    if ancestor == &self.name {
1938                        return Err(DemesError::DemeError(format!(
1939                            "deme: {} lists itself as an ancestor",
1940                            self.name
1941                        )));
1942                    }
1943                    if !deme_map.contains_key(ancestor) {
1944                        return Err(DemesError::DemeError(format!(
1945                            "deme: {} lists invalid ancestor: {ancestor}",
1946                            self.name
1947                        )));
1948                    }
1949                    if ancestor_set.contains(ancestor) {
1950                        return Err(DemesError::DemeError(format!(
1951                            "deme: {} lists ancestor: {ancestor} multiple times",
1952                            self.name
1953                        )));
1954                    }
1955                    ancestor_set.insert(ancestor.clone());
1956                }
1957                Ok(())
1958            }
1959            None => Ok(()),
1960        }
1961    }
1962
    // Make the internal data match the MDM spec
    //
    // Resolves this deme in place. In order, it:
    //   1. validates the deme-level defaults;
    //   2. fills `ancestors`/`proportions` from the graph-level `defaults`;
    //   3. checks the ancestor list against `deme_map`;
    //   4. guarantees at least one epoch;
    //   5. resolves times, then sizes, then the remaining per-epoch fields;
    //   6. checks the ancestor/proportion lengths;
    //   7. records each ancestor's index from `deme_map`.
    //
    // `deme_map` maps deme names to indexes, presumably into `demes`
    // (TODO confirm against the caller). `demes` is only forwarded to
    // `resolve_times`.
    //
    // Panics if `ancestor_indexes` or `ancestor_map` is already non-empty.
    // Both are only filled at the end of this function, so this presumably
    // guards against resolving the same deme twice.
    fn resolve(
        &mut self,
        deme_map: &DemeMap,
        demes: &[UnresolvedDeme],
        defaults: &GraphDefaults,
    ) -> Result<(), DemesError> {
        self.defaults.validate()?;
        // Absent ancestors/proportions are taken from the graph-level deme defaults.
        self.apply_toplevel_defaults(defaults);
        self.validate_ancestor_uniqueness(deme_map)?;
        // Inserts a default epoch when none were given.
        self.check_empty_epochs();
        assert!(
            self.ancestor_indexes.is_empty(),
            "{:?} has non-empty ancestor index",
            self.name
        );
        assert!(
            self.ancestor_map.is_empty(),
            "{:?} has non-empty ancestor map",
            self.name
        );
        self.resolve_times(deme_map, demes, defaults)?;
        // Sizes are resolved before the per-epoch `resolve` calls below.
        self.resolve_sizes(defaults)?;
        // Presumably cloned so the closure can read the deme defaults while
        // `self.epochs` is mutably borrowed.
        let self_defaults = self.defaults.clone();
        self.epochs
            .iter_mut()
            .try_for_each(|e| e.resolve(defaults, &self_defaults))?;
        self.resolve_proportions()?;

        // Translate ancestor names into indexes via `deme_map`.
        let mut ancestor_map = DemeMap::default();
        let ancestors = self.ancestors.as_ref().ok_or_else(|| {
            DemesError::DemeError(format!("deme {}: ancestors are None", self.name))
        })?;
        for ancestor in ancestors {
            let deme = deme_map.get(ancestor).ok_or_else(|| {
                DemesError::DemeError(format!("invalid ancestor of {}: {ancestor}", self.name))
            })?;
            ancestor_map.insert(ancestor.clone(), *deme);
            self.ancestor_indexes.push(*deme);
        }
        self.ancestor_map = ancestor_map;
        Ok(())
    }
2006
2007    fn validate_start_time(&self) -> Result<(), DemesError> {
2008        match self.start_time {
2009            Some(start_time) => {
2010                Time::try_from(start_time).map_err(|_| {
2011                    DemesError::DemeError(format!("invalid start_time: {start_time:?}"))
2012                })?;
2013                start_time.err_if_not_valid_deme_start_time()
2014            }
2015            None => Err(DemesError::DemeError("start_time is None".to_string())),
2016        }
2017    }
2018
2019    fn start_time_resolved_or<F: FnOnce() -> DemesError>(
2020        &self,
2021        err: F,
2022    ) -> Result<InputTime, DemesError> {
2023        self.start_time.ok_or_else(err)
2024    }
2025
2026    // Names must be valid Python identifiers
2027    // https://docs.python.org/3/reference/lexical_analysis.html#identifiers
2028    pub(crate) fn validate_name(&self) -> Result<(), DemesError> {
2029        let python_identifier = match regex::Regex::new(r"^[^\d\W]\w*$") {
2030            Ok(p) => p,
2031            Err(_) => {
2032                return Err(DemesError::DemeError(
2033                    "failed to biuld python_identifier regex".to_string(),
2034                ))
2035            }
2036        };
2037        if python_identifier.is_match(&self.name) {
2038            Ok(())
2039        } else {
2040            Err(DemesError::DemeError(format!(
2041                "invalid deme name: {}:",
2042                self.name
2043            )))
2044        }
2045    }
2046
2047    fn validate(&self) -> Result<(), DemesError> {
2048        self.validate_name()?;
2049        self.validate_start_time()?;
2050        if self.epochs.is_empty() {
2051            return Err(DemesError::DemeError(format!(
2052                "no epochs for deme {}",
2053                self.name
2054            )));
2055        }
2056
2057        self.epochs
2058            .iter()
2059            .enumerate()
2060            .try_for_each(|(i, e)| e.validate(i, &self.name))?;
2061
2062        let proportions = self
2063            .proportions
2064            .as_ref()
2065            .ok_or_else(|| DemesError::DemeError("proportions is None".to_string()))?;
2066        for p in proportions.iter() {
2067            Proportion::try_from(*p)?;
2068        }
2069
2070        if !proportions.is_empty() {
2071            let sum_proportions: f64 = proportions.iter().map(|p| f64::from(*p)).sum();
2072            // NOTE: this is same default as Python's math.isclose().
2073            if (sum_proportions - 1.0).abs() > 1e-9 {
2074                return Err(DemesError::DemeError(format!(
2075                    "proportions for deme {} should sum to ~1.0, got: {sum_proportions}",
2076                    self.name
2077                )));
2078            }
2079        }
2080
2081        Ok(())
2082    }
2083
2084    fn get_time_interval(&self) -> Result<TimeInterval, DemesError> {
2085        let start_time = self.get_start_time()?;
2086        let end_time = self.get_end_time()?;
2087        Ok(TimeInterval::new(start_time, end_time))
2088    }
2089
2090    fn get_ancestor_names(&self) -> Result<&[String], DemesError> {
2091        match &self.ancestors {
2092            Some(ancestors) => Ok(ancestors),
2093            None => Err(DemesError::DemeError(format!(
2094                "deme {} ancestors are unresolved",
2095                self.name
2096            ))),
2097        }
2098    }
2099
2100    fn get_start_time(&self) -> Result<Time, DemesError> {
2101        match self.start_time.ok_or_else(|| {
2102            DemesError::DemeError(format!("deme {} start_time is unresolved", self.name))
2103        }) {
2104            Ok(value) => value.try_into(),
2105            Err(e) => Err(e),
2106        }
2107    }
2108
2109    fn get_end_time(&self) -> Result<Time, DemesError> {
2110        match self
2111            .epochs
2112            .last()
2113            .as_ref()
2114            .ok_or_else(|| DemesError::DemeError(format!("deme {} has no epochs", self.name)))?
2115            .end_time
2116            .ok_or_else(|| {
2117                DemesError::DemeError(format!(
2118                    "last epoch of deme {} end_time unresolved",
2119                    self.name
2120                ))
2121            }) {
2122            Ok(value) => value.try_into(),
2123            Err(e) => Err(e),
2124        }
2125    }
2126
2127    fn get_num_ancestors(&self) -> Result<usize, DemesError> {
2128        Ok(self
2129            .ancestors
2130            .as_ref()
2131            .ok_or_else(|| {
2132                DemesError::DemeError(format!("deme {} ancestors are unresolved", self.name))
2133            })?
2134            .len())
2135    }
2136
2137    fn rescale(&mut self, scaling_factor: f64) -> Result<(), DemesError> {
2138        self.start_time = rescale_input_time(self.start_time, scaling_factor);
2139        self.epochs
2140            .iter_mut()
2141            .try_for_each(|e| e.rescale(scaling_factor))
2142    }
2143}
2144
/// Maps a deme name to an index, e.g. its position in the graph's list of
/// demes as assigned by `UnresolvedGraph::build_deme_map`.
type DemeMap = HashMap<String, usize>;
2146
2147fn deme_name_exists<F: FnOnce(String) -> DemesError>(
2148    map: &DemeMap,
2149    name: &str,
2150    err: F,
2151) -> Result<(), DemesError> {
2152    if !map.contains_key(name) {
2153        Err(err(format!("deme {name} does not exist")))
2154    } else {
2155        Ok(())
2156    }
2157}
2158
2159/// Top-level defaults
2160#[derive(Default, Debug, Deserialize)]
2161#[serde(deny_unknown_fields)]
2162pub struct GraphDefaults {
2163    #[allow(missing_docs)]
2164    #[serde(default = "UnresolvedEpoch::default")]
2165    #[allow(missing_docs)]
2166    pub epoch: UnresolvedEpoch,
2167    #[serde(default = "UnresolvedMigration::default")]
2168    #[allow(missing_docs)]
2169    pub migration: UnresolvedMigration,
2170    #[serde(default = "UnresolvedPulse::default")]
2171    #[allow(missing_docs)]
2172    pub pulse: UnresolvedPulse,
2173    #[serde(default = "TopLevelDemeDefaults::default")]
2174    #[allow(missing_docs)]
2175    pub deme: TopLevelDemeDefaults,
2176}
2177
2178impl GraphDefaults {
2179    // This fn exists so that we catch invalid inputs
2180    // prior to resolution.  During resolution,
2181    // we only visit the top-level defaults if needed.
2182    // Thus, we will miss invalid inputs if we wait
2183    // until resolution.
2184    fn validate(&self) -> Result<(), DemesError> {
2185        self.epoch.validate_as_default()?;
2186        self.pulse.validate_as_default()?;
2187        self.migration.validate()?;
2188        self.deme.validate()
2189    }
2190
2191    fn apply_default_epoch_start_size(
2192        &self,
2193        start_size: Option<InputDemeSize>,
2194    ) -> Option<InputDemeSize> {
2195        if start_size.is_some() {
2196            return start_size;
2197        }
2198        self.epoch.start_size
2199    }
2200
2201    fn apply_default_epoch_end_size(
2202        &self,
2203        end_size: Option<InputDemeSize>,
2204    ) -> Option<InputDemeSize> {
2205        if end_size.is_some() {
2206            return end_size;
2207        }
2208        self.epoch.end_size
2209    }
2210
2211    fn apply_epoch_size_defaults(&self, epoch: &mut UnresolvedEpoch) {
2212        epoch.start_size = self.apply_default_epoch_start_size(epoch.start_size);
2213        epoch.end_size = self.apply_default_epoch_end_size(epoch.end_size);
2214    }
2215
2216    fn apply_epoch_size_function_defaults(
2217        &self,
2218        size_function: Option<SizeFunction>,
2219        deme_level_defaults: &DemeDefaults,
2220    ) -> Option<SizeFunction> {
2221        if size_function.is_some() {
2222            return size_function;
2223        }
2224
2225        match deme_level_defaults.epoch.size_function {
2226            Some(sf) => Some(sf),
2227            None => match self.epoch.size_function {
2228                Some(sf) => Some(sf),
2229                None => Some(SizeFunction::Exponential),
2230            },
2231        }
2232    }
2233
2234    fn apply_migration_defaults(&self, other: &mut UnresolvedMigration) {
2235        if other.rate.is_none() {
2236            other.rate = self.migration.rate;
2237        }
2238        if other.start_time.is_none() {
2239            other.start_time = self.migration.start_time;
2240        }
2241        if other.end_time.is_none() {
2242            other.end_time = self.migration.end_time;
2243        }
2244        if other.source.is_none() {
2245            other.source.clone_from(&self.migration.source);
2246        }
2247        if other.dest.is_none() {
2248            other.dest.clone_from(&self.migration.dest);
2249        }
2250        if other.demes.is_none() {
2251            other.demes.clone_from(&self.migration.demes);
2252        }
2253    }
2254
2255    fn apply_pulse_defaults(&self, other: &mut UnresolvedPulse) {
2256        if other.time.is_none() {
2257            other.time = self.pulse.time;
2258        }
2259        if other.sources.is_none() {
2260            other.sources.clone_from(&self.pulse.sources);
2261        }
2262        if other.dest.is_none() {
2263            other.dest.clone_from(&self.pulse.dest);
2264        }
2265        if other.proportions.is_none() {
2266            other.proportions.clone_from(&self.pulse.proportions);
2267        }
2268    }
2269}
2270
/// Top-level defaults for a [`Deme`](crate::Deme).
///
/// This type is used as a member of
/// [`GraphDefaults`](crate::GraphDefaults)
#[derive(Clone, Default, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct TopLevelDemeDefaults {
    /// Default deme description.
    #[allow(missing_docs)]
    pub description: Option<String>,
    /// Default deme start time; must be a valid time if given.
    #[allow(missing_docs)]
    pub start_time: Option<InputTime>,
    /// Default ancestor names, used for any deme whose `ancestors` field is
    /// absent.
    #[allow(missing_docs)]
    pub ancestors: Option<Vec<String>>,
    /// Default ancestry proportions, used for any deme whose `proportions`
    /// field is absent; each value must be a valid proportion.
    #[allow(missing_docs)]
    pub proportions: Option<Vec<InputProportion>>,
}
2287
2288impl TopLevelDemeDefaults {
2289    fn validate(&self) -> Result<(), DemesError> {
2290        if let Some(value) = self.start_time {
2291            Time::try_from(value)
2292                .map_err(|_| DemesError::DemeError(format!("invalid start_time: {value:?}")))?;
2293        }
2294
2295        if let Some(proportions) = &self.proportions {
2296            for v in proportions {
2297                if Proportion::try_from(*v).is_err() {
2298                    return Err(DemesError::GraphError(format!(
2299                        "invalid default proportion: {v:?}"
2300                    )));
2301                }
2302            }
2303        }
2304
2305        Ok(())
2306    }
2307}
2308
2309/// Deme-level defaults
2310#[derive(Clone, Default, Debug, Deserialize)]
2311#[serde(deny_unknown_fields)]
2312pub struct DemeDefaults {
2313    #[allow(missing_docs)]
2314    pub epoch: UnresolvedEpoch,
2315}
2316
impl DemeDefaults {
    // Validate the deme-level epoch defaults with the same
    // `validate_as_default` check that `GraphDefaults::validate` applies to
    // the top-level epoch defaults.
    fn validate(&self) -> Result<(), DemesError> {
        self.epoch.validate_as_default()
    }
}
2322
/// Top-level metadata
///
/// Wraps an arbitrary string-keyed YAML mapping. Construction via
/// `TryFrom<BTreeMap<..>>` rejects an empty mapping.
///
/// # Examples
///
/// ```
/// #[derive(serde::Deserialize)]
/// struct MyMetaData {
///    foo: i32,
///    bar: String
/// }
///
/// let yaml = "
/// time_units: generations
/// metadata:
///  foo: 1
///  bar: bananas
/// demes:
///  - name: A
///    epochs:
///     - start_size: 100
/// ";
///
/// let graph = demes::loads(yaml).unwrap();
/// let yaml_metadata = graph.metadata().unwrap().as_yaml_string().unwrap();
/// let my_metadata: MyMetaData = serde_yaml::from_str(&yaml_metadata).unwrap();
/// assert_eq!(my_metadata.foo, 1);
/// assert_eq!(&my_metadata.bar, "bananas");
/// ```
#[derive(Clone, Default, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct Metadata {
    // The raw key/value mapping; kept ordered by key (BTreeMap).
    metadata: std::collections::BTreeMap<String, serde_yaml::Value>,
}
2355
2356impl TryFrom<std::collections::BTreeMap<String, serde_yaml::Value>> for Metadata {
2357    type Error = DemesError;
2358
2359    fn try_from(
2360        value: std::collections::BTreeMap<String, serde_yaml::Value>,
2361    ) -> Result<Self, Self::Error> {
2362        if value.is_empty() {
2363            Err(DemesError::GraphError(
2364                "toplevel metadata must mot be empty".to_string(),
2365            ))
2366        } else {
2367            Ok(Metadata { metadata: value })
2368        }
2369    }
2370}
2371
2372fn require_non_empty_metadata<'de, D>(
2373    deserializer: D,
2374) -> Result<Option<std::collections::BTreeMap<String, serde_yaml::Value>>, D::Error>
2375where
2376    D: serde::Deserializer<'de>,
2377{
2378    let buf = std::collections::BTreeMap::<String, serde_yaml::Value>::deserialize(deserializer)?;
2379
2380    if !buf.is_empty() {
2381        Ok(Some(buf))
2382    } else {
2383        Err(serde::de::Error::custom(
2384            "metadata: cannot be an empty mapping".to_string(),
2385        ))
2386    }
2387}
2388
2389impl Metadata {
2390    /// `true` if metadata is present, `false` otherwise
2391    fn is_empty(&self) -> bool {
2392        self.metadata.is_empty()
2393    }
2394
2395    /// Return the metadata as YAML
2396    pub fn as_yaml_string(&self) -> Result<String, serde_yaml::Error> {
2397        serde_yaml::to_string(self.as_raw_ref())
2398    }
2399
2400    pub(crate) fn as_raw_ref(&self) -> &std::collections::BTreeMap<String, serde_yaml::Value> {
2401        &self.metadata
2402    }
2403
2404    #[cfg(feature = "json")]
2405    /// Return a copy of the metadata as string in JSON format
2406    pub fn to_json_string(&self) -> Result<String, DemesError> {
2407        serde_json::to_string(self.as_raw_ref())
2408            .map_err(|e| DemesError::JsonError(crate::error::OpaqueJSONError(e)))
2409    }
2410}
2411
/// A demographic graph as read from input, before resolution.
///
/// NOTE(review): only `Deserialize` is derived here, so the
/// `skip_serializing*` attributes below appear to have no effect on this
/// type. Confirm whether they are kept for a future `Serialize` derive.
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub(crate) struct UnresolvedGraph {
    // Never read from the input; presumably set after parsing to keep the
    // original input around (TODO confirm).
    #[serde(skip_serializing)]
    #[serde(skip_deserializing)]
    #[serde(default = "Option::default")]
    input_string: Option<InputFormatInternal>,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    doi: Option<Vec<String>>,
    // Graph-level defaults; an absent `defaults` key yields all-unset defaults.
    #[serde(default = "GraphDefaults::default")]
    defaults: GraphDefaults,
    // Optional, but if present it must be a non-empty mapping
    // (enforced by `require_non_empty_metadata`).
    #[serde(deserialize_with = "require_non_empty_metadata")]
    #[serde(default = "Option::default")]
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<std::collections::BTreeMap<String, serde_yaml::Value>>,
    time_units: TimeUnits,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_time: Option<InputGenerationTime>,
    pub(crate) demes: Vec<UnresolvedDeme>,
    // The `migrations` key of the input, as written (symmetric and/or
    // asymmetric, possibly relying on defaults).
    #[serde(default = "Vec::<UnresolvedMigration>::default")]
    #[serde(rename = "migrations")]
    #[serde(skip_serializing)]
    input_migrations: Vec<UnresolvedMigration>,
    // Output of `resolve_migrations`: one entry per directed migration.
    // It shares the `migrations` name with `input_migrations`. Only one of
    // the two is deserialized (this one is `skip_deserializing`), so the
    // names do not clash.
    #[serde(default = "Vec::<AsymmetricMigration>::default")]
    #[serde(rename = "migrations")]
    #[serde(skip_deserializing)]
    #[serde(skip_serializing_if = "Vec::<AsymmetricMigration>::is_empty")]
    resolved_migrations: Vec<AsymmetricMigration>,
    #[serde(default = "Vec::<UnresolvedPulse>::default")]
    pulses: Vec<UnresolvedPulse>,
    // Deme name -> index into `demes`. Not part of the input; presumably
    // filled from `build_deme_map` during resolution.
    #[serde(skip)]
    deme_map: DemeMap,
}
2447
2448impl UnresolvedGraph {
2449    pub(crate) fn new(
2450        time_units: TimeUnits,
2451        generation_time: Option<InputGenerationTime>,
2452        defaults: Option<GraphDefaults>,
2453    ) -> Self {
2454        Self {
2455            input_string: None,
2456            time_units,
2457            generation_time,
2458
2459            // remaining fields have defaults
2460            description: Option::<String>::default(),
2461            doi: Option::<Vec<String>>::default(),
2462            defaults: defaults.unwrap_or_default(),
2463            metadata: Option::default(),
2464            demes: Vec::<UnresolvedDeme>::default(),
2465            input_migrations: Vec::<UnresolvedMigration>::default(),
2466            resolved_migrations: Vec::<AsymmetricMigration>::default(),
2467            pulses: Vec::<UnresolvedPulse>::default(),
2468            deme_map: DemeMap::default(),
2469        }
2470    }
2471
2472    pub(crate) fn add_deme(&mut self, deme: UnresolvedDeme) {
2473        self.demes.push(deme);
2474    }
2475
2476    pub(crate) fn add_migration<I: Into<UnresolvedMigration>>(&mut self, migration: I) {
2477        self.input_migrations.push(migration.into());
2478    }
2479
2480    pub(crate) fn add_pulse(
2481        &mut self,
2482        sources: Option<Vec<String>>,
2483        dest: Option<String>,
2484        time: Option<InputTime>,
2485        proportions: Option<Vec<InputProportion>>,
2486    ) {
2487        self.pulses.push(UnresolvedPulse {
2488            sources,
2489            dest,
2490            time,
2491            proportions,
2492        });
2493    }
2494
2495    fn build_deme_map(&self) -> Result<DemeMap, DemesError> {
2496        let mut rv = DemeMap::default();
2497
2498        for (i, deme) in self.demes.iter().enumerate() {
2499            if rv.contains_key(&deme.name) {
2500                return Err(DemesError::DemeError(format!(
2501                    "duplicate deme name: {}",
2502                    deme.name,
2503                )));
2504            }
2505            rv.insert(deme.name.clone(), i);
2506        }
2507
2508        Ok(rv)
2509    }
2510
2511    fn resolve_asymmetric_migration(
2512        &mut self,
2513        source: String,
2514        dest: String,
2515        rate: InputMigrationRate,
2516        start_time: Option<InputTime>,
2517        end_time: Option<InputTime>,
2518    ) -> Result<(), DemesError> {
2519        let source_deme = get_deme!(&source, &self.deme_map, &self.demes).ok_or_else(|| {
2520            crate::DemesError::MigrationError(format!("invalid source deme name {source}"))
2521        })?;
2522        let dest_deme = get_deme!(&dest, &self.deme_map, &self.demes).ok_or_else(|| {
2523            crate::DemesError::MigrationError(format!("invalid dest deme name {dest}"))
2524        })?;
2525
2526        let start_time = match start_time {
2527            Some(t) => t,
2528            None => {
2529                std::cmp::min(source_deme.get_start_time()?, dest_deme.get_start_time()?).into()
2530            }
2531        };
2532
2533        let end_time = match end_time {
2534            Some(t) => t,
2535            None => std::cmp::max(source_deme.get_end_time()?, dest_deme.get_end_time()?).into(),
2536        };
2537
2538        deme_name_exists(&self.deme_map, &source, DemesError::MigrationError)?;
2539        deme_name_exists(&self.deme_map, &dest, DemesError::MigrationError)?;
2540
2541        let a = AsymmetricMigration {
2542            source,
2543            dest,
2544            rate: rate.try_into()?,
2545            start_time: start_time.try_into().map_err(|_| {
2546                DemesError::MigrationError(format!("invalid start_time: {start_time:?}"))
2547            })?,
2548            end_time: end_time.try_into().map_err(|_| {
2549                DemesError::MigrationError(format!("invalid end_time: {end_time:?}"))
2550            })?,
2551        };
2552
2553        self.resolved_migrations.push(a);
2554
2555        Ok(())
2556    }
2557
2558    fn process_input_asymmetric_migration(
2559        &mut self,
2560        u: &UnresolvedMigration,
2561    ) -> Result<(), DemesError> {
2562        self.resolve_asymmetric_migration(
2563            u.source.clone().ok_or_else(|| {
2564                DemesError::MigrationError("migration source is None".to_string())
2565            })?,
2566            u.dest
2567                .clone()
2568                .ok_or_else(|| DemesError::MigrationError("migration dest is None".to_string()))?,
2569            u.rate
2570                .ok_or_else(|| DemesError::MigrationError("migration rate is None".to_string()))?,
2571            u.start_time,
2572            u.end_time,
2573        )
2574    }
2575
2576    fn process_input_symmetric_migration(
2577        &mut self,
2578        u: &UnresolvedMigration,
2579    ) -> Result<(), DemesError> {
2580        let demes = u
2581            .demes
2582            .as_ref()
2583            .ok_or_else(|| DemesError::MigrationError("migration demes is None".to_string()))?;
2584
2585        if demes.len() < 2 {
2586            return Err(DemesError::MigrationError(
2587                "the demes field of a migration mut contain at least two demes".to_string(),
2588            ));
2589        }
2590
2591        let rate = u
2592            .rate
2593            .ok_or_else(|| DemesError::MigrationError("migration rate is None".to_string()))?;
2594
2595        // Each input symmetric migration becomes two AsymmetricMigration instances
2596        for (i, source_name) in demes.iter().enumerate().take(demes.len() - 1) {
2597            for dest_name in demes.iter().skip(i + 1) {
2598                if source_name == dest_name {
2599                    return Err(DemesError::MigrationError(format!(
2600                        "source/dest demes must differ: {source_name}",
2601                    )));
2602                }
2603                deme_name_exists(&self.deme_map, source_name, DemesError::MigrationError)?;
2604                deme_name_exists(&self.deme_map, dest_name, DemesError::MigrationError)?;
2605
2606                let start_time = u.start_time;
2607                let end_time = u.end_time;
2608
2609                self.resolve_asymmetric_migration(
2610                    source_name.to_string(),
2611                    dest_name.to_string(),
2612                    rate,
2613                    start_time,
2614                    end_time,
2615                )?;
2616                self.resolve_asymmetric_migration(
2617                    dest_name.to_string(),
2618                    source_name.to_string(),
2619                    rate,
2620                    start_time,
2621                    end_time,
2622                )?;
2623            }
2624        }
2625
2626        Ok(())
2627    }
2628
2629    fn resolve_migrations(&mut self) -> Result<(), DemesError> {
2630        // NOTE: due to the borrow checker not trusting us, we
2631        // do the old "swap it out" trick to demonstrate
2632        // that we are not doing bad things.
2633        // This object is only mut b/c we need to swap it
2634        let mut input_migrations: Vec<UnresolvedMigration> = vec![];
2635        std::mem::swap(&mut input_migrations, &mut self.input_migrations);
2636
2637        if input_migrations.is_empty() {
2638            // if there are non-default
2639            // fields in migration defaults,
2640            // then we go for it and add a default value
2641            if self.defaults.migration != UnresolvedMigration::default() {
2642                input_migrations.push(self.defaults.migration.clone());
2643            }
2644        }
2645
2646        for input_mig in &input_migrations {
2647            let mut input_mig_clone = input_mig.clone();
2648            self.defaults.apply_migration_defaults(&mut input_mig_clone);
2649            let m = Migration::try_from(input_mig_clone)?;
2650            match m {
2651                Migration::Asymmetric(a) => self.process_input_asymmetric_migration(&a)?,
2652                Migration::Symmetric(s) => self.process_input_symmetric_migration(&s)?,
2653            }
2654        }
2655
2656        // The spec states that we can discard the unresolved migration stuff,
2657        // but we'll swap it back. It does no harm to do so.
2658        std::mem::swap(&mut input_migrations, &mut self.input_migrations);
2659        Ok(())
2660    }
2661
2662    fn build_migration_epochs(&self) -> HashMap<(String, String), Vec<TimeInterval>> {
2663        let mut rv = HashMap::<(String, String), Vec<TimeInterval>>::default();
2664
2665        for migration in &self.resolved_migrations {
2666            let source = migration.source().to_string();
2667            let dest = migration.dest().to_string();
2668            let key = (source, dest);
2669
2670            match rv.get_mut(&key) {
2671                Some(v) => v.push(migration.time_interval()),
2672                None => {
2673                    let _ = rv.insert(key, vec![migration.time_interval()]);
2674                }
2675            }
2676        }
2677
2678        rv
2679    }
2680
2681    fn check_migration_epoch_overlap(&self) -> Result<(), DemesError> {
2682        let mig_epochs = self.build_migration_epochs();
2683
2684        for (demes, epochs) in &mig_epochs {
2685            if epochs.windows(2).any(|w| w[0].overlaps(&w[1])) {
2686                return Err(DemesError::MigrationError(format!(
2687                    "overlapping migration epochs between source: {} and dest: {}",
2688                    demes.0, demes.1
2689                )));
2690            }
2691        }
2692        Ok(())
2693    }
2694
2695    fn get_non_overlapping_migration_intervals(&self) -> Vec<TimeInterval> {
2696        let mut unique_times = HashSet::<HashableTime>::default();
2697        for migration in &self.resolved_migrations {
2698            unique_times.insert(HashableTime::from(migration.start_time()));
2699            unique_times.insert(HashableTime::from(migration.end_time()));
2700        }
2701        unique_times.retain(|t| f64::from(*t).is_finite());
2702
2703        let mut end_times = unique_times.into_iter().map(Time::from).collect::<Vec<_>>();
2704
2705        // REVERSE sort
2706        end_times.sort_by(|a, b| b.cmp(a));
2707
2708        let mut start_times = vec![Time::try_from(f64::INFINITY).unwrap()];
2709
2710        if let Some((_last, elements)) = end_times.split_last() {
2711            start_times.extend_from_slice(elements);
2712        }
2713
2714        start_times
2715            .into_iter()
2716            .zip(end_times)
2717            .map(|times| TimeInterval::new(times.0, times.1))
2718            .collect::<Vec<_>>()
2719    }
2720
2721    fn validate_input_migration_rates(&self) -> Result<(), DemesError> {
2722        let intervals = self.get_non_overlapping_migration_intervals();
2723        let mut input_rates = HashMap::<String, Vec<f64>>::default();
2724
2725        for deme in self.deme_map.keys() {
2726            input_rates.insert(deme.clone(), vec![0.0; intervals.len()]);
2727        }
2728
2729        for (i, ti) in intervals.iter().enumerate() {
2730            for migration in &self.resolved_migrations {
2731                let mti = migration.time_interval();
2732                if ti.overlaps(&mti) {
2733                    match input_rates.get_mut(migration.dest()) {
2734                        Some(rates) => {
2735                            let rate = rates[i] + f64::from(migration.rate());
2736                            if rate > 1.0 + 1e-9 {
2737                                let msg = format!("migration rate into dest: {} is > 1 in the time interval ({:?}, {:?}]",
2738                                                  migration.dest(), ti.start_time(), ti.end_time());
2739                                return Err(DemesError::MigrationError(msg));
2740                            }
2741                            rates[i] = rate;
2742                        }
2743                        None => panic!("fatal error when validating migration rate sums"),
2744                    }
2745                }
2746            }
2747        }
2748
2749        Ok(())
2750    }
2751
2752    fn validate_migrations(&self) -> Result<(), DemesError> {
2753        for m in &self.resolved_migrations {
2754            let source = get_deme!(&m.source, &self.deme_map, &self.demes).ok_or_else(|| {
2755                DemesError::MigrationError(format!("invalid migration source: {}", m.source))
2756            })?;
2757            let dest = get_deme!(&m.dest, &self.deme_map, &self.demes).ok_or_else(|| {
2758                DemesError::MigrationError(format!("invalid migration dest: {}", m.dest))
2759            })?;
2760
2761            if source.name == dest.name {
2762                return Err(DemesError::MigrationError(format!(
2763                    "source: {} == dest: {}",
2764                    source.name, dest.name
2765                )));
2766            }
2767
2768            {
2769                let interval = source.get_time_interval()?;
2770                if !interval.contains_inclusive_start_exclusive_end(m.start_time) {
2771                    return Err(DemesError::MigrationError(format!(
2772                            "migration start_time: {:?} does not overlap with existence of source deme {}",
2773                            m.start_time,
2774                            source.name
2775                        )));
2776                }
2777                let interval = dest.get_time_interval()?;
2778                if !interval.contains_inclusive_start_exclusive_end(m.start_time) {
2779                    return Err(DemesError::MigrationError(format!(
2780                            "migration start_time: {:?} does not overlap with existence of dest deme {}",
2781                            m.start_time,
2782                            dest.name
2783                        )));
2784                }
2785            }
2786
2787            {
2788                if !f64::from(m.end_time).is_finite() {
2789                    return Err(DemesError::MigrationError(format!(
2790                        "invalid migration end_time: {:?}",
2791                        m.end_time
2792                    )));
2793                }
2794                let interval = source.get_time_interval()?;
2795                if !interval.contains_exclusive_start_inclusive_end(m.end_time) {
2796                    return Err(DemesError::MigrationError(format!(
2797                            "migration end_time: {:?} does not overlap with existence of source deme {}",
2798                            m.end_time,
2799                            source.name
2800                        )));
2801                }
2802                let interval = dest.get_time_interval()?;
2803                if !interval.contains_exclusive_start_inclusive_end(m.end_time) {
2804                    return Err(DemesError::MigrationError(format!(
2805                        "migration end_time: {:?} does not overlap with existence of dest deme {}",
2806                        m.end_time, dest.name
2807                    )));
2808                }
2809            }
2810
2811            let interval = m.time_interval();
2812            if !interval.duration_greater_than_zero() {
2813                return Err(DemesError::MigrationError(format!(
2814                    "invalid migration duration: {interval:?} ",
2815                )));
2816            }
2817        }
2818        self.check_migration_epoch_overlap()?;
2819        self.validate_input_migration_rates()?;
2820        Ok(())
2821    }
2822
2823    fn resolve_pulses(&mut self) -> Result<(), DemesError> {
2824        if self.pulses.is_empty() && self.defaults.pulse != UnresolvedPulse::default() {
2825            let c = self.defaults.pulse.clone();
2826            self.pulses.push(c);
2827        }
2828        self.pulses
2829            .iter_mut()
2830            .try_for_each(|pulse| pulse.resolve(&self.defaults))?;
2831        // NOTE: the sort_by flips the order to b, a
2832        // to put more ancient events at the front.
2833        // FIXME: we cannot remove this unwrap
2834        // unless we define Time as fully-ordered.
2835        self.pulses
2836            .sort_by(|a, b| b.time.partial_cmp(&a.time).unwrap());
2837        Ok(())
2838    }
2839
2840    // NOTE: this function could output a resoled Graph
2841    // type and maybe save some extra work/moves.
2842    pub(crate) fn resolve(self) -> Result<Self, DemesError> {
2843        let mut g = self;
2844        if g.demes.is_empty() {
2845            return Err(DemesError::DemeError(
2846                "no demes have been specified".to_string(),
2847            ));
2848        }
2849        g.defaults.validate()?;
2850        g.deme_map = g.build_deme_map()?;
2851
2852        let mut resolved_demes = vec![];
2853        for deme in g.demes.iter_mut() {
2854            deme.resolve(&g.deme_map, &resolved_demes, &g.defaults)?;
2855            resolved_demes.push(deme.clone());
2856        }
2857        g.demes = resolved_demes;
2858        g.demes.iter().try_for_each(|deme| deme.validate())?;
2859        g.resolve_migrations()?;
2860        g.resolve_pulses()?;
2861        g.validate_migrations()?;
2862
2863        match g.generation_time {
2864            Some(_) => (), //value.validate(DemesError::GraphError)?,
2865            None => {
2866                if matches!(g.time_units, TimeUnits::Generations) {
2867                    g.generation_time = Some(InputGenerationTime::from(1.));
2868                }
2869            }
2870        }
2871        Ok(g)
2872    }
2873
2874    pub(crate) fn validate(&self) -> Result<(), DemesError> {
2875        if self.demes.is_empty() {
2876            return Err(DemesError::DemeError("no demes specified".to_string()));
2877        }
2878
2879        if !matches!(&self.time_units, TimeUnits::Generations) && self.generation_time.is_none() {
2880            return Err(DemesError::GraphError(
2881                "missing generation_time".to_string(),
2882            ));
2883        }
2884
2885        if matches!(&self.time_units, TimeUnits::Generations) {
2886            if let Some(value) = self.generation_time {
2887                if !value.equals(1.0) {
2888                    return Err(DemesError::GraphError(
2889                        "time units are generations but generation_time != 1.0".to_string(),
2890                    ));
2891                }
2892            }
2893        }
2894        self.pulses
2895            .iter()
2896            .try_for_each(|pulse| pulse.validate(&self.deme_map, &self.demes))?;
2897
2898        Ok(())
2899    }
2900
2901    pub(crate) fn set_metadata(&mut self, metadata: Metadata) {
2902        assert!(!metadata.is_empty());
2903        self.metadata = Some(metadata.metadata.clone())
2904    }
2905
2906    // Take our definition from
2907    // https://momentsld.github.io/moments/api/api_demes.html#moments.Demes.DemesUtil.rescale
2908    fn rescale(self, scaling_factor: f64) -> Result<Self, DemesError> {
2909        if !scaling_factor.is_finite() || scaling_factor <= 0.0 {
2910            return Err(DemesError::ValueError(format!(
2911                "invalid scaling_factor: {scaling_factor}"
2912            )));
2913        }
2914        let mut g = self;
2915
2916        g.demes
2917            .iter_mut()
2918            .try_for_each(|d| d.rescale(scaling_factor))?;
2919
2920        g.pulses
2921            .iter_mut()
2922            .try_for_each(|p| p.rescale(scaling_factor))?;
2923
2924        g.input_migrations
2925            .iter_mut()
2926            .try_for_each(|m| m.rescale(scaling_factor))?;
2927
2928        g.resolve()
2929    }
2930}
2931
2932impl From<Graph> for UnresolvedGraph {
2933    fn from(value: Graph) -> Self {
2934        let input_migrations: Vec<UnresolvedMigration> = value
2935            .resolved_migrations
2936            .into_iter()
2937            .map(UnresolvedMigration::from)
2938            .collect::<Vec<_>>();
2939        let pulses: Vec<UnresolvedPulse> = value
2940            .pulses
2941            .into_iter()
2942            .map(UnresolvedPulse::from)
2943            .collect::<Vec<_>>();
2944        let demes: Vec<UnresolvedDeme> = value
2945            .demes
2946            .into_iter()
2947            .map(UnresolvedDeme::from)
2948            .collect::<Vec<_>>();
2949
2950        let doi = if value.doi.is_empty() {
2951            None
2952        } else {
2953            Some(value.doi)
2954        };
2955
2956        Self {
2957            input_string: value.input_string,
2958            description: value.description,
2959            doi,
2960            defaults: GraphDefaults::default(),
2961            metadata: value.metadata,
2962            time_units: value.time_units,
2963            generation_time: Some(f64::from(value.generation_time).into()),
2964            demes,
2965            input_migrations,
2966            resolved_migrations: vec![],
2967            pulses,
2968            deme_map: value.deme_map,
2969        }
2970    }
2971}
2972
/// A resolved demes Graph.
///
/// Instances of this type will be fully-resolved according to
/// the machine data model described
/// [here](https://popsim-consortium.github.io/demes-spec-docs/main/specification.html#).
///
/// A graph cannot be directly initialized. See:
/// * [`load`](crate::load)
/// * [`loads`](crate::loads)
/// * [`GraphBuilder`](crate::GraphBuilder)
// NOTE(review): `try_from`, `skip_deserializing` and `default` only affect a
// `Deserialize` impl, and none is derived here; confirm whether a manual
// impl exists elsewhere or whether these attributes are vestigial.
#[derive(Serialize, Debug, Clone)]
#[serde(deny_unknown_fields, try_from = "UnresolvedGraph")]
pub struct Graph {
    /// The text the graph was parsed from, tagged with its format
    /// (YAML/JSON/TOML). Set by the string/reader constructors; never
    /// serialized.
    #[serde(skip_serializing)]
    #[serde(skip_deserializing)]
    #[serde(default = "Option::default")]
    input_string: Option<InputFormatInternal>,
    /// Optional free-text description; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// DOI strings; omitted from output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    doi: Vec<String>,
    /// Optional top-level user metadata; omitted from output when `None`.
    #[serde(default = "Option::default")]
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<std::collections::BTreeMap<String, serde_yaml::Value>>,
    /// Units in which all times of the graph are expressed.
    time_units: TimeUnits,
    /// Generation time relating `time_units` to generations.
    generation_time: GenerationTime,
    /// The resolved demes; a deme's index is its position in this vector.
    pub(crate) demes: Vec<Deme>,
    /// Resolved asymmetric migrations. Serialized under the key
    /// `migrations` and omitted from output when empty.
    #[serde(default = "Vec::<AsymmetricMigration>::default")]
    #[serde(rename = "migrations")]
    #[serde(skip_deserializing)]
    #[serde(skip_serializing_if = "Vec::<AsymmetricMigration>::is_empty")]
    resolved_migrations: Vec<AsymmetricMigration>,
    /// Resolved pulse events.
    #[serde(default = "Vec::<Pulse>::default")]
    pulses: Vec<Pulse>,
    /// Maps deme names to indexes into `demes`; used for name lookups.
    /// Never serialized.
    #[serde(skip)]
    deme_map: DemeMap,
}
3010
3011// NOTE: the manual implementation
3012// skips over stuff that's only used by the HDM.
3013// We are testing equality of the MDM only.
3014impl PartialEq for Graph {
3015    fn eq(&self, other: &Self) -> bool {
3016        self.description == other.description
3017            && self.doi == other.doi
3018            && self.time_units == other.time_units
3019            && self.generation_time == other.generation_time
3020            && self.demes == other.demes
3021            && self.resolved_migrations == other.resolved_migrations
3022            && self.metadata == other.metadata
3023            && self.pulses == other.pulses
3024    }
3025}
3026
3027impl Eq for Graph {}
3028
3029impl std::fmt::Display for Graph {
3030    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3031        write!(f, "{}", self.as_string().unwrap())
3032    }
3033}
3034
3035impl TryFrom<UnresolvedGraph> for Graph {
3036    type Error = DemesError;
3037
3038    fn try_from(value: UnresolvedGraph) -> Result<Self, Self::Error> {
3039        value.validate()?;
3040        let mut pulses = vec![];
3041        for p in value.pulses {
3042            pulses.push(Pulse::try_from(p)?);
3043        }
3044        let mut demes = vec![];
3045        for hdm_deme in value.demes.into_iter() {
3046            let deme = Deme::try_from(hdm_deme)?;
3047            demes.push(deme);
3048        }
3049        Ok(Self {
3050            input_string: value.input_string,
3051            description: value.description,
3052            doi: value.doi.unwrap_or_default(),
3053            metadata: value.metadata,
3054            time_units: value.time_units,
3055            generation_time: value
3056                .generation_time
3057                .ok_or_else(|| DemesError::GraphError("generation_time is unresolved".to_string()))?
3058                .try_into()?,
3059            demes,
3060            resolved_migrations: value.resolved_migrations,
3061            pulses,
3062            deme_map: value.deme_map,
3063        })
3064    }
3065}
3066
3067fn string_from_reader<T: Read>(reader: T) -> Result<String, DemesError> {
3068    let mut reader = reader;
3069    let mut buf = String::default();
3070    let _ = reader
3071        .read_to_string(&mut buf)
3072        .map_err(|e| DemesError::IOerror(crate::error::OpaqueIOError(e)))?;
3073    Ok(buf)
3074}
3075
3076impl Graph {
3077    pub(crate) fn new_from_str(yaml: &'_ str) -> Result<Self, DemesError> {
3078        let g: UnresolvedGraph =
3079            serde_yaml::from_str(yaml).map_err(|e| DemesError::YamlError(OpaqueYamlError(e)))?;
3080        let mut g = g.resolve()?;
3081        g.validate()?;
3082        g.input_string = Some(InputFormatInternal::Yaml(yaml.to_owned()));
3083        g.try_into()
3084    }
3085
3086    #[cfg(feature = "json")]
3087    #[cfg_attr(doc_cfg, doc(cfg(feature = "json")))]
3088    pub(crate) fn new_resolved_from_json_str(json: &'_ str) -> Result<Self, DemesError> {
3089        let json: std::collections::HashMap<String, serde_json::Value> = serde_json::from_str(json)
3090            .map_err(|e| DemesError::JsonError(crate::error::OpaqueJSONError(e)))?;
3091        let json = crate::process_json::fix_json_input(json)?;
3092        let json = serde_json::to_string(&json)
3093            .map_err(|e| DemesError::JsonError(crate::error::OpaqueJSONError(e)))?;
3094        let g: UnresolvedGraph = serde_json::from_str(&json)
3095            .map_err(|e| DemesError::JsonError(crate::error::OpaqueJSONError(e)))?;
3096        let mut g = g.resolve()?;
3097        g.validate()?;
3098        g.input_string = Some(InputFormatInternal::Json(json.to_owned()));
3099        g.try_into()
3100    }
3101
3102    #[cfg(feature = "toml")]
3103    #[cfg_attr(doc_cfg, doc(cfg(feature = "toml")))]
3104    pub(crate) fn new_resolved_from_toml_str(toml: &'_ str) -> Result<Self, DemesError> {
3105        let g: UnresolvedGraph = toml::from_str(toml)
3106            .map_err(|e| DemesError::TomlDeError(crate::error::OpaqueTOMLError(e)))?;
3107        let mut g = g.resolve()?;
3108        g.validate()?;
3109        g.input_string = Some(InputFormatInternal::Toml(toml.to_owned()));
3110        g.try_into()
3111    }
3112
3113    pub(crate) fn new_from_reader<T: Read>(reader: T) -> Result<Self, DemesError> {
3114        let yaml = string_from_reader(reader)?;
3115        Self::new_from_str(&yaml)
3116    }
3117
3118    #[cfg(feature = "json")]
3119    #[cfg_attr(doc_cfg, doc(cfg(feature = "json")))]
3120    pub(crate) fn new_from_json_reader<T: Read>(reader: T) -> Result<Self, DemesError> {
3121        let json = string_from_reader(reader)?;
3122        Self::new_resolved_from_json_str(&json)
3123    }
3124
3125    #[cfg(feature = "toml")]
3126    #[cfg_attr(doc_cfg, doc(cfg(feature = "toml")))]
3127    pub(crate) fn new_from_toml_reader<T: Read>(reader: T) -> Result<Self, DemesError> {
3128        let toml = string_from_reader(reader)?;
3129        Self::new_resolved_from_toml_str(&toml)
3130    }
3131
3132    pub(crate) fn new_resolved_from_str(yaml: &'_ str) -> Result<Self, DemesError> {
3133        let graph = Self::new_from_str(yaml)?;
3134        assert!(graph.input_string.is_some());
3135        Ok(graph)
3136    }
3137
3138    pub(crate) fn new_resolved_from_reader<T: Read>(reader: T) -> Result<Self, DemesError> {
3139        let graph = Self::new_from_reader(reader)?;
3140        assert!(graph.input_string.is_some());
3141        Ok(graph)
3142    }
3143
3144    #[cfg(feature = "json")]
3145    #[cfg_attr(doc_cfg, doc(cfg(feature = "json")))]
3146    pub(crate) fn new_resolved_from_json_reader<T: Read>(reader: T) -> Result<Self, DemesError> {
3147        let graph = Self::new_from_json_reader(reader)?;
3148        assert!(graph.input_string.is_some());
3149        Ok(graph)
3150    }
3151
3152    #[cfg(feature = "toml")]
3153    #[cfg_attr(doc_cfg, doc(cfg(feature = "toml")))]
3154    pub(crate) fn new_resolved_from_toml_reader<T: Read>(reader: T) -> Result<Self, DemesError> {
3155        let graph = Self::new_from_toml_reader(reader)?;
3156        assert!(graph.input_string.is_some());
3157        Ok(graph)
3158    }
3159
3160    /// The number of [`Deme`](crate::Deme) instances in the graph.
3161    pub fn num_demes(&self) -> usize {
3162        self.demes.len()
3163    }
3164
3165    /// Obtain a reference to a [`Deme`](crate::Deme) by its name.
3166    ///
3167    /// # Returns
3168    ///
3169    /// `Some(&Deme)` if `name` exists, `None` otherwise.
3170    ///
3171    /// # Examples
3172    ///
3173    /// See [`here`](crate::SizeFunction).
3174    #[deprecated(note = "Use Graph::get_deme instead")]
3175    pub fn get_deme_from_name(&self, name: &str) -> Option<&Deme> {
3176        let id = DemeId::from(name);
3177        self.get_deme(id)
3178    }
3179
3180    /// Get the [`Deme`](crate::Deme) at identifier `id`.
3181    ///
3182    /// # Parameters
3183    ///
3184    /// * `id`, the [`DemeId`] to fetch.
3185    ///
3186    /// # Panics
3187    ///
3188    /// * If either variant of [`DemeId`] refers to an invalid deme
3189    ///
3190    /// # Note
3191    ///
3192    /// See [`Graph::get_deme`] for a version that will not panic
3193    pub fn deme<'name, I: Into<DemeId<'name>>>(&self, id: I) -> &Deme {
3194        self.get_deme(id).unwrap()
3195    }
3196
3197    /// Get the [`Deme`](crate::Deme) at identifier `id`.
3198    ///
3199    /// # Parameters
3200    ///
3201    /// * `id`, the [`DemeId`] to fetch.
3202    ///
3203    /// # Returns
3204    ///
3205    /// * `Some(&[`Deme`])` if `id` is valid
3206    /// * `None` otherwise
3207    pub fn get_deme<'name, I: Into<DemeId<'name>>>(&self, id: I) -> Option<&Deme> {
3208        match id.into() {
3209            DemeId::Index(i) => self.demes.get(i),
3210            DemeId::Name(name) => get_deme!(name, &self.deme_map, &self.demes),
3211        }
3212    }
3213
3214    /// Get the [`Deme`](crate::Deme) instances via a slice.
3215    pub fn demes(&self) -> &[Deme] {
3216        &self.demes
3217    }
3218
3219    /// Get the [`GenerationTime`](crate::GenerationTime) for the graph.
3220    pub fn generation_time(&self) -> GenerationTime {
3221        self.generation_time
3222    }
3223
3224    /// Get the [`TimeUnits`](crate::TimeUnits) for the graph.
3225    pub fn time_units(&self) -> TimeUnits {
3226        self.time_units.clone()
3227    }
3228
3229    /// Get the migration events for the graph.
3230    pub fn migrations(&self) -> &[AsymmetricMigration] {
3231        &self.resolved_migrations
3232    }
3233
3234    /// Get the pulse events for the graph.
3235    pub fn pulses(&self) -> &[Pulse] {
3236        &self.pulses
3237    }
3238
3239    /// Get a copy of the top-level [`Metadata`](crate::Metadata).
3240    pub fn metadata(&self) -> Option<Metadata> {
3241        self.metadata.as_ref().map(|md| Metadata {
3242            metadata: md.clone(),
3243        })
3244    }
3245
3246    fn convert_to_generations_details(
3247        self,
3248        round: fn(Time, GenerationTime) -> Time,
3249    ) -> Result<Self, DemesError> {
3250        let mut converted = self;
3251
3252        converted.demes.iter_mut().try_for_each(|deme| {
3253            deme.resolved_time_to_generations(converted.generation_time, round)
3254        })?;
3255
3256        converted.pulses.iter_mut().try_for_each(|pulse| {
3257            pulse.resolved_time_to_generations(converted.generation_time, round)
3258        })?;
3259
3260        converted
3261            .resolved_migrations
3262            .iter_mut()
3263            .try_for_each(|pulse| {
3264                pulse.resolved_time_to_generations(converted.generation_time, round)
3265            })?;
3266
3267        converted.time_units = TimeUnits::Generations;
3268        converted.generation_time = GenerationTime::try_from(1.0)?;
3269
3270        Ok(converted)
3271    }
3272
3273    /// Convert the time units to generations.
3274    ///
3275    /// # Errors
3276    ///
3277    /// If the time unit of an event differs sufficiently in
3278    /// magnitude from the `generation_time`, it is possible
3279    /// that conversion results in epochs (or migration
3280    /// durations) of length zero, which will return an error.
3281    ///
3282    /// If any field is unresolved, an error will be returned.
3283    pub fn into_generations(self) -> Result<Self, DemesError> {
3284        self.into_generations_with(crate::time::to_generations)
3285    }
3286
3287    /// Convert the time units to generations, rounding the output to an integer value.
3288    pub fn into_integer_generations(self) -> Result<Graph, DemesError> {
3289        self.into_generations_with(crate::time::round_time_to_integer_generations)
3290    }
3291
3292    /// Convert the time units to generations with a callback to specify the conversion
3293    /// policy
3294    pub fn into_generations_with(
3295        self,
3296        with: fn(Time, GenerationTime) -> Time,
3297    ) -> Result<Graph, DemesError> {
3298        self.convert_to_generations_details(with)
3299    }
3300
3301    /// Return a representation of the graph as a string.
3302    ///
3303    /// The format is in YAML and corresponds to the MDM
3304    /// representation of the data.
3305    ///
3306    /// # Error
3307    ///
3308    /// Will return an error if `serde_yaml::to_string`
3309    /// returns an error.
3310    pub fn as_string(&self) -> Result<String, DemesError> {
3311        match serde_yaml::to_string(self) {
3312            Ok(string) => Ok(string),
3313            Err(e) => Err(DemesError::YamlError(OpaqueYamlError(e))),
3314        }
3315    }
3316
3317    /// Return a representation of the graph as a string.
3318    ///
3319    /// The format is in JSON and corresponds to the MDM
3320    /// representation of the data.
3321    ///
3322    /// # Error
3323    ///
3324    /// Will return an error if `serde_json::to_string`
3325    /// returns an error.
3326    #[cfg(feature = "json")]
3327    #[cfg_attr(doc_cfg, doc(cfg(feature = "json")))]
3328    pub fn as_json_string(&self) -> Result<String, DemesError> {
3329        match serde_json::to_string(self) {
3330            Ok(string) => Ok(string),
3331            Err(e) => Err(DemesError::JsonError(crate::error::OpaqueJSONError(e))),
3332        }
3333    }
3334
3335    /// Return the most recent end time of any deme
3336    /// in the Graph.
3337    ///
3338    /// This function is useful to check if the most
3339    /// recent end time is greater than zero, meaning
3340    /// that the model ends at a time point ancestral to
3341    /// "now".
3342    pub fn most_recent_deme_end_time(&self) -> Time {
3343        let init = self.demes[0].end_time();
3344        self.demes
3345            .iter()
3346            .skip(1)
3347            .fold(init, |current_min, current_deme| {
3348                std::cmp::min(current_min, current_deme.end_time())
3349            })
3350    }
3351
3352    /// Return the description field.
3353    pub fn description(&self) -> Option<&str> {
3354        match &self.description {
3355            Some(x) => Some(x),
3356            None => None,
3357        }
3358    }
3359
3360    /// Return an iterator over DOI information.
3361    pub fn doi(&self) -> impl Iterator<Item = &str> {
3362        self.doi.iter().map(|s| s.as_str())
3363    }
3364
3365    /// Check if any epochs have non-integer
3366    /// `start_size` or `end_size`.
3367    ///
3368    /// # Returns
3369    ///
3370    /// * The deme name and epoch index where the first
3371    ///   non-integer value is encountered
3372    /// * None if non non-integer values are encountered
3373    pub fn has_non_integer_sizes(&self) -> Option<(&str, usize)> {
3374        for deme in &self.demes {
3375            for (i, epoch) in deme.epochs.iter().enumerate() {
3376                for size in [f64::from(epoch.start_size()), f64::from(epoch.end_size())] {
3377                    if size.is_finite() && size.fract() != 0.0 {
3378                        return Some((deme.name(), i));
3379                    }
3380                }
3381            }
3382        }
3383        None
3384    }
3385
3386    fn epoch_start_end_size_rounding_details(
3387        old_size: DemeSize,
3388        rounding_fn: fn(f64) -> f64,
3389    ) -> Result<DemeSize, DemesError> {
3390        let size = f64::from(old_size);
3391        if size.is_finite() && size.fract() != 0.0 {
3392            let new_size = rounding_fn(size);
3393            if !new_size.is_finite() || new_size.fract() != 0.0 || new_size <= 0.0 {
3394                let msg = format!("invalid size after rounding: {new_size}");
3395                return Err(DemesError::EpochError(msg));
3396            }
3397            return new_size.try_into();
3398        }
3399        Ok(old_size)
3400    }
3401
3402    fn round_epoch_start_end_sizes_with(
3403        self,
3404        rounding_fn: fn(f64) -> f64,
3405    ) -> Result<Self, DemesError> {
3406        let mut graph = self;
3407
3408        for deme in &mut graph.demes {
3409            for epoch in &mut deme.epochs {
3410                epoch.start_size =
3411                    Graph::epoch_start_end_size_rounding_details(epoch.start_size, rounding_fn)?;
3412                epoch.end_size =
3413                    Graph::epoch_start_end_size_rounding_details(epoch.end_size, rounding_fn)?;
3414            }
3415        }
3416
3417        Ok(graph)
3418    }
3419
3420    /// Round all epoch start/end sizes to nearest integer value.
3421    ///
3422    /// # Returns
3423    ///
3424    /// A modified graph with rounded sizes.
3425    ///
3426    /// # Error
3427    ///
3428    /// * [`EpochError`](crate::DemesError::EpochError) if rounding
3429    ///   leads to a value of 0.
3430    ///
3431    /// # Note
3432    ///
3433    /// Rounding uses [f64::round](f64::round)
3434    pub fn into_integer_start_end_sizes(self) -> Result<Self, DemesError> {
3435        self.round_epoch_start_end_sizes_with(f64::round)
3436    }
3437
3438    /// Obtain names of all demes in the graph.
3439    ///
3440    /// # Note
3441    ///
3442    /// These are ordered by a deme's index in the model.
3443    ///
3444    /// # Panics
3445    ///
3446    /// This function allocates space for the return value,
3447    /// which may panic upon out-of-memory.
3448    pub fn deme_names(&self) -> Box<[&str]> {
3449        self.demes
3450            .iter()
3451            .map(|deme| deme.name())
3452            .collect::<Vec<&str>>()
3453            .into_boxed_slice()
3454    }
3455
3456    /// Get a reference to the input string, if any.
3457    ///
3458    /// # Examples
3459    ///
3460    /// ```
3461    /// let yaml = "
3462    /// time_units: years
3463    /// generation_time: 25
3464    /// description:
3465    ///   A deme of 50 individuals that grew to 100 individuals
3466    ///   in the last 100 years.
3467    ///   Default behavior is that size changes are exponential.
3468    /// demes:
3469    ///  - name: deme
3470    ///    epochs:
3471    ///     - start_size: 50
3472    ///       end_time: 100
3473    ///     - start_size: 50
3474    ///       end_size: 100
3475    /// ";
3476    /// let graph = demes::loads(yaml).unwrap();
3477    /// assert_eq!(graph.input_string().unwrap().to_str(), yaml);
3478    /// ```
3479    ///
3480    /// # Note
3481    ///
3482    /// The string is in the same format (YAML or JSON)
3483    /// that was used to generate the graph.
3484    pub fn input_string(&'_ self) -> Option<InputFormat<'_>> {
3485        match &self.input_string {
3486            None => None,
3487            Some(format) => match format {
3488                InputFormatInternal::Yaml(string) => Some(InputFormat::Yaml(string.as_str())),
3489                InputFormatInternal::Json(string) => Some(InputFormat::Json(string.as_str())),
3490                InputFormatInternal::Toml(string) => Some(InputFormat::Toml(string.as_str())),
3491            },
3492        }
3493    }
3494
3495    /// Rescale a model by a constant scaling factor.
3496    ///
3497    /// For a given scaling factor, `Q`:
3498    /// 1. All input population sizes will be divided by `Q`.
3499    /// 2. All times will be divided by `Q`.
3500    /// 3. All rates (migration, etc.) will be multiplied by `Q.`
3501    /// 4. Pulse proportions, selfing rates, and cloning rates all
3502    ///    remaing unchanged.
3503    ///
3504    /// The result is a new [`Graph`] where the products of populations sizes
3505    /// times rates and the timings of events divided by population sizes are the
3506    /// same as in the input.
3507    ///
3508    /// # Parameters
3509    ///
3510    /// * `scaling_factor`: the value of `Q`. This must be > 0.0 and finite.
3511    ///
3512    /// # Returns
3513    ///
3514    /// * The rescaled [`Graph`]
3515    ///
3516    /// # Errors
3517    ///
3518    /// * [`DemesError`] if `scaling_factor` is invalid or if rescaling results
3519    ///   in an invalid graph.  For example, rescaling with `scaling_factor << 1`
3520    ///   could result in migration rates `> 1`, which is invalid.
3521    pub fn rescale(self, scaling_factor: f64) -> Result<Self, DemesError> {
3522        let g = UnresolvedGraph::from(self);
3523        g.rescale(scaling_factor)?.try_into()
3524    }
3525
3526    /// Remove recent history from a [Graph].
3527    ///
3528    /// For a given value of `when`, a new graph is created with all
3529    /// history from `[0, when)` removed.
3530    ///
3531    /// # Examples
3532    ///
3533    /// Remove the first ten generations:
3534    ///
3535    /// ```
3536    /// let yaml = "
3537    /// time_units: generations
3538    /// demes:
3539    ///  - start_time: .inf
3540    ///    name: deme
3541    ///    epochs:
3542    ///     - start_size: 100
3543    /// ";
3544    /// let graph = demes::loads(yaml).unwrap();
3545    /// assert_eq!(graph.demes()[0].end_time(), 0.0);
3546    /// let when = demes::Time::try_from(10.0).unwrap();
3547    /// let sliced = graph.slice_until(when).unwrap();
3548    /// assert_eq!(sliced.demes()[0].end_time(), 10.0);
3549    /// ```
3550    ///
3551    /// For the next example, removing the first 20 generations
3552    /// removes the ancestral deme entirely:
3553    ///
3554    /// ```
3555    /// let yaml = "
3556    /// time_units: generations
3557    /// demes:
3558    ///  - start_time: .inf
3559    ///    name: ancestor
3560    ///    epochs:
3561    ///     - start_size: 100
3562    ///       end_time: 20
3563    ///  - name: derived
3564    ///    start_time: 20
3565    ///    ancestors: [ancestor]
3566    ///    proportions: [1.0]
3567    ///    epochs:
3568    ///     - start_size: 50
3569    /// ";
3570    /// let graph = demes::loads(yaml).unwrap();
3571    /// assert_eq!(graph.demes().len(), 2);
3572    /// assert_eq!(graph.demes()[0].end_time(), 20.0);
3573    /// assert_eq!(graph.demes()[1].end_time(), 0.0);
3574    /// let when = demes::Time::try_from(20.0).unwrap();
3575    /// let sliced = graph.slice_until(when).unwrap();
3576    /// assert_eq!(sliced.demes().len(), 1);
3577    /// assert_eq!(sliced.demes()[0].end_time(), 20.0);
3578    /// assert_eq!(sliced.demes()[0].start_time(), f64::INFINITY);
3579    /// assert_eq!(sliced.demes()[0].name(), "ancestor");
3580    /// ```
3581    ///
3582    /// For the same input, removing the first 10 generations
3583    /// simply truncates the duration of the derived deme:
3584    ///
3585    /// ```
3586    /// # let yaml = "
3587    /// # time_units: generations
3588    /// # demes:
3589    /// #  - start_time: .inf
3590    /// #    name: ancestor
3591    /// #    epochs:
3592    /// #     - start_size: 100
3593    /// #       end_time: 20
3594    /// #  - name: derived
3595    /// #    start_time: 20
3596    /// #    ancestors: [ancestor]
3597    /// #    proportions: [1.0]
3598    /// #    epochs:
3599    /// #     - start_size: 50
3600    /// # ";
3601    /// # let graph = demes::loads(yaml).unwrap();
3602    /// # assert_eq!(graph.demes().len(), 2);
3603    /// # assert_eq!(graph.demes()[0].end_time(), 20.0);
3604    /// # assert_eq!(graph.demes()[1].end_time(), 0.0);
3605    /// let when = demes::Time::try_from(10.0).unwrap();
3606    /// let sliced = graph.slice_until(when).unwrap();
3607    /// assert_eq!(sliced.demes().len(), 2);
3608    /// assert_eq!(sliced.demes()[0].end_time(), 20.0);
3609    /// assert_eq!(sliced.demes()[0].start_time(), f64::INFINITY);
3610    /// assert_eq!(sliced.demes()[0].name(), "ancestor");
3611    /// assert_eq!(sliced.demes()[1].end_time(), 10.0);
3612    /// assert_eq!(sliced.demes()[1].start_time(), 20.0);
3613    /// assert_eq!(sliced.demes()[1].name(), "derived");
3614    /// ```
3615    pub fn slice_until(self, when: Time) -> Result<Self, DemesError> {
3616        crate::graph_operations::slice::slice_until(self, when)
3617    }
3618
3619    /// Remove ancient history from a [Graph].
3620    ///
3621    /// For a given value of `when`, a new graph is created with only
3622    /// history from `[0, when)` retained.
3623    ///
3624    /// # Examples
3625    ///
3626    /// ```
3627    /// let yaml = "
3628    /// time_units: generations
3629    /// demes:
3630    ///  - start_time: .inf
3631    ///    name: ancestor
3632    ///    epochs:
3633    ///     - start_size: 100
3634    ///       end_time: 20
3635    ///  - name: derived
3636    ///    start_time: 20
3637    ///    ancestors: [ancestor]
3638    ///    proportions: [1.0]
3639    ///    epochs:
3640    ///     - start_size: 50
3641    /// ";
3642    /// let graph = demes::loads(yaml).unwrap();
3643    /// assert_eq!(graph.demes().len(), 2);
3644    /// assert_eq!(graph.demes()[0].end_time(), 20.0);
3645    /// assert_eq!(graph.demes()[1].end_time(), 0.0);
3646    /// let when = demes::Time::try_from(20.0).unwrap();
3647    /// let sliced = graph.slice_after(when).unwrap();
3648    /// assert_eq!(sliced.demes().len(), 1);
3649    /// assert_eq!(sliced.demes()[0].end_time(), 0.0);
3650    /// assert_eq!(sliced.demes()[0].start_time(), f64::INFINITY);
3651    /// assert_eq!(sliced.demes()[0].name(), "derived");
3652    /// ```
3653    ///
3654    /// If `when` is within an epoch, we insert an extra epoch
3655    /// that has constant size until infinity in the past.
3656    /// The size is the epoch's size at `when`.
3657    /// Let's look at an example involving population growth.
3658    ///
3659    /// ```
3660    /// let yaml = "
3661    /// time_units: generations
3662    /// demes:
3663    ///  - name: growing
3664    ///    epochs:
3665    ///     - start_size: 100
3666    ///       end_time: 100
3667    ///     - start_size: 100
3668    ///       end_size: 200
3669    /// ";
3670    /// let graph = demes::loads(yaml).unwrap();
3671    /// let when = demes::Time::try_from(50.).unwrap();
3672    /// let sliced = graph.clone().slice_after(when).unwrap();
3673    /// let deme = &sliced.deme(0);
3674    /// assert_eq!(deme.num_epochs(), 2);
3675    /// let e = deme.epochs()[0];
3676    /// assert_eq!(e.start_time(), f64::INFINITY);
3677    /// assert_eq!(e.end_time(), when);
3678    /// assert_eq!(e.start_size(), graph.deme(0).size_at(when).unwrap().unwrap());
3679    /// assert_eq!(e.start_size(), e.end_size());
3680    /// let e = deme.epochs()[1];
3681    /// assert_eq!(e.start_time(), when);
3682    /// assert_eq!(e.end_time(), 0.0);
3683    /// assert_eq!(e.start_size(), graph.deme(0).size_at(when).unwrap().unwrap());
3684    /// assert_eq!(e.end_size(), graph.deme(0).end_size());
3685    /// ```
3686    pub fn slice_after(self, when: Time) -> Result<Self, DemesError> {
3687        crate::graph_operations::slice::slice_after(self, when)
3688    }
3689
3690    /// Obtain a deme index from a deme name
3691    ///
3692    /// # Parameters
3693    ///
3694    /// * `name` - the deme name
3695    ///
3696    /// # Returns
3697    ///
3698    /// * `Some(index)` if `name` is a valid deme name
3699    /// * `None` Otherwise
3700    ///
3701    /// # Complexity
3702    ///
3703    /// * O(1)
3704    ///
3705    /// # Examples
3706    ///
3707    /// ```
3708    /// let yaml = "
3709    ///  time_units: generations
3710    ///  demes:
3711    ///   - name: ancestor1
3712    ///     epochs:
3713    ///      - start_size: 50
3714    ///        end_time: 20
3715    ///   - name: ancestor2
3716    ///     epochs:
3717    ///      - start_size: 50
3718    ///        end_time: 20
3719    ///   - name: derived
3720    ///     ancestors: [ancestor1, ancestor2]
3721    ///     proportions: [0.5, 0.5]
3722    ///     start_time: 20
3723    ///     epochs:
3724    ///      - start_size: 50
3725    /// ";
3726    /// let graph = demes::loads(yaml).unwrap();
3727    /// for (i, d) in graph.demes().iter().enumerate() {
3728    ///     assert_eq!(graph.deme_index(d.name()), Some(i))
3729    /// }
3730    /// ```
3731    pub fn deme_index<S: AsRef<str>>(&self, name: S) -> Option<usize> {
3732        self.deme_map.get(name.as_ref()).cloned()
3733    }
3734
3735    /// Obtain the ancestry proportions for a deme at a given time.
3736    ///
3737    /// # Parameters
3738    ///
3739    /// * `deme` - the "focal" deme whose ancestry proportions will be calculated.
3740    /// * `at` - the [Time] at which to calculate ancestry proportions.
3741    ///
3742    /// # Returns
3743    ///
3744    /// * A boxed slice if `deme` exists in the graph and time `at` is a sensible
3745    ///   parental time point
3746    /// * None, otherwise
3747    ///
3748    /// # Details
3749    ///
3750    /// * `at` is treated as a time immediately before individuals are born
3751    ///   into `deme`.
3752    /// * This function allocates memory.
3753    ///   See [Graph::fill_ancestry_proportions] for a function to reuse
3754    ///   existing allocations.
3755    ///
3756    /// # Complexity
3757    ///
3758    /// * See [Graph::fill_ancestry_proportions].
3759    ///
3760    /// # Panics
3761    ///
3762    /// * See [Graph::fill_ancestry_proportions].
3763    pub fn ancestry_proportions<'d, I: Into<DemeId<'d>>>(
3764        &self,
3765        deme: I,
3766        at: Time,
3767    ) -> Option<Box<[f64]>> {
3768        let mut rv = vec![0.0; self.num_demes()];
3769        self.fill_ancestry_proportions(deme, at, &mut rv)
3770            .map(|_| rv.into_boxed_slice())
3771    }
3772
3773    /// Obtain the ancestry proportions for a deme at a given time.
3774    ///
3775    /// # Parameters
3776    ///
3777    /// * `deme` - the "focal" deme whose ancestry proportions will be calculated.
3778    /// * `at` - the [Time] at which to calculate ancestry proportions.
3779    /// * `buffer` -  output location for the ancestry proportions.
3780    ///   The buffer length must be at least the number of demes in
3781    ///   the graph. (See [Graph::num_demes].)
3782    ///
3783    /// # Returns
3784    ///
3785    /// * A unit type if `deme` exists in the graph and time `at` is a sensible
3786    ///   parental time point
3787    /// * None, otherwise
3788    /// # Details
3789    ///
3790    /// `at` is treated as a time immediately before individuals are born
3791    /// into `deme`.
3792    ///
3793    /// # Complexity
3794    ///
3795    /// * Linear in the number of ancestor demes, pulses, and migration events
3796    ///   affecting ancestry at time `at`.
3797    ///
3798    /// # Panics
3799    ///
3800    /// * if any calculations result in invalid values (x < 0, x > 1, x not finite).
3801    /// * if `buffer` does not contain sufficient space for the output
3802    pub fn fill_ancestry_proportions<'d, I: Into<DemeId<'d>>>(
3803        &self,
3804        deme: I,
3805        at: Time,
3806        buffer: &mut [f64],
3807    ) -> Option<()> {
3808        let deme = self.get_deme(deme)?;
3809        if at <= deme.start_time() && at > deme.end_time() {
3810            self.fill_ancestry_proportions_details(deme, at, buffer);
3811            Some(())
3812        } else {
3813            None
3814        }
3815    }
3816
3817    fn fill_ancestry_proportions_details(&self, deme: &Deme, at: Time, buffer: &mut [f64]) {
3818        let deme_index = self.deme_map[deme.name()];
3819        if at <= deme.start_time() && at > deme.end_time() {
3820            buffer.fill_with(|| 0.0);
3821            if at == deme.start_time() {
3822                for (a, p) in deme
3823                    .ancestor_indexes()
3824                    .iter()
3825                    .cloned()
3826                    .zip(deme.proportions().iter().cloned())
3827                {
3828                    buffer[a] += f64::from(p);
3829                }
3830            } else {
3831                buffer[deme_index] = 1.0;
3832            }
3833            for pulse in self
3834                .pulses()
3835                .iter()
3836                .filter(|&p| p.time() == at && p.dest() == deme.name())
3837            {
3838                let sum_pulses = pulse
3839                    .proportions()
3840                    .iter()
3841                    .fold(0.0, |sum, &p| sum + f64::from(p));
3842                buffer.iter_mut().for_each(|v| *v *= 1. - sum_pulses);
3843                for (source, proportion) in pulse
3844                    .sources()
3845                    .iter()
3846                    .zip(pulse.proportions().iter().cloned().map(f64::from))
3847                {
3848                    let source_index = self.deme_map[source];
3849                    buffer[source_index] += proportion;
3850                }
3851            }
3852            let input_migrations = self
3853                .migrations()
3854                .iter()
3855                .filter(|m| at <= m.start_time() && at > m.end_time() && m.dest() == deme.name())
3856                .collect::<Vec<_>>();
3857            let sum_migrates = input_migrations
3858                .iter()
3859                .fold(0.0, |sum, m| sum + f64::from(m.rate()));
3860            buffer.iter_mut().for_each(|v| *v *= 1. - sum_migrates);
3861            for i in input_migrations {
3862                let source = self.deme_map[i.source()];
3863                buffer[source] += f64::from(i.rate())
3864            }
3865            assert!(
3866                buffer
3867                    .iter()
3868                    .all(|p| p.is_finite() && (0.0..=1.0).contains(p)),
3869                "invalid value in {buffer:?}"
3870            );
3871        }
3872    }
3873
3874    #[allow(missing_docs)]
3875    pub fn ancestry_proportions_matrix(&self, at: Time) -> Result<Box<[f64]>, DemesError> {
3876        let mut buffer = vec![0.; self.num_demes() * self.num_demes()];
3877        self.fill_ancestry_proportions_matrix(at, &mut buffer)
3878            .map(|_| buffer.into_boxed_slice())
3879    }
3880
3881    #[allow(missing_docs)]
3882    pub fn fill_ancestry_proportions_matrix(
3883        &self,
3884        at: Time,
3885        buffer: &mut [f64],
3886    ) -> Result<(), DemesError> {
3887        if at == 0.0 {
3888            return Err(DemesError::ValueError(format!(
3889                "time must be > 0.0, got {at:?}"
3890            )));
3891        }
3892        buffer.fill_with(|| 0.);
3893        for (deme_index, deme) in self.demes().iter().enumerate() {
3894            if at == deme.start_time() {
3895                for (a, p) in deme
3896                    .ancestor_indexes()
3897                    .iter()
3898                    .cloned()
3899                    .zip(deme.proportions().iter().cloned())
3900                {
3901                    buffer[deme_index * self.num_demes() + a] += f64::from(p);
3902                }
3903            } else if at > deme.end_time() && at <= deme.start_time() {
3904                buffer[deme_index * self.num_demes() + deme_index] = 1.0;
3905            }
3906        }
3907        let mut temp = vec![0.0; self.num_demes() * self.num_demes()];
3908        let pulses = self
3909            .pulses()
3910            .iter()
3911            .filter(|&p| p.time() == at)
3912            .collect::<Vec<_>>();
3913        if !pulses.is_empty() {
3914            for (deme_index, deme) in self.demes().iter().enumerate() {
3915                for &pulse in pulses.iter().filter(|p| p.dest() == deme.name()) {
3916                    for (a, p) in pulse
3917                        .sources()
3918                        .iter()
3919                        .zip(pulse.proportions().iter().cloned())
3920                    {
3921                        let source_index = self.deme_index(a).unwrap();
3922                        temp[deme_index * self.num_demes() + source_index] = p.into();
3923                    }
3924                }
3925            }
3926            for (i, t) in temp.chunks_exact(self.num_demes()).enumerate() {
3927                let sum = t.iter().sum::<f64>();
3928                for (j, tt) in buffer
3929                    .iter_mut()
3930                    .skip(i * self.num_demes())
3931                    .take(self.num_demes())
3932                    .zip(t.iter().cloned())
3933                {
3934                    *j *= 1. - sum;
3935                    *j += tt;
3936                }
3937            }
3938        }
3939        let input_migrations = self
3940            .migrations()
3941            .iter()
3942            .filter(|m| at <= m.start_time() && at > m.end_time())
3943            .collect::<Vec<_>>();
3944        if !input_migrations.is_empty() {
3945            temp.fill_with(|| 0.);
3946            for (deme_index, deme) in self.demes().iter().enumerate() {
3947                for &m in input_migrations.iter().filter(|p| p.dest() == deme.name()) {
3948                    let source_index = self.deme_index(m.source()).unwrap();
3949                    temp[deme_index * self.num_demes() + source_index] = m.rate().into();
3950                }
3951            }
3952            for (i, t) in temp.chunks_exact(self.num_demes()).enumerate() {
3953                let sum = t.iter().sum::<f64>();
3954                for (j, tt) in buffer
3955                    .iter_mut()
3956                    .skip(i * self.num_demes())
3957                    .take(self.num_demes())
3958                    .zip(t.iter().cloned())
3959                {
3960                    *j *= 1. - sum;
3961                    *j += tt;
3962                }
3963            }
3964        }
3965        Ok(())
3966    }
3967}
3968
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_size_function() {
        let exponential: SizeFunction = serde_yaml::from_str("---\nexponential\n").unwrap();
        assert!(matches!(exponential, SizeFunction::Exponential));

        let constant: SizeFunction = serde_yaml::from_str("---\nconstant\n").unwrap();
        assert!(matches!(constant, SizeFunction::Constant));
    }

    #[test]
    fn test_display() {
        let one = Time::try_from(1.0).unwrap();
        assert_eq!(one.to_string(), String::from("1"));
    }

    #[test]
    #[should_panic]
    fn test_time_validity() {
        // NaN is never a valid time, so the conversion must fail.
        let _nan_time = Time::try_from(f64::NAN).unwrap();
    }

    #[test]
    fn test_newtype_compare_to_f64() {
        // Every newtype must compare equal to the raw f64 with the
        // operands in either order.
        macro_rules! check_both_ways {
            ($newtype:ty, $raw:expr) => {{
                let wrapped = <$newtype>::try_from($raw).unwrap();
                assert_eq!(wrapped, $raw);
                assert_eq!($raw, wrapped);
                wrapped
            }};
        }

        let time = check_both_ways!(Time, 100.0);
        assert!(time > 50.0);
        assert!(50.0 < time);

        check_both_ways!(DemeSize, 100.0);
        check_both_ways!(SelfingRate, 1.0);
        check_both_ways!(CloningRate, 1.0);
        check_both_ways!(Proportion, 1.0);
        check_both_ways!(MigrationRate, 1.0);
    }
}
4038
#[cfg(test)]
mod test_graph {
    use super::*;

    /// Parse and resolve `yaml`, panicking on any error.
    fn resolve(yaml: &str) -> Graph {
        Graph::new_resolved_from_str(yaml).unwrap()
    }

    /// Assert that migration number `index` of `graph` goes from `source`
    /// to `dest` at the given `rate`.
    fn assert_migration(graph: &Graph, index: usize, source: &str, dest: &str, rate: f64) {
        let migration = &graph.migrations()[index];
        assert_eq!(migration.source(), source);
        assert_eq!(migration.dest(), dest);
        assert_eq!(migration.rate(), rate);
    }

    #[test]
    fn test_round_trip_with_default_epoch_sizes() {
        let yaml = "
time_units: generations
defaults:
  epoch:
    start_size: 1000
demes:
  - name: A
";
        let graph = resolve(yaml);
        assert_eq!(graph.num_demes(), 1);

        // Defaults live only in the HDM, not the MDM, so serializing the
        // resolved Graph must not emit a defaults block.
        let written = serde_yaml::to_string(&graph).unwrap();
        assert!(!written.contains("defaults:"));

        let _ = resolve(&written);
    }

    #[test]
    fn custom_time_unit_serialization() {
        let yaml = "
time_units: years
generation_time: 25
defaults:
  epoch:
    start_size: 1000
demes:
  - name: A
";
        let graph = resolve(yaml);
        assert_eq!(graph.num_demes(), 1);

        let written = serde_yaml::to_string(&graph).unwrap();
        let _ = resolve(&written);
    }

    #[test]
    fn deserialize_migration_defaults() {
        let yaml = "
time_units: years
generation_time: 25
defaults:
  migration:
    rate: 0.25
    source: A
    dest: B
demes:
  - name: A
    epochs: 
     - start_size: 100
  - name: B
    epochs:
     - start_size: 42
";
        let graph = resolve(yaml);
        assert_eq!(graph.migrations().len(), 1);
        assert_migration(&graph, 0, "A", "B", 0.25);
    }

    #[test]
    fn deserialize_migration_defaults_rate_only() {
        let yaml = "
time_units: years
generation_time: 25
defaults:
  migration:
    rate: 0.25
demes:
  - name: A
    epochs: 
     - start_size: 100
  - name: B
    epochs:
     - start_size: 42
migrations:
  - source: A
    dest: B
";
        let graph = resolve(yaml);
        assert_eq!(graph.migrations().len(), 1);
        assert_migration(&graph, 0, "A", "B", 0.25);
    }

    #[test]
    fn deserialize_migration_defaults_source_only() {
        let yaml = "
time_units: years
generation_time: 25
defaults:
  migration:
    source: A
demes:
  - name: A
    epochs: 
     - start_size: 100
  - name: B
    epochs:
     - start_size: 42
migrations:
  - dest: B
    rate: 0.25
";
        let graph = resolve(yaml);
        assert_eq!(graph.migrations().len(), 1);
        assert_migration(&graph, 0, "A", "B", 0.25);
    }

    #[test]
    fn deserialize_migration_defaults_dest_only() {
        let yaml = "
time_units: years
generation_time: 25
defaults:
  migration:
    dest: B
demes:
  - name: A
    epochs: 
     - start_size: 100
  - name: B
    epochs:
     - start_size: 42
migrations:
  - source: A
    rate: 0.25
";
        let graph = resolve(yaml);
        assert_eq!(graph.migrations().len(), 1);
        assert_migration(&graph, 0, "A", "B", 0.25);
    }

    #[test]
    fn deserialize_migration_defaults_symmetric() {
        let yaml = "
time_units: years
generation_time: 25
defaults:
  migration:
    rate: 0.25
    demes: [A, B]
demes:
  - name: A
    epochs: 
     - start_size: 100
  - name: B
    epochs:
     - start_size: 42
";
        let graph = resolve(yaml);
        assert_eq!(graph.migrations().len(), 2);
        assert_migration(&graph, 0, "A", "B", 0.25);
        assert_migration(&graph, 1, "B", "A", 0.25);
    }

    #[test]
    fn deserialize_migration_defaults_symmetric_swap_deme_order() {
        let yaml = "
time_units: years
description: same tests as above, but demes in different order in migration defaults
generation_time: 25
defaults:
  migration:
    rate: 0.25
    demes: [B, A]
demes:
  - name: A
    epochs: 
     - start_size: 100
  - name: B
    epochs:
     - start_size: 42
";
        let graph = resolve(yaml);
        assert_eq!(graph.migrations().len(), 2);
        assert_migration(&graph, 0, "B", "A", 0.25);
        assert_migration(&graph, 1, "A", "B", 0.25);
    }

    #[test]
    fn deserialize_pulse_defaults() {
        let yaml = "
time_units: years
generation_time: 25
defaults:
  pulse: {sources: [A], dest: B, proportions: [0.25], time: 100}
demes:
  - name: A
    epochs: 
     - start_size: 100
  - name: B
    epochs:
     - start_size: 250 
";
        let graph = resolve(yaml);
        assert_eq!(graph.pulses().len(), 1);
        let pulse = &graph.pulses()[0];
        assert_eq!(pulse.sources(), vec!["A".to_string()]);
        assert_eq!(pulse.dest(), "B");
        assert_eq!(
            pulse.proportions(),
            vec![Proportion::try_from(0.25).unwrap()]
        );
        assert_eq!(pulse.time(), 100.0);
    }
}
4267
#[cfg(test)]
mod test_to_generations {
    /// Load `yaml` and convert its time units to generations,
    /// panicking if either step fails.
    fn load_as_generations(yaml: &str) -> super::Graph {
        let graph = crate::loads(yaml).unwrap();
        graph.into_generations().unwrap()
    }

    #[test]
    fn test_raw_conversion() {
        let yaml = "
time_units: years
generation_time: 25
demes:
 - name: ancestor
   epochs:
    - start_size: 100
      end_time: 100
 - name: derived
   ancestors: [ancestor]
   epochs:
    - start_size: 100
";
        let converted = load_as_generations(yaml);
        assert_eq!(converted.deme(0).end_time(), 4.0);
        assert_eq!(converted.deme(1).start_time(), 4.0);
    }

    #[test]
    fn test_demelevel_default_epoch_conversion() {
        let yaml = "
time_units: years
generation_time: 25
demes:
 - name: ancestor
   defaults:
    epoch:
     end_time: 10
   epochs:
    - start_size: 100
 - name: derived
   ancestors: [ancestor]
   epochs:
    - start_size: 100
";
        let converted = load_as_generations(yaml);
        assert_eq!(converted.deme(0).end_time(), 0.4);
        assert_eq!(converted.deme(1).start_time(), 0.4);
    }

    #[test]
    fn test_pulse_conversion() {
        let yaml = "
time_units: years
generation_time: 25
demes:
 - name: one
   epochs:
    - start_size: 100
 - name: two
   epochs:
    - start_size: 100
pulses:
 - sources: [one]
   dest: two
   proportions: [0.25]
   time: 50
";
        let converted = load_as_generations(yaml);
        for pulse in converted.pulses() {
            assert_eq!(pulse.time(), 2.0);
        }
    }

    #[test]
    fn test_default_pulse_conversion() {
        let yaml = "
time_units: years
generation_time: 25
defaults:
 pulse:
  time: 50
demes:
 - name: one
   epochs:
    - start_size: 100
 - name: two
   epochs:
    - start_size: 100
pulses:
 - sources: [one]
   dest: two
   proportions: [0.25]
";
        let converted = load_as_generations(yaml);
        for pulse in converted.pulses() {
            assert_eq!(pulse.time(), 2.0);
        }
    }

    #[test]
    fn test_migration_conversion() {
        let yaml = "
time_units: years
generation_time: 25
demes:
 - name: one
   epochs:
    - start_size: 100
 - name: two
   epochs:
    - start_size: 100
migrations:
 - demes: [one, two]
   rate: 0.25
   start_time: 50
   end_time: 10
";
        let converted = load_as_generations(yaml);
        for migration in converted.migrations() {
            assert_eq!(migration.start_time(), 50.0 / 25.0);
            assert_eq!(migration.end_time(), 10.0 / 25.0);
        }
    }

    #[test]
    fn test_default_migration_conversion() {
        let yaml = "
time_units: years
generation_time: 25
defaults:
 migration:
  start_time: 50
  end_time: 10
demes:
 - name: one
   epochs:
    - start_size: 100
 - name: two
   epochs:
    - start_size: 100
migrations:
 - demes: [one, two]
   rate: 0.25
";
        let converted = load_as_generations(yaml);
        for migration in converted.migrations() {
            assert_eq!(migration.start_time(), 50.0 / 25.0);
            assert_eq!(migration.end_time(), 10.0 / 25.0);
        }
    }

    #[test]
    fn test_toplevel_default_epoch_conversion() {
        let yaml = "
time_units: years
generation_time: 25
defaults:
 deme:
   start_time: 100
demes:
 - name: ancestor
   start_time: .inf
   epochs:
    - start_size: 100
      end_time: 10
 - name: derived
   ancestors: [ancestor]
   epochs:
    - start_size: 100
";
        let converted = load_as_generations(yaml);
        assert!(matches!(
            converted.time_units(),
            super::TimeUnits::Generations
        ));
        assert_eq!(converted.deme(0).end_time(), 10.0 / 25.0);
        assert_eq!(converted.deme(1).start_time(), 100.0 / 25.0);
    }

    #[test]
    #[should_panic]
    fn test_raw_conversion_to_zero_length_epoch() {
        let yaml = "
time_units: years
generation_time: 1e300
demes:
 - name: ancestor
   epochs:
    - start_size: 100
      end_time: 1e-200
 - name: derived
   ancestors: [ancestor]
   epochs:
    - start_size: 100
";
        // Converting collapses the ancestor's epoch to zero length.
        let _ = load_as_generations(yaml);
    }
}
4483
#[cfg(test)]
mod test_to_integer_generations {
    /// A deme-level default `end_time` in years is converted to generations
    /// and rounded. The rounded graph must also survive a YAML round trip
    /// unchanged.
    #[test]
    fn test_demelevel_default_epoch_conversion() {
        let input = "
time_units: years
generation_time: 25
demes:
 - name: ancestor
   defaults:
    epoch:
     end_time: 103
   epochs:
    - start_size: 100
 - name: derived
   ancestors: [ancestor]
   epochs:
    - start_size: 100
";
        let rounded = crate::loads(input)
            .unwrap()
            .into_integer_generations()
            .unwrap();
        let expected_time = (103_f64 / 25.0).round();
        assert_eq!(rounded.deme(0).end_time(), expected_time);
        assert_eq!(rounded.deme(1).start_time(), expected_time);

        // Serialize the rounded graph and load it back: nothing may change.
        let serialized = serde_yaml::to_string(&rounded).unwrap();
        let reloaded = crate::loads(&serialized).unwrap();
        assert_eq!(rounded, reloaded);
    }

    /// 10 years / 25 = 0.4 generations, which rounds to 0. Per the graph's
    /// own description, this yields an epoch of length zero and conversion
    /// must fail.
    #[test]
    #[should_panic]
    fn test_conversion_to_zero_length_epoch() {
        let input = "
time_units: years
description: rounding results in epochs of length zero
generation_time: 25
demes:
 - name: ancestor
   epochs:
    - start_size: 100
      end_time: 10
 - name: derived
   ancestors: [ancestor]
   epochs:
    - start_size: 100
";
        let graph = crate::loads(input).unwrap();
        let _rounded = graph.into_integer_generations().unwrap();
    }

    /// Input already in generations but with a fractional time: rounding
    /// alone determines the result.
    #[test]
    fn test_demelevel_epoch_conversion_non_integer_input_times() {
        let input = "
time_units: generations
demes:
 - name: ancestor
   defaults:
    epoch:
     end_time: 10.6
   epochs:
    - start_size: 100
 - name: derived
   ancestors: [ancestor]
   epochs:
    - start_size: 100
";
        let rounded = crate::loads(input)
            .unwrap()
            .into_integer_generations()
            .unwrap();
        let expected_time = 10.6_f64.round();
        assert_eq!(rounded.deme(0).end_time(), expected_time);
        assert_eq!(rounded.deme(1).start_time(), expected_time);
    }

    /// 50 years / 1000 = 0.05 generations, which rounds to 0. The second
    /// epoch therefore has length zero and conversion must fail.
    #[test]
    #[should_panic]
    fn invalid_second_epoch_length_when_integer_rounded() {
        let input = "
time_units: years
description:
  50/1000 = 0.05, rounds to zero.
  Thus, the second epoch has length zero.
generation_time: 1000.0
demes:
 - name: A
   epochs:
    - start_size: 200
      end_time: 50
    - start_size: 100
";
        let graph = crate::loads(input).unwrap();
        let _rounded = graph.into_integer_generations().unwrap();
    }
}
4583
/// A deme name containing a NUL control character must be rejected.
///
/// The graph holds only that one deme, on purpose. Any deme listing
/// `ancestor` as an ancestor would be a dangling reference (the deme is
/// really named `ancestor\0`). That second, unrelated error could make this
/// `should_panic` test pass even if control characters were accepted.
#[test]
#[should_panic]
fn test_control_character_in_yaml() {
    let yaml = "
time_units: years
generation_time: 25
demes:
 - name: ancestor\0
   defaults:
    epoch:
     end_time: 103
   epochs:
    - start_size: 100
";
    let _ = Graph::new_from_str(yaml).unwrap();
}
4604
#[cfg(test)]
mod test_graph_to_unresolved_graph {
    use super::*;

    static YAML0: &str = "
time_units: generations
demes:
 - name: ancestor
   defaults:
    epoch:
     end_time: 10.6
   epochs:
    - start_size: 100
 - name: derived
   ancestors: [ancestor]
   epochs:
    - start_size: 100
";

    static YAML1: &str = "
time_units: years
generation_time: 25
defaults:
  pulse: {sources: [A], dest: B, proportions: [0.25], time: 100}
demes:
  - name: A
    epochs: 
     - start_size: 100
  - name: B
    epochs:
     - start_size: 250 
";

    static YAML2: &str = "
time_units: years
description: same tests as above, but demes in different order in migration defaults
generation_time: 25
defaults:
  migration:
    rate: 0.25
    demes: [B, A]
demes:
  - name: A
    epochs: 
     - start_size: 100
  - name: B
    epochs:
     - start_size: 42
";

    /// Load `yaml` into a resolved `Graph`, turn it back into an
    /// `UnresolvedGraph`, and resolve that again. The round trip must
    /// reproduce the original graph exactly.
    fn assert_unresolved_roundtrip(yaml: &str) {
        let original = crate::loads(yaml).unwrap();
        let unresolved = UnresolvedGraph::from(original.clone());
        let resolved = unresolved.resolve().unwrap();
        let roundtrip = Graph::try_from(resolved).unwrap();
        assert_eq!(original, roundtrip);
    }

    #[test]
    fn test_yaml0() {
        assert_unresolved_roundtrip(YAML0);
    }

    #[test]
    fn test_yaml1() {
        assert_unresolved_roundtrip(YAML1);
    }

    #[test]
    fn test_yaml2() {
        assert_unresolved_roundtrip(YAML2);
    }
}
4672
/// A graph written as TOML deserializes into an `UnresolvedGraph` with the
/// expected demes and epochs, and that graph resolves successfully.
#[test]
#[cfg(feature = "toml")]
fn test_toml() {
    let input: &str = "
        time_units = \"years\"
        description = \"a description\"
        generation_time = 25

        [defaults]
        [defaults.migration]
        rate = 0.25
        demes = [\"A\", \"B\"]

        [[demes]]
        name = \"A\"
        [[demes.epochs]]
        start_size = 100

        [[demes]]
        name = \"B\"
        [[demes.epochs]]
        start_size = 42
";
    let unresolved: UnresolvedGraph = toml::from_str(input).unwrap();
    assert_eq!(unresolved.demes.len(), 2);
    for deme in unresolved.demes.iter() {
        assert_eq!(deme.epochs.len(), 1);
    }
    let _resolved: Graph = unresolved.resolve().unwrap().try_into().unwrap();
}
4702
/// Converting `examples/jouganous.yaml` to TOML (via serde) and loading the
/// TOML must give the same `Graph` as loading the YAML directly.
#[test]
#[cfg(feature = "toml")]
fn test_roundtrip() {
    let yaml_text = std::fs::read_to_string("examples/jouganous.yaml").unwrap();

    // Reference graph, loaded straight from the YAML text.
    let expected = crate::loads(&yaml_text).unwrap();

    // Re-encode the YAML document as TOML through serde's data model.
    let as_toml_value = serde_yaml::from_str::<toml::Value>(&yaml_text).unwrap();
    let toml_text = toml::to_string(&as_toml_value).unwrap();

    let unresolved: UnresolvedGraph = toml::from_str(&toml_text).unwrap();
    let from_toml = Graph::try_from(unresolved.resolve().unwrap()).unwrap();
    assert_eq!(expected, from_toml);
}
4722
/// An epoch `end_time` of -1 is invalid and loading must fail.
#[test]
#[should_panic]
fn test_negative_epoch_end_time() {
    let input = "
 time_units: years
 generation_time: 25
 description: A deme that existed until 20 years ago.
 demes:
  - name: deme
    epochs:
     - start_size: 50
       end_time: -1
 ";
    let _graph = crate::loads(input).unwrap();
}
4738
/// An epoch `end_time` of `.inf` is invalid and loading must fail.
#[test]
#[should_panic]
fn test_infinite_epoch_end_time() {
    let input = "
 time_units: years
 generation_time: 25
 description: A deme that existed until 20 years ago.
 demes:
  - name: deme
    epochs:
     - start_size: 50
       end_time: .inf
 ";
    let _graph = crate::loads(input).unwrap();
}
4754
#[cfg(test)]
mod deme_equality {
    /// Load two graphs and require that their first demes compare unequal.
    ///
    /// A `Graph` cannot contain two demes with the same name. So to check
    /// that a difference in some other field makes demes unequal, the demes
    /// must come from two separate graphs.
    fn assert_first_demes_differ(lhs_yaml: &str, rhs_yaml: &str) {
        let lhs = crate::loads(lhs_yaml).unwrap();
        let rhs = crate::loads(rhs_yaml).unwrap();
        assert_ne!(lhs.demes()[0], rhs.demes()[0]);
    }

    #[test]
    fn test_different_names() {
        let input = "
 time_units: generations
 demes:
  - name: deme
    epochs:
     - start_size: 50
  - name: demeB
    epochs:
     - start_size: 50
       ";
        let graph = crate::loads(input).unwrap();
        let demes = graph.demes();
        // Equality is reflexive ...
        assert_eq!(demes[0], demes[0]);
        assert_eq!(demes[1], demes[1]);
        // ... and demes identical except for their name are still unequal.
        assert_ne!(demes[0], demes[1]);
    }

    #[test]
    fn test_different_epochs_start_size() {
        let size_50 = "
 time_units: generations
 demes:
  - name: deme
    epochs:
     - start_size: 50
       ";
        let size_10 = "
 time_units: generations
 description: A deme that existed until 20 years ago.
 demes:
  - name: deme
    epochs:
     - start_size: 10
       ";
        assert_first_demes_differ(size_50, size_10);
    }

    #[test]
    fn test_different_epochs_end_size() {
        let ends_at_100 = "
 time_units: generations
 demes:
  - name: deme
    epochs:
     - start_size: 50
       end_time: 20
     - end_size: 100
       ";
        let ends_at_500 = "
 time_units: generations
 description: A deme that existed until 20 years ago.
 demes:
  - name: deme
    epochs:
     - start_size: 50
       end_time: 20
     - end_size: 500
       ";
        assert_first_demes_differ(ends_at_100, ends_at_500);
    }

    #[test]
    fn test_different_epochs_growth_function() {
        let default_growth = "
 time_units: generations
 demes:
  - name: deme
    epochs:
     - start_size: 50
       end_time: 20
     - end_size: 100
       ";
        let linear_growth = "
 time_units: generations
 description: A deme that existed until 20 years ago.
 demes:
  - name: deme
    epochs:
     - start_size: 50
       end_time: 20
     - end_size: 100
       size_function: linear
       ";
        assert_first_demes_differ(default_growth, linear_growth);
    }

    #[test]
    fn test_different_descriptions() {
        let described_yes = "
 time_units: generations
 demes:
  - name: deme
    description: yes
    epochs:
     - start_size: 50
       ";
        let described_no = "
 time_units: generations
 description: A deme that existed until 20 years ago.
 demes:
  - name: deme
    description: no
    epochs:
     - start_size: 50
       ";
        assert_first_demes_differ(described_yes, described_no);
    }

    #[test]
    fn test_different_ancestor_proportions() {
        let uneven_mix = "
 time_units: generations
 demes:
  - name: ancestor1
    epochs:
     - start_size: 50
       end_time: 20
  - name: ancestor2
    epochs:
     - start_size: 50
       end_time: 20
  - name: derived
    ancestors: [ancestor1, ancestor2]
    proportions: [0.75, 0.25]
    start_time: 20
    epochs:
     - start_size: 50
";
        let even_mix = "
 time_units: generations
 demes:
  - name: ancestor1
    epochs:
     - start_size: 50
       end_time: 20
  - name: ancestor2
    epochs:
     - start_size: 50
       end_time: 20
  - name: derived
    ancestors: [ancestor1, ancestor2]
    proportions: [0.5, 0.5]
    start_time: 20
    epochs:
     - start_size: 50
";
        let uneven = crate::loads(uneven_mix).unwrap();
        let even = crate::loads(even_mix).unwrap();
        let (lhs, rhs) = (uneven.demes(), even.demes());
        // The two ancestors are identical in both graphs; only the admixed
        // `derived` deme differs.
        assert_eq!(lhs[0], rhs[0]);
        assert_eq!(lhs[1], rhs[1]);
        assert_ne!(lhs[2], rhs[2]);
    }
}
4925
#[cfg(test)]
mod test_rescaling {
    pub static SIMPLE_TEST_GRAPH_0: &str = "
 time_units: generations
 demes:
  - name: ancestor1
    epochs:
     - start_size: 50
       end_time: 20
  - name: ancestor2
    epochs:
     - start_size: 50
       end_time: 20
  - name: derived
    ancestors: [ancestor1, ancestor2]
    proportions: [0.5, 0.5]
    start_time: 20
    epochs:
     - start_size: 50
";

    pub static SIMPLE_TEST_GRAPH_1: &str = "
 time_units: generations
 demes:
  - name: ancestor1
    epochs:
     - start_size: 50
       end_time: 20
  - name: ancestor2
    epochs:
     - start_size: 50
       end_time: 20
  - name: derived1
    ancestors: [ancestor1]
    proportions: [1.0]
    start_time: 20
    epochs:
     - start_size: 50
  - name: derived2
    ancestors: [ancestor2]
    proportions: [1.0]
    start_time: 20
    epochs:
     - start_size: 50
 migrations:
  - demes: [derived1, derived2]
    start_time: 20
    rate: 0.25
";

    pub static SIMPLE_TEST_GRAPH_2: &str = "
 time_units: generations
 demes:
  - name: ancestor1
    epochs:
     - start_size: 50
       end_time: 20
  - name: ancestor2
    epochs:
     - start_size: 50
       end_time: 20
  - name: derived1
    ancestors: [ancestor1]
    proportions: [1.0]
    start_time: 20
    epochs:
     - start_size: 50
  - name: derived2
    ancestors: [ancestor2]
    proportions: [1.0]
    start_time: 20
    epochs:
     - start_size: 50
 pulses:
  - sources: [derived1]
    dest: derived2
    time: 19
    proportions: [0.25]
";

    /// Load `yaml`, rescale a copy by `scaling_factor`, and check every time,
    /// size and migration rate of the copy against the original.
    ///
    /// Returns `Err` describing the first mismatch, or the `Debug` form of
    /// the error if `rescale` itself fails.
    fn run_test(yaml: &str, scaling_factor: f64) -> Result<(), String> {
        let graph = crate::loads(yaml).unwrap();
        let rescaled = graph
            .clone()
            .rescale(scaling_factor)
            .map_err(|e| format!("{e:?}"))?;
        compare_graphs(&graph, &rescaled, scaling_factor)
    }

    /// Fail if two sequences compared element-by-element differ in length.
    /// Without this, `Iterator::zip` would silently skip the surplus items.
    fn compare_len(len_i: usize, len_j: usize, what: &str) -> Result<(), String> {
        if len_i != len_j {
            return Err(format!(
                "{what}: input has {len_i}, rescaled graph has {len_j}"
            ));
        }
        Ok(())
    }

    /// Check that `a == b * scaling_factor`, i.e. that the rescaled time `b`
    /// is the input time `a` divided by the factor.
    ///
    /// Panics if the multiplication fails or the comparison is unordered.
    fn compare_time(
        a: crate::Time,
        b: crate::Time,
        scaling_factor: f64,
        prefix: &str,
    ) -> Result<(), String> {
        if !matches!(
            a.partial_cmp(&(b * scaling_factor).unwrap()).unwrap(),
            std::cmp::Ordering::Equal
        ) {
            return Err(format!(
                "{prefix} {:?}/{scaling_factor} should equal {:?}",
                a, b
            ));
        }

        Ok(())
    }

    /// Check that `a == b * scaling_factor` for deme sizes. This is the same
    /// check as `compare_time`: sizes are divided by the factor too.
    ///
    /// Panics if the multiplication fails or the comparison is unordered.
    fn compare_size(
        a: crate::DemeSize,
        b: crate::DemeSize,
        scaling_factor: f64,
        prefix: &str,
    ) -> Result<(), String> {
        if !matches!(
            a.partial_cmp(&(b * scaling_factor).unwrap()).unwrap(),
            std::cmp::Ordering::Equal
        ) {
            return Err(format!(
                "{prefix} {:?}/{scaling_factor} should equal {:?}",
                a, b
            ));
        }

        Ok(())
    }

    /// Compare corresponding epochs. Start/end times and start/end sizes must
    /// all have been divided by `scaling_factor`.
    fn compare_epochs(
        epochs_i: &[crate::Epoch],
        epochs_j: &[crate::Epoch],
        scaling_factor: f64,
    ) -> Result<(), String> {
        compare_len(epochs_i.len(), epochs_j.len(), "number of epochs")?;
        for (epoch_i, epoch_j) in epochs_i.iter().zip(epochs_j.iter()) {
            compare_time(
                epoch_i.start_time(),
                epoch_j.start_time(),
                scaling_factor,
                "epoch start time",
            )?;
            compare_time(
                epoch_i.end_time(),
                epoch_j.end_time(),
                scaling_factor,
                "epoch end time",
            )?;
            compare_size(
                epoch_i.start_size(),
                epoch_j.start_size(),
                scaling_factor,
                "epoch start size",
            )?;
            compare_size(
                epoch_i.end_size(),
                epoch_j.end_size(),
                scaling_factor,
                "epoch end size",
            )?;
        }
        Ok(())
    }

    /// Compare corresponding demes: their start times and all of their epochs.
    fn compare_demes(
        demes_i: &[crate::Deme],
        demes_j: &[crate::Deme],
        scaling_factor: f64,
    ) -> Result<(), String> {
        compare_len(demes_i.len(), demes_j.len(), "number of demes")?;
        for (deme_i, deme_j) in demes_i.iter().zip(demes_j.iter()) {
            compare_time(
                deme_i.start_time(),
                deme_j.start_time(),
                scaling_factor,
                "deme start time",
            )?;

            compare_epochs(deme_i.epochs(), deme_j.epochs(), scaling_factor)?;
        }
        Ok(())
    }

    /// Compare corresponding pulses by their times.
    fn compare_pulses(
        pulses_i: &[crate::Pulse],
        pulses_j: &[crate::Pulse],
        scaling_factor: f64,
    ) -> Result<(), String> {
        compare_len(pulses_i.len(), pulses_j.len(), "number of pulses")?;
        for (pulse_i, pulse_j) in pulses_i.iter().zip(pulses_j.iter()) {
            compare_time(pulse_i.time(), pulse_j.time(), scaling_factor, "pulse time")?;
        }
        Ok(())
    }

    /// Compare corresponding migrations. Times are divided by the factor,
    /// but rates are *multiplied* by it, so the rate check is
    /// `rate_i == rate_j / scaling_factor`.
    fn compare_migrations(
        migrations_i: &[crate::AsymmetricMigration],
        migrations_j: &[crate::AsymmetricMigration],
        scaling_factor: f64,
    ) -> Result<(), String> {
        compare_len(
            migrations_i.len(),
            migrations_j.len(),
            "number of migrations",
        )?;
        for (mig_i, mig_j) in migrations_i.iter().zip(migrations_j.iter()) {
            compare_time(
                mig_i.start_time(),
                mig_j.start_time(),
                scaling_factor,
                "migration start time",
            )?;
            compare_time(
                mig_i.end_time(),
                mig_j.end_time(),
                scaling_factor,
                "migration end time",
            )?;
            if !matches!(
                mig_i
                    .rate()
                    .partial_cmp(&(mig_j.rate() / scaling_factor).unwrap())
                    .unwrap(),
                std::cmp::Ordering::Equal,
            ) {
                return Err(format!(
                    "migration rate {:?}*{scaling_factor} should equal {:?}",
                    mig_i.rate(),
                    mig_j.rate()
                ));
            }
        }
        Ok(())
    }

    /// Compare an input graph with its rescaled copy: demes, pulses, and
    /// migrations.
    fn compare_graphs(
        input: &crate::Graph,
        rescaled: &crate::Graph,
        scaling_factor: f64,
    ) -> Result<(), String> {
        compare_demes(input.demes(), rescaled.demes(), scaling_factor)?;
        compare_pulses(input.pulses(), rescaled.pulses(), scaling_factor)?;
        compare_migrations(input.migrations(), rescaled.migrations(), scaling_factor)?;
        Ok(())
    }

    #[test]
    fn test_rescaling() {
        run_test(SIMPLE_TEST_GRAPH_0, 10.).unwrap()
    }

    #[test]
    fn test_rescaling1() {
        run_test(SIMPLE_TEST_GRAPH_1, 0.1).unwrap()
    }

    #[test]
    fn test_rescaling2() {
        run_test(SIMPLE_TEST_GRAPH_2, 10.).unwrap()
    }

    #[test]
    fn test_rescaling1_bad_scale_factor() {
        // Will result in migration rates > 1, which is an error at resolution time
        assert!(run_test(SIMPLE_TEST_GRAPH_1, 1000.0).is_err())
    }

    #[test]
    fn test_rescaling_bad_scaling_factors() {
        for bad in [-1.0, f64::INFINITY, 0.0, f64::NAN] {
            let graph = crate::loads(SIMPLE_TEST_GRAPH_0).unwrap();
            assert!(
                matches!(graph.rescale(bad), Err(crate::DemesError::ValueError(_))),
                "scaling factor {bad} should be rejected with a ValueError"
            );
        }
    }

    /// Cloning rates are per-generation probabilities and must not change
    /// under rescaling.
    #[test]
    fn test_rescaling_cloning_rates() {
        let yaml = "
 time_units: generations
 demes:
  - name: ancestor1
    epochs:
     - start_size: 50
       end_time: 20
       cloning_rate: 0.5
";
        let graph = crate::loads(yaml).unwrap();
        let rescaled = graph.clone().rescale(10.).unwrap();
        for (di, dj) in graph.demes().iter().zip(rescaled.demes().iter()) {
            for (ei, ej) in di.epochs().iter().zip(dj.epochs().iter()) {
                assert_eq!(ei.cloning_rate(), ej.cloning_rate());
            }
        }
    }

    /// Selfing rates must likewise be left untouched by rescaling.
    #[test]
    fn test_rescaling_selfing_rates() {
        let yaml = "
 time_units: generations
 demes:
  - name: ancestor1
    epochs:
     - start_size: 50
       end_time: 20
       selfing_rate: 0.5
";
        let graph = crate::loads(yaml).unwrap();
        let rescaled = graph.clone().rescale(10.).unwrap();
        for (di, dj) in graph.demes().iter().zip(rescaled.demes().iter()) {
            for (ei, ej) in di.epochs().iter().zip(dj.epochs().iter()) {
                assert_eq!(ei.selfing_rate(), ej.selfing_rate());
            }
        }
    }
}
5239
#[cfg(test)]
mod test_forward_ancestry_proportions {
    use super::test_rescaling::{SIMPLE_TEST_GRAPH_1, SIMPLE_TEST_GRAPH_2};

    /// Build a `crate::Time` from a literal, panicking if it is invalid.
    fn time(value: f64) -> crate::Time {
        crate::Time::try_from(value).unwrap()
    }

    #[test]
    fn test_simple_graph() {
        let graph = crate::loads(SIMPLE_TEST_GRAPH_2).unwrap();

        // Deme index 2 (`derived1`, start_time 20) does not exist at t = 21.
        assert!(graph.ancestry_proportions(2, time(21.0)).is_none());

        // At its founding, derived1 descends entirely from ancestor1 (index 0).
        let mut proportions = graph
            .ancestry_proportions("derived1", time(20.0))
            .unwrap();
        assert_eq!(proportions[0], 1.0);
        assert!(proportions.iter().skip(1).all(|&p| p == 0.0));

        // One step later, derived1 descends entirely from itself (index 2).
        graph
            .fill_ancestry_proportions("derived1", time(19.0), &mut proportions)
            .unwrap();
        assert_eq!(proportions.as_ref(), &[0., 0., 1., 0.]);
    }

    #[test]
    fn test_migrations() {
        let graph = crate::loads(SIMPLE_TEST_GRAPH_1).unwrap();
        let proportions = graph
            .ancestry_proportions("derived2", time(19.0))
            .map(|p| p.to_vec());
        assert_eq!(proportions, Some(vec![0., 0., 0.25, 0.75]));
    }

    #[test]
    fn test_pulses() {
        let graph = crate::loads(SIMPLE_TEST_GRAPH_2).unwrap();
        let proportions = graph
            .ancestry_proportions("derived2", time(19.0))
            .map(|p| p.to_vec());
        assert_eq!(proportions, Some(vec![0., 0., 0.25, 0.75]));
    }

    #[test]
    fn test_invalid_deme_name() {
        let graph = crate::loads(SIMPLE_TEST_GRAPH_2).unwrap();
        let missing = graph.ancestry_proportions("bananas", time(19.0));
        assert!(missing.is_none());
    }

    #[test]
    fn test_invalid_deme_index() {
        let graph = crate::loads(SIMPLE_TEST_GRAPH_2).unwrap();
        let missing = graph.ancestry_proportions(usize::MAX, time(19.0));
        assert!(missing.is_none());
    }

    /// Time zero yields no answer from either the allocating or the filling API.
    #[test]
    fn test_time_0() {
        let graph = crate::loads(SIMPLE_TEST_GRAPH_2).unwrap();
        assert!(graph.ancestry_proportions(0, time(0.0)).is_none());

        let mut buffer = vec![9.; graph.num_demes()];
        let filled = graph.fill_ancestry_proportions(0, time(0.0), &mut buffer);
        assert!(filled.is_none())
    }
}
5306
#[cfg(test)]
mod test_ancestry_proportion_matrix {
    use std::ops::Deref;

    /// Split a flattened, row-major `n x n` ancestry matrix into its rows.
    ///
    /// Panics if `n` is zero or `matrix` does not hold exactly `n * n` values.
    /// This replaces hand-written `i * n..(i + 1) * n` slicing and catches a
    /// wrongly sized matrix up front.
    fn rows(matrix: &[f64], n: usize) -> Vec<&[f64]> {
        assert!(n > 0, "matrix dimension must be non-zero");
        assert_eq!(matrix.len(), n * n, "expected a {n} x {n} matrix");
        matrix.chunks_exact(n).collect()
    }

    /// Check that row `i` of `matrix` equals `graph.ancestry_proportions(i, at)`
    /// whenever that single-deme query returns a value. When it returns
    /// `None` and `zero_when_absent` is set, the row must be all zeros.
    fn assert_rows_match_per_deme(
        graph: &crate::Graph,
        matrix: &[f64],
        at: crate::Time,
        zero_when_absent: bool,
    ) {
        for (deme_index, row) in rows(matrix, graph.num_demes()).into_iter().enumerate() {
            match graph.ancestry_proportions(deme_index, at) {
                Some(expected) => assert_eq!(row, expected.deref(), "deme {deme_index}"),
                None if zero_when_absent => assert!(
                    row.iter().all(|&p| p == 0.),
                    "deme {deme_index}: {row:?}"
                ),
                None => {}
            }
        }
    }

    #[test]
    fn test_simpler_graph() {
        let graph = crate::loads(super::test_rescaling::SIMPLE_TEST_GRAPH_0).unwrap();
        let n = graph.num_demes();
        let zeros = vec![0.; n];

        // t = 21: rows 0 and 1 are unit vectors; row 2 (`derived`) is all zero.
        let proportions = graph
            .ancestry_proportions_matrix(21.0.try_into().unwrap())
            .unwrap();
        let matrix = rows(&proportions, n);
        for i in [0, 1] {
            let mut expected = vec![0.; n];
            expected[i] = 1.;
            assert_eq!(matrix[i], expected.as_slice(), "row {i}");
        }
        assert_eq!(matrix[2], zeros.as_slice());

        // t = 20: the ancestor rows are zero; `derived` is a 50/50 mix.
        let proportions = graph
            .ancestry_proportions_matrix(20.0.try_into().unwrap())
            .unwrap();
        let matrix = rows(&proportions, n);
        for i in [0, 1] {
            assert_eq!(matrix[i], zeros.as_slice(), "row {i}");
        }
        assert_eq!(matrix[2], &[0.5, 0.5, 0.][..]);

        // t = 19: `derived` now descends entirely from itself.
        let proportions = graph
            .ancestry_proportions_matrix(19.0.try_into().unwrap())
            .unwrap();
        let matrix = rows(&proportions, n);
        for i in [0, 1] {
            assert_eq!(matrix[i], zeros.as_slice(), "row {i}");
        }
        assert_eq!(matrix[2], &[0., 0., 1.][..]);
    }

    #[test]
    fn test_simple_graph() {
        let graph = crate::loads(super::test_rescaling::SIMPLE_TEST_GRAPH_2).unwrap();
        let n = graph.num_demes();
        let zeros = [0., 0., 0., 0.];

        let t21: crate::Time = 21.0.try_into().unwrap();
        let proportions = graph.ancestry_proportions_matrix(t21).unwrap();
        assert_eq!(proportions.len(), 16);
        let matrix = rows(&proportions, n);
        assert_eq!(matrix[0], &[1., 0., 0., 0.]);
        assert_eq!(matrix[1], &[0., 1., 0., 0.]);
        for i in [2, 3] {
            assert_eq!(matrix[i], &zeros, "{i}");
        }
        assert_rows_match_per_deme(&graph, &proportions, t21, false);

        let t20: crate::Time = 20.0.try_into().unwrap();
        let proportions = graph.ancestry_proportions_matrix(t20).unwrap();
        let matrix = rows(&proportions, n);
        for i in [0, 1] {
            assert_eq!(matrix[i], &zeros, "{i}");
        }
        assert_eq!(matrix[2], &[1., 0., 0., 0.], "{proportions:?}");
        assert_eq!(matrix[3], &[0., 1., 0., 0.]);
        assert_rows_match_per_deme(&graph, &proportions, t20, false);

        // Refill the same buffer in place for t = 19.
        let t19: crate::Time = 19.0.try_into().unwrap();
        let mut proportions = proportions;
        graph
            .fill_ancestry_proportions_matrix(t19, &mut proportions)
            .unwrap();
        assert_rows_match_per_deme(&graph, &proportions, t19, false);
    }

    #[test]
    fn test_simple_graph_with_migrations() {
        let graph = crate::loads(super::test_rescaling::SIMPLE_TEST_GRAPH_1).unwrap();
        let n = graph.num_demes();
        let times = [21.0, 20.0, 19.0, 18.0];

        for time in times {
            let at = crate::Time::try_from(time).unwrap();
            let proportions = graph.ancestry_proportions_matrix(at).unwrap();
            assert_rows_match_per_deme(&graph, &proportions, at, true);
        }

        // Again, now reusing one caller-provided buffer. It starts filled
        // with junk so stale values would be detected.
        let mut buffer = vec![666.0; n * n];
        for time in times {
            let at = crate::Time::try_from(time).unwrap();
            graph
                .fill_ancestry_proportions_matrix(at, &mut buffer)
                .unwrap();
            assert_rows_match_per_deme(&graph, &buffer, at, true);
        }
    }

    /// Time zero is rejected by both the allocating and the filling API.
    #[test]
    fn test_time_0() {
        let graph = crate::loads(super::test_rescaling::SIMPLE_TEST_GRAPH_1).unwrap();
        assert!(graph
            .ancestry_proportions_matrix(0.0.try_into().unwrap())
            .is_err());
        let mut buffer = vec![f64::NAN; graph.num_demes() * graph.num_demes()];
        assert!(graph
            .fill_ancestry_proportions_matrix(0.0.try_into().unwrap(), &mut buffer)
            .is_err());
    }
}
5464        assert!(graph
5465            .fill_ancestry_proportions_matrix(0.0.try_into().unwrap(), &mut buffer)
5466            .is_err());
5467    }
5468}