llm_stack_openai/
factory.rs1use llm_stack::registry::{ProviderConfig, ProviderFactory};
4use llm_stack::{DynProvider, LlmError};
5
6use crate::{OpenAiConfig, OpenAiProvider};
7
/// [`ProviderFactory`] that builds OpenAI-backed providers from a generic
/// [`ProviderConfig`].
///
/// This is a stateless unit struct; its factory name is `"openai"`.
/// When building, it reads:
/// - `api_key` (required) and `model` (required, must be non-empty);
/// - `client`, plus the optional `base_url` and `timeout` overrides;
/// - the optional string extra `"organization"`.
///
/// A missing `api_key` or an empty `model` is reported as
/// [`LlmError::InvalidRequest`].
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAiFactory;
32
33impl ProviderFactory for OpenAiFactory {
34 fn name(&self) -> &'static str {
35 "openai"
36 }
37
38 fn build(&self, config: &ProviderConfig) -> Result<Box<dyn DynProvider>, LlmError> {
39 let api_key = config
40 .api_key
41 .clone()
42 .ok_or_else(|| LlmError::InvalidRequest("openai provider requires api_key".into()))?;
43
44 if config.model.is_empty() {
45 return Err(LlmError::InvalidRequest(
46 "openai provider requires model".into(),
47 ));
48 }
49
50 let mut openai_config = OpenAiConfig {
51 api_key,
52 model: config.model.clone(),
53 client: config.client.clone(),
54 ..Default::default()
55 };
56
57 if let Some(base_url) = &config.base_url {
58 openai_config.base_url.clone_from(base_url);
59 }
60
61 if let Some(timeout) = config.timeout {
62 openai_config.timeout = Some(timeout);
63 }
64
65 if let Some(organization) = config.get_extra_str("organization") {
66 openai_config.organization = Some(organization.to_string());
67 }
68
69 Ok(Box::new(OpenAiProvider::new(openai_config)))
70 }
71}
72
73pub fn register_global() {
78 llm_stack::ProviderRegistry::global().register(Box::new(OpenAiFactory));
79}
80
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Runs `OpenAiFactory::build` on `config` and returns the error.
    /// Panics if the build unexpectedly succeeds.
    fn expect_build_error(config: &ProviderConfig) -> LlmError {
        match OpenAiFactory.build(config) {
            Ok(_) => panic!("expected OpenAiFactory::build to fail"),
            Err(err) => err,
        }
    }

    #[test]
    fn test_factory_name() {
        assert_eq!(OpenAiFactory.name(), "openai");
    }

    #[test]
    fn test_factory_build_success() {
        let cfg = ProviderConfig::new("openai", "gpt-4o")
            .api_key("sk-test")
            .timeout(Duration::from_secs(30))
            .extra("organization", "org-123");

        let provider = OpenAiFactory.build(&cfg).unwrap();
        let meta = provider.metadata();
        assert_eq!(meta.name, "openai");
        assert_eq!(meta.model, "gpt-4o");
    }

    #[test]
    fn test_factory_missing_api_key() {
        // No `.api_key(...)` call: the factory must refuse to build.
        let cfg = ProviderConfig::new("openai", "gpt-4o");
        assert!(matches!(expect_build_error(&cfg), LlmError::InvalidRequest(_)));
    }

    #[test]
    fn test_factory_empty_model() {
        let cfg = ProviderConfig::new("openai", "").api_key("sk-test");
        assert!(matches!(expect_build_error(&cfg), LlmError::InvalidRequest(_)));
    }
}