1use std::{
2 collections::{HashMap, HashSet},
3 path::{Path, PathBuf},
4 str,
5};
6
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use syn::{
10 parse::{Parse, ParseStream},
11 Expr, Ident, Token,
12};
13use toml::{value::Table, Value};
14
15use crate::{cpp::CppConfig, demo_gen::DemoConfig, js::JsConfig, kotlin::KotlinConfig};
16use diplomat_core::hir::LoweringConfig;
17
/// Configuration shared by every backend.
///
/// Individual backends may override these values through backend-scoped keys
/// such as `kotlin.lib_name`; see `Config::get_overridden`.
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct SharedConfig {
    /// Library name, set through the `lib_name` key. `None` when not configured.
    pub lib_name: Option<String>,
    /// When `Some`, copied into the `LoweringConfig` built by
    /// `SharedConfig::lowering_config`; `None` keeps the lowering default.
    pub unsafe_references_in_callbacks: Option<bool>,
    /// Set through the `custom_extra_code_location` key; defaults to an empty path.
    /// NOTE(review): how backends consume this path is not visible here — confirm.
    pub custom_extra_code_location: PathBuf,
    /// Feature names set through the `features_enabled` key.
    pub features_enabled: HashSet<String>,
}
29
30impl SharedConfig {
31 pub fn overrides_shared(name: &str) -> bool {
33 let name: String = name.split(".").skip(1).collect();
35 matches!(
36 name.as_str(),
37 "lib_name"
38 | "unsafe_references_in_callbacks"
39 | "custom_extra_code_location"
40 | "features_enabled"
41 )
42 }
43
44 pub fn set(&mut self, key: &str, value: Value) {
45 match key {
46 "lib_name" => {
47 if value.is_str() {
48 self.lib_name = value.as_str().map(|v| v.to_string())
49 } else {
50 panic!("Config key `lib_name` must be a string");
51 }
52 }
53 "unsafe_references_in_callbacks" => {
54 if value.is_bool() {
55 self.unsafe_references_in_callbacks = value.as_bool()
56 } else {
57 panic!("Config key `unsafe_references_in_callbacks` must be a boolean");
58 }
59 }
60 "custom_extra_code_location" => {
61 if value.is_str() {
62 self.custom_extra_code_location = PathBuf::from(value.as_str().unwrap())
63 } else {
64 panic!("Config key `custom_extra_code_location` must be a string");
65 }
66 }
67 "features_enabled" => {
68 let hash_set = match &value {
69 Value::Array(arr) => {
70 let str_arr : HashSet<String> = arr.iter().map(|v| {
71 let st = v.as_str().unwrap_or_else(|| panic!("Expected features_enabled=[] to be an array of strings. Got {v:?}"));
72 st.to_string()
73 }).collect();
74 str_arr
75 }
76 Value::Table(t) if t.len() == 1 => t.keys().cloned().collect(),
77 Value::String(st) => {
78 if st.starts_with("[") && st.ends_with("]") {
80 let slice = &st[1..st.len() - 1];
81 let hash = slice
82 .split(",")
83 .map(|s| s.replace("\"", "").trim().to_string())
84 .collect();
85 hash
86 } else {
87 HashSet::from([st.clone()])
88 }
89 }
90 _ => panic!("Config key `features_enabled` must be an array or string."),
91 };
92 self.features_enabled = hash_set;
93 }
94 _ => (),
95 }
96 }
97
98 pub fn lowering_config(&self) -> LoweringConfig {
99 let mut cfg = LoweringConfig::default();
100 if let Some(refs) = self.unsafe_references_in_callbacks {
101 cfg.unsafe_references_in_callbacks = refs;
102 }
103 cfg
104 }
105}
106
/// Complete generator configuration: shared settings plus per-backend sections.
///
/// Populated from a TOML file (`Config::read_file`) and/or `key=value` settings
/// (`Config::read_cli_settings`).
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct Config {
    /// Settings shared by every backend. Flattened, so its fields appear at the top
    /// level when (de)serialized.
    #[serde(flatten)]
    pub shared_config: SharedConfig,
    /// Kotlin backend settings (`kotlin.*` keys).
    #[serde(rename = "kotlin")]
    pub kotlin_config: KotlinConfig,
    /// demo_gen backend settings (`demo_gen.*` keys).
    #[serde(rename = "demo_gen")]
    pub demo_gen_config: DemoConfig,
    /// JS backend settings (`js.*` keys).
    #[serde(rename = "js")]
    pub js_config: JsConfig,
    /// C++ backend settings (`cpp.*` keys).
    #[serde(rename = "cpp")]
    pub cpp_config: CppConfig,
    /// Backend-scoped overrides of shared keys, keyed by the full dotted key
    /// (e.g. `"js.lib_name"`). Applied by `Config::get_overridden`; never serialized.
    #[serde(skip)]
    pub language_overrides: HashMap<String, Value>,
}
123
124impl Config {
125 pub fn set(&mut self, key: &str, value: Value) {
126 if key.starts_with("kotlin.") {
127 if SharedConfig::overrides_shared(key) {
128 self.language_overrides.insert(key.to_string(), value);
129 } else {
130 self.kotlin_config.set(&key.replace("kotlin.", ""), value);
131 }
132 } else if key.starts_with("demo_gen.") {
133 if SharedConfig::overrides_shared(key) {
134 self.language_overrides.insert(key.to_string(), value);
135 } else {
136 self.demo_gen_config
137 .set(&key.replace("demo_gen.", ""), value);
138 }
139 } else if key.starts_with("nanobind.") {
140 if SharedConfig::overrides_shared(key) {
141 self.language_overrides.insert(key.to_string(), value);
142 } } else if key.starts_with("js.") {
144 if SharedConfig::overrides_shared(key) {
145 self.language_overrides.insert(key.to_string(), value);
146 } else {
147 self.js_config.set(&key.replace("js.", ""), value);
148 }
149 } else if key.starts_with("cpp.") {
150 if SharedConfig::overrides_shared(key) {
151 self.language_overrides.insert(key.to_string(), value);
152 } else {
153 self.cpp_config.set(&key.replace("cpp.", ""), value);
154 }
155 } else {
156 self.shared_config.set(key, value)
157 }
158 }
159
160 pub fn get_overridden(self, target_language: &str) -> Self {
161 let mut out = self.clone();
162
163 let m = format!("{target_language}.");
165 for (k, v) in out.language_overrides.iter() {
166 if k.starts_with(&m) {
167 out.shared_config.set(&k.replace(&m, ""), v.clone());
168 }
169 }
170 out
171 }
172
173 pub fn read_file(&mut self, path: &Path) -> Result<(), String> {
175 let config_table: Table = if path.exists() {
176 let file_buf = std::fs::read(path).map_err(|e| e.to_string())?;
177 let s = str::from_utf8(&file_buf).map_err(|_| "Config file is not UTF8".to_string())?;
178 toml::from_str(s).map_err(|_| "Config file is not valid TOML".to_string())?
179 } else {
180 Table::default()
181 };
182
183 for (key, value) in config_table {
184 let key = heck::AsSnakeCase(key).to_string();
188 if let toml::Value::Table(t) = value {
189 for (subkey, subvalue) in t {
190 let subkey = heck::AsSnakeCase(subkey).to_string();
191 self.set(&format!("{key}.{subkey}"), subvalue);
192 }
193 } else {
194 self.set(&key, value);
195 }
196 }
197 Ok(())
198 }
199
200 pub fn read_cli_settings(&mut self, settings: Vec<String>) {
202 for c in settings {
203 let split = c.split_once("=");
204 if let Some((key, value)) = split {
205 self.set(key, toml_value_from_str(value));
206 } else {
207 eprintln!("Could not read {c}, expected =");
208 }
209 }
210 }
211}
212
213pub fn toml_value_from_str(string: &str) -> toml::Value {
214 let try_parse = toml::from_str::<toml::Value>(string);
215
216 if let Ok(out) = try_parse {
219 out
220 } else {
221 toml::Value::String(string.to_string())
222 }
223}
224
/// The parsed arguments of one `#[diplomat::config(...)]` attribute: a
/// comma-separated list of `key = value` pairs.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)]
#[non_exhaustive]
pub(crate) struct DiplomatBackendConfigAttr {
    /// The pairs in source order.
    pub key_value_pairs: Vec<DiplomatBackendConfigKeyValue>,
}
230
231impl Parse for DiplomatBackendConfigAttr {
232 fn parse(input: ParseStream) -> syn::Result<Self> {
233 let list = input.parse_terminated(DiplomatBackendConfigKeyValue::parse, Token![,])?;
234 let vec = list.into_iter().collect();
235 Ok(Self {
236 key_value_pairs: vec,
237 })
238 }
239}
240
/// One `key = value` entry from a `#[diplomat::config(...)]` attribute.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)]
#[non_exhaustive]
pub(crate) struct DiplomatBackendConfigKeyValue {
    /// Dotted key built from the identifiers on the left of `=`, joined with `.`
    /// (e.g. `kotlin.lib_name`).
    pub key: String,
    /// The right-hand expression rendered back to text from its token stream.
    /// NOTE(review): token-stream rendering keeps string-literal quotes and may
    /// insert spaces between tokens — confirm consumers expect that.
    pub value: String,
}
247
248impl Parse for DiplomatBackendConfigKeyValue {
249 fn parse(input: ParseStream) -> syn::Result<Self> {
250 let mut key_str: Vec<String> = Vec::new();
251
252 loop {
253 let i: Ident = input.parse()?;
254
255 key_str.push(i.to_string());
256
257 if input.peek(Token![.]) {
258 let _period: Token![.] = input.parse()?;
259 } else {
260 break;
261 }
262 }
263
264 let _equals: Token![=] = input.parse()?;
265
266 let val_expr: Expr = input.parse()?;
267
268 let value = val_expr.to_token_stream().to_string();
269
270 Ok(Self {
271 key: key_str.join("."),
272 value,
273 })
274 }
275}
276
277pub(crate) fn find_top_level_attr(module_items: Vec<syn::Item>) -> Vec<DiplomatBackendConfigAttr> {
278 let path = syn::parse_str("diplomat::config").unwrap();
279
280 let attrs = module_items
281 .iter()
282 .filter_map(|i| match i {
283 syn::Item::Struct(s) => Some(s.attrs.clone()),
284 syn::Item::Impl(i) => Some(i.attrs.clone()),
285 syn::Item::Mod(m) => Some(m.attrs.clone()),
286 _ => None,
287 })
288 .filter_map(|attrs| {
289 let attributes_vec = attrs
290 .iter()
291 .filter_map(|attribute| {
292 if attribute.path() == &path {
293 Some(
294 syn::parse2::<DiplomatBackendConfigAttr>(
295 attribute
296 .parse_args()
297 .expect("Failed to parse malformed diplomat::config"),
298 )
299 .expect("Could not parse DiplomatBackendConfig attribute."),
300 )
301 } else {
302 None
303 }
304 })
305 .collect::<Vec<_>>();
306
307 if !attributes_vec.is_empty() {
308 Some(attributes_vec)
309 } else {
310 None
311 }
312 });
313
314 let mut out_config = Vec::new();
315
316 for mut a in attrs {
317 out_config.append(&mut a);
318 }
319
320 out_config
321}
322
#[cfg(test)]
mod test {
    use toml::Value;

    /// A bare scalar like `true` is not a complete TOML document, so
    /// `toml_value_from_str` must hand it back unchanged as a string.
    #[test]
    fn test_toml_parse() {
        let raw = "true";
        let as_document = toml::from_str::<Value>(raw);
        assert!(as_document.is_err());

        let expected = Value::String(String::from(raw));
        assert_eq!(super::toml_value_from_str(raw), expected);
    }
}