Skip to main content

nu_cmd_base/
hook.rs

1use miette::Result;
2use nu_engine::{eval_block, eval_block_with_early_return, redirect_env};
3use nu_parser::parse;
4use nu_protocol::{
5    PipelineData, PositionalArg, ShellError, Span, Type, Value, VarId,
6    debugger::WithoutDebug,
7    engine::{Closure, EngineState, EnvName, Stack, StateWorkingSet},
8    report_error::{report_parse_error, report_shell_error},
9    shell_error::generic::GenericError,
10};
11use std::{collections::HashMap, sync::Arc};
12
13pub fn eval_env_change_hook(
14    env_change_hook: &HashMap<EnvName, Vec<Value>>,
15    engine_state: &mut EngineState,
16    stack: &mut Stack,
17) -> Result<(), ShellError> {
18    for (env, hooks) in env_change_hook {
19        let before = engine_state.previous_env_vars.get(env);
20        let after = stack.get_env_var(engine_state, env.as_str());
21        if before != after {
22            let before = before.cloned().unwrap_or_default();
23            let after = after.cloned().unwrap_or_default();
24
25            eval_hooks(
26                engine_state,
27                stack,
28                vec![("$before".into(), before), ("$after".into(), after.clone())],
29                hooks,
30                "env_change",
31            )?;
32
33            Arc::make_mut(&mut engine_state.previous_env_vars).insert(env.clone(), after);
34        }
35    }
36
37    Ok(())
38}
39
40pub fn eval_hooks(
41    engine_state: &mut EngineState,
42    stack: &mut Stack,
43    arguments: Vec<(String, Value)>,
44    hooks: &[Value],
45    hook_name: &str,
46) -> Result<(), ShellError> {
47    for hook in hooks {
48        eval_hook(
49            engine_state,
50            stack,
51            None,
52            arguments.clone(),
53            hook,
54            &format!("{hook_name} list, recursive"),
55        )?;
56    }
57    Ok(())
58}
59
60pub fn eval_hook(
61    engine_state: &mut EngineState,
62    stack: &mut Stack,
63    input: Option<PipelineData>,
64    arguments: Vec<(String, Value)>,
65    value: &Value,
66    hook_name: &str,
67) -> Result<PipelineData, ShellError> {
68    let mut output = PipelineData::empty();
69
70    let span = value.span();
71    match value {
72        Value::String { val, .. } => {
73            let (block, delta, vars) = {
74                let mut working_set = StateWorkingSet::new(engine_state);
75
76                let mut vars: Vec<(VarId, Value)> = vec![];
77
78                for (name, val) in arguments {
79                    let var_id = working_set.add_variable(
80                        name.as_bytes().to_vec(),
81                        val.span(),
82                        Type::Any,
83                        false,
84                    );
85                    vars.push((var_id, val));
86                }
87
88                let output = parse(
89                    &mut working_set,
90                    Some(&format!("{hook_name} hook")),
91                    val.as_bytes(),
92                    false,
93                );
94                if let Some(err) = working_set.parse_errors.first() {
95                    report_parse_error(Some(stack), &working_set, err);
96                    return Err(ShellError::Generic(GenericError::new(
97                        format!("Failed to run {hook_name} hook"),
98                        "source code has errors",
99                        span,
100                    )));
101                }
102
103                (output, working_set.render(), vars)
104            };
105
106            engine_state.merge_delta(delta)?;
107            let input = if let Some(input) = input {
108                input
109            } else {
110                PipelineData::empty()
111            };
112
113            let var_ids: Vec<VarId> = vars
114                .into_iter()
115                .map(|(var_id, val)| {
116                    stack.add_var(var_id, val);
117                    var_id
118                })
119                .collect();
120
121            match eval_block::<WithoutDebug>(engine_state, stack, &block, input).map(|p| p.body) {
122                Ok(pipeline_data) => {
123                    output = pipeline_data;
124                }
125                Err(err) => {
126                    report_shell_error(Some(stack), engine_state, &err);
127                }
128            }
129
130            for var_id in var_ids.iter() {
131                stack.remove_var(*var_id);
132            }
133        }
134        Value::List { vals, .. } => {
135            eval_hooks(engine_state, stack, arguments, vals, hook_name)?;
136        }
137        Value::Record { val, .. } => {
138            // Hooks can optionally be a record in this form:
139            // {
140            //     condition: {|before, after| ... }  # block that evaluates to true/false
141            //     code: # block or a string
142            // }
143            // The condition block will be run to check whether the main hook (in `code`) should be run.
144            // If it returns true (the default if a condition block is not specified), the hook should be run.
145            let do_run_hook = if let Some(condition) = val.get("condition") {
146                let other_span = condition.span();
147                if let Ok(closure) = condition.as_closure() {
148                    match run_hook(
149                        engine_state,
150                        stack,
151                        closure,
152                        None,
153                        arguments.clone(),
154                        other_span,
155                    ) {
156                        Ok(pipeline_data) => {
157                            if let PipelineData::Value(Value::Bool { val, .. }, ..) = pipeline_data
158                            {
159                                val
160                            } else {
161                                return Err(ShellError::RuntimeTypeMismatch {
162                                    expected: Type::Bool,
163                                    actual: pipeline_data.get_type(),
164                                    span: pipeline_data.span().unwrap_or(other_span),
165                                });
166                            }
167                        }
168                        Err(err) => {
169                            return Err(err);
170                        }
171                    }
172                } else {
173                    return Err(ShellError::RuntimeTypeMismatch {
174                        expected: Type::Closure,
175                        actual: condition.get_type(),
176                        span: other_span,
177                    });
178                }
179            } else {
180                // always run the hook
181                true
182            };
183
184            if do_run_hook {
185                let Some(follow) = val.get("code") else {
186                    return Err(ShellError::CantFindColumn {
187                        col_name: "code".into(),
188                        span: Some(span),
189                        src_span: span,
190                    });
191                };
192                let source_span = follow.span();
193                match follow {
194                    Value::String { val, .. } => {
195                        let (block, delta, vars) = {
196                            let mut working_set = StateWorkingSet::new(engine_state);
197
198                            let mut vars: Vec<(VarId, Value)> = vec![];
199
200                            for (name, val) in arguments {
201                                let var_id = working_set.add_variable(
202                                    name.as_bytes().to_vec(),
203                                    val.span(),
204                                    Type::Any,
205                                    false,
206                                );
207                                vars.push((var_id, val));
208                            }
209
210                            let output = parse(
211                                &mut working_set,
212                                Some(&format!("{hook_name} hook")),
213                                val.as_bytes(),
214                                false,
215                            );
216                            if let Some(err) = working_set.parse_errors.first() {
217                                report_parse_error(Some(stack), &working_set, err);
218                                return Err(ShellError::Generic(GenericError::new(
219                                    format!("Failed to run {hook_name} hook"),
220                                    "source code has errors",
221                                    span,
222                                )));
223                            }
224
225                            (output, working_set.render(), vars)
226                        };
227
228                        engine_state.merge_delta(delta)?;
229                        let input = PipelineData::empty();
230
231                        let var_ids: Vec<VarId> = vars
232                            .into_iter()
233                            .map(|(var_id, val)| {
234                                stack.add_var(var_id, val);
235                                var_id
236                            })
237                            .collect();
238
239                        match eval_block::<WithoutDebug>(engine_state, stack, &block, input)
240                            .map(|p| p.body)
241                        {
242                            Ok(pipeline_data) => {
243                                output = pipeline_data;
244                            }
245                            Err(err) => {
246                                report_shell_error(Some(stack), engine_state, &err);
247                            }
248                        }
249
250                        for var_id in var_ids.iter() {
251                            stack.remove_var(*var_id);
252                        }
253                    }
254                    Value::Closure { val, .. } => {
255                        run_hook(engine_state, stack, val, input, arguments, source_span)?;
256                    }
257                    other => {
258                        return Err(ShellError::RuntimeTypeMismatch {
259                            expected: Type::custom("string or closure"),
260                            actual: other.get_type(),
261                            span: source_span,
262                        });
263                    }
264                }
265            }
266        }
267        Value::Closure { val, .. } => {
268            output = run_hook(engine_state, stack, val, input, arguments, span)?;
269        }
270        other => {
271            return Err(ShellError::RuntimeTypeMismatch {
272                expected: Type::custom("string, closure, record, or list"),
273                actual: other.get_type(),
274                span: other.span(),
275            });
276        }
277    }
278
279    engine_state.merge_env(stack)?;
280
281    Ok(output)
282}
283
284fn run_hook(
285    engine_state: &EngineState,
286    stack: &mut Stack,
287    closure: &Closure,
288    optional_input: Option<PipelineData>,
289    arguments: Vec<(String, Value)>,
290    span: Span,
291) -> Result<PipelineData, ShellError> {
292    let block = engine_state.get_block(closure.block_id);
293
294    let input = optional_input.unwrap_or_else(PipelineData::empty);
295
296    let mut callee_stack = stack
297        .captures_to_stack_preserve_out_dest(closure.captures.clone())
298        .reset_pipes();
299
300    for (idx, PositionalArg { var_id, .. }) in
301        block.signature.required_positional.iter().enumerate()
302    {
303        if let Some(var_id) = var_id {
304            if let Some(arg) = arguments.get(idx) {
305                callee_stack.add_var(*var_id, arg.1.clone())
306            } else {
307                return Err(ShellError::IncompatibleParametersSingle {
308                    msg: "This hook block has too many parameters".into(),
309                    span,
310                });
311            }
312        }
313    }
314
315    let pipeline_data = eval_block_with_early_return::<WithoutDebug>(
316        engine_state,
317        &mut callee_stack,
318        block,
319        input,
320    )?
321    .body;
322
323    if let PipelineData::Value(Value::Error { error, .. }, _) = pipeline_data {
324        return Err(*error);
325    }
326
327    // If all went fine, preserve the environment of the called block
328    redirect_env(engine_state, stack, &callee_stack);
329
330    Ok(pipeline_data)
331}