use crate::common::auth::{AuthProvider, OpenAIAuth};
use crate::common::client::create_http_client;
use crate::common::errors::{ErrorResponse, OpenAIToolError, Result};
use crate::common::models::EmbeddingModel;
use crate::embedding::response::Response;
use core::str;
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Debug, Clone, Deserialize, Default)]
struct Input {
#[serde(skip_serializing_if = "String::is_empty")]
input_text: String,
#[serde(skip_serializing_if = "Vec::is_empty")]
input_text_array: Vec<String>,
}
impl Input {
    /// Builds an input that carries exactly one text; the batch list stays empty.
    pub fn from_text(input_text: &str) -> Self {
        Self { input_text: String::from(input_text), ..Self::default() }
    }

    /// Builds an input that carries a batch of texts; the single-text field stays empty.
    pub fn from_text_array(input_text_array: Vec<String>) -> Self {
        Self { input_text_array, ..Self::default() }
    }
}
impl Serialize for Input {
    /// Writes the input as a bare JSON value rather than an object.
    ///
    /// - only `input_text` set       -> the string itself
    /// - only `input_text_array` set -> the array of strings
    /// - neither or both set         -> an empty string `""`
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let no_text = self.input_text.is_empty();
        let no_array = self.input_text_array.is_empty();
        match (no_text, no_array) {
            (false, true) => self.input_text.serialize(serializer),
            (true, false) => self.input_text_array.serialize(serializer),
            // Ambiguous or absent input: fall back to an empty string.
            _ => "".serialize(serializer),
        }
    }
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
struct Body {
model: EmbeddingModel,
input: Input,
encoding_format: Option<String>,
}
/// Path segment passed to `AuthProvider::endpoint` to build the embeddings request URL.
const EMBEDDINGS_PATH: &str = "embeddings";
/// Builder-style client for embedding requests.
///
/// Configure the request with the setter methods, then call `embed` to send it.
pub struct Embedding {
/// Auth provider used to set request headers and to build the endpoint URL.
auth: AuthProvider,
/// Request payload, serialized to JSON when `embed` is called.
body: Body,
/// Optional HTTP timeout handed to `create_http_client`; `None` unless `timeout` is called.
timeout: Option<Duration>,
}
impl Embedding {
    /// Creates an embedding client authenticated against OpenAI using environment credentials.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`AuthProvider::openai_from_env`].
    pub fn new() -> Result<Self> {
        let auth = AuthProvider::openai_from_env()?;
        Ok(Self::with_auth(auth))
    }

    /// Creates an embedding client from an explicit [`AuthProvider`].
    ///
    /// The request body starts from `Body::default()` and no timeout is set. Every other
    /// constructor funnels through here so the initial state is defined in one place.
    pub fn with_auth(auth: AuthProvider) -> Self {
        Self { auth, body: Body::default(), timeout: None }
    }

    /// Creates an embedding client for Azure using environment credentials.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`AuthProvider::azure_from_env`].
    pub fn azure() -> Result<Self> {
        let auth = AuthProvider::azure_from_env()?;
        Ok(Self::with_auth(auth))
    }

    /// Creates an embedding client whose provider is chosen by [`AuthProvider::from_env`].
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`AuthProvider::from_env`].
    pub fn detect_provider() -> Result<Self> {
        let auth = AuthProvider::from_env()?;
        Ok(Self::with_auth(auth))
    }

    /// Creates an embedding client for an explicit base URL and API key.
    pub fn with_url<S: Into<String>>(base_url: S, api_key: S) -> Self {
        Self::with_auth(AuthProvider::from_url_with_key(base_url, api_key))
    }

    /// Creates an embedding client from a URL via [`AuthProvider::from_url`].
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`AuthProvider::from_url`].
    pub fn from_url<S: Into<String>>(url: S) -> Result<Self> {
        let auth = AuthProvider::from_url(url)?;
        Ok(Self::with_auth(auth))
    }

    /// Returns the auth provider used for requests.
    pub fn auth(&self) -> &AuthProvider {
        &self.auth
    }

    /// Overrides the base URL. Only supported for the OpenAI provider.
    ///
    /// For any other provider this logs a warning and leaves the auth unchanged.
    ///
    /// NOTE(review): a fresh `OpenAIAuth` is rebuilt from the existing API key only; if
    /// `OpenAIAuth` carries other settings they are not copied over — confirm.
    pub fn base_url<T: AsRef<str>>(&mut self, url: T) -> &mut Self {
        if let AuthProvider::OpenAI(ref openai_auth) = self.auth {
            let new_auth = OpenAIAuth::new(openai_auth.api_key()).with_base_url(url.as_ref());
            self.auth = AuthProvider::OpenAI(new_auth);
        } else {
            tracing::warn!("base_url() is only supported for OpenAI provider. Use azure() or with_auth() for Azure.");
        }
        self
    }

    /// Sets the embedding model.
    pub fn model(&mut self, model: EmbeddingModel) -> &mut Self {
        self.body.model = model;
        self
    }

    /// Sets the embedding model from a string identifier.
    #[deprecated(since = "0.2.0", note = "Use `model(EmbeddingModel)` instead for type safety")]
    pub fn model_id<T: AsRef<str>>(&mut self, model_id: T) -> &mut Self {
        self.body.model = EmbeddingModel::from(model_id.as_ref());
        self
    }

    /// Sets the timeout passed to the HTTP client built in [`Embedding::embed`].
    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets a single input text, replacing any previously configured input.
    pub fn input_text<T: AsRef<str>>(&mut self, input_text: T) -> &mut Self {
        self.body.input = Input::from_text(input_text.as_ref());
        self
    }

    /// Sets a batch of input texts, replacing any previously configured input.
    pub fn input_text_array<T: AsRef<str>>(&mut self, input_text_array: Vec<T>) -> &mut Self {
        let input_strings = input_text_array.into_iter().map(|s| s.as_ref().to_owned()).collect();
        self.body.input = Input::from_text_array(input_strings);
        self
    }

    /// Sets the response `encoding_format`.
    ///
    /// # Panics
    ///
    /// Panics if `encoding_format` is anything other than `"float"` or `"base64"`.
    pub fn encoding_format<T: AsRef<str>>(&mut self, encoding_format: T) -> &mut Self {
        let encoding_format = encoding_format.as_ref();
        assert!(
            matches!(encoding_format, "float" | "base64"),
            "encoding_format must be either 'float' or 'base64'"
        );
        self.body.encoding_format = Some(encoding_format.to_string());
        self
    }

    /// Sends the configured embedding request and parses the response.
    ///
    /// # Errors
    ///
    /// - `OpenAIToolError::Error` if no input text has been set.
    /// - Errors from JSON serialization, HTTP-client construction, or header application,
    ///   propagated with `?`.
    /// - `OpenAIToolError::RequestError` if sending the request or reading the body fails.
    /// - `OpenAIToolError::Error` for a non-success status. The message is the provider's
    ///   error message when one is present and non-empty, otherwise
    ///   `"API error (<status>): <body>"`.
    /// - `OpenAIToolError::SerdeJsonError` if a successful body is not a valid `Response`.
    pub async fn embed(&self) -> Result<Response> {
        if self.body.input.input_text.is_empty() && self.body.input.input_text_array.is_empty() {
            return Err(OpenAIToolError::Error("Input text is not set.".into()));
        }
        let body = serde_json::to_string(&self.body)?;
        let client = create_http_client(self.timeout)?;

        let mut headers = request::header::HeaderMap::new();
        headers.insert("Content-Type", request::header::HeaderValue::from_static("application/json"));
        headers.insert("User-Agent", request::header::HeaderValue::from_static("openai-tools-rust"));
        self.auth.apply_headers(&mut headers)?;

        if cfg!(test) {
            let mut body_for_debug = serde_json::to_string_pretty(&self.body)?;
            let api_key = self.auth.api_key();
            // `str::replace` with an empty pattern inserts the replacement between every
            // character, so only mask when there is actually a key to hide.
            if !api_key.is_empty() {
                body_for_debug = body_for_debug.replace(api_key, "*************");
            }
            tracing::info!("Request body: {}", body_for_debug);
        }

        let endpoint = self.auth.endpoint(EMBEDDINGS_PATH);
        let response = client.post(&endpoint).headers(headers).body(body).send().await.map_err(OpenAIToolError::RequestError)?;
        let status = response.status();
        let content = response.text().await.map_err(OpenAIToolError::RequestError)?;
        if cfg!(test) {
            tracing::info!("Response content: {}", content);
        }

        if !status.is_success() {
            // Prefer the provider's structured message. If it is missing or empty, fall
            // through to the generic form so the status and body are not lost.
            if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&content) {
                let message = error_resp.error.message.unwrap_or_default();
                if !message.is_empty() {
                    return Err(OpenAIToolError::Error(message));
                }
            }
            return Err(OpenAIToolError::Error(format!("API error ({}): {}", status, content)));
        }
        serde_json::from_str::<Response>(&content).map_err(OpenAIToolError::SerdeJsonError)
    }
}