use super::{ReasonResult, ReasonTrait};
use crate::{
components::{InstructPromptTrait, instruct_prompt::InstructPrompt},
primitives::*,
};
use alith_interface::requests::{
completion::CompletionRequest,
req_components::{RequestConfig, RequestConfigTrait},
};
use std::collections::HashMap;
/// Starting temperature when dynamic temperature is enabled with more than one
/// vote; also the increment added to the temperature after a failed attempt.
const DYNAMIC_TEMPERATURE_MIN: f32 = 0.11;
/// Temperature the request is moved toward (or set to) after a successful vote
/// that did not yet decide the outcome.
const DYNAMIC_TEMPERATURE_MAX: f32 = 1.89;
/// Best-of-N voting wrapper around a reasoning workflow `D`.
///
/// The workflow is queried repeatedly and its answers are tallied until one
/// choice (or "none") collects a majority of `best_of_n_votes`.
pub struct Decision<D: DecisionTrait> {
/// Request template; it is cloned into `reason` before every attempt, and its
/// temperature is the one adjusted by the dynamic-temperature logic.
pub base_req: CompletionRequest,
/// Number of votes the majority threshold is derived from.
pub best_of_n_votes: u8,
/// Enables temperature adjustment at the start of a run and after failed attempts.
pub dynamic_temperature: bool,
/// The underlying workflow that produces one reason result per attempt.
pub reason: D,
/// Forwarded to the workflow on each attempt; overwritten by `return_result`
/// (false) and `return_optional_result` (true).
pub result_can_be_none: bool,
}
impl<D: DecisionTrait> Decision<D> {
    /// Runs the vote with "none" disallowed and returns the winning primitive.
    ///
    /// # Errors
    /// Propagates any error from the vote or from mapping the winning index back
    /// to a primitive, and fails with "No result returned." when that mapping
    /// yields `None`.
    pub async fn return_primitive(
        &mut self,
    ) -> crate::Result<<D::ReasonPrimitive as PrimitiveTrait>::PrimitiveResult> {
        let res = self.return_result().await?;
        self.reason
            .primitive()
            .result_index_to_primitive(res.winner_index)?
            .ok_or_else(|| anyhow::format_err!("No result returned."))
    }

    /// Runs the vote with "none" allowed and returns the winning primitive, if any.
    ///
    /// # Errors
    /// Propagates any error from the vote or from mapping the winning index back
    /// to a primitive.
    pub async fn return_optional_primitive(
        &mut self,
    ) -> crate::Result<Option<<D::ReasonPrimitive as PrimitiveTrait>::PrimitiveResult>> {
        let res = self.return_optional_result().await?;
        self.reason
            .primitive()
            .result_index_to_primitive(res.winner_index)
    }

    /// Runs the vote with `result_can_be_none` forced to `false`.
    pub async fn return_result(&mut self) -> crate::Result<DecisionResult> {
        self.result_can_be_none = false;
        self.run_decision().await
    }

    /// Runs the vote with `result_can_be_none` forced to `true`.
    pub async fn return_optional_result(&mut self) -> crate::Result<DecisionResult> {
        self.result_can_be_none = true;
        self.run_decision().await
    }

    /// Maps the winning index of `decision_result` back to a primitive value.
    ///
    /// Returns `Ok(None)` without consulting the primitive when there is no
    /// winning index.
    pub fn parse_decision_result(
        &self,
        decision_result: &DecisionResult,
    ) -> crate::Result<Option<<D::ReasonPrimitive as PrimitiveTrait>::PrimitiveResult>> {
        match decision_result.winner_index {
            Some(winner_index) => self
                .reason
                .primitive()
                .result_index_to_primitive(Some(winner_index)),
            None => Ok(None),
        }
    }

    /// Collects votes until a choice (or "none") reaches the majority threshold.
    ///
    /// Each attempt clones `base_req` into the workflow, requests a reason result
    /// and parses it. Failed requests or parses count against
    /// `retry_after_fail_n_times`; successful ones are tallied.
    ///
    /// # Errors
    /// Fails once the number of failed attempts reaches
    /// `base_req.config.retry_after_fail_n_times`.
    async fn run_decision(&mut self) -> crate::Result<DecisionResult> {
        let start = std::time::Instant::now();
        let votes_required = Self::votes_required_to_win(self.best_of_n_votes);
        let mut decision_result = DecisionResult::new();
        let mut failed_attempts = 0;
        let mut none_count: u8 = 0;
        self.set_dynamic_temperature_on_initial(self.dynamic_temperature, self.best_of_n_votes);

        while failed_attempts < self.base_req.config.retry_after_fail_n_times {
            // Every attempt starts from the (possibly temperature-adjusted) template.
            self.reason.base_req_mut().clone_from(&self.base_req);

            let reason_result = match self
                .reason
                .return_reason_result(self.result_can_be_none)
                .await
            {
                Ok(reason_result) => reason_result,
                Err(err) => {
                    tracing::warn!("Decision: reason request failed: {}", err);
                    self.set_dynamic_temperature_on_fail(self.dynamic_temperature);
                    failed_attempts += 1;
                    continue;
                }
            };

            let primitive_result =
                match self.reason.primitive().parse_reason_result(&reason_result) {
                    Ok(primitive_result) => primitive_result,
                    Err(err) => {
                        tracing::warn!("Decision: could not parse reason result: {}", err);
                        self.set_dynamic_temperature_on_fail(self.dynamic_temperature);
                        failed_attempts += 1;
                        continue;
                    }
                };

            decision_result.total_votes = decision_result.total_votes.saturating_add(1);
            match reason_result.result_index {
                Some(result_index) => {
                    let choice_votes = decision_result.votes.entry(result_index).or_insert(0);
                    *choice_votes = choice_votes.saturating_add(1);
                    let choice_votes = *choice_votes;
                    // Only the choice that just gained a vote can overtake the leader.
                    if choice_votes > decision_result.winner_votes {
                        decision_result.winner_votes = choice_votes;
                        decision_result.winner_index = Some(result_index);
                    }
                }
                None => none_count = none_count.saturating_add(1),
            }

            if decision_result.winner_votes >= votes_required {
                decision_result.confidence =
                    decision_result.winner_votes as f32 / decision_result.total_votes as f32;
                decision_result.winner_primitive_result =
                    primitive_result.map(|primitive| primitive.to_string());
                return Ok(Self::finish_decision(decision_result, reason_result, start));
            }

            if none_count >= votes_required {
                decision_result.winner_votes = none_count;
                // "none" won: drop the index of any minority choice so callers that
                // map `winner_index` back to a primitive get `None`.
                decision_result.winner_index = None;
                decision_result.confidence =
                    none_count as f32 / decision_result.total_votes as f32;
                decision_result.winner_primitive_result = Some("none".to_string());
                return Ok(Self::finish_decision(decision_result, reason_result, start));
            }

            self.set_dynamic_temperature_on_success(self.best_of_n_votes, &decision_result);
            decision_result.reason_results.push(reason_result);
        }

        Err(anyhow::format_err!(
            "BaseDecider: failed to get a valid response after {}",
            failed_attempts
        ))
    }

    /// Records the deciding reason result and elapsed time, then logs the decision.
    ///
    /// Logging happens last so the log shows the final vote and the winner.
    fn finish_decision(
        mut decision_result: DecisionResult,
        reason_result: ReasonResult,
        start: std::time::Instant,
    ) -> DecisionResult {
        decision_result.reason_results.push(reason_result);
        decision_result.duration = start.elapsed();
        tracing::info!("{}", decision_result);
        decision_result
    }

    /// Majority threshold for `best_of_n_votes`: half the votes, rounded up, and
    /// never less than one.
    ///
    /// Written as `n / 2 + n % 2` so that `n == 255` does not overflow `u8`.
    fn votes_required_to_win(best_of_n_votes: u8) -> u8 {
        (best_of_n_votes / 2 + best_of_n_votes % 2).max(1)
    }

    /// Starts at the minimum temperature when dynamic temperature is enabled and
    /// more than one vote is requested.
    fn set_dynamic_temperature_on_initial(
        &mut self,
        dynamic_temperature: bool,
        best_of_n_votes: u8,
    ) {
        if dynamic_temperature && best_of_n_votes > 1 {
            self.base_req.config.temperature = DYNAMIC_TEMPERATURE_MIN;
        }
    }

    /// Raises the temperature after a successful but undecided vote.
    ///
    /// Jumps straight to the maximum when the leader is one vote short of
    /// winning; otherwise closes part of the gap to the maximum. Does nothing
    /// when dynamic temperature is disabled.
    fn set_dynamic_temperature_on_success(
        &mut self,
        best_of_n_votes: u8,
        decision_result: &DecisionResult,
    ) {
        if !self.dynamic_temperature {
            return;
        }
        let votes_required_to_win = Self::votes_required_to_win(best_of_n_votes);
        let minimum_votes_remaining =
            votes_required_to_win.saturating_sub(decision_result.winner_votes);
        if minimum_votes_remaining == 1 {
            self.base_req.config.temperature = DYNAMIC_TEMPERATURE_MAX;
            return;
        }
        // Widen before adding: the sum can exceed `u8::MAX` for large vote counts.
        let maybe_average_votes_remaining =
            (f32::from(votes_required_to_win) + f32::from(minimum_votes_remaining)) / 2.0;
        let temperature = self.base_req.config.temperature;
        self.base_req.config.temperature =
            temperature + (DYNAMIC_TEMPERATURE_MAX - temperature) / maybe_average_votes_remaining;
    }

    /// Bumps the temperature by the minimum step after a failed attempt, when
    /// dynamic temperature is enabled.
    fn set_dynamic_temperature_on_fail(&mut self, dynamic_temperature: bool) {
        if dynamic_temperature {
            self.base_req.config.temperature += DYNAMIC_TEMPERATURE_MIN;
        }
    }

    /// Sets the number of votes, rounding even values up to the next odd number.
    pub fn best_of_n_votes(&mut self, best_of_n_votes: u8) -> &mut Self {
        self.best_of_n_votes = if best_of_n_votes % 2 == 0 {
            best_of_n_votes + 1
        } else {
            best_of_n_votes
        };
        self
    }

    /// Enables or disables dynamic temperature adjustment.
    pub fn dynamic_temperature(&mut self, dynamic_temperature: bool) -> &mut Self {
        self.dynamic_temperature = dynamic_temperature;
        self
    }
}
/// Contract a reasoning workflow must satisfy to be driven by a [`Decision`].
///
/// Implementors expose their completion request and primitive, and produce one
/// reason result per call to `return_reason_result`.
// The lint fires on `async fn` in a public trait; it is suppressed here on purpose.
#[allow(async_fn_in_trait)]
pub trait DecisionTrait: Sized + InstructPromptTrait {
    /// Primitive type whose result indices are voted on.
    type ReasonPrimitive: PrimitiveTrait + ReasonTrait;

    /// Shared access to the workflow's completion request.
    fn base_req(&self) -> &CompletionRequest;

    /// Exclusive access to the workflow's completion request.
    fn base_req_mut(&mut self) -> &mut CompletionRequest;

    /// The primitive used to parse and map results.
    fn primitive(&self) -> &Self::ReasonPrimitive;

    /// Performs one reasoning attempt and returns its result.
    async fn return_reason_result(
        &mut self,
        result_can_be_none: bool,
    ) -> crate::Result<ReasonResult>;

    /// Wraps the workflow in a [`Decision`] with three votes, dynamic temperature
    /// enabled, and "none" disallowed.
    fn decision(self) -> Decision<Self> {
        // Snapshot the request before `self` is moved into the wrapper.
        let base_req = self.base_req().clone();
        Decision {
            base_req,
            reason: self,
            best_of_n_votes: 3,
            dynamic_temperature: true,
            result_can_be_none: false,
        }
    }
}
impl<D: DecisionTrait> RequestConfigTrait for Decision<D> {
    /// Exposes the configuration of the template request.
    fn config(&mut self) -> &mut RequestConfig {
        let request = &mut self.base_req;
        &mut request.config
    }

    /// Resets the workflow's instruct prompt first, then the template request.
    fn reset_request(&mut self) {
        let prompt = self.reason.instruct_prompt_mut();
        prompt.reset_instruct_prompt();
        self.base_req.reset_completion_request();
    }
}
impl<D: DecisionTrait> InstructPromptTrait for Decision<D> {
    /// Delegates to the wrapped workflow's instruct prompt.
    fn instruct_prompt_mut(&mut self) -> &mut InstructPrompt {
        let workflow = &mut self.reason;
        workflow.instruct_prompt_mut()
    }
}
/// Tally and outcome of one decision run.
#[derive(Clone)]
pub struct DecisionResult {
/// Number of votes received per result index.
pub votes: HashMap<u32, u8>,
/// Winner's share of all counted votes, set when the decision is reached.
pub confidence: f32,
/// Wall-clock time the decision took, set when the decision is reached.
pub duration: std::time::Duration,
/// String form of the winning primitive, or `"none"` when "none" won.
pub winner_primitive_result: Option<String>,
/// Reason results of the counted votes, in the order they were produced.
pub reason_results: Vec<ReasonResult>,
/// Number of attempts that produced a parseable vote.
pub total_votes: u8,
/// Vote count of the current leader.
pub winner_votes: u8,
/// Result index of the leading choice; `None` if no vote has selected a choice.
pub winner_index: Option<u32>,
}
impl DecisionResult {
    /// Creates an empty tally: no votes, no winner, zero confidence and duration.
    fn new() -> Self {
        let votes = HashMap::new();
        let reason_results = Vec::new();
        Self {
            votes,
            reason_results,
            confidence: 0.0,
            duration: std::time::Duration::ZERO,
            winner_primitive_result: None,
            winner_index: None,
            total_votes: 0,
            winner_votes: 0,
        }
    }
}
impl std::fmt::Display for DecisionResult {
    /// Writes a colourised (ANSI escape) report: one section per reason result,
    /// followed by the vote summary and the winning primitive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f)?;
        writeln!(f)?;
        writeln!(f, "\x1b[38;5;45m\x1b[1mDecision\x1b[0m")?;

        // One section per counted vote, numbered from 1.
        for (i, res) in self.reason_results.iter().enumerate() {
            writeln!(f)?;
            writeln!(
                f,
                "\x1b[38;5;33m\x1b[1m{} {}\x1b[0m:",
                res.workflow.cascade_name,
                i + 1
            )?;
            writeln!(f)?;
            if let Some(primitive_result) = &res.primitive_result {
                writeln!(
                    f,
                    "\x1b[38;5;32mprimitive_result\x1b[0m: {}",
                    primitive_result
                )?;
            } else {
                writeln!(f, "\x1b[38;5;32mprimitive_result\x1b[0m: None")?;
            }
            writeln!(f, "\x1b[38;5;31mreason duration\x1b[0m: {:?}", res.duration)?;
            writeln!(
                f,
                "\x1b[38;5;30mreason temperature\x1b[0m: {:?}",
                res.temperature
            )?;
        }

        writeln!(f)?;
        writeln!(f)?;
        writeln!(f, "\x1b[38;5;45m\x1b[1mDecisionResult\x1b[0m:")?;
        writeln!(f)?;
        writeln!(
            f,
            "\x1b[38;5;44mvote results\x1b[0m: {} out of {} votes for winner.",
            self.winner_votes, self.total_votes
        )?;
        writeln!(f, "\x1b[38;5;44mconfidence\x1b[0m: {}", self.confidence)?;
        writeln!(
            f,
            "\x1b[38;5;43mdecision duration\x1b[0m: {:?}",
            self.duration
        )?;

        // Same label and styling whether or not a winner was recorded.
        let winner = self.winner_primitive_result.as_deref().unwrap_or("None");
        writeln!(
            f,
            "\x1b[38;5;42m\x1b[1mDecision primitive result\x1b[0m: {}",
            winner
        )?;
        writeln!(f)
    }
}