grid_tariffs/
operator.rs

1use std::{hash::Hash, str::FromStr};
2
3use chrono::{NaiveDate, Utc};
4use indexmap::IndexMap;
5#[cfg(feature = "schemars")]
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use crate::{
10    Country, Currency, Language, Links, PriceList, price_list::PriceListSimplified,
11    registry::sweden,
12};
13
14#[derive(Debug, Clone)]
15pub struct GridOperator {
16    name: &'static str,
17    vat_number: &'static str,
18    /// Costs are specified in this currency
19    country: Country,
20    price_lists: &'static [PriceList],
21    links: Links,
22}
23
24impl GridOperator {
25    pub const fn name(&self) -> &str {
26        self.name
27    }
28
29    pub const fn vat_number(&self) -> &str {
30        self.vat_number
31    }
32
33    pub const fn country(&self) -> Country {
34        self.country
35    }
36
37    pub const fn links(&self) -> &Links {
38        &self.links
39    }
40
41    pub fn all_price_lists(&self) -> &'static [PriceList] {
42        self.price_lists
43    }
44
45    pub fn price_lists(&self, date: NaiveDate) -> Vec<&'static PriceList> {
46        let mut map: IndexMap<Option<&str>, &PriceList> = IndexMap::new();
47        for pl in self.price_lists {
48            if date >= pl.from_date() {
49                if let Some(current_max_date) = map.get(&pl.variant()).map(|pl| pl.from_date()) {
50                    if pl.from_date() > current_max_date {
51                        map.insert(pl.variant(), pl);
52                    }
53                } else {
54                    map.insert(pl.variant(), pl);
55                }
56            }
57        }
58        map.into_values().collect()
59    }
60
61    pub fn price_list(&self, variant: Option<&str>, date: NaiveDate) -> Option<&'static PriceList> {
62        self.price_lists(date)
63            .iter()
64            .filter(|pl| pl.variant() == variant)
65            .next_back()
66            .copied()
67    }
68
69    pub fn current_price_lists(&self) -> Vec<&'static PriceList> {
70        let now = Utc::now().date_naive();
71        self.price_lists(now)
72    }
73
74    pub fn current_price_list(&self, variant: Option<&str>) -> Option<&'static PriceList> {
75        let now = Utc::now().date_naive();
76        self.price_list(variant, now)
77    }
78
79    pub const fn currency(&self) -> Currency {
80        match self.country {
81            Country::SE => Currency::SEK,
82        }
83    }
84
85    pub fn get(country: Country, name: &str) -> Option<&'static Self> {
86        match country {
87            Country::SE => sweden::GRID_OPERATORS
88                .iter()
89                .find(|o| o.name == name)
90                .copied(),
91        }
92    }
93
94    pub fn get_by_vat_id(vat_id: &str) -> Option<&'static Self> {
95        let country: Country = vat_id[0..2].parse().ok()?;
96        match country {
97            Country::SE => sweden::GRID_OPERATORS
98                .iter()
99                .find(|o| o.vat_number == vat_id)
100                .copied(),
101        }
102    }
103
104    pub fn all() -> Vec<&'static Self> {
105        sweden::GRID_OPERATORS.to_vec()
106    }
107
108    pub fn all_for_country(country: Country) -> &'static [&'static Self] {
109        match country {
110            Country::SE => sweden::GRID_OPERATORS,
111        }
112    }
113
114    pub const fn builder() -> GridOperatorBuilder {
115        GridOperatorBuilder::new()
116    }
117
118    pub fn simplified(
119        &self,
120        fuse_size: u16,
121        yearly_consumption: u32,
122        language: Language,
123    ) -> GridOperatorSimplified {
124        GridOperatorSimplified::new(self, fuse_size, yearly_consumption, language)
125    }
126}
127
128/// Grid operator with only current prices
129#[derive(Debug, Clone, Serialize)]
130#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
131pub struct GridOperatorSimplified {
132    name: &'static str,
133    vat_number: &'static str,
134    /// Costs are specified in this currency
135    country: Country,
136    price_lists: Vec<PriceListSimplified>,
137}
138
139impl GridOperatorSimplified {
140    pub fn name(&self) -> &'static str {
141        self.name
142    }
143
144    pub fn vat_number(&self) -> &'static str {
145        self.vat_number
146    }
147
148    pub fn country(&self) -> Country {
149        self.country
150    }
151
152    pub fn price_lists(&self) -> &[PriceListSimplified] {
153        &self.price_lists
154    }
155}
156
157impl GridOperatorSimplified {
158    fn new(op: &GridOperator, fuse_size: u16, yearly_consumption: u32, language: Language) -> Self {
159        Self {
160            name: op.name,
161            vat_number: op.vat_number,
162            country: op.country(),
163            price_lists: op
164                .current_price_lists()
165                .into_iter()
166                .map(|pl| pl.simplified(op, fuse_size, yearly_consumption, language))
167                .collect(),
168        }
169    }
170}
171
172#[derive(Debug, Clone)]
173pub struct GridOperatorBuilder {
174    name: Option<&'static str>,
175    vat_number: Option<&'static str>,
176    /// Costs are specified in this currency
177    country: Option<Country>,
178    price_lists: Option<&'static [PriceList]>,
179    links: Option<Links>,
180}
181
182impl Default for GridOperatorBuilder {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
188impl GridOperatorBuilder {
189    pub const fn new() -> Self {
190        Self {
191            name: None,
192            vat_number: None,
193            country: None,
194            price_lists: None,
195            links: None,
196        }
197    }
198
199    pub const fn name(mut self, name: &'static str) -> Self {
200        self.name = Some(name);
201        self
202    }
203
204    pub const fn vat_number(mut self, vat_number: &'static str) -> Self {
205        self.vat_number = Some(vat_number);
206        self
207    }
208
209    pub const fn country(mut self, country: Country) -> Self {
210        self.country = Some(country);
211        self
212    }
213
214    pub const fn links(mut self, links: Links) -> Self {
215        self.links = Some(links);
216        self
217    }
218
219    pub const fn price_lists(mut self, price_lists: &'static [PriceList]) -> Self {
220        self.price_lists = Some(price_lists);
221        self
222    }
223
224    pub const fn build(self) -> GridOperator {
225        GridOperator {
226            name: self.name.expect("`name` required"),
227            vat_number: self.vat_number.expect("`vat_number` required"),
228            country: self.country.expect("`country` required"),
229            price_lists: self.price_lists.expect("`price_lists` expected"),
230            links: self.links.expect("`links` required"),
231        }
232    }
233}
234
235impl FromStr for &'static GridOperator {
236    type Err = &'static str;
237
238    fn from_str(s: &str) -> Result<Self, Self::Err> {
239        GridOperator::get_by_vat_id(s).ok_or("grid operator not found")
240    }
241}
242
243impl<'de> Deserialize<'de> for &'static GridOperator {
244    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
245    where
246        D: serde::Deserializer<'de>,
247    {
248        let s = String::deserialize(deserializer)?;
249        Self::from_str(&s).map_err(serde::de::Error::custom)
250    }
251}
252
253impl Serialize for GridOperator {
254    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
255    where
256        S: serde::Serializer,
257    {
258        serializer.serialize_str(self.vat_number())
259    }
260}
261
#[cfg(feature = "schemars")]
impl JsonSchema for GridOperator {
    /// Name under which schemars registers this schema.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("GridOperator")
    }

    /// An operator is serialized as its VAT number, so the schema is that of a string.
    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        <String as JsonSchema>::json_schema(generator)
    }
}
272
273impl PartialEq for GridOperator {
274    fn eq(&self, other: &Self) -> bool {
275        self.vat_number == other.vat_number
276    }
277}
278
279impl Eq for GridOperator {}
280
281impl PartialOrd for GridOperator {
282    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
283        Some(self.cmp(other))
284    }
285}
286
287impl Ord for GridOperator {
288    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
289        self.vat_number.cmp(other.vat_number)
290    }
291}
292
293impl Hash for GridOperator {
294    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
295        self.vat_number.hash(state);
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// VAT number of an operator in the Swedish registry; sorts before `VAT_B`.
    const VAT_A: &str = "SE556037732601";
    /// VAT number of a second, different operator in the Swedish registry.
    const VAT_B: &str = "SE556532083401";

    /// Resolves a registered operator by VAT number, panicking if it is missing.
    fn lookup(vat: &str) -> &'static GridOperator {
        vat.parse().expect("operator should be registered")
    }

    #[test]
    fn test_grid_operator_serialize_deserialize() {
        let operator = lookup(VAT_A);

        // Serializes to the bare VAT number as a JSON string.
        let serialized = serde_json::to_string(operator).expect("serialization should succeed");
        assert_eq!(serialized, format!("\"{}\"", operator.vat_number()));

        // Deserializes back to the same registry entry.
        let deserialized: &'static GridOperator =
            serde_json::from_str(&serialized).expect("deserialization should succeed");
        assert_eq!(deserialized.name(), operator.name());
        assert_eq!(deserialized.vat_number(), operator.vat_number());
        assert_eq!(deserialized.country(), operator.country());
    }

    #[test]
    fn test_grid_operator_deserialize_invalid() {
        // Well-formed but unregistered VAT number.
        let result: Result<&'static GridOperator, _> = serde_json::from_str("\"SE000000000000\"");

        assert!(
            result.is_err(),
            "deserialization should fail for invalid VAT number"
        );
    }

    #[test]
    fn test_grid_operator_partial_eq() {
        let a = lookup(VAT_A);
        let a_again = lookup(VAT_A);
        let b = lookup(VAT_B);

        // Same VAT number.
        assert_eq!(a, a_again);
        // Different VAT numbers.
        assert_ne!(a, b);
    }

    #[test]
    fn test_grid_operator_partial_ord() {
        let a = lookup(VAT_A);
        let b = lookup(VAT_B);
        // Bound separately so the reflexive checks below do not compare an
        // expression with itself, which clippy's deny-by-default `eq_op` rejects.
        let a_again = lookup(VAT_A);

        assert!(a < b);
        assert!(b > a);
        assert!(a <= a_again);
        assert!(a >= a_again);
    }

    #[test]
    fn test_grid_operator_ord() {
        let a = lookup(VAT_A);
        let b = lookup(VAT_B);
        let a_again = lookup(VAT_A);

        assert_eq!(a.cmp(b), std::cmp::Ordering::Less);
        assert_eq!(b.cmp(a), std::cmp::Ordering::Greater);
        assert_eq!(a.cmp(a_again), std::cmp::Ordering::Equal);
    }

    #[test]
    fn test_grid_operator_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let a = lookup(VAT_A);
        let a_again = lookup(VAT_A);
        let b = lookup(VAT_B);

        let compute_hash = |op: &GridOperator| {
            let mut hasher = DefaultHasher::new();
            op.hash(&mut hasher);
            hasher.finish()
        };

        // Same VAT number => same hash.
        assert_eq!(compute_hash(a), compute_hash(a_again));
        // Different VAT numbers => different hash (holds for these two inputs).
        assert_ne!(compute_hash(a), compute_hash(b));
    }
}