1use anyhow::{Context, Result};
4use std::env;
5use std::net::SocketAddr;
6use tracing::{info, warn};
7use nvml_wrapper::Nvml;
8use sysinfo::System;
/// Runtime configuration for the llama.cpp-backed inference service.
///
/// Populated by `Config::from_env` from environment variables (optionally
/// loaded from a `.env` file). `threads`, `gpu_layers`, `ctx_size` and
/// `batch_size` support an "auto" mode that detects values from hardware.
// NOTE(review): some fields appear unread within this file; confirm which
// consumers use them before removing the allow below.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct Config {
    /// Path to the GGUF model file (MODEL_PATH, with embedded-model fallback).
    pub model_path: String,
    /// Path to the llama.cpp server binary (LLAMA_BIN, required, must exist).
    pub llama_bin: String,
    /// Host the llama.cpp backend listens on (LLAMA_HOST, default 127.0.0.1).
    pub llama_host: String,
    /// Port of the llama.cpp backend (LLAMA_PORT, default 8081).
    pub llama_port: u16,
    /// Context window in tokens (CTX_SIZE; "auto" infers from the model name).
    pub ctx_size: u32,
    /// Batch size (BATCH_SIZE; "auto" sizes from available RAM and context).
    pub batch_size: u32,
    /// Worker thread count (THREADS; "auto" derives from CPU core count).
    pub threads: u32,
    /// Layers offloaded to GPU (GPU_LAYERS; "auto" derives from VRAM).
    pub gpu_layers: u32,
    /// HEALTH_TIMEOUT_SECONDS, default 60.
    pub health_timeout_seconds: u64,
    /// HOT_SWAP_GRACE_SECONDS, default 25.
    pub hot_swap_grace_seconds: u64,
    /// MAX_CONCURRENT_STREAMS, default 4.
    pub max_concurrent_streams: u32,
    /// PROMETHEUS_PORT, default 9000.
    pub prometheus_port: u16,
    /// API_HOST, default 127.0.0.1 (see `api_addr`).
    pub api_host: String,
    /// API_PORT, default 8000 (see `api_addr`).
    pub api_port: u16,
    /// REQUESTS_PER_SECOND, default 24.
    pub requests_per_second: u32,
    /// GENERATE_TIMEOUT_SECONDS, default 300.
    pub generate_timeout_seconds: u64,
    /// STREAM_TIMEOUT_SECONDS, default 600.
    pub stream_timeout_seconds: u64,
    /// HEALTH_CHECK_TIMEOUT_SECONDS, default 90.
    pub health_check_timeout_seconds: u64,
    /// QUEUE_SIZE, default 100.
    pub queue_size: usize,
    /// QUEUE_TIMEOUT_SECONDS, default 30.
    pub queue_timeout_seconds: u64,
}
35
36impl Config {
37 pub fn from_env() -> Result<Self> {
38 if let Err(e) = dotenvy::dotenv() {
39 warn!("Failed to load .env file: {}. Using system environment variables.", e);
40 } else {
41 info!("Loaded environment variables from .env file");
42 }
43
44 let llama_bin = env::var("LLAMA_BIN")
46 .context("LLAMA_BIN environment variable not set. Please set it in your .env file")?;
47
48 if !std::path::Path::new(&llama_bin).exists() {
50 return Err(anyhow::anyhow!(
51 "Llama binary not found at: {}. Please check LLAMA_BIN in .env file.",
52 llama_bin
53 ));
54 }
55
56 info!("Using llama binary from .env: {}", llama_bin);
57
58 let model_path = Self::get_model_path_with_fallback()?;
60
61 let threads = if env::var("THREADS").unwrap_or_else(|_| "auto".into()) == "auto" {
63 Self::auto_detect_threads()
64 } else {
65 env::var("THREADS").unwrap_or_else(|_| "6".into()).parse().unwrap_or(6)
66 };
67
68 let gpu_layers = if env::var("GPU_LAYERS").unwrap_or_else(|_| "auto".into()) == "auto" {
70 Self::auto_detect_gpu_layers()
71 } else {
72 env::var("GPU_LAYERS").unwrap_or_else(|_| "20".into()).parse().unwrap_or(20)
73 };
74
75 let ctx_size = if env::var("CTX_SIZE").unwrap_or_else(|_| "auto".into()) == "auto" {
77 Self::auto_detect_ctx_size(&model_path)
78 } else {
79 env::var("CTX_SIZE").unwrap_or_else(|_| "8192".into()).parse().unwrap_or(8192)
80 };
81
82 let batch_size = if env::var("BATCH_SIZE").unwrap_or_else(|_| "auto".into()) == "auto" {
84 Self::auto_detect_batch_size(gpu_layers, ctx_size)
85 } else {
86 env::var("BATCH_SIZE").unwrap_or_else(|_| "256".into()).parse().unwrap_or(256)
87 };
88
89 let llama_host = env::var("LLAMA_HOST").unwrap_or_else(|_| "127.0.0.1".into());
91 let llama_port = env::var("LLAMA_PORT").unwrap_or_else(|_| "8081".into()).parse()?;
92
93 info!(
94 "Resource Configuration: {} GPU layers, {} threads, batch size: {}, context: {}",
95 gpu_layers, threads, batch_size, ctx_size
96 );
97
98 Ok(Self {
99 model_path,
100 llama_bin,
101 llama_host: llama_host.clone(),
102 llama_port,
103 ctx_size,
104 batch_size,
105 threads,
106 gpu_layers,
107 health_timeout_seconds: env::var("HEALTH_TIMEOUT_SECONDS")
108 .unwrap_or_else(|_| "60".into())
109 .parse()?,
110 hot_swap_grace_seconds: env::var("HOT_SWAP_GRACE_SECONDS")
111 .unwrap_or_else(|_| "25".into())
112 .parse()?,
113 max_concurrent_streams: env::var("MAX_CONCURRENT_STREAMS")
114 .unwrap_or_else(|_| "4".into())
115 .parse()?,
116 prometheus_port: env::var("PROMETHEUS_PORT")
117 .unwrap_or_else(|_| "9000".into())
118 .parse()?,
119 api_host: env::var("API_HOST").unwrap_or_else(|_| "127.0.0.1".into()),
120 api_port: env::var("API_PORT").unwrap_or_else(|_| "8000".into()).parse()?,
121 requests_per_second: env::var("REQUESTS_PER_SECOND")
122 .unwrap_or_else(|_| "24".into())
123 .parse()?,
124 generate_timeout_seconds: env::var("GENERATE_TIMEOUT_SECONDS")
125 .unwrap_or_else(|_| "300".into())
126 .parse()?,
127 stream_timeout_seconds: env::var("STREAM_TIMEOUT_SECONDS")
128 .unwrap_or_else(|_| "600".into())
129 .parse()?,
130 health_check_timeout_seconds: env::var("HEALTH_CHECK_TIMEOUT_SECONDS")
131 .unwrap_or_else(|_| "90".into())
132 .parse()?,
133 queue_size: env::var("QUEUE_SIZE")
134 .unwrap_or_else(|_| "100".into())
135 .parse()?,
136 queue_timeout_seconds: env::var("QUEUE_TIMEOUT_SECONDS")
137 .unwrap_or_else(|_| "30".into())
138 .parse()?,
139 })
140 }
141
142 fn get_model_path_with_fallback() -> Result<String> {
143 if let Ok(model_path) = env::var("MODEL_PATH") {
145 if std::path::Path::new(&model_path).exists() {
147 info!("Using model from MODEL_PATH: {}", model_path);
148 return Ok(model_path);
149 } else {
150 warn!("MODEL_PATH set but file doesn't exist: {}", model_path);
151 }
152 }
153
154 let exe_dir = std::env::current_exe()
156 .ok()
157 .and_then(|exe| exe.parent().map(|p| p.to_path_buf()))
158 .unwrap_or_else(|| std::env::current_dir().unwrap_or_default());
159
160 let possible_model_locations = vec![
162 exe_dir.join("resources/models/default.gguf"),
163 exe_dir.join("resources/models/model.gguf"),
164 exe_dir.join("models/default.gguf"),
165 exe_dir.join("models/model.gguf"),
166 exe_dir.join("default.gguf"),
167 ];
168
169 for model_path in possible_model_locations {
170 if model_path.exists() {
171 info!("Using embedded model: {}", model_path.display());
172 return Ok(model_path.to_string_lossy().to_string());
173 }
174 }
175
176 if let Ok(entries) = std::fs::read_dir(exe_dir.join("resources/models")) {
178 for entry in entries.flatten() {
179 if let Some(ext) = entry.path().extension() {
180 if ext == "gguf" {
181 info!("Using found model: {}", entry.path().display());
182 return Ok(entry.path().to_string_lossy().to_string());
183 }
184 }
185 }
186 }
187
188 Err(anyhow::anyhow!(
189 "No model file found. Please set MODEL_PATH environment variable or place a .gguf file in resources/models/"
190 ))
191 }
192
193 fn auto_detect_threads() -> u32 {
194 let num_cpus = num_cpus::get() as u32;
195 info!("Auto‑detected CPU cores: {}", num_cpus);
196
197 match num_cpus {
198 1..=2 => 1,
199 3..=4 => (num_cpus * 2) / 3,
200 5..=8 => (num_cpus * 3) / 5,
201 9..=16 => num_cpus / 2,
202 17..=32 => (num_cpus * 2) / 5,
203 _ => 16,
204 }
205 }
206
207 fn auto_detect_gpu_layers() -> u32 {
208 if let Ok(nvml) = Nvml::init() {
209 if let Ok(device_count) = nvml.device_count() {
210 if device_count > 0 {
211 if let Ok(first_gpu) = nvml.device_by_index(0) {
212 if let Ok(memory) = first_gpu.memory_info() {
213 let vram_gb = memory.total / 1024 / 1024 / 1024;
214 let layers = match vram_gb {
215 0..=4 => 12,
216 5..=8 => 20,
217 9..=12 => 32,
218 13..=16 => 40,
219 _ => 50,
220 };
221 info!("Auto‑detected GPU layers: {} ({} GB VRAM)", layers, vram_gb);
222 return layers;
223 }
224 }
225 }
226 }
227 }
228 warn!("Failed to detect GPU, using default 20 layers");
229 20
230 }
231
232 fn auto_detect_ctx_size(model_path: &str) -> u32 {
233 let inferred = Self::read_ctx_size_from_model_path(model_path)
234 .unwrap_or_else(|| {
235 info!("Falling back to default context size (8192)");
236 8192
237 });
238 let adjusted = Self::adjust_ctx_size_for_system(inferred);
239 info!("Final context size: {} (inferred: {})", adjusted, inferred);
240 adjusted
241 }
242
243 fn read_ctx_size_from_model_path(model_path: &str) -> Option<u32> {
244 let path_lower = model_path.to_lowercase();
246
247 if path_lower.contains("32k") {
248 Some(32768)
249 } else if path_lower.contains("16k") {
250 Some(16384)
251 } else if path_lower.contains("8k") {
252 Some(8192)
253 } else if path_lower.contains("4k") {
254 Some(4096)
255 } else if path_lower.contains("2k") {
256 Some(2048)
257 } else if path_lower.contains("7b") || path_lower.contains("8b") {
258 Some(4096)
259 } else if path_lower.contains("13b") {
260 Some(4096)
261 } else if path_lower.contains("34b") || path_lower.contains("70b") {
262 Some(8192)
263 } else {
264 Some(8192)
266 }
267 }
268
269 fn adjust_ctx_size_for_system(inferred_ctx: u32) -> u32 {
270 let mut system = System::new_all();
271 system.refresh_memory();
272
273 let available_ram_gb = system.available_memory() / 1024 / 1024 / 1024;
274 let _total_ram_gb = system.total_memory() / 1024 / 1024 / 1024;
275
276 let required_ram_gb = (inferred_ctx as f32 / 4096.0) * 1.5;
277 if available_ram_gb < required_ram_gb as u64 {
278 let adjusted = (available_ram_gb as f32 * 4096.0 / 1.5) as u32;
279 let safe_ctx = adjusted.min(inferred_ctx).max(2048);
280 warn!(
281 "Reducing context size from {} → {} due to limited RAM ({}GB available)",
282 inferred_ctx, safe_ctx, available_ram_gb
283 );
284 safe_ctx
285 } else {
286 inferred_ctx
287 }
288 }
289
290 fn auto_detect_batch_size(gpu_layers: u32, ctx_size: u32) -> u32 {
291 let mut system = System::new_all();
292 system.refresh_memory();
293
294 let available_mb = system.available_memory() / 1024;
295 let has_gpu = gpu_layers > 0;
296 let memory_per_batch = Self::estimate_memory_per_batch(ctx_size, has_gpu);
297 let safe_available_mb = (available_mb as f32 * 0.6) as u32;
298 let max_batch = (safe_available_mb as f32 / memory_per_batch).max(1.0) as u32;
299
300 let optimal = Self::apply_batch_limits(max_batch, ctx_size, has_gpu);
301 info!(
302 "Auto batch size: {} (ctx: {}, GPU: {}, est mem: {:.1}MB/batch)",
303 optimal, ctx_size, has_gpu, memory_per_batch
304 );
305 optimal
306 }
307
308 fn estimate_memory_per_batch(ctx_size: u32, has_gpu: bool) -> f32 {
309 if has_gpu {
310 (ctx_size as f32 / 1024.0) * 0.5
311 } else {
312 (ctx_size as f32 / 1024.0) * 1.2
313 }
314 }
315
316 fn apply_batch_limits(batch_size: u32, ctx_size: u32, _has_gpu: bool) -> u32 {
317 let limited = batch_size.max(16).min(1024);
318 match ctx_size {
319 0..=2048 => limited.min(512),
320 2049..=4096 => limited.min(384),
321 4097..=8192 => limited.min(256),
322 8193..=16384 => limited.min(128),
323 16385..=32768 => limited.min(64),
324 _ => limited.min(32),
325 }
326 }
327
328 pub fn print_config(&self) {
329 info!("Current Configuration:");
330 info!("- Model Path: {}", self.model_path);
331 info!("- Llama Binary: {}", self.llama_bin);
332 info!("- Context Size: {}", self.ctx_size);
333 info!("- Batch Size: {}", self.batch_size);
334 info!("- Threads: {}", self.threads);
335 info!("- GPU Layers: {}", self.gpu_layers);
336 info!("- Max Streams: {}", self.max_concurrent_streams);
337 info!("- API: {}:{}", self.api_host, self.api_port);
338 info!("- LLM Backend: {}:{}", self.llama_host, self.llama_port);
339 info!("- Queue Size: {}", self.queue_size);
340 info!("- Queue Timeout: {}s", self.queue_timeout_seconds);
341 }
342
343 pub fn api_addr(&self) -> SocketAddr {
344 format!("{}:{}", self.api_host, self.api_port).parse().unwrap()
345 }
346}