use std::fs;
use std::path::PathBuf;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
#[cfg(feature = "cuda")]
use crate::acceleration::{env_acceleration_mode_lossy, AccelerationMode};
use crate::index::paths::get_colgrep_data_dir;
// File name of the persisted configuration; `get_config_path` joins it onto the config directory.
const CONFIG_FILE: &str = "config.json";
/// Pool factor reported by `Config::get_pool_factor` when none is configured.
pub const DEFAULT_POOL_FACTOR: usize = 2;
/// Recursion depth reported by `Config::get_max_recursion_depth` when none is configured.
pub const DEFAULT_MAX_RECURSION_DEPTH: usize = 1024;
/// Batch size returned by `get_default_batch_size` when the CPU path is selected.
pub const DEFAULT_BATCH_SIZE_CPU: usize = 1;
/// Batch size returned by `get_default_batch_size` when the GPU path is selected.
pub const DEFAULT_BATCH_SIZE_GPU: usize = 64;
/// Compile-time default batch size: the GPU value on builds with the `cuda` feature.
#[cfg(feature = "cuda")]
pub const DEFAULT_BATCH_SIZE: usize = DEFAULT_BATCH_SIZE_GPU;
/// Compile-time default batch size: the CPU value on builds without the `cuda` feature.
#[cfg(not(feature = "cuda"))]
pub const DEFAULT_BATCH_SIZE: usize = DEFAULT_BATCH_SIZE_CPU;
/// Resolves the default batch size at runtime on `cuda` builds.
///
/// The acceleration mode reported by `env_acceleration_mode_lossy` decides the
/// outcome: the forced modes map directly to their constant, while `Auto`
/// selects the GPU value only when cuDNN is reported as available.
#[cfg(feature = "cuda")]
pub fn get_default_batch_size() -> usize {
    let gpu_selected = match env_acceleration_mode_lossy() {
        AccelerationMode::ForceCpu => false,
        AccelerationMode::ForceGpu => true,
        // Only the automatic mode needs to probe the runtime.
        AccelerationMode::Auto => crate::onnx_runtime::is_cudnn_available(),
    };

    if gpu_selected {
        DEFAULT_BATCH_SIZE_GPU
    } else {
        DEFAULT_BATCH_SIZE_CPU
    }
}

/// Resolves the default batch size on builds without the `cuda` feature,
/// which always use the CPU value.
#[cfg(not(feature = "cuda"))]
pub fn get_default_batch_size() -> usize {
    DEFAULT_BATCH_SIZE_CPU
}
/// Returns the default number of parallel sessions for CPU execution.
///
/// The value is the parallelism reported by the standard library (16 is
/// assumed when it cannot be queried), capped at `MAX_PARALLEL_SESSIONS_CPU`.
pub fn get_default_cpu_parallel_sessions() -> usize {
    let detected = match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 16,
    };
    std::cmp::min(detected, MAX_PARALLEL_SESSIONS_CPU)
}
/// Resolves the default number of parallel sessions at runtime on `cuda` builds.
///
/// The acceleration mode reported by `env_acceleration_mode_lossy` decides the
/// outcome: forced CPU uses the CPU default, forced GPU uses
/// `DEFAULT_PARALLEL_SESSIONS_GPU`, and `Auto` picks the GPU value only when
/// cuDNN is reported as available.
#[cfg(feature = "cuda")]
pub fn get_default_parallel_sessions() -> usize {
    let gpu_selected = match env_acceleration_mode_lossy() {
        AccelerationMode::ForceCpu => false,
        AccelerationMode::ForceGpu => true,
        // Only the automatic mode needs to probe the runtime.
        AccelerationMode::Auto => crate::onnx_runtime::is_cudnn_available(),
    };

    if gpu_selected {
        DEFAULT_PARALLEL_SESSIONS_GPU
    } else {
        get_default_cpu_parallel_sessions()
    }
}

/// Resolves the default number of parallel sessions on builds without the
/// `cuda` feature, which always use the CPU default.
#[cfg(not(feature = "cuda"))]
pub fn get_default_parallel_sessions() -> usize {
    get_default_cpu_parallel_sessions()
}
/// Session count returned by `get_default_parallel_sessions` when the GPU path is selected.
pub const DEFAULT_PARALLEL_SESSIONS_GPU: usize = 1;
/// Upper bound that `get_default_cpu_parallel_sessions` applies to the detected CPU parallelism.
pub const MAX_PARALLEL_SESSIONS_CPU: usize = 16;
/// NOTE(review): not referenced in this part of the file; presumably caps the
/// intra-op thread count per session — confirm at the use site.
pub const MAX_INTRA_OP_THREADS: usize = 16;
/// User settings that `Config::load` / `Config::save` persist as JSON.
///
/// Every `Option` field is left out of the serialized output while it is
/// `None`, and both pattern lists are left out while they are empty, so a
/// default `Config` serializes to an empty object.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
/// Default model identifier, if one has been chosen.
#[serde(skip_serializing_if = "Option::is_none")]
pub default_model: Option<String>,
/// Default `k` value, if configured.
/// NOTE(review): presumably the number of results to return — confirm with the CLI.
#[serde(skip_serializing_if = "Option::is_none")]
pub default_k: Option<usize>,
/// Default `n` value, if configured.
/// NOTE(review): meaning is not visible in this file — confirm with the CLI.
#[serde(skip_serializing_if = "Option::is_none")]
pub default_n: Option<usize>,
/// Explicit FP32 preference; when unset, `use_fp32` falls back to `true` on
/// `cuda` builds and `false` otherwise.
#[serde(skip_serializing_if = "Option::is_none")]
pub fp32: Option<bool>,
/// Pool factor override; when unset, `get_pool_factor` uses `DEFAULT_POOL_FACTOR`.
#[serde(skip_serializing_if = "Option::is_none")]
pub pool_factor: Option<usize>,
/// Parallel session override; when unset, `get_parallel_sessions` resolves a
/// default at runtime.
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_sessions: Option<usize>,
/// Batch size override; when unset, `get_batch_size` resolves a default at runtime.
#[serde(skip_serializing_if = "Option::is_none")]
pub batch_size: Option<usize>,
/// Verbosity preference; `is_verbose` treats unset as `false`.
#[serde(skip_serializing_if = "Option::is_none")]
pub verbose: Option<bool>,
/// Recursion depth override; when unset, `get_max_recursion_depth` uses
/// `DEFAULT_MAX_RECURSION_DEPTH`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_recursion_depth: Option<usize>,
/// Relative-path preference; `use_relative_paths` treats unset as `true`.
#[serde(skip_serializing_if = "Option::is_none")]
pub relative_paths: Option<bool>,
/// Hybrid search preference; `use_hybrid_search` treats unset as `true`.
#[serde(skip_serializing_if = "Option::is_none")]
pub hybrid_search: Option<bool>,
/// Hybrid alpha override; `set_hybrid_alpha` clamps it to `0.0..=1.0` and
/// `get_hybrid_alpha` falls back to `0.60` when it is unset.
#[serde(skip_serializing_if = "Option::is_none")]
pub hybrid_alpha: Option<f32>,
/// Additional ignore patterns; a missing key deserializes to an empty list.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub extra_ignore: Vec<String>,
/// Force-include patterns; a missing key deserializes to an empty list.
/// NOTE(review): presumably these override ignore rules — confirm at the use site.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub force_include: Vec<String>,
}
impl Config {
pub fn load() -> Result<Self> {
let path = get_config_path()?;
if !path.exists() {
return Ok(Self::default());
}
let content = fs::read_to_string(&path)
.with_context(|| format!("Failed to read config from {}", path.display()))?;
let config: Config = serde_json::from_str(&content)
.with_context(|| format!("Failed to parse config from {}", path.display()))?;
Ok(config)
}
pub fn save(&self) -> Result<()> {
let path = get_config_path()?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
let content = serde_json::to_string_pretty(self)?;
fs::write(&path, content)?;
Ok(())
}
pub fn get_default_model(&self) -> Option<&str> {
self.default_model.as_deref()
}
pub fn set_default_model(&mut self, model: impl Into<String>) {
self.default_model = Some(model.into());
}
pub fn get_default_k(&self) -> Option<usize> {
self.default_k
}
pub fn set_default_k(&mut self, k: usize) {
self.default_k = Some(k);
}
pub fn clear_default_k(&mut self) {
self.default_k = None;
}
pub fn get_default_n(&self) -> Option<usize> {
self.default_n
}
pub fn set_default_n(&mut self, n: usize) {
self.default_n = Some(n);
}
pub fn clear_default_n(&mut self) {
self.default_n = None;
}
pub fn use_fp32(&self) -> bool {
#[cfg(feature = "cuda")]
{
self.fp32.unwrap_or(true)
}
#[cfg(not(feature = "cuda"))]
{
self.fp32.unwrap_or(false)
}
}
pub fn set_fp32(&mut self, fp32: bool) {
self.fp32 = Some(fp32);
}
pub fn clear_fp32(&mut self) {
self.fp32 = None;
}
pub fn get_pool_factor(&self) -> usize {
self.pool_factor.unwrap_or(DEFAULT_POOL_FACTOR)
}
pub fn set_pool_factor(&mut self, factor: usize) {
self.pool_factor = Some(factor.max(1)); }
pub fn clear_pool_factor(&mut self) {
self.pool_factor = None;
}
pub fn configured_parallel_sessions(&self) -> Option<usize> {
self.parallel_sessions.map(|sessions| sessions.max(1))
}
pub fn get_parallel_sessions(&self) -> usize {
self.configured_parallel_sessions()
.unwrap_or_else(get_default_parallel_sessions)
}
pub fn set_parallel_sessions(&mut self, sessions: usize) {
self.parallel_sessions = Some(sessions.max(1)); }
pub fn clear_parallel_sessions(&mut self) {
self.parallel_sessions = None;
}
pub fn configured_batch_size(&self) -> Option<usize> {
self.batch_size.map(|size| size.max(1))
}
pub fn get_batch_size(&self) -> usize {
self.configured_batch_size()
.unwrap_or_else(get_default_batch_size)
}
pub fn set_batch_size(&mut self, size: usize) {
self.batch_size = Some(size.max(1)); }
pub fn clear_batch_size(&mut self) {
self.batch_size = None;
}
pub fn is_verbose(&self) -> bool {
self.verbose.unwrap_or(false)
}
pub fn set_verbose(&mut self, verbose: bool) {
self.verbose = Some(verbose);
}
pub fn clear_verbose(&mut self) {
self.verbose = None;
}
pub fn use_relative_paths(&self) -> bool {
self.relative_paths.unwrap_or(true)
}
pub fn set_relative_paths(&mut self, relative: bool) {
self.relative_paths = Some(relative);
}
pub fn clear_relative_paths(&mut self) {
self.relative_paths = None;
}
pub fn get_max_recursion_depth(&self) -> usize {
self.max_recursion_depth
.unwrap_or(DEFAULT_MAX_RECURSION_DEPTH)
}
pub fn set_max_recursion_depth(&mut self, depth: usize) {
self.max_recursion_depth = Some(depth.max(1));
}
pub fn clear_max_recursion_depth(&mut self) {
self.max_recursion_depth = None;
}
pub fn use_hybrid_search(&self) -> bool {
self.hybrid_search.unwrap_or(true)
}
pub fn set_hybrid_search(&mut self, enabled: bool) {
self.hybrid_search = Some(enabled);
}
pub fn clear_hybrid_search(&mut self) {
self.hybrid_search = None;
}
pub fn get_hybrid_alpha(&self) -> f32 {
if let Ok(env_alpha) = std::env::var("COLGREP_ALPHA") {
if let Ok(v) = env_alpha.parse::<f32>() {
return v.clamp(0.0, 1.0);
}
}
self.hybrid_alpha.unwrap_or(0.60)
}
pub fn set_hybrid_alpha(&mut self, alpha: f32) {
self.hybrid_alpha = Some(alpha.clamp(0.0, 1.0));
}
pub fn clear_hybrid_alpha(&mut self) {
self.hybrid_alpha = None;
}
pub fn get_extra_ignore(&self) -> &[String] {
&self.extra_ignore
}
pub fn add_extra_ignore(&mut self, pattern: impl Into<String>) {
let p = pattern.into();
if !self.extra_ignore.contains(&p) {
self.extra_ignore.push(p);
}
}
pub fn remove_extra_ignore(&mut self, pattern: &str) -> bool {
let len = self.extra_ignore.len();
self.extra_ignore.retain(|p| p != pattern);
self.extra_ignore.len() < len
}
pub fn clear_extra_ignore(&mut self) {
self.extra_ignore.clear();
}
pub fn get_force_include(&self) -> &[String] {
&self.force_include
}
pub fn add_force_include(&mut self, pattern: impl Into<String>) {
let p = pattern.into();
if !self.force_include.contains(&p) {
self.force_include.push(p);
}
}
pub fn remove_force_include(&mut self, pattern: &str) -> bool {
let len = self.force_include.len();
self.force_include.retain(|p| p != pattern);
self.force_include.len() < len
}
pub fn clear_force_include(&mut self) {
self.force_include.clear();
}
}
/// Returns the location of the configuration file: `CONFIG_FILE` placed in
/// the parent of the directory reported by `get_colgrep_data_dir`.
///
/// # Errors
///
/// Fails when the data directory cannot be resolved or has no parent.
pub fn get_config_path() -> Result<PathBuf> {
    get_colgrep_data_dir()?
        .parent()
        .map(|config_dir| config_dir.join(CONFIG_FILE))
        .context("Could not determine config directory")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Recomputes the capped CPU session default independently of the code
    /// under test.
    fn capped_cpu_sessions() -> usize {
        std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(16)
            .min(MAX_PARALLEL_SESSIONS_CPU)
    }

    // --- default model / k / n ---------------------------------------------

    #[test]
    fn test_config_default() {
        let cfg = Config::default();
        assert!(cfg.default_model.is_none());
        assert!(cfg.get_default_model().is_none());
        assert!(cfg.default_k.is_none());
        assert!(cfg.get_default_k().is_none());
        assert!(cfg.default_n.is_none());
        assert!(cfg.get_default_n().is_none());
    }

    #[test]
    fn test_config_set_default_model() {
        let mut cfg = Config::default();
        cfg.set_default_model("test-model");
        assert_eq!(cfg.get_default_model(), Some("test-model"));
    }

    #[test]
    fn test_config_set_default_model_string() {
        let mut cfg = Config::default();
        let owned_name = String::from("another-model");
        cfg.set_default_model(owned_name);
        assert_eq!(cfg.get_default_model(), Some("another-model"));
    }

    #[test]
    fn test_config_serialization() {
        let mut cfg = Config::default();
        cfg.set_default_model("lightonai/LateOn-Code-edge");

        let encoded = serde_json::to_string(&cfg).unwrap();
        assert!(encoded.contains("lightonai/LateOn-Code-edge"));

        let decoded: Config = serde_json::from_str(&encoded).unwrap();
        assert_eq!(
            decoded.get_default_model(),
            Some("lightonai/LateOn-Code-edge")
        );
    }

    #[test]
    fn test_config_serialization_empty() {
        let encoded = serde_json::to_string(&Config::default()).unwrap();
        assert!(!encoded.contains("default_model"));

        let decoded: Config = serde_json::from_str(&encoded).unwrap();
        assert!(decoded.get_default_model().is_none());
    }

    #[test]
    fn test_config_deserialization_missing_field() {
        let cfg: Config = serde_json::from_str("{}").unwrap();
        assert!(cfg.get_default_model().is_none());
    }

    #[test]
    fn test_config_deserialization_null_field() {
        let payload = r#"{"default_model": null}"#;
        let cfg: Config = serde_json::from_str(payload).unwrap();
        assert!(cfg.get_default_model().is_none());
    }

    #[test]
    fn test_config_path_exists() {
        let located = get_config_path();
        assert!(located.is_ok());
        let path = located.unwrap();
        assert!(path.to_string_lossy().contains("config.json"));
    }

    #[test]
    fn test_config_default_k() {
        assert!(Config::default().get_default_k().is_none());
    }

    #[test]
    fn test_config_set_default_k() {
        let mut cfg = Config::default();
        cfg.set_default_k(25);
        assert_eq!(cfg.get_default_k(), Some(25));
    }

    #[test]
    fn test_config_clear_default_k() {
        let mut cfg = Config::default();
        cfg.set_default_k(25);
        assert_eq!(cfg.get_default_k(), Some(25));

        cfg.clear_default_k();
        assert!(cfg.get_default_k().is_none());
    }

    #[test]
    fn test_config_default_n() {
        assert!(Config::default().get_default_n().is_none());
    }

    #[test]
    fn test_config_set_default_n() {
        let mut cfg = Config::default();
        cfg.set_default_n(10);
        assert_eq!(cfg.get_default_n(), Some(10));
    }

    #[test]
    fn test_config_clear_default_n() {
        let mut cfg = Config::default();
        cfg.set_default_n(10);
        assert_eq!(cfg.get_default_n(), Some(10));

        cfg.clear_default_n();
        assert!(cfg.get_default_n().is_none());
    }

    #[test]
    fn test_config_serialization_with_k_and_n() {
        let mut cfg = Config::default();
        cfg.set_default_k(20);
        cfg.set_default_n(8);

        let encoded = serde_json::to_string(&cfg).unwrap();
        assert!(encoded.contains("\"default_k\":20"));
        assert!(encoded.contains("\"default_n\":8"));

        let decoded: Config = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.get_default_k(), Some(20));
        assert_eq!(decoded.get_default_n(), Some(8));
    }

    #[test]
    fn test_config_serialization_skips_none_k_n() {
        let encoded = serde_json::to_string(&Config::default()).unwrap();
        assert!(!encoded.contains("default_k"));
        assert!(!encoded.contains("default_n"));
    }

    #[test]
    fn test_config_deserialization_with_k_n() {
        let payload = r#"{"default_k": 30, "default_n": 12}"#;
        let cfg: Config = serde_json::from_str(payload).unwrap();
        assert_eq!(cfg.get_default_k(), Some(30));
        assert_eq!(cfg.get_default_n(), Some(12));
    }

    // --- parallel sessions / batch size ------------------------------------

    #[test]
    fn test_default_parallel_sessions_capped_at_16() {
        assert_eq!(MAX_PARALLEL_SESSIONS_CPU, 16);
        let sessions = get_default_parallel_sessions();

        #[cfg(feature = "cuda")]
        let expected = {
            let gpu_selected = match env_acceleration_mode_lossy() {
                AccelerationMode::ForceCpu => false,
                AccelerationMode::ForceGpu => true,
                AccelerationMode::Auto => crate::onnx_runtime::is_cudnn_available(),
            };
            if gpu_selected {
                DEFAULT_PARALLEL_SESSIONS_GPU
            } else {
                capped_cpu_sessions()
            }
        };
        #[cfg(not(feature = "cuda"))]
        let expected = capped_cpu_sessions();

        assert_eq!(sessions, expected);
        assert!(
            sessions <= MAX_PARALLEL_SESSIONS_CPU || sessions == DEFAULT_PARALLEL_SESSIONS_GPU,
            "Sessions should match either the capped CPU default or the fixed GPU default"
        );
    }

    #[test]
    fn test_config_parallel_sessions_default() {
        let sessions = Config::default().get_parallel_sessions();
        assert!(sessions >= 1);
        assert!(sessions <= 16);
    }

    #[test]
    fn test_config_auto_getters_resolve_to_concrete_values() {
        let cfg = Config::default();

        // Nothing is stored explicitly...
        assert!(cfg.parallel_sessions.is_none());
        assert!(cfg.batch_size.is_none());
        assert!(cfg.configured_parallel_sessions().is_none());
        assert!(cfg.configured_batch_size().is_none());

        // ...yet the resolving getters still produce usable values.
        assert!(cfg.get_parallel_sessions() >= 1);
        assert!(cfg.get_batch_size() >= 1);
    }

    #[test]
    fn test_config_configured_runtime_overrides_normalize_legacy_zero_values() {
        let cfg = Config {
            parallel_sessions: Some(0),
            batch_size: Some(0),
            ..Default::default()
        };

        assert_eq!(cfg.configured_parallel_sessions(), Some(1));
        assert_eq!(cfg.configured_batch_size(), Some(1));
        assert_eq!(cfg.get_parallel_sessions(), 1);
        assert_eq!(cfg.get_batch_size(), 1);
    }

    // --- extra_ignore --------------------------------------------------------

    #[test]
    fn test_extra_ignore_default_empty() {
        assert!(Config::default().get_extra_ignore().is_empty());
    }

    #[test]
    fn test_add_extra_ignore() {
        let mut cfg = Config::default();
        cfg.add_extra_ignore("generated");
        cfg.add_extra_ignore("*.pb.go");
        assert_eq!(cfg.get_extra_ignore(), &["generated", "*.pb.go"]);
    }

    #[test]
    fn test_add_extra_ignore_dedup() {
        let mut cfg = Config::default();
        cfg.add_extra_ignore("generated");
        cfg.add_extra_ignore("generated");
        assert_eq!(cfg.get_extra_ignore(), &["generated"]);
    }

    #[test]
    fn test_remove_extra_ignore() {
        let mut cfg = Config::default();
        cfg.add_extra_ignore("generated");
        cfg.add_extra_ignore("migrations");

        assert!(cfg.remove_extra_ignore("generated"));
        assert_eq!(cfg.get_extra_ignore(), &["migrations"]);
        assert!(!cfg.remove_extra_ignore("nonexistent"));
    }

    #[test]
    fn test_clear_extra_ignore() {
        let mut cfg = Config::default();
        cfg.add_extra_ignore("a");
        cfg.add_extra_ignore("b");

        cfg.clear_extra_ignore();
        assert!(cfg.get_extra_ignore().is_empty());
    }

    // --- force_include -------------------------------------------------------

    #[test]
    fn test_force_include_default_empty() {
        assert!(Config::default().get_force_include().is_empty());
    }

    #[test]
    fn test_add_force_include() {
        let mut cfg = Config::default();
        cfg.add_force_include(".vscode");
        cfg.add_force_include("vendor/internal");
        assert_eq!(cfg.get_force_include(), &[".vscode", "vendor/internal"]);
    }

    #[test]
    fn test_add_force_include_dedup() {
        let mut cfg = Config::default();
        cfg.add_force_include(".vscode");
        cfg.add_force_include(".vscode");
        assert_eq!(cfg.get_force_include(), &[".vscode"]);
    }

    #[test]
    fn test_remove_force_include() {
        let mut cfg = Config::default();
        cfg.add_force_include(".vscode");
        cfg.add_force_include("build");

        assert!(cfg.remove_force_include(".vscode"));
        assert_eq!(cfg.get_force_include(), &["build"]);
        assert!(!cfg.remove_force_include("nonexistent"));
    }

    #[test]
    fn test_clear_force_include() {
        let mut cfg = Config::default();
        cfg.add_force_include("a");
        cfg.add_force_include("b");

        cfg.clear_force_include();
        assert!(cfg.get_force_include().is_empty());
    }

    // --- pattern list serialization -----------------------------------------

    #[test]
    fn test_ignore_force_include_serialization() {
        let mut cfg = Config::default();
        cfg.add_extra_ignore("generated");
        cfg.add_extra_ignore("*.pb.go");
        cfg.add_force_include(".vscode");

        let encoded = serde_json::to_string_pretty(&cfg).unwrap();
        assert!(encoded.contains("extra_ignore"));
        assert!(encoded.contains("generated"));
        assert!(encoded.contains("force_include"));
        assert!(encoded.contains(".vscode"));

        let decoded: Config = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.get_extra_ignore(), &["generated", "*.pb.go"]);
        assert_eq!(decoded.get_force_include(), &[".vscode"]);
    }

    #[test]
    fn test_ignore_force_include_serialization_skips_empty() {
        let encoded = serde_json::to_string(&Config::default()).unwrap();
        assert!(!encoded.contains("extra_ignore"));
        assert!(!encoded.contains("force_include"));
    }

    #[test]
    fn test_ignore_force_include_deserialization_missing() {
        let payload = r#"{"default_k": 10}"#;
        let cfg: Config = serde_json::from_str(payload).unwrap();
        assert!(cfg.get_extra_ignore().is_empty());
        assert!(cfg.get_force_include().is_empty());
    }

    // --- relative_paths ------------------------------------------------------

    #[test]
    fn test_relative_paths_default_true() {
        assert!(Config::default().use_relative_paths());
    }

    #[test]
    fn test_relative_paths_set_clear() {
        let mut cfg = Config::default();

        cfg.set_relative_paths(false);
        assert!(!cfg.use_relative_paths());

        cfg.clear_relative_paths();
        assert!(cfg.use_relative_paths());
    }

    #[test]
    fn test_relative_paths_serialization() {
        let mut cfg = Config::default();

        // Unset: the key must not appear at all.
        let encoded = serde_json::to_string(&cfg).unwrap();
        assert!(!encoded.contains("relative_paths"));

        // Set: the key is written and survives a round trip.
        cfg.set_relative_paths(true);
        let encoded = serde_json::to_string(&cfg).unwrap();
        assert!(encoded.contains("relative_paths"));

        let decoded: Config = serde_json::from_str(&encoded).unwrap();
        assert!(decoded.use_relative_paths());
    }
}