1use super::{openai_post, ApiResponseOrError, Usage};
4use crate::openai_request_stream;
5use derive_builder::Builder;
6use futures_util::StreamExt;
7use reqwest::Method;
8use reqwest_eventsource::{CannotCloneRequestError, Event, EventSource};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use tokio::sync::mpsc::{channel, Receiver, Sender};
13
14pub type ChatCompletion = ChatCompletionGeneric<ChatCompletionChoice>;
16
17pub type ChatCompletionDelta = ChatCompletionGeneric<ChatCompletionChoiceDelta>;
19
20#[derive(Deserialize, Clone, Debug)]
21pub struct ChatCompletionGeneric<C> {
22 pub id: String,
23 pub object: String,
24 pub created: u64,
25 pub model: String,
26 pub choices: Vec<C>,
27 pub usage: Option<Usage>,
28}
29
30#[derive(Deserialize, Clone, Debug)]
31pub struct ChatCompletionChoice {
32 pub index: u64,
33 pub finish_reason: String,
34 pub message: ChatCompletionMessage,
35}
36
37#[derive(Deserialize, Clone, Debug)]
38pub struct ChatCompletionChoiceDelta {
39 pub index: u64,
40 pub finish_reason: Option<String>,
41 pub delta: ChatCompletionMessageDelta,
42}
43
44#[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
45pub struct ChatCompletionMessage {
46 pub role: ChatCompletionMessageRole,
48 pub content: Option<String>,
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub name: Option<String>,
56 #[serde(skip_serializing_if = "Option::is_none")]
60 pub function_call: Option<ChatCompletionFunctionCall>,
61}
62
63#[derive(Deserialize, Clone, Debug, Eq, PartialEq)]
65pub struct ChatCompletionMessageDelta {
66 pub role: Option<ChatCompletionMessageRole>,
68 pub content: Option<String>,
70 #[serde(skip_serializing_if = "Option::is_none")]
72 pub name: Option<String>,
73 #[serde(skip_serializing_if = "Option::is_none")]
77 pub function_call: Option<ChatCompletionFunctionCallDelta>,
78}
79
80#[derive(Deserialize, Serialize, Debug, Clone)]
81pub struct ChatCompletionFunctionDefinition {
82 pub name: String,
84 #[serde(skip_serializing_if = "Option::is_none")]
86 pub description: Option<String>,
87 #[serde(skip_serializing_if = "Option::is_none")]
91 pub parameters: Option<Value>,
92}
93
94#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
95pub struct ChatCompletionFunctionCall {
96 pub name: String,
98 pub arguments: String,
101}
102
103#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
105pub struct ChatCompletionFunctionCallDelta {
106 pub name: Option<String>,
108 pub arguments: Option<String>,
111}
112
113#[derive(Deserialize, Serialize, Debug, Clone, Copy, Eq, PartialEq)]
114#[serde(rename_all = "lowercase")]
115pub enum ChatCompletionMessageRole {
116 System,
117 User,
118 Assistant,
119 Function,
120}
121
122#[derive(Serialize, Builder, Debug, Clone)]
123#[builder(pattern = "owned")]
124#[builder(name = "ChatCompletionBuilder")]
125#[builder(setter(strip_option, into))]
126pub struct ChatCompletionRequest {
127 model: String,
130 messages: Vec<ChatCompletionMessage>,
132 #[builder(default)]
136 #[serde(skip_serializing_if = "Option::is_none")]
137 temperature: Option<f32>,
138 #[builder(default)]
142 #[serde(skip_serializing_if = "Option::is_none")]
143 top_p: Option<f32>,
144 #[builder(default)]
146 #[serde(skip_serializing_if = "Option::is_none")]
147 n: Option<u8>,
148 #[builder(default)]
149 #[serde(skip_serializing_if = "Option::is_none")]
150 stream: Option<bool>,
151 #[builder(default)]
153 #[serde(skip_serializing_if = "Vec::is_empty")]
154 stop: Vec<String>,
155 #[builder(default)]
157 #[serde(skip_serializing_if = "Option::is_none")]
158 seed: Option<u64>,
159 #[builder(default)]
161 #[serde(skip_serializing_if = "Option::is_none")]
162 max_tokens: Option<u64>,
163 #[builder(default)]
167 #[serde(skip_serializing_if = "Option::is_none")]
168 presence_penalty: Option<f32>,
169 #[builder(default)]
173 #[serde(skip_serializing_if = "Option::is_none")]
174 frequency_penalty: Option<f32>,
175 #[builder(default)]
179 #[serde(skip_serializing_if = "Option::is_none")]
180 logit_bias: Option<HashMap<String, f32>>,
181 #[builder(default)]
183 #[serde(skip_serializing_if = "String::is_empty")]
184 user: String,
185 #[builder(default)]
192 #[serde(skip_serializing_if = "Vec::is_empty")]
193 functions: Vec<ChatCompletionFunctionDefinition>,
194 #[builder(default)]
204 #[serde(skip_serializing_if = "Option::is_none")]
205 function_call: Option<Value>,
206}
207
208impl<C> ChatCompletionGeneric<C> {
209 pub fn builder(
210 model: &str,
211 messages: impl Into<Vec<ChatCompletionMessage>>,
212 ) -> ChatCompletionBuilder {
213 ChatCompletionBuilder::create_empty()
214 .model(model)
215 .messages(messages)
216 }
217}
218
219impl ChatCompletion {
220 pub async fn create(request: &ChatCompletionRequest) -> ApiResponseOrError<Self> {
221 openai_post("chat/completions", request).await
222 }
223}
224
225impl ChatCompletionDelta {
226 pub async fn create(
227 request: &ChatCompletionRequest,
228 ) -> Result<Receiver<Self>, CannotCloneRequestError> {
229 let stream =
230 openai_request_stream(Method::POST, "chat/completions", |r| r.json(request)).await?;
231 let (tx, rx) = channel::<Self>(32);
232 tokio::spawn(forward_deserialized_chat_response_stream(stream, tx));
233 Ok(rx)
234 }
235
236 pub fn merge(
238 &mut self,
239 other: ChatCompletionDelta,
240 ) -> Result<(), ChatCompletionDeltaMergeError> {
241 if other.id.ne(&self.id) {
242 return Err(ChatCompletionDeltaMergeError::DifferentCompletionIds);
243 }
244 for other_choice in other.choices.iter() {
245 for choice in self.choices.iter_mut() {
246 if choice.index != other_choice.index {
247 continue;
248 }
249 choice.merge(other_choice)?;
250 }
251 }
252 Ok(())
253 }
254}
255
256impl ChatCompletionChoiceDelta {
257 pub fn merge(
258 &mut self,
259 other: &ChatCompletionChoiceDelta,
260 ) -> Result<(), ChatCompletionDeltaMergeError> {
261 if self.index != other.index {
262 return Err(ChatCompletionDeltaMergeError::DifferentCompletionChoiceIndices);
263 }
264 if self.delta.role.is_none() {
265 if let Some(other_role) = other.delta.role {
266 self.delta.role = Some(other_role);
268 }
269 }
270 if self.delta.name.is_none() {
271 if let Some(other_name) = &other.delta.name {
272 self.delta.name = Some(other_name.clone());
274 }
275 }
276 match self.delta.content.as_mut() {
278 Some(content) => {
279 match &other.delta.content {
280 Some(other_content) => {
281 content.push_str(other_content)
283 }
284 None => {}
285 }
286 }
287 None => {
288 match &other.delta.content {
289 Some(other_content) => {
290 self.delta.content = Some(other_content.clone());
292 }
293 None => {}
294 }
295 }
296 };
297
298 match self.delta.function_call.as_mut() {
302 Some(function_call) => {
303 match &other.delta.function_call {
304 Some(other_function_call) => {
305 match (&mut function_call.arguments, &other_function_call.arguments) {
307 (Some(function_call), Some(other_function_call)) => {
308 function_call.push_str(&other_function_call);
309 }
310 (None, Some(other_function_call)) => {
311 function_call.arguments = Some(other_function_call.clone());
312 }
313 _ => {}
314 }
315 }
316 None => {}
317 }
318 }
319 None => {
320 match &other.delta.function_call {
321 Some(other_function_call) => {
322 self.delta.function_call = Some(other_function_call.clone());
324 }
325 None => {}
326 }
327 }
328 };
329 Ok(())
330 }
331}
332
333impl From<ChatCompletionDelta> for ChatCompletion {
334 fn from(delta: ChatCompletionDelta) -> Self {
335 ChatCompletion {
336 id: delta.id,
337 object: delta.object,
338 created: delta.created,
339 model: delta.model,
340 usage: delta.usage,
341 choices: delta
342 .choices
343 .iter()
344 .map(|choice| ChatCompletionChoice {
345 index: choice.index,
346 finish_reason: clone_default_unwrapped_option_string(&choice.finish_reason),
347 message: ChatCompletionMessage {
348 role: choice
349 .delta
350 .role
351 .unwrap_or_else(|| ChatCompletionMessageRole::System),
352 content: choice.delta.content.clone(),
353 name: choice.delta.name.clone(),
354 function_call: choice.delta.function_call.clone().map(|f| f.into()),
355 },
356 })
357 .collect(),
358 }
359 }
360}
361
362impl From<ChatCompletionFunctionCallDelta> for ChatCompletionFunctionCall {
363 fn from(delta: ChatCompletionFunctionCallDelta) -> Self {
364 ChatCompletionFunctionCall {
365 name: delta.name.unwrap_or("".to_string()),
366 arguments: delta.arguments.unwrap_or_default(),
367 }
368 }
369}
370
/// Reasons why two streamed chunks (or choice fragments) cannot be merged.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChatCompletionDeltaMergeError {
    /// The chunks belong to different completions (their `id`s differ).
    DifferentCompletionIds,
    /// The choice fragments belong to different choices (their `index`es differ).
    DifferentCompletionChoiceIndices,
    /// Function-call arguments could not be combined.
    ///
    /// NOTE(review): not constructed by any merge code visible in this module.
    FunctionCallArgumentTypeMismatch,
}

impl std::fmt::Display for ChatCompletionDeltaMergeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::DifferentCompletionIds => {
                "cannot merge chat completion deltas with different completion ids"
            }
            Self::DifferentCompletionChoiceIndices => {
                "cannot merge chat completion choices with different indices"
            }
            Self::FunctionCallArgumentTypeMismatch => {
                "cannot merge function call arguments of mismatched types"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for ChatCompletionDeltaMergeError {}
377
378async fn forward_deserialized_chat_response_stream(
379 mut stream: EventSource,
380 tx: Sender<ChatCompletionDelta>,
381) -> anyhow::Result<()> {
382 while let Some(event) = stream.next().await {
383 let event = event?;
384 match event {
385 Event::Message(event) => {
386 let completion = serde_json::from_str::<ChatCompletionDelta>(&event.data)?;
387 tx.send(completion).await?;
388 }
389 _ => {}
390 }
391 }
392 Ok(())
393}
394
395impl ChatCompletionBuilder {
396 pub async fn create(self) -> ApiResponseOrError<ChatCompletion> {
397 ChatCompletion::create(&self.build().unwrap()).await
398 }
399
400 pub async fn create_stream(
401 mut self,
402 ) -> Result<Receiver<ChatCompletionDelta>, CannotCloneRequestError> {
403 self.stream = Some(Some(true));
404 ChatCompletionDelta::create(&self.build().unwrap()).await
405 }
406}
407
/// Returns an owned copy of the string held in `string`, or an empty `String` for `None`.
fn clone_default_unwrapped_option_string(string: &Option<String>) -> String {
    let borrowed: &str = string.as_deref().unwrap_or_default();
    borrowed.to_owned()
}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::set_key;
    use dotenvy::dotenv;
    use std::env;

    /// Loads `.env` (if present) and registers `OPENAI_KEY` as the API key.
    fn init_api_key() {
        dotenv().ok();
        set_key(env::var("OPENAI_KEY").unwrap());
    }

    /// A plain user message with the given text and no name or function call.
    fn user_message(text: &str) -> ChatCompletionMessage {
        ChatCompletionMessage {
            role: ChatCompletionMessageRole::User,
            content: Some(text.to_string()),
            name: None,
            function_call: None,
        }
    }

    /// The message carried by the first choice of `completion` (panics if there is none).
    fn first_message(completion: &ChatCompletion) -> &ChatCompletionMessage {
        &completion.choices.first().unwrap().message
    }

    #[tokio::test]
    async fn chat() {
        init_api_key();

        let completion = ChatCompletion::builder("gpt-3.5-turbo", [user_message("Hello!")])
            .temperature(0.0)
            .create()
            .await
            .unwrap();

        let content = first_message(&completion).content.as_ref().unwrap();
        assert_eq!(content, "Hello! How can I assist you today?");
    }

    #[tokio::test]
    async fn chat_seed() {
        init_api_key();

        let question = "What type of seed does Mr. England sow in the song? Reply with 1 word.";
        let completion = ChatCompletion::builder("gpt-3.5-turbo", [user_message(question)])
            .temperature(0.0)
            .seed(1337u64)
            .create()
            .await
            .unwrap();

        let content = first_message(&completion).content.as_ref().unwrap();
        assert_eq!(content, "Love");
    }

    #[tokio::test]
    async fn chat_stream() {
        init_api_key();

        let deltas = ChatCompletion::builder("gpt-3.5-turbo", [user_message("Hello!")])
            .temperature(0.0)
            .create_stream()
            .await
            .unwrap();
        let completion = stream_to_completion(deltas).await;

        let content = first_message(&completion).content.as_ref().unwrap();
        assert_eq!(content, "Hello! How can I assist you today?");
    }

    #[tokio::test]
    async fn chat_function() {
        init_api_key();

        let weather_function = ChatCompletionFunctionDefinition {
            name: "get_current_weather".to_string(),
            description: Some("Get the current weather in a given location.".to_string()),
            parameters: Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state to get the weather for. (eg: San Francisco, CA)"
                    }
                },
                "required": ["location"]
            })),
        };
        let deltas = ChatCompletion::builder(
            "gpt-3.5-turbo-0613",
            [user_message("What is the weather in Boston?")],
        )
        .functions([weather_function])
        .temperature(0.2)
        .create_stream()
        .await
        .unwrap();
        let completion = stream_to_completion(deltas).await;

        let function_call = first_message(&completion).function_call.as_ref().unwrap();
        assert_eq!(function_call.name, "get_current_weather".to_string());
        assert_eq!(
            serde_json::from_str::<Value>(&function_call.arguments).unwrap(),
            serde_json::json!({
                "location": "Boston, MA"
            }),
        );
    }

    /// Drains `chat_stream`, merging every delta into the first one, and converts the result.
    ///
    /// Panics if the stream yields no deltas or if any delta fails to merge.
    async fn stream_to_completion(
        mut chat_stream: Receiver<ChatCompletionDelta>,
    ) -> ChatCompletion {
        let mut merged = chat_stream.recv().await.unwrap();
        while let Some(delta) = chat_stream.recv().await {
            merged.merge(delta).unwrap();
        }
        merged.into()
    }
}
609}