Skip to main content

diplomat_tool/
config.rs

1use std::{
2    collections::{HashMap, HashSet},
3    path::{Path, PathBuf},
4    str,
5};
6
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use syn::{
10    parse::{Parse, ParseStream},
11    Expr, Ident, Token,
12};
13use toml::{value::Table, Value};
14
15use crate::{cpp::CppConfig, demo_gen::DemoConfig, js::JsConfig, kotlin::KotlinConfig};
16use diplomat_core::hir::LoweringConfig;
17
/// Settings shared by every backend.
///
/// Any of these can also be given per language (e.g. `kotlin.lib_name`); such language-scoped
/// keys are recognized by [`SharedConfig::overrides_shared`].
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct SharedConfig {
    /// Name of the library, set via the `lib_name` key; `None` when unset.
    // NOTE(review): how each backend consumes this name is decided outside this file — confirm there.
    pub lib_name: Option<String>,
    /// Whether or not callbacks support references in parameters. This is unsafe: you need to be careful to not
    /// retain these references on the foreign side. `None` keeps the [`LoweringConfig`] default
    /// (see [`SharedConfig::lowering_config`]).
    pub unsafe_references_in_callbacks: Option<bool>,
    /// The folder to pull custom bindings from. Defaults to the lib.rs folder.
    // NOTE(review): the derived `Default` leaves this as an empty path; the lib.rs fallback is
    // presumably applied by callers — confirm.
    pub custom_extra_code_location: PathBuf,
    /// List of features to enable/disable generation for.
    pub features_enabled: HashSet<String>,
}
29
30impl SharedConfig {
31    // / Quick and dirty way to tell [`set_overrides`] whether or not to copy an override from a specific language over.
32    pub fn overrides_shared(name: &str) -> bool {
33        // Expect the first item in the iterator to be the name of the language, so we eliminate that:
34        let name: String = name.split(".").skip(1).collect();
35        matches!(
36            name.as_str(),
37            "lib_name"
38                | "unsafe_references_in_callbacks"
39                | "custom_extra_code_location"
40                | "features_enabled"
41        )
42    }
43
44    pub fn set(&mut self, key: &str, value: Value) {
45        match key {
46            "lib_name" => {
47                if value.is_str() {
48                    self.lib_name = value.as_str().map(|v| v.to_string())
49                } else {
50                    panic!("Config key `lib_name` must be a string");
51                }
52            }
53            "unsafe_references_in_callbacks" => {
54                if value.is_bool() {
55                    self.unsafe_references_in_callbacks = value.as_bool()
56                } else {
57                    panic!("Config key `unsafe_references_in_callbacks` must be a boolean");
58                }
59            }
60            "custom_extra_code_location" => {
61                if value.is_str() {
62                    self.custom_extra_code_location = PathBuf::from(value.as_str().unwrap())
63                } else {
64                    panic!("Config key `custom_extra_code_location` must be a string");
65                }
66            }
67            "features_enabled" => {
68                let hash_set = match &value {
69                    Value::Array(arr) => {
70                        let str_arr : HashSet<String> = arr.iter().map(|v| {
71                            let st = v.as_str().unwrap_or_else(|| panic!("Expected features_enabled=[] to be an array of strings. Got {v:?}"));
72                            st.to_string()
73                        }).collect();
74                        str_arr
75                    }
76                    Value::Table(t) if t.len() == 1 => t.keys().cloned().collect(),
77                    Value::String(st) => {
78                        // Serde Toml has screwed up reading an array:
79                        if st.starts_with("[") && st.ends_with("]") {
80                            let slice = &st[1..st.len() - 1];
81                            let hash = slice
82                                .split(",")
83                                .map(|s| s.replace("\"", "").trim().to_string())
84                                .collect();
85                            hash
86                        } else {
87                            HashSet::from([st.clone()])
88                        }
89                    }
90                    _ => panic!("Config key `features_enabled` must be an array or string."),
91                };
92                self.features_enabled = hash_set;
93            }
94            _ => (),
95        }
96    }
97
98    pub fn lowering_config(&self) -> LoweringConfig {
99        let mut cfg = LoweringConfig::default();
100        if let Some(refs) = self.unsafe_references_in_callbacks {
101            cfg.unsafe_references_in_callbacks = refs;
102        }
103        cfg
104    }
105}
106
/// The complete tool configuration: [`SharedConfig`] settings plus one section per backend.
///
/// It is populated from a TOML file ([`Config::read_file`]) and/or `key=value` command-line
/// settings ([`Config::read_cli_settings`]). [`Config::get_overridden`] then produces the view for
/// one language, with that language's overrides of the shared settings applied.
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct Config {
    /// Top-level settings. Flattened, so in the serialized form they sit at the root rather
    /// than under a section.
    #[serde(flatten)]
    pub shared_config: SharedConfig,
    /// Settings from the `kotlin` section.
    #[serde(rename = "kotlin")]
    pub kotlin_config: KotlinConfig,
    /// Settings from the `demo_gen` section.
    #[serde(rename = "demo_gen")]
    pub demo_gen_config: DemoConfig,
    /// Settings from the `js` section.
    #[serde(rename = "js")]
    pub js_config: JsConfig,
    /// Settings from the `cpp` section.
    #[serde(rename = "cpp")]
    pub cpp_config: CppConfig,
    /// Any language can override what's in [`SharedConfig`]. This is a structure that holds information about those specific overrides. [`Config`] will update [`SharedConfig`] based on the current language.
    // Keys are stored in full form, e.g. `"kotlin.lib_name"`.
    #[serde(skip)]
    pub language_overrides: HashMap<String, Value>,
}
123
124impl Config {
125    pub fn set(&mut self, key: &str, value: Value) {
126        if key.starts_with("kotlin.") {
127            if SharedConfig::overrides_shared(key) {
128                self.language_overrides.insert(key.to_string(), value);
129            } else {
130                self.kotlin_config.set(&key.replace("kotlin.", ""), value);
131            }
132        } else if key.starts_with("demo_gen.") {
133            if SharedConfig::overrides_shared(key) {
134                self.language_overrides.insert(key.to_string(), value);
135            } else {
136                self.demo_gen_config
137                    .set(&key.replace("demo_gen.", ""), value);
138            }
139        } else if key.starts_with("nanobind.") {
140            if SharedConfig::overrides_shared(key) {
141                self.language_overrides.insert(key.to_string(), value);
142            } // nanobind doesn't have any other config setting
143        } else if key.starts_with("js.") {
144            if SharedConfig::overrides_shared(key) {
145                self.language_overrides.insert(key.to_string(), value);
146            } else {
147                self.js_config.set(&key.replace("js.", ""), value);
148            }
149        } else if key.starts_with("cpp.") {
150            if SharedConfig::overrides_shared(key) {
151                self.language_overrides.insert(key.to_string(), value);
152            } else {
153                self.cpp_config.set(&key.replace("cpp.", ""), value);
154            }
155        } else {
156            self.shared_config.set(key, value)
157        }
158    }
159
160    pub fn get_overridden(self, target_language: &str) -> Self {
161        let mut out = self.clone();
162
163        // Look for a match of language_name.some_value in a potential key.
164        let m = format!("{target_language}.");
165        for (k, v) in out.language_overrides.iter() {
166            if k.starts_with(&m) {
167                out.shared_config.set(&k.replace(&m, ""), v.clone());
168            }
169        }
170        out
171    }
172
173    /// Given a filepath, read TOML formatted config settings from it (and modify the current Config struct from the read)
174    pub fn read_file(&mut self, path: &Path) -> Result<(), String> {
175        let config_table: Table = if path.exists() {
176            let file_buf = std::fs::read(path).map_err(|e| e.to_string())?;
177            let s = str::from_utf8(&file_buf).map_err(|_| "Config file is not UTF8".to_string())?;
178            toml::from_str(s).map_err(|_| "Config file is not valid TOML".to_string())?
179        } else {
180            Table::default()
181        };
182
183        for (key, value) in config_table {
184            // Quick way to take config.toml from kebab to snake case.
185            // This technically means that someone could also just as easily do CamelCase and have it translated,
186            // but I'm not sure I want to bother writing validation code for such a scenario.
187            let key = heck::AsSnakeCase(key).to_string();
188            if let toml::Value::Table(t) = value {
189                for (subkey, subvalue) in t {
190                    let subkey = heck::AsSnakeCase(subkey).to_string();
191                    self.set(&format!("{key}.{subkey}"), subvalue);
192                }
193            } else {
194                self.set(&key, value);
195            }
196        }
197        Ok(())
198    }
199
200    /// Given a vector of strings with the format `config.setting = value`, modify the `Config` struct appropriately.
201    pub fn read_cli_settings(&mut self, settings: Vec<String>) {
202        for c in settings {
203            let split = c.split_once("=");
204            if let Some((key, value)) = split {
205                self.set(key, toml_value_from_str(value));
206            } else {
207                eprintln!("Could not read {c}, expected =");
208            }
209        }
210    }
211}
212
213pub fn toml_value_from_str(string: &str) -> toml::Value {
214    let try_parse = toml::from_str::<toml::Value>(string);
215
216    // If there's an error parsing (because clap will not parse quotes, for example), we just treat what we're passed as a string:
217    // toml from_str
218    if let Ok(out) = try_parse {
219        out
220    } else {
221        toml::Value::String(string.to_string())
222    }
223}
224
/// The parsed arguments of one `#[diplomat::config(...)]` attribute: a comma-separated list of
/// `key = value` pairs.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)]
#[non_exhaustive]
pub(crate) struct DiplomatBackendConfigAttr {
    /// The pairs in the order they appear in the attribute.
    pub key_value_pairs: Vec<DiplomatBackendConfigKeyValue>,
}
230
231impl Parse for DiplomatBackendConfigAttr {
232    fn parse(input: ParseStream) -> syn::Result<Self> {
233        let list = input.parse_terminated(DiplomatBackendConfigKeyValue::parse, Token![,])?;
234        let vec = list.into_iter().collect();
235        Ok(Self {
236            key_value_pairs: vec,
237        })
238    }
239}
240
/// One `key = value` entry inside a `#[diplomat::config(...)]` attribute.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)]
#[non_exhaustive]
pub(crate) struct DiplomatBackendConfigKeyValue {
    /// The key: one or more identifiers, stored joined with `.` (e.g. `a.b.c`).
    pub key: String,
    /// The right-hand-side expression, stringified from its token stream (not evaluated).
    pub value: String,
}
247
248impl Parse for DiplomatBackendConfigKeyValue {
249    fn parse(input: ParseStream) -> syn::Result<Self> {
250        let mut key_str: Vec<String> = Vec::new();
251
252        loop {
253            let i: Ident = input.parse()?;
254
255            key_str.push(i.to_string());
256
257            if input.peek(Token![.]) {
258                let _period: Token![.] = input.parse()?;
259            } else {
260                break;
261            }
262        }
263
264        let _equals: Token![=] = input.parse()?;
265
266        let val_expr: Expr = input.parse()?;
267
268        let value = val_expr.to_token_stream().to_string();
269
270        Ok(Self {
271            key: key_str.join("."),
272            value,
273        })
274    }
275}
276
277pub(crate) fn find_top_level_attr(module_items: Vec<syn::Item>) -> Vec<DiplomatBackendConfigAttr> {
278    let path = syn::parse_str("diplomat::config").unwrap();
279
280    let attrs = module_items
281        .iter()
282        .filter_map(|i| match i {
283            syn::Item::Struct(s) => Some(s.attrs.clone()),
284            syn::Item::Impl(i) => Some(i.attrs.clone()),
285            syn::Item::Mod(m) => Some(m.attrs.clone()),
286            _ => None,
287        })
288        .filter_map(|attrs| {
289            let attributes_vec = attrs
290                .iter()
291                .filter_map(|attribute| {
292                    if attribute.path() == &path {
293                        Some(
294                            syn::parse2::<DiplomatBackendConfigAttr>(
295                                attribute
296                                    .parse_args()
297                                    .expect("Failed to parse malformed diplomat::config"),
298                            )
299                            .expect("Could not parse DiplomatBackendConfig attribute."),
300                        )
301                    } else {
302                        None
303                    }
304                })
305                .collect::<Vec<_>>();
306
307            if !attributes_vec.is_empty() {
308                Some(attributes_vec)
309            } else {
310                None
311            }
312        });
313
314    let mut out_config = Vec::new();
315
316    for mut a in attrs {
317        out_config.append(&mut a);
318    }
319
320    out_config
321}
322
#[cfg(test)]
mod test {
    use toml::Value;

    /// A bare `true` is not a valid TOML document, so `toml_value_from_str` must fall back to
    /// returning it as a string.
    #[test]
    fn test_toml_parse() {
        let input = "true";
        let as_document = toml::from_str::<Value>(input);
        assert!(as_document.is_err());

        let fallback = super::toml_value_from_str(input);
        assert_eq!(fallback, Value::String(String::from(input)));
    }
}