1use anyhow::{Context, Result};
4use std::env;
5use std::net::SocketAddr;
6use tracing::{info, warn};
7use nvml_wrapper::Nvml;
8use sysinfo::System;
9
/// Runtime configuration for the llama.cpp server wrapper.
///
/// Populated from environment variables (optionally seeded from a `.env`
/// file) by [`Config::from_env`]; several resource fields support an
/// `"auto"` value that triggers host-based detection.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct Config {
    /// Filesystem path to the model file (resolved by `get_model_path_with_fallback`).
    pub model_path: String,
    /// Path to the llama server binary (`LLAMA_BIN`, validated to exist).
    pub llama_bin: String,
    /// Host the backend llama server listens on (`LLAMA_HOST`).
    pub llama_host: String,
    /// Port the backend llama server listens on (`LLAMA_PORT`).
    pub llama_port: u16,
    /// Context window size in tokens (`CTX_SIZE` or auto-detected).
    pub ctx_size: u32,
    /// Batch size for inference (`BATCH_SIZE` or auto-detected).
    pub batch_size: u32,
    /// Worker thread count (`THREADS` or auto-detected from CPU cores).
    pub threads: u32,
    /// Number of model layers offloaded to GPU (`GPU_LAYERS` or auto-detected from VRAM).
    pub gpu_layers: u32,
    /// Backend health-wait timeout, seconds (`HEALTH_TIMEOUT_SECONDS`).
    pub health_timeout_seconds: u64,
    /// Grace period during hot model swaps, seconds (`HOT_SWAP_GRACE_SECONDS`).
    pub hot_swap_grace_seconds: u64,
    /// Cap on simultaneous streaming responses (`MAX_CONCURRENT_STREAMS`).
    pub max_concurrent_streams: u32,
    /// Port for the Prometheus metrics endpoint (`PROMETHEUS_PORT`).
    pub prometheus_port: u16,
    /// Host the public API binds to (`API_HOST`).
    pub api_host: String,
    /// Port the public API binds to (`API_PORT`).
    pub api_port: u16,
    /// Rate limit for incoming requests (`REQUESTS_PER_SECOND`).
    pub requests_per_second: u32,
    /// Timeout for non-streaming generation, seconds (`GENERATE_TIMEOUT_SECONDS`).
    pub generate_timeout_seconds: u64,
    /// Timeout for streaming generation, seconds (`STREAM_TIMEOUT_SECONDS`).
    pub stream_timeout_seconds: u64,
    /// Timeout for health-check requests, seconds (`HEALTH_CHECK_TIMEOUT_SECONDS`).
    pub health_check_timeout_seconds: u64,
    /// Maximum queued requests (`QUEUE_SIZE`).
    pub queue_size: usize,
    /// Maximum time a request may wait in the queue, seconds (`QUEUE_TIMEOUT_SECONDS`).
    pub queue_timeout_seconds: u64,
    /// Full base URL of the backend, derived as `http://{llama_host}:{llama_port}`.
    pub backend_url: String,
}
35
36impl Config {
37 pub fn from_env() -> Result<Self> {
38 if let Err(e) = dotenvy::dotenv() {
39 warn!("Failed to load .env file: {}. Using system environment variables.", e);
40 } else {
41 info!("Loaded environment variables from .env file");
42 }
43
44 let llama_bin = env::var("LLAMA_BIN")
46 .context("LLAMA_BIN environment variable not set. Please set it in your .env file")?;
47
48 if !std::path::Path::new(&llama_bin).exists() {
50 return Err(anyhow::anyhow!(
51 "Llama binary not found at: {}. Please check LLAMA_BIN in .env file.",
52 llama_bin
53 ));
54 }
55
56 info!("Using llama binary from .env: {}", llama_bin);
57
58 let model_path = Self::get_model_path_with_fallback()?;
60
61 let threads = if env::var("THREADS").unwrap_or_else(|_| "auto".into()) == "auto" {
63 Self::auto_detect_threads()
64 } else {
65 env::var("THREADS").unwrap_or_else(|_| "6".into()).parse().unwrap_or(6)
66 };
67
68 let gpu_layers = if env::var("GPU_LAYERS").unwrap_or_else(|_| "auto".into()) == "auto" {
70 Self::auto_detect_gpu_layers()
71 } else {
72 env::var("GPU_LAYERS").unwrap_or_else(|_| "20".into()).parse().unwrap_or(20)
73 };
74
75 let ctx_size = if env::var("CTX_SIZE").unwrap_or_else(|_| "auto".into()) == "auto" {
77 Self::auto_detect_ctx_size(&model_path)
78 } else {
79 env::var("CTX_SIZE").unwrap_or_else(|_| "8192".into()).parse().unwrap_or(8192)
80 };
81
82 let batch_size = if env::var("BATCH_SIZE").unwrap_or_else(|_| "auto".into()) == "auto" {
84 Self::auto_detect_batch_size(gpu_layers, ctx_size)
85 } else {
86 env::var("BATCH_SIZE").unwrap_or_else(|_| "256".into()).parse().unwrap_or(256)
87 };
88
89 let llama_host = env::var("LLAMA_HOST").unwrap_or_else(|_| "127.0.0.1".into());
91 let llama_port = env::var("LLAMA_PORT").unwrap_or_else(|_| "8081".into()).parse()?;
92 let backend_url = format!("http://{}:{}", llama_host, llama_port);
93
94 info!(
95 "Resource Configuration: {} GPU layers, {} threads, batch size: {}, context: {}",
96 gpu_layers, threads, batch_size, ctx_size
97 );
98
99 Ok(Self {
100 model_path,
101 llama_bin,
102 llama_host: llama_host.clone(),
103 llama_port,
104 ctx_size,
105 batch_size,
106 threads,
107 gpu_layers,
108 health_timeout_seconds: env::var("HEALTH_TIMEOUT_SECONDS")
109 .unwrap_or_else(|_| "60".into())
110 .parse()?,
111 hot_swap_grace_seconds: env::var("HOT_SWAP_GRACE_SECONDS")
112 .unwrap_or_else(|_| "25".into())
113 .parse()?,
114 max_concurrent_streams: env::var("MAX_CONCURRENT_STREAMS")
115 .unwrap_or_else(|_| "4".into())
116 .parse()?,
117 prometheus_port: env::var("PROMETHEUS_PORT")
118 .unwrap_or_else(|_| "9000".into())
119 .parse()?,
120 api_host: env::var("API_HOST").unwrap_or_else(|_| "127.0.0.1".into()),
121 api_port: env::var("API_PORT").unwrap_or_else(|_| "8000".into()).parse()?,
122 requests_per_second: env::var("REQUESTS_PER_SECOND")
123 .unwrap_or_else(|_| "24".into())
124 .parse()?,
125 generate_timeout_seconds: env::var("GENERATE_TIMEOUT_SECONDS")
126 .unwrap_or_else(|_| "300".into())
127 .parse()?,
128 stream_timeout_seconds: env::var("STREAM_TIMEOUT_SECONDS")
129 .unwrap_or_else(|_| "600".into())
130 .parse()?,
131 health_check_timeout_seconds: env::var("HEALTH_CHECK_TIMEOUT_SECONDS")
132 .unwrap_or_else(|_| "90".into())
133 .parse()?,
134 queue_size: env::var("QUEUE_SIZE")
135 .unwrap_or_else(|_| "100".into())
136 .parse()?,
137 queue_timeout_seconds: env::var("QUEUE_TIMEOUT_SECONDS")
138 .unwrap_or_else(|_| "30".into())
139 .parse()?,
140 backend_url,
141 })
142 }
143
144 fn get_model_path_with_fallback() -> Result<String> {
145 if let Ok(model_path) = env::var("MODEL_PATH") {
147 if std::path::Path::new(&model_path).exists() {
149 info!("Using model from MODEL_PATH: {}", model_path);
150 return Ok(model_path);
151 } else {
152 warn!("MODEL_PATH set but file doesn't exist: {}", model_path);
153 }
154 }
155
156 let exe_dir = std::env::current_exe()
158 .ok()
159 .and_then(|exe| exe.parent().map(|p| p.to_path_buf()))
160 .unwrap_or_else(|| std::env::current_dir().unwrap_or_default());
161
162 let possible_model_locations = vec![
164 exe_dir.join("resources/models/default.gguf"),
166 exe_dir.join("resources/models/model.gguf"),
167 exe_dir.join("models/default.gguf"),
168 exe_dir.join("models/model.gguf"),
169 exe_dir.join("default.gguf"),
170 exe_dir.join("resources/models/default.onnx"),
172 exe_dir.join("resources/models/model.onnx"),
173 exe_dir.join("resources/models/default.trt"),
175 exe_dir.join("resources/models/model.engine"),
176 exe_dir.join("resources/models/default.safetensors"),
178 exe_dir.join("resources/models/model.safetensors"),
179 exe_dir.join("resources/models/default.ggml"),
181 exe_dir.join("resources/models/model.bin"),
182 ];
183
184 for model_path in possible_model_locations {
185 if model_path.exists() {
186 info!("Using embedded model: {}", model_path.display());
187 return Ok(model_path.to_string_lossy().to_string());
188 }
189 }
190
191 if let Ok(entries) = std::fs::read_dir(exe_dir.join("resources/models")) {
193 for entry in entries.flatten() {
194 if let Some(ext) = entry.path().extension() {
195 let ext_str = ext.to_str().unwrap_or("").to_lowercase();
196 if matches!(ext_str.as_str(), "gguf" | "ggml" | "onnx" | "trt" | "engine" | "plan" | "safetensors" | "mlmodel") {
198 info!("Using found model: {}", entry.path().display());
199 return Ok(entry.path().to_string_lossy().to_string());
200 }
201 }
202 }
203 }
204
205 Err(anyhow::anyhow!(
206 "No model file found. Please set MODEL_PATH environment variable or place a model file (supported formats: GGUF, GGML, ONNX, TensorRT, Safetensors) in resources/models/"
207 ))
208 }
209
210 fn auto_detect_threads() -> u32 {
211 let num_cpus = num_cpus::get() as u32;
212 info!("Auto‑detected CPU cores: {}", num_cpus);
213
214 match num_cpus {
215 1..=2 => 1,
216 3..=4 => (num_cpus * 2) / 3,
217 5..=8 => (num_cpus * 3) / 5,
218 9..=16 => num_cpus / 2,
219 17..=32 => (num_cpus * 2) / 5,
220 _ => 16,
221 }
222 }
223
224 fn auto_detect_gpu_layers() -> u32 {
225 if let Ok(nvml) = Nvml::init() {
226 if let Ok(device_count) = nvml.device_count() {
227 if device_count > 0 {
228 if let Ok(first_gpu) = nvml.device_by_index(0) {
229 if let Ok(memory) = first_gpu.memory_info() {
230 let vram_gb = memory.total / 1024 / 1024 / 1024;
231 let layers = match vram_gb {
232 0..=4 => 12,
233 5..=8 => 20,
234 9..=12 => 32,
235 13..=16 => 40,
236 _ => 50,
237 };
238 info!("Auto‑detected GPU layers: {} ({} GB VRAM)", layers, vram_gb);
239 return layers;
240 }
241 }
242 }
243 }
244 }
245 warn!("Failed to detect GPU, using default 20 layers");
246 20
247 }
248
249 fn auto_detect_ctx_size(model_path: &str) -> u32 {
250 let inferred = Self::read_ctx_size_from_model_path(model_path)
251 .unwrap_or_else(|| {
252 info!("Falling back to default context size (8192)");
253 8192
254 });
255 let adjusted = Self::adjust_ctx_size_for_system(inferred);
256 info!("Final context size: {} (inferred: {})", adjusted, inferred);
257 adjusted
258 }
259
260 fn read_ctx_size_from_model_path(model_path: &str) -> Option<u32> {
261 let path_lower = model_path.to_lowercase();
263
264 if path_lower.contains("32k") {
265 Some(32768)
266 } else if path_lower.contains("16k") {
267 Some(16384)
268 } else if path_lower.contains("8k") {
269 Some(8192)
270 } else if path_lower.contains("4k") {
271 Some(4096)
272 } else if path_lower.contains("2k") {
273 Some(2048)
274 } else if path_lower.contains("7b") || path_lower.contains("8b") || path_lower.contains("13b") {
275 Some(4096)
276 } else if path_lower.contains("34b") || path_lower.contains("70b") {
277 Some(8192)
278 } else {
279 Some(8192)
281 }
282 }
283
284 fn adjust_ctx_size_for_system(inferred_ctx: u32) -> u32 {
285 let mut system = System::new_all();
286 system.refresh_memory();
287
288 let available_ram_gb = system.available_memory() / 1024 / 1024 / 1024;
289 let _total_ram_gb = system.total_memory() / 1024 / 1024 / 1024;
290
291 let required_ram_gb = (inferred_ctx as f32 / 4096.0) * 1.5;
292 if available_ram_gb < required_ram_gb as u64 {
293 let adjusted = (available_ram_gb as f32 * 4096.0 / 1.5) as u32;
294 let safe_ctx = adjusted.min(inferred_ctx).max(2048);
295 warn!(
296 "Reducing context size from {} → {} due to limited RAM ({}GB available)",
297 inferred_ctx, safe_ctx, available_ram_gb
298 );
299 safe_ctx
300 } else {
301 inferred_ctx
302 }
303 }
304
305 fn auto_detect_batch_size(gpu_layers: u32, ctx_size: u32) -> u32 {
306 let mut system = System::new_all();
307 system.refresh_memory();
308
309 let available_mb = system.available_memory() / 1024;
310 let has_gpu = gpu_layers > 0;
311 let memory_per_batch = Self::estimate_memory_per_batch(ctx_size, has_gpu);
312 let safe_available_mb = (available_mb as f32 * 0.6) as u32;
313 let max_batch = (safe_available_mb as f32 / memory_per_batch).max(1.0) as u32;
314
315 let optimal = Self::apply_batch_limits(max_batch, ctx_size, has_gpu);
316 info!(
317 "Auto batch size: {} (ctx: {}, GPU: {}, est mem: {:.1}MB/batch)",
318 optimal, ctx_size, has_gpu, memory_per_batch
319 );
320 optimal
321 }
322
323 fn estimate_memory_per_batch(ctx_size: u32, has_gpu: bool) -> f32 {
324 if has_gpu {
325 (ctx_size as f32 / 1024.0) * 0.5
326 } else {
327 (ctx_size as f32 / 1024.0) * 1.2
328 }
329 }
330
331 fn apply_batch_limits(batch_size: u32, ctx_size: u32, _has_gpu: bool) -> u32 {
332 let limited = batch_size.clamp(16, 1024);
333 match ctx_size {
334 0..=2048 => limited.min(512),
335 2049..=4096 => limited.min(384),
336 4097..=8192 => limited.min(256),
337 8193..=16384 => limited.min(128),
338 16385..=32768 => limited.min(64),
339 _ => limited.min(32),
340 }
341 }
342
343 pub fn print_config(&self) {
344 info!("Current Configuration:");
345 info!("- Model Path: {}", self.model_path);
346 info!("- Llama Binary: {}", self.llama_bin);
347 info!("- Context Size: {}", self.ctx_size);
348 info!("- Batch Size: {}", self.batch_size);
349 info!("- Threads: {}", self.threads);
350 info!("- GPU Layers: {}", self.gpu_layers);
351 info!("- Max Streams: {}", self.max_concurrent_streams);
352 info!("- API: {}:{}", self.api_host, self.api_port);
353 info!("- Backend: {}:{}", self.llama_host, self.llama_port);
354 info!("- Queue Size: {}", self.queue_size);
355 info!("- Queue Timeout: {}s", self.queue_timeout_seconds);
356 info!("- Backend URL: {}", self.backend_url);
357 }
358
359 pub fn api_addr(&self) -> SocketAddr {
360 format!("{}:{}", self.api_host, self.api_port).parse().unwrap()
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fully-populated `Config` with fixed, self-consistent values.
    /// Constructed directly (not via `from_env`) so tests never touch the
    /// process environment or the filesystem and stay order-independent.
    fn create_test_config() -> Config {
        Config {
            model_path: "/test/model.gguf".to_string(),
            llama_bin: "/test/llama-server".to_string(),
            llama_host: "127.0.0.1".to_string(),
            llama_port: 8001,
            ctx_size: 8192,
            batch_size: 128,
            threads: 6,
            gpu_layers: 20,
            health_timeout_seconds: 600,
            hot_swap_grace_seconds: 25,
            max_concurrent_streams: 2,
            prometheus_port: 9000,
            api_host: "127.0.0.1".to_string(),
            api_port: 8000,
            requests_per_second: 24,
            generate_timeout_seconds: 300,
            stream_timeout_seconds: 600,
            health_check_timeout_seconds: 900,
            queue_size: 1000,
            queue_timeout_seconds: 300,
            backend_url: "http://127.0.0.1:8001".to_string(),
        }
    }

    // --- Basic construction ---

    #[test]
    fn test_config_creation_with_default_values() {
        let config = create_test_config();

        assert_eq!(config.model_path, "/test/model.gguf");
        assert_eq!(config.llama_bin, "/test/llama-server");
        assert_eq!(config.api_port, 8000);
        assert_eq!(config.llama_port, 8001);
    }

    #[test]
    fn test_config_clone() {
        let config1 = create_test_config();
        let config2 = config1.clone();

        assert_eq!(config1.api_host, config2.api_host);
        assert_eq!(config1.threads, config2.threads);
        assert_eq!(config1.gpu_layers, config2.gpu_layers);
    }

    // --- api_addr ---

    #[test]
    fn test_api_addr_parsing() {
        let config = create_test_config();
        let addr = config.api_addr();

        assert_eq!(addr.ip().to_string(), "127.0.0.1");
        assert_eq!(addr.port(), 8000);
    }

    #[test]
    fn test_api_addr_with_different_ports() {
        let mut config = create_test_config();
        config.api_port = 3000;

        let addr = config.api_addr();
        assert_eq!(addr.port(), 3000);
    }

    #[test]
    fn test_api_addr_with_zero_address() {
        let mut config = create_test_config();
        config.api_host = "0.0.0.0".to_string();
        config.api_port = 5000;

        let addr = config.api_addr();
        assert_eq!(addr.port(), 5000);
        assert_eq!(addr.ip().to_string(), "0.0.0.0");
    }

    // --- Timeouts and limits ---

    #[test]
    fn test_config_timeouts_are_positive() {
        let config = create_test_config();

        assert!(config.health_timeout_seconds > 0);
        assert!(config.generate_timeout_seconds > 0);
        assert!(config.stream_timeout_seconds > 0);
        assert!(config.health_check_timeout_seconds > 0);
    }

    #[test]
    fn test_health_check_timeout_greater_than_health_timeout() {
        let config = create_test_config();

        assert!(config.health_check_timeout_seconds >= config.health_timeout_seconds);
    }

    #[test]
    fn test_max_concurrent_streams_is_positive() {
        let config = create_test_config();
        assert!(config.max_concurrent_streams > 0);
    }

    #[test]
    fn test_requests_per_second_is_reasonable() {
        let config = create_test_config();

        assert!(config.requests_per_second > 0);
        assert!(config.requests_per_second <= 1000);
    }

    #[test]
    fn test_queue_size_is_positive() {
        let config = create_test_config();
        assert!(config.queue_size > 0);
    }

    // --- Resource parameters ---

    #[test]
    fn test_context_size_within_valid_range() {
        let config = create_test_config();

        assert!(config.ctx_size >= 512);
        assert!(config.ctx_size <= 32768);
    }

    #[test]
    fn test_batch_size_valid_range() {
        let config = create_test_config();

        assert!(config.batch_size >= 16);
        assert!(config.batch_size <= 1024);
    }

    #[test]
    fn test_batch_size_reasonable_vs_context() {
        let config = create_test_config();

        assert!(config.batch_size < config.ctx_size);
    }

    #[test]
    fn test_threads_is_positive() {
        let config = create_test_config();
        assert!(config.threads > 0);
    }

    #[test]
    fn test_threads_within_reasonable_range() {
        let config = create_test_config();

        assert!(config.threads <= 256);
    }

    #[test]
    fn test_gpu_layers_non_negative() {
        // `gpu_layers` is u32, so non-negativity is guaranteed by the type.
        // The original body asserted `gpu_layers <= ctx_size`, relating two
        // unrelated quantities; check a sanity upper bound instead.
        let config = create_test_config();
        assert!(config.gpu_layers <= 512);
    }

    #[test]
    fn test_gpu_layers_within_range() {
        let config = create_test_config();

        assert!(config.gpu_layers <= 100);
    }

    // --- Ports ---

    #[test]
    fn test_api_port_valid() {
        let config = create_test_config();
        assert!(config.api_port > 0);
        assert!(config.api_port != config.llama_port);
    }

    #[test]
    fn test_llama_port_valid() {
        let config = create_test_config();
        assert!(config.llama_port > 0);
    }

    #[test]
    fn test_prometheus_port_valid() {
        let config = create_test_config();
        assert!(config.prometheus_port > 0);
    }

    #[test]
    fn test_ports_are_different() {
        let config = create_test_config();

        assert_ne!(config.api_port, config.llama_port);
        assert_ne!(config.api_port, config.prometheus_port);
        assert_ne!(config.llama_port, config.prometheus_port);
    }

    // --- String fields ---

    #[test]
    fn test_model_path_not_empty() {
        let config = create_test_config();
        assert!(!config.model_path.is_empty());
    }

    #[test]
    fn test_llama_bin_not_empty() {
        let config = create_test_config();
        assert!(!config.llama_bin.is_empty());
    }

    #[test]
    fn test_backend_url_not_empty() {
        let config = create_test_config();
        assert!(!config.backend_url.is_empty());
    }

    #[test]
    fn test_backend_url_format() {
        let config = create_test_config();

        assert!(config.backend_url.starts_with("http://") || config.backend_url.starts_with("https://"));
    }

    #[test]
    fn test_api_host_not_empty() {
        let config = create_test_config();
        assert!(!config.api_host.is_empty());
    }

    #[test]
    fn test_llama_host_not_empty() {
        let config = create_test_config();
        assert!(!config.llama_host.is_empty());
    }

    // --- Hot swap ---

    #[test]
    fn test_hot_swap_grace_positive() {
        let config = create_test_config();
        assert!(config.hot_swap_grace_seconds > 0);
    }

    #[test]
    fn test_hot_swap_grace_reasonable() {
        let config = create_test_config();

        assert!(config.hot_swap_grace_seconds < 300);
    }

    // --- Auto-detection helpers (environment-dependent, so only bounds are
    //     asserted, not exact values) ---

    #[test]
    fn test_auto_detect_threads_returns_positive() {
        let threads = Config::auto_detect_threads();
        assert!(threads > 0);
    }

    #[test]
    fn test_auto_detect_gpu_layers_non_negative() {
        let layers = Config::auto_detect_gpu_layers();
        assert!(layers <= 512);
    }

    #[test]
    fn test_apply_batch_limits_small_context() {
        let batch = Config::apply_batch_limits(1024, 1024, false);
        assert!(batch <= 512);
    }

    #[test]
    fn test_apply_batch_limits_medium_context() {
        let batch = Config::apply_batch_limits(1024, 3000, false);
        assert!(batch <= 384);
    }

    #[test]
    fn test_apply_batch_limits_large_context() {
        let batch = Config::apply_batch_limits(1024, 24576, false);
        assert!(batch <= 64);
    }

    #[test]
    fn test_apply_batch_limits_minimum() {
        let batch = Config::apply_batch_limits(1, 8192, false);
        assert!(batch >= 16);
    }

    #[test]
    fn test_estimate_memory_per_batch_cpu() {
        let memory_cpu = Config::estimate_memory_per_batch(8192, false);
        assert!(memory_cpu > 0.0);
    }

    #[test]
    fn test_estimate_memory_per_batch_gpu() {
        let memory_gpu = Config::estimate_memory_per_batch(8192, true);
        assert!(memory_gpu > 0.0);
    }

    #[test]
    fn test_estimate_memory_gpu_less_than_cpu() {
        let memory_cpu = Config::estimate_memory_per_batch(8192, false);
        let memory_gpu = Config::estimate_memory_per_batch(8192, true);

        assert!(memory_gpu < memory_cpu);
    }

    // --- Queue ---

    #[test]
    fn test_queue_timeout_is_positive() {
        let config = create_test_config();
        assert!(config.queue_timeout_seconds > 0);
    }

    #[test]
    fn test_queue_timeout_less_than_generate_timeout() {
        let config = create_test_config();

        assert!(config.queue_timeout_seconds <= config.generate_timeout_seconds);
    }

    // --- Cross-field consistency ---

    #[test]
    fn test_config_values_consistency() {
        let config = create_test_config();

        // Previously crammed onto a single line; one assertion per line.
        assert!(config.health_timeout_seconds <= 3600);
        assert!(config.generate_timeout_seconds <= 1800);
        assert!(config.stream_timeout_seconds <= 3600);
        assert!(config.health_check_timeout_seconds <= 3600);
    }

    #[test]
    fn test_config_backend_url_consistency() {
        let config = create_test_config();

        assert!(config.backend_url.contains(&config.llama_host) ||
                config.backend_url.contains("127.0.0.1") ||
                config.backend_url.contains("localhost"));
    }

    #[test]
    fn test_config_all_fields_initialized() {
        let config = create_test_config();

        assert!(!config.model_path.is_empty());
        assert!(!config.llama_bin.is_empty());
        assert!(!config.api_host.is_empty());
        assert!(!config.llama_host.is_empty());
        assert!(config.threads > 0);
        // Replaces the original `gpu_layers <= ctx_size` check, which
        // compared unrelated quantities.
        assert!(config.ctx_size > 0);
        assert!(config.api_port > 0);
        assert!(config.llama_port > 0);
    }
}