Skip to main content

nihility_model/
text.rs

1use crate::config::ModelType;
2use crate::error::Result;
3use crate::{ModelClient, NihilityModel, NihilityModelError};
4use std::pin::Pin;
5use async_openai::types::completions::CreateCompletionRequestArgs;
6use tokio::spawn;
7use tokio::sync::mpsc;
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::{Stream, StreamExt};
10use tracing::error;
11
12pub type TextResponseStream =
13    Pin<Box<dyn Stream<Item = core::result::Result<String, NihilityModelError>> + Send>>;
14
15impl NihilityModel {
16    pub async fn completion(&self, model_type: ModelType, prompt: &str) -> Result<String> {
17        let (client, model) = self
18            .config
19            .model_client(self.http_client.clone(), model_type)?;
20        let request = CreateCompletionRequestArgs::default()
21            .model(model)
22            .prompt(prompt)
23            .stream(false)
24            .build()?;
25        Ok(match client {
26            ModelClient::OpenAI(openai_client) => openai_client
27                .completions()
28                .create(request)
29                .await?
30                .choices
31                .first()
32                .ok_or(NihilityModelError::Response)?
33                .text
34                .clone(),
35            ModelClient::Azure(azure_client) => azure_client
36                .completions()
37                .create(request)
38                .await?
39                .choices
40                .first()
41                .ok_or(NihilityModelError::Response)?
42                .text
43                .clone(),
44        })
45    }
46
47    pub async fn completion_stream(
48        &self,
49        model_type: ModelType,
50        prompt: &str,
51    ) -> Result<TextResponseStream> {
52        let (client, model) = self
53            .config
54            .model_client(self.http_client.clone(), model_type)?;
55        let request = CreateCompletionRequestArgs::default()
56            .model(model)
57            .prompt(prompt)
58            .stream(true)
59            .build()?;
60        let (tx, rx) = mpsc::channel(10);
61        let mut resp_stream = match client {
62            ModelClient::OpenAI(openai_client) => {
63                openai_client.completions().create_stream(request).await?
64            }
65            ModelClient::Azure(azure_client) => {
66                azure_client.completions().create_stream(request).await?
67            }
68        };
69        spawn(async move {
70            while let Some(resp) = resp_stream.next().await {
71                match resp {
72                    Ok(ccr) => {
73                        let text = match ccr.choices.first() {
74                            None => {
75                                error!("Received empty choices");
76                                if let Err(e) = tx.send(Err(NihilityModelError::Response)).await {
77                                    error!("Send result to response error: {}", e);
78                                    break;
79                                }
80                                continue;
81                            }
82                            Some(choice) => choice.text.clone(),
83                        };
84                        if let Err(e) = tx.send(Ok(text)).await {
85                            error!("Send result to response error: {}", e);
86                            break;
87                        }
88                    }
89                    Err(e) => {
90                        error!("Response Stream resp Error: {}", e);
91                        if let Err(e) = tx.send(Err(NihilityModelError::from(e))).await {
92                            error!("Send result to response error: {}", e);
93                            break;
94                        }
95                        break;
96                    }
97                }
98            }
99        });
100        Ok(Box::pin(ReceiverStream::new(rx)) as TextResponseStream)
101    }
102}