1use std::collections::BTreeMap;
4use std::fmt;
5use std::str::FromStr;
6
7use serde::de;
8use serde::{Deserialize, Serialize};
9
10use super::imports::ImportEntry;
11use super::node::Node;
12
/// Token estimate attached to a load profile.
///
/// Holds either a concrete token count or the `variable` marker for profiles
/// whose size is not fixed.
#[derive(Clone, Debug, PartialEq)]
pub enum TokenEstimate {
    /// A fixed number of tokens.
    Count(u64),
    /// No fixed size; written as `variable` in text and serialized form.
    Variable,
}
22
23impl fmt::Display for TokenEstimate {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 match self {
26 Self::Count(n) => write!(f, "{n}"),
27 Self::Variable => write!(f, "variable"),
28 }
29 }
30}
31
32impl FromStr for TokenEstimate {
33 type Err = std::num::ParseIntError;
34
35 fn from_str(s: &str) -> Result<Self, Self::Err> {
36 if s.eq_ignore_ascii_case("variable") {
37 Ok(Self::Variable)
38 } else {
39 s.parse::<u64>().map(Self::Count)
40 }
41 }
42}
43
44impl Serialize for TokenEstimate {
45 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
46 match self {
47 Self::Count(n) => serializer.serialize_u64(*n),
48 Self::Variable => serializer.serialize_str("variable"),
49 }
50 }
51}
52
53impl<'de> Deserialize<'de> for TokenEstimate {
54 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
55 struct Visitor;
56
57 impl<'de> de::Visitor<'de> for Visitor {
58 type Value = TokenEstimate;
59
60 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 f.write_str("a positive integer or the string \"variable\"")
62 }
63
64 fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
65 Ok(TokenEstimate::Count(v))
66 }
67
68 fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
69 if v >= 0 {
70 Ok(TokenEstimate::Count(v as u64))
71 } else {
72 Err(E::custom("token estimate must be non-negative"))
73 }
74 }
75
76 fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
77 if v.eq_ignore_ascii_case("variable") {
78 Ok(TokenEstimate::Variable)
79 } else {
80 Err(E::custom(format!("expected \"variable\", got {v:?}")))
81 }
82 }
83 }
84
85 deserializer.deserialize_any(Visitor)
86 }
87}
88
89#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
94pub struct LoadProfile {
95 pub filter: String,
96 #[serde(skip_serializing_if = "Option::is_none")]
97 pub estimated_tokens: Option<TokenEstimate>,
98}
99
100#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
105pub struct Header {
106 pub agm: String,
107 pub package: String,
108 pub version: String,
109 #[serde(skip_serializing_if = "Option::is_none")]
110 pub title: Option<String>,
111 #[serde(skip_serializing_if = "Option::is_none")]
112 pub owner: Option<String>,
113 #[serde(skip_serializing_if = "Option::is_none")]
114 pub imports: Option<Vec<ImportEntry>>,
115 #[serde(skip_serializing_if = "Option::is_none")]
116 pub default_load: Option<String>,
117 #[serde(skip_serializing_if = "Option::is_none")]
118 pub description: Option<String>,
119 #[serde(skip_serializing_if = "Option::is_none")]
120 pub tags: Option<Vec<String>>,
121 #[serde(skip_serializing_if = "Option::is_none")]
122 pub status: Option<String>,
123 #[serde(skip_serializing_if = "Option::is_none")]
124 pub load_profiles: Option<BTreeMap<String, LoadProfile>>,
125 #[serde(skip_serializing_if = "Option::is_none")]
126 pub target_runtime: Option<String>,
127}
128
129#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
134pub struct AgmFile {
135 #[serde(flatten)]
136 pub header: Header,
137 pub nodes: Vec<Node>,
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header with only the required keys set; shared by several tests.
    fn minimal_header() -> Header {
        Header {
            agm: "1.0".to_owned(),
            package: "test.pkg".to_owned(),
            version: "0.1.0".to_owned(),
            title: None,
            owner: None,
            imports: None,
            default_load: None,
            description: None,
            tags: None,
            status: None,
            load_profiles: None,
            target_runtime: None,
        }
    }

    #[test]
    fn test_token_estimate_from_str_number_returns_count() {
        let t: TokenEstimate = "1200".parse().unwrap();
        assert_eq!(t, TokenEstimate::Count(1200));
    }

    #[test]
    fn test_token_estimate_from_str_variable_returns_variable() {
        let t: TokenEstimate = "variable".parse().unwrap();
        assert_eq!(t, TokenEstimate::Variable);
    }

    #[test]
    fn test_token_estimate_from_str_variable_case_insensitive() {
        let t: TokenEstimate = "Variable".parse().unwrap();
        assert_eq!(t, TokenEstimate::Variable);
    }

    #[test]
    fn test_token_estimate_from_str_invalid_returns_error() {
        assert!("not_a_number".parse::<TokenEstimate>().is_err());
    }

    #[test]
    fn test_token_estimate_display_count() {
        assert_eq!(TokenEstimate::Count(4200).to_string(), "4200");
    }

    #[test]
    fn test_token_estimate_display_variable() {
        assert_eq!(TokenEstimate::Variable.to_string(), "variable");
    }

    #[test]
    fn test_token_estimate_serde_count() {
        let t = TokenEstimate::Count(1200);
        let json = serde_json::to_string(&t).unwrap();
        assert_eq!(json, "1200");
        let back: TokenEstimate = serde_json::from_str(&json).unwrap();
        assert_eq!(t, back);
    }

    #[test]
    fn test_token_estimate_serde_variable() {
        let t = TokenEstimate::Variable;
        let json = serde_json::to_string(&t).unwrap();
        assert_eq!(json, "\"variable\"");
        let back: TokenEstimate = serde_json::from_str(&json).unwrap();
        assert_eq!(t, back);
    }

    #[test]
    fn test_token_estimate_deserialize_zero_returns_count() {
        let t: TokenEstimate = serde_json::from_str("0").unwrap();
        assert_eq!(t, TokenEstimate::Count(0));
    }

    #[test]
    fn test_token_estimate_deserialize_variable_case_insensitive() {
        let t: TokenEstimate = serde_json::from_str("\"VARIABLE\"").unwrap();
        assert_eq!(t, TokenEstimate::Variable);
    }

    #[test]
    fn test_token_estimate_deserialize_negative_is_error() {
        assert!(serde_json::from_str::<TokenEstimate>("-5").is_err());
    }

    #[test]
    fn test_token_estimate_deserialize_unknown_string_is_error() {
        assert!(serde_json::from_str::<TokenEstimate>("\"lots\"").is_err());
    }

    #[test]
    fn test_token_estimate_deserialize_float_is_error() {
        assert!(serde_json::from_str::<TokenEstimate>("1.5").is_err());
    }

    #[test]
    fn test_header_required_fields_only() {
        let h = minimal_header();
        assert_eq!(h.agm, "1.0");
        assert_eq!(h.package, "test.pkg");
        // Every `None` field must be skipped on output.
        let json = serde_json::to_string(&h).unwrap();
        assert_eq!(json, r#"{"agm":"1.0","package":"test.pkg","version":"0.1.0"}"#);
        let back: Header = serde_json::from_str(&json).unwrap();
        assert_eq!(back, h);
    }

    #[test]
    fn test_header_missing_required_field_is_error() {
        let json = r#"{"agm":"1.0","package":"test.pkg"}"#;
        assert!(serde_json::from_str::<Header>(json).is_err());
    }

    #[test]
    fn test_load_profile_serde_roundtrip() {
        let lp = LoadProfile {
            filter: "priority in [critical]".to_owned(),
            estimated_tokens: Some(TokenEstimate::Count(1200)),
        };
        let json = serde_json::to_string(&lp).unwrap();
        let back: LoadProfile = serde_json::from_str(&json).unwrap();
        assert_eq!(lp, back);
    }

    #[test]
    fn test_load_profile_without_estimate_omits_key() {
        let lp = LoadProfile {
            filter: "all".to_owned(),
            estimated_tokens: None,
        };
        let json = serde_json::to_string(&lp).unwrap();
        assert_eq!(json, r#"{"filter":"all"}"#);
        let back: LoadProfile = serde_json::from_str(&json).unwrap();
        assert_eq!(back, lp);
    }

    #[test]
    fn test_agm_file_serde_minimal() {
        let file = AgmFile {
            header: minimal_header(),
            nodes: vec![],
        };
        let json = serde_json::to_string(&file).unwrap();
        let back: AgmFile = serde_json::from_str(&json).unwrap();
        assert_eq!(file, back);
    }

    #[test]
    fn test_agm_file_header_is_flattened() {
        let file = AgmFile {
            header: minimal_header(),
            nodes: vec![],
        };
        let value = serde_json::to_value(&file).unwrap();
        // Header keys sit at the top level; there is no nested `header` object.
        assert!(value.get("header").is_none());
        assert_eq!(value["package"], "test.pkg");
        assert_eq!(value["nodes"].as_array().map(Vec::len), Some(0));
    }
}