langchain_rust/llm/openai/mod.rs

use std::pin::Pin;

pub use async_openai::config::{AzureConfig, Config, OpenAIConfig};
use async_openai::{
    error::OpenAIError,
    types::{
        ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
        ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImageArgs,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
        ChatCompletionRequestUserMessageContentPart, ChatCompletionStreamOptions,
        ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
        CreateChatCompletionRequestArgs, FunctionObjectArgs,
    },
    Client,
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};

use crate::{
    language_models::{llm::LLM, options::CallOptions, GenerateResult, LLMError, TokenUsage},
    schemas::{
        messages::{Message, MessageType},
        FunctionCallBehavior, StreamData,
    },
};

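/// Well-known OpenAI chat models; `to_string()` (and `Into<String>`) yields the
/// model identifier sent with the request, e.g. `Gpt4o` -> "gpt-4o".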
#[derive(Clone)]
pub enum OpenAIModel {
    Gpt35,
    Gpt4,
    Gpt4Turbo,
    Gpt4o,
    Gpt4oMini,
}

impl ToString for OpenAIModel {
    fn to_string(&self) -> String {
        match self {
            OpenAIModel::Gpt35 => "gpt-3.5-turbo".to_string(),
            OpenAIModel::Gpt4 => "gpt-4".to_string(),
            OpenAIModel::Gpt4Turbo => "gpt-4-turbo-preview".to_string(),
            OpenAIModel::Gpt4o => "gpt-4o".to_string(),
            OpenAIModel::Gpt4oMini => "gpt-4o-mini".to_string(),
        }
    }
}

impl Into<String> for OpenAIModel {
    fn into(self) -> String {
        self.to_string()
    }
}

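/// Chat-completion client generic over any OpenAI-compatible backend config
/// (`OpenAIConfig`, `AzureConfig`, ...).
///
/// A minimal usage sketch (assumes an async runtime and that `OpenAIConfig::default()`
/// can find your API key; error handling elided):
///
/// ```rust,ignore
/// let llm = OpenAI::default().with_model(OpenAIModel::Gpt4oMini);
/// let reply = llm.invoke("Say hello").await?;
/// ```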
#[derive(Clone)]
pub struct OpenAI<C: Config> {
    config: C,
    options: CallOptions,
    model: String,
}

impl<C: Config> OpenAI<C> {
    pub fn new(config: C) -> Self {
        Self {
            config,
            options: CallOptions::default(),
            model: OpenAIModel::Gpt4oMini.to_string(),
        }
    }

    pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
        self.model = model.into();
        self
    }

    pub fn with_config(mut self, config: C) -> Self {
        self.config = config;
        self
    }

    pub fn with_options(mut self, options: CallOptions) -> Self {
        self.options = options;
        self
    }
}

impl Default for OpenAI<OpenAIConfig> {
    fn default() -> Self {
        Self::new(OpenAIConfig::default())
    }
}

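// `LLM` implementation: `generate` streams when a streaming callback is present in
// `CallOptions` (forwarding each serialized `ChatChoiceStream` chunk to the callback
// while accumulating the text), otherwise it issues a single completion request.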
#[async_trait]
impl<C: Config + Send + Sync + 'static> LLM for OpenAI<C> {
    async fn generate(&self, prompt: &[Message]) -> Result<GenerateResult, LLMError> {
        let client = Client::with_config(self.config.clone());
        let request = self.generate_request(prompt, self.options.streaming_func.is_some())?;
        match &self.options.streaming_func {
            Some(func) => {
                let mut stream = client.chat().create_stream(request).await?;
                let mut generate_result = GenerateResult::default();
                while let Some(result) = stream.next().await {
                    match result {
                        Ok(response) => {
                            if let Some(usage) = response.usage {
                                generate_result.tokens = Some(TokenUsage {
                                    prompt_tokens: usage.prompt_tokens,
                                    completion_tokens: usage.completion_tokens,
                                    total_tokens: usage.total_tokens,
                                });
                            }
                            for chat_choice in response.choices.iter() {
                                let chat_choice: ChatChoiceStream = chat_choice.clone();
                                {
                                    let mut func = func.lock().await;
                                    let _ = func(
                                        serde_json::to_string(&chat_choice).unwrap_or("".into()),
                                    )
                                    .await;
                                }
                                if let Some(content) = chat_choice.delta.content {
                                    generate_result.generation.push_str(&content);
                                }
                            }
                        }
                        Err(err) => {
                            eprintln!("Error from streaming response: {:?}", err);
                        }
                    }
                }
                Ok(generate_result)
            }
            None => {
                let response = client.chat().create(request).await?;
                let mut generate_result = GenerateResult::default();

                if let Some(usage) = response.usage {
                    generate_result.tokens = Some(TokenUsage {
                        prompt_tokens: usage.prompt_tokens,
                        completion_tokens: usage.completion_tokens,
                        total_tokens: usage.total_tokens,
                    });
                }

                if let Some(choice) = &response.choices.first() {
                    generate_result.generation = choice.message.content.clone().unwrap_or_default();
                    if let Some(function) = &choice.message.tool_calls {
                        generate_result.generation =
                            serde_json::to_string(&function).unwrap_or_default();
                    }
                } else {
                    generate_result.generation = "".to_string();
                }

                Ok(generate_result)
            }
        }
    }

    async fn invoke(&self, prompt: &str) -> Result<String, LLMError> {
        self.generate(&[Message::new_human_message(prompt)])
            .await
            .map(|res| res.generation)
    }

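    // `stream` hands the raw chunk stream back to the caller as `StreamData`: usage-only
    // chunks carry token counts, all others carry the `/choices/0/delta/content` text.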
    async fn stream(
        &self,
        messages: &[Message],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamData, LLMError>> + Send>>, LLMError> {
        let client = Client::with_config(self.config.clone());
        let request = self.generate_request(messages, true)?;

        let original_stream = client.chat().create_stream(request).await?;

        let new_stream = original_stream.map(|result| match result {
            Ok(completion) => {
                let value_completion = serde_json::to_value(completion).map_err(LLMError::from)?;
                let usage = value_completion.pointer("/usage");
                if usage.is_some() && !usage.unwrap().is_null() {
                    let usage = serde_json::from_value::<TokenUsage>(usage.unwrap().clone())
                        .map_err(LLMError::from)?;
                    return Ok(StreamData::new(value_completion, Some(usage), ""));
                }
                let content = value_completion
                    .pointer("/choices/0/delta/content")
                    .ok_or(LLMError::ContentNotFound(
                        "/choices/0/delta/content".to_string(),
                    ))?
                    .clone();

                Ok(StreamData::new(
                    value_completion,
                    None,
                    content.as_str().unwrap_or(""),
                ))
            }
            Err(e) => Err(LLMError::from(e)),
        });

        Ok(Box::pin(new_stream))
    }

    fn add_options(&mut self, options: CallOptions) {
        self.options.merge_options(options)
    }
}

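// Request-building helpers shared by `generate` and `stream`.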
impl<C: Config> OpenAI<C> {
    fn to_openai_messages(
        &self,
        messages: &[Message],
    ) -> Result<Vec<ChatCompletionRequestMessage>, LLMError> {
        let mut openai_messages: Vec<ChatCompletionRequestMessage> = Vec::new();
        for m in messages {
            match m.message_type {
                MessageType::AIMessage => openai_messages.push(match &m.tool_calls {
                    Some(value) => {
                        let function: Vec<ChatCompletionMessageToolCall> =
                            serde_json::from_value(value.clone())?;
                        ChatCompletionRequestAssistantMessageArgs::default()
                            .tool_calls(function)
                            .content(m.content.clone())
                            .build()?
                            .into()
                    }
                    None => ChatCompletionRequestAssistantMessageArgs::default()
                        .content(m.content.clone())
                        .build()?
                        .into(),
                }),
                MessageType::HumanMessage => {
                    let content: ChatCompletionRequestUserMessageContent = match m.images.clone() {
                        Some(images) => {
                            let content: Result<
                                Vec<ChatCompletionRequestUserMessageContentPart>,
                                OpenAIError,
                            > = images
                                .into_iter()
                                .map(|image| {
                                    Ok(ChatCompletionRequestMessageContentPartImageArgs::default()
                                        .image_url(image.image_url)
                                        .build()?
                                        .into())
                                })
                                .collect();

                            content?.into()
                        }
                        None => m.content.clone().into(),
                    };

                    openai_messages.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    )
                }
                MessageType::SystemMessage => openai_messages.push(
                    ChatCompletionRequestSystemMessageArgs::default()
                        .content(m.content.clone())
                        .build()?
                        .into(),
                ),
                MessageType::ToolMessage => {
                    openai_messages.push(
                        ChatCompletionRequestToolMessageArgs::default()
                            .content(m.content.clone())
                            .tool_call_id(m.id.clone().unwrap_or_default())
                            .build()?
                            .into(),
                    );
                }
            }
        }
        Ok(openai_messages)
    }

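    // Translate `CallOptions` (temperature, max_tokens, stream usage, stop words, tools,
    // and tool-choice behavior) plus the converted messages into a chat completion request.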
    fn generate_request(
        &self,
        messages: &[Message],
        stream: bool,
    ) -> Result<CreateChatCompletionRequest, LLMError> {
        let messages: Vec<ChatCompletionRequestMessage> = self.to_openai_messages(messages)?;
        let mut request_builder = CreateChatCompletionRequestArgs::default();
        if let Some(temperature) = self.options.temperature {
            request_builder.temperature(temperature);
        }
        if let Some(max_tokens) = self.options.max_tokens {
            request_builder.max_tokens(max_tokens);
        }
        if stream {
            if let Some(include_usage) = self.options.stream_usage {
                request_builder.stream_options(ChatCompletionStreamOptions { include_usage });
            }
        }
        request_builder.model(self.model.to_string());
        if let Some(stop_words) = &self.options.stop_words {
            request_builder.stop(stop_words);
        }

        if let Some(behavior) = &self.options.functions {
            let mut functions = Vec::new();
            for f in behavior.iter() {
                let tool = FunctionObjectArgs::default()
                    .name(f.name.clone())
                    .description(f.description.clone())
                    .parameters(f.parameters.clone())
                    .build()?;
                functions.push(
                    ChatCompletionToolArgs::default()
                        .r#type(ChatCompletionToolType::Function)
                        .function(tool)
                        .build()?,
                )
            }
            request_builder.tools(functions);
        }

        if let Some(behavior) = &self.options.function_call_behavior {
            match behavior {
                FunctionCallBehavior::Auto => request_builder.tool_choice("auto"),
                FunctionCallBehavior::None => request_builder.tool_choice("none"),
                FunctionCallBehavior::Named(name) => request_builder.tool_choice(name.as_str()),
            };
        }
        request_builder.messages(messages);
        Ok(request_builder.build()?)
    }
}
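// The tests below call the live OpenAI API via `OpenAIConfig::default()`, so they are
// `#[ignore]`d and must be run explicitly.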
#[cfg(test)]
mod tests {

    use crate::schemas::FunctionDefinition;

    use super::*;

    use base64::prelude::*;
    use serde_json::json;
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tokio::test;

    #[test]
    #[ignore]
    async fn test_invoke() {
        let message_complete = Arc::new(Mutex::new(String::new()));

        let streaming_func = {
            let message_complete = message_complete.clone();
            move |content: String| {
                let message_complete = message_complete.clone();
                async move {
                    let mut message_complete_lock = message_complete.lock().await;
                    println!("Content: {:?}", content);
                    message_complete_lock.push_str(&content);
                    Ok(())
                }
            }
        };
        let options = CallOptions::new().with_streaming_func(streaming_func);
        let open_ai = OpenAI::new(OpenAIConfig::default())
            .with_model(OpenAIModel::Gpt35.to_string())
            .with_options(options);

        match open_ai.invoke("hola").await {
            Ok(result) => {
                println!("Generate Result: {:?}", result);
                println!("Message Complete: {:?}", message_complete.lock().await);
            }
            Err(e) => {
                eprintln!("Error calling generate: {:?}", e);
            }
        }
    }

    #[test]
    #[ignore]
    async fn test_generate_function() {
        let message_complete = Arc::new(Mutex::new(String::new()));

        let streaming_func = {
            let message_complete = message_complete.clone();
            move |content: String| {
                let message_complete = message_complete.clone();
                async move {
                    let content = serde_json::from_str::<ChatChoiceStream>(&content).unwrap();
                    if content.finish_reason.is_some() {
                        return Ok(());
                    }
                    let mut message_complete_lock = message_complete.lock().await;
                    println!("Content: {:?}", content);
                    message_complete_lock.push_str(&content.delta.content.unwrap());
                    Ok(())
                }
            }
        };
        let options = CallOptions::new().with_streaming_func(streaming_func);
        let open_ai = OpenAI::new(OpenAIConfig::default())
            .with_model(OpenAIModel::Gpt35.to_string())
            .with_options(options);

        let messages = vec![Message::new_human_message("Hello, how are you?")];

        match open_ai.generate(&messages).await {
            Ok(result) => {
                println!("Generate Result: {:?}", result);
                println!("Message Complete: {:?}", message_complete.lock().await);
            }
            Err(e) => {
                eprintln!("Error calling generate: {:?}", e);
            }
        }
    }

    #[test]
    #[ignore]
    async fn test_openai_stream() {
        let open_ai = OpenAI::default().with_model(OpenAIModel::Gpt35.to_string());

        let messages = vec![Message::new_human_message("Hello, how are you?")];

        open_ai
            .stream(&messages)
            .await
            .unwrap()
            .for_each(|result| async {
                match result {
                    Ok(stream_data) => {
                        println!("Stream Data: {:?}", stream_data.content);
                    }
                    Err(e) => {
                        eprintln!("Error calling generate: {:?}", e);
                    }
                }
            })
            .await;
    }

    #[test]
    #[ignore]
    async fn test_function() {
        let mut functions = Vec::new();
        functions.push(FunctionDefinition {
            name: "cli".to_string(),
            description: "Use the Ubuntu command line to perform any action you wish.".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "command": {
                        "type": "string",
                        "description": "The raw command you want executed"
                    }
                },
                "required": ["command"]
            }),
        });

        let llm = OpenAI::default()
            .with_model(OpenAIModel::Gpt35)
            .with_config(OpenAIConfig::new())
            .with_options(CallOptions::new().with_functions(functions));
        let response = llm
            .invoke("Use the command line to create a new rust project. Execute the first command.")
            .await
            .unwrap();
        println!("{}", response)
    }

    #[test]
    #[ignore]
    async fn test_generate_with_image_message() {
        let open_ai =
            OpenAI::new(OpenAIConfig::default()).with_model(OpenAIModel::Gpt4o.to_string());

        let image = std::fs::read("./src/llm/test_data/example.jpg").unwrap();
        let image_base64 = BASE64_STANDARD.encode(image);

        let image_urls = vec![format!("data:image/jpeg;base64,{image_base64}")];
        let messages = vec![
            Message::new_human_message("Describe this image"),
            Message::new_human_message_with_images(image_urls),
        ];

        let response = open_ai.generate(&messages).await.unwrap();
        println!("Response: {:?}", response);
    }
}