1use async_trait::async_trait;
2use futures::Stream;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7
8use super::{
9 FinishReason, GenerateOptions, GenerateResult, LanguageModel, Message, MessageContent,
10 MessagePart, MessageRole, StreamChunk, StreamOptions, ToolCall, ToolDefinition, Usage,
11};
12use crate::auth::{Auth, AuthCredentials};
13
/// Provider for the OpenAI chat-completions API.
///
/// Holds the credential source, a shared HTTP client, the built-in model
/// catalog, and the API root URL.
pub struct OpenAIProvider {
    // Credential source; `get_auth_header` requires it to yield an API key.
    auth: Box<dyn Auth>,
    // Shared reqwest HTTP client reused across requests.
    client: Client,
    // Known models keyed by model id (see `default_models`).
    models: HashMap<String, OpenAIModel>,
    // API root, e.g. "https://api.openai.com"; paths are appended to it.
    api_base: String,
}
21
/// Static metadata describing one OpenAI model and its capabilities.
#[derive(Debug, Clone)]
pub struct OpenAIModel {
    // Model identifier sent in the request body (e.g. "gpt-4o").
    pub id: String,
    // Human-readable display name.
    pub name: String,
    // Default completion-token budget used when the caller sets none.
    pub max_tokens: u32,
    // Whether the model accepts function/tool definitions.
    pub supports_tools: bool,
    // Whether the model accepts image content parts.
    pub supports_vision: bool,
    // Whether the model supports prompt caching (false for all bundled entries).
    pub supports_caching: bool,
}
31
/// Request body for `POST /v1/chat/completions` (serialize-only).
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    // Always sent; callers fall back to the model's default budget.
    max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Omitted from the wire entirely when no tools are supplied.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAITool>,
    // Omitted when no stop sequences are configured.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}
46
47#[derive(Debug, Serialize, Deserialize)]
48struct OpenAIMessage {
49 role: String,
50 content: OpenAIContent,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 name: Option<String>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 tool_calls: Option<Vec<OpenAIToolCall>>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 tool_call_id: Option<String>,
57}
58
59#[derive(Debug, Serialize, Deserialize)]
60#[serde(untagged)]
61enum OpenAIContent {
62 Text(String),
63 Parts(Vec<OpenAIContentPart>),
64}
65
/// A typed segment of multimodal message content, internally tagged on
/// the JSON `"type"` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum OpenAIContentPart {
    // Plain text segment: {"type": "text", "text": "..."}.
    #[serde(rename = "text")]
    Text { text: String },
    // Image segment: {"type": "image_url", "image_url": {...}}.
    #[serde(rename = "image_url")]
    ImageUrl { image_url: OpenAIImageUrl },
}
74
75#[derive(Debug, Serialize, Deserialize)]
76struct OpenAIImageUrl {
77 url: String,
78 detail: Option<String>,
79}
80
/// Tool envelope sent in requests; `tool_type` is always set to "function".
#[derive(Debug, Serialize)]
struct OpenAITool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunction,
}

/// Function declaration: name, description, and a JSON `parameters` object
/// (passed through verbatim from `ToolDefinition::parameters`).
#[derive(Debug, Serialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    parameters: Value,
}

/// A tool invocation emitted by the model, or echoed back in message history.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunctionCall,
}

/// The invoked function; note `arguments` is a JSON-encoded *string*,
/// not a JSON object — callers parse it separately.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIFunctionCall {
    name: String,
    arguments: String,
}
108
/// Top-level non-streaming chat-completions response (only the fields this
/// module consumes are modeled).
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
    usage: OpenAIUsage,
}

/// One completion candidate; `generate` uses only the first choice.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
    // Textual reason such as "stop", "length", "tool_calls",
    // "content_filter"; may be absent.
    finish_reason: Option<String>,
}

/// Token accounting reported by the API.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
127
128impl OpenAIProvider {
129 const API_BASE: &'static str = "https://api.openai.com";
130
131 pub fn new(auth: Box<dyn Auth>) -> Self {
132 Self::with_api_base(auth, Self::API_BASE.to_string())
133 }
134
135 pub fn with_api_base(auth: Box<dyn Auth>, api_base: String) -> Self {
136 let client = Client::new();
137 let models = Self::default_models();
138
139 Self {
140 auth,
141 client,
142 models,
143 api_base,
144 }
145 }
146
147 fn default_models() -> HashMap<String, OpenAIModel> {
148 let mut models = HashMap::new();
149
150 models.insert(
151 "gpt-4o".to_string(),
152 OpenAIModel {
153 id: "gpt-4o".to_string(),
154 name: "GPT-4o".to_string(),
155 max_tokens: 4096,
156 supports_tools: true,
157 supports_vision: true,
158 supports_caching: false,
159 },
160 );
161
162 models.insert(
163 "gpt-4o-mini".to_string(),
164 OpenAIModel {
165 id: "gpt-4o-mini".to_string(),
166 name: "GPT-4o Mini".to_string(),
167 max_tokens: 4096,
168 supports_tools: true,
169 supports_vision: true,
170 supports_caching: false,
171 },
172 );
173
174 models.insert(
175 "gpt-4-turbo".to_string(),
176 OpenAIModel {
177 id: "gpt-4-turbo".to_string(),
178 name: "GPT-4 Turbo".to_string(),
179 max_tokens: 4096,
180 supports_tools: true,
181 supports_vision: true,
182 supports_caching: false,
183 },
184 );
185
186 models.insert(
187 "gpt-3.5-turbo".to_string(),
188 OpenAIModel {
189 id: "gpt-3.5-turbo".to_string(),
190 name: "GPT-3.5 Turbo".to_string(),
191 max_tokens: 4096,
192 supports_tools: true,
193 supports_vision: false,
194 supports_caching: false,
195 },
196 );
197
198 models.insert(
199 "o1-preview".to_string(),
200 OpenAIModel {
201 id: "o1-preview".to_string(),
202 name: "OpenAI o1 Preview".to_string(),
203 max_tokens: 32768,
204 supports_tools: false,
205 supports_vision: false,
206 supports_caching: false,
207 },
208 );
209
210 models.insert(
211 "o1-mini".to_string(),
212 OpenAIModel {
213 id: "o1-mini".to_string(),
214 name: "OpenAI o1 Mini".to_string(),
215 max_tokens: 65536,
216 supports_tools: false,
217 supports_vision: false,
218 supports_caching: false,
219 },
220 );
221
222 models
223 }
224
225 async fn get_auth_header(&self) -> crate::Result<String> {
226 let credentials = self.auth.get_credentials().await?;
227
228 match credentials {
229 AuthCredentials::ApiKey { key } => Ok(format!("Bearer {}", key)),
230 _ => Err(crate::Error::Other(anyhow::anyhow!(
231 "Invalid credentials for OpenAI (API key required)"
232 ))),
233 }
234 }
235
236 fn convert_messages(&self, messages: Vec<Message>) -> Vec<OpenAIMessage> {
237 messages
238 .into_iter()
239 .map(|msg| self.convert_message(msg))
240 .collect()
241 }
242
243 fn convert_message(&self, message: Message) -> OpenAIMessage {
244 let role = match message.role {
245 MessageRole::System => "system",
246 MessageRole::User => "user",
247 MessageRole::Assistant => "assistant",
248 MessageRole::Tool => "tool",
249 }
250 .to_string();
251
252 let content = match message.content {
253 MessageContent::Text(text) => OpenAIContent::Text(text),
254 MessageContent::Parts(parts) => {
255 let openai_parts: Vec<OpenAIContentPart> = parts
256 .into_iter()
257 .filter_map(|part| match part {
258 MessagePart::Text { text } => Some(OpenAIContentPart::Text { text }),
259 MessagePart::Image { image } => {
260 if let Some(url) = image.url {
261 Some(OpenAIContentPart::ImageUrl {
262 image_url: OpenAIImageUrl {
263 url,
264 detail: Some("auto".to_string()),
265 },
266 })
267 } else if let Some(base64) = image.base64 {
268 Some(OpenAIContentPart::ImageUrl {
269 image_url: OpenAIImageUrl {
270 url: format!("data:{};base64,{}", image.mime_type, base64),
271 detail: Some("auto".to_string()),
272 },
273 })
274 } else {
275 None
276 }
277 }
278 })
279 .collect();
280 OpenAIContent::Parts(openai_parts)
281 }
282 };
283
284 let tool_calls = message.tool_calls.map(|calls| {
285 calls
286 .into_iter()
287 .map(|call| OpenAIToolCall {
288 id: call.id,
289 tool_type: "function".to_string(),
290 function: OpenAIFunctionCall {
291 name: call.name,
292 arguments: call.arguments.to_string(),
293 },
294 })
295 .collect()
296 });
297
298 OpenAIMessage {
299 role,
300 content,
301 name: message.name,
302 tool_calls,
303 tool_call_id: message.tool_call_id,
304 }
305 }
306
307 fn convert_tools(&self, tools: Vec<ToolDefinition>) -> Vec<OpenAITool> {
308 tools
309 .into_iter()
310 .map(|tool| OpenAITool {
311 tool_type: "function".to_string(),
312 function: OpenAIFunction {
313 name: tool.name,
314 description: tool.description,
315 parameters: tool.parameters,
316 },
317 })
318 .collect()
319 }
320
321 fn parse_finish_reason(&self, reason: Option<String>) -> FinishReason {
322 match reason.as_deref() {
323 Some("stop") => FinishReason::Stop,
324 Some("length") => FinishReason::Length,
325 Some("tool_calls") => FinishReason::ToolCalls,
326 Some("content_filter") => FinishReason::ContentFilter,
327 _ => FinishReason::Stop,
328 }
329 }
330}
331
/// Binds a concrete model descriptor to an `OpenAIProvider`, forming the
/// unit that implements `LanguageModel`.
pub struct OpenAIModelWithProvider {
    model: OpenAIModel,
    provider: OpenAIProvider,
}
336
337impl OpenAIModelWithProvider {
338 pub fn new(model: OpenAIModel, provider: OpenAIProvider) -> Self {
339 Self { model, provider }
340 }
341}
342
343#[async_trait]
344impl LanguageModel for OpenAIModelWithProvider {
345 async fn generate(
346 &self,
347 messages: Vec<Message>,
348 options: GenerateOptions,
349 ) -> crate::Result<GenerateResult> {
350 let auth_header = self.provider.get_auth_header().await?;
351 let openai_messages = self.provider.convert_messages(messages);
352 let tools = self.provider.convert_tools(options.tools);
353
354 let request = OpenAIRequest {
355 model: self.model.id.clone(),
356 messages: openai_messages,
357 max_tokens: options.max_tokens.unwrap_or(self.model.max_tokens),
358 temperature: options.temperature,
359 tools,
360 stop: options.stop_sequences,
361 stream: Some(false),
362 };
363
364 let response = self
365 .provider
366 .client
367 .post(&format!("{}/v1/chat/completions", self.provider.api_base))
368 .header("Authorization", auth_header)
369 .header("Content-Type", "application/json")
370 .json(&request)
371 .send()
372 .await
373 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Request failed: {}", e)))?;
374
375 if !response.status().is_success() {
376 let status = response.status();
377 let body = response.text().await.unwrap_or_default();
378 return Err(crate::Error::Other(anyhow::anyhow!(
379 "API request failed with status {}: {}",
380 status,
381 body
382 )));
383 }
384
385 let openai_response: OpenAIResponse = response
386 .json()
387 .await
388 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse response: {}", e)))?;
389
390 let choice = openai_response
391 .choices
392 .into_iter()
393 .next()
394 .ok_or_else(|| crate::Error::Other(anyhow::anyhow!("No choices in response")))?;
395
396 let content = match choice.message.content {
397 OpenAIContent::Text(text) => text,
398 OpenAIContent::Parts(parts) => {
399 parts
400 .into_iter()
401 .filter_map(|part| match part {
402 OpenAIContentPart::Text { text } => Some(text),
403 _ => None,
404 })
405 .collect::<Vec<_>>()
406 .join("")
407 }
408 };
409
410 let tool_calls = choice
411 .message
412 .tool_calls
413 .unwrap_or_default()
414 .into_iter()
415 .map(|call| ToolCall {
416 id: call.id,
417 name: call.function.name,
418 arguments: serde_json::from_str(&call.function.arguments)
419 .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
420 })
421 .collect();
422
423 Ok(GenerateResult {
424 content,
425 tool_calls,
426 usage: Usage {
427 prompt_tokens: openai_response.usage.prompt_tokens,
428 completion_tokens: openai_response.usage.completion_tokens,
429 total_tokens: openai_response.usage.total_tokens,
430 },
431 finish_reason: self.provider.parse_finish_reason(choice.finish_reason),
432 })
433 }
434
435 async fn stream(
436 &self,
437 messages: Vec<Message>,
438 options: StreamOptions,
439 ) -> crate::Result<Box<dyn Stream<Item = crate::Result<StreamChunk>> + Send + Unpin>> {
440 Err(crate::Error::Other(anyhow::anyhow!(
443 "Streaming not yet implemented for OpenAI"
444 )))
445 }
446
447 fn supports_tools(&self) -> bool {
448 self.model.supports_tools
449 }
450
451 fn supports_vision(&self) -> bool {
452 self.model.supports_vision
453 }
454
455 fn supports_caching(&self) -> bool {
456 self.model.supports_caching
457 }
458}
459
/// OpenAI-compatible provider targeting an Azure OpenAI deployment.
///
/// Reuses `OpenAIProvider` for auth and message conversion; only the
/// endpoint shape (deployment-scoped path plus `api-version` query) differs.
pub struct AzureOpenAIProvider {
    base_provider: OpenAIProvider,
    // Azure deployment name that selects the underlying model.
    deployment_name: String,
    // REST API version query parameter, e.g. "2024-02-15-preview".
    api_version: String,
}
466
467impl AzureOpenAIProvider {
468 pub fn new(
469 auth: Box<dyn Auth>,
470 endpoint: String,
471 deployment_name: String,
472 api_version: String,
473 ) -> Self {
474 let base_provider = OpenAIProvider::with_api_base(auth, endpoint);
475
476 Self {
477 base_provider,
478 deployment_name,
479 api_version,
480 }
481 }
482
483 pub fn default_api_version() -> String {
484 "2024-02-15-preview".to_string()
485 }
486
487 fn get_endpoint(&self) -> String {
488 format!(
489 "{}/openai/deployments/{}/chat/completions?api-version={}",
490 self.base_provider.api_base, self.deployment_name, self.api_version
491 )
492 }
493}
494
/// Binds a model descriptor to an `AzureOpenAIProvider`, forming the unit
/// that implements `LanguageModel` for Azure deployments.
pub struct AzureOpenAIModelWithProvider {
    model: OpenAIModel,
    provider: AzureOpenAIProvider,
}
499
500impl AzureOpenAIModelWithProvider {
501 pub fn new(model: OpenAIModel, provider: AzureOpenAIProvider) -> Self {
502 Self { model, provider }
503 }
504}
505
506#[async_trait]
507impl LanguageModel for AzureOpenAIModelWithProvider {
508 async fn generate(
509 &self,
510 messages: Vec<Message>,
511 options: GenerateOptions,
512 ) -> crate::Result<GenerateResult> {
513 let auth_header = self.provider.base_provider.get_auth_header().await?;
514 let openai_messages = self.provider.base_provider.convert_messages(messages);
515 let tools = self.provider.base_provider.convert_tools(options.tools);
516
517 let request = OpenAIRequest {
519 model: self.model.id.clone(), messages: openai_messages,
521 max_tokens: options.max_tokens.unwrap_or(self.model.max_tokens),
522 temperature: options.temperature,
523 tools,
524 stop: options.stop_sequences,
525 stream: Some(false),
526 };
527
528 let response = self
529 .provider
530 .base_provider
531 .client
532 .post(&self.provider.get_endpoint())
533 .header("Authorization", auth_header)
534 .header("Content-Type", "application/json")
535 .json(&request)
536 .send()
537 .await
538 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Request failed: {}", e)))?;
539
540 if !response.status().is_success() {
541 let status = response.status();
542 let body = response.text().await.unwrap_or_default();
543 return Err(crate::Error::Other(anyhow::anyhow!(
544 "API request failed with status {}: {}",
545 status,
546 body
547 )));
548 }
549
550 let openai_response: OpenAIResponse = response
551 .json()
552 .await
553 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse response: {}", e)))?;
554
555 let choice = openai_response
556 .choices
557 .into_iter()
558 .next()
559 .ok_or_else(|| crate::Error::Other(anyhow::anyhow!("No choices in response")))?;
560
561 let content = match choice.message.content {
562 OpenAIContent::Text(text) => text,
563 OpenAIContent::Parts(parts) => {
564 parts
565 .into_iter()
566 .filter_map(|part| match part {
567 OpenAIContentPart::Text { text } => Some(text),
568 _ => None,
569 })
570 .collect::<Vec<_>>()
571 .join("")
572 }
573 };
574
575 let tool_calls = choice
576 .message
577 .tool_calls
578 .unwrap_or_default()
579 .into_iter()
580 .map(|call| ToolCall {
581 id: call.id,
582 name: call.function.name,
583 arguments: serde_json::from_str(&call.function.arguments)
584 .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
585 })
586 .collect();
587
588 Ok(GenerateResult {
589 content,
590 tool_calls,
591 usage: Usage {
592 prompt_tokens: openai_response.usage.prompt_tokens,
593 completion_tokens: openai_response.usage.completion_tokens,
594 total_tokens: openai_response.usage.total_tokens,
595 },
596 finish_reason: self.provider.base_provider.parse_finish_reason(choice.finish_reason),
597 })
598 }
599
600 async fn stream(
601 &self,
602 _messages: Vec<Message>,
603 _options: StreamOptions,
604 ) -> crate::Result<Box<dyn Stream<Item = crate::Result<StreamChunk>> + Send + Unpin>> {
605 Err(crate::Error::Other(anyhow::anyhow!(
606 "Streaming not yet implemented for Azure OpenAI"
607 )))
608 }
609
610 fn supports_tools(&self) -> bool {
611 self.model.supports_tools
612 }
613
614 fn supports_vision(&self) -> bool {
615 self.model.supports_vision
616 }
617
618 fn supports_caching(&self) -> bool {
619 self.model.supports_caching
620 }
621}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in catalog is non-empty and contains the expected ids.
    #[test]
    fn test_default_models() {
        let models = OpenAIProvider::default_models();
        assert!(!models.is_empty());
        for id in ["gpt-4o", "gpt-4o-mini", "o1-preview"] {
            assert!(models.contains_key(id));
        }
    }

    /// Capability flags differ per model family: gpt-4o has tools + vision,
    /// o1 has neither, gpt-3.5 has tools only.
    #[test]
    fn test_model_capabilities() {
        let models = OpenAIProvider::default_models();

        let gpt4o = models.get("gpt-4o").unwrap();
        assert!(gpt4o.supports_tools);
        assert!(gpt4o.supports_vision);

        let o1 = models.get("o1-preview").unwrap();
        assert!(!o1.supports_tools);
        assert!(!o1.supports_vision);

        let gpt35 = models.get("gpt-3.5-turbo").unwrap();
        assert!(gpt35.supports_tools);
        assert!(!gpt35.supports_vision);
    }

    /// The Azure endpoint embeds the deployment name and API version.
    #[test]
    fn test_azure_endpoint() {
        use crate::auth::FileAuthStorage;
        use tempfile::tempdir;

        let temp_dir = tempdir().unwrap();
        let storage = FileAuthStorage::new(temp_dir.path().join("auth.json"));
        // NOTE(review): AnthropicAuth serves purely as a stand-in `Auth`
        // implementation — credentials are never exercised by get_endpoint().
        let auth = Box::new(crate::auth::AnthropicAuth::new(Box::new(storage)));

        let provider = AzureOpenAIProvider::new(
            auth,
            "https://test.openai.azure.com".to_string(),
            "gpt-4".to_string(),
            "2024-02-15-preview".to_string(),
        );

        let endpoint = provider.get_endpoint();
        assert!(endpoint.contains("openai/deployments/gpt-4"));
        assert!(endpoint.contains("api-version=2024-02-15-preview"));
    }
}