msi/internal/
package.rs

1use crate::internal::category::Category;
2use crate::internal::codepage::CodePage;
3use crate::internal::column::Column;
4use crate::internal::expr::Expr;
5use crate::internal::query::{Delete, Insert, Select, Update};
6use crate::internal::stream::{StreamReader, StreamWriter, Streams};
7use crate::internal::streamname::{
8    self, DIGITAL_SIGNATURE_STREAM_NAME, MSI_DIGITAL_SIGNATURE_EX_STREAM_NAME,
9    SUMMARY_INFO_STREAM_NAME,
10};
11use crate::internal::stringpool::{StringPool, StringPoolBuilder};
12use crate::internal::summary::SummaryInfo;
13use crate::internal::table::{Rows, Table};
14use crate::internal::value::{Value, ValueRef};
15use cfb;
16use std::borrow::Borrow;
17use std::collections::{btree_map, BTreeMap, HashMap, HashSet};
18use std::io::{self, Read, Seek, Write};
19use std::rc::Rc;
20use uuid::Uuid;
21
22// ========================================================================= //
23
// CLSIDs stored on the compound file's root storage entry.  `PackageType`
// maps between these strings and its variants: `open` uses them to recognize
// the kind of package, and `create` stamps the chosen one onto a new file.
const INSTALLER_PACKAGE_CLSID: &str = "000C1084-0000-0000-C000-000000000046";
const PATCH_PACKAGE_CLSID: &str = "000C1086-0000-0000-C000-000000000046";
const TRANSFORM_PACKAGE_CLSID: &str = "000C1082-0000-0000-C000-000000000046";

// Schema tables that describe the rest of the database; `drop_table` refuses
// to remove these (see `is_reserved_table_name`).
const COLUMNS_TABLE_NAME: &str = "_Columns";
const TABLES_TABLE_NAME: &str = "_Tables";
const VALIDATION_TABLE_NAME: &str = "_Validation";

// Streams holding the database string pool.  Their names are encoded the same
// way as table names, but `open` reads them directly rather than via `_Tables`.
const STRING_DATA_TABLE_NAME: &str = "_StringData";
const STRING_POOL_TABLE_NAME: &str = "_StringPool";

// Largest number of columns `create_table` accepts for a new table.
const MAX_NUM_TABLE_COLUMNS: usize = 32;
36
37// ========================================================================= //
38
39fn make_columns_table(long_string_refs: bool) -> Rc<Table> {
40    Table::new(
41        COLUMNS_TABLE_NAME.to_string(),
42        vec![
43            Column::build("Table").primary_key().string(64),
44            Column::build("Number").primary_key().int16(),
45            Column::build("Name").string(64),
46            Column::build("Type").int16(),
47        ],
48        long_string_refs,
49    )
50}
51
52fn make_tables_table(long_string_refs: bool) -> Rc<Table> {
53    Table::new(
54        TABLES_TABLE_NAME.to_string(),
55        vec![Column::build("Name").primary_key().string(64)],
56        long_string_refs,
57    )
58}
59
60fn make_validation_columns() -> Vec<Column> {
61    let min = -0x7fff_ffff;
62    let max = 0x7fff_ffff;
63    let values: Vec<&str> =
64        Category::all().into_iter().map(Category::as_str).collect();
65    vec![
66        Column::build("Table").primary_key().id_string(32),
67        Column::build("Column").primary_key().id_string(32),
68        Column::build("Nullable").enum_values(&["Y", "N"]).string(4),
69        Column::build("MinValue").nullable().range(min, max).int32(),
70        Column::build("MaxValue").nullable().range(min, max).int32(),
71        Column::build("KeyTable").nullable().id_string(255),
72        Column::build("KeyColumn").nullable().range(1, 32).int16(),
73        Column::build("Category").nullable().enum_values(&values).string(32),
74        Column::build("Set").nullable().text_string(255),
75        Column::build("Description").nullable().text_string(255),
76    ]
77}
78
79fn make_validation_table(long_string_refs: bool) -> Rc<Table> {
80    Table::new(
81        VALIDATION_TABLE_NAME.to_string(),
82        make_validation_columns(),
83        long_string_refs,
84    )
85}
86
87fn is_reserved_table_name(table_name: &str) -> bool {
88    table_name == COLUMNS_TABLE_NAME
89        || table_name == TABLES_TABLE_NAME
90        || table_name == VALIDATION_TABLE_NAME
91}
92
93// ========================================================================= //
94
/// Identifies which kind of MSI package a file holds (for example, an
/// installer or a patch).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PackageType {
    /// An installer package: installs a new application.
    Installer,
    /// A patch package: delivers an update to an application.
    Patch,
    /// A transform: a collection of changes applied to an installation.
    Transform,
}
106
107impl PackageType {
108    fn from_clsid(clsid: &Uuid) -> Option<PackageType> {
109        if *clsid == PackageType::Installer.clsid() {
110            Some(PackageType::Installer)
111        } else if *clsid == PackageType::Patch.clsid() {
112            Some(PackageType::Patch)
113        } else if *clsid == PackageType::Transform.clsid() {
114            Some(PackageType::Transform)
115        } else {
116            None
117        }
118    }
119
120    fn clsid(self) -> Uuid {
121        match self {
122            PackageType::Installer => {
123                Uuid::parse_str(INSTALLER_PACKAGE_CLSID).unwrap()
124            }
125            PackageType::Patch => {
126                Uuid::parse_str(PATCH_PACKAGE_CLSID).unwrap()
127            }
128            PackageType::Transform => {
129                Uuid::parse_str(TRANSFORM_PACKAGE_CLSID).unwrap()
130            }
131        }
132    }
133
134    fn default_title(&self) -> &str {
135        match *self {
136            PackageType::Installer => "Installation Database",
137            PackageType::Patch => "Patch",
138            PackageType::Transform => "Transform",
139        }
140    }
141}
142
143// ========================================================================= //
144
/// An MSI package file, backed by an underlying reader/writer (such as a
/// [`File`](https://doc.rust-lang.org/std/fs/struct.File.html) or
/// [`Cursor`](https://doc.rust-lang.org/std/io/struct.Cursor.html)).
///
/// # Examples
///
/// ```
/// use msi::{Column, Expr, Insert, Package, PackageType, Select, Value};
/// use std::io::Cursor;
///
/// // Create an in-memory package using a Cursor:
/// let cursor = Cursor::new(Vec::new());
/// let mut package = Package::create(PackageType::Installer, cursor).unwrap();
/// // Set some summary information:
/// package.summary_info_mut().set_author("Jane Doe".to_string());
/// // Add a table to the package:
/// let columns = vec![
///     Column::build("Property").primary_key().id_string(72),
///     Column::build("Value").nullable().formatted_string(64),
/// ];
/// package.create_table("CheckBox", columns).unwrap();
/// // Add a row to the new table:
/// let query = Insert::into("CheckBox").row(vec![
///     Value::from("MoreMagic"),
///     Value::from("Whether magic should be maximized"),
/// ]);
/// package.insert_rows(query).unwrap();
/// // Close the package and get the cursor back out.
/// let cursor = package.into_inner().unwrap();
///
/// // Now, re-open the package and make sure our data is still there.
/// let mut package = Package::open(cursor).unwrap();
/// assert_eq!(package.summary_info().author(), Some("Jane Doe"));
/// let query = Select::table("CheckBox")
///     .with(Expr::col("Property").eq(Expr::string("MoreMagic")));
/// let mut rows = package.select_rows(query).unwrap();
/// assert_eq!(rows.len(), 1);
/// let row = rows.next().unwrap();
/// assert_eq!(row["Property"], Value::Str("MoreMagic".to_string()));
/// assert_eq!(row["Value"],
///            Value::Str("Whether magic should be maximized".to_string()));
/// ```
pub struct Package<F> {
    // The comp field is always `Some`, unless we are about to destroy the
    // `Package` object.  The only reason for it to be an `Option` is to make
    // it possible for the `into_inner()` method to move the `CompoundFile` out
    // of the `Package` object, even though `Package` implements `Drop`
    // (normally you can't move fields out of an object that implements
    // `Drop`).
    comp: Option<cfb::CompoundFile<F>>,
    // Kind of package, recognized from (or written to) the root storage
    // CLSID.
    package_type: PackageType,
    // In-memory copy of the summary information stream.
    summary_info: SummaryInfo,
    // Set by `summary_info_mut` and by `create`; `create` asserts it is clear
    // again after `flush()`, so the finisher is expected to reset it once the
    // summary info has been written.
    is_summary_info_modified: bool,
    // String pool that row values are resolved against.
    string_pool: StringPool,
    // Schema of every table, keyed by table name; includes `_Tables`,
    // `_Columns`, and `_Validation` when present.
    tables: BTreeMap<String, Rc<Table>>,
    // Pending finish step, installed by mutating methods and consumed (then
    // cleared) by `flush`, `into_inner`, or `Drop`.
    finisher: Option<Box<dyn Finish<F>>>,
}
201
202impl<F> Package<F> {
203    /// Returns what type of package this is.
204    #[must_use]
205    pub fn package_type(&self) -> PackageType {
206        self.package_type
207    }
208
209    /// Returns summary information for this package.
210    #[must_use]
211    pub fn summary_info(&self) -> &SummaryInfo {
212        &self.summary_info
213    }
214
215    /// Returns the code page used for serializing strings in the database.
216    #[must_use]
217    pub fn database_codepage(&self) -> CodePage {
218        self.string_pool.codepage()
219    }
220
221    /// Returns true if the database has a table with the given name.
222    #[must_use]
223    pub fn has_table(&self, table_name: &str) -> bool {
224        self.tables.contains_key(table_name)
225    }
226
227    /// Returns the database table with the given name (if any).
228    pub fn get_table(&self, table_name: &str) -> Option<&Table> {
229        self.tables.get(table_name).map(Rc::borrow)
230    }
231
232    /// Returns an iterator over the database tables in this package.
233    #[must_use]
234    pub fn tables(&self) -> Tables {
235        Tables { iter: self.tables.values() }
236    }
237
238    /// Returns true if the package has an embedded binary stream with the
239    /// given name.
240    #[must_use]
241    pub fn has_stream(&self, stream_name: &str) -> bool {
242        self.comp().is_stream(streamname::encode(stream_name, false))
243    }
244
245    /// Returns an iterator over the embedded binary streams in this package.
246    #[must_use]
247    pub fn streams(&self) -> Streams<F> {
248        Streams::new(self.comp().read_root_storage())
249    }
250
251    /// Returns true if the package has been digitally signed.  Note that this
252    /// method only checks whether a signature is present; it does *not* verify
253    /// that the signature is actually valid.
254    #[must_use]
255    pub fn has_digital_signature(&self) -> bool {
256        self.comp().is_stream(DIGITAL_SIGNATURE_STREAM_NAME)
257    }
258
259    /// Consumes the `Package` object, returning the underlying reader/writer.
260    pub fn into_inner(mut self) -> io::Result<F> {
261        if let Some(finisher) = self.finisher.take() {
262            finisher.finish(&mut self)?;
263        }
264        Ok(self.comp.take().unwrap().into_inner())
265    }
266
267    fn comp(&self) -> &cfb::CompoundFile<F> {
268        self.comp.as_ref().unwrap()
269    }
270
271    fn comp_mut(&mut self) -> &mut cfb::CompoundFile<F> {
272        self.comp.as_mut().unwrap()
273    }
274}
275
276impl<F: Read + Seek> Package<F> {
277    /// Opens an existing MSI file, using the underlying reader.  If the
278    /// underlying reader also supports the `Write` trait, then the `Package`
279    /// object will be writable as well.
280    pub fn open(inner: F) -> io::Result<Package<F>> {
281        let mut comp = cfb::CompoundFile::open(inner)?;
282        let package_type = {
283            let root_entry = comp.root_entry();
284            let clsid = root_entry.clsid();
285            match PackageType::from_clsid(clsid) {
286                Some(ptype) => ptype,
287                None => invalid_data!(
288                    "Unrecognized package CLSID ({})",
289                    clsid.hyphenated()
290                ),
291            }
292        };
293        let summary_info =
294            SummaryInfo::read(comp.open_stream(SUMMARY_INFO_STREAM_NAME)?)?;
295        let string_pool = {
296            let builder = {
297                let name = streamname::encode(STRING_POOL_TABLE_NAME, true);
298                let stream = comp.open_stream(name)?;
299                StringPoolBuilder::read_from_pool(stream)?
300            };
301            let name = streamname::encode(STRING_DATA_TABLE_NAME, true);
302            let stream = comp.open_stream(name)?;
303            builder.build_from_data(stream)?
304        };
305        let mut all_tables = BTreeMap::<String, Rc<Table>>::new();
306        // Read in _Tables table:
307        let table_names: HashSet<String> = {
308            let table = make_tables_table(string_pool.long_string_refs());
309            let stream_name = table.stream_name();
310            let mut names = HashSet::<String>::new();
311            if comp.exists(&stream_name) {
312                let stream = comp.open_stream(&stream_name)?;
313                let rows = Rows::new(
314                    &string_pool,
315                    table.clone(),
316                    table.read_rows(stream)?,
317                );
318                for row in rows {
319                    let table_name = row[0].as_str().unwrap().to_string();
320                    if names.contains(&table_name) {
321                        invalid_data!(
322                            "Repeated key in {:?} table: {:?}",
323                            TABLES_TABLE_NAME,
324                            table_name
325                        );
326                    }
327                    names.insert(table_name);
328                }
329            }
330            all_tables.insert(table.name().to_string(), table);
331            names
332        };
333        // Read in _Columns table:
334        let mut columns_map: HashMap<String, BTreeMap<i32, (String, i32)>> =
335            table_names
336                .into_iter()
337                .map(|name| (name, BTreeMap::new()))
338                .collect();
339        {
340            let table = make_columns_table(string_pool.long_string_refs());
341            let stream_name = table.stream_name();
342            if comp.exists(&stream_name) {
343                let stream = comp.open_stream(&stream_name)?;
344                let rows = Rows::new(
345                    &string_pool,
346                    table.clone(),
347                    table.read_rows(stream)?,
348                );
349                for row in rows {
350                    let table_name = row[0].as_str().unwrap();
351                    if let Some(cols) = columns_map.get_mut(table_name) {
352                        let col_index = row[1].as_int().unwrap();
353                        if cols.contains_key(&col_index) {
354                            invalid_data!(
355                                "Repeated key in {:?} table: {:?}",
356                                COLUMNS_TABLE_NAME,
357                                (table_name, col_index)
358                            );
359                        }
360                        let col_name = row[2].as_str().unwrap().to_string();
361                        let type_bits = row[3].as_int().unwrap();
362                        cols.insert(col_index, (col_name, type_bits));
363                    } else {
364                        invalid_data!(
365                            "_Columns mentions table {:?}, which isn't in \
366                             _Tables",
367                            table_name
368                        );
369                    }
370                }
371            }
372            all_tables.insert(table.name().to_string(), table);
373        }
374        // Read in _Validation table:
375        let mut validation_map =
376            HashMap::<(String, String), Vec<ValueRef>>::new();
377        {
378            let table = make_validation_table(string_pool.long_string_refs());
379            // TODO: Ensure that columns_map["_Validation"].columns() matches
380            // the hard-coded validation table definition.
381            let stream_name = table.stream_name();
382            if comp.exists(&stream_name) {
383                let stream = comp.open_stream(&stream_name)?;
384                for value_refs in table.read_rows(stream)? {
385                    let table_name = value_refs[0]
386                        .to_value(&string_pool)
387                        .as_str()
388                        .unwrap()
389                        .to_string();
390                    let column_name = value_refs[1]
391                        .to_value(&string_pool)
392                        .as_str()
393                        .unwrap()
394                        .to_string();
395                    let key = (table_name, column_name);
396                    if validation_map.contains_key(&key) {
397                        invalid_data!(
398                            "Repeated key in {:?} table: {:?}",
399                            VALIDATION_TABLE_NAME,
400                            key
401                        );
402                    }
403                    validation_map.insert(key, value_refs);
404                }
405            }
406        }
407        // Construct Table objects from column/validation data:
408        for (table_name, column_specs) in columns_map {
409            if column_specs.is_empty() {
410                invalid_data!("No columns found for table {:?}", table_name);
411            }
412            let num_columns = column_specs.len() as i32;
413            if column_specs.keys().next() != Some(&1)
414                || column_specs.keys().next_back() != Some(&num_columns)
415            {
416                invalid_data!(
417                    "Table {:?} does not have a complete set of columns",
418                    table_name
419                );
420            }
421            let mut columns = Vec::<Column>::with_capacity(column_specs.len());
422            for (_, (column_name, bitfield)) in column_specs {
423                let mut builder = Column::build(column_name.as_str());
424                let key = (table_name.clone(), column_name);
425                if let Some(value_refs) = validation_map.get(&key) {
426                    let is_nullable = value_refs[2].to_value(&string_pool);
427                    if is_nullable.as_str().unwrap() == "Y" {
428                        builder = builder.nullable();
429                    }
430                    let min_value = value_refs[3].to_value(&string_pool);
431                    let max_value = value_refs[4].to_value(&string_pool);
432                    if !min_value.is_null() && !max_value.is_null() {
433                        let min = min_value.as_int().unwrap();
434                        let max = max_value.as_int().unwrap();
435                        builder = builder.range(min, max);
436                    }
437                    let key_table = value_refs[5].to_value(&string_pool);
438                    let key_column = value_refs[6].to_value(&string_pool);
439                    if !key_table.is_null() && !key_column.is_null() {
440                        builder = builder.foreign_key(
441                            key_table.as_str().unwrap(),
442                            key_column.as_int().unwrap(),
443                        );
444                    }
445                    let category_value = value_refs[7].to_value(&string_pool);
446                    if !category_value.is_null() {
447                        let category = category_value
448                            .as_str()
449                            .unwrap()
450                            .parse::<Category>()
451                            .ok();
452                        if let Some(category) = category {
453                            builder = builder.category(category);
454                        }
455                    }
456                    let enum_values = value_refs[8].to_value(&string_pool);
457                    if !enum_values.is_null() {
458                        let enum_values: Vec<&str> =
459                            enum_values.as_str().unwrap().split(';').collect();
460                        builder = builder.enum_values(&enum_values);
461                    }
462                }
463                columns.push(builder.with_bitfield(bitfield)?);
464            }
465            let table = Table::new(
466                table_name,
467                columns,
468                string_pool.long_string_refs(),
469            );
470            all_tables.insert(table.name().to_string(), table);
471        }
472        Ok(Package {
473            comp: Some(comp),
474            package_type,
475            summary_info,
476            is_summary_info_modified: false,
477            string_pool,
478            tables: all_tables,
479            finisher: None,
480        })
481    }
482
483    /// Attempts to execute a select query.  Returns an error if the query
484    /// fails (e.g. due to the column names being incorrect or the table(s) not
485    /// existing).
486    pub fn select_rows(&mut self, query: Select) -> io::Result<Rows> {
487        query.exec(
488            self.comp.as_mut().unwrap(),
489            &self.string_pool,
490            &self.tables,
491        )
492    }
493
494    /// Opens an existing binary stream in the package for reading.
495    pub fn read_stream(
496        &mut self,
497        stream_name: &str,
498    ) -> io::Result<StreamReader<F>> {
499        if !streamname::is_valid(stream_name, false) {
500            invalid_input!("{:?} is not a valid stream name", stream_name);
501        }
502        let encoded_name = streamname::encode(stream_name, false);
503        if !self.comp().is_stream(&encoded_name) {
504            not_found!("Stream {:?} does not exist", stream_name);
505        }
506        Ok(StreamReader::new(self.comp_mut().open_stream(&encoded_name)?))
507    }
508
509    // TODO: pub fn has_valid_digital_signature(&mut self) -> io::Result<bool>
510}
511
512impl<F: Read + Write + Seek> Package<F> {
513    /// Creates a new, empty package of the given type, using the underlying
514    /// reader/writer.  The reader/writer should be initially empty.
515    pub fn create(
516        package_type: PackageType,
517        inner: F,
518    ) -> io::Result<Package<F>> {
519        let mut comp = cfb::CompoundFile::create(inner)?;
520        comp.set_storage_clsid("/", package_type.clsid())?;
521        let mut summary_info = SummaryInfo::new();
522        summary_info.set_title(package_type.default_title().to_string());
523        let string_pool = StringPool::new(summary_info.codepage());
524        let tables = {
525            let mut tables = BTreeMap::<String, Rc<Table>>::new();
526            let table = make_tables_table(string_pool.long_string_refs());
527            tables.insert(table.name().to_string(), table);
528            let table = make_columns_table(string_pool.long_string_refs());
529            tables.insert(table.name().to_string(), table);
530            tables
531        };
532        let mut package = Package {
533            comp: Some(comp),
534            package_type,
535            summary_info,
536            is_summary_info_modified: true,
537            string_pool,
538            tables,
539            finisher: None,
540        };
541        package
542            .create_table(VALIDATION_TABLE_NAME, make_validation_columns())?;
543        package.flush()?;
544        debug_assert!(!package.is_summary_info_modified);
545        debug_assert!(!package.string_pool.is_modified());
546        Ok(package)
547    }
548
549    /// Returns a mutable reference to the summary information for this
550    /// package.  Call `flush()` or drop the `Package` object to persist any
551    /// changes made to the underlying writer.
552    pub fn summary_info_mut(&mut self) -> &mut SummaryInfo {
553        self.is_summary_info_modified = true;
554        self.set_finisher();
555        &mut self.summary_info
556    }
557
558    /// Sets the code page used for serializing strings in the database.
559    pub fn set_database_codepage(&mut self, codepage: CodePage) {
560        self.set_finisher();
561        self.string_pool.set_codepage(codepage)
562    }
563
564    /// Creates a new database table.  Returns an error without modifying the
565    /// database if the table name or columns are invalid, or if a table with
566    /// that name already exists.
567    pub fn create_table<S: Into<String>>(
568        &mut self,
569        table_name: S,
570        columns: Vec<Column>,
571    ) -> io::Result<()> {
572        self.create_table_with_name(table_name.into(), columns)
573    }
574
575    fn create_table_with_name(
576        &mut self,
577        table_name: String,
578        columns: Vec<Column>,
579    ) -> io::Result<()> {
580        if !Table::is_valid_name(&table_name) {
581            invalid_input!("{:?} is not a valid table name", table_name);
582        }
583        if columns.is_empty() {
584            invalid_input!("Cannot create a table with no columns");
585        }
586        if columns.len() > MAX_NUM_TABLE_COLUMNS {
587            invalid_input!(
588                "Cannot create a table with more than {} columns",
589                MAX_NUM_TABLE_COLUMNS
590            );
591        }
592        if !columns.iter().any(Column::is_primary_key) {
593            invalid_input!(
594                "Cannot create a table without at least one primary key column"
595            );
596        }
597        {
598            let mut column_names = HashSet::<&str>::new();
599            for column in &columns {
600                let name = column.name();
601                if !Column::is_valid_name(name) {
602                    invalid_input!("{:?} is not a valid column name", name);
603                }
604                if column_names.contains(name) {
605                    invalid_input!(
606                        "Cannot create a table with multiple columns with the \
607                         same name ({:?})",
608                        name
609                    );
610                }
611                column_names.insert(name);
612            }
613        }
614        if self.tables.contains_key(&table_name) {
615            already_exists!("Table {:?} already exists", table_name);
616        }
617        self.insert_rows(
618            Insert::into(COLUMNS_TABLE_NAME).rows(
619                columns
620                    .iter()
621                    .enumerate()
622                    .map(|(index, column)| {
623                        vec![
624                            Value::Str(table_name.clone()),
625                            Value::Int(1 + index as i32),
626                            Value::Str(column.name().to_string()),
627                            Value::Int(column.bitfield()),
628                        ]
629                    })
630                    .collect(),
631            ),
632        )?;
633        self.insert_rows(
634            Insert::into(TABLES_TABLE_NAME)
635                .row(vec![Value::Str(table_name.clone())]),
636        )?;
637        let validation_rows: Vec<Vec<Value>> = columns
638            .iter()
639            .map(|column| {
640                let (min_value, max_value) =
641                    if let Some((min, max)) = column.value_range() {
642                        (Value::Int(min), Value::Int(max))
643                    } else {
644                        (Value::Null, Value::Null)
645                    };
646                let (key_table, key_column) =
647                    if let Some((table, column)) = column.foreign_key() {
648                        (Value::Str(table.to_string()), Value::Int(column))
649                    } else {
650                        (Value::Null, Value::Null)
651                    };
652                vec![
653                    Value::Str(table_name.clone()),
654                    Value::Str(column.name().to_string()),
655                    Value::Str(if column.is_nullable() {
656                        "Y".to_string()
657                    } else {
658                        "N".to_string()
659                    }),
660                    min_value,
661                    max_value,
662                    key_table,
663                    key_column,
664                    if let Some(category) = column.category() {
665                        Value::Str(category.to_string())
666                    } else {
667                        Value::Null
668                    },
669                    if let Some(values) = column.enum_values() {
670                        Value::Str(values.join(";"))
671                    } else {
672                        Value::Null
673                    },
674                    Value::Null,
675                ]
676            })
677            .collect();
678        let long_string_refs = self.string_pool.long_string_refs();
679        let table = Table::new(table_name.clone(), columns, long_string_refs);
680        self.tables.insert(table_name, table);
681        self.insert_rows(
682            Insert::into(VALIDATION_TABLE_NAME).rows(validation_rows),
683        )?;
684        Ok(())
685    }
686
687    /// Removes an existing database table.  Returns an error without modifying
688    /// the database if the table name is invalid, or if no such table exists.
689    pub fn drop_table(&mut self, table_name: &str) -> io::Result<()> {
690        if is_reserved_table_name(table_name) {
691            invalid_input!("Cannot drop special {:?} table", table_name);
692        }
693        if !Table::is_valid_name(table_name) {
694            invalid_input!("{:?} is not a valid table name", table_name);
695        }
696        if !self.tables.contains_key(table_name) {
697            not_found!("Table {:?} does not exist", table_name);
698        }
699        let stream_name = self.tables.get(table_name).unwrap().stream_name();
700        if self.comp().exists(&stream_name) {
701            self.comp_mut().remove_stream(&stream_name)?;
702        }
703        self.delete_rows(
704            Delete::from(VALIDATION_TABLE_NAME)
705                .with(Expr::col("Table").eq(Expr::string(table_name))),
706        )?;
707        self.delete_rows(
708            Delete::from(COLUMNS_TABLE_NAME)
709                .with(Expr::col("Table").eq(Expr::string(table_name))),
710        )?;
711        self.delete_rows(
712            Delete::from(TABLES_TABLE_NAME)
713                .with(Expr::col("Name").eq(Expr::string(table_name))),
714        )?;
715        self.tables.remove(table_name);
716        Ok(())
717    }
718
719    /// Attempts to execute a delete query.  Returns an error without modifying
720    /// the database if the query fails (e.g. due to the table not existing).
721    pub fn delete_rows(&mut self, query: Delete) -> io::Result<()> {
722        self.set_finisher();
723        query.exec(
724            self.comp.as_mut().unwrap(),
725            &mut self.string_pool,
726            &self.tables,
727        )
728    }
729
730    /// Attempts to execute an insert query.  Returns an error without
731    /// modifying the database if the query fails (e.g. due to values being
732    /// invalid, or keys not being unique, or the table not existing).
733    pub fn insert_rows(&mut self, query: Insert) -> io::Result<()> {
734        self.set_finisher();
735        query.exec(
736            self.comp.as_mut().unwrap(),
737            &mut self.string_pool,
738            &self.tables,
739        )
740    }
741
742    /// Attempts to execute an update query.  Returns an error without
743    /// modifying the database if the query fails (e.g. due to values being
744    /// invalid, or column names being incorrect, or the table not existing).
745    pub fn update_rows(&mut self, query: Update) -> io::Result<()> {
746        self.set_finisher();
747        query.exec(
748            self.comp.as_mut().unwrap(),
749            &mut self.string_pool,
750            &self.tables,
751        )
752    }
753
754    /// Creates (or overwrites) a binary stream in the package.
755    pub fn write_stream(
756        &mut self,
757        stream_name: &str,
758    ) -> io::Result<StreamWriter<F>> {
759        if !streamname::is_valid(stream_name, false) {
760            invalid_input!("{:?} is not a valid stream name", stream_name);
761        }
762        let encoded_name = streamname::encode(stream_name, false);
763        Ok(StreamWriter::new(self.comp_mut().create_stream(&encoded_name)?))
764    }
765
766    /// Removes an existing binary stream from the package.
767    pub fn remove_stream(&mut self, stream_name: &str) -> io::Result<()> {
768        if !streamname::is_valid(stream_name, false) {
769            invalid_input!("{:?} is not a valid stream name", stream_name);
770        }
771        let encoded_name = streamname::encode(stream_name, false);
772        if !self.comp().is_stream(&encoded_name) {
773            not_found!("Stream {:?} does not exist", stream_name);
774        }
775        self.comp_mut().remove_stream(&encoded_name)
776    }
777
778    // TODO: pub fn add_digital_signature(&mut self, ...) -> io::Result<()>
779
780    /// Removes any existing digital signature from the package.  This can be
781    /// useful if you need to modify a signed package (which will invalidate
782    /// the signature).
783    pub fn remove_digital_signature(&mut self) -> io::Result<()> {
784        if self.comp().is_stream(DIGITAL_SIGNATURE_STREAM_NAME) {
785            self.comp_mut().remove_stream(DIGITAL_SIGNATURE_STREAM_NAME)?;
786        }
787        if self.comp().is_stream(MSI_DIGITAL_SIGNATURE_EX_STREAM_NAME) {
788            self.comp_mut()
789                .remove_stream(MSI_DIGITAL_SIGNATURE_EX_STREAM_NAME)?;
790        }
791        Ok(())
792    }
793
794    /// Flushes any buffered changes to the underlying writer.
795    pub fn flush(&mut self) -> io::Result<()> {
796        if let Some(finisher) = self.finisher.take() {
797            finisher.finish(self)?;
798        }
799        self.comp_mut().flush()
800    }
801
802    fn set_finisher(&mut self) {
803        if self.finisher.is_none() {
804            let finisher: Box<dyn Finish<F>> = Box::new(FinishImpl {});
805            self.finisher = Some(finisher);
806        }
807    }
808}
809
810impl<F> Drop for Package<F> {
811    fn drop(&mut self) {
812        if let Some(finisher) = self.finisher.take() {
813            let _ = finisher.finish(self);
814        }
815    }
816}
817
818// ========================================================================= //
819
820/// An iterator over the database tables in a package.
821///
822/// No guarantees are made about the order in which items are returned.
823#[derive(Clone)]
824pub struct Tables<'a> {
825    iter: btree_map::Values<'a, String, Rc<Table>>,
826}
827
828impl<'a> Iterator for Tables<'a> {
829    type Item = &'a Table;
830
831    fn next(&mut self) -> Option<&'a Table> {
832        self.iter.next().map(Rc::borrow)
833    }
834
835    fn size_hint(&self) -> (usize, Option<usize>) {
836        self.iter.size_hint()
837    }
838}
839
840impl<'a> ExactSizeIterator for Tables<'a> {}
841
842// ========================================================================= //
843
844trait Finish<F> {
845    fn finish(&self, package: &mut Package<F>) -> io::Result<()>;
846}
847
848struct FinishImpl {}
849
850impl<F: Read + Write + Seek> Finish<F> for FinishImpl {
851    fn finish(&self, package: &mut Package<F>) -> io::Result<()> {
852        if package.is_summary_info_modified {
853            let stream = package
854                .comp
855                .as_mut()
856                .unwrap()
857                .create_stream(SUMMARY_INFO_STREAM_NAME)?;
858            package.summary_info.write(stream)?;
859            package.is_summary_info_modified = false;
860        }
861        if package.string_pool.is_modified() {
862            {
863                let name = streamname::encode(STRING_POOL_TABLE_NAME, true);
864                let stream =
865                    package.comp.as_mut().unwrap().create_stream(name)?;
866                package.string_pool.write_pool(stream)?;
867            }
868            {
869                let name = streamname::encode(STRING_DATA_TABLE_NAME, true);
870                let stream =
871                    package.comp.as_mut().unwrap().create_stream(name)?;
872                package.string_pool.write_data(stream)?;
873            }
874            package.string_pool.mark_unmodified();
875        }
876        Ok(())
877    }
878}
879
880// ========================================================================= //
881
#[cfg(test)]
mod tests {
    use super::{Package, PackageType};
    use crate::internal::codepage::CodePage;
    use crate::internal::column::Column;
    use crate::internal::expr::Expr;
    use crate::internal::query::{Insert, Select, Update};
    use crate::internal::value::Value;
    use std::io::Cursor;

    /// A package backed entirely by an in-memory buffer.
    type MemPackage = Package<Cursor<Vec<u8>>>;

    /// Creates a new, empty installer package in memory.
    fn new_installer() -> MemPackage {
        let buffer = Cursor::new(Vec::new());
        Package::create(PackageType::Installer, buffer).expect("create")
    }

    /// Closes `package` and re-opens its serialized bytes, so that later
    /// assertions observe only what was actually persisted.
    fn round_trip(package: MemPackage) -> MemPackage {
        let buffer = package.into_inner().expect("into_inner");
        Package::open(buffer).expect("open")
    }

    #[test]
    fn set_database_codepage() {
        let mut package = new_installer();
        assert_eq!(package.database_codepage(), CodePage::Utf8);
        package.set_database_codepage(CodePage::MacintoshRoman);
        assert_eq!(package.database_codepage(), CodePage::MacintoshRoman);

        let package = round_trip(package);
        assert_eq!(package.database_codepage(), CodePage::MacintoshRoman);
    }

    #[test]
    fn insert_rows() {
        let mut package = new_installer();
        let columns = vec![
            Column::build("Number").primary_key().int16(),
            Column::build("Word").nullable().string(50),
        ];
        package.create_table("Numbers", columns).expect("create_table");
        let insert = Insert::into("Numbers")
            .row(vec![Value::Int(2), Value::Str("Two".to_string())])
            .row(vec![Value::Int(4), Value::Str("Four".to_string())])
            .row(vec![Value::Int(1), Value::Str("One".to_string())]);
        package.insert_rows(insert).expect("insert_rows");
        let row_count = package
            .select_rows(Select::table("Numbers"))
            .expect("select")
            .len();
        assert_eq!(row_count, 3);

        let mut package = round_trip(package);
        let rows = package.select_rows(Select::table("Numbers")).unwrap();
        assert_eq!(rows.len(), 3);
        let values: Vec<(i32, String)> = rows
            .map(|row| {
                let number = row[0].as_int().unwrap();
                let word = row[1].as_str().unwrap().to_string();
                (number, word)
            })
            .collect();
        // Rows were inserted as 2, 4, 1 but are expected back as 1, 2, 4.
        let expected = vec![
            (1, "One".to_string()),
            (2, "Two".to_string()),
            (4, "Four".to_string()),
        ];
        assert_eq!(values, expected);
    }

    #[test]
    fn update_rows() {
        let mut package = new_installer();
        let columns = vec![
            Column::build("Key").primary_key().int16(),
            Column::build("Value").nullable().int32(),
        ];
        package.create_table("Mapping", columns).expect("create_table");
        let insert = Insert::into("Mapping")
            .row(vec![Value::Int(1), Value::Int(17)])
            .row(vec![Value::Int(2), Value::Int(42)])
            .row(vec![Value::Int(3), Value::Int(17)]);
        package.insert_rows(insert).expect("insert_rows");
        // Only the rows whose Value is 17 (keys 1 and 3) should change.
        let update = Update::table("Mapping")
            .set("Value", Value::Int(-5))
            .with(Expr::col("Value").eq(Expr::integer(17)));
        package.update_rows(update).unwrap();

        let mut package = round_trip(package);
        let rows = package.select_rows(Select::table("Mapping")).unwrap();
        let values: Vec<(i32, i32)> = rows
            .map(|row| (row[0].as_int().unwrap(), row[1].as_int().unwrap()))
            .collect();
        assert_eq!(values, vec![(1, -5), (2, 42), (3, -5)]);
    }
}
989
990// ========================================================================= //