llmvm_core_lib/generation.rs

use std::str::FromStr;

use llmvm_protocol::service::{BackendRequest, BackendResponse};
use llmvm_protocol::tower::Service;
use llmvm_protocol::{
    BackendGenerationRequest, BackendGenerationResponse, GenerationParameters, GenerationRequest,
    Message, MessageRole, ModelDescription, NotificationStream, ServiceResponse,
};
use serde_json::Value;

use tracing::{debug, info};

use crate::error::CoreError;
use crate::presets::load_preset;
use crate::prompts::ReadyPrompt;
use crate::threads::get_thread_messages;
use crate::{LLMVMCore, Result};

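/// Merges generation parameters from a preset with parameters supplied in the
/// request. Request values take precedence; for `model_parameters`, the two
/// maps are merged key-by-key with request entries overriding preset entries.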
fn merge_generation_parameters(
    preset_parameters: GenerationParameters,
    mut request_parameters: GenerationParameters,
) -> GenerationParameters {
    GenerationParameters {
        model: request_parameters.model.or(preset_parameters.model),
        prompt_template_id: request_parameters
            .prompt_template_id
            .or(preset_parameters.prompt_template_id),
        custom_prompt_template: request_parameters
            .custom_prompt_template
            .or(preset_parameters.custom_prompt_template),
        max_tokens: request_parameters
            .max_tokens
            .or(preset_parameters.max_tokens),
        model_parameters: preset_parameters
            .model_parameters
            .map(|mut parameters| {
                parameters.extend(
                    request_parameters
                        .model_parameters
                        .take()
                        .unwrap_or_default(),
                );
                parameters
            })
            .or(request_parameters.model_parameters),
        prompt_parameters: request_parameters
            .prompt_parameters
            .or(preset_parameters.prompt_parameters),
    }
}

impl LLMVMCore {
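    /// Sends a single (non-streaming) generation request to the backend
    /// client responsible for the given model.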
    pub(super) async fn send_generate_request(
        &self,
        request: BackendGenerationRequest,
        model_description: &ModelDescription,
    ) -> Result<BackendGenerationResponse> {
        let mut clients_guard = self.clients.lock().await;
        let client = self
            .get_client(&mut clients_guard, model_description)
            .await?;
        let resp_future = client.call(BackendRequest::Generation(request));
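        // Release the clients lock before awaiting the response so other
        // requests can acquire backend clients concurrently.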
        drop(clients_guard);
        let resp = resp_future
            .await
            .map_err(|e| CoreError::Protocol(e.into()))?;
        match resp {
            ServiceResponse::Single(response) => match response {
                BackendResponse::Generation(response) => Ok(response),
                _ => Err(CoreError::UnexpectedServiceResponse),
            },
            _ => Err(CoreError::UnexpectedServiceResponse),
        }
    }

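    /// Sends a streaming generation request to the backend client responsible
    /// for the given model, returning a stream of response notifications.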
    pub(super) async fn send_generate_request_for_stream(
        &self,
        request: BackendGenerationRequest,
        model_description: &ModelDescription,
    ) -> Result<NotificationStream<BackendResponse>> {
        let mut clients_guard = self.clients.lock().await;
        let client = self
            .get_client(&mut clients_guard, model_description)
            .await?;
        let resp_future = client.call(BackendRequest::GenerationStream(request));
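        // As above, release the lock before awaiting the streamed response.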
        drop(clients_guard);
        let resp = resp_future
            .await
            .map_err(|e| CoreError::Protocol(e.into()))?;
        match resp {
            ServiceResponse::Multiple(stream) => Ok(stream),
            _ => Err(CoreError::UnexpectedServiceResponse),
        }
    }

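    /// Resolves generation parameters, renders the prompt, and assembles the
    /// backend request. Returns the request, the parsed model description,
    /// and, when `save_thread` is set, the thread messages to persist once
    /// generation completes.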
    pub(super) async fn prepare_for_generate(
        &self,
        request: &GenerationRequest,
    ) -> Result<(
        BackendGenerationRequest,
        ModelDescription,
        Option<Vec<Message>>,
    )> {
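        // Resolve parameters: start from the preset when one is given,
        // letting request parameters override preset values.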
        let mut parameters = match &request.preset_id {
            Some(preset_id) => {
                let mut parameters = load_preset(preset_id).await?;
                if let Some(request_parameters) = request.parameters.clone() {
                    parameters = merge_generation_parameters(parameters, request_parameters);
                }
                parameters
            }
            None => request
                .parameters
                .clone()
                .ok_or(CoreError::MissingParameters)?,
        };
        debug!("generation parameters: {:?}", parameters);

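        // Default to 2048 max tokens when neither the preset nor the request
        // specifies a limit.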
        if parameters.max_tokens.is_none() {
            parameters.max_tokens = Some(2048);
        }

        let model = parameters
            .model
            .ok_or(CoreError::MissingParameter("model"))?;
        let model_description =
            ModelDescription::from_str(&model).map_err(|_| CoreError::ModelDescriptionParse)?;
        let is_chat_model = model_description.is_chat_model();
        let prompt_parameters = parameters
            .prompt_parameters
            .unwrap_or(Value::Object(Default::default()));

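        // Build the prompt: a custom template takes precedence, then a stored
        // template, then the raw custom prompt on the request.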
        let prompt = match parameters.custom_prompt_template {
            Some(template) => {
                ReadyPrompt::from_custom_template(&template, &prompt_parameters, is_chat_model)?
            }
            None => match parameters.prompt_template_id {
                Some(template_id) => {
                    ReadyPrompt::from_stored_template(
                        &template_id,
                        &prompt_parameters,
                        is_chat_model,
                    )
                    .await?
                }
                None => ReadyPrompt::from_custom_prompt(
                    request
                        .custom_prompt
                        .as_ref()
                        .ok_or(CoreError::TemplateNotFound)?
                        .clone(),
                ),
            },
        };

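        // Load existing thread messages if a thread was referenced; when a
        // system prompt is present, drop any prior system message and insert
        // the new one at the start of the thread.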
        let mut thread_messages = match request.existing_thread_id.as_ref() {
            Some(thread_id) => Some(get_thread_messages(thread_id).await?),
            None => None,
        };
        if let Some(content) = prompt.system_prompt {
            let messages = thread_messages.get_or_insert_with(|| Vec::with_capacity(1));
            messages.retain(|message| !matches!(message.role, MessageRole::System));
            messages.insert(
                0,
                Message {
                    role: MessageRole::System,
                    content,
                },
            );
        }

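        // When the caller wants the thread saved, snapshot the messages with
        // the new user prompt appended, for persistence after generation.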
        let thread_messages_to_save = if request.save_thread {
            let mut messages = thread_messages.clone().unwrap_or_default();
            messages.push(Message {
                role: MessageRole::User,
                content: prompt.main_prompt.clone(),
            });
            Some(messages)
        } else {
            None
        };

        let backend_request = BackendGenerationRequest {
            model,
            prompt: prompt.main_prompt,
            max_tokens: parameters
                .max_tokens
                .ok_or(CoreError::MissingParameter("max_tokens"))?,
            thread_messages,
            model_parameters: parameters.model_parameters,
        };

        info!(
            "Sending backend request with prompt: {}",
            backend_request.prompt
        );
        debug!(
            "Thread messages for request: {:#?}",
            backend_request.thread_messages
        );
        Ok((backend_request, model_description, thread_messages_to_save))
    }
}