Skip to main content

agm_core/model/
orchestration.rs

//! Orchestration types (spec S13.10, S27).
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5use std::str::FromStr;
6
7use super::fields::ParseEnumError;
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum Strategy {
12    Sequential,
13    Parallel,
14}
15
16impl fmt::Display for Strategy {
17    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18        match self {
19            Self::Sequential => write!(f, "sequential"),
20            Self::Parallel => write!(f, "parallel"),
21        }
22    }
23}
24
25impl FromStr for Strategy {
26    type Err = ParseEnumError;
27
28    fn from_str(s: &str) -> Result<Self, Self::Err> {
29        match s {
30            "sequential" => Ok(Self::Sequential),
31            "parallel" => Ok(Self::Parallel),
32            _ => Err(ParseEnumError {
33                type_name: "Strategy",
34                value: s.to_owned(),
35            }),
36        }
37    }
38}
39
40#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
41pub struct ParallelGroup {
42    pub group: String,
43    pub nodes: Vec<String>,
44    pub strategy: Strategy,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub requires: Option<Vec<String>>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub max_concurrency: Option<u32>,
49}
50
#[cfg(test)]
mod tests {
    use super::*;

    /// Both documented spellings parse to their variants.
    #[test]
    fn test_strategy_from_str_valid_returns_ok() {
        assert_eq!(
            "sequential".parse::<Strategy>().unwrap(),
            Strategy::Sequential
        );
        assert_eq!("parallel".parse::<Strategy>().unwrap(), Strategy::Parallel);
    }

    /// Unknown input is rejected, and the error carries both the target type
    /// name and the rejected text.
    #[test]
    fn test_strategy_from_str_invalid_returns_error() {
        let err = "concurrent".parse::<Strategy>().unwrap_err();
        assert_eq!(err.type_name, "Strategy");
        assert_eq!(err.value, "concurrent");
    }

    /// `Display` output parses back to the same variant.
    #[test]
    fn test_strategy_display_roundtrip() {
        for s in [Strategy::Sequential, Strategy::Parallel] {
            let text = s.to_string();
            assert_eq!(text.parse::<Strategy>().unwrap(), s);
        }
    }

    /// The serde wire form is the `Display` text as a JSON string, so the
    /// three textual representations cannot drift apart unnoticed.
    #[test]
    fn test_strategy_serde_matches_display() {
        for s in [Strategy::Sequential, Strategy::Parallel] {
            let json = serde_json::to_string(&s).unwrap();
            assert_eq!(json, format!("\"{}\"", s));
            let back: Strategy = serde_json::from_str(&json).unwrap();
            assert_eq!(back, s);
        }
    }

    /// A group with no optional fields survives a JSON round trip.
    #[test]
    fn test_parallel_group_serde_roundtrip() {
        let pg = ParallelGroup {
            group: "1-schema".to_owned(),
            nodes: vec!["migration.schema".to_owned()],
            strategy: Strategy::Sequential,
            requires: None,
            max_concurrency: None,
        };
        let json = serde_json::to_string(&pg).unwrap();
        let back: ParallelGroup = serde_json::from_str(&json).unwrap();
        assert_eq!(pg, back);
    }

    /// A group with every optional field populated survives a JSON round trip.
    #[test]
    fn test_parallel_group_with_requires_and_concurrency() {
        let pg = ParallelGroup {
            group: "3-backend".to_owned(),
            nodes: vec!["backend.repo".to_owned(), "backend.cmd".to_owned()],
            strategy: Strategy::Parallel,
            requires: Some(vec!["2-models".to_owned()]),
            max_concurrency: Some(4),
        };
        let json = serde_json::to_string(&pg).unwrap();
        let back: ParallelGroup = serde_json::from_str(&json).unwrap();
        assert_eq!(pg, back);
    }

    /// `None` optional fields are omitted from the JSON, not emitted as `null`.
    #[test]
    fn test_parallel_group_optional_fields_absent_in_json() {
        let pg = ParallelGroup {
            group: "1-test".to_owned(),
            nodes: vec![],
            strategy: Strategy::Sequential,
            requires: None,
            max_concurrency: None,
        };
        let json = serde_json::to_string(&pg).unwrap();
        assert!(!json.contains("requires"));
        assert!(!json.contains("max_concurrency"));
    }
}