alith_client/components/cascade/
round.rs1use super::step::{CascadeStep, StepConfig};
2use alith_interface::requests::completion::CompletionRequest;
3use std::collections::VecDeque;
4
5#[derive(Clone)]
6pub struct CascadeRound {
7 pub task: String,
8 pub unresolved_steps: VecDeque<CascadeStep>,
9 pub resolved_steps: VecDeque<CascadeStep>,
10 pub step_separator: Option<char>,
11}
12
13impl CascadeRound {
14 pub fn new<T: Into<String>>(task: T) -> CascadeRound {
15 CascadeRound {
16 task: task.into(),
17 unresolved_steps: VecDeque::new(),
18 resolved_steps: VecDeque::new(),
19 step_separator: Some(' '),
20 }
21 }
22
23 pub fn step_separator(&mut self, separator: char) -> &mut Self {
24 self.step_separator = Some(separator);
25 self
26 }
27
28 pub fn add_inference_step(&mut self, step_config: &StepConfig) -> &mut CascadeStep {
29 self.unresolved_steps
30 .push_back(CascadeStep::new_inference_step(
31 step_config.clone(),
32 self.unresolved_steps.len() + 1,
33 ));
34 self.unresolved_steps.back_mut().unwrap()
35 }
36
37 pub fn add_guidance_step<T: Into<String>>(
38 &mut self,
39 step_config: &StepConfig,
40 llm_content: T,
41 ) -> &mut CascadeStep {
42 self.unresolved_steps
43 .push_back(CascadeStep::new_guidance_step(
44 step_config.clone(),
45 self.unresolved_steps.len() + 1,
46 llm_content,
47 ));
48 self.unresolved_steps.back_mut().unwrap()
49 }
50
51 pub fn generation_prefix(&self, current_step: &CascadeStep) -> crate::Result<Option<String>> {
52 let mut generation_prefix = String::new();
53 for step in &self.resolved_steps {
54 if generation_prefix.is_empty() {
55 generation_prefix.push_str(&step.display_step_outcome()?);
56 } else {
57 if let Some(step_separator) = self.step_separator {
58 generation_prefix.push(step_separator);
59 }
60 generation_prefix.push_str(&step.display_step_outcome()?);
61 };
62 }
63 if let Some(step_prefix) = current_step.display_step_prefix() {
64 if generation_prefix.is_empty() {
65 generation_prefix.push_str(&step_prefix);
66 } else {
67 if let Some(step_separator) = self.step_separator {
68 generation_prefix.push(step_separator);
69 }
70 generation_prefix.push_str(&step_prefix);
71 };
72 }
73
74 if generation_prefix.is_empty() {
75 Ok(None)
76 } else {
77 Ok(Some(generation_prefix))
78 }
79 }
80
81 pub fn display_outcome(&self) -> crate::Result<String> {
82 let mut round_outcome = String::new();
83 for step in self.resolved_steps.iter() {
84 let step_outcome = step.display_step_outcome()?;
85 if round_outcome.is_empty() {
86 round_outcome.push_str(&step_outcome);
87 } else {
88 if let Some(step_separator) = self.step_separator {
89 round_outcome.push(step_separator);
90 }
91 round_outcome.push_str(&step_outcome);
92 }
93 }
94 Ok(round_outcome)
95 }
96
97 pub async fn run_all_steps(&mut self, base_req: &mut CompletionRequest) -> crate::Result<()> {
98 base_req.prompt.add_user_message()?.set_content(&self.task);
99 while !self.unresolved_steps.is_empty() {
100 match self.run_next_step(base_req).await {
101 Ok(_) => {}
102 Err(e) => {
103 let mut resolved = std::mem::take(&mut self.resolved_steps);
104 resolved.append(&mut self.unresolved_steps);
105 self.unresolved_steps = resolved;
106 return Err(e);
107 }
108 }
109 }
110
111 let outcome = self.display_outcome()?;
112 base_req
113 .prompt
114 .add_assistant_message()?
115 .set_content(outcome);
116 Ok(())
117 }
118
119 pub async fn run_next_step(&mut self, base_req: &mut CompletionRequest) -> crate::Result<()> {
120 let mut current_step = self.unresolved_steps.pop_front().unwrap();
121 let generation_prefix = self.generation_prefix(¤t_step)?;
122 match current_step
123 .run_step(generation_prefix.as_deref(), base_req)
124 .await
125 {
126 Ok(..) => {
127 self.resolved_steps.push_back(current_step);
128 Ok(())
129 }
130 Err(e) => {
131 self.unresolved_steps.push_front(current_step);
132 Err(e)
133 }
134 }
135 }
136
137 pub async fn cache_next_step(&mut self, base_req: &mut CompletionRequest) -> crate::Result<()> {
138 let mut current_step = self.unresolved_steps.pop_front().unwrap();
139 let generation_prefix = self.generation_prefix(¤t_step)?;
140 match current_step
141 .set_cache_up_to_step(generation_prefix.as_deref(), base_req)
142 .await
143 {
144 Ok(..) => {
145 self.resolved_steps.push_back(current_step);
146 Ok(())
147 }
148 Err(e) => {
149 self.unresolved_steps.push_front(current_step);
150 Err(e)
151 }
152 }
153 }
154
155 pub fn primitive_result(&self) -> Option<String> {
156 if let Some(step) = self.resolved_steps.back() {
157 step.primitive_result()
158 } else {
159 None
160 }
161 }
162
163 pub fn open_round(&mut self, base_req: &mut CompletionRequest) -> crate::Result<()> {
164 base_req.prompt.add_user_message()?.set_content(&self.task);
165 Ok(())
166 }
167
168 pub fn last_step(&mut self) -> crate::Result<&mut CascadeStep> {
169 match self.resolved_steps.back_mut() {
170 Some(step) => Ok(step),
171 None => crate::bail!("No steps in round"),
172 }
173 }
174
175 pub fn drop_last_step(&mut self) -> crate::Result<()> {
176 match self.resolved_steps.pop_back() {
177 Some(..) => Ok(()),
178 None => crate::bail!("No steps in round"),
179 }
180 }
181
182 pub fn close_round(&mut self, base_req: &mut CompletionRequest) -> crate::Result<()> {
183 base_req
184 .prompt
185 .add_assistant_message()?
186 .set_content(self.display_outcome()?);
187
188 Ok(())
189 }
190}
191
192impl std::fmt::Display for CascadeRound {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 fn print_step(
195 i: usize,
196 step: &CascadeStep,
197 f: &mut std::fmt::Formatter<'_>,
198 ) -> std::fmt::Result {
199 writeln!(f)?;
200 let color = STEP_GRADIENT[i % STEP_GRADIENT.len()];
201 if let Ok(outcome) = step.display_step_outcome() {
202 writeln!(f, "\x1b[1m{color}step {}\x1b[0m: '{}'", i + 1, outcome)?;
203 } else {
204 writeln!(f, "\x1b[1m{color}step {}\x1b[0m: 'No outcome'", i + 1,)?;
205 }
206 Ok(())
207 }
208
209 writeln!(f)?;
210 writeln!(
211 f,
212 "\x1b[1m{}task\x1b[0m: '{}'",
213 STEP_GRADIENT.last().unwrap(),
214 self.task
215 )?;
216 if !self.unresolved_steps.is_empty() {
217 writeln!(f, "\x1b[1munresolved_steps\x1b[0m")?;
218 for (i, step) in self.unresolved_steps.iter().enumerate() {
219 print_step(i, step, f)?;
220 }
221 writeln!(f)?;
222 if !self.resolved_steps.is_empty() {
223 writeln!(f, "\x1b[1mresolved_steps\x1b[0m")?;
224 for (i, step) in self.resolved_steps.iter().enumerate() {
225 print_step(i, step, f)?;
226 }
227 }
228 } else if !self.resolved_steps.is_empty() {
229 for (i, step) in self.resolved_steps.iter().enumerate() {
230 print_step(i, step, f)?;
231 }
232 }
233
234 Ok(())
235 }
236}
237
/// 24-bit ("truecolor") ANSI foreground escape sequences forming a gradient
/// from RGB (0, 142, 250) to RGB (219, 0, 113).
///
/// Step lines cycle through these by index; the last entry colours the task
/// line. The table is compile-time constant, so a plain static array is used
/// instead of a lazily initialised `Vec`. That avoids both the heap allocation
/// and the first-access synchronisation.
static STEP_GRADIENT: [&str; 20] = [
    "\x1B[38;2;0;142;250m",
    "\x1B[38;2;53;138;249m",
    "\x1B[38;2;77;133;248m",
    "\x1B[38;2;95;128;246m",
    "\x1B[38;2;111;123;243m",
    "\x1B[38;2;125;118;239m",
    "\x1B[38;2;138;112;234m",
    "\x1B[38;2;150;106;228m",
    "\x1B[38;2;160;100;222m",
    "\x1B[38;2;170;93;214m",
    "\x1B[38;2;179;86;206m",
    "\x1B[38;2;187;79;198m",
    "\x1B[38;2;194;71;189m",
    "\x1B[38;2;200;63;179m",
    "\x1B[38;2;206;54;169m",
    "\x1B[38;2;210;45;158m",
    "\x1B[38;2;214;36;147m",
    "\x1B[38;2;216;26;136m",
    "\x1B[38;2;218;13;124m",
    "\x1B[38;2;219;0;113m",
];