Skip to main content

pike/commands/config/
apply.rs

1use anyhow::{bail, Context, Result};
2use derive_builder::Builder;
3use log::info;
4use serde::Deserialize;
5use std::{
6    collections::HashMap,
7    env, fs,
8    io::{BufRead, BufReader, Read, Write},
9    path::{Path, PathBuf},
10    process::{self, Command, Stdio},
11};
12use toml_edit::DocumentMut;
13
14/// Mapping of plugin service names to their properties specified in
15/// [plugin configuration](https://github.com/picodata/pike?tab=readme-ov-file#config-apply).
16///
17/// ### Example:
18///
19/// Assume plugin configuration YAML has the following content:
20///
21/// **`plugin_config.yaml`:**
22/// ```yaml
23/// service_name:
24///   http_server:
25///     url: "www.example.com"
26/// ```
27/// Mapping for such config is supposed to look like:
28///
29/// ```rust,no_run
30/// use std::collections::HashMap;
31///
32/// let plugin_config = HashMap::from([(
33///     // Name of the service
34///     "service_name".to_string(),
35///
36///     // Mapping of properties corresponding to the service
37///     HashMap::from([(
38///         "http_server".to_string(),
39///         serde_yaml::to_value(HashMap::from([(
40///             "url".to_string(),
41///             // URL is overridden for testing.
42///             "localhost:29092".to_string(),
43///         )]))
44///         .unwrap(),
45///     )]),
46/// )]);
47/// ```
48pub type ConfigMap = HashMap<String, HashMap<String, serde_yaml::Value>>;
49
/// Config file used when no explicit config source is given; it is resolved
/// relative to each plugin directory.
const DEFAULT_PLUGIN_CONFIG_PATH: &str = "plugin_config.yaml";
/// Banner printed by `cmd` right before it exits with status 1, when a
/// non-default config path is used on a Cargo workspace without
/// `--plugin-name`.
const WISE_PIKE: &str = r"
  ________________________________________
/ You are trying to apply config from     \
| custom directory, however to use this   |
| flag, you must specify the plugin with  |
\           --plugin-name                 /
 ----------------------------------------
 o
o      ______/~/~/~/__           /((
  o  // __            ====__    /_((
 o  //  @))       ))))      ===/__((
    ))           )))))))        __((
    \\     \)     ))))    __===\ _((
     \\_______________====      \_((
                                 \((
 ";
67
68fn read_config_from_path(path: &PathBuf) -> Result<ConfigMap> {
69    serde_yaml::from_str(
70        &fs::read_to_string(path)
71            .context(format!("failed to read config file at {}", path.display()))?,
72    )
73    .context(format!(
74        "failed to parse config file at {} as toml",
75        path.display()
76    ))
77}
78
79fn apply_service_config(
80    plugin_name: &str,
81    plugin_version: &str,
82    service_name: &str,
83    config: &HashMap<String, serde_yaml::Value>,
84    admin_socket: &Path,
85    picodata_path: &Path,
86) -> Result<()> {
87    let mut queries: Vec<String> = Vec::new();
88
89    for (key, value) in config {
90        let value = serde_json::to_string(&value)
91            .context(format!("failed to serialize the string with key {key}"))?;
92        queries.push(format!(
93            r#"ALTER PLUGIN "{plugin_name}" {plugin_version} SET "{service_name}"."{key}"='{value}';"#
94        ));
95    }
96
97    for query in queries {
98        log::info!("picodata admin: {query}");
99
100        let mut picodata_admin = Command::new(picodata_path)
101            .arg("admin")
102            .arg(
103                admin_socket
104                    .to_str()
105                    .context("path to picodata admin socket contains invalid characters")?,
106            )
107            .stdout(Stdio::piped())
108            .stderr(Stdio::piped())
109            .stdin(Stdio::piped())
110            .spawn()
111            .context("failed to run picodata admin")?;
112
113        {
114            let picodata_stdin = picodata_admin
115                .stdin
116                .as_mut()
117                .context("failed to get picodata stdin")?;
118            picodata_stdin
119                .write_all(query.as_bytes())
120                .context("failed to push queries into picodata admin")?;
121        }
122
123        let exit_status = picodata_admin
124            .wait()
125            .context("failed to wait for picodata admin")?
126            .code()
127            .unwrap();
128
129        let outputs: [Box<dyn Read + Send>; 2] = [
130            Box::new(picodata_admin.stdout.unwrap()),
131            Box::new(picodata_admin.stderr.unwrap()),
132        ];
133        for output in outputs {
134            let reader = BufReader::new(output);
135            for line in reader.lines() {
136                let line = line.expect("failed to read picodata admin output");
137                log::info!("picodata admin: {line}");
138            }
139        }
140
141        if exit_status == 1 {
142            bail!("failed to execute picodata query {query}");
143        }
144    }
145
146    Ok(())
147}
148
149fn apply_plugin_config(params: &Params, current_plugin_path: &str) -> Result<()> {
150    let cur_plugin_dir = env::current_dir()?
151        .join(&params.plugin_path)
152        .join(current_plugin_path);
153
154    let admin_socket = params
155        .plugin_path
156        .join(&params.data_dir)
157        .join("cluster")
158        .join("i1")
159        .join("admin.sock");
160
161    let cargo_manifest: &CargoManifest = &toml::from_str(
162        &fs::read_to_string(cur_plugin_dir.join("Cargo.toml"))
163            .context("failed to read Cargo.toml")?,
164    )
165    .context("failed to parse Cargo.toml")?;
166
167    let config: ConfigMap = match &params.config_source {
168        ConfigSource::Map(map) => map.clone(),
169        ConfigSource::Path(path) => read_config_from_path(&cur_plugin_dir.join(path))?,
170    };
171
172    for (service_name, service_config) in config {
173        apply_service_config(
174            &cargo_manifest.package.name,
175            &cargo_manifest.package.version,
176            &service_name,
177            &service_config,
178            &admin_socket,
179            &params.picodata_path,
180        )
181        .context(format!(
182            "failed to apply service config for service {service_name}"
183        ))?;
184    }
185
186    Ok(())
187}
188
189#[derive(Debug, Deserialize)]
190struct Package {
191    name: String,
192    version: String,
193}
194
195#[derive(Debug, Deserialize)]
196struct CargoManifest {
197    package: Package,
198}
199
200#[derive(Debug, Clone)]
201pub enum ConfigSource {
202    Map(ConfigMap),
203    Path(PathBuf),
204}
205
206impl Default for ConfigSource {
207    fn default() -> Self {
208        ConfigSource::Path(DEFAULT_PLUGIN_CONFIG_PATH.into())
209    }
210}
211
/// Options of the `config apply` command, constructed through the
/// derive_builder-generated [`ParamsBuilder`].
#[derive(Debug, Builder)]
pub struct Params {
    /// Source of the configuration to apply. It is set through the custom
    /// `config_path` / `config_map` builder methods, and defaults to
    /// `ConfigSource::default()` (`plugin_config.yaml`).
    #[builder(default, setter(custom))]
    config_source: ConfigSource,
    /// Cluster data directory, joined onto `plugin_path` to locate the admin
    /// socket. Defaults to `./tmp`.
    #[builder(default = "PathBuf::from(\"./tmp\")")]
    data_dir: PathBuf,
    /// Project root, joined onto the current working directory. Defaults to
    /// `./`.
    #[builder(default = "PathBuf::from(\"./\")")]
    plugin_path: PathBuf,
    /// When set, only the plugin at `<plugin_path>/<plugin_name>` is
    /// configured, skipping workspace discovery.
    #[builder(default)]
    plugin_name: Option<String>,
    /// Executable used to run `picodata admin`. The default, `picodata`, is
    /// looked up on `PATH` by `Command::new`.
    #[builder(default = "PathBuf::from(\"picodata\")")]
    picodata_path: PathBuf,
}
225
226impl ParamsBuilder {
227    pub fn config_path(&mut self, path: PathBuf) -> &mut Self {
228        self.config_source = Some(ConfigSource::Path(path));
229        self
230    }
231
232    #[allow(unused)]
233    pub fn config_map(&mut self, map: ConfigMap) -> &mut Self {
234        self.config_source = Some(ConfigSource::Map(map));
235        self
236    }
237}
238
239pub fn cmd(params: &Params) -> Result<()> {
240    // If plugin name flag was specified, apply config only for
241    // this exact plugin
242    if let Some(plugin_name) = &params.plugin_name {
243        info!("Applying plugin config for plugin {plugin_name}");
244        apply_plugin_config(params, plugin_name)?;
245        return Ok(());
246    }
247
248    let root_dir = env::current_dir()?.join(&params.plugin_path);
249
250    let cargo_toml_path = root_dir.join("Cargo.toml");
251    let cargo_toml_content = fs::read_to_string(&cargo_toml_path).context(format!(
252        "Failed to read Cargo.toml in {}",
253        &cargo_toml_path.display()
254    ))?;
255
256    let parsed_toml: DocumentMut = cargo_toml_content
257        .parse()
258        .context("Failed to parse Cargo.toml")?;
259
260    if let Some(workspace) = parsed_toml.get("workspace") {
261        if let ConfigSource::Path(config_path) = &params.config_source {
262            if config_path.to_str().unwrap() != DEFAULT_PLUGIN_CONFIG_PATH {
263                println!("{WISE_PIKE}");
264                process::exit(1);
265            }
266        }
267        info!("Applying plugin config for each plugin");
268
269        if let Some(members) = workspace.get("members") {
270            if let Some(members_array) = members.as_array() {
271                for member in members_array {
272                    let member_str = member.as_str();
273                    if member_str.is_none() {
274                        continue;
275                    }
276
277                    if !root_dir
278                        .join(member_str.unwrap())
279                        .join("manifest.yaml.template")
280                        .exists()
281                    {
282                        continue;
283                    }
284                    apply_plugin_config(params, member_str.unwrap())?;
285                }
286            }
287        }
288
289        return Ok(());
290    }
291
292    info!("Applying plugin config");
293
294    apply_plugin_config(params, "./")?;
295
296    Ok(())
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Builds a unique path under the system temp directory. The directory is
    /// not created; tests that need it must create (and remove) it themselves.
    fn tmp_dir(prefix: &str) -> PathBuf {
        let ts = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let mut dir = env::temp_dir();
        dir.push(format!("pike-config-apply-ut-{prefix}-{ts}"));
        dir
    }

    /// Writes `content` to a fresh temp config file, parses it with
    /// `read_config_from_path`, and then removes the temp directory.
    fn read_config_str(prefix: &str, content: &str) -> Result<ConfigMap> {
        let dir = tmp_dir(prefix);
        fs::create_dir_all(&dir).expect("create temp dir");
        let cfg = dir.join("plugin_config.yaml");
        fs::write(&cfg, content).expect("write temp config");
        let res = read_config_from_path(&cfg);
        // Best-effort cleanup; a leftover temp dir must not fail the test.
        let _ = fs::remove_dir_all(&dir);
        res
    }

    #[test]
    fn apply_service_config_uses_custom_picodata_path_and_fails_cleanly() {
        let mut service_cfg: HashMap<String, serde_yaml::Value> = HashMap::new();
        service_cfg.insert(
            "k".to_string(),
            serde_yaml::to_value("v").expect("serialize test value"),
        );

        let bogus_picodata = PathBuf::from("/this/does/not/exist/picodata-bogus");
        let bogus_socket = Path::new("/tmp/nonexistent-admin.sock");

        let err = apply_service_config(
            "p",
            "0.1.0",
            "svc",
            &service_cfg,
            bogus_socket,
            &bogus_picodata,
        )
        .unwrap_err();

        let msg = format!("{err:#}");
        assert!(
            msg.contains("failed to run picodata admin"),
            "expected process spawn error context, got: {msg}"
        );
    }

    #[test]
    fn params_builder_has_default_picodata_path() {
        let params = ParamsBuilder::default().build().unwrap();
        assert_eq!(params.picodata_path, PathBuf::from("picodata"));
    }

    #[test]
    fn read_config_from_path_reports_read_error() {
        let dir = tmp_dir("cfg");
        let cfg = dir.join("no-file.yaml");
        let res = read_config_from_path(&cfg);
        assert!(res.is_err());
    }

    #[test]
    fn read_config_from_path_parses_service_mapping() {
        let config = read_config_str(
            "cfg-ok",
            "service_name:\n  http_server:\n    url: \"www.example.com\"\n",
        )
        .expect("valid config must parse");

        let url = config["service_name"]["http_server"]
            .get("url")
            .and_then(serde_yaml::Value::as_str);
        assert_eq!(url, Some("www.example.com"));
    }

    #[test]
    fn read_config_from_path_rejects_non_mapping_yaml() {
        // Valid YAML, but a sequence cannot deserialize into a ConfigMap.
        let res = read_config_str("cfg-bad", "- not\n- a mapping\n");
        assert!(res.is_err());
    }
}