Skip to main content

agm_core/model/
file.rs

1//! Top-level AGM file and header types (spec S8, S9).
2
3use std::collections::BTreeMap;
4use std::fmt;
5use std::str::FromStr;
6
7use serde::de;
8use serde::{Deserialize, Serialize};
9
10use super::imports::ImportEntry;
11use super::node::Node;
12
13// ---------------------------------------------------------------------------
14// TokenEstimate
15// ---------------------------------------------------------------------------
16
/// Token estimate attached to a [`LoadProfile`].
///
/// Either a concrete count or the literal `variable` in place of a number.
/// Parsing accepts `variable` in any ASCII case; formatting always emits it in
/// lowercase.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TokenEstimate {
    /// A concrete, non-negative token count.
    Count(u64),
    /// The literal `variable` given instead of a number.
    Variable,
}

impl fmt::Display for TokenEstimate {
    /// Writes the bare number for [`TokenEstimate::Count`] and `variable` for
    /// [`TokenEstimate::Variable`]; the output parses back via [`FromStr`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Count(n) => write!(f, "{n}"),
            Self::Variable => f.write_str("variable"),
        }
    }
}

impl FromStr for TokenEstimate {
    type Err = std::num::ParseIntError;

    /// Parses `variable` (ASCII case-insensitive) or a base-10 `u64`.
    ///
    /// # Errors
    ///
    /// Returns the [`ParseIntError`](std::num::ParseIntError) produced by `u64`
    /// parsing when the input is neither `variable` nor a valid unsigned integer
    /// (for example an empty string, a negative number, or text with surrounding
    /// whitespace — no trimming is performed).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.eq_ignore_ascii_case("variable") {
            return Ok(Self::Variable);
        }
        s.parse::<u64>().map(Self::Count)
    }
}
43
44impl Serialize for TokenEstimate {
45    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
46        match self {
47            Self::Count(n) => serializer.serialize_u64(*n),
48            Self::Variable => serializer.serialize_str("variable"),
49        }
50    }
51}
52
53impl<'de> Deserialize<'de> for TokenEstimate {
54    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
55        struct Visitor;
56
57        impl<'de> de::Visitor<'de> for Visitor {
58            type Value = TokenEstimate;
59
60            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61                f.write_str("a positive integer or the string \"variable\"")
62            }
63
64            fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
65                Ok(TokenEstimate::Count(v))
66            }
67
68            fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
69                if v >= 0 {
70                    Ok(TokenEstimate::Count(v as u64))
71                } else {
72                    Err(E::custom("token estimate must be non-negative"))
73                }
74            }
75
76            fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
77                if v.eq_ignore_ascii_case("variable") {
78                    Ok(TokenEstimate::Variable)
79                } else {
80                    Err(E::custom(format!("expected \"variable\", got {v:?}")))
81                }
82            }
83        }
84
85        deserializer.deserialize_any(Visitor)
86    }
87}
88
89// ---------------------------------------------------------------------------
90// LoadProfile
91// ---------------------------------------------------------------------------
92
93#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
94pub struct LoadProfile {
95    pub filter: String,
96    #[serde(skip_serializing_if = "Option::is_none")]
97    pub estimated_tokens: Option<TokenEstimate>,
98}
99
100// ---------------------------------------------------------------------------
101// Header
102// ---------------------------------------------------------------------------
103
104#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
105pub struct Header {
106    pub agm: String,
107    pub package: String,
108    pub version: String,
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub title: Option<String>,
111    #[serde(skip_serializing_if = "Option::is_none")]
112    pub owner: Option<String>,
113    #[serde(skip_serializing_if = "Option::is_none")]
114    pub imports: Option<Vec<ImportEntry>>,
115    #[serde(skip_serializing_if = "Option::is_none")]
116    pub default_load: Option<String>,
117    #[serde(skip_serializing_if = "Option::is_none")]
118    pub description: Option<String>,
119    #[serde(skip_serializing_if = "Option::is_none")]
120    pub tags: Option<Vec<String>>,
121    #[serde(skip_serializing_if = "Option::is_none")]
122    pub status: Option<String>,
123    #[serde(skip_serializing_if = "Option::is_none")]
124    pub load_profiles: Option<BTreeMap<String, LoadProfile>>,
125    #[serde(skip_serializing_if = "Option::is_none")]
126    pub target_runtime: Option<String>,
127}
128
129// ---------------------------------------------------------------------------
130// AgmFile
131// ---------------------------------------------------------------------------
132
133#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
134pub struct AgmFile {
135    #[serde(flatten)]
136    pub header: Header,
137    pub nodes: Vec<Node>,
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Header with only the required fields set; every optional field is `None`.
    fn minimal_header() -> Header {
        Header {
            agm: "1.0".to_owned(),
            package: "test.pkg".to_owned(),
            version: "0.1.0".to_owned(),
            title: None,
            owner: None,
            imports: None,
            default_load: None,
            description: None,
            tags: None,
            status: None,
            load_profiles: None,
            target_runtime: None,
        }
    }

    // --- TokenEstimate: FromStr / Display -----------------------------------

    #[test]
    fn test_token_estimate_from_str_number_returns_count() {
        let t: TokenEstimate = "1200".parse().unwrap();
        assert_eq!(t, TokenEstimate::Count(1200));
    }

    #[test]
    fn test_token_estimate_from_str_variable_returns_variable() {
        let t: TokenEstimate = "variable".parse().unwrap();
        assert_eq!(t, TokenEstimate::Variable);
    }

    #[test]
    fn test_token_estimate_from_str_variable_case_insensitive() {
        let t: TokenEstimate = "Variable".parse().unwrap();
        assert_eq!(t, TokenEstimate::Variable);
    }

    #[test]
    fn test_token_estimate_from_str_invalid_returns_error() {
        assert!("not_a_number".parse::<TokenEstimate>().is_err());
    }

    #[test]
    fn test_token_estimate_from_str_negative_returns_error() {
        assert!("-1".parse::<TokenEstimate>().is_err());
    }

    #[test]
    fn test_token_estimate_display_count() {
        assert_eq!(TokenEstimate::Count(4200).to_string(), "4200");
    }

    #[test]
    fn test_token_estimate_display_variable() {
        assert_eq!(TokenEstimate::Variable.to_string(), "variable");
    }

    #[test]
    fn test_token_estimate_display_from_str_roundtrip() {
        for t in [
            TokenEstimate::Count(0),
            TokenEstimate::Count(u64::MAX),
            TokenEstimate::Variable,
        ] {
            let back: TokenEstimate = t.to_string().parse().unwrap();
            assert_eq!(back, t);
        }
    }

    // --- TokenEstimate: serde -----------------------------------------------

    #[test]
    fn test_token_estimate_serde_count() {
        let t = TokenEstimate::Count(1200);
        let json = serde_json::to_string(&t).unwrap();
        assert_eq!(json, "1200");
        let back: TokenEstimate = serde_json::from_str(&json).unwrap();
        assert_eq!(t, back);
    }

    #[test]
    fn test_token_estimate_serde_variable() {
        let t = TokenEstimate::Variable;
        let json = serde_json::to_string(&t).unwrap();
        assert_eq!(json, "\"variable\"");
        let back: TokenEstimate = serde_json::from_str(&json).unwrap();
        assert_eq!(t, back);
    }

    #[test]
    fn test_token_estimate_deserialize_variable_case_insensitive() {
        let t: TokenEstimate = serde_json::from_str("\"VARIABLE\"").unwrap();
        assert_eq!(t, TokenEstimate::Variable);
    }

    #[test]
    fn test_token_estimate_deserialize_negative_returns_error() {
        assert!(serde_json::from_str::<TokenEstimate>("-5").is_err());
    }

    #[test]
    fn test_token_estimate_deserialize_unknown_string_returns_error() {
        assert!(serde_json::from_str::<TokenEstimate>("\"lots\"").is_err());
    }

    #[test]
    fn test_token_estimate_deserialize_wrong_type_returns_error() {
        assert!(serde_json::from_str::<TokenEstimate>("true").is_err());
    }

    // --- Header ---------------------------------------------------------------

    #[test]
    fn test_header_required_fields_only() {
        let h = minimal_header();
        assert_eq!(h.agm, "1.0");
        assert_eq!(h.package, "test.pkg");
    }

    #[test]
    fn test_header_serialize_omits_none_fields() {
        let value = serde_json::to_value(minimal_header()).unwrap();
        let obj = value.as_object().unwrap();
        let mut keys: Vec<&str> = obj.keys().map(String::as_str).collect();
        keys.sort_unstable();
        assert_eq!(keys, ["agm", "package", "version"]);
    }

    // --- LoadProfile ----------------------------------------------------------

    #[test]
    fn test_load_profile_serde_roundtrip() {
        let lp = LoadProfile {
            filter: "priority in [critical]".to_owned(),
            estimated_tokens: Some(TokenEstimate::Count(1200)),
        };
        let json = serde_json::to_string(&lp).unwrap();
        let back: LoadProfile = serde_json::from_str(&json).unwrap();
        assert_eq!(lp, back);
    }

    #[test]
    fn test_load_profile_without_estimate_omits_field_and_roundtrips() {
        let lp = LoadProfile {
            filter: "all".to_owned(),
            estimated_tokens: None,
        };
        let json = serde_json::to_string(&lp).unwrap();
        assert_eq!(json, r#"{"filter":"all"}"#);
        let back: LoadProfile = serde_json::from_str(&json).unwrap();
        assert_eq!(back, lp);
    }

    // --- AgmFile --------------------------------------------------------------

    #[test]
    fn test_agm_file_serde_minimal() {
        let file = AgmFile {
            header: minimal_header(),
            nodes: vec![],
        };
        let json = serde_json::to_string(&file).unwrap();
        let back: AgmFile = serde_json::from_str(&json).unwrap();
        assert_eq!(file, back);
    }

    #[test]
    fn test_agm_file_header_is_flattened() {
        let file = AgmFile {
            header: minimal_header(),
            nodes: vec![],
        };
        let value = serde_json::to_value(&file).unwrap();
        assert_eq!(value["agm"], "1.0");
        assert_eq!(value["package"], "test.pkg");
        assert!(value.get("header").is_none());
        assert_eq!(value["nodes"], serde_json::json!([]));

        let parsed: AgmFile = serde_json::from_str(
            r#"{"agm":"1.0","package":"test.pkg","version":"0.1.0","nodes":[]}"#,
        )
        .unwrap();
        assert_eq!(parsed, file);
    }
}