ai_chain_openai_compatible/chatgpt/executor.rs

use super::prompt::completion_to_output;
use super::prompt::stream_to_output;
use async_openai::config::OpenAIConfig;
use async_openai::types::ChatCompletionRequestMessage;

use async_openai::types::ChatCompletionRequestUserMessageContent;
use ai_chain::options::Opt;
use ai_chain::options::Options;
use ai_chain::options::OptionsCascade;
use ai_chain::output::Output;
use ai_chain::tokens::TokenCollection;
use tiktoken_rs::{cl100k_base, get_bpe_from_tokenizer};
use tiktoken_rs::tokenizer::get_tokenizer;

use super::prompt::create_chat_completion_request;
use super::prompt::format_chat_messages;
use async_openai::error::OpenAIError;
use ai_chain::prompt::Prompt;

use ai_chain::tokens::PromptTokensError;
use ai_chain::tokens::{Tokenizer, TokenizerError};
use ai_chain::traits;
use ai_chain::traits::{ExecutorCreationError, ExecutorError};

use async_trait::async_trait;
use ai_chain::tokens::TokenCount;

use std::sync::Arc;
use tiktoken_rs::tokenizer::Tokenizer::Cl100kBase;
use crate::chatgpt::config::OAIConfig;
use crate::chatgpt::OpenAICompatibleInnerError;

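/// Executor for OpenAI-compatible chat completion APIs, generic over the
/// provider configuration `C`.
///
/// A minimal usage sketch, assuming some `OAIConfig` implementation (named
/// `ChatGPTConfig` here purely for illustration) and that `options` / `prompt`
/// have already been built with ai_chain's own APIs; nothing in this snippet
/// beyond the two method calls is defined in this file:
///
/// ```ignore
/// use ai_chain::traits::Executor as _;
///
/// let exec = Executor::<ChatGPTConfig>::new_with_options(options)?;
/// let output = exec.execute(&options, &prompt).await?;
/// ```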
#[derive(Clone)]
pub struct Executor<C: OAIConfig> {
    /// Provider configuration (API key, base URL, ...).
    config: C,
    /// Shared `async_openai` client used for every request.
    client: Arc<async_openai::Client<C>>,
    /// Default options applied to each invocation.
    options: Options,
}

impl<C: OAIConfig> Executor<C> {
    /// Creates an executor that reuses an existing `async_openai` client,
    /// keeping `options` as the per-executor defaults.
    pub fn for_client(client: async_openai::Client<C>, options: Options) -> Self {
        use ai_chain::traits::Executor as _;
        let mut exec = Self::new_with_options(options).unwrap();
        exec.client = Arc::new(client);
        exec
    }

    /// Resolves the model name from the cascaded options, falling back to the
    /// provider's default model.
    fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> String {
        let Some(Opt::Model(model)) = opts.get(ai_chain::options::OptDiscriminants::Model) else {
            return C::model_config().0;
        };
        model.to_name()
    }

    /// Layers per-call options on top of the executor's defaults.
    fn cascade<'a>(&'a self, opts: Option<&'a Options>) -> OptionsCascade<'a> {
        let mut v: Vec<&'a Options> = vec![&self.options];
        if let Some(o) = opts {
            v.push(o);
        }
        OptionsCascade::from_vec(v)
    }
}

/// Error type wrapping failures reported by the underlying `async_openai` client.
#[derive(thiserror::Error, Debug)]
#[error(transparent)]
pub enum Error {
    OpenAIError(#[from] OpenAIError),
}

impl<OConfig: OAIConfig> Executor<OConfig> {
    fn create_config() -> OConfig {
        OConfig::create()
    }

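    /// Estimates the number of prompt tokens that `messages` will consume for
    /// `model`. The per-message / per-name overheads and the final `+ 3`
    /// priming tokens follow the heuristic from OpenAI's token-counting
    /// cookbook example, so treat the result as an estimate rather than an
    /// exact billing figure.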
    fn num_tokens_from_messages(
        model: &str,
        messages: &[ChatCompletionRequestMessage],
    ) -> Result<usize, PromptTokensError> {
        // Models declared by the provider config are assumed to use cl100k_base;
        // anything else must resolve to cl100k_base through tiktoken's lookup.
        let (_, config_params) = OConfig::model_config();
        let tokenizer = if config_params.iter().any(|x| model.starts_with(x)) {
            Cl100kBase
        } else {
            let tokenizer = get_tokenizer(model).ok_or_else(|| PromptTokensError::NotAvailable)?;
            if tokenizer != Cl100kBase {
                return Err(PromptTokensError::NotAvailable);
            }
            tokenizer
        };

        let bpe = get_bpe_from_tokenizer(tokenizer).map_err(|_| PromptTokensError::NotAvailable)?;

        let (tokens_per_message, tokens_per_name) = if model.starts_with("gpt-3.5") {
            (
                4,  // every message follows <|start|>{role/name}\n{content}<|end|>\n
                -1, // if there's a name, the role is omitted
            )
        } else {
            (3, 1)
        };

        let mut num_tokens: i32 = 0;
        for message in messages {
            let (role, content, name) = match message {
                ChatCompletionRequestMessage::System(x) => (
                    x.role.to_string(),
                    x.content.to_owned().unwrap_or_default(),
                    None,
                ),
                ChatCompletionRequestMessage::User(x) => (
                    x.role.to_string(),
                    x.content
                        .as_ref()
                        .and_then(|x| match x {
                            ChatCompletionRequestUserMessageContent::Text(x) => Some(x.to_string()),
                            _ => None,
                        })
                        .unwrap_or_default(),
                    None,
                ),
                ChatCompletionRequestMessage::Assistant(x) => (
                    x.role.to_string(),
                    x.content.to_owned().unwrap_or_default(),
                    None,
                ),
                ChatCompletionRequestMessage::Tool(x) => (
                    x.role.to_string(),
                    x.content.to_owned().unwrap_or_default(),
                    None,
                ),
                ChatCompletionRequestMessage::Function(x) => (
                    x.role.to_string(),
                    x.content.to_owned().unwrap_or_default(),
                    None,
                ),
            };
            num_tokens += tokens_per_message;
            num_tokens += bpe.encode_with_special_tokens(&role).len() as i32;
            num_tokens += bpe.encode_with_special_tokens(&content).len() as i32;
            if let Some(name) = name {
                num_tokens += bpe.encode_with_special_tokens(name).len() as i32;
                num_tokens += tokens_per_name;
            }
        }
        // Every reply is primed with <|start|>assistant<|message|>.
        num_tokens += 3;
        Ok(num_tokens as usize)
    }
}

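// `traits::Executor` is where ai_chain drives this type: construction from
// `Options`, the async `execute` call (streaming or blocking depending on the
// cascaded options), and prompt-size accounting via `tokens_used` and
// `max_tokens_allowed`.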
#[async_trait]
impl<OConfig: OAIConfig> traits::Executor for Executor<OConfig> {
    type StepTokenizer<'a> = OpenAITokenizer<OConfig>;

    fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
        let mut cfg = Self::create_config();

        let opts = OptionsCascade::new().with_options(&options);

        if let Some(Opt::ApiKey(api_key)) = opts.get(ai_chain::options::OptDiscriminants::ApiKey) {
            cfg = cfg.with_api_key(api_key);
        }
        // Allow overriding the API base URL from the environment.
        if let Ok(base_url) = std::env::var("OPENAI_API_BASE_URL") {
            cfg = cfg.with_api_base(base_url);
        }
        let config = cfg.clone();

        let client = Arc::new(async_openai::Client::with_config(cfg));
        Ok(Self {
            config,
            client,
            options,
        })
    }

    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
        let opts = self.cascade(Some(options));
        let client = self.client.clone();
        let model = self.get_model_from_invocation_options(&opts);
        let input = create_chat_completion_request(model, prompt, opts.is_streaming()).unwrap();
        if opts.is_streaming() {
            // Streaming: wrap the response stream without waiting for completion.
            let res = client
                .chat()
                .create_stream(input)
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(stream_to_output(res))
        } else {
            // Blocking: wait for the full completion, then convert it.
            let res = client
                .chat()
                .create(input)
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(completion_to_output(res))
        }
    }

    fn tokens_used(
        &self,
        opts: &Options,
        prompt: &Prompt,
    ) -> Result<TokenCount, PromptTokensError> {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        let messages = format_chat_messages(prompt.to_chat()).map_err(|e| match e {
            OpenAICompatibleInnerError::StringTemplateError(e) => {
                PromptTokensError::PromptFormatFailed(e)
            }
            _ => PromptTokensError::UnableToCompute,
        })?;
        let tokens_used = Self::num_tokens_from_messages(&model, &messages)
            .map_err(|_| PromptTokensError::NotAvailable)?;

        Ok(TokenCount::new(
            self.max_tokens_allowed(opts),
            tokens_used as i32,
        ))
    }

    /// Looks up the model's context window via tiktoken, falling back to 4096
    /// if the reported size does not fit in an `i32`.
    fn max_tokens_allowed(&self, opts: &Options) -> i32 {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        tiktoken_rs::model::get_context_size(&model)
            .try_into()
            .unwrap_or(4096)
    }

    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
        None
    }

    fn get_tokenizer(&self, options: &Options) -> Result<OpenAITokenizer<OConfig>, TokenizerError> {
        Ok(OpenAITokenizer::new(self.cascade(Some(options))))
    }
}

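/// Tokenizer backed by tiktoken for the model this executor is configured with.
///
/// A minimal sketch of standalone use (hedged: `ChatGPTConfig` is an assumed
/// `OAIConfig` implementation, not something defined in this file):
///
/// ```ignore
/// use ai_chain::tokens::Tokenizer as _;
///
/// let tok = OpenAITokenizer::<ChatGPTConfig>::for_model_name("gpt-4");
/// let tokens = tok.tokenize_str("hello world")?;
/// assert_eq!(tok.to_string(tokens)?, "hello world");
/// ```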
pub struct OpenAITokenizer<C> {
    model_name: String,
    config: Option<C>,
}

impl<C: OAIConfig> OpenAITokenizer<C> {
    pub fn new(options: OptionsCascade) -> Self {
        let model_name = match options.get(ai_chain::options::OptDiscriminants::Model) {
            Some(Opt::Model(model_name)) => model_name.to_name(),
            _ => C::model_config().0,
        };
        Self::for_model_name(model_name)
    }

    pub fn for_model_name<S: Into<String>>(model_name: S) -> Self {
        let model_name: String = model_name.into();
        Self { model_name, config: None }
    }

    fn get_bpe_from_model(&self) -> Result<tiktoken_rs::CoreBPE, PromptTokensError> {
        use tiktoken_rs::get_bpe_from_model;
        // Models declared by the provider config are assumed to use cl100k_base;
        // everything else goes through tiktoken's own model lookup.
        let (_, config_params) = C::model_config();
        if config_params.iter().any(|x| self.model_name.starts_with(x)) {
            return cl100k_base().map_err(|_| PromptTokensError::NotAvailable);
        }
        get_bpe_from_model(&self.model_name).map_err(|_| PromptTokensError::NotAvailable)
    }
}

impl<C: OAIConfig> Tokenizer for OpenAITokenizer<C> {
    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
        Ok(self
            .get_bpe_from_model()
            .map_err(|_| TokenizerError::TokenizationError)?
            .encode_ordinary(doc)
            .into())
    }

    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
        let res = self
            .get_bpe_from_model()
            .map_err(|_e| TokenizerError::ToStringError)?
            .decode(tokens.as_usize()?)
            .map_err(|_e| TokenizerError::ToStringError)?;
        Ok(res)
    }
}