use trustformers_core::{
errors::Result,
layers::{LayerNorm, Linear},
quantum::{QuantumAnsatz, QuantumNeuralLayer},
tensor::Tensor,
traits::Layer,
};
use super::{config::QuantumClassicalConfig, model::QuantumClassicalModelOutput};
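/// Hybrid quantum-classical recurrent network: classical recurrent layers
/// produce a hidden state per time step, which variational quantum memory
/// layers then refine before layer normalization and an output projection.
///
/// Illustrative usage, a sketch assuming `QuantumClassicalConfig` implements
/// `Default` (not guaranteed by this module):
///
/// ```ignore
/// let config = QuantumClassicalConfig::default();
/// let mut rnn = QuantumRecurrentNN::new(&config)?;
/// let input = Tensor::zeros(&[2, 16, config.d_model])?; // [batch, seq_len, d_model]
/// let output = rnn.forward(&input)?;
/// rnn.reset_hidden_state(); // before an unrelated sequence
/// ```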
#[derive(Debug)]
pub struct QuantumRecurrentNN {
pub config: QuantumClassicalConfig,
pub rnn_layers: Vec<Linear>,
pub quantum_memory_layers: Vec<QuantumNeuralLayer>,
pub hidden_state: Option<Tensor>,
pub layer_norm: LayerNorm,
pub output_projection: Linear,
}
impl QuantumRecurrentNN {
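    /// Builds the network from `config`: `n_classical_layers` recurrent
    /// `Linear` layers, `n_quantum_layers` variational quantum memory layers,
    /// a `LayerNorm`, and an output projection.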
pub fn new(config: &QuantumClassicalConfig) -> Result<Self> {
let mut rnn_layers = Vec::new();
let mut quantum_memory_layers = Vec::new();
        for i in 0..config.n_classical_layers {
            // Only the first layer consumes the concatenated
            // [input_t, hidden_state] vector; deeper layers receive d_model
            // features, so their input width must shrink accordingly.
            let in_features = if i == 0 {
                config.d_model * 2
            } else {
                config.d_model
            };
            rnn_layers.push(Linear::new(in_features, config.d_model, config.use_bias));
        }
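        // Quantum memory layers share the configured ansatz family and start
        // from small, uniform rotation parameters.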
for _ in 0..config.n_quantum_layers {
let ansatz = QuantumAnsatz::from(config.quantum_ansatz.clone());
let parameters = vec![0.1; config.get_quantum_parameters_count()];
quantum_memory_layers.push(QuantumNeuralLayer::new(
config.num_qubits,
ansatz,
                &parameters,
)?);
}
let layer_norm = LayerNorm::new(vec![config.d_model], 1e-12)?;
let output_projection = Linear::new(config.d_model, config.d_model, config.use_bias);
Ok(Self {
config: config.clone(),
rnn_layers,
quantum_memory_layers,
hidden_state: None,
layer_norm,
output_projection,
})
}
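    /// Runs the recurrence over `input` of shape `[batch, seq_len, d_model]`,
    /// updating `self.hidden_state` as a side effect. Call
    /// [`reset_hidden_state`](Self::reset_hidden_state) between unrelated
    /// sequences.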
pub fn forward(&mut self, input: &Tensor) -> Result<QuantumClassicalModelOutput> {
let batch_size = input.shape()[0];
let seq_len = input.shape()[1];
let d_model = self.config.d_model;
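        // Lazily initialize the recurrent state to zeros for this batch size.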
if self.hidden_state.is_none() {
self.hidden_state = Some(Tensor::zeros(&[batch_size, d_model])?);
}
let mut outputs = Vec::new();
let mut quantum_measurements = Vec::new();
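        // Unroll the recurrence one time step at a time.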
for t in 0..seq_len {
let input_t = input.slice(1, t, t + 1)?.squeeze(1)?;
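            // Concatenate the step input with the previous hidden state along
            // the feature axis: [batch, 2 * d_model].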
let combined = Tensor::concat(
&[
input_t,
                    self.hidden_state.as_ref().expect("hidden state initialized above").clone(),
],
1,
)?;
let mut hidden = combined;
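            // Classical recurrent stack with tanh nonlinearities.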
for rnn_layer in &self.rnn_layers {
hidden = rnn_layer.forward(hidden)?;
hidden = hidden.tanh()?;
}
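            // Commit the classical activation as the recurrent state; the
            // quantum memory layers below refine only the emitted output.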
self.hidden_state = Some(hidden.clone());
for quantum_layer in &mut self.quantum_memory_layers {
let quantum_output = quantum_layer.forward(&hidden)?;
quantum_measurements.push(quantum_output.clone());
hidden = quantum_output;
}
            // Restore the time dimension so the concat below yields
            // [batch, seq_len, d_model] rather than [batch, seq_len * d_model];
            // assumes `Tensor::unsqueeze` mirrors the `squeeze` used above.
            outputs.push(hidden.unsqueeze(1)?);
}
let output = Tensor::concat(&outputs, 1)?;
let normalized = self.layer_norm.forward(output.clone())?;
let final_output = self.output_projection.forward(normalized)?;
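        // Collapse the per-step, per-layer measurements into a single tensor
        // by elementwise summation.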
let combined_measurements = if !quantum_measurements.is_empty() {
let mut combined = quantum_measurements[0].clone();
for measurement in quantum_measurements.iter().skip(1) {
combined = combined.add(measurement)?;
}
Some(combined)
} else {
None
};
Ok(QuantumClassicalModelOutput {
hidden_states: final_output,
quantum_measurements: combined_measurements,
classical_activations: Some(output),
quantum_attention_weights: None,
quantum_fidelity_scores: None,
quantum_entanglement_measures: None,
quantum_error_mitigation: None,
})
}
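    /// Clears the recurrent state so the next `forward` call starts from zeros.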
pub fn reset_hidden_state(&mut self) {
self.hidden_state = None;
}
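    /// Total trainable parameters across the classical layers, normalization,
    /// output projection, and quantum memory layers.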
pub fn parameter_count(&self) -> usize {
self.rnn_layers.iter().map(|l| l.parameter_count()).sum::<usize>()
+ self.layer_norm.parameter_count()
+ self.output_projection.parameter_count()
+ self.quantum_memory_layers.len() * self.config.get_quantum_parameters_count()
}
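    /// Approximate parameter memory in megabytes, assuming 4-byte `f32` weights.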
pub fn memory_usage(&self) -> f32 {
self.parameter_count() as f32 * 4.0 / 1_000_000.0
}
}