use anyhow::Result;
use super::bert::RlxBertModel;
use super::pooling::{Pooling, pool_embeddings};
use super::tokenizer::BertTokenizer;
/// Embeds `texts` with an RLX-backed BERT model and pools the per-token output.
///
/// The texts are tokenized as a single batch, the model is recompiled for the
/// `(texts.len(), seq_len)` shape reported by the tokenizer, and the token ids,
/// attention mask, token-type ids and position indices are passed to
/// `model.forward` as flat `f32` buffers (rows concatenated in batch order).
/// The forward output is then reduced by `pool_embeddings`, using the integer
/// attention mask rows and the requested `pooling` mode.
///
/// # Errors
///
/// Propagates any error returned by `tokenizer.encode_batch` or
/// `model.recompile`.
pub fn embed_with_rlx(
    model: &mut RlxBertModel,
    tokenizer: &BertTokenizer,
    texts: &[&str],
    pooling: Pooling,
) -> Result<Vec<Vec<f32>>> {
    let encoded = tokenizer.encode_batch(texts)?;
    let batch_size = texts.len();
    let seq_len = encoded.seq_len;
    let hidden_size = model.hidden_size();

    // The model is recompiled for this batch's shape before inputs are fed in.
    model.recompile(batch_size, seq_len)?;

    // `forward` takes every input as a flat f32 buffer, so each integer row is
    // cast element-wise and appended in batch order.
    let mut input_ids: Vec<f32> = Vec::new();
    for row in encoded.input_ids.iter() {
        input_ids.extend(row.iter().map(|&id| id as f32));
    }

    let mut attention: Vec<f32> = Vec::new();
    for row in encoded.attention_mask.iter() {
        attention.extend(row.iter().map(|&flag| flag as f32));
    }

    let mut token_types: Vec<f32> = Vec::new();
    for row in encoded.token_type_ids.iter() {
        token_types.extend(row.iter().map(|&kind| kind as f32));
    }

    // Every sequence uses the same positions 0..seq_len, so build one row and
    // repeat it once per batch entry.
    let position_row: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
    let positions: Vec<f32> = position_row.repeat(batch_size);

    let hidden = model.forward(&input_ids, &attention, &token_types, &positions);

    // Pooling consumes the original integer mask rows, not the f32 copy.
    let mask_rows: Vec<&[u32]> = encoded
        .attention_mask
        .iter()
        .map(|row| row.as_slice())
        .collect();

    let pooled = pool_embeddings(&hidden, &mask_rows, batch_size, seq_len, hidden_size, pooling);
    Ok(pooled)
}