alith_client/components/cascade/round.rs

1use super::step::{CascadeStep, StepConfig};
2use alith_interface::requests::completion::CompletionRequest;
3use std::collections::VecDeque;
4
5#[derive(Clone)]
6pub struct CascadeRound {
7    pub task: String,
8    pub unresolved_steps: VecDeque<CascadeStep>,
9    pub resolved_steps: VecDeque<CascadeStep>,
10    pub step_separator: Option<char>,
11}
12
13impl CascadeRound {
14    pub fn new<T: Into<String>>(task: T) -> CascadeRound {
15        CascadeRound {
16            task: task.into(),
17            unresolved_steps: VecDeque::new(),
18            resolved_steps: VecDeque::new(),
19            step_separator: Some(' '),
20        }
21    }
22
23    pub fn step_separator(&mut self, separator: char) -> &mut Self {
24        self.step_separator = Some(separator);
25        self
26    }
27
28    pub fn add_inference_step(&mut self, step_config: &StepConfig) -> &mut CascadeStep {
29        self.unresolved_steps
30            .push_back(CascadeStep::new_inference_step(
31                step_config.clone(),
32                self.unresolved_steps.len() + 1,
33            ));
34        self.unresolved_steps.back_mut().unwrap()
35    }
36
37    pub fn add_guidance_step<T: Into<String>>(
38        &mut self,
39        step_config: &StepConfig,
40        llm_content: T,
41    ) -> &mut CascadeStep {
42        self.unresolved_steps
43            .push_back(CascadeStep::new_guidance_step(
44                step_config.clone(),
45                self.unresolved_steps.len() + 1,
46                llm_content,
47            ));
48        self.unresolved_steps.back_mut().unwrap()
49    }
50
51    pub fn generation_prefix(&self, current_step: &CascadeStep) -> crate::Result<Option<String>> {
52        let mut generation_prefix = String::new();
53        for step in &self.resolved_steps {
54            if generation_prefix.is_empty() {
55                generation_prefix.push_str(&step.display_step_outcome()?);
56            } else {
57                if let Some(step_separator) = self.step_separator {
58                    generation_prefix.push(step_separator);
59                }
60                generation_prefix.push_str(&step.display_step_outcome()?);
61            };
62        }
63        if let Some(step_prefix) = current_step.display_step_prefix() {
64            if generation_prefix.is_empty() {
65                generation_prefix.push_str(&step_prefix);
66            } else {
67                if let Some(step_separator) = self.step_separator {
68                    generation_prefix.push(step_separator);
69                }
70                generation_prefix.push_str(&step_prefix);
71            };
72        }
73
74        if generation_prefix.is_empty() {
75            Ok(None)
76        } else {
77            Ok(Some(generation_prefix))
78        }
79    }
80
81    pub fn display_outcome(&self) -> crate::Result<String> {
82        let mut round_outcome = String::new();
83        for step in self.resolved_steps.iter() {
84            let step_outcome = step.display_step_outcome()?;
85            if round_outcome.is_empty() {
86                round_outcome.push_str(&step_outcome);
87            } else {
88                if let Some(step_separator) = self.step_separator {
89                    round_outcome.push(step_separator);
90                }
91                round_outcome.push_str(&step_outcome);
92            }
93        }
94        Ok(round_outcome)
95    }
96
97    pub async fn run_all_steps(&mut self, base_req: &mut CompletionRequest) -> crate::Result<()> {
98        base_req.prompt.add_user_message()?.set_content(&self.task);
99        while !self.unresolved_steps.is_empty() {
100            match self.run_next_step(base_req).await {
101                Ok(_) => {}
102                Err(e) => {
103                    let mut resolved = std::mem::take(&mut self.resolved_steps);
104                    resolved.append(&mut self.unresolved_steps);
105                    self.unresolved_steps = resolved;
106                    return Err(e);
107                }
108            }
109        }
110
111        let outcome = self.display_outcome()?;
112        base_req
113            .prompt
114            .add_assistant_message()?
115            .set_content(outcome);
116        Ok(())
117    }
118
119    pub async fn run_next_step(&mut self, base_req: &mut CompletionRequest) -> crate::Result<()> {
120        let mut current_step = self.unresolved_steps.pop_front().unwrap();
121        let generation_prefix = self.generation_prefix(&current_step)?;
122        match current_step
123            .run_step(generation_prefix.as_deref(), base_req)
124            .await
125        {
126            Ok(..) => {
127                self.resolved_steps.push_back(current_step);
128                Ok(())
129            }
130            Err(e) => {
131                self.unresolved_steps.push_front(current_step);
132                Err(e)
133            }
134        }
135    }
136
137    pub async fn cache_next_step(&mut self, base_req: &mut CompletionRequest) -> crate::Result<()> {
138        let mut current_step = self.unresolved_steps.pop_front().unwrap();
139        let generation_prefix = self.generation_prefix(&current_step)?;
140        match current_step
141            .set_cache_up_to_step(generation_prefix.as_deref(), base_req)
142            .await
143        {
144            Ok(..) => {
145                self.resolved_steps.push_back(current_step);
146                Ok(())
147            }
148            Err(e) => {
149                self.unresolved_steps.push_front(current_step);
150                Err(e)
151            }
152        }
153    }
154
155    pub fn primitive_result(&self) -> Option<String> {
156        if let Some(step) = self.resolved_steps.back() {
157            step.primitive_result()
158        } else {
159            None
160        }
161    }
162
163    pub fn open_round(&mut self, base_req: &mut CompletionRequest) -> crate::Result<()> {
164        base_req.prompt.add_user_message()?.set_content(&self.task);
165        Ok(())
166    }
167
168    pub fn last_step(&mut self) -> crate::Result<&mut CascadeStep> {
169        match self.resolved_steps.back_mut() {
170            Some(step) => Ok(step),
171            None => crate::bail!("No steps in round"),
172        }
173    }
174
175    pub fn drop_last_step(&mut self) -> crate::Result<()> {
176        match self.resolved_steps.pop_back() {
177            Some(..) => Ok(()),
178            None => crate::bail!("No steps in round"),
179        }
180    }
181
182    pub fn close_round(&mut self, base_req: &mut CompletionRequest) -> crate::Result<()> {
183        base_req
184            .prompt
185            .add_assistant_message()?
186            .set_content(self.display_outcome()?);
187
188        Ok(())
189    }
190}
191
192impl std::fmt::Display for CascadeRound {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        fn print_step(
195            i: usize,
196            step: &CascadeStep,
197            f: &mut std::fmt::Formatter<'_>,
198        ) -> std::fmt::Result {
199            writeln!(f)?;
200            let color = STEP_GRADIENT[i % STEP_GRADIENT.len()];
201            if let Ok(outcome) = step.display_step_outcome() {
202                writeln!(f, "\x1b[1m{color}step {}\x1b[0m: '{}'", i + 1, outcome)?;
203            } else {
204                writeln!(f, "\x1b[1m{color}step {}\x1b[0m: 'No outcome'", i + 1,)?;
205            }
206            Ok(())
207        }
208
209        writeln!(f)?;
210        writeln!(
211            f,
212            "\x1b[1m{}task\x1b[0m: '{}'",
213            STEP_GRADIENT.last().unwrap(),
214            self.task
215        )?;
216        if !self.unresolved_steps.is_empty() {
217            writeln!(f, "\x1b[1munresolved_steps\x1b[0m")?;
218            for (i, step) in self.unresolved_steps.iter().enumerate() {
219                print_step(i, step, f)?;
220            }
221            writeln!(f)?;
222            if !self.resolved_steps.is_empty() {
223                writeln!(f, "\x1b[1mresolved_steps\x1b[0m")?;
224                for (i, step) in self.resolved_steps.iter().enumerate() {
225                    print_step(i, step, f)?;
226                }
227            }
228        } else if !self.resolved_steps.is_empty() {
229            for (i, step) in self.resolved_steps.iter().enumerate() {
230                print_step(i, step, f)?;
231            }
232        }
233
234        Ok(())
235    }
236}
237
/// 24-bit ("truecolor") ANSI foreground escape codes forming a blue-to-magenta gradient.
///
/// The table is fixed, so a plain static array is enough: there is no lazy initialisation or
/// heap allocation. Callers index it modulo its length and may use `.last()`.
static STEP_GRADIENT: [&str; 20] = [
    "\x1B[38;2;0;142;250m",
    "\x1B[38;2;53;138;249m",
    "\x1B[38;2;77;133;248m",
    "\x1B[38;2;95;128;246m",
    "\x1B[38;2;111;123;243m",
    "\x1B[38;2;125;118;239m",
    "\x1B[38;2;138;112;234m",
    "\x1B[38;2;150;106;228m",
    "\x1B[38;2;160;100;222m",
    "\x1B[38;2;170;93;214m",
    "\x1B[38;2;179;86;206m",
    "\x1B[38;2;187;79;198m",
    "\x1B[38;2;194;71;189m",
    "\x1B[38;2;200;63;179m",
    "\x1B[38;2;206;54;169m",
    "\x1B[38;2;210;45;158m",
    "\x1B[38;2;214;36;147m",
    "\x1B[38;2;216;26;136m",
    "\x1B[38;2;218;13;124m",
    "\x1B[38;2;219;0;113m",
];