alith_client/components/cascade/
mod.rs

1pub mod round;
2pub mod step;
3
4use alith_interface::requests::{
5    completion::{CompletionFinishReason, CompletionRequest},
6    stop_sequence::StoppingSequence,
7};
8use anyhow::{Result, anyhow};
9use core::panic;
10pub use round::CascadeRound;
11use step::InferenceStep;
12
13#[derive(Clone)]
14pub struct CascadeFlow {
15    pub cascade_name: String,
16    pub duration: std::time::Duration,
17    pub result_can_be_none: bool,
18    pub rounds: Vec<CascadeRound>,
19    pub start_time: std::time::Instant,
20}
21
22impl CascadeFlow {
23    pub fn new<T: Into<String>>(cascade_name: T) -> Self {
24        Self {
25            cascade_name: cascade_name.into(),
26            start_time: std::time::Instant::now(),
27            duration: std::time::Duration::default(),
28            rounds: Vec::new(),
29            result_can_be_none: false,
30        }
31    }
32
33    pub fn new_round<T: Into<String>>(&mut self, task: T) -> &mut CascadeRound {
34        let round = CascadeRound::new(task);
35        self.rounds.push(round);
36        self.rounds.last_mut().unwrap()
37    }
38
39    pub fn add_round(&mut self, round: CascadeRound) {
40        self.rounds.push(round);
41    }
42
43    pub async fn run_all_rounds(&mut self, base_req: &mut CompletionRequest) -> Result<()> {
44        self.start_time = std::time::Instant::now();
45
46        for round in self.rounds.iter_mut() {
47            round.run_all_steps(base_req).await?;
48        }
49
50        self.duration = self.start_time.elapsed();
51        Ok(())
52    }
53
54    pub fn last_round(&mut self) -> Result<&mut CascadeRound> {
55        match self.rounds.last_mut() {
56            Some(round) => Ok(round),
57            None => Err(anyhow!("No rounds in cascade")),
58        }
59    }
60
61    pub fn drop_last_round(&mut self) -> crate::Result<()> {
62        match self.rounds.pop() {
63            Some(..) => Ok(()),
64            None => crate::bail!("No rounds in cascade"),
65        }
66    }
67
68    pub fn open_cascade(&mut self) {
69        self.start_time = std::time::Instant::now();
70    }
71
72    pub fn close_cascade(&mut self) -> Result<()> {
73        self.duration = self.start_time.elapsed();
74        Ok(())
75    }
76
77    pub fn primitive_result(&self) -> Option<String> {
78        match self.rounds.last() {
79            Some(round) => round.primitive_result(),
80            None => panic!("No rounds in cascade"),
81        }
82    }
83}
84
85pub(crate) async fn cascade_request(
86    base_req: &mut CompletionRequest,
87    step: &mut InferenceStep,
88) -> Result<()> {
89    let res = base_req.request().await?;
90    if matches!(
91        res.finish_reason,
92        CompletionFinishReason::MatchingStoppingSequence(StoppingSequence::NoResult(_))
93    ) {
94        step.llm_content = None;
95        return Ok(());
96    }
97
98    match step.step_config.grammar.validate_clean(&res.content) {
99        Ok(content) => {
100            step.llm_content = Some(content.clone());
101        }
102        Err(e) => {
103            crate::info!(?e);
104        }
105    }
106    Ok(())
107}
108
109impl std::fmt::Display for CascadeFlow {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        writeln!(f)?;
112        writeln!(f, "\x1b[1m\x1B[38;2;92;244;37m{}\x1b[0m", self.cascade_name)?;
113        writeln!(f)?;
114        for (i, round) in self.rounds.iter().enumerate() {
115            let color = ROUND_GRADIENT[i % ROUND_GRADIENT.len()];
116            writeln!(f, "\x1b[1m{color}Round {}\x1b[0m", i + 1)?;
117            writeln!(f, "{round}",)?;
118        }
119        Ok(())
120    }
121}
/// ANSI 24-bit foreground color escape sequences (`ESC[38;2;R;G;Bm`) used to
/// color the "Round N" headings when a `CascadeFlow` is displayed.
///
/// The entries are indexed by round position modulo the table length, so the
/// colors repeat after the last one. Stored as a plain static array: the data
/// is all literals, so no lazy initialization or heap allocation is needed.
static ROUND_GRADIENT: [&str; 15] = [
    "\x1B[38;2;230;175;45m",
    "\x1B[38;2;235;158;57m",
    "\x1B[38;2;235;142;68m",
    "\x1B[38;2;232;127;80m",
    "\x1B[38;2;226;114;91m",
    "\x1B[38;2;216;103;100m",
    "\x1B[38;2;204;94;108m",
    "\x1B[38;2;189;88;114m",
    "\x1B[38;2;172;83;118m",
    "\x1B[38;2;153;79;119m",
    "\x1B[38;2;134;76;118m",
    "\x1B[38;2;115;73;114m",
    "\x1B[38;2;97;69;108m",
    "\x1B[38;2;80;65;99m",
    "\x1B[38;2;65;60;88m",
];