quantrs2_ml/torchquantum/
encoding.rs

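//! Quantum state encoders (TorchQuantum-compatible)
//!
//! This module provides encoding schemes for classical data into quantum states:
//! - GeneralEncoder: Configurable encoder with custom gate sequences
//! - PhaseEncoder: One rotation gate (RX, RY or RZ) per wire
//! - StateEncoder: Direct state vector encoding
//! - AmplitudeEncoder: Alias of StateEncoder (amplitude encoding)
//! - MultiPhaseEncoder: Several rotation gates applied to every wire
//! - MagnitudeEncoder: Absolute values of the data used as real amplitudes
//! - AngleEncoder: Scaled (or arcsin-mapped) rotation angles
//! - IQPEncoder: IQP-style circuit (Hadamards, RZ, pairwise RZZ)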
8
9use super::{
10    gates::{TQHadamard, TQPauliX, TQPauliY, TQPauliZ, TQRx, TQRy, TQRz, TQSX},
11    CType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQOperator, TQParameter, WiresEnum,
12};
13use crate::error::{MLError, Result};
14use scirs2_core::ndarray::{Array2, ArrayD, IxDyn};
15use scirs2_core::Complex64;
16use std::f64::consts::PI;
17
18/// Base encoder trait
19pub trait TQEncoder: TQModule {
20    /// Encode classical data into quantum state
21    fn encode(&mut self, qdev: &mut TQDevice, x: &Array2<f64>) -> Result<()>;
22}
23
24/// General encoder with configurable gate sequence
25#[derive(Debug, Clone)]
26pub struct TQGeneralEncoder {
27    /// Encoding function list
28    pub func_list: Vec<EncodingOp>,
29    n_wires: Option<usize>,
30    static_mode: bool,
31}
32
/// One step of a [`TQGeneralEncoder`]: a named gate, the wires it acts on,
/// and the input columns that supply its parameters.
#[derive(Debug, Clone)]
pub struct EncodingOp {
    /// Column indices of the input data used as gate parameters.
    pub input_idx: Vec<usize>,
    /// Gate name (e.g. "rx", "ry", "rz", "h").
    pub func: String,
    /// Wires the gate is applied to.
    pub wires: Vec<usize>,
}
43
44impl TQGeneralEncoder {
45    pub fn new(func_list: Vec<EncodingOp>) -> Self {
46        Self {
47            func_list,
48            n_wires: None,
49            static_mode: false,
50        }
51    }
52
53    /// Create encoder from predefined pattern
54    pub fn from_pattern(pattern: &str, n_wires: usize) -> Self {
55        let func_list = match pattern {
56            "ry" => (0..n_wires)
57                .map(|i| EncodingOp {
58                    input_idx: vec![i],
59                    func: "ry".to_string(),
60                    wires: vec![i],
61                })
62                .collect(),
63            "rx" => (0..n_wires)
64                .map(|i| EncodingOp {
65                    input_idx: vec![i],
66                    func: "rx".to_string(),
67                    wires: vec![i],
68                })
69                .collect(),
70            "rz" => (0..n_wires)
71                .map(|i| EncodingOp {
72                    input_idx: vec![i],
73                    func: "rz".to_string(),
74                    wires: vec![i],
75                })
76                .collect(),
77            "rxyz" => {
78                let mut ops = Vec::new();
79                for (gate_idx, gate) in ["rx", "ry", "rz"].iter().enumerate() {
80                    for i in 0..n_wires {
81                        ops.push(EncodingOp {
82                            input_idx: vec![gate_idx * n_wires + i],
83                            func: gate.to_string(),
84                            wires: vec![i],
85                        });
86                    }
87                }
88                ops
89            }
90            _ => {
91                // Default: RY encoding
92                (0..n_wires)
93                    .map(|i| EncodingOp {
94                        input_idx: vec![i],
95                        func: "ry".to_string(),
96                        wires: vec![i],
97                    })
98                    .collect()
99            }
100        };
101
102        Self {
103            func_list,
104            n_wires: Some(n_wires),
105            static_mode: false,
106        }
107    }
108}
109
110impl TQModule for TQGeneralEncoder {
111    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
112        Err(MLError::InvalidConfiguration(
113            "Use encode() instead of forward() for encoders".to_string(),
114        ))
115    }
116
117    fn forward_with_input(&mut self, qdev: &mut TQDevice, x: Option<&Array2<f64>>) -> Result<()> {
118        if let Some(data) = x {
119            self.encode(qdev, data)
120        } else {
121            Err(MLError::InvalidConfiguration(
122                "Input data required for encoder".to_string(),
123            ))
124        }
125    }
126
127    fn parameters(&self) -> Vec<TQParameter> {
128        Vec::new()
129    }
130
131    fn n_wires(&self) -> Option<usize> {
132        self.n_wires
133    }
134
135    fn set_n_wires(&mut self, n_wires: usize) {
136        self.n_wires = Some(n_wires);
137    }
138
139    fn is_static_mode(&self) -> bool {
140        self.static_mode
141    }
142
143    fn static_on(&mut self) {
144        self.static_mode = true;
145    }
146
147    fn static_off(&mut self) {
148        self.static_mode = false;
149    }
150
151    fn name(&self) -> &str {
152        "GeneralEncoder"
153    }
154}
155
156impl TQEncoder for TQGeneralEncoder {
157    fn encode(&mut self, qdev: &mut TQDevice, x: &Array2<f64>) -> Result<()> {
158        let bsz = x.nrows();
159
160        // Ensure device batch size matches input
161        if qdev.bsz != bsz {
162            qdev.reset_states(bsz);
163        }
164
165        for op in &self.func_list {
166            // Get parameters from input data
167            let params: Vec<f64> = op
168                .input_idx
169                .iter()
170                .filter_map(|&idx| {
171                    if idx < x.ncols() {
172                        Some(x[[0, idx]]) // Use first batch element for parameter
173                    } else {
174                        None
175                    }
176                })
177                .collect();
178
179            // Apply gate based on function name
180            match op.func.as_str() {
181                "rx" => {
182                    let mut gate = TQRx::new(true, false);
183                    gate.apply_with_params(qdev, &op.wires, Some(&params))?;
184                }
185                "ry" => {
186                    let mut gate = TQRy::new(true, false);
187                    gate.apply_with_params(qdev, &op.wires, Some(&params))?;
188                }
189                "rz" => {
190                    let mut gate = TQRz::new(true, false);
191                    gate.apply_with_params(qdev, &op.wires, Some(&params))?;
192                }
193                "h" | "hadamard" => {
194                    let mut gate = TQHadamard::new();
195                    gate.apply(qdev, &op.wires)?;
196                }
197                "x" | "paulix" => {
198                    let mut gate = TQPauliX::new();
199                    gate.apply(qdev, &op.wires)?;
200                }
201                "y" | "pauliy" => {
202                    let mut gate = TQPauliY::new();
203                    gate.apply(qdev, &op.wires)?;
204                }
205                "z" | "pauliz" => {
206                    let mut gate = TQPauliZ::new();
207                    gate.apply(qdev, &op.wires)?;
208                }
209                "sx" => {
210                    let mut gate = TQSX::new();
211                    gate.apply(qdev, &op.wires)?;
212                }
213                _ => {
214                    return Err(MLError::InvalidConfiguration(format!(
215                        "Unknown gate: {}",
216                        op.func
217                    )));
218                }
219            }
220        }
221
222        Ok(())
223    }
224}
225
/// Phase encoder: applies the same kind of rotation gate to every wire.
#[derive(Debug, Clone)]
pub struct TQPhaseEncoder {
    /// Rotation gate name: "rx", "ry" or "rz".
    pub func: String,
    /// Number of wires, when known.
    n_wires: Option<usize>,
    /// Whether static mode is enabled.
    static_mode: bool,
}
234
235impl TQPhaseEncoder {
236    pub fn new(func: impl Into<String>) -> Self {
237        Self {
238            func: func.into(),
239            n_wires: None,
240            static_mode: false,
241        }
242    }
243
244    /// Create RY phase encoder
245    pub fn ry() -> Self {
246        Self::new("ry")
247    }
248
249    /// Create RX phase encoder
250    pub fn rx() -> Self {
251        Self::new("rx")
252    }
253
254    /// Create RZ phase encoder
255    pub fn rz() -> Self {
256        Self::new("rz")
257    }
258}
259
260impl TQModule for TQPhaseEncoder {
261    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
262        Err(MLError::InvalidConfiguration(
263            "Use encode() instead of forward() for encoders".to_string(),
264        ))
265    }
266
267    fn forward_with_input(&mut self, qdev: &mut TQDevice, x: Option<&Array2<f64>>) -> Result<()> {
268        if let Some(data) = x {
269            self.encode(qdev, data)
270        } else {
271            Err(MLError::InvalidConfiguration(
272                "Input data required for encoder".to_string(),
273            ))
274        }
275    }
276
277    fn parameters(&self) -> Vec<TQParameter> {
278        Vec::new()
279    }
280
281    fn n_wires(&self) -> Option<usize> {
282        self.n_wires
283    }
284
285    fn set_n_wires(&mut self, n_wires: usize) {
286        self.n_wires = Some(n_wires);
287    }
288
289    fn is_static_mode(&self) -> bool {
290        self.static_mode
291    }
292
293    fn static_on(&mut self) {
294        self.static_mode = true;
295    }
296
297    fn static_off(&mut self) {
298        self.static_mode = false;
299    }
300
301    fn name(&self) -> &str {
302        "PhaseEncoder"
303    }
304}
305
306impl TQEncoder for TQPhaseEncoder {
307    fn encode(&mut self, qdev: &mut TQDevice, x: &Array2<f64>) -> Result<()> {
308        let n_wires = qdev.n_wires;
309
310        for wire in 0..n_wires {
311            let param = if wire < x.ncols() { x[[0, wire]] } else { 0.0 };
312
313            match self.func.as_str() {
314                "rx" => {
315                    let mut gate = TQRx::new(true, false);
316                    gate.apply_with_params(qdev, &[wire], Some(&[param]))?;
317                }
318                "ry" => {
319                    let mut gate = TQRy::new(true, false);
320                    gate.apply_with_params(qdev, &[wire], Some(&[param]))?;
321                }
322                "rz" => {
323                    let mut gate = TQRz::new(true, false);
324                    gate.apply_with_params(qdev, &[wire], Some(&[param]))?;
325                }
326                _ => {
327                    return Err(MLError::InvalidConfiguration(format!(
328                        "Unknown rotation gate: {}",
329                        self.func
330                    )));
331                }
332            }
333        }
334
335        Ok(())
336    }
337}
338
/// State (amplitude) encoder: loads input data directly as state-vector
/// amplitudes instead of applying gates.
#[derive(Debug, Clone)]
pub struct TQStateEncoder {
    /// Number of wires, when known.
    n_wires: Option<usize>,
    /// Whether static mode is enabled.
    static_mode: bool,
}
345
346impl TQStateEncoder {
347    pub fn new() -> Self {
348        Self {
349            n_wires: None,
350            static_mode: false,
351        }
352    }
353}
354
355impl Default for TQStateEncoder {
356    fn default() -> Self {
357        Self::new()
358    }
359}
360
361impl TQModule for TQStateEncoder {
362    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
363        Err(MLError::InvalidConfiguration(
364            "Use encode() instead of forward() for encoders".to_string(),
365        ))
366    }
367
368    fn forward_with_input(&mut self, qdev: &mut TQDevice, x: Option<&Array2<f64>>) -> Result<()> {
369        if let Some(data) = x {
370            self.encode(qdev, data)
371        } else {
372            Err(MLError::InvalidConfiguration(
373                "Input data required for encoder".to_string(),
374            ))
375        }
376    }
377
378    fn parameters(&self) -> Vec<TQParameter> {
379        Vec::new()
380    }
381
382    fn n_wires(&self) -> Option<usize> {
383        self.n_wires
384    }
385
386    fn set_n_wires(&mut self, n_wires: usize) {
387        self.n_wires = Some(n_wires);
388    }
389
390    fn is_static_mode(&self) -> bool {
391        self.static_mode
392    }
393
394    fn static_on(&mut self) {
395        self.static_mode = true;
396    }
397
398    fn static_off(&mut self) {
399        self.static_mode = false;
400    }
401
402    fn name(&self) -> &str {
403        "StateEncoder"
404    }
405}
406
407impl TQEncoder for TQStateEncoder {
408    fn encode(&mut self, qdev: &mut TQDevice, x: &Array2<f64>) -> Result<()> {
409        let bsz = x.nrows();
410        let state_size = 1 << qdev.n_wires;
411
412        // Normalize input and prepare state
413        let mut state_data = Vec::with_capacity(state_size * bsz);
414
415        for batch in 0..bsz {
416            // Get amplitude values
417            let mut amplitudes: Vec<f64> = (0..state_size)
418                .map(|i| if i < x.ncols() { x[[batch, i]] } else { 0.0 })
419                .collect();
420
421            // Normalize
422            let norm: f64 = amplitudes.iter().map(|a| a * a).sum::<f64>().sqrt();
423            if norm > 1e-10 {
424                for a in &mut amplitudes {
425                    *a /= norm;
426                }
427            }
428
429            // Convert to complex
430            for &a in &amplitudes {
431                state_data.push(CType::new(a, 0.0));
432            }
433        }
434
435        // Reshape and set states
436        let mut shape = vec![bsz];
437        shape.extend(vec![2; qdev.n_wires]);
438        let states = ArrayD::from_shape_vec(IxDyn(&shape), state_data)
439            .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
440
441        qdev.set_states(states);
442
443        Ok(())
444    }
445}
446
447/// Alias for amplitude encoding
448pub type TQAmplitudeEncoder = TQStateEncoder;
449
/// Multi-phase encoder: applies several rotation layers (e.g. RX, RY, RZ),
/// each covering every mapped wire and consuming its own slice of features.
#[derive(Debug, Clone)]
pub struct TQMultiPhaseEncoder {
    /// Rotation layers to apply, in order (e.g. `["rx", "ry", "rz"]`).
    pub funcs: Vec<String>,
    /// Explicit wire mapping; `None` means wires `0..n_wires` of the device.
    pub wires: Option<Vec<usize>>,
    /// Number of wires, when known.
    n_wires: Option<usize>,
    /// Whether static mode is enabled.
    static_mode: bool,
}
461
462impl TQMultiPhaseEncoder {
463    pub fn new(funcs: Vec<&str>) -> Self {
464        Self {
465            funcs: funcs.iter().map(|s| s.to_string()).collect(),
466            wires: None,
467            n_wires: None,
468            static_mode: false,
469        }
470    }
471
472    /// Create with specific wire mapping
473    pub fn with_wires(funcs: Vec<&str>, wires: Vec<usize>) -> Self {
474        let n_wires = wires.len();
475        Self {
476            funcs: funcs.iter().map(|s| s.to_string()).collect(),
477            wires: Some(wires),
478            n_wires: Some(n_wires),
479            static_mode: false,
480        }
481    }
482
483    /// Create RX, RY, RZ encoder
484    pub fn rxyz() -> Self {
485        Self::new(vec!["rx", "ry", "rz"])
486    }
487
488    /// Create RY, RZ encoder (common for VQE)
489    pub fn ryrz() -> Self {
490        Self::new(vec!["ry", "rz"])
491    }
492}
493
494impl TQModule for TQMultiPhaseEncoder {
495    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
496        Err(MLError::InvalidConfiguration(
497            "Use encode() instead of forward() for encoders".to_string(),
498        ))
499    }
500
501    fn forward_with_input(&mut self, qdev: &mut TQDevice, x: Option<&Array2<f64>>) -> Result<()> {
502        if let Some(data) = x {
503            self.encode(qdev, data)
504        } else {
505            Err(MLError::InvalidConfiguration(
506                "Input data required for encoder".to_string(),
507            ))
508        }
509    }
510
511    fn parameters(&self) -> Vec<TQParameter> {
512        Vec::new()
513    }
514
515    fn n_wires(&self) -> Option<usize> {
516        self.n_wires
517    }
518
519    fn set_n_wires(&mut self, n_wires: usize) {
520        self.n_wires = Some(n_wires);
521    }
522
523    fn is_static_mode(&self) -> bool {
524        self.static_mode
525    }
526
527    fn static_on(&mut self) {
528        self.static_mode = true;
529    }
530
531    fn static_off(&mut self) {
532        self.static_mode = false;
533    }
534
535    fn name(&self) -> &str {
536        "MultiPhaseEncoder"
537    }
538}
539
540impl TQEncoder for TQMultiPhaseEncoder {
541    fn encode(&mut self, qdev: &mut TQDevice, x: &Array2<f64>) -> Result<()> {
542        let wires: Vec<usize> = self
543            .wires
544            .clone()
545            .unwrap_or_else(|| (0..qdev.n_wires).collect());
546
547        let mut x_idx = 0;
548
549        for (func_idx, func) in self.funcs.iter().enumerate() {
550            for (wire_idx, &wire) in wires.iter().enumerate() {
551                // Calculate parameter index
552                let param_idx = func_idx * wires.len() + wire_idx;
553                let param = if param_idx < x.ncols() {
554                    x[[0, param_idx]]
555                } else {
556                    0.0
557                };
558
559                match func.as_str() {
560                    "rx" => {
561                        let mut gate = TQRx::new(true, false);
562                        gate.apply_with_params(qdev, &[wire], Some(&[param]))?;
563                    }
564                    "ry" => {
565                        let mut gate = TQRy::new(true, false);
566                        gate.apply_with_params(qdev, &[wire], Some(&[param]))?;
567                    }
568                    "rz" => {
569                        let mut gate = TQRz::new(true, false);
570                        gate.apply_with_params(qdev, &[wire], Some(&[param]))?;
571                    }
572                    "u1" | "phaseshift" => {
573                        let mut gate = TQRz::new(true, false); // U1 ≈ RZ
574                        gate.apply_with_params(qdev, &[wire], Some(&[param]))?;
575                    }
576                    _ => {
577                        return Err(MLError::InvalidConfiguration(format!(
578                            "Unknown gate in MultiPhaseEncoder: {}",
579                            func
580                        )));
581                    }
582                }
583                x_idx += 1;
584            }
585        }
586
587        Ok(())
588    }
589}
590
/// Magnitude encoder: uses the absolute values of the input as the
/// (real, non-negative) amplitudes of computational basis states.
#[derive(Debug, Clone)]
pub struct TQMagnitudeEncoder {
    /// Number of wires, when known.
    n_wires: Option<usize>,
    /// Whether static mode is enabled.
    static_mode: bool,
}
598
599impl TQMagnitudeEncoder {
600    pub fn new() -> Self {
601        Self {
602            n_wires: None,
603            static_mode: false,
604        }
605    }
606}
607
608impl Default for TQMagnitudeEncoder {
609    fn default() -> Self {
610        Self::new()
611    }
612}
613
614impl TQModule for TQMagnitudeEncoder {
615    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
616        Err(MLError::InvalidConfiguration(
617            "Use encode() instead of forward() for encoders".to_string(),
618        ))
619    }
620
621    fn forward_with_input(&mut self, qdev: &mut TQDevice, x: Option<&Array2<f64>>) -> Result<()> {
622        if let Some(data) = x {
623            self.encode(qdev, data)
624        } else {
625            Err(MLError::InvalidConfiguration(
626                "Input data required for encoder".to_string(),
627            ))
628        }
629    }
630
631    fn parameters(&self) -> Vec<TQParameter> {
632        Vec::new()
633    }
634
635    fn n_wires(&self) -> Option<usize> {
636        self.n_wires
637    }
638
639    fn set_n_wires(&mut self, n_wires: usize) {
640        self.n_wires = Some(n_wires);
641    }
642
643    fn is_static_mode(&self) -> bool {
644        self.static_mode
645    }
646
647    fn static_on(&mut self) {
648        self.static_mode = true;
649    }
650
651    fn static_off(&mut self) {
652        self.static_mode = false;
653    }
654
655    fn name(&self) -> &str {
656        "MagnitudeEncoder"
657    }
658}
659
660impl TQEncoder for TQMagnitudeEncoder {
661    fn encode(&mut self, qdev: &mut TQDevice, x: &Array2<f64>) -> Result<()> {
662        let bsz = x.nrows();
663        let state_size = 1 << qdev.n_wires;
664
665        // Normalize input values to use as magnitudes
666        let mut state_data = Vec::with_capacity(state_size * bsz);
667
668        for batch in 0..bsz {
669            // Get magnitude values (must be non-negative)
670            let mut magnitudes: Vec<f64> = (0..state_size)
671                .map(|i| {
672                    if i < x.ncols() {
673                        x[[batch, i]].abs()
674                    } else {
675                        0.0
676                    }
677                })
678                .collect();
679
680            // Normalize to ensure sum of squared magnitudes = 1
681            let norm_sq: f64 = magnitudes.iter().map(|m| m * m).sum();
682            let norm = norm_sq.sqrt();
683            if norm > 1e-10 {
684                for m in &mut magnitudes {
685                    *m /= norm;
686                }
687            }
688
689            // Convert to complex amplitudes (real-valued)
690            for &m in &magnitudes {
691                state_data.push(CType::new(m, 0.0));
692            }
693        }
694
695        // Reshape and set states
696        let mut shape = vec![bsz];
697        shape.extend(vec![2; qdev.n_wires]);
698        let states = ArrayD::from_shape_vec(IxDyn(&shape), state_data)
699            .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
700
701        qdev.set_states(states);
702
703        Ok(())
704    }
705}
706
/// Angle encoder: rotates each wire by a scaled input value. It is more
/// flexible than the phase encoder because the scale factor is configurable.
#[derive(Debug, Clone)]
pub struct TQAngleEncoder {
    /// Rotation type: "rx", "ry", "rz", or "arcsin" (RY by `2 * asin(x)`).
    pub func: String,
    /// Multiplier applied to each input value (not used in "arcsin" mode).
    pub scaling: f64,
    /// Number of wires, when known.
    n_wires: Option<usize>,
    /// Whether static mode is enabled.
    static_mode: bool,
}
718
719impl TQAngleEncoder {
720    pub fn new(func: impl Into<String>, scaling: f64) -> Self {
721        Self {
722            func: func.into(),
723            scaling,
724            n_wires: None,
725            static_mode: false,
726        }
727    }
728
729    /// Create with default PI scaling (maps [0,1] to [0, PI])
730    pub fn with_pi_scaling(func: impl Into<String>) -> Self {
731        Self::new(func, PI)
732    }
733
734    /// Create with 2*PI scaling (maps [0,1] to [0, 2*PI])
735    pub fn with_2pi_scaling(func: impl Into<String>) -> Self {
736        Self::new(func, 2.0 * PI)
737    }
738
739    /// Create RY encoder with arcsin scaling (for probability amplitude encoding)
740    pub fn arcsin() -> Self {
741        Self {
742            func: "ry".to_string(),
743            scaling: 1.0, // Will use arcsin transformation
744            n_wires: None,
745            static_mode: false,
746        }
747    }
748}
749
750impl TQModule for TQAngleEncoder {
751    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
752        Err(MLError::InvalidConfiguration(
753            "Use encode() instead of forward() for encoders".to_string(),
754        ))
755    }
756
757    fn forward_with_input(&mut self, qdev: &mut TQDevice, x: Option<&Array2<f64>>) -> Result<()> {
758        if let Some(data) = x {
759            self.encode(qdev, data)
760        } else {
761            Err(MLError::InvalidConfiguration(
762                "Input data required for encoder".to_string(),
763            ))
764        }
765    }
766
767    fn parameters(&self) -> Vec<TQParameter> {
768        Vec::new()
769    }
770
771    fn n_wires(&self) -> Option<usize> {
772        self.n_wires
773    }
774
775    fn set_n_wires(&mut self, n_wires: usize) {
776        self.n_wires = Some(n_wires);
777    }
778
779    fn is_static_mode(&self) -> bool {
780        self.static_mode
781    }
782
783    fn static_on(&mut self) {
784        self.static_mode = true;
785    }
786
787    fn static_off(&mut self) {
788        self.static_mode = false;
789    }
790
791    fn name(&self) -> &str {
792        "AngleEncoder"
793    }
794}
795
796impl TQEncoder for TQAngleEncoder {
797    fn encode(&mut self, qdev: &mut TQDevice, x: &Array2<f64>) -> Result<()> {
798        let n_wires = qdev.n_wires;
799
800        for wire in 0..n_wires {
801            let raw_value = if wire < x.ncols() { x[[0, wire]] } else { 0.0 };
802
803            // Apply scaling
804            let param = if self.func == "arcsin" {
805                // Arcsin encoding: map value to angle via arcsin
806                // Clamp to [-1, 1] for valid arcsin input
807                let clamped = raw_value.clamp(-1.0, 1.0);
808                2.0 * clamped.asin()
809            } else {
810                raw_value * self.scaling
811            };
812
813            match self.func.as_str() {
814                "rx" => {
815                    let mut gate = TQRx::new(true, false);
816                    gate.apply_with_params(qdev, &[wire], Some(&[param]))?;
817                }
818                "ry" | "arcsin" => {
819                    let mut gate = TQRy::new(true, false);
820                    gate.apply_with_params(qdev, &[wire], Some(&[param]))?;
821                }
822                "rz" => {
823                    let mut gate = TQRz::new(true, false);
824                    gate.apply_with_params(qdev, &[wire], Some(&[param]))?;
825                }
826                _ => {
827                    return Err(MLError::InvalidConfiguration(format!(
828                        "Unknown rotation gate: {}",
829                        self.func
830                    )));
831                }
832            }
833        }
834
835        Ok(())
836    }
837}
838
/// IQP (Instantaneous Quantum Polynomial) encoder.
///
/// Each repetition applies Hadamards, single-feature RZ rotations, and
/// pairwise ZZ interactions driven by products of features.
#[derive(Debug, Clone)]
pub struct TQIQPEncoder {
    /// How many times the IQP layer is repeated.
    pub reps: usize,
    /// Number of wires, when known.
    n_wires: Option<usize>,
    /// Whether static mode is enabled.
    static_mode: bool,
}
848
849impl TQIQPEncoder {
850    pub fn new(reps: usize) -> Self {
851        Self {
852            reps,
853            n_wires: None,
854            static_mode: false,
855        }
856    }
857}
858
859impl Default for TQIQPEncoder {
860    fn default() -> Self {
861        Self::new(1)
862    }
863}
864
865impl TQModule for TQIQPEncoder {
866    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
867        Err(MLError::InvalidConfiguration(
868            "Use encode() instead of forward() for encoders".to_string(),
869        ))
870    }
871
872    fn forward_with_input(&mut self, qdev: &mut TQDevice, x: Option<&Array2<f64>>) -> Result<()> {
873        if let Some(data) = x {
874            self.encode(qdev, data)
875        } else {
876            Err(MLError::InvalidConfiguration(
877                "Input data required for encoder".to_string(),
878            ))
879        }
880    }
881
882    fn parameters(&self) -> Vec<TQParameter> {
883        Vec::new()
884    }
885
886    fn n_wires(&self) -> Option<usize> {
887        self.n_wires
888    }
889
890    fn set_n_wires(&mut self, n_wires: usize) {
891        self.n_wires = Some(n_wires);
892    }
893
894    fn is_static_mode(&self) -> bool {
895        self.static_mode
896    }
897
898    fn static_on(&mut self) {
899        self.static_mode = true;
900    }
901
902    fn static_off(&mut self) {
903        self.static_mode = false;
904    }
905
906    fn name(&self) -> &str {
907        "IQPEncoder"
908    }
909}
910
911impl TQEncoder for TQIQPEncoder {
912    fn encode(&mut self, qdev: &mut TQDevice, x: &Array2<f64>) -> Result<()> {
913        use super::gates::TQRZZ;
914
915        let n_wires = qdev.n_wires;
916
917        for _ in 0..self.reps {
918            // First: Hadamard on all qubits
919            for wire in 0..n_wires {
920                let mut h = TQHadamard::new();
921                h.apply(qdev, &[wire])?;
922            }
923
924            // Second: RZ encoding
925            for wire in 0..n_wires {
926                let param = if wire < x.ncols() { x[[0, wire]] } else { 0.0 };
927                let mut rz = TQRz::new(true, false);
928                rz.apply_with_params(qdev, &[wire], Some(&[param]))?;
929            }
930
931            // Third: ZZ interactions (product encoding)
932            let mut pair_idx = 0;
933            for i in 0..n_wires {
934                for j in (i + 1)..n_wires {
935                    // Product of features
936                    let xi = if i < x.ncols() { x[[0, i]] } else { 0.0 };
937                    let xj = if j < x.ncols() { x[[0, j]] } else { 0.0 };
938                    let param = xi * xj;
939
940                    let mut rzz = TQRZZ::new(true, false);
941                    rzz.apply_with_params(qdev, &[i, j], Some(&[param]))?;
942                    pair_idx += 1;
943                }
944            }
945        }
946
947        Ok(())
948    }
949}