// llguidance/lark/compiler.rs

1use crate::{
2    grammar_builder::{GrammarResult, RegexId},
3    substring::substring,
4    HashMap, HashSet,
5};
6use anyhow::{anyhow, bail, ensure, Result};
7use derivre::RegexAst;
8
9use crate::{
10    api::{GenGrammarOptions, GenOptions, GrammarId, LLGuidanceOptions, NodeProps, RegexExt},
11    json::json_merge,
12    substring::{chunk_into_chars, chunk_into_words},
13    GrammarBuilder, JsonCompileOptions, NodeRef,
14};
15
16use super::{
17    ast::*,
18    common::lookup_common_regex,
19    lexer::Location,
20    parser::{parse_lark, ParsedLark},
21};
22
/// Intermediate representation of a parsed Lark grammar: top-level items
/// indexed by name, collected before compilation into grammar nodes.
#[derive(Debug)]
struct Grammar {
    // Rule name -> rule definition; entries are consumed (removed) as rules
    // are compiled.
    rules: HashMap<String, Rule>,
    // Terminal name -> token definition; likewise consumed during compilation.
    tokens: HashMap<String, TokenDef>,
    // Expansions accumulated from %ignore statements.
    ignore: Vec<Expansions>,
    // JSON value merged from %llguidance declarations; validated against
    // LLGuidanceOptions when the grammar is executed.
    llguidance_options: serde_json::Value,
}
30
31impl Default for Grammar {
32    fn default() -> Self {
33        Self {
34            rules: HashMap::default(),
35            tokens: HashMap::default(),
36            ignore: vec![],
37            llguidance_options: serde_json::Value::Object(serde_json::Map::new()),
38        }
39    }
40}
41
/// A nested sub-grammar whose compilation is deferred until the outer
/// grammar is finished (see `Compiler::pending_grammars` / `execute`).
enum PendingGrammar {
    // A %json schema, compiled via JsonCompileOptions.
    Json(serde_json::Value),
    // Items of a nested %lark { ... } grammar, compiled via compile_lark.
    Lark(Vec<Item>),
}
46
/// Walks the parsed Lark AST and emits nodes/regexes into a GrammarBuilder.
struct Compiler {
    builder: GrammarBuilder,
    // Parsed items; drained into `grammar` at the start of execute().
    parsed: ParsedLark,
    // Indexed rule/token definitions, consumed as they are compiled.
    grammar: Grammar,
    // Rule name -> grammar node; may temporarily hold a placeholder node
    // for rules referenced recursively (resolved in do_rule).
    node_ids: HashMap<String, NodeRef>,
    // Terminal name -> compiled regex (memoization for do_token).
    regex_ids: HashMap<String, RegexId>,
    // Names currently being compiled: recursion is allowed for rules (via
    // placeholder) but is an error for tokens.
    in_progress: HashSet<String>,
    // Deferred %json / %lark sub-grammars, linked in after the main pass.
    pending_grammars: Vec<(NodeRef, Location, PendingGrammar)>,
}
56
57fn compile_lark(builder: GrammarBuilder, parsed: ParsedLark) -> Result<GrammarResult> {
58    let c = Compiler {
59        builder,
60        parsed,
61        grammar: Grammar::default(),
62        node_ids: HashMap::default(),
63        regex_ids: HashMap::default(),
64        in_progress: HashSet::default(),
65        pending_grammars: vec![],
66    };
67    c.execute()
68}
69
70pub fn lark_to_llguidance(mut builder: GrammarBuilder, lark: &str) -> Result<GrammarResult> {
71    let parsed = parse_lark(lark)?;
72
73    let n = std::cmp::min(lark.len() / 8, 1_000_000);
74    builder.regex.spec.regex_builder.reserve(n);
75
76    compile_lark(builder, parsed)
77}
78
impl Compiler {
    /// Compile the terminal `name` into a regex, memoizing in `regex_ids`.
    /// Unlike rules, tokens may not be recursive — a cycle is an error.
    fn do_token(&mut self, name: &str) -> Result<RegexId> {
        if let Some(id) = self.regex_ids.get(name) {
            return Ok(*id);
        }
        if self.in_progress.contains(name) {
            bail!("circular reference in token {:?} definition", name);
        }
        self.in_progress.insert(name.to_string());
        // remove(): each definition is consumed exactly once; all later
        // references are served by the memoized RegexId.
        let token = self
            .grammar
            .tokens
            .remove(name)
            .ok_or_else(|| anyhow!("unknown name: {:?}", name))?;
        let id = self.do_token_expansions(token.expansions)?;
        self.regex_ids.insert(name.to_string(), id);
        self.in_progress.remove(name);
        Ok(id)
    }

    /// Build a regex node from source text `rx`, wrapping any parse error
    /// with the `info` context string.
    fn mk_regex(&mut self, info: &str, rx: String) -> Result<RegexId> {
        self.builder
            .regex
            .regex(&rx)
            .map_err(|e| anyhow!("invalid regex {rx:?} (in {info}): {e}"))
    }

    /// Compile an atom appearing inside a terminal definition to a regex.
    /// Rule-only constructs (special tokens, %json, grammar refs, nested
    /// %lark, templates) are rejected here.
    fn do_token_atom(&mut self, atom: Atom) -> Result<RegexId> {
        self.builder.check_limits()?;
        match atom {
            Atom::Group(expansions) => self.do_token_expansions(expansions),
            Atom::Maybe(expansions) => {
                let id = self.do_token_expansions(expansions)?;
                Ok(self.builder.regex.optional(id))
            }
            Atom::Value(value) => match value {
                Value::LiteralRange(a, b) => {
                    // "a".."z" — both endpoints must be single characters.
                    ensure!(
                        a.chars().count() == 1,
                        "range start must be a single character"
                    );
                    ensure!(
                        b.chars().count() == 1,
                        "range end must be a single character"
                    );
                    let a = a.chars().next().unwrap();
                    let b = b.chars().next().unwrap();
                    if a <= b {
                        // Escape both endpoints so e.g. "-" or "]" are safe
                        // inside the character class.
                        self.mk_regex(
                            "range",
                            format!(
                                "[{}-{}]",
                                regex_syntax::escape(&a.to_string()),
                                regex_syntax::escape(&b.to_string())
                            ),
                        )
                    } else {
                        bail!("invalid range order: {:?}..{:?}", a, b);
                    }
                }
                Value::Name(n) => self.do_token(&n),
                Value::LiteralString(val, flags) => {
                    if flags.contains("i") {
                        // Case-insensitive literal: escape and wrap in (?i).
                        self.mk_regex(
                            "string with i-flag",
                            format!("(?i){}", regex_syntax::escape(&val)),
                        )
                    } else {
                        Ok(self.builder.regex.literal(val))
                    }
                }
                Value::LiteralRegex(val, flags) => {
                    ensure!(!flags.contains("l"), "l-flag is not supported in regexes");
                    // Inline remaining flags as a (?flags) prefix.
                    let rx = if flags.is_empty() {
                        val
                    } else {
                        format!("(?{}){}", flags, val)
                    };
                    self.mk_regex("regex", rx)
                }
                Value::RegexExt(s) => compile_lark_regex(&mut self.builder, s),
                Value::SpecialToken(s) => {
                    bail!("special tokens (like {:?}) cannot be used in terminals", s);
                }
                Value::Json(_) => {
                    bail!("%json literals cannot be used in terminals");
                }
                Value::GrammarRef(g) => {
                    bail!(
                        "grammar references (like {:?}) cannot be used in terminals",
                        g
                    );
                }
                Value::NestedLark(_) => {
                    bail!("nested %lark {{ ... }} cannot be used in terminals");
                }
                Value::TemplateUsage { .. } => bail!("template usage not supported yet"),
            },
        }
    }

    /// Compile an expression inside a terminal: an atom plus an optional
    /// repetition range (`~ n..m`) or postfix operator (* + ?).
    fn do_token_expr(&mut self, expr: Expr) -> Result<RegexId> {
        let atom = self.do_token_atom(expr.atom)?;
        if let Some(range) = &expr.range {
            ensure!(expr.op.is_none(), "ranges not supported with operators");
            ensure!(range.0 >= 0, "range start must be >= 0, got {:?}", range);
            ensure!(
                range.1 >= range.0,
                "range end must be >= start, got {:?}",
                range
            );
            // i32::MAX is the parser's encoding of an open-ended range.
            Ok(self.builder.regex.repeat(
                atom,
                range.0 as u32,
                if range.1 == i32::MAX {
                    None
                } else {
                    Some(range.1 as u32)
                },
            ))
        } else {
            match &expr.op {
                Some(op) => match op.0.as_str() {
                    "*" => Ok(self.builder.regex.zero_or_more(atom)),
                    "+" => Ok(self.builder.regex.one_or_more(atom)),
                    "?" => Ok(self.builder.regex.optional(atom)),
                    _ => {
                        bail!("unsupported operator: {:?}", op.0);
                    }
                },
                None => Ok(atom),
            }
        }
    }

    /// Compile a terminal's alternatives: each alternative is a concatenation
    /// of expressions, and the whole thing is a regex alternation.
    fn do_token_expansions(&mut self, expansions: Expansions) -> Result<RegexId> {
        self.builder.check_limits()?;
        let options = expansions
            .1
            .into_iter()
            .map(|alias| {
                let args = alias
                    .expansion
                    .0
                    .into_iter()
                    .map(|e| self.do_token_expr(e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(self.builder.regex.concat(args))
            })
            .collect::<Result<Vec<_>>>()
            // Attach source location to any error from inside the expansion.
            .map_err(|e| expansions.0.augment(e))?;
        Ok(self.builder.regex.select(options))
    }

    /// Wrap a compiled regex as a lexeme grammar node.
    fn lift_regex(&mut self, rx_id: RegexId) -> Result<NodeRef> {
        Ok(self.builder.lexeme(rx_id))
    }

    /// Handle a %lark or %json value: emit a gen_grammar node referencing a
    /// fresh synthetic name and queue the body for later compilation (see
    /// the pending-grammars loop in `execute`).
    fn do_nested(
        &mut self,
        loc: &Location,
        v: Value,
        temperature: Option<f32>,
        props: NodeProps,
    ) -> Result<NodeRef> {
        let inner = match v {
            Value::NestedLark(items) => PendingGrammar::Lark(items),
            Value::Json(json) => PendingGrammar::Json(json),
            _ => bail!("expected %lark or %json, got {:?}", v),
        };
        // num_nodes() gives a unique, stable suffix for the synthetic name.
        let name = format!("%nested---{}", self.builder.num_nodes());
        let gg = self.builder.gen_grammar(
            GenGrammarOptions {
                grammar: GrammarId::Name(name),
                temperature,
            },
            props,
        );
        self.pending_grammars.push((gg, loc.clone(), inner));
        Ok(gg)
    }

    /// Compile an atom in rule (non-terminal) position to a grammar node.
    /// Names resolve to rules when one exists, otherwise fall through to
    /// terminal handling at the bottom.
    fn do_atom(&mut self, loc: &Location, expr: Atom) -> Result<NodeRef> {
        match expr {
            Atom::Group(expansions) => self.do_expansions(expansions),
            Atom::Maybe(expansions) => {
                let id = self.do_expansions(expansions)?;
                Ok(self.builder.optional(id))
            }
            Atom::Value(value) => {
                match &value {
                    Value::Name(n) => {
                        if self.is_rule(n) {
                            return self.do_rule(n);
                        } else {
                            // OK -> treat as token
                        }
                    }
                    Value::SpecialToken(s) => {
                        // <[...]> syntax: explicit token-id ranges, e.g.
                        // <[10-20,30]> — comma-separated ids or id ranges.
                        if s.starts_with("<[") && s.ends_with("]>") {
                            let s = &s[2..s.len() - 2];
                            let mut ranges = vec![];
                            for range in s.split(",") {
                                let ends: Vec<&str> = range.split('-').map(|s| s.trim()).collect();
                                ensure!(
                                    ends.len() == 1 || ends.len() == 2,
                                    "invalid token range: {:?}",
                                    range
                                );
                                // Allow (and skip) empty segments from trailing commas.
                                if ends.len() == 1 && ends[0].is_empty() {
                                    continue;
                                }
                                let start = ends[0].parse::<u32>()?;
                                let end = if ends.len() == 2 {
                                    ends[1].parse::<u32>()?
                                } else {
                                    start
                                };
                                ensure!(start <= end, "invalid token range: {:?}", range);
                                ranges.push(start..=end);
                            }
                            ensure!(!ranges.is_empty(), "empty token range");
                            return self.builder.token_ranges(ranges);
                        }
                        // Otherwise a named special token like <|endoftext|>.
                        return self.builder.special_token(s);
                    }
                    Value::GrammarRef(g) => {
                        return self.gen_grammar(g, None, NodeProps::default());
                    }
                    Value::NestedLark(_) | Value::Json(_) => {
                        return self.do_nested(loc, value, None, NodeProps::default());
                    }
                    // special case "" literal, so it doesn't pollute grammar with epsilon regex
                    Value::LiteralString(s, _) if s.is_empty() => return Ok(self.builder.empty()),
                    Value::RegexExt(_)
                    | Value::LiteralRange(_, _)
                    | Value::LiteralString(_, _)
                    | Value::LiteralRegex(_, _) => {
                        // treat as token
                    }
                    Value::TemplateUsage { .. } => {
                        bail!("template usage not supported yet");
                    }
                };
                // Fallthrough: compile the value as a terminal and lift the
                // resulting regex into a lexeme node.
                let rx = self.do_token_atom(Atom::Value(value))?;
                Ok(self.lift_regex(rx)?)
            }
        }
    }

    /// Compile an expression in rule position: atom plus an optional
    /// repetition range or postfix operator (* + ?).
    fn do_expr(&mut self, loc: &Location, expr: Expr) -> Result<NodeRef> {
        let atom = self.do_atom(loc, expr.atom)?;

        if let Some((a, b)) = expr.range {
            ensure!(expr.op.is_none(), "ranges not supported with operators");
            ensure!(a <= b, "range end must be >= start, got {:?}", (a, b));
            ensure!(a >= 0, "range start must be >= 0, got {:?}", a);
            // i32::MAX encodes an open-ended upper bound.
            Ok(self.builder.repeat(
                atom,
                a as usize,
                if b == i32::MAX {
                    None
                } else {
                    Some(b as usize)
                },
            ))
        } else {
            match &expr.op {
                Some(op) => match op.0.as_str() {
                    "*" => Ok(self.builder.zero_or_more(atom)),
                    "+" => Ok(self.builder.one_or_more(atom)),
                    "?" => Ok(self.builder.optional(atom)),
                    _ => {
                        bail!("unsupported operator: {}", op.0);
                    }
                },
                None => Ok(atom),
            }
        }
    }

    /// Compile a rule's alternatives: each alternative is a join of its
    /// expressions; the whole set becomes a select (alternation) node.
    fn do_expansions(&mut self, expansions: Expansions) -> Result<NodeRef> {
        self.builder.check_limits()?;
        let loc = expansions.0;
        let options = expansions
            .1
            .into_iter()
            .map(|alias| {
                let args = alias
                    .expansion
                    .0
                    .into_iter()
                    .map(|e| self.do_expr(&loc, e))
                    .collect::<Result<Vec<_>>>()?;
                Ok(self.builder.join(&args))
            })
            .collect::<Result<Vec<_>>>()
            // Attach source location to any error from inside the expansion.
            .map_err(|e| loc.augment(e))?;
        Ok(self.builder.select(&options))
    }

    /// True if `name` refers to a rule: already compiled, currently being
    /// compiled, or still pending in the grammar.
    fn is_rule(&self, name: &str) -> bool {
        self.node_ids.contains_key(name)
            || self.in_progress.contains(name)
            || self.grammar.rules.contains_key(name)
    }

    /// Compile rule `name` to a node, memoizing in `node_ids`. A recursive
    /// reference (name already in_progress) gets a placeholder node that is
    /// patched to the real node once compilation of the rule finishes.
    fn do_rule(&mut self, name: &str) -> Result<NodeRef> {
        if let Some(id) = self.node_ids.get(name) {
            return Ok(*id);
        }
        if self.in_progress.contains(name) {
            // Recursion: hand out a placeholder; resolved below.
            let id = self.builder.new_node(name);
            self.node_ids.insert(name.to_string(), id);
            return Ok(id);
        }
        self.in_progress.insert(name.to_string());

        let id = self.do_rule_core(name)?;

        // If recursion created a placeholder meanwhile, point it at the
        // real node before overwriting the map entry.
        if let Some(placeholder) = self.node_ids.get(name) {
            self.builder.set_placeholder(*placeholder, id);
        }
        self.node_ids.insert(name.to_string(), id);
        self.in_progress.remove(name);
        Ok(id)
    }

    /// Emit a gen_grammar node for an `@name` grammar reference.
    fn gen_grammar(
        &mut self,
        name: &str,
        temperature: Option<f32>,
        props: NodeProps,
    ) -> Result<NodeRef> {
        assert!(name.starts_with("@"));
        // see if name[1..] is an integer
        let name = if name[1..].parse::<usize>().is_ok() {
            bail!("numeric grammar references no longer supported");
        } else {
            name[1..].to_string()
        };
        let id = self.builder.gen_grammar(
            GenGrammarOptions {
                grammar: GrammarId::Name(name.clone()),
                temperature,
            },
            props,
        );
        Ok(id)
    }

    /// The body of rule compilation (called once per rule by do_rule).
    /// Dispatches between gen-with-stop rules, terminal-like rules carrying
    /// temperature/max_tokens, sub-grammar references, and plain expansions.
    fn do_rule_core(&mut self, name: &str) -> Result<NodeRef> {
        // remove(): the definition is consumed; the memoized node serves
        // all later references.
        let mut rule = self
            .grammar
            .rules
            .remove(name)
            .ok_or_else(|| anyhow!("rule {:?} not found", name))?;

        let props = NodeProps {
            max_tokens: rule.max_tokens,
            capture_name: rule.capture_name.clone(),
            ..Default::default()
        };

        if rule.stop.is_some() && rule.suffix.is_some() {
            bail!("stop= and suffix= cannot be used together");
        }

        let id = if let Some(stop) = rule.stop_like() {
            // stop= / suffix= rule: body must be a terminal-like regex,
            // emitted as a gen node with a stop regex.
            let is_suffix = rule.suffix.is_some();
            let is_empty = matches!(stop, Value::LiteralString(s, _) if s.is_empty());
            let lazy = rule.is_lazy();
            let stop_val = Atom::Value(rule.take_stop_like().unwrap());
            let rx_id = self.do_token_expansions(rule.expansions)?;
            let stop_id = self.do_token_atom(stop_val)?;

            self.builder.gen(
                GenOptions {
                    body_rx: RegexAst::ExprRef(rx_id),
                    stop_rx: if is_empty {
                        RegexAst::EmptyString
                    } else {
                        RegexAst::ExprRef(stop_id)
                    },
                    stop_capture_name: rule.stop_capture_name.clone(),
                    lazy: Some(lazy),
                    temperature: rule.temperature,
                    is_suffix: Some(is_suffix),
                },
                props,
            )?
        } else {
            ensure!(
                rule.stop_capture_name.is_none(),
                "stop_capture_name requires stop= or suffix="
            );
            if rule.temperature.is_some() || rule.max_tokens.is_some() {
                // These attributes only make sense on terminals and
                // sub-grammars; see the three cases below.
                match rule.expansions.single_atom() {
                    Some(Atom::Value(Value::GrammarRef(g))) => {
                        return self.gen_grammar(g, rule.temperature, props);
                    }
                    Some(Atom::Value(Value::Json(_) | Value::NestedLark(_))) => {
                        // Move the single atom out of the expansions so it
                        // can be passed by value to do_nested.
                        if let Atom::Value(x) = rule.expansions.1[0].expansion.0.pop().unwrap().atom
                        {
                            return self.do_nested(&rule.expansions.0, x, rule.temperature, props);
                        } else {
                            unreachable!();
                        }
                    }
                    _ => {
                        // try as terminal
                        let rx_id = self.do_token_expansions(rule.expansions).map_err(|e| {
                            anyhow::anyhow!(
                                "{}; temperature= and max_tokens= only \
                                supported on TERMINALS and @subgrammars",
                                e
                            )
                        })?;
                        return Ok(self.builder.lexeme_ext(rx_id, rule.temperature, props));
                    }
                }
            }

            let inner = self.do_expansions(rule.expansions)?;
            #[allow(clippy::assertions_on_constants)]
            if let Some(max_tokens) = rule.max_tokens {
                // Dead branch: max_tokens was fully handled above; the
                // assert documents that this path must not be reached.
                assert!(false, "max_tokens handled above for now");
                self.builder.join_props(
                    &[inner],
                    NodeProps {
                        max_tokens: Some(max_tokens),
                        // assume the user also wants capture
                        capture_name: Some(name.to_string()),
                        ..Default::default()
                    },
                )
            } else if rule.capture_name.is_some() {
                // Wrap so the capture name is attached to the rule node.
                self.builder.join_props(&[inner], props)
            } else {
                inner
            }
        };
        Ok(id)
    }

    /// Top-level driver: index all items, validate options, compile the
    /// ignore set and the start rule, then compile and link all deferred
    /// %json/%lark sub-grammars.
    fn execute(mut self) -> Result<GrammarResult> {
        let mut grm = Grammar::default();
        for item in std::mem::take(&mut self.parsed.items) {
            let loc = item.location().clone();
            grm.process_item(item).map_err(|e| loc.augment(e))?;
        }
        let start_name = "start";
        ensure!(
            grm.rules.contains_key(start_name),
            "no {} rule found",
            start_name
        );
        let ignore = std::mem::take(&mut grm.ignore);
        self.grammar = grm;

        // Re-validate the merged %llguidance JSON as typed options.
        let opts: LLGuidanceOptions =
            serde_json::from_value(self.grammar.llguidance_options.clone())
                .map_err(|e| anyhow!("failed to parse %llguidance declaration: {}", e))?;

        // All %ignore expansions become one alternation used as the skip regex.
        let ignore = ignore
            .into_iter()
            .map(|exp| Ok(RegexAst::ExprRef(self.do_token_expansions(exp)?)))
            .collect::<Result<Vec<_>>>()?;
        let id = self.builder.add_grammar(opts, RegexAst::Or(ignore))?;

        let start = self.do_rule(start_name)?;
        self.builder.set_start_node(start);

        // Compile deferred sub-grammars, threading the builder through each
        // compilation, and link each gen_grammar node to its start node.
        let mut builder = self.builder;
        for (gg, loc, grm) in self.pending_grammars {
            let res = match grm {
                PendingGrammar::Json(json_schema) => JsonCompileOptions::default()
                    .json_to_llg_with_overrides(builder, json_schema)
                    .map_err(|e| loc.augment(anyhow!("failed to compile JSON schema: {}", e)))?,
                PendingGrammar::Lark(items) => compile_lark(builder, ParsedLark { items })?,
            };
            builder = res.builder;
            builder.link_gen_grammar(gg, res.start_node)?;
        }

        Ok(builder.finalize(id))
    }
}
567
impl Grammar {
    /// Register an imported terminal `local_name` defined by the regex
    /// source `regex` (wrapped into a single-alternative expansion).
    /// Errors on a duplicate name.
    fn add_token_def(&mut self, loc: &Location, local_name: String, regex: &str) -> Result<()> {
        ensure!(
            !self.tokens.contains_key(&local_name),
            "duplicate token (in import): {:?}",
            local_name
        );

        // Synthesize a TokenDef whose body is a single literal-regex atom.
        let t = TokenDef {
            name: local_name,
            params: None,
            priority: None,
            expansions: Expansions(
                loc.clone(),
                vec![Alias {
                    expansion: Expansion(vec![Expr {
                        atom: Atom::Value(Value::LiteralRegex(regex.to_string(), "".to_string())),
                        op: None,
                        range: None,
                    }]),
                    alias: None,
                }],
            ),
        };
        self.tokens.insert(t.name.clone(), t);
        Ok(())
    }

    /// Process a single %-statement (%ignore, %import, %llguidance, ...).
    fn do_statement(&mut self, loc: &Location, statement: Statement) -> Result<()> {
        match statement {
            Statement::Ignore(exp) => {
                self.ignore.push(exp);
            }
            Statement::Import { path, alias } => {
                let regex = lookup_common_regex(&path)?;
                // Without an alias, the local name is the last path segment.
                let local_name =
                    alias.unwrap_or_else(|| path.split('.').next_back().unwrap().to_string());
                self.add_token_def(loc, local_name, regex)?;
            }
            Statement::MultiImport { path, names } => {
                for n in names {
                    let qname = format!("{}.{}", path, n);
                    let regex = lookup_common_regex(&qname)?;
                    self.add_token_def(loc, n.to_string(), regex)?;
                }
            }
            Statement::LLGuidance(json_value) => {
                // merge-in at the JSON level
                json_merge(&mut self.llguidance_options, &json_value);
                // but also check if it's valid format and all the right types
                let _v: LLGuidanceOptions = serde_json::from_value(json_value)
                    .map_err(|e| anyhow!("failed to parse %llguidance declaration: {}", e))?;
            }
            Statement::OverrideRule(_) => {
                bail!("override statement not supported yet");
            }
            Statement::Declare(_) => {
                bail!("declare statement not supported yet");
            }
        }
        Ok(())
    }

    /// Index one top-level item into the grammar, rejecting duplicates and
    /// yet-unsupported features (params, priorities).
    fn process_item(&mut self, item: Item) -> Result<()> {
        match item {
            Item::Rule(rule) => {
                ensure!(rule.params.is_none(), "params not supported yet");
                ensure!(rule.priority.is_none(), "priority not supported yet");
                ensure!(
                    !self.rules.contains_key(&rule.name),
                    "duplicate rule: {:?}",
                    rule.name
                );
                self.rules.insert(rule.name.clone(), rule);
            }
            Item::Token(token_def) => {
                ensure!(token_def.params.is_none(), "params not supported yet");
                ensure!(token_def.priority.is_none(), "priority not supported yet");
                ensure!(
                    !self.tokens.contains_key(&token_def.name),
                    "duplicate token: {:?}",
                    token_def.name
                );
                self.tokens.insert(token_def.name.clone(), token_def);
            }
            Item::Statement(loc, statement) => {
                self.do_statement(&loc, statement)?;
            }
        }
        Ok(())
    }
}
660
661fn compile_lark_regex(builder: &mut GrammarBuilder, l: RegexExt) -> Result<RegexId> {
662    let mut fields_set = vec![];
663    if l.substring_chunks.is_some() {
664        fields_set.push("substring_chunks");
665    }
666    if l.substring_words.is_some() {
667        fields_set.push("substring_words");
668    }
669    if l.substring_chars.is_some() {
670        fields_set.push("substring_chars");
671    }
672    if fields_set.is_empty() {
673        bail!("no fields set on %regex");
674    }
675    if fields_set.len() > 1 {
676        bail!("only one field can be set on %regex; got {:?}", fields_set);
677    }
678
679    let bld = &mut builder.regex.spec.regex_builder;
680
681    let eref = if let Some(s) = l.substring_words {
682        substring(bld, chunk_into_words(&s))?
683    } else if let Some(s) = l.substring_chars {
684        substring(bld, chunk_into_chars(&s))?
685    } else if let Some(s) = l.substring_chunks {
686        substring(bld, s.iter().map(|s| s.as_str()).collect())?
687    } else {
688        unreachable!()
689    };
690
691    Ok(eref)
692}