llmvm_outsource_lib/
openai.rs

use std::collections::VecDeque;

use async_stream::stream;
use futures::StreamExt;
use llmvm_protocol::{
    BackendGenerationRequest, BackendGenerationResponse, Message, MessageRole, ModelDescription,
};
use llmvm_protocol::{NotificationStream, ServiceError};
use reqwest::{Client, Response as HttpResponse, Url};
use serde::Deserialize;

use crate::util::check_status_code;
use crate::{OutsourceError, Result};

const DEFAULT_OPENAI_API_HOST: &str = "https://api.openai.com";

const CHAT_COMPLETION_ENDPOINT: &str = "/v1/chat/completions";
const COMPLETION_ENDPOINT: &str = "/v1/completions";

const MODEL_KEY: &str = "model";
const PROMPT_KEY: &str = "prompt";
const MESSAGES_KEY: &str = "messages";
const MAX_TOKENS_KEY: &str = "max_tokens";
const STREAM_KEY: &str = "stream";

const SSE_DATA_PREFIX: &str = "data: ";
const SSE_DONE_MESSAGE: &str = "[DONE]";
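
// Illustrative framing only (not captured output): the streaming endpoint replies
// with newline-delimited server-sent events, each payload on a `data: ` line and
// terminated by a `data: [DONE]` sentinel, which is what the constants above and
// the parser in `generate_stream` below expect, e.g.
//
//   data: {"choices":[{"delta":{"content":"Hel"}}]}
//   data: [DONE]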
/// Subset of a non-streaming completion response; only the generated text is kept.
#[derive(Deserialize)]
struct CompletionChoice {
    text: Option<String>,
}

#[derive(Deserialize)]
struct CompletionResponse {
    choices: Vec<CompletionChoice>,
}

/// Subset of a non-streaming chat completion response.
#[derive(Deserialize)]
struct ChatCompletionResponse {
    choices: Vec<ChatCompletionChoice>,
}

/// Subset of a single chat completion stream event.
#[derive(Deserialize)]
struct ChatCompletionStreamResponse {
    choices: Vec<ChatCompletionStreamChoice>,
}

#[derive(Deserialize)]
struct ChatCompletionChoice {
    message: ChatCompletionChoiceMessage,
}

#[derive(Deserialize)]
struct ChatCompletionStreamChoice {
    delta: ChatCompletionChoiceMessage,
}

#[derive(Deserialize)]
struct ChatCompletionChoiceMessage {
    content: Option<String>,
}
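
// For orientation, the payloads these types deserialize look roughly like the
// following (illustrative, trimmed to the fields above, not captured output):
//
//   chat completion:   {"choices":[{"message":{"content":"..."}}]}
//   chat stream event: {"choices":[{"delta":{"content":"..."}}]}
//   completion:        {"choices":[{"text":"..."}]}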

async fn send_generate_request(
    mut request: BackendGenerationRequest,
    model_description: ModelDescription,
    config_openai_endpoint: Option<&str>,
    api_key: &str,
    is_chat_model: bool,
    should_stream: bool,
) -> Result<HttpResponse> {
    let endpoint = if is_chat_model {
        CHAT_COMPLETION_ENDPOINT
    } else {
        COMPLETION_ENDPOINT
    };
    // Prefer an endpoint from the model description; otherwise fall back to the
    // configured endpoint, and finally to the default OpenAI host.
    let url = if model_description.endpoint.is_some() {
        model_description.endpoint.unwrap().join(endpoint).unwrap()
    } else {
        Url::parse(config_openai_endpoint.unwrap_or(DEFAULT_OPENAI_API_HOST))
            .expect("url should parse")
            .join(endpoint)
            .unwrap()
    };

    // Start from any caller-supplied model parameters, then fill in the
    // required fields of the request body.
    let mut body = request.model_parameters.take().unwrap_or_default();

    body.insert(MODEL_KEY.to_string(), model_description.model_name.into());
    body.insert(MAX_TOKENS_KEY.to_string(), request.max_tokens.into());
    if is_chat_model {
        // Chat models take the full message history plus the new user prompt.
        let mut messages: Vec<_> = request.thread_messages.take().unwrap_or_default();
        messages.push(Message {
            role: MessageRole::User,
            content: request.prompt,
        });
        body.insert(MESSAGES_KEY.to_string(), serde_json::to_value(messages)?);
    } else {
        body.insert(PROMPT_KEY.to_string(), request.prompt.into());
    }

    if should_stream {
        body.insert(STREAM_KEY.to_string(), true.into());
    }

    let client = Client::new();
    let response = client
        .post(url)
        .bearer_auth(api_key)
        .json(&body)
        .send()
        .await?;

    let response = check_status_code(response).await?;

    Ok(response)
}
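
// As a rough sketch (not captured output), the body assembled above for a chat
// model with streaming enabled serializes to something like:
//
//   {
//     "model": "gpt-3.5-turbo",    // example name taken from model_description.model_name
//     "max_tokens": 256,
//     "messages": [{"role": "user", "content": "..."}],
//     "stream": true
//   }
//
// plus any caller-supplied model parameters merged in.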

/// Generate text and return the whole response using the OpenAI API.
pub async fn generate(
    request: BackendGenerationRequest,
    model_description: ModelDescription,
    config_openai_endpoint: Option<&str>,
    api_key: &str,
) -> Result<BackendGenerationResponse> {
    let is_chat_model = model_description.is_chat_model();
    let response = send_generate_request(
        request,
        model_description,
        config_openai_endpoint,
        api_key,
        is_chat_model,
        false,
    )
    .await?;
    // Pull the generated text out of the returned choice; chat and completion
    // models use different response shapes.
    let response = if is_chat_model {
        let mut body: ChatCompletionResponse = response.json().await?;
        let choice = body.choices.pop().ok_or(OutsourceError::NoTextInResponse)?;
        choice.message.content
    } else {
        let mut body: CompletionResponse = response.json().await?;
        let choice = body.choices.pop().ok_or(OutsourceError::NoTextInResponse)?;
        choice.text
    }
    .unwrap_or_default();

    Ok(BackendGenerationResponse { response })
}
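
// A minimal usage sketch (assuming a `BackendGenerationRequest` already populated
// by the caller and a valid API key; request fields other than those used in this
// module are not shown):
//
//   let response = generate(request, model_description, None, &api_key).await?;
//   println!("{}", response.response);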

fn extract_response_from_stream_event(
    line_json: &str,
    is_chat_model: bool,
) -> Result<BackendGenerationResponse> {
    let response = if is_chat_model {
        let mut update: ChatCompletionStreamResponse = serde_json::from_str(line_json)?;
        let choice = update
            .choices
            .pop()
            .ok_or(OutsourceError::NoTextInResponse)?;
        choice.delta.content
    } else {
        let mut update: CompletionResponse = serde_json::from_str(line_json)?;
        let choice = update
            .choices
            .pop()
            .ok_or(OutsourceError::NoTextInResponse)?;
        choice.text
    }
    .unwrap_or_default();
    Ok(BackendGenerationResponse { response })
}

/// Request text generation and return an asynchronous stream of generated tokens,
/// using the OpenAI API.
pub async fn generate_stream(
    request: BackendGenerationRequest,
    model_description: ModelDescription,
    config_openai_endpoint: Option<&str>,
    api_key: &str,
) -> Result<NotificationStream<BackendGenerationResponse>> {
    let is_chat_model = model_description.is_chat_model();
    let response = send_generate_request(
        request,
        model_description,
        config_openai_endpoint,
        api_key,
        is_chat_model,
        true,
    )
    .await?;
    let mut response_stream = response.bytes_stream();
    Ok(stream! {
        // Byte chunks from the HTTP stream are not aligned to SSE line
        // boundaries, so buffer them and split on newlines.
        let mut buffer = VecDeque::new();
        while let Some(bytes_result) = response_stream.next().await {
            match bytes_result {
                Err(e) => {
                    let boxed_e: ServiceError = Box::new(e);
                    yield Err(boxed_e.into());
                    return;
                },
                Ok(bytes) => {
                    buffer.extend(bytes);
                }
            }
            // Process every complete line currently in the buffer.
            while let Some(linebreak_pos) = buffer.iter().position(|b| b == &b'\n') {
                let line_bytes = buffer.drain(0..(linebreak_pos + 1)).collect::<Vec<_>>();
                if let Ok(line) = std::str::from_utf8(&line_bytes) {
                    // Only `data: ` events carry payloads; skip everything else.
                    if !line.starts_with(SSE_DATA_PREFIX) {
                        continue;
                    }
                    let line_json = &line[SSE_DATA_PREFIX.len()..];
                    // The final `data: [DONE]` event carries no JSON payload.
                    if line_json.starts_with(SSE_DONE_MESSAGE) {
                        continue;
                    }
                    let result = extract_response_from_stream_event(line_json, is_chat_model);
                    yield result.map_err(|e| e.into());
                }
            }
        }
    }
    .boxed())
}
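
// A minimal consumption sketch (assuming the caller already has a request, a model
// description, and an API key; `futures::StreamExt` provides `next`):
//
//   let mut stream = generate_stream(request, model_description, None, &api_key).await?;
//   while let Some(update) = stream.next().await {
//       print!("{}", update?.response);
//   }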