//! gliner2 0.1.1
//!
//! Rust implementation of GLiNER2 with compatibility for upstream weights and
//! Python training output. See the crate documentation for details.
use anyhow::{Result, bail};
use serde::{Deserialize, Serialize};

/// A single decoded entity span: the matched text, its label, the model score,
/// and its position within the original input string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// The matched substring of the input text.
    pub text: String,
    /// The label this span was scored against.
    pub label: String,
    /// Model score for this span; compared against the caller's threshold.
    pub confidence: f32,
    // NOTE(review): documented as char offsets, but `find_spans` slices the
    // input with `text[start..end]`, which takes BYTE offsets — for non-ASCII
    // text these differ. Confirm which unit the offset tables actually carry.
    pub start: usize, // char offset
    pub end: usize,   // char offset
}

/// Decodes span scores into entities. Greedy overlap suppression runs **per label** (same as
/// GliNER2 `_extract_entities`); results are concatenated in `labels` order.
///
/// `scores_flat` is row-major `[num_entities * l * max_width]` matching `scores[p][i][j]` =
/// `scores_flat[p * (l * max_width) + i * max_width + j]`.
#[allow(clippy::too_many_arguments)] // Mirrors tensor layout; grouping would not simplify call sites.
pub fn find_spans(
    scores_flat: &[f32],
    num_entities: usize,
    l: usize,
    max_width: usize,
    threshold: f32,
    labels: &[&str],
    text: &str,
    start_offsets: &[usize],
    end_offsets: &[usize],
) -> Result<Vec<Entity>> {
    let expected = num_entities * l * max_width;
    if scores_flat.len() != expected {
        bail!(
            "find_spans: expected {} scores ({} * {} * {}), got {}",
            expected,
            num_entities,
            l,
            max_width,
            scores_flat.len()
        );
    }

    let mut out = Vec::new();

    for (p, &label) in labels.iter().enumerate().take(num_entities) {
        let mut per_label = Vec::new();
        let base = p * (l * max_width);
        for i in 0..l {
            for j in 0..max_width {
                let conf = scores_flat[base + i * max_width + j];
                if conf >= threshold {
                    let end_token_idx = i + j;
                    if end_token_idx < l {
                        let char_start = start_offsets[i];
                        let char_end = end_offsets[end_token_idx];
                        let text_val = text[char_start..char_end].to_string();

                        per_label.push(Entity {
                            text: text_val,
                            label: label.to_string(),
                            confidence: conf,
                            start: char_start,
                            end: char_end,
                        });
                    }
                }
            }
        }
        out.extend(greedy_select(per_label));
    }

    Ok(out)
}

/// Convenience when the `candle` feature is enabled: decode from a `[NumEntities, L, max_width]` tensor.
#[cfg(feature = "candle")]
pub fn find_spans_tensor(
    scores: &candle_core::Tensor,
    threshold: f32,
    labels: &[&str],
    text: &str,
    start_offsets: &[usize],
    end_offsets: &[usize],
) -> anyhow::Result<Vec<Entity>> {
    // Candle errors do not convert into `anyhow::Error` automatically, so
    // stringify them through one shared closure.
    let wrap = |e: candle_core::Error| anyhow::anyhow!("{e}");

    let (num_entities, l, max_width) = scores.dims3().map_err(wrap)?;
    let flat = scores.flatten_all().map_err(wrap)?;
    let values = flat.to_vec1::<f32>().map_err(wrap)?;

    find_spans(
        &values,
        num_entities,
        l,
        max_width,
        threshold,
        labels,
        text,
        start_offsets,
        end_offsets,
    )
}

/// `[NumEntities, L, max_width]` LibTorch tensor → entities (`tch` feature).
#[cfg(feature = "tch")]
pub fn find_spans_tch_tensor(
    scores: &tch::Tensor,
    threshold: f32,
    labels: &[&str],
    text: &str,
    start_offsets: &[usize],
    end_offsets: &[usize],
) -> Result<Vec<Entity>> {
    let dims = scores.size();
    // Destructure the shape in one place; anything but rank 3 is an error.
    let (num_entities, l, max_width) = match dims.as_slice() {
        &[a, b, c] => (a as usize, b as usize, c as usize),
        _ => bail!(
            "find_spans_tch_tensor: expected 3D scores, got {} dims",
            dims.len()
        ),
    };

    // Copy the scores out of LibTorch into a plain row-major f32 buffer.
    let total = num_entities * l * max_width;
    let mut buf = vec![0f32; total];
    scores.flatten(0, 2).copy_data(&mut buf, total);

    find_spans(
        &buf,
        num_entities,
        l,
        max_width,
        threshold,
        labels,
        text,
        start_offsets,
        end_offsets,
    )
}

/// Greedy non-maximum suppression over candidate spans.
///
/// Sorts candidates by confidence (descending), then keeps each span only if
/// it does not overlap any already-kept span. Spans are half-open `[start, end)`
/// intervals, so touching spans (one ends where the next starts) do not count
/// as overlapping.
pub fn greedy_select(mut entities: Vec<Entity>) -> Vec<Entity> {
    // `total_cmp` gives a total order on f32, so a NaN confidence cannot
    // panic the sort the way `partial_cmp(..).unwrap()` did.
    entities.sort_by(|a, b| b.confidence.total_cmp(&a.confidence));

    let mut selected: Vec<Entity> = Vec::new();

    for entity in entities {
        // Half-open intervals overlap iff each starts before the other ends.
        let overlaps = selected
            .iter()
            .any(|s| entity.start < s.end && entity.end > s.start);
        if !overlaps {
            selected.push(entity);
        }
    }

    selected
}

#[cfg(test)]
mod tests {
    use super::{Entity, greedy_select};

    /// Builds a throwaway entity for the suppression tests.
    fn span(text: &str, label: &str, confidence: f32, start: usize, end: usize) -> Entity {
        Entity {
            text: text.into(),
            label: label.into(),
            confidence,
            start,
            end,
        }
    }

    #[test]
    fn per_label_greedy_keeps_overlapping_spans_for_different_labels() {
        let high = span("foo", "A", 0.9, 0, 5);
        let low = span("bar", "B", 0.5, 2, 7);

        // Run suppression globally: the overlapping low-confidence span loses.
        let global = greedy_select(vec![high.clone(), low.clone()]);
        assert_eq!(global.len(), 1, "global NMS drops lower-confidence overlap");

        // Run suppression per label and concatenate: both spans survive.
        let mut per_label = greedy_select(vec![high]);
        per_label.extend(greedy_select(vec![low]));
        assert_eq!(per_label.len(), 2, "per-label NMS matches GliNER2 engine");
    }
}