1use clap::Args;
2use serde::{Serialize,Deserialize};
3use serde::de::DeserializeOwned;
4use std::fs::{self,File,OpenOptions};
5use std::io::{self,Write};
6use crate::Config;
7
// Completion settings parsed from the command line by clap. Every field is an
// `Option` so that `merge` can distinguish "not given" from an explicit value.
// The serde derives let the options round-trip through YAML (presumably a session
// header; see `load_session_file`) — confirm at the call site.
//
// NOTE(review): plain `//` comments are used on purpose — `///` doc comments on clap
// `Args` fields become `--help` text. Field order also matters: it fixes both the
// help listing order and the serialized YAML key order.
#[derive(Args, Clone, Default, Debug, Serialize, Deserialize)]
pub struct CompletionOptions {
    // `validate` rejects this when combined with `append`.
    #[arg(long)]
    pub ai_responds_first: Option<bool>,

    // Requires `name` (see `validate`). NOTE(review): presumably the text handed to
    // `CompletionFile::read` as `append`, which uses it instead of prompting — confirm.
    #[arg(long)]
    pub append: Option<String>,

    // `-t` / `--temperature`. Presumably the sampling temperature sent to the
    // completion backend — usage not visible here.
    #[arg(long, short)]
    pub temperature: Option<f32>,

    // `-n` / `--name`. Session name; `load_session_file` uses it as the file name
    // inside `<config dir>/sessions`.
    #[arg(short, long)]
    pub name: Option<String>,

    // NOTE(review): presumably feeds the `no_context` flag of `CompletionFile::write` /
    // `read`, which skips recording to the transcript — confirm.
    #[arg(long)]
    pub no_context: Option<bool>,

    // NOTE(review): usage not visible here; presumably a single exchange then exit.
    #[arg(long)]
    pub once: Option<bool>,

    // When true, `load_session_file` truncates an existing session file first.
    // Requires `name` (see `validate`).
    #[arg(long)]
    pub overwrite: Option<bool>,

    // Forces streaming off in `parse_stream_option`; an error when combined with
    // `stream = true`.
    #[arg(long)]
    pub quiet: Option<bool>,

    // NOTE(review): presumably the label for AI lines — usage not visible here.
    #[arg(long)]
    pub prefix_ai: Option<String>,

    // NOTE(review): presumably passed as `prefix_user` to `CompletionFile::read`,
    // which prepends `"<prefix>: "` to user input — confirm.
    #[arg(long)]
    pub prefix_user: Option<String>,

    // Not settable from the CLI (`#[arg(skip)]`); `validate` rejects `Some(0)`.
    #[arg(skip)]
    pub response_count: Option<usize>,

    // Stop sequences; each value may hold several comma-separated entries, which
    // `parse_stops` splits and trims.
    #[arg(long)]
    pub stop: Option<Vec<String>>,

    // Streaming toggle, resolved together with `quiet` by `parse_stream_option`
    // (streaming is on when neither is given).
    #[arg(long)]
    pub stream: Option<bool>,

    // NOTE(review): presumably a maximum token count — usage not visible here.
    #[arg(long)]
    pub tokens_max: Option<usize>,

    // NOTE(review): meaning not visible here — confirm with the consumer.
    #[arg(long)]
    pub tokens_balance: Option<f32>,
}
76
77impl CompletionOptions {
78 pub fn merge(&self, merged: &CompletionOptions) -> Self {
79 let original = self.clone();
80 let merged = merged.clone();
81
82 CompletionOptions {
83 ai_responds_first: original.ai_responds_first.or(merged.ai_responds_first),
84 append: original.append.or(merged.append),
85 temperature: original.temperature.or(merged.temperature),
86 name: original.name.or(merged.name),
87 overwrite: original.overwrite.or(merged.overwrite),
88 once: original.once.or(merged.once),
89 quiet: original.quiet.or(merged.quiet),
90 prefix_ai: original.prefix_ai.or(merged.prefix_ai),
91 prefix_user: original.prefix_user.or(merged.prefix_user),
92 stop: original.stop.or(merged.stop),
93 stream: original.stream.or(merged.stream),
94 tokens_max: original.tokens_max.or(merged.tokens_max),
95 tokens_balance: original.tokens_balance.or(merged.tokens_balance),
96 no_context: original.no_context.or(merged.no_context),
97 response_count: original.response_count.or(merged.response_count),
98 }
99 }
100
101 pub fn load_session_file<T>(&self, config: &Config, mut overrides: T) -> CompletionFile<T>
102 where
103 T: Clone + Default + DeserializeOwned + Serialize
104 {
105 let session_dir = {
106 let mut dir = config.dir.clone();
107 dir.push("sessions");
108 dir
109 };
110 fs::create_dir_all(&session_dir).expect("Config directory could not be created");
111
112 if self.overwrite.unwrap_or(false) {
113 let path = {
114 let mut path = session_dir.clone();
115 path.push(self.name.as_ref().unwrap());
116 path
117 };
118 let file = OpenOptions::new().write(true).truncate(true).open(path);
119 if let Ok(mut session_file) = file {
120 session_file.write_all(b"").expect("Unable to write to session file");
121 session_file.flush().expect("Unable to write to session file");
122 }
123 }
124
125 let file = self.name.clone().map(|name| {
126 let path = {
127 let mut path = session_dir.clone();
128 path.push(name);
129 path
130 };
131
132 let mut transcript = String::new();
133 let file = match fs::read_to_string(&path) {
134 Ok(mut session_config) if session_config.find("<->").is_some() => {
135 let divider_index = session_config.find("<->").unwrap();
136
137 transcript = session_config
138 .split_off(divider_index + 4)
139 .trim_start()
140 .to_string();
141 session_config.truncate(divider_index);
142 overrides = serde_yaml::from_str(&session_config)
143 .expect("Serializing self to yaml config should work 100% of the time");
144
145 OpenOptions::new()
146 .append(true)
147 .create(true)
148 .open(path)
149 .expect("Unable to open session file")
150 },
151 _ => {
152 let config = serde_yaml::to_string(&overrides)
153 .expect("Serializing self to yaml config should work 100% of the time");
154
155 let mut file = OpenOptions::new()
156 .append(true)
157 .create(true)
158 .open(path)
159 .expect("Unable to open session file");
160
161 if let Err(e) = writeln!(file, "{}<->", &config) {
162 eprintln!("Couldn't write new configuration to file: {}", e);
163 }
164
165 file
166 }
167 };
168
169 CompletionFile {
170 file: Some(file),
171 overrides,
172 transcript,
173 last_read_input: String::new(),
174 last_written_input: String::new()
175 }
176 });
177
178 file.unwrap_or_default()
179 }
180
181 pub fn parse_stops(&self) -> Vec<String> {
182 self.stop.iter()
183 .map(|s| s.iter().map(|s| s.split(",").map(|s| s.trim().to_string())).flatten())
184 .flatten()
185 .collect()
186 }
187
188 pub fn parse_stream_option(&self) -> Result<bool, ClashingArgumentsError> {
189 match (self.quiet, self.stream) {
190 (Some(true), Some(true)) => return Err(ClashingArgumentsError::new(
191 "Having both quiet and stream enabled doesn't make sense."
192 )),
193 (Some(true), None) |
194 (Some(true), Some(false)) |
195 (None, Some(false)) |
196 (Some(false), Some(false)) => Ok(false),
197 (Some(false), None) |
198 (Some(false), Some(true)) |
199 (None, Some(true)) |
200 (None, None) => Ok(true)
201 }
202 }
203
204 pub fn validate(&self) -> Result<(), ClashingArgumentsError> {
205 if self.name.is_none() {
206 if self.append.is_some() {
207 return Err(ClashingArgumentsError::new(
208 "The append option also requires a session name"));
209 }
210
211 if self.overwrite.unwrap_or(false) {
212 return Err(ClashingArgumentsError::new(
213 "The overwrite options also requires a session name"));
214 }
215 }
216
217 if self.ai_responds_first.unwrap_or(false) && self.append.is_some() {
218 return Err(ClashingArgumentsError::new(
219 "Specifying that the ai responds first with the append option is nonsensical"));
220 }
221
222 if let Some(count) = self.response_count {
223 if count == 0 {
224 return Err(ClashingArgumentsError::new("The response count should be more than 0"));
225 }
226 }
227
228 Ok(())
229 }
230}
231
/// Error describing a set of options that contradict each other.
///
/// The wrapped string is a human-readable explanation of the conflict.
#[derive(Debug)]
pub struct ClashingArgumentsError(String);

impl ClashingArgumentsError {
    /// Builds the error from anything convertible into a `String` message.
    pub fn new(error: impl Into<String>) -> Self {
        let message: String = error.into();
        ClashingArgumentsError(message)
    }
}
238
/// An open conversation session.
///
/// Holds the file the session is persisted to (if any), the per-session settings,
/// and the transcript accumulated so far.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the struct
/// itself, so the type can be named, debugged and defaulted for any `T`.
#[derive(Debug, Default)]
pub struct CompletionFile<T> {
    /// Append handle to the session file; `None` when no named session is in use.
    pub file: Option<File>,
    /// Per-session settings (the header of the session file when one is loaded).
    pub overrides: T,
    /// Conversation text so far; `write` appends each recorded line plus `'\n'`.
    pub transcript: String,
    /// The most recent input returned by `read`, after any user prefix was added.
    pub last_read_input: String,
    /// The most recent line given to `write` that did not come from `read`.
    pub last_written_input: String,
}
247
248impl<T> CompletionFile<T>
249where
250 T: Clone + Default + DeserializeOwned + Serialize
251{
252 pub fn write(&mut self, line: String, no_context: bool, is_read: bool) -> io::Result<String> {
253 if !is_read {
254 self.last_written_input = line.clone();
255 }
256
257 if no_context {
258 return Ok(line)
259 }
260
261 match &mut self.file {
262 Some(file) => match writeln!(file, "{}", line) {
263 Ok(()) => {
264 self.transcript += &line;
265 self.transcript += "\n";
266 Ok(line)
267 },
268 Err(e) => Err(e)
269 },
270 None => {
271 self.transcript += &line;
272 self.transcript += "\n";
273 Ok(line)
274 }
275 }
276 }
277
278 pub fn read(
279 &mut self,
280 append: Option<&str>,
281 prefix_user: Option<&str>,
282 no_context: bool) -> Option<String>
283 {
284 let line = append
285 .map(|s| s.to_string())
286 .or_else(|| read_next_user_line(prefix_user))
287 .map(|s| s.trim().to_string());
288
289 line
290 .and_then(|line| {
291 let line = match &prefix_user {
292 Some(prefix) if !line.to_lowercase().starts_with(prefix) => {
293 format!("{}: {}", prefix, line)
294 },
295 _ => line
296 };
297 self.last_read_input = line.clone();
298 Some(line)
299 })
300 .and_then(|line| if no_context {
301 Some(line)
302 } else {
303 self.write(line, no_context, true).ok()
304 })
305 }
306}
307
308fn read_next_user_line(prefix_user: Option<&str>) -> Option<String> {
309 let mut rl = rustyline::Editor::<()>::new().expect("Failed to create rusty line editor");
310 let prefix = match prefix_user {
311 Some(user) => format!("{}: ", user),
312 None => String::new(),
313 };
314
315 match rl.readline(&prefix) {
316 Ok(line) => Some(String::from("") + line.trim_end()),
317 Err(_) => None
318 }
319}