// kalosm_language_model/embedding/model.rs
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// A pinned, heap-allocated, `Send` future yielding `T` that may borrow data for `'a`.
///
/// Used as the return type of the `BoxedEmbedder` methods, which hand back a boxed
/// trait-object future instead of `impl Future`.
pub(crate) type BoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
use crate::embedding::Embedding;
/// A model that can turn text into an [`Embedding`].
///
/// [`Embedder::embed_for`] is the only required method. The remaining methods have
/// default implementations that are expressed in terms of it and may be overridden
/// (for example to provide real batching).
pub trait Embedder: Send + Sync + 'static {
    /// The error returned when an embedding request fails.
    type Error: Send + Sync + 'static;

    /// Embed a single [`EmbeddingInput`] (text plus its [`EmbeddingVariant`]).
    fn embed_for(
        &self,
        input: EmbeddingInput,
    ) -> impl Future<Output = Result<Embedding, Self::Error>> + Send;

    /// Embed a single string.
    ///
    /// By default the string is tagged as [`EmbeddingVariant::Document`] and passed to
    /// [`Embedder::embed_for`].
    fn embed_string(
        &self,
        input: String,
    ) -> impl Future<Output = Result<Embedding, Self::Error>> + Send {
        let document = EmbeddingInput {
            text: input,
            variant: EmbeddingVariant::Document,
        };
        self.embed_for(document)
    }

    /// Embed several strings, returning the embeddings in input order.
    ///
    /// The default implementation awaits [`Embedder::embed_string`] for one input at a
    /// time and stops at the first error.
    fn embed_vec(
        &self,
        inputs: Vec<String>,
    ) -> impl Future<Output = Result<Vec<Embedding>, Self::Error>> + Send {
        async move {
            let mut results = Vec::with_capacity(inputs.len());
            for text in inputs {
                let embedding = self.embed_string(text).await?;
                results.push(embedding);
            }
            Ok(results)
        }
    }

    /// Embed several [`EmbeddingInput`]s, returning the embeddings in input order.
    ///
    /// The default implementation awaits [`Embedder::embed_for`] for one input at a
    /// time and stops at the first error.
    fn embed_vec_for(
        &self,
        inputs: Vec<EmbeddingInput>,
    ) -> impl Future<Output = Result<Vec<Embedding>, Self::Error>> + Send {
        async move {
            let mut results = Vec::with_capacity(inputs.len());
            for item in inputs {
                let embedding = self.embed_for(item).await?;
                results.push(embedding);
            }
            Ok(results)
        }
    }
}
/// Lets a shared `Arc<E>` be used wherever an [`Embedder`] is expected.
///
/// All four methods are forwarded to the wrapped embedder, so any method `E`
/// overrides is used rather than the trait's default implementation.
impl<E: Embedder> Embedder for Arc<E> {
    type Error = E::Error;

    fn embed_string(
        &self,
        input: String,
    ) -> impl Future<Output = Result<Embedding, Self::Error>> + Send {
        // Deref explicitly so the call targets `E`'s implementation, not this one.
        let inner: &E = self;
        <E as Embedder>::embed_string(inner, input)
    }

    fn embed_vec(
        &self,
        inputs: Vec<String>,
    ) -> impl Future<Output = Result<Vec<Embedding>, Self::Error>> + Send {
        let inner: &E = self;
        <E as Embedder>::embed_vec(inner, inputs)
    }

    fn embed_for(
        &self,
        input: EmbeddingInput,
    ) -> impl Future<Output = Result<Embedding, Self::Error>> + Send {
        let inner: &E = self;
        <E as Embedder>::embed_for(inner, input)
    }

    fn embed_vec_for(
        &self,
        inputs: Vec<EmbeddingInput>,
    ) -> impl Future<Output = Result<Vec<Embedding>, Self::Error>> + Send {
        let inner: &E = self;
        <E as Embedder>::embed_vec_for(inner, inputs)
    }
}
/// A piece of text to embed together with the [`EmbeddingVariant`] it is tagged with.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct EmbeddingInput {
    /// The text to embed.
    pub text: String,
    /// The variant tag passed along with the text.
    pub variant: EmbeddingVariant,
}

impl EmbeddingInput {
    /// Create an input from anything that implements [`ToString`].
    ///
    /// `text` is converted with `to_string()`; `variant` is stored unchanged.
    pub fn new(text: impl ToString, variant: EmbeddingVariant) -> Self {
        let text = text.to_string();
        Self { text, variant }
    }
}

/// The tag attached to an [`EmbeddingInput`]: either a query or a document.
///
/// How an embedder treats each variant is decided by its `Embedder` implementation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EmbeddingVariant {
    /// The text is a query.
    Query,
    /// The text is a document. This is the default variant.
    #[default]
    Document,
}
pub trait EmbedderExt: Embedder {
fn into_any_embedder(self) -> DynEmbedder
where
Self: Sized,
Self::Error: std::error::Error,
{
DynEmbedder {
embedder: Box::new(AnyEmbedder::<Self>(self)),
}
}
fn embed(
&self,
input: impl ToString,
) -> impl Future<Output = Result<Embedding, Self::Error>> + Send {
self.embed_string(input.to_string())
}
fn embed_query(
&self,
input: impl ToString,
) -> impl Future<Output = Result<Embedding, Self::Error>> + Send {
self.embed_for(EmbeddingInput {
text: input.to_string(),
variant: EmbeddingVariant::Query,
})
}
fn embed_batch(
&self,
inputs: impl IntoIterator<Item = impl ToString>,
) -> impl Future<Output = Result<Vec<Embedding>, Self::Error>> + Send {
let inputs = inputs
.into_iter()
.map(|s| s.to_string())
.collect::<Vec<_>>();
self.embed_vec(inputs)
}
fn embed_batch_for(
&self,
inputs: impl IntoIterator<Item = EmbeddingInput>,
) -> impl Future<Output = Result<Vec<Embedding>, Self::Error>> + Send {
self.embed_vec_for(inputs.into_iter().collect())
}
}
impl<E: Embedder> EmbedderExt for E {}
/// A type-erased embedder.
///
/// Holds a boxed trait object, so embedders of different concrete types can be stored
/// behind one type. Errors are reported as `Box<dyn std::error::Error + Send + Sync>`.
pub struct DynEmbedder {
    embedder: Box<dyn BoxedEmbedder + Send + Sync>,
}

/// Forwards each [`Embedder`] method to the matching `*_boxed` method of the wrapped
/// trait object.
impl Embedder for DynEmbedder {
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn embed_for(
        &self,
        input: EmbeddingInput,
    ) -> impl Future<Output = Result<Embedding, Self::Error>> + Send {
        let erased = &*self.embedder;
        erased.embed_for_boxed(input)
    }

    fn embed_vec_for(
        &self,
        inputs: Vec<EmbeddingInput>,
    ) -> impl Future<Output = Result<Vec<Embedding>, Self::Error>> + Send {
        let erased = &*self.embedder;
        erased.embed_vec_for_boxed(inputs)
    }

    fn embed_string(
        &self,
        input: String,
    ) -> impl Future<Output = Result<Embedding, Self::Error>> + Send {
        let erased = &*self.embedder;
        erased.embed_string_boxed(input)
    }

    fn embed_vec(
        &self,
        inputs: Vec<String>,
    ) -> impl Future<Output = Result<Vec<Embedding>, Self::Error>> + Send {
        let erased = &*self.embedder;
        erased.embed_vec_boxed(inputs)
    }
}
/// Private newtype around a concrete embedder `E`.
///
/// `BoxedEmbedder` is implemented for this wrapper, and
/// `EmbedderExt::into_any_embedder` boxes it to build a `DynEmbedder`.
struct AnyEmbedder<E: Embedder + Send + Sync + 'static>(E);
// Silences clippy's `type_complexity` lint for the long boxed-future return types below.
#[allow(clippy::type_complexity)]
/// Trait-object-friendly counterpart of `Embedder`.
///
/// Each method mirrors an `Embedder` method but returns a `BoxedFuture` and reports
/// failures as `Box<dyn std::error::Error + Send + Sync>`, so the trait can be used as
/// `dyn BoxedEmbedder` (as `DynEmbedder` does).
trait BoxedEmbedder {
    /// Boxed counterpart of `Embedder::embed_string`.
    fn embed_string_boxed(
        &self,
        input: String,
    ) -> BoxedFuture<'_, Result<Embedding, Box<dyn std::error::Error + Send + Sync>>>;
    /// Boxed counterpart of `Embedder::embed_vec`.
    fn embed_vec_boxed(
        &self,
        inputs: Vec<String>,
    ) -> BoxedFuture<'_, Result<Vec<Embedding>, Box<dyn std::error::Error + Send + Sync>>>;
    /// Boxed counterpart of `Embedder::embed_for`.
    fn embed_for_boxed(
        &self,
        input: EmbeddingInput,
    ) -> BoxedFuture<'_, Result<Embedding, Box<dyn std::error::Error + Send + Sync>>>;
    /// Boxed counterpart of `Embedder::embed_vec_for`.
    fn embed_vec_for_boxed(
        &self,
        inputs: Vec<EmbeddingInput>,
    ) -> BoxedFuture<'_, Result<Vec<Embedding>, Box<dyn std::error::Error + Send + Sync>>>;
}
/// Adapts any concrete `Embedder` to the `BoxedEmbedder` interface.
///
/// Every method follows the same recipe:
/// 1. call the matching method on the wrapped embedder right away to obtain its future,
/// 2. box and pin an `async` block that awaits that future, and
/// 3. convert the embedder's error into `Box<dyn std::error::Error + Send + Sync>`.
///
/// The `E::Error: std::error::Error` bound is what makes the `into()` conversion in
/// step 3 available.
impl<E: Embedder + Send + Sync + 'static> BoxedEmbedder for AnyEmbedder<E>
where
    E::Error: std::error::Error,
{
    /// Boxed form of the wrapped embedder's `embed_string`.
    fn embed_string_boxed(
        &self,
        input: String,
    ) -> BoxedFuture<'_, Result<Embedding, Box<dyn std::error::Error + Send + Sync>>> {
        let future = self.0.embed_string(input);
        Box::pin(async move { future.await.map_err(|e| e.into()) })
    }

    /// Boxed form of the wrapped embedder's `embed_vec`.
    fn embed_vec_boxed(
        &self,
        inputs: Vec<String>,
    ) -> BoxedFuture<'_, Result<Vec<Embedding>, Box<dyn std::error::Error + Send + Sync>>> {
        let future = self.0.embed_vec(inputs);
        // The wrapped future already yields `Vec<Embedding>`, so only the error needs
        // converting; the vector is passed through untouched.
        Box::pin(async move { future.await.map_err(|e| e.into()) })
    }

    /// Boxed form of the wrapped embedder's `embed_for`.
    fn embed_for_boxed(
        &self,
        input: EmbeddingInput,
    ) -> BoxedFuture<'_, Result<Embedding, Box<dyn std::error::Error + Send + Sync>>> {
        let future = self.0.embed_for(input);
        Box::pin(async move { future.await.map_err(|e| e.into()) })
    }

    /// Boxed form of the wrapped embedder's `embed_vec_for`.
    fn embed_vec_for_boxed(
        &self,
        inputs: Vec<EmbeddingInput>,
    ) -> BoxedFuture<'_, Result<Vec<Embedding>, Box<dyn std::error::Error + Send + Sync>>> {
        let future = self.0.embed_vec_for(inputs);
        // Same as `embed_vec_boxed`: the success value is already the right type.
        Box::pin(async move { future.await.map_err(|e| e.into()) })
    }
}