alith_interface/requests/completion/request.rs

use super::{ToolChoice, ToolDefinition, error::CompletionError, response::CompletionResponse};
use crate::{
    llms::LLMBackend,
    requests::{
        completion::response::CompletionFinishReason, logit_bias::LogitBias,
        req_components::RequestConfig, stop_sequence::StopSequences,
    },
};
use alith_prompt::LLMPrompt;
use std::sync::Arc;

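/// A single completion request against an `LLMBackend`: the prompt, stop
/// sequences, optional grammar string and logit bias, generation config, and
/// any tool definitions, plus the errors accumulated while retrying.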
pub struct CompletionRequest {
    pub start_time: std::time::Instant,
    pub stop_sequences: StopSequences,
    pub grammar_string: Option<String>,
    pub logit_bias: Option<LogitBias>,
    pub prompt: LLMPrompt,
    pub config: RequestConfig,
    pub backend: Arc<LLMBackend>,
    pub llm_interface_errors: Vec<CompletionError>,
    pub tools: Vec<ToolDefinition>,
    pub tool_choice: ToolChoice,
}

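// A manual `Clone` so that a cloned request starts with a clean slate: the
// accumulated `llm_interface_errors` and `tools` are emptied and `tool_choice`
// is reset to `ToolChoice::Auto`, while the prompt and generation settings
// carry over.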
impl Clone for CompletionRequest {
    fn clone(&self) -> Self {
        Self {
            start_time: self.start_time,
            stop_sequences: self.stop_sequences.clone(),
            grammar_string: self.grammar_string.clone(),
            logit_bias: self.logit_bias.clone(),
            prompt: self.prompt.clone(),
            config: self.config.clone(),
            backend: Arc::clone(&self.backend),
            llm_interface_errors: Vec::new(),
            tools: Vec::new(),
            tool_choice: ToolChoice::Auto,
        }
    }
}

impl CompletionRequest {
    pub fn new(backend: Arc<LLMBackend>) -> CompletionRequest {
        CompletionRequest {
            start_time: std::time::Instant::now(),
            stop_sequences: Default::default(),
            logit_bias: None,
            config: RequestConfig::new(backend.model_ctx_size(), backend.inference_ctx_size()),
            prompt: backend.new_prompt(),
            grammar_string: None,
            backend: Arc::clone(&backend),
            llm_interface_errors: Vec::new(),
            tools: Vec::new(),
            tool_choice: ToolChoice::default(),
        }
    }

    pub fn reset_completion_request(&mut self) {
        self.prompt.reset_prompt();
        self.stop_sequences.sequences.clear();
        self.grammar_string = None;
        self.logit_bias = None;
    }

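    /// Runs the completion against the backend, retrying on recoverable errors
    /// up to `config.retry_after_fail_n_times` attempts. Builder, unsupported
    /// stop reason, and client errors abort immediately. When stop sequences
    /// are required, a completion that finishes without matching one is retried
    /// (optionally growing the token limit), and a `StopLimit` finish is retried
    /// with a larger limit when `config.increase_limit_on_fail` is set.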
    pub async fn request(&mut self) -> crate::Result<CompletionResponse, CompletionError> {
        self.llm_interface_errors.clear();
        self.start_time = std::time::Instant::now();
        self.backend
            .build_logit_bias(&mut self.logit_bias)
            .map_err(|e| CompletionError::RequestBuilderError(e.to_string()))?;

        let total_prompt_tokens = self
            .backend
            .get_total_prompt_tokens(&self.prompt)
            .map_err(|e| CompletionError::RequestBuilderError(e.to_string()))?;

        self.config
            .set_max_tokens_for_request(total_prompt_tokens)
            .map_err(CompletionError::RequestTokenLimitError)?;

        let mut retry_count: u8 = 0;

        loop {
            // Give up once the retry budget is exhausted, returning the errors
            // accumulated across attempts.
            if retry_count >= self.config.retry_after_fail_n_times {
                let llm_interface_error = CompletionError::ExceededRetryCount {
                    message: format!("Request failed after {retry_count} attempts."),
                    errors: std::mem::take(&mut self.llm_interface_errors),
                };
                tracing::error!(?llm_interface_error);
                eprintln!("{}", llm_interface_error);
                return Err(llm_interface_error);
            }
            tracing::info!("{}", self);
            match self.backend.completion_request(self).await {
                Err(e) => {
                    tracing::warn!(?e);
                    retry_count += 1;
                    // Builder, unsupported-stop-reason, and client errors are not
                    // recoverable; everything else is recorded and retried.
                    match e {
                        CompletionError::RequestBuilderError { .. }
                        | CompletionError::StopReasonUnsupported { .. }
                        | CompletionError::ClientError { .. } => {
                            return Err(e);
                        }
                        _ => (),
                    }
                    self.llm_interface_errors.push(e);
                    continue;
                }
                Ok(res) => {
                    tracing::info!("{}", res);
                    // When a stop sequence is required, only a matching stop counts
                    // as success; anything else is recorded and retried.
                    if self.stop_sequences.required {
                        if matches!(
                            res.finish_reason,
                            CompletionFinishReason::MatchingStoppingSequence(_)
                        ) {
                            return Ok(res);
                        } else {
                            let llm_interface_error = match res.finish_reason {
                                CompletionFinishReason::NonMatchingStoppingSequence(s) => {
                                    if let Some(s) = s {
                                        CompletionError::NonMatchingStopSequence(s.clone())
                                    } else {
                                        CompletionError::NoRequiredStopSequence
                                    }
                                }
                                _ => CompletionError::NoRequiredStopSequence,
                            };
                            tracing::warn!(?llm_interface_error);
                            self.llm_interface_errors.push(llm_interface_error);
                            retry_count += 1;
                            if self.config.increase_limit_on_fail {
                                self.config
                                    .increase_token_limit(total_prompt_tokens, None)?;
                            }
                            continue;
                        };
                    };
                    match res.finish_reason {
                        CompletionFinishReason::NonMatchingStoppingSequence(_)
                        | CompletionFinishReason::MatchingStoppingSequence(_) => return Ok(res),
                        CompletionFinishReason::StopLimit => {
                            // Hitting the token limit can optionally grow the limit
                            // and retry rather than return a truncated response.
                            if self.config.increase_limit_on_fail {
                                let llm_interface_error = CompletionError::StopLimitRetry;
                                tracing::warn!(?llm_interface_error);
                                self.llm_interface_errors.push(llm_interface_error);
                                self.config
                                    .increase_token_limit(total_prompt_tokens, None)?;
                                retry_count += 1;
                                continue;
                            }
                            return Ok(res);
                        }
                        CompletionFinishReason::Eos | CompletionFinishReason::ToolsCall => {
                            return Ok(res);
                        }
                    }
                }
            };
        }
    }

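    /// Installs the base stop sequences for a request: if either a "done" or a
    /// "no result" stop word is provided, any existing sequences are cleared and
    /// stop sequences are marked as required before the words are registered.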
    pub fn set_base_req_stop_sequences(
        &mut self,
        stop_word_done: &Option<String>,
        stop_word_no_result: &Option<String>,
    ) {
        if stop_word_done.is_some() || stop_word_no_result.is_some() {
            self.stop_sequences.required = true;
            self.stop_sequences.sequences.clear();
        }
        if let Some(stop_word_done) = stop_word_done {
            self.stop_sequences.set_stop_word_done(stop_word_done);
        }

        if let Some(no_result_stop_word) = stop_word_no_result {
            self.stop_sequences
                .set_stop_word_no_result(no_result_stop_word);
        }
    }
}

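// Usage sketch: the `example_completion` helper below is an illustrative,
// hypothetical example rather than part of this module's API. It assumes an
// `Arc<LLMBackend>` has already been constructed elsewhere, requires a "DONE"
// stop word, and drives the retry loop in `request()`. Prompt content would be
// added through `request.prompt`, whose API is outside this file.
#[allow(dead_code)]
async fn example_completion(
    backend: Arc<LLMBackend>,
) -> crate::Result<CompletionResponse, CompletionError> {
    let mut request = CompletionRequest::new(backend);
    // With a required stop word, `request()` retries completions that finish
    // without matching it, up to `config.retry_after_fail_n_times` attempts.
    request.set_base_req_stop_sequences(&Some("DONE".to_string()), &None);
    request.request().await
}
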
impl std::fmt::Display for CompletionRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f)?;
        writeln!(f, "CompletionRequest:")?;

        writeln!(f, " prompt: {}", self.prompt)?;
        writeln!(f, " stop_sequences: {:?}", self.stop_sequences.to_vec())?;
        writeln!(f, " grammar_string: {:?}", self.grammar_string)?;
        write!(f, " config: {}", self.config)?;
        write!(f, " tools: {:?}", self.tools)
    }
}