use std::future::Future;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::pin::Pin;
pub use crate::Bert;
use crate::BertBuilder;
use crate::BertError;
use crate::BertLoadingError;
use crate::Pooling;
pub use kalosm_language_model::{
Embedder, EmbedderCacheExt, EmbedderExt, Embedding, EmbeddingInput, EmbeddingVariant,
ModelBuilder,
};
use kalosm_model_types::ModelLoadingProgress;
/// Allows a [`BertBuilder`] to be used wherever a generic [`ModelBuilder`] is expected.
impl ModelBuilder for BertBuilder {
    type Error = BertLoadingError;
    type Model = Bert;

    /// Always reports `true`.
    ///
    /// NOTE(review): this does not inspect any local cache, so callers are told a
    /// download is required on every call; confirm that is intended.
    fn requires_download(&self) -> bool {
        true
    }

    /// Builds the [`Bert`] model, handing `handler` to
    /// `BertBuilder::build_with_loading_handler` so it receives the
    /// [`ModelLoadingProgress`] updates emitted while loading.
    ///
    /// # Errors
    ///
    /// Returns whatever [`BertLoadingError`] the underlying build reports.
    async fn start_with_loading_handler(
        self,
        handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> Result<Self::Model, Self::Error> {
        let model = self.build_with_loading_handler(handler).await?;
        Ok(model)
    }
}
impl Bert {
    /// Embeds a single piece of text using the given [`Pooling`] strategy.
    ///
    /// The text is passed to `embed_batch_raw` as a batch of one; the first row of the
    /// last tensor it returns is wrapped in an [`Embedding`].
    ///
    /// # Errors
    ///
    /// Returns a [`BertError`] if `embed_batch_raw` fails or if the tensor cannot be
    /// read back with `to_vec2`.
    ///
    /// # Panics
    ///
    /// Panics if `embed_batch_raw` returns no tensors, or if the tensor has no rows.
    pub fn embed_with_pooling(
        &self,
        input: &str,
        pooling: Pooling,
    ) -> Result<Embedding, BertError> {
        let mut batch = self.embed_batch_raw(vec![input], pooling)?;
        let tensor = batch.pop().unwrap();
        let first_row = tensor.to_vec2()?.into_iter().next().unwrap();
        Ok(Embedding::from(first_row))
    }

    /// Embeds a batch of texts using the given [`Pooling`] strategy.
    ///
    /// Produces one [`Embedding`] per tensor returned by `embed_batch_raw`, in the order
    /// the tensors are returned. Each embedding is built from the first row of its tensor.
    ///
    /// # Errors
    ///
    /// Returns a [`BertError`] if `embed_batch_raw` fails, or as soon as one tensor
    /// cannot be read back with `to_vec2`.
    ///
    /// # Panics
    ///
    /// Panics if any returned tensor has no rows.
    pub fn embed_batch_with_pooling(
        &self,
        inputs: Vec<&str>,
        pooling: Pooling,
    ) -> Result<Vec<Embedding>, BertError> {
        self.embed_batch_raw(inputs, pooling)?
            .into_iter()
            .map(|tensor| -> Result<Embedding, BertError> {
                let first_row = tensor.to_vec2()?.into_iter().next().unwrap();
                Ok(Embedding::from(first_row))
            })
            .collect()
    }
}
/// [`Embedder`] implementation that runs the synchronous BERT embedding code on tokio's
/// blocking thread pool.
impl Embedder for Bert {
    type Error = BertError;

    /// Embeds one input, taking its [`EmbeddingVariant`] into account.
    ///
    /// If a search prefix is configured and the input is an
    /// [`EmbeddingVariant::Query`], the prefix is prepended to the text before it is
    /// embedded. Every other input is embedded unchanged.
    fn embed_for(
        &self,
        input: EmbeddingInput,
    ) -> impl Future<Output = Result<Embedding, Self::Error>> + Send {
        let text = match (&*self.embedding_search_prefix, input.variant) {
            (Some(prefix), EmbeddingVariant::Query) => format!("{prefix}{}", input.text),
            _ => input.text,
        };
        self.embed_string(text)
    }

    /// Embeds a batch of inputs, applying the same per-input prefix rule as
    /// [`Embedder::embed_for`] and then embedding all resulting texts in one batch.
    fn embed_vec_for(
        &self,
        inputs: Vec<EmbeddingInput>,
    ) -> impl Future<Output = Result<Vec<Embedding>, Self::Error>> + Send {
        let search_prefix = &*self.embedding_search_prefix;
        let mut texts = Vec::with_capacity(inputs.len());
        for input in inputs {
            let text = match (search_prefix, input.variant) {
                (Some(prefix), EmbeddingVariant::Query) => format!("{prefix}{}", input.text),
                _ => input.text,
            };
            texts.push(text);
        }
        self.embed_vec(texts)
    }

    /// Embeds a single string with [`Pooling::CLS`].
    ///
    /// The work runs on a clone of `self` inside `tokio::task::spawn_blocking`, so the
    /// async caller is not blocked while the model runs.
    ///
    /// # Errors
    ///
    /// Returns the error from `embed_with_pooling`. A failure to join the blocking task
    /// is also surfaced as `Self::Error` through `?`.
    async fn embed_string(&self, input: String) -> Result<Embedding, Self::Error> {
        let model = self.clone();
        let task =
            tokio::task::spawn_blocking(move || model.embed_with_pooling(&input, Pooling::CLS));
        task.await?
    }

    /// Embeds a batch of strings with [`Pooling::CLS`].
    ///
    /// As with [`Embedder::embed_string`], the work runs on a clone of `self` inside
    /// `tokio::task::spawn_blocking`.
    ///
    /// # Errors
    ///
    /// Returns the error from `embed_batch_with_pooling`. A failure to join the blocking
    /// task is also surfaced as `Self::Error` through `?`.
    async fn embed_vec(&self, inputs: Vec<String>) -> Result<Vec<Embedding>, Self::Error> {
        let model = self.clone();
        let task = tokio::task::spawn_blocking(move || {
            // The owned strings were moved into this closure; borrow them for the
            // batch API.
            let texts: Vec<&str> = inputs.iter().map(String::as_str).collect();
            model.embed_batch_with_pooling(texts, Pooling::CLS)
        });
        task.await?
    }
}
/// Makes a [`Bert`] callable like a function: `bert("text")` goes through this `Deref`
/// impl to a `dyn Fn(&str)` that returns a boxed future resolving to the text's embedding.
///
/// No closure is stored anywhere. `deref` reinterprets `&Bert` as a reference to a
/// closure type whose only captured state is a `MaybeUninit<Bert>`.
///
/// NOTE(review): this is an `unsafe` layout trick. Its soundness assumes the compiler
/// lays the closure out exactly like its single capture; closure layout does not appear
/// to be guaranteed by the language. Confirm, and check any change here with Miri.
impl Deref for Bert {
// A callable taking the text to embed and returning a boxed, `Send + 'static` future.
type Target = dyn Fn(
&str,
) -> Pin<
Box<dyn Future<Output = Result<Embedding, BertError>> + Send + 'static>,
>;
/// Returns `self` reinterpreted as a closure that embeds text with `Pooling::CLS`.
///
/// # Panics
///
/// Panics if the closure's size or alignment differs from that of `Self`.
fn deref(&self) -> &Self::Target {
// Placeholder with the size and alignment of `Self`. It is never initialized; it
// exists only so the closure below has a capture shaped like `Self`.
let uninit_callable = MaybeUninit::<Self>::uninit();
// `move` captures the placeholder by value, so it is the closure's only captured
// state. This local closure is never called in this function; only its *type* is used.
let uninit_closure = move |text: &str| {
// SAFETY: relies on the transmute at the end of `deref`. When the closure is invoked
// through the reference `deref` returns, the bytes behind `uninit_callable` are the
// caller's `Bert`, so they are a valid `Self`.
let myself = unsafe { &*uninit_callable.as_ptr() };
// Clone the model and copy the text so the returned future owns everything it needs
// (the future is declared `'static`).
let self_clone = myself.clone();
let input = text.to_string();
Box::pin(async move {
// Run the synchronous embedding on tokio's blocking pool. `?` turns a failed join
// into a `BertError` (presumably via a `From` impl defined elsewhere), and the
// task's own `Result` is the value of the async block.
tokio::task::spawn_blocking(move || {
self_clone.embed_with_pooling(&input, Pooling::CLS)
})
.await?
})
as Pin<Box<dyn Future<Output = Result<Embedding, BertError>> + Send + 'static>>
};
// Guard for the transmute: `Layout` equality compares size and alignment of the
// closure against `Self`, and the assert panics on mismatch.
// NOTE(review): equal size and alignment does not by itself prove identical field
// placement.
let size_of_closure = std::alloc::Layout::for_value(&uninit_closure);
assert_eq!(size_of_closure, std::alloc::Layout::new::<Self>());
// Returns `b` unchanged. It exists for type inference: the closure type cannot be
// named, so passing `&uninit_closure` as `_a` fixes `T`, which in turn fixes the
// target type of the `transmute` in the second argument. The returned lifetime
// comes from `b` only.
fn cast_lifetime<'a, T>(_a: &T, b: &'a T) -> &'a T {
b
}
// SAFETY: reinterprets `&Self` as a reference to the closure type while keeping
// `self`'s lifetime. Relies on the layout check above and on the closure consisting
// solely of its `MaybeUninit<Self>` capture.
let reference_to_closure = cast_lifetime(
{
&uninit_closure
},
#[allow(clippy::missing_transmute_annotations)]
unsafe {
std::mem::transmute(self)
},
);
// Unsize the concrete closure reference into the `dyn Fn` trait object `Target`.
// The local `uninit_closure` is dropped on return; `MaybeUninit` does not run the
// destructor of its contents.
reference_to_closure as &_
}
}
/// Smoke test: builds a `Bert` from the `snowflake_arctic_embed_extra_small` source and
/// embeds one sentence by calling the model directly, which exercises the `Deref` impl.
#[cfg(test)]
#[tokio::test]
async fn test_bert() {
    use crate::BertSource;

    let source = BertSource::snowflake_arctic_embed_extra_small();
    let model = Bert::builder()
        .with_source(source)
        .build()
        .await
        .unwrap();

    let sentence = "The quick brown fox jumps over the lazy dog.";
    let result = model(sentence).await.unwrap();
    println!("{result:?}");
}