qudit_expr/expressions/
named.rs

1use itertools::Itertools;
2use qudit_core::RealScalar;
3use std::{
4    collections::HashMap,
5    ops::{Deref, DerefMut},
6};
7
8use crate::ComplexExpression;
9
10#[derive(PartialEq, Eq, Debug, Clone)]
11pub struct ExpressionBody {
12    body: Vec<ComplexExpression>,
13}
14
15impl ExpressionBody {
16    pub fn new(body: Vec<ComplexExpression>) -> Self {
17        Self { body }
18    }
19
20    pub fn num_elements(&self) -> usize {
21        self.body.len()
22    }
23
24    pub fn elements(&self) -> &[ComplexExpression] {
25        &self.body
26    }
27
28    pub fn elements_mut(&mut self) -> &mut [ComplexExpression] {
29        &mut self.body
30    }
31
32    pub fn conjugate(&mut self) {
33        for expr in self.body.iter_mut() {
34            expr.conjugate_in_place()
35        }
36    }
37
38    pub fn apply_element_permutation(&mut self, elem_perm: &[usize]) {
39        // TODO: do physical element permutation in place via transpositions
40        let mut swap_vec = vec![];
41        std::mem::swap(&mut swap_vec, &mut self.body);
42        self.body = swap_vec
43            .into_iter()
44            .enumerate()
45            .sorted_by(|(old_idx_a, _), (old_idx_b, _)| {
46                elem_perm[*old_idx_a].cmp(&elem_perm[*old_idx_b])
47            })
48            .map(|(_, expr)| expr)
49            .collect();
50    }
51
52    pub fn rename(&mut self, var_map: HashMap<String, String>) {
53        for elem in self.body.iter_mut() {
54            *elem = elem.map_var_names(&var_map);
55        }
56    }
57}
58
59impl From<Vec<ComplexExpression>> for ExpressionBody {
60    fn from(value: Vec<ComplexExpression>) -> Self {
61        ExpressionBody::new(value)
62    }
63}
64
65impl From<ExpressionBody> for Vec<ComplexExpression> {
66    fn from(value: ExpressionBody) -> Self {
67        value.body
68    }
69}
70
71impl AsRef<[ComplexExpression]> for ExpressionBody {
72    fn as_ref(&self) -> &[ComplexExpression] {
73        self.elements()
74    }
75}
76
77impl Deref for ExpressionBody {
78    type Target = [ComplexExpression];
79
80    fn deref(&self) -> &Self::Target {
81        &self.body
82    }
83}
84
85impl DerefMut for ExpressionBody {
86    fn deref_mut(&mut self) -> &mut Self::Target {
87        &mut self.body
88    }
89}
90
91#[derive(PartialEq, Eq, Debug, Clone)]
92pub struct BoundExpressionBody {
93    variables: Vec<String>,
94    body: ExpressionBody,
95}
96
97impl BoundExpressionBody {
98    pub fn new<B: Into<ExpressionBody>>(variables: Vec<String>, body: B) -> Self {
99        Self {
100            variables,
101            body: body.into(),
102        }
103    }
104
105    pub fn num_params(&self) -> usize {
106        self.variables.len()
107    }
108
109    pub fn variables(&self) -> &[String] {
110        &self.variables
111    }
112
113    pub fn set_variables(&mut self, new_variables: Vec<String>) {
114        self.variables = new_variables;
115    }
116
117    pub fn num_elements(&self) -> usize {
118        self.body.num_elements()
119    }
120
121    pub fn elements(&self) -> &[ComplexExpression] {
122        self.body.elements()
123    }
124
125    pub fn elements_mut(&mut self) -> &mut [ComplexExpression] {
126        self.body.elements_mut()
127    }
128
129    pub fn conjugate(&mut self) {
130        self.body.conjugate()
131    }
132
133    pub fn destruct(self) -> (Vec<String>, Vec<ComplexExpression>) {
134        let Self { variables, body } = self;
135        (variables, body.into())
136    }
137
138    pub fn apply_element_permutation(&mut self, elem_perm: &[usize]) {
139        self.body.apply_element_permutation(elem_perm);
140    }
141
142    pub fn alpha_rename(&mut self, starting_number: Option<usize>) {
143        let mut var_id = starting_number.unwrap_or_default();
144
145        let mut var_map = HashMap::new();
146        let mut new_vars = Vec::new();
147
148        for var in self.variables() {
149            let new_var_name = format!("alpha_{}", var_id);
150            new_vars.push(new_var_name.clone());
151            var_map.insert(var.clone(), new_var_name);
152            var_id += 1;
153        }
154
155        self.body.rename(var_map);
156        self.variables = new_vars;
157    }
158
159    pub fn get_arg_map<R: RealScalar>(&self, args: &[R]) -> HashMap<&str, R> {
160        self.variables()
161            .iter()
162            .zip(args.iter())
163            .map(|(a, b)| (a.as_str(), *b))
164            .collect()
165    }
166}
167
168impl AsRef<[ComplexExpression]> for BoundExpressionBody {
169    fn as_ref(&self) -> &[ComplexExpression] {
170        self.elements()
171    }
172}
173
174impl Deref for BoundExpressionBody {
175    type Target = ExpressionBody;
176
177    fn deref(&self) -> &Self::Target {
178        &self.body
179    }
180}
181
182impl DerefMut for BoundExpressionBody {
183    fn deref_mut(&mut self) -> &mut Self::Target {
184        &mut self.body
185    }
186}
187
188#[derive(Debug, Clone)]
189pub struct NamedExpression {
190    name: String,
191    body: BoundExpressionBody,
192}
193
194impl NamedExpression {
195    pub fn new<S: Into<String>, B: Into<ExpressionBody>>(
196        name: S,
197        variables: Vec<String>,
198        body: B,
199    ) -> Self {
200        Self {
201            name: name.into(),
202            body: BoundExpressionBody::new(variables, body),
203        }
204    }
205
206    pub fn from_body_with_name(name: String, body: BoundExpressionBody) -> Self {
207        Self { name, body }
208    }
209
210    pub fn name(&self) -> &str {
211        &self.name
212    }
213
214    pub fn set_name(&mut self, new_name: impl Into<String>) {
215        self.name = new_name.into();
216    }
217
218    pub fn set_variables(&mut self, new_variables: Vec<String>) {
219        self.body.set_variables(new_variables);
220    }
221
222    pub fn destruct(self) -> (String, Vec<String>, Vec<ComplexExpression>) {
223        let Self { name, body } = self;
224        let (variables, body) = body.destruct();
225        (name, variables, body)
226    }
227
228    pub fn apply_element_permutation(&mut self, elem_perm: &[usize]) {
229        self.body.apply_element_permutation(elem_perm);
230    }
231}
232
233impl AsRef<[ComplexExpression]> for NamedExpression {
234    fn as_ref(&self) -> &[ComplexExpression] {
235        self.elements()
236    }
237}
238
239impl AsRef<BoundExpressionBody> for NamedExpression {
240    fn as_ref(&self) -> &BoundExpressionBody {
241        &self.body
242    }
243}
244
245impl<B: AsRef<[ComplexExpression]>> PartialEq<B> for NamedExpression {
246    fn eq(&self, other: &B) -> bool {
247        self.elements() == other.as_ref()
248    }
249}
250
251impl Eq for NamedExpression {}
252
253impl Deref for NamedExpression {
254    type Target = BoundExpressionBody;
255
256    fn deref(&self) -> &Self::Target {
257        &self.body
258    }
259}
260
261impl DerefMut for NamedExpression {
262    fn deref_mut(&mut self) -> &mut Self::Target {
263        &mut self.body
264    }
265}