1use std::{env::current_dir, fmt::Display, fs::create_dir_all, path::PathBuf};
2
3use directories::ProjectDirs;
4
/// Subdirectory name used for [`DirType::Prompts`].
const PROMPTS_DIR: &str = "prompts";
/// Subdirectory name used for [`DirType::Presets`].
const PRESETS_DIR: &str = "presets";
/// Subdirectory name used for [`DirType::Threads`].
const THREADS_DIR: &str = "threads";
/// Subdirectory name used for [`DirType::Logs`].
const LOGS_DIR: &str = "logs";
/// Subdirectory name used for [`DirType::Config`].
const CONFIG_DIR: &str = "config";
/// Subdirectory name used for [`DirType::Weights`].
const WEIGHTS_DIR: &str = "weights";

/// Name of the project-local directory, joined onto the current working directory to
/// form the project directory path.
pub const PROJECT_DIR_NAME: &str = ".llmvm";

/// Kind of resource directory. Each variant maps to a fixed subdirectory name, available
/// through [`DirType::as_str`] or the `Display` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DirType {
    Prompts,
    Presets,
    Threads,
    Logs,
    Config,
    Weights,
}

impl DirType {
    /// Returns the subdirectory name for this directory type without allocating.
    pub fn as_str(self) -> &'static str {
        match self {
            DirType::Prompts => PROMPTS_DIR,
            DirType::Presets => PRESETS_DIR,
            DirType::Threads => THREADS_DIR,
            DirType::Logs => LOGS_DIR,
            DirType::Config => CONFIG_DIR,
            DirType::Weights => WEIGHTS_DIR,
        }
    }
}

impl Display for DirType {
    /// Writes the subdirectory name (same text as [`DirType::as_str`]).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write_str` emits the name verbatim, exactly like `write!(f, "{}", name)`.
        f.write_str(self.as_str())
    }
}
39
40pub fn get_home_dirs() -> Option<ProjectDirs> {
41 ProjectDirs::from("com", "djandries", "llmvm")
42}
43
44pub fn get_project_dir() -> Option<PathBuf> {
45 current_dir().ok().map(|p| p.join(PROJECT_DIR_NAME))
46}
47
48fn get_home_file_path(dir_type: DirType, filename: &str) -> Option<PathBuf> {
49 get_home_dirs().map(|p| {
50 let subdir = match dir_type {
51 DirType::Config => p.config_dir().into(),
52 _ => p.data_dir().join(dir_type.to_string()),
53 };
54 create_dir_all(&subdir).ok();
55 subdir.join(filename)
56 })
57}
58
59pub fn get_file_path(dir_type: DirType, filename: &str, will_create: bool) -> Option<PathBuf> {
60 let project_dir = get_project_dir();
62 if let Some(project_dir) = project_dir {
63 if project_dir.exists() {
64 let type_dir = project_dir.join(dir_type.to_string());
65 create_dir_all(&type_dir).ok();
66 let file_dir = type_dir.join(filename);
67 if will_create || file_dir.exists() {
68 return Some(file_dir);
69 }
70 }
71 }
72 get_home_file_path(dir_type, filename)
74}
75
#[cfg(feature = "logging")]
pub mod logging {
    use std::{fs::OpenOptions, str::FromStr};

    use tracing_subscriber::{filter::Directive, EnvFilter};

    use super::{get_file_path, DirType};

    /// Installs a global `tracing` subscriber with an environment-driven filter.
    ///
    /// * `directive`: parsed into the default directive given to
    ///   `EnvFilter::builder().with_default_directive`. `None` uses `Directive::default()`.
    /// * `log_filename`: when `Some`, output goes to that file in the `Logs` directory,
    ///   resolved with `get_file_path(.., true)`, which is created if missing and truncated.
    ///   When `None`, output goes to stderr.
    ///
    /// # Panics
    /// Panics if `directive` does not parse, if the filter cannot be built from the
    /// environment, if the log path cannot be resolved or opened, or if `.init()` fails
    /// (for example when a global subscriber is already installed).
    pub fn setup_subscriber(directive: Option<&str>, log_filename: Option<&str>) {
        let default_directive = match directive {
            Some(raw) => Directive::from_str(raw).expect("logging directive should be valid"),
            None => Directive::default(),
        };
        let env_filter = EnvFilter::builder()
            .with_default_directive(default_directive)
            .from_env()
            .expect("should be able to read logging directive from env");
        let builder = tracing_subscriber::fmt().with_env_filter(env_filter);

        if let Some(filename) = log_filename {
            let log_path = get_file_path(DirType::Logs, filename, true)
                .expect("should be able to find log path");
            let log_file = OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .open(log_path)
                .expect("should be able to open log file");
            builder.with_writer(log_file).init();
        } else {
            builder.with_writer(std::io::stderr).init();
        }
    }
}
117
#[cfg(feature = "config")]
pub mod config {
    use config::{Config, ConfigError, Environment, File, FileFormat};
    use multilink::ConfigExampleSnippet;
    use serde::de::DeserializeOwned;
    use std::fs;

    use super::{get_file_path, get_home_file_path, DirType};

    /// Writes `T`'s example configuration snippet to `config_filename` in the home config
    /// directory, unless a file already exists there.
    ///
    /// Best-effort: when the home directories cannot be resolved, or the write fails,
    /// nothing happens and no error is reported. This helper therefore never aborts
    /// config loading.
    fn maybe_save_example_config<T: ConfigExampleSnippet>(config_filename: &str) {
        let home_config_path = match get_home_file_path(DirType::Config, config_filename) {
            Some(path) => path,
            // No resolvable home directory: there is nowhere to put the example.
            None => return,
        };
        if home_config_path.exists() {
            return;
        }
        // `fs::write` creates the file and writes the whole snippet (create + write_all).
        let _ = fs::write(&home_config_path, T::config_example_snippet().as_bytes());
    }

    /// Loads a `T` from the TOML file `config_filename`, merged with environment variables
    /// prefixed with `LLMVM`. The environment source is added after the file source.
    ///
    /// The file is located with `get_file_path(DirType::Config, config_filename, false)`:
    /// the project-local copy if it exists, otherwise the home config location. A missing
    /// file is not an error (`required(false)`). As a side effect, an example config is
    /// written to the home config directory if none exists there yet.
    ///
    /// # Errors
    /// Returns [`ConfigError`] if a source cannot be read or parsed, or if the merged
    /// configuration cannot be deserialized into `T`.
    ///
    /// # Panics
    /// Panics if no project-local config file exists and the home directories cannot be
    /// resolved.
    pub fn load_config<T: DeserializeOwned + ConfigExampleSnippet>(
        config_filename: &str,
    ) -> Result<T, ConfigError> {
        let config_path = get_file_path(DirType::Config, config_filename, false)
            .expect("should be able to find config path");

        maybe_save_example_config::<T>(config_filename);

        Config::builder()
            .add_source(
                // Build the source from the path itself rather than a `&str`, so non-UTF-8
                // paths load instead of panicking.
                File::from(config_path)
                    .format(FileFormat::Toml)
                    .required(false),
            )
            .add_source(Environment::with_prefix("LLMVM"))
            .build()?
            .try_deserialize()
    }
}