1use async_recursion::async_recursion;
2use clap::{Args,ValueEnum};
3use serde::{Serialize,Deserialize};
4use reqwest::Client;
5use derive_more::From;
6use tiktoken_rs::p50k_base;
7use crate::openai::chat::OpenAIChatCommand;
8use crate::openai::OpenAIError;
9use crate::completion::{CompletionOptions,CompletionFile,ClashingArgumentsError};
10use crate::Config;
11use log::debug;
12
// Command-line arguments (and serialized session overrides) for a chat session.
//
// NOTE: plain `//` comments are used on purpose here. clap's derive turns `///` doc
// comments on these fields into `--help` text, so doc comments would change CLI output.
#[derive(Args, Clone, Debug, Default, Serialize, Deserialize)]
pub struct ChatCommand {
    // Shared completion options, flattened both into the clap arguments and into the
    // serde representation.
    #[command(flatten)]
    #[serde(flatten)]
    pub completion: CompletionOptions,

    // System prompt. When absent, `ChatOptions::try_from` falls back to the session
    // file's stored value and then to a built-in default prompt.
    #[arg(long, short)]
    pub system: Option<String>,

    // Extra instruction. `ChatOptions::try_from` turns it into a system-role message that
    // is appended after the parsed transcript.
    #[arg(long, short)]
    pub direction: Option<String>,

    // Chat model to use. Optional; the fallback is decided in `ChatOptions::try_from`.
    #[arg(long)]
    pub provider: Option<ChatProvider>,
}
28
29impl ChatCommand {
30 #[async_recursion]
31 pub async fn run(&self, client: &Client, config: &Config) -> ChatResult {
32 let mut options = ChatOptions::try_from((self, config))?;
33 let print_output = !options.completion.quiet.unwrap_or(false);
34
35 if print_output && options.file.transcript.len() > 0 {
36 print!("{}", options.file.transcript);
37 }
38
39 if !options.ai_responds_first {
40 let append = options.completion.append.as_ref().map(|a| &**a);
41 let prefix_user = Some(&*options.prefix_user);
42
43 if let None = options.file.read(append, prefix_user, options.no_context) {
44 return Ok(vec![]);
45 }
46 }
47
48 let mut command = OpenAIChatCommand::try_from(options)?;
49 command.run(client, config).await
50 }
51}
52
/// Fully resolved settings for one chat turn.
///
/// Normally built from a [`ChatCommand`] plus the session file through
/// `TryFrom<(&ChatCommand, &Config)>`. The defaults listed below are the ones that
/// conversion applies. The derived `Default` impl instead yields zero/empty values.
#[derive(Default, Debug)]
pub(crate) struct ChatOptions {
    /// When `true`, `ChatCommand::run` does not read user input before calling the model
    /// (default `false`).
    pub ai_responds_first: bool,
    /// Completion options, merged with the session file's overrides when a file was loaded.
    pub completion: CompletionOptions,
    /// Optional system-role message appended after the parsed transcript.
    pub direction: Option<ChatMessage>,
    /// System prompt; becomes the first message of the conversation.
    pub system: String,
    /// Session file: holds the transcript plus the last read/written inputs.
    pub file: CompletionFile<ChatCommand>,
    /// Passed to `CompletionFile::read`. When set, the file's `last_read_input` (as a user
    /// message) and `last_written_input` (as an AI message) are also appended to the
    /// conversation. Defaults to `false`.
    pub no_context: bool,
    /// Chat model to call.
    pub provider: ChatProvider,
    /// Transcript label for AI lines (default `"AI"`); compared case-insensitively.
    pub prefix_ai: String,
    /// Transcript label for user lines (default `"USER"`); compared case-insensitively and
    /// also passed to `CompletionFile::read`.
    pub prefix_user: String,
    /// Whether to stream the response, from `CompletionOptions::parse_stream_option`.
    pub stream: bool,
    /// Sampling temperature (default `0.8`).
    pub temperature: f32,
    /// Token budget. Copied as-is from the completion options, so it may be `None` here,
    /// but `labotomize` panics unless it has been set to `Some` beforehand.
    pub tokens_max: Option<usize>,
    /// Fraction of `tokens_max` the prompt may use when `labotomize` trims the
    /// conversation (default `0.5`).
    pub tokens_balance: f32,
    /// Stop sequences, from `CompletionOptions::parse_stops`.
    pub stop: Vec<String>
}
70
71impl TryFrom<(&ChatCommand, &Config)> for ChatOptions {
72 type Error = ChatError;
73
74 fn try_from((command, config): (&ChatCommand, &Config)) -> Result<Self, Self::Error> {
75 let file = command.completion.load_session_file::<ChatCommand>(config, command.clone());
76 let completion = if file.file.is_some() {
77 command.completion.merge(&file.overrides.completion)
78 } else {
79 command.completion.clone()
80 };
81
82 let stream = completion.parse_stream_option()?;
83 let system = command.system
84 .clone()
85 .or_else(|| file.overrides.system.clone())
86 .clone()
87 .unwrap_or_else(|| String::from("A friendly and helpful AI assistant."));
88
89 let provider = command.provider.unwrap_or_default();
90
91 Ok(ChatOptions {
92 ai_responds_first: completion.ai_responds_first.unwrap_or(false),
93 direction: command.direction.clone()
94 .map(|direction| ChatMessage::new(ChatRole::System, direction)),
95 temperature: completion.temperature.unwrap_or(0.8),
96 no_context: completion.no_context.unwrap_or(false),
97 provider,
98 prefix_ai: completion.prefix_ai.clone().unwrap_or_else(|| String::from("AI")),
99 prefix_user: completion.prefix_user.clone().unwrap_or_else(|| String::from("USER")),
100 system,
101 tokens_balance: completion.tokens_balance.unwrap_or(0.5),
102 stop: completion.parse_stops(),
103 tokens_max: completion.tokens_max,
104 completion,
105 stream,
106 file,
107 })
108 }
109}
110
/// Errors produced while preparing or running a chat turn.
///
/// `derive_more::From` generates `From` conversions for the wrapped error types, which is
/// what lets `?` turn them into a `ChatError`.
#[derive(Debug, From)]
pub enum ChatError {
    /// Conflicting options (presumably reported by the `CompletionOptions` parsers; confirm
    /// at the raise sites).
    ClashingArguments(ClashingArgumentsError),
    /// The transcript could not be turned into chat messages, or the system prompt does
    /// not fit the token budget.
    ChatTranscriptionError(ChatTranscriptionError),
    /// A `serde_json` (de)serialization failure.
    TranscriptDeserializationError(serde_json::Error),
    /// An error reported by the OpenAI backend module.
    OpenAIError(OpenAIError),
    /// A `reqwest` HTTP/transport failure.
    NetworkError(reqwest::Error),
    /// A local I/O failure.
    IOError(std::io::Error),
    /// A failure from the `reqwest_eventsource` stream.
    EventSource(reqwest_eventsource::Error),
    /// The request was not authorized (NOTE(review): presumably a missing or invalid API
    /// key; confirm where this is raised).
    Unauthorized
}
122
/// Human-readable reason why a transcript could not be turned into chat messages, e.g. a
/// continuation line with no preceding `role:` line, or a system prompt that does not fit
/// the token budget.
#[derive(Debug)]
pub struct ChatTranscriptionError(pub String);

/// Outcome of a chat turn: the messages produced by the provider command (empty when no
/// user input was read), or the error that stopped the turn.
pub type ChatResult = Result<Vec<ChatMessage>, ChatError>;
127
/// A single chat message (author + text) together with its cached token count.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ChatMessage {
    /// Author; serialized as `"assistant"`, `"user"` or `"system"`.
    pub role: ChatRole,
    /// Message text.
    pub content: String,
    /// Token count computed by `ChatMessage::new`.
    ///
    /// Not serialized. With `#[serde(skip)]`, deserialization fills it with
    /// `usize::default()` (0), so deserialized messages carry no real count. Because
    /// `PartialEq` is derived, they also compare unequal to an otherwise identical message
    /// built with `new`.
    #[serde(skip)]
    pub tokens: usize
}
135
136impl ChatMessage {
137 pub fn new(role: ChatRole, content: impl AsRef<str>) -> Self {
138 let tokens = p50k_base().unwrap()
139 .encode_with_special_tokens(&format!("{}{}", role, content.as_ref()))
140 .len();
141
142 ChatMessage {
143 role,
144 content: content.as_ref().to_string(),
145 tokens
146 }
147 }
148}
149
/// An ordered conversation. Built from a session transcript via `TryFrom<&ChatOptions>`.
pub type ChatMessages = Vec<ChatMessage>;
151
152impl TryFrom<&ChatOptions> for ChatMessages {
153 type Error = ChatError;
154
155 fn try_from(options: &ChatOptions) -> Result<Self, Self::Error> {
156 let ChatOptions { file, system, .. } = options;
157
158 let mut messages = vec![];
159 let mut message: Option<ChatMessage> = None;
160
161 messages.push(ChatMessage::new(ChatRole::System, system));
162
163 let handle_continuing_line = |line, message: &mut Option<ChatMessage>| match message {
164 Some(m) => {
165 *message = Some(ChatMessage::new(m.role, {
166 let mut content = m.content.clone();
167 content += "\n";
168 content += line;
169 content
170 }));
171 Ok(())
172 },
173 None => {
174 return Err(ChatError::ChatTranscriptionError(ChatTranscriptionError(
175 "Missing opening chat role".into()
176 )));
177 }
178 };
179
180 for line in file.transcript.lines() {
181 match line.split_once(':') {
182 Some((role, dialog)) => match ChatRole::try_from((role, options)) {
183 Ok(normalized_role) => {
184 if let Some(message) = message {
185 messages.push(message);
186 }
187
188 let mut dialog = dialog.trim_start().to_string();
189 if role != "ai" && role != "assitant" && role != "user" && role != "system"
190 && !dialog.to_lowercase().starts_with(role) {
191 dialog = format!("{role}: {dialog}");
192 }
193
194 message = Some(ChatMessage::new(normalized_role, dialog));
195 },
196 Err(_) => {
197 println!("lkjlkj");
198 handle_continuing_line(line, &mut message)?
199 }
200 },
201 None => handle_continuing_line(line, &mut message)?
202 }
203 }
204
205 if let Some(message) = message {
206 messages.push(message);
207 }
208
209 if options.no_context {
210 messages.push(ChatMessage::new(ChatRole::User, file.last_read_input.clone()));
211 }
212
213 if let Some(direction) = &options.direction {
214 messages.push(direction.clone());
215 }
216
217 if options.no_context {
218 messages.push(ChatMessage::new(ChatRole::Ai, file.last_written_input.clone()))
219 }
220
221 let lab = messages.labotomize(&options)?;
222 return Ok(lab);
223 }
224}
225
/// Crate-internal helpers for [`ChatMessages`].
pub(crate) trait ChatMessagesInternalExt {
    /// Returns a copy of the conversation trimmed to the prompt token budget derived from
    /// `options` (`tokens_max` scaled by `tokens_balance`).
    // NOTE: the spelling `labotomize` is kept as-is because callers elsewhere use this name.
    fn labotomize(&self, options: &ChatOptions) -> Result<Self, ChatError> where Self: Sized;
}
229
230impl ChatMessagesInternalExt for ChatMessages {
231 fn labotomize(&self, options: &ChatOptions) -> Result<Self, ChatError> {
232 let tokens_max = options.tokens_max
233 .expect("The max tokens should have been assigned when creating the command");
234
235 let tokens_balance = options.tokens_balance;
236 let upper_bound = (tokens_max as f32 * tokens_balance).floor() as usize;
237 let current_token_length: usize = self.iter().map(|m| m.tokens).sum();
238
239 if current_token_length > upper_bound {
240 let system = ChatMessage::new(ChatRole::System, options.system.clone());
241 let mut messages = vec![];
242 let mut remaining = upper_bound.checked_sub(system.tokens)
243 .ok_or_else(|| ChatTranscriptionError(format!(
244 "Cannot fit your system message into the chat messages list. This means \
245 that your tokens_max value is either too small or your system message is \
246 too long. You're upper bound on transcript tokens is {upper_bound} and \
247 your system message has {} tokens", system.tokens)))?;
248
249 for message in self.iter().skip(1).rev() {
250 match remaining.checked_sub(message.tokens) {
251 Some(subtracted) => {
252 remaining = subtracted;
253 messages.push(message);
254 },
255 None => {
256 debug!(target: "AI chat", "Lobotomized message {:?}", message);
257 },
258 }
259 }
260
261 messages.push(&system);
262 Ok(messages.iter().rev().map(|i| i.clone()).cloned().collect())
263 } else {
264 Ok(self.clone())
265 }
266 }
267}
268
// Supported chat models. The `Display` impl maps each variant to its OpenAI model name.
//
// `non_camel_case_types` is allowed because several variants contain underscores to
// separate the dated snapshot suffix (e.g. `OpenAiGPT4_0314`).
// Plain `//` comments are used on purpose: clap's `ValueEnum` derive turns `///` doc
// comments on variants into help text for the possible values.
#[derive(Copy, Clone, Debug, Default, ValueEnum, Serialize, Deserialize)]
#[allow(non_camel_case_types)]
pub enum ChatProvider {
    // Used when no provider is chosen (`ChatProvider::default()`).
    #[default]
    OpenAiGPT3Turbo,
    OpenAiGPT3Turbo_0301,
    OpenAiGPT4,
    OpenAiGPT4_0314,
    OpenAiGPT4_32K,
    OpenAiGPT4_32K_0314
}
280
281impl std::fmt::Display for ChatProvider {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
283 write!(f, "{}", match self {
284 Self::OpenAiGPT3Turbo => "gpt-3.5-turbo",
285 Self::OpenAiGPT3Turbo_0301 => "gpt-3.5-turbo-0301",
286 Self::OpenAiGPT4 => "gpt-4",
287 Self::OpenAiGPT4_0314 => "gpt-4-0314",
288 Self::OpenAiGPT4_32K => "gpt-4-32k",
289 Self::OpenAiGPT4_32K_0314 => "gpt-4-32k-0314"
290 })
291 }
292}
293
/// Author of a chat message.
///
/// Serialized as `"assistant"`, `"user"` and `"system"`. The `Display` impl instead
/// produces transcript-style labels.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
pub enum ChatRole {
    /// The model's replies.
    #[serde(rename = "assistant")]
    Ai,
    /// The human side of the conversation (also the fallback for unknown labels).
    #[serde(rename = "user")]
    User,
    /// Instructions such as the system prompt or a direction.
    #[serde(rename = "system")]
    System
}
303
304impl std::fmt::Display for ChatRole {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
306 write!(f, "{}", match self {
307 Self::Ai => "AI: ",
308 Self::User => "USER: ",
309 Self::System => "SYSTEM: "
310 })
311 }
312}
313
314impl TryFrom<(&str, &ChatOptions)> for ChatRole {
315 type Error = ChatError;
316
317 fn try_from((role, options): (&str, &ChatOptions)) -> Result<Self, Self::Error> {
318 let role = role.to_lowercase();
319 let role = role.trim();
320
321 if role == options.prefix_ai.to_lowercase() {
322 return Ok(ChatRole::Ai)
323 }
324
325 if role == options.prefix_user.to_lowercase() {
326 return Ok(ChatRole::User)
327 }
328
329 match &*role {
330 "ai" |
331 "assistant" => Ok(ChatRole::Ai),
332 "system" => Ok(ChatRole::System),
333 _ => Ok(ChatRole::User),
334 }
335 }
336}