fern_masking/
lib.rs

1// SPDX-FileCopyrightText:  Copyright © 2022 The Fern Authors <team@fernproxy.io>
2// SPDX-License-Identifier: Apache-2.0
3
4use async_trait::async_trait;
5use bytes::Bytes;
6
7use fern_protocol_postgresql::codec::backend;
8use fern_proxy_interfaces::{SQLMessage, SQLMessageHandler};
9
10use crate::strategies::MaskingStrategy;
11
12// Re-export.
13pub use fern_proxy_interfaces::SQLHandlerConfig;
14
15/// An `SQLMessageHandler` applying a data masking strategy.
16///
17/// The [`MaskingStrategy`] to use is set at struct instantiation,
18/// depending on settings in provided `SQLHandlerConfig`.
19///
20/// Should no settings be defined for data masking, by default a
21/// fixed-length caviar strategy will mask all `DataRow` fields.
22#[derive(Debug)]
23pub struct DataMaskingHandler {
24    state: QueryState,
25
26    /// Masking strategy applied by this Handler.
27    //TODO(ppiotr3k): investigate if `Box`-ing can be avoided
28    strategy: Box<dyn MaskingStrategy>,
29
30    /// Column names where masking will not be applied, unless forced.
31    columns_excluded: Vec<Bytes>,
32
33    /// Column names where masking will be applied, in any case.
34    /// This allows using a wildcard in exclusions, and progressively mask.
35    columns_forced: Vec<Bytes>,
36}
37
/// Progress of the masking handler through the Messages answering a query.
///
/// The handler starts in [`QueryState::Description`]. A `RowDescription`
/// Message moves it to [`QueryState::Data`], carrying the indexes of the
/// columns excluded from masking. A `CommandComplete` Message moves it back
/// to [`QueryState::Description`].
#[derive(Debug)]
enum QueryState {
    /// Awaiting a `RowDescription` Message.
    Description,

    /// Processing `DataRow` Messages; holds the indexes of the columns that
    /// must be left unmasked.
    Data(Vec<usize>),
}
47
48//TODO(ppiotr3k): this crate should only process abstracted types
49#[async_trait]
50impl SQLMessageHandler<backend::Message> for DataMaskingHandler {
51    fn new(config: &SQLHandlerConfig) -> Self {
52        //TODO(ppiotr3k): make length configurable
53        let strategy: Box<dyn MaskingStrategy> =
54            if let Ok(strategy) = config.get::<String>("masking.strategy") {
55                match strategy.as_str() {
56                    "caviar" => Box::new(strategies::CaviarMask::new(6)),
57                    "caviar-preserve-shape" => Box::new(strategies::CaviarShapeMask::new()),
58                    _ => Box::new(strategies::CaviarMask::new(6)),
59                }
60            } else {
61                // Default strategy, if nothing is defined in `config`.
62                Box::new(strategies::CaviarMask::new(6))
63            };
64
65        let mut columns_excluded = vec![];
66        if let Ok(columns) = config.get::<Vec<String>>("masking.exclude.columns") {
67            for column_name in columns.iter() {
68                columns_excluded.push(Bytes::from(column_name.clone()));
69            }
70        }
71
72        let mut columns_forced = vec![];
73        if let Ok(columns) = config.get::<Vec<String>>("masking.force.columns") {
74            for column_name in columns.iter() {
75                columns_forced.push(Bytes::from(column_name.clone()));
76            }
77        }
78
79        Self {
80            state: QueryState::Description,
81            strategy,
82            columns_excluded,
83            columns_forced,
84        }
85    }
86
87    async fn process(&mut self, msg: backend::Message) -> backend::Message {
88        match msg {
89            backend::Message::RowDescription(descriptions) => {
90                // Define indexes of columns to exclude from masking.
91                let mut no_mask = vec![];
92                if self.columns_excluded.len() == 1 && self.columns_excluded[0] == "*" {
93                    // Wildcard `*` translates to all indexes.
94                    // Note: `forced` columns prevail on exclusions anyway.
95                    for (idx, description) in descriptions.iter().enumerate() {
96                        if !self.columns_forced.contains(&description.name) {
97                            no_mask.push(idx);
98                        }
99                    }
100                } else {
101                    // If no wildcard, capture columns to exclude from masking.
102                    // Note: `forced` columns prevail on exclusions anyway.
103                    for (idx, description) in descriptions.iter().enumerate() {
104                        if self.columns_excluded.contains(&description.name)
105                            && !self.columns_forced.contains(&description.name)
106                        {
107                            no_mask.push(idx);
108                        }
109                    }
110                }
111
112                // Store indexes to exclude from masking in upcoming `DataRow`s.
113                self.state = QueryState::Data(no_mask);
114                log::debug!("new masking exclusion state: {:?}", self.state);
115                backend::Message::RowDescription(descriptions)
116            }
117            backend::Message::CommandComplete(command) => {
118                // No more `DataRow`s to process, reset state.
119                self.state = QueryState::Description;
120                log::debug!("resetting masking state, awaiting next query");
121                backend::Message::CommandComplete(command)
122            }
123            backend::Message::DataRow(fields) => {
124                log::trace!("processing fields: {:?}", fields);
125                let mask = if let QueryState::Data(mask) = &self.state {
126                    mask
127                } else {
128                    panic!("unexpected state for `QueryState`");
129                };
130
131                let mut replaced_fields = vec![];
132                for (idx, field) in fields.iter().enumerate() {
133                    if !mask.contains(&idx) {
134                        log::debug!("applying masking to field #{}", idx);
135                        let rewritten = self.strategy.mask(field);
136                        replaced_fields.push(rewritten);
137                    } else {
138                        replaced_fields.push(field.clone());
139                    }
140                }
141                backend::Message::DataRow(replaced_fields)
142            }
143            _ => msg,
144        }
145    }
146}
147
148/// Handler used currently for PostgreSQL frontend Messages.
149/// Does nothing but passthrough.
150#[derive(Debug)]
151pub struct PassthroughHandler<M> {
152    _phantom: std::marker::PhantomData<M>,
153}
154
155#[async_trait]
156impl<M> SQLMessageHandler<M> for PassthroughHandler<M>
157where
158    M: SQLMessage,
159{
160    fn new(_config: &SQLHandlerConfig) -> Self {
161        Self {
162            _phantom: std::marker::PhantomData,
163        }
164    }
165}
166
167mod strategies;
168
#[cfg(test)]
mod tests {
    use super::QueryState;
    use bytes::Bytes;

    /// Wildcard exclusion detection compares a configured column name,
    /// converted from a `String` into `Bytes`, against the `"*"` literal.
    #[test]
    fn wildcard_column_name_matches_literal() {
        let wildcard = Bytes::from(String::from("*"));
        assert!(wildcard == "*");

        let column = Bytes::from(String::from("email"));
        assert!(column != "*");
    }

    /// `QueryState` is logged with `{:?}` whenever exclusions are computed.
    #[test]
    fn query_state_debug_output() {
        assert_eq!(format!("{:?}", QueryState::Description), "Description");
        assert_eq!(
            format!("{:?}", QueryState::Data(vec![0, 2])),
            "Data([0, 2])"
        );
    }
}