omni_llm_kit/models/openai_provider/
openai_model.rs

1use anyhow::anyhow;
2use futures_core::Stream;
3use futures_core::future::BoxFuture;
4use futures_core::stream::BoxStream;
5use schemars::JsonSchema;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::Arc;
9// use futures_core::{future::BoxFuture, stream::{BoxStream};
10
11use crate::OpenAiSettings;
12use crate::http_client::HttpClient;
13use crate::model::{
14    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
15    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
16    LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent, Role,
17};
18use crate::models::openai_provider::event_mapper::OpenAiEventMapper;
19use crate::openai::{self, ImageUrl, ResponseStreamEvent};
20use futures_util::{FutureExt, StreamExt};
21use log::info;
22use serde::{Deserialize, Serialize};
23use strum::EnumIter;
24
/// Stable machine-readable identifier for the OpenAI provider.
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
/// Human-readable provider name shown to users.
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");
28
29pub struct OpenAiLanguageModel {
30    pub(crate) id: LanguageModelId,
31    pub(crate) model: openai::Model,
32    // pub(crate) state: State,
33    pub(crate) http_client: Arc<dyn HttpClient>,
34}
35
36impl OpenAiLanguageModel {
37    async fn stream_completion(
38        &self,
39        request: openai::Request,
40    ) -> anyhow::Result<BoxStream<'static, anyhow::Result<ResponseStreamEvent>>> {
41        let http_client = self.http_client.clone();
42        let openai_settings =
43            global_registry::get!(OpenAiSettings).expect("OpenAiSettings not found");
44        let api_key = openai_settings.api_key.clone();
45        let base_url = openai_settings.api_url.clone();
46
47        let response =
48            openai::stream_completion(http_client.as_ref(), &base_url, &api_key, request).await?;
49        Ok(response.boxed())
50    }
51}
52#[async_trait::async_trait]
53impl LanguageModel for OpenAiLanguageModel {
54    fn id(&self) -> LanguageModelId {
55        self.id.clone()
56    }
57
58    fn name(&self) -> LanguageModelName {
59        LanguageModelName::from(self.model.display_name().to_string())
60    }
61
62    fn provider_id(&self) -> LanguageModelProviderId {
63        OPEN_AI_PROVIDER_ID
64    }
65
66    fn provider_name(&self) -> LanguageModelProviderName {
67        OPEN_AI_PROVIDER_NAME
68    }
69
70    fn max_token_count(&self) -> u64 {
71        self.model.max_token_count()
72    }
73    fn max_output_tokens(&self) -> Option<u64> {
74        self.model.max_output_tokens()
75    }
76    fn supports_tools(&self) -> bool {
77        return true;
78    }
79    fn supports_burn_mode(&self) -> bool {
80        return false;
81    }
82    async fn stream_completion(
83        &self,
84        request: LanguageModelRequest,
85    ) -> Result<
86        BoxStream<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
87        LanguageModelCompletionError,
88    > {
89        let request = into_open_ai(
90            request,
91            self.model.id(),
92            self.model.supports_parallel_tool_calls(),
93            self.max_output_tokens(),
94        );
95        let completion = self.stream_completion(request).await?.boxed();
96        let mapper = OpenAiEventMapper::new();
97        Ok(mapper.map_stream(completion).boxed())
98    }
99}
100fn add_message_content_part(
101    new_part: openai::MessagePart,
102    role: Role,
103    messages: &mut Vec<openai::RequestMessage>,
104) {
105    match (role, messages.last_mut()) {
106        (Role::User, Some(openai::RequestMessage::User { content }))
107        | (
108            Role::Assistant,
109            Some(openai::RequestMessage::Assistant {
110                content: Some(content),
111                ..
112            }),
113        )
114        | (Role::System, Some(openai::RequestMessage::System { content, .. })) => {
115            content.push_part(new_part);
116        }
117        _ => {
118            messages.push(match role {
119                Role::User => openai::RequestMessage::User {
120                    content: openai::MessageContent::from(vec![new_part]),
121                },
122                Role::Assistant => openai::RequestMessage::Assistant {
123                    content: Some(openai::MessageContent::from(vec![new_part])),
124                    tool_calls: Vec::new(),
125                },
126                Role::System => openai::RequestMessage::System {
127                    content: openai::MessageContent::from(vec![new_part]),
128                },
129            });
130        }
131    }
132}
133pub fn into_open_ai(
134    request: LanguageModelRequest,
135    model_id: &str,
136    supports_parallel_tool_calls: bool,
137    max_output_tokens: Option<u64>,
138) -> openai::Request {
139    let stream = !model_id.starts_with("o1-");
140
141    let mut messages = Vec::new();
142    for message in request.messages {
143        for content in message.content {
144            match content {
145                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
146                    add_message_content_part(
147                        openai::MessagePart::Text { text: text },
148                        message.role,
149                        &mut messages,
150                    )
151                }
152                MessageContent::RedactedThinking(_) => {}
153                MessageContent::Image(image) => {
154                    add_message_content_part(
155                        openai::MessagePart::Image {
156                            image_url: ImageUrl {
157                                url: image.to_base64_url(),
158                                detail: None,
159                            },
160                        },
161                        message.role,
162                        &mut messages,
163                    );
164                }
165                MessageContent::ToolUse(tool_use) => {
166                    let tool_call = openai::ToolCall {
167                        id: tool_use.id.to_string(),
168                        content: openai::ToolCallContent::Function {
169                            function: openai::FunctionContent {
170                                name: tool_use.name.to_string(),
171                                arguments: serde_json::to_string(&tool_use.input)
172                                    .unwrap_or_default(),
173                            },
174                        },
175                    };
176
177                    if let Some(openai::RequestMessage::Assistant { tool_calls, .. }) =
178                        messages.last_mut()
179                    {
180                        tool_calls.push(tool_call);
181                    } else {
182                        messages.push(openai::RequestMessage::Assistant {
183                            content: None,
184                            tool_calls: vec![tool_call],
185                        });
186                    }
187                }
188                MessageContent::ToolResult(tool_result) => {
189                    let content = match &tool_result.content {
190                        LanguageModelToolResultContent::Text(text) => {
191                            vec![openai::MessagePart::Text {
192                                text: text.to_string(),
193                            }]
194                        } // LanguageModelToolResultContent::Image(image) => {
195                          //     vec![openai::MessagePart::Image {
196                          //         image_url: ImageUrl {
197                          //             url: image.to_base64_url(),
198                          //             detail: None,
199                          //         },
200                          //     }]
201                          // }
202                    };
203
204                    messages.push(openai::RequestMessage::Tool {
205                        content: content.into(),
206                        tool_call_id: tool_result.tool_use_id.to_string(),
207                    });
208                }
209            }
210        }
211    }
212
213    openai::Request {
214        model: model_id.into(),
215        messages,
216        stream,
217        stop: request.stop,
218        temperature: request.temperature.unwrap_or(1.0),
219        max_completion_tokens: max_output_tokens,
220        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
221            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
222            Some(false)
223        } else {
224            None
225        },
226        tools: request
227            .tools
228            .into_iter()
229            .map(|tool| openai::ToolDefinition::Function {
230                function: openai::FunctionDefinition {
231                    name: tool.name,
232                    description: Some(tool.description),
233                    parameters: Some(tool.input_schema),
234                },
235            })
236            .collect(),
237        tool_choice: request.tool_choice.map(|choice| match choice {
238            LanguageModelToolChoice::Auto => openai::ToolChoice::Auto,
239            LanguageModelToolChoice::Any => openai::ToolChoice::Required,
240            LanguageModelToolChoice::None => openai::ToolChoice::None,
241        }),
242    }
243}
244
/// A model entry made available through configuration (deserialized from
/// settings), describing a model by name together with its token limits.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Model identifier — presumably the OpenAI model id sent on requests;
    /// verify against call sites.
    pub name: String,
    /// Optional human-readable display-name override.
    pub display_name: Option<String>,
    /// Maximum context token count.
    pub max_tokens: u64,
    /// Optional cap on output tokens.
    pub max_output_tokens: Option<u64>,
    /// Optional cap on completion tokens (maps to OpenAI's
    /// `max_completion_tokens`) — TODO confirm against the request builder.
    pub max_completion_tokens: Option<u64>,
}