use regex::Regex;
use rust_decimal::MathematicalOps;
pub fn get_vote(
mut pfx_tree: super::PfxTree,
with_ticks_pattern: &str,
responses_len: usize,
content: &str,
logprobs: Option<&objectiveai_sdk::agent::completions::response::Logprobs>,
) -> (usize, Vec<rust_decimal::Decimal>) {
let with_ticks_re = Regex::new(with_ticks_pattern).unwrap();
let key_matches = with_ticks_re.find_iter(content).collect::<Vec<_>>();
if key_matches.is_empty() {
let weight = rust_decimal::Decimal::ONE / rust_decimal::Decimal::from(responses_len);
return (0, vec![weight; responses_len]);
}
let match_count = key_matches.len();
let key_matches_len_decimal =
rust_decimal::Decimal::from(key_matches.len());
let keys_rev = key_matches
.into_iter()
.rev()
.map(|cap| cap.as_str())
.collect::<Vec<_>>();
let mut vote = vec![rust_decimal::Decimal::ZERO; responses_len];
let mut logprob_i = 0;
for key in keys_rev {
let (final_pfx_char, final_pfx) = key
.chars()
.rev()
.map(|c| (c, super::Pfx::from_char(c)))
.filter(|(_, pfx)| pfx.is_some())
.next()
.unwrap();
let final_pfx = final_pfx.unwrap();
let mut i = pfx_tree.depth() - 1;
if i > 0 {
for c in key.chars() {
if let Some(pfx) = super::Pfx::from_char(c) {
pfx_tree = pfx_tree.get(pfx).unwrap();
i -= 1;
if i == 0 {
break;
}
}
}
}
let pfx_tree = match pfx_tree.clone() {
super::PfxTree::Branch(branch) => branch,
super::PfxTree::Leaf(_) => unreachable!(),
};
let mut from_logprobs = false;
if let Some(objectiveai_sdk::agent::completions::response::Logprobs {
content: Some(logprob_content),
..
}) = logprobs
{
let key_rev = key.chars().rev().collect::<String>();
let mut key_rev_slice = key_rev.as_str();
let mut key_logprob = None;
let mut key_logprob_index = 0;
'outer: for logprob in logprob_content.iter().rev().skip(logprob_i) {
logprob_i += 1;
let mut i = logprob.token.len();
for c in logprob.token.chars().rev() {
i -= c.len_utf8();
if key_rev_slice.starts_with(c) {
key_rev_slice = &key_rev_slice[c.len_utf8()..];
if key_logprob.is_none() && c == final_pfx_char {
key_logprob = Some(logprob);
key_logprob_index = i;
}
if key_rev_slice.is_empty() {
break 'outer;
}
} else if key_rev_slice.len() != key_rev.len() {
key_rev_slice = key_rev.as_str();
key_logprob = None;
key_logprob_index = 0;
} else {
}
}
}
if key_rev_slice.is_empty() {
let mut probabilities =
vec![rust_decimal::Decimal::ZERO; responses_len];
let mut probabilities_sum = rust_decimal::Decimal::ZERO;
for objectiveai_sdk::agent::completions::response::TopLogprob {
token,
logprob,
..
} in &key_logprob.as_ref().unwrap().top_logprobs
{
if key_logprob_index < token.len()
&& let Some(logprob) = logprob
&& let Some((_, c)) = token
.char_indices()
.find(|(i, _)| *i == key_logprob_index)
&& let Some(pfx) = super::Pfx::from_char(c)
&& let Some(leaf) = pfx_tree.get(&pfx)
{
from_logprobs = true;
let probability = logprob.exp();
probabilities[leaf.unwrap_leaf()] += probability;
probabilities_sum += probability;
}
}
if probabilities_sum > rust_decimal::Decimal::ZERO {
let mut vote_i = 0;
while vote_i < vote.len() {
vote[vote_i] += (probabilities[vote_i]
/ probabilities_sum)
/ key_matches_len_decimal;
vote_i += 1;
}
}
}
}
if !from_logprobs {
vote[pfx_tree.get(&final_pfx).unwrap().unwrap_leaf()] =
rust_decimal::Decimal::ONE / key_matches_len_decimal;
}
}
(match_count, vote)
}