biscuit_auth/token/authorizer/snapshot.rs

1use prost::Message;
2use std::{collections::HashMap, time::Duration};
3
4use crate::{
5    builder::{BlockBuilder, Convert, Policy},
6    datalog::{Origin, RunLimits, TrustedOrigins},
7    error,
8    format::{
9        convert::{
10            proto_snapshot_block_to_token_block, token_block_to_proto_snapshot_block,
11            v2::{
12                policy_to_proto_policy, proto_fact_to_token_fact, proto_policy_to_policy,
13                token_fact_to_proto_fact,
14            },
15        },
16        schema::{self, GeneratedFacts},
17    },
18    token::{default_symbol_table, MAX_SCHEMA_VERSION, MIN_SCHEMA_VERSION},
19    PublicKey,
20};
21
22impl super::Authorizer {
23    pub fn from_snapshot(input: schema::AuthorizerSnapshot) -> Result<Self, error::Token> {
24        let schema::AuthorizerSnapshot {
25            limits,
26            execution_time,
27            world,
28        } = input;
29
30        let limits = RunLimits {
31            max_facts: limits.max_facts,
32            max_iterations: limits.max_iterations,
33            max_time: Duration::from_nanos(limits.max_time),
34        };
35
36        let execution_time = Duration::from_nanos(execution_time);
37
38        let version = world.version.unwrap_or(0);
39        if !(MIN_SCHEMA_VERSION..=MAX_SCHEMA_VERSION).contains(&version) {
40            return Err(error::Format::Version {
41                minimum: crate::token::MIN_SCHEMA_VERSION,
42                maximum: crate::token::MAX_SCHEMA_VERSION,
43                actual: version,
44            }
45            .into());
46        }
47
48        let mut symbols = default_symbol_table();
49        for symbol in world.symbols {
50            symbols.insert(&symbol);
51        }
52        for public_key in world.public_keys {
53            symbols
54                .public_keys
55                .insert(&PublicKey::from_proto(&public_key)?);
56        }
57
58        let authorizer_block = proto_snapshot_block_to_token_block(&world.authorizer_block)?;
59
60        let authorizer_block_builder = BlockBuilder::convert_from(&authorizer_block, &symbols)?;
61        let policies = world
62            .authorizer_policies
63            .iter()
64            .map(|policy| proto_policy_to_policy(policy, &symbols, version))
65            .collect::<Result<Vec<Policy>, error::Format>>()?;
66
67        let mut authorizer = super::Authorizer::new();
68        authorizer.symbols = symbols;
69        authorizer.authorizer_block_builder = authorizer_block_builder;
70        authorizer.policies = policies;
71        authorizer.limits = limits;
72        authorizer.execution_time = execution_time;
73
74        let mut public_key_to_block_id: HashMap<usize, Vec<usize>> = HashMap::new();
75        let mut blocks = Vec::new();
76        for (i, block) in world.blocks.iter().enumerate() {
77            let token_symbols = if block.external_key.is_none() {
78                authorizer.symbols.clone()
79            } else {
80                let mut token_symbols = authorizer.symbols.clone();
81                token_symbols.public_keys = authorizer.symbols.public_keys.clone();
82                token_symbols
83            };
84
85            let mut block = proto_snapshot_block_to_token_block(block)?;
86
87            if let Some(key) = block.external_key.as_ref() {
88                public_key_to_block_id
89                    .entry(authorizer.symbols.public_keys.insert(key) as usize)
90                    .or_default()
91                    .push(i);
92            }
93
94            authorizer.load_and_translate_block(&mut block, i, &token_symbols)?;
95            blocks.push(block);
96        }
97
98        authorizer.public_key_to_block_id = public_key_to_block_id;
99
100        if !blocks.is_empty() {
101            authorizer.token_origins = TrustedOrigins::from_scopes(
102                &[crate::token::Scope::Previous],
103                &TrustedOrigins::default(),
104                blocks.len(),
105                &authorizer.public_key_to_block_id,
106            );
107            authorizer.blocks = Some(blocks);
108        }
109
110        for GeneratedFacts { origins, facts } in world.generated_facts {
111            let origin = proto_origin_to_authorizer_origin(&origins)?;
112
113            for fact in &facts {
114                let fact = proto_fact_to_token_fact(fact)?;
115                //let fact = Fact::convert_from(&fact, &symbols)?.convert(&mut authorizer.symbols);
116                authorizer.world.facts.insert(&origin, fact);
117            }
118        }
119
120        authorizer.world.iterations = world.iterations;
121
122        Ok(authorizer)
123    }
124
125    pub fn from_raw_snapshot(input: &[u8]) -> Result<Self, error::Token> {
126        let snapshot = schema::AuthorizerSnapshot::decode(input).map_err(|e| {
127            error::Format::DeserializationError(format!("deserialization error: {:?}", e))
128        })?;
129        Self::from_snapshot(snapshot)
130    }
131
132    pub fn from_base64_snapshot(input: &str) -> Result<Self, error::Token> {
133        let bytes = base64::decode_config(input, base64::URL_SAFE)?;
134        Self::from_raw_snapshot(&bytes)
135    }
136
137    pub fn snapshot(&self) -> Result<schema::AuthorizerSnapshot, error::Format> {
138        let mut symbols = default_symbol_table();
139
140        let authorizer_policies = self
141            .policies
142            .iter()
143            .map(|policy| policy_to_proto_policy(policy, &mut symbols))
144            .collect();
145
146        let authorizer_block = self.authorizer_block_builder.clone().build(symbols.clone());
147        symbols.extend(&authorizer_block.symbols)?;
148        symbols.public_keys.extend(&authorizer_block.public_keys)?;
149
150        let authorizer_block = token_block_to_proto_snapshot_block(&authorizer_block);
151
152        let blocks = match self.blocks.as_ref() {
153            None => Vec::new(),
154            Some(blocks) => blocks
155                .iter()
156                .map(|block| {
157                    block
158                        .translate(&self.symbols, &mut symbols)
159                        .map(|block| token_block_to_proto_snapshot_block(&block))
160                })
161                .collect::<Result<Vec<_>, error::Format>>()?,
162        };
163
164        let generated_facts = self
165            .world
166            .facts
167            .inner
168            .iter()
169            .map(|(origin, facts)| {
170                Ok(GeneratedFacts {
171                    origins: authorizer_origin_to_proto_origin(origin),
172                    facts: facts
173                        .iter()
174                        .map(|fact| {
175                            Ok(token_fact_to_proto_fact(
176                                &crate::builder::Fact::convert_from(fact, &self.symbols)?
177                                    .convert(&mut symbols),
178                            ))
179                        })
180                        .collect::<Result<Vec<_>, error::Format>>()?,
181                })
182            })
183            .collect::<Result<Vec<GeneratedFacts>, error::Format>>()?;
184
185        let world = schema::AuthorizerWorld {
186            version: Some(MAX_SCHEMA_VERSION),
187            symbols: symbols.strings(),
188            public_keys: symbols
189                .public_keys
190                .into_inner()
191                .into_iter()
192                .map(|key| key.to_proto())
193                .collect(),
194            blocks,
195            authorizer_block,
196            authorizer_policies,
197            generated_facts,
198            iterations: self.world.iterations,
199        };
200
201        Ok(schema::AuthorizerSnapshot {
202            world,
203            execution_time: self.execution_time.as_nanos() as u64,
204            limits: schema::RunLimits {
205                max_facts: self.limits.max_facts,
206                max_iterations: self.limits.max_iterations,
207                max_time: self.limits.max_time.as_nanos() as u64,
208            },
209        })
210    }
211
212    pub fn to_raw_snapshot(&self) -> Result<Vec<u8>, error::Format> {
213        let snapshot = self.snapshot()?;
214        let mut bytes = Vec::new();
215        snapshot.encode(&mut bytes).map_err(|e| {
216            error::Format::SerializationError(format!("serialization error: {:?}", e))
217        })?;
218        Ok(bytes)
219    }
220
221    pub fn to_base64_snapshot(&self) -> Result<String, error::Format> {
222        let snapshot_bytes = self.to_raw_snapshot()?;
223        Ok(base64::encode_config(snapshot_bytes, base64::URL_SAFE))
224    }
225}
226
227fn authorizer_origin_to_proto_origin(origin: &Origin) -> Vec<schema::Origin> {
228    origin
229        .inner
230        .iter()
231        .map(|o| {
232            if *o == usize::MAX {
233                schema::Origin {
234                    content: Some(schema::origin::Content::Authorizer(schema::Empty {})),
235                }
236            } else {
237                schema::Origin {
238                    content: Some(schema::origin::Content::Origin(*o as u32)),
239                }
240            }
241        })
242        .collect()
243}
244
245fn proto_origin_to_authorizer_origin(origins: &[schema::Origin]) -> Result<Origin, error::Format> {
246    let mut new_origin = Origin::default();
247
248    for origin in origins {
249        match origin.content {
250            Some(schema::origin::Content::Authorizer(schema::Empty {})) => {
251                new_origin.insert(usize::MAX)
252            }
253            Some(schema::origin::Content::Origin(o)) => new_origin.insert(o as usize),
254            _ => {
255                return Err(error::Format::DeserializationError(
256                    "invalid origin".to_string(),
257                ))
258            }
259        }
260    }
261
262    Ok(new_origin)
263}