omni_llm_kit/models/openai_provider/openai_model.rs

use futures_core::stream::BoxStream;
use schemars::JsonSchema;
use std::sync::Arc;
use crate::OpenAiSettings;
use crate::http_client::HttpClient;
use crate::model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent, Role,
};
use crate::models::openai_provider::event_mapper::OpenAiEventMapper;
use crate::openai::{self, ImageUrl, ResponseStreamEvent};
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");
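
/// A single OpenAI model exposed through the crate's `LanguageModel` abstraction.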
pub struct OpenAiLanguageModel {
    pub(crate) id: LanguageModelId,
    pub(crate) model: openai::Model,
    pub(crate) http_client: Arc<dyn HttpClient>,
}

impl OpenAiLanguageModel {
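    /// Sends `request` to the endpoint configured in the global `OpenAiSettings`
    /// and returns the raw response event stream.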
    async fn stream_completion(
        &self,
        request: openai::Request,
    ) -> anyhow::Result<BoxStream<'static, anyhow::Result<ResponseStreamEvent>>> {
        let http_client = self.http_client.clone();
        let openai_settings =
            global_registry::get!(OpenAiSettings).expect("OpenAiSettings not found");
        let api_key = openai_settings.api_key.clone();
        let base_url = openai_settings.api_url.clone();

        let response =
            openai::stream_completion(http_client.as_ref(), &base_url, &api_key, request).await?;
        Ok(response.boxed())
    }
}

#[async_trait::async_trait]
impl LanguageModel for OpenAiLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        OPEN_AI_PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        OPEN_AI_PROVIDER_NAME
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_burn_mode(&self) -> bool {
        false
    }

    async fn stream_completion(
        &self,
        request: LanguageModelRequest,
    ) -> Result<
        BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        LanguageModelCompletionError,
    > {
        let request = into_open_ai(
            request,
            self.model.id(),
            self.model.supports_parallel_tool_calls(),
            self.max_output_tokens(),
        );
        let completion = self.stream_completion(request).await?;
        let mapper = OpenAiEventMapper::new();
        Ok(mapper.map_stream(completion).boxed())
    }
}
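
/// Appends `new_part` to the trailing message in `messages` when that message
/// has the same role and inline content; otherwise pushes a new message for `role`.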
fn add_message_content_part(
    new_part: openai::MessagePart,
    role: Role,
    messages: &mut Vec<openai::RequestMessage>,
) {
    match (role, messages.last_mut()) {
        (Role::User, Some(openai::RequestMessage::User { content }))
        | (
            Role::Assistant,
            Some(openai::RequestMessage::Assistant {
                content: Some(content),
                ..
            }),
        )
        | (Role::System, Some(openai::RequestMessage::System { content, .. })) => {
            content.push_part(new_part);
        }
        _ => {
            messages.push(match role {
                Role::User => openai::RequestMessage::User {
                    content: openai::MessageContent::from(vec![new_part]),
                },
                Role::Assistant => openai::RequestMessage::Assistant {
                    content: Some(openai::MessageContent::from(vec![new_part])),
                    tool_calls: Vec::new(),
                },
                Role::System => openai::RequestMessage::System {
                    content: openai::MessageContent::from(vec![new_part]),
                },
            });
        }
    }
}
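
/// Lowers a provider-agnostic `LanguageModelRequest` into the OpenAI
/// chat completions request shape.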
pub fn into_open_ai(
    request: LanguageModelRequest,
    model_id: &str,
    supports_parallel_tool_calls: bool,
    max_output_tokens: Option<u64>,
) -> openai::Request {
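    // OpenAI's o1 models initially rejected streaming requests, so streaming is
    // disabled for model ids with the `o1-` prefix.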
    let stream = !model_id.starts_with("o1-");

    let mut messages = Vec::new();
    for message in request.messages {
        for content in message.content {
            match content {
                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                    add_message_content_part(
                        openai::MessagePart::Text { text },
                        message.role,
                        &mut messages,
                    )
                }
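                // Redacted thinking has no OpenAI representation, so it is dropped.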
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(image) => {
                    add_message_content_part(
                        openai::MessagePart::Image {
                            image_url: ImageUrl {
                                url: image.to_base64_url(),
                                detail: None,
                            },
                        },
                        message.role,
                        &mut messages,
                    );
                }
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = openai::ToolCall {
                        id: tool_use.id.to_string(),
                        content: openai::ToolCallContent::Function {
                            function: openai::FunctionContent {
                                name: tool_use.name.to_string(),
                                arguments: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                            },
                        },
                    };

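                    // Attach the call to the trailing assistant message when one
                    // exists; otherwise open a new assistant message for it.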
                    if let Some(openai::RequestMessage::Assistant { tool_calls, .. }) =
                        messages.last_mut()
                    {
                        tool_calls.push(tool_call);
                    } else {
                        messages.push(openai::RequestMessage::Assistant {
                            content: None,
                            tool_calls: vec![tool_call],
                        });
                    }
                }
                MessageContent::ToolResult(tool_result) => {
                    let content = match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            vec![openai::MessagePart::Text {
                                text: text.to_string(),
                            }]
                        }
                        LanguageModelToolResultContent::Image(image) => {
                            vec![openai::MessagePart::Image {
                                image_url: ImageUrl {
                                    url: image.to_base64_url(),
                                    detail: None,
                                },
                            }]
                        }
                    };

                    messages.push(openai::RequestMessage::Tool {
                        content: content.into(),
                        tool_call_id: tool_result.tool_use_id.to_string(),
                    });
                }
            }
        }
    }

    openai::Request {
        model: model_id.into(),
        messages,
        stream,
        stop: request.stop,
        temperature: request.temperature.unwrap_or(1.0),
        max_completion_tokens: max_output_tokens,
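        // When the model supports parallel tool calls and tools are supplied,
        // explicitly opt out by sending `parallel_tool_calls: false`; otherwise
        // omit the field entirely.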
        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
            Some(false)
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| openai::ToolDefinition::Function {
                function: openai::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => openai::ToolChoice::Auto,
            LanguageModelToolChoice::Any => openai::ToolChoice::Required,
            LanguageModelToolChoice::None => openai::ToolChoice::None,
        }),
    }
}
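
/// A user-configurable model entry, including display metadata and token limits.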
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
}