1use crate::config::ModelType;
2use crate::error::Result;
3use crate::{ModelClient, NihilityModel, NihilityModelError};
4use std::pin::Pin;
5use async_openai::types::completions::CreateCompletionRequestArgs;
6use tokio::spawn;
7use tokio::sync::mpsc;
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::{Stream, StreamExt};
10use tracing::error;
11
12pub type TextResponseStream =
13 Pin<Box<dyn Stream<Item = core::result::Result<String, NihilityModelError>> + Send>>;
14
15impl NihilityModel {
16 pub async fn completion(&self, model_type: ModelType, prompt: &str) -> Result<String> {
17 let (client, model) = self
18 .config
19 .model_client(self.http_client.clone(), model_type)?;
20 let request = CreateCompletionRequestArgs::default()
21 .model(model)
22 .prompt(prompt)
23 .stream(false)
24 .build()?;
25 Ok(match client {
26 ModelClient::OpenAI(openai_client) => openai_client
27 .completions()
28 .create(request)
29 .await?
30 .choices
31 .first()
32 .ok_or(NihilityModelError::Response)?
33 .text
34 .clone(),
35 ModelClient::Azure(azure_client) => azure_client
36 .completions()
37 .create(request)
38 .await?
39 .choices
40 .first()
41 .ok_or(NihilityModelError::Response)?
42 .text
43 .clone(),
44 })
45 }
46
47 pub async fn completion_stream(
48 &self,
49 model_type: ModelType,
50 prompt: &str,
51 ) -> Result<TextResponseStream> {
52 let (client, model) = self
53 .config
54 .model_client(self.http_client.clone(), model_type)?;
55 let request = CreateCompletionRequestArgs::default()
56 .model(model)
57 .prompt(prompt)
58 .stream(true)
59 .build()?;
60 let (tx, rx) = mpsc::channel(10);
61 let mut resp_stream = match client {
62 ModelClient::OpenAI(openai_client) => {
63 openai_client.completions().create_stream(request).await?
64 }
65 ModelClient::Azure(azure_client) => {
66 azure_client.completions().create_stream(request).await?
67 }
68 };
69 spawn(async move {
70 while let Some(resp) = resp_stream.next().await {
71 match resp {
72 Ok(ccr) => {
73 let text = match ccr.choices.first() {
74 None => {
75 error!("Received empty choices");
76 if let Err(e) = tx.send(Err(NihilityModelError::Response)).await {
77 error!("Send result to response error: {}", e);
78 break;
79 }
80 continue;
81 }
82 Some(choice) => choice.text.clone(),
83 };
84 if let Err(e) = tx.send(Ok(text)).await {
85 error!("Send result to response error: {}", e);
86 break;
87 }
88 }
89 Err(e) => {
90 error!("Response Stream resp Error: {}", e);
91 if let Err(e) = tx.send(Err(NihilityModelError::from(e))).await {
92 error!("Send result to response error: {}", e);
93 break;
94 }
95 break;
96 }
97 }
98 }
99 });
100 Ok(Box::pin(ReceiverStream::new(rx)) as TextResponseStream)
101 }
102}