use crate::model_type::{EmbeddingModelType, OllamaTextEmbeddingsInference};
use crate::zoo_embedding_errors::ZooEmbeddingError;
use async_trait::async_trait;
use lazy_static::lazy_static;
use reqwest::blocking::Client;
use reqwest::Client as AsyncClient;
use reqwest::ClientBuilder;
use serde::{Deserialize, Serialize};
use std::time::Duration;
lazy_static! {
/// Default hosted embeddings endpoint; used as the `api_url` by
/// `RemoteEmbeddingGenerator::new_default`.
pub static ref DEFAULT_EMBEDDINGS_SERVER_URL: &'static str = "https://api.zoo.ngo/embeddings";
/// Default local endpoint; used as the `api_url` by
/// `RemoteEmbeddingGenerator::new_default_local`.
/// NOTE(review): port 11434 is presumably a locally running Ollama server — confirm.
pub static ref DEFAULT_EMBEDDINGS_LOCAL_URL: &'static str = "http://localhost:11434/";
}
#[async_trait]
/// Common interface for anything that can turn text into embedding vectors.
///
/// Implementors supply the model-type accessors, `box_clone`, and the four
/// `generate_*` primitives (single/batch, blocking/async). The `*_default`
/// variants are provided and simply forward to the matching primitive.
pub trait EmbeddingGenerator: Sync + Send {
    // ---- Required methods -------------------------------------------------

    /// Returns the embedding model type this generator is configured with.
    fn model_type(&self) -> EmbeddingModelType;

    /// Replaces the embedding model type used by this generator.
    fn set_model_type(&mut self, model_type: EmbeddingModelType);

    /// Returns a boxed copy of this generator as a trait object.
    fn box_clone(&self) -> Box<dyn EmbeddingGenerator>;

    /// Generates the embedding for a single string, blocking the calling thread.
    fn generate_embedding_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError>;

    /// Generates one embedding per input string, blocking the calling thread.
    fn generate_embeddings_blocking(
        &self,
        input_strings: &Vec<String>,
    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError>;

    /// Generates the embedding for a single string asynchronously.
    async fn generate_embedding(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError>;

    /// Generates one embedding per input string asynchronously.
    async fn generate_embeddings(
        &self,
        input_strings: &Vec<String>,
    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError>;

    // ---- Provided methods -------------------------------------------------

    /// Forwards to [`EmbeddingGenerator::generate_embedding_blocking`].
    fn generate_embedding_default_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
        self.generate_embedding_blocking(input_string)
    }

    /// Forwards to [`EmbeddingGenerator::generate_embeddings_blocking`].
    fn generate_embeddings_blocking_default(
        &self,
        input_strings: &Vec<String>,
    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
        self.generate_embeddings_blocking(input_strings)
    }

    /// Forwards to [`EmbeddingGenerator::generate_embedding`].
    async fn generate_embedding_default(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
        self.generate_embedding(input_string).await
    }

    /// Forwards to [`EmbeddingGenerator::generate_embeddings`].
    async fn generate_embeddings_default(
        &self,
        input_strings: &Vec<String>,
    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
        self.generate_embeddings(input_strings).await
    }
}
/// Embedding generator that obtains embeddings from a remote HTTP API.
///
/// Note: `Debug` and the serde traits are derived over every field, so
/// `api_key` is included in debug output and in serialized forms.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct RemoteEmbeddingGenerator {
    /// Embedding model used for requests.
    pub model_type: EmbeddingModelType,
    /// URL of the embeddings server.
    pub api_url: String,
    /// Optional API key; when present it is sent as a `Bearer` token in the
    /// `Authorization` header.
    pub api_key: Option<String>,
}
#[async_trait]
impl EmbeddingGenerator for RemoteEmbeddingGenerator {
    /// Returns a boxed clone of this generator.
    fn box_clone(&self) -> Box<dyn EmbeddingGenerator> {
        Box::new(self.clone())
    }

    /// Generates one embedding per input string, blocking the calling thread.
    ///
    /// Each input is truncated to at most `max_input_token_count()` *characters*
    /// of the configured model before being sent. Requests are issued one at a
    /// time, in input order.
    ///
    /// # Errors
    /// Returns the first error produced by any individual request; remaining
    /// inputs are not processed.
    fn generate_embeddings_blocking(
        &self,
        input_strings: &Vec<String>,
    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
        let max_chars = self.model_type.max_input_token_count();
        match self.model_type {
            EmbeddingModelType::OllamaTextEmbeddingsInference(_) => input_strings
                .iter()
                .map(|input| {
                    let truncated = truncate_to_char_limit(input, max_chars);
                    self.generate_embedding_ollama_blocking(&truncated)
                })
                // Collecting into `Result` stops at the first failed request.
                .collect(),
        }
    }

    /// Generates the embedding for a single string, blocking the calling thread.
    ///
    /// # Errors
    /// Propagates request errors, and returns
    /// `ZooEmbeddingError::FailedEmbeddingGeneration` if no embedding came back.
    fn generate_embedding_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
        // Truncation is applied by `generate_embeddings_blocking`; no need to do it twice.
        let single_input = vec![input_string.to_string()];
        let embeddings = self.generate_embeddings_blocking(&single_input)?;
        first_embedding(embeddings)
    }

    /// Generates one embedding per input string asynchronously.
    ///
    /// Each input is truncated to at most `max_input_token_count()` *characters*
    /// of the configured model before being sent. Requests are awaited one at a
    /// time, in input order.
    ///
    /// # Errors
    /// Returns the first error produced by any individual request; remaining
    /// inputs are not processed.
    async fn generate_embeddings(&self, input_strings: &Vec<String>) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
        let max_chars = self.model_type.max_input_token_count();
        match &self.model_type {
            EmbeddingModelType::OllamaTextEmbeddingsInference(model) => {
                let model_name = model.to_string();
                let mut embeddings = Vec::with_capacity(input_strings.len());
                for input in input_strings {
                    let truncated = truncate_to_char_limit(input, max_chars);
                    let embedding = self
                        .generate_embedding_ollama(truncated, model_name.clone())
                        .await?;
                    embeddings.push(embedding);
                }
                Ok(embeddings)
            }
        }
    }

    /// Generates the embedding for a single string asynchronously.
    ///
    /// # Errors
    /// Propagates request errors, and returns
    /// `ZooEmbeddingError::FailedEmbeddingGeneration` if no embedding came back.
    async fn generate_embedding(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
        // Truncation is applied by `generate_embeddings`; no need to do it twice.
        let single_input = vec![input_string.to_string()];
        let embeddings = self.generate_embeddings(&single_input).await?;
        first_embedding(embeddings)
    }

    /// Returns a copy of the configured embedding model type.
    fn model_type(&self) -> EmbeddingModelType {
        self.model_type.clone()
    }

    /// Replaces the configured embedding model type.
    fn set_model_type(&mut self, model_type: EmbeddingModelType) {
        self.model_type = model_type
    }
}

/// Returns at most the first `max_chars` characters of `input` as a new `String`.
fn truncate_to_char_limit(input: &str, max_chars: usize) -> String {
    input.chars().take(max_chars).collect()
}

/// Moves the first embedding out of `embeddings`, or reports that none was returned.
fn first_embedding(embeddings: Vec<Vec<f32>>) -> Result<Vec<f32>, ZooEmbeddingError> {
    embeddings.into_iter().next().ok_or_else(|| {
        ZooEmbeddingError::FailedEmbeddingGeneration(
            "No results returned from the embedding generation".to_string(),
        )
    })
}
impl RemoteEmbeddingGenerator {
pub fn new(model_type: EmbeddingModelType, api_url: &str, api_key: Option<String>) -> RemoteEmbeddingGenerator {
RemoteEmbeddingGenerator {
model_type,
api_url: api_url.to_string(),
api_key,
}
}
pub fn new_default() -> RemoteEmbeddingGenerator {
let model_architecture =
EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM);
RemoteEmbeddingGenerator {
model_type: model_architecture,
api_url: DEFAULT_EMBEDDINGS_SERVER_URL.to_string(),
api_key: None,
}
}
pub fn new_default_local() -> RemoteEmbeddingGenerator {
let model_architecture =
EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM);
RemoteEmbeddingGenerator {
model_type: model_architecture,
api_url: DEFAULT_EMBEDDINGS_LOCAL_URL.to_string(),
api_key: None,
}
}
fn tei_endpoint_url(&self) -> String {
if self.api_url.ends_with('/') {
format!("{}embed", self.api_url)
} else {
format!("{}/embed", self.api_url)
}
}
fn ollama_endpoint_url(&self) -> String {
if self.api_url.ends_with('/') {
format!("{}api/embeddings", self.api_url)
} else {
format!("{}/api/embeddings", self.api_url)
}
}
pub async fn generate_embedding_ollama(
&self,
input_string: String,
model: String,
) -> Result<Vec<f32>, ZooEmbeddingError> {
let max_retries = 3;
let mut retry_count = 0;
let mut shortening_retry = 0;
let mut input_string = input_string.clone();
loop {
let request_body = OllamaEmbeddingsRequestBody {
model: model.clone(),
prompt: input_string.clone(),
};
let timeout = Duration::from_secs(60);
let client = ClientBuilder::new().timeout(timeout).build()?;
let mut request = client
.post(self.ollama_endpoint_url().to_string())
.header("Content-Type", "application/json")
.json(&request_body);
if let Some(api_key) = &self.api_key {
request = request.header("Authorization", format!("Bearer {}", api_key));
}
let response = request.send().await;
match response {
Ok(response) if response.status().is_success() => {
let embedding_response: Result<OllamaEmbeddingsResponse, _> =
response.json::<OllamaEmbeddingsResponse>().await;
match embedding_response {
Ok(embedding_response) => {
return Ok(embedding_response.embedding);
}
Err(err) => {
return Err(ZooEmbeddingError::RequestFailed(format!(
"Failed to deserialize response JSON: {}",
err
)));
}
}
}
Ok(response) if response.status() == reqwest::StatusCode::PAYLOAD_TOO_LARGE => {
let reduction_step = if shortening_retry > 1 {
100 * shortening_retry
} else {
50
};
let shortened_max_size = input_string.len().saturating_sub(reduction_step).max(5);
input_string = input_string.chars().take(shortened_max_size).collect();
retry_count = 0;
shortening_retry += 1;
if shortening_retry > 10 {
return Err(ZooEmbeddingError::RequestFailed(format!(
"HTTP request failed after multiple recursive iterations shortening input. Status: {}",
response.status()
)));
}
continue;
}
Ok(response) => {
return Err(ZooEmbeddingError::RequestFailed(format!(
"HTTP request failed with status: {}",
response.status()
)));
}
Err(err) => {
if retry_count < max_retries {
retry_count += 1;
continue;
} else {
return Err(ZooEmbeddingError::RequestFailed(format!(
"HTTP request failed after {} retries: {}",
max_retries, err
)));
}
}
}
}
}
fn generate_embedding_ollama_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
let request_body = OllamaEmbeddingsRequestBody {
model: self.model_type.to_string(),
prompt: String::from(input_string),
};
let client = Client::new();
let mut request = client
.post(&format!("{}", self.ollama_endpoint_url()))
.header("Content-Type", "application/json")
.json(&request_body);
if let Some(api_key) = &self.api_key {
request = request.header("Authorization", format!("Bearer {}", api_key));
}
let response = request.send().map_err(|err| {
ZooEmbeddingError::RequestFailed(format!("HTTP request failed: {}", err))
})?;
if response.status().is_success() {
let embedding_response: OllamaEmbeddingsResponse = response.json().map_err(|err| {
ZooEmbeddingError::RequestFailed(format!("Failed to deserialize response JSON: {}", err))
})?;
Ok(embedding_response.embedding)
} else {
Err(ZooEmbeddingError::RequestFailed(format!(
"HTTP request failed with status: {}",
response.status()
)))
}
}
pub async fn generate_embedding_tei(
&self,
input_strings: Vec<String>,
) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
let max_retries = 3;
let mut retry_count = 0;
let mut shortening_retry = 0;
let mut current_input_strings = input_strings.clone();
loop {
let request_body = EmbeddingArrayRequestBody {
inputs: current_input_strings.iter().map(|s| s.to_string()).collect(),
};
let timeout = Duration::from_secs(60);
let client = ClientBuilder::new().timeout(timeout).build()?;
let mut request = client
.post(self.tei_endpoint_url().to_string())
.header("Content-Type", "application/json")
.json(&request_body);
if let Some(api_key) = &self.api_key {
request = request.header("Authorization", format!("Bearer {}", api_key));
}
let response = request.send().await;
match response {
Ok(response) if response.status().is_success() => {
let embedding_response: Result<Vec<Vec<f32>>, _> = response.json::<Vec<Vec<f32>>>().await;
match embedding_response {
Ok(embedding_response) => {
return Ok(embedding_response);
}
Err(err) => {
return Err(ZooEmbeddingError::RequestFailed(format!(
"Failed to deserialize response JSON: {}",
err
)));
}
}
}
Ok(response) if response.status() == reqwest::StatusCode::PAYLOAD_TOO_LARGE => {
let max_size = current_input_strings.iter().map(|s| s.len()).max().unwrap_or(0);
let reduction_step = if shortening_retry > 1 {
100 * shortening_retry
} else {
50
};
let shortened_max_size = max_size.saturating_sub(reduction_step).max(5);
current_input_strings = current_input_strings
.iter()
.map(|s| {
if s.len() > shortened_max_size {
s.chars().take(shortened_max_size).collect()
} else {
s.clone()
}
})
.collect();
retry_count = 0;
shortening_retry += 1;
if shortening_retry > 10 {
return Err(ZooEmbeddingError::RequestFailed(format!(
"HTTP request failed after multiple recursive iterations shortening input. Status: {}",
response.status()
)));
}
continue;
}
Ok(response) => {
return Err(ZooEmbeddingError::RequestFailed(format!(
"HTTP request failed with status: {}",
response.status()
)));
}
Err(err) => {
if retry_count < max_retries {
retry_count += 1;
continue;
} else {
return Err(ZooEmbeddingError::RequestFailed(format!(
"HTTP request failed after {} retries: {}",
max_retries, err
)));
}
}
}
}
}
pub async fn generate_embedding_open_ai(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
let request_body = EmbeddingRequestBody {
input: String::from(input_string),
model: self.model_type().to_string(),
};
let client = AsyncClient::new();
let mut request = client
.post(self.api_url.to_string())
.header("Content-Type", "application/json")
.json(&request_body);
if let Some(api_key) = &self.api_key {
request = request.header("Authorization", format!("Bearer {}", api_key));
}
let response = request.send().await.map_err(|err| {
ZooEmbeddingError::RequestFailed(format!("HTTP request failed: {}", err))
})?;
if response.status().is_success() {
let embedding_response: EmbeddingResponse = response.json().await.map_err(|err| {
ZooEmbeddingError::RequestFailed(format!("Failed to deserialize response JSON: {}", err))
})?;
Ok(embedding_response.data[0].embedding.clone())
} else {
Err(ZooEmbeddingError::RequestFailed(format!(
"HTTP request failed with status: {}",
response.status()
)))
}
}
}
/// JSON request body sent by `RemoteEmbeddingGenerator::generate_embedding_open_ai`.
#[derive(Serialize)]
#[allow(dead_code)]
struct EmbeddingRequestBody {
    /// Text to embed.
    input: String,
    /// Model name, taken from the generator's model type.
    model: String,
}
/// One entry of the `data` array in an [`EmbeddingResponse`].
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmbeddingResponseData {
    /// The embedding vector for this entry.
    embedding: Vec<f32>,
    /// Deserialized from the response; not read in this file.
    index: usize,
    /// Deserialized from the response; not read in this file.
    object: String,
}
/// JSON response body decoded by `RemoteEmbeddingGenerator::generate_embedding_open_ai`.
///
/// Only `data` is read in this file; the other fields must still be present
/// for deserialization to succeed.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmbeddingResponse {
    /// Deserialized from the response; not read in this file.
    object: String,
    /// Deserialized from the response; not read in this file.
    model: String,
    /// Embedding entries returned by the server.
    data: Vec<EmbeddingResponseData>,
    /// Kept as untyped JSON; not read in this file.
    usage: serde_json::Value,
}
/// JSON request body sent by `RemoteEmbeddingGenerator::generate_embedding_tei`.
#[derive(Serialize)]
#[allow(dead_code)]
struct EmbeddingArrayRequestBody {
    /// Batch of texts to embed.
    inputs: Vec<String>,
}
/// JSON request body sent to the `api/embeddings` endpoint by the Ollama helpers
/// of `RemoteEmbeddingGenerator`.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct OllamaEmbeddingsRequestBody {
    /// Name of the model to run.
    model: String,
    /// Text to embed.
    prompt: String,
}
/// JSON response body decoded from the `api/embeddings` endpoint by the Ollama
/// helpers of `RemoteEmbeddingGenerator`.
#[derive(Deserialize)]
#[allow(dead_code)]
struct OllamaEmbeddingsResponse {
    /// The embedding vector for the submitted prompt.
    embedding: Vec<f32>,
}