grid_tariffs/
operator.rs

1use std::{hash::Hash, str::FromStr};
2
3use chrono::{NaiveDate, Utc};
4use indexmap::IndexMap;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    Country, Currency, Language, Links, PriceList, fuse::MainFuseSizes,
9    price_list::PriceListSimplified, registry::sweden,
10};
11
/// A grid operator together with its price lists and related metadata.
///
/// Instances are built through [`GridOperatorBuilder`] and looked up through
/// the associated functions on this type. All string and slice fields borrow
/// `'static` data.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct GridOperator {
    /// Name of the grid operator
    name: &'static str,
    /// VAT number of the grid operator
    vat_number: &'static str,
    /// Country of the operator; the currency that costs are specified in is derived from it
    country: Country,
    /// The main fuse size range that this info covers
    main_fuses: MainFuseSizes,
    /// All known price lists for this operator
    price_lists: &'static [PriceList],
    /// Links associated with this operator
    links: Links,
}
24
25impl GridOperator {
26    pub const fn name(&self) -> &str {
27        self.name
28    }
29
30    pub const fn vat_number(&self) -> &str {
31        self.vat_number
32    }
33
34    pub const fn country(&self) -> Country {
35        self.country
36    }
37
38    pub const fn links(&self) -> &Links {
39        &self.links
40    }
41
42    pub fn all_price_lists(&self) -> &'static [PriceList] {
43        self.price_lists
44    }
45
46    pub fn price_lists(&self, date: NaiveDate) -> Vec<&'static PriceList> {
47        let mut map: IndexMap<Option<&str>, &PriceList> = IndexMap::new();
48        for pl in self.price_lists {
49            if date >= pl.from_date() {
50                if let Some(current_max_date) = map.get(&pl.variant()).map(|pl| pl.from_date()) {
51                    if pl.from_date() > current_max_date {
52                        map.insert(pl.variant(), pl);
53                    }
54                } else {
55                    map.insert(pl.variant(), pl);
56                }
57            }
58        }
59        map.into_values().collect()
60    }
61
62    pub fn price_list(&self, variant: Option<&str>, date: NaiveDate) -> Option<&'static PriceList> {
63        self.price_lists(date)
64            .iter()
65            .filter(|pl| pl.variant() == variant)
66            .next_back()
67            .copied()
68    }
69
70    pub fn current_price_lists(&self) -> Vec<&'static PriceList> {
71        let now = Utc::now().date_naive();
72        self.price_lists(now)
73    }
74
75    pub fn current_price_list(&self, variant: Option<&str>) -> Option<&'static PriceList> {
76        let now = Utc::now().date_naive();
77        self.price_list(variant, now)
78    }
79
80    pub const fn currency(&self) -> Currency {
81        match self.country {
82            Country::SE => Currency::SEK,
83        }
84    }
85
86    pub fn get(country: Country, name: &str) -> Option<&'static Self> {
87        match country {
88            Country::SE => sweden::GRID_OPERATORS
89                .iter()
90                .find(|o| o.name == name)
91                .copied(),
92        }
93    }
94
95    pub fn get_by_vat_id(vat_id: &str) -> Option<&'static Self> {
96        let country: Country = vat_id[0..2].parse().ok()?;
97        match country {
98            Country::SE => sweden::GRID_OPERATORS
99                .iter()
100                .find(|o| o.vat_number == vat_id)
101                .copied(),
102        }
103    }
104
105    pub fn all() -> Vec<&'static Self> {
106        sweden::GRID_OPERATORS.to_vec()
107    }
108
109    pub fn all_for_country(country: Country) -> &'static [&'static Self] {
110        match country {
111            Country::SE => sweden::GRID_OPERATORS,
112        }
113    }
114
115    pub const fn builder() -> GridOperatorBuilder {
116        GridOperatorBuilder::new()
117    }
118
119    pub fn simplified(
120        &self,
121        fuse_size: u16,
122        yearly_consumption: u32,
123        language: Language,
124    ) -> GridOperatorSimplified {
125        GridOperatorSimplified::new(self, fuse_size, yearly_consumption, language)
126    }
127}
128
/// Grid operator with only current prices
///
/// Created from a `GridOperator` for a specific fuse size, yearly consumption
/// and language.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct GridOperatorSimplified {
    /// Name of the grid operator
    name: &'static str,
    /// VAT number of the grid operator
    vat_number: &'static str,
    /// Country of the operator; the currency that costs are specified in is derived from it
    country: Country,
    /// Simplified versions of the price lists that were current at creation time
    price_lists: Vec<PriceListSimplified>,
}
139
140impl GridOperatorSimplified {
141    pub fn name(&self) -> &'static str {
142        self.name
143    }
144
145    pub fn vat_number(&self) -> &'static str {
146        self.vat_number
147    }
148
149    pub fn country(&self) -> Country {
150        self.country
151    }
152
153    pub fn price_lists(&self) -> &[PriceListSimplified] {
154        &self.price_lists
155    }
156}
157
158impl GridOperatorSimplified {
159    fn new(op: &GridOperator, fuse_size: u16, yearly_consumption: u32, language: Language) -> Self {
160        Self {
161            name: op.name,
162            vat_number: op.vat_number,
163            country: op.country(),
164            price_lists: op
165                .current_price_lists()
166                .into_iter()
167                .map(|pl| pl.simplified(fuse_size, yearly_consumption, language))
168                .collect(),
169        }
170    }
171}
172
/// Builder for [`GridOperator`].
///
/// Every field starts out unset (`None`) and must be provided before
/// `build` is called.
#[derive(Debug, Clone)]
pub struct GridOperatorBuilder {
    /// Name of the grid operator
    name: Option<&'static str>,
    /// VAT number of the grid operator
    vat_number: Option<&'static str>,
    /// Country of the operator; the currency that costs are specified in is derived from it
    country: Option<Country>,
    /// The main fuse size range that this info covers
    main_fuses: Option<MainFuseSizes>,
    /// All known price lists for the operator
    price_lists: Option<&'static [PriceList]>,
    /// Links associated with the operator
    links: Option<Links>,
}
184
185impl Default for GridOperatorBuilder {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
191impl GridOperatorBuilder {
192    pub const fn new() -> Self {
193        Self {
194            name: None,
195            vat_number: None,
196            country: None,
197            main_fuses: None,
198            price_lists: None,
199            links: None,
200        }
201    }
202
203    pub const fn name(mut self, name: &'static str) -> Self {
204        self.name = Some(name);
205        self
206    }
207
208    pub const fn vat_number(mut self, vat_number: &'static str) -> Self {
209        self.vat_number = Some(vat_number);
210        self
211    }
212
213    pub const fn country(mut self, country: Country) -> Self {
214        self.country = Some(country);
215        self
216    }
217
218    pub const fn main_fuses(mut self, main_fuses: MainFuseSizes) -> Self {
219        self.main_fuses = Some(main_fuses);
220        self
221    }
222
223    pub const fn links(mut self, links: Links) -> Self {
224        self.links = Some(links);
225        self
226    }
227
228    pub const fn price_lists(mut self, price_lists: &'static [PriceList]) -> Self {
229        self.price_lists = Some(price_lists);
230        self
231    }
232
233    pub const fn build(self) -> GridOperator {
234        GridOperator {
235            name: self.name.expect("`name` required"),
236            vat_number: self.vat_number.expect("`vat_number` required"),
237            country: self.country.expect("`country` required"),
238            main_fuses: self.main_fuses.expect("`main_fuses` required"),
239            price_lists: self.price_lists.expect("`price_lists` expected"),
240            links: self.links.expect("`links` required"),
241        }
242    }
243}
244
245impl FromStr for &'static GridOperator {
246    type Err = &'static str;
247
248    fn from_str(s: &str) -> Result<Self, Self::Err> {
249        GridOperator::get_by_vat_id(s).ok_or("grid operator not found")
250    }
251}
252
253impl<'de> Deserialize<'de> for &'static GridOperator {
254    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
255    where
256        D: serde::Deserializer<'de>,
257    {
258        let s = String::deserialize(deserializer)?;
259        Self::from_str(&s).map_err(serde::de::Error::custom)
260    }
261}
262
263impl Serialize for GridOperator {
264    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
265    where
266        S: serde::Serializer,
267    {
268        serializer.serialize_str(self.vat_number())
269    }
270}
271
272impl PartialEq for GridOperator {
273    fn eq(&self, other: &Self) -> bool {
274        self.vat_number == other.vat_number
275    }
276}
277
278impl Eq for GridOperator {}
279
280impl PartialOrd for GridOperator {
281    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
282        Some(self.cmp(other))
283    }
284}
285
286impl Ord for GridOperator {
287    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
288        self.vat_number.cmp(other.vat_number)
289    }
290}
291
292impl Hash for GridOperator {
293    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
294        self.vat_number.hash(state);
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    /// VAT id of the first operator used throughout these tests.
    const VAT_A: &str = "SE556037732601";
    /// VAT id of a second, different operator; sorts after `VAT_A`.
    const VAT_B: &str = "SE556532083401";

    /// Resolves a VAT id to its operator, panicking if it is not registered.
    fn lookup(vat_id: &str) -> &'static GridOperator {
        vat_id.parse().unwrap()
    }

    #[test]
    fn test_grid_operator_serialize_deserialize() {
        let original = lookup(VAT_A);

        // The JSON form is just the VAT number as a quoted string.
        let json = serde_json::to_string(original).expect("serialization should succeed");
        let expected_json = format!("\"{}\"", original.vat_number());
        assert_eq!(json, expected_json);

        // Reading that JSON back must resolve to the same operator.
        let restored: &'static GridOperator =
            serde_json::from_str(&json).expect("deserialization should succeed");

        assert_eq!(restored.name(), original.name());
        assert_eq!(restored.vat_number(), original.vat_number());
        assert_eq!(restored.country(), original.country());
    }

    #[test]
    fn test_grid_operator_deserialize_invalid() {
        // A well-formed but unregistered VAT number must be rejected.
        let outcome: Result<&'static GridOperator, _> = serde_json::from_str("\"SE000000000000\"");

        assert!(
            outcome.is_err(),
            "deserialization should fail for invalid VAT number"
        );
    }

    #[test]
    fn test_grid_operator_partial_eq() {
        let first = lookup(VAT_A);
        let same_as_first = lookup(VAT_A);
        let different = lookup(VAT_B);

        // Same VAT number compares equal.
        assert_eq!(first, same_as_first);

        // Different VAT numbers compare unequal.
        assert_ne!(first, different);
    }

    #[test]
    fn test_grid_operator_partial_ord() {
        let lower = lookup(VAT_A);
        let higher = lookup(VAT_B);

        // Comparison operators follow the VAT number ordering.
        assert!(lower < higher);
        assert!(higher > lower);
        assert!(lower <= lower);
        assert!(lower >= lower);
    }

    #[test]
    fn test_grid_operator_ord() {
        use std::cmp::Ordering;

        let lower = lookup(VAT_A);
        let higher = lookup(VAT_B);
        let lower_again = lookup(VAT_A);

        // `Ord::cmp` reports all three outcomes correctly.
        assert_eq!(lower.cmp(higher), Ordering::Less);
        assert_eq!(higher.cmp(lower), Ordering::Greater);
        assert_eq!(lower.cmp(lower_again), Ordering::Equal);
    }

    #[test]
    fn test_grid_operator_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        /// Hashes an operator with a fresh `DefaultHasher`.
        fn hash_of(op: &GridOperator) -> u64 {
            let mut hasher = DefaultHasher::new();
            op.hash(&mut hasher);
            hasher.finish()
        }

        let first = lookup(VAT_A);
        let same_as_first = lookup(VAT_A);
        let different = lookup(VAT_B);

        // Equal VAT numbers hash identically.
        assert_eq!(hash_of(first), hash_of(same_as_first));

        // Different VAT numbers are expected to hash differently.
        assert_ne!(hash_of(first), hash_of(different));
    }
}