use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use trusty_common::LocalModelConfig;
/// Model slug used for `openrouter.model` when the config does not set one.
///
/// OpenRouter model IDs take the form `vendor/model` with a dotted version
/// (`anthropic/claude-3.5-sonnet`), unlike Anthropic's native hyphenated IDs
/// (`claude-3-5-sonnet-…`); the hyphenated form is not a valid OpenRouter slug.
pub const DEFAULT_OPENROUTER_MODEL: &str = "anthropic/claude-3.5-sonnet";
/// Value used for `openrouter.max_context_tokens` when the config does not set one.
pub const DEFAULT_MAX_CONTEXT_TOKENS: usize = 4096;
/// Settings for the OpenRouter backend (the `[openrouter]` table of the config file).
///
/// Every field has a serde default, so a config file may omit any of them.
/// `Debug` is implemented by hand (not derived) so that `api_key` is never printed verbatim.
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct OpenRouterConfig {
    /// OpenRouter API key; empty when not configured.
    #[serde(default)]
    pub api_key: String,
    /// OpenRouter model slug; falls back to [`DEFAULT_OPENROUTER_MODEL`] when absent.
    #[serde(default = "default_model")]
    pub model: String,
    /// Context-token limit; falls back to [`DEFAULT_MAX_CONTEXT_TOKENS`] when absent.
    #[serde(default = "default_max_context_tokens")]
    pub max_context_tokens: usize,
    /// System prompt; empty when not configured.
    #[serde(default)]
    pub system_prompt: String,
}

impl std::fmt::Debug for OpenRouterConfig {
    /// Formats like a derived `Debug`, except that `api_key` is rendered as `"<redacted>"`
    /// when set (or `""` when empty). This keeps the secret out of logs, panic messages and
    /// `assert_eq!` failure output while still showing whether a key is configured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let api_key = if self.api_key.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("OpenRouterConfig")
            .field("api_key", &api_key)
            .field("model", &self.model)
            .field("max_context_tokens", &self.max_context_tokens)
            .field("system_prompt", &self.system_prompt)
            .finish()
    }
}
/// Serde default for `OpenRouterConfig::model`: an owned copy of [`DEFAULT_OPENROUTER_MODEL`].
///
/// `#[serde(default = "...")]` requires a function path, hence this tiny wrapper.
fn default_model() -> String {
    String::from(DEFAULT_OPENROUTER_MODEL)
}
/// Serde default for `OpenRouterConfig::max_context_tokens`.
///
/// `#[serde(default = "...")]` requires a function path rather than a constant, so this
/// wraps [`DEFAULT_MAX_CONTEXT_TOKENS`]; `Default for OpenRouterConfig` reuses it too.
fn default_max_context_tokens() -> usize {
    DEFAULT_MAX_CONTEXT_TOKENS
}
impl Default for OpenRouterConfig {
fn default() -> Self {
Self {
api_key: String::new(),
model: default_model(),
max_context_tokens: default_max_context_tokens(),
system_prompt: String::new(),
}
}
}
/// Top-level user configuration, stored as TOML (by default at `~/.trusty-memory/config.toml`).
///
/// Both sections are `#[serde(default)]`, so a config file may omit either table entirely.
/// `PartialEq`/`Eq` are implemented by hand rather than derived — presumably because
/// `LocalModelConfig` does not implement them; confirm before switching to a derive.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UserConfig {
    /// OpenRouter backend settings (the `[openrouter]` table).
    #[serde(default)]
    pub openrouter: OpenRouterConfig,
    /// Local model backend settings (the `[local_model]` table).
    #[serde(default)]
    pub local_model: LocalModelConfig,
}
impl PartialEq for UserConfig {
    /// Compares the full `openrouter` section, then `local_model` on `enabled`, `base_url`
    /// and `model` only — any other `LocalModelConfig` fields are ignored.
    fn eq(&self, other: &Self) -> bool {
        if self.openrouter != other.openrouter {
            return false;
        }
        let (mine, theirs) = (&self.local_model, &other.local_model);
        // Tuple equality checks elements left to right and stops at the first mismatch.
        (&mine.enabled, &mine.base_url, &mine.model)
            == (&theirs.enabled, &theirs.base_url, &theirs.model)
    }
}

// Marker impl: sound as long as every compared field has total equality
// (strings and, presumably, a bool `enabled` — confirm against `LocalModelConfig`).
impl Eq for UserConfig {}
impl UserConfig {
pub fn load() -> Result<Self> {
Self::load_from(&default_config_path()?)
}
pub fn load_from(path: &Path) -> Result<Self> {
if !path.exists() {
return Ok(Self::default());
}
let raw = std::fs::read_to_string(path)
.with_context(|| format!("read config file {}", path.display()))?;
let cfg: UserConfig = toml::from_str(&raw)
.with_context(|| format!("parse config file {}", path.display()))?;
Ok(cfg)
}
pub fn save(&self) -> Result<()> {
self.save_to(&default_config_path()?)
}
pub fn save_to(&self, path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("create config dir {}", parent.display()))?;
}
let toml_str = toml::to_string_pretty(self).context("serialize user config to TOML")?;
std::fs::write(path, toml_str)
.with_context(|| format!("write config file {}", path.display()))?;
Ok(())
}
pub fn set_dotted(&mut self, key: &str, value: &str) -> Result<()> {
match key {
"openrouter.api_key" => self.openrouter.api_key = value.to_string(),
"openrouter.model" => self.openrouter.model = value.to_string(),
"openrouter.max_context_tokens" => {
self.openrouter.max_context_tokens = value
.parse()
.with_context(|| format!("parse usize for {key}"))?;
}
"openrouter.system_prompt" => self.openrouter.system_prompt = value.to_string(),
"local_model.enabled" => {
self.local_model.enabled = value
.parse()
.with_context(|| format!("parse bool for {key}"))?;
}
"local_model.base_url" => self.local_model.base_url = value.to_string(),
"local_model.model" => self.local_model.model = value.to_string(),
other => anyhow::bail!("unknown config key: {other}"),
}
Ok(())
}
}
/// Returns the default config file location, `<home>/.trusty-memory/config.toml`.
///
/// # Errors
/// Fails when the current user's home directory cannot be determined.
pub fn default_config_path() -> Result<PathBuf> {
    dirs::home_dir()
        .map(|home_dir| home_dir.join(".trusty-memory").join("config.toml"))
        .context("could not resolve home directory")
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn user_config_default_when_missing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("nope/config.toml");
        let cfg = UserConfig::load_from(&path).expect("missing file should be ok");
        assert_eq!(cfg, UserConfig::default());
        assert_eq!(cfg.openrouter.model, DEFAULT_OPENROUTER_MODEL);
        assert_eq!(cfg.openrouter.max_context_tokens, DEFAULT_MAX_CONTEXT_TOKENS);
        assert!(cfg.openrouter.api_key.is_empty());
    }

    #[test]
    fn user_config_roundtrip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        let mut cfg = UserConfig::default();
        cfg.openrouter.api_key = "sk-or-test-key".to_string();
        cfg.openrouter.model = "anthropic/claude-3-opus".to_string();
        cfg.openrouter.max_context_tokens = 2048;
        cfg.openrouter.system_prompt = "be brief".to_string();
        cfg.save_to(&path).unwrap();
        let loaded = UserConfig::load_from(&path).unwrap();
        assert_eq!(loaded, cfg);
        assert_eq!(loaded.openrouter.api_key, "sk-or-test-key");
        assert_eq!(loaded.openrouter.model, "anthropic/claude-3-opus");
        assert_eq!(loaded.openrouter.max_context_tokens, 2048);
        assert_eq!(loaded.openrouter.system_prompt, "be brief");
    }

    #[test]
    fn partial_file_falls_back_to_field_defaults() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        std::fs::write(&path, "[openrouter]\napi_key = \"sk-or-partial\"\n").unwrap();
        let cfg = UserConfig::load_from(&path).unwrap();
        assert_eq!(cfg.openrouter.api_key, "sk-or-partial");
        assert_eq!(cfg.openrouter.model, DEFAULT_OPENROUTER_MODEL);
        assert_eq!(cfg.openrouter.max_context_tokens, DEFAULT_MAX_CONTEXT_TOKENS);
        assert!(cfg.openrouter.system_prompt.is_empty());
    }

    #[test]
    fn empty_file_loads_defaults() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        std::fs::write(&path, "").unwrap();
        assert_eq!(UserConfig::load_from(&path).unwrap(), UserConfig::default());
    }

    #[test]
    fn malformed_file_errors() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("config.toml");
        std::fs::write(&path, "this is = = not toml").unwrap();
        let err = UserConfig::load_from(&path).unwrap_err();
        assert!(err.to_string().contains("parse config file"));
    }

    #[test]
    fn set_dotted_known_keys() {
        let mut cfg = UserConfig::default();
        cfg.set_dotted("openrouter.api_key", "sk-or-x").unwrap();
        cfg.set_dotted("openrouter.model", "anthropic/foo").unwrap();
        cfg.set_dotted("openrouter.max_context_tokens", "1234").unwrap();
        cfg.set_dotted("openrouter.system_prompt", "hello").unwrap();
        cfg.set_dotted("local_model.base_url", "http://localhost:11434").unwrap();
        cfg.set_dotted("local_model.model", "llama3").unwrap();
        assert_eq!(cfg.openrouter.api_key, "sk-or-x");
        assert_eq!(cfg.openrouter.model, "anthropic/foo");
        assert_eq!(cfg.openrouter.max_context_tokens, 1234);
        assert_eq!(cfg.openrouter.system_prompt, "hello");
        assert_eq!(cfg.local_model.base_url, "http://localhost:11434");
        assert_eq!(cfg.local_model.model, "llama3");
    }

    #[test]
    fn set_dotted_unknown_key_errors() {
        let mut cfg = UserConfig::default();
        let err = cfg.set_dotted("openrouter.unknown", "x").unwrap_err();
        assert!(err.to_string().contains("unknown config key"));
    }

    #[test]
    fn set_dotted_invalid_values_error_and_leave_config_unchanged() {
        let mut cfg = UserConfig::default();
        let err = cfg
            .set_dotted("openrouter.max_context_tokens", "lots")
            .unwrap_err();
        assert!(err
            .to_string()
            .contains("parse usize for openrouter.max_context_tokens"));
        let err = cfg.set_dotted("local_model.enabled", "not-a-bool").unwrap_err();
        assert!(err.to_string().contains("parse bool for local_model.enabled"));
        assert_eq!(cfg, UserConfig::default());
    }
}