1use crate::onnx::{ExecutionProvider, OnnxBackend, OnnxBackendBuilder};
2use kapsl_core::loader::Manifest;
3use kapsl_core::HardwareRequirements;
4use kapsl_engine_api::Engine;
5use kapsl_hal::device::DeviceInfo;
6use kapsl_llm::llm_backend::LLMBackend;
7use kapsl_llm::GgufBackend;
8#[cfg(target_os = "windows")]
9use ort::execution_providers::DirectMLExecutionProvider;
10use ort::execution_providers::ExecutionProvider as _;
11use ort::execution_providers::{
12 CUDAExecutionProvider, CoreMLExecutionProvider, OpenVINOExecutionProvider,
13 ROCmExecutionProvider, TensorRTExecutionProvider,
14};
15use ort::session::builder::GraphOptimizationLevel;
16
/// Stateless factory that selects and constructs an inference [`Engine`] for a model
/// [`Manifest`] on the hardware described by a [`DeviceInfo`].
pub struct BackendFactory;
18
/// Optional ONNX Runtime session overrides.
///
/// Every field defaults to `None`, meaning "keep the backend builder's own default".
/// Each `Some` value is forwarded to the matching `OnnxBackendBuilder::with_*` setter
/// when a backend is built by `BackendFactory`.
#[derive(Default, Clone, Debug)]
pub struct OnnxRuntimeTuning {
    /// Forwarded to `OnnxBackendBuilder::with_memory_pattern`.
    pub memory_pattern: Option<bool>,
    /// Forwarded to `OnnxBackendBuilder::with_disable_cpu_mem_arena`.
    pub disable_cpu_mem_arena: Option<bool>,
    /// Forwarded to `OnnxBackendBuilder::with_max_bucket_sessions`.
    pub session_buckets: Option<usize>,
    /// Forwarded to `OnnxBackendBuilder::with_bucket_dim_granularity`.
    pub bucket_dim_granularity: Option<usize>,
    /// Forwarded to `OnnxBackendBuilder::with_bucket_max_dims`.
    pub bucket_max_dims: Option<usize>,
    /// Forwarded to `OnnxBackendBuilder::with_peak_concurrency_hint`.
    pub peak_concurrency_hint: Option<u32>,
}
28
29pub fn parse_optimization_level(level: Option<&String>) -> Result<GraphOptimizationLevel, String> {
30 match level.as_ref().map(|s| s.as_str()) {
31 Some("disable") | Some("0") => Ok(GraphOptimizationLevel::Disable),
32 Some("basic") | Some("1") => Ok(GraphOptimizationLevel::Level1),
33 Some("extended") | Some("2") => Ok(GraphOptimizationLevel::Level2),
34 Some("all") | Some("3") | Some("99") | None => Ok(GraphOptimizationLevel::Level3),
35 _ => Err("Unknown optimization level".to_string()),
36 }
37}
38
39impl BackendFactory {
40 fn apply_onnx_tuning(
41 mut builder: OnnxBackendBuilder,
42 tuning: &OnnxRuntimeTuning,
43 ) -> OnnxBackendBuilder {
44 if let Some(v) = tuning.memory_pattern {
45 builder = builder.with_memory_pattern(v);
46 }
47 if let Some(v) = tuning.disable_cpu_mem_arena {
48 builder = builder.with_disable_cpu_mem_arena(v);
49 }
50 if let Some(v) = tuning.session_buckets {
51 builder = builder.with_max_bucket_sessions(v);
52 }
53 if let Some(v) = tuning.bucket_dim_granularity {
54 builder = builder.with_bucket_dim_granularity(v);
55 }
56 if let Some(v) = tuning.bucket_max_dims {
57 builder = builder.with_bucket_max_dims(v);
58 }
59 if let Some(v) = tuning.peak_concurrency_hint {
60 builder = builder.with_peak_concurrency_hint(v);
61 }
62 builder
63 }
64
65 fn build_onnx_backend(
66 provider: ExecutionProvider,
67 opt_level: GraphOptimizationLevel,
68 device_id: i32,
69 tuning: &OnnxRuntimeTuning,
70 ) -> Result<Box<dyn Engine>, String> {
71 let mut builder = OnnxBackend::builder()
72 .with_provider(provider)
73 .with_optimization_level(opt_level);
74 if !matches!(provider, ExecutionProvider::CPU) {
75 builder = builder.with_device_id(device_id)?;
76 }
77 builder = Self::apply_onnx_tuning(builder, tuning);
78 Ok(Box::new(builder.build()))
79 }
80
81 fn push_unique_provider(providers: &mut Vec<String>, provider: &str) {
82 if providers
83 .iter()
84 .any(|candidate| candidate.eq_ignore_ascii_case(provider))
85 {
86 return;
87 }
88 providers.push(provider.to_string());
89 }
90
91 fn provider_policy() -> String {
92 std::env::var("KAPSL_PROVIDER_POLICY")
93 .or_else(|_| std::env::var("KAPSL_PROVIDER_POLICY"))
94 .unwrap_or_else(|_| "fastest".to_string())
95 .trim()
96 .to_ascii_lowercase()
97 }
98
99 fn should_append_fastest_candidates(providers: &[String]) -> bool {
100 if Self::provider_policy() == "manifest" {
101 return false;
102 }
103
104 providers.is_empty()
105 || providers
106 .iter()
107 .all(|provider| matches!(provider.trim().to_ascii_lowercase().as_str(), "" | "cpu"))
108 }
109
110 fn append_fastest_candidates(device_info: &DeviceInfo, providers: &mut Vec<String>) {
111 if device_info.has_cuda {
112 Self::push_unique_provider(providers, "tensorrt");
113 Self::push_unique_provider(providers, "cuda");
114 }
115 if device_info.has_metal {
116 Self::push_unique_provider(providers, "coreml");
117 }
118 if device_info.has_rocm {
119 Self::push_unique_provider(providers, "rocm");
120 }
121 if device_info.has_directml {
122 Self::push_unique_provider(providers, "directml");
123 }
124 Self::push_unique_provider(providers, "cpu");
125 }
126
127 pub fn create_best_backend(
129 manifest: &Manifest,
130 device_info: &DeviceInfo,
131 ) -> Result<Box<dyn Engine>, String> {
132 Self::create_best_backend_with_tuning(manifest, device_info, &OnnxRuntimeTuning::default())
133 }
134
135 pub fn create_best_backend_with_tuning(
136 manifest: &Manifest,
137 device_info: &DeviceInfo,
138 tuning: &OnnxRuntimeTuning,
139 ) -> Result<Box<dyn Engine>, String> {
140 if manifest.framework == "gguf" {
142 log::info!("✓ Using GgufBackend (llama.cpp)");
143 return Ok(Box::new(GgufBackend::new()));
144 }
145
146 if manifest.framework == "llm" {
148 let requirements = &manifest.hardware_requirements;
149 if Self::provider_policy() == "manifest" {
150 if let Some(provider) = requirements.preferred_provider.clone() {
151 let device_id = requirements.device_id.unwrap_or(0);
152 log::info!(
153 "✓ Using LLMBackend with manifest provider override: {}",
154 provider
155 );
156 return Ok(Box::new(LLMBackend::with_device(provider, device_id)));
157 }
158 }
159 log::info!("✓ Using LLMBackend with runtime fastest-provider selection");
160 return Ok(Box::new(LLMBackend::new()));
161 }
162
163 let requirements = &manifest.hardware_requirements;
164
165 log::info!("🔍 Selecting backend based on requirements:");
166 log::info!(" Preferred: {:?}", requirements.preferred_provider);
167 log::info!(" Fallbacks: {:?}", requirements.fallback_providers);
168 log::info!(
169 " Graph Optimization: {:?}",
170 requirements.graph_optimization_level
171 );
172
173 let opt_level = parse_optimization_level(requirements.graph_optimization_level.as_ref())
175 .map_err(|e| format!("Invalid graph optimization level in manifest: {}", e))?;
176
177 log::info!(" Graph Optimization: {:?}", opt_level);
178
179 let mut providers_to_try = Vec::new();
180 if let Some(preferred) = &requirements.preferred_provider {
181 Self::push_unique_provider(&mut providers_to_try, preferred);
182 }
183 for provider in &requirements.fallback_providers {
184 Self::push_unique_provider(&mut providers_to_try, provider);
185 }
186
187 if Self::should_append_fastest_candidates(&providers_to_try) {
188 log::info!("⚡ Provider policy `fastest`: appending hardware-accelerated providers");
189 Self::append_fastest_candidates(device_info, &mut providers_to_try);
190 }
191
192 for provider in &providers_to_try {
193 let device_id = requirements.device_id.unwrap_or(0);
194 match Self::try_create_provider(provider, device_info, opt_level, device_id, tuning) {
195 Ok(backend) => {
196 log::info!("✓ Using provider: {}", provider);
197 return Ok(backend);
198 }
199 Err(err) => {
200 log::warn!("⚠ Provider '{}' not available: {}", provider, err);
201 }
202 }
203 }
204
205 log::info!("⚠ Using last-resort CPU backend");
207 let opt_cpu = parse_optimization_level(requirements.graph_optimization_level.as_ref())
208 .unwrap_or(GraphOptimizationLevel::Level3);
209 Self::build_onnx_backend(ExecutionProvider::CPU, opt_cpu, 0, tuning)
210 }
211
212 pub fn create_backend_for_device(
214 manifest: &Manifest,
215 provider: &str,
216 device_id: usize,
217 device_info: &DeviceInfo,
218 ) -> Result<Box<dyn Engine>, String> {
219 Self::create_backend_for_device_with_tuning(
220 manifest,
221 provider,
222 device_id,
223 device_info,
224 &OnnxRuntimeTuning::default(),
225 )
226 }
227
228 pub fn create_backend_for_device_with_tuning(
229 manifest: &Manifest,
230 provider: &str,
231 device_id: usize,
232 device_info: &DeviceInfo,
233 tuning: &OnnxRuntimeTuning,
234 ) -> Result<Box<dyn Engine>, String> {
235 if manifest.framework == "gguf" {
237 log::info!("✓ Using GgufBackend (llama.cpp)");
238 return Ok(Box::new(GgufBackend::new()));
239 }
240
241 if manifest.framework == "llm" {
243 if Self::provider_policy() == "manifest" {
244 log::info!(
245 "✓ Using LLMBackend with manifest provider override: {}",
246 provider
247 );
248 return Ok(Box::new(LLMBackend::with_device(
249 provider.to_string(),
250 device_id as i32,
251 )));
252 }
253
254 log::info!(
255 "✓ Using LLMBackend with device pinning and runtime provider auto-selection"
256 );
257 return Ok(Box::new(LLMBackend::with_device_id(device_id as i32)));
258 }
259
260 let requirements = &manifest.hardware_requirements;
261 let opt_level = parse_optimization_level(requirements.graph_optimization_level.as_ref())
262 .map_err(|e| format!("Invalid graph optimization level in manifest: {}", e))?;
263
264 Self::try_create_provider(provider, device_info, opt_level, device_id as i32, tuning)
265 }
266
267 fn try_create_provider(
268 provider: &str,
269 device_info: &DeviceInfo,
270 opt_level: GraphOptimizationLevel,
271 device_id: i32,
272 tuning: &OnnxRuntimeTuning,
273 ) -> Result<Box<dyn Engine>, String> {
274 let provider_lower = provider.to_lowercase();
275
276 match provider_lower.as_str() {
277 "cuda" => {
278 if !device_info.has_cuda {
279 return Err("CUDA not available on this system".to_string());
280 }
281 if !CUDAExecutionProvider::default()
282 .is_available()
283 .unwrap_or(false)
284 {
285 return Err(
286 "CUDA execution provider is not available in ONNX Runtime".to_string()
287 );
288 }
289 let cuda_version = device_info
290 .devices
291 .iter()
292 .find(|d| matches!(d.backend, kapsl_hal::device::DeviceBackend::Cuda))
293 .and_then(|d| d.cuda_version.as_ref())
294 .map(|s| s.as_str())
295 .unwrap_or("unknown");
296 log::info!(" CUDA available: version {:?}", cuda_version);
297 Self::build_onnx_backend(ExecutionProvider::CUDA, opt_level, device_id, tuning)
298 }
299
300 "tensorrt" => {
301 if !device_info.has_cuda {
302 return Err("TensorRT requires CUDA-capable GPU".to_string());
303 }
304 if !TensorRTExecutionProvider::default()
305 .is_available()
306 .unwrap_or(false)
307 {
308 return Err(
309 "TensorRT execution provider is not available in ONNX Runtime".to_string(),
310 );
311 }
312 log::info!(" TensorRT requested (requires CUDA)");
313 Self::build_onnx_backend(ExecutionProvider::TensorRT, opt_level, device_id, tuning)
314 }
315
316 "metal" | "coreml" => {
317 if !device_info.has_metal {
318 return Err(format!(
319 "{} not available on this system",
320 if provider_lower == "metal" {
321 "Metal"
322 } else {
323 "CoreML"
324 }
325 ));
326 }
327 if !CoreMLExecutionProvider::default()
328 .is_available()
329 .unwrap_or(false)
330 {
331 return Err("CoreML execution provider is not available".to_string());
332 }
333 if provider_lower == "metal" {
334 log::info!(" Metal available on macOS");
335 log::info!(" Using CoreML execution provider for Metal");
336 } else {
337 log::info!(" CoreML available on macOS");
338 }
339 let coreml_opt_level = match opt_level {
342 GraphOptimizationLevel::Level2 | GraphOptimizationLevel::Level3 => {
343 log::info!(" Capping optimization level to Level1 for CoreML backend");
344 GraphOptimizationLevel::Level1
345 }
346 other => other,
347 };
348 Self::build_onnx_backend(
349 ExecutionProvider::CoreML,
350 coreml_opt_level,
351 device_id,
352 tuning,
353 )
354 }
355 "rocm" => {
356 if !device_info.has_rocm {
357 return Err("ROCm not available on this system".to_string());
358 }
359 if !ROCmExecutionProvider::default()
360 .is_available()
361 .unwrap_or(false)
362 {
363 return Err("ROCm execution provider is not available".to_string());
364 }
365 log::info!(" ROCm available");
366 Self::build_onnx_backend(ExecutionProvider::ROCm, opt_level, device_id, tuning)
367 }
368 "directml" => {
369 #[cfg(target_os = "windows")]
370 {
371 if !device_info.has_directml {
372 return Err("DirectML not available on this system".to_string());
373 }
374 if !DirectMLExecutionProvider::default()
375 .is_available()
376 .unwrap_or(false)
377 {
378 return Err("DirectML execution provider is not available".to_string());
379 }
380 log::info!(" DirectML available");
381 Self::build_onnx_backend(
382 ExecutionProvider::DirectML,
383 opt_level,
384 device_id,
385 tuning,
386 )
387 }
388 #[cfg(not(target_os = "windows"))]
389 {
390 Err("DirectML is only supported on Windows".to_string())
391 }
392 }
393 "openvino" => {
394 if !OpenVINOExecutionProvider::default()
395 .is_available()
396 .unwrap_or(false)
397 {
398 return Err("OpenVINO execution provider is not available".to_string());
399 }
400 log::info!(" OpenVINO available");
401 Self::build_onnx_backend(ExecutionProvider::OpenVINO, opt_level, device_id, tuning)
402 }
403
404 "cpu" => {
405 log::info!(" Using CPU execution");
406 Self::build_onnx_backend(ExecutionProvider::CPU, opt_level, 0, tuning)
407 }
408
409 _ => Err(format!("Unknown provider: {}", provider)),
410 }
411 }
412
413 pub fn validate_requirements(
415 requirements: &HardwareRequirements,
416 device_info: &DeviceInfo,
417 ) -> Result<(), String> {
418 if let Some(min_mem_mb) = requirements.min_memory_mb {
420 let available_mb = device_info.total_memory / (1024 * 1024);
421 if available_mb < min_mem_mb {
422 return Err(format!(
423 "Insufficient memory: need {}MB, have {}MB",
424 min_mem_mb, available_mb
425 ));
426 }
427 }
428
429 let mut providers_to_check = Vec::new();
431 if let Some(preferred) = &requirements.preferred_provider {
432 providers_to_check.push(preferred.clone());
433 }
434 providers_to_check.extend(requirements.fallback_providers.clone());
435
436 let mut reasons = Vec::new();
438 let mut has_valid_provider = false;
439
440 let strategy = requirements
441 .strategy
442 .as_deref()
443 .unwrap_or("")
444 .to_ascii_lowercase();
445 let allow_multi = matches!(
446 strategy.as_str(),
447 "pool"
448 | "round-robin"
449 | "data-parallel"
450 | "pipeline"
451 | "pipeline-parallel"
452 | "tensor-parallel"
453 | "auto"
454 );
455
456 for provider in &providers_to_check {
457 let provider_lower = provider.to_lowercase();
458 let backend_key = match provider_lower.as_str() {
459 "tensorrt" => "cuda",
460 "coreml" => "metal",
461 other => other,
462 };
463 let is_cpu = backend_key == "cpu";
464
465 if is_cpu {
466 has_valid_provider = true;
469 break;
470 }
471
472 if !device_info.has_provider(backend_key) {
474 reasons.push(format!("Provider {} not available", provider));
475 continue;
476 }
477
478 let device_meets = |device: &kapsl_hal::device::Device| -> bool {
479 if backend_key != "cpu" {
480 if let Some(min_vram) = requirements.min_vram_mb {
481 if device.memory_mb < min_vram {
482 return false;
483 }
484 }
485 if backend_key == "cuda" {
486 if let Some(min_ver) = &requirements.min_cuda_version {
487 if let Some(dev_ver) = &device.cuda_version {
488 if dev_ver < min_ver {
489 return false;
490 }
491 } else {
492 return false;
493 }
494 }
495 }
496 }
497 true
498 };
499
500 if allow_multi {
501 let mut candidates = device_info
502 .devices
503 .iter()
504 .filter(|d| d.backend.to_string().to_lowercase() == backend_key);
505
506 if candidates.any(device_meets) {
507 has_valid_provider = true;
508 break;
509 }
510
511 reasons.push(format!(
512 "No devices meet requirements for provider {}",
513 provider
514 ));
515 continue;
516 }
517
518 let dev_id = requirements.device_id.unwrap_or(0) as usize;
520 if let Some(device) = device_info
522 .devices
523 .iter()
524 .find(|d| d.id == dev_id && d.backend.to_string().to_lowercase() == backend_key)
525 {
526 if let Some(min_vram) = requirements.min_vram_mb {
528 if device.memory_mb < min_vram {
529 reasons.push(format!(
530 "Provider {} (Device {}) has insufficient VRAM: {}MB < required {}MB",
531 provider, dev_id, device.memory_mb, min_vram
532 ));
533 continue;
534 }
535 }
536
537 if backend_key == "cuda" {
539 if let Some(min_ver) = &requirements.min_cuda_version {
540 if let Some(dev_ver) = &device.cuda_version {
541 if dev_ver < min_ver {
542 reasons.push(format!(
543 "CUDA version too old: {} < required {}",
544 dev_ver, min_ver
545 ));
546 continue;
547 }
548 } else {
549 reasons.push("Unknown CUDA version on device".to_string());
550 continue;
551 }
552 }
553 }
554
555 has_valid_provider = true;
556 break;
557 } else {
558 reasons.push(format!(
559 "Device ID {} not found for provider {}",
560 dev_id, provider
561 ));
562 }
563 }
564
565 if !has_valid_provider {
566 if providers_to_check.is_empty() {
567 return Ok(());
569 }
570 return Err(format!(
573 "No compatible provider found. Reasons: {:?}",
574 reasons
575 ));
576 }
577
578 Ok(())
579 }
580}
581
// Loads the `factory_tests` module from the sibling file `factory_tests.rs`. As a child
// module it can reach this module's private items.
// NOTE(review): this declaration is not gated with `#[cfg(test)]`, so the module is also
// compiled into non-test builds. Confirm that is intended.
#[path = "factory_tests.rs"]
mod factory_tests;