use config::{Config as ConfigRs, Environment, File};
use std::{
path::{Path, PathBuf},
str::FromStr,
time::Duration,
};
use anyhow::{Context as _, Result};
use serde::{Deserialize, Serialize};
use swiftide::integrations::treesitter::SupportedLanguages;
use super::{api_key::ApiKey, tools::Tools};
use super::{
defaults::{
default_auto_push_remote, default_cache_dir, default_docker_context, default_dockerfile,
default_log_dir, default_main_branch, default_project_name,
},
mcp::McpServer,
};
use super::{CommandConfiguration, LLMConfiguration, LLMConfigurations};
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
#[serde(default = "default_project_name")]
pub project_name: String,
pub language: SupportedLanguages,
pub llm: Box<LLMConfigurations>,
pub commands: CommandConfiguration,
#[serde(default = "default_cache_dir")]
pub cache_dir: PathBuf,
#[serde(default = "default_log_dir")]
pub log_dir: PathBuf,
#[serde(default)]
pub agent: SupportedAgentConfigurations,
#[serde(default)]
indexing_concurrency: Option<usize>,
#[serde(default)]
indexing_batch_size: Option<usize>,
#[serde(default = "default_indexing_garbage_collection")]
pub indexing_garbage_collection_enabled: bool,
#[serde(default)]
pub docker: DockerConfiguration,
#[serde(default)]
pub backoff: BackoffConfiguration,
pub git: GitConfiguration,
#[serde(default)]
pub tavily_api_key: Option<ApiKey>,
#[serde(default)]
pub github_api_key: Option<ApiKey>,
#[serde(default)]
pub openai_api_key: Option<ApiKey>,
#[serde(default)]
pub anthropic_api_key: Option<ApiKey>,
#[serde(default)]
pub open_router_api_key: Option<ApiKey>,
#[serde(default)]
pub azure_openai_api_key: Option<ApiKey>,
#[serde(default)]
pub tool_executor: SupportedToolExecutors,
#[serde(default)]
pub tools: Tools,
#[serde(default)]
pub endless_mode: bool,
#[serde(default)]
pub stop_on_empty_messages: bool,
#[serde(default = "default_otel_enabled")]
pub otel_enabled: bool,
#[serde(default)]
pub agent_edit_mode: AgentEditMode,
#[serde(default)]
pub agent_custom_constraints: Option<Vec<String>>,
#[serde(default)]
pub ui: UIConfig,
#[serde(default = "default_num_completions_for_summary")]
pub num_completions_for_summary: usize,
#[serde(default)]
pub mcp: Option<Vec<McpServer>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct UIConfig {
pub hide_header: bool,
}
fn default_otel_enabled() -> bool {
false
}
fn default_num_completions_for_summary() -> usize {
10
}
fn default_indexing_garbage_collection() -> bool {
true
}
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Default)]
pub enum SupportedAgentConfigurations {
#[default]
#[serde(alias = "V1")]
Coding,
PlanAct,
}
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum SupportedToolExecutors {
#[default]
Docker,
Local,
}
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Default, strum_macros::EnumIs)]
#[serde(rename_all = "kebab-case")]
pub enum AgentEditMode {
Whole,
Line,
#[default]
Patch,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DockerConfiguration {
#[serde(default = "default_dockerfile")]
pub dockerfile: PathBuf,
#[serde(default = "default_docker_context")]
pub context: PathBuf,
}
impl Default for DockerConfiguration {
fn default() -> Self {
Self {
dockerfile: "Dockerfile".into(),
context: ".".into(),
}
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct BackoffConfiguration {
pub initial_interval_sec: u64,
pub multiplier: f64,
pub randomization_factor: f64,
pub max_elapsed_time_sec: u64,
}
impl Default for BackoffConfiguration {
fn default() -> Self {
Self {
initial_interval_sec: 15,
multiplier: 2.0,
randomization_factor: 0.05,
max_elapsed_time_sec: 120,
}
}
}
impl From<BackoffConfiguration> for backoff::ExponentialBackoff {
fn from(from_backoff: BackoffConfiguration) -> Self {
backoff::ExponentialBackoffBuilder::default()
.with_initial_interval(Duration::from_secs(from_backoff.initial_interval_sec))
.with_multiplier(from_backoff.multiplier)
.with_randomization_factor(from_backoff.randomization_factor)
.with_max_elapsed_time(Some(Duration::from_secs(from_backoff.max_elapsed_time_sec)))
.build()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitConfiguration {
pub repository: Option<String>,
pub owner: Option<String>,
#[serde(default = "default_main_branch")]
pub main_branch: String,
#[serde(default = "default_auto_push_remote")]
pub auto_push_remote: bool,
#[serde(default)]
pub auto_commit_disabled: bool,
#[serde(default = "default_agent_user_name")]
pub agent_user_name: String,
#[serde(default = "default_agent_user_email")]
pub agent_user_email: String,
#[serde(default = "default_generate_commit_message")]
pub generate_commit_message: bool,
}
fn default_agent_user_name() -> String {
"kwaak".to_string()
}
fn default_agent_user_email() -> String {
"kwaak@bosun.ai".to_string()
}
fn default_generate_commit_message() -> bool {
true
}
impl FromStr for Config {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
toml::from_str(s)
.context("Failed to parse configuration")
.and_then(Config::fill_llm_api_keys)
.map(Config::add_project_name_to_paths)
}
}
impl Config {
pub fn load(path: Option<&Path>) -> Result<Self> {
let builder = match path {
Some(path) => ConfigRs::builder().add_source(File::from(path)),
None => {
if std::fs::metadata(".config/kwaak.toml").is_ok() {
ConfigRs::builder().add_source(File::with_name(".config/kwaak.toml"))
} else {
ConfigRs::builder().add_source(File::with_name("kwaak.toml"))
}
}
};
let config = builder
.add_source(File::with_name("kwaak.local").required(false))
.add_source(Environment::with_prefix("KWAAK").separator("__"))
.build()
.context("Failed to build configuration")?;
config
.try_deserialize()
.map_err(Into::into)
.and_then(Config::fill_llm_api_keys)
.map(Config::add_project_name_to_paths)
}
fn fill_llm_api_keys(mut self) -> Result<Self> {
let previous = self.clone();
let LLMConfigurations {
indexing,
embedding,
query,
} = &mut *self.llm;
for config in &mut [indexing, embedding, query] {
let maybe_root_key = previous.root_provider_api_key_for(config);
fill_llm(config, maybe_root_key)?;
}
Ok(self)
}
fn add_project_name_to_paths(mut self) -> Self {
if self.cache_dir.ends_with("kwaak") {
self.cache_dir.push(&self.project_name);
}
if self.log_dir.ends_with("kwaak/logs") {
self.log_dir.push(&self.project_name);
}
self
}
#[must_use]
fn root_provider_api_key_for(&self, provider: &LLMConfiguration) -> Option<&ApiKey> {
match provider {
LLMConfiguration::OpenAI { .. } => self.openai_api_key.as_ref(),
LLMConfiguration::OpenRouter { .. } => self.open_router_api_key.as_ref(),
LLMConfiguration::Anthropic { .. } => self.anthropic_api_key.as_ref(),
LLMConfiguration::AzureOpenAI { .. } => self.azure_openai_api_key.as_ref(),
_ => None,
}
}
#[must_use]
pub fn indexing_provider(&self) -> &LLMConfiguration {
let LLMConfigurations { indexing, .. } = &*self.llm;
indexing
}
#[must_use]
pub fn embedding_provider(&self) -> &LLMConfiguration {
let LLMConfigurations { embedding, .. } = &*self.llm;
embedding
}
#[must_use]
pub fn query_provider(&self) -> &LLMConfiguration {
let LLMConfigurations { query, .. } = &*self.llm;
query
}
#[must_use]
pub fn cache_dir(&self) -> &Path {
self.cache_dir.as_path()
}
#[must_use]
pub fn log_dir(&self) -> &Path {
self.log_dir.as_path()
}
#[must_use]
pub fn indexing_concurrency(&self) -> usize {
if let Some(concurrency) = self.indexing_concurrency {
return concurrency;
}
match self.indexing_provider() {
LLMConfiguration::OpenAI { .. } => num_cpus::get() * 4,
LLMConfiguration::AzureOpenAI { .. } => num_cpus::get() * 4,
LLMConfiguration::OpenRouter { .. } => num_cpus::get() * 4,
LLMConfiguration::Ollama { .. } => num_cpus::get(),
LLMConfiguration::FastEmbed { .. } => num_cpus::get(),
LLMConfiguration::Anthropic { .. } => num_cpus::get() * 4,
#[cfg(debug_assertions)]
LLMConfiguration::Testing => num_cpus::get(),
}
}
#[must_use]
pub fn indexing_batch_size(&self) -> usize {
if let Some(batch_size) = self.indexing_batch_size {
return batch_size;
}
match self.indexing_provider() {
LLMConfiguration::OpenAI { .. } => 12,
LLMConfiguration::AzureOpenAI { .. } => 12,
LLMConfiguration::Ollama { .. } => 256,
LLMConfiguration::OpenRouter { .. } => 12,
LLMConfiguration::FastEmbed { .. } => 256,
LLMConfiguration::Anthropic { .. } => 12,
#[cfg(debug_assertions)]
LLMConfiguration::Testing => 1,
}
}
#[must_use]
pub fn is_github_enabled(&self) -> bool {
self.github_api_key.is_some() && self.git.owner.is_some() && self.git.repository.is_some()
}
#[must_use]
pub fn enabled_tools(&self) -> Vec<&str> {
self.tools
.iter()
.filter(|(_, &enabled)| enabled)
.map(|(key, _)| key.as_str())
.collect()
}
#[must_use]
pub fn disabled_tools(&self) -> Vec<&str> {
self.tools
.iter()
.filter(|(_, &enabled)| !enabled)
.map(|(key, _)| key.as_str())
.collect()
}
}
fn fill_llm(llm: &mut LLMConfiguration, root_key: Option<&ApiKey>) -> Result<()> {
match llm {
LLMConfiguration::OpenAI { api_key, .. } => {
if api_key.is_none() {
if let Some(root) = root_key {
*api_key = Some(root.clone());
} else {
anyhow::bail!(
"OpenAI config requires an `api_key`, and none was provided or available in the root"
);
}
}
}
LLMConfiguration::AzureOpenAI { api_key, .. } => {
if api_key.is_none() {
if let Some(root) = root_key {
*api_key = Some(root.clone());
} else {
anyhow::bail!(
"AzureOpenAI config requires an `api_key`, and none was provided or available in the root"
);
}
}
}
LLMConfiguration::Anthropic { api_key, .. } => {
if api_key.is_none() {
if let Some(root) = root_key {
*api_key = Some(root.clone());
} else {
anyhow::bail!(
"Anthropic config requires an `api_key`, and none was provided or available in the root"
);
}
}
}
LLMConfiguration::OpenRouter { api_key, .. } => {
if api_key.is_none() {
if let Some(root) = root_key {
*api_key = Some(root.clone());
} else {
anyhow::bail!(
"OpenRouter config requires an `api_key`, and none was provided or available in the root"
);
}
}
}
LLMConfiguration::Ollama { .. } | LLMConfiguration::FastEmbed { .. } => {
}
#[cfg(debug_assertions)]
LLMConfiguration::Testing => {}
}
Ok(())
}
#[cfg(test)]
mod tests {
#![allow(irrefutable_let_patterns)]
use crate::config::{OpenAIEmbeddingModel, OpenAIPromptModel};
use super::*;
use swiftide::integrations::treesitter::SupportedLanguages;
#[test]
fn test_deserialize_toml_multiple() {
let toml = r#"
language = "rust"
[commands]
test = "cargo test"
coverage = "cargo tarpaulin"
[git]
owner = "bosun-ai"
repository = "kwaak"
[llm.indexing]
provider = "OpenAI"
api_key = "text:test-key"
prompt_model = "gpt-4o-mini"
[llm.query]
provider = "OpenAI"
api_key = "text:other-test-key"
prompt_model = "gpt-4o-mini"
[llm.embedding]
provider = "OpenAI"
api_key = "text:other-test-key"
embedding_model = "text-embedding-3-small"
"#;
let config: Config = Config::from_str(toml).unwrap();
assert_eq!(config.language, SupportedLanguages::Rust);
if let LLMConfigurations {
indexing,
embedding,
query,
} = &*config.llm
{
if let LLMConfiguration::OpenAI {
api_key,
prompt_model,
..
} = indexing
{
assert_eq!(api_key.as_ref().unwrap().expose_secret(), "test-key");
assert_eq!(prompt_model, &OpenAIPromptModel::GPT4OMini);
} else {
panic!("Expected OpenAI configuration for indexing");
}
if let LLMConfiguration::OpenAI {
api_key,
prompt_model,
..
} = query
{
assert_eq!(api_key.as_ref().unwrap().expose_secret(), "other-test-key");
assert_eq!(prompt_model, &OpenAIPromptModel::GPT4OMini);
} else {
panic!("Expected OpenAI configuration for query");
}
if let LLMConfiguration::OpenAI {
api_key,
embedding_model,
..
} = embedding
{
assert_eq!(api_key.as_ref().unwrap().expose_secret(), "other-test-key");
assert_eq!(embedding_model, &OpenAIEmbeddingModel::TextEmbedding3Small);
}
} else {
panic!("Expected multiple LLM configurations");
}
assert!(!config.otel_enabled);
}
#[test]
fn test_seed_openai_api_key_from_root_multiple_with_overwrite() {
let toml = r#"
language = "rust"
openai_api_key = "text:root-api-key"
[commands]
test = "cargo test"
coverage = "cargo tarpaulin"
[git]
owner = "bosun-ai"
repository = "kwaak"
[llm.indexing]
provider = "OpenAI"
prompt_model = "gpt-4o-mini"
[llm.query]
provider = "OpenAI"
api_key = "text:child-api-key"
prompt_model = "gpt-4o-mini"
[llm.embedding]
provider = "OpenAI"
embedding_model = "text-embedding-3-small"
"#;
let config: Config = Config::from_str(toml).unwrap();
let LLMConfiguration::OpenAI { api_key, .. } = config.indexing_provider() else {
panic!("Expected OpenAI configuration for indexing")
};
assert_eq!(
api_key.as_ref().unwrap().expose_secret(),
config.openai_api_key.as_ref().unwrap().expose_secret()
);
let LLMConfiguration::OpenAI { api_key, .. } = config.query_provider() else {
panic!("Expected OpenAI configuration for indexing")
};
assert_eq!(api_key.as_ref().unwrap().expose_secret(), "child-api-key");
}
#[test]
fn test_add_project_name_to_paths() {
let toml = r#"
language = "rust"
project_name = "test"
[commands]
test = "cargo test"
coverage = "cargo tarpaulin"
[git]
owner = "bosun-ai"
repository = "kwaak"
[llm.indexing]
provider = "OpenAI"
api_key = "text:test-key"
prompt_model = "gpt-4o-mini"
[llm.query]
provider = "OpenAI"
api_key = "text:other-test-key"
prompt_model = "gpt-4o-mini"
[llm.embedding]
provider = "OpenAI"
api_key = "text:other-test-key"
embedding_model = "text-embedding-3-small"
"#;
let config: Config = Config::from_str(toml).unwrap();
assert!(config.cache_dir.ends_with("kwaak/test"));
assert!(config.log_dir().ends_with("kwaak/logs/test"));
}
#[test]
fn test_deserialize_disabled_tools_list() {
let toml = r#"
language = "rust"
[tools]
git = false
shell_command = false
write_file = false
run_tests = true
[commands]
test = "cargo test"
coverage = "cargo tarpaulin"
[llm.indexing]
provider = "OpenAI"
api_key = "text:test-key"
prompt_model = "gpt-4o-mini"
[llm.query]
provider = "OpenAI"
api_key = "text:test-key"
prompt_model = "gpt-4o-mini"
[llm.embedding]
provider = "OpenAI"
api_key = "text:test-key"
embedding_model = "text-embedding-3-small"
[git]
repository = "kwaak"
owner = "bosun-ai"
"#;
let config: Config = Config::from_str(toml).unwrap();
assert_eq!(config.disabled_tools().len(), 3);
assert!(config.disabled_tools().contains(&"git"));
assert!(config.disabled_tools().contains(&"shell_command"));
assert!(config.disabled_tools().contains(&"write_file"));
assert!(config.enabled_tools().contains(&"run_tests"));
assert_eq!(config.enabled_tools().len(), 1);
}
}