Skip to main content

objectiveai_api/vector/completions/
get_vote.rs

1//! Vote extraction from LLM responses.
2//!
3//! Extracts votes from LLM chat completion responses by parsing response keys
4//! and computing probability distributions from logprobs when available.
5
6use regex::Regex;
7use rust_decimal::MathematicalOps;
8
9/// Extracts a vote from an LLM choice.
10///
11/// Parses the response to find selected response keys and converts them into a
12/// probability distribution. When logprobs are available, uses them to capture
13/// the model's preference distribution (probabilistic voting). Otherwise, falls
14/// back to discrete voting based on the final sampled token.
15///
16/// Returns None if no response key is found in the content.
17pub fn get_vote(
18    mut pfx_tree: super::PfxTree,
19    with_ticks_pattern: &str,
20    without_ticks_pattern: &str,
21    responses_len: usize,
22    choice: &objectiveai::chat::completions::response::streaming::Choice,
23) -> Option<Vec<rust_decimal::Decimal>> {
24    // extract content, return None if empty
25    let content_owned = match Content::from_choice(choice) {
26        Some(c) => c,
27        None => {
28            return None;
29        }
30    };
31    let content = content_owned.as_str();
32
33    // extract response keys, return if not found
34    let with_ticks_re = Regex::new(with_ticks_pattern).unwrap();
35    let mut key_matches = with_ticks_re.find_iter(content).collect::<Vec<_>>();
36    let without_ticks_re = match key_matches.len() {
37        0 => Some(Regex::new(without_ticks_pattern).unwrap()),
38        _ => None,
39    };
40    if let Some(without_ticks_re) = without_ticks_re.as_ref() {
41        key_matches = without_ticks_re.find_iter(content).collect::<Vec<_>>();
42    }
43
44    // return None if no keys found
45    if key_matches.is_empty() {
46        return None;
47    }
48
49    // each match has an equal vote weight
50    let key_matches_len_decimal =
51        rust_decimal::Decimal::from(key_matches.len());
52
53    // reverse matches for processing
54    let keys_rev = key_matches
55        .into_iter()
56        .rev()
57        .map(|cap| cap.as_str())
58        .collect::<Vec<_>>();
59
60    // prepare vote
61    let mut vote = vec![rust_decimal::Decimal::ZERO; responses_len];
62
63    // track logprob index
64    let mut logprob_i = 0;
65
66    for key in keys_rev {
67        // get the final prefix
68        let (final_pfx_char, final_pfx) = key
69            .chars()
70            .rev()
71            .map(|c| (c, super::Pfx::from_char(c)))
72            .filter(|(_, pfx)| pfx.is_some())
73            .next()
74            .unwrap();
75        let final_pfx = final_pfx.unwrap();
76
77        // get to the lowest pfx tree branch
78        let mut i = pfx_tree.depth() - 1;
79        if i > 0 {
80            for c in key.chars() {
81                if let Some(pfx) = super::Pfx::from_char(c) {
82                    pfx_tree = pfx_tree.get(pfx).unwrap();
83                    i -= 1;
84                    if i == 0 {
85                        break;
86                    }
87                }
88            }
89        }
90        let pfx_tree = match pfx_tree.clone() {
91            super::PfxTree::Branch(branch) => branch,
92            super::PfxTree::Leaf(_) => unreachable!(),
93        };
94
95        // try to get probabilities from logprobs
96        let mut from_logprobs = false;
97        if let Some(objectiveai::chat::completions::response::Logprobs {
98            content: Some(logprobs),
99            ..
100        }) = choice.logprobs.as_ref()
101        {
102            // reverse key to check against
103            let key_rev = key.chars().rev().collect::<String>();
104
105            // slice as we go
106            let mut key_rev_slice = key_rev.as_str();
107
108            // keep the relevant logprob
109            let mut key_logprob = None;
110            let mut key_logprob_index = 0;
111
112            // find the logprob segment that matches the key
113            'outer: for logprob in logprobs.iter().rev().skip(logprob_i) {
114                logprob_i += 1;
115                let mut i = logprob.token.len();
116                for c in logprob.token.chars().rev() {
117                    i -= c.len_utf8();
118                    if key_rev_slice.starts_with(c) {
119                        // match
120                        // remove the matched char from the slice
121                        key_rev_slice = &key_rev_slice[c.len_utf8()..];
122                        // keep the logprob that contains the final pfx
123                        if key_logprob.is_none() && c == final_pfx_char {
124                            key_logprob = Some(logprob);
125                            key_logprob_index = i;
126                        }
127                        // stop when the full match is found
128                        if key_rev_slice.is_empty() {
129                            break 'outer;
130                        }
131                    } else if key_rev_slice.len() != key_rev.len() {
132                        // not match
133                        // reset
134                        key_rev_slice = key_rev.as_str();
135                        key_logprob = None;
136                        key_logprob_index = 0;
137                    } else {
138                        // unknown
139                    }
140                }
141            }
142
143            // matching logprob segment found
144            if key_rev_slice.is_empty() {
145                // collect probabilities
146                let mut probabilities =
147                    vec![rust_decimal::Decimal::ZERO; responses_len];
148                let mut probabilities_sum = rust_decimal::Decimal::ZERO;
149                for objectiveai::chat::completions::response::TopLogprob {
150                    token,
151                    logprob,
152                    ..
153                } in &key_logprob.as_ref().unwrap().top_logprobs
154                {
155                    if key_logprob_index < token.len()
156                        && let Some(logprob) = logprob
157                        && let Some((_, c)) = token
158                            .char_indices()
159                            .find(|(i, _)| *i == key_logprob_index)
160                        && let Some(pfx) = super::Pfx::from_char(c)
161                        && let Some(leaf) = pfx_tree.get(&pfx)
162                    {
163                        // logprobs sourced vote successful
164                        from_logprobs = true;
165
166                        // add to probabilities
167                        let probability = logprob.exp();
168                        probabilities[leaf.unwrap_leaf()] += probability;
169                        probabilities_sum += probability;
170                    }
171                }
172
173                // normalize and add to vote
174                if probabilities_sum > rust_decimal::Decimal::ZERO {
175                    let mut vote_i = 0;
176                    while vote_i < vote.len() {
177                        vote[vote_i] += (probabilities[vote_i]
178                            / probabilities_sum)
179                            / key_matches_len_decimal;
180                        vote_i += 1;
181                    }
182                }
183            }
184        }
185
186        // fallback, set vote indexed to selected response to 1.0
187        if !from_logprobs {
188            vote[pfx_tree.get(&final_pfx).unwrap().unwrap_leaf()] =
189                rust_decimal::Decimal::ONE / key_matches_len_decimal;
190        }
191    }
192
193    // return vote
194    Some(vote)
195}
196
/// Text extracted from a choice, held either by reference or by value.
///
/// Having both variants lets the common case hand back a borrow of the
/// choice's own content and only allocate when several pieces of text have to
/// be joined together.
enum Content<'s> {
    /// A slice borrowed directly from `choice.delta.content`.
    Ref(&'s str),
    /// A freshly built string, used when tool call arguments and content are
    /// combined.
    Owned(String),
}
204
205impl<'s> Content<'s> {
206    /// Returns the content as a string slice.
207    fn as_str(&self) -> &str {
208        match self {
209            Content::Ref(s) => s,
210            Content::Owned(s) => s.as_str(),
211        }
212    }
213
214    /// Extracts content from a choice, combining tool call arguments if present.
215    fn from_choice(
216        choice: &'s objectiveai::chat::completions::response::streaming::Choice,
217    ) -> Option<Self> {
218        match choice.delta.tool_calls.as_ref() {
219            Some(tool_calls) => {
220                let mut len = 0;
221                for tool_call in tool_calls {
222                    if let Some(
223                    objectiveai::chat::completions::response::streaming::ToolCallFunction {
224                        arguments: Some(arguments),
225                        ..
226                    },
227                ) = tool_call.function.as_ref()
228                {
229                    len += arguments.len();
230                }
231                }
232                if let Some(content) = choice.delta.content.as_ref() {
233                    len += content.len();
234                }
235                if len == 0 {
236                    return None;
237                }
238                let mut owned = String::with_capacity(len);
239                for tool_call in tool_calls {
240                    if let Some(
241                    objectiveai::chat::completions::response::streaming::ToolCallFunction {
242                        arguments: Some(arguments),
243                        ..
244                    },
245                ) = tool_call.function.as_ref()
246                {
247                    owned.push_str(arguments);
248                }
249                }
250                if let Some(content) = choice.delta.content.as_ref() {
251                    owned.push_str(content);
252                }
253                Some(Content::Owned(owned))
254            }
255            None => {
256                if let Some(content) = choice.delta.content.as_ref()
257                    && content.len() > 0
258                {
259                    Some(Content::Ref(content.as_str()))
260                } else {
261                    None
262                }
263            }
264        }
265    }
266}