use lash_core::{ToolDefinition, ToolFailure, ToolFailureClass, ToolResult};
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Deserializer, Serialize};
use std::future::Future;
use std::path::{Component, Path, PathBuf};
mod static_provider;
#[cfg(feature = "lashlang")]
pub use lash_lashlang_runtime::LashlangToolBinding;
pub use static_provider::{StaticToolExecute, StaticToolProvider};
#[cfg(not(feature = "lashlang"))]
#[derive(Clone, Debug, Default)]
pub struct LashlangToolBinding;
#[cfg(not(feature = "lashlang"))]
impl LashlangToolBinding {
pub fn new(
module_path: impl IntoIterator<Item = impl Into<String>>,
operation: impl Into<String>,
) -> Self {
let _ = module_path
.into_iter()
.map(Into::into)
.collect::<Vec<String>>();
let _ = operation.into();
Self
}
pub fn with_authority_type(self, authority_type: impl Into<String>) -> Self {
let _ = authority_type.into();
self
}
pub fn with_aliases(self, aliases: impl IntoIterator<Item = impl Into<String>>) -> Self {
let _ = aliases.into_iter().map(Into::into).collect::<Vec<String>>();
self
}
}
pub trait ToolDefinitionLashlangExt {
fn with_lashlang_binding(self, lashlang_binding: LashlangToolBinding) -> Self;
}
#[cfg(feature = "lashlang")]
impl ToolDefinitionLashlangExt for ToolDefinition {
fn with_lashlang_binding(self, lashlang_binding: LashlangToolBinding) -> Self {
lash_lashlang_runtime::ToolDefinitionLashlangExt::with_lashlang_binding(
self,
lashlang_binding,
)
}
}
#[cfg(not(feature = "lashlang"))]
impl ToolDefinitionLashlangExt for ToolDefinition {
fn with_lashlang_binding(self, _lashlang_binding: LashlangToolBinding) -> Self {
self
}
}
pub fn resolve_under(base: &Path, path: &Path) -> PathBuf {
let joined = if path.is_absolute() {
path.to_path_buf()
} else {
base.join(path)
};
normalize_lexical(&joined)
}
pub fn normalize_lexical(path: &Path) -> PathBuf {
let mut normalized = PathBuf::new();
for component in path.components() {
match component {
Component::CurDir => {}
Component::ParentDir => {
if !normalized.pop() {
normalized.push(component.as_os_str());
}
}
Component::Prefix(_) | Component::RootDir | Component::Normal(_) => {
normalized.push(component.as_os_str());
}
}
}
normalized
}
pub fn canonicalize_under(base: &Path, path: &Path) -> std::io::Result<PathBuf> {
std::fs::canonicalize(resolve_under(base, path))
}
pub fn display_relative(base: &Path, path: &Path) -> String {
let display = path
.strip_prefix(base)
.unwrap_or(path)
.display()
.to_string();
let display = if display.is_empty() {
path.file_name()
.and_then(|name| name.to_str())
.unwrap_or(".")
.to_string()
} else {
display
};
display.replace('\\', "/")
}
pub const FS_DEFAULTS_PREAMBLE: &str = "By default this excludes hidden entries, `.git`, and `node_modules`, and respects ignore files.";
#[derive(Clone, Debug, Serialize, JsonSchema)]
pub struct TruncationMeta {
pub shown: usize,
pub total: usize,
pub omitted: usize,
}
pub fn invalid_tool_args(message: impl Into<String>) -> ToolResult {
ToolResult::failure(ToolFailure::tool(
ToolFailureClass::InvalidRequest,
"invalid_tool_args",
message.into(),
))
}
pub fn typed_tool_args<Args>(args: &serde_json::Value) -> Result<Args, ToolResult>
where
Args: DeserializeOwned + JsonSchema,
{
serde_json::from_value(args.clone())
.map_err(|err| invalid_tool_args(format!("Invalid tool arguments: {err}")))
}
pub fn typed_tool_ok<Output>(output: Output) -> ToolResult
where
Output: Serialize + JsonSchema,
{
match serde_json::to_value(output) {
Ok(value) => ToolResult::ok(value),
Err(err) => ToolResult::err_fmt(format_args!("Failed to serialize tool result: {err}")),
}
}
pub async fn execute_typed_tool<Args, Output, F, Fut>(
args: &serde_json::Value,
execute: F,
) -> ToolResult
where
Args: DeserializeOwned + JsonSchema,
Output: Serialize + JsonSchema,
F: FnOnce(Args) -> Fut,
Fut: Future<Output = Result<Output, ToolResult>>,
{
let args = match typed_tool_args::<Args>(args) {
Ok(args) => args,
Err(err) => return err,
};
match execute(args).await {
Ok(output) => typed_tool_ok(output),
Err(err) => err,
}
}
pub async fn execute_typed_tool_result<Args, F, Fut>(
args: &serde_json::Value,
execute: F,
) -> ToolResult
where
Args: DeserializeOwned + JsonSchema,
F: FnOnce(Args) -> Fut,
Fut: Future<Output = ToolResult>,
{
let args = match typed_tool_args::<Args>(args) {
Ok(args) => args,
Err(err) => return err,
};
execute(args).await
}
pub fn non_empty_string(value: &str, key: &str) -> Result<(), ToolResult> {
if value.is_empty() {
Err(invalid_tool_args(format!(
"Missing required parameter: {key}"
)))
} else {
Ok(())
}
}
pub fn default_path_dot() -> String {
".".to_string()
}
#[derive(Clone, Debug, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum OptionalUsizeArg {
Value(usize),
NoneString(String),
Null(()),
}
impl OptionalUsizeArg {
pub fn into_option(self, key: &str, min: usize) -> Result<Option<usize>, ToolResult> {
match self {
Self::Value(value) if value >= min => Ok(Some(value)),
Self::Value(_) => Err(invalid_tool_args(format!(
"Invalid {key}: must be >= {min}, or use null/\"none\" for no cap"
))),
Self::NoneString(value) if value.eq_ignore_ascii_case("none") => Ok(None),
Self::NoneString(_) => Err(invalid_tool_args(format!(
"Invalid {key}: expected int, null, or \"none\""
))),
Self::Null(()) => Ok(None),
}
}
}
pub fn deserialize_optional_usize_none<'de, D>(deserializer: D) -> Result<Option<usize>, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum OptionalUsize {
Int(usize),
String(String),
Null,
}
match Option::<OptionalUsize>::deserialize(deserializer)? {
None | Some(OptionalUsize::Null) => Ok(None),
Some(OptionalUsize::Int(value)) => Ok(Some(value)),
Some(OptionalUsize::String(value)) if value.eq_ignore_ascii_case("none") => Ok(None),
Some(OptionalUsize::String(_)) => Err(serde::de::Error::custom(
"expected integer, null, or \"none\"",
)),
}
}
pub fn default_glob_limit() -> OptionalUsizeArg {
OptionalUsizeArg::Value(100)
}
pub fn require_str<'a>(args: &'a serde_json::Value, key: &str) -> Result<&'a str, ToolResult> {
args.get(key)
.and_then(|v| v.as_str())
.filter(|s| !s.is_empty())
.ok_or_else(|| ToolResult::err_fmt(format_args!("Missing required parameter: {key}")))
}
pub fn parse_optional_bool(
args: &serde_json::Value,
key: &str,
default: bool,
) -> Result<bool, ToolResult> {
match args.get(key) {
None => Ok(default),
Some(v) if v.is_null() => Ok(default),
Some(v) => match v.as_bool() {
Some(b) => Ok(b),
None => Err(ToolResult::err_fmt(format_args!(
"Invalid {key}: expected bool"
))),
},
}
}
pub fn parse_optional_usize_arg(
args: &serde_json::Value,
key: &str,
default: Option<usize>,
allow_none: bool,
min: usize,
) -> Result<Option<usize>, ToolResult> {
match args.get(key) {
None => Ok(default),
Some(v) if v.is_null() => {
if allow_none {
Ok(None)
} else {
Err(ToolResult::err_fmt(format_args!(
"Invalid {key}: expected int >= {min}"
)))
}
}
Some(v) => {
if let Some(s) = v.as_str() {
if allow_none && s.eq_ignore_ascii_case("none") {
return Ok(None);
}
return Err(ToolResult::err_fmt(format_args!(
"Invalid {key}: expected int{}",
if allow_none {
", null, or \"none\""
} else {
""
}
)));
}
let n = v.as_u64().ok_or_else(|| {
ToolResult::err_fmt(format_args!(
"Invalid {key}: expected int{}",
if allow_none {
", null, or \"none\""
} else {
""
}
))
})? as usize;
if n < min {
return Err(ToolResult::err_fmt(format_args!(
"Invalid {key}: must be >= {min}{}",
if allow_none {
", or use null/\"none\" for no cap"
} else {
""
}
)));
}
Ok(Some(n))
}
}
}
pub fn object_schema(properties: serde_json::Value, required: &[&str]) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": false,
})
}
pub fn lashlang_binding(
module_path: impl IntoIterator<Item = impl Into<String>>,
operation: impl Into<String>,
aliases: &[&str],
) -> LashlangToolBinding {
LashlangToolBinding::new(module_path, operation).with_aliases(aliases.iter().copied())
}
pub async fn run_blocking<F>(f: F) -> ToolResult
where
F: FnOnce() -> ToolResult + Send + 'static,
{
match tokio::task::spawn_blocking(f).await {
Ok(result) => result,
Err(e) => ToolResult::err_fmt(format_args!("blocking task failed: {e}")),
}
}
pub async fn run_blocking_value<F, T>(f: F) -> Result<T, String>
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
tokio::task::spawn_blocking(f)
.await
.map_err(|err| format!("blocking task failed: {err}"))
}
pub fn rg_file_list(
base: &Path,
show_hidden_entries: bool,
respect_ignore_files: bool,
max_depth: Option<usize>,
globs: &[String],
) -> Result<Vec<PathBuf>, ToolResult> {
if is_default_excluded_entry(base) {
return Ok(Vec::new());
}
let mut builder = ignore::WalkBuilder::new(base);
builder
.hidden(!show_hidden_entries)
.max_depth(max_depth)
.filter_entry(|entry| !is_default_excluded_entry(entry.path()));
if respect_ignore_files {
builder.git_ignore(true).git_exclude(true).git_global(true);
builder.require_git(true);
} else {
builder
.git_ignore(false)
.git_exclude(false)
.git_global(false)
.ignore(false)
.parents(false)
.require_git(false);
}
if !globs.is_empty() {
let mut override_builder = ignore::overrides::OverrideBuilder::new(base);
for glob in globs {
override_builder.add(glob).map_err(|err| {
ToolResult::err_fmt(format_args!(
"invalid ignore glob for {}: {err}",
base.display()
))
})?;
}
let overrides = override_builder.build().map_err(|err| {
ToolResult::err_fmt(format_args!(
"failed to build ignore globs for {}: {err}",
base.display()
))
})?;
builder.overrides(overrides);
}
let files = builder
.build()
.filter_map(Result::ok)
.filter(|entry| entry.path() != base)
.filter(|entry| !is_default_excluded_entry(entry.path()))
.map(ignore::DirEntry::into_path)
.collect();
Ok(files)
}
fn is_default_excluded_entry(path: &Path) -> bool {
path.file_name().is_some_and(|name| {
let name = name.to_string_lossy();
matches!(name.as_ref(), ".git" | "node_modules")
})
}
pub fn compact_diff(old: &str, new: &str, path: &str, max_lines: usize) -> String {
let diff = similar::TextDiff::from_lines(old, new);
let unified = diff
.unified_diff()
.header(&format!("a/{path}"), &format!("b/{path}"))
.to_string();
if unified.is_empty() {
return String::new();
}
let lines: Vec<&str> = unified.lines().collect();
if lines.len() <= max_lines {
unified
} else {
let mut truncated: String = lines[..max_lines].join("\n");
truncated.push_str(&format!("\n... ({} more lines)", lines.len() - max_lines));
truncated
}
}