// qtransformers_core/lib.rs

use std::cmp::Ordering;

use numpy::{PyArray2, PyReadonlyArray2};
use pyo3::prelude::*;
use rand::distributions::WeightedIndex;
use rand::prelude::*;
use rayon::prelude::*;

9/// High-performance quantum-inspired attention kernel
10#[pyfunction]
11fn quantum_attention_rs<'py>(
12    py: Python<'py>,
13    q: PyReadonlyArray2<'py, f32>,
14    k: PyReadonlyArray2<'py, f32>,
15    v: PyReadonlyArray2<'py, f32>,
16    top_k: usize,
17) -> PyResult<Bound<'py, PyArray2<f32>>> {
18    let q = q.as_array();
19    let k = k.as_array();
20    let v = v.as_array();
21
22    let seq_len = q.shape()[0];
23    let d_model = q.shape()[1];
24
25    if k.shape()[0] != seq_len
26        || v.shape()[0] != seq_len
27        || k.shape()[1] != d_model
28        || v.shape()[1] != d_model
29    {
30        return Err(pyo3::exceptions::PyValueError::new_err(
31            "q, k, v must have shapes (seq_len, d_model) with matching dims",
32        ));
33    }
34
35    // Sampling-based approximation of attention probabilities
36    // logits_i = (K · q_i) / sqrt(d)
37    let scale = (d_model as f32).sqrt();
38    let num_samples = top_k.max(1).min(seq_len.max(1));
39
40    // Compute per-row outputs in parallel
41    let rows: Vec<Vec<f32>> = (0..seq_len)
42        .into_par_iter()
43        .map(|i| {
44            // Compute logits for row i: K @ q[i]
45            let mut logits: Vec<f32> = vec![0.0; seq_len];
46            for j in 0..seq_len {
47                let mut dot = 0.0f32;
48                for d in 0..d_model {
49                    dot += k[[j, d]] * q[[i, d]];
50                }
51                logits[j] = dot / scale;
52            }
53
54            // Amplitude encoding and probabilities: p_j ∝ exp(logit/2)^2 = exp(logit)
55            let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
56            let weights: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
57
58            let sum_w: f32 = weights.iter().sum();
59            if !sum_w.is_finite() || sum_w <= 0.0 {
60                return vec![0.0f32; d_model];
61            }
62
63            // Build categorical sampler
64            let dist = match WeightedIndex::new(&weights) {
65                Ok(d) => d,
66                Err(_) => return vec![0.0f32; d_model],
67            };
68            let mut rng = thread_rng();
69
70            // Empirical probability via sampling
71            let mut counts = vec![0usize; seq_len];
72            for _ in 0..num_samples {
73                let j = dist.sample(&mut rng);
74                counts[j] += 1;
75            }
76            let inv_samples = 1.0f32 / (num_samples as f32);
77
78            // Weighted sum over V using empirical probabilities
79            let mut out = vec![0.0f32; d_model];
80            for (j, &c) in counts.iter().enumerate() {
81                if c > 0 {
82                    let w = (c as f32) * inv_samples;
83                    for d in 0..d_model {
84                        out[d] += v[[j, d]] * w;
85                    }
86                }
87            }
88
89            out
90        })
91        .collect();
92
93    // Create 2D array from rows
94    let result = PyArray2::from_vec2_bound(py, &rows)
95        .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?;
96
97    Ok(result)
98}
99
100/// Classical efficient attention approximation
101#[pyfunction]
102fn classical_attention_rs<'py>(
103    py: Python<'py>,
104    q: PyReadonlyArray2<'py, f32>,
105    k: PyReadonlyArray2<'py, f32>,
106    v: PyReadonlyArray2<'py, f32>,
107    top_k: usize,
108) -> PyResult<Bound<'py, PyArray2<f32>>> {
109    let q = q.as_array();
110    let k = k.as_array();
111    let v = v.as_array();
112
113    let seq_len = q.shape()[0];
114    let d_model = q.shape()[1];
115
116    if k.shape()[0] != seq_len
117        || v.shape()[0] != seq_len
118        || k.shape()[1] != d_model
119        || v.shape()[1] != d_model
120    {
121        return Err(pyo3::exceptions::PyValueError::new_err(
122            "q, k, v must have shapes (seq_len, d_model) with matching dims",
123        ));
124    }
125
126    let scale = (d_model as f32).sqrt();
127    let k_eff = top_k.clamp(1, seq_len.max(1));
128
129    // Compute per-row outputs in parallel
130    let rows: Vec<Vec<f32>> = (0..seq_len)
131        .into_par_iter()
132        .map(|i| {
133            // Compute scores for row i: K @ q[i] / sqrt(d)
134            let mut scores: Vec<f32> = vec![0.0; seq_len];
135            for j in 0..seq_len {
136                let mut dot = 0.0f32;
137                for d in 0..d_model {
138                    dot += k[[j, d]] * q[[i, d]];
139                }
140                scores[j] = dot / scale;
141            }
142
143            // Select top-k indices by score (descending)
144            let mut idx: Vec<usize> = (0..seq_len).collect();
145            idx.sort_unstable_by(|&a, &b| {
146                let sa = scores[a];
147                let sb = scores[b];
148                sb.partial_cmp(&sa).unwrap_or(Ordering::Equal)
149            });
150            let top_idx = &idx[0..k_eff];
151
152            // Stable softmax over top-k
153            let max_top = top_idx
154                .iter()
155                .map(|&j| scores[j])
156                .fold(f32::NEG_INFINITY, f32::max);
157            let mut exp_scores: Vec<f32> = top_idx
158                .iter()
159                .map(|&j| (scores[j] - max_top).exp())
160                .collect();
161            let sum_exp: f32 = exp_scores.iter().sum();
162            if !sum_exp.is_finite() || sum_exp <= 0.0 {
163                return vec![0.0f32; d_model];
164            }
165            for e in exp_scores.iter_mut() {
166                *e /= sum_exp;
167            }
168
169            // Weighted sum of V rows
170            let mut out = vec![0.0f32; d_model];
171            for (w, &j) in exp_scores.iter().zip(top_idx.iter()) {
172                for d in 0..d_model {
173                    out[d] += v[[j, d]] * *w;
174                }
175            }
176            out
177        })
178        .collect();
179
180    // Create 2D array from rows
181    let result = PyArray2::from_vec2_bound(py, &rows)
182        .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?;
183
184    Ok(result)
185}
186
187/// Python module definition
188#[pymodule]
189fn qtransformers_core(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
190    m.add_function(wrap_pyfunction!(quantum_attention_rs, m)?)?;
191    m.add_function(wrap_pyfunction!(classical_attention_rs, m)?)?;
192    Ok(())
193}
194
// Note: Unit tests for PyO3 extension modules cannot run via `cargo test`
// because they require Python symbols at link time. Tests are run via:
//   maturin develop && pytest tests/rust/