llmvm_util/
lib.rs

1use std::{env::current_dir, fmt::Display, fs::create_dir_all, path::PathBuf};
2
3use directories::ProjectDirs;
4
// Subdirectory names, one per `DirType` variant. They are what the
// `Display` impl of `DirType` renders and therefore what gets joined onto a
// project or data root when a path is resolved.

/// Subdirectory name for `DirType::Prompts`.
const PROMPTS_DIR: &str = "prompts";
/// Subdirectory name for `DirType::Presets`.
const PRESETS_DIR: &str = "presets";
/// Subdirectory name for `DirType::Threads`.
const THREADS_DIR: &str = "threads";
/// Subdirectory name for `DirType::Logs`.
const LOGS_DIR: &str = "logs";
/// Subdirectory name for `DirType::Config`.
const CONFIG_DIR: &str = "config";
/// Subdirectory name for `DirType::Weights`.
const WEIGHTS_DIR: &str = "weights";

/// Name of the project-local directory that is looked up inside the current
/// working directory.
pub const PROJECT_DIR_NAME: &str = ".llmvm";
13
/// Category of file managed by this crate.
///
/// Each variant maps to a subdirectory name through its `Display`
/// implementation; that name is appended to either the project directory or
/// the per-user directory when a path is resolved.
///
/// The enum is a plain, field-less tag, so it derives the common value traits
/// to let callers log, compare, copy and hash it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DirType {
    /// Displayed as `prompts`.
    Prompts,
    /// Displayed as `presets`.
    Presets,
    /// Displayed as `threads`.
    Threads,
    /// Displayed as `logs`.
    Logs,
    /// Displayed as `config`. In the per-user location this category resolves
    /// to the platform configuration directory rather than a data
    /// subdirectory.
    Config,
    /// Displayed as `weights`.
    Weights,
}
22
23impl Display for DirType {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        write!(
26            f,
27            "{}",
28            match self {
29                DirType::Prompts => PROMPTS_DIR,
30                DirType::Presets => PRESETS_DIR,
31                DirType::Threads => THREADS_DIR,
32                DirType::Logs => LOGS_DIR,
33                DirType::Config => CONFIG_DIR,
34                DirType::Weights => WEIGHTS_DIR,
35            }
36        )
37    }
38}
39
40pub fn get_home_dirs() -> Option<ProjectDirs> {
41    ProjectDirs::from("com", "djandries", "llmvm")
42}
43
44pub fn get_project_dir() -> Option<PathBuf> {
45    current_dir().ok().map(|p| p.join(PROJECT_DIR_NAME))
46}
47
48fn get_home_file_path(dir_type: DirType, filename: &str) -> Option<PathBuf> {
49    get_home_dirs().map(|p| {
50        let subdir = match dir_type {
51            DirType::Config => p.config_dir().into(),
52            _ => p.data_dir().join(dir_type.to_string()),
53        };
54        create_dir_all(&subdir).ok();
55        subdir.join(filename)
56    })
57}
58
59pub fn get_file_path(dir_type: DirType, filename: &str, will_create: bool) -> Option<PathBuf> {
60    // Check for project file path, if it exists or if creating new file
61    let project_dir = get_project_dir();
62    if let Some(project_dir) = project_dir {
63        if project_dir.exists() {
64            let type_dir = project_dir.join(dir_type.to_string());
65            create_dir_all(&type_dir).ok();
66            let file_dir = type_dir.join(filename);
67            if will_create || file_dir.exists() {
68                return Some(file_dir);
69            }
70        }
71    }
72    // Return user home file path
73    get_home_file_path(dir_type, filename)
74}
75
#[cfg(feature = "logging")]
pub mod logging {
    //! Helpers for installing the global `tracing` subscriber.

    use std::str::FromStr;

    use std::fs::OpenOptions;
    use tracing_subscriber::{filter::Directive, EnvFilter};

    use super::{get_file_path, DirType};

    /// Installs a formatting subscriber that writes either to stderr or to a
    /// log file.
    ///
    /// * `directive` - default filter directive, used together with whatever
    ///   the filter builder reads from the environment. When `None`, the
    ///   `Directive` default is used.
    /// * `log_filename` - when `Some`, log output goes to that file inside the
    ///   logs directory; the file is created if missing and truncated if
    ///   present. When `None`, output goes to stderr.
    ///
    /// # Panics
    ///
    /// Panics when `directive` does not parse, when the filter cannot be built
    /// from the environment, or when the log path cannot be resolved or the
    /// log file cannot be opened.
    pub fn setup_subscriber(directive: Option<&str>, log_filename: Option<&str>) {
        let default_directive = match directive {
            Some(raw) => Directive::from_str(raw).expect("logging directive should be valid"),
            None => Directive::default(),
        };

        let env_filter = EnvFilter::builder()
            .with_default_directive(default_directive)
            .from_env()
            .expect("should be able to read logging directive from env");

        let builder = tracing_subscriber::fmt().with_env_filter(env_filter);

        if let Some(filename) = log_filename {
            let log_path = get_file_path(DirType::Logs, filename, true)
                .expect("should be able to find log path");
            let log_file = OpenOptions::new()
                .create(true)
                .truncate(true)
                .write(true)
                .open(log_path)
                .expect("should be able to open log file");
            builder.with_writer(log_file).init();
        } else {
            builder.with_writer(std::io::stderr).init();
        }
    }
}
117
#[cfg(feature = "config")]
pub mod config {
    //! Helpers for loading configuration files and seeding example configs.

    use config::{Config, ConfigError, Environment, File, FileFormat};
    use multilink::ConfigExampleSnippet;
    use serde::de::DeserializeOwned;
    use std::{fs, io::Write};

    use crate::{get_home_file_path, DirType};

    use super::get_file_path;

    /// Writes `T`'s example configuration snippet to the per-user config path
    /// for `config_filename`, unless a file is already there.
    ///
    /// This is best effort: any I/O failure, including "file already exists",
    /// is ignored.
    ///
    /// # Panics
    ///
    /// Panics when the per-user config path cannot be determined.
    fn maybe_save_example_config<T: ConfigExampleSnippet>(config_filename: &str) {
        let home_config_path = get_home_file_path(DirType::Config, config_filename)
            .expect("should be able to find home config path");
        // `create_new` makes "create only if absent" a single atomic step.
        // A separate existence check followed by a truncating create could
        // wipe a config file written in between the two calls.
        fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(home_config_path)
            .and_then(|mut f| f.write_all(T::config_example_snippet().as_bytes()))
            .ok();
    }

    /// Loads and deserializes the configuration named `config_filename`.
    ///
    /// The config path is resolved first (an existing project-local file wins
    /// over the per-user one), then an example config is seeded in the
    /// per-user location if none exists there. The resolved file is read as
    /// TOML and is optional; environment variables with the `LLMVM` prefix are
    /// added as a second source after the file.
    ///
    /// # Errors
    ///
    /// Returns a `ConfigError` when the sources cannot be built or the result
    /// cannot be deserialized into `T`.
    ///
    /// # Panics
    ///
    /// Panics when no config path can be determined or when the resolved path
    /// is not valid UTF-8.
    pub fn load_config<T: DeserializeOwned + ConfigExampleSnippet>(
        config_filename: &str,
    ) -> Result<T, ConfigError> {
        // TODO: add both root and project configs as sources
        let config_path = get_file_path(DirType::Config, config_filename, false)
            .expect("should be able to find config path");

        maybe_save_example_config::<T>(config_filename);

        Config::builder()
            .add_source(
                File::new(
                    config_path
                        .to_str()
                        .expect("config path should return to str"),
                    FileFormat::Toml,
                )
                .required(false),
            )
            .add_source(Environment::with_prefix("LLMVM"))
            .build()?
            .try_deserialize()
    }
}