alith_client/components/cascade/
mod.rs1pub mod round;
2pub mod step;
3
4use alith_interface::requests::{
5 completion::{CompletionFinishReason, CompletionRequest},
6 stop_sequence::StoppingSequence,
7};
8use anyhow::{Result, anyhow};
9use core::panic;
10pub use round::CascadeRound;
11use step::InferenceStep;
12
13#[derive(Clone)]
14pub struct CascadeFlow {
15 pub cascade_name: String,
16 pub duration: std::time::Duration,
17 pub result_can_be_none: bool,
18 pub rounds: Vec<CascadeRound>,
19 pub start_time: std::time::Instant,
20}
21
22impl CascadeFlow {
23 pub fn new<T: Into<String>>(cascade_name: T) -> Self {
24 Self {
25 cascade_name: cascade_name.into(),
26 start_time: std::time::Instant::now(),
27 duration: std::time::Duration::default(),
28 rounds: Vec::new(),
29 result_can_be_none: false,
30 }
31 }
32
33 pub fn new_round<T: Into<String>>(&mut self, task: T) -> &mut CascadeRound {
34 let round = CascadeRound::new(task);
35 self.rounds.push(round);
36 self.rounds.last_mut().unwrap()
37 }
38
39 pub fn add_round(&mut self, round: CascadeRound) {
40 self.rounds.push(round);
41 }
42
43 pub async fn run_all_rounds(&mut self, base_req: &mut CompletionRequest) -> Result<()> {
44 self.start_time = std::time::Instant::now();
45
46 for round in self.rounds.iter_mut() {
47 round.run_all_steps(base_req).await?;
48 }
49
50 self.duration = self.start_time.elapsed();
51 Ok(())
52 }
53
54 pub fn last_round(&mut self) -> Result<&mut CascadeRound> {
55 match self.rounds.last_mut() {
56 Some(round) => Ok(round),
57 None => Err(anyhow!("No rounds in cascade")),
58 }
59 }
60
61 pub fn drop_last_round(&mut self) -> crate::Result<()> {
62 match self.rounds.pop() {
63 Some(..) => Ok(()),
64 None => crate::bail!("No rounds in cascade"),
65 }
66 }
67
68 pub fn open_cascade(&mut self) {
69 self.start_time = std::time::Instant::now();
70 }
71
72 pub fn close_cascade(&mut self) -> Result<()> {
73 self.duration = self.start_time.elapsed();
74 Ok(())
75 }
76
77 pub fn primitive_result(&self) -> Option<String> {
78 match self.rounds.last() {
79 Some(round) => round.primitive_result(),
80 None => panic!("No rounds in cascade"),
81 }
82 }
83}
84
85pub(crate) async fn cascade_request(
86 base_req: &mut CompletionRequest,
87 step: &mut InferenceStep,
88) -> Result<()> {
89 let res = base_req.request().await?;
90 if matches!(
91 res.finish_reason,
92 CompletionFinishReason::MatchingStoppingSequence(StoppingSequence::NoResult(_))
93 ) {
94 step.llm_content = None;
95 return Ok(());
96 }
97
98 match step.step_config.grammar.validate_clean(&res.content) {
99 Ok(content) => {
100 step.llm_content = Some(content.clone());
101 }
102 Err(e) => {
103 crate::info!(?e);
104 }
105 }
106 Ok(())
107}
108
109impl std::fmt::Display for CascadeFlow {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 writeln!(f)?;
112 writeln!(f, "\x1b[1m\x1B[38;2;92;244;37m{}\x1b[0m", self.cascade_name)?;
113 writeln!(f)?;
114 for (i, round) in self.rounds.iter().enumerate() {
115 let color = ROUND_GRADIENT[i % ROUND_GRADIENT.len()];
116 writeln!(f, "\x1b[1m{color}Round {}\x1b[0m", i + 1)?;
117 writeln!(f, "{round}",)?;
118 }
119 Ok(())
120 }
121}
/// ANSI 24-bit foreground color escape sequences used for the per-round
/// headings, ordered from a warm yellow-orange to a dark muted purple.
///
/// Stored as a plain `static` array: the entries are string literals, so no
/// lazy initialisation or heap allocation is needed. Indexing, `len()` and
/// `iter()` work on it directly.
static ROUND_GRADIENT: [&str; 15] = [
    "\x1B[38;2;230;175;45m",
    "\x1B[38;2;235;158;57m",
    "\x1B[38;2;235;142;68m",
    "\x1B[38;2;232;127;80m",
    "\x1B[38;2;226;114;91m",
    "\x1B[38;2;216;103;100m",
    "\x1B[38;2;204;94;108m",
    "\x1B[38;2;189;88;114m",
    "\x1B[38;2;172;83;118m",
    "\x1B[38;2;153;79;119m",
    "\x1B[38;2;134;76;118m",
    "\x1B[38;2;115;73;114m",
    "\x1B[38;2;97;69;108m",
    "\x1B[38;2;80;65;99m",
    "\x1B[38;2;65;60;88m",
];