use crate::config::ModelType;
use crate::error::Result;
use crate::{ModelClient, NihilityModel, NihilityModelError};
use std::pin::Pin;
use async_openai::types::completions::CreateCompletionRequestArgs;
use tokio::spawn;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::{Stream, StreamExt};
use tracing::error;
/// A pinned, boxed, `Send` stream of text results.
///
/// Each item is either a `String` or a [`NihilityModelError`]. This is the
/// stream type returned by `NihilityModel::completion_stream`.
pub type TextResponseStream =
Pin<Box<dyn Stream<Item = core::result::Result<String, NihilityModelError>> + Send>>;
impl NihilityModel {
pub async fn completion(&self, model_type: ModelType, prompt: &str) -> Result<String> {
let (client, model) = self
.config
.model_client(self.http_client.clone(), model_type)?;
let request = CreateCompletionRequestArgs::default()
.model(model)
.prompt(prompt)
.stream(false)
.build()?;
Ok(match client {
ModelClient::OpenAI(openai_client) => openai_client
.completions()
.create(request)
.await?
.choices
.first()
.ok_or(NihilityModelError::Response)?
.text
.clone(),
ModelClient::Azure(azure_client) => azure_client
.completions()
.create(request)
.await?
.choices
.first()
.ok_or(NihilityModelError::Response)?
.text
.clone(),
})
}
pub async fn completion_stream(
&self,
model_type: ModelType,
prompt: &str,
) -> Result<TextResponseStream> {
let (client, model) = self
.config
.model_client(self.http_client.clone(), model_type)?;
let request = CreateCompletionRequestArgs::default()
.model(model)
.prompt(prompt)
.stream(true)
.build()?;
let (tx, rx) = mpsc::channel(10);
let mut resp_stream = match client {
ModelClient::OpenAI(openai_client) => {
openai_client.completions().create_stream(request).await?
}
ModelClient::Azure(azure_client) => {
azure_client.completions().create_stream(request).await?
}
};
spawn(async move {
while let Some(resp) = resp_stream.next().await {
match resp {
Ok(ccr) => {
let text = match ccr.choices.first() {
None => {
error!("Received empty choices");
if let Err(e) = tx.send(Err(NihilityModelError::Response)).await {
error!("Send result to response error: {}", e);
break;
}
continue;
}
Some(choice) => choice.text.clone(),
};
if let Err(e) = tx.send(Ok(text)).await {
error!("Send result to response error: {}", e);
break;
}
}
Err(e) => {
error!("Response Stream resp Error: {}", e);
if let Err(e) = tx.send(Err(NihilityModelError::from(e))).await {
error!("Send result to response error: {}", e);
break;
}
break;
}
}
}
});
Ok(Box::pin(ReceiverStream::new(rx)) as TextResponseStream)
}
}