1use std::{collections::BTreeSet, iter::FromIterator};
2use aces::{ContextHandle, NodeID};
3use crate::{Node, ToNode, NodeList};
4
/// A polynomial over [`Node`]s, stored as a sum of products.
///
/// Each monomial is a set of nodes, so a node appears at most once per monomial and
/// identical monomials collapse into one.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Polynomial {
    /// The monomials (products of nodes) whose sum this polynomial represents.
    pub(crate) monomials: BTreeSet<BTreeSet<Node>>,

    /// Construction-time flag. It is `true` when the polynomial was built as a single
    /// product: a node, a node list, the default value, or a product of flat factors.
    /// Addition and multiplication by a non-flat factor clear it. Note that `false`
    /// does not guarantee that more than one monomial is present.
    pub(crate) is_flat: bool,
}
21
22impl Polynomial {
23 pub(crate) fn with_product_multiplied(mut self, mut factors: Vec<Self>) -> Self {
25 self.multiply_assign(&mut factors);
26 self
27 }
28
29 pub(crate) fn with_product_added(mut self, mut factors: Vec<Self>) -> Self {
31 if let Some((head, tail)) = factors.split_first_mut() {
32 head.multiply_assign(tail);
33 self.add_assign(head);
34 }
35 self
36 }
37
38 pub(crate) fn flattened_clone(&self) -> Self {
42 if self.is_flat {
43 self.clone()
44 } else {
45 let mut more_monos = self.monomials.iter();
46 let mut single_mono = more_monos.next().expect("non-flat empty polynomial").clone();
47
48 for mono in more_monos {
49 single_mono.append(&mut mono.clone());
50 }
51
52 Polynomial { monomials: BTreeSet::from_iter(Some(single_mono)), is_flat: true }
53 }
54 }
55
56 fn multiply_assign(&mut self, factors: &mut [Self]) {
57 for factor in factors {
58 if !factor.is_flat {
59 self.is_flat = false;
60 }
61
62 let lhs: Vec<_> = self.monomials.iter().cloned().collect();
63 self.monomials.clear();
64
65 for this_mono in lhs.iter() {
66 for other_mono in factor.monomials.iter() {
67 let mut mono = this_mono.clone();
68 mono.extend(other_mono.iter().cloned());
69 self.monomials.insert(mono);
70 }
71 }
72 }
73 }
74
75 pub(crate) fn add_assign(&mut self, other: &mut Self) {
76 self.is_flat = false;
77 self.monomials.append(&mut other.monomials);
78 }
79
80 pub(crate) fn compile_as_vec(&self, ctx: &ContextHandle) -> Vec<Vec<NodeID>> {
81 let mut ctx = ctx.lock().unwrap();
82
83 self.monomials
84 .iter()
85 .map(|mono| mono.iter().map(|node| ctx.share_node_name(node)).collect())
86 .collect()
87 }
88}
89
90impl Default for Polynomial {
91 fn default() -> Self {
92 Polynomial { monomials: BTreeSet::default(), is_flat: true }
93 }
94}
95
96impl From<Node> for Polynomial {
97 fn from(node: Node) -> Self {
98 Polynomial {
99 monomials: BTreeSet::from_iter(Some(BTreeSet::from_iter(Some(node)))),
100 is_flat: true,
101 }
102 }
103}
104
105impl From<&str> for Polynomial {
107 fn from(node: &str) -> Self {
108 Polynomial {
109 monomials: BTreeSet::from_iter(Some(BTreeSet::from_iter(Some(node.to_node())))),
110 is_flat: true,
111 }
112 }
113}
114
115impl From<Vec<Node>> for Polynomial {
116 fn from(mono: Vec<Node>) -> Self {
117 Polynomial {
118 monomials: BTreeSet::from_iter(Some(BTreeSet::from_iter(mono.iter().cloned()))),
119 is_flat: true,
120 }
121 }
122}
123
124impl From<Vec<&str>> for Polynomial {
126 fn from(mono: Vec<&str>) -> Self {
127 Polynomial {
128 monomials: BTreeSet::from_iter(Some(BTreeSet::from_iter(
129 mono.iter().map(|n| n.to_node()),
130 ))),
131 is_flat: true,
132 }
133 }
134}
135
136impl From<Vec<Vec<Node>>> for Polynomial {
137 fn from(monos: Vec<Vec<Node>>) -> Self {
138 Polynomial {
139 monomials: BTreeSet::from_iter(
140 monos.into_iter().map(|mono| BTreeSet::from_iter(mono.iter().cloned())),
141 ),
142 is_flat: false,
143 }
144 }
145}
146
147impl From<Vec<Vec<&str>>> for Polynomial {
149 fn from(monos: Vec<Vec<&str>>) -> Self {
150 Polynomial {
151 monomials: BTreeSet::from_iter(
152 monos.into_iter().map(|mono| BTreeSet::from_iter(mono.iter().map(|n| n.to_node()))),
153 ),
154 is_flat: false,
155 }
156 }
157}
158
159impl From<NodeList> for Polynomial {
160 fn from(mono: NodeList) -> Self {
161 Polynomial {
162 monomials: BTreeSet::from_iter(Some(BTreeSet::from_iter(mono.nodes.iter().cloned()))),
163 is_flat: true,
164 }
165 }
166}
167
168impl From<Vec<NodeList>> for Polynomial {
169 fn from(monos: Vec<NodeList>) -> Self {
170 Polynomial {
171 monomials: BTreeSet::from_iter(
172 monos.into_iter().map(|mono| BTreeSet::from_iter(mono.nodes.iter().cloned())),
173 ),
174 is_flat: false,
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use crate::ToNode;
    use super::*;

    /// Builds one monomial from a list of node names.
    fn mono(names: &[&str]) -> BTreeSet<Node> {
        names.iter().map(|&name| name.to_node()).collect()
    }

    #[test]
    fn test_poly() {
        let parsed: Polynomial = "(a (b + c) d e) + f g".parse().unwrap();

        // Distributing the inner sum gives `a b d e + a c d e + f g`.
        let expected = Polynomial {
            monomials: vec![
                mono(&["a", "b", "d", "e"]),
                mono(&["a", "c", "d", "e"]),
                mono(&["f", "g"]),
            ]
            .into_iter()
            .collect(),
            is_flat: false,
        };

        assert_eq!(parsed, expected);
    }
}