ai_chain_openai_compatible/chatgpt/
executor.rs

use super::prompt::completion_to_output;
use super::prompt::stream_to_output;
use async_openai::config::OpenAIConfig;
use async_openai::types::ChatCompletionRequestMessage;

use async_openai::types::ChatCompletionRequestUserMessageContent;
use ai_chain::options::Opt;
use ai_chain::options::Options;
use ai_chain::options::OptionsCascade;
use ai_chain::output::Output;
use ai_chain::tokens::TokenCollection;
use tiktoken_rs::{cl100k_base, get_bpe_from_tokenizer};
use tiktoken_rs::tokenizer::get_tokenizer;

use super::prompt::create_chat_completion_request;
use super::prompt::format_chat_messages;
use async_openai::error::OpenAIError;
use ai_chain::prompt::Prompt;

use ai_chain::tokens::PromptTokensError;
use ai_chain::tokens::{Tokenizer, TokenizerError};
use ai_chain::traits;
use ai_chain::traits::{ExecutorCreationError, ExecutorError};

use async_trait::async_trait;
use ai_chain::tokens::TokenCount;

use std::sync::Arc;
use tiktoken_rs::tokenizer::Tokenizer::Cl100kBase;
use crate::chatgpt::config::OAIConfig;
use crate::chatgpt::OpenAICompatibleInnerError;

/// The `Executor` struct for the ChatGPT model. This executor uses the `async_openai` crate to communicate with the OpenAI API (or any OpenAI-compatible endpoint).
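///
/// # Example
///
/// A minimal usage sketch (not doctested). `MyConfig` is a hypothetical placeholder for a type
/// implementing `OAIConfig`; the `Options` and `Prompt` values are assumed to be built with the
/// usual `ai_chain` helpers.
///
/// ```ignore
/// use ai_chain::traits::Executor as _;
///
/// async fn ask(
///     creation_options: Options,
///     call_options: &Options,
///     prompt: &Prompt,
/// ) -> Result<Output, Box<dyn std::error::Error>> {
///     // Build the executor from its default config, API key option, and OPENAI_API_BASE_URL.
///     let exec = Executor::<MyConfig>::new_with_options(creation_options)?;
///     // Run a single completion; `call_options` are layered on top of `creation_options`.
///     Ok(exec.execute(call_options, prompt).await?)
/// }
/// ```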
#[derive(Clone)]
pub struct Executor<C: OAIConfig> {
    /// The resolved configuration for this executor.
    config: C,
    /// The client used to communicate with the OpenAI-compatible API.
    client: Arc<async_openai::Client<C>>,
    /// The executor-level options; per-invocation options are layered on top of these.
    options: Options,
}

impl<C: OAIConfig> Executor<C> {
    /// Creates a new `Executor` with the given client.
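    ///
    /// # Example
    ///
    /// A sketch (not doctested): wrap an `async_openai` client you have already configured,
    /// e.g. with a custom base URL. `my_config` and `options` are hypothetical values of a type
    /// implementing `OAIConfig` and of `ai_chain::options::Options`, respectively.
    ///
    /// ```ignore
    /// let client = async_openai::Client::with_config(my_config);
    /// let exec = Executor::for_client(client, options);
    /// ```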
    pub fn for_client(client: async_openai::Client<C>, options: Options) -> Self {
        use ai_chain::traits::Executor as _;
        let mut exec = Self::new_with_options(options).unwrap();
        exec.client = Arc::new(client);
        exec
    }

    fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> String {
        let Some(Opt::Model(model)) = opts.get(ai_chain::options::OptDiscriminants::Model) else {
            // Fall back to the default model name declared by the config type.
            return C::model_config().0;
        };
        model.to_name()
    }

    fn cascade<'a>(&'a self, opts: Option<&'a Options>) -> OptionsCascade<'a> {
        let mut v: Vec<&'a Options> = vec![&self.options];
        if let Some(o) = opts {
            v.push(o);
        }
        OptionsCascade::from_vec(v)
    }
}

#[derive(thiserror::Error, Debug)]
#[error(transparent)]
pub enum Error {
    OpenAIError(#[from] OpenAIError),
}

impl<OConfig: OAIConfig> Executor<OConfig> {
    fn create_config() -> OConfig {
        OConfig::create()
    }

    fn num_tokens_from_messages(
        model: &str,
        messages: &[ChatCompletionRequestMessage],
    ) -> Result<usize, PromptTokensError> {
        // Models matching the config's prefix list are counted with cl100k_base; anything else
        // must also resolve to cl100k_base via tiktoken_rs, otherwise counting is unavailable.
        let (_, model_prefixes) = OConfig::model_config();
        let tokenizer = if model_prefixes.iter().any(|prefix| model.starts_with(prefix)) {
            Cl100kBase
        } else {
            let tokenizer = get_tokenizer(model).ok_or_else(|| PromptTokensError::NotAvailable)?;
            if tokenizer != Cl100kBase {
                return Err(PromptTokensError::NotAvailable);
            }
            tokenizer
        };

        let bpe = get_bpe_from_tokenizer(tokenizer).map_err(|_| PromptTokensError::NotAvailable)?;

        let (tokens_per_message, tokens_per_name) = if model.starts_with("gpt-3.5") {
            (
                4,  // every message follows <im_start>{role/name}\n{content}<im_end>\n
                -1, // if there's a name, the role is omitted
            )
        } else {
            (3, 1)
        };

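        // Accounting sketch (hypothetical token counts): with the (3, 1) constants above, a chat
        // of two messages whose role strings encode to 1 token each and whose contents encode to
        // 10 and 7 tokens, and which carry no `name` field, totals
        //   (3 + 1 + 10) + (3 + 1 + 7) + 3 = 28 tokens,
        // the trailing 3 being the assistant-reply priming added after the loop below.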
        let mut num_tokens: i32 = 0;
        for message in messages {
            let (role, content, name) = match message {
                ChatCompletionRequestMessage::System(x) => (
                    x.role.to_string(),
                    x.content.to_owned().unwrap_or_default(),
                    None,
                ),
                ChatCompletionRequestMessage::User(x) => (
                    x.role.to_string(),
                    x.content
                        .as_ref()
                        .and_then(|x| match x {
                            ChatCompletionRequestUserMessageContent::Text(x) => Some(x.to_string()),
                            _ => None,
                        })
                        .unwrap_or_default(),
                    None,
                ),
                ChatCompletionRequestMessage::Assistant(x) => (
                    x.role.to_string(),
                    x.content.to_owned().unwrap_or_default(),
                    None,
                ),
                ChatCompletionRequestMessage::Tool(x) => (
                    x.role.to_string(),
                    x.content.to_owned().unwrap_or_default(),
                    None,
                ),
                ChatCompletionRequestMessage::Function(x) => (
                    x.role.to_string(),
                    x.content.to_owned().unwrap_or_default(),
                    None,
                ),
            };
            num_tokens += tokens_per_message;
            num_tokens += bpe.encode_with_special_tokens(&role).len() as i32;
            num_tokens += bpe.encode_with_special_tokens(&content).len() as i32;
            if let Some(name) = name {
                num_tokens += bpe.encode_with_special_tokens(name).len() as i32;
                num_tokens += tokens_per_name;
            }
        }
        num_tokens += 3; // every reply is primed with <|start|>assistant<|message|>
        Ok(num_tokens as usize)
    }
}

#[async_trait]
impl<OConfig: OAIConfig> traits::Executor for Executor<OConfig> {
    type StepTokenizer<'a> = OpenAITokenizer<OConfig>;

    /// Creates a new `Executor` with the given options.
    ///
    /// If the `OPENAI_ORG_ID` environment variable is present, it is used as the org_id for the
    /// OpenAI client; if `OPENAI_API_BASE_URL` is set, it overrides the API base URL.
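    ///
    /// # Example
    ///
    /// A sketch (not doctested) for pointing the executor at a self-hosted, OpenAI-compatible
    /// endpoint. `MyConfig` and the URL are hypothetical; the API key is taken from the
    /// `Opt::ApiKey` entry in `options` if present.
    ///
    /// ```ignore
    /// std::env::set_var("OPENAI_API_BASE_URL", "http://localhost:8000/v1");
    /// let exec = Executor::<MyConfig>::new_with_options(options)?;
    /// ```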
    fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
        let mut cfg = Self::create_config();

        let opts = OptionsCascade::new().with_options(&options);

        if let Some(Opt::ApiKey(api_key)) = opts.get(ai_chain::options::OptDiscriminants::ApiKey) {
            cfg = cfg.with_api_key(api_key);
        }
        if let Ok(base_url) = std::env::var("OPENAI_API_BASE_URL") {
            cfg = cfg.with_api_base(base_url);
        }
        let config = cfg.clone();

        let client = Arc::new(async_openai::Client::with_config(cfg));
        Ok(Self {
            config,
            client,
            options,
        })
    }

    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
        let opts = self.cascade(Some(options));
        let client = self.client.clone();
        let model = self.get_model_from_invocation_options(&opts);
        let input = create_chat_completion_request(model, prompt, opts.is_streaming()).unwrap();
        if opts.is_streaming() {
            let res = client
                .chat()
                .create_stream(input)
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(stream_to_output(res))
        } else {
            let res = client
                .chat()
                .create(input)
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(completion_to_output(res))
        }
    }

    fn tokens_used(
        &self,
        opts: &Options,
        prompt: &Prompt,
    ) -> Result<TokenCount, PromptTokensError> {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        let messages = format_chat_messages(prompt.to_chat()).map_err(|e| match e {
            OpenAICompatibleInnerError::StringTemplateError(e) => {
                PromptTokensError::PromptFormatFailed(e)
            }
            _ => PromptTokensError::UnableToCompute,
        })?;
        let tokens_used = Self::num_tokens_from_messages(&model, &messages)
            .map_err(|_| PromptTokensError::NotAvailable)?;

        Ok(TokenCount::new(
            self.max_tokens_allowed(opts),
            tokens_used as i32,
        ))
    }

    /// Get the context size for the resolved model, or fall back to the default context size.
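    ///
    /// A sketch (not doctested); `exec` and `options` are hypothetical values built as in the
    /// examples above:
    ///
    /// ```ignore
    /// // The upper bound used by `tokens_used` when it builds the returned `TokenCount`.
    /// let window = exec.max_tokens_allowed(&options);
    /// ```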
    fn max_tokens_allowed(&self, opts: &Options) -> i32 {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        tiktoken_rs::model::get_context_size(&model)
            .try_into()
            .unwrap_or(4096)
    }

    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
        None
    }

    fn get_tokenizer(&self, options: &Options) -> Result<OpenAITokenizer<OConfig>, TokenizerError> {
        Ok(OpenAITokenizer::new(self.cascade(Some(options))))
    }
}

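/// A [`Tokenizer`] backed by `tiktoken_rs`, with the BPE resolved from the model name.
///
/// # Example
///
/// A round-trip sketch (not doctested); `MyConfig` is a hypothetical placeholder for a type
/// implementing `OAIConfig`.
///
/// ```ignore
/// use ai_chain::tokens::Tokenizer as _;
///
/// let tokenizer = OpenAITokenizer::<MyConfig>::for_model_name("gpt-3.5-turbo");
/// let tokens = tokenizer.tokenize_str("hello world")?;
/// assert_eq!(tokenizer.to_string(tokens)?, "hello world");
/// ```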
pub struct OpenAITokenizer<C> {
    model_name: String,
    /// Held only to tie the tokenizer to a concrete `OAIConfig` type; always `None`.
    config: Option<C>,
}

impl<C: OAIConfig> OpenAITokenizer<C> {
    pub fn new(options: OptionsCascade) -> Self {
        let model_name = match options.get(ai_chain::options::OptDiscriminants::Model) {
            Some(Opt::Model(model_name)) => model_name.to_name(),
            _ => C::model_config().0,
        };
        Self::for_model_name(model_name)
    }

    /// Creates an `OpenAITokenizer` for the passed-in model name.
    pub fn for_model_name<S: Into<String>>(model_name: S) -> Self {
        let model_name: String = model_name.into();
        Self { model_name, config: None }
    }

    fn get_bpe_from_model(&self) -> Result<tiktoken_rs::CoreBPE, PromptTokensError> {
        use tiktoken_rs::get_bpe_from_model;
        let (_, model_prefixes) = C::model_config();
        if model_prefixes.iter().any(|prefix| self.model_name.starts_with(prefix)) {
            return cl100k_base().map_err(|_| PromptTokensError::NotAvailable);
        }
        get_bpe_from_model(&self.model_name).map_err(|_| PromptTokensError::NotAvailable)
    }
}


impl<C: OAIConfig> Tokenizer for OpenAITokenizer<C> {
    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
        Ok(self
            .get_bpe_from_model()
            .map_err(|_| TokenizerError::TokenizationError)?
            .encode_ordinary(doc)
            .into())
    }

    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
        let res = self
            .get_bpe_from_model()
            .map_err(|_e| TokenizerError::ToStringError)?
            .decode(tokens.as_usize()?)
            .map_err(|_e| TokenizerError::ToStringError)?;
        Ok(res)
    }
}