Skip to main content

onepass_seed/expr/
generator.rs

1use core::fmt;
2use std::{io, sync::Arc};
3
4use crypto_bigint::{NonZero, U256, Word as _Word};
5use onepass_base::dict::Dict;
6use secrecy::{ExposeSecret, ExposeSecretMut, SecretBox};
7
8use super::{
9    EvalContext,
10    context::Context,
11    repr::write_literal,
12    util::{u256_saturating_pow, u256_to_word},
13};
14use crate::dict::EFF_WORDLIST;
15
16pub trait GeneratorFunc: Send + Sync {
17    fn name(&self) -> &'static str;
18
19    // TODO(soon): return Result from size so we can report dict lookup failure
20    fn size(&self, context: &Context<'_>, args: &[&str]) -> NonZero<U256>;
21
22    fn write_to(
23        &self,
24        context: &Context<'_>,
25        w: &mut dyn io::Write,
26        index: &mut dyn ExposeSecretMut<U256>,
27        args: &[&str],
28    ) -> io::Result<()>;
29
30    /// `GeneratorFunc`s know how to format themselves, which they may use to e.g. inject
31    /// dictionary hashes for canonical serialization.
32    // TODO(someday): standardize `write_sep_arg`, and instead have an optional trait method that
33    // yields each argument.
34    fn write_repr(&self, _: &Context<'_>, w: &mut dyn fmt::Write, args: &[&str]) -> fmt::Result {
35        write!(w, "{}", self.name())?;
36        for &arg in args {
37            write_sep_arg(w, arg)?;
38        }
39        Ok(())
40    }
41}
42
/// An unparsed generator reference such as `word`, `words:4:-` or `word|<dict hash>`.
///
/// The name is the leading run of ASCII lowercase letters. The first character after it
/// acts as the separator between arguments (see [`Generator::name`] and [`Generator::args`]).
#[derive(Clone, Debug, PartialEq)]
pub struct Generator(Box<str>);
45
/// The `word` generator: emits a single entry of the selected dictionary.
pub struct Word;
47
/// The `words` generator: emits several dictionary entries joined by a separator.
pub struct Words;
49
50fn write_sep_arg<W>(w: &mut W, arg: &str) -> fmt::Result
51where
52    W: fmt::Write + ?Sized,
53{
54    w.write_char('|')?;
55    write_literal(w, arg)?;
56    Ok(())
57}
58
59impl EvalContext for Generator {
60    type Context<'a> = Context<'a>;
61
62    fn size(&self, context: &Context) -> NonZero<U256> {
63        context
64            .get_generator(self.name())
65            .unwrap()
66            .size(context, &self.args())
67    }
68
69    fn write_to(
70        &self,
71        context: &Context,
72        w: &mut dyn io::Write,
73        index: &mut dyn ExposeSecretMut<U256>,
74    ) -> io::Result<()> {
75        context
76            .get_generator(self.name())
77            .unwrap()
78            .write_to(context, w, index, &self.args())
79    }
80}
81
82impl Generator {
83    pub fn from(s: impl Into<Box<str>>) -> Self {
84        Generator(s.into())
85    }
86
87    pub fn new(s: &str) -> Self {
88        Generator(s.into())
89    }
90
91    pub fn name(&self) -> &str {
92        let n = self
93            .0
94            .find(|c: char| !c.is_ascii_lowercase())
95            .unwrap_or(self.0.len());
96        &self.0[..n]
97    }
98
99    pub fn args(&self) -> Box<[&str]> {
100        let Some(sep) = self.0.chars().find(|&c| !c.is_ascii_lowercase()) else {
101            return [].into();
102        };
103        self.0.split(sep).skip(1).collect()
104    }
105}
106
107impl<'a> Context<'a> {
108    // TODO(soon): remove
109    pub fn with_dict(dict: Arc<dyn Dict + 'a>) -> Self {
110        Context::default().with_default_dict(dict)
111    }
112}
113
114impl Default for Context<'_> {
115    fn default() -> Self {
116        let generators: Vec<Arc<dyn GeneratorFunc>> = vec![Arc::new(Word), Arc::new(Words)];
117        Context::new(generators, [], Arc::new(EFF_WORDLIST))
118    }
119}
120
121fn fmt_with_hash<W>(w: &mut W, hash: &[u8; 32], args: &[&str]) -> fmt::Result
122where
123    W: fmt::Write + ?Sized,
124{
125    if !args.iter().copied().any(|arg| {
126        let mut out = vec![0u8; 32];
127        let Ok(()) = hex::decode_to_slice(arg, &mut out) else {
128            return false;
129        };
130        out == hash
131    }) {
132        let mut out = vec![0u8; 64];
133        hex::encode_to_slice(hash, &mut out).unwrap();
134        let out = String::from_utf8(out).unwrap();
135        write_sep_arg(w, &out)?;
136    };
137    for &arg in args {
138        write_sep_arg(w, arg)?;
139    }
140    Ok(())
141}
142
143impl GeneratorFunc for Word {
144    fn name(&self) -> &'static str {
145        "word"
146    }
147
148    fn size(&self, context: &Context<'_>, args: &[&str]) -> NonZero<U256> {
149        let dict = context.get_dict(&Context::dict_hash(args)).unwrap();
150        NonZero::new(_Word::try_from(dict.words().len()).unwrap().into()).unwrap()
151    }
152
153    fn write_to(
154        &self,
155        context: &Context<'_>,
156        w: &mut dyn io::Write,
157        index: &mut dyn ExposeSecretMut<U256>,
158        args: &[&str],
159    ) -> io::Result<()> {
160        let dict = context.get_dict(&Context::dict_hash(args)).unwrap();
161        let upper = args.iter().copied().any(|s| s == "U");
162        let word = dict.words()[u256_to_word(index.expose_secret_mut()) as usize];
163        if !upper {
164            write!(w, "{word}")?;
165            return Ok(());
166        }
167        let mut iter = word.chars();
168        let first = iter.next().unwrap();
169        write!(w, "{}", first.to_uppercase())?;
170        for c in iter {
171            write!(w, "{c}")?;
172        }
173        Ok(())
174    }
175
176    fn write_repr(
177        &self,
178        context: &Context<'_>,
179        w: &mut dyn fmt::Write,
180        args: &[&str],
181    ) -> fmt::Result {
182        // TODO(soon): clean up
183        let hash = Context::dict_hash(args).unwrap_or_else(|| *context.default_dict.hash());
184        write!(w, "{}", self.name())?;
185        fmt_with_hash(w, &hash, args)
186    }
187}
188
189impl Words {
190    pub fn parse_args<'a>(args: &'_ [&'a str]) -> (u32, &'a str, bool) {
191        let mut count = 5;
192        let mut sep = " ";
193        let mut upper = false;
194        for &arg in args {
195            if let Some(c) = arg.chars().next() {
196                if c.is_ascii_digit()
197                    && let Ok(n) = arg.parse()
198                {
199                    count = n;
200                } else if arg.len() == 1 {
201                    if c.is_ascii_punctuation() {
202                        sep = arg;
203                    } else if c == 'U' {
204                        upper = true;
205                    }
206                }
207            } else {
208                sep = "";
209            }
210        }
211        assert!(count > 0);
212        (count, sep, upper)
213    }
214}
215
216impl GeneratorFunc for Words {
217    fn name(&self) -> &'static str {
218        "words"
219    }
220
221    fn size(&self, context: &Context<'_>, args: &[&str]) -> NonZero<U256> {
222        let (count, _, upper) = Self::parse_args(args);
223        let base = Word.size(context, args);
224        let mut n = U256::ZERO;
225        u256_saturating_pow(&base, count.into(), &mut n);
226        if upper {
227            n = n.saturating_mul(&U256::from_u32(count));
228        }
229        NonZero::new(n).unwrap()
230    }
231
232    fn write_to(
233        &self,
234        context: &Context<'_>,
235        w: &mut dyn io::Write,
236        index: &mut dyn ExposeSecretMut<U256>,
237        args: &[&str],
238    ) -> io::Result<()> {
239        let (count, sep, upper) = Self::parse_args(args);
240        // TODO(soon): better Words -> Word arg mapping
241        let base = Word.size(context, args);
242        let j = if !upper {
243            0
244        } else {
245            let index = index.expose_secret_mut();
246            let j_uint = SecretBox::init_with_mut(|j| {
247                (*index, *j) = index.div_rem(&NonZero::new(U256::from_u32(count)).unwrap());
248            });
249            u32::try_from(u256_to_word(j_uint.expose_secret())).unwrap()
250        };
251        for i in 0..count {
252            if i != 0 {
253                write!(w, "{sep}")?;
254            }
255            let index = index.expose_secret_mut();
256            let mut word_index = SecretBox::init_with_mut(|word_index| {
257                (*index, *word_index) = index.div_rem(&base);
258            });
259            let args: &[&str] = if upper && i == j { &["U"] } else { &[] };
260            Word.write_to(context, w, &mut word_index, args)?;
261        }
262        assert!(bool::from(index.expose_secret_mut().is_zero()));
263        Ok(())
264    }
265
266    fn write_repr(
267        &self,
268        context: &Context<'_>,
269        w: &mut dyn fmt::Write,
270        args: &[&str],
271    ) -> fmt::Result {
272        let hash = Context::dict_hash(args).unwrap_or_else(|| *context.default_dict.hash());
273        write!(w, "{}", self.name())?;
274        fmt_with_hash(w, &hash, args)
275    }
276}
277
278impl<'a> fmt::Debug for dyn GeneratorFunc + 'a {
279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280        // TODO(soon): represent args, context
281        write!(f, "GeneratorFunc({:?})", self.name())?;
282        Ok(())
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::{
        super::{Expr, Node, util::*},
        *,
    };
    use crate::dict::BoxDict;

    /// Checks sizes and index-to-output mapping of the built-in generators on the default
    /// (EFF) dictionary.
    #[test]
    fn test_generators() {
        let ctx = Context::default();
        // Expected sizes: 7776 entries in the default dictionary; 0xCFD41B9100000 = 7776^4;
        // 0x7354800 = 2 * 7776^2 (the `U` flag multiplies by the word count).
        let tests: [(&str, u64, &[(&str, u64)]); _] = [
            ("word", 7776, &[("abacus", 0), ("zoom", 7775)]),
            (
                "words:4:-",
                0xCFD41B9100000,
                &[
                    // The first word is the least significant digit of the index.
                    ("abacus-abacus-abacus-abacus", 0),
                    ("abdomen-abacus-abacus-abacus", 1),
                    ("abacus-abdomen-abacus-abacus", 7776),
                    ("zoology-zoom-zoom-zoom", 0xCFD41B90FFFFE),
                    ("zoom-zoom-zoom-zoom", 0xCFD41B90FFFFF),
                ],
            ),
            (
                "words:2:U",
                0x7354800,
                &[
                    // With `U`, the lowest index digit (base 2 here) picks the capitalized word.
                    ("Abacus abacus", 0),
                    ("abacus Abacus", 1),
                    ("Abdomen abacus", 2),
                    ("abdomen Abacus", 3),
                    ("Zoom zoom", 0x73547fe),
                    ("zoom Zoom", 0x73547ff),
                ],
            ),
        ];
        for (g, sz, tt) in tests {
            let g = Generator::new(g);
            assert_eq!(U256::from_u64(sz), *g.size(&ctx));
            for (s, i) in tt {
                assert_eq!(s, &format_at_ctx(&g, &ctx, U256::from_u64(*i)));
            }
        }
    }

    /// A hash argument selects among dictionaries added to the context; the assertions
    /// below pin that the two hashes select `dict_a` and `dict_b` respectively.
    #[test]
    fn test_hashes() {
        let mut ctx = Context::default();
        let dict_a = Arc::new(BoxDict::from_lines("a\nb"));
        let dict_b = Arc::new(BoxDict::from_lines("c\nd"));
        ctx.extend([dict_a as Arc<dyn Dict>, dict_b]);
        let ctx = ctx;
        let a =
            Generator::new("word|e622f861cfb90d7fc2773ebf739fd5331515e652d2d3bad8d5a24ec90bf505fd");
        let b =
            Generator::new("word|ca492d04b5ed9cb47f4405591bb0ca14f5cdf0e45ea86a1d38466e8965e9abb2");
        assert_eq!("a", &format_at_ctx(&a, &ctx, U256::ZERO));
        assert_eq!("c", &format_at_ctx(&b, &ctx, U256::ZERO));
    }

    /// Capitalization via the `U` argument.
    #[test]
    fn test_case() {
        let ctx = Context::default();
        let g = Generator::new("word:U");
        assert_eq!("Abacus", &format_at_ctx(&g, &ctx, U256::ZERO));
        // The trailing empty argument selects an empty separator.
        let g = Generator::new("words:U:3:");
        assert_eq!("Abacusabacusabacus", &format_at_ctx(&g, &ctx, U256::ZERO));
    }

    /// Uses a dictionary built from a non-'static local string via `Context::with_dict`.
    #[test]
    fn test_lifetimes() {
        let s = "bob\ndole".to_string();
        let dict = Arc::new(BoxDict::from_lines(&s));
        let ctx = Context::with_dict(dict);
        let g = Generator::new("word");
        assert_eq!(U256::from_u32(2), *g.size(&ctx));
        assert_eq!("bob", &format_at_ctx(&g, &ctx, U256::from_u32(0)));
        assert_eq!("dole", &format_at_ctx(&g, &ctx, U256::from_u32(1)));
    }

    /// Canonical formatting injects the default dictionary hash as the first argument and
    /// re-renders the user's arguments with `|` separators (escaping a literal `|`).
    #[test]
    fn test_fmt() {
        let expr = Expr::new(Node::Generator(Generator::new("word")));
        assert_eq!(
            "{word|323606b363ebdedff9f562cb84c50df1a21cbd4b597ff4566df92bb9f2cefdfd}",
            &format!("{expr}"),
        );
        let expr = Expr::new(Node::Generator(Generator::new("word:up|:too")));
        assert_eq!(
            "{word|323606b363ebdedff9f562cb84c50df1a21cbd4b597ff4566df92bb9f2cefdfd|up\\||too}",
            &format!("{expr}"),
        );
        let expr = Expr::new(Node::Generator(Generator::new("word|up:|too")));
        assert_eq!(
            "{word|323606b363ebdedff9f562cb84c50df1a21cbd4b597ff4566df92bb9f2cefdfd|up:|too}",
            &format!("{expr}"),
        );
    }
}