use clap::Args;
use serde::{Serialize,Deserialize};
use serde::de::DeserializeOwned;
use std::fs::{self,File,OpenOptions};
use std::io::{self,Write};
use crate::Config;
// Options controlling a completion run, parsed from the command line by clap
// and (de)serializable with serde.
// Every field is an `Option` so that two option sets can be layered with
// `CompletionOptions::merge`, where values set on the receiver take precedence.
// Plain `//` comments are used on purpose: clap turns `///` doc comments into
// `--help` text, which would change the CLI output.
// NOTE(review): field order determines the order of clap's help listing and of
// serde's serialized keys — keep it stable.
#[derive(Args, Clone, Default, Debug, Serialize, Deserialize)]
pub struct CompletionOptions {
// Let the AI produce the first message; `validate` rejects it together with `append`.
#[arg(long)]
pub ai_responds_first: Option<bool>,
// Text to append to the session; `validate` requires a session `name` with it.
// NOTE(review): presumably forwarded as the `append` argument of `CompletionFile::read`.
#[arg(long)]
pub append: Option<String>,
// Sampling temperature (`-t` / `--temperature`); not interpreted in this file.
#[arg(long, short)]
pub temperature: Option<f32>,
// Session name; the session file lives at `<config dir>/sessions/<name>`.
#[arg(short, long)]
pub name: Option<String>,
// NOTE(review): presumably disables recording to the session transcript
// (compare the `no_context` argument of `CompletionFile::write`) — confirm.
#[arg(long)]
pub no_context: Option<bool>,
// Not read in this file. NOTE(review): presumably limits the run to one exchange.
#[arg(long)]
pub once: Option<bool>,
// Truncate an existing session file before loading it; requires `name`.
#[arg(long)]
pub overwrite: Option<bool>,
// Quiet output; forces streaming off and clashes with `stream = true`.
#[arg(long)]
pub quiet: Option<bool>,
// Not read in this file. NOTE(review): presumably the label for AI lines.
#[arg(long)]
pub prefix_ai: Option<String>,
// NOTE(review): presumably the label passed to `CompletionFile::read` as `prefix_user`.
#[arg(long)]
pub prefix_user: Option<String>,
// Not settable from the command line (`arg(skip)`); `validate` rejects `Some(0)`.
#[arg(skip)]
pub response_count: Option<usize>,
// Stop sequences; each value may hold several comma-separated entries (see `parse_stops`).
#[arg(long)]
pub stop: Option<Vec<String>>,
// Whether to stream output; combined with `quiet` by `parse_stream_option` (default: on).
#[arg(long)]
pub stream: Option<bool>,
// Not read in this file. NOTE(review): presumably a maximum token count.
#[arg(long)]
pub tokens_max: Option<usize>,
// Not read in this file; meaning cannot be determined from here.
#[arg(long)]
pub tokens_balance: Option<f32>,
}
impl CompletionOptions {
pub fn merge(&self, merged: &CompletionOptions) -> Self {
let original = self.clone();
let merged = merged.clone();
CompletionOptions {
ai_responds_first: original.ai_responds_first.or(merged.ai_responds_first),
append: original.append.or(merged.append),
temperature: original.temperature.or(merged.temperature),
name: original.name.or(merged.name),
overwrite: original.overwrite.or(merged.overwrite),
once: original.once.or(merged.once),
quiet: original.quiet.or(merged.quiet),
prefix_ai: original.prefix_ai.or(merged.prefix_ai),
prefix_user: original.prefix_user.or(merged.prefix_user),
stop: original.stop.or(merged.stop),
stream: original.stream.or(merged.stream),
tokens_max: original.tokens_max.or(merged.tokens_max),
tokens_balance: original.tokens_balance.or(merged.tokens_balance),
no_context: original.no_context.or(merged.no_context),
response_count: original.response_count.or(merged.response_count),
}
}
pub fn load_session_file<T>(&self, config: &Config, mut overrides: T) -> CompletionFile<T>
where
T: Clone + Default + DeserializeOwned + Serialize
{
let session_dir = {
let mut dir = config.dir.clone();
dir.push("sessions");
dir
};
fs::create_dir_all(&session_dir).expect("Config directory could not be created");
if self.overwrite.unwrap_or(false) {
let path = {
let mut path = session_dir.clone();
path.push(self.name.as_ref().unwrap());
path
};
let file = OpenOptions::new().write(true).truncate(true).open(path);
if let Ok(mut session_file) = file {
session_file.write_all(b"").expect("Unable to write to session file");
session_file.flush().expect("Unable to write to session file");
}
}
let file = self.name.clone().map(|name| {
let path = {
let mut path = session_dir.clone();
path.push(name);
path
};
let mut transcript = String::new();
let file = match fs::read_to_string(&path) {
Ok(mut session_config) if session_config.find("<->").is_some() => {
let divider_index = session_config.find("<->").unwrap();
transcript = session_config
.split_off(divider_index + 4)
.trim_start()
.to_string();
session_config.truncate(divider_index);
overrides = serde_yaml::from_str(&session_config)
.expect("Serializing self to yaml config should work 100% of the time");
OpenOptions::new()
.append(true)
.create(true)
.open(path)
.expect("Unable to open session file")
},
_ => {
let config = serde_yaml::to_string(&overrides)
.expect("Serializing self to yaml config should work 100% of the time");
let mut file = OpenOptions::new()
.append(true)
.create(true)
.open(path)
.expect("Unable to open session file");
if let Err(e) = writeln!(file, "{}<->", &config) {
eprintln!("Couldn't write new configuration to file: {}", e);
}
file
}
};
CompletionFile {
file: Some(file),
overrides,
transcript,
last_read_input: String::new(),
last_written_input: String::new()
}
});
file.unwrap_or_default()
}
pub fn parse_stops(&self) -> Vec<String> {
self.stop.iter()
.map(|s| s.iter().map(|s| s.split(",").map(|s| s.trim().to_string())).flatten())
.flatten()
.collect()
}
pub fn parse_stream_option(&self) -> Result<bool, ClashingArgumentsError> {
match (self.quiet, self.stream) {
(Some(true), Some(true)) => return Err(ClashingArgumentsError::new(
"Having both quiet and stream enabled doesn't make sense."
)),
(Some(true), None) |
(Some(true), Some(false)) |
(None, Some(false)) |
(Some(false), Some(false)) => Ok(false),
(Some(false), None) |
(Some(false), Some(true)) |
(None, Some(true)) |
(None, None) => Ok(true)
}
}
pub fn validate(&self) -> Result<(), ClashingArgumentsError> {
if self.name.is_none() {
if self.append.is_some() {
return Err(ClashingArgumentsError::new(
"The append option also requires a session name"));
}
if self.overwrite.unwrap_or(false) {
return Err(ClashingArgumentsError::new(
"The overwrite options also requires a session name"));
}
}
if self.ai_responds_first.unwrap_or(false) && self.append.is_some() {
return Err(ClashingArgumentsError::new(
"Specifying that the ai responds first with the append option is nonsensical"));
}
if let Some(count) = self.response_count {
if count == 0 {
return Err(ClashingArgumentsError::new("The response count should be more than 0"));
}
}
Ok(())
}
}
/// Error reporting that command-line options contradict each other
/// (produced by `CompletionOptions::validate` and `parse_stream_option`).
#[derive(Debug)]
pub struct ClashingArgumentsError(String);

impl ClashingArgumentsError {
    /// Creates an error carrying the given human-readable message.
    pub fn new(error: impl Into<String>) -> Self {
        Self(error.into())
    }
}

/// Displays the bare message, so the error can be shown to users directly.
impl std::fmt::Display for ClashingArgumentsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

/// Lets the error be boxed as `Box<dyn Error>` and propagated with `?`.
impl std::error::Error for ClashingArgumentsError {}
/// A conversation session: its optional backing file, the per-session option
/// overrides, and the transcript text.
///
/// Trait bounds live on the `impl` blocks that need them, not on the struct,
/// so the type itself can be named and defaulted for any `T`.
#[derive(Debug, Default)]
pub struct CompletionFile<T> {
    /// File that recorded lines are appended to; `None` when the session is
    /// not backed by a file.
    pub file: Option<File>,
    /// Per-session settings (`load_session_file` reads them from, or writes
    /// them to, the file header).
    pub overrides: T,
    /// Conversation text recorded so far.
    pub transcript: String,
    /// The last line produced by `read`.
    pub last_read_input: String,
    /// The last line passed to `write` with `is_read == false`.
    pub last_written_input: String,
}
impl<T> CompletionFile<T>
where
    T: Clone + Default + DeserializeOwned + Serialize,
{
    /// Records `line` in the session and hands it back.
    ///
    /// When `is_read` is false, the line is also remembered as
    /// `last_written_input`. When `no_context` is true, nothing is persisted.
    /// Otherwise the line plus a newline is written to the session file (if
    /// any) and appended to `transcript`.
    ///
    /// # Errors
    /// Returns the I/O error if writing to the session file fails. The
    /// transcript is then left unchanged.
    pub fn write(&mut self, line: String, no_context: bool, is_read: bool) -> io::Result<String> {
        if !is_read {
            self.last_written_input.clone_from(&line);
        }
        if no_context {
            return Ok(line);
        }
        if let Some(file) = self.file.as_mut() {
            writeln!(file, "{}", line)?;
        }
        self.transcript.push_str(&line);
        self.transcript.push('\n');
        Ok(line)
    }

    /// Obtains the next user line and records it in the session.
    ///
    /// The line is `append` when given; otherwise it is read interactively via
    /// `read_next_user_line`. Surrounding whitespace is trimmed.
    ///
    /// If `prefix_user` is set and the line does not already start with it
    /// (compared case-insensitively), the line becomes `"<prefix>: <line>"`.
    /// The result is stored in `last_read_input`. Unless `no_context` is set,
    /// it is also recorded via `write`.
    ///
    /// Returns `None` when no input could be read, or when recording the line
    /// fails (the I/O error is reported on stderr).
    pub fn read(
        &mut self,
        append: Option<&str>,
        prefix_user: Option<&str>,
        no_context: bool,
    ) -> Option<String> {
        let input = match append {
            Some(text) => text.to_string(),
            None => read_next_user_line(prefix_user)?,
        };
        let input = input.trim();
        let line = match prefix_user {
            // Lowercase both sides. Lowercasing only the line meant a prefix
            // with capitals (e.g. "User") never matched and got duplicated.
            Some(prefix) if !input.to_lowercase().starts_with(&prefix.to_lowercase()) => {
                format!("{}: {}", prefix, input)
            }
            _ => input.to_string(),
        };
        self.last_read_input = line.clone();
        if no_context {
            return Some(line);
        }
        match self.write(line, false, true) {
            Ok(line) => Some(line),
            Err(e) => {
                // `None` otherwise looks like "no input"; say why.
                eprintln!("Unable to write to session file: {}", e);
                None
            }
        }
    }
}
/// Prompts on the terminal for one line of user input.
///
/// The prompt is `"<prefix_user>: "` when a user prefix is given and empty
/// otherwise. Trailing whitespace is removed from the entered text. Any
/// rustyline error (e.g. end of input or an interrupt) yields `None`.
///
/// A fresh editor is created on every call, so no line history carries over
/// between calls.
///
/// # Panics
/// If the rustyline editor cannot be created.
fn read_next_user_line(prefix_user: Option<&str>) -> Option<String> {
    let mut editor = rustyline::Editor::<()>::new().expect("Failed to create rusty line editor");
    let prompt = prefix_user.map_or_else(String::new, |user| format!("{}: ", user));
    editor
        .readline(&prompt)
        .ok()
        .map(|entered| entered.trim_end().to_owned())
}