use serde::{Deserialize, Serialize};
/// Response describing whether a hook's condition was met.
///
/// `reason` is optional and is omitted entirely from the serialized form when
/// it is `None`. `PartialEq`/`Eq` are derived so whole responses can be
/// compared directly (e.g. in `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// NOTE: both field names are single words, so `camelCase` renaming is
// currently a no-op; it only matters if multi-word fields are added.
#[serde(rename_all = "camelCase")]
pub struct HookResponse {
    /// Whether the condition was met.
    pub ok: bool,
    /// Optional explanation; skipped during serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
impl HookResponse {
    /// A successful response with no explanation attached.
    pub fn ok() -> Self {
        Self::from_parts(true, None)
    }

    /// A successful response that still carries an explanatory `reason`.
    pub fn ok_with_reason(reason: impl Into<String>) -> Self {
        Self::from_parts(true, Some(reason.into()))
    }

    /// An unsuccessful response; `reason` explains why.
    pub fn not_ok(reason: impl Into<String>) -> Self {
        Self::from_parts(false, Some(reason.into()))
    }

    /// Assembles the struct literal shared by the public constructors above.
    fn from_parts(ok: bool, reason: Option<String>) -> Self {
        Self { ok, reason }
    }
}
/// JSON Schema describing the serialized shape of `HookResponse`.
///
/// The schema is a closed object (`additionalProperties: false`) with a
/// mandatory boolean `ok` and an optional string `reason`.
pub fn hook_response_json_schema() -> serde_json::Value {
    let ok_property = serde_json::json!({
        "type": "boolean",
        "description": "Whether the condition was met"
    });
    let reason_property = serde_json::json!({
        "type": "string",
        "description": "Reason, if the condition was not met"
    });

    serde_json::json!({
        "type": "object",
        "properties": {
            "ok": ok_property,
            "reason": reason_property
        },
        "required": ["ok"],
        "additionalProperties": false
    })
}
/// Splits a raw argument string into whitespace-separated tokens.
///
/// Any run of Unicode whitespace acts as one separator, and leading or
/// trailing whitespace is ignored, so blank input yields an empty vector.
/// Quotes are not interpreted: `"a b"` produces the two tokens `"a` and `b"`.
pub fn parse_arguments(args: &str) -> Vec<String> {
    args.split_whitespace().map(String::from).collect()
}
/// Keeps only the argument names usable as `$name` placeholders.
///
/// Names that are blank (empty or whitespace-only) or consist solely of ASCII
/// digits are dropped; digit-only names would collide with the `$0`, `$1`, …
/// shorthand. Surviving names are returned unchanged and in their original
/// order. `None` yields an empty vector.
pub fn parse_argument_names(argument_names: Option<Vec<String>>) -> Vec<String> {
    let is_usable = |candidate: &String| {
        let blank = candidate.trim().is_empty();
        let numeric = candidate.chars().all(|c| c.is_ascii_digit());
        !(blank || numeric)
    };
    argument_names
        .unwrap_or_default()
        .into_iter()
        .filter(is_usable)
        .collect()
}
pub fn substitute_arguments(
content: &str,
args: Option<&str>,
append_if_no_placeholder: bool,
argument_names: Vec<String>,
) -> String {
let args = match args {
Some(a) => a,
None => return content.to_string(),
};
let parsed_args = parse_arguments(args);
let original_content = content.to_string();
let mut content = original_content.clone();
for (i, name) in argument_names.iter().enumerate() {
if name.is_empty() {
continue;
}
let needle = format!("${}", name);
let mut search_start = 0;
while let Some(dollar_pos) = content[search_start..].find(&needle) {
let actual_pos = search_start + dollar_pos;
let after_name_pos = actual_pos + needle.len();
let after = content
.get(after_name_pos..after_name_pos + 1)
.unwrap_or("");
if after.is_empty()
|| (!after
.chars()
.next()
.map(|c| c.is_alphanumeric() || c == '_' || c == '[')
.unwrap_or(true))
{
let replacement = parsed_args.get(i).map(|s| s.as_str()).unwrap_or("");
content = format!(
"{}{}{}",
&content[..actual_pos],
replacement,
&content[after_name_pos..]
);
search_start = actual_pos + replacement.len();
} else {
search_start = after_name_pos;
}
}
}
if let Ok(re) = regex::Regex::new(r"\$ARGUMENTS\[(\d+)\]") {
content = re
.replace_all(&content, |caps: ®ex::Captures| {
let index: usize = caps[1].parse().unwrap_or(0);
parsed_args.get(index).map(|s| s.as_str()).unwrap_or("")
})
.to_string();
}
let mut result = content.to_string();
let mut search_start = 0;
while let Some(dollar_pos) = result[search_start..].find('$') {
let actual_pos = search_start + dollar_pos;
if let Some(ch) = result.chars().nth(actual_pos + 1) {
if ch.is_ascii_digit() {
let rest = &result[actual_pos + 1..];
let num_chars: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect();
if let Ok(index) = num_chars.parse::<usize>() {
let after_num_pos = actual_pos + 1 + num_chars.len();
let after = result.get(after_num_pos..after_num_pos + 1).unwrap_or("");
if after.is_empty()
|| (!after
.chars()
.next()
.map(|c| c.is_alphanumeric())
.unwrap_or(true))
{
let replacement = parsed_args.get(index).map(|s| s.as_str()).unwrap_or("");
result = format!(
"{}{}{}",
&result[..actual_pos],
replacement,
&result[after_num_pos..]
);
search_start = actual_pos + replacement.len();
continue;
}
}
}
}
search_start = actual_pos + 1;
}
content = result;
content = content.replace("$ARGUMENTS", args);
if content == original_content && append_if_no_placeholder && !args.is_empty() {
content = format!("{}\n\nARGUMENTS: {}", content, args);
}
content
}
/// Expands placeholders in `prompt` using `json_input` as the argument string.
///
/// Delegates to `substitute_arguments` with appending enabled and no named
/// arguments. If `prompt` is left unchanged and `json_input` is non-empty,
/// `json_input` is appended after an `ARGUMENTS:` label.
pub fn add_arguments_to_prompt(prompt: &str, json_input: &str) -> String {
    let no_names: Vec<String> = Vec::new();
    substitute_arguments(prompt, Some(json_input), true, no_names)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hook_response_ok() {
        let resp = HookResponse::ok();
        assert!(resp.ok);
        assert!(resp.reason.is_none());
    }

    #[test]
    fn test_hook_response_ok_with_reason() {
        let resp = HookResponse::ok_with_reason("success");
        assert!(resp.ok);
        assert_eq!(resp.reason, Some("success".to_string()));
    }

    #[test]
    fn test_hook_response_not_ok() {
        let resp = HookResponse::not_ok("failed");
        assert!(!resp.ok);
        assert_eq!(resp.reason, Some("failed".to_string()));
    }

    #[test]
    fn test_hook_response_json_schema_shape() {
        let schema = hook_response_json_schema();
        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["ok"]["type"], "boolean");
        assert_eq!(schema["properties"]["reason"]["type"], "string");
        assert_eq!(schema["required"][0], "ok");
        assert_eq!(schema["additionalProperties"], false);
    }

    #[test]
    fn test_parse_arguments() {
        assert_eq!(parse_arguments("foo bar baz"), vec!["foo", "bar", "baz"]);
        assert_eq!(parse_arguments(""), Vec::<&str>::new());
        assert_eq!(parse_arguments(" "), Vec::<&str>::new());
    }

    #[test]
    fn test_parse_argument_names() {
        assert_eq!(
            parse_argument_names(Some(vec!["foo".to_string(), "bar".to_string()])),
            vec!["foo", "bar"]
        );
        assert_eq!(parse_argument_names(None), Vec::<&str>::new());
        assert_eq!(
            parse_argument_names(Some(vec!["0".to_string(), "foo".to_string()])),
            vec!["foo"]
        );
    }

    #[test]
    fn test_substitute_arguments_basic() {
        let result = substitute_arguments("Hello $ARGUMENTS", Some("world"), true, vec![]);
        assert_eq!(result, "Hello world");
    }

    #[test]
    fn test_substitute_arguments_indexed() {
        let result =
            substitute_arguments("$ARGUMENTS[0] $ARGUMENTS[1]", Some("foo bar"), true, vec![]);
        assert_eq!(result, "foo bar");
    }

    #[test]
    fn test_substitute_arguments_shorthand() {
        let result = substitute_arguments("$0 $1", Some("foo bar"), true, vec![]);
        assert_eq!(result, "foo bar");
    }

    #[test]
    fn test_substitute_arguments_named() {
        let result = substitute_arguments(
            "$greeting $name",
            Some("hello world"),
            true,
            vec!["greeting".to_string(), "name".to_string()],
        );
        assert_eq!(result, "hello world");
    }

    #[test]
    fn test_substitute_arguments_named_word_boundary() {
        let result = substitute_arguments(
            "$name $names $name_x $name[0]",
            Some("bob"),
            true,
            vec!["name".to_string()],
        );
        assert_eq!(result, "bob $names $name_x $name[0]");
    }

    #[test]
    fn test_substitute_arguments_missing_index_is_empty() {
        let result =
            substitute_arguments("[$0][$1][$ARGUMENTS[2]]", Some("only"), true, vec![]);
        assert_eq!(result, "[only][][]");
    }

    #[test]
    fn test_substitute_arguments_no_args() {
        let result = substitute_arguments("Hello $ARGUMENTS", None, true, vec![]);
        assert_eq!(result, "Hello $ARGUMENTS");
    }

    #[test]
    fn test_substitute_arguments_append() {
        let result = substitute_arguments("Hello", Some("world"), true, vec![]);
        assert_eq!(result, "Hello\n\nARGUMENTS: world");
    }

    #[test]
    fn test_substitute_arguments_without_append() {
        assert_eq!(substitute_arguments("Hello", Some("world"), false, vec![]), "Hello");
        assert_eq!(substitute_arguments("Hello", Some(""), true, vec![]), "Hello");
    }

    #[test]
    fn test_substitute_arguments_trailing_dollar() {
        assert_eq!(
            substitute_arguments("cost $", Some("x"), true, vec![]),
            "cost $\n\nARGUMENTS: x"
        );
    }

    #[test]
    fn test_add_arguments_to_prompt() {
        let result = add_arguments_to_prompt("Run: $ARGUMENTS", "ls -la");
        assert_eq!(result, "Run: ls -la");
    }
}