1use super::{BinaryParameters, BinaryRecord, GroupCount, PureParameters};
2use crate::{FeosResult, parameter::PureRecord};
3use nalgebra::DVector;
4use num_traits::Zero;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Serialize, Deserialize, Clone, Debug)]
10pub struct AssociationRecord<A> {
11 #[serde(skip_serializing_if = "String::is_empty")]
12 #[serde(default)]
13 pub id: String,
14 #[serde(flatten)]
15 pub parameters: Option<A>,
16 #[serde(skip_serializing_if = "f64::is_zero")]
18 #[serde(default)]
19 pub na: f64,
20 #[serde(skip_serializing_if = "f64::is_zero")]
22 #[serde(default)]
23 pub nb: f64,
24 #[serde(skip_serializing_if = "f64::is_zero")]
26 #[serde(default)]
27 pub nc: f64,
28}
29
30impl<A> AssociationRecord<A> {
31 pub fn new(parameters: Option<A>, na: f64, nb: f64, nc: f64) -> Self {
32 Self::with_id(Default::default(), parameters, na, nb, nc)
33 }
34
35 pub fn with_id(id: String, parameters: Option<A>, na: f64, nb: f64, nc: f64) -> Self {
36 Self {
37 id,
38 parameters,
39 na,
40 nb,
41 nc,
42 }
43 }
44}
45
46#[derive(Serialize, Deserialize, Clone, Debug)]
48pub struct BinaryAssociationRecord<A> {
49 #[serde(skip_serializing_if = "String::is_empty")]
51 #[serde(default)]
52 pub id1: String,
53 #[serde(skip_serializing_if = "String::is_empty")]
55 #[serde(default)]
56 pub id2: String,
57 #[serde(flatten)]
59 pub parameters: A,
60}
61
62impl<A> BinaryAssociationRecord<A> {
63 pub fn new(parameters: A) -> Self {
64 Self::with_id(Default::default(), Default::default(), parameters)
65 }
66
67 pub fn with_id(id1: String, id2: String, parameters: A) -> Self {
68 Self {
69 id1,
70 id2,
71 parameters,
72 }
73 }
74}
75
/// A single association site attached to a component (or group).
#[derive(Debug, Clone)]
pub struct AssociationSite {
    /// Index of the component/group carrying this site.
    pub assoc_comp: usize,
    /// Site identifier, matched against the ids of explicit binary records.
    pub id: String,
    /// Multiplicity of the site (scaled by the group count for group contributions).
    pub n: f64,
}

impl AssociationSite {
    fn new(assoc_comp: usize, id: String, n: f64) -> Self {
        AssociationSite { assoc_comp, id, n }
    }
}
88
/// Rule producing cross parameters for a pair of association sites when no
/// explicit binary record is available.
///
/// `comp_i`/`comp_j` are the model records of the components (or groups)
/// carrying the two sites; `parameters_i`/`parameters_j` are the sites' own
/// parameters.
pub trait CombiningRule<P> {
    fn combining_rule(comp_i: &P, comp_j: &P, parameters_i: &Self, parameters_j: &Self) -> Self;
}

/// Parameter-less sites (`()`) combine trivially to `()`.
impl<P> CombiningRule<P> for () {
    fn combining_rule(_comp_i: &P, _comp_j: &P, _par_i: &Self, _par_j: &Self) -> Self {}
}
96
97#[derive(Clone)]
100pub struct AssociationParameters<A> {
101 pub component_index: DVector<usize>,
102 pub sites_a: Vec<AssociationSite>,
103 pub sites_b: Vec<AssociationSite>,
104 pub sites_c: Vec<AssociationSite>,
105 pub binary_ab: Vec<BinaryParameters<A, ()>>,
106 pub binary_cc: Vec<BinaryParameters<A, ()>>,
107}
108
109impl<A: Clone> AssociationParameters<A> {
110 pub fn new<P, B>(
111 pure_records: &[PureRecord<P, A>],
112 binary_records: &[BinaryRecord<usize, B, A>],
113 ) -> FeosResult<Self>
114 where
115 A: CombiningRule<P>,
116 {
117 let mut sites_a = Vec::new();
118 let mut sites_b = Vec::new();
119 let mut sites_c = Vec::new();
120 let mut pars_a = Vec::new();
121 let mut pars_b = Vec::new();
122 let mut pars_c = Vec::new();
123
124 for (i, record) in pure_records.iter().enumerate() {
125 for site in record.association_sites.iter() {
126 if site.na > 0.0 {
127 sites_a.push(AssociationSite::new(i, site.id.clone(), site.na));
128 pars_a.push(&site.parameters);
129 }
130 if site.nb > 0.0 {
131 sites_b.push(AssociationSite::new(i, site.id.clone(), site.nb));
132 pars_b.push(&site.parameters);
133 }
134 if site.nc > 0.0 {
135 sites_c.push(AssociationSite::new(i, site.id.clone(), site.nc));
136 pars_c.push(&site.parameters);
137 }
138 }
139 }
140
141 let record_map: HashMap<_, _> = binary_records
142 .iter()
143 .flat_map(|br| {
144 br.association_sites.iter().flat_map(|a| {
145 [
146 ((br.id1, br.id2, &a.id1, &a.id2), &a.parameters),
147 ((br.id2, br.id1, &a.id2, &a.id1), &a.parameters),
148 ]
149 })
150 })
151 .collect();
152
153 let mut binary_ab = Vec::new();
154 for ((a, site_a), pa) in sites_a.iter().enumerate().zip(&pars_a) {
155 for ((b, site_b), pb) in sites_b.iter().enumerate().zip(&pars_b) {
156 if let Some(&record) =
157 record_map.get(&(site_a.assoc_comp, site_b.assoc_comp, &site_a.id, &site_b.id))
158 {
159 binary_ab.push(BinaryParameters::new(a, b, record.clone(), ()));
160 } else if let (Some(pa), Some(pb)) = (pa, pb) {
161 binary_ab.push(BinaryParameters::new(
162 a,
163 b,
164 A::combining_rule(
165 &pure_records[site_a.assoc_comp].model_record,
166 &pure_records[site_b.assoc_comp].model_record,
167 pa,
168 pb,
169 ),
170 (),
171 ));
172 }
173 }
174 }
175
176 let mut binary_cc = Vec::new();
177 for ((a, site_a), pa) in sites_c.iter().enumerate().zip(&pars_c) {
178 for ((b, site_b), pb) in sites_c.iter().enumerate().zip(&pars_c) {
179 if let Some(&record) =
180 record_map.get(&(site_a.assoc_comp, site_b.assoc_comp, &site_a.id, &site_b.id))
181 {
182 binary_cc.push(BinaryParameters::new(a, b, record.clone(), ()));
183 } else if let (Some(pa), Some(pb)) = (pa, pb) {
184 binary_cc.push(BinaryParameters::new(
185 a,
186 b,
187 A::combining_rule(
188 &pure_records[site_a.assoc_comp].model_record,
189 &pure_records[site_b.assoc_comp].model_record,
190 pa,
191 pb,
192 ),
193 (),
194 ));
195 }
196 }
197 }
198
199 let component_index = DVector::from_vec((0..pure_records.len()).collect());
200
201 Ok(Self {
202 component_index,
203 sites_a,
204 sites_b,
205 sites_c,
206 binary_ab,
207 binary_cc,
208 })
209 }
210
211 pub fn new_hetero<P, C: GroupCount>(
212 groups: &[PureParameters<P, C>],
213 association_sites: &[Vec<AssociationRecord<A>>],
214 binary_records: &[BinaryParameters<Vec<BinaryAssociationRecord<A>>, ()>],
215 ) -> FeosResult<Self>
216 where
217 A: CombiningRule<P>,
218 {
219 let mut sites_a = Vec::new();
220 let mut sites_b = Vec::new();
221 let mut sites_c = Vec::new();
222 let mut pars_a = Vec::new();
223 let mut pars_b = Vec::new();
224 let mut pars_c = Vec::new();
225
226 for (i, (record, sites)) in groups.iter().zip(association_sites).enumerate() {
227 for site in sites.iter() {
228 if site.na > 0.0 {
229 let na = site.na * record.count.into_f64();
230 sites_a.push(AssociationSite::new(i, site.id.clone(), na));
231 pars_a.push(&site.parameters)
232 }
233 if site.nb > 0.0 {
234 let nb = site.nb * record.count.into_f64();
235 sites_b.push(AssociationSite::new(i, site.id.clone(), nb));
236 pars_b.push(&site.parameters)
237 }
238 if site.nc > 0.0 {
239 let nc = site.nc * record.count.into_f64();
240 sites_c.push(AssociationSite::new(i, site.id.clone(), nc));
241 pars_c.push(&site.parameters)
242 }
243 }
244 }
245
246 let record_map: HashMap<_, _> = binary_records
247 .iter()
248 .flat_map(|br| {
249 br.model_record.iter().flat_map(|a| {
250 [
251 ((br.id1, br.id2, &a.id1, &a.id2), &a.parameters),
252 ((br.id2, br.id1, &a.id2, &a.id1), &a.parameters),
253 ]
254 })
255 })
256 .collect();
257
258 let mut binary_ab = Vec::new();
259 for ((a, site_a), pa) in sites_a.iter().enumerate().zip(&pars_a) {
260 for ((b, site_b), pb) in sites_b.iter().enumerate().zip(&pars_b) {
261 if let Some(&record) =
262 record_map.get(&(site_a.assoc_comp, site_b.assoc_comp, &site_a.id, &site_b.id))
263 {
264 binary_ab.push(BinaryParameters::new(a, b, record.clone(), ()));
265 } else if let (Some(pa), Some(pb)) = (pa, pb) {
266 binary_ab.push(BinaryParameters::new(
267 a,
268 b,
269 A::combining_rule(
270 &groups[site_a.assoc_comp].model_record,
271 &groups[site_b.assoc_comp].model_record,
272 pa,
273 pb,
274 ),
275 (),
276 ));
277 }
278 }
279 }
280
281 let mut binary_cc = Vec::new();
282 for ((a, site_a), pa) in sites_c.iter().enumerate().zip(&pars_c) {
283 for ((b, site_b), pb) in sites_c.iter().enumerate().zip(&pars_c) {
284 if let Some(&record) =
285 record_map.get(&(site_a.assoc_comp, site_b.assoc_comp, &site_a.id, &site_b.id))
286 {
287 binary_cc.push(BinaryParameters::new(a, b, record.clone(), ()));
288 } else if let (Some(pa), Some(pb)) = (pa, pb) {
289 binary_cc.push(BinaryParameters::new(
290 a,
291 b,
292 A::combining_rule(
293 &groups[site_a.assoc_comp].model_record,
294 &groups[site_b.assoc_comp].model_record,
295 pa,
296 pb,
297 ),
298 (),
299 ));
300 }
301 }
302 }
303
304 let component_index =
305 DVector::from_vec(groups.iter().map(|pr| pr.component_index).collect());
306
307 Ok(Self {
308 component_index,
309 sites_a,
310 sites_b,
311 sites_c,
312 binary_ab,
313 binary_cc,
314 })
315 }
316
317 pub fn is_empty(&self) -> bool {
318 (self.sites_a.is_empty() | self.sites_b.is_empty()) & self.sites_c.is_empty()
319 }
320}