Skip to main content

objectiveai_api/vector/completions/
pfx.rs

1//! Prefix tree for response key generation.
2//!
3//! Generates unique prefix keys (e.g., `` `A` ``, `` `B` ``) for labeling vector responses.
4//! The LLM sees these keys and responds with its choice.
5//!
6//! The tree structure is designed around logprobs for probabilistic voting. Instead of
7//! relying on the LLM's final sampled answer, we use logprobs to capture a probability
8//! distribution over responses. The leaf width matches the number of logprobs the LLM
9//! generates (e.g., 20 logprobs = 20 leaves per branch). For large response sets, nested
10//! structures (`` `A` `` `` `A` ``, `` `A` `` `` `B` ``) allow capturing preferences across
11//! more responses than a single logprobs batch allows.
12//!
13//! This enables probabilistic voting: LLMs are inherently probabilistic, and the sampler
14//! makes the final discrete choice. By using logprobs, we bypass the sampler and capture
15//! the model's full preference distribution.
16
17use indexmap::IndexMap;
18use rand::{Rng, seq::SliceRandom};
19use std::sync::Arc;
20
/// Prefix labels `A` through `T`: the twenty single-character keys that can
/// label the entries of one branch.
///
/// Variants are listed alphabetically with discriminants `0..=19`, so a
/// variant's discriminant equals its letter's offset from `'A'`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Pfx {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
    E = 4,
    F = 5,
    G = 6,
    H = 7,
    I = 8,
    J = 9,
    K = 10,
    L = 11,
    M = 12,
    N = 13,
    O = 14,
    P = 15,
    Q = 16,
    R = 17,
    S = 18,
    T = 19,
}
45
46impl Pfx {
47    /// Converts this prefix to its character representation.
48    pub fn to_char(&self) -> char {
49        match self {
50            Pfx::A => 'A',
51            Pfx::B => 'B',
52            Pfx::C => 'C',
53            Pfx::D => 'D',
54            Pfx::E => 'E',
55            Pfx::F => 'F',
56            Pfx::G => 'G',
57            Pfx::H => 'H',
58            Pfx::I => 'I',
59            Pfx::J => 'J',
60            Pfx::K => 'K',
61            Pfx::L => 'L',
62            Pfx::M => 'M',
63            Pfx::N => 'N',
64            Pfx::O => 'O',
65            Pfx::P => 'P',
66            Pfx::Q => 'Q',
67            Pfx::R => 'R',
68            Pfx::S => 'S',
69            Pfx::T => 'T',
70        }
71    }
72
73    /// Parses a character into a prefix, if valid.
74    pub fn from_char(c: char) -> Option<Self> {
75        match c {
76            'A' => Some(Pfx::A),
77            'B' => Some(Pfx::B),
78            'C' => Some(Pfx::C),
79            'D' => Some(Pfx::D),
80            'E' => Some(Pfx::E),
81            'F' => Some(Pfx::F),
82            'G' => Some(Pfx::G),
83            'H' => Some(Pfx::H),
84            'I' => Some(Pfx::I),
85            'J' => Some(Pfx::J),
86            'K' => Some(Pfx::K),
87            'L' => Some(Pfx::L),
88            'M' => Some(Pfx::M),
89            'N' => Some(Pfx::N),
90            'O' => Some(Pfx::O),
91            'P' => Some(Pfx::P),
92            'Q' => Some(Pfx::Q),
93            'R' => Some(Pfx::R),
94            'S' => Some(Pfx::S),
95            'T' => Some(Pfx::T),
96            _ => None,
97        }
98    }
99
100    /// Returns all prefixes in randomized order.
101    pub fn rng_vec(rng: &mut impl Rng) -> Vec<Self> {
102        let mut vec = vec![
103            Pfx::A,
104            Pfx::B,
105            Pfx::C,
106            Pfx::D,
107            Pfx::E,
108            Pfx::F,
109            Pfx::G,
110            Pfx::H,
111            Pfx::I,
112            Pfx::J,
113            Pfx::K,
114            Pfx::L,
115            Pfx::M,
116            Pfx::N,
117            Pfx::O,
118            Pfx::P,
119            Pfx::Q,
120            Pfx::R,
121            Pfx::S,
122            Pfx::T,
123        ];
124        vec.shuffle(rng);
125        vec
126    }
127}
128
129impl std::fmt::Display for Pfx {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        write!(f, "{}", self.to_char())
132    }
133}
134
135/// A tree structure for generating unique prefix keys.
136///
137/// The tree width is determined by the number of logprobs the LLM generates.
138/// For flat structures (`` `A` ``, `` `B` ``), each leaf corresponds to one logprob slot.
139/// For large response sets exceeding the logprobs limit, nested structures
140/// (`` `A` `` `` `A` ``, `` `A` `` `` `B` ``) allow capturing preferences in stages.
141#[derive(Debug, Clone)]
142pub enum PfxTree {
143    /// A branch containing child nodes.
144    Branch(Arc<IndexMap<Pfx, PfxTree>>),
145    /// A leaf containing the response index.
146    Leaf(usize),
147}
148
149impl PfxTree {
150    /// Creates a new prefix tree for the given number of responses.
151    ///
152    /// The `max_branch_len` should match the number of logprobs the LLM generates,
153    /// ensuring each branch fits within one logprobs batch for probability capture.
154    pub fn new(
155        rng: &mut impl Rng,
156        source_len: usize,
157        max_branch_len: usize,
158    ) -> Self {
159        let mut source: Vec<usize> = (0..source_len).collect();
160        source.shuffle(rng);
161        Self::new_inner(rng, &source, max_branch_len, false)
162    }
163
164    /// Internal recursive constructor.
165    pub fn new_inner(
166        rng: &mut impl Rng,
167        source: &[usize],
168        max_branch_len: usize,
169        force_sub_branch: bool,
170    ) -> Self {
171        let pfxs = Pfx::rng_vec(rng);
172        if !force_sub_branch && source.len() <= max_branch_len {
173            // return a single branch containing all leaves
174            let mut branch = IndexMap::with_capacity(source.len());
175            for (i, source_index) in source.iter().enumerate() {
176                branch.insert(pfxs[i], PfxTree::Leaf(*source_index));
177            }
178            Self::Branch(Arc::new(branch))
179        } else {
180            // split into sub-branches
181            let n = {
182                let candidate =
183                    (source.len() + max_branch_len - 1) / max_branch_len;
184                if candidate <= max_branch_len {
185                    candidate
186                } else {
187                    max_branch_len
188                }
189            };
190            let base_per = source.len() / n;
191            let extra = source.len() % n;
192            let force_sub_branch =
193                base_per + { if extra > 0 { 1 } else { 0 } } > max_branch_len;
194            let mut branch = IndexMap::with_capacity(n);
195            let mut i = 0;
196            let mut count = 0;
197            while i < n {
198                let branch_len = base_per + if i < extra { 1 } else { 0 };
199                branch.insert(
200                    pfxs[i],
201                    PfxTree::new_inner(
202                        rng,
203                        &source[count..count + branch_len],
204                        max_branch_len,
205                        force_sub_branch,
206                    ),
207                );
208                count += branch_len;
209                i += 1;
210            }
211            Self::Branch(Arc::new(branch))
212        }
213    }
214
215    /// Generates prefix-to-index mappings in randomized order.
216    ///
217    /// Returns pairs of (prefix key, response index).
218    pub fn pfx_indices(
219        &self,
220        rng: &mut impl Rng,
221        source_len: usize,
222    ) -> Vec<(String, usize)> {
223        let mut indices = Vec::with_capacity(source_len);
224        self.pfx_indices_inner(None, &mut indices);
225        indices.shuffle(rng);
226        indices
227    }
228
229    /// Internal recursive method for generating prefix indices.
230    pub fn pfx_indices_inner(
231        &self,
232        parent_pfx: Option<String>,
233        indices: &mut Vec<(String, usize)>,
234    ) {
235        match self {
236            PfxTree::Branch(branch) => {
237                for (pfx, child) in branch.as_ref() {
238                    let parent_pfx = Some(match &parent_pfx {
239                        Some(parent_pfx) => format!("{}`{}`", parent_pfx, pfx),
240                        None => format!("`{}`", pfx),
241                    });
242                    child.pfx_indices_inner(parent_pfx, indices);
243                }
244            }
245            PfxTree::Leaf(index) => {
246                indices.push((parent_pfx.unwrap(), *index));
247            }
248        }
249    }
250
251    /// Gets a child node by prefix character.
252    pub fn get(&self, pfx: Pfx) -> Option<PfxTree> {
253        match self {
254            PfxTree::Branch(branch) => branch.get(&pfx).cloned(),
255            PfxTree::Leaf(_) => None,
256        }
257    }
258
259    /// Returns the depth of the tree.
260    pub fn depth(&self) -> usize {
261        match self {
262            PfxTree::Branch(branch) => {
263                1 + branch
264                    .values()
265                    .next() // all sub-branches have the same depth
266                    .map(|v| v.depth())
267                    .unwrap_or(0)
268            }
269            PfxTree::Leaf(_) => 0,
270        }
271    }
272
273    /// Unwraps a leaf node to get its response index.
274    ///
275    /// Panics if called on a branch node.
276    pub fn unwrap_leaf(&self) -> usize {
277        match self {
278            PfxTree::Leaf(index) => *index,
279            PfxTree::Branch(_) => {
280                panic!("Called unwrap_leaf on a Branch")
281            }
282        }
283    }
284
285    /// Generates regex patterns for matching response keys.
286    ///
287    /// Returns (pattern with backticks, pattern without backticks).
288    pub fn regex_patterns(&self, keys: &[(String, usize)]) -> (String, String) {
289        let depth = self.depth();
290        let mut with_ticks = String::with_capacity(
291            (keys.len() - 1) // '|' characters
292                + (keys.len() * depth * 3) // each key
293                + keys.len() * 2, // parentheses
294        );
295        let mut without_ticks = String::with_capacity(
296            (keys.len() - 1) // for '|' characters
297                + keys.len() * (depth * 3 - 2) // each key stripped of ticks
298                + keys.len() * 2, // parentheses
299        );
300        for (key, _) in keys {
301            if with_ticks.len() > 0 {
302                with_ticks.push('|');
303                without_ticks.push('|');
304            }
305            with_ticks.push('(');
306            without_ticks.push('(');
307            with_ticks.push_str(key);
308            without_ticks.push_str(&key[1..key.len() - 1]); // strip ticks
309            with_ticks.push(')');
310            without_ticks.push(')');
311        }
312        (with_ticks, without_ticks)
313    }
314}
315
316/// Prefix data for a specific LLM, including tree and regex patterns.
317#[derive(Debug, Clone)]
318pub struct PfxData {
319    /// The prefix tree for this LLM.
320    pub pfx_tree: PfxTree,
321    /// Regex pattern matching response keys with backticks.
322    pub responses_key_pattern: String,
323    /// Regex pattern matching response keys without backticks.
324    pub responses_key_pattern_stripped: String,
325}