1use crate::error::{QuantRS2Error, QuantRS2Result};
8use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
9use num_complex::Complex64;
10use scirs2_linalg::{det, inv};
11use scirs2_sparse::csr::CsrMatrix;
12use std::fmt::Debug;
13
14pub trait QuantumMatrix: Debug + Send + Sync {
16 fn dim(&self) -> usize;
18
19 fn to_dense(&self) -> Array2<Complex64>;
21
22 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>>;
24
25 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool>;
27
28 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>>;
30
31 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>>;
33}
34
35#[derive(Debug, Clone)]
37pub struct DenseMatrix {
38 data: Array2<Complex64>,
39}
40
41impl DenseMatrix {
42 pub fn new(data: Array2<Complex64>) -> QuantRS2Result<Self> {
44 if data.nrows() != data.ncols() {
45 return Err(QuantRS2Error::InvalidInput(
46 "Matrix must be square".to_string(),
47 ));
48 }
49 Ok(Self { data })
50 }
51
52 pub fn from_vec(data: Vec<Complex64>, dim: usize) -> QuantRS2Result<Self> {
54 if data.len() != dim * dim {
55 return Err(QuantRS2Error::InvalidInput(format!(
56 "Expected {} elements, got {}",
57 dim * dim,
58 data.len()
59 )));
60 }
61 let matrix = Array2::from_shape_vec((dim, dim), data)
62 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
63 Self::new(matrix)
64 }
65
66 pub fn as_array(&self) -> &Array2<Complex64> {
68 &self.data
69 }
70
71 pub fn is_hermitian(&self, tolerance: f64) -> bool {
73 let n = self.data.nrows();
74 for i in 0..n {
75 for j in i..n {
76 let diff = (self.data[[i, j]] - self.data[[j, i]].conj()).norm();
77 if diff > tolerance {
78 return false;
79 }
80 }
81 }
82 true
83 }
84}
85
86impl QuantumMatrix for DenseMatrix {
87 fn dim(&self) -> usize {
88 self.data.nrows()
89 }
90
91 fn to_dense(&self) -> Array2<Complex64> {
92 self.data.clone()
93 }
94
95 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
96 let n = self.dim();
97 let mut rows = Vec::new();
98 let mut cols = Vec::new();
99 let mut data = Vec::new();
100
101 let tolerance = 1e-14;
102 for i in 0..n {
103 for j in 0..n {
104 let val = self.data[[i, j]];
105 if val.norm() > tolerance {
106 rows.push(i);
107 cols.push(j);
108 data.push(val);
109 }
110 }
111 }
112
113 CsrMatrix::new(data, rows, cols, (n, n))
114 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))
115 }
116
117 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
118 let n = self.dim();
119 let conj_transpose = self.data.t().mapv(|x| x.conj());
120 let product = self.data.dot(&conj_transpose);
121
122 for i in 0..n {
124 for j in 0..n {
125 let expected = if i == j {
126 Complex64::new(1.0, 0.0)
127 } else {
128 Complex64::new(0.0, 0.0)
129 };
130 let diff = (product[[i, j]] - expected).norm();
131 if diff > tolerance {
132 return Ok(false);
133 }
134 }
135 }
136 Ok(true)
137 }
138
139 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
140 let other_dense = other.to_dense();
141 let n1 = self.dim();
142 let n2 = other_dense.nrows();
143 let n = n1 * n2;
144
145 let mut result = Array2::zeros((n, n));
146
147 for i1 in 0..n1 {
148 for j1 in 0..n1 {
149 let val1 = self.data[[i1, j1]];
150 for i2 in 0..n2 {
151 for j2 in 0..n2 {
152 let val2 = other_dense[[i2, j2]];
153 result[[i1 * n2 + i2, j1 * n2 + j2]] = val1 * val2;
154 }
155 }
156 }
157 }
158
159 Ok(result)
160 }
161
162 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
163 if state.len() != self.dim() {
164 return Err(QuantRS2Error::InvalidInput(format!(
165 "State dimension {} doesn't match matrix dimension {}",
166 state.len(),
167 self.dim()
168 )));
169 }
170 Ok(self.data.dot(state))
171 }
172}
173
174#[derive(Clone)]
176pub struct SparseMatrix {
177 csr: CsrMatrix<Complex64>,
178 dim: usize,
179}
180
181impl Debug for SparseMatrix {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("SparseMatrix")
184 .field("dim", &self.dim)
185 .field("nnz", &self.csr.nnz())
186 .finish()
187 }
188}
189
190impl SparseMatrix {
191 pub fn new(csr: CsrMatrix<Complex64>) -> QuantRS2Result<Self> {
193 let (rows, cols) = csr.shape();
194 if rows != cols {
195 return Err(QuantRS2Error::InvalidInput(
196 "Matrix must be square".to_string(),
197 ));
198 }
199 Ok(Self { csr, dim: rows })
200 }
201
202 pub fn from_triplets(
204 rows: Vec<usize>,
205 cols: Vec<usize>,
206 data: Vec<Complex64>,
207 dim: usize,
208 ) -> QuantRS2Result<Self> {
209 let csr = CsrMatrix::new(data, rows, cols, (dim, dim))
210 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
211 Self::new(csr)
212 }
213}
214
215impl QuantumMatrix for SparseMatrix {
216 fn dim(&self) -> usize {
217 self.dim
218 }
219
220 fn to_dense(&self) -> Array2<Complex64> {
221 let dense_vec = self.csr.to_dense();
222 let rows = dense_vec.len();
223 let cols = if rows > 0 { dense_vec[0].len() } else { 0 };
224
225 let mut flat = Vec::with_capacity(rows * cols);
226 for row in dense_vec {
227 flat.extend(row);
228 }
229
230 Array2::from_shape_vec((rows, cols), flat).unwrap()
231 }
232
233 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
234 Ok(self.csr.clone())
235 }
236
237 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
238 let dense = DenseMatrix::new(self.to_dense())?;
240 dense.is_unitary(tolerance)
241 }
242
243 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
244 let dense = DenseMatrix::new(self.to_dense())?;
247 dense.tensor_product(other)
248 }
249
250 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
251 if state.len() != self.dim() {
252 return Err(QuantRS2Error::InvalidInput(format!(
253 "State dimension {} doesn't match matrix dimension {}",
254 state.len(),
255 self.dim()
256 )));
257 }
258 let dense = self.to_dense();
260 Ok(dense.dot(state))
261 }
262}
263
264pub fn partial_trace(
266 matrix: &Array2<Complex64>,
267 keep_qubits: &[usize],
268 total_qubits: usize,
269) -> QuantRS2Result<Array2<Complex64>> {
270 let full_dim = 1 << total_qubits;
271 if matrix.nrows() != full_dim || matrix.ncols() != full_dim {
272 return Err(QuantRS2Error::InvalidInput(format!(
273 "Matrix dimension {} doesn't match {} qubits",
274 matrix.nrows(),
275 total_qubits
276 )));
277 }
278
279 let keep_dim = 1 << keep_qubits.len();
280 let trace_qubits: Vec<usize> = (0..total_qubits)
281 .filter(|q| !keep_qubits.contains(q))
282 .collect();
283 let trace_dim = 1 << trace_qubits.len();
284
285 let mut result = Array2::zeros((keep_dim, keep_dim));
286
287 for i in 0..keep_dim {
289 for j in 0..keep_dim {
290 let mut sum = Complex64::new(0.0, 0.0);
291
292 for t in 0..trace_dim {
294 let row_idx = reconstruct_index(i, t, keep_qubits, &trace_qubits, total_qubits);
295 let col_idx = reconstruct_index(j, t, keep_qubits, &trace_qubits, total_qubits);
296 sum += matrix[[row_idx, col_idx]];
297 }
298
299 result[[i, j]] = sum;
300 }
301 }
302
303 Ok(result)
304}
305
/// Assembles a full basis-state index from a kept-subsystem index and a traced-subsystem
/// index.
///
/// Bit `k` of `keep_idx` is placed at bit position `keep_qubits[k]`, and bit `k` of
/// `trace_idx` at bit position `trace_qubits[k]`; the two results are OR-ed together.
/// `total_qubits` is used only for a debug-build check that every listed qubit is in range.
fn reconstruct_index(
    keep_idx: usize,
    trace_idx: usize,
    keep_qubits: &[usize],
    trace_qubits: &[usize],
    total_qubits: usize,
) -> usize {
    debug_assert!(
        keep_qubits.iter().chain(trace_qubits).all(|&q| q < total_qubits),
        "qubit index out of range"
    );

    // Scatter bit `k` of `packed` to bit position `qubits[k]`.
    let scatter = |packed: usize, qubits: &[usize]| -> usize {
        qubits
            .iter()
            .enumerate()
            .filter(|&(k, _)| (packed >> k) & 1 == 1)
            .fold(0usize, |acc, (_, &q)| acc | (1usize << q))
    };

    scatter(keep_idx, keep_qubits) | scatter(trace_idx, trace_qubits)
}
332
333pub fn tensor_product_many(matrices: &[&dyn QuantumMatrix]) -> QuantRS2Result<Array2<Complex64>> {
335 if matrices.is_empty() {
336 return Err(QuantRS2Error::InvalidInput(
337 "Cannot compute tensor product of empty list".to_string(),
338 ));
339 }
340
341 if matrices.len() == 1 {
342 return Ok(matrices[0].to_dense());
343 }
344
345 let mut result = matrices[0].to_dense();
346 for matrix in matrices.iter().skip(1) {
347 let dense_result = DenseMatrix::new(result)?;
348 result = dense_result.tensor_product(*matrix)?;
349 }
350
351 Ok(result)
352}
353
354pub fn matrices_approx_equal(
356 a: &ArrayView2<Complex64>,
357 b: &ArrayView2<Complex64>,
358 tolerance: f64,
359) -> bool {
360 if a.shape() != b.shape() {
361 return false;
362 }
363
364 for (x, y) in a.iter().zip(b.iter()) {
365 if (x - y).norm() > tolerance {
366 return false;
367 }
368 }
369
370 true
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;

    /// Real-valued complex number, to keep the fixtures short.
    fn re(x: f64) -> Complex64 {
        Complex64::new(x, 0.0)
    }

    /// 2×2 `DenseMatrix` built from real entries given in row-major order.
    fn dense_2x2(entries: [f64; 4]) -> DenseMatrix {
        let values: Vec<Complex64> = entries.iter().map(|&x| re(x)).collect();
        DenseMatrix::new(Array2::from_shape_vec((2, 2), values).unwrap()).unwrap()
    }

    #[test]
    fn test_dense_matrix_creation() {
        let identity = dense_2x2([1.0, 0.0, 0.0, 1.0]);
        assert_eq!(identity.dim(), 2);
    }

    #[test]
    fn test_unitary_check() {
        let s = 1.0 / 2.0_f64.sqrt();
        let hadamard = dense_2x2([s, s, s, -s]);
        assert!(hadamard.is_unitary(1e-10).unwrap());
    }

    #[test]
    fn test_tensor_product() {
        let identity = dense_2x2([1.0, 0.0, 0.0, 1.0]);
        let pauli_x = dense_2x2([0.0, 1.0, 1.0, 0.0]);

        let product = identity.tensor_product(&pauli_x).unwrap();
        assert_eq!(product.shape(), &[4, 4]);

        // I ⊗ X is block-diagonal with a copy of X in each diagonal block.
        assert_eq!(product[[0, 1]], re(1.0));
        assert_eq!(product[[2, 3]], re(1.0));
    }
}