alith_interface/requests/
req_components.rs

use alith_prompt::{MaxTokenState, RequestTokenLimitError, check_and_get_max_tokens};

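/// Configuration for a single LLM request: token budgets, sampling parameters,
/// and retry behavior. Applies to both local and API-based backends.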
#[derive(Clone)]
pub struct RequestConfig {
    /// Total token limit for input and output combined.
    ///
    /// This value represents the maximum number of tokens that can be used for both
    /// the input prompt and the model's output combined. It's set once when the
    /// RequestConfig is created and is used to calculate the available token budget
    /// for each request.
    ///
    /// This limit applies to all LLM types, including both local and API-based models.
    pub(crate) model_ctx_size: u64,
    /// Maximum token limit for model output.
    ///
    /// This value represents the maximum number of tokens the model can generate
    /// as output. It's set once when the RequestConfig is created and is used to
    /// ensure the model's response doesn't exceed this limit.
    ///
    /// Note: This limit is primarily used by API-based LLMs. For local LLMs,
    /// [RequestConfig::inference_ctx_size] should use the same value as [RequestConfig::model_ctx_size].
    pub(crate) inference_ctx_size: u64,
    /// Requested maximum number of tokens for the model's output.
    ///
    /// This value specifies the upper limit of tokens the model should generate in its response.
    ///
    /// The system uses this value, along with the input prompt length, to ensure the entire
    /// request (input + output) stays within the model's token limits.
    ///
    /// - For OpenAI API-compatible LLMs, this corresponds to the 'max_tokens' parameter.
    /// - For local LLMs, this is equivalent to the 'n_predict' parameter.
    ///
    /// If `None`, the system will use a default or calculated value based on [RequestConfig::model_ctx_size] or [RequestConfig::inference_ctx_size].
    pub requested_response_tokens: Option<u64>,
    /// A small safety margin to prevent exceeding model limits.
    ///
    /// This is a count of tokens subtracted from the total available tokens to help ensure
    /// that the model doesn't unexpectedly exceed its token limit.
    /// This prevents issues that might arise from slight discrepancies in token counting or unexpected model behavior.
    ///
    /// Defaults to 10 tokens.
    pub safety_tokens: u64,
    /// Final adjusted token count for model output.
    ///
    /// This value represents the actual number of tokens requested for the model's output
    /// after all adjustments and calculations have been made. It's derived from
    /// [RequestConfig::requested_response_tokens] but may be different to ensure the request stays
    /// within the model's limits.
    pub(crate) actual_request_tokens: Option<u64>,
    /// Controls the randomness of the model's output.
    ///
    /// The temperature parameter adjusts the randomness in token selection for the model's
    /// response. It accepts values between 0.0 and 2.0:
    /// - Higher values (e.g., 0.8) increase randomness, leading to more diverse and creative outputs.
    /// - Lower values (e.g., 0.2) decrease randomness, resulting in more focused and deterministic responses.
    ///
    /// Note: It's generally recommended to adjust either this parameter or `top_p`, but not both simultaneously.
    ///
    /// Special considerations:
    /// - For Anthropic models: This value is automatically scaled to the range 0.0 to 1.0 to match
    ///   the requirements of [crate::llms::api::anthropic::completion::AnthropicCompletionRequest::temperature].
    ///
    /// Supported by all LLM backends.
    ///
    /// Defaults to `1.0`.
    pub temperature: f32,
    /// Adjusts token selection based on their frequency in the generated text.
    ///
    /// The frequency penalty influences how the model selects tokens based on their existing
    /// frequency in the output. It accepts values between -2.0 and 2.0:
    /// - Positive values decrease the likelihood of repeating tokens, reducing verbatim repetition.
    /// - Negative values increase the likelihood of repeating tokens, potentially leading to more repetitive text.
    /// - A value of 0.0 (or `None`) applies no frequency-based adjustments.
    ///
    /// This can be particularly useful for:
    /// - Encouraging more diverse vocabulary usage (with positive values)
    /// - Maintaining consistent terminology (with negative values)
    ///
    /// Supported LLMs: openai
    ///
    /// Defaults to `None` (no frequency penalty applied).
    pub frequency_penalty: Option<f32>,
    /// Adjusts token selection based on their presence in the generated text.
    ///
    /// The presence penalty influences how the model selects tokens based on whether they've
    /// appeared at all in the output, regardless of frequency. It accepts values between -2.0 and 2.0:
    /// - Positive values decrease the likelihood of using tokens that have appeared at all,
    ///   encouraging the model to introduce new concepts and topics.
    /// - Negative values increase the likelihood of reusing tokens that have appeared,
    ///   potentially leading to more focused or repetitive text.
    /// - A value of 0.0 applies no presence-based adjustments.
    ///
    /// This differs from `frequency_penalty` in that it considers only whether a token has
    /// appeared, not how often.
    ///
    /// Use cases:
    /// - Encouraging the model to cover more topics (with positive values)
    /// - Maintaining focus on specific themes (with negative values)
    ///
    /// Supported LLMs: openai
    ///
    /// Defaults to `0.0` (no presence penalty applied).
    pub presence_penalty: f32,
    /// Controls diversity via nucleus sampling.
    ///
    /// Top-p sampling (also called nucleus sampling) is an alternative to temperature-based sampling.
    /// It selects from the smallest possible set of tokens whose cumulative probability exceeds
    /// the probability `p`. The value should be between 0.0 and 1.0:
    /// - A value of 0.1 means only the tokens comprising the top 10% probability mass are considered.
    /// - Lower values lead to more focused and deterministic outputs.
    /// - Higher values allow for more diverse outputs.
    ///
    /// Key points:
    /// - It's generally recommended to adjust either this or `temperature`, but not both simultaneously.
    /// - This method is considered more advanced than `temperature` and is recommended for
    ///   users who need fine-grained control over output diversity.
    ///
    /// Supported LLMs: All
    ///
    /// Defaults to `None` (not used, falling back to temperature-based sampling).
    pub top_p: Option<f32>,
    /// Maximum number of retry attempts after a request failure.
    ///
    /// Specifies how many times the system should attempt to retry a failed request before giving up.
    /// This can help handle transient errors or temporary service unavailability.
    ///
    /// Supported LLMs: All
    ///
    /// Defaults to `3`.
    pub retry_after_fail_n_times: u8,
    /// Automatically increase token limit on request failure.
    ///
    /// When set to `true`, if a request fails due to token limit constraints or other errors,
    /// the system will attempt to increase the token limit using [`RequestConfig::increase_token_limit`]
    /// before retrying the request.
    ///
    /// Supported LLMs: All
    ///
    /// Defaults to `false`.
    pub increase_limit_on_fail: bool,
    /// Enable prompt caching for subsequent requests.
    ///
    /// When set to `true`, the system will cache the prompt and reuse it for the next request.
    /// This can potentially improve performance for repeated or similar queries.
    ///
    /// Supported LLMs
    ///
    /// Defaults to `false`.
    pub cache_prompt: bool,
}
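
// Rough sketch of how the token-budget fields above fit together. This is an
// approximation of what `check_and_get_max_tokens` enforces, not its exact formula:
//
//   available_output ≈ min(model_ctx_size - total_prompt_tokens, inference_ctx_size) - safety_tokens
//   actual_request_tokens ≈ min(requested_response_tokens, available_output)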

impl RequestConfig {
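    /// Creates a config with the library defaults: temperature `1.0`, no penalties, no `top_p`,
    /// a 10-token safety margin, 3 retry attempts, and no response-token request set.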
    pub fn new(model_ctx_size: u64, inference_ctx_size: u64) -> Self {
        Self {
            model_ctx_size,
            inference_ctx_size,
            requested_response_tokens: None,
            actual_request_tokens: None,
            frequency_penalty: None,
            presence_penalty: 0.0,
            temperature: 1.0,
            top_p: None,
            safety_tokens: 10,
            retry_after_fail_n_times: 3,
            increase_limit_on_fail: false,
            cache_prompt: false,
        }
    }

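    /// Computes the output-token budget for a prompt of `total_prompt_tokens` tokens via
    /// `check_and_get_max_tokens` and stores it in [RequestConfig::actual_request_tokens]
    /// (falling back to the prompt length if the computed budget is zero). Also initializes
    /// [RequestConfig::requested_response_tokens] if the caller did not set it.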
    pub fn set_max_tokens_for_request(
        &mut self,
        total_prompt_tokens: u64,
    ) -> crate::Result<(), RequestTokenLimitError> {
        let actual_request_tokens = check_and_get_max_tokens(
            self.model_ctx_size,
            Some(self.inference_ctx_size),
            total_prompt_tokens,
            Some(self.safety_tokens),
            self.requested_response_tokens,
        )?;
        let tokens = if actual_request_tokens == 0 {
            total_prompt_tokens
        } else {
            actual_request_tokens
        };
        self.actual_request_tokens = Some(tokens);
        if self.requested_response_tokens.is_none() {
            self.requested_response_tokens = Some(tokens);
        }
        Ok(())
    }

    pub const DEFAULT_INCREASE_FACTOR: f32 = 1.33;
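    /// Scales [RequestConfig::requested_response_tokens] by `token_increase_factor`
    /// (defaulting to [RequestConfig::DEFAULT_INCREASE_FACTOR]), recomputes the budget via
    /// [RequestConfig::set_max_tokens_for_request], and returns an error if the adjusted
    /// budget did not grow. Used when [RequestConfig::increase_limit_on_fail] is enabled.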
    pub fn increase_token_limit(
        &mut self,
        total_prompt_tokens: u64,
        token_increase_factor: Option<f32>,
    ) -> crate::Result<(), RequestTokenLimitError> {
        let token_increase_factor = token_increase_factor.unwrap_or(Self::DEFAULT_INCREASE_FACTOR);
        crate::info!(
            "Attempting to increase requested_response_tokens by a factor of {token_increase_factor} before retrying."
        );

        if self.actual_request_tokens.is_none() || self.requested_response_tokens.is_none() {
            self.set_max_tokens_for_request(total_prompt_tokens)?; // Ensure both token counts are populated.
        }

        let initial_state = MaxTokenState {
            actual_request: self.actual_request_tokens.expect("actual_request_tokens"),
            requested_response: self
                .requested_response_tokens
                .expect("requested_response_tokens"),
        };

        // Scale the requested response tokens, then recompute the budget so that
        // actual_request_tokens reflects the increased request before comparison.
        self.requested_response_tokens =
            Some((initial_state.requested_response as f32 * token_increase_factor) as u64);
        self.set_max_tokens_for_request(total_prompt_tokens)?;

        let new_state = MaxTokenState {
            actual_request: self.actual_request_tokens.expect("actual_request_tokens"),
            requested_response: self
                .requested_response_tokens
                .expect("requested_response_tokens"),
        };

        crate::info!(
            "Token counts changed: actual_request ({} -> {}), requested_response ({} -> {})",
            initial_state.actual_request,
            new_state.actual_request,
            initial_state.requested_response,
            new_state.requested_response
        );

        if new_state.actual_request <= initial_state.actual_request {
            crate::error!("Increase limit failed.");
            Err(RequestTokenLimitError::TokenLimitIncreaseError {
                initial_state,
                new_state,
            })
        } else {
            crate::info!("Increase limit succeeded. Retrying.");
            Ok(())
        }
    }
}
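/// Builder-style setters for the [RequestConfig] carried by a request type.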
pub trait RequestConfigTrait {
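    /// Returns a mutable reference to the underlying [RequestConfig].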
    fn config(&mut self) -> &mut RequestConfig;

    fn reset_request(&mut self);

    /// Sets the value of [RequestConfig::requested_response_tokens].
    fn max_tokens(&mut self, max_tokens: u64) -> &mut Self {
        self.config().requested_response_tokens = Some(max_tokens);
        self
    }

    /// Sets the value of [RequestConfig::frequency_penalty].
    fn frequency_penalty(&mut self, frequency_penalty: f32) -> &mut Self {
        self.config().frequency_penalty = Some(frequency_penalty);
        self
    }

    /// Sets the value of [RequestConfig::presence_penalty].
    fn presence_penalty(&mut self, presence_penalty: f32) -> &mut Self {
        match presence_penalty {
            value if (-2.0..=2.0).contains(&value) => self.config().presence_penalty = value,
            _ => self.config().presence_penalty = 0.0,
        };
        self
    }

    /// Sets the value of [RequestConfig::temperature].
    fn temperature(&mut self, temperature: f32) -> &mut Self {
        match temperature {
            value if (0.0..=2.0).contains(&value) => self.config().temperature = value,
            _ => self.config().temperature = 1.0,
        };
        self
    }

    /// Sets the value of [RequestConfig::top_p].
    fn top_p(&mut self, top_p: f32) -> &mut Self {
        self.config().top_p = Some(top_p);
        self
    }

    /// Sets the value of [RequestConfig::retry_after_fail_n_times].
    fn retry_after_fail_n_times(&mut self, retry_after_fail_n_times: u8) -> &mut Self {
        self.config().retry_after_fail_n_times = retry_after_fail_n_times;
        self
    }

    /// Sets the value of [RequestConfig::increase_limit_on_fail].
    fn increase_limit_on_fail(&mut self, increase_limit_on_fail: bool) -> &mut Self {
        self.config().increase_limit_on_fail = increase_limit_on_fail;
        self
    }

    /// Sets the value of [RequestConfig::cache_prompt].
    fn cache_prompt(&mut self, cache_prompt: bool) -> &mut Self {
        self.config().cache_prompt = cache_prompt;
        self
    }
}

impl std::fmt::Display for RequestConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f)?;
        writeln!(f, "    model_ctx_size: {}", self.model_ctx_size)?;
        writeln!(f, "    inference_ctx_size: {}", self.inference_ctx_size)?;
        writeln!(
            f,
            "    requested_response_tokens: {:?}",
            self.requested_response_tokens
        )?;
        writeln!(
            f,
            "    actual_request_tokens: {:?}",
            self.actual_request_tokens
        )?;
        writeln!(f, "    frequency_penalty: {:?}", self.frequency_penalty)?;
        writeln!(f, "    presence_penalty: {:?}", self.presence_penalty)?;
        writeln!(f, "    temperature: {:?}", self.temperature)?;
        writeln!(f, "    top_p: {:?}", self.top_p)?;
        writeln!(
            f,
            "    retry_after_fail_n_times: {:?}",
            self.retry_after_fail_n_times
        )?;
        writeln!(
            f,
            "    increase_limit_on_fail: {:?}",
            self.increase_limit_on_fail
        )?;
        writeln!(f, "    cache_prompt: {:?}", self.cache_prompt)
    }
}
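
// A minimal usage sketch, not part of the original API surface. It assumes that
// `check_and_get_max_tokens` succeeds when the prompt is much smaller than the
// hypothetical context sizes used below; the exact budget it returns is not asserted.
// `ExampleRequest` is a hypothetical type used only to demonstrate the builder chain.
#[cfg(test)]
mod request_config_sketch {
    use super::*;

    #[test]
    fn budget_is_set_for_a_small_prompt() {
        // Hypothetical context sizes; real values come from the model's metadata.
        let mut config = RequestConfig::new(8192, 4096);
        config.requested_response_tokens = Some(512);

        // A 100-token prompt should fit comfortably within the limits above.
        config
            .set_max_tokens_for_request(100)
            .expect("budget calculation should succeed for a small prompt");

        // After the call, the adjusted output budget is populated.
        assert!(config.actual_request_tokens.is_some());
    }

    // Hypothetical implementor of RequestConfigTrait.
    struct ExampleRequest {
        config: RequestConfig,
    }

    impl RequestConfigTrait for ExampleRequest {
        fn config(&mut self) -> &mut RequestConfig {
            &mut self.config
        }

        fn reset_request(&mut self) {
            self.config = RequestConfig::new(8192, 4096);
        }
    }

    #[test]
    fn builder_setters_update_the_config() {
        let mut request = ExampleRequest {
            config: RequestConfig::new(8192, 4096),
        };

        // Chained setters write through to the underlying RequestConfig.
        request.max_tokens(256).temperature(0.2).top_p(0.9);

        assert_eq!(request.config.requested_response_tokens, Some(256));
        assert_eq!(request.config.temperature, 0.2);
        assert_eq!(request.config.top_p, Some(0.9));
    }
}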