Skip to main content

rlx_clinicalbert/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! High-level [`ClinicalBertRunner`] — config + safetensors → forward + pool.
17
18use std::path::{Path, PathBuf};
19
20use anyhow::{Context, Result, bail};
21use rlx_core::config::BertConfig;
22use rlx_core::flow_util::compile_built;
23use rlx_core::validate_standard_device;
24use rlx_core::weight_map::WeightMap;
25use rlx_runtime::{CompiledGraph, Device};
26
27use crate::builder::build_clinicalbert_built;
28#[cfg(feature = "mlm")]
29use crate::builder::build_clinicalbert_with_mlm_built;
30use crate::config::{ClinicalBertConfig, ClinicalBertVariant, validate_hf_config};
31#[cfg(feature = "mlm")]
32use crate::heads::MlmHead;
33#[cfg(feature = "pooler")]
34use crate::heads::PoolerHead;
35
/// How per-token hidden states are reduced to a single vector per
/// sequence (mirrors the HF reference pooling).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Pooling {
    /// Keep the hidden state of the first token (`[CLS]`).
    Cls,
    /// Average the hidden states of the tokens whose attention-mask value
    /// is positive.
    Mean,
    /// Skip pooling and hand back the full `[batch, seq, hidden]` states.
    None,
}

impl Pooling {
    /// Parse a pooling name, ignoring ASCII case.
    ///
    /// Recognised spellings: `cls`; `mean`, `avg`, `average`; `none`, `raw`.
    /// Any other input yields `None`.
    pub fn from_str_opt(s: &str) -> Option<Self> {
        const NAMES: &[(&str, Pooling)] = &[
            ("cls", Pooling::Cls),
            ("mean", Pooling::Mean),
            ("avg", Pooling::Mean),
            ("average", Pooling::Mean),
            ("none", Pooling::None),
            ("raw", Pooling::None),
        ];
        NAMES
            .iter()
            .find(|&&(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, pooling)| pooling)
    }
}
57
58/// Where the Masked-Language-Model head — `dense(H→H) + GeLU + LN +
59/// tied-decoder(H→V) + bias` — runs. Both modes produce numerically
60/// equivalent logits (cos > 0.999999, drift from GEMM reduction order).
61///
62/// Measured on Bio_ClinicalBERT (seq=32, RTX 4090 + Intel x86, total ms):
63///
64/// | Device  | B=1  | B=8     | B=32     | Winner @ B=32 |
65/// |---------|------|---------|----------|---------------|
66/// | CPU+MKL | 52.9 / 55.4 | 135.8 / **124.1** | 437.0 / **399.1** | InGraph |
67/// | CUDA    | 9.4 / **6.1**  | 30.2 / **28.6**  | **79.8** / 96.6  | Cpu     |
68///
69/// (`Cpu / InGraph`; bold = faster.) CUDA crosses over near B=8 because the
70/// host sgemm of the H×V decoder matmul stops being the bottleneck once
71/// CUDA's encoder runs in single-digit ms.
72#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
73pub enum MlmExecMode {
74    /// CPU post-process via `rlx_cpu::blas::sgemm_bias` (Accelerate / MKL /
75    /// OpenBLAS depending on the rlx-cpu link).
76    Cpu,
77    /// Head appended to the encoder's IR graph as a second output — runs on
78    /// the same backend as the encoder.
79    InGraph,
80    /// Resolve at build time: `Cpu` for CUDA with batch > 8, `InGraph`
81    /// otherwise (the table above is the source of this rule).
82    #[default]
83    Auto,
84}
85
86impl MlmExecMode {
87    /// Resolve [`MlmExecMode::Auto`] against `(device, batch)`. Returns
88    /// `Cpu` or `InGraph` — non-`Auto` inputs pass through unchanged.
89    pub fn resolve(self, device: Device, batch: usize) -> MlmExecMode {
90        match self {
91            MlmExecMode::Cpu | MlmExecMode::InGraph => self,
92            MlmExecMode::Auto => match device {
93                Device::Cuda if batch > 8 => MlmExecMode::Cpu,
94                _ => MlmExecMode::InGraph,
95            },
96        }
97    }
98
99    pub fn from_str_opt(s: &str) -> Option<Self> {
100        match s.to_ascii_lowercase().as_str() {
101            "cpu" | "post" | "host" => Some(MlmExecMode::Cpu),
102            "ingraph" | "in-graph" | "in_graph" | "graph" | "fold" | "folded" => {
103                Some(MlmExecMode::InGraph)
104            }
105            "auto" | "default" => Some(MlmExecMode::Auto),
106            _ => None,
107        }
108    }
109}
110
111/// Compiled ClinicalBERT encoder.
112///
113/// Each `forward_*` call must match the compiled `(batch, seq)`. Call
114/// [`ClinicalBertRunner::recompile`] to retarget the graph for a new shape
115/// (cached: a no-op when dims are unchanged).
116pub struct ClinicalBertRunner {
117    config: ClinicalBertConfig,
118    weights_path: PathBuf,
119    compiled: CompiledGraph,
120    compiled_bs: (usize, usize),
121    device: Device,
122    pooling: Pooling,
123    #[cfg(feature = "pooler")]
124    pooler_head: Option<PoolerHead>,
125    #[cfg(feature = "mlm")]
126    mlm_head: Option<MlmHead>,
127    /// `true` when the compiled graph embeds the MLM head as a second
128    /// output (`with_mlm_in_graph()`). In that mode `forward()` caches
129    /// the head's `mlm_logits [B,S,V]` output so `mlm_logits()` /
130    /// `mlm_logits_into()` return it directly without a CPU matmul.
131    #[cfg(feature = "mlm")]
132    mlm_in_graph: bool,
133    /// Last forward's cached `mlm_logits` when `mlm_in_graph == true`.
134    #[cfg(feature = "mlm")]
135    cached_mlm_logits: Option<Vec<f32>>,
136}
137
138impl ClinicalBertRunner {
139    pub fn builder() -> ClinicalBertRunnerBuilder {
140        ClinicalBertRunnerBuilder::default()
141    }
142
143    pub fn config(&self) -> &ClinicalBertConfig {
144        &self.config
145    }
146
147    pub fn hidden_size(&self) -> usize {
148        self.config.bert.hidden_size
149    }
150
151    pub fn device(&self) -> Device {
152        self.device
153    }
154
155    pub fn pooling(&self) -> Pooling {
156        self.pooling
157    }
158
159    pub fn compiled_shape(&self) -> (usize, usize) {
160        self.compiled_bs
161    }
162
163    /// `true` when the builder was called with `.with_pooler()` and the
164    /// pooler weights were found in the checkpoint.
165    #[cfg(feature = "pooler")]
166    pub fn has_pooler(&self) -> bool {
167        self.pooler_head.is_some()
168    }
169
170    /// `true` when the builder was called with `.with_mlm()` and the MLM head
171    /// weights were found in the checkpoint.
172    #[cfg(feature = "mlm")]
173    pub fn has_mlm(&self) -> bool {
174        self.mlm_head.is_some() || self.mlm_in_graph
175    }
176
177    /// `true` when the MLM head runs in [`MlmExecMode::InGraph`] mode.
178    #[cfg(feature = "mlm")]
179    pub fn mlm_in_graph(&self) -> bool {
180        self.mlm_in_graph
181    }
182
183    /// The resolved [`MlmExecMode`] the runner is executing in (never
184    /// `Auto`), or `None` when the MLM head is disabled.
185    #[cfg(feature = "mlm")]
186    pub fn mlm_mode(&self) -> Option<MlmExecMode> {
187        if self.mlm_in_graph {
188            Some(MlmExecMode::InGraph)
189        } else if self.mlm_head.is_some() {
190            Some(MlmExecMode::Cpu)
191        } else {
192            None
193        }
194    }
195
196    /// Pooler output `[batch, hidden_size]` = `tanh(W · h_cls + b)`.
197    ///
198    /// `hidden` must be the encoder output for the same `(batch, seq)` the
199    /// runner is compiled for. Returns an error if `.with_pooler()` wasn't
200    /// called at build time.
201    #[cfg(feature = "pooler")]
202    pub fn pooler_output(&self, hidden: &[f32]) -> Result<Vec<f32>> {
203        let head = self.pooler_head.as_ref().ok_or_else(|| {
204            anyhow::anyhow!(
205                "rlx-clinicalbert: pooler not enabled — call .with_pooler() on the builder"
206            )
207        })?;
208        let (b, s) = self.compiled_bs;
209        head.apply(hidden, b, s)
210    }
211
212    /// MLM logits `[batch, seq, vocab_size]`. In [`MlmExecMode::InGraph`]
213    /// mode this returns the cached output of the last [`Self::forward`]
214    /// call (no compute); in [`MlmExecMode::Cpu`] mode it runs the CPU
215    /// post-process head on `hidden`.
216    #[cfg(feature = "mlm")]
217    pub fn mlm_logits(&self, hidden: &[f32]) -> Result<Vec<f32>> {
218        if self.mlm_in_graph {
219            return self.cached_mlm_logits.clone().ok_or_else(|| {
220                anyhow::anyhow!(
221                    "rlx-clinicalbert: call forward() first to populate the in-graph MLM logits"
222                )
223            });
224        }
225        let head = self
226            .mlm_head
227            .as_ref()
228            .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: MLM head not enabled — call .with_mlm() or .with_mlm_in_graph() on the builder"))?;
229        let (b, s) = self.compiled_bs;
230        head.apply(hidden, b, s)
231    }
232
233    /// MLM logits into a caller-provided buffer (zero-allocation hot path).
234    /// `logits.len()` must equal `batch * seq * vocab_size`. Use
235    /// [`Self::allocate_mlm_logits`] to size it.
236    #[cfg(feature = "mlm")]
237    pub fn mlm_logits_into(&self, hidden: &[f32], logits: &mut [f32]) -> Result<()> {
238        if self.mlm_in_graph {
239            let src = self.cached_mlm_logits.as_ref().ok_or_else(|| {
240                anyhow::anyhow!(
241                    "rlx-clinicalbert: call forward() first to populate the in-graph MLM logits"
242                )
243            })?;
244            if logits.len() != src.len() {
245                bail!(
246                    "rlx-clinicalbert: mlm_logits_into expected buffer of {} floats, got {}",
247                    src.len(),
248                    logits.len()
249                );
250            }
251            logits.copy_from_slice(src);
252            return Ok(());
253        }
254        let head = self
255            .mlm_head
256            .as_ref()
257            .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: MLM head not enabled — call .with_mlm() or .with_mlm_in_graph() on the builder"))?;
258        let (b, s) = self.compiled_bs;
259        head.apply_into(hidden, b, s, logits)
260    }
261
262    /// Allocate a buffer sized for [`Self::mlm_logits_into`].
263    #[cfg(feature = "mlm")]
264    pub fn allocate_mlm_logits(&self) -> Result<Vec<f32>> {
265        if self.mlm_in_graph {
266            let (b, s) = self.compiled_bs;
267            return Ok(vec![0f32; b * s * self.config.bert.vocab_size]);
268        }
269        let head = self
270            .mlm_head
271            .as_ref()
272            .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: MLM head not enabled"))?;
273        let (b, s) = self.compiled_bs;
274        Ok(head.allocate_logits_buffer(b, s))
275    }
276
277    /// Retarget the compiled graph for a new `(batch, seq)`.
278    pub fn recompile(&mut self, batch: usize, seq: usize) -> Result<()> {
279        if self.compiled_bs == (batch, seq) {
280            return Ok(());
281        }
282        let mut wm = if self.weights_path.is_dir() {
283            WeightMap::from_resolved_path(&self.weights_path)
284        } else {
285            WeightMap::from_file(self.weights_path.to_str().ok_or_else(|| {
286                anyhow::anyhow!(
287                    "rlx-clinicalbert: non-UTF8 weights path {:?}",
288                    self.weights_path
289                )
290            })?)
291        }?;
292        let built = build_clinicalbert_built(&self.config.bert, &mut wm, batch, seq)?;
293        self.compiled = compile_built(built, self.device)?;
294        self.compiled_bs = (batch, seq);
295        Ok(())
296    }
297
298    /// Raw forward — returns flat `[batch * seq * hidden]` F32 hidden states.
299    /// All four inputs are flattened `[batch, seq]` F32 buffers.
300    ///
301    /// In [`MlmExecMode::InGraph`] mode `mlm_logits` is also computed and
302    /// cached for [`Self::mlm_logits`] / [`Self::mlm_logits_into`].
303    pub fn forward(
304        &mut self,
305        input_ids: &[f32],
306        attention_mask: &[f32],
307        token_type_ids: &[f32],
308        position_ids: &[f32],
309    ) -> Result<Vec<f32>> {
310        let (b, s) = self.compiled_bs;
311        let expected = b * s;
312        if input_ids.len() != expected
313            || attention_mask.len() != expected
314            || token_type_ids.len() != expected
315            || position_ids.len() != expected
316        {
317            bail!(
318                "rlx-clinicalbert: forward expects each input of length {expected} \
319                 (batch={b}, seq={s}); got {}, {}, {}, {}",
320                input_ids.len(),
321                attention_mask.len(),
322                token_type_ids.len(),
323                position_ids.len()
324            );
325        }
326        let outputs = self.compiled.run(&[
327            ("input_ids", input_ids),
328            ("attention_mask", attention_mask),
329            ("token_type_ids", token_type_ids),
330            ("position_ids", position_ids),
331        ]);
332        if std::env::var("RLX_CLINICALBERT_DEBUG").is_ok() {
333            let sizes: Vec<usize> = outputs.iter().map(|o| o.len()).collect();
334            eprintln!("[rlx-clinicalbert] forward outputs: {sizes:?}");
335        }
336        // The encoder graph declares `hidden_states` as the FIRST output.
337        // When `.with_mlm_in_graph()` is active, `mlm_logits` is the SECOND
338        // output; cache it for [`Self::mlm_logits`].
339        #[cfg(feature = "mlm")]
340        if self.mlm_in_graph {
341            if outputs.len() >= 2 {
342                self.cached_mlm_logits = Some(outputs[1].clone());
343            } else {
344                bail!(
345                    "rlx-clinicalbert: with_mlm_in_graph but compiled graph returned {} outputs",
346                    outputs.len()
347                );
348            }
349        }
350        outputs
351            .into_iter()
352            .next()
353            .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: compiled graph returned no outputs"))
354    }
355
356    /// Forward + pool using the runner's configured pooling.
357    ///
358    /// Returns `[batch, hidden]` flattened. With [`Pooling::None`] this matches
359    /// [`Self::forward`] (returns `[batch, seq, hidden]`).
360    pub fn embed(
361        &mut self,
362        input_ids: &[f32],
363        attention_mask: &[f32],
364        token_type_ids: &[f32],
365        position_ids: &[f32],
366    ) -> Result<Vec<f32>> {
367        let hidden = self.forward(input_ids, attention_mask, token_type_ids, position_ids)?;
368        let (b, s) = self.compiled_bs;
369        let h = self.hidden_size();
370        Ok(match self.pooling {
371            Pooling::None => hidden,
372            Pooling::Cls => pool_cls(&hidden, b, s, h),
373            Pooling::Mean => pool_mean(&hidden, attention_mask, b, s, h),
374        })
375    }
376}
377
378#[derive(Debug, Clone, Default)]
379pub struct ClinicalBertRunnerBuilder {
380    weights: Option<PathBuf>,
381    config: Option<ClinicalBertConfig>,
382    config_path: Option<PathBuf>,
383    variant: Option<ClinicalBertVariant>,
384    device: Option<Device>,
385    batch: Option<usize>,
386    seq: Option<usize>,
387    pooling: Option<Pooling>,
388    #[cfg(feature = "pooler")]
389    enable_pooler: bool,
390    #[cfg(feature = "mlm")]
391    enable_mlm: bool,
392    #[cfg(feature = "mlm")]
393    enable_mlm_in_graph: bool,
394    #[cfg(feature = "mlm")]
395    mlm_mode: Option<MlmExecMode>,
396}
397
398impl ClinicalBertRunnerBuilder {
399    pub fn weights(mut self, path: impl Into<PathBuf>) -> Self {
400        self.weights = Some(path.into());
401        self
402    }
403
404    pub fn config(mut self, cfg: BertConfig) -> Self {
405        self.config = Some(ClinicalBertConfig::new(cfg));
406        self
407    }
408
409    pub fn config_path(mut self, path: impl Into<PathBuf>) -> Self {
410        self.config_path = Some(path.into());
411        self
412    }
413
414    pub fn variant(mut self, v: ClinicalBertVariant) -> Self {
415        self.variant = Some(v);
416        self
417    }
418
419    pub fn device(mut self, d: Device) -> Self {
420        self.device = Some(d);
421        self
422    }
423
424    pub fn batch(mut self, b: usize) -> Self {
425        self.batch = Some(b);
426        self
427    }
428
429    pub fn max_seq(mut self, s: usize) -> Self {
430        self.seq = Some(s);
431        self
432    }
433
434    pub fn pooling(mut self, p: Pooling) -> Self {
435        self.pooling = Some(p);
436        self
437    }
438
439    /// Load the pre-trained pooler head (`bert.pooler.dense.*`) so
440    /// [`ClinicalBertRunner::pooler_output`] becomes available.
441    #[cfg(feature = "pooler")]
442    pub fn with_pooler(mut self) -> Self {
443        self.enable_pooler = true;
444        self
445    }
446
447    /// Enable the MLM head in [`MlmExecMode::Cpu`] mode. Shortcut for
448    /// `.mlm_mode(MlmExecMode::Cpu)`.
449    #[cfg(feature = "mlm")]
450    pub fn with_mlm(mut self) -> Self {
451        self.enable_mlm = true;
452        self
453    }
454
455    /// Enable the MLM head in [`MlmExecMode::InGraph`] mode. Shortcut for
456    /// `.mlm_mode(MlmExecMode::InGraph)`. Mutually exclusive with
457    /// [`Self::with_mlm`].
458    #[cfg(feature = "mlm")]
459    pub fn with_mlm_in_graph(mut self) -> Self {
460        self.enable_mlm_in_graph = true;
461        self
462    }
463
464    /// Enable the MLM head and pick where it runs. The explicit form of
465    /// [`Self::with_mlm`] / [`Self::with_mlm_in_graph`]; pass
466    /// [`MlmExecMode::Auto`] to let the runner pick at [`Self::build`] time
467    /// based on the configured `(device, batch)`. Overrides any prior
468    /// shortcut call.
469    #[cfg(feature = "mlm")]
470    pub fn mlm_mode(mut self, mode: MlmExecMode) -> Self {
471        self.mlm_mode = Some(mode);
472        self.enable_mlm = false;
473        self.enable_mlm_in_graph = false;
474        self
475    }
476
477    pub fn build(self) -> Result<ClinicalBertRunner> {
478        let weights = self
479            .weights
480            .clone()
481            .ok_or_else(|| anyhow::anyhow!("rlx-clinicalbert: weights path required"))?;
482        let device = self.device.unwrap_or(Device::Cpu);
483        validate_standard_device("clinicalbert", device)?;
484
485        let mut config = if let Some(cfg) = self.config {
486            cfg
487        } else if let Some(variant) = self.variant {
488            ClinicalBertConfig::new(variant.preset()).with_variant(variant)
489        } else {
490            let cfg_path = self
491                .config_path
492                .clone()
493                .unwrap_or_else(|| ClinicalBertConfig::config_json_path(&weights));
494            if cfg_path.is_file() {
495                validate_hf_config(cfg_path.parent().unwrap_or(Path::new(".")))?;
496                ClinicalBertConfig::from_file(&cfg_path)?
497            } else {
498                bail!(
499                    "rlx-clinicalbert: no config supplied — call `.config(..)`, \
500                     `.config_path(..)`, or `.variant(..)`, or place `config.json` next \
501                     to {weights:?}"
502                );
503            }
504        };
505
506        if config.variant.is_none() {
507            config.variant = self.variant;
508        }
509
510        let batch = self.batch.unwrap_or(1);
511        let seq = self
512            .seq
513            .unwrap_or_else(|| config.bert.max_position_embeddings.min(512));
514
515        let weights_str = weights.to_str().ok_or_else(|| {
516            anyhow::anyhow!("rlx-clinicalbert: non-UTF8 weights path {weights:?}")
517        })?;
518        let mut wm = if weights.is_dir() {
519            WeightMap::from_resolved_path(&weights)
520        } else {
521            WeightMap::from_file(weights_str)
522        }
523        .with_context(|| format!("rlx-clinicalbert: loading {weights_str}"))?;
524
525        // Disallow combining the two MLM modes — they consume the same
526        // checkpoint tensors but route the head differently.
527        #[cfg(feature = "mlm")]
528        if self.enable_mlm && self.enable_mlm_in_graph {
529            bail!("rlx-clinicalbert: .with_mlm() and .with_mlm_in_graph() are mutually exclusive");
530        }
531
532        // Resolve the MLM execution mode. Three input paths:
533        //   1. `.mlm_mode(MlmExecMode::*)`           — explicit, takes precedence.
534        //   2. `.with_mlm()` / `.with_mlm_in_graph()` — legacy shortcuts.
535        //   3. neither                                — MLM head disabled.
536        // `Auto` is resolved against the configured (device, batch) here so
537        // the downstream code only sees the two concrete modes.
538        #[cfg(feature = "mlm")]
539        let resolved_mlm: Option<MlmExecMode> = match self.mlm_mode {
540            Some(MlmExecMode::Auto) => Some(MlmExecMode::Auto.resolve(device, batch)),
541            Some(m) => Some(m),
542            None => {
543                if self.enable_mlm {
544                    Some(MlmExecMode::Cpu)
545                } else if self.enable_mlm_in_graph {
546                    Some(MlmExecMode::InGraph)
547                } else {
548                    None
549                }
550            }
551        };
552
553        // Load heads BEFORE the encoder build — MLM head needs to clone the
554        // embedding matrix while it's still in the WeightMap.
555        #[cfg(feature = "mlm")]
556        let mlm_head: Option<MlmHead> = if resolved_mlm == Some(MlmExecMode::Cpu) {
557            Some(MlmHead::load(&config.bert, &mut wm)?)
558        } else {
559            None
560        };
561        #[cfg(feature = "pooler")]
562        let pooler_head: Option<PoolerHead> = if self.enable_pooler {
563            Some(PoolerHead::load(&config.bert, &mut wm)?)
564        } else {
565            None
566        };
567
568        // Pick the right encoder builder: with the head folded into the
569        // graph we emit a two-output graph (hidden_states + mlm_logits) and
570        // compile both in one pipeline. Otherwise the original encoder-only
571        // builder — the head, if any, runs as a CPU post-process.
572        #[cfg(feature = "mlm")]
573        let built = if resolved_mlm == Some(MlmExecMode::InGraph) {
574            build_clinicalbert_with_mlm_built(&config.bert, &mut wm, batch, seq)?
575        } else {
576            build_clinicalbert_built(&config.bert, &mut wm, batch, seq)?
577        };
578        #[cfg(not(feature = "mlm"))]
579        let built = build_clinicalbert_built(&config.bert, &mut wm, batch, seq)?;
580        let compiled = compile_built(built, device)?;
581
582        Ok(ClinicalBertRunner {
583            config,
584            weights_path: weights,
585            compiled,
586            compiled_bs: (batch, seq),
587            device,
588            pooling: self.pooling.unwrap_or(Pooling::Cls),
589            #[cfg(feature = "pooler")]
590            pooler_head,
591            #[cfg(feature = "mlm")]
592            mlm_head,
593            #[cfg(feature = "mlm")]
594            mlm_in_graph: resolved_mlm == Some(MlmExecMode::InGraph),
595            #[cfg(feature = "mlm")]
596            cached_mlm_logits: None,
597        })
598    }
599}
600
/// Gather the first-token (`[CLS]`) vector of every sequence.
///
/// `hidden` is read as row-major `[batch, seq, h]`. The result is the
/// flattened `[batch, h]` matrix of each sequence's position-0 vector.
///
/// # Panics
/// If `hidden` is too short to contain position 0 of every sequence.
fn pool_cls(hidden: &[f32], batch: usize, seq: usize, h: usize) -> Vec<f32> {
    let row_stride = seq * h;
    let mut pooled = Vec::with_capacity(batch * h);
    for row in 0..batch {
        let first_token = row * row_stride;
        pooled.extend_from_slice(&hidden[first_token..first_token + h]);
    }
    pooled
}
609
/// Mean of the token vectors whose attention-mask value is `> 0`, per
/// sequence.
///
/// `hidden` is read as row-major `[batch, seq, h]` and `attention_mask` as
/// `[batch, seq]`. Tokens are summed in sequence order, then scaled by
/// `1 / max(count, 1)`. A fully masked sequence therefore pools to zeros
/// instead of NaN. Returns the flattened `[batch, h]` result.
///
/// # Panics
/// If `attention_mask` holds fewer than `batch * seq` entries, or `hidden`
/// is too short for a kept token.
fn pool_mean(
    hidden: &[f32],
    attention_mask: &[f32],
    batch: usize,
    seq: usize,
    h: usize,
) -> Vec<f32> {
    let mut pooled = vec![0f32; batch * h];
    for row in 0..batch {
        let acc = &mut pooled[row * h..(row + 1) * h];
        let mut kept = 0.0f32;
        for pos in 0..seq {
            let token = row * seq + pos;
            if attention_mask[token] > 0.0 {
                kept += 1.0;
                let vector = &hidden[token * h..(token + 1) * h];
                for (slot, &value) in acc.iter_mut().zip(vector) {
                    *slot += value;
                }
            }
        }
        let scale = 1.0 / kept.max(1.0);
        acc.iter_mut().for_each(|slot| *slot *= scale);
    }
    pooled
}