use trustformers_core::{
errors::Result,
layers::{LayerNorm, Linear},
quantum::QuantumAttention,
tensor::Tensor,
traits::Layer,
};
use super::{config::QuantumClassicalConfig, model::QuantumClassicalModelOutput};
#[derive(Debug)]
pub struct QuantumAttentionLayer {
pub config: QuantumClassicalConfig,
pub quantum_attention: QuantumAttention,
pub query_projection: Linear,
pub key_projection: Linear,
pub value_projection: Linear,
pub output_projection: Linear,
pub layer_norm: LayerNorm,
}
impl QuantumAttentionLayer {
pub fn new(config: &QuantumClassicalConfig) -> Result<Self> {
let quantum_attention = QuantumAttention::new(config.d_model, config.num_qubits)?;
let query_projection = Linear::new(config.d_model, config.d_model, config.use_bias);
let key_projection = Linear::new(config.d_model, config.d_model, config.use_bias);
let value_projection = Linear::new(config.d_model, config.d_model, config.use_bias);
let output_projection = Linear::new(config.d_model, config.d_model, config.use_bias);
let layer_norm = LayerNorm::new(vec![config.d_model], 1e-12)?;
Ok(Self {
config: config.clone(),
quantum_attention,
query_projection,
key_projection,
value_projection,
output_projection,
layer_norm,
})
}
pub fn forward(&mut self, input: &Tensor) -> Result<QuantumClassicalModelOutput> {
let _batch_size = input.shape()[0];
let seq_len = input.shape()[1];
let mut outputs = Vec::new();
let mut quantum_measurements = Vec::new();
for t in 0..seq_len {
let input_t = input.slice(1, t, t + 1)?.squeeze(1)?;
let query = self.query_projection.forward(input_t.clone())?;
let key = self.key_projection.forward(input_t.clone())?;
let value = self.value_projection.forward(input_t)?;
let quantum_output = self.quantum_attention.forward(&query, &key, &value)?;
quantum_measurements.push(quantum_output.clone());
outputs.push(quantum_output);
}
let output = Tensor::concat(&outputs, 1)?;
let normalized = self.layer_norm.forward(output.clone())?;
let final_output = self.output_projection.forward(normalized)?;
let combined_measurements = if !quantum_measurements.is_empty() {
let mut combined = quantum_measurements[0].clone();
for measurement in quantum_measurements.iter().skip(1) {
combined = combined.add(measurement)?;
}
Some(combined)
} else {
None
};
Ok(QuantumClassicalModelOutput {
hidden_states: final_output.clone(),
quantum_measurements: combined_measurements.clone(),
classical_activations: Some(output),
quantum_attention_weights: Some(combined_measurements.unwrap_or(final_output)),
quantum_fidelity_scores: None,
quantum_entanglement_measures: None,
quantum_error_mitigation: None,
})
}
pub fn parameter_count(&self) -> usize {
self.query_projection.parameter_count()
+ self.key_projection.parameter_count()
+ self.value_projection.parameter_count()
+ self.output_projection.parameter_count()
+ self.layer_norm.parameter_count()
}
pub fn memory_usage(&self) -> f32 {
self.parameter_count() as f32 * 4.0 / 1_000_000.0
}
}