1use prost::Message;
6use std::{collections::HashMap, time::Duration};
7
8use crate::{
9 builder::{load_and_translate_block, BlockBuilder, Convert, Policy},
10 datalog::{Origin, RunLimits, TrustedOrigins},
11 error,
12 format::{
13 convert::{
14 policy_to_proto_policy, proto_fact_to_token_fact, proto_policy_to_policy,
15 proto_snapshot_block_to_token_block, token_block_to_proto_snapshot_block,
16 token_fact_to_proto_fact,
17 },
18 schema::{self, GeneratedFacts},
19 },
20 token::{default_symbol_table, MAX_SCHEMA_VERSION, MIN_SCHEMA_VERSION},
21 PublicKey,
22};
23
24impl super::Authorizer {
25 pub fn from_snapshot(input: schema::AuthorizerSnapshot) -> Result<Self, error::Token> {
26 let schema::AuthorizerSnapshot {
27 limits,
28 execution_time,
29 world,
30 } = input;
31
32 let limits = RunLimits {
33 max_facts: limits.max_facts,
34 max_iterations: limits.max_iterations,
35 max_time: Duration::from_nanos(limits.max_time),
36 };
37
38 let execution_time = Duration::from_nanos(execution_time);
39
40 let version = world.version.unwrap_or(0);
41 if !(MIN_SCHEMA_VERSION..=MAX_SCHEMA_VERSION).contains(&version) {
42 return Err(error::Format::Version {
43 minimum: crate::token::MIN_SCHEMA_VERSION,
44 maximum: crate::token::MAX_SCHEMA_VERSION,
45 actual: version,
46 }
47 .into());
48 }
49
50 let mut symbols = default_symbol_table();
51 for symbol in world.symbols {
52 symbols.insert(&symbol);
53 }
54 for public_key in world.public_keys {
55 symbols
56 .public_keys
57 .insert(&PublicKey::from_proto(&public_key)?);
58 }
59
60 let authorizer_block = proto_snapshot_block_to_token_block(&world.authorizer_block)?;
61
62 let authorizer_block_builder = BlockBuilder::convert_from(&authorizer_block, &symbols)?;
63 let policies = world
64 .authorizer_policies
65 .iter()
66 .map(|policy| proto_policy_to_policy(policy, &symbols, version))
67 .collect::<Result<Vec<Policy>, error::Format>>()?;
68
69 let mut authorizer = super::Authorizer::new();
70 authorizer.symbols = symbols;
71 authorizer.authorizer_block_builder = authorizer_block_builder;
72 authorizer.policies = policies;
73 authorizer.limits = limits;
74 authorizer.execution_time =
75 Some(execution_time).filter(|_| execution_time > Duration::default());
76
77 let mut public_key_to_block_id: HashMap<usize, Vec<usize>> = HashMap::new();
78 let mut blocks = Vec::new();
79 for (i, block) in world.blocks.iter().enumerate() {
80 let token_symbols = if block.external_key.is_none() {
81 authorizer.symbols.clone()
82 } else {
83 let mut token_symbols = authorizer.symbols.clone();
84 token_symbols.public_keys = authorizer.symbols.public_keys.clone();
85 token_symbols
86 };
87
88 let mut block = proto_snapshot_block_to_token_block(block)?;
89
90 if let Some(key) = block.external_key.as_ref() {
91 public_key_to_block_id
92 .entry(authorizer.symbols.public_keys.insert(key) as usize)
93 .or_default()
94 .push(i);
95 }
96
97 load_and_translate_block(
98 &mut block,
99 i,
100 &token_symbols,
101 &mut authorizer.symbols,
102 &mut public_key_to_block_id,
103 &mut authorizer.world,
104 )?;
105 blocks.push(block);
106 }
107
108 authorizer.public_key_to_block_id = public_key_to_block_id;
109
110 if !blocks.is_empty() {
111 authorizer.token_origins = TrustedOrigins::from_scopes(
112 &[crate::token::Scope::Previous],
113 &TrustedOrigins::default(),
114 blocks.len(),
115 &authorizer.public_key_to_block_id,
116 );
117 authorizer.blocks = Some(blocks);
118 }
119
120 let mut authorizer_origin = Origin::default();
121 authorizer_origin.insert(usize::MAX);
122
123 let authorizer_scopes: Vec<crate::token::Scope> = authorizer
124 .authorizer_block_builder
125 .scopes
126 .clone()
127 .iter()
128 .map(|s| s.convert(&mut authorizer.symbols))
129 .collect();
130
131 let authorizer_trusted_origins = TrustedOrigins::from_scopes(
132 &authorizer_scopes,
133 &TrustedOrigins::default(),
134 usize::MAX,
135 &authorizer.public_key_to_block_id,
136 );
137 for fact in &authorizer.authorizer_block_builder.facts {
138 authorizer
139 .world
140 .facts
141 .insert(&authorizer_origin, fact.convert(&mut authorizer.symbols));
142 }
143
144 for rule in &authorizer.authorizer_block_builder.rules {
145 let rule = rule.convert(&mut authorizer.symbols);
146
147 let rule_trusted_origins = TrustedOrigins::from_scopes(
148 &rule.scopes,
149 &authorizer_trusted_origins,
150 usize::MAX,
151 &authorizer.public_key_to_block_id,
152 );
153
154 authorizer
155 .world
156 .rules
157 .insert(usize::MAX, &rule_trusted_origins, rule);
158 }
159
160 for GeneratedFacts { origins, facts } in world.generated_facts {
161 let origin = proto_origin_to_authorizer_origin(&origins)?;
162
163 for fact in &facts {
164 let fact = proto_fact_to_token_fact(fact)?;
165 authorizer.world.facts.insert(&origin, fact);
167 }
168 }
169
170 authorizer.world.iterations = world.iterations;
171
172 Ok(authorizer)
173 }
174
175 pub fn from_raw_snapshot(input: &[u8]) -> Result<Self, error::Token> {
176 let snapshot = schema::AuthorizerSnapshot::decode(input).map_err(|e| {
177 error::Format::DeserializationError(format!("deserialization error: {:?}", e))
178 })?;
179 Self::from_snapshot(snapshot)
180 }
181
182 pub fn from_base64_snapshot(input: &str) -> Result<Self, error::Token> {
183 let bytes = base64::decode_config(input, base64::URL_SAFE)?;
184 Self::from_raw_snapshot(&bytes)
185 }
186
187 pub fn snapshot(&self) -> Result<schema::AuthorizerSnapshot, error::Format> {
188 let mut symbols = default_symbol_table();
189
190 let authorizer_policies = self
191 .policies
192 .iter()
193 .map(|policy| policy_to_proto_policy(policy, &mut symbols))
194 .collect();
195
196 let authorizer_block = self.authorizer_block_builder.clone().build(symbols.clone());
197 symbols.extend(&authorizer_block.symbols)?;
198 symbols.public_keys.extend(&authorizer_block.public_keys)?;
199
200 let authorizer_block = token_block_to_proto_snapshot_block(&authorizer_block);
201
202 let blocks = match self.blocks.as_ref() {
203 None => Vec::new(),
204 Some(blocks) => blocks
205 .iter()
206 .map(|block| {
207 block
208 .translate(&self.symbols, &mut symbols)
209 .map(|block| token_block_to_proto_snapshot_block(&block))
210 })
211 .collect::<Result<Vec<_>, error::Format>>()?,
212 };
213
214 let generated_facts = self
215 .world
216 .facts
217 .inner
218 .iter()
219 .map(|(origin, facts)| {
220 Ok(GeneratedFacts {
221 origins: authorizer_origin_to_proto_origin(origin),
222 facts: facts
223 .iter()
224 .map(|fact| {
225 Ok(token_fact_to_proto_fact(
226 &crate::builder::Fact::convert_from(fact, &self.symbols)?
227 .convert(&mut symbols),
228 ))
229 })
230 .collect::<Result<Vec<_>, error::Format>>()?,
231 })
232 })
233 .collect::<Result<Vec<GeneratedFacts>, error::Format>>()?;
234
235 let world = schema::AuthorizerWorld {
236 version: Some(MAX_SCHEMA_VERSION),
237 symbols: symbols.strings(),
238 public_keys: symbols
239 .public_keys
240 .into_inner()
241 .into_iter()
242 .map(|key| key.to_proto())
243 .collect(),
244 blocks,
245 authorizer_block,
246 authorizer_policies,
247 generated_facts,
248 iterations: self.world.iterations,
249 };
250
251 Ok(schema::AuthorizerSnapshot {
252 world,
253 execution_time: self.execution_time.unwrap_or_default().as_nanos() as u64,
254 limits: schema::RunLimits {
255 max_facts: self.limits.max_facts,
256 max_iterations: self.limits.max_iterations,
257 max_time: self.limits.max_time.as_nanos() as u64,
258 },
259 })
260 }
261
262 pub fn to_raw_snapshot(&self) -> Result<Vec<u8>, error::Format> {
263 let snapshot = self.snapshot()?;
264 let mut bytes = Vec::new();
265 snapshot.encode(&mut bytes).map_err(|e| {
266 error::Format::SerializationError(format!("serialization error: {:?}", e))
267 })?;
268 Ok(bytes)
269 }
270
271 pub fn to_base64_snapshot(&self) -> Result<String, error::Format> {
272 let snapshot_bytes = self.to_raw_snapshot()?;
273 Ok(base64::encode_config(snapshot_bytes, base64::URL_SAFE))
274 }
275}
276
277pub(crate) fn authorizer_origin_to_proto_origin(origin: &Origin) -> Vec<schema::Origin> {
278 origin
279 .inner
280 .iter()
281 .map(|o| {
282 if *o == usize::MAX {
283 schema::Origin {
284 content: Some(schema::origin::Content::Authorizer(schema::Empty {})),
285 }
286 } else {
287 schema::Origin {
288 content: Some(schema::origin::Content::Origin(*o as u32)),
289 }
290 }
291 })
292 .collect()
293}
294
295pub(crate) fn proto_origin_to_authorizer_origin(
296 origins: &[schema::Origin],
297) -> Result<Origin, error::Format> {
298 let mut new_origin = Origin::default();
299
300 for origin in origins {
301 match origin.content {
302 Some(schema::origin::Content::Authorizer(schema::Empty {})) => {
303 new_origin.insert(usize::MAX)
304 }
305 Some(schema::origin::Content::Origin(o)) => new_origin.insert(o as usize),
306 _ => {
307 return Err(error::Format::DeserializationError(
308 "invalid origin".to_string(),
309 ))
310 }
311 }
312 }
313
314 Ok(new_origin)
315}
316
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::time::Duration;

    use crate::{datalog::RunLimits, Algorithm, AuthorizerBuilder};
    use crate::{Authorizer, BiscuitBuilder, KeyPair, PublicKey};

    /// Limits shared by every test, so the roundtrip can check they survive.
    fn test_limits() -> RunLimits {
        RunLimits {
            max_facts: 42,
            max_iterations: 42,
            max_time: Duration::from_secs(1),
        }
    }

    /// Scope parameters referenced as `{ed_pubkey}` and `{secp_pubkey}` in datalog code.
    fn pubkey_params(ed_pubkey: PublicKey, secp_pubkey: PublicKey) -> HashMap<String, PublicKey> {
        let mut params = HashMap::new();
        params.insert("ed_pubkey".to_string(), ed_pubkey);
        params.insert("secp_pubkey".to_string(), secp_pubkey);
        params
    }

    /// Snapshots `authorizer`, restores it, and checks that code and limits match.
    fn assert_snapshot_roundtrip(authorizer: &Authorizer) {
        let restored = Authorizer::from_snapshot(authorizer.snapshot().unwrap()).unwrap();
        assert_eq!(restored.dump_code(), authorizer.dump_code());
        assert_eq!(restored.limits(), authorizer.limits());
    }

    /// Checks the snapshot roundtrip on `authorizer` and on a clone that has been run.
    fn assert_roundtrip_before_and_after_run(authorizer: &Authorizer) {
        assert_snapshot_roundtrip(authorizer);

        let mut ran = authorizer.clone();
        // Only the post-run state matters here, not whether authorization succeeded.
        let _ = ran.run();
        assert_snapshot_roundtrip(&ran);
    }

    #[test]
    fn roundtrip_builder() {
        let secp_pubkey = KeyPair::new_with_algorithm(Algorithm::Secp256r1).public();
        let ed_pubkey = KeyPair::new_with_algorithm(Algorithm::Ed25519).public();
        let builder = AuthorizerBuilder::new()
            .set_limits(test_limits())
            .code_with_params(
                r#"
            fact(true);
            head($a) <- fact($a);
            check if head(true) trusting authority, {ed_pubkey}, {secp_pubkey};
            allow if head(true);
            deny if head(false);
            "#,
                HashMap::default(),
                pubkey_params(ed_pubkey, secp_pubkey),
            )
            .unwrap();

        let restored = AuthorizerBuilder::from_snapshot(builder.snapshot().unwrap()).unwrap();
        assert_eq!(restored.dump_code(), builder.dump_code());
        assert_eq!(restored.limits, builder.limits);
    }

    #[test]
    fn roundtrip_with_token() {
        let secp_pubkey = KeyPair::new_with_algorithm(Algorithm::Secp256r1).public();
        let ed_pubkey = KeyPair::new_with_algorithm(Algorithm::Ed25519).public();
        let builder = AuthorizerBuilder::new()
            .set_limits(test_limits())
            .code_with_params(
                r#"
            fact(true);
            head($a) <- fact($a);
            check if head(true) trusting authority, {ed_pubkey}, {secp_pubkey};
            allow if head(true);
            deny if head(false);
            "#,
                HashMap::default(),
                pubkey_params(ed_pubkey, secp_pubkey),
            )
            .unwrap();
        let biscuit = BiscuitBuilder::new()
            .code_with_params(
                r#"
            bfact(true);
            bhead($a) <- fact($a);
            check if bhead(true) trusting authority, {ed_pubkey}, {secp_pubkey};
            "#,
                HashMap::default(),
                pubkey_params(ed_pubkey, secp_pubkey),
            )
            .unwrap()
            .build(&KeyPair::new())
            .unwrap();

        let authorizer = builder.build(&biscuit).unwrap();
        assert_roundtrip_before_and_after_run(&authorizer);
    }

    #[test]
    fn roundtrip_without_token() {
        let authorizer = AuthorizerBuilder::new()
            .set_limits(test_limits())
            .code(
                r#"
            fact(true);
            head($a) <- fact($a);
            check if head(true);
            allow if head(true);
            deny if head(false);
            "#,
            )
            .unwrap()
            .build_unauthenticated()
            .unwrap();

        assert_roundtrip_before_and_after_run(&authorizer);
    }

    #[test]
    fn roundtrip_with_eval_error() {
        let authorizer = AuthorizerBuilder::new()
            .set_limits(test_limits())
            .code(
                r#"
            fact(true);
            head($a) <- fact($a), $a.length();
            allow if head(true);
            deny if head(false);
            "#,
            )
            .unwrap()
            .build_unauthenticated()
            .unwrap();

        assert_roundtrip_before_and_after_run(&authorizer);
    }
}