1use std::marker::PhantomData;
2
3use num::Num;
4use serde::{
5 de::Visitor,
6 ser::SerializeSeq,
7 Deserialize,
8 Serialize,
9};
10
11use crate::{
12 code::fermions::{
13 An,
14 Cr,
15 Fermions,
16 Orbital,
17 },
18 serialize::Encoding,
19 terms::SumRepr,
20};
21
22impl Serialize for Fermions {
23 fn serialize<S>(
24 &self,
25 serializer: S,
26 ) -> Result<S::Ok, S::Error>
27 where
28 S: serde::Serializer,
29 {
30 match self {
31 Fermions::Offset => {
32 let seq = serializer.serialize_seq(Some(0))?;
33 seq.end()
34 }
35 Fermions::One {
36 cr,
37 an,
38 } => {
39 let mut seq = serializer.serialize_seq(Some(2))?;
40 seq.serialize_element(&cr.0.index())?;
41 seq.serialize_element(&an.0.index())?;
42 seq.end()
43 }
44 Fermions::Two {
45 cr,
46 an,
47 } => {
48 let mut seq = serializer.serialize_seq(Some(4))?;
49 seq.serialize_element(&cr.0 .0.index())?;
50 seq.serialize_element(&cr.1 .0.index())?;
51 seq.serialize_element(&an.0 .0.index())?;
52 seq.serialize_element(&an.1 .0.index())?;
53 seq.end()
54 }
55 }
56 }
57}
58
59struct FermionsVisitor;
60
61impl<'de> Visitor<'de> for FermionsVisitor {
62 type Value = Fermions;
63
64 fn expecting(
65 &self,
66 formatter: &mut std::fmt::Formatter,
67 ) -> std::fmt::Result {
68 formatter.write_str("sequence of 0, 2 or 4 orbital indices")
69 }
70
71 fn visit_seq<A>(
72 self,
73 seq: A,
74 ) -> Result<Self::Value, A::Error>
75 where
76 A: serde::de::SeqAccess<'de>,
77 {
78 use serde::de::Error;
79
80 let mut seq = seq;
81 let idx_tup: (Option<u32>, Option<u32>, Option<u32>, Option<u32>) = (
82 seq.next_element()?,
83 seq.next_element()?,
84 seq.next_element()?,
85 seq.next_element()?,
86 );
87
88 match idx_tup {
89 (None, None, None, None) => Ok(Fermions::Offset),
90 (Some(p), Some(q), None, None) => Fermions::one_electron(
91 Cr(Orbital::with_index(p)),
92 An(Orbital::with_index(q)),
93 )
94 .ok_or(A::Error::custom("cannot parse one-electron term")),
95 (Some(p), Some(q), Some(r), Some(s)) => Fermions::two_electron(
96 (Cr(Orbital::with_index(p)), Cr(Orbital::with_index(q))),
97 (An(Orbital::with_index(r)), An(Orbital::with_index(s))),
98 )
99 .ok_or(A::Error::custom("cannot parse two-electron term")),
100 _ => Err(A::Error::custom("cannot parse sequence")),
101 }
102 }
103}
104
105impl<'de> Deserialize<'de> for Fermions {
106 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
107 where
108 D: serde::Deserializer<'de>,
109 {
110 deserializer.deserialize_seq(FermionsVisitor)
111 }
112}
113
114#[derive(Serialize, Deserialize)]
115struct FermiSumTerm<T> {
116 code: Fermions,
117 value: T,
118}
119
120struct FermiSumSerSequence<'a, T>(&'a SumRepr<T, Fermions>);
121
122impl<'a, T> Serialize for FermiSumSerSequence<'a, T>
123where
124 T: Num + Serialize,
125{
126 fn serialize<S>(
127 &self,
128 serializer: S,
129 ) -> Result<S::Ok, S::Error>
130 where
131 S: serde::Serializer,
132 {
133 let mut seq = serializer.serialize_seq(Some(self.0.len()))?;
134 for (coeff, &code) in self.0.iter() {
135 seq.serialize_element(&FermiSumTerm {
136 code,
137 value: coeff,
138 })?;
139 }
140
141 seq.end()
142 }
143}
144
145#[derive(Serialize)]
146struct FermiSumSer<'a, T>
147where
148 T: Num,
149{
150 r#type: &'a str,
151 encoding: Encoding,
152 terms: FermiSumSerSequence<'a, T>,
153}
154
155impl<T> Serialize for SumRepr<T, Fermions>
156where
157 T: Num + Serialize,
158{
159 fn serialize<S>(
160 &self,
161 serializer: S,
162 ) -> Result<S::Ok, S::Error>
163 where
164 S: serde::Serializer,
165 {
166 (FermiSumSer {
167 r#type: "sumrepr",
168 encoding: Encoding::Fermions,
169 terms: FermiSumSerSequence(self),
170 })
171 .serialize(serializer)
172 }
173}
174
175struct FermiSumDeSequence<T>(SumRepr<T, Fermions>);
176
177struct FermiSumVisitor<T> {
178 _marker: PhantomData<T>,
179}
180
181impl<T> FermiSumVisitor<T> {
182 fn new() -> Self {
183 Self {
184 _marker: PhantomData,
185 }
186 }
187}
188
189impl<'de, T> Visitor<'de> for FermiSumVisitor<T>
190where
191 T: Num + Deserialize<'de>,
192{
193 type Value = FermiSumDeSequence<T>;
194
195 fn expecting(
196 &self,
197 formatter: &mut std::fmt::Formatter,
198 ) -> std::fmt::Result {
199 write!(formatter, "sequence of objects with keys: 'code', 'value'")
200 }
201
202 fn visit_seq<A>(
203 self,
204 seq: A,
205 ) -> Result<Self::Value, A::Error>
206 where
207 A: serde::de::SeqAccess<'de>,
208 {
209 let mut seq = seq;
210 let mut repr = SumRepr::new();
211
212 while let Some(FermiSumTerm {
213 code,
214 value,
215 }) = seq.next_element()?
216 {
217 repr.add_term(code, value);
218 }
219
220 Ok(FermiSumDeSequence(repr))
221 }
222}
223
224impl<'de, T> Deserialize<'de> for FermiSumDeSequence<T>
225where
226 T: Num + Deserialize<'de>,
227{
228 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
229 where
230 D: serde::Deserializer<'de>,
231 {
232 deserializer.deserialize_seq(FermiSumVisitor::new())
233 }
234}
235
236#[derive(Deserialize)]
237struct FermiSumDe<T>
238where
239 T: Num,
240{
241 r#type: String,
242 encoding: Encoding,
243 terms: FermiSumDeSequence<T>,
244}
245
246impl<'de, T> Deserialize<'de> for SumRepr<T, Fermions>
247where
248 T: Num + Deserialize<'de>,
249{
250 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
251 where
252 D: serde::Deserializer<'de>,
253 {
254 use serde::de::Error;
255
256 let sumde = FermiSumDe::deserialize(deserializer)?;
257
258 if sumde.r#type != "sumrepr" {
259 return Err(D::Error::custom("type should be: 'sumrepr'"));
260 }
261
262 if sumde.encoding != Encoding::Fermions {
263 return Err(D::Error::custom("encoding should be: 'fermions'"));
264 }
265
266 Ok(sumde.terms.0)
267 }
268}