Skip to main content

kapsl_backends/
factory.rs

1use crate::onnx::{ExecutionProvider, OnnxBackend, OnnxBackendBuilder};
2use kapsl_core::loader::Manifest;
3use kapsl_core::HardwareRequirements;
4use kapsl_engine_api::Engine;
5use kapsl_hal::device::DeviceInfo;
6use kapsl_llm::llm_backend::LLMBackend;
7use kapsl_llm::GgufBackend;
8#[cfg(target_os = "windows")]
9use ort::execution_providers::DirectMLExecutionProvider;
10use ort::execution_providers::ExecutionProvider as _;
11use ort::execution_providers::{
12    CUDAExecutionProvider, CoreMLExecutionProvider, OpenVINOExecutionProvider,
13    ROCmExecutionProvider, TensorRTExecutionProvider,
14};
15use ort::session::builder::GraphOptimizationLevel;
16
/// Stateless entry point for backend construction.
///
/// All functionality lives in associated functions, so the type carries no
/// data; the derives make it usable in debug output and freely copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct BackendFactory;
18
/// Optional ONNX Runtime tuning knobs.
///
/// Every field defaults to `None`, meaning "leave the builder's own setting
/// untouched"; only fields that are `Some` are forwarded to the ONNX backend
/// builder.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OnnxRuntimeTuning {
    /// Forwarded to the builder's `with_memory_pattern` when set.
    pub memory_pattern: Option<bool>,
    /// Forwarded to the builder's `with_disable_cpu_mem_arena` when set.
    pub disable_cpu_mem_arena: Option<bool>,
    /// Forwarded to the builder's `with_max_bucket_sessions` when set.
    pub session_buckets: Option<usize>,
    /// Forwarded to the builder's `with_bucket_dim_granularity` when set.
    pub bucket_dim_granularity: Option<usize>,
    /// Forwarded to the builder's `with_bucket_max_dims` when set.
    pub bucket_max_dims: Option<usize>,
    /// Forwarded to the builder's `with_peak_concurrency_hint` when set.
    pub peak_concurrency_hint: Option<u32>,
}
28
29pub fn parse_optimization_level(level: Option<&String>) -> Result<GraphOptimizationLevel, String> {
30    match level.as_ref().map(|s| s.as_str()) {
31        Some("disable") | Some("0") => Ok(GraphOptimizationLevel::Disable),
32        Some("basic") | Some("1") => Ok(GraphOptimizationLevel::Level1),
33        Some("extended") | Some("2") => Ok(GraphOptimizationLevel::Level2),
34        Some("all") | Some("3") | Some("99") | None => Ok(GraphOptimizationLevel::Level3),
35        _ => Err("Unknown optimization level".to_string()),
36    }
37}
38
39impl BackendFactory {
40    fn apply_onnx_tuning(
41        mut builder: OnnxBackendBuilder,
42        tuning: &OnnxRuntimeTuning,
43    ) -> OnnxBackendBuilder {
44        if let Some(v) = tuning.memory_pattern {
45            builder = builder.with_memory_pattern(v);
46        }
47        if let Some(v) = tuning.disable_cpu_mem_arena {
48            builder = builder.with_disable_cpu_mem_arena(v);
49        }
50        if let Some(v) = tuning.session_buckets {
51            builder = builder.with_max_bucket_sessions(v);
52        }
53        if let Some(v) = tuning.bucket_dim_granularity {
54            builder = builder.with_bucket_dim_granularity(v);
55        }
56        if let Some(v) = tuning.bucket_max_dims {
57            builder = builder.with_bucket_max_dims(v);
58        }
59        if let Some(v) = tuning.peak_concurrency_hint {
60            builder = builder.with_peak_concurrency_hint(v);
61        }
62        builder
63    }
64
65    fn build_onnx_backend(
66        provider: ExecutionProvider,
67        opt_level: GraphOptimizationLevel,
68        device_id: i32,
69        tuning: &OnnxRuntimeTuning,
70    ) -> Result<Box<dyn Engine>, String> {
71        let mut builder = OnnxBackend::builder()
72            .with_provider(provider)
73            .with_optimization_level(opt_level);
74        if !matches!(provider, ExecutionProvider::CPU) {
75            builder = builder.with_device_id(device_id)?;
76        }
77        builder = Self::apply_onnx_tuning(builder, tuning);
78        Ok(Box::new(builder.build()))
79    }
80
81    fn push_unique_provider(providers: &mut Vec<String>, provider: &str) {
82        if providers
83            .iter()
84            .any(|candidate| candidate.eq_ignore_ascii_case(provider))
85        {
86            return;
87        }
88        providers.push(provider.to_string());
89    }
90
91    fn provider_policy() -> String {
92        std::env::var("KAPSL_PROVIDER_POLICY")
93            .or_else(|_| std::env::var("KAPSL_PROVIDER_POLICY"))
94            .unwrap_or_else(|_| "fastest".to_string())
95            .trim()
96            .to_ascii_lowercase()
97    }
98
99    fn should_append_fastest_candidates(providers: &[String]) -> bool {
100        if Self::provider_policy() == "manifest" {
101            return false;
102        }
103
104        providers.is_empty()
105            || providers
106                .iter()
107                .all(|provider| matches!(provider.trim().to_ascii_lowercase().as_str(), "" | "cpu"))
108    }
109
110    fn append_fastest_candidates(device_info: &DeviceInfo, providers: &mut Vec<String>) {
111        if device_info.has_cuda {
112            Self::push_unique_provider(providers, "tensorrt");
113            Self::push_unique_provider(providers, "cuda");
114        }
115        if device_info.has_metal {
116            Self::push_unique_provider(providers, "coreml");
117        }
118        if device_info.has_rocm {
119            Self::push_unique_provider(providers, "rocm");
120        }
121        if device_info.has_directml {
122            Self::push_unique_provider(providers, "directml");
123        }
124        Self::push_unique_provider(providers, "cpu");
125    }
126
127    /// Create the best available backend based on manifest requirements and available hardware
128    pub fn create_best_backend(
129        manifest: &Manifest,
130        device_info: &DeviceInfo,
131    ) -> Result<Box<dyn Engine>, String> {
132        Self::create_best_backend_with_tuning(manifest, device_info, &OnnxRuntimeTuning::default())
133    }
134
135    pub fn create_best_backend_with_tuning(
136        manifest: &Manifest,
137        device_info: &DeviceInfo,
138        tuning: &OnnxRuntimeTuning,
139    ) -> Result<Box<dyn Engine>, String> {
140        // GGUF: route to the llama.cpp-backed GgufBackend
141        if manifest.framework == "gguf" {
142            log::info!("✓ Using GgufBackend (llama.cpp)");
143            return Ok(Box::new(GgufBackend::new()));
144        }
145
146        // Check for LLM framework
147        if manifest.framework == "llm" {
148            let requirements = &manifest.hardware_requirements;
149            if Self::provider_policy() == "manifest" {
150                if let Some(provider) = requirements.preferred_provider.clone() {
151                    let device_id = requirements.device_id.unwrap_or(0);
152                    log::info!(
153                        "✓ Using LLMBackend with manifest provider override: {}",
154                        provider
155                    );
156                    return Ok(Box::new(LLMBackend::with_device(provider, device_id)));
157                }
158            }
159            log::info!("✓ Using LLMBackend with runtime fastest-provider selection");
160            return Ok(Box::new(LLMBackend::new()));
161        }
162
163        let requirements = &manifest.hardware_requirements;
164
165        log::info!("🔍 Selecting backend based on requirements:");
166        log::info!("   Preferred: {:?}", requirements.preferred_provider);
167        log::info!("   Fallbacks: {:?}", requirements.fallback_providers);
168        log::info!(
169            "   Graph Optimization: {:?}",
170            requirements.graph_optimization_level
171        );
172
173        // Parse and validate optimization level early (fail-fast)
174        let opt_level = parse_optimization_level(requirements.graph_optimization_level.as_ref())
175            .map_err(|e| format!("Invalid graph optimization level in manifest: {}", e))?;
176
177        log::info!("   Graph Optimization: {:?}", opt_level);
178
179        let mut providers_to_try = Vec::new();
180        if let Some(preferred) = &requirements.preferred_provider {
181            Self::push_unique_provider(&mut providers_to_try, preferred);
182        }
183        for provider in &requirements.fallback_providers {
184            Self::push_unique_provider(&mut providers_to_try, provider);
185        }
186
187        if Self::should_append_fastest_candidates(&providers_to_try) {
188            log::info!("⚡ Provider policy `fastest`: appending hardware-accelerated providers");
189            Self::append_fastest_candidates(device_info, &mut providers_to_try);
190        }
191
192        for provider in &providers_to_try {
193            let device_id = requirements.device_id.unwrap_or(0);
194            match Self::try_create_provider(provider, device_info, opt_level, device_id, tuning) {
195                Ok(backend) => {
196                    log::info!("✓ Using provider: {}", provider);
197                    return Ok(backend);
198                }
199                Err(err) => {
200                    log::warn!("⚠ Provider '{}' not available: {}", provider, err);
201                }
202            }
203        }
204
205        // Last resort: CPU
206        log::info!("⚠ Using last-resort CPU backend");
207        let opt_cpu = parse_optimization_level(requirements.graph_optimization_level.as_ref())
208            .unwrap_or(GraphOptimizationLevel::Level3);
209        Self::build_onnx_backend(ExecutionProvider::CPU, opt_cpu, 0, tuning)
210    }
211
212    /// Create a backend for a specific device
213    pub fn create_backend_for_device(
214        manifest: &Manifest,
215        provider: &str,
216        device_id: usize,
217        device_info: &DeviceInfo,
218    ) -> Result<Box<dyn Engine>, String> {
219        Self::create_backend_for_device_with_tuning(
220            manifest,
221            provider,
222            device_id,
223            device_info,
224            &OnnxRuntimeTuning::default(),
225        )
226    }
227
228    pub fn create_backend_for_device_with_tuning(
229        manifest: &Manifest,
230        provider: &str,
231        device_id: usize,
232        device_info: &DeviceInfo,
233        tuning: &OnnxRuntimeTuning,
234    ) -> Result<Box<dyn Engine>, String> {
235        // GGUF: route to the llama.cpp-backed GgufBackend
236        if manifest.framework == "gguf" {
237            log::info!("✓ Using GgufBackend (llama.cpp)");
238            return Ok(Box::new(GgufBackend::new()));
239        }
240
241        // Check for LLM framework
242        if manifest.framework == "llm" {
243            if Self::provider_policy() == "manifest" {
244                log::info!(
245                    "✓ Using LLMBackend with manifest provider override: {}",
246                    provider
247                );
248                return Ok(Box::new(LLMBackend::with_device(
249                    provider.to_string(),
250                    device_id as i32,
251                )));
252            }
253
254            log::info!(
255                "✓ Using LLMBackend with device pinning and runtime provider auto-selection"
256            );
257            return Ok(Box::new(LLMBackend::with_device_id(device_id as i32)));
258        }
259
260        let requirements = &manifest.hardware_requirements;
261        let opt_level = parse_optimization_level(requirements.graph_optimization_level.as_ref())
262            .map_err(|e| format!("Invalid graph optimization level in manifest: {}", e))?;
263
264        Self::try_create_provider(provider, device_info, opt_level, device_id as i32, tuning)
265    }
266
267    fn try_create_provider(
268        provider: &str,
269        device_info: &DeviceInfo,
270        opt_level: GraphOptimizationLevel,
271        device_id: i32,
272        tuning: &OnnxRuntimeTuning,
273    ) -> Result<Box<dyn Engine>, String> {
274        let provider_lower = provider.to_lowercase();
275
276        match provider_lower.as_str() {
277            "cuda" => {
278                if !device_info.has_cuda {
279                    return Err("CUDA not available on this system".to_string());
280                }
281                if !CUDAExecutionProvider::default()
282                    .is_available()
283                    .unwrap_or(false)
284                {
285                    return Err(
286                        "CUDA execution provider is not available in ONNX Runtime".to_string()
287                    );
288                }
289                let cuda_version = device_info
290                    .devices
291                    .iter()
292                    .find(|d| matches!(d.backend, kapsl_hal::device::DeviceBackend::Cuda))
293                    .and_then(|d| d.cuda_version.as_ref())
294                    .map(|s| s.as_str())
295                    .unwrap_or("unknown");
296                log::info!("   CUDA available: version {:?}", cuda_version);
297                Self::build_onnx_backend(ExecutionProvider::CUDA, opt_level, device_id, tuning)
298            }
299
300            "tensorrt" => {
301                if !device_info.has_cuda {
302                    return Err("TensorRT requires CUDA-capable GPU".to_string());
303                }
304                if !TensorRTExecutionProvider::default()
305                    .is_available()
306                    .unwrap_or(false)
307                {
308                    return Err(
309                        "TensorRT execution provider is not available in ONNX Runtime".to_string(),
310                    );
311                }
312                log::info!("   TensorRT requested (requires CUDA)");
313                Self::build_onnx_backend(ExecutionProvider::TensorRT, opt_level, device_id, tuning)
314            }
315
316            "metal" | "coreml" => {
317                if !device_info.has_metal {
318                    return Err(format!(
319                        "{} not available on this system",
320                        if provider_lower == "metal" {
321                            "Metal"
322                        } else {
323                            "CoreML"
324                        }
325                    ));
326                }
327                if !CoreMLExecutionProvider::default()
328                    .is_available()
329                    .unwrap_or(false)
330                {
331                    return Err("CoreML execution provider is not available".to_string());
332                }
333                if provider_lower == "metal" {
334                    log::info!("   Metal available on macOS");
335                    log::info!("   Using CoreML execution provider for Metal");
336                } else {
337                    log::info!("   CoreML available on macOS");
338                }
339                // CoreML performs best with basic optimization; aggressive levels
340                // can cause layout issues and runtime errors on Apple Silicon.
341                let coreml_opt_level = match opt_level {
342                    GraphOptimizationLevel::Level2 | GraphOptimizationLevel::Level3 => {
343                        log::info!("   Capping optimization level to Level1 for CoreML backend");
344                        GraphOptimizationLevel::Level1
345                    }
346                    other => other,
347                };
348                Self::build_onnx_backend(
349                    ExecutionProvider::CoreML,
350                    coreml_opt_level,
351                    device_id,
352                    tuning,
353                )
354            }
355            "rocm" => {
356                if !device_info.has_rocm {
357                    return Err("ROCm not available on this system".to_string());
358                }
359                if !ROCmExecutionProvider::default()
360                    .is_available()
361                    .unwrap_or(false)
362                {
363                    return Err("ROCm execution provider is not available".to_string());
364                }
365                log::info!("   ROCm available");
366                Self::build_onnx_backend(ExecutionProvider::ROCm, opt_level, device_id, tuning)
367            }
368            "directml" => {
369                #[cfg(target_os = "windows")]
370                {
371                    if !device_info.has_directml {
372                        return Err("DirectML not available on this system".to_string());
373                    }
374                    if !DirectMLExecutionProvider::default()
375                        .is_available()
376                        .unwrap_or(false)
377                    {
378                        return Err("DirectML execution provider is not available".to_string());
379                    }
380                    log::info!("   DirectML available");
381                    Self::build_onnx_backend(
382                        ExecutionProvider::DirectML,
383                        opt_level,
384                        device_id,
385                        tuning,
386                    )
387                }
388                #[cfg(not(target_os = "windows"))]
389                {
390                    Err("DirectML is only supported on Windows".to_string())
391                }
392            }
393            "openvino" => {
394                if !OpenVINOExecutionProvider::default()
395                    .is_available()
396                    .unwrap_or(false)
397                {
398                    return Err("OpenVINO execution provider is not available".to_string());
399                }
400                log::info!("   OpenVINO available");
401                Self::build_onnx_backend(ExecutionProvider::OpenVINO, opt_level, device_id, tuning)
402            }
403
404            "cpu" => {
405                log::info!("   Using CPU execution");
406                Self::build_onnx_backend(ExecutionProvider::CPU, opt_level, 0, tuning)
407            }
408
409            _ => Err(format!("Unknown provider: {}", provider)),
410        }
411    }
412
413    /// Validate that hardware meets minimum requirements
414    pub fn validate_requirements(
415        requirements: &HardwareRequirements,
416        device_info: &DeviceInfo,
417    ) -> Result<(), String> {
418        // Validation logic for CPU memory
419        if let Some(min_mem_mb) = requirements.min_memory_mb {
420            let available_mb = device_info.total_memory / (1024 * 1024);
421            if available_mb < min_mem_mb {
422                return Err(format!(
423                    "Insufficient memory: need {}MB, have {}MB",
424                    min_mem_mb, available_mb
425                ));
426            }
427        }
428
429        // Collect all providers to check (preferred + fallbacks)
430        let mut providers_to_check = Vec::new();
431        if let Some(preferred) = &requirements.preferred_provider {
432            providers_to_check.push(preferred.clone());
433        }
434        providers_to_check.extend(requirements.fallback_providers.clone());
435
436        // We only fail if NONE of the providers are valid/present
437        let mut reasons = Vec::new();
438        let mut has_valid_provider = false;
439
440        let strategy = requirements
441            .strategy
442            .as_deref()
443            .unwrap_or("")
444            .to_ascii_lowercase();
445        let allow_multi = matches!(
446            strategy.as_str(),
447            "pool"
448                | "round-robin"
449                | "data-parallel"
450                | "pipeline"
451                | "pipeline-parallel"
452                | "tensor-parallel"
453                | "auto"
454        );
455
456        for provider in &providers_to_check {
457            let provider_lower = provider.to_lowercase();
458            let backend_key = match provider_lower.as_str() {
459                "tensorrt" => "cuda",
460                "coreml" => "metal",
461                other => other,
462            };
463            let is_cpu = backend_key == "cpu";
464
465            if is_cpu {
466                // CPU is always valid if memory check passed (which is global above, though strictly
467                // memory check should maybe be per-provider if requirements differed, but here it's global)
468                has_valid_provider = true;
469                break;
470            }
471
472            // GPU checks
473            if !device_info.has_provider(backend_key) {
474                reasons.push(format!("Provider {} not available", provider));
475                continue;
476            }
477
478            let device_meets = |device: &kapsl_hal::device::Device| -> bool {
479                if backend_key != "cpu" {
480                    if let Some(min_vram) = requirements.min_vram_mb {
481                        if device.memory_mb < min_vram {
482                            return false;
483                        }
484                    }
485                    if backend_key == "cuda" {
486                        if let Some(min_ver) = &requirements.min_cuda_version {
487                            if let Some(dev_ver) = &device.cuda_version {
488                                if dev_ver < min_ver {
489                                    return false;
490                                }
491                            } else {
492                                return false;
493                            }
494                        }
495                    }
496                }
497                true
498            };
499
500            if allow_multi {
501                let mut candidates = device_info
502                    .devices
503                    .iter()
504                    .filter(|d| d.backend.to_string().to_lowercase() == backend_key);
505
506                if candidates.any(device_meets) {
507                    has_valid_provider = true;
508                    break;
509                }
510
511                reasons.push(format!(
512                    "No devices meet requirements for provider {}",
513                    provider
514                ));
515                continue;
516            }
517
518            // Find the device
519            let dev_id = requirements.device_id.unwrap_or(0) as usize;
520            // Note: device_id 0 is usually the first GPU if provider is GPU.
521            if let Some(device) = device_info
522                .devices
523                .iter()
524                .find(|d| d.id == dev_id && d.backend.to_string().to_lowercase() == backend_key)
525            {
526                // Check VRAM
527                if let Some(min_vram) = requirements.min_vram_mb {
528                    if device.memory_mb < min_vram {
529                        reasons.push(format!(
530                            "Provider {} (Device {}) has insufficient VRAM: {}MB < required {}MB",
531                            provider, dev_id, device.memory_mb, min_vram
532                        ));
533                        continue;
534                    }
535                }
536
537                // Check CUDA version
538                if backend_key == "cuda" {
539                    if let Some(min_ver) = &requirements.min_cuda_version {
540                        if let Some(dev_ver) = &device.cuda_version {
541                            if dev_ver < min_ver {
542                                reasons.push(format!(
543                                    "CUDA version too old: {} < required {}",
544                                    dev_ver, min_ver
545                                ));
546                                continue;
547                            }
548                        } else {
549                            reasons.push("Unknown CUDA version on device".to_string());
550                            continue;
551                        }
552                    }
553                }
554
555                has_valid_provider = true;
556                break;
557            } else {
558                reasons.push(format!(
559                    "Device ID {} not found for provider {}",
560                    dev_id, provider
561                ));
562            }
563        }
564
565        if !has_valid_provider {
566            if providers_to_check.is_empty() {
567                // No requirements?
568                return Ok(());
569            }
570            // If we have CPU in list and it wasn't caught above, it means something weird happened.
571            // But usually CPU works.
572            return Err(format!(
573                "No compatible provider found. Reasons: {:?}",
574                reasons
575            ));
576        }
577
578        Ok(())
579    }
580}
581
// Unit tests live in a sibling file; gate the module so it is compiled only
// for test builds.
#[cfg(test)]
#[path = "factory_tests.rs"]
mod factory_tests;