use tonic::IntoRequest;
use crate::{
content::{IntoContent, TryIntoContent},
error::status_into_error,
full_model_name,
proto::{BatchEmbedContentsResponse, Content, EmbedContentResponse, Model as Info, TaskType},
};
use super::{
client::Client,
error::{Error, ServiceError},
proto::{BatchEmbedContentsRequest, EmbedContentRequest},
};
/// Handle to a single embedding model, borrowed from a [`Client`].
///
/// Instances are created with [`Model::new`] or [`Client::embedding_model`].
#[derive(Debug)]
pub struct Model<'c> {
    /// Client through which every request of this model is sent.
    client: &'c Client,
    /// Model name as returned by `full_model_name` at construction time.
    name: String,
    /// Task type attached to requests that carry no title.
    ///
    /// Requests with a non-empty title always use
    /// `TaskType::RetrievalDocument` instead of this value.
    pub task_type: Option<TaskType>,
}
impl<'c> Model<'c> {
    /// Creates a handle for the embedding model `name` that sends its
    /// requests through `client`.
    ///
    /// The name is normalized with `full_model_name` before it is stored, and
    /// the task type starts out unset.
    pub fn new(client: &'c Client, name: &str) -> Self {
        Self {
            name: full_model_name(name),
            task_type: None,
            client,
        }
    }

    /// Embeds a single piece of content without a title.
    ///
    /// Equivalent to calling [`Model::embed_content_with_title`] with an empty
    /// title.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Model::embed_content_with_title`] reports.
    pub async fn embed_content<T: TryIntoContent>(
        &self,
        content: T,
    ) -> Result<EmbedContentResponse, Error> {
        self.embed_content_with_title("", content).await
    }

    /// Embeds a single piece of content, optionally labelled with `title`.
    ///
    /// An empty `title` means "no title". A non-empty title also switches the
    /// request's task type to `TaskType::RetrievalDocument`.
    ///
    /// # Errors
    ///
    /// Fails when the content cannot be converted, when the client cannot
    /// authorize the request, or when the remote call returns a status (which
    /// is mapped through `status_into_error`).
    pub async fn embed_content_with_title<T>(
        &self,
        title: &str,
        content: T,
    ) -> Result<EmbedContentResponse, Error>
    where
        T: TryIntoContent,
    {
        let content = content.try_into_content()?;
        let request = self.build_request(title, content).await?;

        // The call is issued on a fresh clone of the client's `gc` handle.
        let outcome = self.client.gc.clone().embed_content(request).await;
        outcome
            .map(|reply| reply.into_inner())
            .map_err(status_into_error)
    }

    /// Starts an empty batch of embedding requests addressed to this model.
    pub fn new_batch(&self) -> Batch<'_> {
        let req = BatchEmbedContentsRequest {
            requests: Vec::new(),
            model: self.name.clone(),
        };
        Batch { m: self, req }
    }

    /// Embeds every item of `contents` with a single batch request.
    ///
    /// Items are converted and queued in iteration order; none of them gets a
    /// title.
    ///
    /// # Errors
    ///
    /// Stops at the first item that cannot be converted and returns that
    /// error without sending anything. Otherwise returns whatever
    /// [`Batch::embed`] reports.
    pub async fn embed_batch<I, T>(&self, contents: I) -> Result<BatchEmbedContentsResponse, Error>
    where
        I: IntoIterator<Item = T>,
        T: TryIntoContent,
    {
        let batch = contents
            .into_iter()
            .try_fold(self.new_batch(), |batch, item| {
                item.try_into_content()
                    .map(|content| batch.add_content(content))
            })?;
        batch.embed().await
    }

    /// Looks up this model through the client's `get_model`, using the stored
    /// model name.
    ///
    /// # Errors
    ///
    /// Returns whatever error `Client::get_model` reports.
    pub async fn info(&self) -> Result<Info, Error> {
        self.client.get_model(&self.name).await
    }

    /// Builds the request for one piece of content and passes it through
    /// `Client::add_auth` before handing it back.
    async fn build_request(
        &self,
        title: &str,
        content: Content,
    ) -> Result<tonic::Request<EmbedContentRequest>, Error> {
        let payload = self._build_request(title, content);
        let mut request = payload.into_request();
        self.client.add_auth(&mut request).await?;
        Ok(request)
    }

    /// Assembles the plain request message for one piece of content.
    ///
    /// An empty `title` is sent as `None`. When a title is present the task
    /// type is always `TaskType::RetrievalDocument`; otherwise the model's own
    /// `task_type` (if any) is used.
    fn _build_request(&self, title: &str, content: Content) -> EmbedContentRequest {
        let title = match title {
            "" => None,
            text => Some(text.to_owned()),
        };

        // A title takes precedence over the task type configured on the model.
        let task_type = if title.is_some() {
            Some(TaskType::RetrievalDocument.into())
        } else {
            self.task_type.map(Into::into)
        };

        EmbedContentRequest {
            model: self.name.clone(),
            content: Some(content),
            task_type,
            title,
            output_dimensionality: None,
        }
    }
}
/// A batch of embedding requests that is accumulated and then sent at once.
///
/// Created by [`Model::new_batch`]; filled with [`Batch::add_content`] or
/// [`Batch::add_content_with_title`]; sent with [`Batch::embed`].
#[derive(Debug)]
pub struct Batch<'m> {
    /// Model that builds the individual requests and supplies the client.
    m: &'m Model<'m>,
    /// Batch request message collecting the queued requests.
    req: BatchEmbedContentsRequest,
}
impl Batch<'_> {
pub fn add_content<T: IntoContent>(self, content: T) -> Self {
self.add_content_with_title("", content)
}
pub fn add_content_with_title<T: IntoContent>(mut self, title: &str, content: T) -> Self {
self.req
.requests
.push(self.m._build_request(title, content.into_content()));
self
}
pub async fn embed(self) -> Result<BatchEmbedContentsResponse, Error> {
let expected = self.req.requests.len();
let mut request = self.req.into_request();
self.m.client.add_auth(&mut request).await?;
let response = self
.m
.client
.gc
.clone()
.batch_embed_contents(request)
.await
.map_err(status_into_error)
.map(|response| response.into_inner())?;
if response.embeddings.len() != expected {
return Err(Error::Service(ServiceError::InvalidResponse(
format!(
"Expected {} embeddings, got {}",
expected,
response.embeddings.len()
)
.into(),
)));
}
Ok(response)
}
}
impl Client {
    /// Returns a [`Model`] handle for the embedding model `name` that borrows
    /// this client.
    ///
    /// Shorthand for [`Model::new`].
    pub fn embedding_model<'c>(&'c self, name: &str) -> Model<'c> {
        Model::new(self, name)
    }
}