use crate::repl;
use crate::tool_definition;
use crate::utils;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
/// Tunables for [`LuaRepl`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LuaReplOptions {
    /// Length limit handed to the output-truncation helper for each field of a tool
    /// response.
    pub max_output_len: usize,
}
/// Wraps a [`repl::Repl`] so it can be offered to a model as a `mistralrs` tool.
///
/// The REPL sits behind an [`Arc`], so clones of this value share one REPL instance.
#[derive(Clone)]
pub struct LuaRepl {
    /// Shared handle to the REPL that evaluates incoming tool calls.
    repl: Arc<repl::Repl>,
    /// Limits applied when formatting tool responses.
    options: LuaReplOptions,
}
impl LuaRepl {
    /// Options used by [`LuaRepl::new`]: a `max_output_len` of 50 000.
    const DEFAULT_OPTIONS: LuaReplOptions = LuaReplOptions {
        max_output_len: 50_000,
    };

    /// Wraps `repl` using the default options.
    pub fn new(repl: repl::Repl) -> Self {
        let options = Self::DEFAULT_OPTIONS;
        Self::new_with(repl, options)
    }

    /// Wraps `repl` using caller-supplied `options`.
    ///
    /// The REPL is moved behind an `Arc`, so cloning the returned value shares the same
    /// underlying REPL instead of creating a new one.
    pub fn new_with(repl: repl::Repl, options: LuaReplOptions) -> Self {
        let shared = Arc::new(repl);
        Self {
            options,
            repl: shared,
        }
    }
}
impl LuaRepl {
pub fn definition(&self) -> mistralrs::Tool {
let schema = tool_definition::json_schema();
let parameters: HashMap<String, serde_json::Value> =
serde_json::from_value(schema).expect("Failed to parse schema");
mistralrs::Tool {
tp: mistralrs::ToolType::Function,
function: mistralrs::Function {
name: tool_definition::NAME.to_string(),
description: Some(tool_definition::DESCRIPTION.to_string()),
parameters: Some(parameters),
},
}
}
pub fn call(&self, tool_call: &mistralrs::ToolCallResponse) -> String {
if tool_call.function.name != tool_definition::NAME {
return json!({
"error": format!(
"Unknown tool: {}. Expected: {}",
tool_call.function.name,
tool_definition::NAME
)
})
.to_string();
}
let arguments: serde_json::Value = match serde_json::from_str(&tool_call.function.arguments)
{
Ok(args) => args,
Err(err) => {
return json!({
"error": format!("Failed to parse arguments: {}", err)
})
.to_string();
}
};
let source_code = match arguments.get(tool_definition::PARAM_SOURCE_CODE) {
Some(serde_json::Value::String(s)) => s,
_ => {
return json!({
"error": format!(
"Missing or invalid parameter: {}",
tool_definition::PARAM_SOURCE_CODE
)
})
.to_string();
}
};
let eval_outcome = match self.repl.eval(source_code) {
Ok(outcome) => outcome,
Err(err) => {
return json!({
"error": format!("REPL evaluation failed: {}", err)
})
.to_string();
}
};
let truncated_output =
utils::truncate_output(&eval_outcome.output.join("\n"), self.options.max_output_len);
let full_result = match eval_outcome.result {
Ok(values) => values.join("\n"),
Err(err) => format!("error: {}", err),
};
let truncated_result = utils::truncate_output(&full_result, self.options.max_output_len);
json!({
"output": truncated_output,
"result": truncated_result
})
.to_string()
}
}