objectiveai_api/vector/completions/
get_vote.rs1use regex::Regex;
7use rust_decimal::MathematicalOps;
8
9pub fn get_vote(
18 mut pfx_tree: super::PfxTree,
19 with_ticks_pattern: &str,
20 without_ticks_pattern: &str,
21 responses_len: usize,
22 choice: &objectiveai::chat::completions::response::streaming::Choice,
23) -> Option<Vec<rust_decimal::Decimal>> {
24 let content_owned = match Content::from_choice(choice) {
26 Some(c) => c,
27 None => {
28 return None;
29 }
30 };
31 let content = content_owned.as_str();
32
33 let with_ticks_re = Regex::new(with_ticks_pattern).unwrap();
35 let mut key_matches = with_ticks_re.find_iter(content).collect::<Vec<_>>();
36 let without_ticks_re = match key_matches.len() {
37 0 => Some(Regex::new(without_ticks_pattern).unwrap()),
38 _ => None,
39 };
40 if let Some(without_ticks_re) = without_ticks_re.as_ref() {
41 key_matches = without_ticks_re.find_iter(content).collect::<Vec<_>>();
42 }
43
44 if key_matches.is_empty() {
46 return None;
47 }
48
49 let key_matches_len_decimal =
51 rust_decimal::Decimal::from(key_matches.len());
52
53 let keys_rev = key_matches
55 .into_iter()
56 .rev()
57 .map(|cap| cap.as_str())
58 .collect::<Vec<_>>();
59
60 let mut vote = vec![rust_decimal::Decimal::ZERO; responses_len];
62
63 let mut logprob_i = 0;
65
66 for key in keys_rev {
67 let (final_pfx_char, final_pfx) = key
69 .chars()
70 .rev()
71 .map(|c| (c, super::Pfx::from_char(c)))
72 .filter(|(_, pfx)| pfx.is_some())
73 .next()
74 .unwrap();
75 let final_pfx = final_pfx.unwrap();
76
77 let mut i = pfx_tree.depth() - 1;
79 if i > 0 {
80 for c in key.chars() {
81 if let Some(pfx) = super::Pfx::from_char(c) {
82 pfx_tree = pfx_tree.get(pfx).unwrap();
83 i -= 1;
84 if i == 0 {
85 break;
86 }
87 }
88 }
89 }
90 let pfx_tree = match pfx_tree.clone() {
91 super::PfxTree::Branch(branch) => branch,
92 super::PfxTree::Leaf(_) => unreachable!(),
93 };
94
95 let mut from_logprobs = false;
97 if let Some(objectiveai::chat::completions::response::Logprobs {
98 content: Some(logprobs),
99 ..
100 }) = choice.logprobs.as_ref()
101 {
102 let key_rev = key.chars().rev().collect::<String>();
104
105 let mut key_rev_slice = key_rev.as_str();
107
108 let mut key_logprob = None;
110 let mut key_logprob_index = 0;
111
112 'outer: for logprob in logprobs.iter().rev().skip(logprob_i) {
114 logprob_i += 1;
115 let mut i = logprob.token.len();
116 for c in logprob.token.chars().rev() {
117 i -= c.len_utf8();
118 if key_rev_slice.starts_with(c) {
119 key_rev_slice = &key_rev_slice[c.len_utf8()..];
122 if key_logprob.is_none() && c == final_pfx_char {
124 key_logprob = Some(logprob);
125 key_logprob_index = i;
126 }
127 if key_rev_slice.is_empty() {
129 break 'outer;
130 }
131 } else if key_rev_slice.len() != key_rev.len() {
132 key_rev_slice = key_rev.as_str();
135 key_logprob = None;
136 key_logprob_index = 0;
137 } else {
138 }
140 }
141 }
142
143 if key_rev_slice.is_empty() {
145 let mut probabilities =
147 vec![rust_decimal::Decimal::ZERO; responses_len];
148 let mut probabilities_sum = rust_decimal::Decimal::ZERO;
149 for objectiveai::chat::completions::response::TopLogprob {
150 token,
151 logprob,
152 ..
153 } in &key_logprob.as_ref().unwrap().top_logprobs
154 {
155 if key_logprob_index < token.len()
156 && let Some(logprob) = logprob
157 && let Some((_, c)) = token
158 .char_indices()
159 .find(|(i, _)| *i == key_logprob_index)
160 && let Some(pfx) = super::Pfx::from_char(c)
161 && let Some(leaf) = pfx_tree.get(&pfx)
162 {
163 from_logprobs = true;
165
166 let probability = logprob.exp();
168 probabilities[leaf.unwrap_leaf()] += probability;
169 probabilities_sum += probability;
170 }
171 }
172
173 if probabilities_sum > rust_decimal::Decimal::ZERO {
175 let mut vote_i = 0;
176 while vote_i < vote.len() {
177 vote[vote_i] += (probabilities[vote_i]
178 / probabilities_sum)
179 / key_matches_len_decimal;
180 vote_i += 1;
181 }
182 }
183 }
184 }
185
186 if !from_logprobs {
188 vote[pfx_tree.get(&final_pfx).unwrap().unwrap_leaf()] =
189 rust_decimal::Decimal::ONE / key_matches_len_decimal;
190 }
191 }
192
193 Some(vote)
195}
196
/// Text extracted from a streamed choice.
///
/// The text is either borrowed straight from the choice or assembled into a
/// newly allocated `String`.
enum Content<'a> {
    /// Text borrowed from the choice without copying.
    Ref(&'a str),
    /// Text built by concatenating several pieces of the choice.
    Owned(String),
}
204
205impl<'s> Content<'s> {
206 fn as_str(&self) -> &str {
208 match self {
209 Content::Ref(s) => s,
210 Content::Owned(s) => s.as_str(),
211 }
212 }
213
214 fn from_choice(
216 choice: &'s objectiveai::chat::completions::response::streaming::Choice,
217 ) -> Option<Self> {
218 match choice.delta.tool_calls.as_ref() {
219 Some(tool_calls) => {
220 let mut len = 0;
221 for tool_call in tool_calls {
222 if let Some(
223 objectiveai::chat::completions::response::streaming::ToolCallFunction {
224 arguments: Some(arguments),
225 ..
226 },
227 ) = tool_call.function.as_ref()
228 {
229 len += arguments.len();
230 }
231 }
232 if let Some(content) = choice.delta.content.as_ref() {
233 len += content.len();
234 }
235 if len == 0 {
236 return None;
237 }
238 let mut owned = String::with_capacity(len);
239 for tool_call in tool_calls {
240 if let Some(
241 objectiveai::chat::completions::response::streaming::ToolCallFunction {
242 arguments: Some(arguments),
243 ..
244 },
245 ) = tool_call.function.as_ref()
246 {
247 owned.push_str(arguments);
248 }
249 }
250 if let Some(content) = choice.delta.content.as_ref() {
251 owned.push_str(content);
252 }
253 Some(Content::Owned(owned))
254 }
255 None => {
256 if let Some(content) = choice.delta.content.as_ref()
257 && content.len() > 0
258 {
259 Some(Content::Ref(content.as_str()))
260 } else {
261 None
262 }
263 }
264 }
265 }
266}