ai/
session.rs

1use async_recursion::async_recursion;
2use clap::{Args,ValueEnum};
3use reqwest::Client;
4use derive_more::From;
5use serde::{Serialize,Deserialize};
6use crate::openai::{OpenAISessionCommand,OpenAIError};
7use crate::cohere::session::{CohereSessionCommand,CohereError};
8use crate::completion::{CompletionFile,CompletionOptions,ClashingArgumentsError};
9use crate::Config;
10
11#[derive(Args, Clone, Default, Debug, Serialize, Deserialize)]
12pub struct SessionCommand {
13    #[command(flatten)]
14    #[serde(flatten)]
15    pub completion: CompletionOptions,
16
17    /// Model size
18    #[arg(value_enum, long, short)]
19    pub model: Option<Model>,
20
21    /// Model focus
22    #[arg(value_enum, long)]
23    pub model_focus: Option<ModelFocus>,
24
25    /// Prompt
26    #[arg(short, long)]
27    pub prompt: Option<String>,
28
29    /// Prompt path
30    #[arg(long)]
31    pub prompt_path: Option<String>,
32
33    /// Provider
34    #[arg(long)]
35    pub provider: Option<Provider>,
36}
37
38#[derive(Debug, Default)]
39pub(crate) struct SessionOptions {
40    pub ai_responds_first: bool,
41    pub completion: CompletionOptions,
42    pub file: CompletionFile<SessionCommand>,
43    pub model: Model,
44    pub model_focus: ModelFocus,
45    pub prompt: String,
46    pub stream: bool,
47    pub no_context: bool,
48    pub provider: Provider,
49}
50
51impl TryFrom<(&SessionCommand, &Config)> for SessionOptions {
52    type Error = SessionError;
53
54    fn try_from((command, config): (&SessionCommand, &Config)) -> Result<Self, Self::Error> {
55        let file = command.completion.load_session_file::<SessionCommand>(config, command.clone());
56        let completion = if file.file.is_some() {
57            command.completion.merge(&file.overrides.completion)
58        } else {
59            command.completion.clone()
60        };
61
62        completion.validate()?;
63
64        Ok(SessionOptions {
65            ai_responds_first: completion.ai_responds_first.unwrap_or(false),
66            stream: completion.parse_stream_option()?,
67            prompt: command.parse_prompt_option(),
68            no_context: command.parse_no_context_option(),
69            model: command.model.unwrap_or(Model::XXLarge),
70            model_focus: command.model_focus.unwrap_or(ModelFocus::Text),
71            provider: command.provider.unwrap_or(Provider::OpenAI),
72            completion,
73            file
74        })
75    }
76}
77
78pub type SessionResult = Result<Vec<String>, SessionError>;
79pub trait SessionResultExt {
80    fn single_result(&self) -> Option<&str>;
81}
82
83impl SessionResultExt for SessionResult {
84    fn single_result(&self) -> Option<&str> {
85        self.as_ref().ok().and_then(|r| r.first()).map(|x| &**x)
86    }
87}
88
89#[derive(From, Debug)]
90pub enum SessionError {
91    NoMatchingModel,
92    TemperatureOutOfValidRange,
93    ClashingArguments(ClashingArgumentsError),
94    CohereError(CohereError),
95    OpenAIError(OpenAIError),
96    IOError(std::io::Error),
97    DeserializeError(reqwest::Error),
98    Unauthorized
99}
100
101impl SessionCommand {
102    #[async_recursion]
103    pub async fn run(&self, client: &Client, config: &Config) -> SessionResult {
104        let mut options = SessionOptions::try_from((self, config))?;
105        let prefix_user = options.completion.prefix_user.as_ref().map(|u| &**u);
106
107        // The commands need to be instantiated before printing the opening prompt because they can
108        // print warnings about mismatched options without failing.
109        let command = match options.provider {
110            Provider::OpenAI => Ok(OpenAISessionCommand::try_from(&options)?),
111            Provider::Cohere => Err(CohereSessionCommand::try_from(&options)?),
112        };
113
114        let print_output = !options.completion.quiet.unwrap_or(false);
115        if print_output && options.file.transcript.len() > 0 {
116            println!("{}", options.file.transcript);
117        }
118
119        let line = if options.ai_responds_first {
120            String::new()
121        } else {
122            let append = options.completion.append.as_ref().map(|a| &**a);
123
124            if let Some(line) = options.file.read(append, prefix_user, options.no_context) {
125                line
126            } else {
127                return Ok(vec![]);
128            }
129        };
130
131        loop {
132            let transcript = &options.file.transcript;
133            let prompt = &options.prompt;
134            let prompt = match (options.no_context, &options.completion.prefix_ai) {
135                (true, None) => prompt.replace("${TRANSCRIPT}", &line),
136                (true, Some(prefix)) => prompt.replace("${TRANSCRIPT}", &line) + &prefix,
137                (false, None) => prompt.replace("${TRANSCRIPT}", transcript),
138                (false, Some(prefix)) =>
139                    prompt.replace("${TRANSCRIPT}", transcript) + &prefix
140            };
141
142            let result = match &command {
143                Ok(command) => command.run(client, config, &prompt).await?,
144                Err(command) => command.run(client, config, &prompt).await?,
145            };
146
147            if let Some(count) = options.completion.response_count {
148                if count > 1 {
149                    return Ok(result);
150                }
151            }
152
153            let text = result.first().unwrap().trim();
154            let written_response = match &options.completion.prefix_ai {
155                Some(prefix) => format!("{}{}", prefix, text),
156                None => text.to_owned()
157            };
158            let text = options.file.write(text.into(), options.no_context, false)?;
159
160            if !options.completion.quiet.unwrap_or(false) {
161                println!("{}", written_response);
162            }
163
164            if options.completion.append.is_some() {
165                return Ok(vec![ text.to_string() ]);
166            }
167
168            if let None = options.file.read(None, prefix_user, options.no_context) {
169                return Ok(vec![]);
170            }
171        }
172    }
173
174    pub fn parse_no_context_option(&self) -> bool {
175        self.completion.no_context.unwrap_or_else(|| {
176            match self.model_focus {
177                Some(ModelFocus::Code) => true,
178                _ => false,
179            }
180        })
181    }
182
183    pub fn parse_prompt_option(&self) -> String {
184        self.prompt
185            .clone()
186            .or_else(|| {
187                self.prompt_path
188                    .clone()
189                    .and_then(|path| {
190                        std::fs::read_to_string(path).ok()
191                    })
192            })
193            .unwrap_or_else(|| {
194                match self.model_focus {
195                    Some(ModelFocus::Text) | None => DEFAULT_CHAT_PROMPT_WRAPPER.to_owned(),
196                    Some(ModelFocus::Code) => DEFAULT_CODE_PROMPT_WRAPPER.to_owned(),
197                }
198            })
199    }
200}
201
202#[derive(Copy, Clone, Debug, Default, ValueEnum, Serialize, Deserialize)]
203pub enum Provider {
204    /// Cohere
205    Cohere,
206
207    /// OpenAI
208    #[default]
209    OpenAI,
210}
211
212#[derive(Copy, Clone, Debug, Default, ValueEnum, Serialize, Deserialize)]
213pub enum Model {
214    /// In the range of 0 - 1 billion parameters. OpenAI's Ada, Cohere's "small" option.
215    Tiny,
216
217    /// In the range of 1 - 5 billion parameters. OpenAI's Babbage option.
218    Small,
219
220    /// In the range of 5 - 10 billion parameters. OpenAI's Curie, Cohere's "medium" option.
221    Medium,
222
223    /// In the range of 10 - 50 billion parameters. Cohere's large option.
224    Large,
225
226    /// In the range of 50 - 150 billion paramaters. Cohere's xlarge option.
227    XLarge,
228
229    /// Greater than 150 billion paramaters. OpenAI's davinci model.
230    #[default]
231    XXLarge
232}
233
234#[derive(Copy, Clone, Default, Debug, ValueEnum, Serialize, Deserialize)]
235pub enum ModelFocus {
236    Code,
237    #[default]
238    Text
239}
240
241
// Default prompt templates, chosen by `SessionCommand::parse_prompt_option` when no explicit
// prompt is available. Before each request, `${TRANSCRIPT}` is replaced with either the
// transcript or the user's input line (in no-context mode).

/// Code-focused default: the transcript alone, with no surrounding instructions.
const DEFAULT_CODE_PROMPT_WRAPPER: &str = "${TRANSCRIPT}";

/// Chat-focused default: a short framing paragraph followed by the transcript.
const DEFAULT_CHAT_PROMPT_WRAPPER: &str = concat!(
    "\n",
    "The following is a transcript between a helpful AI assistant and a human. ",
    "The AI assistant can provide factual information (but only from before mid 2021, ",
    "when its training data cuts off), ask clarifying questions, and engage in chit chat.\n",
    "\n",
    "Transcript:\n",
    "\n",
    "${TRANSCRIPT}\n",
    "\n"
);