alith_client/components/cascade/
step.rs1use super::cascade_request;
2use crate::components::grammar::Grammar;
3use alith_interface::requests::{completion::CompletionRequest, logit_bias::LogitBias};
4use std::cell::RefCell;
5
6#[derive(Clone)]
7pub enum CascadeStep {
8 Inference(InferenceStep),
9 Guidance(GuidanceStep),
10}
11
12impl CascadeStep {
13 pub fn new_inference_step(step_config: StepConfig, step_counter: usize) -> Self {
14 CascadeStep::Inference(InferenceStep {
15 llm_content: None,
16 dynamic_suffix: None,
17 outcome: RefCell::new(None),
18 step_config,
19 step_counter,
20 })
21 }
22
23 pub fn new_guidance_step<S: Into<String>>(
24 step_config: StepConfig,
25 step_counter: usize,
26 llm_content: S,
27 ) -> Self {
28 CascadeStep::Guidance(GuidanceStep {
29 llm_content: llm_content.into(),
30 step_counter,
31 step_config,
32 })
33 }
34
35 pub fn display_step_prefix(&self) -> Option<String> {
36 match self {
37 Self::Inference(step) => step.step_config.display_prefix(step.step_counter),
38 Self::Guidance(step) => step.step_config.display_prefix(step.step_counter),
39 }
40 }
41
42 pub async fn run_step(
43 &mut self,
44 generation_prefix: Option<&str>,
45 base_req: &mut CompletionRequest,
46 ) -> crate::Result<()> {
47 match self {
48 Self::Inference(step) => step.run(generation_prefix, base_req).await,
49 Self::Guidance(_) => self.set_cache_up_to_step(generation_prefix, base_req).await,
50 }
51 }
52
53 pub async fn set_cache_up_to_step(
54 &mut self,
55 generation_prefix: Option<&str>,
56 base_req: &mut CompletionRequest,
57 ) -> crate::Result<()> {
58 if let Some(generation_prefix) = generation_prefix {
59 base_req.prompt.set_generation_prefix(generation_prefix);
60 }
61 base_req
62 .backend
63 .set_cache(&base_req.prompt)
64 .await
65 .map_err(|e| crate::anyhow!("Failed to set cache up to step: {}", e))?;
66 Ok(())
67 }
68
69 pub fn set_dynamic_suffix<S: Into<String>>(&mut self, dynamic_suffix: S) {
70 match self {
71 Self::Inference(step) => step.dynamic_suffix = Some(dynamic_suffix.into()),
72 Self::Guidance(_) => panic!("GuidanceStep does not have dynamic_suffix."),
73 }
74 }
75
76 pub fn display_step_outcome(&self) -> crate::Result<String> {
77 match self {
78 Self::Inference(step) => step.display_outcome(),
79 Self::Guidance(step) => Ok(step.display_outcome()),
80 }
81 }
82
83 pub fn primitive_result(&self) -> Option<String> {
84 match self {
85 Self::Inference(step) => step.llm_content.clone(),
86 Self::Guidance(_) => panic!("GuidanceStep does not have primitive_result."),
87 }
88 }
89}
90
91#[derive(Clone)]
92pub struct InferenceStep {
93 pub llm_content: Option<String>, pub dynamic_suffix: Option<String>, pub outcome: RefCell<Option<String>>,
96 pub step_config: StepConfig,
97 pub step_counter: usize,
98}
99
100impl InferenceStep {
101 async fn run(
102 &mut self,
103 generation_prefix: Option<&str>,
104 base_req: &mut CompletionRequest,
105 ) -> crate::Result<()> {
106 base_req.config.requested_response_tokens = None;
108 base_req.stop_sequences.required = true;
110 base_req.set_base_req_stop_sequences(
111 &Some(self.step_config.stop_word_done.clone()),
112 &self.step_config.stop_word_no_result,
113 );
114 if let Some(stop_word_no_result) = &self.step_config.stop_word_no_result {
116 self.step_config
117 .grammar
118 .set_stop_word_no_result(stop_word_no_result);
119 }
120 self.step_config
121 .grammar
122 .set_stop_word_done(&self.step_config.stop_word_done);
123 if !matches!(self.step_config.grammar, Grammar::NoneGrammar(_)) {
124 base_req.grammar_string = Some(self.step_config.grammar.grammar_string());
125 base_req.stop_sequences.required = true;
126 } else {
127 base_req.grammar_string = None;
128 base_req.stop_sequences.required = false;
129 }
130
131 if let Some(generation_prefix) = generation_prefix {
133 base_req.prompt.set_generation_prefix(generation_prefix);
134 } else {
135 base_req.prompt.clear_generation_prefix();
136 }
137
138 base_req.logit_bias = Some(self.step_config.logit_bias.clone());
140
141 base_req.config.cache_prompt = self.step_config.cache_prompt;
142 cascade_request(base_req, self).await
143 }
144
145 fn display_outcome(&self) -> crate::Result<String> {
147 let llm_content = if let Some(llm_content) = &self.llm_content {
148 llm_content
149 } else if let Some(stop_word_no_result) = &self.step_config.stop_word_no_result {
150 stop_word_no_result
151 } else {
152 crate::bail!("llm_content not yet set and stop_word_no_result not set.")
153 };
154
155 Ok(
156 match (
157 self.step_config.display_prefix(self.step_counter),
158 &self.dynamic_suffix,
159 ) {
160 (Some(step_prefix), Some(dynamic_suffix)) => {
161 format!("{}{}{}", step_prefix, llm_content, dynamic_suffix)
162 }
163 (Some(step_prefix), None) => format!("{}{}", step_prefix, llm_content),
164 (None, Some(dynamic_suffix)) => {
165 format!("{}{}", llm_content, dynamic_suffix)
166 }
167 (None, None) => llm_content.to_owned(),
168 },
169 )
170 }
171}
172
173#[derive(Clone)]
174pub struct GuidanceStep {
175 pub llm_content: String,
176 pub step_config: StepConfig,
177 pub step_counter: usize,
178}
179
180impl GuidanceStep {
181 fn display_outcome(&self) -> String {
182 match self.step_config.display_prefix(self.step_counter) {
183 Some(step_prefix) => format!("{}{}", step_prefix, self.llm_content),
184 None => self.llm_content.to_owned(),
185 }
186 }
187}
188
189#[derive(Clone)]
190pub struct StepConfig {
191 pub step_prefix: Option<String>,
192 pub stop_word_done: String,
193 pub stop_word_no_result: Option<String>,
194 pub use_counter: bool,
195 pub cache_prompt: bool,
196 pub grammar: Grammar,
197 pub logit_bias: LogitBias,
198}
199
200impl Default for StepConfig {
201 fn default() -> Self {
202 Self {
203 step_prefix: None,
204 stop_word_done: "Done.".to_owned(),
205 stop_word_no_result: None,
206 use_counter: false,
207 cache_prompt: true,
208 grammar: Grammar::default(),
209 logit_bias: LogitBias::default(),
210 }
211 }
212}
213
214impl StepConfig {
215 pub fn step_prefix<T: Into<String>>(&mut self, step_prefix: T) -> &mut Self {
216 self.step_prefix = Some(step_prefix.into());
217 self
218 }
219
220 pub fn stop_word_done<T: Into<String>>(&mut self, stop_word_done: T) -> &mut Self {
221 self.stop_word_done = stop_word_done.into();
222 self
223 }
224
225 pub fn stop_word_no_result<T: Into<String>>(&mut self, stop_word_no_result: T) -> &mut Self {
226 self.stop_word_no_result = Some(stop_word_no_result.into());
227 self
228 }
229
230 pub fn use_counter(&mut self, use_counter: bool) -> &mut Self {
231 self.use_counter = use_counter;
232 self
233 }
234
235 pub fn cache_prompt(&mut self, cache_prompt: bool) -> &mut Self {
236 self.cache_prompt = cache_prompt;
237 self
238 }
239
240 pub fn grammar(&mut self, grammar: Grammar) -> &mut Self {
241 self.grammar = grammar;
242 self
243 }
244
245 fn display_prefix(&self, step_counter: usize) -> Option<String> {
246 match (self.use_counter, &self.step_prefix) {
247 (true, Some(step_prefix)) => Some(format!("{} {}", step_counter, step_prefix)),
248 (true, None) => Some(step_counter.to_string()),
249 (false, Some(step_prefix)) => Some(step_prefix.to_string()),
250 (false, None) => None,
251 }
252 }
253}