use async_openai::types::CreateEmbeddingRequestArgs;
use async_openai::{types::CreateCompletionRequestArgs, Client};
use futures_util::StreamExt;
use kalosm_sample::Tokenizer;
use kalosm_streams::ChannelTextStream;
use std::sync::Arc;
use crate::{CreateModel, Embedder, Embedding, GenerationParameters, VectorSpace};
/// Generates a remote OpenAI text-completion model type together with its
/// builder and its `CreateModel` / `Model` trait implementations.
///
/// * `$ty` - name of the generated model struct.
/// * `$tybuilder` - name of the generated builder struct.
/// * `$model` - model identifier sent with every completions request.
macro_rules! openai_model {
    ($ty: ident, $tybuilder: ident, $model: literal) => {
        #[doc = concat!("A remote OpenAI model (`", $model, "`) queried through the completions API.")]
        #[derive(Debug)]
        pub struct $ty {
            client: Client<async_openai::config::OpenAIConfig>,
        }

        #[doc = concat!("A builder for [`", stringify!($ty), "`], which uses the `", $model, "` model.")]
        #[derive(Debug, Default)]
        pub struct $tybuilder {
            config: async_openai::config::OpenAIConfig,
        }

        impl $tybuilder {
            /// Creates a builder starting from the default OpenAI configuration.
            pub fn new() -> Self {
                Self {
                    config: Default::default(),
                }
            }

            /// Sets the API key used to authenticate requests.
            pub fn with_api_key(mut self, api_key: &str) -> Self {
                self.config = self.config.with_api_key(api_key);
                self
            }

            /// Sets the base URL that requests are sent to.
            pub fn with_base_url(mut self, base_url: &str) -> Self {
                self.config = self.config.with_api_base(base_url);
                self
            }

            /// Sets the organization id attached to requests.
            pub fn with_organization_id(mut self, organization_id: &str) -> Self {
                self.config = self.config.with_org_id(organization_id);
                self
            }

            /// Builds the model with a client using the configured settings.
            pub fn build(self) -> $ty {
                $ty {
                    client: Client::with_config(self.config),
                }
            }
        }

        impl $ty {
            /// Returns a builder for configuring this model.
            pub fn builder() -> $tybuilder {
                $tybuilder::new()
            }
        }

        impl Default for $ty {
            fn default() -> Self {
                Self::builder().build()
            }
        }

        #[async_trait::async_trait]
        impl CreateModel for $ty {
            /// Creates the model with the default configuration; nothing is downloaded.
            async fn start() -> Self {
                Self::default()
            }

            fn requires_download() -> bool {
                false
            }
        }

        #[async_trait::async_trait]
        impl crate::model::Model for $ty {
            type TextStream = ChannelTextStream<String>;
            type SyncModel = crate::SyncModelNotSupported;

            /// OpenAI does not expose its tokenizer.
            ///
            /// # Panics
            ///
            /// Always panics.
            fn tokenizer(&self) -> Arc<dyn Tokenizer + Send + Sync> {
                panic!("OpenAI does not expose tokenization")
            }

            /// Starts a streaming completion for `prompt` and returns a stream of text chunks.
            ///
            /// The stream ends when the API stream finishes or reports an error. Errors
            /// are logged rather than propagated.
            ///
            /// # Errors
            ///
            /// Returns an error if the request cannot be built or the stream cannot be opened.
            async fn stream_text_inner(
                &self,
                prompt: &str,
                generation_parameters: GenerationParameters,
            ) -> anyhow::Result<Self::TextStream> {
                // The request builder takes `max_tokens` as a u16. Saturate instead of
                // letting an `as` cast wrap large limits around to small (or zero) values.
                let max_tokens =
                    u16::try_from(generation_parameters.max_length).unwrap_or(u16::MAX);
                let stop = generation_parameters
                    .stop_on
                    .iter()
                    .cloned()
                    .collect::<Vec<String>>();
                let request = CreateCompletionRequestArgs::default()
                    .model($model)
                    .n(1)
                    .prompt(prompt)
                    .stream(true)
                    // NOTE(review): OpenAI's `frequency_penalty` is additive in [-2.0, 2.0],
                    // with 0.0 meaning "no penalty". Confirm that `repetition_penalty` uses
                    // the same scale before passing it through unchanged.
                    .frequency_penalty(generation_parameters.repetition_penalty)
                    .temperature(generation_parameters.temperature)
                    .stop(stop)
                    .max_tokens(max_tokens)
                    .build()?;

                let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
                let mut stream = self.client.completions().create_stream(request).await?;
                tokio::spawn(async move {
                    while let Some(response) = stream.next().await {
                        match response {
                            Ok(response) => {
                                // A chunk without choices carries no text; skip it instead
                                // of panicking on `choices[0]`.
                                if let Some(choice) = response.choices.first() {
                                    if tx.send(choice.text.clone()).is_err() {
                                        // The receiver was dropped; stop pulling from the API.
                                        break;
                                    }
                                }
                            }
                            Err(e) => {
                                log::error!("Error in OpenAI stream: {}", e);
                                break;
                            }
                        }
                    }
                });
                Ok(rx.into())
            }
        }
    };
}
// `gpt-3.5-turbo` is a chat-only model that the `/v1/completions` endpoint used by
// `openai_model!` rejects. The instruct variant is the GPT-3.5 model that the
// completions endpoint serves.
openai_model!(Gpt3_5, Gpt3_5Builder, "gpt-3.5-turbo-instruct");
// NOTE(review): despite its name, this type uses `text-davinci-003`, a GPT-3 model.
// GPT-4 is only available through the chat completions API, which `openai_model!`
// does not use, so the identifier cannot simply be swapped to "gpt-4".
openai_model!(Gpt4, Gpt4Builder, "text-davinci-003");
/// A remote embedding model that sends text to OpenAI's embeddings API using the
/// `text-embedding-ada-002` model.
#[derive(Debug)]
pub struct AdaEmbedder {
// HTTP client for the OpenAI API, configured via `AdaEmbedderBuilder` or defaults.
client: Client<async_openai::config::OpenAIConfig>,
}
/// A builder for [`AdaEmbedder`] that collects the OpenAI client configuration
/// (API key, base URL, organization id).
#[derive(Debug, Default)]
pub struct AdaEmbedderBuilder {
// Accumulated client configuration, consumed by `build`.
config: async_openai::config::OpenAIConfig,
}
impl AdaEmbedderBuilder {
    /// Returns a builder holding the default OpenAI configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Uses `api_key` to authenticate requests.
    pub fn with_api_key(self, api_key: &str) -> Self {
        Self {
            config: self.config.with_api_key(api_key),
        }
    }

    /// Sends requests to `base_url` instead of the default API base.
    pub fn with_base_url(self, base_url: &str) -> Self {
        Self {
            config: self.config.with_api_base(base_url),
        }
    }

    /// Attaches `organization_id` to outgoing requests.
    pub fn with_organization_id(self, organization_id: &str) -> Self {
        Self {
            config: self.config.with_org_id(organization_id),
        }
    }

    /// Consumes the builder and creates an [`AdaEmbedder`] backed by the configured client.
    pub fn build(self) -> AdaEmbedder {
        let client = Client::with_config(self.config);
        AdaEmbedder { client }
    }
}
impl AdaEmbedder {
    /// Starts configuring an [`AdaEmbedder`] from the default OpenAI settings.
    pub fn builder() -> AdaEmbedderBuilder {
        AdaEmbedderBuilder::default()
    }
}
impl Default for AdaEmbedder {
fn default() -> Self {
Self {
client: Client::new(),
}
}
}
#[async_trait::async_trait]
impl CreateModel for AdaEmbedder {
    /// Creates the embedder with the default client; the model runs remotely.
    async fn start() -> Self {
        Self::default()
    }

    /// Nothing needs to be downloaded for a remote model.
    fn requires_download() -> bool {
        false
    }
}
/// Marker type naming the vector space of embeddings produced by [`AdaEmbedder`].
pub struct AdaEmbedding;
// Tags `AdaEmbedding` as a `VectorSpace` so it can parameterize `Embedding<AdaEmbedding>`.
impl VectorSpace for AdaEmbedding {}
#[async_trait::async_trait]
impl Embedder<AdaEmbedding> for AdaEmbedder {
    /// Embeds `input` with OpenAI's `text-embedding-ada-002` model.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be built, the API call fails, or
    /// the response contains no embedding.
    async fn embed(&mut self, input: &str) -> anyhow::Result<Embedding<AdaEmbedding>> {
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-ada-002")
            .input([input])
            .build()?;
        let response = self.client.embeddings().create(request).await?;
        // A single input was sent, so one embedding is expected. Report an empty
        // `data` list as an error instead of panicking on `data[0]`.
        let data = response
            .data
            .first()
            .ok_or_else(|| anyhow::anyhow!("OpenAI returned no embedding for the input"))?;
        Ok(Embedding::from(data.embedding.iter().copied()))
    }
}