llm_chain_openai/chatgpt/executor.rs

use super::error::OpenAIInnerError;
use super::prompt::completion_to_output;
use super::prompt::stream_to_output;
use async_openai::config::OpenAIConfig;
use async_openai::types::ChatCompletionRequestMessage;

use async_openai::types::ChatCompletionRequestUserMessageContent;
use llm_chain::options::Opt;
use llm_chain::options::Options;
use llm_chain::options::OptionsCascade;
use llm_chain::output::Output;
use llm_chain::tokens::TokenCollection;
use tiktoken_rs::get_bpe_from_tokenizer;
use tiktoken_rs::tokenizer::get_tokenizer;

use super::prompt::create_chat_completion_request;
use super::prompt::format_chat_messages;
use async_openai::error::OpenAIError;
use llm_chain::prompt::Prompt;

use llm_chain::tokens::PromptTokensError;
use llm_chain::tokens::{Tokenizer, TokenizerError};
use llm_chain::traits;
use llm_chain::traits::{ExecutorCreationError, ExecutorError};

use async_trait::async_trait;
use llm_chain::tokens::TokenCount;

use std::sync::Arc;

/// The `Executor` struct for the ChatGPT model. This executor uses the `async_openai` crate to communicate with the OpenAI API.
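///
/// # Example
///
/// A minimal construction sketch, assuming this type is re-exported as
/// `llm_chain_openai::chatgpt::Executor` and that `OPENAI_API_KEY` (and
/// optionally `OPENAI_ORG_ID`) are set in the environment:
///
/// ```no_run
/// use llm_chain::options::Options;
/// use llm_chain::traits::Executor as _; // brings `new_with_options` into scope
/// use llm_chain_openai::chatgpt::Executor; // path assumed from this crate's layout
///
/// // With default options the API key is picked up from the environment by `async_openai`.
/// let exec = Executor::new_with_options(Options::default()).unwrap();
/// ```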
#[derive(Clone)]
pub struct Executor {
    /// The client used to communicate with the OpenAI API.
    client: Arc<async_openai::Client<OpenAIConfig>>,
    /// The base options for this executor; per-invocation options are layered on top of these.
    options: Options,
}

impl Default for Executor {
    fn default() -> Self {
        let options = Options::default();
        let client = Arc::new(async_openai::Client::new());
        Self { client, options }
    }
}

impl Executor {
    /// Creates a new `Executor` with the given client and options.
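    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the `llm_chain_openai::chatgpt` re-export of this
    /// type; the API key and base URL below are placeholders:
    ///
    /// ```no_run
    /// use async_openai::config::OpenAIConfig;
    /// use llm_chain::options::Options;
    /// use llm_chain_openai::chatgpt::Executor; // path assumed from this crate's layout
    ///
    /// // Point the underlying client at a custom (OpenAI-compatible) endpoint.
    /// let cfg = OpenAIConfig::new()
    ///     .with_api_key("sk-...")
    ///     .with_api_base("https://example.invalid/v1");
    /// let client = async_openai::Client::with_config(cfg);
    /// let exec = Executor::for_client(client, Options::default());
    /// ```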
    pub fn for_client(client: async_openai::Client<OpenAIConfig>, options: Options) -> Self {
        use llm_chain::traits::Executor as _;
        let mut exec = Self::new_with_options(options).unwrap();
        exec.client = Arc::new(client);
        exec
    }

    fn get_model_from_invocation_options(&self, opts: &OptionsCascade) -> String {
        let Some(Opt::Model(model)) = opts.get(llm_chain::options::OptDiscriminants::Model) else {
            return "gpt-3.5-turbo".to_string();
        };
        model.to_name()
    }

    fn cascade<'a>(&'a self, opts: Option<&'a Options>) -> OptionsCascade<'a> {
        let mut v: Vec<&'a Options> = vec![&self.options];
        if let Some(o) = opts {
            v.push(o);
        }
        OptionsCascade::from_vec(v)
    }
}

#[derive(thiserror::Error, Debug)]
#[error(transparent)]
pub enum Error {
    OpenAIError(#[from] OpenAIError),
}

#[async_trait]
impl traits::Executor for Executor {
    type StepTokenizer<'a> = OpenAITokenizer;
    /// Creates a new `Executor` with the given options.
    ///
    /// If the `OPENAI_ORG_ID` environment variable is present, it will be used as the org id for the OpenAI client.
    fn new_with_options(options: Options) -> Result<Self, ExecutorCreationError> {
        let mut cfg = OpenAIConfig::new();

        let opts = OptionsCascade::new().with_options(&options);

        if let Some(Opt::ApiKey(api_key)) = opts.get(llm_chain::options::OptDiscriminants::ApiKey) {
            cfg = cfg.with_api_key(api_key)
        }

        if let Ok(org_id) = std::env::var("OPENAI_ORG_ID") {
            cfg = cfg.with_org_id(org_id);
        }
        let client = Arc::new(async_openai::Client::with_config(cfg));
        Ok(Self { client, options })
    }

    async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
        let opts = self.cascade(Some(options));
        let client = self.client.clone();
        let model = self.get_model_from_invocation_options(&opts);
        // Build the request up front; prompt-template formatting errors currently panic here.
        let input = create_chat_completion_request(model, prompt, opts.is_streaming()).unwrap();
        if opts.is_streaming() {
            let res = client
                .chat()
                .create_stream(input)
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(stream_to_output(res))
        } else {
            let res = client
                .chat()
                .create(input)
                .await
                .map_err(|e| ExecutorError::InnerError(e.into()))?;
            Ok(completion_to_output(res))
        }
    }

    fn tokens_used(
        &self,
        opts: &Options,
        prompt: &Prompt,
    ) -> Result<TokenCount, PromptTokensError> {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        let messages = format_chat_messages(prompt.to_chat()).map_err(|e| match e {
            OpenAIInnerError::StringTemplateError(e) => PromptTokensError::PromptFormatFailed(e),
            _ => PromptTokensError::UnableToCompute,
        })?;
        let tokens_used = num_tokens_from_messages(&model, &messages)
            .map_err(|_| PromptTokensError::NotAvailable)?;

        Ok(TokenCount::new(
            self.max_tokens_allowed(opts),
            tokens_used as i32,
        ))
    }
    /// Gets the context size for the model, or a default context size (4096) if it cannot be determined.
    fn max_tokens_allowed(&self, opts: &Options) -> i32 {
        let opts_cas = self.cascade(Some(opts));
        let model = self.get_model_from_invocation_options(&opts_cas);
        tiktoken_rs::model::get_context_size(&model)
            .try_into()
            .unwrap_or(4096)
    }

    fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
        None
    }

    fn get_tokenizer(&self, options: &Options) -> Result<OpenAITokenizer, TokenizerError> {
        Ok(OpenAITokenizer::new(self.cascade(Some(options))))
    }
}

/// Estimates the number of prompt tokens consumed by `messages`, mirroring
/// OpenAI's token-counting heuristics for cl100k_base chat models.
fn num_tokens_from_messages(
    model: &str,
    messages: &[ChatCompletionRequestMessage],
) -> Result<usize, PromptTokensError> {
    let tokenizer = get_tokenizer(model).ok_or_else(|| PromptTokensError::NotAvailable)?;
    if tokenizer != tiktoken_rs::tokenizer::Tokenizer::Cl100kBase {
        return Err(PromptTokensError::NotAvailable);
    }
    let bpe = get_bpe_from_tokenizer(tokenizer).map_err(|_| PromptTokensError::NotAvailable)?;

    let (tokens_per_message, tokens_per_name) = if model.starts_with("gpt-3.5") {
        (
            4,  // every message follows <im_start>{role/name}\n{content}<im_end>\n
            -1, // if there's a name, the role is omitted
        )
    } else {
        (3, 1)
    };

    let mut num_tokens: i32 = 0;
    for message in messages {
        let (role, content, name) = match message {
            ChatCompletionRequestMessage::System(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::User(x) => (
                x.role.to_string(),
                x.content
                    .as_ref()
                    .and_then(|x| match x {
                        ChatCompletionRequestUserMessageContent::Text(x) => Some(x.to_string()),
                        _ => None,
                    })
                    .unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Assistant(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Tool(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
            ChatCompletionRequestMessage::Function(x) => (
                x.role.to_string(),
                x.content.to_owned().unwrap_or_default(),
                None,
            ),
        };
        num_tokens += tokens_per_message;
        num_tokens += bpe.encode_with_special_tokens(&role).len() as i32;
        num_tokens += bpe.encode_with_special_tokens(&content).len() as i32;
        if let Some(name) = name {
            num_tokens += bpe.encode_with_special_tokens(name).len() as i32;
            num_tokens += tokens_per_name;
        }
    }
    num_tokens += 3; // every reply is primed with <|start|>assistant<|message|>
    Ok(num_tokens as usize)
}
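// Worked example of the accounting in `num_tokens_from_messages` (a sketch,
// assuming the role string "user" and the content "Hello" each encode to a
// single cl100k_base token): one user message on a non-gpt-3.5 model costs
// tokens_per_message (3) + ~1 (role) + ~1 (content) = ~5 tokens, and the
// trailing `+ 3` covers the assistant reply priming, for ~8 prompt tokens total.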

/// A `Tokenizer` implementation backed by `tiktoken_rs`, keyed by OpenAI model name.
pub struct OpenAITokenizer {
    model_name: String,
}

impl OpenAITokenizer {
    pub fn new(options: OptionsCascade) -> Self {
        let model_name = match options.get(llm_chain::options::OptDiscriminants::Model) {
            Some(Opt::Model(model_name)) => model_name.to_name(),
            _ => "gpt-3.5-turbo".to_string(),
        };
        Self::for_model_name(model_name)
    }
    /// Creates an `OpenAITokenizer` for the given model name.
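    ///
    /// # Example
    ///
    /// A small counting sketch, assuming the `llm_chain_openai::chatgpt` re-export
    /// of this type:
    ///
    /// ```no_run
    /// use llm_chain::tokens::Tokenizer;
    /// use llm_chain_openai::chatgpt::OpenAITokenizer; // path assumed from this crate's layout
    ///
    /// let tok = OpenAITokenizer::for_model_name("gpt-3.5-turbo");
    /// let n = tok.tokenize_str("hello world").unwrap().as_usize().unwrap().len();
    /// assert!(n > 0);
    /// ```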
    pub fn for_model_name<S: Into<String>>(model_name: S) -> Self {
        let model_name: String = model_name.into();
        Self { model_name }
    }

    fn get_bpe_from_model(&self) -> Result<tiktoken_rs::CoreBPE, PromptTokensError> {
        use tiktoken_rs::get_bpe_from_model;
        get_bpe_from_model(&self.model_name).map_err(|_| PromptTokensError::NotAvailable)
    }
}

impl Tokenizer for OpenAITokenizer {
    fn tokenize_str(&self, doc: &str) -> Result<TokenCollection, TokenizerError> {
        Ok(self
            .get_bpe_from_model()
            .map_err(|_| TokenizerError::TokenizationError)?
            .encode_ordinary(doc)
            .into())
    }

    fn to_string(&self, tokens: TokenCollection) -> Result<String, TokenizerError> {
        let res = self
            .get_bpe_from_model()
            .map_err(|_e| TokenizerError::ToStringError)?
            .decode(tokens.as_usize()?)
            .map_err(|_e| TokenizerError::ToStringError)?;
        Ok(res)
    }
}
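
#[cfg(test)]
mod tests {
    use super::*;

    // A small round-trip sketch (not part of the original file): tokens produced
    // by `tokenize_str` should decode back to the original text via `to_string`.
    // It only relies on tiktoken's embedded BPE data, so it runs offline.
    #[test]
    fn openai_tokenizer_roundtrip() {
        let tokenizer = OpenAITokenizer::for_model_name("gpt-3.5-turbo");
        let tokens = tokenizer.tokenize_str("Hello, world!").unwrap();
        let text = tokenizer.to_string(tokens).unwrap();
        assert_eq!(text, "Hello, world!");
    }
}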