langchain_rust/llm/openai/mod.rs

use std::pin::Pin;

pub use async_openai::config::{AzureConfig, Config, OpenAIConfig};
use async_openai::{
    error::OpenAIError,
    types::{
        ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
        ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImageArgs,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
        ChatCompletionRequestUserMessageContentPart, ChatCompletionStreamOptions,
        ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
        CreateChatCompletionRequestArgs, FunctionObjectArgs,
    },
    Client,
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};

use crate::{
    language_models::{llm::LLM, options::CallOptions, GenerateResult, LLMError, TokenUsage},
    schemas::{
        messages::{Message, MessageType},
        FunctionCallBehavior, StreamData,
    },
};

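/// Convenience identifiers for common OpenAI chat models. Any other model
/// name can also be passed to [`OpenAI::with_model`] as a plain string.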
#[derive(Clone)]
pub enum OpenAIModel {
    Gpt35,
    Gpt4,
    Gpt4Turbo,
    Gpt4o,
    Gpt4oMini,
}

impl std::fmt::Display for OpenAIModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            OpenAIModel::Gpt35 => "gpt-3.5-turbo",
            OpenAIModel::Gpt4 => "gpt-4",
            OpenAIModel::Gpt4Turbo => "gpt-4-turbo-preview",
            OpenAIModel::Gpt4o => "gpt-4o",
            OpenAIModel::Gpt4oMini => "gpt-4o-mini",
        };
        write!(f, "{}", name)
    }
}

impl From<OpenAIModel> for String {
    fn from(model: OpenAIModel) -> Self {
        model.to_string()
    }
}

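/// An [`LLM`] implementation backed by the OpenAI chat completions API, or any
/// compatible endpoint selected through the [`Config`] parameter (for example
/// [`AzureConfig`]).
///
/// A minimal usage sketch (assumes `OPENAI_API_KEY` is set in the environment;
/// the module paths follow this file's location and may differ from the
/// crate's public re-exports):
///
/// ```rust,ignore
/// use langchain_rust::language_models::llm::LLM;
/// use langchain_rust::llm::openai::{OpenAI, OpenAIModel};
///
/// async fn demo() {
///     let llm = OpenAI::default().with_model(OpenAIModel::Gpt4oMini);
///     let answer = llm.invoke("Say hello").await.unwrap();
///     println!("{answer}");
/// }
/// ```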
#[derive(Clone)]
pub struct OpenAI<C: Config> {
    config: C,
    options: CallOptions,
    model: String,
}

impl<C: Config> OpenAI<C> {
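    /// Creates a client with the given configuration, default call options,
    /// and the default model (`gpt-4o-mini`).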
    pub fn new(config: C) -> Self {
        Self {
            config,
            options: CallOptions::default(),
            model: OpenAIModel::Gpt4oMini.to_string(),
        }
    }

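    /// Sets the model name sent with every request. Accepts an
    /// [`OpenAIModel`] variant or any model string the endpoint understands.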
    pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
        self.model = model.into();
        self
    }

    pub fn with_config(mut self, config: C) -> Self {
        self.config = config;
        self
    }

    pub fn with_options(mut self, options: CallOptions) -> Self {
        self.options = options;
        self
    }
}

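/// Defaults to the standard OpenAI endpoint; `OpenAIConfig::default` reads the
/// API key from the `OPENAI_API_KEY` environment variable.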
impl Default for OpenAI<OpenAIConfig> {
    fn default() -> Self {
        Self::new(OpenAIConfig::default())
    }
}

#[async_trait]
impl<C: Config + Send + Sync + 'static> LLM for OpenAI<C> {
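    /// Sends the messages as a chat completion request. When a streaming
    /// callback is set in [`CallOptions`], the response is streamed and each
    /// raw chunk is forwarded to the callback while the text is accumulated;
    /// otherwise a single completion is requested.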
    async fn generate(&self, prompt: &[Message]) -> Result<GenerateResult, LLMError> {
        let client = Client::with_config(self.config.clone());
        let request = self.generate_request(prompt, self.options.streaming_func.is_some())?;
        match &self.options.streaming_func {
            Some(func) => {
                let mut stream = client.chat().create_stream(request).await?;
                let mut generate_result = GenerateResult::default();
                while let Some(result) = stream.next().await {
                    match result {
                        Ok(response) => {
                            if let Some(usage) = response.usage {
                                generate_result.tokens = Some(TokenUsage {
                                    prompt_tokens: usage.prompt_tokens,
                                    completion_tokens: usage.completion_tokens,
                                    total_tokens: usage.total_tokens,
                                });
                            }
                            for chat_choice in response.choices.iter() {
                                let chat_choice: ChatChoiceStream = chat_choice.clone();
                                {
                                    let mut func = func.lock().await;
                                    let _ = func(
                                        serde_json::to_string(&chat_choice).unwrap_or_default(),
                                    )
                                    .await;
                                }
                                if let Some(content) = chat_choice.delta.content {
                                    generate_result.generation.push_str(&content);
                                }
                            }
                        }
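                        // Log and skip failed chunks so a single bad chunk
                        // does not abort the whole stream.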
                        Err(err) => {
                            eprintln!("Error from streaming response: {:?}", err);
                        }
                    }
                }
                Ok(generate_result)
            }
            None => {
                let response = client.chat().create(request).await?;
                let mut generate_result = GenerateResult::default();

                if let Some(usage) = response.usage {
                    generate_result.tokens = Some(TokenUsage {
                        prompt_tokens: usage.prompt_tokens,
                        completion_tokens: usage.completion_tokens,
                        total_tokens: usage.total_tokens,
                    });
                }

                if let Some(choice) = response.choices.first() {
                    generate_result.generation = choice.message.content.clone().unwrap_or_default();
                    if let Some(function) = &choice.message.tool_calls {
                        generate_result.generation =
                            serde_json::to_string(&function).unwrap_or_default();
                    }
                }

                Ok(generate_result)
            }
        }
    }

    async fn invoke(&self, prompt: &str) -> Result<String, LLMError> {
        self.generate(&[Message::new_human_message(prompt)])
            .await
            .map(|res| res.generation)
    }

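    /// Streams the chat completion, yielding one [`StreamData`] per chunk.
    /// When `stream_usage` is enabled in [`CallOptions`], a trailing
    /// usage-only chunk is emitted with empty content.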
    async fn stream(
        &self,
        messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, LLMError>> + Send>>, LLMError> {
        let client = Client::with_config(self.config.clone());
        let request = self.generate_request(messages, true)?;

        let original_stream = client.chat().create_stream(request).await?;

        let new_stream = original_stream.map(|result| match result {
            Ok(completion) => {
                let value_completion = serde_json::to_value(completion).map_err(LLMError::from)?;
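                // With `include_usage` enabled, the final chunk carries token
                // usage and no choices; surface it as a usage-only StreamData.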
                if let Some(usage) = value_completion.pointer("/usage").filter(|v| !v.is_null()) {
                    let usage = serde_json::from_value::<TokenUsage>(usage.clone())
                        .map_err(LLMError::from)?;
                    return Ok(StreamData::new(value_completion, Some(usage), ""));
                }
                let content = value_completion
                    .pointer("/choices/0/delta/content")
                    .ok_or(LLMError::ContentNotFound(
                        "/choices/0/delta/content".to_string(),
                    ))?
                    .clone();

                Ok(StreamData::new(
                    value_completion,
                    None,
                    content.as_str().unwrap_or(""),
                ))
            }
            Err(e) => Err(LLMError::from(e)),
        });

        Ok(Box::pin(new_stream))
    }

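    /// Merges additional call options into the current ones.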
    fn add_options(&mut self, options: CallOptions) {
        self.options.merge_options(options)
    }
}

impl<C: Config> OpenAI<C> {
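    /// Converts this crate's [`Message`] values into `async_openai` request
    /// messages, mapping each [`MessageType`] to the matching chat role and
    /// attaching tool calls and image parts where present.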
    fn to_openai_messages(
        &self,
        messages: &[Message],
    ) -> Result<Vec<ChatCompletionRequestMessage>, LLMError> {
        let mut openai_messages: Vec<ChatCompletionRequestMessage> = Vec::new();
        for m in messages {
            match m.message_type {
                MessageType::AIMessage => openai_messages.push(match &m.tool_calls {
                    Some(value) => {
                        let function: Vec<ChatCompletionMessageToolCall> =
                            serde_json::from_value(value.clone())?;
                        ChatCompletionRequestAssistantMessageArgs::default()
                            .tool_calls(function)
                            .content(m.content.clone())
                            .build()?
                            .into()
                    }
                    None => ChatCompletionRequestAssistantMessageArgs::default()
                        .content(m.content.clone())
                        .build()?
                        .into(),
                }),
                MessageType::HumanMessage => {
                    let content: ChatCompletionRequestUserMessageContent = match m.images.clone() {
                        Some(images) => {
                            let content: Result<
                                Vec<ChatCompletionRequestUserMessageContentPart>,
                                OpenAIError,
                            > = images
                                .into_iter()
                                .map(|image| {
                                    Ok(ChatCompletionRequestMessageContentPartImageArgs::default()
                                        .image_url(image.image_url)
                                        .build()?
                                        .into())
                                })
                                .collect();

                            content?.into()
                        }
                        None => m.content.clone().into(),
                    };

                    openai_messages.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    )
                }
                MessageType::SystemMessage => openai_messages.push(
                    ChatCompletionRequestSystemMessageArgs::default()
                        .content(m.content.clone())
                        .build()?
                        .into(),
                ),
                MessageType::ToolMessage => {
                    openai_messages.push(
                        ChatCompletionRequestToolMessageArgs::default()
                            .content(m.content.clone())
                            .tool_call_id(m.id.clone().unwrap_or_default())
                            .build()?
                            .into(),
                    );
                }
            }
        }
        Ok(openai_messages)
    }

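    /// Builds the `CreateChatCompletionRequest`, applying temperature, token
    /// limits, stream options, stop words, tool definitions, and tool-choice
    /// behavior from [`CallOptions`].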
    fn generate_request(
        &self,
        messages: &[Message],
        stream: bool,
    ) -> Result<CreateChatCompletionRequest, LLMError> {
        let messages: Vec<ChatCompletionRequestMessage> = self.to_openai_messages(messages)?;
        let mut request_builder = CreateChatCompletionRequestArgs::default();
        if let Some(temperature) = self.options.temperature {
            request_builder.temperature(temperature);
        }
        if let Some(max_tokens) = self.options.max_tokens {
            request_builder.max_tokens(max_tokens);
        }
        if stream {
            if let Some(include_usage) = self.options.stream_usage {
                request_builder.stream_options(ChatCompletionStreamOptions { include_usage });
            }
        }
        request_builder.model(self.model.clone());
        if let Some(stop_words) = &self.options.stop_words {
            request_builder.stop(stop_words);
        }

        if let Some(functions) = &self.options.functions {
            let mut tools = Vec::new();
            for f in functions.iter() {
                let function = FunctionObjectArgs::default()
                    .name(f.name.clone())
                    .description(f.description.clone())
                    .parameters(f.parameters.clone())
                    .build()?;
                tools.push(
                    ChatCompletionToolArgs::default()
                        .r#type(ChatCompletionToolType::Function)
                        .function(function)
                        .build()?,
                )
            }
            request_builder.tools(tools);
        }

        if let Some(behavior) = &self.options.function_call_behavior {
            match behavior {
                FunctionCallBehavior::Auto => request_builder.tool_choice("auto"),
                FunctionCallBehavior::None => request_builder.tool_choice("none"),
                FunctionCallBehavior::Named(name) => request_builder.tool_choice(name.as_str()),
            };
        }
        request_builder.messages(messages);
        Ok(request_builder.build()?)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::schemas::FunctionDefinition;

    use base64::prelude::*;
    use serde_json::json;
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tokio::test;

    #[test]
    #[ignore]
    async fn test_invoke() {
        let message_complete = Arc::new(Mutex::new(String::new()));

        // Define the streaming function
        // This function will append the content received from the stream to `message_complete`
        let streaming_func = {
            let message_complete = message_complete.clone();
            move |content: String| {
                let message_complete = message_complete.clone();
                async move {
                    let mut message_complete_lock = message_complete.lock().await;
                    println!("Content: {:?}", content);
                    message_complete_lock.push_str(&content);
                    Ok(())
                }
            }
        };
        let options = CallOptions::new().with_streaming_func(streaming_func);
        // Set up the OpenAI client with the necessary options
        let open_ai = OpenAI::new(OpenAIConfig::default())
            .with_model(OpenAIModel::Gpt35.to_string()) // You can change the model as needed
            .with_options(options);

        // Call the invoke function
        match open_ai.invoke("hola").await {
            Ok(result) => {
                // Print the response from the invoke function
                println!("Generate Result: {:?}", result);
                println!("Message Complete: {:?}", message_complete.lock().await);
            }
            Err(e) => {
                // Handle any errors
                eprintln!("Error calling invoke: {:?}", e);
            }
        }
    }

    #[test]
    #[ignore]
    async fn test_generate_function() {
        let message_complete = Arc::new(Mutex::new(String::new()));

        // Define the streaming function
        // This function will append the content received from the stream to `message_complete`
        let streaming_func = {
            let message_complete = message_complete.clone();
            move |content: String| {
                let message_complete = message_complete.clone();
                async move {
                    let content = serde_json::from_str::<ChatChoiceStream>(&content).unwrap();
                    if content.finish_reason.is_some() {
                        return Ok(());
                    }
                    let mut message_complete_lock = message_complete.lock().await;
                    println!("Content: {:?}", content);
                    // The delta may carry no content (e.g. a role-only first
                    // chunk), so avoid unwrapping it blindly.
                    if let Some(chunk) = content.delta.content {
                        message_complete_lock.push_str(&chunk);
                    }
                    Ok(())
                }
            }
        };
        let options = CallOptions::new().with_streaming_func(streaming_func);
        // Set up the OpenAI client with the necessary options
        let open_ai = OpenAI::new(OpenAIConfig::default())
            .with_model(OpenAIModel::Gpt35.to_string()) // You can change the model as needed
            .with_options(options);

        // Define a set of messages to send to the generate function
        let messages = vec![Message::new_human_message("Hello, how are you?")];

        // Call the generate function
        match open_ai.generate(&messages).await {
            Ok(result) => {
                // Print the response from the generate function
                println!("Generate Result: {:?}", result);
                println!("Message Complete: {:?}", message_complete.lock().await);
            }
            Err(e) => {
                // Handle any errors
                eprintln!("Error calling generate: {:?}", e);
            }
        }
    }

    #[test]
    #[ignore]
    async fn test_openai_stream() {
        // Set up the OpenAI client with the necessary options
        let open_ai = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());

        // Define a set of messages to send to the stream function
        let messages = vec![Message::new_human_message("Hello, how are you?")];

        open_ai
            .stream(&messages)
            .await
            .unwrap()
            .for_each(|result| async {
                match result {
                    Ok(stream_data) => {
                        println!("Stream Data: {:?}", stream_data.content);
                    }
                    Err(e) => {
                        eprintln!("Error in stream: {:?}", e);
                    }
                }
            })
            .await;
    }

    #[test]
    #[ignore]
    async fn test_function() {
        let functions = vec![FunctionDefinition {
            name: "cli".to_string(),
            description: "Use the Ubuntu command line to perform any action you wish.".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "command": {
                        "type": "string",
                        "description": "The raw command you want executed"
                    }
                },
                "required": ["command"]
            }),
        }];

        let llm = OpenAI::default()
            .with_model(OpenAIModel::Gpt35)
            .with_config(OpenAIConfig::new())
            .with_options(CallOptions::new().with_functions(functions));
        let response = llm
            .invoke("Use the command line to create a new rust project. Execute the first command.")
            .await
            .unwrap();
        println!("{}", response)
    }

    #[test]
    #[ignore]
    async fn test_generate_with_image_message() {
        // Set up the OpenAI client with the necessary options
        let open_ai =
            OpenAI::new(OpenAIConfig::default()).with_model(OpenAIModel::Gpt4o.to_string());

        // Convert the image to a base64-encoded data URL
        let image = std::fs::read("./src/llm/test_data/example.jpg").unwrap();
        let image_base64 = BASE64_STANDARD.encode(image);

        // Define a set of messages to send to the generate function
        let image_urls = vec![format!("data:image/jpeg;base64,{image_base64}")];
        let messages = vec![
            Message::new_human_message("Describe this image"),
            Message::new_human_message_with_images(image_urls),
        ];

        // Call the generate function
        let response = open_ai.generate(&messages).await.unwrap();
        println!("Response: {:?}", response);
    }
}