Skip to main content

llm_stack_openai/
factory.rs

1//! Factory for building `OpenAI` providers from configuration.
2
3use llm_stack::registry::{ProviderConfig, ProviderFactory};
4use llm_stack::{DynProvider, LlmError};
5
6use crate::{OpenAiConfig, OpenAiProvider};
7
/// Config-driven constructor for [`OpenAiProvider`].
///
/// Hand an instance of this factory to the global registry so that provider
/// configurations naming `"openai"` can be turned into providers automatically:
///
/// ```rust,no_run
/// use llm_stack::ProviderRegistry;
/// use llm_stack_openai::OpenAiFactory;
///
/// ProviderRegistry::global().register(Box::new(OpenAiFactory));
/// ```
///
/// # Configuration
///
/// | Field | Required | Description |
/// |-------|----------|-------------|
/// | `provider` | Yes | Must be `"openai"` |
/// | `api_key` | Yes | `OpenAI` API key |
/// | `model` | Yes | Non-empty model identifier (e.g., `"gpt-4o"`) |
/// | `base_url` | No | Custom API endpoint; replaces the default only when set |
/// | `timeout` | No | Request timeout; replaces the default only when set |
/// | `extra.organization` | No | `OpenAI` organization ID |
///
/// Building fails with [`LlmError::InvalidRequest`] when `api_key` is absent
/// or `model` is an empty string.
#[derive(Clone, Copy, Debug, Default)]
pub struct OpenAiFactory;
32
33impl ProviderFactory for OpenAiFactory {
34    fn name(&self) -> &'static str {
35        "openai"
36    }
37
38    fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
39        let api_key = config
40            .api_key
41            .clone()
42            .ok_or_else(|| LlmError::InvalidRequest("openai provider requires api_key".into()))?;
43
44        if config.model.is_empty() {
45            return Err(LlmError::InvalidRequest(
46                "openai provider requires model".into(),
47            ));
48        }
49
50        let mut openai_config = OpenAiConfig {
51            api_key,
52            model: config.model.clone(),
53            client: config.client.clone(),
54            ..Default::default()
55        };
56
57        if let Some(base_url) = &config.base_url {
58            openai_config.base_url.clone_from(base_url);
59        }
60
61        if let Some(timeout) = config.timeout {
62            openai_config.timeout = Some(timeout);
63        }
64
65        if let Some(organization) = config.get_extra_str("organization") {
66            openai_config.organization = Some(organization.to_string());
67        }
68
69        Ok(Box::new(OpenAiProvider::new(openai_config)))
70    }
71}
72
73/// Registers the `OpenAI` factory with the global registry.
74///
75/// Call this once at application startup to enable config-driven
76/// `OpenAI` provider creation.
77pub fn register_global() {
78    llm_stack::ProviderRegistry::global().register(Box::new(OpenAiFactory));
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Asserts that building `config` is rejected with `LlmError::InvalidRequest`.
    fn assert_invalid_request(config: &ProviderConfig) {
        let outcome = OpenAiFactory.build(config);
        assert!(matches!(outcome, Err(LlmError::InvalidRequest(_))));
    }

    #[test]
    fn test_factory_name() {
        assert_eq!(OpenAiFactory.name(), "openai");
    }

    #[test]
    fn test_factory_build_success() {
        let config = ProviderConfig::new("openai", "gpt-4o")
            .api_key("sk-test")
            .timeout(Duration::from_secs(30))
            .extra("organization", "org-123");

        let provider = OpenAiFactory.build(&config).unwrap();
        let meta = provider.metadata();
        assert_eq!(meta.name, "openai");
        assert_eq!(meta.model, "gpt-4o");
    }

    #[test]
    fn test_factory_missing_api_key() {
        // No api_key set at all.
        assert_invalid_request(&ProviderConfig::new("openai", "gpt-4o"));
    }

    #[test]
    fn test_factory_empty_model() {
        // api_key present, but the model identifier is blank.
        assert_invalid_request(&ProviderConfig::new("openai", "").api_key("sk-test"));
    }
}