qtransformers_core/
lib.rs1use numpy::{PyArray2, PyReadonlyArray2};
2use pyo3::prelude::*;
3use std::cmp::Ordering;
4
5use rand::distributions::WeightedIndex;
6use rand::prelude::*;
7use rayon::prelude::*;
8
9#[pyfunction]
11fn quantum_attention_rs<'py>(
12 py: Python<'py>,
13 q: PyReadonlyArray2<'py, f32>,
14 k: PyReadonlyArray2<'py, f32>,
15 v: PyReadonlyArray2<'py, f32>,
16 top_k: usize,
17) -> PyResult<Bound<'py, PyArray2<f32>>> {
18 let q = q.as_array();
19 let k = k.as_array();
20 let v = v.as_array();
21
22 let seq_len = q.shape()[0];
23 let d_model = q.shape()[1];
24
25 if k.shape()[0] != seq_len
26 || v.shape()[0] != seq_len
27 || k.shape()[1] != d_model
28 || v.shape()[1] != d_model
29 {
30 return Err(pyo3::exceptions::PyValueError::new_err(
31 "q, k, v must have shapes (seq_len, d_model) with matching dims",
32 ));
33 }
34
35 let scale = (d_model as f32).sqrt();
38 let num_samples = top_k.max(1).min(seq_len.max(1));
39
40 let rows: Vec<Vec<f32>> = (0..seq_len)
42 .into_par_iter()
43 .map(|i| {
44 let mut logits: Vec<f32> = vec![0.0; seq_len];
46 for j in 0..seq_len {
47 let mut dot = 0.0f32;
48 for d in 0..d_model {
49 dot += k[[j, d]] * q[[i, d]];
50 }
51 logits[j] = dot / scale;
52 }
53
54 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
56 let weights: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
57
58 let sum_w: f32 = weights.iter().sum();
59 if !sum_w.is_finite() || sum_w <= 0.0 {
60 return vec![0.0f32; d_model];
61 }
62
63 let dist = match WeightedIndex::new(&weights) {
65 Ok(d) => d,
66 Err(_) => return vec![0.0f32; d_model],
67 };
68 let mut rng = thread_rng();
69
70 let mut counts = vec![0usize; seq_len];
72 for _ in 0..num_samples {
73 let j = dist.sample(&mut rng);
74 counts[j] += 1;
75 }
76 let inv_samples = 1.0f32 / (num_samples as f32);
77
78 let mut out = vec![0.0f32; d_model];
80 for (j, &c) in counts.iter().enumerate() {
81 if c > 0 {
82 let w = (c as f32) * inv_samples;
83 for d in 0..d_model {
84 out[d] += v[[j, d]] * w;
85 }
86 }
87 }
88
89 out
90 })
91 .collect();
92
93 let result = PyArray2::from_vec2_bound(py, &rows)
95 .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?;
96
97 Ok(result)
98}
99
100#[pyfunction]
102fn classical_attention_rs<'py>(
103 py: Python<'py>,
104 q: PyReadonlyArray2<'py, f32>,
105 k: PyReadonlyArray2<'py, f32>,
106 v: PyReadonlyArray2<'py, f32>,
107 top_k: usize,
108) -> PyResult<Bound<'py, PyArray2<f32>>> {
109 let q = q.as_array();
110 let k = k.as_array();
111 let v = v.as_array();
112
113 let seq_len = q.shape()[0];
114 let d_model = q.shape()[1];
115
116 if k.shape()[0] != seq_len
117 || v.shape()[0] != seq_len
118 || k.shape()[1] != d_model
119 || v.shape()[1] != d_model
120 {
121 return Err(pyo3::exceptions::PyValueError::new_err(
122 "q, k, v must have shapes (seq_len, d_model) with matching dims",
123 ));
124 }
125
126 let scale = (d_model as f32).sqrt();
127 let k_eff = top_k.clamp(1, seq_len.max(1));
128
129 let rows: Vec<Vec<f32>> = (0..seq_len)
131 .into_par_iter()
132 .map(|i| {
133 let mut scores: Vec<f32> = vec![0.0; seq_len];
135 for j in 0..seq_len {
136 let mut dot = 0.0f32;
137 for d in 0..d_model {
138 dot += k[[j, d]] * q[[i, d]];
139 }
140 scores[j] = dot / scale;
141 }
142
143 let mut idx: Vec<usize> = (0..seq_len).collect();
145 idx.sort_unstable_by(|&a, &b| {
146 let sa = scores[a];
147 let sb = scores[b];
148 sb.partial_cmp(&sa).unwrap_or(Ordering::Equal)
149 });
150 let top_idx = &idx[0..k_eff];
151
152 let max_top = top_idx
154 .iter()
155 .map(|&j| scores[j])
156 .fold(f32::NEG_INFINITY, f32::max);
157 let mut exp_scores: Vec<f32> = top_idx
158 .iter()
159 .map(|&j| (scores[j] - max_top).exp())
160 .collect();
161 let sum_exp: f32 = exp_scores.iter().sum();
162 if !sum_exp.is_finite() || sum_exp <= 0.0 {
163 return vec![0.0f32; d_model];
164 }
165 for e in exp_scores.iter_mut() {
166 *e /= sum_exp;
167 }
168
169 let mut out = vec![0.0f32; d_model];
171 for (w, &j) in exp_scores.iter().zip(top_idx.iter()) {
172 for d in 0..d_model {
173 out[d] += v[[j, d]] * *w;
174 }
175 }
176 out
177 })
178 .collect();
179
180 let result = PyArray2::from_vec2_bound(py, &rows)
182 .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?;
183
184 Ok(result)
185}
186
187#[pymodule]
189fn qtransformers_core(m: &Bound<'_, pyo3::types::PyModule>) -> PyResult<()> {
190 m.add_function(wrap_pyfunction!(quantum_attention_rs, m)?)?;
191 m.add_function(wrap_pyfunction!(classical_attention_rs, m)?)?;
192 Ok(())
193}
194
195