1use async_trait::async_trait;
7use chrono::Utc;
8use futures::{Stream, StreamExt};
9use paladin_core::platform::container::content::{ContentItem, ContentType};
10use paladin_core::platform::container::prompt::{PromptItem, PromptRole, PromptType};
11use paladin_ports::output::llm_port::{
12 FinishReason, LlmError, LlmPort, LlmRequest, LlmResponse, ProviderCapabilities,
13 StreamingResponse, TokenUsage,
14};
15use rand::Rng;
16use reqwest::Client;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::env;
20use std::pin::Pin;
21use std::time::Duration;
22use uuid::Uuid;
23
/// Default API root used when no base URL is configured.
const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Default per-request HTTP timeout, in seconds.
const DEFAULT_OPENAI_TIMEOUT_SECONDS: u64 = 300;
/// Default number of retries after the first failed attempt.
const DEFAULT_OPENAI_MAX_RETRIES: u32 = 3;

/// Connection settings for the OpenAI chat-completions adapter.
///
/// `Debug` is implemented by hand so the API key is never written to logs.
#[derive(Clone)]
pub struct OpenAIConfig {
    /// Secret API key, sent as a `Bearer` token.
    pub api_key: String,
    /// API root without a trailing slash, e.g. `https://api.openai.com/v1`.
    pub base_url: String,
    /// Optional value for the `OpenAI-Organization` header.
    pub organization: Option<String>,
    /// Per-request HTTP timeout in seconds; must be non-zero.
    pub timeout_seconds: u64,
    /// Retries performed after the first failed attempt.
    pub max_retries: u32,
}

impl std::fmt::Debug for OpenAIConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIConfig")
            // Never print the secret itself.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("organization", &self.organization)
            .field("timeout_seconds", &self.timeout_seconds)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}

impl OpenAIConfig {
    /// Builds a configuration from environment variables.
    ///
    /// - `OPENAI_API_KEY` is required. Surrounding whitespace is trimmed.
    /// - `OPENAI_BASE_URL` defaults to the public API. Trailing `/` is removed.
    /// - `OPENAI_ORGANIZATION` is optional.
    /// - `OPENAI_TIMEOUT_SECONDS` defaults to 300.
    /// - `OPENAI_MAX_RETRIES` defaults to 3.
    ///
    /// # Errors
    ///
    /// Returns a message if the API key is missing or a numeric variable fails to parse.
    /// The result is not validated here; call [`OpenAIConfig::validate`] (or build an adapter,
    /// which does).
    pub fn from_env() -> Result<Self, String> {
        let api_key = env::var("OPENAI_API_KEY")
            .map_err(|_| "OPENAI_API_KEY environment variable not set")?
            .trim()
            .to_string();

        // A trailing slash would otherwise produce `//chat/completions` when paths are joined.
        let base_url = env::var("OPENAI_BASE_URL")
            .map(|url| url.trim().trim_end_matches('/').to_string())
            .unwrap_or_else(|_| DEFAULT_OPENAI_BASE_URL.to_string());

        let organization = env::var("OPENAI_ORGANIZATION").ok();

        let timeout_seconds = match env::var("OPENAI_TIMEOUT_SECONDS") {
            Ok(raw) => raw
                .trim()
                .parse()
                .map_err(|_| "Invalid OPENAI_TIMEOUT_SECONDS value")?,
            Err(_) => DEFAULT_OPENAI_TIMEOUT_SECONDS,
        };

        let max_retries = match env::var("OPENAI_MAX_RETRIES") {
            Ok(raw) => raw
                .trim()
                .parse()
                .map_err(|_| "Invalid OPENAI_MAX_RETRIES value")?,
            Err(_) => DEFAULT_OPENAI_MAX_RETRIES,
        };

        Ok(Self {
            api_key,
            base_url,
            organization,
            timeout_seconds,
            max_retries,
        })
    }

    /// Creates a configuration with the given key and default settings for everything else.
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            base_url: DEFAULT_OPENAI_BASE_URL.to_string(),
            organization: None,
            timeout_seconds: DEFAULT_OPENAI_TIMEOUT_SECONDS,
            max_retries: DEFAULT_OPENAI_MAX_RETRIES,
        }
    }

    /// Checks that the configuration can be used to issue requests.
    ///
    /// # Errors
    ///
    /// Returns a message in any of these cases:
    /// - the API key is empty;
    /// - the base URL is empty or does not use the `http://` or `https://` scheme;
    /// - the timeout is zero, which would make every request time out immediately.
    pub fn validate(&self) -> Result<(), String> {
        if self.api_key.is_empty() {
            return Err("API key cannot be empty".to_string());
        }
        if self.base_url.is_empty() {
            return Err("Base URL cannot be empty".to_string());
        }
        // A bare `starts_with("http")` would also accept strings like `httpfoo`.
        if !(self.base_url.starts_with("http://") || self.base_url.starts_with("https://")) {
            return Err("Base URL must start with http or https".to_string());
        }
        if self.timeout_seconds == 0 {
            return Err("Timeout must be greater than zero seconds".to_string());
        }
        Ok(())
    }
}
103
/// JSON body sent to `POST {base_url}/chat/completions`.
///
/// Optional sampling fields are omitted from the serialized JSON when they are `None`.
/// Field names are the wire names, so renaming them changes the request payload.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    /// Model identifier, copied from the incoming port request.
    model: String,
    /// Conversation messages, in the order they are sent.
    messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// `true` asks for a streamed response, which this file parses as `data:` lines.
    stream: bool,
}
120
121#[derive(Debug, Serialize, Deserialize)]
122struct OpenAIMessage {
123 role: String,
124 content: String,
125}
126
/// Successful, non-streaming chat-completions response body.
///
/// NOTE(review): `usage` is required here, so a server that omits it (possible with some
/// OpenAI-compatible backends reached through `base_url`) fails deserialization. Confirm
/// that this is acceptable.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    // Deserialized but never read by the adapter, which generates its own response ids.
    #[allow(dead_code)]
    id: String,
    /// Model name reported by the server; copied into the port response.
    model: String,
    /// Candidate completions; only the first one is used.
    choices: Vec<OpenAIChoice>,
    /// Token accounting copied into the port response.
    usage: OpenAIUsage,
}
135
/// One candidate completion inside an [`OpenAIResponse`].
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    // Deserialized for completeness; the adapter always takes the first choice instead.
    #[allow(dead_code)]
    index: u32,
    /// The generated assistant message.
    message: OpenAIMessage,
    /// Raw stop reason string such as `stop` or `length`, when present.
    finish_reason: Option<String>,
}
143
/// Token counts reported for a non-streaming completion.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    /// Tokens consumed by the input messages.
    prompt_tokens: u32,
    /// Tokens generated in the reply.
    completion_tokens: u32,
    /// Total tokens as reported by the server (not recomputed locally).
    total_tokens: u32,
}
150
/// JSON payload of one `data:` line in a streamed chat-completions response.
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
    // Deserialized but unused; each emitted streaming event gets a fresh UUID instead.
    #[allow(dead_code)]
    id: String,
    /// Incremental choices; only the first one is consumed.
    choices: Vec<OpenAIStreamChoice>,
}
157
/// One choice within a streamed chunk.
#[derive(Debug, Deserialize)]
struct OpenAIStreamChoice {
    // Deserialized for completeness; the stream consumer only looks at the first choice.
    #[allow(dead_code)]
    index: u32,
    /// Incremental content for this chunk.
    delta: OpenAIStreamDelta,
    /// Raw stop reason, present only on the chunk that ends the choice.
    finish_reason: Option<String>,
}
165
/// Incremental message fragment carried by a streamed choice.
#[derive(Debug, Deserialize)]
struct OpenAIStreamDelta {
    // Accepted so role-bearing chunks deserialize, but the adapter never reads it.
    #[allow(dead_code)]
    role: Option<String>,
    /// Newly generated text, if this chunk carries any.
    content: Option<String>,
}
172
/// [`LlmPort`] implementation that talks to the OpenAI chat-completions HTTP API.
///
/// Construct it with `OpenAIAdapter::new` or `OpenAIAdapter::from_env`. Both validate the
/// configuration and build the HTTP client with the configured timeout.
pub struct OpenAIAdapter {
    /// Connection settings: key, base URL, organization, timeout, and retry budget.
    pub(crate) config: OpenAIConfig,
    /// Shared HTTP client, built once with `config.timeout_seconds` as its request timeout.
    pub(crate) client: Client,
}
182
183impl OpenAIAdapter {
184 pub fn new(config: OpenAIConfig) -> Result<Self, String> {
186 config.validate()?;
187 let client = Client::builder()
188 .timeout(Duration::from_secs(config.timeout_seconds))
189 .build()
190 .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
191 Ok(Self { config, client })
192 }
193
194 pub fn from_env() -> Result<Self, String> {
196 Self::new(OpenAIConfig::from_env()?)
197 }
198
199 fn convert_to_messages(
201 &self,
202 prompt: &PromptItem,
203 attachments: &[ContentItem],
204 ) -> Result<Vec<OpenAIMessage>, LlmError> {
205 let mut messages = Vec::new();
206
207 match prompt.prompt_type() {
208 PromptType::System(system_prompt) => {
209 let mut content = system_prompt.instructions.clone();
210 if let Some(constraints) = &system_prompt.constraints
211 && !constraints.is_empty()
212 {
213 content.push_str("\n\nConstraints:\n");
214 for constraint in constraints {
215 content.push_str(&format!("- {}\n", constraint));
216 }
217 }
218 messages.push(OpenAIMessage {
219 role: "system".to_string(),
220 content,
221 });
222 }
223 PromptType::User(user_prompt) => {
224 messages.push(OpenAIMessage {
225 role: "user".to_string(),
226 content: user_prompt.context.clone().unwrap_or_default(),
227 });
228 }
229 PromptType::Assistant(assistant_prompt) => {
230 let mut content = assistant_prompt.response.clone();
231 if let Some(reasoning) = &assistant_prompt.reasoning {
232 content.push_str(&format!("\n\nReasoning: {}", reasoning));
233 }
234 messages.push(OpenAIMessage {
235 role: "assistant".to_string(),
236 content,
237 });
238 }
239 PromptType::Text(text_prompt) => {
240 let role = match text_prompt.role {
241 PromptRole::System => "system",
242 PromptRole::User => "user",
243 PromptRole::Assistant => "assistant",
244 PromptRole::Function => "function",
245 };
246 messages.push(OpenAIMessage {
247 role: role.to_string(),
248 content: text_prompt.content.clone(),
249 });
250 }
251 PromptType::Function(function_prompt) => {
252 messages.push(OpenAIMessage {
253 role: "function".to_string(),
254 content: function_prompt.function_name.clone(),
255 });
256 }
257 }
258
259 for content in attachments {
260 if let Ok(content_text) = self.convert_content_to_text(content)
261 && !content_text.is_empty()
262 {
263 messages.push(OpenAIMessage {
264 role: "user".to_string(),
265 content: format!("Content to analyze:\n{}", content_text),
266 });
267 }
268 }
269
270 Ok(messages)
271 }
272
273 fn convert_content_to_text(&self, content: &ContentItem) -> Result<String, LlmError> {
274 match content.content() {
275 ContentType::Text(text_content) => {
276 Ok(text_content.content.as_deref().unwrap_or("").to_string())
277 }
278 ContentType::Video(video_content) => Ok(format!(
279 "Video: {} (Duration: {}s)",
280 content.title().unwrap_or(&"Untitled".to_string()),
281 video_content.duration
282 )),
283 ContentType::Audio(audio_content) => Ok(format!(
284 "Audio: {} (Duration: {}s)",
285 content.title().unwrap_or(&"Untitled".to_string()),
286 audio_content.duration
287 )),
288 ContentType::Image(image_content) => Ok(format!(
289 "Image: {} ({}x{})",
290 content.title().unwrap_or(&"Untitled".to_string()),
291 image_content.resolution.0,
292 image_content.resolution.1
293 )),
294 }
295 }
296
297 fn convert_finish_reason(&self, reason: Option<String>) -> FinishReason {
298 match reason.as_deref() {
299 Some("stop") => FinishReason::Stop,
300 Some("length") => FinishReason::Length,
301 Some("content_filter") => FinishReason::ContentFilter,
302 Some("function_call") => FinishReason::FunctionCall,
303 Some(other) => FinishReason::Error(format!("Unknown: {}", other)),
304 None => FinishReason::Stop,
305 }
306 }
307
308 async fn make_request_with_retries(
309 &self,
310 request: &OpenAIRequest,
311 ) -> Result<OpenAIResponse, LlmError> {
312 let mut last_error = None;
313
314 for attempt in 0..=self.config.max_retries {
315 match self.make_single_request(request).await {
316 Ok(response) => return Ok(response),
317 Err(e) => {
318 last_error = Some(e.clone());
319
320 if matches!(e, LlmError::AuthenticationError(_)) {
321 return Err(e);
322 }
323
324 if attempt < self.config.max_retries {
325 let base_delay = Duration::from_secs(1);
326 let exponential_delay = base_delay * 2_u32.pow(attempt);
327 let max_delay = Duration::from_secs(10);
328 let delay = exponential_delay.min(max_delay);
329
330 let jitter_ms = {
331 let mut rng = rand::thread_rng();
332 rng.gen_range(0..=(delay.as_millis() / 5)) as u64
333 };
334 let total_delay = delay + Duration::from_millis(jitter_ms);
335
336 tokio::time::sleep(total_delay).await;
337 }
338 }
339 }
340 }
341
342 Err(last_error
343 .unwrap_or_else(|| LlmError::ProcessingError("Maximum retries exceeded".to_string())))
344 }
345
346 async fn make_single_request(
347 &self,
348 request: &OpenAIRequest,
349 ) -> Result<OpenAIResponse, LlmError> {
350 let url = format!("{}/chat/completions", self.config.base_url);
351
352 let mut req = self
353 .client
354 .post(&url)
355 .header("Authorization", format!("Bearer {}", self.config.api_key))
356 .header("Content-Type", "application/json");
357
358 if let Some(org) = &self.config.organization {
359 req = req.header("OpenAI-Organization", org);
360 }
361
362 let response = req
363 .json(request)
364 .send()
365 .await
366 .map_err(|e| LlmError::NetworkError(format!("Request failed: {}", e)))?;
367
368 let status = response.status();
369 let response_text = response
370 .text()
371 .await
372 .map_err(|e| LlmError::ProcessingError(format!("Failed to read response: {}", e)))?;
373
374 if !status.is_success() {
375 return match status.as_u16() {
376 401 => Err(LlmError::AuthenticationError(
377 "Invalid OpenAI API key".to_string(),
378 )),
379 429 => Err(LlmError::RateLimitExceeded),
380 400 => {
381 if response_text.contains("maximum context length") {
382 Err(LlmError::TokenLimitExceeded)
383 } else {
384 Err(LlmError::InvalidPrompt(response_text))
385 }
386 }
387 500..=599 => Err(LlmError::ProcessingError(format!(
388 "OpenAI server error: {}",
389 response_text
390 ))),
391 _ => Err(LlmError::ProcessingError(format!(
392 "HTTP {}: {}",
393 status, response_text
394 ))),
395 };
396 }
397
398 serde_json::from_str::<OpenAIResponse>(&response_text)
399 .map_err(|e| LlmError::ProcessingError(format!("Failed to parse response: {}", e)))
400 }
401
402 async fn make_streaming_request(
403 &self,
404 request: &OpenAIRequest,
405 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamingResponse, LlmError>> + Send>>, LlmError>
406 {
407 let url = format!("{}/chat/completions", self.config.base_url);
408
409 let mut req = self
410 .client
411 .post(&url)
412 .header("Authorization", format!("Bearer {}", self.config.api_key))
413 .header("Content-Type", "application/json");
414
415 if let Some(org) = &self.config.organization {
416 req = req.header("OpenAI-Organization", org);
417 }
418
419 let response = req
420 .json(request)
421 .send()
422 .await
423 .map_err(|e| LlmError::NetworkError(format!("Request failed: {}", e)))?;
424
425 if !response.status().is_success() {
426 let status = response.status();
427 let error_text = response.text().await.unwrap_or_default();
428 return Err(match status.as_u16() {
429 401 => LlmError::AuthenticationError("Invalid OpenAI API key".to_string()),
430 429 => LlmError::RateLimitExceeded,
431 400 => LlmError::InvalidPrompt(error_text),
432 _ => LlmError::ProcessingError(format!("HTTP {}: {}", status, error_text)),
433 });
434 }
435
436 let stream = response.bytes_stream().map(|chunk_result| {
437 chunk_result
438 .map_err(|e| LlmError::NetworkError(format!("Stream error: {}", e)))
439 .and_then(|chunk| {
440 let chunk_str = String::from_utf8_lossy(&chunk);
441
442 for line in chunk_str.lines() {
443 if let Some(data) = line.strip_prefix("data: ") {
444 if data == "[DONE]" {
445 return Ok(StreamingResponse {
446 id: Uuid::new_v4(),
447 delta: String::new(),
448 finish_reason: Some(FinishReason::Stop),
449 });
450 }
451
452 match serde_json::from_str::<OpenAIStreamChunk>(data) {
453 Ok(chunk) => {
454 if let Some(choice) = chunk.choices.first() {
455 let delta =
456 choice.delta.content.clone().unwrap_or_default();
457 let finish_reason =
458 choice.finish_reason.as_ref().map(|r| {
459 match r.as_str() {
460 "stop" => FinishReason::Stop,
461 "length" => FinishReason::Length,
462 "content_filter" => FinishReason::ContentFilter,
463 "function_call" => FinishReason::FunctionCall,
464 other => FinishReason::Error(format!(
465 "Unknown: {}",
466 other
467 )),
468 }
469 });
470
471 return Ok(StreamingResponse {
472 id: Uuid::new_v4(),
473 delta,
474 finish_reason,
475 });
476 }
477 }
478 Err(e) => {
479 return Err(LlmError::ProcessingError(format!(
480 "Failed to parse stream chunk: {}",
481 e
482 )));
483 }
484 }
485 }
486 }
487
488 Ok(StreamingResponse {
489 id: Uuid::new_v4(),
490 delta: String::new(),
491 finish_reason: None,
492 })
493 })
494 });
495
496 Ok(Box::pin(stream))
497 }
498}
499
500#[async_trait]
501impl LlmPort for OpenAIAdapter {
502 async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
503 let messages = self.convert_to_messages(&request.prompt, &request.attachments)?;
504
505 let temperature = request
506 .prompt
507 .node
508 .node
509 .parameters
510 .temperature
511 .unwrap_or(0.7);
512 let max_tokens = request
513 .prompt
514 .node
515 .node
516 .parameters
517 .max_tokens
518 .unwrap_or(4096);
519
520 let openai_request = OpenAIRequest {
521 model: request.model.clone(),
522 messages,
523 temperature: Some(temperature),
524 max_tokens: Some(max_tokens),
525 top_p: Some(1.0),
526 stream: false,
527 };
528
529 let response = self.make_request_with_retries(&openai_request).await?;
530
531 if response.choices.is_empty() {
532 return Err(LlmError::ProcessingError(
533 "No choices in response".to_string(),
534 ));
535 }
536
537 let choice = &response.choices[0];
538 let finish_reason = self.convert_finish_reason(choice.finish_reason.clone());
539
540 Ok(LlmResponse {
541 id: Uuid::new_v4(),
542 request_id: request.id,
543 model: response.model,
544 content: choice.message.content.clone(),
545 finish_reason,
546 usage: TokenUsage {
547 prompt_tokens: response.usage.prompt_tokens,
548 completion_tokens: response.usage.completion_tokens,
549 total_tokens: response.usage.total_tokens,
550 },
551 created_at: Utc::now(),
552 metadata: HashMap::new(),
553 function_call: None,
554 })
555 }
556
557 async fn generate_stream(
558 &self,
559 request: LlmRequest,
560 ) -> Result<Box<dyn Stream<Item = Result<StreamingResponse, LlmError>> + Send>, LlmError> {
561 let messages = self.convert_to_messages(&request.prompt, &request.attachments)?;
562
563 let temperature = request
564 .prompt
565 .node
566 .node
567 .parameters
568 .temperature
569 .unwrap_or(0.7);
570 let max_tokens = request
571 .prompt
572 .node
573 .node
574 .parameters
575 .max_tokens
576 .unwrap_or(4096);
577
578 let openai_request = OpenAIRequest {
579 model: request.model.clone(),
580 messages,
581 temperature: Some(temperature),
582 max_tokens: Some(max_tokens),
583 top_p: Some(1.0),
584 stream: true,
585 };
586
587 let stream = self.make_streaming_request(&openai_request).await?;
588 Ok(Box::new(stream))
589 }
590
591 async fn validate_model(&self, model: &str) -> Result<bool, LlmError> {
592 let available_models = self.get_available_models().await?;
593 Ok(available_models.contains(&model.to_string()))
594 }
595
596 async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
597 let url = format!("{}/models", self.config.base_url);
598
599 let mut req = self
600 .client
601 .get(&url)
602 .header("Authorization", format!("Bearer {}", self.config.api_key));
603
604 if let Some(org) = &self.config.organization {
605 req = req.header("OpenAI-Organization", org);
606 }
607
608 let response = req
609 .send()
610 .await
611 .map_err(|e| LlmError::NetworkError(format!("Failed to fetch models: {}", e)))?;
612
613 if !response.status().is_success() {
614 return Err(LlmError::ProcessingError(format!(
615 "HTTP {}",
616 response.status()
617 )));
618 }
619
620 let response_text = response
621 .text()
622 .await
623 .map_err(|e| LlmError::ProcessingError(format!("Failed to read response: {}", e)))?;
624
625 let models_response: serde_json::Value = serde_json::from_str(&response_text)
626 .map_err(|e| LlmError::ProcessingError(format!("Failed to parse response: {}", e)))?;
627
628 let models = models_response["data"]
629 .as_array()
630 .ok_or_else(|| LlmError::ProcessingError("Invalid models response format".to_string()))?
631 .iter()
632 .filter_map(|model| model["id"].as_str().map(String::from))
633 .collect();
634
635 Ok(models)
636 }
637
638 fn get_provider_name(&self) -> &'static str {
639 "openai"
640 }
641
642 fn get_capabilities(&self) -> ProviderCapabilities {
643 ProviderCapabilities {
644 supports_streaming: true,
645 supports_tool_calling: true,
646 supports_function_calling: true,
647 supports_vision: true,
648 max_context_tokens: Some(128000),
649 supports_embeddings: true,
650 supports_system_messages: true,
651 }
652 }
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an adapter with a dummy key; no network traffic happens in these tests.
    fn test_adapter() -> OpenAIAdapter {
        OpenAIAdapter::new(OpenAIConfig::new("test-key".to_string())).unwrap()
    }

    #[test]
    fn test_config_creation() {
        let config = OpenAIConfig::new("test-key".to_string());
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert_eq!(config.timeout_seconds, 300);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_config_validation() {
        let valid_config = OpenAIConfig::new("test-key".to_string());
        assert!(valid_config.validate().is_ok());

        let invalid_config = OpenAIConfig {
            api_key: String::new(),
            base_url: "https://api.openai.com/v1".to_string(),
            organization: None,
            timeout_seconds: 300,
            max_retries: 3,
        };
        assert!(invalid_config.validate().is_err());
    }

    #[test]
    fn test_config_validation_rejects_non_http_scheme() {
        let mut config = OpenAIConfig::new("test-key".to_string());
        config.base_url = "ftp://example.com".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_adapter_creation() {
        let config = OpenAIConfig::new("test-key".to_string());
        let adapter = OpenAIAdapter::new(config);
        assert!(adapter.is_ok());
    }

    #[test]
    fn test_get_provider_name() {
        assert_eq!(test_adapter().get_provider_name(), "openai");
    }

    #[test]
    fn test_get_capabilities() {
        let caps = test_adapter().get_capabilities();
        assert!(caps.supports_streaming);
        assert!(caps.supports_tool_calling);
        assert!(caps.supports_vision);
        assert_eq!(caps.max_context_tokens, Some(128000));
    }

    #[test]
    fn test_convert_finish_reason_known_values() {
        let adapter = test_adapter();
        assert!(matches!(
            adapter.convert_finish_reason(Some("stop".to_string())),
            FinishReason::Stop
        ));
        assert!(matches!(
            adapter.convert_finish_reason(Some("length".to_string())),
            FinishReason::Length
        ));
        assert!(matches!(
            adapter.convert_finish_reason(Some("content_filter".to_string())),
            FinishReason::ContentFilter
        ));
        assert!(matches!(
            adapter.convert_finish_reason(Some("function_call".to_string())),
            FinishReason::FunctionCall
        ));
        assert!(matches!(
            adapter.convert_finish_reason(None),
            FinishReason::Stop
        ));
        assert!(matches!(
            adapter.convert_finish_reason(Some("bogus".to_string())),
            FinishReason::Error(_)
        ));
    }

    #[test]
    fn test_config_with_organization() {
        let mut config = OpenAIConfig::new("test-key".to_string());
        config.organization = Some("org-123".to_string());
        assert_eq!(config.organization, Some("org-123".to_string()));
    }

    #[test]
    fn test_config_validation_empty_base_url() {
        let config = OpenAIConfig {
            api_key: "test-key".to_string(),
            base_url: String::new(),
            organization: None,
            timeout_seconds: 300,
            max_retries: 3,
        };
        assert!(config.validate().is_err());
    }
}