// alith_client/components/cascade/step.rs

1use super::cascade_request;
2use crate::components::grammar::Grammar;
3use alith_interface::requests::{completion::CompletionRequest, logit_bias::LogitBias};
4use std::cell::RefCell;
5
6#[derive(Clone)]
7pub enum CascadeStep {
8    Inference(InferenceStep),
9    Guidance(GuidanceStep),
10}
11
12impl CascadeStep {
13    pub fn new_inference_step(step_config: StepConfig, step_counter: usize) -> Self {
14        CascadeStep::Inference(InferenceStep {
15            llm_content: None,
16            dynamic_suffix: None,
17            outcome: RefCell::new(None),
18            step_config,
19            step_counter,
20        })
21    }
22
23    pub fn new_guidance_step<S: Into<String>>(
24        step_config: StepConfig,
25        step_counter: usize,
26        llm_content: S,
27    ) -> Self {
28        CascadeStep::Guidance(GuidanceStep {
29            llm_content: llm_content.into(),
30            step_counter,
31            step_config,
32        })
33    }
34
35    pub fn display_step_prefix(&self) -> Option<String> {
36        match self {
37            Self::Inference(step) => step.step_config.display_prefix(step.step_counter),
38            Self::Guidance(step) => step.step_config.display_prefix(step.step_counter),
39        }
40    }
41
42    pub async fn run_step(
43        &mut self,
44        generation_prefix: Option<&str>,
45        base_req: &mut CompletionRequest,
46    ) -> crate::Result<()> {
47        match self {
48            Self::Inference(step) => step.run(generation_prefix, base_req).await,
49            Self::Guidance(_) => self.set_cache_up_to_step(generation_prefix, base_req).await,
50        }
51    }
52
53    pub async fn set_cache_up_to_step(
54        &mut self,
55        generation_prefix: Option<&str>,
56        base_req: &mut CompletionRequest,
57    ) -> crate::Result<()> {
58        if let Some(generation_prefix) = generation_prefix {
59            base_req.prompt.set_generation_prefix(generation_prefix);
60        }
61        base_req
62            .backend
63            .set_cache(&base_req.prompt)
64            .await
65            .map_err(|e| crate::anyhow!("Failed to set cache up to step: {}", e))?;
66        Ok(())
67    }
68
69    pub fn set_dynamic_suffix<S: Into<String>>(&mut self, dynamic_suffix: S) {
70        match self {
71            Self::Inference(step) => step.dynamic_suffix = Some(dynamic_suffix.into()),
72            Self::Guidance(_) => panic!("GuidanceStep does not have dynamic_suffix."),
73        }
74    }
75
76    pub fn display_step_outcome(&self) -> crate::Result<String> {
77        match self {
78            Self::Inference(step) => step.display_outcome(),
79            Self::Guidance(step) => Ok(step.display_outcome()),
80        }
81    }
82
83    pub fn primitive_result(&self) -> Option<String> {
84        match self {
85            Self::Inference(step) => step.llm_content.clone(),
86            Self::Guidance(_) => panic!("GuidanceStep does not have primitive_result."),
87        }
88    }
89}
90
91#[derive(Clone)]
92pub struct InferenceStep {
93    pub llm_content: Option<String>, // raw, unformatted result from llm.
94    pub dynamic_suffix: Option<String>, // suffix to be added to the result.
95    pub outcome: RefCell<Option<String>>,
96    pub step_config: StepConfig,
97    pub step_counter: usize,
98}
99
100impl InferenceStep {
101    async fn run(
102        &mut self,
103        generation_prefix: Option<&str>,
104        base_req: &mut CompletionRequest,
105    ) -> crate::Result<()> {
106        // Request tokens
107        base_req.config.requested_response_tokens = None;
108        // Request stop words
109        base_req.stop_sequences.required = true;
110        base_req.set_base_req_stop_sequences(
111            &Some(self.step_config.stop_word_done.clone()),
112            &self.step_config.stop_word_no_result,
113        );
114        // Request grammar
115        if let Some(stop_word_no_result) = &self.step_config.stop_word_no_result {
116            self.step_config
117                .grammar
118                .set_stop_word_no_result(stop_word_no_result);
119        }
120        self.step_config
121            .grammar
122            .set_stop_word_done(&self.step_config.stop_word_done);
123        if !matches!(self.step_config.grammar, Grammar::NoneGrammar(_)) {
124            base_req.grammar_string = Some(self.step_config.grammar.grammar_string());
125            base_req.stop_sequences.required = true;
126        } else {
127            base_req.grammar_string = None;
128            base_req.stop_sequences.required = false;
129        }
130
131        // Request prompt
132        if let Some(generation_prefix) = generation_prefix {
133            base_req.prompt.set_generation_prefix(generation_prefix);
134        } else {
135            base_req.prompt.clear_generation_prefix();
136        }
137
138        // Request logit bias
139        base_req.logit_bias = Some(self.step_config.logit_bias.clone());
140
141        base_req.config.cache_prompt = self.step_config.cache_prompt;
142        cascade_request(base_req, self).await
143    }
144
145    // step_counter + step_prefix + prefix_delimiter + (llm_content | stop_word_no_result) + dynamic_suffix
146    fn display_outcome(&self) -> crate::Result<String> {
147        let llm_content = if let Some(llm_content) = &self.llm_content {
148            llm_content
149        } else if let Some(stop_word_no_result) = &self.step_config.stop_word_no_result {
150            stop_word_no_result
151        } else {
152            crate::bail!("llm_content not yet set and stop_word_no_result not set.")
153        };
154
155        Ok(
156            match (
157                self.step_config.display_prefix(self.step_counter),
158                &self.dynamic_suffix,
159            ) {
160                (Some(step_prefix), Some(dynamic_suffix)) => {
161                    format!("{}{}{}", step_prefix, llm_content, dynamic_suffix)
162                }
163                (Some(step_prefix), None) => format!("{}{}", step_prefix, llm_content),
164                (None, Some(dynamic_suffix)) => {
165                    format!("{}{}", llm_content, dynamic_suffix)
166                }
167                (None, None) => llm_content.to_owned(),
168            },
169        )
170    }
171}
172
173#[derive(Clone)]
174pub struct GuidanceStep {
175    pub llm_content: String,
176    pub step_config: StepConfig,
177    pub step_counter: usize,
178}
179
180impl GuidanceStep {
181    fn display_outcome(&self) -> String {
182        match self.step_config.display_prefix(self.step_counter) {
183            Some(step_prefix) => format!("{}{}", step_prefix, self.llm_content),
184            None => self.llm_content.to_owned(),
185        }
186    }
187}
188
189#[derive(Clone)]
190pub struct StepConfig {
191    pub step_prefix: Option<String>,
192    pub stop_word_done: String,
193    pub stop_word_no_result: Option<String>,
194    pub use_counter: bool,
195    pub cache_prompt: bool,
196    pub grammar: Grammar,
197    pub logit_bias: LogitBias,
198}
199
200impl Default for StepConfig {
201    fn default() -> Self {
202        Self {
203            step_prefix: None,
204            stop_word_done: "Done.".to_owned(),
205            stop_word_no_result: None,
206            use_counter: false,
207            cache_prompt: true,
208            grammar: Grammar::default(),
209            logit_bias: LogitBias::default(),
210        }
211    }
212}
213
214impl StepConfig {
215    pub fn step_prefix<T: Into<String>>(&mut self, step_prefix: T) -> &mut Self {
216        self.step_prefix = Some(step_prefix.into());
217        self
218    }
219
220    pub fn stop_word_done<T: Into<String>>(&mut self, stop_word_done: T) -> &mut Self {
221        self.stop_word_done = stop_word_done.into();
222        self
223    }
224
225    pub fn stop_word_no_result<T: Into<String>>(&mut self, stop_word_no_result: T) -> &mut Self {
226        self.stop_word_no_result = Some(stop_word_no_result.into());
227        self
228    }
229
230    pub fn use_counter(&mut self, use_counter: bool) -> &mut Self {
231        self.use_counter = use_counter;
232        self
233    }
234
235    pub fn cache_prompt(&mut self, cache_prompt: bool) -> &mut Self {
236        self.cache_prompt = cache_prompt;
237        self
238    }
239
240    pub fn grammar(&mut self, grammar: Grammar) -> &mut Self {
241        self.grammar = grammar;
242        self
243    }
244
245    fn display_prefix(&self, step_counter: usize) -> Option<String> {
246        match (self.use_counter, &self.step_prefix) {
247            (true, Some(step_prefix)) => Some(format!("{} {}", step_counter, step_prefix)),
248            (true, None) => Some(step_counter.to_string()),
249            (false, Some(step_prefix)) => Some(step_prefix.to_string()),
250            (false, None) => None,
251        }
252    }
253}