//! Chat API
//!
//! Given a chat conversation, the model will return a chat completion response.
//! See: https://platform.openai.com/docs/api-reference/chat

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::requests::Requests;
use crate::*;
use super::{completions::Completion, CHAT_COMPLETION_CREATE};
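/// Request body for creating a chat completion.
///
/// A minimal construction sketch (values are illustrative; any `Option`
/// field left as `None` falls back to the server-side default documented
/// on that field). Marked `ignore` because the surrounding imports are
/// elided:
///
/// ```ignore
/// let body = ChatBody {
///     model: "gpt-3.5-turbo".to_string(),
///     messages: vec![Message { role: Role::User, content: "Hello!".to_string() }],
///     temperature: Some(0.7),
///     top_p: None,
///     n: None,
///     stream: None,
///     stop: None,
///     max_tokens: Some(64),
///     presence_penalty: None,
///     frequency_penalty: None,
///     logit_bias: None,
///     user: None,
/// };
/// ```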
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatBody {
    /// ID of the model to use.
    /// See the model endpoint compatibility table for details on which models work with the Chat API.
    pub model: String,
    /// The messages to generate chat completions for, in the chat format.
    pub messages: Vec<Message>,
    /// What sampling temperature to use, between 0 and 2.
    /// Higher values like 0.8 will make the output more random,
    /// while lower values like 0.2 will make it more focused and deterministic.
    /// We generally recommend altering this or `top_p` but not both.
    /// Defaults to 1.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// An alternative to sampling with temperature, called nucleus sampling,
    /// where the model considers the results of the tokens with `top_p` probability mass.
    /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    /// We generally recommend altering this or temperature but not both.
    /// Defaults to 1.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// How many chat completion choices to generate for each input message.
    /// Defaults to 1.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<i32>,
    /// If set, partial message deltas will be sent, like in ChatGPT.
    /// Tokens will be sent as data-only server-sent events as they become available,
    /// with the stream terminated by a `data: [DONE]` message. See the OpenAI Cookbook for example code.
    /// Defaults to false.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Up to 4 sequences where the API will stop generating further tokens.
    /// Defaults to null.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// The maximum number of tokens to generate in the chat completion.
    /// The total length of input tokens and generated tokens is limited by the model's context length.
    /// Defaults to inf.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    /// Number between -2.0 and 2.0.
    /// Positive values penalize new tokens based on whether they appear in the text so far,
    /// increasing the model's likelihood to talk about new topics.
    /// Defaults to 0.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Number between -2.0 and 2.0.
    /// Positive values penalize new tokens based on their existing frequency in the text so far,
    /// decreasing the model's likelihood to repeat the same line verbatim.
    /// Defaults to 0.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion.
    /// Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer)
    /// to an associated bias value from -100 to 100. Mathematically,
    /// the bias is added to the logits generated by the model prior to sampling.
    /// The exact effect will vary per model, but values between -1 and 1 should
    /// decrease or increase likelihood of selection;
    /// values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
    /// Defaults to null. (A small construction sketch follows this struct definition.)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, f32>>,
    /// A unique identifier representing your end-user,
    /// which can help OpenAI to monitor and detect abuse.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
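
// `logit_bias` construction sketch: keys are tokenizer token IDs rendered as
// strings, values are biases in [-100, 100]. The token ID below is purely
// illustrative; look up real IDs with the tokenizer for your model.
//
//     let mut bias = HashMap::new();
//     bias.insert("50256".to_string(), -100.0_f32); // effectively bans the token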
pub trait ChatApi {
    /// Creates a completion for the given chat conversation.
    fn chat_completion_create(&self, chat_body: &ChatBody) -> ApiResult<Completion>;
}
impl ChatApi for OpenAI {
    fn chat_completion_create(&self, chat_body: &ChatBody) -> ApiResult<Completion> {
        // Serializing a struct with a derived `Serialize` impl cannot fail,
        // so this `expect` documents an invariant rather than a real error path.
        let request_body = serde_json::to_value(chat_body).expect("ChatBody serializes to JSON");
        let res = self.post(CHAT_COMPLETION_CREATE, request_body)?;
        // A panic here means the API returned a shape that does not match
        // the `Completion` schema.
        let completion: Completion =
            serde_json::from_value(res).expect("response matches the Completion schema");
        Ok(completion)
    }
}
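
// Usage sketch (the `client` and `body` bindings are hypothetical; the test
// module below has a runnable variant):
//
//     let rs = client.chat_completion_create(&body);
//     match rs {
//         Ok(completion) => println!("choices: {:?}", completion.choices),
//         Err(err) => eprintln!("chat completion failed: {:?}", err),
//     }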
#[cfg(test)]
mod tests {
    use super::ChatApi;
    use crate::{apis::chat::ChatBody, openai::new_test_openai, Message, Role};

    #[test]
    fn test_chat_completion() {
        let openai = new_test_openai();
        let body = ChatBody {
            model: "gpt-3.5-turbo".to_string(),
            max_tokens: Some(7),
            temperature: Some(0_f32),
            top_p: Some(0_f32),
            n: Some(2),
            stream: Some(false),
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            user: None,
            messages: vec![Message { role: Role::User, content: "Hello!".to_string() }],
        };
        let rs = openai.chat_completion_create(&body);
        let choices = rs.unwrap().choices;
        let message = choices[0].message.as_ref().unwrap();
        assert!(message.content.contains("Hello"));
    }
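
    // Serialization sketch: fields left as `None` should be omitted from the
    // request body thanks to `skip_serializing_if`, so the server-side
    // defaults documented on `ChatBody` apply. Exercises serde only; no
    // network access or API key needed.
    #[test]
    fn test_none_fields_are_omitted() {
        let body = ChatBody {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![Message { role: Role::User, content: "Hi".to_string() }],
            temperature: None,
            top_p: None,
            n: None,
            stream: None,
            stop: None,
            max_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            user: None,
        };
        let json = serde_json::to_value(&body).unwrap();
        assert!(json.get("temperature").is_none());
        assert!(json.get("max_tokens").is_none());
        assert_eq!(json["model"], "gpt-3.5-turbo");
    }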
}