f2q/serialize/
fermions.rs

1use std::marker::PhantomData;
2
3use num::Num;
4use serde::{
5    de::Visitor,
6    ser::SerializeSeq,
7    Deserialize,
8    Serialize,
9};
10
11use crate::{
12    code::fermions::{
13        An,
14        Cr,
15        Fermions,
16        Orbital,
17    },
18    serialize::Encoding,
19    terms::SumRepr,
20};
21
22impl Serialize for Fermions {
23    fn serialize<S>(
24        &self,
25        serializer: S,
26    ) -> Result<S::Ok, S::Error>
27    where
28        S: serde::Serializer,
29    {
30        match self {
31            Fermions::Offset => {
32                let seq = serializer.serialize_seq(Some(0))?;
33                seq.end()
34            }
35            Fermions::One {
36                cr,
37                an,
38            } => {
39                let mut seq = serializer.serialize_seq(Some(2))?;
40                seq.serialize_element(&cr.0.index())?;
41                seq.serialize_element(&an.0.index())?;
42                seq.end()
43            }
44            Fermions::Two {
45                cr,
46                an,
47            } => {
48                let mut seq = serializer.serialize_seq(Some(4))?;
49                seq.serialize_element(&cr.0 .0.index())?;
50                seq.serialize_element(&cr.1 .0.index())?;
51                seq.serialize_element(&an.0 .0.index())?;
52                seq.serialize_element(&an.1 .0.index())?;
53                seq.end()
54            }
55        }
56    }
57}
58
59struct FermionsVisitor;
60
61impl<'de> Visitor<'de> for FermionsVisitor {
62    type Value = Fermions;
63
64    fn expecting(
65        &self,
66        formatter: &mut std::fmt::Formatter,
67    ) -> std::fmt::Result {
68        formatter.write_str("sequence of 0, 2 or 4 orbital indices")
69    }
70
71    fn visit_seq<A>(
72        self,
73        seq: A,
74    ) -> Result<Self::Value, A::Error>
75    where
76        A: serde::de::SeqAccess<'de>,
77    {
78        use serde::de::Error;
79
80        let mut seq = seq;
81        let idx_tup: (Option<u32>, Option<u32>, Option<u32>, Option<u32>) = (
82            seq.next_element()?,
83            seq.next_element()?,
84            seq.next_element()?,
85            seq.next_element()?,
86        );
87
88        match idx_tup {
89            (None, None, None, None) => Ok(Fermions::Offset),
90            (Some(p), Some(q), None, None) => Fermions::one_electron(
91                Cr(Orbital::with_index(p)),
92                An(Orbital::with_index(q)),
93            )
94            .ok_or(A::Error::custom("cannot parse one-electron term")),
95            (Some(p), Some(q), Some(r), Some(s)) => Fermions::two_electron(
96                (Cr(Orbital::with_index(p)), Cr(Orbital::with_index(q))),
97                (An(Orbital::with_index(r)), An(Orbital::with_index(s))),
98            )
99            .ok_or(A::Error::custom("cannot parse two-electron term")),
100            _ => Err(A::Error::custom("cannot parse sequence")),
101        }
102    }
103}
104
105impl<'de> Deserialize<'de> for Fermions {
106    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
107    where
108        D: serde::Deserializer<'de>,
109    {
110        deserializer.deserialize_seq(FermionsVisitor)
111    }
112}
113
/// One entry of a serialized fermionic sum: a `Fermions` code paired with its
/// coefficient.
///
/// Written and read as an object with the keys `code` and `value`. The field
/// names are the wire keys, so renaming or reordering them changes the format.
#[derive(Serialize, Deserialize)]
struct FermiSumTerm<T> {
    code:  Fermions,
    value: T,
}
119
120struct FermiSumSerSequence<'a, T>(&'a SumRepr<T, Fermions>);
121
122impl<'a, T> Serialize for FermiSumSerSequence<'a, T>
123where
124    T: Num + Serialize,
125{
126    fn serialize<S>(
127        &self,
128        serializer: S,
129    ) -> Result<S::Ok, S::Error>
130    where
131        S: serde::Serializer,
132    {
133        let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
134        for (coeff, &code) in self.0.iter() {
135            seq.serialize_element(&FermiSumTerm {
136                code,
137                value: coeff,
138            })?;
139        }
140
141        seq.end()
142    }
143}
144
/// Serialization envelope for a fermionic `SumRepr`: the fields `type`,
/// `encoding` and `terms`, in that order.
///
/// NOTE(review): the `where T: Num` bound appears to be required because the
/// derived `Serialize` impl must serialize `terms`, and the `Serialize` impl of
/// `FermiSumSerSequence` demands `T: Num`. Confirm before removing it.
#[derive(Serialize)]
struct FermiSumSer<'a, T>
where
    T: Num,
{
    r#type:   &'a str,
    encoding: Encoding,
    terms:    FermiSumSerSequence<'a, T>,
}
154
155impl<T> Serialize for SumRepr<T, Fermions>
156where
157    T: Num + Serialize,
158{
159    fn serialize<S>(
160        &self,
161        serializer: S,
162    ) -> Result<S::Ok, S::Error>
163    where
164        S: serde::Serializer,
165    {
166        (FermiSumSer {
167            r#type:   "sumrepr",
168            encoding: Encoding::Fermions,
169            terms:    FermiSumSerSequence(self),
170        })
171        .serialize(serializer)
172    }
173}
174
/// Newtype used to deserialize the `terms` array straight into a fermionic
/// `SumRepr` (unwrapped via `.0` once the envelope has been validated).
struct FermiSumDeSequence<T>(SumRepr<T, Fermions>);
176
/// Serde visitor that folds a sequence of `{code, value}` objects into a sum.
///
/// The coefficient type `T` is tracked only at the type level; the visitor
/// carries no runtime state.
struct FermiSumVisitor<T> {
    _marker: PhantomData<T>,
}

impl<T> FermiSumVisitor<T> {
    /// Builds a fresh visitor for coefficient type `T`.
    fn new() -> Self {
        let marker = PhantomData;
        Self {
            _marker: marker,
        }
    }
}
188
189impl<'de, T> Visitor<'de> for FermiSumVisitor<T>
190where
191    T: Num + Deserialize<'de>,
192{
193    type Value = FermiSumDeSequence<T>;
194
195    fn expecting(
196        &self,
197        formatter: &mut std::fmt::Formatter,
198    ) -> std::fmt::Result {
199        write!(formatter, "sequence of objects with keys: 'code', 'value'")
200    }
201
202    fn visit_seq<A>(
203        self,
204        seq: A,
205    ) -> Result<Self::Value, A::Error>
206    where
207        A: serde::de::SeqAccess<'de>,
208    {
209        let mut seq = seq;
210        let mut repr = SumRepr::new();
211
212        while let Some(FermiSumTerm {
213            code,
214            value,
215        }) = seq.next_element()?
216        {
217            repr.add_term(code, value);
218        }
219
220        Ok(FermiSumDeSequence(repr))
221    }
222}
223
224impl<'de, T> Deserialize<'de> for FermiSumDeSequence<T>
225where
226    T: Num + Deserialize<'de>,
227{
228    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
229    where
230        D: serde::Deserializer<'de>,
231    {
232        deserializer.deserialize_seq(FermiSumVisitor::new())
233    }
234}
235
/// Deserialization counterpart of the `{type, encoding, terms}` envelope.
///
/// `type` and `encoding` are read as-is here. They are checked afterwards by the
/// `Deserialize` impl for `SumRepr<T, Fermions>`.
///
/// NOTE(review): the `where T: Num` bound appears to be needed by the derived
/// impl, because `FermiSumDeSequence<T>` only implements `Deserialize` for
/// `T: Num`. Confirm before removing it.
#[derive(Deserialize)]
struct FermiSumDe<T>
where
    T: Num,
{
    r#type:   String,
    encoding: Encoding,
    terms:    FermiSumDeSequence<T>,
}
245
246impl<'de, T> Deserialize<'de> for SumRepr<T, Fermions>
247where
248    T: Num + Deserialize<'de>,
249{
250    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
251    where
252        D: serde::Deserializer<'de>,
253    {
254        use serde::de::Error;
255
256        let sumde = FermiSumDe::deserialize(deserializer)?;
257
258        if sumde.r#type != "sumrepr" {
259            return Err(D::Error::custom("type should be: 'sumrepr'"));
260        }
261
262        if sumde.encoding != Encoding::Fermions {
263            return Err(D::Error::custom("encoding should be: 'fermions'"));
264        }
265
266        Ok(sumde.terms.0)
267    }
268}