use std::{borrow::Cow, fmt::Formatter, str::FromStr, sync::Arc};
use arrow::array::{AsArray, Float32Builder};
use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array};
use arrow_data::ArrayData;
use arrow_schema::DataType;
use async_openai::{
Client,
config::OpenAIConfig,
types::{CreateEmbeddingRequest, Embedding, EmbeddingInput, EncodingFormat},
};
use tokio::{runtime::Handle, task};
use crate::{Error, Result};
use super::EmbeddingFunction;
/// Embedding models that can be requested from the OpenAI embeddings endpoint.
///
/// The enum is a field-less selector, so it is cheap to copy and can be
/// compared for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingModel {
    /// The `text-embedding-ada-002` model.
    TextEmbeddingAda002,
    /// The `text-embedding-3-small` model.
    TextEmbedding3Small,
    /// The `text-embedding-3-large` model.
    TextEmbedding3Large,
}
impl EmbeddingModel {
    /// Returns the number of components in one embedding vector of this model.
    fn ndims(&self) -> usize {
        match self {
            Self::TextEmbedding3Large => 3072,
            // Both remaining models share the same vector width.
            Self::TextEmbeddingAda002 | Self::TextEmbedding3Small => 1536,
        }
    }
}
impl FromStr for EmbeddingModel {
type Err = Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"text-embedding-ada-002" => Ok(Self::TextEmbeddingAda002),
"text-embedding-3-small" => Ok(Self::TextEmbedding3Small),
"text-embedding-3-large" => Ok(Self::TextEmbedding3Large),
_ => Err(Error::InvalidInput {
message: "Invalid input. Available models are: 'text-embedding-3-small', 'text-embedding-ada-002', 'text-embedding-3-large' ".to_string()
}),
}
}
}
impl std::fmt::Display for EmbeddingModel {
    /// Writes the model identifier exactly as the OpenAI API expects it
    /// (for example `text-embedding-3-small`).
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let identifier = match self {
            Self::TextEmbeddingAda002 => "text-embedding-ada-002",
            Self::TextEmbedding3Small => "text-embedding-3-small",
            Self::TextEmbedding3Large => "text-embedding-3-large",
        };
        f.write_str(identifier)
    }
}
impl TryFrom<&str> for EmbeddingModel {
type Error = Error;
fn try_from(value: &str) -> std::result::Result<Self, Self::Error> {
value.parse()
}
}
/// Embedding function that obtains vectors from the OpenAI embeddings API.
///
/// Construct it with `new` or `new_with_model`, then optionally chain
/// `api_base` / `org_id` to override the client configuration.
pub struct OpenAIEmbeddingFunction {
/// Model whose identifier is sent with each request and whose dimension
/// count fixes the width of the produced vectors.
model: EmbeddingModel,
/// API key handed to the OpenAI client configuration. The `Debug` impl
/// masks it instead of printing it in full.
api_key: String,
/// Optional API base handed to the client configuration; when `None` the
/// configuration's own value is left untouched.
api_base: Option<String>,
/// Optional organization id handed to the client configuration; when `None`
/// the configuration's own value is left untouched.
org_id: Option<String>,
}
impl std::fmt::Debug for OpenAIEmbeddingFunction {
    /// Formats the function for diagnostics without revealing the API key.
    ///
    /// Only the first two and the last four bytes of the key are shown, joined
    /// by `***`. The key is rendered as `[INVALID]` when it is six bytes or
    /// shorter, or when either cut would land inside a multi-byte UTF-8
    /// character (plain index slicing would panic in that case).
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        let key = self.api_key.as_str();
        // `str::get` yields `None` instead of panicking when the range is out
        // of bounds or not on a character boundary.
        let prefix = key.get(..2);
        let suffix = key
            .len()
            .checked_sub(4)
            .and_then(|start| key.get(start..));
        let creds_display = match (prefix, suffix) {
            (Some(head), Some(tail)) if key.len() > 6 => format!("{head}***{tail}"),
            _ => "[INVALID]".to_string(),
        };
        f.debug_struct("OpenAI")
            .field("model", &self.model)
            .field("api_key", &creds_display)
            .field("api_base", &self.api_base)
            .field("org_id", &self.org_id)
            .finish()
    }
}
impl OpenAIEmbeddingFunction {
    /// Creates an embedding function that uses the `TextEmbeddingAda002` model.
    pub fn new<A: Into<String>>(api_key: A) -> Self {
        let default_model = EmbeddingModel::TextEmbeddingAda002;
        Self::new_impl(api_key.into(), default_model)
    }

    /// Creates an embedding function for an explicitly chosen model.
    ///
    /// `model` may be anything convertible into an [`EmbeddingModel`].
    ///
    /// # Errors
    /// Returns the conversion error (converted into `crate::Error`) when
    /// `model` cannot be turned into an [`EmbeddingModel`].
    pub fn new_with_model<A: Into<String>, M: TryInto<EmbeddingModel>>(
        api_key: A,
        model: M,
    ) -> crate::Result<Self>
    where
        M::Error: Into<crate::Error>,
    {
        let key = api_key.into();
        match model.try_into() {
            Ok(resolved) => Ok(Self::new_impl(key, resolved)),
            Err(cause) => Err(cause.into()),
        }
    }

    /// Shared constructor: stores the key and model, leaving overrides unset.
    fn new_impl(api_key: String, model: EmbeddingModel) -> Self {
        Self {
            api_key,
            model,
            org_id: None,
            api_base: None,
        }
    }

    /// Builder-style setter for the API base; consumes and returns `self`.
    pub fn api_base<S: Into<String>>(self, api_base: S) -> Self {
        Self {
            api_base: Some(api_base.into()),
            ..self
        }
    }

    /// Builder-style setter for the organization id; consumes and returns `self`.
    pub fn org_id<S: Into<String>>(self, org_id: S) -> Self {
        Self {
            org_id: Some(org_id.into()),
            ..self
        }
    }
}
impl EmbeddingFunction for OpenAIEmbeddingFunction {
    /// Name under which this embedding function identifies itself.
    fn name(&self) -> &str {
        "openai"
    }

    /// Type reported for the source column: `Utf8`.
    fn source_type(&self) -> Result<Cow<'_, DataType>> {
        Ok(Cow::Owned(DataType::Utf8))
    }

    /// Type of the produced column: a fixed-size list of non-nullable
    /// `Float32` values whose length is the model's dimension count.
    fn dest_type(&self) -> Result<Cow<'_, DataType>> {
        let width = self.model.ndims() as i32;
        let list_type = DataType::new_fixed_size_list(DataType::Float32, width, false);
        Ok(Cow::Owned(list_type))
    }

    /// Embeds every entry of `source` and returns one fixed-size list per row.
    ///
    /// # Errors
    /// Forwards errors from the embedding request and from building the
    /// fixed-size list array over the returned values.
    fn compute_source_embeddings(&self, source: ArrayRef) -> crate::Result<ArrayRef> {
        let row_count = source.len();
        let list_type = self.dest_type()?.into_owned();
        let values = self.compute_inner(source)?;
        // Wrap the flat value buffer as the single child of the list array.
        let data = ArrayData::builder(list_type)
            .len(row_count)
            .add_child_data(values.into_data())
            .build()?;
        let embeddings: ArrayRef = Arc::new(FixedSizeListArray::from(data));
        Ok(embeddings)
    }

    /// Embeds the query input and returns the values as a flat `Float32Array`
    /// (not wrapped in a fixed-size list).
    fn compute_query_embeddings(&self, input: Arc<dyn Array>) -> Result<Arc<dyn Array>> {
        self.compute_inner(input)
            .map(|values| Arc::new(values) as Arc<dyn Array>)
    }
}
impl OpenAIEmbeddingFunction {
fn compute_inner(&self, source: Arc<dyn Array>) -> Result<Float32Array> {
if source.is_nullable() {
return Err(crate::Error::InvalidInput {
message: "Expected non-nullable data type".to_string(),
});
}
if !matches!(source.data_type(), DataType::Utf8 | DataType::LargeUtf8) {
return Err(crate::Error::InvalidInput {
message: "Expected Utf8 data type".to_string(),
});
};
let mut creds = OpenAIConfig::new().with_api_key(self.api_key.clone());
if let Some(api_base) = &self.api_base {
creds = creds.with_api_base(api_base.clone());
}
if let Some(org_id) = &self.org_id {
creds = creds.with_org_id(org_id.clone());
}
let input = match source.data_type() {
DataType::Utf8 => {
let array = source
.as_string::<i32>()
.into_iter()
.map(|s| {
s.expect("we already asserted that the array is non-nullable")
.to_string()
})
.collect::<Vec<String>>();
EmbeddingInput::StringArray(array)
}
DataType::LargeUtf8 => {
let array = source
.as_string::<i64>()
.into_iter()
.map(|s| {
s.expect("we already asserted that the array is non-nullable")
.to_string()
})
.collect::<Vec<String>>();
EmbeddingInput::StringArray(array)
}
_ => unreachable!("This should not happen. We already checked the data type."),
};
let client = Client::with_config(creds);
let embed = client.embeddings();
let req = CreateEmbeddingRequest {
model: self.model.to_string(),
input,
encoding_format: Some(EncodingFormat::Float),
user: None,
dimensions: None,
};
task::block_in_place(move || {
Handle::current().block_on(async {
let mut builder = Float32Builder::new();
let res = embed.create(req).await.map_err(|e| crate::Error::Runtime {
message: format!("OpenAI embed request failed: {e}"),
})?;
for Embedding { embedding, .. } in res.data.iter() {
builder.append_slice(embedding);
}
Ok(builder.finish())
})
})
}
}