ragtime 0.2.0

Easy Retrieval Augmented Generation
Documentation
use crate::{
    doc::ChunkId, l2_normalize, session_from_model_file, simple_prompt::SimplePrompt, EmbedModel,
    FormattedPrompt, Persistable,
};
use anyhow::{anyhow, bail, Result};
use ndarray::{Array1, Array2, Axis};
use ort::{inputs, Session};
use serde::{Deserialize, Serialize};
use std::{
    fs::{File, OpenOptions},
    iter,
    path::{Path, PathBuf},
    cmp::max,
};
use tokenizers::Tokenizer;
use usearch::{ffi::Matches, Index, IndexOptions, MetricKind, ScalarKind};

/// Vector dimensionality handed to `options` whenever this module creates a usearch index.
// NOTE(review): presumably equal to the hidden size of the model's embeddings — confirm.
const DIMS: usize = 1024;

/// Arguments used to construct a [`GteLargeEn`].
///
/// The same value is serialized to JSON when the model is saved and read back
/// when it is loaded, so both paths must still be valid at load time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GteLargeEnArgs {
    /// Path of the model file passed to `session_from_model_file`.
    pub model: PathBuf,
    /// Path of the tokenizer file passed to `session_from_model_file`.
    pub tokenizer: PathBuf,
}

/// Embedding model paired with the usearch index that stores its vectors.
pub struct GteLargeEn {
    /// Arguments this instance was built from; written out as JSON on save.
    params: GteLargeEnArgs,
    /// Inference session returned by `session_from_model_file`; run by `embed`.
    session: Session,
    /// Tokenizer returned by `session_from_model_file`; encodes text batches in `embed`.
    tokenizer: Tokenizer,
    /// Vector index; entries are keyed by the inner value of a `ChunkId`.
    index: Index,
}

fn options(dims: usize) -> IndexOptions {
    let mut opts = IndexOptions::default();
    opts.dimensions = dims;
    opts.metric = MetricKind::Cos;
    opts.quantization = ScalarKind::F32;
    opts
}

impl Persistable for GteLargeEn {
    type Ctx = ();

    /// Persist the model arguments and the vector index next to each other.
    ///
    /// Two files are written, both derived from `path` by replacing its
    /// extension: `<stem>.json` holds the serialized [`GteLargeEnArgs`] and
    /// `<stem>.usearch` holds the index.
    ///
    /// # Errors
    ///
    /// Fails if the JSON file cannot be opened or written, or if the index
    /// cannot be saved.
    fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let mut path = PathBuf::from(path.as_ref());
        path.set_extension("json");
        // Truncate on open: without it, overwriting a longer existing file
        // would leave stale trailing bytes and produce invalid JSON.
        let mut fd = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&path)?;
        serde_json::to_writer_pretty(&mut fd, &self.params)?;
        path.set_extension("usearch");
        Ok(self.index.save(&*path.to_string_lossy())?)
    }

    /// Restore a model previously written by `save`.
    ///
    /// `path` is normalized exactly as in `save`: its extension is replaced
    /// with `json` to find the arguments and with `usearch` to find the index,
    /// so the same path can be passed to both functions. When `view` is true
    /// the index file is opened with `Index::view`, otherwise with
    /// `Index::load`.
    ///
    /// # Errors
    ///
    /// Fails if the JSON file is missing or malformed, if the model session or
    /// tokenizer cannot be created, or if the index file cannot be opened.
    fn load<P: AsRef<Path>>(_ctx: (), path: P, view: bool) -> Result<Self> {
        let mut path = PathBuf::from(path.as_ref());
        // Mirror `save`, which always writes the arguments to `<stem>.json`.
        path.set_extension("json");
        // Buffer the reader: `serde_json::from_reader` issues many small reads.
        let fd = std::io::BufReader::new(File::open(&path)?);
        let params: GteLargeEnArgs = serde_json::from_reader(fd)?;
        let (session, tokenizer) = session_from_model_file(&params.model, &params.tokenizer)?;
        let index = Index::new(&options(DIMS))?;
        path.set_extension("usearch");
        let path = path.to_string_lossy();
        if view {
            index.view(&*path)?;
        } else {
            index.load(&*path)?;
        }
        Ok(Self {
            params,
            session,
            tokenizer,
            index,
        })
    }
}

impl EmbedModel for GteLargeEn {
    type Ctx = ();
    type Args = GteLargeEnArgs;
    type EmbedPrompt = SimplePrompt;
    type SearchPrompt = SimplePrompt;

    /// Create a model with a fresh, empty index.
    ///
    /// The session and tokenizer are built from the paths in `params`, and the
    /// index starts with room reserved for 1000 vectors.
    fn new(_ctx: (), params: Self::Args) -> Result<Self> {
        let (session, tokenizer) = session_from_model_file(&params.model, &params.tokenizer)?;
        let index = Index::new(&options(DIMS))?;
        index.reserve(1000)?;
        let model = Self {
            params,
            session,
            tokenizer,
            index,
        };
        Ok(model)
    }

    /// Remove every listed chunk from the index, stopping at the first failure.
    fn remove(&mut self, chunks: &[ChunkId]) -> Result<()> {
        chunks.iter().try_for_each(|id| -> Result<()> {
            self.index.remove(id.0)?;
            Ok(())
        })
    }

    /// Embed a document summary together with its chunks and index the chunks.
    ///
    /// The summary and all chunk texts are embedded in a single batch. The
    /// vector stored for each chunk is its own embedding plus half of the
    /// summary embedding, passed through `l2_normalize`. The index capacity is
    /// grown (doubled, with a floor of 10) whenever it is full.
    fn add(
        &mut self,
        summary: <Self::EmbedPrompt as FormattedPrompt>::FinalPrompt,
        text: &[(ChunkId, <Self::EmbedPrompt as FormattedPrompt>::FinalPrompt)],
    ) -> Result<()> {
        // Row 0 of the batch is the summary; rows 1.. line up with `text`.
        let prompts: Vec<&str> = iter::once(summary.as_ref())
            .chain(text.iter().map(|(_, chunk)| chunk.as_ref()))
            .collect();
        let embeddings = Self::embed(&self.tokenizer, &self.session, prompts)?;
        let mut rows = embeddings.axis_iter(Axis(0));
        let summary_row = rows.next().ok_or_else(|| anyhow!("no summary"))?;
        let half_summary = &summary_row * 0.5;
        // One scratch buffer reused for every chunk.
        let mut scratch = Array1::zeros(half_summary.len());
        for (row, (id, _)) in rows.zip(text) {
            scratch.assign(&row);
            scratch += &half_summary;
            let vector = scratch
                .as_slice_mut()
                .ok_or_else(|| anyhow!("could not get embedding"))?;
            l2_normalize(vector);
            if self.index.size() == self.index.capacity() {
                let grown = max(10, self.index.capacity() * 2);
                self.index.reserve(grown)?;
            }
            self.index.add(id.0, &*vector)?;
        }
        Ok(())
    }

    /// Embed `text` and return up to `n` matches from the index.
    ///
    /// The keys component of Matches is a vec of ChunkIds represented as u64s.
    fn search(
        &mut self,
        text: <Self::SearchPrompt as FormattedPrompt>::FinalPrompt,
        n: usize,
    ) -> Result<Matches> {
        let mut embedded = Self::embed(&self.tokenizer, &self.session, vec![text.as_ref()])?;
        let mut first_row = embedded
            .axis_iter_mut(Axis(0))
            .next()
            .ok_or_else(|| anyhow!("no query"))?;
        let query = first_row
            .as_slice_mut()
            .ok_or_else(|| anyhow!("could not get slice"))?;
        l2_normalize(query);
        let matches = self.index.search(query, n)?;
        Ok(matches)
    }
}

impl GteLargeEn {
    /// Embed a batch of strings, returning one row per input string.
    ///
    /// The batch is tokenized with special tokens enabled, right-padded with
    /// token id 0 to the longest sequence, and run through the session with an
    /// attention mask that is 1 on real tokens and 0 on padding, plus all-zero
    /// token type ids. The embedding of each input is the `last_hidden_state`
    /// vector at sequence position 0. Rows are NOT normalized here.
    ///
    /// # Errors
    ///
    /// Fails if `text` is empty, if tokenization or inference fails, or if
    /// `last_hidden_state` is not a rank-3 tensor with a non-empty sequence
    /// axis.
    fn embed(tokenizer: &Tokenizer, session: &Session, text: Vec<&str>) -> Result<Array2<f32>> {
        if text.is_empty() {
            bail!("can't add an empty batch")
        }
        let tokens = tokenizer
            .encode_batch(text, true)
            .map_err(|e| anyhow!("{e:?}"))?;
        let longest = tokens
            .iter()
            .map(|t| t.get_ids().len())
            .max()
            .ok_or_else(|| anyhow!("no tokens returned from tokenizer"))?;
        let dims = (tokens.len(), longest);
        // Positions past the end of a sequence are padded with id 0.
        let input = Array2::from_shape_fn(dims, |(i, j)| {
            tokens[i].get_ids().get(j).map_or(0i64, |&id| i64::from(id))
        });
        // NOTE(review): the mask is derived from the id count, not from the
        // tokenizer's own attention mask — confirm the tokenizer does not pad.
        let attention_mask =
            Array2::from_shape_fn(dims, |(i, j)| i64::from(j < tokens[i].get_ids().len()));
        let token_type_ids: Array2<i64> = Array2::zeros(dims);
        let res = session.run(inputs![input, attention_mask, token_type_ids]?)?;
        let res = res["last_hidden_state"].try_extract_tensor::<f32>()?;
        // Validate the shape up front so that a surprising model output is
        // reported as an error instead of an out-of-bounds panic below.
        let (batch, hidden) = match res.shape() {
            &[batch, seq, hidden] if seq > 0 => (batch, hidden),
            other => bail!("unexpected last_hidden_state shape {other:?}"),
        };
        let mut out: Array2<f32> = Array2::zeros((batch, hidden));
        for (i, e) in res.axis_iter(Axis(0)).enumerate() {
            // Pool by taking the first position of the sequence axis. Averaging
            // over that axis was tried previously and is deliberately not used.
            out.row_mut(i).assign(&e.index_axis(Axis(0), 0));
        }
        Ok(out)
    }
}