use glob::glob;
use log::debug;
use std::fs;
use crate::{
config::{Message, Prompt, PLACEHOLDER_TOKEN},
PromptParams,
};
pub fn customize_prompt(mut prompt: Prompt, prompt_params: &PromptParams) -> Prompt {
debug!("pre-customization prompt {:?}", prompt);
if let Some(api) = prompt_params.api.clone() {
prompt.api = api.to_owned();
}
if prompt_params.model.is_some() {
prompt.model = prompt_params.model.to_owned();
}
let mut first_user_message_index = prompt
.messages
.iter()
.position(|m| m.role == "user")
.unwrap_or(0);
let mut maybe_insert_message = |content: Option<String>, prefix: Option<String>| {
if let Some(mut content) = content {
if let Some(mut pre) = prefix {
pre.push_str(&content);
content = pre;
}
let system_message = Message::system(&content);
prompt
.messages
.insert(first_user_message_index, system_message);
first_user_message_index += 1;
}
};
maybe_insert_message(prompt_params.system_message.clone(), None);
let context = prompt_params.context.clone().and_then(|glob_pattern| {
let files_content = glob(&glob_pattern)
.expect("Failed to read glob pattern")
.filter_map(Result::ok)
.filter_map(|path| {
fs::read_to_string(&path)
.ok()
.map(|content| format!("{}:\n```\n{}\n```\n", path.display(), content))
})
.collect::<String>();
(!files_content.is_empty()).then_some(files_content)
});
maybe_insert_message(context, Some("files content for context:\n\n".to_owned()));
if let Some(command_text) = prompt_params.custom_prompt.clone() {
let mut prompt_message = command_text;
if !prompt_message.contains(PLACEHOLDER_TOKEN) {
prompt_message.push_str(PLACEHOLDER_TOKEN);
}
for message in prompt.messages.iter_mut() {
message.content = message.content.replace(PLACEHOLDER_TOKEN, "");
}
prompt.messages.push(Message::user(&prompt_message));
}
let mut last_message =
if prompt.messages.is_empty() | prompt.messages.last().is_some_and(|m| m.role != "user") {
Message::user(PLACEHOLDER_TOKEN)
} else {
prompt.messages.pop().unwrap()
};
if !last_message.content.contains(PLACEHOLDER_TOKEN) {
last_message.content.push_str(PLACEHOLDER_TOKEN);
}
if let Some(after_input_text) = prompt_params.after_input.clone() {
last_message.content.push_str(&after_input_text);
}
if prompt_params.temperature.is_some() {
prompt.temperature = prompt_params.temperature;
}
prompt.messages.push(last_message);
debug!("post-customization prompt {:?}", prompt);
prompt
}
#[cfg(test)]
mod tests {
use crate::config::Api;
use std::io::Write;
use super::*;
#[test]
fn test_customize_prompt_empty_no_overrides() {
let prompt = Prompt::default();
let prompt_params = PromptParams::default();
let customized = customize_prompt(prompt, &prompt_params);
let default_prompt = Prompt::empty();
assert_eq!(customized.api, default_prompt.api);
assert_eq!(customized.model, default_prompt.model);
assert_eq!(customized.temperature, default_prompt.temperature);
assert_eq!(
customized.messages,
vec![
Prompt::default().messages.first().unwrap().to_owned(),
Message::user(PLACEHOLDER_TOKEN)
]
);
assert!(
customized
.messages
.last()
.unwrap()
.content
.contains(PLACEHOLDER_TOKEN),
"The last message should contain the placeholder."
);
}
#[test]
fn test_customize_prompt_api_override() {
let prompt = Prompt::empty();
let prompt_params = PromptParams {
api: Some(Api::AnotherApiForTests),
..PromptParams::default()
};
let customized = customize_prompt(prompt, &prompt_params);
let default_prompt = Prompt::empty();
assert_eq!(customized.api, Api::AnotherApiForTests);
assert_eq!(customized.model, default_prompt.model);
}
#[test]
fn test_customize_prompt_model_override() {
let prompt = Prompt::empty();
let prompt_params = PromptParams {
model: Some("test_model".to_owned()),
..PromptParams::default()
};
let customized = customize_prompt(prompt, &prompt_params);
let default_prompt = Prompt::empty();
assert_eq!(customized.model, prompt_params.model);
assert_eq!(customized.api, default_prompt.api);
}
#[test]
fn test_customize_prompt_command_insertion() {
let prompt = Prompt::empty();
let prompt_params = PromptParams {
custom_prompt: Some("test_command".to_owned()),
..PromptParams::default()
};
let customized = customize_prompt(prompt, &prompt_params);
assert!(customized
.messages
.iter()
.any(|m| m.content.contains("test_command")));
}
#[test]
fn test_customize_prompt_system_message_insertion() {
let prompt = Prompt::empty();
let prompt_params = PromptParams {
system_message: Some("system message".to_owned()),
..PromptParams::default()
};
let customized = customize_prompt(prompt, &prompt_params);
assert_eq!(
customized.messages[0].content,
prompt_params.system_message.unwrap(),
"{:?}",
customized.messages
);
assert_eq!(
customized.messages[0].role,
Message::system("").role,
"{:?}",
customized.messages
);
}
#[test]
fn test_customize_prompt_system_message_insertion_with_user_message() {
let mut prompt = Prompt::empty();
prompt.messages.push(Message::user("user message"));
let prompt_params = PromptParams {
system_message: Some("system message".to_owned()),
..PromptParams::default()
};
let customized = customize_prompt(prompt, &prompt_params);
assert_eq!(
customized.messages[0].content,
prompt_params.system_message.unwrap(),
"{:?}",
customized.messages
);
assert_eq!(
customized.messages[0].role,
Message::system("").role,
"{:?}",
customized.messages
);
}
#[test]
fn test_customize_prompt_with_context_file() {
let prompt = Prompt::empty();
let context_content = "hello there".to_owned();
let mut context_file = tempfile::NamedTempFile::new().unwrap();
context_file.write_all(context_content.as_bytes()).unwrap();
let prompt_params = PromptParams {
context: Some(context_file.path().to_str().unwrap().to_owned()),
..PromptParams::default()
};
let customized = customize_prompt(prompt, &prompt_params);
assert_eq!(
customized.messages[0].content,
format!(
"files content for context:\n\n{}:\n```\n{}\n```\n",
context_file.path().display(),
context_content
)
);
assert_eq!(customized.messages[0].role, "system");
}
#[test]
fn test_customize_prompt_temperature_override() {
let prompt = Prompt::empty();
let prompt_params = PromptParams {
temperature: Some(42.),
..PromptParams::default()
};
let customized = customize_prompt(prompt, &prompt_params);
assert_eq!(customized.temperature, Some(42.));
}
#[test]
fn test_customize_prompt_after_input_insertion() {
let mut prompt = Prompt::empty();
prompt
.messages
.push(Message::user(&format!("command {}", PLACEHOLDER_TOKEN)));
let prompt_params = PromptParams {
after_input: Some("-- after input".to_owned()),
..PromptParams::default()
};
let customized = customize_prompt(prompt, &prompt_params);
let last_message_content = &customized.messages.last().unwrap().content;
assert!(
last_message_content.ends_with("-- after input"),
"The last message should end with the after input text. Got {}",
&last_message_content
)
}
#[test]
fn test_customize_prompt_with_all_overrides() {
let prompt = Prompt::empty();
let context_content = "hello there".to_owned();
let mut context_file = tempfile::NamedTempFile::new().unwrap();
context_file.write_all(context_content.as_bytes()).unwrap();
let promt_params = PromptParams {
api: Some(Api::AnotherApiForTests),
model: Some("test_model_override".to_owned()),
custom_prompt: Some("test_command_override".to_owned()),
context: Some(context_file.path().to_str().unwrap().to_owned()),
after_input: Some(" test_after_input_override".to_owned()),
system_message: Some("system message override".to_owned()),
temperature: Some(42.),
};
let customized = customize_prompt(prompt, &promt_params);
assert_eq!(customized.api, promt_params.api.unwrap());
assert!(customized.messages.iter().any(|m| m
.content
.contains(promt_params.custom_prompt.as_ref().unwrap())));
assert_eq!(
customized.messages[0].content,
promt_params.system_message.unwrap()
);
assert_eq!(customized.messages[0].role, "system");
assert_eq!(customized.model, promt_params.model);
assert_eq!(customized.temperature, promt_params.temperature);
assert_eq!(
customized.messages[1].content,
format!(
"files content for context:\n\n{}:\n```\n{}\n```\n",
context_file.path().display(),
context_content
)
);
assert_eq!(customized.messages[1].role, "system");
assert!(
customized
.messages
.last()
.unwrap()
.content
.ends_with(&promt_params.after_input.unwrap()),
"The last message should end with the after input text."
);
}
}