1use crate::cleanroom::ServicePlugin;
7use crate::config::ServiceConfig;
8use crate::error::{CleanroomError, Result};
9use crate::services::{
10 generic::GenericContainerPlugin,
11 ollama::{OllamaConfig, OllamaPlugin},
12 surrealdb::SurrealDbPlugin,
13 tgi::{TgiConfig, TgiPlugin},
14 vllm::{VllmConfig, VllmPlugin},
15};
16
17pub struct ServiceFactory;
19
20impl ServiceFactory {
21 pub fn create_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
60 config.validate()?;
62
63 let service_type = config.plugin.to_lowercase();
65
66 match service_type.as_str() {
67 "surrealdb" => Self::create_surrealdb_plugin(name, config),
68 "generic_container" => Self::create_generic_plugin(name, config),
69 "ollama" => Self::create_ollama_plugin(name, config),
70 "tgi" => Self::create_tgi_plugin(name, config),
71 "vllm" => Self::create_vllm_plugin(name, config),
72 _ => Err(CleanroomError::configuration_error(format!(
73 "Unknown service type: '{}'. Supported types: surrealdb, generic_container, ollama, tgi, vllm",
74 config.plugin
75 ))),
76 }
77 }
78
79 fn create_surrealdb_plugin(
81 _name: &str,
82 config: &ServiceConfig,
83 ) -> Result<Box<dyn ServicePlugin>> {
84 let username = Self::get_env_or_config(config, "SURREALDB_USER", "username")
86 .unwrap_or_else(|| "root".to_string());
87
88 let password = Self::get_env_or_config(config, "SURREALDB_PASS", "password")
89 .unwrap_or_else(|| "root".to_string());
90
91 let strict = Self::get_config_bool(config, "strict").unwrap_or(false);
93
94 let plugin = SurrealDbPlugin::with_credentials(&username, &password).with_strict(strict);
96
97 Ok(Box::new(plugin))
98 }
99
100 fn create_generic_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
102 let image = config.image.as_ref().ok_or_else(|| {
104 CleanroomError::configuration_error(
105 "Generic container requires 'image' field in configuration",
106 )
107 })?;
108
109 let mut plugin = GenericContainerPlugin::new(name, image);
111
112 if let Some(ref env_vars) = config.env {
114 for (key, value) in env_vars.iter() {
115 plugin = plugin.with_env(key, value);
116 }
117 }
118
119 if let Some(ref ports) = config.ports {
121 for port in ports {
122 plugin = plugin.with_port(*port);
123 }
124 }
125
126 if let Some(ref volumes) = config.volumes {
128 for volume in volumes {
129 plugin = plugin
130 .with_volume(
131 &volume.host_path,
132 &volume.container_path,
133 volume.read_only.unwrap_or(false),
134 )
135 .map_err(|e| {
136 CleanroomError::configuration_error(format!(
137 "Invalid volume configuration: {}",
138 e
139 ))
140 })?;
141 }
142 }
143
144 Ok(Box::new(plugin))
145 }
146
147 fn create_ollama_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
149 let endpoint = Self::get_config_string(config, "endpoint").ok_or_else(|| {
151 CleanroomError::configuration_error(
152 "Ollama service requires 'endpoint' in env configuration",
153 )
154 })?;
155
156 let default_model = Self::get_config_string(config, "default_model")
158 .or_else(|| Self::get_config_string(config, "model"))
159 .ok_or_else(|| {
160 CleanroomError::configuration_error(
161 "Ollama service requires 'default_model' or 'model' in env configuration",
162 )
163 })?;
164
165 let timeout_seconds = Self::get_config_string(config, "timeout_seconds")
167 .and_then(|s| s.parse::<u64>().ok())
168 .unwrap_or(60);
169
170 let ollama_config = OllamaConfig {
171 endpoint,
172 default_model,
173 timeout_seconds,
174 };
175
176 let plugin = OllamaPlugin::new(name, ollama_config);
177 Ok(Box::new(plugin))
178 }
179
180 fn create_tgi_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
182 let endpoint = Self::get_config_string(config, "endpoint").ok_or_else(|| {
184 CleanroomError::configuration_error(
185 "TGI service requires 'endpoint' in env configuration",
186 )
187 })?;
188
189 let model_id = Self::get_config_string(config, "model_id")
191 .or_else(|| Self::get_config_string(config, "model"))
192 .ok_or_else(|| {
193 CleanroomError::configuration_error(
194 "TGI service requires 'model_id' or 'model' in env configuration",
195 )
196 })?;
197
198 let max_total_tokens =
200 Self::get_config_string(config, "max_total_tokens").and_then(|s| s.parse::<u32>().ok());
201
202 let max_input_length =
203 Self::get_config_string(config, "max_input_length").and_then(|s| s.parse::<u32>().ok());
204
205 let max_batch_prefill_tokens = Self::get_config_string(config, "max_batch_prefill_tokens")
206 .and_then(|s| s.parse::<u32>().ok());
207
208 let max_concurrent_requests = Self::get_config_string(config, "max_concurrent_requests")
209 .and_then(|s| s.parse::<u32>().ok());
210
211 let max_batch_total_tokens = Self::get_config_string(config, "max_batch_total_tokens")
212 .and_then(|s| s.parse::<u32>().ok());
213
214 let timeout_seconds = Self::get_config_string(config, "timeout_seconds")
215 .and_then(|s| s.parse::<u64>().ok())
216 .unwrap_or(60);
217
218 let tgi_config = TgiConfig {
219 endpoint,
220 model_id,
221 max_total_tokens,
222 max_input_length,
223 max_batch_prefill_tokens,
224 max_concurrent_requests,
225 max_batch_total_tokens,
226 timeout_seconds,
227 };
228
229 let plugin = TgiPlugin::new(name, tgi_config);
230 Ok(Box::new(plugin))
231 }
232
233 fn create_vllm_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
235 let endpoint = Self::get_config_string(config, "endpoint").ok_or_else(|| {
237 CleanroomError::configuration_error(
238 "vLLM service requires 'endpoint' in env configuration",
239 )
240 })?;
241
242 let model = Self::get_config_string(config, "model").ok_or_else(|| {
244 CleanroomError::configuration_error(
245 "vLLM service requires 'model' in env configuration",
246 )
247 })?;
248
249 let max_num_seqs =
251 Self::get_config_string(config, "max_num_seqs").and_then(|s| s.parse::<u32>().ok());
252
253 let max_model_len =
254 Self::get_config_string(config, "max_model_len").and_then(|s| s.parse::<u32>().ok());
255
256 let tensor_parallel_size = Self::get_config_string(config, "tensor_parallel_size")
257 .and_then(|s| s.parse::<u32>().ok());
258
259 let gpu_memory_utilization = Self::get_config_string(config, "gpu_memory_utilization")
260 .and_then(|s| s.parse::<f32>().ok());
261
262 let enable_prefix_caching = Self::get_config_bool(config, "enable_prefix_caching");
263
264 let timeout_seconds = Self::get_config_string(config, "timeout_seconds")
265 .and_then(|s| s.parse::<u64>().ok())
266 .unwrap_or(60);
267
268 let vllm_config = VllmConfig {
269 endpoint,
270 model,
271 max_num_seqs,
272 max_model_len,
273 tensor_parallel_size,
274 gpu_memory_utilization,
275 enable_prefix_caching,
276 timeout_seconds,
277 };
278
279 let plugin = VllmPlugin::new(name, vllm_config);
280 Ok(Box::new(plugin))
281 }
282
283 fn get_env_or_config(
287 config: &ServiceConfig,
288 env_var: &str,
289 config_key: &str,
290 ) -> Option<String> {
291 std::env::var(env_var)
293 .ok()
294 .or_else(|| {
296 config
297 .env
298 .as_ref()
299 .and_then(|env_map| env_map.get(config_key).cloned())
300 })
301 }
302
303 fn get_config_string(config: &ServiceConfig, key: &str) -> Option<String> {
305 config
306 .env
307 .as_ref()
308 .and_then(|env_map| env_map.get(key).cloned())
309 }
310
311 fn get_config_bool(config: &ServiceConfig, key: &str) -> Option<bool> {
313 Self::get_config_string(config, key).and_then(|s| s.parse::<bool>().ok())
314 }
315}