1use std::{hash::Hash, str::FromStr};
2
3use chrono::{NaiveDate, Utc};
4use indexmap::IndexMap;
5#[cfg(feature = "schemars")]
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use crate::{
10 Country, Currency, Language, Links, PriceList, price_list::PriceListSimplified,
11 registry::sweden,
12};
13
14#[derive(Debug, Clone)]
15pub struct GridOperator {
16 name: &'static str,
17 vat_number: &'static str,
18 country: Country,
20 price_lists: &'static [PriceList],
21 links: Links,
22}
23
24impl GridOperator {
25 pub const fn name(&self) -> &str {
26 self.name
27 }
28
29 pub const fn vat_number(&self) -> &str {
30 self.vat_number
31 }
32
33 pub const fn country(&self) -> Country {
34 self.country
35 }
36
37 pub const fn links(&self) -> &Links {
38 &self.links
39 }
40
41 pub fn all_price_lists(&self) -> &'static [PriceList] {
42 self.price_lists
43 }
44
45 pub fn price_lists(&self, date: NaiveDate) -> Vec<&'static PriceList> {
46 let mut map: IndexMap<Option<&str>, &PriceList> = IndexMap::new();
47 for pl in self.price_lists {
48 if date >= pl.from_date() {
49 if let Some(current_max_date) = map.get(&pl.variant()).map(|pl| pl.from_date()) {
50 if pl.from_date() > current_max_date {
51 map.insert(pl.variant(), pl);
52 }
53 } else {
54 map.insert(pl.variant(), pl);
55 }
56 }
57 }
58 map.into_values().collect()
59 }
60
61 pub fn price_list(&self, variant: Option<&str>, date: NaiveDate) -> Option<&'static PriceList> {
62 self.price_lists(date)
63 .iter()
64 .filter(|pl| pl.variant() == variant)
65 .next_back()
66 .copied()
67 }
68
69 pub fn current_price_lists(&self) -> Vec<&'static PriceList> {
70 let now = Utc::now().date_naive();
71 self.price_lists(now)
72 }
73
74 pub fn current_price_list(&self, variant: Option<&str>) -> Option<&'static PriceList> {
75 let now = Utc::now().date_naive();
76 self.price_list(variant, now)
77 }
78
79 pub const fn currency(&self) -> Currency {
80 match self.country {
81 Country::SE => Currency::SEK,
82 }
83 }
84
85 pub fn get(country: Country, name: &str) -> Option<&'static Self> {
86 match country {
87 Country::SE => sweden::GRID_OPERATORS
88 .iter()
89 .find(|o| o.name == name)
90 .copied(),
91 }
92 }
93
94 pub fn get_by_vat_id(vat_id: &str) -> Option<&'static Self> {
95 let country: Country = vat_id[0..2].parse().ok()?;
96 match country {
97 Country::SE => sweden::GRID_OPERATORS
98 .iter()
99 .find(|o| o.vat_number == vat_id)
100 .copied(),
101 }
102 }
103
104 pub fn all() -> Vec<&'static Self> {
105 sweden::GRID_OPERATORS.to_vec()
106 }
107
108 pub fn all_for_country(country: Country) -> &'static [&'static Self] {
109 match country {
110 Country::SE => sweden::GRID_OPERATORS,
111 }
112 }
113
114 pub const fn builder() -> GridOperatorBuilder {
115 GridOperatorBuilder::new()
116 }
117
118 pub fn simplified(
119 &self,
120 fuse_size: u16,
121 yearly_consumption: u32,
122 language: Language,
123 ) -> GridOperatorSimplified {
124 GridOperatorSimplified::new(self, fuse_size, yearly_consumption, language)
125 }
126}
127
128#[derive(Debug, Clone, Serialize)]
130#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
131pub struct GridOperatorSimplified {
132 name: &'static str,
133 vat_number: &'static str,
134 country: Country,
136 price_lists: Vec<PriceListSimplified>,
137}
138
139impl GridOperatorSimplified {
140 pub fn name(&self) -> &'static str {
141 self.name
142 }
143
144 pub fn vat_number(&self) -> &'static str {
145 self.vat_number
146 }
147
148 pub fn country(&self) -> Country {
149 self.country
150 }
151
152 pub fn price_lists(&self) -> &[PriceListSimplified] {
153 &self.price_lists
154 }
155}
156
157impl GridOperatorSimplified {
158 fn new(op: &GridOperator, fuse_size: u16, yearly_consumption: u32, language: Language) -> Self {
159 Self {
160 name: op.name,
161 vat_number: op.vat_number,
162 country: op.country(),
163 price_lists: op
164 .current_price_lists()
165 .into_iter()
166 .map(|pl| pl.simplified(op, fuse_size, yearly_consumption, language))
167 .collect(),
168 }
169 }
170}
171
172#[derive(Debug, Clone)]
173pub struct GridOperatorBuilder {
174 name: Option<&'static str>,
175 vat_number: Option<&'static str>,
176 country: Option<Country>,
178 price_lists: Option<&'static [PriceList]>,
179 links: Option<Links>,
180}
181
182impl Default for GridOperatorBuilder {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
188impl GridOperatorBuilder {
189 pub const fn new() -> Self {
190 Self {
191 name: None,
192 vat_number: None,
193 country: None,
194 price_lists: None,
195 links: None,
196 }
197 }
198
199 pub const fn name(mut self, name: &'static str) -> Self {
200 self.name = Some(name);
201 self
202 }
203
204 pub const fn vat_number(mut self, vat_number: &'static str) -> Self {
205 self.vat_number = Some(vat_number);
206 self
207 }
208
209 pub const fn country(mut self, country: Country) -> Self {
210 self.country = Some(country);
211 self
212 }
213
214 pub const fn links(mut self, links: Links) -> Self {
215 self.links = Some(links);
216 self
217 }
218
219 pub const fn price_lists(mut self, price_lists: &'static [PriceList]) -> Self {
220 self.price_lists = Some(price_lists);
221 self
222 }
223
224 pub const fn build(self) -> GridOperator {
225 GridOperator {
226 name: self.name.expect("`name` required"),
227 vat_number: self.vat_number.expect("`vat_number` required"),
228 country: self.country.expect("`country` required"),
229 price_lists: self.price_lists.expect("`price_lists` expected"),
230 links: self.links.expect("`links` required"),
231 }
232 }
233}
234
235impl FromStr for &'static GridOperator {
236 type Err = &'static str;
237
238 fn from_str(s: &str) -> Result<Self, Self::Err> {
239 GridOperator::get_by_vat_id(s).ok_or("grid operator not found")
240 }
241}
242
243impl<'de> Deserialize<'de> for &'static GridOperator {
244 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
245 where
246 D: serde::Deserializer<'de>,
247 {
248 let s = String::deserialize(deserializer)?;
249 Self::from_str(&s).map_err(serde::de::Error::custom)
250 }
251}
252
253impl Serialize for GridOperator {
254 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
255 where
256 S: serde::Serializer,
257 {
258 serializer.serialize_str(self.vat_number())
259 }
260}
261
#[cfg(feature = "schemars")]
impl JsonSchema for GridOperator {
    /// The schema is registered under the name `GridOperator`.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("GridOperator")
    }

    /// Matches the serialized form: a plain string (the VAT number).
    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        <String as JsonSchema>::json_schema(generator)
    }
}
272
273impl PartialEq for GridOperator {
274 fn eq(&self, other: &Self) -> bool {
275 self.vat_number == other.vat_number
276 }
277}
278
279impl Eq for GridOperator {}
280
281impl PartialOrd for GridOperator {
282 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
283 Some(self.cmp(other))
284 }
285}
286
287impl Ord for GridOperator {
288 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
289 self.vat_number.cmp(other.vat_number)
290 }
291}
292
293impl Hash for GridOperator {
294 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
295 self.vat_number.hash(state);
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    // Two operators present in the Swedish registry; `LOW_VAT` sorts before `HIGH_VAT`.
    const LOW_VAT: &str = "SE556037732601";
    const HIGH_VAT: &str = "SE556532083401";

    /// Resolves a registered operator by VAT number, panicking if it is absent.
    fn lookup(vat: &str) -> &'static GridOperator {
        vat.parse().unwrap()
    }

    #[test]
    fn test_grid_operator_serialize_deserialize() {
        let original = lookup(LOW_VAT);

        let json = serde_json::to_string(original).expect("serialization should succeed");
        assert_eq!(json, format!("\"{}\"", original.vat_number()));

        let round_tripped: &'static GridOperator =
            serde_json::from_str(&json).expect("deserialization should succeed");
        assert_eq!(round_tripped.name(), original.name());
        assert_eq!(round_tripped.vat_number(), original.vat_number());
        assert_eq!(round_tripped.country(), original.country());
    }

    #[test]
    fn test_grid_operator_deserialize_invalid() {
        let outcome = serde_json::from_str::<&'static GridOperator>("\"SE000000000000\"");
        assert!(
            outcome.is_err(),
            "deserialization should fail for invalid VAT number"
        );
    }

    #[test]
    fn test_grid_operator_partial_eq() {
        let low = lookup(LOW_VAT);
        let low_again = lookup(LOW_VAT);
        let high = lookup(HIGH_VAT);

        assert_eq!(low, low_again);
        assert_ne!(low, high);
    }

    #[test]
    fn test_grid_operator_partial_ord() {
        let low = lookup(LOW_VAT);
        let high = lookup(HIGH_VAT);

        assert!(low < high);
        assert!(high > low);
        assert!(low <= low);
        assert!(low >= low);
    }

    #[test]
    fn test_grid_operator_ord() {
        let low = lookup(LOW_VAT);
        let high = lookup(HIGH_VAT);
        let low_again = lookup(LOW_VAT);

        assert_eq!(low.cmp(high), Ordering::Less);
        assert_eq!(high.cmp(low), Ordering::Greater);
        assert_eq!(low.cmp(low_again), Ordering::Equal);
    }

    #[test]
    fn test_grid_operator_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        fn digest(op: &GridOperator) -> u64 {
            let mut hasher = DefaultHasher::new();
            op.hash(&mut hasher);
            hasher.finish()
        }

        let low = lookup(LOW_VAT);
        let low_again = lookup(LOW_VAT);
        let high = lookup(HIGH_VAT);

        assert_eq!(digest(low), digest(low_again));
        assert_ne!(digest(low), digest(high));
    }
}