1use crate::cleanroom::ServicePlugin;
7use crate::config::ServiceConfig;
8use crate::error::{CleanroomError, Result};
9use crate::services::{
10 generic::GenericContainerPlugin,
11 ollama::{OllamaConfig, OllamaPlugin},
12 surrealdb::SurrealDbPlugin,
13 tgi::{TgiConfig, TgiPlugin},
14 vllm::{VllmConfig, VllmPlugin},
15};
16
/// Stateless factory that turns a [`ServiceConfig`] into a boxed [`ServicePlugin`].
///
/// The type carries no data; its functionality is exposed through associated functions
/// such as `ServiceFactory::create_plugin`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ServiceFactory;
19
20impl ServiceFactory {
21 pub fn create_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
60 config.validate()?;
62
63 let service_type = config.plugin.to_lowercase();
65
66 match service_type.as_str() {
67 "surrealdb" => Self::create_surrealdb_plugin(name, config),
68 "generic_container" => Self::create_generic_plugin(name, config),
69 "ollama" => Self::create_ollama_plugin(name, config),
70 "tgi" => Self::create_tgi_plugin(name, config),
71 "vllm" => Self::create_vllm_plugin(name, config),
72 _ => Err(CleanroomError::configuration_error(format!(
73 "Unknown service type: '{}'. Supported types: surrealdb, generic_container, ollama, tgi, vllm",
74 config.plugin
75 ))),
76 }
77 }
78
79 fn create_surrealdb_plugin(
81 _name: &str,
82 config: &ServiceConfig,
83 ) -> Result<Box<dyn ServicePlugin>> {
84 let username = Self::get_env_or_config(config, "SURREALDB_USER", "username")
86 .unwrap_or_else(|| "root".to_string());
87
88 let password = Self::get_env_or_config(config, "SURREALDB_PASS", "password")
89 .unwrap_or_else(|| "root".to_string());
90
91 let strict = Self::get_config_bool(config, "strict").unwrap_or(false);
93
94 let plugin = SurrealDbPlugin::with_credentials(&username, &password).with_strict(strict);
96
97 Ok(Box::new(plugin))
98 }
99
100 fn create_generic_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
102 let image = config.image.as_ref().ok_or_else(|| {
104 CleanroomError::configuration_error(
105 "Generic container requires 'image' field in configuration",
106 )
107 })?;
108
109 let mut plugin = GenericContainerPlugin::new(name, image);
111
112 if let Some(ref env_vars) = config.env {
114 for (key, value) in env_vars.iter() {
115 plugin = plugin.with_env(key, value);
116 }
117 }
118
119 if let Some(ref ports) = config.ports {
121 for port in ports {
122 plugin = plugin.with_port(*port);
123 }
124 }
125
126 if let Some(ref volumes) = config.volumes {
128 for volume in volumes {
129 plugin = plugin
130 .with_volume(
131 &volume.host_path,
132 &volume.container_path,
133 volume.read_only.unwrap_or(false),
134 )
135 .map_err(|e| {
136 CleanroomError::configuration_error(format!(
137 "Invalid volume configuration: {}",
138 e
139 ))
140 })?;
141 }
142 }
143
144 Ok(Box::new(plugin))
145 }
146
147 fn create_ollama_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
149 let endpoint = Self::get_config_string(config, "endpoint").ok_or_else(|| {
151 CleanroomError::configuration_error(
152 "Ollama service requires 'endpoint' in env configuration",
153 )
154 })?;
155
156 let default_model = Self::get_config_string(config, "default_model")
158 .or_else(|| Self::get_config_string(config, "model"))
159 .ok_or_else(|| {
160 CleanroomError::configuration_error(
161 "Ollama service requires 'default_model' or 'model' in env configuration",
162 )
163 })?;
164
165 let timeout_seconds = Self::get_config_string(config, "timeout_seconds")
167 .and_then(|s| s.parse::<u64>().ok())
168 .unwrap_or(60);
169
170 let ollama_config = OllamaConfig {
171 endpoint,
172 default_model,
173 timeout_seconds,
174 };
175
176 let plugin = OllamaPlugin::new(name, ollama_config);
177 Ok(Box::new(plugin))
178 }
179
180 fn create_tgi_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
182 let endpoint = Self::get_config_string(config, "endpoint").ok_or_else(|| {
184 CleanroomError::configuration_error(
185 "TGI service requires 'endpoint' in env configuration",
186 )
187 })?;
188
189 let model_id = Self::get_config_string(config, "model_id")
191 .or_else(|| Self::get_config_string(config, "model"))
192 .ok_or_else(|| {
193 CleanroomError::configuration_error(
194 "TGI service requires 'model_id' or 'model' in env configuration",
195 )
196 })?;
197
198 let max_total_tokens =
200 Self::get_config_string(config, "max_total_tokens").and_then(|s| s.parse::<u32>().ok());
201
202 let max_input_length =
203 Self::get_config_string(config, "max_input_length").and_then(|s| s.parse::<u32>().ok());
204
205 let max_batch_prefill_tokens = Self::get_config_string(config, "max_batch_prefill_tokens")
206 .and_then(|s| s.parse::<u32>().ok());
207
208 let max_concurrent_requests = Self::get_config_string(config, "max_concurrent_requests")
209 .and_then(|s| s.parse::<u32>().ok());
210
211 let max_batch_total_tokens = Self::get_config_string(config, "max_batch_total_tokens")
212 .and_then(|s| s.parse::<u32>().ok());
213
214 let timeout_seconds = Self::get_config_string(config, "timeout_seconds")
215 .and_then(|s| s.parse::<u64>().ok())
216 .unwrap_or(60);
217
218 let tgi_config = TgiConfig {
219 endpoint,
220 model_id,
221 max_total_tokens,
222 max_input_length,
223 max_batch_prefill_tokens,
224 max_concurrent_requests,
225 max_batch_total_tokens,
226 timeout_seconds,
227 };
228
229 let plugin = TgiPlugin::new(name, tgi_config);
230 Ok(Box::new(plugin))
231 }
232
233 fn create_vllm_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
235 let endpoint = Self::get_config_string(config, "endpoint").ok_or_else(|| {
237 CleanroomError::configuration_error(
238 "vLLM service requires 'endpoint' in env configuration",
239 )
240 })?;
241
242 let model = Self::get_config_string(config, "model").ok_or_else(|| {
244 CleanroomError::configuration_error(
245 "vLLM service requires 'model' in env configuration",
246 )
247 })?;
248
249 let max_num_seqs =
251 Self::get_config_string(config, "max_num_seqs").and_then(|s| s.parse::<u32>().ok());
252
253 let max_model_len =
254 Self::get_config_string(config, "max_model_len").and_then(|s| s.parse::<u32>().ok());
255
256 let tensor_parallel_size = Self::get_config_string(config, "tensor_parallel_size")
257 .and_then(|s| s.parse::<u32>().ok());
258
259 let gpu_memory_utilization = Self::get_config_string(config, "gpu_memory_utilization")
260 .and_then(|s| s.parse::<f32>().ok());
261
262 let enable_prefix_caching = Self::get_config_bool(config, "enable_prefix_caching");
263
264 let timeout_seconds = Self::get_config_string(config, "timeout_seconds")
265 .and_then(|s| s.parse::<u64>().ok())
266 .unwrap_or(60);
267
268 let vllm_config = VllmConfig {
269 endpoint,
270 model,
271 max_num_seqs,
272 max_model_len,
273 tensor_parallel_size,
274 gpu_memory_utilization,
275 enable_prefix_caching,
276 timeout_seconds,
277 };
278
279 let plugin = VllmPlugin::new(name, vllm_config);
280 Ok(Box::new(plugin))
281 }
282
283 fn get_env_or_config(
287 config: &ServiceConfig,
288 env_var: &str,
289 config_key: &str,
290 ) -> Option<String> {
291 std::env::var(env_var)
293 .ok()
294 .or_else(|| {
296 config
297 .env
298 .as_ref()
299 .and_then(|env_map| env_map.get(config_key).cloned())
300 })
301 }
302
303 fn get_config_string(config: &ServiceConfig, key: &str) -> Option<String> {
305 config
306 .env
307 .as_ref()
308 .and_then(|env_map| env_map.get(key).cloned())
309 }
310
311 fn get_config_bool(config: &ServiceConfig, key: &str) -> Option<bool> {
313 Self::get_config_string(config, key).and_then(|s| s.parse::<bool>().ok())
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Assemble a `ServiceConfig` whose remaining optional fields are all `None`.
    fn service_config(
        kind: &str,
        plugin: &str,
        image: Option<&str>,
        env: Option<HashMap<String, String>>,
    ) -> ServiceConfig {
        ServiceConfig {
            r#type: kind.to_string(),
            plugin: plugin.to_string(),
            image: image.map(|img| img.to_string()),
            env,
            ports: None,
            volumes: None,
            health_check: None,
            username: None,
            password: None,
            strict: None,
        }
    }

    /// Turn borrowed `(key, value)` pairs into an owned env map.
    fn env_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn test_create_surrealdb_plugin() -> Result<()> {
        let cfg = service_config(
            "database",
            "surrealdb",
            Some("surrealdb/surrealdb:latest"),
            None,
        );

        let built = ServiceFactory::create_plugin("test_db", &cfg)?;
        assert_eq!(built.name(), "surrealdb");
        Ok(())
    }

    #[test]
    fn test_create_surrealdb_plugin_with_credentials() -> Result<()> {
        let vars = env_map(&[
            ("username", "admin"),
            ("password", "secret"),
            ("strict", "true"),
        ]);
        let cfg = service_config(
            "database",
            "surrealdb",
            Some("surrealdb/surrealdb:latest"),
            Some(vars),
        );

        let built = ServiceFactory::create_plugin("test_db", &cfg)?;
        assert_eq!(built.name(), "surrealdb");
        Ok(())
    }

    #[test]
    fn test_create_generic_plugin() -> Result<()> {
        let cfg = service_config("container", "generic_container", Some("alpine:latest"), None);

        let built = ServiceFactory::create_plugin("test_container", &cfg)?;
        assert_eq!(built.name(), "test_container");
        Ok(())
    }

    #[test]
    fn test_create_generic_plugin_with_env_and_ports() -> Result<()> {
        let vars = env_map(&[("KEY1", "value1"), ("KEY2", "value2")]);
        let mut cfg = service_config(
            "container",
            "generic_container",
            Some("nginx:latest"),
            Some(vars),
        );
        cfg.ports = Some(vec![8080, 8443]);

        let built = ServiceFactory::create_plugin("nginx", &cfg)?;
        assert_eq!(built.name(), "nginx");
        Ok(())
    }

    #[test]
    fn test_create_ollama_plugin() -> Result<()> {
        let vars = env_map(&[
            ("endpoint", "http://localhost:11434"),
            ("default_model", "llama2"),
            ("timeout_seconds", "120"),
        ]);
        let cfg = service_config("ollama", "ollama", None, Some(vars));

        let built = ServiceFactory::create_plugin("ollama_service", &cfg)?;
        assert_eq!(built.name(), "ollama_service");
        Ok(())
    }

    #[test]
    fn test_create_tgi_plugin() -> Result<()> {
        let vars = env_map(&[
            ("endpoint", "http://localhost:8080"),
            ("model_id", "microsoft/DialoGPT-medium"),
            ("max_total_tokens", "2048"),
        ]);
        let cfg = service_config("network_service", "tgi", None, Some(vars));

        let built = ServiceFactory::create_plugin("tgi_service", &cfg)?;
        assert_eq!(built.name(), "tgi_service");
        Ok(())
    }

    #[test]
    fn test_create_vllm_plugin() -> Result<()> {
        let vars = env_map(&[
            ("endpoint", "http://localhost:8000"),
            ("model", "facebook/opt-125m"),
            ("max_num_seqs", "100"),
        ]);
        let cfg = service_config("network_service", "vllm", None, Some(vars));

        let built = ServiceFactory::create_plugin("vllm_service", &cfg)?;
        assert_eq!(built.name(), "vllm_service");
        Ok(())
    }

    #[test]
    fn test_unknown_service_type_returns_error() {
        let cfg = service_config("unknown", "unknown_plugin", Some("some:image"), None);

        let outcome = ServiceFactory::create_plugin("test", &cfg);
        assert!(outcome.is_err());

        if let Err(err) = outcome {
            assert!(err.message.contains("Unknown service type"));
        }
    }

    #[test]
    fn test_generic_container_without_image_returns_error() {
        let cfg = service_config("container", "generic_container", None, None);

        let outcome = ServiceFactory::create_plugin("test", &cfg);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_ollama_without_endpoint_returns_error() {
        let vars = env_map(&[("default_model", "llama2")]);
        let cfg = service_config("ollama", "ollama", None, Some(vars));

        let outcome = ServiceFactory::create_plugin("test", &cfg);
        assert!(outcome.is_err());

        if let Err(err) = outcome {
            assert!(err.message.contains("endpoint"));
        }
    }

    #[test]
    fn test_case_insensitive_plugin_type() -> Result<()> {
        let cfg = service_config(
            "database",
            "SurrealDB",
            Some("surrealdb/surrealdb:latest"),
            None,
        );

        let built = ServiceFactory::create_plugin("test_db", &cfg)?;
        assert_eq!(built.name(), "surrealdb");
        Ok(())
    }
}