alith_client/workflows/reason/
decision.rs

1use super::{ReasonResult, ReasonTrait};
2use crate::{
3    components::{InstructPromptTrait, instruct_prompt::InstructPrompt},
4    primitives::*,
5};
6use alith_interface::requests::{
7    completion::CompletionRequest,
8    req_components::{RequestConfig, RequestConfigTrait},
9};
10use std::collections::HashMap;
11
/// Temperature written to the base request when a dynamic-temperature run with more
/// than one vote starts; also the amount added to the temperature after each failed attempt.
const DYNAMIC_TEMPERATURE_MIN: f32 = 0.11;
/// Temperature the dynamic scaling moves towards, and the value used once the leading
/// choice is a single vote away from winning.
const DYNAMIC_TEMPERATURE_MAX: f32 = 1.89;
14
/// Best-of-N voting wrapper around a reasoning workflow `D`.
///
/// The workflow is run repeatedly and its answers are tallied until one choice (or
/// "no choice") has collected a majority of the configured votes.
pub struct Decision<D: DecisionTrait> {
    /// Request template. It is cloned into `reason` before every attempt, and it is where
    /// the dynamic temperature is written.
    pub base_req: CompletionRequest,
    /// Vote budget; a choice wins once it holds half of this, rounded up.
    pub best_of_n_votes: u8,
    /// Whether the temperature hooks are allowed to rewrite the request temperature.
    pub dynamic_temperature: bool,
    /// The reasoning workflow that produces each individual vote.
    pub reason: D,
    /// Forwarded to every reasoning attempt; overwritten by `return_result` (false) and
    /// `return_optional_result` (true) before the voting loop runs.
    pub result_can_be_none: bool,
}
22
23impl<D: DecisionTrait> Decision<D> {
24    pub async fn return_primitive(
25        &mut self,
26    ) -> crate::Result<<D::ReasonPrimitive as PrimitiveTrait>::PrimitiveResult> {
27        let res = self.return_result().await?;
28        if let Some(primitive_result) = self
29            .reason
30            .primitive()
31            .result_index_to_primitive(res.winner_index)?
32        {
33            Ok(primitive_result)
34        } else {
35            Err(anyhow::format_err!("No result returned."))
36        }
37    }
38    pub async fn return_optional_primitive(
39        &mut self,
40    ) -> crate::Result<Option<<D::ReasonPrimitive as PrimitiveTrait>::PrimitiveResult>> {
41        let res = self.return_optional_result().await?;
42        self.reason
43            .primitive()
44            .result_index_to_primitive(res.winner_index)
45    }
46
47    pub async fn return_result(&mut self) -> crate::Result<DecisionResult> {
48        self.result_can_be_none = false;
49        self.run_decision().await
50    }
51
52    pub async fn return_optional_result(&mut self) -> crate::Result<DecisionResult> {
53        self.result_can_be_none = true;
54        self.run_decision().await
55    }
56
57    pub fn parse_decision_result(
58        &self,
59        decision_result: &DecisionResult,
60    ) -> crate::Result<Option<<D::ReasonPrimitive as PrimitiveTrait>::PrimitiveResult>> {
61        if let Some(winner_index) = decision_result.winner_index {
62            self.reason
63                .primitive()
64                .result_index_to_primitive(Some(winner_index))
65        } else {
66            Ok(None)
67        }
68    }
69
70    async fn run_decision(&mut self) -> crate::Result<DecisionResult> {
71        let start = std::time::Instant::now();
72        let mut decision_result = DecisionResult::new();
73        let mut failed_attempts = 0;
74        let mut none_count = 0;
75
76        self.set_dynamic_temperature_on_initial(self.dynamic_temperature, self.best_of_n_votes);
77
78        while failed_attempts < self.base_req.config.retry_after_fail_n_times {
79            if failed_attempts >= self.base_req.config.retry_after_fail_n_times {
80                break;
81            }
82            *self.reason.base_req_mut() = self.base_req.clone();
83            let reason_result = match self
84                .reason
85                .return_reason_result(self.result_can_be_none)
86                .await
87            {
88                Ok(reason_result) => reason_result,
89                Err(_) => {
90                    self.set_dynamic_temperature_on_fail(self.dynamic_temperature);
91                    failed_attempts += 1;
92                    continue;
93                }
94            };
95
96            match self.reason.primitive().parse_reason_result(&reason_result) {
97                Err(_) => {
98                    self.set_dynamic_temperature_on_fail(self.dynamic_temperature);
99                    failed_attempts += 1;
100                }
101                Ok(primitive_result) => {
102                    decision_result.total_votes += 1;
103                    if let Some(result_index) = reason_result.result_index {
104                        *decision_result.votes.entry(result_index).or_insert(0) += 1;
105                        for (choice_index, choice_votes) in &mut decision_result.votes {
106                            if *choice_votes > decision_result.winner_votes {
107                                decision_result.winner_votes = *choice_votes;
108                                decision_result.winner_index = Some(*choice_index);
109                            }
110                        }
111                    } else {
112                        none_count += 1;
113                    }
114                    if decision_result.winner_votes
115                        >= (self.best_of_n_votes + (self.best_of_n_votes % 2)) / 2
116                    {
117                        decision_result.confidence = decision_result.winner_votes as f32
118                            / decision_result.total_votes as f32;
119                        decision_result.duration = start.elapsed();
120                        tracing::info!("{}", decision_result.to_string());
121
122                        decision_result.winner_primitive_result =
123                            Some(primitive_result.unwrap().to_string());
124
125                        decision_result.reason_results.push(reason_result);
126
127                        return Ok(decision_result);
128                    } else if none_count >= (self.best_of_n_votes + (self.best_of_n_votes % 2)) / 2
129                    {
130                        decision_result.winner_votes = none_count;
131                        decision_result.confidence =
132                            none_count as f32 / decision_result.total_votes as f32;
133                        decision_result.duration = start.elapsed();
134                        tracing::info!("{}", decision_result.to_string());
135
136                        decision_result.winner_primitive_result = Some("none".to_string());
137
138                        decision_result.reason_results.push(reason_result);
139
140                        return Ok(decision_result);
141                    } else {
142                        self.set_dynamic_temperature_on_success(
143                            self.best_of_n_votes,
144                            &decision_result,
145                        );
146                        decision_result.reason_results.push(reason_result);
147                    }
148                }
149            }
150        }
151        Err(anyhow::format_err!(
152            "BaseDecider: failed to get a valid response after {}",
153            failed_attempts
154        ))
155    }
156
157    fn set_dynamic_temperature_on_initial(
158        &mut self,
159        dynamic_temperature: bool,
160        best_of_n_votes: u8,
161    ) {
162        if dynamic_temperature && best_of_n_votes > 1 {
163            self.base_req.config.temperature = DYNAMIC_TEMPERATURE_MIN;
164        }
165    }
166
167    fn set_dynamic_temperature_on_success(
168        &mut self,
169        best_of_n_votes: u8,
170        decision_result: &DecisionResult,
171    ) {
172        let votes_required_to_win = (best_of_n_votes + (best_of_n_votes % 2)) / 2;
173        if votes_required_to_win - decision_result.winner_votes == 1 {
174            self.base_req.config.temperature = DYNAMIC_TEMPERATURE_MAX;
175            return;
176        }
177
178        let minimum_votes_remaining = votes_required_to_win - decision_result.winner_votes;
179
180        let maybe_average_votes_remaining =
181            (votes_required_to_win + minimum_votes_remaining) as f32 / 2.0;
182
183        self.base_req.config.temperature = self.base_req.config.temperature
184            + ((DYNAMIC_TEMPERATURE_MAX - self.base_req.config.temperature)
185                / maybe_average_votes_remaining);
186    }
187
188    fn set_dynamic_temperature_on_fail(&mut self, dynamic_temperature: bool) {
189        if dynamic_temperature {
190            self.base_req.config.temperature += DYNAMIC_TEMPERATURE_MIN;
191        }
192    }
193
194    /// Sets the number of votes to reach consensus. It is the maxium number of votes for a decision, but often the decision is reached before this number is reached.
195    /// For example, with the default of `3` votes, the first decision is made after 2 votes for a choice.
196    /// If given an even number, it will round up to the nearest odd number.
197    pub fn best_of_n_votes(&mut self, best_of_n_votes: u8) -> &mut Self {
198        if best_of_n_votes % 2 == 0 {
199            self.best_of_n_votes = best_of_n_votes + 1;
200        } else {
201            self.best_of_n_votes = best_of_n_votes;
202        }
203        self
204    }
205
206    /// Dynamically scales temperature during the voting process. Starts at a low temperature and increases towards max temperature as the number of votes increases.
207    pub fn dynamic_temperature(&mut self, dynamic_temperature: bool) -> &mut Self {
208        self.dynamic_temperature = dynamic_temperature;
209        self
210    }
211}
212
213#[allow(async_fn_in_trait)]
214pub trait DecisionTrait: Sized + InstructPromptTrait {
215    type ReasonPrimitive: PrimitiveTrait + ReasonTrait;
216    fn base_req(&self) -> &CompletionRequest;
217
218    fn base_req_mut(&mut self) -> &mut CompletionRequest;
219
220    fn primitive(&self) -> &Self::ReasonPrimitive;
221
222    async fn return_reason_result(
223        &mut self,
224        result_can_be_none: bool,
225    ) -> crate::Result<ReasonResult>;
226
227    fn decision(self) -> Decision<Self> {
228        Decision {
229            base_req: self.base_req().clone(),
230            best_of_n_votes: 3,
231            dynamic_temperature: true,
232            reason: self,
233            result_can_be_none: false,
234        }
235    }
236}
237
238impl<D: DecisionTrait> RequestConfigTrait for Decision<D> {
239    fn config(&mut self) -> &mut RequestConfig {
240        &mut self.base_req.config
241    }
242
243    fn reset_request(&mut self) {
244        self.reason.instruct_prompt_mut().reset_instruct_prompt();
245        self.base_req.reset_completion_request();
246    }
247}
248
249impl<D: DecisionTrait> InstructPromptTrait for Decision<D> {
250    fn instruct_prompt_mut(&mut self) -> &mut InstructPrompt {
251        self.reason.instruct_prompt_mut()
252    }
253}
254
/// Outcome and bookkeeping of one best-of-N voting run.
#[derive(Clone)]
pub struct DecisionResult {
    /// Vote tally, keyed by result index.
    pub votes: HashMap<u32, u8>,
    /// `winner_votes / total_votes`, filled in when the run reaches consensus.
    pub confidence: f32,
    /// Wall-clock time of the run, filled in when the run reaches consensus.
    pub duration: std::time::Duration,
    /// String form of the winning primitive, or `"none"` when no-choice votes won.
    pub winner_primitive_result: Option<String>,
    /// Reason results of the attempts that were counted as votes.
    pub reason_results: Vec<ReasonResult>,
    /// Number of attempts that were counted as votes (including no-choice votes).
    pub total_votes: u8,
    /// Votes held by the leading choice; overwritten with the number of no-choice votes
    /// when those reach consensus.
    pub winner_votes: u8,
    /// Result index of the choice with the most votes so far; `None` until a choice
    /// receives a vote.
    pub winner_index: Option<u32>,
}
266
267impl DecisionResult {
268    fn new() -> Self {
269        Self {
270            votes: HashMap::new(),
271            confidence: 0.0,
272            duration: std::time::Duration::new(0, 0),
273            winner_primitive_result: None,
274            reason_results: Vec::new(),
275            total_votes: 0,
276            winner_votes: 0,
277            winner_index: None,
278        }
279    }
280}
281
282impl std::fmt::Display for DecisionResult {
283    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284        writeln!(f)?;
285        writeln!(f)?;
286        writeln!(f, "\x1b[38;5;45m\x1b[1mDecision\x1b[0m",)?;
287        for (i, res) in self.reason_results.iter().enumerate() {
288            writeln!(f)?;
289            writeln!(
290                f,
291                "\x1b[38;5;33m\x1b[1m{} {}\x1b[0m:",
292                res.workflow.cascade_name,
293                i + 1
294            )?;
295            writeln!(f)?;
296            if let Some(primitive_result) = &res.primitive_result {
297                writeln!(
298                    f,
299                    "\x1b[38;5;32mprimitive_result\x1b[0m: {}",
300                    primitive_result
301                )?;
302            } else {
303                writeln!(f, "\x1b[38;5;32mprimitive_result\x1b[0m: None")?;
304            };
305            writeln!(f, "\x1b[38;5;31mreason duration\x1b[0m: {:?}", res.duration)?;
306            writeln!(
307                f,
308                "\x1b[38;5;30mreason temperature\x1b[0m: {:?}",
309                res.temperature
310            )?;
311        }
312
313        writeln!(f)?;
314        writeln!(f)?;
315        writeln!(f, "\x1b[38;5;45m\x1b[1mDecisionResult\x1b[0m:")?;
316        writeln!(f)?;
317        writeln!(
318            f,
319            "\x1b[38;5;44mvote results\x1b[0m: {} out of {} votes for winner.",
320            self.winner_votes, self.total_votes
321        )?;
322        writeln!(f, "\x1b[38;5;44mconfidence\x1b[0m: {}", self.confidence)?;
323        writeln!(
324            f,
325            "\x1b[38;5;43mdecision duration\x1b[0m: {:?}",
326            self.duration
327        )?;
328        if let Some(winner_primitive_result) = &self.winner_primitive_result {
329            writeln!(
330                f,
331                "\x1b[38;5;42m\x1b[1mDecision primitive result\x1b[0m: {}",
332                winner_primitive_result
333            )?;
334        } else {
335            writeln!(
336                f,
337                "\x1b[38;5;42mfs\x1b[1mdecision primitive result\x1b[0m: None"
338            )?;
339        }
340        writeln!(f)
341    }
342}