clnrm_core/services/
factory.rs

1//! Service factory for creating plugins from configuration
2//!
3//! Provides centralized plugin creation from TOML ServiceConfig,
4//! handling type-specific configuration and validation.
5
6use crate::cleanroom::ServicePlugin;
7use crate::config::ServiceConfig;
8use crate::error::{CleanroomError, Result};
9use crate::services::{
10    generic::GenericContainerPlugin,
11    ollama::{OllamaConfig, OllamaPlugin},
12    surrealdb::SurrealDbPlugin,
13    tgi::{TgiConfig, TgiPlugin},
14    vllm::{VllmConfig, VllmPlugin},
15};
16
/// Service factory for creating plugins from configuration.
///
/// This is a stateless namespace type: all functionality is exposed through
/// associated functions such as [`ServiceFactory::create_plugin`]. The
/// standard derives make it debuggable and cheap to pass around.
#[derive(Debug, Clone, Copy, Default)]
pub struct ServiceFactory;
19
20impl ServiceFactory {
21    /// Create a service plugin from configuration
22    ///
23    /// # Arguments
24    ///
25    /// * `name` - Service name identifier
26    /// * `config` - Service configuration from TOML
27    ///
28    /// # Returns
29    ///
30    /// A boxed `ServicePlugin` implementation matching the service type
31    ///
32    /// # Errors
33    ///
34    /// Returns error if:
35    /// - Service type is unknown or unsupported
36    /// - Required configuration fields are missing
37    /// - Configuration values are invalid
38    ///
39    /// # Example
40    ///
41    /// ```no_run
42    /// use clnrm_core::services::factory::ServiceFactory;
43    /// use clnrm_core::config::ServiceConfig;
44    /// use std::collections::HashMap;
45    ///
46    /// let mut config = ServiceConfig {
47    ///     r#type: "surrealdb".to_string(),
48    ///     plugin: "surrealdb".to_string(),
49    ///     image: Some("surrealdb/surrealdb:latest".to_string()),
50    ///     env: None,
51    ///     ports: None,
52    ///     volumes: None,
53    ///     health_check: None,
54    /// };
55    ///
56    /// let plugin = ServiceFactory::create_plugin("my_db", &config)?;
57    /// # Ok::<(), clnrm_core::error::CleanroomError>(())
58    /// ```
59    pub fn create_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
60        // Validate configuration before processing
61        config.validate()?;
62
63        // Determine service type from plugin field (normalized to lowercase)
64        let service_type = config.plugin.to_lowercase();
65
66        match service_type.as_str() {
67            "surrealdb" => Self::create_surrealdb_plugin(name, config),
68            "generic_container" => Self::create_generic_plugin(name, config),
69            "ollama" => Self::create_ollama_plugin(name, config),
70            "tgi" => Self::create_tgi_plugin(name, config),
71            "vllm" => Self::create_vllm_plugin(name, config),
72            _ => Err(CleanroomError::configuration_error(format!(
73                "Unknown service type: '{}'. Supported types: surrealdb, generic_container, ollama, tgi, vllm",
74                config.plugin
75            ))),
76        }
77    }
78
79    /// Create a SurrealDB plugin from configuration
80    fn create_surrealdb_plugin(
81        _name: &str,
82        config: &ServiceConfig,
83    ) -> Result<Box<dyn ServicePlugin>> {
84        // Extract credentials from environment variables or config
85        let username = Self::get_env_or_config(config, "SURREALDB_USER", "username")
86            .unwrap_or_else(|| "root".to_string());
87
88        let password = Self::get_env_or_config(config, "SURREALDB_PASS", "password")
89            .unwrap_or_else(|| "root".to_string());
90
91        // Extract strict mode flag (default: false)
92        let strict = Self::get_config_bool(config, "strict").unwrap_or(false);
93
94        // Create plugin with credentials
95        let plugin = SurrealDbPlugin::with_credentials(&username, &password).with_strict(strict);
96
97        Ok(Box::new(plugin))
98    }
99
100    /// Create a generic container plugin from configuration
101    fn create_generic_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
102        // Image is required for generic containers
103        let image = config.image.as_ref().ok_or_else(|| {
104            CleanroomError::configuration_error(
105                "Generic container requires 'image' field in configuration",
106            )
107        })?;
108
109        // Create base plugin
110        let mut plugin = GenericContainerPlugin::new(name, image);
111
112        // Add environment variables if present
113        if let Some(ref env_vars) = config.env {
114            for (key, value) in env_vars.iter() {
115                plugin = plugin.with_env(key, value);
116            }
117        }
118
119        // Add port mappings if present
120        if let Some(ref ports) = config.ports {
121            for port in ports {
122                plugin = plugin.with_port(*port);
123            }
124        }
125
126        // Add volume mounts if present
127        if let Some(ref volumes) = config.volumes {
128            for volume in volumes {
129                plugin = plugin
130                    .with_volume(
131                        &volume.host_path,
132                        &volume.container_path,
133                        volume.read_only.unwrap_or(false),
134                    )
135                    .map_err(|e| {
136                        CleanroomError::configuration_error(format!(
137                            "Invalid volume configuration: {}",
138                            e
139                        ))
140                    })?;
141            }
142        }
143
144        Ok(Box::new(plugin))
145    }
146
147    /// Create an Ollama plugin from configuration
148    fn create_ollama_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
149        // Extract endpoint (required)
150        let endpoint = Self::get_config_string(config, "endpoint").ok_or_else(|| {
151            CleanroomError::configuration_error(
152                "Ollama service requires 'endpoint' in env configuration",
153            )
154        })?;
155
156        // Extract default model (required)
157        let default_model = Self::get_config_string(config, "default_model")
158            .or_else(|| Self::get_config_string(config, "model"))
159            .ok_or_else(|| {
160                CleanroomError::configuration_error(
161                    "Ollama service requires 'default_model' or 'model' in env configuration",
162                )
163            })?;
164
165        // Extract timeout (optional, default: 60 seconds)
166        let timeout_seconds = Self::get_config_string(config, "timeout_seconds")
167            .and_then(|s| s.parse::<u64>().ok())
168            .unwrap_or(60);
169
170        let ollama_config = OllamaConfig {
171            endpoint,
172            default_model,
173            timeout_seconds,
174        };
175
176        let plugin = OllamaPlugin::new(name, ollama_config);
177        Ok(Box::new(plugin))
178    }
179
180    /// Create a TGI (Text Generation Inference) plugin from configuration
181    fn create_tgi_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
182        // Extract endpoint (required)
183        let endpoint = Self::get_config_string(config, "endpoint").ok_or_else(|| {
184            CleanroomError::configuration_error(
185                "TGI service requires 'endpoint' in env configuration",
186            )
187        })?;
188
189        // Extract model_id (required)
190        let model_id = Self::get_config_string(config, "model_id")
191            .or_else(|| Self::get_config_string(config, "model"))
192            .ok_or_else(|| {
193                CleanroomError::configuration_error(
194                    "TGI service requires 'model_id' or 'model' in env configuration",
195                )
196            })?;
197
198        // Extract optional configuration
199        let max_total_tokens =
200            Self::get_config_string(config, "max_total_tokens").and_then(|s| s.parse::<u32>().ok());
201
202        let max_input_length =
203            Self::get_config_string(config, "max_input_length").and_then(|s| s.parse::<u32>().ok());
204
205        let max_batch_prefill_tokens = Self::get_config_string(config, "max_batch_prefill_tokens")
206            .and_then(|s| s.parse::<u32>().ok());
207
208        let max_concurrent_requests = Self::get_config_string(config, "max_concurrent_requests")
209            .and_then(|s| s.parse::<u32>().ok());
210
211        let max_batch_total_tokens = Self::get_config_string(config, "max_batch_total_tokens")
212            .and_then(|s| s.parse::<u32>().ok());
213
214        let timeout_seconds = Self::get_config_string(config, "timeout_seconds")
215            .and_then(|s| s.parse::<u64>().ok())
216            .unwrap_or(60);
217
218        let tgi_config = TgiConfig {
219            endpoint,
220            model_id,
221            max_total_tokens,
222            max_input_length,
223            max_batch_prefill_tokens,
224            max_concurrent_requests,
225            max_batch_total_tokens,
226            timeout_seconds,
227        };
228
229        let plugin = TgiPlugin::new(name, tgi_config);
230        Ok(Box::new(plugin))
231    }
232
233    /// Create a vLLM plugin from configuration
234    fn create_vllm_plugin(name: &str, config: &ServiceConfig) -> Result<Box<dyn ServicePlugin>> {
235        // Extract endpoint (required)
236        let endpoint = Self::get_config_string(config, "endpoint").ok_or_else(|| {
237            CleanroomError::configuration_error(
238                "vLLM service requires 'endpoint' in env configuration",
239            )
240        })?;
241
242        // Extract model (required)
243        let model = Self::get_config_string(config, "model").ok_or_else(|| {
244            CleanroomError::configuration_error(
245                "vLLM service requires 'model' in env configuration",
246            )
247        })?;
248
249        // Extract optional configuration
250        let max_num_seqs =
251            Self::get_config_string(config, "max_num_seqs").and_then(|s| s.parse::<u32>().ok());
252
253        let max_model_len =
254            Self::get_config_string(config, "max_model_len").and_then(|s| s.parse::<u32>().ok());
255
256        let tensor_parallel_size = Self::get_config_string(config, "tensor_parallel_size")
257            .and_then(|s| s.parse::<u32>().ok());
258
259        let gpu_memory_utilization = Self::get_config_string(config, "gpu_memory_utilization")
260            .and_then(|s| s.parse::<f32>().ok());
261
262        let enable_prefix_caching = Self::get_config_bool(config, "enable_prefix_caching");
263
264        let timeout_seconds = Self::get_config_string(config, "timeout_seconds")
265            .and_then(|s| s.parse::<u64>().ok())
266            .unwrap_or(60);
267
268        let vllm_config = VllmConfig {
269            endpoint,
270            model,
271            max_num_seqs,
272            max_model_len,
273            tensor_parallel_size,
274            gpu_memory_utilization,
275            enable_prefix_caching,
276            timeout_seconds,
277        };
278
279        let plugin = VllmPlugin::new(name, vllm_config);
280        Ok(Box::new(plugin))
281    }
282
283    // Helper functions for extracting configuration values
284
285    /// Get value from environment variable or config env map
286    fn get_env_or_config(
287        config: &ServiceConfig,
288        env_var: &str,
289        config_key: &str,
290    ) -> Option<String> {
291        // First try environment variable
292        std::env::var(env_var)
293            .ok()
294            // Then try config env map
295            .or_else(|| {
296                config
297                    .env
298                    .as_ref()
299                    .and_then(|env_map| env_map.get(config_key).cloned())
300            })
301    }
302
303    /// Get string value from config env map
304    fn get_config_string(config: &ServiceConfig, key: &str) -> Option<String> {
305        config
306            .env
307            .as_ref()
308            .and_then(|env_map| env_map.get(key).cloned())
309    }
310
311    /// Get boolean value from config env map
312    fn get_config_bool(config: &ServiceConfig, key: &str) -> Option<bool> {
313        Self::get_config_string(config, key).and_then(|s| s.parse::<bool>().ok())
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a `ServiceConfig` with only type, plugin and image populated;
    /// every optional section starts out empty.
    fn base_config(service_type: &str, plugin: &str, image: Option<&str>) -> ServiceConfig {
        ServiceConfig {
            r#type: service_type.to_string(),
            plugin: plugin.to_string(),
            image: image.map(|img| img.to_string()),
            env: None,
            ports: None,
            volumes: None,
            health_check: None,
            username: None,
            password: None,
            strict: None,
        }
    }

    /// Turn borrowed `(key, value)` pairs into an owned env map.
    fn env_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_create_surrealdb_plugin() -> Result<()> {
        let cfg = base_config("database", "surrealdb", Some("surrealdb/surrealdb:latest"));

        let created = ServiceFactory::create_plugin("test_db", &cfg)?;
        assert_eq!(created.name(), "surrealdb");
        Ok(())
    }

    #[test]
    fn test_create_surrealdb_plugin_with_credentials() -> Result<()> {
        let cfg = ServiceConfig {
            env: Some(env_map(&[
                ("username", "admin"),
                ("password", "secret"),
                ("strict", "true"),
            ])),
            ..base_config("database", "surrealdb", Some("surrealdb/surrealdb:latest"))
        };

        let created = ServiceFactory::create_plugin("test_db", &cfg)?;
        assert_eq!(created.name(), "surrealdb");
        Ok(())
    }

    #[test]
    fn test_create_generic_plugin() -> Result<()> {
        let cfg = base_config("container", "generic_container", Some("alpine:latest"));

        let created = ServiceFactory::create_plugin("test_container", &cfg)?;
        assert_eq!(created.name(), "test_container");
        Ok(())
    }

    #[test]
    fn test_create_generic_plugin_with_env_and_ports() -> Result<()> {
        let cfg = ServiceConfig {
            env: Some(env_map(&[("KEY1", "value1"), ("KEY2", "value2")])),
            ports: Some(vec![8080, 8443]),
            ..base_config("container", "generic_container", Some("nginx:latest"))
        };

        let created = ServiceFactory::create_plugin("nginx", &cfg)?;
        assert_eq!(created.name(), "nginx");
        Ok(())
    }

    #[test]
    fn test_create_ollama_plugin() -> Result<()> {
        let cfg = ServiceConfig {
            env: Some(env_map(&[
                ("endpoint", "http://localhost:11434"),
                ("default_model", "llama2"),
                ("timeout_seconds", "120"),
            ])),
            ..base_config("ollama", "ollama", None)
        };

        let created = ServiceFactory::create_plugin("ollama_service", &cfg)?;
        assert_eq!(created.name(), "ollama_service");
        Ok(())
    }

    #[test]
    fn test_create_tgi_plugin() -> Result<()> {
        let cfg = ServiceConfig {
            env: Some(env_map(&[
                ("endpoint", "http://localhost:8080"),
                ("model_id", "microsoft/DialoGPT-medium"),
                ("max_total_tokens", "2048"),
            ])),
            ..base_config("network_service", "tgi", None)
        };

        let created = ServiceFactory::create_plugin("tgi_service", &cfg)?;
        assert_eq!(created.name(), "tgi_service");
        Ok(())
    }

    #[test]
    fn test_create_vllm_plugin() -> Result<()> {
        let cfg = ServiceConfig {
            env: Some(env_map(&[
                ("endpoint", "http://localhost:8000"),
                ("model", "facebook/opt-125m"),
                ("max_num_seqs", "100"),
            ])),
            ..base_config("network_service", "vllm", None)
        };

        let created = ServiceFactory::create_plugin("vllm_service", &cfg)?;
        assert_eq!(created.name(), "vllm_service");
        Ok(())
    }

    #[test]
    fn test_unknown_service_type_returns_error() {
        let cfg = base_config("unknown", "unknown_plugin", Some("some:image"));

        match ServiceFactory::create_plugin("test", &cfg) {
            Ok(_) => panic!("expected an error for an unknown plugin type"),
            Err(err) => assert!(err.message.contains("Unknown service type")),
        }
    }

    #[test]
    fn test_generic_container_without_image_returns_error() {
        // `image` is deliberately left unset: it is required for this plugin.
        let cfg = base_config("container", "generic_container", None);

        assert!(ServiceFactory::create_plugin("test", &cfg).is_err());
    }

    #[test]
    fn test_ollama_without_endpoint_returns_error() {
        // Only the model is supplied; the required endpoint is missing.
        let cfg = ServiceConfig {
            env: Some(env_map(&[("default_model", "llama2")])),
            ..base_config("ollama", "ollama", None)
        };

        match ServiceFactory::create_plugin("test", &cfg) {
            Ok(_) => panic!("expected an error when 'endpoint' is missing"),
            Err(err) => assert!(err.message.contains("endpoint")),
        }
    }

    #[test]
    fn test_case_insensitive_plugin_type() -> Result<()> {
        let cfg = base_config("database", "SurrealDB", Some("surrealdb/surrealdb:latest"));

        let created = ServiceFactory::create_plugin("test_db", &cfg)?;
        assert_eq!(created.name(), "surrealdb");
        Ok(())
    }
}