ai_chain_openai/chatgpt/executor.rs

use super::error::OpenAIInnerError;
use super::prompt::completion_to_output;
use super::prompt::stream_to_output;
use async_openai::config::OpenAIConfig;
use async_openai::types::ChatCompletionRequestMessage;

use async_openai::types::ChatCompletionRequestUserMessageContent;
use ai_chain::options::Opt;
use ai_chain::options::Options;
use ai_chain::options::OptionsCascade;
use ai_chain::output::Output;
use ai_chain::tokens::TokenCollection;
use tiktoken_rs::{cl100k_base, get_bpe_from_tokenizer};
use tiktoken_rs::tokenizer::get_tokenizer;

use super::prompt::create_chat_completion_request;
use super::prompt::format_chat_messages;
use async_openai::error::OpenAIError;
use ai_chain::prompt::Prompt;

use ai_chain::tokens::PromptTokensError;
use ai_chain::tokens::{Tokenizer, TokenizerError};
use ai_chain::traits;
use ai_chain::traits::{ExecutorCreationError, ExecutorError};

use async_trait::async_trait;
use ai_chain::tokens::TokenCount;

use std::sync::Arc;
use tiktoken_rs::tokenizer::Tokenizer::Cl100kBase;

/// The `Executor` struct for the ChatGPT model. This executor uses the `async_openai` crate to communicate with the OpenAI API.
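///
/// A minimal construction sketch (hedged: the module path `ai_chain_openai::chatgpt` is
/// assumed from this file's location):
///
/// ```no_run
/// use ai_chain::options::Options;
/// use ai_chain::traits::Executor as _;
/// use ai_chain_openai::chatgpt::Executor;
///
/// // Reads OPENAI_API_KEY from the environment via `async_openai`'s default config,
/// // plus the optional OPENAI_ORG_ID / OPENAI_API_BASE_URL handled in `new_with_options`.
/// let exec = Executor::new_with_options(Options::default()).expect("executor creation");
/// ```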
#[derive(Clone)]
pub struct Executor {
    /// The client used to communicate with the OpenAI API.
    client: Arc<async_openai::Client<OpenAIConfig>>,
    /// The per-invocation options for this executor.
    options: Options,
}

impl Default for Executor {
    fn default() -> Self {
        let options = Options::default();
        let client = Arc::new(async_openai::Client::new());
        Self { client, options }
    }
}

impl Executor {
    /// Creates a new `Executor` with the given client and options.
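    ///
    /// A sketch of supplying a pre-configured client (mirrors the config calls used by
    /// `new_with_options`; the module path is assumed from this file's location):
    ///
    /// ```no_run
    /// use ai_chain::options::Options;
    /// use ai_chain_openai::chatgpt::Executor;
    /// use async_openai::config::OpenAIConfig;
    ///
    /// let cfg = OpenAIConfig::new().with_api_key("sk-...");
    /// let client = async_openai::Client::with_config(cfg);
    /// let exec = Executor::for_client(client, Options::default());
    /// ```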
    pub fn for_client(client: async_openai::Client<OpenAIConfig>, options: Options) -> Self {
        use ai_chain::traits::Executor as _;
        let mut exec = Self::new_with_options(options).unwrap();
        exec.client = Arc::new(client);
        exec
    }

    fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> String {
        let Some(Opt::Model(model)) = opts.get(ai_chain::options::OptDiscriminants::Model) else {
            return "gpt-3.5-turbo".to_string();
        };
        model.to_name()
    }

    fn cascade<'a>(&'a self, opts: Option<&'a Options>) -> OptionsCascade<'a> {
        let mut v: Vec<&'a Options> = vec![&self.options];
        if let Some(o) = opts {
            v.push(o);
        }
        OptionsCascade::from_vec(v)
    }
}
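/// Error type wrapping failures returned by the `async_openai` client.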
#[derive(thiserror::Error, Debug)]
#[error(transparent)]
pub enum Error {
    OpenAIError(#[from] OpenAIError),
}

#[async_trait]
impl traits::Executor for Executor {
    type StepTokenizer<'a> = OpenAITokenizer;
    /// Creates a new `Executor` with the given options.
    ///
    /// If the `OPENAI_ORG_ID` environment variable is present, it is used as the organization id
    /// for the OpenAI client; if `OPENAI_API_BASE_URL` is set, it overrides the API base URL.
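    ///
    /// An API key can also be supplied programmatically through the `ApiKey` option, which
    /// overrides the `OPENAI_API_KEY` environment variable read by `async_openai` itself.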
    fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
        let mut cfg = OpenAIConfig::new();

        let opts = OptionsCascade::new().with_options(&options);

        if let Some(Opt::ApiKey(api_key)) = opts.get(ai_chain::options::OptDiscriminants::ApiKey) {
            cfg = cfg.with_api_key(api_key);
        }

        if let Ok(org_id) = std::env::var("OPENAI_ORG_ID") {
            cfg = cfg.with_org_id(org_id);
        }

        if let Ok(base_url) = std::env::var("OPENAI_API_BASE_URL") {
            cfg = cfg.with_api_base(base_url);
        }

        let client = Arc::new(async_openai::Client::with_config(cfg));
        Ok(Self { client, options })
    }

    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
        let opts = self.cascade(Some(options));
        let client = self.client.clone();
        let model = self.get_model_from_invocation_options(&opts);
        let input = create_chat_completion_request(model, prompt, opts.is_streaming()).unwrap();
        if opts.is_streaming() {
            let res = client
                .chat()
                .create_stream(input)
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(stream_to_output(res))
        } else {
            let res = client
                .chat()
                .create(input)
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(completion_to_output(res))
        }
    }

    fn tokens_used(
        &self,
        opts: &Options,
        prompt: &Prompt,
    ) -> Result<TokenCount, PromptTokensError> {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        let messages = format_chat_messages(prompt.to_chat()).map_err(|e| match e {
            OpenAIInnerError::StringTemplateError(e) => PromptTokensError::PromptFormatFailed(e),
            _ => PromptTokensError::UnableToCompute,
        })?;
        let tokens_used = num_tokens_from_messages(&model, &messages)
            .map_err(|_| PromptTokensError::NotAvailable)?;

        Ok(TokenCount::new(
            self.max_tokens_allowed(opts),
            tokens_used as i32,
        ))
    }

    /// Get the context size from the model, or return the default context size.
    fn max_tokens_allowed(&self, opts: &Options) -> i32 {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        tiktoken_rs::model::get_context_size(&model)
            .try_into()
            .unwrap_or(4096)
    }

    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
        None
    }

    fn get_tokenizer(&self, options: &Options) -> Result<OpenAITokenizer, TokenizerError> {
        Ok(OpenAITokenizer::new(self.cascade(Some(options))))
    }
}
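/// Estimates the prompt token usage for a slice of chat messages, following the OpenAI
/// cookbook heuristic: a fixed per-message overhead plus the BPE length of each role,
/// content and optional name, plus 3 tokens that prime the assistant's reply.
/// Only models that use the `cl100k_base` encoding are supported; anything else returns
/// `PromptTokensError::NotAvailable`.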
fn num_tokens_from_messages(
    model: &str,
    messages: &[ChatCompletionRequestMessage],
) -> Result<usize, PromptTokensError> {
    // Moonshot models are special-cased to the cl100k_base encoding; any other model must
    // already map to cl100k_base in tiktoken-rs, otherwise the count is unavailable.
    let tokenizer = if model.starts_with("moonshot") {
        Cl100kBase
    } else {
        let tokenizer = get_tokenizer(model).ok_or(PromptTokensError::NotAvailable)?;
        if tokenizer != Cl100kBase {
            return Err(PromptTokensError::NotAvailable);
        }
        tokenizer
    };

    let bpe = get_bpe_from_tokenizer(tokenizer).map_err(|_| PromptTokensError::NotAvailable)?;

    let (tokens_per_message, tokens_per_name) = if model.starts_with("gpt-3.5") {
        (
            4,  // every message follows <im_start>{role/name}\n{content}<im_end>\n
            -1, // if there's a name, the role is omitted
        )
    } else {
        (3, 1)
    };

    let mut num_tokens: i32 = 0;
    for message in messages {
        let (role, content, name) = match message {
            ChatCompletionRequestMessage::System(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::User(x) => (
                x.role.to_string(),
                x.content
                    .as_ref()
                    .and_then(|x| match x {
                        ChatCompletionRequestUserMessageContent::Text(x) => Some(x.to_string()),
                        _ => None,
                    })
                    .unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Assistant(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Tool(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Function(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
        };
        num_tokens += tokens_per_message;
        num_tokens += bpe.encode_with_special_tokens(&role).len() as i32;
        num_tokens += bpe.encode_with_special_tokens(&content).len() as i32;
        if let Some(name) = name {
            num_tokens += bpe.encode_with_special_tokens(name).len() as i32;
            num_tokens += tokens_per_name;
        }
    }
    num_tokens += 3; // every reply is primed with <|start|>assistant<|message|>
    Ok(num_tokens as usize)
}
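/// Tokenizer backed by `tiktoken_rs`, resolved from an OpenAI model name.
///
/// A round-trip sketch (hedged: the module path is assumed from this file's location):
///
/// ```no_run
/// use ai_chain::tokens::Tokenizer as _;
/// use ai_chain_openai::chatgpt::OpenAITokenizer;
///
/// let tokenizer = OpenAITokenizer::for_model_name("gpt-3.5-turbo");
/// let tokens = tokenizer.tokenize_str("hello world").unwrap();
/// assert_eq!(tokenizer.to_string(tokens).unwrap(), "hello world");
/// ```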
pub struct OpenAITokenizer {
    model_name: String,
}

impl OpenAITokenizer {
    pub fn new(options: OptionsCascade) -> Self {
        let model_name = match options.get(ai_chain::options::OptDiscriminants::Model) {
            Some(Opt::Model(model_name)) => model_name.to_name(),
            _ => "gpt-3.5-turbo".to_string(),
        };
        Self::for_model_name(model_name)
    }

    /// Creates an OpenAITokenizer for the passed in model name
    pub fn for_model_name<S: Into<String>>(model_name: S) -> Self {
        let model_name: String = model_name.into();
        Self { model_name }
    }

    fn get_bpe_from_model(&self) -> Result<tiktoken_rs::CoreBPE, PromptTokensError> {
        use tiktoken_rs::get_bpe_from_model;
        // Moonshot models are not recognized by `get_bpe_from_model`, so fall back to cl100k_base.
        if self.model_name.starts_with("moonshot") {
            return cl100k_base().map_err(|_| PromptTokensError::NotAvailable);
        }
        get_bpe_from_model(&self.model_name).map_err(|_| PromptTokensError::NotAvailable)
    }
}

impl Tokenizer for OpenAITokenizer {
    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
        Ok(self
            .get_bpe_from_model()
            .map_err(|_| TokenizerError::TokenizationError)?
            .encode_ordinary(doc)
            .into())
    }

    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
        let res = self
            .get_bpe_from_model()
            .map_err(|_e| TokenizerError::ToStringError)?
            .decode(tokens.as_usize()?)
            .map_err(|_e| TokenizerError::ToStringError)?;
        Ok(res)
    }
}