use std::ops::{Deref, DerefMut};
use crate::core::embedding_model::{EmbeddingModel, EmbeddingModelOptions, EmbeddingModelResponse};
use crate::error::Result;
/// A ready-to-run embedding request: a model paired with the options it will be called with.
///
/// The options are reachable directly on the request through `Deref`/`DerefMut`
/// (e.g. `request.input`), so they can be adjusted after the request is built.
#[derive(Debug, Clone)]
pub struct EmbeddingModelRequest<M>
where
    M: EmbeddingModel,
{
    /// The model that performs the embedding.
    pub model: M,
    /// Options handed to the model on each `embed` call.
    pub(crate) options: EmbeddingModelOptions,
}
impl<M: EmbeddingModel> EmbeddingModelRequest<M> {
pub fn builder() -> EmbeddingModelRequestBuilder<M> {
EmbeddingModelRequestBuilder::default()
}
pub async fn embed(&self) -> Result<EmbeddingModelResponse> {
self.model.embed(self.options.clone()).await
}
}
impl<M> Deref for EmbeddingModelRequest<M>
where
    M: EmbeddingModel,
{
    type Target = EmbeddingModelOptions;

    /// Lets the request's options be read as if they were fields of the request itself.
    fn deref(&self) -> &EmbeddingModelOptions {
        let Self { options, .. } = self;
        options
    }
}
impl<M: EmbeddingModel> DerefMut for EmbeddingModelRequest<M> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.options
}
}
/// Type-state marker: the builder has no model yet; `.model(...)` advances it to `OptionsStage`.
///
/// A unit struct (zero-sized); the braced form `ModelStage {}` still constructs it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ModelStage;
/// Type-state marker: the model is set; options may be configured and `.build()` called.
///
/// A unit struct (zero-sized); the braced form `OptionsStage {}` still constructs it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct OptionsStage;
/// Type-state builder for [`EmbeddingModelRequest`].
///
/// `State` is a zero-sized marker (`ModelStage` or `OptionsStage`) tracked only at the
/// type level. `.build()` therefore exists only after a model has been supplied.
/// The options being assembled can be read/written in any stage through `Deref`/`DerefMut`.
///
/// Every step consumes the builder and returns a new one, so discarding a result loses the
/// configuration; `#[must_use]` makes the compiler flag that.
#[must_use = "a request builder does nothing until `.build()` is called"]
pub struct EmbeddingModelRequestBuilder<M: EmbeddingModel, State = ModelStage> {
    /// `None` in `ModelStage`; set to `Some` by `.model(...)` when moving to `OptionsStage`.
    model: Option<M>,
    /// Options accumulated so far.
    options: EmbeddingModelOptions,
    /// Carries the `State` marker type without storing a value.
    state: std::marker::PhantomData<State>,
}
impl<M, State> Deref for EmbeddingModelRequestBuilder<M, State>
where
    M: EmbeddingModel,
{
    type Target = EmbeddingModelOptions;

    /// Read access to the options assembled so far, available in every builder stage.
    fn deref(&self) -> &EmbeddingModelOptions {
        let Self { options, .. } = self;
        options
    }
}
impl<M: EmbeddingModel, State> DerefMut for EmbeddingModelRequestBuilder<M, State> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.options
}
}
impl<M: EmbeddingModel> EmbeddingModelRequestBuilder<M> {
    /// Creates a builder in the `ModelStage` state: no model, an empty input list,
    /// and `dimensions` left as `None`.
    ///
    /// # Panics
    /// Panics if `EmbeddingModelOptions::builder()` rejects these constant values. The
    /// inputs never vary, so a failure here indicates a bug, not a recoverable error.
    fn default() -> Self {
        let options = EmbeddingModelOptions::builder()
            .input(vec![])
            .dimensions(None)
            .build()
            .expect("empty input with no dimensions must always form valid EmbeddingModelOptions");
        EmbeddingModelRequestBuilder {
            model: None,
            options,
            state: std::marker::PhantomData,
        }
    }
}
impl<M: EmbeddingModel> EmbeddingModelRequestBuilder<M, ModelStage> {
    /// Supplies the model and advances the builder to `OptionsStage`.
    ///
    /// Any options already set (e.g. through `DerefMut`) are carried over unchanged.
    pub fn model(self, model: M) -> EmbeddingModelRequestBuilder<M, OptionsStage> {
        let Self { options, .. } = self;
        EmbeddingModelRequestBuilder {
            model: Some(model),
            options,
            state: std::marker::PhantomData,
        }
    }
}
impl<M: EmbeddingModel> EmbeddingModelRequestBuilder<M, OptionsStage> {
    /// Replaces the `input` option with the given texts.
    pub fn input(mut self, input: impl Into<Vec<String>>) -> Self {
        let texts: Vec<String> = input.into();
        self.options.input = texts;
        self
    }

    /// Sets the `dimensions` option to `Some(dimensions)`.
    pub fn dimensions(mut self, dimensions: usize) -> Self {
        let requested = Some(dimensions);
        self.options.dimensions = requested;
        self
    }

    /// Finishes the builder, producing an [`EmbeddingModelRequest`].
    ///
    /// # Panics
    /// Panics with "Model must be set" if no model is stored. The constructors shown here
    /// only create `OptionsStage` builders via `model(...)`, which always stores `Some(model)`.
    pub fn build(self) -> EmbeddingModelRequest<M> {
        let Self { model, options, .. } = self;
        match model {
            Some(model) => EmbeddingModelRequest { model, options },
            None => unreachable!("Model must be set"),
        }
    }
}