use anyhow::Context;
use futures::{Stream, StreamExt};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
use crate::{
    get_or_set_runtime,
    types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json},
    TransformerPipeline,
};
#[cfg(feature = "rust_bridge")]
use rust_bridge::{alias, alias_methods};
#[cfg(feature = "python")]
use crate::types::{GeneralJsonAsyncIteratorPython, GeneralJsonIteratorPython, JsonPython};
#[cfg(feature = "c")]
use crate::languages::c::{GeneralJsonAsyncIteratorC, GeneralJsonIteratorC, JsonC};
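
/// A client that mimics the OpenAI chat completions API, backed by
/// open-source models run through a conversational [`TransformerPipeline`].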
#[cfg_attr(feature = "rust_bridge", derive(alias))]
#[derive(Debug, Clone)]
pub struct OpenSourceAI {
    database_url: Option<String>,
}
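
/// Adapts the pipeline's async [`Stream`] of JSON results into a blocking
/// [`Iterator`] by driving each `next()` call on the shared Tokio runtime.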
struct AsyncToSyncJsonIterator(std::pin::Pin<Box<dyn Stream<Item = anyhow::Result<Json>> + Send>>);

impl Iterator for AsyncToSyncJsonIterator {
    type Item = anyhow::Result<Json>;

    fn next(&mut self) -> Option<Self::Item> {
        let runtime = get_or_set_runtime();
        runtime.block_on(self.0.next())
    }
}

#[cfg_attr(
    feature = "rust_bridge",
    alias_methods(
        new,
        chat_completions_create,
        chat_completions_create_async,
        chat_completions_create_stream,
        chat_completions_create_stream_async
    )
)]
impl OpenSourceAI {
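    /// Creates a new [`OpenSourceAI`] client. Passing `None` for
    /// `database_url` defers to the library's default connection handling.
    ///
    /// A minimal usage sketch; marked `ignore` because the exact crate path
    /// is an assumption here and a live database is required:
    ///
    /// ```ignore
    /// use pgml::OpenSourceAI;
    ///
    /// let client = OpenSourceAI::new(None);
    /// ```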
    pub fn new(database_url: Option<String>) -> Self {
        Self { database_url }
    }
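
    /// Resolves the user-supplied `model` argument, which may be either a
    /// model name string or an object with a required `model` key, into a
    /// conversational [`TransformerPipeline`], the model name, and the
    /// parameters later hashed into the system fingerprint.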
    fn create_pipeline_model_name_parameters(
        &self,
        mut model: Json,
    ) -> anyhow::Result<(TransformerPipeline, String, Json)> {
        if model.is_object() {
            let args = model.as_object_mut().unwrap();
            let model_name = args
                .remove("model")
                .context("`model` is a required key in the model object")?;
            let model_name = model_name.as_str().context("`model` must be a string")?;
            Ok((
                TransformerPipeline::new(
                    "conversational",
                    model_name,
                    Some(model.clone()),
                    self.database_url.clone(),
                ),
                model_name.to_string(),
                model,
            ))
        } else {
            let model_name = model
                .as_str()
                .context("`model` must either be a string or an object")?;
            // Build the default parameters once and share them between the
            // pipeline and the returned fingerprint input.
            let model_parameters: Json = serde_json::json!({
                "task": "conversational",
                "model": model_name
            })
            .into();
            Ok((
                TransformerPipeline::new(
                    "conversational",
                    model_name,
                    Some(model_parameters.clone()),
                    self.database_url.clone(),
                ),
                model_name.to_string(),
                model_parameters,
            ))
        }
    }
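
    /// Streaming variant of [`Self::chat_completions_create_async`]: yields
    /// OpenAI-style `chat.completion.chunk` objects as the model generates.
    ///
    /// A sketch of consuming the stream; marked `ignore` because the module
    /// paths and model name are assumptions and a live database is required:
    ///
    /// ```ignore
    /// use futures::StreamExt;
    /// use pgml::{types::Json, OpenSourceAI};
    ///
    /// async fn example() -> anyhow::Result<()> {
    ///     let client = OpenSourceAI::new(None);
    ///     let mut stream = client
    ///         .chat_completions_create_stream_async(
    ///             Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"),
    ///             vec![serde_json::json!({"role": "user", "content": "Ahoy!"}).into()],
    ///             Some(10), // max_tokens
    ///             None,     // temperature (defaults to 0.8)
    ///             Some(1),  // n
    ///             None,     // chat_template
    ///         )
    ///         .await?;
    ///     while let Some(chunk) = stream.next().await {
    ///         println!("{:?}", chunk?);
    ///     }
    ///     Ok(())
    /// }
    /// ```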
    #[allow(clippy::too_many_arguments)]
    pub async fn chat_completions_create_stream_async(
        &self,
        model: Json,
        messages: Vec<Json>,
        max_tokens: Option<i32>,
        temperature: Option<f64>,
        n: Option<i32>,
        chat_template: Option<String>,
    ) -> anyhow::Result<GeneralJsonAsyncIterator> {
        let (transformer_pipeline, model_name, model_parameters) =
            self.create_pipeline_model_name_parameters(model)?;
        let max_tokens = max_tokens.unwrap_or(1000);
        let temperature = temperature.unwrap_or(0.8);
        let n = n.unwrap_or(1) as usize;
        // Derive a deterministic fingerprint for this model configuration by
        // hashing the parameters and formatting the digest as a UUID.
        let to_hash = format!("{}{}{}{}", *model_parameters, max_tokens, temperature, n);
        let md5_digest = md5::compute(to_hash.as_bytes());
        let fingerprint = Uuid::from_slice(&md5_digest.0)?;
        let mut args = serde_json::json!({ "max_tokens": max_tokens, "temperature": temperature });
        if let Some(t) = chat_template {
            args.as_object_mut()
                .unwrap()
                .insert("chat_template".to_string(), serde_json::Value::String(t));
        }
        let messages = serde_json::to_value(messages)?.into();
        let iterator = transformer_pipeline
            .transform_stream(messages, Some(args.into()), Some(1))
            .await?;
        // Wrap each batch of raw model output in the OpenAI
        // `chat.completion.chunk` response shape.
        let id = Uuid::new_v4().to_string();
        let iter = iterator.map(move |choices| {
            let since_the_epoch = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time went backwards");
            Ok(serde_json::json!({
                "id": id.clone(),
                "system_fingerprint": fingerprint,
                "object": "chat.completion.chunk",
                "created": since_the_epoch.as_secs(),
                "model": model_name.clone(),
                "choices": choices?
                    .as_array()
                    .context("Error parsing choices from GeneralJsonAsyncIterator")?
                    .iter()
                    .enumerate()
                    .map(|(i, c)| {
                        serde_json::json!({
                            "index": i,
                            "delta": {
                                "role": "assistant",
                                "content": c
                            }
                        })
                    })
                    .collect::<serde_json::Value>()
            })
            .into())
        });
        Ok(GeneralJsonAsyncIterator(Box::pin(iter)))
    }
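
    /// Blocking wrapper around
    /// [`Self::chat_completions_create_stream_async`]; each `next()` on the
    /// returned iterator drives the underlying stream on the shared Tokio
    /// runtime.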
    #[allow(clippy::too_many_arguments)]
    pub fn chat_completions_create_stream(
        &self,
        model: Json,
        messages: Vec<Json>,
        max_tokens: Option<i32>,
        temperature: Option<f64>,
        n: Option<i32>,
        chat_template: Option<String>,
    ) -> anyhow::Result<GeneralJsonIterator> {
        let runtime = get_or_set_runtime();
        let iter = runtime.block_on(self.chat_completions_create_stream_async(
            model,
            messages,
            max_tokens,
            temperature,
            n,
            chat_template,
        ))?;
        Ok(GeneralJsonIterator(Box::new(AsyncToSyncJsonIterator(
            Box::pin(iter),
        ))))
    }
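
    /// Runs a chat completion and returns a single OpenAI-style
    /// `chat.completion` response. Token usage is not currently tracked, so
    /// the `usage` block is always zeroed.
    ///
    /// A sketch of a one-shot completion; marked `ignore` because the module
    /// paths and model name are assumptions and a live database is required:
    ///
    /// ```ignore
    /// use pgml::{types::Json, OpenSourceAI};
    ///
    /// async fn example() -> anyhow::Result<()> {
    ///     let client = OpenSourceAI::new(None);
    ///     let results = client
    ///         .chat_completions_create_async(
    ///             Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"),
    ///             vec![serde_json::json!({"role": "user", "content": "Ahoy!"}).into()],
    ///             Some(10), // max_tokens
    ///             None,     // temperature
    ///             Some(1),  // n
    ///             None,     // chat_template
    ///         )
    ///         .await?;
    ///     println!("{}", results["choices"]);
    ///     Ok(())
    /// }
    /// ```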
    #[allow(clippy::too_many_arguments)]
    pub async fn chat_completions_create_async(
        &self,
        model: Json,
        messages: Vec<Json>,
        max_tokens: Option<i32>,
        temperature: Option<f64>,
        n: Option<i32>,
        chat_template: Option<String>,
    ) -> anyhow::Result<Json> {
        let (transformer_pipeline, model_name, model_parameters) =
            self.create_pipeline_model_name_parameters(model)?;
        let max_tokens = max_tokens.unwrap_or(1000);
        let temperature = temperature.unwrap_or(0.8);
        let n = n.unwrap_or(1) as usize;
        // Same deterministic fingerprint scheme as the streaming variant.
        let to_hash = format!("{}{}{}{}", *model_parameters, max_tokens, temperature, n);
        let md5_digest = md5::compute(to_hash.as_bytes());
        let fingerprint = Uuid::from_slice(&md5_digest.0)?;
        let mut args = serde_json::json!({ "max_tokens": max_tokens, "temperature": temperature });
        if let Some(t) = chat_template {
            args.as_object_mut()
                .unwrap()
                .insert("chat_template".to_string(), serde_json::Value::String(t));
        }
        let choices = transformer_pipeline
            .transform(messages, Some(args.into()))
            .await?;
        // Re-shape each raw completion into an OpenAI-style `choices` entry.
        let choices: Vec<Json> = choices
            .as_array()
            .context("Error parsing return from TransformerPipeline")?
            .iter()
            .enumerate()
            .map(|(i, c)| {
                serde_json::json!({
                    "index": i,
                    "message": {
                        "role": "assistant",
                        "content": c
                    }
                })
                .into()
            })
            .collect();
        let since_the_epoch = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards");
        Ok(serde_json::json!({
            "id": Uuid::new_v4().to_string(),
            "object": "chat.completion",
            "created": since_the_epoch.as_secs(),
            "model": model_name,
            "system_fingerprint": fingerprint,
            "choices": choices,
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
        })
        .into())
    }
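
    /// Blocking wrapper around [`Self::chat_completions_create_async`],
    /// executed on the shared Tokio runtime.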
    #[allow(clippy::too_many_arguments)]
    pub fn chat_completions_create(
        &self,
        model: Json,
        messages: Vec<Json>,
        max_tokens: Option<i32>,
        temperature: Option<f64>,
        n: Option<i32>,
        chat_template: Option<String>,
    ) -> anyhow::Result<Json> {
        let runtime = get_or_set_runtime();
        runtime.block_on(self.chat_completions_create_async(
            model,
            messages,
            max_tokens,
            temperature,
            n,
            chat_template,
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[test]
    fn can_open_source_ai_create() -> anyhow::Result<()> {
        let client = OpenSourceAI::new(None);
        let results = client.chat_completions_create(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![
            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
        ], Some(10), None, Some(3), None)?;
        assert!(results["choices"].as_array().is_some());
        Ok(())
    }

    #[sqlx::test]
    async fn can_open_source_ai_create_async() -> anyhow::Result<()> {
        let client = OpenSourceAI::new(None);
        let results = client.chat_completions_create_async(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![
            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
        ], Some(10), None, Some(3), None).await?;
        assert!(results["choices"].as_array().is_some());
        Ok(())
    }

    #[sqlx::test]
    async fn can_open_source_ai_create_stream_async() -> anyhow::Result<()> {
        let client = OpenSourceAI::new(None);
        let mut stream = client.chat_completions_create_stream_async(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![
            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
        ], Some(10), None, Some(3), None).await?;
        while let Some(o) = stream.next().await {
            o?;
        }
        Ok(())
    }

    #[test]
    fn can_open_source_ai_create_stream() -> anyhow::Result<()> {
        let client = OpenSourceAI::new(None);
        let iterator = client.chat_completions_create_stream(Json::from_serializable("meta-llama/Meta-Llama-3-8B-Instruct"), vec![
            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
        ], Some(10), None, Some(3), None)?;
        for o in iterator {
            o?;
        }
        Ok(())
    }
}