//! OpenAI and Azure OpenAI provider implementations (code_mesh_core/llm/openai.rs).

1use async_trait::async_trait;
2use futures::Stream;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7
8use super::{
9    FinishReason, GenerateOptions, GenerateResult, LanguageModel, Message, MessageContent,
10    MessagePart, MessageRole, StreamChunk, StreamOptions, ToolCall, ToolDefinition, Usage,
11};
12use crate::auth::{Auth, AuthCredentials};
13
/// OpenAI provider implementation
///
/// Holds authentication, a shared HTTP client, the built-in model
/// catalog, and the API base URL (overridable for proxies/gateways).
pub struct OpenAIProvider {
    auth: Box<dyn Auth>,
    client: Client,
    // Keyed by model id (e.g. "gpt-4o"); populated from `default_models()`.
    models: HashMap<String, OpenAIModel>,
    // Base URL without a trailing path, e.g. "https://api.openai.com".
    api_base: String,
}
21
/// Static description of one OpenAI model and its capability flags.
#[derive(Debug, Clone)]
pub struct OpenAIModel {
    /// Model identifier sent in the request body (e.g. "gpt-4o").
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Default cap for generated tokens when the caller sets none.
    pub max_tokens: u32,
    pub supports_tools: bool,
    pub supports_vision: bool,
    pub supports_caching: bool,
}
31
/// JSON request body for the chat completions endpoint.
///
/// NOTE(review): `max_tokens` is always serialized; the o1 model family
/// reportedly requires `max_completion_tokens` instead — confirm against
/// current API docs before targeting o1 models with this request shape.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Omitted entirely when no tools are supplied.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAITool>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}
46
47#[derive(Debug, Serialize, Deserialize)]
48struct OpenAIMessage {
49    role: String,
50    content: OpenAIContent,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    name: Option<String>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    tool_calls: Option<Vec<OpenAIToolCall>>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    tool_call_id: Option<String>,
57}
58
/// Message content: either a plain string or a list of multimodal parts.
/// `untagged` lets serde pick the variant from the JSON shape.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum OpenAIContent {
    Text(String),
    Parts(Vec<OpenAIContentPart>),
}
65
/// One multimodal content part, discriminated by the `type` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum OpenAIContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: OpenAIImageUrl },
}
74
75#[derive(Debug, Serialize, Deserialize)]
76struct OpenAIImageUrl {
77    url: String,
78    detail: Option<String>,
79}
80
/// Tool envelope; `tool_type` is always "function" in current API usage.
#[derive(Debug, Serialize)]
struct OpenAITool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunction,
}
87
/// Function declaration: name, description, and a JSON-schema `parameters`
/// object (passed through verbatim from `ToolDefinition`).
#[derive(Debug, Serialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    parameters: Value,
}
94
/// A tool call emitted by the assistant (or echoed back in a request).
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunctionCall,
}
102
/// Function invocation payload; `arguments` is a JSON string (not an
/// object) per the OpenAI wire format, parsed later by the caller.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIFunctionCall {
    name: String,
    arguments: String,
}
108
/// Non-streaming chat completion response (only the fields we read).
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
    usage: OpenAIUsage,
}
114
/// One completion choice; only the first is consumed downstream.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
    // e.g. "stop", "length", "tool_calls", "content_filter"; may be absent.
    finish_reason: Option<String>,
}
120
/// Token accounting reported by the API for a completed request.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
127
impl OpenAIProvider {
    /// Default public API base; override via `with_api_base`.
    const API_BASE: &'static str = "https://api.openai.com";
    
    /// Create a provider against the public OpenAI endpoint.
    pub fn new(auth: Box<dyn Auth>) -> Self {
        Self::with_api_base(auth, Self::API_BASE.to_string())
    }
    
    /// Create a provider against a custom base URL (proxy, gateway, etc.).
    pub fn with_api_base(auth: Box<dyn Auth>, api_base: String) -> Self {
        let client = Client::new();
        let models = Self::default_models();
        
        Self {
            auth,
            client,
            models,
            api_base,
        }
    }
    
    /// Built-in model catalog with hard-coded capability flags.
    ///
    /// NOTE(review): the `max_tokens` values are output caps, not context
    /// windows — verify against current OpenAI model documentation.
    fn default_models() -> HashMap<String, OpenAIModel> {
        let mut models = HashMap::new();
        
        models.insert(
            "gpt-4o".to_string(),
            OpenAIModel {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                max_tokens: 4096,
                supports_tools: true,
                supports_vision: true,
                supports_caching: false,
            },
        );
        
        models.insert(
            "gpt-4o-mini".to_string(),
            OpenAIModel {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                max_tokens: 4096,
                supports_tools: true,
                supports_vision: true,
                supports_caching: false,
            },
        );
        
        models.insert(
            "gpt-4-turbo".to_string(),
            OpenAIModel {
                id: "gpt-4-turbo".to_string(),
                name: "GPT-4 Turbo".to_string(),
                max_tokens: 4096,
                supports_tools: true,
                supports_vision: true,
                supports_caching: false,
            },
        );
        
        models.insert(
            "gpt-3.5-turbo".to_string(),
            OpenAIModel {
                id: "gpt-3.5-turbo".to_string(),
                name: "GPT-3.5 Turbo".to_string(),
                max_tokens: 4096,
                supports_tools: true,
                supports_vision: false,
                supports_caching: false,
            },
        );
        
        // o1 models: no tool or vision support in the API at time of writing.
        models.insert(
            "o1-preview".to_string(),
            OpenAIModel {
                id: "o1-preview".to_string(),
                name: "OpenAI o1 Preview".to_string(),
                max_tokens: 32768,
                supports_tools: false,
                supports_vision: false,
                supports_caching: false,
            },
        );
        
        models.insert(
            "o1-mini".to_string(),
            OpenAIModel {
                id: "o1-mini".to_string(),
                name: "OpenAI o1 Mini".to_string(),
                max_tokens: 65536,
                supports_tools: false,
                supports_vision: false,
                supports_caching: false,
            },
        );
        
        models
    }
    
    /// Build the `Authorization` header value from stored credentials.
    /// Only API-key credentials are accepted; anything else is an error.
    async fn get_auth_header(&self) -> crate::Result<String> {
        let credentials = self.auth.get_credentials().await?;
        
        match credentials {
            AuthCredentials::ApiKey { key } => Ok(format!("Bearer {}", key)),
            _ => Err(crate::Error::Other(anyhow::anyhow!(
                "Invalid credentials for OpenAI (API key required)"
            ))),
        }
    }
    
    /// Convert internal messages to OpenAI wire format, preserving order.
    fn convert_messages(&self, messages: Vec<Message>) -> Vec<OpenAIMessage> {
        messages
            .into_iter()
            .map(|msg| self.convert_message(msg))
            .collect()
    }
    
    /// Convert one internal message: role enum → role string, content →
    /// plain text or multimodal parts, tool calls → "function" entries.
    fn convert_message(&self, message: Message) -> OpenAIMessage {
        let role = match message.role {
            MessageRole::System => "system",
            MessageRole::User => "user",
            MessageRole::Assistant => "assistant",
            MessageRole::Tool => "tool",
        }
        .to_string();
        
        let content = match message.content {
            MessageContent::Text(text) => OpenAIContent::Text(text),
            MessageContent::Parts(parts) => {
                // Images without a URL or base64 payload are silently dropped.
                let openai_parts: Vec<OpenAIContentPart> = parts
                    .into_iter()
                    .filter_map(|part| match part {
                        MessagePart::Text { text } => Some(OpenAIContentPart::Text { text }),
                        MessagePart::Image { image } => {
                            if let Some(url) = image.url {
                                Some(OpenAIContentPart::ImageUrl {
                                    image_url: OpenAIImageUrl {
                                        url,
                                        detail: Some("auto".to_string()),
                                    },
                                })
                            } else if let Some(base64) = image.base64 {
                                // Inline image bytes as a data URI.
                                Some(OpenAIContentPart::ImageUrl {
                                    image_url: OpenAIImageUrl {
                                        url: format!("data:{};base64,{}", image.mime_type, base64),
                                        detail: Some("auto".to_string()),
                                    },
                                })
                            } else {
                                None
                            }
                        }
                    })
                    .collect();
                OpenAIContent::Parts(openai_parts)
            }
        };
        
        // Arguments are re-serialized to the JSON-string form OpenAI expects.
        let tool_calls = message.tool_calls.map(|calls| {
            calls
                .into_iter()
                .map(|call| OpenAIToolCall {
                    id: call.id,
                    tool_type: "function".to_string(),
                    function: OpenAIFunctionCall {
                        name: call.name,
                        arguments: call.arguments.to_string(),
                    },
                })
                .collect()
        });
        
        OpenAIMessage {
            role,
            content,
            name: message.name,
            tool_calls,
            tool_call_id: message.tool_call_id,
        }
    }
    
    /// Wrap tool definitions in OpenAI's "function"-typed tool envelope.
    fn convert_tools(&self, tools: Vec<ToolDefinition>) -> Vec<OpenAITool> {
        tools
            .into_iter()
            .map(|tool| OpenAITool {
                tool_type: "function".to_string(),
                function: OpenAIFunction {
                    name: tool.name,
                    description: tool.description,
                    parameters: tool.parameters,
                },
            })
            .collect()
    }
    
    /// Map the API's finish-reason string to the internal enum.
    /// Unknown or missing reasons deliberately default to `Stop`.
    fn parse_finish_reason(&self, reason: Option<String>) -> FinishReason {
        match reason.as_deref() {
            Some("stop") => FinishReason::Stop,
            Some("length") => FinishReason::Length,
            Some("tool_calls") => FinishReason::ToolCalls,
            Some("content_filter") => FinishReason::ContentFilter,
            _ => FinishReason::Stop,
        }
    }
}
331
/// Pairing of a concrete model with its provider; this is the unit that
/// implements `LanguageModel`.
pub struct OpenAIModelWithProvider {
    model: OpenAIModel,
    provider: OpenAIProvider,
}
336
impl OpenAIModelWithProvider {
    /// Bind a model descriptor to a provider instance.
    pub fn new(model: OpenAIModel, provider: OpenAIProvider) -> Self {
        Self { model, provider }
    }
}
342
343#[async_trait]
344impl LanguageModel for OpenAIModelWithProvider {
345    async fn generate(
346        &self,
347        messages: Vec<Message>,
348        options: GenerateOptions,
349    ) -> crate::Result<GenerateResult> {
350        let auth_header = self.provider.get_auth_header().await?;
351        let openai_messages = self.provider.convert_messages(messages);
352        let tools = self.provider.convert_tools(options.tools);
353        
354        let request = OpenAIRequest {
355            model: self.model.id.clone(),
356            messages: openai_messages,
357            max_tokens: options.max_tokens.unwrap_or(self.model.max_tokens),
358            temperature: options.temperature,
359            tools,
360            stop: options.stop_sequences,
361            stream: Some(false),
362        };
363        
364        let response = self
365            .provider
366            .client
367            .post(&format!("{}/v1/chat/completions", self.provider.api_base))
368            .header("Authorization", auth_header)
369            .header("Content-Type", "application/json")
370            .json(&request)
371            .send()
372            .await
373            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Request failed: {}", e)))?;
374            
375        if !response.status().is_success() {
376            let status = response.status();
377            let body = response.text().await.unwrap_or_default();
378            return Err(crate::Error::Other(anyhow::anyhow!(
379                "API request failed with status {}: {}",
380                status,
381                body
382            )));
383        }
384        
385        let openai_response: OpenAIResponse = response
386            .json()
387            .await
388            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse response: {}", e)))?;
389            
390        let choice = openai_response
391            .choices
392            .into_iter()
393            .next()
394            .ok_or_else(|| crate::Error::Other(anyhow::anyhow!("No choices in response")))?;
395            
396        let content = match choice.message.content {
397            OpenAIContent::Text(text) => text,
398            OpenAIContent::Parts(parts) => {
399                parts
400                    .into_iter()
401                    .filter_map(|part| match part {
402                        OpenAIContentPart::Text { text } => Some(text),
403                        _ => None,
404                    })
405                    .collect::<Vec<_>>()
406                    .join("")
407            }
408        };
409        
410        let tool_calls = choice
411            .message
412            .tool_calls
413            .unwrap_or_default()
414            .into_iter()
415            .map(|call| ToolCall {
416                id: call.id,
417                name: call.function.name,
418                arguments: serde_json::from_str(&call.function.arguments)
419                    .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
420            })
421            .collect();
422            
423        Ok(GenerateResult {
424            content,
425            tool_calls,
426            usage: Usage {
427                prompt_tokens: openai_response.usage.prompt_tokens,
428                completion_tokens: openai_response.usage.completion_tokens,
429                total_tokens: openai_response.usage.total_tokens,
430            },
431            finish_reason: self.provider.parse_finish_reason(choice.finish_reason),
432        })
433    }
434    
435    async fn stream(
436        &self,
437        messages: Vec<Message>,
438        options: StreamOptions,
439    ) -> crate::Result<Box<dyn Stream<Item = crate::Result<StreamChunk>> + Send + Unpin>> {
440        // Similar to generate but with stream: true
441        // Implementation would handle SSE stream parsing
442        Err(crate::Error::Other(anyhow::anyhow!(
443            "Streaming not yet implemented for OpenAI"
444        )))
445    }
446    
447    fn supports_tools(&self) -> bool {
448        self.model.supports_tools
449    }
450    
451    fn supports_vision(&self) -> bool {
452        self.model.supports_vision
453    }
454    
455    fn supports_caching(&self) -> bool {
456        self.model.supports_caching
457    }
458}
459
/// Azure OpenAI provider implementation
///
/// Wraps the base provider but routes requests to an Azure resource URL
/// of the form `{endpoint}/openai/deployments/{deployment}/...`.
pub struct AzureOpenAIProvider {
    base_provider: OpenAIProvider,
    // Azure deployment name; takes the place of the model id in the URL.
    deployment_name: String,
    // Azure REST API version query parameter, e.g. "2024-02-15-preview".
    api_version: String,
}
466
467impl AzureOpenAIProvider {
468    pub fn new(
469        auth: Box<dyn Auth>,
470        endpoint: String,
471        deployment_name: String,
472        api_version: String,
473    ) -> Self {
474        let base_provider = OpenAIProvider::with_api_base(auth, endpoint);
475        
476        Self {
477            base_provider,
478            deployment_name,
479            api_version,
480        }
481    }
482    
483    pub fn default_api_version() -> String {
484        "2024-02-15-preview".to_string()
485    }
486    
487    fn get_endpoint(&self) -> String {
488        format!(
489            "{}/openai/deployments/{}/chat/completions?api-version={}",
490            self.base_provider.api_base, self.deployment_name, self.api_version
491        )
492    }
493}
494
/// Pairing of a model descriptor with an Azure provider; implements
/// `LanguageModel` against the Azure deployment endpoint.
pub struct AzureOpenAIModelWithProvider {
    model: OpenAIModel,
    provider: AzureOpenAIProvider,
}
499
impl AzureOpenAIModelWithProvider {
    /// Bind a model descriptor to an Azure provider instance.
    pub fn new(model: OpenAIModel, provider: AzureOpenAIProvider) -> Self {
        Self { model, provider }
    }
}
505
506#[async_trait]
507impl LanguageModel for AzureOpenAIModelWithProvider {
508    async fn generate(
509        &self,
510        messages: Vec<Message>,
511        options: GenerateOptions,
512    ) -> crate::Result<GenerateResult> {
513        let auth_header = self.provider.base_provider.get_auth_header().await?;
514        let openai_messages = self.provider.base_provider.convert_messages(messages);
515        let tools = self.provider.base_provider.convert_tools(options.tools);
516        
517        // Azure uses deployment name instead of model in URL
518        let request = OpenAIRequest {
519            model: self.model.id.clone(), // Still include model in body for compatibility
520            messages: openai_messages,
521            max_tokens: options.max_tokens.unwrap_or(self.model.max_tokens),
522            temperature: options.temperature,
523            tools,
524            stop: options.stop_sequences,
525            stream: Some(false),
526        };
527        
528        let response = self
529            .provider
530            .base_provider
531            .client
532            .post(&self.provider.get_endpoint())
533            .header("Authorization", auth_header)
534            .header("Content-Type", "application/json")
535            .json(&request)
536            .send()
537            .await
538            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Request failed: {}", e)))?;
539            
540        if !response.status().is_success() {
541            let status = response.status();
542            let body = response.text().await.unwrap_or_default();
543            return Err(crate::Error::Other(anyhow::anyhow!(
544                "API request failed with status {}: {}",
545                status,
546                body
547            )));
548        }
549        
550        let openai_response: OpenAIResponse = response
551            .json()
552            .await
553            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse response: {}", e)))?;
554            
555        let choice = openai_response
556            .choices
557            .into_iter()
558            .next()
559            .ok_or_else(|| crate::Error::Other(anyhow::anyhow!("No choices in response")))?;
560            
561        let content = match choice.message.content {
562            OpenAIContent::Text(text) => text,
563            OpenAIContent::Parts(parts) => {
564                parts
565                    .into_iter()
566                    .filter_map(|part| match part {
567                        OpenAIContentPart::Text { text } => Some(text),
568                        _ => None,
569                    })
570                    .collect::<Vec<_>>()
571                    .join("")
572            }
573        };
574        
575        let tool_calls = choice
576            .message
577            .tool_calls
578            .unwrap_or_default()
579            .into_iter()
580            .map(|call| ToolCall {
581                id: call.id,
582                name: call.function.name,
583                arguments: serde_json::from_str(&call.function.arguments)
584                    .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
585            })
586            .collect();
587            
588        Ok(GenerateResult {
589            content,
590            tool_calls,
591            usage: Usage {
592                prompt_tokens: openai_response.usage.prompt_tokens,
593                completion_tokens: openai_response.usage.completion_tokens,
594                total_tokens: openai_response.usage.total_tokens,
595            },
596            finish_reason: self.provider.base_provider.parse_finish_reason(choice.finish_reason),
597        })
598    }
599    
600    async fn stream(
601        &self,
602        _messages: Vec<Message>,
603        _options: StreamOptions,
604    ) -> crate::Result<Box<dyn Stream<Item = crate::Result<StreamChunk>> + Send + Unpin>> {
605        Err(crate::Error::Other(anyhow::anyhow!(
606            "Streaming not yet implemented for Azure OpenAI"
607        )))
608    }
609    
610    fn supports_tools(&self) -> bool {
611        self.model.supports_tools
612    }
613    
614    fn supports_vision(&self) -> bool {
615        self.model.supports_vision
616    }
617    
618    fn supports_caching(&self) -> bool {
619        self.model.supports_caching
620    }
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    
    /// Catalog sanity: the well-known model ids must all be present.
    #[test]
    fn test_default_models() {
        let catalog = OpenAIProvider::default_models();
        assert!(!catalog.is_empty());
        for id in ["gpt-4o", "gpt-4o-mini", "o1-preview"] {
            assert!(catalog.contains_key(id));
        }
    }
    
    /// Capability flags per model: (id, supports_tools, supports_vision).
    #[test]
    fn test_model_capabilities() {
        let catalog = OpenAIProvider::default_models();
        let expected = [
            ("gpt-4o", true, true),
            ("o1-preview", false, false),
            ("gpt-3.5-turbo", true, false),
        ];
        for (id, tools, vision) in expected {
            let entry = catalog.get(id).unwrap();
            assert_eq!(entry.supports_tools, tools);
            assert_eq!(entry.supports_vision, vision);
        }
    }
    
    /// The Azure endpoint must embed the deployment name and API version.
    #[test]
    fn test_azure_endpoint() {
        use crate::auth::FileAuthStorage;
        use tempfile::tempdir;
        
        let dir = tempdir().unwrap();
        let creds_path = dir.path().join("auth.json");
        let file_storage = FileAuthStorage::new(creds_path);
        // Dummy auth: the endpoint builder never touches credentials.
        let dummy_auth = Box::new(crate::auth::AnthropicAuth::new(Box::new(file_storage)));
        
        let azure = AzureOpenAIProvider::new(
            dummy_auth,
            "https://test.openai.azure.com".to_string(),
            "gpt-4".to_string(),
            "2024-02-15-preview".to_string(),
        );
        
        let url = azure.get_endpoint();
        assert!(url.contains("openai/deployments/gpt-4"));
        assert!(url.contains("api-version=2024-02-15-preview"));
    }
}