ai/
completion.rs

1use clap::Args;
2use serde::{Serialize,Deserialize};
3use serde::de::DeserializeOwned;
4use std::fs::{self,File,OpenOptions};
5use std::io::{self,Write};
6use crate::Config;
7
8#[derive(Args, Clone, Default, Debug, Serialize, Deserialize)]
9pub struct CompletionOptions {
10    /// Allow the AI to generate a response to the prompt before user input
11    #[arg(long)]
12    pub ai_responds_first: Option<bool>,
13
14    /// Append a string to an existing session and get only the latest response.
15    #[arg(long)]
16    pub append: Option<String>,
17
18    /// Temperature of the model, the allowed range of this value is different across providers,
19    /// for OpenAI it's 0 - 2, and Cohere uses a 0 - 5 scale.
20    #[arg(long, short)]
21    pub temperature: Option<f32>,
22
23    /// Saves your conversation context using the session name
24    #[arg(short, long)]
25    pub name: Option<String>,
26
27    /// Disables the context of the conversation, every message sent to the AI is standalone. If you
28    /// use a coding model this defaults to true unless prompt is specified.
29    #[arg(long)]
30    pub no_context: Option<bool>,
31
32    /// Only do one question / answer cycle and return the result.
33    #[arg(long)]
34    pub once: Option<bool>,
35
36    /// Overwrite the existing session if it already exists
37    #[arg(long)]
38    pub overwrite: Option<bool>,
39
40    /// Only write output the session file
41    #[arg(long)]
42    pub quiet: Option<bool>,
43
44    /// Prefix ai responses with the supplied string. This can be used for labels if your prompt has
45    /// a conversational style. Defaults to "AI"
46    #[arg(long)]
47    pub prefix_ai: Option<String>,
48
49    /// Prefix input with the supplied string. This can be used for labels if your prompt has a
50    /// conversational style. Defaults to "USER:"
51    #[arg(long)]
52    pub prefix_user: Option<String>,
53
54    /// Number of responses to generate
55    #[arg(skip)]
56    pub response_count: Option<usize>,
57
58    /// Stop tokens
59    #[arg(long)]
60    pub stop: Option<Vec<String>>,
61
62    /// Stream the output to the terminal
63    #[arg(long)]
64    pub stream: Option<bool>,
65
66    /// The number of maximum total tokens to allow. The maximum upper value of this is dependant on
67    /// the model you're currently using, but often it's 4096.
68    #[arg(long)]
69    pub tokens_max: Option<usize>,
70
71    /// A percentage given from 0 to 0.9 to indicate what percentage of the current conversation
72    /// context to keep. Defaults to 0.5
73    #[arg(long)]
74    pub tokens_balance: Option<f32>,
75}
76
77impl CompletionOptions {
78    pub fn merge(&self, merged: &CompletionOptions) -> Self {
79        let original = self.clone();
80        let merged = merged.clone();
81
82        CompletionOptions {
83            ai_responds_first: original.ai_responds_first.or(merged.ai_responds_first),
84            append: original.append.or(merged.append),
85            temperature: original.temperature.or(merged.temperature),
86            name: original.name.or(merged.name),
87            overwrite: original.overwrite.or(merged.overwrite),
88            once: original.once.or(merged.once),
89            quiet: original.quiet.or(merged.quiet),
90            prefix_ai: original.prefix_ai.or(merged.prefix_ai),
91            prefix_user: original.prefix_user.or(merged.prefix_user),
92            stop: original.stop.or(merged.stop),
93            stream: original.stream.or(merged.stream),
94            tokens_max: original.tokens_max.or(merged.tokens_max),
95            tokens_balance: original.tokens_balance.or(merged.tokens_balance),
96            no_context: original.no_context.or(merged.no_context),
97            response_count: original.response_count.or(merged.response_count),
98        }
99    }
100
101    pub fn load_session_file<T>(&self, config: &Config, mut overrides: T) -> CompletionFile<T>
102    where
103        T: Clone + Default + DeserializeOwned + Serialize
104    {
105        let session_dir = {
106            let mut dir = config.dir.clone();
107            dir.push("sessions");
108            dir
109        };
110        fs::create_dir_all(&session_dir).expect("Config directory could not be created");
111
112        if self.overwrite.unwrap_or(false) {
113            let path = {
114                let mut path = session_dir.clone();
115                path.push(self.name.as_ref().unwrap());
116                path
117            };
118            let file = OpenOptions::new().write(true).truncate(true).open(path);
119            if let Ok(mut session_file) = file {
120                session_file.write_all(b"").expect("Unable to write to session file");
121                session_file.flush().expect("Unable to write to session file");
122            }
123        }
124
125        let file = self.name.clone().map(|name| {
126            let path = {
127                let mut path = session_dir.clone();
128                path.push(name);
129                path
130            };
131
132            let mut transcript = String::new();
133            let file = match fs::read_to_string(&path) {
134                Ok(mut session_config) if session_config.find("<->").is_some() => {
135                    let divider_index = session_config.find("<->").unwrap();
136
137                    transcript = session_config
138                        .split_off(divider_index + 4)
139                        .trim_start()
140                        .to_string();
141                    session_config.truncate(divider_index);
142                    overrides = serde_yaml::from_str(&session_config)
143                        .expect("Serializing self to yaml config should work 100% of the time");
144
145                    OpenOptions::new()
146                        .append(true)
147                        .create(true)
148                        .open(path)
149                        .expect("Unable to open session file")
150                },
151                _ => {
152                    let config = serde_yaml::to_string(&overrides)
153                        .expect("Serializing self to yaml config should work 100% of the time");
154
155                    let mut file = OpenOptions::new()
156                        .append(true)
157                        .create(true)
158                        .open(path)
159                        .expect("Unable to open session file");
160
161                    if let Err(e) = writeln!(file, "{}<->", &config) {
162                        eprintln!("Couldn't write new configuration to file: {}", e);
163                    }
164
165                    file
166                }
167            };
168
169            CompletionFile {
170                file: Some(file),
171                overrides,
172                transcript,
173                last_read_input: String::new(),
174                last_written_input: String::new()
175            }
176        });
177
178        file.unwrap_or_default()
179    }
180
181    pub fn parse_stops(&self) -> Vec<String> {
182        self.stop.iter()
183            .map(|s| s.iter().map(|s| s.split(",").map(|s| s.trim().to_string())).flatten())
184            .flatten()
185            .collect()
186    }
187
188    pub fn parse_stream_option(&self) -> Result<bool, ClashingArgumentsError> {
189        match (self.quiet, self.stream) {
190            (Some(true), Some(true)) => return Err(ClashingArgumentsError::new(
191                "Having both quiet and stream enabled doesn't make sense."
192            )),
193            (Some(true), None) |
194            (Some(true), Some(false)) |
195            (None, Some(false)) |
196            (Some(false), Some(false)) => Ok(false),
197            (Some(false), None) |
198            (Some(false), Some(true)) |
199            (None, Some(true)) |
200            (None, None) => Ok(true)
201        }
202    }
203
204    pub fn validate(&self) -> Result<(), ClashingArgumentsError> {
205        if self.name.is_none() {
206            if self.append.is_some() {
207                return Err(ClashingArgumentsError::new(
208                    "The append option also requires a session name"));
209            }
210
211            if self.overwrite.unwrap_or(false) {
212                return Err(ClashingArgumentsError::new(
213                    "The overwrite options also requires a session name"));
214            }
215        }
216
217        if self.ai_responds_first.unwrap_or(false) && self.append.is_some() {
218            return Err(ClashingArgumentsError::new(
219                "Specifying that the ai responds first with the append option is nonsensical"));
220        }
221
222        if let Some(count) = self.response_count {
223            if count == 0 {
224                return Err(ClashingArgumentsError::new("The response count should be more than 0"));
225            }
226        }
227
228        Ok(())
229    }
230}
231
/// Error describing command line options that contradict each other or that are missing an
/// option they depend on.
///
/// The wrapped string is the human readable explanation.
#[derive(Debug)]
pub struct ClashingArgumentsError(String);

impl ClashingArgumentsError {
    /// Builds the error from anything that converts into a `String` message.
    pub fn new(error: impl Into<String>) -> Self {
        let message: String = error.into();
        ClashingArgumentsError(message)
    }
}
238
239#[derive(Debug, Default)]
240pub struct CompletionFile<T: Clone + Default + DeserializeOwned + Serialize> {
241    pub file: Option<File>,
242    pub overrides: T,
243    pub transcript: String,
244    pub last_read_input: String,
245    pub last_written_input: String
246}
247
248impl<T> CompletionFile<T>
249where
250    T: Clone + Default + DeserializeOwned + Serialize
251{
252    pub fn write(&mut self, line: String, no_context: bool, is_read: bool) -> io::Result<String> {
253        if !is_read {
254            self.last_written_input = line.clone();
255        }
256
257        if no_context {
258            return Ok(line)
259        }
260
261        match &mut self.file {
262            Some(file) => match writeln!(file, "{}", line) {
263                Ok(()) => {
264                    self.transcript += &line;
265                    self.transcript += "\n";
266                    Ok(line)
267                },
268                Err(e) => Err(e)
269            },
270            None => {
271                self.transcript += &line;
272                self.transcript += "\n";
273                Ok(line)
274            }
275        }
276    }
277
278    pub fn read(
279        &mut self,
280        append: Option<&str>,
281        prefix_user: Option<&str>,
282        no_context: bool) -> Option<String>
283    {
284        let line = append
285            .map(|s| s.to_string())
286            .or_else(|| read_next_user_line(prefix_user))
287            .map(|s| s.trim().to_string());
288
289        line
290            .and_then(|line| {
291                let line = match &prefix_user {
292                    Some(prefix) if !line.to_lowercase().starts_with(prefix) => {
293                        format!("{}: {}", prefix, line)
294                    },
295                    _ => line
296                };
297                self.last_read_input = line.clone();
298                Some(line)
299            })
300            .and_then(|line| if no_context {
301                Some(line)
302            } else {
303                self.write(line, no_context, true).ok()
304            })
305    }
306}
307
308fn read_next_user_line(prefix_user: Option<&str>) -> Option<String> {
309    let mut rl = rustyline::Editor::<()>::new().expect("Failed to create rusty line editor");
310    let prefix = match prefix_user {
311        Some(user) => format!("{}: ", user),
312        None => String::new(),
313    };
314
315    match rl.readline(&prefix) {
316        Ok(line) => Some(String::from("") + line.trim_end()),
317        Err(_) => None
318    }
319}