nihility-model 0.2.2

Nihility project AI model module
Documentation
use crate::config::ModelType;
use crate::error::Result;
use crate::{ModelClient, NihilityModel, NihilityModelError};
use std::pin::Pin;
use async_openai::types::completions::CreateCompletionRequestArgs;
use tokio::spawn;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::{Stream, StreamExt};
use tracing::error;

/// Pinned, boxed, `Send` stream of completion text, as returned by
/// [`NihilityModel::completion_stream`]. Each item is either a chunk of text or an error.
pub type TextResponseStream =
    Pin<Box<dyn Stream<Item = std::result::Result<String, NihilityModelError>> + Send>>;

impl NihilityModel {
    pub async fn completion(&self, model_type: ModelType, prompt: &str) -> Result<String> {
        let (client, model) = self
            .config
            .model_client(self.http_client.clone(), model_type)?;
        let request = CreateCompletionRequestArgs::default()
            .model(model)
            .prompt(prompt)
            .stream(false)
            .build()?;
        Ok(match client {
            ModelClient::OpenAI(openai_client) => openai_client
                .completions()
                .create(request)
                .await?
                .choices
                .first()
                .ok_or(NihilityModelError::Response)?
                .text
                .clone(),
            ModelClient::Azure(azure_client) => azure_client
                .completions()
                .create(request)
                .await?
                .choices
                .first()
                .ok_or(NihilityModelError::Response)?
                .text
                .clone(),
        })
    }

    pub async fn completion_stream(
        &self,
        model_type: ModelType,
        prompt: &str,
    ) -> Result<TextResponseStream> {
        let (client, model) = self
            .config
            .model_client(self.http_client.clone(), model_type)?;
        let request = CreateCompletionRequestArgs::default()
            .model(model)
            .prompt(prompt)
            .stream(true)
            .build()?;
        let (tx, rx) = mpsc::channel(10);
        let mut resp_stream = match client {
            ModelClient::OpenAI(openai_client) => {
                openai_client.completions().create_stream(request).await?
            }
            ModelClient::Azure(azure_client) => {
                azure_client.completions().create_stream(request).await?
            }
        };
        spawn(async move {
            while let Some(resp) = resp_stream.next().await {
                match resp {
                    Ok(ccr) => {
                        let text = match ccr.choices.first() {
                            None => {
                                error!("Received empty choices");
                                if let Err(e) = tx.send(Err(NihilityModelError::Response)).await {
                                    error!("Send result to response error: {}", e);
                                    break;
                                }
                                continue;
                            }
                            Some(choice) => choice.text.clone(),
                        };
                        if let Err(e) = tx.send(Ok(text)).await {
                            error!("Send result to response error: {}", e);
                            break;
                        }
                    }
                    Err(e) => {
                        error!("Response Stream resp Error: {}", e);
                        if let Err(e) = tx.send(Err(NihilityModelError::from(e))).await {
                            error!("Send result to response error: {}", e);
                            break;
                        }
                        break;
                    }
                }
            }
        });
        Ok(Box::pin(ReceiverStream::new(rx)) as TextResponseStream)
    }
}