Skip to main content

nu_command/stor/
update.rs

1use crate::database::{MEMORY_DB, SQLiteDatabase, values_to_sql};
2use nu_engine::command_prelude::*;
3use nu_protocol::Signals;
4use nu_protocol::shell_error::generic::GenericError;
5use rusqlite::params_from_iter;
6use std::fmt::Write;
7
8#[derive(Clone)]
9pub struct StorUpdate;
10
11impl Command for StorUpdate {
12    fn name(&self) -> &str {
13        "stor update"
14    }
15
16    fn signature(&self) -> Signature {
17        Signature::build("stor update")
18            .input_output_types(vec![
19                (Type::Nothing, Type::table()),
20                (Type::record(), Type::table()),
21                // FIXME Type::Any input added to disable pipeline input type checking, as run-time checks can raise undesirable type errors
22                // which aren't caught by the parser. see https://github.com/nushell/nushell/pull/14922 for more details
23                (Type::Any, Type::table()),
24            ])
25            .required_named(
26                "table-name",
27                SyntaxShape::String,
28                "Name of the table you want to insert into.",
29                Some('t'),
30            )
31            .named(
32                "update-record",
33                SyntaxShape::Record(vec![]),
34                "A record of column names and column values to update in the specified table.",
35                Some('u'),
36            )
37            .named(
38                "where-clause",
39                SyntaxShape::String,
40                "A sql string to use as a where clause without the WHERE keyword.",
41                Some('w'),
42            )
43            .allow_variants_without_examples(true)
44            .category(Category::Database)
45    }
46
47    fn description(&self) -> &str {
48        "Update information in a specified table in the in-memory sqlite database."
49    }
50
51    fn search_terms(&self) -> Vec<&str> {
52        vec!["sqlite", "storing", "table", "saving", "changing"]
53    }
54
55    fn examples(&self) -> Vec<Example<'_>> {
56        vec![
57            Example {
58                description: "Update the in-memory sqlite database",
59                example: "stor update --table-name nudb --update-record {str1: nushell datetime1: 2020-04-17}",
60                result: None,
61            },
62            Example {
63                description: "Update the in-memory sqlite database with a where clause",
64                example: "stor update --table-name nudb --update-record {str1: nushell datetime1: 2020-04-17} --where-clause \"bool1 = 1\"",
65                result: None,
66            },
67            Example {
68                description: "Update the in-memory sqlite database through pipeline input",
69                example: "{str1: nushell datetime1: 2020-04-17} | stor update --table-name nudb",
70                result: None,
71            },
72        ]
73    }
74
75    fn run(
76        &self,
77        engine_state: &EngineState,
78        stack: &mut Stack,
79        call: &Call,
80        input: PipelineData,
81    ) -> Result<PipelineData, ShellError> {
82        let span = call.head;
83        let table_name: Option<String> = call.get_flag(engine_state, stack, "table-name")?;
84        let update_record: Option<Record> = call.get_flag(engine_state, stack, "update-record")?;
85        let where_clause_opt: Option<Spanned<String>> =
86            call.get_flag(engine_state, stack, "where-clause")?;
87
88        // Open the in-mem database
89        let db = Box::new(SQLiteDatabase::new(
90            std::path::Path::new(MEMORY_DB),
91            Signals::empty(),
92        ));
93
94        // Check if the record is being passed as input or using the update record parameter
95        let columns = handle(span, update_record, input)?;
96
97        process(
98            engine_state,
99            table_name,
100            span,
101            &db,
102            columns,
103            where_clause_opt,
104        )?;
105
106        Ok(Value::custom(db, span).into_pipeline_data())
107    }
108}
109
110fn handle(
111    span: Span,
112    update_record: Option<Record>,
113    input: PipelineData,
114) -> Result<Record, ShellError> {
115    match input {
116        PipelineData::Empty => update_record.ok_or_else(|| ShellError::MissingParameter {
117            param_name: "requires a record".into(),
118            span,
119        }),
120        PipelineData::Value(value, ..) => {
121            // Since input is being used, check if the data record parameter is used too
122            if update_record.is_some() {
123                return Err(ShellError::Generic(GenericError::new(
124                    "Pipeline and Flag both being used",
125                    "Use either pipeline input or '--update-record' parameter",
126                    span,
127                )));
128            }
129            match value {
130                Value::Record { val, .. } => Ok(val.into_owned()),
131                val => Err(ShellError::OnlySupportsThisInputType {
132                    exp_input_type: "record".into(),
133                    wrong_type: val.get_type().to_string(),
134                    dst_span: span,
135                    src_span: val.span(),
136                }),
137            }
138        }
139        _ => {
140            if update_record.is_some() {
141                return Err(ShellError::Generic(GenericError::new(
142                    "Pipeline and Flag both being used",
143                    "Use either pipeline input or '--update-record' parameter",
144                    span,
145                )));
146            }
147            Err(ShellError::OnlySupportsThisInputType {
148                exp_input_type: "record".into(),
149                wrong_type: "".into(),
150                dst_span: span,
151                src_span: span,
152            })
153        }
154    }
155}
156
157fn process(
158    engine_state: &EngineState,
159    table_name: Option<String>,
160    span: Span,
161    db: &SQLiteDatabase,
162    record: Record,
163    where_clause_opt: Option<Spanned<String>>,
164) -> Result<(), ShellError> {
165    if table_name.is_none() {
166        return Err(ShellError::MissingParameter {
167            param_name: "requires at table name".into(),
168            span,
169        });
170    }
171    let new_table_name = table_name.unwrap_or("table".into());
172    if let Ok(conn) = db.open_connection() {
173        let mut update_stmt = format!("UPDATE {new_table_name} ");
174
175        update_stmt.push_str("SET ");
176        let mut placeholders: Vec<String> = Vec::new();
177
178        for (index, (key, _)) in record.iter().enumerate() {
179            placeholders.push(format!("{} = ?{}", key, index + 1));
180        }
181        update_stmt.push_str(&placeholders.join(", "));
182
183        // Yup, this is a bit janky, but I'm not sure a better way to do this without having
184        // --and and --or flags as well as supporting ==, !=, <>, is null, is not null, etc.
185        // and other sql syntax. So, for now, just type a sql where clause as a string.
186        if let Some(where_clause) = where_clause_opt {
187            write!(update_stmt, " WHERE {}", where_clause.item)
188                .expect("writing to a String is infallible");
189        }
190        // dbg!(&update_stmt);
191
192        // Get the params from the passed values
193        let params = values_to_sql(engine_state, record.values().cloned(), span)?;
194
195        conn.execute(&update_stmt, params_from_iter(params))
196            .map_err(|err| {
197                ShellError::Generic(GenericError::new_internal(
198                    "Failed to open SQLite connection in memory from update",
199                    err.to_string(),
200                ))
201            })?;
202    }
203    // dbg!(db.clone());
204    Ok(())
205}
206
#[cfg(test)]
mod test {
    use super::StorUpdate;

    /// Runs every example declared by `stor update` through the shared
    /// example-test harness.
    #[test]
    fn test_examples() -> nu_test_support::Result {
        nu_test_support::test().examples(StorUpdate)
    }
}