Skip to main content

paladin_llm/openai/
adapter.rs

1//! OpenAI GPT adapter.
2//!
3//! Implements [`LlmPort`] for the OpenAI Chat Completions API.
4//! Supports GPT-3.5-Turbo, GPT-4, GPT-4o, and other compatible models.
5
6use async_trait::async_trait;
7use chrono::Utc;
8use futures::{Stream, StreamExt};
9use paladin_core::platform::container::content::{ContentItem, ContentType};
10use paladin_core::platform::container::prompt::{PromptItem, PromptRole, PromptType};
11use paladin_ports::output::llm_port::{
12    FinishReason, LlmError, LlmPort, LlmRequest, LlmResponse, ProviderCapabilities,
13    StreamingResponse, TokenUsage,
14};
15use rand::Rng;
16use reqwest::Client;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::env;
20use std::pin::Pin;
21use std::time::Duration;
22use uuid::Uuid;
23
/// Configuration for the OpenAI adapter.
///
/// `Debug` is implemented by hand so the API key never ends up in logs or
/// panic messages.
#[derive(Clone)]
pub struct OpenAIConfig {
    /// OpenAI API key.
    pub api_key: String,
    /// Base URL for the API (default: `https://api.openai.com/v1`).
    pub base_url: String,
    /// Optional organisation ID.
    pub organization: Option<String>,
    /// Request timeout in seconds (default: 300). Must be non-zero.
    pub timeout_seconds: u64,
    /// Maximum retry attempts (default: 3).
    pub max_retries: u32,
}

impl std::fmt::Debug for OpenAIConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIConfig")
            // The key is a credential: never print it.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("organization", &self.organization)
            .field("timeout_seconds", &self.timeout_seconds)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}

impl OpenAIConfig {
    /// Load configuration from environment variables.
    ///
    /// Required:
    /// - `OPENAI_API_KEY`
    ///
    /// Optional:
    /// - `OPENAI_BASE_URL` (default: `https://api.openai.com/v1`)
    /// - `OPENAI_ORGANIZATION`
    /// - `OPENAI_TIMEOUT_SECONDS` (default: 300)
    /// - `OPENAI_MAX_RETRIES` (default: 3)
    ///
    /// # Errors
    ///
    /// Returns a message when `OPENAI_API_KEY` is unset, or when the timeout or
    /// retry variables are present but not valid unsigned integers. The result
    /// is not validated here; [`OpenAIConfig::validate`] does that.
    pub fn from_env() -> Result<Self, String> {
        let api_key = env::var("OPENAI_API_KEY")
            .map_err(|_| "OPENAI_API_KEY environment variable not set")?;

        let base_url =
            env::var("OPENAI_BASE_URL").unwrap_or_else(|_| "https://api.openai.com/v1".to_string());

        let organization = env::var("OPENAI_ORGANIZATION").ok();

        let timeout_seconds = env::var("OPENAI_TIMEOUT_SECONDS")
            .unwrap_or_else(|_| "300".to_string())
            .parse()
            .map_err(|_| "Invalid OPENAI_TIMEOUT_SECONDS value")?;

        let max_retries = env::var("OPENAI_MAX_RETRIES")
            .unwrap_or_else(|_| "3".to_string())
            .parse()
            .map_err(|_| "Invalid OPENAI_MAX_RETRIES value")?;

        Ok(Self {
            api_key,
            base_url,
            organization,
            timeout_seconds,
            max_retries,
        })
    }

    /// Create a configuration with the given API key and sensible defaults.
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            organization: None,
            timeout_seconds: 300,
            max_retries: 3,
        }
    }

    /// Validate the configuration fields.
    ///
    /// # Errors
    ///
    /// Returns a message when:
    /// - the API key is empty;
    /// - the base URL is empty or does not start with `http://` or `https://`;
    /// - the timeout is zero.
    ///
    /// A zero timeout would make every request fail immediately.
    pub fn validate(&self) -> Result<(), String> {
        if self.api_key.is_empty() {
            return Err("API key cannot be empty".to_string());
        }
        if self.base_url.is_empty() {
            return Err("Base URL cannot be empty".to_string());
        }
        // Check the full scheme separator: a bare "http" prefix also accepts
        // strings such as "httpfoo" that are not URLs at all.
        let has_http_scheme =
            self.base_url.starts_with("http://") || self.base_url.starts_with("https://");
        if !has_http_scheme {
            return Err("Base URL must start with http:// or https://".to_string());
        }
        if self.timeout_seconds == 0 {
            return Err("Timeout must be greater than zero seconds".to_string());
        }
        Ok(())
    }
}
103
104// ---------------------------------------------------------------------------
105// Internal API structures
106// ---------------------------------------------------------------------------
107
108#[derive(Debug, Serialize)]
109struct OpenAIRequest {
110    model: String,
111    messages: Vec<OpenAIMessage>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    temperature: Option<f32>,
114    #[serde(skip_serializing_if = "Option::is_none")]
115    max_tokens: Option<u32>,
116    #[serde(skip_serializing_if = "Option::is_none")]
117    top_p: Option<f32>,
118    stream: bool,
119}
120
121#[derive(Debug, Serialize, Deserialize)]
122struct OpenAIMessage {
123    role: String,
124    content: String,
125}
126
127#[derive(Debug, Deserialize)]
128struct OpenAIResponse {
129    #[allow(dead_code)]
130    id: String,
131    model: String,
132    choices: Vec<OpenAIChoice>,
133    usage: OpenAIUsage,
134}
135
136#[derive(Debug, Deserialize)]
137struct OpenAIChoice {
138    #[allow(dead_code)]
139    index: u32,
140    message: OpenAIMessage,
141    finish_reason: Option<String>,
142}
143
144#[derive(Debug, Deserialize)]
145struct OpenAIUsage {
146    prompt_tokens: u32,
147    completion_tokens: u32,
148    total_tokens: u32,
149}
150
151#[derive(Debug, Deserialize)]
152struct OpenAIStreamChunk {
153    #[allow(dead_code)]
154    id: String,
155    choices: Vec<OpenAIStreamChoice>,
156}
157
158#[derive(Debug, Deserialize)]
159struct OpenAIStreamChoice {
160    #[allow(dead_code)]
161    index: u32,
162    delta: OpenAIStreamDelta,
163    finish_reason: Option<String>,
164}
165
166#[derive(Debug, Deserialize)]
167struct OpenAIStreamDelta {
168    #[allow(dead_code)]
169    role: Option<String>,
170    content: Option<String>,
171}
172
173// ---------------------------------------------------------------------------
174// Adapter
175// ---------------------------------------------------------------------------
176
177/// OpenAI LLM adapter implementing [`LlmPort`].
178pub struct OpenAIAdapter {
179    pub(crate) config: OpenAIConfig,
180    pub(crate) client: Client,
181}
182
183impl OpenAIAdapter {
184    /// Create a new adapter from explicit configuration.
185    pub fn new(config: OpenAIConfig) -> Result<Self, String> {
186        config.validate()?;
187        let client = Client::builder()
188            .timeout(Duration::from_secs(config.timeout_seconds))
189            .build()
190            .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
191        Ok(Self { config, client })
192    }
193
194    /// Create an adapter by loading configuration from environment variables.
195    pub fn from_env() -> Result<Self, String> {
196        Self::new(OpenAIConfig::from_env()?)
197    }
198
199    /// Convert a [`PromptItem`] and optional attachments into OpenAI messages.
200    fn convert_to_messages(
201        &self,
202        prompt: &PromptItem,
203        attachments: &[ContentItem],
204    ) -> Result<Vec<OpenAIMessage>, LlmError> {
205        let mut messages = Vec::new();
206
207        match prompt.prompt_type() {
208            PromptType::System(system_prompt) => {
209                let mut content = system_prompt.instructions.clone();
210                if let Some(constraints) = &system_prompt.constraints
211                    && !constraints.is_empty()
212                {
213                    content.push_str("\n\nConstraints:\n");
214                    for constraint in constraints {
215                        content.push_str(&format!("- {}\n", constraint));
216                    }
217                }
218                messages.push(OpenAIMessage {
219                    role: "system".to_string(),
220                    content,
221                });
222            }
223            PromptType::User(user_prompt) => {
224                messages.push(OpenAIMessage {
225                    role: "user".to_string(),
226                    content: user_prompt.context.clone().unwrap_or_default(),
227                });
228            }
229            PromptType::Assistant(assistant_prompt) => {
230                let mut content = assistant_prompt.response.clone();
231                if let Some(reasoning) = &assistant_prompt.reasoning {
232                    content.push_str(&format!("\n\nReasoning: {}", reasoning));
233                }
234                messages.push(OpenAIMessage {
235                    role: "assistant".to_string(),
236                    content,
237                });
238            }
239            PromptType::Text(text_prompt) => {
240                let role = match text_prompt.role {
241                    PromptRole::System => "system",
242                    PromptRole::User => "user",
243                    PromptRole::Assistant => "assistant",
244                    PromptRole::Function => "function",
245                };
246                messages.push(OpenAIMessage {
247                    role: role.to_string(),
248                    content: text_prompt.content.clone(),
249                });
250            }
251            PromptType::Function(function_prompt) => {
252                messages.push(OpenAIMessage {
253                    role: "function".to_string(),
254                    content: function_prompt.function_name.clone(),
255                });
256            }
257        }
258
259        for content in attachments {
260            if let Ok(content_text) = self.convert_content_to_text(content)
261                && !content_text.is_empty()
262            {
263                messages.push(OpenAIMessage {
264                    role: "user".to_string(),
265                    content: format!("Content to analyze:\n{}", content_text),
266                });
267            }
268        }
269
270        Ok(messages)
271    }
272
273    fn convert_content_to_text(&self, content: &ContentItem) -> Result<String, LlmError> {
274        match content.content() {
275            ContentType::Text(text_content) => {
276                Ok(text_content.content.as_deref().unwrap_or("").to_string())
277            }
278            ContentType::Video(video_content) => Ok(format!(
279                "Video: {} (Duration: {}s)",
280                content.title().unwrap_or(&"Untitled".to_string()),
281                video_content.duration
282            )),
283            ContentType::Audio(audio_content) => Ok(format!(
284                "Audio: {} (Duration: {}s)",
285                content.title().unwrap_or(&"Untitled".to_string()),
286                audio_content.duration
287            )),
288            ContentType::Image(image_content) => Ok(format!(
289                "Image: {} ({}x{})",
290                content.title().unwrap_or(&"Untitled".to_string()),
291                image_content.resolution.0,
292                image_content.resolution.1
293            )),
294        }
295    }
296
297    fn convert_finish_reason(&self, reason: Option<String>) -> FinishReason {
298        match reason.as_deref() {
299            Some("stop") => FinishReason::Stop,
300            Some("length") => FinishReason::Length,
301            Some("content_filter") => FinishReason::ContentFilter,
302            Some("function_call") => FinishReason::FunctionCall,
303            Some(other) => FinishReason::Error(format!("Unknown: {}", other)),
304            None => FinishReason::Stop,
305        }
306    }
307
308    async fn make_request_with_retries(
309        &self,
310        request: &OpenAIRequest,
311    ) -> Result<OpenAIResponse, LlmError> {
312        let mut last_error = None;
313
314        for attempt in 0..=self.config.max_retries {
315            match self.make_single_request(request).await {
316                Ok(response) => return Ok(response),
317                Err(e) => {
318                    last_error = Some(e.clone());
319
320                    if matches!(e, LlmError::AuthenticationError(_)) {
321                        return Err(e);
322                    }
323
324                    if attempt < self.config.max_retries {
325                        let base_delay = Duration::from_secs(1);
326                        let exponential_delay = base_delay * 2_u32.pow(attempt);
327                        let max_delay = Duration::from_secs(10);
328                        let delay = exponential_delay.min(max_delay);
329
330                        let jitter_ms = {
331                            let mut rng = rand::thread_rng();
332                            rng.gen_range(0..=(delay.as_millis() / 5)) as u64
333                        };
334                        let total_delay = delay + Duration::from_millis(jitter_ms);
335
336                        tokio::time::sleep(total_delay).await;
337                    }
338                }
339            }
340        }
341
342        Err(last_error
343            .unwrap_or_else(|| LlmError::ProcessingError("Maximum retries exceeded".to_string())))
344    }
345
346    async fn make_single_request(
347        &self,
348        request: &OpenAIRequest,
349    ) -> Result<OpenAIResponse, LlmError> {
350        let url = format!("{}/chat/completions", self.config.base_url);
351
352        let mut req = self
353            .client
354            .post(&url)
355            .header("Authorization", format!("Bearer {}", self.config.api_key))
356            .header("Content-Type", "application/json");
357
358        if let Some(org) = &self.config.organization {
359            req = req.header("OpenAI-Organization", org);
360        }
361
362        let response = req
363            .json(request)
364            .send()
365            .await
366            .map_err(|e| LlmError::NetworkError(format!("Request failed: {}", e)))?;
367
368        let status = response.status();
369        let response_text = response
370            .text()
371            .await
372            .map_err(|e| LlmError::ProcessingError(format!("Failed to read response: {}", e)))?;
373
374        if !status.is_success() {
375            return match status.as_u16() {
376                401 => Err(LlmError::AuthenticationError(
377                    "Invalid OpenAI API key".to_string(),
378                )),
379                429 => Err(LlmError::RateLimitExceeded),
380                400 => {
381                    if response_text.contains("maximum context length") {
382                        Err(LlmError::TokenLimitExceeded)
383                    } else {
384                        Err(LlmError::InvalidPrompt(response_text))
385                    }
386                }
387                500..=599 => Err(LlmError::ProcessingError(format!(
388                    "OpenAI server error: {}",
389                    response_text
390                ))),
391                _ => Err(LlmError::ProcessingError(format!(
392                    "HTTP {}: {}",
393                    status, response_text
394                ))),
395            };
396        }
397
398        serde_json::from_str::<OpenAIResponse>(&response_text)
399            .map_err(|e| LlmError::ProcessingError(format!("Failed to parse response: {}", e)))
400    }
401
402    async fn make_streaming_request(
403        &self,
404        request: &OpenAIRequest,
405    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamingResponse, LlmError>> + Send>>, LlmError>
406    {
407        let url = format!("{}/chat/completions", self.config.base_url);
408
409        let mut req = self
410            .client
411            .post(&url)
412            .header("Authorization", format!("Bearer {}", self.config.api_key))
413            .header("Content-Type", "application/json");
414
415        if let Some(org) = &self.config.organization {
416            req = req.header("OpenAI-Organization", org);
417        }
418
419        let response = req
420            .json(request)
421            .send()
422            .await
423            .map_err(|e| LlmError::NetworkError(format!("Request failed: {}", e)))?;
424
425        if !response.status().is_success() {
426            let status = response.status();
427            let error_text = response.text().await.unwrap_or_default();
428            return Err(match status.as_u16() {
429                401 => LlmError::AuthenticationError("Invalid OpenAI API key".to_string()),
430                429 => LlmError::RateLimitExceeded,
431                400 => LlmError::InvalidPrompt(error_text),
432                _ => LlmError::ProcessingError(format!("HTTP {}: {}", status, error_text)),
433            });
434        }
435
436        let stream = response.bytes_stream().map(|chunk_result| {
437            chunk_result
438                .map_err(|e| LlmError::NetworkError(format!("Stream error: {}", e)))
439                .and_then(|chunk| {
440                    let chunk_str = String::from_utf8_lossy(&chunk);
441
442                    for line in chunk_str.lines() {
443                        if let Some(data) = line.strip_prefix("data: ") {
444                            if data == "[DONE]" {
445                                return Ok(StreamingResponse {
446                                    id: Uuid::new_v4(),
447                                    delta: String::new(),
448                                    finish_reason: Some(FinishReason::Stop),
449                                });
450                            }
451
452                            match serde_json::from_str::<OpenAIStreamChunk>(data) {
453                                Ok(chunk) => {
454                                    if let Some(choice) = chunk.choices.first() {
455                                        let delta =
456                                            choice.delta.content.clone().unwrap_or_default();
457                                        let finish_reason =
458                                            choice.finish_reason.as_ref().map(|r| {
459                                                match r.as_str() {
460                                                    "stop" => FinishReason::Stop,
461                                                    "length" => FinishReason::Length,
462                                                    "content_filter" => FinishReason::ContentFilter,
463                                                    "function_call" => FinishReason::FunctionCall,
464                                                    other => FinishReason::Error(format!(
465                                                        "Unknown: {}",
466                                                        other
467                                                    )),
468                                                }
469                                            });
470
471                                        return Ok(StreamingResponse {
472                                            id: Uuid::new_v4(),
473                                            delta,
474                                            finish_reason,
475                                        });
476                                    }
477                                }
478                                Err(e) => {
479                                    return Err(LlmError::ProcessingError(format!(
480                                        "Failed to parse stream chunk: {}",
481                                        e
482                                    )));
483                                }
484                            }
485                        }
486                    }
487
488                    Ok(StreamingResponse {
489                        id: Uuid::new_v4(),
490                        delta: String::new(),
491                        finish_reason: None,
492                    })
493                })
494        });
495
496        Ok(Box::pin(stream))
497    }
498}
499
500#[async_trait]
501impl LlmPort for OpenAIAdapter {
502    async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
503        let messages = self.convert_to_messages(&request.prompt, &request.attachments)?;
504
505        let temperature = request
506            .prompt
507            .node
508            .node
509            .parameters
510            .temperature
511            .unwrap_or(0.7);
512        let max_tokens = request
513            .prompt
514            .node
515            .node
516            .parameters
517            .max_tokens
518            .unwrap_or(4096);
519
520        let openai_request = OpenAIRequest {
521            model: request.model.clone(),
522            messages,
523            temperature: Some(temperature),
524            max_tokens: Some(max_tokens),
525            top_p: Some(1.0),
526            stream: false,
527        };
528
529        let response = self.make_request_with_retries(&openai_request).await?;
530
531        if response.choices.is_empty() {
532            return Err(LlmError::ProcessingError(
533                "No choices in response".to_string(),
534            ));
535        }
536
537        let choice = &response.choices[0];
538        let finish_reason = self.convert_finish_reason(choice.finish_reason.clone());
539
540        Ok(LlmResponse {
541            id: Uuid::new_v4(),
542            request_id: request.id,
543            model: response.model,
544            content: choice.message.content.clone(),
545            finish_reason,
546            usage: TokenUsage {
547                prompt_tokens: response.usage.prompt_tokens,
548                completion_tokens: response.usage.completion_tokens,
549                total_tokens: response.usage.total_tokens,
550            },
551            created_at: Utc::now(),
552            metadata: HashMap::new(),
553            function_call: None,
554        })
555    }
556
557    async fn generate_stream(
558        &self,
559        request: LlmRequest,
560    ) -> Result<Box<dyn Stream<Item = Result<StreamingResponse, LlmError>> + Send>, LlmError> {
561        let messages = self.convert_to_messages(&request.prompt, &request.attachments)?;
562
563        let temperature = request
564            .prompt
565            .node
566            .node
567            .parameters
568            .temperature
569            .unwrap_or(0.7);
570        let max_tokens = request
571            .prompt
572            .node
573            .node
574            .parameters
575            .max_tokens
576            .unwrap_or(4096);
577
578        let openai_request = OpenAIRequest {
579            model: request.model.clone(),
580            messages,
581            temperature: Some(temperature),
582            max_tokens: Some(max_tokens),
583            top_p: Some(1.0),
584            stream: true,
585        };
586
587        let stream = self.make_streaming_request(&openai_request).await?;
588        Ok(Box::new(stream))
589    }
590
591    async fn validate_model(&self, model: &str) -> Result<bool, LlmError> {
592        let available_models = self.get_available_models().await?;
593        Ok(available_models.contains(&model.to_string()))
594    }
595
596    async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
597        let url = format!("{}/models", self.config.base_url);
598
599        let mut req = self
600            .client
601            .get(&url)
602            .header("Authorization", format!("Bearer {}", self.config.api_key));
603
604        if let Some(org) = &self.config.organization {
605            req = req.header("OpenAI-Organization", org);
606        }
607
608        let response = req
609            .send()
610            .await
611            .map_err(|e| LlmError::NetworkError(format!("Failed to fetch models: {}", e)))?;
612
613        if !response.status().is_success() {
614            return Err(LlmError::ProcessingError(format!(
615                "HTTP {}",
616                response.status()
617            )));
618        }
619
620        let response_text = response
621            .text()
622            .await
623            .map_err(|e| LlmError::ProcessingError(format!("Failed to read response: {}", e)))?;
624
625        let models_response: serde_json::Value = serde_json::from_str(&response_text)
626            .map_err(|e| LlmError::ProcessingError(format!("Failed to parse response: {}", e)))?;
627
628        let models = models_response["data"]
629            .as_array()
630            .ok_or_else(|| LlmError::ProcessingError("Invalid models response format".to_string()))?
631            .iter()
632            .filter_map(|model| model["id"].as_str().map(String::from))
633            .collect();
634
635        Ok(models)
636    }
637
638    fn get_provider_name(&self) -> &'static str {
639        "openai"
640    }
641
642    fn get_capabilities(&self) -> ProviderCapabilities {
643        ProviderCapabilities {
644            supports_streaming: true,
645            supports_tool_calling: true,
646            supports_function_calling: true,
647            supports_vision: true,
648            max_context_tokens: Some(128000),
649            supports_embeddings: true,
650            supports_system_messages: true,
651        }
652    }
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration carrying a dummy API key.
    fn default_config() -> OpenAIConfig {
        OpenAIConfig::new("test-key".to_string())
    }

    /// Adapter built from `default_config`; construction performs no network I/O.
    fn default_adapter() -> OpenAIAdapter {
        OpenAIAdapter::new(default_config()).unwrap()
    }

    /// Configuration with every field spelled out, for validation edge cases.
    fn config_with(api_key: &str, base_url: &str) -> OpenAIConfig {
        OpenAIConfig {
            api_key: api_key.to_string(),
            base_url: base_url.to_string(),
            organization: None,
            timeout_seconds: 300,
            max_retries: 3,
        }
    }

    #[test]
    fn test_config_creation() {
        let OpenAIConfig {
            api_key,
            base_url,
            timeout_seconds,
            max_retries,
            ..
        } = default_config();
        assert_eq!(api_key, "test-key");
        assert_eq!(base_url, "https://api.openai.com/v1");
        assert_eq!(timeout_seconds, 300);
        assert_eq!(max_retries, 3);
    }

    #[test]
    fn test_config_validation() {
        assert!(default_config().validate().is_ok());
        assert!(config_with("", "https://api.openai.com/v1")
            .validate()
            .is_err());
    }

    #[test]
    fn test_adapter_creation() {
        assert!(OpenAIAdapter::new(default_config()).is_ok());
    }

    #[test]
    fn test_get_provider_name() {
        assert_eq!(default_adapter().get_provider_name(), "openai");
    }

    #[test]
    fn test_get_capabilities() {
        let capabilities = default_adapter().get_capabilities();
        assert!(capabilities.supports_streaming);
        assert!(capabilities.supports_tool_calling);
        assert!(capabilities.supports_vision);
        assert_eq!(capabilities.max_context_tokens, Some(128000));
    }

    #[test]
    fn test_config_with_organization() {
        let config = OpenAIConfig {
            organization: Some("org-123".to_string()),
            ..default_config()
        };
        assert_eq!(config.organization, Some("org-123".to_string()));
    }

    #[test]
    fn test_config_validation_empty_base_url() {
        assert!(config_with("test-key", "").validate().is_err());
    }
}