use super::arch::ModelArch;
use super::driver::Driver;
use super::{EmbedBackend, Encoding};
/// Embedding backend that pairs a compute [`Driver`] with a model architecture
/// ([`ModelArch`]) and holds on to the model's memory map for as long as the
/// backend lives.
///
/// NOTE: struct fields are dropped in declaration order, so `_mmap` (declared
/// last) is dropped after `driver` and `arch`. Presumably `driver`/`arch` may
/// reference mapped memory — TODO confirm before reordering fields.
pub struct GenericBackend<D: Driver, A: ModelArch<D>> {
/// Driver passed to every `ModelArch::forward` call.
driver: D,
/// Model architecture that runs the forward pass.
arch: A,
/// Value reported by `EmbedBackend::max_tokens`.
max_tokens: usize,
/// Value reported by `EmbedBackend::is_gpu`; also gates the construction-time warmup.
is_gpu: bool,
/// Largest number of encodings `embed_batch` sends to a single forward call.
max_batch: usize,
/// Never read; kept only to keep the memory map alive (see drop-order note above).
_mmap: MmapHolder,
}
/// Keeps the memory-mapped model file behind a `GenericBackend` alive.
///
/// The backend never reads the holder after construction. Presumably it exists
/// so that memory referenced by the backend's driver/architecture stays mapped —
/// confirm against the model loader.
#[derive(Debug)]
pub enum MmapHolder {
    /// The backend is the sole owner of the mapping.
    Owned(memmap2::Mmap),
    /// The mapping is reference-counted and may be shared with other owners.
    Shared(std::sync::Arc<memmap2::Mmap>),
}
impl<D: Driver, A: ModelArch<D>> GenericBackend<D, A> {
pub fn new(driver: D, arch: A, max_tokens: usize, is_gpu: bool, mmap: memmap2::Mmap) -> Self {
Self::with_max_batch(
driver,
arch,
max_tokens,
is_gpu,
MmapHolder::Owned(mmap),
32,
)
}
pub fn new_shared(
driver: D,
arch: A,
max_tokens: usize,
is_gpu: bool,
mmap: std::sync::Arc<memmap2::Mmap>,
) -> Self {
Self::with_max_batch(
driver,
arch,
max_tokens,
is_gpu,
MmapHolder::Shared(mmap),
32,
)
}
#[expect(clippy::cast_possible_wrap, reason = "warmup seq length is small")]
pub fn with_max_batch(
driver: D,
arch: A,
max_tokens: usize,
is_gpu: bool,
mmap: MmapHolder,
max_batch: usize,
) -> Self {
let backend = Self {
driver,
arch,
max_tokens,
is_gpu,
max_batch,
_mmap: mmap,
};
if is_gpu && max_tokens <= 1024 {
let seq = if max_tokens <= 1024 {
512.min(max_tokens)
} else {
64
};
let mut dummy = Vec::with_capacity(32);
for _ in 0..32 {
let ids: Vec<i64> = (0..seq as i64).collect();
dummy.push(Encoding {
input_ids: ids,
attention_mask: vec![1; seq],
token_type_ids: vec![0; seq],
});
}
let _ = backend.arch.forward(&backend.driver, &dummy);
}
backend
}
}
impl<D, A> EmbedBackend for GenericBackend<D, A>
where
    D: Driver + Send + Sync + 'static,
    A: ModelArch<D> + Send + Sync + 'static,
{
    /// Embeds `encodings` and returns the results in input order.
    ///
    /// Inputs larger than `max_batch` are split into consecutive chunks of at
    /// most `max_batch` encodings, and each chunk gets its own forward call.
    ///
    /// # Errors
    /// Returns the first error produced by `ModelArch::forward`; any chunks
    /// after the failing one are not processed.
    fn embed_batch(&self, encodings: &[Encoding]) -> crate::Result<Vec<Vec<f32>>> {
        // `slice::chunks` panics on a chunk size of 0; treat 0 as 1.
        let max_batch = self.max_batch.max(1);
        if encodings.len() <= max_batch {
            return self.arch.forward(&self.driver, encodings);
        }
        let mut all = Vec::with_capacity(encodings.len());
        for chunk in encodings.chunks(max_batch) {
            all.extend(self.arch.forward(&self.driver, chunk)?);
        }
        Ok(all)
    }

    /// Always `false`: this backend cannot be cloned.
    fn supports_clone(&self) -> bool {
        false
    }

    /// # Panics
    /// Always panics; callers must check [`Self::supports_clone`] first.
    fn clone_backend(&self) -> Box<dyn EmbedBackend> {
        panic!("GenericBackend does not support cloning")
    }

    /// Returns the `is_gpu` flag supplied at construction.
    fn is_gpu(&self) -> bool {
        self.is_gpu
    }

    /// Returns the `max_tokens` value supplied at construction.
    fn max_tokens(&self) -> usize {
        self.max_tokens
    }

    /// Returns the name reported by the underlying driver.
    fn name(&self) -> &'static str {
        self.driver.name()
    }
}