alith_interface/requests/completion/request.rs

1use super::{ToolChoice, ToolDefinition, error::CompletionError, response::CompletionResponse};
2use crate::{
3    llms::LLMBackend,
4    requests::{
5        completion::response::CompletionFinishReason, logit_bias::LogitBias,
6        req_components::RequestConfig, stop_sequence::StopSequences,
7    },
8};
9use alith_prompt::LLMPrompt;
10use std::sync::Arc;
11
12pub struct CompletionRequest {
13    pub start_time: std::time::Instant,
14    pub stop_sequences: StopSequences,
15    pub grammar_string: Option<String>,
16    pub logit_bias: Option<LogitBias>,
17    pub prompt: LLMPrompt,
18    pub config: RequestConfig,
19    pub backend: Arc<LLMBackend>,
20    pub llm_interface_errors: Vec<CompletionError>,
21    pub tools: Vec<ToolDefinition>,
22    pub tool_choice: ToolChoice,
23}
24
25impl Clone for CompletionRequest {
26    fn clone(&self) -> Self {
27        Self {
28            start_time: self.start_time,
29            stop_sequences: self.stop_sequences.clone(),
30            grammar_string: self.grammar_string.clone(),
31            logit_bias: self.logit_bias.clone(),
32            prompt: self.prompt.clone(),
33            config: self.config.clone(),
34            backend: Arc::clone(&self.backend),
35            llm_interface_errors: Vec::new(),
36            tools: Vec::new(),
37            tool_choice: ToolChoice::Auto,
38        }
39    }
40}
41
42impl CompletionRequest {
43    pub fn new(backend: Arc<LLMBackend>) -> CompletionRequest {
44        CompletionRequest {
45            start_time: std::time::Instant::now(),
46            stop_sequences: Default::default(),
47            logit_bias: None,
48            config: RequestConfig::new(backend.model_ctx_size(), backend.inference_ctx_size()),
49            prompt: backend.new_prompt(),
50            grammar_string: None,
51            backend: Arc::clone(&backend),
52            llm_interface_errors: Vec::new(),
53            tools: Vec::new(),
54            tool_choice: ToolChoice::default(),
55        }
56    }
57
58    pub fn reset_completion_request(&mut self) {
59        self.prompt.reset_prompt();
60        self.stop_sequences.sequences.clear();
61        self.grammar_string = None;
62        self.logit_bias = None;
63    }
64
65    pub async fn request(&mut self) -> crate::Result<CompletionResponse, CompletionError> {
66        self.llm_interface_errors.clear();
67        self.start_time = std::time::Instant::now();
68        self.backend
69            .build_logit_bias(&mut self.logit_bias)
70            .map_err(|e| CompletionError::RequestBuilderError(e.to_string()))?;
71
72        let total_prompt_tokens = self
73            .backend
74            .get_total_prompt_tokens(&self.prompt)
75            .map_err(|e| CompletionError::RequestBuilderError(e.to_string()))?;
76
77        self.config
78            .set_max_tokens_for_request(total_prompt_tokens)
79            .map_err(CompletionError::RequestTokenLimitError)?;
80
81        let mut retry_count: u8 = 0;
82
83        loop {
84            if retry_count >= self.config.retry_after_fail_n_times {
85                let llm_interface_error = CompletionError::ExceededRetryCount {
86                    message: format!("Request failed after {retry_count} attempts."),
87                    errors: std::mem::take(&mut self.llm_interface_errors),
88                };
89                tracing::error!(?llm_interface_error);
90                eprintln!("{}", llm_interface_error);
91                return Err(llm_interface_error);
92            }
93            tracing::info!("{}", self);
94            match self.backend.completion_request(self).await {
95                Err(e) => {
96                    tracing::warn!(?e);
97                    retry_count += 1;
98                    match e {
99                        CompletionError::RequestBuilderError { .. }
100                        | CompletionError::StopReasonUnsupported { .. }
101                        | CompletionError::ClientError { .. } => {
102                            return Err(e);
103                        }
104
105                        _ => (),
106                    }
107                    self.llm_interface_errors.push(e);
108                    continue;
109                }
110                Ok(res) => {
111                    tracing::info!("{}", res);
112                    if self.stop_sequences.required {
113                        if matches!(
114                            res.finish_reason,
115                            CompletionFinishReason::MatchingStoppingSequence(_)
116                        ) {
117                            return Ok(res);
118                        } else {
119                            let llm_interface_error = match res.finish_reason {
120                                CompletionFinishReason::NonMatchingStoppingSequence(s) => {
121                                    if let Some(s) = s {
122                                        CompletionError::NonMatchingStopSequence(s.clone())
123                                    } else {
124                                        CompletionError::NoRequiredStopSequence
125                                    }
126                                }
127                                _ => CompletionError::NoRequiredStopSequence,
128                            };
129                            tracing::warn!(?llm_interface_error);
130                            self.llm_interface_errors.push(llm_interface_error);
131                            retry_count += 1;
132                            if self.config.increase_limit_on_fail {
133                                self.config
134                                    .increase_token_limit(total_prompt_tokens, None)?;
135                            }
136                            continue;
137                        };
138                    };
139                    match res.finish_reason {
140                        CompletionFinishReason::NonMatchingStoppingSequence(_)
141                        | CompletionFinishReason::MatchingStoppingSequence(_) => return Ok(res),
142                        CompletionFinishReason::StopLimit => {
143                            if self.config.increase_limit_on_fail {
144                                let llm_interface_error = CompletionError::StopLimitRetry;
145                                tracing::warn!(?llm_interface_error);
146                                self.llm_interface_errors.push(llm_interface_error);
147                                self.config
148                                    .increase_token_limit(total_prompt_tokens, None)?;
149                                retry_count += 1;
150                                continue;
151                            }
152                            return Ok(res);
153                        }
154                        CompletionFinishReason::Eos | CompletionFinishReason::ToolsCall => {
155                            return Ok(res);
156                        }
157                    }
158                }
159            };
160        }
161    }
162
163    pub fn set_base_req_stop_sequences(
164        &mut self,
165        stop_word_done: &Option<String>,
166        stop_word_no_result: &Option<String>,
167    ) {
168        if stop_word_done.is_some() || stop_word_no_result.is_some()
169        // || step.stop_word_steps_done.is_some()
170        {
171            self.stop_sequences.required = true;
172            self.stop_sequences.sequences.clear();
173        }
174        if let Some(stop_word_done) = &stop_word_done {
175            self.stop_sequences.set_stop_word_done(stop_word_done);
176        }
177
178        if let Some(no_result_stop_word) = &stop_word_no_result {
179            self.stop_sequences
180                .set_stop_word_no_result(no_result_stop_word);
181        }
182    }
183}
184
185impl std::fmt::Display for CompletionRequest {
186    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187        writeln!(f)?;
188        writeln!(f, "CompletionRequest:")?;
189
190        writeln!(f, "  prompt: {}", self.prompt)?;
191        writeln!(f, "  stop_sequences: {:?}", self.stop_sequences.to_vec())?;
192        writeln!(f, "  grammar_string: {:?}", self.grammar_string)?;
193        write!(f, "  config: {}", self.config)?;
194        write!(f, "  tools: {:?}", self.tools)
195    }
196}