ai_chain_openai/chatgpt/executor.rs

use super::error::OpenAIInnerError;
use super::prompt::completion_to_output;
use super::prompt::stream_to_output;
use async_openai::config::OpenAIConfig;
use async_openai::types::ChatCompletionRequestMessage;

use async_openai::types::ChatCompletionRequestUserMessageContent;
use ai_chain::options::Opt;
use ai_chain::options::Options;
use ai_chain::options::OptionsCascade;
use ai_chain::output::Output;
use ai_chain::tokens::TokenCollection;
use tiktoken_rs::{cl100k_base, get_bpe_from_tokenizer};
use tiktoken_rs::tokenizer::get_tokenizer;

use super::prompt::create_chat_completion_request;
use super::prompt::format_chat_messages;
use async_openai::error::OpenAIError;
use ai_chain::prompt::Prompt;

use ai_chain::tokens::PromptTokensError;
use ai_chain::tokens::{Tokenizer, TokenizerError};
use ai_chain::traits;
use ai_chain::traits::{ExecutorCreationError, ExecutorError};

use async_trait::async_trait;
use ai_chain::tokens::TokenCount;

use std::sync::Arc;
use tiktoken_rs::tokenizer::Tokenizer::Cl100kBase;

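/// Executor for the OpenAI chat-completion API, backed by an `async_openai` client.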
#[derive(Clone)]
pub struct Executor {
    client: Arc<async_openai::Client<OpenAIConfig>>,
    options: Options,
}

impl Default for Executor {
    fn default() -> Self {
        let options = Options::default();
        let client = Arc::new(async_openai::Client::new());
        Self { client, options }
    }
}

impl Executor {
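    /// Creates an executor that reuses an already-configured `async_openai` client together with the given options.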
    pub fn for_client(client: async_openai::Client<OpenAIConfig>, options: Options) -> Self {
        use ai_chain::traits::Executor as _;
        let mut exec = Self::new_with_options(options).unwrap();
        exec.client = Arc::new(client);
        exec
    }

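    /// Resolves the model name from the cascaded options, defaulting to "gpt-3.5-turbo".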
    fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> String {
        let Some(Opt::Model(model)) = opts.get(ai_chain::options::OptDiscriminants::Model) else {
            return "gpt-3.5-turbo".to_string();
        };
        model.to_name()
    }

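    /// Builds an options cascade from the executor's own options plus optional per-call overrides.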
    fn cascade<'a>(&'a self, opts: Option<&'a Options>) -> OptionsCascade<'a> {
        let mut v: Vec<&'a Options> = vec![&self.options];
        if let Some(o) = opts {
            v.push(o);
        }
        OptionsCascade::from_vec(v)
    }
}

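/// Error type wrapping errors returned by the OpenAI API client.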
#[derive(thiserror::Error, Debug)]
#[error(transparent)]
pub enum Error {
    OpenAIError(#[from] OpenAIError),
}

#[async_trait]
impl traits::Executor for Executor {
    type StepTokenizer<'a> = OpenAITokenizer;
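    /// Creates a new executor, taking the API key from the options and, when set, the organization id (`OPENAI_ORG_ID`) and API base URL (`OPENAI_API_BASE_URL`) from the environment.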
    fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
        let mut cfg = OpenAIConfig::new();

        let opts = OptionsCascade::new().with_options(&options);

        if let Some(Opt::ApiKey(api_key)) = opts.get(ai_chain::options::OptDiscriminants::ApiKey) {
            cfg = cfg.with_api_key(api_key);
        }

        if let Ok(org_id) = std::env::var("OPENAI_ORG_ID") {
            cfg = cfg.with_org_id(org_id);
        }

        if let Ok(base_url) = std::env::var("OPENAI_API_BASE_URL") {
            cfg = cfg.with_api_base(base_url);
        }

        let client = Arc::new(async_openai::Client::with_config(cfg));
        Ok(Self { client, options })
    }

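    /// Sends the prompt to the chat-completions endpoint, streaming the response when streaming is enabled in the options.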
    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
        let opts = self.cascade(Some(options));
        let client = self.client.clone();
        let model = self.get_model_from_invocation_options(&opts);
        let input = create_chat_completion_request(model, prompt, opts.is_streaming()).unwrap();
        if opts.is_streaming() {
            let res = client
                .chat()
                .create_stream(input)
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(stream_to_output(res))
        } else {
            let res = client
                .chat()
                .create(input)
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(completion_to_output(res))
        }
    }

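    /// Counts the tokens the formatted prompt will consume for the selected model.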
    fn tokens_used(
        &self,
        opts: &Options,
        prompt: &Prompt,
    ) -> Result<TokenCount, PromptTokensError> {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        let messages = format_chat_messages(prompt.to_chat()).map_err(|e| match e {
            OpenAIInnerError::StringTemplateError(e) => PromptTokensError::PromptFormatFailed(e),
            _ => PromptTokensError::UnableToCompute,
        })?;
        let tokens_used = num_tokens_from_messages(&model, &messages)
            .map_err(|_| PromptTokensError::NotAvailable)?;

        Ok(TokenCount::new(
            self.max_tokens_allowed(opts),
            tokens_used as i32,
        ))
    }
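    /// Returns the context window size for the selected model, falling back to 4096 tokens when it cannot be determined.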
    fn max_tokens_allowed(&self, opts: &Options) -> i32 {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        tiktoken_rs::model::get_context_size(&model)
            .try_into()
            .unwrap_or(4096)
    }

    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
        None
    }

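    /// Returns a tokenizer configured for the model selected by the cascaded options.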
    fn get_tokenizer(&self, options: &Options) -> Result<OpenAITokenizer, TokenizerError> {
        Ok(OpenAITokenizer::new(self.cascade(Some(options))))
    }
}

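/// Estimates the number of tokens a chat request will use, following the counting scheme from the OpenAI cookbook: a fixed overhead per message, an adjustment when a name is present, and a final cost for priming the assistant's reply.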
fn num_tokens_from_messages(
    model: &str,
    messages: &[ChatCompletionRequestMessage],
) -> Result<usize, PromptTokensError> {
    let tokenizer = if model.starts_with("moonshot") {
        // Moonshot models are not known to tiktoken-rs, so fall back to cl100k_base.
        Cl100kBase
    } else {
        let tokenizer = get_tokenizer(model).ok_or_else(|| PromptTokensError::NotAvailable)?;
        if tokenizer != Cl100kBase {
            return Err(PromptTokensError::NotAvailable);
        }
        tokenizer
    };

    let bpe = get_bpe_from_tokenizer(tokenizer).map_err(|_| PromptTokensError::NotAvailable)?;

    let (tokens_per_message, tokens_per_name) = if model.starts_with("gpt-3.5") {
        (
            4,  // every message follows <|start|>{role/name}\n{content}<|end|>\n
            -1, // if there's a name, the role is omitted
        )
    } else {
        (3, 1)
    };

    let mut num_tokens: i32 = 0;
    for message in messages {
        let (role, content, name) = match message {
            ChatCompletionRequestMessage::System(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::User(x) => (
                x.role.to_string(),
                x.content
                    .as_ref()
                    .and_then(|x| match x {
                        ChatCompletionRequestUserMessageContent::Text(x) => Some(x.to_string()),
                        _ => None,
                    })
                    .unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Assistant(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Tool(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Function(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
        };
        num_tokens += tokens_per_message;
        num_tokens += bpe.encode_with_special_tokens(&role).len() as i32;
        num_tokens += bpe.encode_with_special_tokens(&content).len() as i32;
        if let Some(name) = name {
            num_tokens += bpe.encode_with_special_tokens(name).len() as i32;
            num_tokens += tokens_per_name;
        }
    }
    num_tokens += 3; // every reply is primed with <|start|>assistant<|message|>
    Ok(num_tokens as usize)
}

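/// Tokenizer for OpenAI-style models, backed by tiktoken-rs and selected by model name.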
pub struct OpenAITokenizer {
    model_name: String,
}

impl OpenAITokenizer {
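    /// Creates a tokenizer for the model selected by the cascaded options, defaulting to "gpt-3.5-turbo".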
    pub fn new(options: OptionsCascade) -> Self {
        let model_name = match options.get(ai_chain::options::OptDiscriminants::Model) {
            Some(Opt::Model(model_name)) => model_name.to_name(),
            _ => "gpt-3.5-turbo".to_string(),
        };
        Self::for_model_name(model_name)
    }
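    /// Creates a tokenizer for the given model name.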
    pub fn for_model_name<S: Into<String>>(model_name: S) -> Self {
        let model_name: String = model_name.into();
        Self { model_name }
    }

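    /// Looks up the BPE encoding for the model name; moonshot models are not known to tiktoken-rs, so they fall back to cl100k_base.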
    fn get_bpe_from_model(&self) -> Result<tiktoken_rs::CoreBPE, PromptTokensError> {
        use tiktoken_rs::get_bpe_from_model;
        if self.model_name.starts_with("moonshot") {
            return cl100k_base().map_err(|_| PromptTokensError::NotAvailable);
        }
        get_bpe_from_model(&self.model_name).map_err(|_| PromptTokensError::NotAvailable)
    }
}

impl Tokenizer for OpenAITokenizer {
    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
        Ok(self
            .get_bpe_from_model()
            .map_err(|_| TokenizerError::TokenizationError)?
            .encode_ordinary(doc)
            .into())
    }

    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
        let res = self
            .get_bpe_from_model()
            .map_err(|_e| TokenizerError::ToStringError)?
            .decode(tokens.as_usize()?)
            .map_err(|_e| TokenizerError::ToStringError)?;
        Ok(res)
    }
}