llmvm_core_lib/generation.rs

use std::str::FromStr;

use llmvm_protocol::service::{BackendRequest, BackendResponse};
use llmvm_protocol::tower::Service;
use llmvm_protocol::{
    BackendGenerationRequest, BackendGenerationResponse, GenerationParameters, GenerationRequest,
    Message, MessageRole, ModelDescription, NotificationStream, ServiceResponse,
};
use serde_json::Value;

use tracing::{debug, info};

use crate::error::CoreError;
use crate::presets::load_preset;
use crate::prompts::ReadyPrompt;
use crate::threads::get_thread_messages;
use crate::{LLMVMCore, Result};

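/// Merge preset parameters with request parameters, giving precedence to the
/// request. `model_parameters` maps are merged key by key, with request values
/// overriding preset values.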
fn merge_generation_parameters(
    preset_parameters: GenerationParameters,
    mut request_parameters: GenerationParameters,
) -> GenerationParameters {
    GenerationParameters {
        model: request_parameters.model.or(preset_parameters.model),
        prompt_template_id: request_parameters
            .prompt_template_id
            .or(preset_parameters.prompt_template_id),
        custom_prompt_template: request_parameters
            .custom_prompt_template
            .or(preset_parameters.custom_prompt_template),
        max_tokens: request_parameters
            .max_tokens
            .or(preset_parameters.max_tokens),
        model_parameters: preset_parameters
            .model_parameters
            .map(|mut parameters| {
                parameters.extend(
                    request_parameters
                        .model_parameters
                        .take()
                        .unwrap_or_default(),
                );
                parameters
            })
            .or(request_parameters.model_parameters),
        prompt_parameters: request_parameters
            .prompt_parameters
            .or(preset_parameters.prompt_parameters),
    }
}

impl LLMVMCore {
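    /// Send a single generation request to the backend client for the given
    /// model and return the backend's response.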
    pub(super) async fn send_generate_request(
        &self,
        request: BackendGenerationRequest,
        model_description: &ModelDescription,
    ) -> Result<BackendGenerationResponse> {
        let mut clients_guard = self.clients.lock().await;
        let client = self
            .get_client(&mut clients_guard, model_description)
            .await?;
        let resp_future = client.call(BackendRequest::Generation(request));
        drop(clients_guard);
        let resp = resp_future
            .await
            .map_err(|e| CoreError::Protocol(e.into()))?;
        match resp {
            ServiceResponse::Single(response) => match response {
                BackendResponse::Generation(response) => Ok(response),
                _ => Err(CoreError::UnexpectedServiceResponse),
            },
            _ => Err(CoreError::UnexpectedServiceResponse),
        }
    }

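    /// Send a streaming generation request to the backend client for the given
    /// model and return the resulting notification stream.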
    pub(super) async fn send_generate_request_for_stream(
        &self,
        request: BackendGenerationRequest,
        model_description: &ModelDescription,
    ) -> Result<NotificationStream<BackendResponse>> {
        let mut clients_guard = self.clients.lock().await;
        let client = self
            .get_client(&mut clients_guard, model_description)
            .await?;
        let resp_future = client.call(BackendRequest::GenerationStream(request));
        drop(clients_guard);
        let resp = resp_future
            .await
            .map_err(|e| CoreError::Protocol(e.into()))?;
        match resp {
            ServiceResponse::Multiple(stream) => Ok(stream),
            _ => Err(CoreError::UnexpectedServiceResponse),
        }
    }

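    /// Turn a frontend `GenerationRequest` into a `BackendGenerationRequest`:
    /// resolve preset and request parameters, render the prompt, load any
    /// existing thread messages, and compute the messages to persist when the
    /// caller asked for the thread to be saved.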
    pub(super) async fn prepare_for_generate(
        &self,
        request: &GenerationRequest,
    ) -> Result<(
        BackendGenerationRequest,
        ModelDescription,
        Option<Vec<Message>>,
    )> {
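        // Resolve generation parameters: load the preset if one was specified
        // and overlay any parameters supplied directly on the request.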
        let mut parameters = match &request.preset_id {
            Some(preset_id) => {
                let mut parameters = load_preset(&preset_id).await?;
                if let Some(request_parameters) = request.parameters.clone() {
                    parameters = merge_generation_parameters(parameters, request_parameters);
                }
                parameters
            }
            None => request
                .parameters
                .clone()
                .ok_or(CoreError::MissingParameters)?,
        };
        debug!("generation parameters: {:?}", parameters);

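        // Fall back to a default token limit when neither the preset nor the
        // request provides one.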
        if parameters.max_tokens.is_none() {
            parameters.max_tokens = Some(2048);
        }

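        // Parse the model identifier and pull out the prompt parameters.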
        let model = parameters
            .model
            .ok_or(CoreError::MissingParameter("model"))?;
        let model_description =
            ModelDescription::from_str(&model).map_err(|_| CoreError::ModelDescriptionParse)?;
        let is_chat_model = model_description.is_chat_model();
        let prompt_parameters = parameters
            .prompt_parameters
            .unwrap_or(Value::Object(Default::default()));

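        // Render the prompt: a custom template takes precedence, then a stored
        // template, and finally the raw custom prompt from the request.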
        let prompt = match parameters.custom_prompt_template {
            Some(template) => {
                ReadyPrompt::from_custom_template(&template, &prompt_parameters, is_chat_model)?
            }
            None => match parameters.prompt_template_id {
                Some(template_id) => {
                    ReadyPrompt::from_stored_template(
                        &template_id,
                        &prompt_parameters,
                        is_chat_model,
                    )
                    .await?
                }
                None => ReadyPrompt::from_custom_prompt(
                    request
                        .custom_prompt
                        .as_ref()
                        .ok_or(CoreError::TemplateNotFound)?
                        .clone(),
                ),
            },
        };

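        // Load the existing thread (if any) and make the rendered system prompt
        // the sole system message at the front of the thread.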
        let mut thread_messages = match request.existing_thread_id.as_ref() {
            Some(thread_id) => Some(get_thread_messages(thread_id).await?),
            None => None,
        };
        if let Some(content) = prompt.system_prompt {
            let messages = thread_messages.get_or_insert_with(|| Vec::with_capacity(1));
            messages.retain(|message| !matches!(message.role, MessageRole::System));
            messages.insert(
                0,
                Message {
                    role: MessageRole::System,
                    content,
                },
            );
        }

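        // When the caller wants the thread saved, append the new user message
        // to a copy of the thread for persistence.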
        let thread_messages_to_save = match request.save_thread {
            true => {
                let mut clone = thread_messages.clone().unwrap_or_default();
                clone.push(Message {
                    role: MessageRole::User,
                    content: prompt.main_prompt.clone(),
                });
                Some(clone)
            }
            false => None,
        };

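        // Assemble the request that will be sent to the backend.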
        let backend_request = BackendGenerationRequest {
            model,
            prompt: prompt.main_prompt,
            max_tokens: parameters
                .max_tokens
                .ok_or(CoreError::MissingParameter("max_tokens"))?,
            thread_messages,
            model_parameters: parameters.model_parameters,
        };

        info!(
            "Sending backend request with prompt: {}",
            backend_request.prompt
        );
        debug!(
            "Thread messages for requests: {:#?}",
            backend_request.thread_messages
        );
        Ok((backend_request, model_description, thread_messages_to_save))
    }
}