use crate::repl;
use crate::tool_definition;
use crate::utils;
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Arguments the model supplies when invoking the Lua REPL tool.
///
/// Deserialized from the tool-call payload; `Debug`/`Clone` are derived so
/// callers can log or retain the request.
#[derive(Debug, Clone, Deserialize)]
pub struct LuaReplArgs {
    /// Lua source code to evaluate in the shared REPL.
    pub source_code: String,
}
/// Result of a single Lua REPL tool invocation, serialized back to the model.
#[derive(Debug, Clone, Serialize)]
pub struct LuaReplOutput {
    /// Captured output lines joined with `\n` and passed through
    /// `utils::truncate_output`; empty when the REPL evaluation itself failed.
    pub output: String,
    /// Returned values joined with `\n`, or a message prefixed with `error: `.
    pub result: String,
}
/// Tunable limits for the Lua REPL tool.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LuaReplOptions {
    /// Length limit handed to `utils::truncate_output` for both the captured
    /// output and the result text returned to the model.
    pub max_output_len: usize,
}
/// LLM tool that evaluates Lua code in a shared `repl::Repl`.
///
/// Cloning is cheap: every clone points at the same REPL through an `Arc`
/// and copies the (`Copy`) options.
#[derive(Clone)]
pub struct LuaRepl {
    /// REPL instance shared by all clones of this tool.
    repl: Arc<repl::Repl>,
    /// Limits applied to the text handed back to the model.
    options: LuaReplOptions,
}
impl LuaRepl {
const DEFAULT_OPTIONS: LuaReplOptions = LuaReplOptions {
max_output_len: 50_000,
};
pub fn new(repl: repl::Repl) -> Self {
Self::new_with(repl, Self::DEFAULT_OPTIONS)
}
pub fn new_with(repl: repl::Repl, options: LuaReplOptions) -> Self {
Self {
repl: Arc::new(repl),
options,
}
}
}
impl Tool for LuaRepl {
    const NAME: &'static str = tool_definition::NAME;

    // Every failure is reported to the model inside `LuaReplOutput`, so the
    // tool itself can never fail.
    type Error = std::convert::Infallible;
    type Args = LuaReplArgs;
    type Output = LuaReplOutput;

    /// Describes the tool using the shared name, description and JSON schema
    /// from `tool_definition`. The prompt is ignored.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: tool_definition::NAME.to_string(),
            description: tool_definition::DESCRIPTION.to_string(),
            parameters: tool_definition::json_schema(),
        }
    }

    /// Evaluates `args.source_code` in the shared REPL.
    ///
    /// Captured output lines and returned values are each joined with `\n`
    /// and passed through `utils::truncate_output` with `max_output_len`.
    /// A failed REPL evaluation and a Lua-level error are both rendered into
    /// `result` with an `error: ` prefix, also truncated, so the configured
    /// length cap applies on every path. Never returns `Err`.
    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let limit = self.options.max_output_len;

        let eval_outcome = match self.repl.eval(&args.source_code) {
            Ok(outcome) => outcome,
            Err(err) => {
                // Previously this message bypassed truncation, letting
                // `result` exceed `max_output_len` on this path only.
                let message = format!("error: REPL evaluation failed: {err}");
                return Ok(LuaReplOutput {
                    output: String::new(),
                    result: utils::truncate_output(&message, limit),
                });
            }
        };

        let truncated_output = utils::truncate_output(&eval_outcome.output.join("\n"), limit);
        let full_result = match eval_outcome.result {
            Ok(values) => values.join("\n"),
            Err(err) => format!("error: {err}"),
        };

        Ok(LuaReplOutput {
            output: truncated_output,
            result: utils::truncate_output(&full_result, limit),
        })
    }
}