llm_chain_openai/chatgpt/executor.rs

use super::error::OpenAIInnerError;
use super::prompt::completion_to_output;
use super::prompt::stream_to_output;
use async_openai::config::OpenAIConfig;
use async_openai::types::ChatCompletionRequestMessage;

use async_openai::types::ChatCompletionRequestUserMessageContent;
use llm_chain::options::Opt;
use llm_chain::options::Options;
use llm_chain::options::OptionsCascade;
use llm_chain::output::Output;
use llm_chain::tokens::TokenCollection;
use tiktoken_rs::get_bpe_from_tokenizer;
use tiktoken_rs::tokenizer::get_tokenizer;

use super::prompt::create_chat_completion_request;
use super::prompt::format_chat_messages;
use async_openai::error::OpenAIError;
use llm_chain::prompt::Prompt;

use llm_chain::tokens::PromptTokensError;
use llm_chain::tokens::{Tokenizer, TokenizerError};
use llm_chain::traits;
use llm_chain::traits::{ExecutorCreationError, ExecutorError};

use async_trait::async_trait;
use llm_chain::tokens::TokenCount;

use std::sync::Arc;

/// Executor for the OpenAI chat (ChatGPT) API, backed by the `async_openai` client.
#[derive(Clone)]
pub struct Executor {
    /// The client used to communicate with the OpenAI API.
    client: Arc<async_openai::Client<OpenAIConfig>>,
    /// Executor-level options; per-invocation options are layered on top of these.
    options: Options,
}

impl Default for Executor {
    fn default() -> Self {
        let options = Options::default();
        let client = Arc::new(async_openai::Client::new());
        Self { client, options }
    }
}

impl Executor {
    /// Creates an `Executor` that uses the given preconfigured client instead of
    /// one built from the options.
    pub fn for_client(client: async_openai::Client<OpenAIConfig>, options: Options) -> Self {
        use llm_chain::traits::Executor as _;
        let mut exec = Self::new_with_options(options).unwrap();
        exec.client = Arc::new(client);
        exec
    }

    /// Resolves the model name from the options cascade, defaulting to
    /// "gpt-3.5-turbo" when none is set.
    fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> String {
        let Some(Opt::Model(model)) = opts.get(llm_chain::options::OptDiscriminants::Model) else {
            return "gpt-3.5-turbo".to_string();
        };
        model.to_name()
    }

    /// Builds an options cascade from the executor options plus, if given, the
    /// per-invocation options.
    fn cascade<'a>(&'a self, opts: Option<&'a Options>) -> OptionsCascade<'a> {
        let mut v: Vec<&'a Options> = vec![&self.options];
        if let Some(o) = opts {
            v.push(o);
        }
        OptionsCascade::from_vec(v)
    }
}
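
// A minimal usage sketch, not part of the original file: wrapping a
// preconfigured `async_openai` client in an `Executor`. The configuration
// values shown are illustrative.
//
//     let client = async_openai::Client::with_config(OpenAIConfig::new());
//     let exec = Executor::for_client(client, Options::default());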

#[derive(thiserror::Error, Debug)]
#[error(transparent)]
pub enum Error {
    OpenAIError(#[from] OpenAIError),
}

#[async_trait]
impl traits::Executor for Executor {
    type StepTokenizer<'a> = OpenAITokenizer;

    /// Creates an `Executor` from the given options. An API key supplied in the
    /// options is used if present; the `OPENAI_ORG_ID` environment variable, when
    /// set, selects the organization.
    fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
        let mut cfg = OpenAIConfig::new();

        let opts = OptionsCascade::new().with_options(&options);

        if let Some(Opt::ApiKey(api_key)) = opts.get(llm_chain::options::OptDiscriminants::ApiKey) {
            cfg = cfg.with_api_key(api_key)
        }

        if let Ok(org_id) = std::env::var("OPENAI_ORG_ID") {
            cfg = cfg.with_org_id(org_id);
        }
        let client = Arc::new(async_openai::Client::with_config(cfg));
        Ok(Self { client, options })
    }

    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
        let opts = self.cascade(Some(options));
        let client = self.client.clone();
        let model = self.get_model_from_invocation_options(&opts);
        let input = create_chat_completion_request(model, prompt, opts.is_streaming()).unwrap();
        if opts.is_streaming() {
            let res = async move { client.chat().create_stream(input).await }
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(stream_to_output(res))
        } else {
            let res = async move { client.chat().create(input).await }
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(completion_to_output(res))
        }
    }

    fn tokens_used(
        &self,
        opts: &Options,
        prompt: &Prompt,
    ) -> Result<TokenCount, PromptTokensError> {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        let messages = format_chat_messages(prompt.to_chat()).map_err(|e| match e {
            OpenAIInnerError::StringTemplateError(e) => PromptTokensError::PromptFormatFailed(e),
            _ => PromptTokensError::UnableToCompute,
        })?;
        let tokens_used = num_tokens_from_messages(&model, &messages)
            .map_err(|_| PromptTokensError::NotAvailable)?;

        Ok(TokenCount::new(
            self.max_tokens_allowed(opts),
            tokens_used as i32,
        ))
    }

    /// Returns the model's context size, falling back to 4096 tokens when it
    /// cannot be determined.
    fn max_tokens_allowed(&self, opts: &Options) -> i32 {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        tiktoken_rs::model::get_context_size(&model)
            .try_into()
            .unwrap_or(4096)
    }

    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
        None
    }

    fn get_tokenizer(&self, options: &Options) -> Result<OpenAITokenizer, TokenizerError> {
        Ok(OpenAITokenizer::new(self.cascade(Some(options))))
    }
}

fn num_tokens_from_messages(
    model: &str,
    messages: &[ChatCompletionRequestMessage],
) -> Result<usize, PromptTokensError> {
    let tokenizer = get_tokenizer(model).ok_or_else(|| PromptTokensError::NotAvailable)?;
    if tokenizer != tiktoken_rs::tokenizer::Tokenizer::Cl100kBase {
        return Err(PromptTokensError::NotAvailable);
    }
    let bpe = get_bpe_from_tokenizer(tokenizer).map_err(|_| PromptTokensError::NotAvailable)?;

    let (tokens_per_message, tokens_per_name) = if model.starts_with("gpt-3.5") {
        (
            4,  // every message follows <|start|>{role/name}\n{content}<|end|>\n
            -1, // if there's a name, the role is omitted
        )
    } else {
        (3, 1)
    };

    let mut num_tokens: i32 = 0;
    for message in messages {
        let (role, content, name) = match message {
            ChatCompletionRequestMessage::System(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::User(x) => (
                x.role.to_string(),
                x.content
                    .as_ref()
                    .and_then(|x| match x {
                        ChatCompletionRequestUserMessageContent::Text(x) => Some(x.to_string()),
                        _ => None,
                    })
                    .unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Assistant(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Tool(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Function(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
        };
        num_tokens += tokens_per_message;
        num_tokens += bpe.encode_with_special_tokens(&role).len() as i32;
        num_tokens += bpe.encode_with_special_tokens(&content).len() as i32;
        if let Some(name) = name {
            num_tokens += bpe.encode_with_special_tokens(name).len() as i32;
            num_tokens += tokens_per_name;
        }
    }
    num_tokens += 3; // every reply is primed with <|start|>assistant<|message|>
    Ok(num_tokens as usize)
}
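
// Illustrative note, not from the original file: for a single message on a
// non-gpt-3.5 model the estimate above works out to
//   3 (tokens_per_message) + tokens(role) + tokens(content) + 3 (reply priming),
// with tokens_per_name added only when the message carries a name.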

/// Tokenizer for OpenAI chat models, backed by the `tiktoken_rs` BPE tables.
pub struct OpenAITokenizer {
    model_name: String,
}

impl OpenAITokenizer {
    /// Creates a tokenizer for the model named in the options cascade,
    /// defaulting to "gpt-3.5-turbo".
    pub fn new(options: OptionsCascade) -> Self {
        let model_name = match options.get(llm_chain::options::OptDiscriminants::Model) {
            Some(Opt::Model(model_name)) => model_name.to_name(),
            _ => "gpt-3.5-turbo".to_string(),
        };
        Self::for_model_name(model_name)
    }

    /// Creates a tokenizer for the given model name.
    pub fn for_model_name<S: Into<String>>(model_name: S) -> Self {
        let model_name: String = model_name.into();
        Self { model_name }
    }

    fn get_bpe_from_model(&self) -> Result<tiktoken_rs::CoreBPE, PromptTokensError> {
        use tiktoken_rs::get_bpe_from_model;
        get_bpe_from_model(&self.model_name).map_err(|_| PromptTokensError::NotAvailable)
    }
}

impl Tokenizer for OpenAITokenizer {
    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
        Ok(self
            .get_bpe_from_model()
            .map_err(|_| TokenizerError::TokenizationError)?
            .encode_ordinary(doc)
            .into())
    }

    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
        let res = self
            .get_bpe_from_model()
            .map_err(|_e| TokenizerError::ToStringError)?
            .decode(tokens.as_usize()?)
            .map_err(|_e| TokenizerError::ToStringError)?;
        Ok(res)
    }
}
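
// ---------------------------------------------------------------------------
// A minimal sketch, not part of the original file: it exercises the
// `OpenAITokenizer` defined above using the BPE data that tiktoken-rs ships,
// so it runs offline. The model name is the same default used elsewhere in
// this file.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tokenizer_sketch {
    use super::OpenAITokenizer;
    use llm_chain::tokens::Tokenizer;

    #[test]
    fn roundtrips_a_short_string() {
        let tok = OpenAITokenizer::for_model_name("gpt-3.5-turbo");
        // Encode the text into token ids, then decode back to the original string.
        let tokens = tok.tokenize_str("Hello, world!").unwrap();
        let text = tok.to_string(tokens).unwrap();
        assert_eq!(text, "Hello, world!");
    }
}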