use async_recursion::async_recursion;
use clap::{Args, ValueEnum};
use reqwest::Client;
use derive_more::From;
use serde::{Serialize, Deserialize};
use crate::openai::{OpenAISessionCommand, OpenAIError};
use crate::cohere::session::{CohereSessionCommand, CohereError};
use crate::completion::{CompletionFile, CompletionOptions, ClashingArgumentsError};
use crate::Config;

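/// CLI arguments for starting or resuming an AI completion session.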
#[derive(Args, Clone, Default, Debug, Serialize, Deserialize)]
pub struct SessionCommand {
    #[command(flatten)]
    #[serde(flatten)]
    pub completion: CompletionOptions,

    #[arg(value_enum, long, short)]
    pub model: Option<Model>,

    #[arg(value_enum, long)]
    pub model_focus: Option<ModelFocus>,

    #[arg(short, long)]
    pub prompt: Option<String>,

    #[arg(long)]
    pub prompt_path: Option<String>,

    #[arg(long)]
    pub provider: Option<Provider>,
}

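/// Session settings resolved from CLI arguments, the loaded config, and
/// any overrides persisted in the session file.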
#[derive(Debug, Default)]
pub(crate) struct SessionOptions {
    pub ai_responds_first: bool,
    pub completion: CompletionOptions,
    pub file: CompletionFile<SessionCommand>,
    pub model: Model,
    pub model_focus: ModelFocus,
    pub prompt: String,
    pub stream: bool,
    pub no_context: bool,
    pub provider: Provider,
}

impl TryFrom<(&SessionCommand, &Config)> for SessionOptions {
    type Error = SessionError;

    fn try_from((command, config): (&SessionCommand, &Config)) -> Result<Self, Self::Error> {
        let file = command.completion.load_session_file::<SessionCommand>(config, command.clone());
        let completion = if file.file.is_some() {
            command.completion.merge(&file.overrides.completion)
        } else {
            command.completion.clone()
        };

        completion.validate()?;

        Ok(SessionOptions {
            ai_responds_first: completion.ai_responds_first.unwrap_or(false),
            stream: completion.parse_stream_option()?,
            prompt: command.parse_prompt_option(),
            no_context: command.parse_no_context_option(),
            model: command.model.unwrap_or(Model::XXLarge),
            model_focus: command.model_focus.unwrap_or(ModelFocus::Text),
            provider: command.provider.unwrap_or(Provider::OpenAI),
            completion,
            file,
        })
    }
}

pub type SessionResult = Result<Vec<String>, SessionError>;

pub trait SessionResultExt {
    fn single_result(&self) -> Option<&str>;
}

impl SessionResultExt for SessionResult {
    fn single_result(&self) -> Option<&str> {
        self.as_ref().ok().and_then(|r| r.first()).map(String::as_str)
    }
}
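
// A minimal usage sketch of `SessionResultExt`, assuming only the items
// in this module: `single_result` yields the first completion, if any.
#[cfg(test)]
mod session_result_tests {
    use super::*;

    #[test]
    fn yields_first_completion_when_present() {
        let result: SessionResult = Ok(vec!["hello".to_string()]);
        assert_eq!(result.single_result(), Some("hello"));
    }

    #[test]
    fn yields_none_when_empty() {
        let result: SessionResult = Ok(vec![]);
        assert_eq!(result.single_result(), None);
    }
}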

#[derive(From, Debug)]
pub enum SessionError {
    NoMatchingModel,
    TemperatureOutOfValidRange,
    ClashingArguments(ClashingArgumentsError),
    CohereError(CohereError),
    OpenAIError(OpenAIError),
    IOError(std::io::Error),
    DeserializeError(reqwest::Error),
    Unauthorized,
}

impl SessionCommand {
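    /// Runs the session: resolves options, builds the provider-specific
    /// command, then alternates between reading user input and printing
    /// model completions until input is exhausted or an `append`
    /// invocation returns its single response.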
    #[async_recursion]
    pub async fn run(&self, client: &Client, config: &Config) -> SessionResult {
        let mut options = SessionOptions::try_from((self, config))?;
        let prefix_user = options.completion.prefix_user.as_ref().map(String::as_str);

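        // `Result` doubles as an Either here: `Ok` carries the OpenAI
        // command and `Err` carries the Cohere one, letting both be
        // matched below without a dedicated enum.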
        let command = match options.provider {
            Provider::OpenAI => Ok(OpenAISessionCommand::try_from(&options)?),
            Provider::Cohere => Err(CohereSessionCommand::try_from(&options)?),
        };

        let print_output = !options.completion.quiet.unwrap_or(false);
        if print_output && !options.file.transcript.is_empty() {
            println!("{}", options.file.transcript);
        }

        let line = if options.ai_responds_first {
            String::new()
        } else {
            let append = options.completion.append.as_ref().map(String::as_str);

            if let Some(line) = options.file.read(append, prefix_user, options.no_context) {
                line
            } else {
                return Ok(vec![]);
            }
        };

        loop {
            let transcript = &options.file.transcript;
            let prompt = match (options.no_context, &options.completion.prefix_ai) {
                (true, None) => options.prompt.replace("${TRANSCRIPT}", &line),
                (true, Some(prefix)) => options.prompt.replace("${TRANSCRIPT}", &line) + prefix,
                (false, None) => options.prompt.replace("${TRANSCRIPT}", transcript),
                (false, Some(prefix)) => options.prompt.replace("${TRANSCRIPT}", transcript) + prefix,
            };

            let result = match &command {
                Ok(command) => command.run(client, config, &prompt).await?,
                Err(command) => command.run(client, config, &prompt).await?,
            };

            if let Some(count) = options.completion.response_count {
                if count > 1 {
                    return Ok(result);
                }
            }

            let text = result.first().expect("provider returned no completions").trim();
            let written_response = match &options.completion.prefix_ai {
                Some(prefix) => format!("{}{}", prefix, text),
                None => text.to_owned(),
            };
            let text = options.file.write(text.into(), options.no_context, false)?;

            if print_output {
                println!("{}", written_response);
            }

            if options.completion.append.is_some() {
                return Ok(vec![text.to_string()]);
            }

            if options.file.read(None, prefix_user, options.no_context).is_none() {
                return Ok(vec![]);
            }
        }
    }

    pub fn parse_no_context_option(&self) -> bool {
        self.completion.no_context
            .unwrap_or_else(|| matches!(self.model_focus, Some(ModelFocus::Code)))
    }

    pub fn parse_prompt_option(&self) -> String {
        self.prompt
            .clone()
            .or_else(|| {
                self.prompt_path
                    .clone()
                    .and_then(|path| std::fs::read_to_string(path).ok())
            })
            .unwrap_or_else(|| {
                match self.model_focus {
                    Some(ModelFocus::Text) | None => DEFAULT_CHAT_PROMPT_WRAPPER.to_owned(),
                    Some(ModelFocus::Code) => DEFAULT_CODE_PROMPT_WRAPPER.to_owned(),
                }
            })
    }
}
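
// A minimal sketch of the prompt-resolution fallbacks (explicit prompt,
// then prompt path, then a default chosen by model focus); test names and
// structure are illustrative, assuming only the items in this module.
#[cfg(test)]
mod prompt_option_tests {
    use super::*;

    #[test]
    fn defaults_to_chat_wrapper_and_keeps_context() {
        let command = SessionCommand::default();
        assert_eq!(command.parse_prompt_option(), DEFAULT_CHAT_PROMPT_WRAPPER);
        assert!(!command.parse_no_context_option());
    }

    #[test]
    fn code_focus_defaults_to_code_wrapper_without_context() {
        let command = SessionCommand {
            model_focus: Some(ModelFocus::Code),
            ..SessionCommand::default()
        };
        assert_eq!(command.parse_prompt_option(), DEFAULT_CODE_PROMPT_WRAPPER);
        assert!(command.parse_no_context_option());
    }
}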
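
/// Which API backend serves the session's completion requests.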
#[derive(Copy, Clone, Debug, Default, ValueEnum, Serialize, Deserialize)]
pub enum Provider {
    Cohere,

    #[default]
    OpenAI,
}
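
/// Requested model size tier.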
#[derive(Copy, Clone, Debug, Default, ValueEnum, Serialize, Deserialize)]
pub enum Model {
    Tiny,

    Small,

    Medium,

    Large,

    XLarge,

    #[default]
    XXLarge,
}
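
/// Whether the session is tuned for code generation or conversational text.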
#[derive(Copy, Clone, Default, Debug, ValueEnum, Serialize, Deserialize)]
pub enum ModelFocus {
    Code,
    #[default]
    Text,
}
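
// `${TRANSCRIPT}` is substituted on each turn of `run`: with no-context
// enabled it becomes just the latest input line, otherwise the whole
// running transcript.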
const DEFAULT_CODE_PROMPT_WRAPPER: &str = "${TRANSCRIPT}";
const DEFAULT_CHAT_PROMPT_WRAPPER: &str = "
The following is a transcript between a helpful AI assistant and a human. The AI assistant can provide factual information (but only from before mid-2021, when its training data cuts off), ask clarifying questions, and engage in chit-chat.

Transcript:

${TRANSCRIPT}

";