1use crate::error::{QuantRS2Error, QuantRS2Result};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::Complex64;
10use crate::linalg_stubs::CsrMatrix;
12use std::fmt::Debug;
13
/// Common interface for the complex matrices used as quantum operators in this
/// module; implemented here by [`DenseMatrix`] and [`SparseMatrix`].
pub trait QuantumMatrix: Debug + Send + Sync {
    /// Number of rows. Both implementations in this module reject non-square
    /// input at construction, so this is also the number of columns for them.
    fn dim(&self) -> usize;

    /// Returns an owned dense copy of the matrix.
    fn to_dense(&self) -> Array2<Complex64>;

    /// Returns the matrix as a [`CsrMatrix`] (sparse) representation.
    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>>;

    /// Reports whether the matrix is unitary, with entries compared against
    /// the identity up to `tolerance`.
    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool>;

    /// Kronecker (tensor) product `self ⊗ other`, returned as a dense array.
    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>>;

    /// Multiplies the matrix by the column vector `state` and returns the result.
    /// The implementations in this module return an error when
    /// `state.len() != self.dim()`.
    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>>;
}
34
/// Dense complex matrix backed by an `ndarray` [`Array2`].
///
/// [`DenseMatrix::new`] rejects non-square arrays, so every value built through
/// the public constructors is square.
#[derive(Debug, Clone)]
pub struct DenseMatrix {
    // Matrix entries, indexed as `data[[row, col]]`.
    data: Array2<Complex64>,
}
40
41impl DenseMatrix {
42 pub fn new(data: Array2<Complex64>) -> QuantRS2Result<Self> {
44 if data.nrows() != data.ncols() {
45 return Err(QuantRS2Error::InvalidInput(
46 "Matrix must be square".to_string(),
47 ));
48 }
49 Ok(Self { data })
50 }
51
52 pub fn from_vec(data: Vec<Complex64>, dim: usize) -> QuantRS2Result<Self> {
54 if data.len() != dim * dim {
55 return Err(QuantRS2Error::InvalidInput(format!(
56 "Expected {} elements, got {}",
57 dim * dim,
58 data.len()
59 )));
60 }
61 let matrix = Array2::from_shape_vec((dim, dim), data)
62 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
63 Self::new(matrix)
64 }
65
66 pub const fn as_array(&self) -> &Array2<Complex64> {
68 &self.data
69 }
70
71 pub fn is_hermitian(&self, tolerance: f64) -> bool {
73 let n = self.data.nrows();
74 for i in 0..n {
75 for j in i..n {
76 let diff = (self.data[[i, j]] - self.data[[j, i]].conj()).norm();
77 if diff > tolerance {
78 return false;
79 }
80 }
81 }
82 true
83 }
84}
85
86impl QuantumMatrix for DenseMatrix {
87 fn dim(&self) -> usize {
88 self.data.nrows()
89 }
90
91 fn to_dense(&self) -> Array2<Complex64> {
92 self.data.clone()
93 }
94
95 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
96 let n = self.dim();
97 let mut rows = Vec::new();
98 let mut cols = Vec::new();
99 let mut data = Vec::new();
100
101 let tolerance = 1e-14;
102 for i in 0..n {
103 for j in 0..n {
104 let val = self.data[[i, j]];
105 if val.norm() > tolerance {
106 rows.push(i);
107 cols.push(j);
108 data.push(val);
109 }
110 }
111 }
112
113 CsrMatrix::new(data, rows, cols, (n, n)).map_err(|e| QuantRS2Error::InvalidInput(e))
114 }
115
116 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
117 let n = self.dim();
118 let conj_transpose = self.data.t().mapv(|x| x.conj());
119 let product = self.data.dot(&conj_transpose);
120
121 for i in 0..n {
123 for j in 0..n {
124 let expected = if i == j {
125 Complex64::new(1.0, 0.0)
126 } else {
127 Complex64::new(0.0, 0.0)
128 };
129 let diff = (product[[i, j]] - expected).norm();
130 if diff > tolerance {
131 return Ok(false);
132 }
133 }
134 }
135 Ok(true)
136 }
137
138 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
139 let other_dense = other.to_dense();
140 let n1 = self.dim();
141 let n2 = other_dense.nrows();
142 let n = n1 * n2;
143
144 let mut result = Array2::zeros((n, n));
145
146 for i1 in 0..n1 {
147 for j1 in 0..n1 {
148 let val1 = self.data[[i1, j1]];
149 for i2 in 0..n2 {
150 for j2 in 0..n2 {
151 let val2 = other_dense[[i2, j2]];
152 result[[i1 * n2 + i2, j1 * n2 + j2]] = val1 * val2;
153 }
154 }
155 }
156 }
157
158 Ok(result)
159 }
160
161 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
162 if state.len() != self.dim() {
163 return Err(QuantRS2Error::InvalidInput(format!(
164 "State dimension {} doesn't match matrix dimension {}",
165 state.len(),
166 self.dim()
167 )));
168 }
169 Ok(self.data.dot(state))
170 }
171}
172
/// Square complex matrix stored in CSR (compressed sparse row) form.
///
/// `Debug` is implemented by hand for this type. It prints only the dimension
/// and the number of stored entries, not the matrix contents.
#[derive(Clone)]
pub struct SparseMatrix {
    // Underlying sparse storage.
    csr: CsrMatrix<Complex64>,
    // Row count (== column count), captured when the matrix is constructed.
    dim: usize,
}
179
180impl Debug for SparseMatrix {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.debug_struct("SparseMatrix")
183 .field("dim", &self.dim)
184 .field("nnz", &self.csr.nnz())
185 .finish()
186 }
187}
188
189impl SparseMatrix {
190 pub fn new(csr: CsrMatrix<Complex64>) -> QuantRS2Result<Self> {
192 let (rows, cols) = csr.shape();
193 if rows != cols {
194 return Err(QuantRS2Error::InvalidInput(
195 "Matrix must be square".to_string(),
196 ));
197 }
198 Ok(Self { csr, dim: rows })
199 }
200
201 pub fn from_triplets(
203 rows: Vec<usize>,
204 cols: Vec<usize>,
205 data: Vec<Complex64>,
206 dim: usize,
207 ) -> QuantRS2Result<Self> {
208 let csr = CsrMatrix::new(data, rows, cols, (dim, dim))
209 .map_err(|e| QuantRS2Error::InvalidInput(e))?;
210 Self::new(csr)
211 }
212}
213
214impl QuantumMatrix for SparseMatrix {
215 fn dim(&self) -> usize {
216 self.dim
217 }
218
219 fn to_dense(&self) -> Array2<Complex64> {
220 self.csr.to_dense()
221 }
222
223 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
224 Ok(self.csr.clone())
225 }
226
227 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
228 let dense = DenseMatrix::new(self.to_dense())?;
230 dense.is_unitary(tolerance)
231 }
232
233 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
234 let dense = DenseMatrix::new(self.to_dense())?;
237 dense.tensor_product(other)
238 }
239
240 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
241 if state.len() != self.dim() {
242 return Err(QuantRS2Error::InvalidInput(format!(
243 "State dimension {} doesn't match matrix dimension {}",
244 state.len(),
245 self.dim()
246 )));
247 }
248 let dense = self.to_dense();
250 Ok(dense.dot(state))
251 }
252}
253
254pub fn partial_trace(
256 matrix: &Array2<Complex64>,
257 keep_qubits: &[usize],
258 total_qubits: usize,
259) -> QuantRS2Result<Array2<Complex64>> {
260 let full_dim = 1 << total_qubits;
261 if matrix.nrows() != full_dim || matrix.ncols() != full_dim {
262 return Err(QuantRS2Error::InvalidInput(format!(
263 "Matrix dimension {} doesn't match {} qubits",
264 matrix.nrows(),
265 total_qubits
266 )));
267 }
268
269 let keep_dim = 1 << keep_qubits.len();
270 let trace_qubits: Vec<usize> = (0..total_qubits)
271 .filter(|q| !keep_qubits.contains(q))
272 .collect();
273 let trace_dim = 1 << trace_qubits.len();
274
275 let mut result = Array2::zeros((keep_dim, keep_dim));
276
277 for i in 0..keep_dim {
279 for j in 0..keep_dim {
280 let mut sum = Complex64::new(0.0, 0.0);
281
282 for t in 0..trace_dim {
284 let row_idx = reconstruct_index(i, t, keep_qubits, &trace_qubits, total_qubits);
285 let col_idx = reconstruct_index(j, t, keep_qubits, &trace_qubits, total_qubits);
286 sum += matrix[[row_idx, col_idx]];
287 }
288
289 result[[i, j]] = sum;
290 }
291 }
292
293 Ok(result)
294}
295
/// Combines a reduced index over `keep_qubits` and one over `trace_qubits`
/// into a full-register basis index.
///
/// Bit `k` of `keep_idx` is placed at bit `keep_qubits[k]`, and bit `k` of
/// `trace_idx` at bit `trace_qubits[k]`. The two partial results are OR-ed
/// together. `_total_qubits` is currently unused.
fn reconstruct_index(
    keep_idx: usize,
    trace_idx: usize,
    keep_qubits: &[usize],
    trace_qubits: &[usize],
    _total_qubits: usize,
) -> usize {
    // Spread the low bits of `packed` onto the qubit positions listed in `qubits`.
    let scatter = |packed: usize, qubits: &[usize]| -> usize {
        qubits
            .iter()
            .enumerate()
            .filter(|&(k, _)| (packed >> k) & 1 == 1)
            .fold(0, |acc, (_, &q)| acc | (1 << q))
    };

    scatter(keep_idx, keep_qubits) | scatter(trace_idx, trace_qubits)
}
322
323pub fn tensor_product_many(matrices: &[&dyn QuantumMatrix]) -> QuantRS2Result<Array2<Complex64>> {
325 if matrices.is_empty() {
326 return Err(QuantRS2Error::InvalidInput(
327 "Cannot compute tensor product of empty list".to_string(),
328 ));
329 }
330
331 if matrices.len() == 1 {
332 return Ok(matrices[0].to_dense());
333 }
334
335 let mut result = matrices[0].to_dense();
336 for matrix in matrices.iter().skip(1) {
337 let dense_result = DenseMatrix::new(result)?;
338 result = dense_result.tensor_product(*matrix)?;
339 }
340
341 Ok(result)
342}
343
344pub fn matrices_approx_equal(
346 a: &ArrayView2<Complex64>,
347 b: &ArrayView2<Complex64>,
348 tolerance: f64,
349) -> bool {
350 if a.shape() != b.shape() {
351 return false;
352 }
353
354 for (x, y) in a.iter().zip(b.iter()) {
355 if (x - y).norm() > tolerance {
356 return false;
357 }
358 }
359
360 true
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Builds a 2x2 `DenseMatrix` from real, row-major entries.
    ///
    /// `data_msg` and `matrix_msg` are the `expect` messages for the array
    /// construction and the square check, respectively.
    fn real_2x2(entries: [f64; 4], data_msg: &str, matrix_msg: &str) -> DenseMatrix {
        let values: Vec<Complex64> = entries.iter().map(|&re| Complex64::new(re, 0.0)).collect();
        let data = Array2::from_shape_vec((2, 2), values).expect(data_msg);
        DenseMatrix::new(data).expect(matrix_msg)
    }

    #[test]
    fn test_dense_matrix_creation() {
        let identity = real_2x2(
            [1.0, 0.0, 0.0, 1.0],
            "Matrix data creation should succeed",
            "DenseMatrix creation should succeed",
        );
        assert_eq!(identity.dim(), 2);
    }

    #[test]
    fn test_unitary_check() {
        // Hadamard gate: (1/sqrt(2)) * [[1, 1], [1, -1]].
        let h = 1.0 / 2.0_f64.sqrt();
        let hadamard = real_2x2(
            [h, h, h, -h],
            "Hadamard matrix data creation should succeed",
            "DenseMatrix creation should succeed",
        );
        let unitary = hadamard
            .is_unitary(1e-10)
            .expect("Unitary check should succeed");
        assert!(unitary);
    }

    #[test]
    fn test_tensor_product() {
        let identity = real_2x2(
            [1.0, 0.0, 0.0, 1.0],
            "Identity matrix data creation should succeed",
            "Identity DenseMatrix creation should succeed",
        );
        let pauli_x = real_2x2(
            [0.0, 1.0, 1.0, 0.0],
            "Pauli-X matrix data creation should succeed",
            "Pauli-X DenseMatrix creation should succeed",
        );

        let product = identity
            .tensor_product(&pauli_x)
            .expect("Tensor product should succeed");
        assert_eq!(product.shape(), &[4, 4]);

        // I ⊗ X is block-diagonal with a copy of X in each 2x2 diagonal block.
        let one = Complex64::new(1.0, 0.0);
        assert_eq!(product[[0, 1]], one);
        assert_eq!(product[[2, 3]], one);
    }
}