1use crate::error::{QuantRS2Error, QuantRS2Result};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::Complex64;
10use crate::linalg_stubs::CsrMatrix;
12use std::fmt::Debug;
13
14pub trait QuantumMatrix: Debug + Send + Sync {
16 fn dim(&self) -> usize;
18
19 fn to_dense(&self) -> Array2<Complex64>;
21
22 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>>;
24
25 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool>;
27
28 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>>;
30
31 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>>;
33}
34
35#[derive(Debug, Clone)]
37pub struct DenseMatrix {
38 data: Array2<Complex64>,
39}
40
41impl DenseMatrix {
42 pub fn new(data: Array2<Complex64>) -> QuantRS2Result<Self> {
44 if data.nrows() != data.ncols() {
45 return Err(QuantRS2Error::InvalidInput(
46 "Matrix must be square".to_string(),
47 ));
48 }
49 Ok(Self { data })
50 }
51
52 pub fn from_vec(data: Vec<Complex64>, dim: usize) -> QuantRS2Result<Self> {
54 if data.len() != dim * dim {
55 return Err(QuantRS2Error::InvalidInput(format!(
56 "Expected {} elements, got {}",
57 dim * dim,
58 data.len()
59 )));
60 }
61 let matrix = Array2::from_shape_vec((dim, dim), data)
62 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
63 Self::new(matrix)
64 }
65
66 pub fn as_array(&self) -> &Array2<Complex64> {
68 &self.data
69 }
70
71 pub fn is_hermitian(&self, tolerance: f64) -> bool {
73 let n = self.data.nrows();
74 for i in 0..n {
75 for j in i..n {
76 let diff = (self.data[[i, j]] - self.data[[j, i]].conj()).norm();
77 if diff > tolerance {
78 return false;
79 }
80 }
81 }
82 true
83 }
84}
85
86impl QuantumMatrix for DenseMatrix {
87 fn dim(&self) -> usize {
88 self.data.nrows()
89 }
90
91 fn to_dense(&self) -> Array2<Complex64> {
92 self.data.clone()
93 }
94
95 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
96 let n = self.dim();
97 let mut rows = Vec::new();
98 let mut cols = Vec::new();
99 let mut data = Vec::new();
100
101 let tolerance = 1e-14;
102 for i in 0..n {
103 for j in 0..n {
104 let val = self.data[[i, j]];
105 if val.norm() > tolerance {
106 rows.push(i);
107 cols.push(j);
108 data.push(val);
109 }
110 }
111 }
112
113 CsrMatrix::new(data, rows, cols, (n, n))
114 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))
115 }
116
117 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
118 let n = self.dim();
119 let conj_transpose = self.data.t().mapv(|x| x.conj());
120 let product = self.data.dot(&conj_transpose);
121
122 for i in 0..n {
124 for j in 0..n {
125 let expected = if i == j {
126 Complex64::new(1.0, 0.0)
127 } else {
128 Complex64::new(0.0, 0.0)
129 };
130 let diff = (product[[i, j]] - expected).norm();
131 if diff > tolerance {
132 return Ok(false);
133 }
134 }
135 }
136 Ok(true)
137 }
138
139 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
140 let other_dense = other.to_dense();
141 let n1 = self.dim();
142 let n2 = other_dense.nrows();
143 let n = n1 * n2;
144
145 let mut result = Array2::zeros((n, n));
146
147 for i1 in 0..n1 {
148 for j1 in 0..n1 {
149 let val1 = self.data[[i1, j1]];
150 for i2 in 0..n2 {
151 for j2 in 0..n2 {
152 let val2 = other_dense[[i2, j2]];
153 result[[i1 * n2 + i2, j1 * n2 + j2]] = val1 * val2;
154 }
155 }
156 }
157 }
158
159 Ok(result)
160 }
161
162 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
163 if state.len() != self.dim() {
164 return Err(QuantRS2Error::InvalidInput(format!(
165 "State dimension {} doesn't match matrix dimension {}",
166 state.len(),
167 self.dim()
168 )));
169 }
170 Ok(self.data.dot(state))
171 }
172}
173
174#[derive(Clone)]
176pub struct SparseMatrix {
177 csr: CsrMatrix<Complex64>,
178 dim: usize,
179}
180
181impl Debug for SparseMatrix {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("SparseMatrix")
184 .field("dim", &self.dim)
185 .field("nnz", &self.csr.nnz())
186 .finish()
187 }
188}
189
190impl SparseMatrix {
191 pub fn new(csr: CsrMatrix<Complex64>) -> QuantRS2Result<Self> {
193 let (rows, cols) = csr.shape();
194 if rows != cols {
195 return Err(QuantRS2Error::InvalidInput(
196 "Matrix must be square".to_string(),
197 ));
198 }
199 Ok(Self { csr, dim: rows })
200 }
201
202 pub fn from_triplets(
204 rows: Vec<usize>,
205 cols: Vec<usize>,
206 data: Vec<Complex64>,
207 dim: usize,
208 ) -> QuantRS2Result<Self> {
209 let csr = CsrMatrix::new(data, rows, cols, (dim, dim))
210 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
211 Self::new(csr)
212 }
213}
214
215impl QuantumMatrix for SparseMatrix {
216 fn dim(&self) -> usize {
217 self.dim
218 }
219
220 fn to_dense(&self) -> Array2<Complex64> {
221 self.csr.to_dense()
222 }
223
224 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
225 Ok(self.csr.clone())
226 }
227
228 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
229 let dense = DenseMatrix::new(self.to_dense())?;
231 dense.is_unitary(tolerance)
232 }
233
234 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
235 let dense = DenseMatrix::new(self.to_dense())?;
238 dense.tensor_product(other)
239 }
240
241 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
242 if state.len() != self.dim() {
243 return Err(QuantRS2Error::InvalidInput(format!(
244 "State dimension {} doesn't match matrix dimension {}",
245 state.len(),
246 self.dim()
247 )));
248 }
249 let dense = self.to_dense();
251 Ok(dense.dot(state))
252 }
253}
254
255pub fn partial_trace(
257 matrix: &Array2<Complex64>,
258 keep_qubits: &[usize],
259 total_qubits: usize,
260) -> QuantRS2Result<Array2<Complex64>> {
261 let full_dim = 1 << total_qubits;
262 if matrix.nrows() != full_dim || matrix.ncols() != full_dim {
263 return Err(QuantRS2Error::InvalidInput(format!(
264 "Matrix dimension {} doesn't match {} qubits",
265 matrix.nrows(),
266 total_qubits
267 )));
268 }
269
270 let keep_dim = 1 << keep_qubits.len();
271 let trace_qubits: Vec<usize> = (0..total_qubits)
272 .filter(|q| !keep_qubits.contains(q))
273 .collect();
274 let trace_dim = 1 << trace_qubits.len();
275
276 let mut result = Array2::zeros((keep_dim, keep_dim));
277
278 for i in 0..keep_dim {
280 for j in 0..keep_dim {
281 let mut sum = Complex64::new(0.0, 0.0);
282
283 for t in 0..trace_dim {
285 let row_idx = reconstruct_index(i, t, keep_qubits, &trace_qubits, total_qubits);
286 let col_idx = reconstruct_index(j, t, keep_qubits, &trace_qubits, total_qubits);
287 sum += matrix[[row_idx, col_idx]];
288 }
289
290 result[[i, j]] = sum;
291 }
292 }
293
294 Ok(result)
295}
296
/// Builds a full basis-state index from a kept-subsystem index and a traced-subsystem index.
///
/// Bit `k` of `keep_idx` is placed at bit position `keep_qubits[k]`. Bit `k` of
/// `trace_idx` is placed at bit position `trace_qubits[k]`. The two bit sets are
/// OR-ed together. `_total_qubits` is accepted but not used.
fn reconstruct_index(
    keep_idx: usize,
    trace_idx: usize,
    keep_qubits: &[usize],
    trace_qubits: &[usize],
    _total_qubits: usize,
) -> usize {
    // Scatter the low bits of `packed` onto the bit positions listed in `targets`.
    let scatter = |packed: usize, targets: &[usize]| -> usize {
        targets
            .iter()
            .enumerate()
            .filter(|&(bit, _)| (packed >> bit) & 1 == 1)
            .fold(0usize, |acc, (_, &qubit)| acc | (1 << qubit))
    };

    scatter(keep_idx, keep_qubits) | scatter(trace_idx, trace_qubits)
}
323
324pub fn tensor_product_many(matrices: &[&dyn QuantumMatrix]) -> QuantRS2Result<Array2<Complex64>> {
326 if matrices.is_empty() {
327 return Err(QuantRS2Error::InvalidInput(
328 "Cannot compute tensor product of empty list".to_string(),
329 ));
330 }
331
332 if matrices.len() == 1 {
333 return Ok(matrices[0].to_dense());
334 }
335
336 let mut result = matrices[0].to_dense();
337 for matrix in matrices.iter().skip(1) {
338 let dense_result = DenseMatrix::new(result)?;
339 result = dense_result.tensor_product(*matrix)?;
340 }
341
342 Ok(result)
343}
344
345pub fn matrices_approx_equal(
347 a: &ArrayView2<Complex64>,
348 b: &ArrayView2<Complex64>,
349 tolerance: f64,
350) -> bool {
351 if a.shape() != b.shape() {
352 return false;
353 }
354
355 for (x, y) in a.iter().zip(b.iter()) {
356 if (x - y).norm() > tolerance {
357 return false;
358 }
359 }
360
361 true
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Builds a 2x2 complex matrix from real entries listed in row-major order.
    fn real_2x2(entries: [f64; 4]) -> Array2<Complex64> {
        let values: Vec<Complex64> = entries.iter().map(|&re| Complex64::new(re, 0.0)).collect();
        Array2::from_shape_vec((2, 2), values).unwrap()
    }

    #[test]
    fn test_dense_matrix_creation() {
        let identity = DenseMatrix::new(real_2x2([1.0, 0.0, 0.0, 1.0])).unwrap();
        assert_eq!(identity.dim(), 2);
    }

    #[test]
    fn test_unitary_check() {
        // Hadamard gate: (1/sqrt(2)) * [[1, 1], [1, -1]].
        let h = 1.0 / 2.0_f64.sqrt();
        let hadamard = DenseMatrix::new(real_2x2([h, h, h, -h])).unwrap();
        assert!(hadamard.is_unitary(1e-10).unwrap());
    }

    #[test]
    fn test_tensor_product() {
        let identity = DenseMatrix::new(real_2x2([1.0, 0.0, 0.0, 1.0])).unwrap();
        let pauli_x = DenseMatrix::new(real_2x2([0.0, 1.0, 1.0, 0.0])).unwrap();

        // I ⊗ X is block-diagonal with a copy of X in each diagonal block.
        let product = identity.tensor_product(&pauli_x).unwrap();
        assert_eq!(product.shape(), &[4, 4]);

        let one = Complex64::new(1.0, 0.0);
        assert_eq!(product[[0, 1]], one);
        assert_eq!(product[[2, 3]], one);
    }
}