// alith_client/workflows/reason/
decision.rs1use super::{ReasonResult, ReasonTrait};
2use crate::{
3 components::{InstructPromptTrait, instruct_prompt::InstructPrompt},
4 primitives::*,
5};
6use alith_interface::requests::{
7 completion::CompletionRequest,
8 req_components::{RequestConfig, RequestConfigTrait},
9};
10use std::collections::HashMap;
11
/// Starting temperature for multi-vote decisions when dynamic temperature is enabled, and the
/// step by which the temperature is raised after a failed attempt.
const DYNAMIC_TEMPERATURE_MIN: f32 = 0.11;
/// Temperature that the success-path ramp moves towards while the vote is still undecided.
const DYNAMIC_TEMPERATURE_MAX: f32 = 1.89;
14
/// Runs a [`DecisionTrait`] workflow repeatedly and settles on the answer chosen by a vote.
pub struct Decision<D: DecisionTrait> {
    /// Request template; copied into `reason` before every attempt.
    pub base_req: CompletionRequest,
    /// Number of votes the decision is nominally based on; an outcome wins once it collects
    /// about half of them (rounded up). This is not a hard cap on the number of attempts.
    pub best_of_n_votes: u8,
    /// Controls the dynamic-temperature schedule applied to `base_req` between attempts.
    pub dynamic_temperature: bool,
    /// The wrapped workflow that casts each vote.
    pub reason: D,
    /// Passed to `DecisionTrait::return_reason_result` to say whether a `None` answer is
    /// acceptable; overwritten by `return_result` / `return_optional_result`.
    pub result_can_be_none: bool,
}
22
23impl<D: DecisionTrait> Decision<D> {
24 pub async fn return_primitive(
25 &mut self,
26 ) -> crate::Result<<D::ReasonPrimitive as PrimitiveTrait>::PrimitiveResult> {
27 let res = self.return_result().await?;
28 if let Some(primitive_result) = self
29 .reason
30 .primitive()
31 .result_index_to_primitive(res.winner_index)?
32 {
33 Ok(primitive_result)
34 } else {
35 Err(anyhow::format_err!("No result returned."))
36 }
37 }
38 pub async fn return_optional_primitive(
39 &mut self,
40 ) -> crate::Result<Option<<D::ReasonPrimitive as PrimitiveTrait>::PrimitiveResult>> {
41 let res = self.return_optional_result().await?;
42 self.reason
43 .primitive()
44 .result_index_to_primitive(res.winner_index)
45 }
46
47 pub async fn return_result(&mut self) -> crate::Result<DecisionResult> {
48 self.result_can_be_none = false;
49 self.run_decision().await
50 }
51
52 pub async fn return_optional_result(&mut self) -> crate::Result<DecisionResult> {
53 self.result_can_be_none = true;
54 self.run_decision().await
55 }
56
57 pub fn parse_decision_result(
58 &self,
59 decision_result: &DecisionResult,
60 ) -> crate::Result<Option<<D::ReasonPrimitive as PrimitiveTrait>::PrimitiveResult>> {
61 if let Some(winner_index) = decision_result.winner_index {
62 self.reason
63 .primitive()
64 .result_index_to_primitive(Some(winner_index))
65 } else {
66 Ok(None)
67 }
68 }
69
70 async fn run_decision(&mut self) -> crate::Result<DecisionResult> {
71 let start = std::time::Instant::now();
72 let mut decision_result = DecisionResult::new();
73 let mut failed_attempts = 0;
74 let mut none_count = 0;
75
76 self.set_dynamic_temperature_on_initial(self.dynamic_temperature, self.best_of_n_votes);
77
78 while failed_attempts < self.base_req.config.retry_after_fail_n_times {
79 if failed_attempts >= self.base_req.config.retry_after_fail_n_times {
80 break;
81 }
82 *self.reason.base_req_mut() = self.base_req.clone();
83 let reason_result = match self
84 .reason
85 .return_reason_result(self.result_can_be_none)
86 .await
87 {
88 Ok(reason_result) => reason_result,
89 Err(_) => {
90 self.set_dynamic_temperature_on_fail(self.dynamic_temperature);
91 failed_attempts += 1;
92 continue;
93 }
94 };
95
96 match self.reason.primitive().parse_reason_result(&reason_result) {
97 Err(_) => {
98 self.set_dynamic_temperature_on_fail(self.dynamic_temperature);
99 failed_attempts += 1;
100 }
101 Ok(primitive_result) => {
102 decision_result.total_votes += 1;
103 if let Some(result_index) = reason_result.result_index {
104 *decision_result.votes.entry(result_index).or_insert(0) += 1;
105 for (choice_index, choice_votes) in &mut decision_result.votes {
106 if *choice_votes > decision_result.winner_votes {
107 decision_result.winner_votes = *choice_votes;
108 decision_result.winner_index = Some(*choice_index);
109 }
110 }
111 } else {
112 none_count += 1;
113 }
114 if decision_result.winner_votes
115 >= (self.best_of_n_votes + (self.best_of_n_votes % 2)) / 2
116 {
117 decision_result.confidence = decision_result.winner_votes as f32
118 / decision_result.total_votes as f32;
119 decision_result.duration = start.elapsed();
120 tracing::info!("{}", decision_result.to_string());
121
122 decision_result.winner_primitive_result =
123 Some(primitive_result.unwrap().to_string());
124
125 decision_result.reason_results.push(reason_result);
126
127 return Ok(decision_result);
128 } else if none_count >= (self.best_of_n_votes + (self.best_of_n_votes % 2)) / 2
129 {
130 decision_result.winner_votes = none_count;
131 decision_result.confidence =
132 none_count as f32 / decision_result.total_votes as f32;
133 decision_result.duration = start.elapsed();
134 tracing::info!("{}", decision_result.to_string());
135
136 decision_result.winner_primitive_result = Some("none".to_string());
137
138 decision_result.reason_results.push(reason_result);
139
140 return Ok(decision_result);
141 } else {
142 self.set_dynamic_temperature_on_success(
143 self.best_of_n_votes,
144 &decision_result,
145 );
146 decision_result.reason_results.push(reason_result);
147 }
148 }
149 }
150 }
151 Err(anyhow::format_err!(
152 "BaseDecider: failed to get a valid response after {}",
153 failed_attempts
154 ))
155 }
156
157 fn set_dynamic_temperature_on_initial(
158 &mut self,
159 dynamic_temperature: bool,
160 best_of_n_votes: u8,
161 ) {
162 if dynamic_temperature && best_of_n_votes > 1 {
163 self.base_req.config.temperature = DYNAMIC_TEMPERATURE_MIN;
164 }
165 }
166
167 fn set_dynamic_temperature_on_success(
168 &mut self,
169 best_of_n_votes: u8,
170 decision_result: &DecisionResult,
171 ) {
172 let votes_required_to_win = (best_of_n_votes + (best_of_n_votes % 2)) / 2;
173 if votes_required_to_win - decision_result.winner_votes == 1 {
174 self.base_req.config.temperature = DYNAMIC_TEMPERATURE_MAX;
175 return;
176 }
177
178 let minimum_votes_remaining = votes_required_to_win - decision_result.winner_votes;
179
180 let maybe_average_votes_remaining =
181 (votes_required_to_win + minimum_votes_remaining) as f32 / 2.0;
182
183 self.base_req.config.temperature = self.base_req.config.temperature
184 + ((DYNAMIC_TEMPERATURE_MAX - self.base_req.config.temperature)
185 / maybe_average_votes_remaining);
186 }
187
188 fn set_dynamic_temperature_on_fail(&mut self, dynamic_temperature: bool) {
189 if dynamic_temperature {
190 self.base_req.config.temperature += DYNAMIC_TEMPERATURE_MIN;
191 }
192 }
193
194 pub fn best_of_n_votes(&mut self, best_of_n_votes: u8) -> &mut Self {
198 if best_of_n_votes % 2 == 0 {
199 self.best_of_n_votes = best_of_n_votes + 1;
200 } else {
201 self.best_of_n_votes = best_of_n_votes;
202 }
203 self
204 }
205
206 pub fn dynamic_temperature(&mut self, dynamic_temperature: bool) -> &mut Self {
208 self.dynamic_temperature = dynamic_temperature;
209 self
210 }
211}
212
213#[allow(async_fn_in_trait)]
214pub trait DecisionTrait: Sized + InstructPromptTrait {
215 type ReasonPrimitive: PrimitiveTrait + ReasonTrait;
216 fn base_req(&self) -> &CompletionRequest;
217
218 fn base_req_mut(&mut self) -> &mut CompletionRequest;
219
220 fn primitive(&self) -> &Self::ReasonPrimitive;
221
222 async fn return_reason_result(
223 &mut self,
224 result_can_be_none: bool,
225 ) -> crate::Result<ReasonResult>;
226
227 fn decision(self) -> Decision<Self> {
228 Decision {
229 base_req: self.base_req().clone(),
230 best_of_n_votes: 3,
231 dynamic_temperature: true,
232 reason: self,
233 result_can_be_none: false,
234 }
235 }
236}
237
238impl<D: DecisionTrait> RequestConfigTrait for Decision<D> {
239 fn config(&mut self) -> &mut RequestConfig {
240 &mut self.base_req.config
241 }
242
243 fn reset_request(&mut self) {
244 self.reason.instruct_prompt_mut().reset_instruct_prompt();
245 self.base_req.reset_completion_request();
246 }
247}
248
249impl<D: DecisionTrait> InstructPromptTrait for Decision<D> {
250 fn instruct_prompt_mut(&mut self) -> &mut InstructPrompt {
251 self.reason.instruct_prompt_mut()
252 }
253}
254
/// Tally and summary of a finished (or in-progress) vote.
#[derive(Clone)]
pub struct DecisionResult {
    /// Vote count per choice index; votes without a choice index are not recorded here.
    pub votes: HashMap<u32, u8>,
    /// Winning vote count divided by `total_votes`, filled in when the decision finishes.
    pub confidence: f32,
    /// Time elapsed from the start of the decision until it finished.
    pub duration: std::time::Duration,
    /// String form of the winning primitive result, or `"none"` when `None` won the vote.
    pub winner_primitive_result: Option<String>,
    /// The reason result of every successfully parsed vote, in order.
    pub reason_results: Vec<ReasonResult>,
    /// Number of successfully parsed votes, including votes with no choice index.
    pub total_votes: u8,
    /// Vote count of the leading outcome.
    pub winner_votes: u8,
    /// Index of the leading choice, if any.
    pub winner_index: Option<u32>,
}
266
267impl DecisionResult {
268 fn new() -> Self {
269 Self {
270 votes: HashMap::new(),
271 confidence: 0.0,
272 duration: std::time::Duration::new(0, 0),
273 winner_primitive_result: None,
274 reason_results: Vec::new(),
275 total_votes: 0,
276 winner_votes: 0,
277 winner_index: None,
278 }
279 }
280}
281
282impl std::fmt::Display for DecisionResult {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 writeln!(f)?;
285 writeln!(f)?;
286 writeln!(f, "\x1b[38;5;45m\x1b[1mDecision\x1b[0m",)?;
287 for (i, res) in self.reason_results.iter().enumerate() {
288 writeln!(f)?;
289 writeln!(
290 f,
291 "\x1b[38;5;33m\x1b[1m{} {}\x1b[0m:",
292 res.workflow.cascade_name,
293 i + 1
294 )?;
295 writeln!(f)?;
296 if let Some(primitive_result) = &res.primitive_result {
297 writeln!(
298 f,
299 "\x1b[38;5;32mprimitive_result\x1b[0m: {}",
300 primitive_result
301 )?;
302 } else {
303 writeln!(f, "\x1b[38;5;32mprimitive_result\x1b[0m: None")?;
304 };
305 writeln!(f, "\x1b[38;5;31mreason duration\x1b[0m: {:?}", res.duration)?;
306 writeln!(
307 f,
308 "\x1b[38;5;30mreason temperature\x1b[0m: {:?}",
309 res.temperature
310 )?;
311 }
312
313 writeln!(f)?;
314 writeln!(f)?;
315 writeln!(f, "\x1b[38;5;45m\x1b[1mDecisionResult\x1b[0m:")?;
316 writeln!(f)?;
317 writeln!(
318 f,
319 "\x1b[38;5;44mvote results\x1b[0m: {} out of {} votes for winner.",
320 self.winner_votes, self.total_votes
321 )?;
322 writeln!(f, "\x1b[38;5;44mconfidence\x1b[0m: {}", self.confidence)?;
323 writeln!(
324 f,
325 "\x1b[38;5;43mdecision duration\x1b[0m: {:?}",
326 self.duration
327 )?;
328 if let Some(winner_primitive_result) = &self.winner_primitive_result {
329 writeln!(
330 f,
331 "\x1b[38;5;42m\x1b[1mDecision primitive result\x1b[0m: {}",
332 winner_primitive_result
333 )?;
334 } else {
335 writeln!(
336 f,
337 "\x1b[38;5;42mfs\x1b[1mdecision primitive result\x1b[0m: None"
338 )?;
339 }
340 writeln!(f)
341 }
342}