// ai/chat.rs

1use async_recursion::async_recursion;
2use clap::{Args,ValueEnum};
3use serde::{Serialize,Deserialize};
4use reqwest::Client;
5use derive_more::From;
6use tiktoken_rs::p50k_base;
7use crate::openai::chat::OpenAIChatCommand;
8use crate::openai::OpenAIError;
9use crate::completion::{CompletionOptions,CompletionFile,ClashingArgumentsError};
10use crate::Config;
11use log::debug;
12
/// CLI arguments for the chat subcommand. Derives both clap `Args` and
/// serde traits so the same struct is parsed from the command line and
/// persisted/restored through a session file.
#[derive(Args, Clone, Debug, Default, Serialize, Deserialize)]
pub struct ChatCommand {
    /// Shared completion options, flattened into both the CLI surface and
    /// the serialized session representation.
    #[command(flatten)]
    #[serde(flatten)]
    pub completion: CompletionOptions,

    /// System prompt describing the assistant's persona/behavior.
    #[arg(long, short)]
    pub system: Option<String>,

    /// One-off steering instruction; injected as a system message near the
    /// end of the assembled transcript.
    #[arg(long, short)]
    pub direction: Option<String>,

    /// Which chat model to use (defaults to GPT-3.5 Turbo).
    #[arg(long)]
    pub provider: Option<ChatProvider>,
}
28
29impl ChatCommand {
30    #[async_recursion]
31    pub async fn run(&self, client: &Client, config: &Config) -> ChatResult {
32        let mut options = ChatOptions::try_from((self, config))?;
33        let print_output = !options.completion.quiet.unwrap_or(false);
34
35        if print_output && options.file.transcript.len() > 0 {
36            print!("{}", options.file.transcript);
37        }
38
39        if !options.ai_responds_first {
40            let append = options.completion.append.as_ref().map(|a| &**a);
41            let prefix_user = Some(&*options.prefix_user);
42
43            if let None = options.file.read(append, prefix_user, options.no_context) {
44                return Ok(vec![]);
45            }
46        }
47
48        let mut command = OpenAIChatCommand::try_from(options)?;
49        command.run(client, config).await
50    }
51}
52
/// Fully-resolved runtime settings for a chat session, produced by merging
/// CLI arguments, session-file overrides, and defaults (see the
/// `TryFrom<(&ChatCommand, &Config)>` impl below).
#[derive(Default, Debug)]
pub(crate) struct ChatOptions {
    /// When true, the AI produces the first message without waiting for input.
    pub ai_responds_first: bool,
    /// Merged completion options (CLI args layered over session overrides).
    pub completion: CompletionOptions,
    /// Optional steering instruction, pre-wrapped as a system message.
    pub direction: Option<ChatMessage>,
    /// System prompt text that leads every assembled transcript.
    pub system: String,
    /// Session file holding the transcript and saved command overrides.
    pub file: CompletionFile<ChatCommand>,
    /// When true, only the latest input/response pair is pushed as explicit
    /// messages — presumably to limit context; confirm against transcript use.
    pub no_context: bool,
    /// Selected chat model/provider.
    pub provider: ChatProvider,
    /// Transcript label for AI lines (default "AI").
    pub prefix_ai: String,
    /// Transcript label for user lines (default "USER").
    pub prefix_user: String,
    /// Whether to stream responses.
    pub stream: bool,
    /// Sampling temperature (default 0.8).
    pub temperature: f32,
    /// Hard token limit; expected to be set before trimming runs.
    pub tokens_max: Option<usize>,
    /// Fraction of `tokens_max` budgeted for the transcript (default 0.5).
    pub tokens_balance: f32,
    /// Stop sequences passed through to the backend.
    pub stop: Vec<String>
}
70
71impl TryFrom<(&ChatCommand, &Config)> for ChatOptions {
72    type Error = ChatError;
73
74    fn try_from((command, config): (&ChatCommand, &Config)) -> Result<Self, Self::Error> {
75        let file = command.completion.load_session_file::<ChatCommand>(config, command.clone());
76        let completion = if file.file.is_some() {
77            command.completion.merge(&file.overrides.completion)
78        } else {
79            command.completion.clone()
80        };
81
82        let stream = completion.parse_stream_option()?;
83        let system = command.system
84            .clone()
85            .or_else(|| file.overrides.system.clone())
86            .clone()
87            .unwrap_or_else(|| String::from("A friendly and helpful AI assistant."));
88
89        let provider = command.provider.unwrap_or_default();
90
91        Ok(ChatOptions {
92            ai_responds_first: completion.ai_responds_first.unwrap_or(false),
93            direction: command.direction.clone()
94                .map(|direction| ChatMessage::new(ChatRole::System, direction)),
95            temperature: completion.temperature.unwrap_or(0.8),
96            no_context: completion.no_context.unwrap_or(false),
97            provider,
98            prefix_ai: completion.prefix_ai.clone().unwrap_or_else(|| String::from("AI")),
99            prefix_user: completion.prefix_user.clone().unwrap_or_else(|| String::from("USER")),
100            system,
101            tokens_balance: completion.tokens_balance.unwrap_or(0.5),
102            stop: completion.parse_stops(),
103            tokens_max: completion.tokens_max,
104            completion,
105            stream,
106            file,
107        })
108    }
109}
110
/// All failure modes of the chat pipeline. `derive_more::From` generates the
/// `From` impls so each wrapped error converts automatically via `?`.
#[derive(Debug, From)]
pub enum ChatError {
    /// Conflicting CLI arguments were supplied.
    ClashingArguments(ClashingArgumentsError),
    /// The session transcript could not be parsed into messages.
    ChatTranscriptionError(ChatTranscriptionError),
    /// The serialized transcript failed JSON deserialization.
    TranscriptDeserializationError(serde_json::Error),
    /// An error reported by the OpenAI backend.
    OpenAIError(OpenAIError),
    /// A transport-level HTTP failure.
    NetworkError(reqwest::Error),
    /// A local I/O failure (e.g. session-file access).
    IOError(std::io::Error),
    /// A failure in the server-sent-events stream.
    EventSource(reqwest_eventsource::Error),
    /// The request was not authorized — presumably bad/missing credentials;
    /// confirm against the backend that produces it.
    Unauthorized
}
122
/// Explains why a session transcript could not be converted into chat
/// messages; the payload is a human-readable description.
#[derive(Debug)]
pub struct ChatTranscriptionError(pub String);

/// Messages produced by a chat run, or the error that stopped it.
pub type ChatResult = Result<Vec<ChatMessage>, ChatError>;
127
/// A single chat message: role plus text content, with a cached token count
/// used for transcript-length budgeting.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ChatMessage {
    pub role: ChatRole,
    pub content: String,
    /// p50k token count including the role prefix; computed in
    /// `ChatMessage::new` and deliberately excluded from serialization.
    #[serde(skip)]
    pub tokens: usize
}
135
136impl ChatMessage {
137    pub fn new(role: ChatRole, content: impl AsRef<str>) -> Self {
138        let tokens = p50k_base().unwrap()
139            .encode_with_special_tokens(&format!("{}{}", role, content.as_ref()))
140            .len();
141
142        ChatMessage {
143            role,
144            content: content.as_ref().to_string(),
145            tokens
146        }
147    }
148}
149
/// A full conversation transcript as an ordered list of messages.
pub type ChatMessages = Vec<ChatMessage>;
151
152impl TryFrom<&ChatOptions> for ChatMessages {
153    type Error = ChatError;
154
155    fn try_from(options: &ChatOptions) -> Result<Self, Self::Error> {
156        let ChatOptions { file, system, .. } = options;
157
158        let mut messages = vec![];
159        let mut message: Option<ChatMessage> = None;
160
161        messages.push(ChatMessage::new(ChatRole::System, system));
162
163        let handle_continuing_line = |line, message: &mut Option<ChatMessage>| match message {
164            Some(m) => {
165                *message = Some(ChatMessage::new(m.role, {
166                    let mut content = m.content.clone();
167                    content += "\n";
168                    content += line;
169                    content
170                }));
171                Ok(())
172            },
173            None => {
174                return Err(ChatError::ChatTranscriptionError(ChatTranscriptionError(
175                    "Missing opening chat role".into()
176                )));
177            }
178        };
179
180        for line in file.transcript.lines() {
181            match line.split_once(':') {
182                Some((role, dialog)) => match ChatRole::try_from((role, options)) {
183                    Ok(normalized_role) => {
184                        if let Some(message) = message {
185                            messages.push(message);
186                        }
187
188                        let mut dialog = dialog.trim_start().to_string();
189                        if role != "ai" && role != "assitant" && role != "user" && role != "system"
190                            && !dialog.to_lowercase().starts_with(role) {
191                            dialog = format!("{role}: {dialog}");
192                        }
193
194                        message = Some(ChatMessage::new(normalized_role, dialog));
195                    },
196                    Err(_) => {
197                        println!("lkjlkj");
198                        handle_continuing_line(line, &mut message)?
199                    }
200                },
201                None => handle_continuing_line(line, &mut message)?
202            }
203        }
204
205        if let Some(message) = message {
206            messages.push(message);
207        }
208
209        if options.no_context {
210            messages.push(ChatMessage::new(ChatRole::User, file.last_read_input.clone()));
211        }
212
213        if let Some(direction) = &options.direction {
214            messages.push(direction.clone());
215        }
216
217        if options.no_context {
218            messages.push(ChatMessage::new(ChatRole::Ai, file.last_written_input.clone()))
219        }
220
221        let lab = messages.labotomize(&options)?;
222        return Ok(lab);
223    }
224}
225
/// Crate-internal helpers for transcript message lists.
pub(crate) trait ChatMessagesInternalExt {
    /// Trims the message list to the configured token budget, dropping the
    /// oldest messages first while always keeping the system prompt.
    /// NOTE(review): the name is a historical misspelling of "lobotomize";
    /// renaming would touch the impl and call sites, so it is kept as-is.
    fn labotomize(&self, options: &ChatOptions) -> Result<Self, ChatError> where Self: Sized;
}
229
230impl ChatMessagesInternalExt for ChatMessages {
231    fn labotomize(&self, options: &ChatOptions) -> Result<Self, ChatError> {
232        let tokens_max = options.tokens_max
233            .expect("The max tokens should have been assigned when creating the command");
234
235        let tokens_balance = options.tokens_balance;
236        let upper_bound = (tokens_max as f32 * tokens_balance).floor() as usize;
237        let current_token_length: usize = self.iter().map(|m| m.tokens).sum();
238
239        if current_token_length > upper_bound {
240            let system = ChatMessage::new(ChatRole::System, options.system.clone());
241            let mut messages = vec![];
242            let mut remaining = upper_bound.checked_sub(system.tokens)
243                .ok_or_else(|| ChatTranscriptionError(format!(
244                    "Cannot fit your system message into the chat messages list. This means \
245                    that your tokens_max value is either too small or your system message is \
246                    too long. You're upper bound on transcript tokens is {upper_bound} and \
247                    your system message has {} tokens", system.tokens)))?;
248
249            for message in self.iter().skip(1).rev() {
250                match remaining.checked_sub(message.tokens) {
251                    Some(subtracted) => {
252                        remaining = subtracted;
253                        messages.push(message);
254                    },
255                    None => {
256                        debug!(target: "AI chat", "Lobotomized message {:?}", message);
257                    },
258                }
259            }
260
261            messages.push(&system);
262            Ok(messages.iter().rev().map(|i| i.clone()).cloned().collect())
263        } else {
264            Ok(self.clone())
265        }
266    }
267}
268
/// Selectable chat model, exposed as a clap value enum; the `Display` impl
/// below renders the matching OpenAI API model identifier.
#[derive(Copy, Clone, Debug, Default, ValueEnum, Serialize, Deserialize)]
#[allow(non_camel_case_types)]
pub enum ChatProvider {
    /// gpt-3.5-turbo (the default).
    #[default]
    OpenAiGPT3Turbo,
    /// gpt-3.5-turbo-0301 snapshot.
    OpenAiGPT3Turbo_0301,
    /// gpt-4.
    OpenAiGPT4,
    /// gpt-4-0314 snapshot.
    OpenAiGPT4_0314,
    /// gpt-4-32k (extended context).
    OpenAiGPT4_32K,
    /// gpt-4-32k-0314 snapshot.
    OpenAiGPT4_32K_0314
}
280
281impl std::fmt::Display for ChatProvider {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
283        write!(f, "{}", match self {
284            Self::OpenAiGPT3Turbo => "gpt-3.5-turbo",
285            Self::OpenAiGPT3Turbo_0301 => "gpt-3.5-turbo-0301",
286            Self::OpenAiGPT4 => "gpt-4",
287            Self::OpenAiGPT4_0314 => "gpt-4-0314",
288            Self::OpenAiGPT4_32K => "gpt-4-32k",
289            Self::OpenAiGPT4_32K_0314 => "gpt-4-32k-0314"
290        })
291    }
292}
293
/// Who authored a message. Serializes to the OpenAI API role strings
/// ("assistant" / "user" / "system"); `Display` renders transcript prefixes.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
pub enum ChatRole {
    #[serde(rename = "assistant")]
    Ai,
    #[serde(rename = "user")]
    User,
    #[serde(rename = "system")]
    System
}
303
304impl std::fmt::Display for ChatRole {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
306        write!(f, "{}", match self {
307            Self::Ai => "AI: ",
308            Self::User => "USER: ",
309            Self::System => "SYSTEM: "
310        })
311    }
312}
313
314impl TryFrom<(&str, &ChatOptions)> for ChatRole {
315    type Error = ChatError;
316
317    fn try_from((role, options): (&str, &ChatOptions)) -> Result<Self, Self::Error> {
318        let role = role.to_lowercase();
319        let role = role.trim();
320
321        if role == options.prefix_ai.to_lowercase() {
322            return Ok(ChatRole::Ai)
323        }
324
325        if role == options.prefix_user.to_lowercase() {
326            return Ok(ChatRole::User)
327        }
328
329        match &*role {
330            "ai" |
331            "assistant" => Ok(ChatRole::Ai),
332            "system" => Ok(ChatRole::System),
333            _ => Ok(ChatRole::User),
334        }
335    }
336}