grid_tariffs/
operator.rs

1use std::{hash::Hash, str::FromStr};
2
3use chrono::{NaiveDate, Utc};
4use indexmap::IndexMap;
5#[cfg(feature = "schemars")]
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use crate::{
10    Country, Currency, Language, Links, PriceList, fuse::MainFuseSizes,
11    price_list::PriceListSimplified, registry::sweden,
12};
13
14#[derive(Debug, Clone)]
15pub struct GridOperator {
16    name: &'static str,
17    vat_number: &'static str,
18    /// Costs are specified in this currency
19    country: Country,
20    /// The main fuse size range that this info covers
21    main_fuses: MainFuseSizes,
22    price_lists: &'static [PriceList],
23    links: Links,
24}
25
26impl GridOperator {
27    pub const fn name(&self) -> &str {
28        self.name
29    }
30
31    pub const fn vat_number(&self) -> &str {
32        self.vat_number
33    }
34
35    pub const fn country(&self) -> Country {
36        self.country
37    }
38
39    pub const fn links(&self) -> &Links {
40        &self.links
41    }
42
43    pub fn all_price_lists(&self) -> &'static [PriceList] {
44        self.price_lists
45    }
46
47    pub fn price_lists(&self, date: NaiveDate) -> Vec<&'static PriceList> {
48        let mut map: IndexMap<Option<&str>, &PriceList> = IndexMap::new();
49        for pl in self.price_lists {
50            if date >= pl.from_date() {
51                if let Some(current_max_date) = map.get(&pl.variant()).map(|pl| pl.from_date()) {
52                    if pl.from_date() > current_max_date {
53                        map.insert(pl.variant(), pl);
54                    }
55                } else {
56                    map.insert(pl.variant(), pl);
57                }
58            }
59        }
60        map.into_values().collect()
61    }
62
63    pub fn price_list(&self, variant: Option<&str>, date: NaiveDate) -> Option<&'static PriceList> {
64        self.price_lists(date)
65            .iter()
66            .filter(|pl| pl.variant() == variant)
67            .next_back()
68            .copied()
69    }
70
71    pub fn current_price_lists(&self) -> Vec<&'static PriceList> {
72        let now = Utc::now().date_naive();
73        self.price_lists(now)
74    }
75
76    pub fn current_price_list(&self, variant: Option<&str>) -> Option<&'static PriceList> {
77        let now = Utc::now().date_naive();
78        self.price_list(variant, now)
79    }
80
81    pub const fn currency(&self) -> Currency {
82        match self.country {
83            Country::SE => Currency::SEK,
84        }
85    }
86
87    pub fn get(country: Country, name: &str) -> Option<&'static Self> {
88        match country {
89            Country::SE => sweden::GRID_OPERATORS
90                .iter()
91                .find(|o| o.name == name)
92                .copied(),
93        }
94    }
95
96    pub fn get_by_vat_id(vat_id: &str) -> Option<&'static Self> {
97        let country: Country = vat_id[0..2].parse().ok()?;
98        match country {
99            Country::SE => sweden::GRID_OPERATORS
100                .iter()
101                .find(|o| o.vat_number == vat_id)
102                .copied(),
103        }
104    }
105
106    pub fn all() -> Vec<&'static Self> {
107        sweden::GRID_OPERATORS.to_vec()
108    }
109
110    pub fn all_for_country(country: Country) -> &'static [&'static Self] {
111        match country {
112            Country::SE => sweden::GRID_OPERATORS,
113        }
114    }
115
116    pub const fn builder() -> GridOperatorBuilder {
117        GridOperatorBuilder::new()
118    }
119
120    pub fn simplified(
121        &self,
122        fuse_size: u16,
123        yearly_consumption: u32,
124        language: Language,
125    ) -> GridOperatorSimplified {
126        GridOperatorSimplified::new(self, fuse_size, yearly_consumption, language)
127    }
128}
129
130/// Grid operator with only current prices
131#[derive(Debug, Clone, Serialize)]
132#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
133pub struct GridOperatorSimplified {
134    name: &'static str,
135    vat_number: &'static str,
136    /// Costs are specified in this currency
137    country: Country,
138    price_lists: Vec<PriceListSimplified>,
139}
140
141impl GridOperatorSimplified {
142    pub fn name(&self) -> &'static str {
143        self.name
144    }
145
146    pub fn vat_number(&self) -> &'static str {
147        self.vat_number
148    }
149
150    pub fn country(&self) -> Country {
151        self.country
152    }
153
154    pub fn price_lists(&self) -> &[PriceListSimplified] {
155        &self.price_lists
156    }
157}
158
159impl GridOperatorSimplified {
160    fn new(op: &GridOperator, fuse_size: u16, yearly_consumption: u32, language: Language) -> Self {
161        Self {
162            name: op.name,
163            vat_number: op.vat_number,
164            country: op.country(),
165            price_lists: op
166                .current_price_lists()
167                .into_iter()
168                .map(|pl| pl.simplified(fuse_size, yearly_consumption, language))
169                .collect(),
170        }
171    }
172}
173
174#[derive(Debug, Clone)]
175pub struct GridOperatorBuilder {
176    name: Option<&'static str>,
177    vat_number: Option<&'static str>,
178    /// Costs are specified in this currency
179    country: Option<Country>,
180    /// The main fuse size range that this info covers
181    main_fuses: Option<MainFuseSizes>,
182    price_lists: Option<&'static [PriceList]>,
183    links: Option<Links>,
184}
185
186impl Default for GridOperatorBuilder {
187    fn default() -> Self {
188        Self::new()
189    }
190}
191
192impl GridOperatorBuilder {
193    pub const fn new() -> Self {
194        Self {
195            name: None,
196            vat_number: None,
197            country: None,
198            main_fuses: None,
199            price_lists: None,
200            links: None,
201        }
202    }
203
204    pub const fn name(mut self, name: &'static str) -> Self {
205        self.name = Some(name);
206        self
207    }
208
209    pub const fn vat_number(mut self, vat_number: &'static str) -> Self {
210        self.vat_number = Some(vat_number);
211        self
212    }
213
214    pub const fn country(mut self, country: Country) -> Self {
215        self.country = Some(country);
216        self
217    }
218
219    pub const fn main_fuses(mut self, main_fuses: MainFuseSizes) -> Self {
220        self.main_fuses = Some(main_fuses);
221        self
222    }
223
224    pub const fn links(mut self, links: Links) -> Self {
225        self.links = Some(links);
226        self
227    }
228
229    pub const fn price_lists(mut self, price_lists: &'static [PriceList]) -> Self {
230        self.price_lists = Some(price_lists);
231        self
232    }
233
234    pub const fn build(self) -> GridOperator {
235        GridOperator {
236            name: self.name.expect("`name` required"),
237            vat_number: self.vat_number.expect("`vat_number` required"),
238            country: self.country.expect("`country` required"),
239            main_fuses: self.main_fuses.expect("`main_fuses` required"),
240            price_lists: self.price_lists.expect("`price_lists` expected"),
241            links: self.links.expect("`links` required"),
242        }
243    }
244}
245
246impl FromStr for &'static GridOperator {
247    type Err = &'static str;
248
249    fn from_str(s: &str) -> Result<Self, Self::Err> {
250        GridOperator::get_by_vat_id(s).ok_or("grid operator not found")
251    }
252}
253
254impl<'de> Deserialize<'de> for &'static GridOperator {
255    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
256    where
257        D: serde::Deserializer<'de>,
258    {
259        let s = String::deserialize(deserializer)?;
260        Self::from_str(&s).map_err(serde::de::Error::custom)
261    }
262}
263
264impl Serialize for GridOperator {
265    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
266    where
267        S: serde::Serializer,
268    {
269        serializer.serialize_str(self.vat_number())
270    }
271}
272
/// Schema mirrors the `Serialize` impl: an operator is a plain string (its VAT number).
#[cfg(feature = "schemars")]
impl JsonSchema for GridOperator {
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("GridOperator")
    }

    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        <String as JsonSchema>::json_schema(generator)
    }
}
283
284impl PartialEq for GridOperator {
285    fn eq(&self, other: &Self) -> bool {
286        self.vat_number == other.vat_number
287    }
288}
289
290impl Eq for GridOperator {}
291
292impl PartialOrd for GridOperator {
293    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
294        Some(self.cmp(other))
295    }
296}
297
298impl Ord for GridOperator {
299    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
300        self.vat_number.cmp(other.vat_number)
301    }
302}
303
304impl Hash for GridOperator {
305    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
306        self.vat_number.hash(state);
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grid_operator_serialize_deserialize() {
        let operator: &GridOperator = "SE556037732601".parse().unwrap();

        // Serialize to JSON string
        let serialized = serde_json::to_string(operator).expect("serialization should succeed");

        // The serialized value should be the VAT number wrapped in quotes
        assert_eq!(serialized, format!("\"{}\"", operator.vat_number()));

        // Deserialize back to GridOperator reference
        let deserialized: &'static GridOperator =
            serde_json::from_str(&serialized).expect("deserialization should succeed");

        // Verify the deserialized operator matches the original
        assert_eq!(deserialized.name(), operator.name());
        assert_eq!(deserialized.vat_number(), operator.vat_number());
        assert_eq!(deserialized.country(), operator.country());
    }

    #[test]
    fn test_grid_operator_deserialize_invalid() {
        // Try to deserialize an invalid VAT number
        let result: Result<&'static GridOperator, _> = serde_json::from_str("\"SE000000000000\"");

        assert!(
            result.is_err(),
            "deserialization should fail for invalid VAT number"
        );
    }

    #[test]
    fn test_grid_operator_partial_eq() {
        let operator1: &GridOperator = "SE556037732601".parse().unwrap();
        let operator2: &GridOperator = "SE556037732601".parse().unwrap();
        let operator3: &GridOperator = "SE556532083401".parse().unwrap();

        // Test equality - same VAT number
        assert_eq!(operator1, operator2);

        // Test inequality - different VAT numbers
        assert_ne!(operator1, operator3);
    }

    #[test]
    fn test_grid_operator_partial_ord() {
        let operator1: &GridOperator = "SE556037732601".parse().unwrap();
        let operator2: &GridOperator = "SE556532083401".parse().unwrap();
        let operator1_again: &GridOperator = "SE556037732601".parse().unwrap();

        // Test ordering based on VAT number
        assert!(operator1 < operator2);
        assert!(operator2 > operator1);
        // Compare against a separate lookup of the same VAT number: comparing an expression
        // with itself is rejected by the deny-by-default `clippy::eq_op` lint.
        assert!(operator1 <= operator1_again);
        assert!(operator1 >= operator1_again);
    }

    #[test]
    fn test_grid_operator_ord() {
        let operator1: &GridOperator = "SE556037732601".parse().unwrap();
        let operator2: &GridOperator = "SE556532083401".parse().unwrap();
        let operator3: &GridOperator = "SE556037732601".parse().unwrap();

        // Test Ord::cmp
        assert_eq!(operator1.cmp(operator2), std::cmp::Ordering::Less);
        assert_eq!(operator2.cmp(operator1), std::cmp::Ordering::Greater);
        assert_eq!(operator1.cmp(operator3), std::cmp::Ordering::Equal);
    }

    #[test]
    fn test_grid_operator_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let operator1: &GridOperator = "SE556037732601".parse().unwrap();
        let operator2: &GridOperator = "SE556037732601".parse().unwrap();
        let operator3: &GridOperator = "SE556532083401".parse().unwrap();

        // Helper function to compute hash
        let compute_hash = |op: &GridOperator| {
            let mut hasher = DefaultHasher::new();
            op.hash(&mut hasher);
            hasher.finish()
        };

        // Operators with same VAT number should have same hash
        assert_eq!(compute_hash(operator1), compute_hash(operator2));

        // Operators with different VAT numbers should have different hashes
        assert_ne!(compute_hash(operator1), compute_hash(operator3));
    }
}