Skip to main content

quantrs2_core/qml/
nlp.rs

1//! Quantum Machine Learning for Natural Language Processing
2//!
3//! This module provides specialized quantum machine learning layers and algorithms
4//! optimized for natural language processing tasks such as text classification,
5//! sentiment analysis, and language modeling.
6
7use super::{Parameter, QMLLayer};
8use crate::{
9    error::{QuantRS2Error, QuantRS2Result},
10    gate::{multi::*, single::*, GateOp},
11    parametric::{ParametricRotationX, ParametricRotationY, ParametricRotationZ},
12    qubit::QubitId,
13};
14use scirs2_core::ndarray::Array1;
15use scirs2_core::Complex64;
16use std::collections::HashMap;
17use std::f64::consts::PI;
18
/// Strategy used to split raw text into the units that get encoded onto qubits.
///
/// The variants differ only in the granularity of the encoded unit.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TextEmbeddingStrategy {
    /// Every word becomes its own encoded unit.
    WordLevel,
    /// Every character becomes its own encoded unit.
    CharLevel,
    /// Overlapping n-grams are encoded; the payload is `n`.
    NGram(usize),
    /// Tokens are encoded together with positional information.
    TokenPositional,
    /// Hierarchical encoding: words, then sentences, then documents.
    Hierarchical,
}
33
34/// Configuration for quantum NLP models
35#[derive(Debug, Clone)]
36pub struct QNLPConfig {
37    /// Number of qubits for text representation
38    pub text_qubits: usize,
39    /// Number of qubits for feature extraction
40    pub feature_qubits: usize,
41    /// Maximum sequence length
42    pub max_sequence_length: usize,
43    /// Vocabulary size
44    pub vocab_size: usize,
45    /// Embedding dimension
46    pub embedding_dim: usize,
47    /// Text embedding strategy
48    pub embedding_strategy: TextEmbeddingStrategy,
49    /// Number of attention heads (for quantum attention)
50    pub num_attention_heads: usize,
51    /// Hidden dimension for feedforward layers
52    pub hidden_dim: usize,
53}
54
55impl Default for QNLPConfig {
56    fn default() -> Self {
57        Self {
58            text_qubits: 8,
59            feature_qubits: 4,
60            max_sequence_length: 32,
61            vocab_size: 1000,
62            embedding_dim: 64,
63            embedding_strategy: TextEmbeddingStrategy::WordLevel,
64            num_attention_heads: 4,
65            hidden_dim: 128,
66        }
67    }
68}
69
70/// Quantum word embedding layer
71pub struct QuantumWordEmbedding {
72    /// Configuration
73    config: QNLPConfig,
74    /// Embedding parameters for each word in vocabulary
75    embeddings: Vec<Vec<Parameter>>,
76    /// Flattened view of all embedding parameters (row-major: word_id * num_qubits + qubit)
77    /// This cache is the single source of truth exposed via the QMLLayer trait.
78    /// It is kept in sync with `embeddings` via `rebuild_flat_cache` and `sync_from_flat`.
79    flat_params: Vec<Parameter>,
80    /// Number of qubits
81    num_qubits: usize,
82}
83
84impl QuantumWordEmbedding {
85    /// Create a new quantum word embedding layer
86    pub fn new(config: QNLPConfig) -> Self {
87        let num_qubits = config.text_qubits;
88        let mut embeddings = Vec::new();
89        let mut flat_params: Vec<Parameter> = Vec::new();
90
91        // Initialize embeddings for each word in vocabulary
92        for word_id in 0..config.vocab_size {
93            let mut word_embedding = Vec::new();
94            for qubit in 0..num_qubits {
95                // Initialize with deterministic pseudo-random values
96                let value = ((word_id * qubit.max(1)) as f64 * 0.1).sin() * 0.5;
97                let param = Parameter {
98                    name: format!("embed_{word_id}_{qubit}"),
99                    value,
100                    bounds: None,
101                };
102                flat_params.push(param.clone());
103                word_embedding.push(param);
104            }
105            embeddings.push(word_embedding);
106        }
107
108        Self {
109            config,
110            embeddings,
111            flat_params,
112            num_qubits,
113        }
114    }
115
116    /// Rebuild the flat parameter cache from the nested embeddings.
117    fn rebuild_flat_cache(&mut self) {
118        self.flat_params.clear();
119        for word_emb in &self.embeddings {
120            self.flat_params.extend(word_emb.iter().cloned());
121        }
122    }
123
124    /// Sync the nested embeddings from the flat parameter cache after an
125    /// external mutation through `parameters_mut()`.
126    fn sync_from_flat(&mut self) {
127        let nq = self.num_qubits;
128        for (word_id, word_emb) in self.embeddings.iter_mut().enumerate() {
129            for (qubit, param) in word_emb.iter_mut().enumerate() {
130                let flat_idx = word_id * nq + qubit;
131                if let Some(flat_param) = self.flat_params.get(flat_idx) {
132                    param.value = flat_param.value;
133                }
134            }
135        }
136    }
137
138    /// Encode a sequence of word IDs into quantum gates
139    pub fn encode_sequence(&self, word_ids: &[usize]) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
140        let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
141        let nq = self.num_qubits;
142
143        for (position, &word_id) in word_ids.iter().enumerate() {
144            if word_id >= self.config.vocab_size {
145                return Err(QuantRS2Error::InvalidInput(format!(
146                    "Word ID {} exceeds vocabulary size {}",
147                    word_id, self.config.vocab_size
148                )));
149            }
150
151            if position >= self.config.max_sequence_length {
152                break; // Truncate sequence if too long
153            }
154
155            // Read embedding values from the flat_params cache (canonical store)
156            let flat_base = word_id * nq;
157            for qubit_idx in 0..nq {
158                let flat_idx = flat_base + qubit_idx;
159                let value = self
160                    .flat_params
161                    .get(flat_idx)
162                    .map(|p| p.value)
163                    .unwrap_or(0.0);
164
165                let qubit = QubitId(qubit_idx as u32);
166
167                // Use rotation gates to encode the embedding values
168                gates.push(Box::new(ParametricRotationY {
169                    target: qubit,
170                    theta: crate::parametric::Parameter::Constant(value * PI),
171                }));
172
173                // Add positional encoding (sinusoidal, scaled to small contribution)
174                let positional_angle =
175                    (position as f64) / (self.config.max_sequence_length as f64) * PI;
176                gates.push(Box::new(ParametricRotationZ {
177                    target: qubit,
178                    theta: crate::parametric::Parameter::Constant(positional_angle * 0.1),
179                }));
180            }
181        }
182
183        Ok(gates)
184    }
185}
186
187impl QMLLayer for QuantumWordEmbedding {
188    fn num_qubits(&self) -> usize {
189        self.num_qubits
190    }
191
192    fn parameters(&self) -> &[Parameter] {
193        // Return the pre-built flat cache.  The cache is row-major over
194        // (word_id, qubit) and built on construction; it is also updated
195        // whenever set_parameters() is called via parameters_mut().
196        &self.flat_params
197    }
198
199    fn parameters_mut(&mut self) -> &mut [Parameter] {
200        // Callers mutate the flat cache.  The nested `embeddings` field is
201        // a convenience copy that is kept in sync by sync_from_flat(), which
202        // is called by the default set_parameters() implementation via this
203        // method.  If callers mutate flat_params directly (e.g. in a training
204        // loop) they should call sync_from_flat() before using encode_sequence.
205        &mut self.flat_params
206    }
207
208    fn gates(&self) -> Vec<Box<dyn GateOp>> {
209        // Return empty - this layer provides encoding method
210        Vec::new()
211    }
212
213    fn compute_gradients(
214        &self,
215        _state: &Array1<Complex64>,
216        _loss_gradient: &Array1<Complex64>,
217    ) -> QuantRS2Result<Vec<f64>> {
218        // Placeholder for gradient computation
219        let total_params = self.config.vocab_size * self.num_qubits;
220        Ok(vec![0.0; total_params])
221    }
222
223    fn name(&self) -> &'static str {
224        "QuantumWordEmbedding"
225    }
226}
227
228/// Quantum attention mechanism for NLP
229pub struct QuantumAttention {
230    /// Number of qubits
231    num_qubits: usize,
232    /// Number of attention heads
233    num_heads: usize,
234    /// Query parameters
235    query_params: Vec<Parameter>,
236    /// Key parameters
237    key_params: Vec<Parameter>,
238    /// Value parameters
239    value_params: Vec<Parameter>,
240    /// Output projection parameters
241    output_params: Vec<Parameter>,
242    /// Flattened view: [query... | key... | value... | output...]
243    /// Used by the QMLLayer trait (parameters / parameters_mut).
244    flat_params: Vec<Parameter>,
245}
246
247impl QuantumAttention {
248    /// Create a new quantum attention layer
249    pub fn new(num_qubits: usize, num_heads: usize) -> Self {
250        let params_per_head = num_qubits / num_heads.max(1);
251
252        let mut query_params = Vec::new();
253        let mut key_params = Vec::new();
254        let mut value_params = Vec::new();
255        let mut output_params = Vec::new();
256
257        // Initialize parameters for each head
258        for head in 0..num_heads {
259            for i in 0..params_per_head {
260                // Query parameters
261                query_params.push(Parameter {
262                    name: format!("query_{head}_{i}"),
263                    value: ((head + i) as f64 * 0.1).sin() * 0.5,
264                    bounds: None,
265                });
266
267                // Key parameters
268                key_params.push(Parameter {
269                    name: format!("key_{head}_{i}"),
270                    value: ((head + i + 1) as f64 * 0.1).cos() * 0.5,
271                    bounds: None,
272                });
273
274                // Value parameters
275                value_params.push(Parameter {
276                    name: format!("value_{head}_{i}"),
277                    value: ((head + i + 2) as f64 * 0.1).sin() * 0.5,
278                    bounds: None,
279                });
280
281                // Output parameters
282                output_params.push(Parameter {
283                    name: format!("output_{head}_{i}"),
284                    value: ((head + i + 3) as f64 * 0.1).cos() * 0.5,
285                    bounds: None,
286                });
287            }
288        }
289
290        // Build the flat cache: query | key | value | output
291        let mut flat_params: Vec<Parameter> = Vec::new();
292        flat_params.extend(query_params.iter().cloned());
293        flat_params.extend(key_params.iter().cloned());
294        flat_params.extend(value_params.iter().cloned());
295        flat_params.extend(output_params.iter().cloned());
296
297        Self {
298            num_qubits,
299            num_heads,
300            query_params,
301            key_params,
302            value_params,
303            output_params,
304            flat_params,
305        }
306    }
307
308    /// Rebuild the flat cache from the four per-group parameter vectors.
309    pub fn rebuild_flat_cache(&mut self) {
310        self.flat_params.clear();
311        self.flat_params.extend(self.query_params.iter().cloned());
312        self.flat_params.extend(self.key_params.iter().cloned());
313        self.flat_params.extend(self.value_params.iter().cloned());
314        self.flat_params.extend(self.output_params.iter().cloned());
315    }
316
317    /// Sync the four per-group parameter vectors from the flat cache.
318    pub fn sync_from_flat(&mut self) {
319        let qlen = self.query_params.len();
320        let klen = self.key_params.len();
321        let vlen = self.value_params.len();
322
323        for (i, p) in self.query_params.iter_mut().enumerate() {
324            if let Some(fp) = self.flat_params.get(i) {
325                p.value = fp.value;
326            }
327        }
328        for (i, p) in self.key_params.iter_mut().enumerate() {
329            if let Some(fp) = self.flat_params.get(qlen + i) {
330                p.value = fp.value;
331            }
332        }
333        for (i, p) in self.value_params.iter_mut().enumerate() {
334            if let Some(fp) = self.flat_params.get(qlen + klen + i) {
335                p.value = fp.value;
336            }
337        }
338        for (i, p) in self.output_params.iter_mut().enumerate() {
339            if let Some(fp) = self.flat_params.get(qlen + klen + vlen + i) {
340                p.value = fp.value;
341            }
342        }
343    }
344
345    /// Generate attention gates for a sequence
346    pub fn attention_gates(&self) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
347        let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
348        let params_per_head = self.num_qubits / self.num_heads;
349
350        // For each attention head
351        for head in 0..self.num_heads {
352            let head_offset = head * params_per_head;
353
354            // Apply query transformations
355            for i in 0..params_per_head {
356                let qubit = QubitId((head_offset + i) as u32);
357                let param_idx = head * params_per_head + i;
358
359                gates.push(Box::new(ParametricRotationY {
360                    target: qubit,
361                    theta: crate::parametric::Parameter::Constant(
362                        self.query_params[param_idx].value,
363                    ),
364                }));
365            }
366
367            // Apply key transformations
368            for i in 0..params_per_head {
369                let qubit = QubitId((head_offset + i) as u32);
370                let param_idx = head * params_per_head + i;
371
372                gates.push(Box::new(ParametricRotationZ {
373                    target: qubit,
374                    theta: crate::parametric::Parameter::Constant(self.key_params[param_idx].value),
375                }));
376            }
377
378            // Add entanglement within head (for attention computation)
379            for i in 0..params_per_head - 1 {
380                let control = QubitId((head_offset + i) as u32);
381                let target = QubitId((head_offset + i + 1) as u32);
382                gates.push(Box::new(CNOT { control, target }));
383            }
384
385            // Apply value transformations
386            for i in 0..params_per_head {
387                let qubit = QubitId((head_offset + i) as u32);
388                let param_idx = head * params_per_head + i;
389
390                gates.push(Box::new(ParametricRotationX {
391                    target: qubit,
392                    theta: crate::parametric::Parameter::Constant(
393                        self.value_params[param_idx].value,
394                    ),
395                }));
396            }
397        }
398
399        // Add inter-head entanglement (for multi-head attention)
400        for head in 0..self.num_heads - 1 {
401            let control = QubitId((head * params_per_head) as u32);
402            let target = QubitId(((head + 1) * params_per_head) as u32);
403            gates.push(Box::new(CNOT { control, target }));
404        }
405
406        // Apply output projection
407        for i in 0..self.output_params.len() {
408            let qubit = QubitId(i as u32);
409            gates.push(Box::new(ParametricRotationY {
410                target: qubit,
411                theta: crate::parametric::Parameter::Constant(self.output_params[i].value),
412            }));
413        }
414
415        Ok(gates)
416    }
417}
418
419impl QMLLayer for QuantumAttention {
420    fn num_qubits(&self) -> usize {
421        self.num_qubits
422    }
423
424    fn parameters(&self) -> &[Parameter] {
425        // Return the pre-built flat cache [query | key | value | output].
426        // The cache is constructed in `new()` and can be refreshed with
427        // `rebuild_flat_cache()` if the per-group Vecs are mutated directly.
428        &self.flat_params
429    }
430
431    fn parameters_mut(&mut self) -> &mut [Parameter] {
432        // Callers may mutate via this slice; call sync_from_flat() afterwards
433        // to propagate changes back to the per-group parameter Vecs used in
434        // attention_gates().
435        &mut self.flat_params
436    }
437
438    fn gates(&self) -> Vec<Box<dyn GateOp>> {
439        self.attention_gates().unwrap_or_default()
440    }
441
442    fn compute_gradients(
443        &self,
444        _state: &Array1<Complex64>,
445        _loss_gradient: &Array1<Complex64>,
446    ) -> QuantRS2Result<Vec<f64>> {
447        let total_params = self.query_params.len()
448            + self.key_params.len()
449            + self.value_params.len()
450            + self.output_params.len();
451        Ok(vec![0.0; total_params])
452    }
453
454    fn name(&self) -> &'static str {
455        "QuantumAttention"
456    }
457}
458
459/// Quantum text classifier for sentiment analysis and text classification
460pub struct QuantumTextClassifier {
461    /// Configuration
462    config: QNLPConfig,
463    /// Word embedding layer
464    embedding: QuantumWordEmbedding,
465    /// Attention layers
466    attention_layers: Vec<QuantumAttention>,
467    /// Classification parameters
468    classifier_params: Vec<Parameter>,
469    /// Number of output classes
470    num_classes: usize,
471}
472
473impl QuantumTextClassifier {
474    /// Create a new quantum text classifier
475    pub fn new(config: QNLPConfig, num_classes: usize) -> Self {
476        let embedding = QuantumWordEmbedding::new(config.clone());
477
478        // Create multiple attention layers for deeper models
479        let mut attention_layers = Vec::new();
480        for _layer_idx in 0..2 {
481            // 2 attention layers
482            attention_layers.push(QuantumAttention::new(
483                config.text_qubits,
484                config.num_attention_heads,
485            ));
486        }
487
488        // Create classification parameters
489        let mut classifier_params = Vec::new();
490        for class in 0..num_classes {
491            for qubit in 0..config.feature_qubits {
492                classifier_params.push(Parameter {
493                    name: format!("classifier_{class}_{qubit}"),
494                    value: ((class + qubit) as f64 * 0.2).sin() * 0.3,
495                    bounds: None,
496                });
497            }
498        }
499
500        Self {
501            config,
502            embedding,
503            attention_layers,
504            classifier_params,
505            num_classes,
506        }
507    }
508
509    /// Classify a text sequence
510    pub fn classify(&self, word_ids: &[usize]) -> QuantRS2Result<Vec<f64>> {
511        // This would implement the full forward pass
512        // For now, return dummy probabilities
513        let mut probs = vec![1.0 / self.num_classes as f64; self.num_classes];
514
515        // Add some variation based on input
516        for (i, &word_id) in word_ids.iter().enumerate() {
517            let variation = ((word_id + i) as f64 * 0.1).sin() * 0.1;
518            probs[i % self.num_classes] += variation;
519        }
520
521        // Normalize probabilities
522        let sum: f64 = probs.iter().sum();
523        if sum > 0.0 {
524            for prob in &mut probs {
525                *prob /= sum;
526            }
527        }
528
529        Ok(probs)
530    }
531
532    /// Generate the full circuit for text classification
533    pub fn build_circuit(&self, word_ids: &[usize]) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
534        let mut gates = Vec::new();
535
536        // 1. Word embedding
537        gates.extend(self.embedding.encode_sequence(word_ids)?);
538
539        // 2. Attention layers
540        for attention in &self.attention_layers {
541            gates.extend(attention.attention_gates()?);
542        }
543
544        // 3. Feature extraction and pooling (using measurement-like operations)
545        // This would include global pooling operations
546        for qubit in 0..self.config.text_qubits {
547            gates.push(Box::new(Hadamard {
548                target: QubitId(qubit as u32),
549            }));
550        }
551
552        // 4. Classification layer
553        for (_class, chunk) in self
554            .classifier_params
555            .chunks(self.config.feature_qubits)
556            .enumerate()
557        {
558            for (i, param) in chunk.iter().enumerate() {
559                let qubit = QubitId(i as u32);
560                gates.push(Box::new(ParametricRotationY {
561                    target: qubit,
562                    theta: crate::parametric::Parameter::Constant(param.value),
563                }));
564            }
565        }
566
567        Ok(gates)
568    }
569
570    /// Train the classifier using a dataset
571    pub fn train(
572        &mut self,
573        training_data: &[(Vec<usize>, usize)],
574        learning_rate: f64,
575        epochs: usize,
576    ) -> QuantRS2Result<Vec<f64>> {
577        let mut losses = Vec::new();
578
579        for epoch in 0..epochs {
580            let mut epoch_loss = 0.0;
581
582            for (word_ids, true_label) in training_data {
583                // Forward pass
584                let predictions = self.classify(word_ids)?;
585
586                // Compute loss (cross-entropy)
587                let loss = -predictions[*true_label].ln();
588                epoch_loss += loss;
589
590                // Backward pass (simplified gradient computation)
591                // In practice, this would use automatic differentiation
592                self.update_parameters(predictions, *true_label, learning_rate)?;
593            }
594
595            epoch_loss /= training_data.len() as f64;
596            losses.push(epoch_loss);
597
598            if epoch % 10 == 0 {
599                println!("Epoch {epoch}: Loss = {epoch_loss:.4}");
600            }
601        }
602
603        Ok(losses)
604    }
605
606    /// Update parameters based on gradients (simplified)
607    fn update_parameters(
608        &mut self,
609        predictions: Vec<f64>,
610        true_label: usize,
611        learning_rate: f64,
612    ) -> QuantRS2Result<()> {
613        // Simplified parameter update
614        // In practice, would compute proper gradients using parameter shift rule
615
616        for (i, param) in self.classifier_params.iter_mut().enumerate() {
617            // All parameters are learnable in this simplified implementation
618            {
619                let class_idx = i / self.config.feature_qubits;
620                let error = if class_idx == true_label {
621                    predictions[class_idx] - 1.0
622                } else {
623                    predictions[class_idx]
624                };
625
626                // Simple gradient descent update
627                param.value -= learning_rate * error * 0.1;
628            }
629        }
630
631        Ok(())
632    }
633}
634
635/// Quantum language model for text generation
636pub struct QuantumLanguageModel {
637    /// Configuration
638    config: QNLPConfig,
639    /// Embedding layer
640    embedding: QuantumWordEmbedding,
641    /// Transformer layers
642    transformer_layers: Vec<QuantumAttention>,
643    /// Output parameters
644    output_params: Vec<Parameter>,
645}
646
647impl QuantumLanguageModel {
648    /// Create a new quantum language model
649    pub fn new(config: QNLPConfig) -> Self {
650        let embedding = QuantumWordEmbedding::new(config.clone());
651
652        // Create transformer layers
653        let mut transformer_layers = Vec::new();
654        for _layer in 0..3 {
655            // 3 transformer layers
656            transformer_layers.push(QuantumAttention::new(
657                config.text_qubits,
658                config.num_attention_heads,
659            ));
660        }
661
662        // Create output parameters for next token prediction
663        let mut output_params = Vec::new();
664        for token in 0..config.vocab_size {
665            output_params.push(Parameter {
666                name: format!("output_{token}"),
667                value: (token as f64 * 0.01).sin() * 0.1,
668                bounds: None,
669            });
670        }
671
672        Self {
673            config,
674            embedding,
675            transformer_layers,
676            output_params,
677        }
678    }
679
680    /// Generate next token probabilities given a context
681    pub fn predict_next_token(&self, context: &[usize]) -> QuantRS2Result<Vec<f64>> {
682        // Build circuit for the context
683        let _gates = self.build_circuit(context)?;
684
685        // Simulate the circuit (placeholder)
686        // In practice, would run the quantum circuit and measure
687
688        // Return dummy probabilities for now
689        let mut probs = vec![1.0 / self.config.vocab_size as f64; self.config.vocab_size];
690
691        // Add some variation based on context
692        for (i, &token) in context.iter().enumerate() {
693            let variation = ((token + i) as f64 * 0.05).sin() * 0.01;
694            probs[token % self.config.vocab_size] += variation;
695        }
696
697        // Normalize
698        let sum: f64 = probs.iter().sum();
699        if sum > 0.0 {
700            for prob in &mut probs {
701                *prob /= sum;
702            }
703        }
704
705        Ok(probs)
706    }
707
708    /// Generate text given a starting context
709    pub fn generate_text(
710        &self,
711        start_context: &[usize],
712        max_length: usize,
713        temperature: f64,
714    ) -> QuantRS2Result<Vec<usize>> {
715        let mut generated = start_context.to_vec();
716
717        for _step in 0..max_length {
718            // Get context (last N tokens)
719            let context_start = if generated.len() > self.config.max_sequence_length {
720                generated.len() - self.config.max_sequence_length
721            } else {
722                0
723            };
724            let context = &generated[context_start..];
725
726            // Predict next token
727            let mut probs = self.predict_next_token(context)?;
728
729            // Apply temperature scaling
730            if temperature != 1.0 {
731                for prob in &mut probs {
732                    *prob = (*prob).powf(1.0 / temperature);
733                }
734                let sum: f64 = probs.iter().sum();
735                for prob in &mut probs {
736                    *prob /= sum;
737                }
738            }
739
740            // Sample next token (using simple deterministic selection for now)
741            let next_token = probs
742                .iter()
743                .enumerate()
744                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
745                .map(|(i, _)| i)
746                .unwrap_or(0);
747
748            generated.push(next_token);
749        }
750
751        Ok(generated)
752    }
753
754    /// Build the full language model circuit
755    fn build_circuit(&self, context: &[usize]) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
756        let mut gates = Vec::new();
757
758        // 1. Embedding
759        gates.extend(self.embedding.encode_sequence(context)?);
760
761        // 2. Transformer layers
762        for transformer in &self.transformer_layers {
763            gates.extend(transformer.attention_gates()?);
764        }
765
766        // 3. Output projection
767        for (i, param) in self.output_params.iter().enumerate() {
768            let qubit = QubitId((i % self.config.text_qubits) as u32);
769            gates.push(Box::new(ParametricRotationZ {
770                target: qubit,
771                theta: crate::parametric::Parameter::Constant(param.value),
772            }));
773        }
774
775        Ok(gates)
776    }
777}
778
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_word_embedding() {
        let embedding = QuantumWordEmbedding::new(QNLPConfig {
            text_qubits: 4,
            vocab_size: 100,
            ..Default::default()
        });
        assert_eq!(embedding.num_qubits(), 4);

        // A short in-vocabulary sequence must produce at least one gate.
        let encoded = embedding
            .encode_sequence(&[1, 5, 10])
            .expect("Failed to encode sequence");
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_quantum_attention() {
        let layer = QuantumAttention::new(8, 2);
        assert_eq!(layer.num_qubits(), 8);
        assert_eq!(layer.num_heads, 2);

        let circuit = layer
            .attention_gates()
            .expect("Failed to get attention gates");
        assert!(!circuit.is_empty());
    }

    #[test]
    fn test_quantum_text_classifier() {
        let classifier = QuantumTextClassifier::new(
            QNLPConfig {
                feature_qubits: 2,
                text_qubits: 4,
                vocab_size: 50,
                ..Default::default()
            },
            3,
        );

        let distribution = classifier
            .classify(&[1, 2, 3])
            .expect("Failed to classify text");
        assert_eq!(distribution.len(), 3);

        // The class probabilities must form a distribution.
        let total: f64 = distribution.iter().sum();
        assert!((total - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_quantum_language_model() {
        let model = QuantumLanguageModel::new(QNLPConfig {
            max_sequence_length: 8,
            text_qubits: 4,
            vocab_size: 20,
            ..Default::default()
        });
        let prompt = [1, 2, 3];

        let next = model
            .predict_next_token(&prompt)
            .expect("Failed to predict next token");
        assert_eq!(next.len(), 20);

        let text = model
            .generate_text(&prompt, 5, 1.0)
            .expect("Failed to generate text");
        assert_eq!(text.len(), prompt.len() + 5);
    }

    #[test]
    fn test_text_classifier_training() {
        let mut classifier = QuantumTextClassifier::new(
            QNLPConfig {
                feature_qubits: 2,
                text_qubits: 3,
                vocab_size: 10,
                ..Default::default()
            },
            2,
        );

        // Tiny two-class dataset: (token IDs, label).
        let dataset = vec![
            (vec![1, 2], 0),
            (vec![3, 4], 1),
            (vec![1, 3], 0),
            (vec![2, 4], 1),
        ];

        let history = classifier
            .train(&dataset, 0.01, 5)
            .expect("Failed to train classifier");
        assert_eq!(history.len(), 5);
    }
}
887
888/// Advanced quantum NLP utilities and algorithms
889pub mod advanced {
890    use super::*;
891
892    /// Quantum text preprocessing utilities
893    pub struct QuantumTextPreprocessor {
894        /// Vocabulary mapping
895        vocab: HashMap<String, usize>,
896        /// Reverse vocabulary mapping
897        reverse_vocab: HashMap<usize, String>,
898        /// Special tokens
899        special_tokens: HashMap<String, usize>,
900    }
901
902    impl QuantumTextPreprocessor {
903        /// Create a new preprocessor
904        pub fn new() -> Self {
905            let mut special_tokens = HashMap::new();
906            special_tokens.insert("<PAD>".to_string(), 0);
907            special_tokens.insert("<UNK>".to_string(), 1);
908            special_tokens.insert("<START>".to_string(), 2);
909            special_tokens.insert("<END>".to_string(), 3);
910
911            Self {
912                vocab: HashMap::new(),
913                reverse_vocab: HashMap::new(),
914                special_tokens,
915            }
916        }
917
918        /// Build vocabulary from text corpus
919        pub fn build_vocab(&mut self, texts: &[String], max_vocab_size: usize) {
920            let mut word_counts: HashMap<String, usize> = HashMap::new();
921
922            // Count word frequencies
923            for text in texts {
924                for word in text.split_whitespace() {
925                    *word_counts.entry(word.to_lowercase()).or_insert(0) += 1;
926                }
927            }
928
929            // Sort by frequency and take top words
930            let mut word_freq: Vec<_> = word_counts.into_iter().collect();
931            word_freq.sort_by_key(|b| std::cmp::Reverse(b.1));
932
933            // Add special tokens first
934            for (token, id) in &self.special_tokens {
935                self.vocab.insert(token.clone(), *id);
936                self.reverse_vocab.insert(*id, token.clone());
937            }
938
939            // Add most frequent words
940            let mut vocab_id = self.special_tokens.len();
941            for (word, _count) in word_freq
942                .into_iter()
943                .take(max_vocab_size - self.special_tokens.len())
944            {
945                self.vocab.insert(word.clone(), vocab_id);
946                self.reverse_vocab.insert(vocab_id, word);
947                vocab_id += 1;
948            }
949        }
950
951        /// Tokenize text to word IDs
952        pub fn tokenize(&self, text: &str) -> Vec<usize> {
953            let mut tokens = vec![self.special_tokens["<START>"]];
954
955            for word in text.split_whitespace() {
956                let word = word.to_lowercase();
957                let token_id = self
958                    .vocab
959                    .get(&word)
960                    .copied()
961                    .unwrap_or_else(|| self.special_tokens["<UNK>"]);
962                tokens.push(token_id);
963            }
964
965            tokens.push(self.special_tokens["<END>"]);
966            tokens
967        }
968
969        /// Convert token IDs back to text
970        pub fn detokenize(&self, token_ids: &[usize]) -> String {
971            token_ids
972                .iter()
973                .filter_map(|&id| self.reverse_vocab.get(&id))
974                .filter(|&word| !["<PAD>", "<START>", "<END>"].contains(&word.as_str()))
975                .cloned()
976                .collect::<Vec<_>>()
977                .join(" ")
978        }
979
980        /// Get vocabulary size
981        pub fn vocab_size(&self) -> usize {
982            self.vocab.len()
983        }
984    }
985
986    /// Quantum semantic similarity computation
987    pub struct QuantumSemanticSimilarity {
988        /// Embedding dimension
989        embedding_dim: usize,
990        /// Number of qubits
991        num_qubits: usize,
992        /// Similarity computation parameters
993        similarity_params: Vec<Parameter>,
994    }
995
996    impl QuantumSemanticSimilarity {
997        /// Create a new quantum semantic similarity computer
998        pub fn new(embedding_dim: usize, num_qubits: usize) -> Self {
999            let mut similarity_params = Vec::new();
1000
1001            // Parameters for similarity computation
1002            for i in 0..num_qubits * 2 {
1003                // For two text inputs
1004                similarity_params.push(Parameter {
1005                    name: format!("sim_{i}"),
1006                    value: (i as f64 * 0.1).sin() * 0.5,
1007                    bounds: None,
1008                });
1009            }
1010
1011            Self {
1012                embedding_dim,
1013                num_qubits,
1014                similarity_params,
1015            }
1016        }
1017
1018        /// Compute semantic similarity between two texts
1019        pub fn compute_similarity(
1020            &self,
1021            text1_tokens: &[usize],
1022            text2_tokens: &[usize],
1023        ) -> QuantRS2Result<f64> {
1024            // Create embeddings for both texts
1025            let config = QNLPConfig {
1026                text_qubits: self.num_qubits,
1027                vocab_size: 1000, // Default
1028                ..Default::default()
1029            };
1030
1031            let embedding1 = QuantumWordEmbedding::new(config.clone());
1032            let embedding2 = QuantumWordEmbedding::new(config);
1033
1034            // Generate quantum circuits for both texts
1035            let gates1 = embedding1.encode_sequence(text1_tokens)?;
1036            let gates2 = embedding2.encode_sequence(text2_tokens)?;
1037
1038            // Compute similarity using quantum interference
1039            // This is a simplified version - full implementation would measure overlap
1040            let similarity = self.quantum_text_overlap(gates1, gates2)?;
1041
1042            Ok(similarity)
1043        }
1044
1045        /// Compute quantum overlap between two text representations
1046        fn quantum_text_overlap(
1047            &self,
1048            _gates1: Vec<Box<dyn GateOp>>,
1049            _gates2: Vec<Box<dyn GateOp>>,
1050        ) -> QuantRS2Result<f64> {
1051            // Placeholder for quantum overlap computation
1052            // In practice, would:
1053            // 1. Prepare states using gates1 and gates2
1054            // 2. Compute fidelity/overlap between states
1055            // 3. Return similarity score
1056
1057            // Return dummy similarity for now
1058            Ok(0.7)
1059        }
1060    }
1061
1062    /// Quantum text summarization model
1063    pub struct QuantumTextSummarizer {
1064        /// Configuration
1065        config: QNLPConfig,
1066        /// Encoder for input text
1067        encoder: QuantumWordEmbedding,
1068        /// Attention mechanism for importance scoring
1069        attention: QuantumAttention,
1070        /// Summary generation parameters
1071        summary_params: Vec<Parameter>,
1072    }
1073
1074    impl QuantumTextSummarizer {
1075        /// Create a new quantum text summarizer
1076        pub fn new(config: QNLPConfig) -> Self {
1077            let encoder = QuantumWordEmbedding::new(config.clone());
1078            let attention = QuantumAttention::new(config.text_qubits, config.num_attention_heads);
1079
1080            let mut summary_params = Vec::new();
1081            for i in 0..config.text_qubits {
1082                summary_params.push(Parameter {
1083                    name: format!("summary_{i}"),
1084                    value: (i as f64 * 0.15).sin() * 0.4,
1085                    bounds: None,
1086                });
1087            }
1088
1089            Self {
1090                config,
1091                encoder,
1092                attention,
1093                summary_params,
1094            }
1095        }
1096
1097        /// Generate extractive summary from input text
1098        pub fn extractive_summarize(
1099            &self,
1100            text_tokens: &[usize],
1101            summary_length: usize,
1102        ) -> QuantRS2Result<Vec<usize>> {
1103            // Encode input text
1104            let _encoding_gates = self.encoder.encode_sequence(text_tokens)?;
1105
1106            // Apply attention to find important tokens
1107            let _attention_gates = self.attention.attention_gates()?;
1108
1109            // Score tokens for importance (simplified)
1110            let mut token_scores = Vec::new();
1111            for (i, &token) in text_tokens.iter().enumerate() {
1112                // Simple scoring based on token frequency and position
1113                let position_weight = (i as f64 / text_tokens.len() as f64).mul_add(-0.5, 1.0);
1114                let token_weight = (token as f64 * 0.1).sin().abs();
1115                let score = position_weight * token_weight;
1116                token_scores.push((i, token, score));
1117            }
1118
1119            // Sort by score and select top tokens
1120            token_scores.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
1121
1122            let mut summary_tokens = Vec::new();
1123            for (_, token, _) in token_scores.into_iter().take(summary_length) {
1124                summary_tokens.push(token);
1125            }
1126
1127            Ok(summary_tokens)
1128        }
1129
1130        /// Generate abstractive summary (placeholder)
1131        pub fn abstractive_summarize(
1132            &self,
1133            _text_tokens: &[usize],
1134            _summary_length: usize,
1135        ) -> QuantRS2Result<Vec<usize>> {
1136            // Placeholder for abstractive summarization
1137            // Would use sequence-to-sequence quantum model
1138            Ok(vec![1, 2, 3]) // Dummy summary
1139        }
1140    }
1141
1142    /// Quantum named entity recognition
1143    pub struct QuantumNamedEntityRecognition {
1144        /// Configuration
1145        config: QNLPConfig,
1146        /// Token encoder
1147        encoder: QuantumWordEmbedding,
1148        /// Entity type classifiers
1149        entity_classifiers: HashMap<String, Vec<Parameter>>,
1150        /// Supported entity types
1151        entity_types: Vec<String>,
1152    }
1153
1154    impl QuantumNamedEntityRecognition {
1155        /// Create a new quantum NER model
1156        pub fn new(config: QNLPConfig) -> Self {
1157            let encoder = QuantumWordEmbedding::new(config.clone());
1158            let entity_types = vec![
1159                "PERSON".to_string(),
1160                "ORGANIZATION".to_string(),
1161                "LOCATION".to_string(),
1162                "DATE".to_string(),
1163                "MONEY".to_string(),
1164            ];
1165
1166            let mut entity_classifiers = HashMap::new();
1167            for entity_type in &entity_types {
1168                let mut classifier_params = Vec::new();
1169                for i in 0..config.text_qubits {
1170                    classifier_params.push(Parameter {
1171                        name: format!("{entity_type}_{i}"),
1172                        value: (i as f64).mul_add(0.1, entity_type.len() as f64).sin() * 0.3,
1173                        bounds: None,
1174                    });
1175                }
1176                entity_classifiers.insert(entity_type.clone(), classifier_params);
1177            }
1178
1179            Self {
1180                config,
1181                encoder,
1182                entity_classifiers,
1183                entity_types,
1184            }
1185        }
1186
1187        /// Recognize named entities in text
1188        pub fn recognize_entities(
1189            &self,
1190            text_tokens: &[usize],
1191        ) -> QuantRS2Result<Vec<(usize, usize, String)>> {
1192            let mut entities = Vec::new();
1193
1194            // Simple sliding window approach
1195            for start in 0..text_tokens.len() {
1196                for end in start + 1..=text_tokens.len().min(start + 5) {
1197                    // Max entity length 5
1198                    let entity_tokens = &text_tokens[start..end];
1199
1200                    // Classify this span
1201                    if let Some(entity_type) = self.classify_span(entity_tokens)? {
1202                        entities.push((start, end, entity_type));
1203                    }
1204                }
1205            }
1206
1207            // Remove overlapping entities (keep longer ones)
1208            entities.sort_by_key(|b| std::cmp::Reverse(b.1 - b.0));
1209            let mut final_entities = Vec::new();
1210            let mut used_positions = vec![false; text_tokens.len()];
1211
1212            for (start, end, entity_type) in entities {
1213                if used_positions[start..end].iter().all(|&used| !used) {
1214                    for pos in start..end {
1215                        used_positions[pos] = true;
1216                    }
1217                    final_entities.push((start, end, entity_type));
1218                }
1219            }
1220
1221            final_entities.sort_by_key(|&(start, _, _)| start);
1222            Ok(final_entities)
1223        }
1224
1225        /// Classify a span of tokens as an entity type
1226        fn classify_span(&self, tokens: &[usize]) -> QuantRS2Result<Option<String>> {
1227            // Encode the span
1228            let _encoding_gates = self.encoder.encode_sequence(tokens)?;
1229
1230            let mut best_score = 0.0;
1231            let mut best_type = None;
1232
1233            // Score each entity type
1234            for entity_type in &self.entity_types {
1235                let score = self.compute_entity_score(tokens, entity_type)?;
1236                if score > best_score && score > 0.5 {
1237                    // Threshold
1238                    best_score = score;
1239                    best_type = Some(entity_type.clone());
1240                }
1241            }
1242
1243            Ok(best_type)
1244        }
1245
1246        /// Compute score for a specific entity type
1247        fn compute_entity_score(&self, tokens: &[usize], entity_type: &str) -> QuantRS2Result<f64> {
1248            // Simple scoring based on token patterns
1249            let mut score = 0.0;
1250
1251            for &token in tokens {
1252                // Simple heuristics based on token ID patterns
1253                match entity_type {
1254                    "PERSON" => {
1255                        if token % 7 == 1 {
1256                            // Arbitrary pattern for person names
1257                            score += 0.3;
1258                        }
1259                    }
1260                    "LOCATION" => {
1261                        if token % 5 == 2 {
1262                            // Arbitrary pattern for locations
1263                            score += 0.3;
1264                        }
1265                    }
1266                    "ORGANIZATION" => {
1267                        if token % 11 == 3 {
1268                            // Arbitrary pattern for organizations
1269                            score += 0.3;
1270                        }
1271                    }
1272                    "DATE" => {
1273                        if token % 13 == 4 {
1274                            // Arbitrary pattern for dates
1275                            score += 0.3;
1276                        }
1277                    }
1278                    "MONEY" => {
1279                        if token % 17 == 5 {
1280                            // Arbitrary pattern for money
1281                            score += 0.3;
1282                        }
1283                    }
1284                    _ => {}
1285                }
1286            }
1287
1288            score /= tokens.len() as f64; // Normalize by span length
1289            Ok(score)
1290        }
1291    }
1292}
1293
1294// Re-export advanced utilities
1295pub use advanced::*;