1use async_trait::async_trait;
5use bytes::Bytes;
6
7use fern_protocol_postgresql::codec::backend;
8use fern_proxy_interfaces::{SQLMessage, SQLMessageHandler};
9
10use crate::strategies::MaskingStrategy;
11
12pub use fern_proxy_interfaces::SQLHandlerConfig;
14
/// Backend message handler that masks the values of `DataRow` fields.
///
/// When a `RowDescription` arrives, column names are matched against the
/// configured exclusion and forced lists. The resulting set of unmasked
/// column indices is then applied to every following `DataRow`.
#[derive(Debug)]
pub struct DataMaskingHandler {
    /// Where the handler is within the current result set: waiting for a
    /// column description, or holding the indices to leave unmasked.
    state: QueryState,

    /// Strategy used to rewrite the values of masked fields.
    strategy: Box<dyn MaskingStrategy>,

    /// Column names forwarded unmasked (read from `masking.exclude.columns`);
    /// a single `"*"` entry leaves every column unmasked.
    columns_excluded: Vec<Bytes>,

    /// Column names that are always masked, even when they also appear in
    /// `columns_excluded` (read from `masking.force.columns`).
    columns_forced: Vec<Bytes>,
}
37
/// Progress of a [`DataMaskingHandler`] through a query's result set.
#[derive(Debug)]
enum QueryState {
    /// No column description is in effect yet; this is the initial state and
    /// the state restored after each `CommandComplete`.
    Description,

    /// A column description has been processed; holds the indices of the
    /// columns whose values are forwarded unmasked, in ascending order.
    Data(Vec<usize>),
}
47
48#[async_trait]
50impl SQLMessageHandler<backend::Message> for DataMaskingHandler {
51 fn new(config: &SQLHandlerConfig) -> Self {
52 let strategy: Box<dyn MaskingStrategy> =
54 if let Ok(strategy) = config.get::<String>("masking.strategy") {
55 match strategy.as_str() {
56 "caviar" => Box::new(strategies::CaviarMask::new(6)),
57 "caviar-preserve-shape" => Box::new(strategies::CaviarShapeMask::new()),
58 _ => Box::new(strategies::CaviarMask::new(6)),
59 }
60 } else {
61 Box::new(strategies::CaviarMask::new(6))
63 };
64
65 let mut columns_excluded = vec![];
66 if let Ok(columns) = config.get::<Vec<String>>("masking.exclude.columns") {
67 for column_name in columns.iter() {
68 columns_excluded.push(Bytes::from(column_name.clone()));
69 }
70 }
71
72 let mut columns_forced = vec![];
73 if let Ok(columns) = config.get::<Vec<String>>("masking.force.columns") {
74 for column_name in columns.iter() {
75 columns_forced.push(Bytes::from(column_name.clone()));
76 }
77 }
78
79 Self {
80 state: QueryState::Description,
81 strategy,
82 columns_excluded,
83 columns_forced,
84 }
85 }
86
87 async fn process(&mut self, msg: backend::Message) -> backend::Message {
88 match msg {
89 backend::Message::RowDescription(descriptions) => {
90 let mut no_mask = vec![];
92 if self.columns_excluded.len() == 1 && self.columns_excluded[0] == "*" {
93 for (idx, description) in descriptions.iter().enumerate() {
96 if !self.columns_forced.contains(&description.name) {
97 no_mask.push(idx);
98 }
99 }
100 } else {
101 for (idx, description) in descriptions.iter().enumerate() {
104 if self.columns_excluded.contains(&description.name)
105 && !self.columns_forced.contains(&description.name)
106 {
107 no_mask.push(idx);
108 }
109 }
110 }
111
112 self.state = QueryState::Data(no_mask);
114 log::debug!("new masking exclusion state: {:?}", self.state);
115 backend::Message::RowDescription(descriptions)
116 }
117 backend::Message::CommandComplete(command) => {
118 self.state = QueryState::Description;
120 log::debug!("resetting masking state, awaiting next query");
121 backend::Message::CommandComplete(command)
122 }
123 backend::Message::DataRow(fields) => {
124 log::trace!("processing fields: {:?}", fields);
125 let mask = if let QueryState::Data(mask) = &self.state {
126 mask
127 } else {
128 panic!("unexpected state for `QueryState`");
129 };
130
131 let mut replaced_fields = vec![];
132 for (idx, field) in fields.iter().enumerate() {
133 if !mask.contains(&idx) {
134 log::debug!("applying masking to field #{}", idx);
135 let rewritten = self.strategy.mask(field);
136 replaced_fields.push(rewritten);
137 } else {
138 replaced_fields.push(field.clone());
139 }
140 }
141 backend::Message::DataRow(replaced_fields)
142 }
143 _ => msg,
144 }
145 }
146}
147
/// Handler that keeps no state of its own.
///
/// Its `SQLMessageHandler` implementation only provides `new`, so message
/// processing comes from the trait's provided `process` method. `M` is
/// carried only at the type level (through `PhantomData`) so that the handler
/// can implement `SQLMessageHandler<M>` for any message type.
#[derive(Debug)]
pub struct PassthroughHandler<M> {
    _phantom: std::marker::PhantomData<M>,
}
154
155#[async_trait]
156impl<M> SQLMessageHandler<M> for PassthroughHandler<M>
157where
158 M: SQLMessage,
159{
160 fn new(_config: &SQLHandlerConfig) -> Self {
161 Self {
162 _phantom: std::marker::PhantomData,
163 }
164 }
165}
166
/// Masking strategy implementations (`MaskingStrategy`, `CaviarMask`,
/// `CaviarShapeMask`) used by [`DataMaskingHandler`].
mod strategies;
168
#[cfg(test)]
mod tests {
    // NOTE(review): placeholder with no assertions; it always passes and does
    // not exercise `DataMaskingHandler` or `PassthroughHandler` yet.
    #[test]
    fn it_works() {}
}