use std::{
any::{Any, TypeId},
marker::PhantomData,
path::PathBuf,
sync::Arc,
};
use serde::de::DeserializeOwned;
use crate::{NoMeta, ToolCollection, ToolError, collect_inventory_inner};
use crate::ffi::{Language, leak_string, load_language};
/// Holds the supertrait that seals [`BuilderState`].
///
/// The module is private, so `Sealed` cannot be named outside this module
/// and downstream code cannot add new builder states.
mod sealed {
    /// Supertrait that only types in this module implement.
    pub trait Sealed {}
}

/// Marker trait for the typestates a [`ToolsBuilder`] can be in.
///
/// It is sealed: only the marker types declared in this module implement it.
pub trait BuilderState: sealed::Sealed {}
/// Starting state: no context and no scripting language chosen yet.
pub struct Blank;
impl sealed::Sealed for Blank {}
impl BuilderState for Blank {}

/// State reached through `with_context`: inventory tools plus a shared context.
pub struct Native;
impl sealed::Sealed for Native {}
impl BuilderState for Native {}

/// State reached through `with_language`: inventory tools plus script files.
pub struct Scripted;
impl sealed::Sealed for Scripted {}
impl BuilderState for Scripted {}
/// Configuration shared by every [`ToolsBuilder`] typestate.
///
/// Which fields hold data depends on the state:
/// - `Native` fills the three `ctx*` fields (see `with_context`).
/// - `Scripted` fills `language` and accumulates `script_paths`.
/// - `Blank` leaves everything empty.
struct BuilderInner {
    /// Type-erased shared context handed to the inventory collector.
    ctx: Option<Arc<dyn Any + Send + Sync>>,
    /// `TypeId` of the concrete context type, captured before erasure.
    ctx_type_id: Option<TypeId>,
    /// `std::any::type_name` of the context type; `""` when no context is set.
    ctx_type_name: &'static str,
    /// Script language chosen via `with_language`.
    language: Option<Language>,
    /// Script files queued via `from_path`, in insertion order.
    script_paths: Vec<PathBuf>,
}
impl BuilderInner {
fn empty() -> Self {
Self {
ctx: None,
ctx_type_id: None,
ctx_type_name: "",
language: None,
script_paths: Vec::new(),
}
}
}
/// Typestate builder that assembles a [`ToolCollection`].
///
/// `S` tracks the configuration stage and decides which methods exist:
///
/// - [`Blank`]: the starting state. `collect` gathers inventory tools with no context.
/// - [`Native`]: reached through `with_context`. `collect` passes the context along.
/// - [`Scripted`]: reached through `with_language`. `from_path` queues script
///   files that `collect` loads on top of the inventory tools.
///
/// `M` is the per-tool metadata type of the resulting collection. It defaults
/// to [`NoMeta`] and can be changed with `with_meta` in any state.
#[must_use = "a ToolsBuilder does nothing until `collect` is called"]
pub struct ToolsBuilder<S: BuilderState = Blank, M = NoMeta> {
    /// Accumulated configuration; which fields are populated depends on `S`.
    inner: BuilderInner,
    // `fn() -> (S, M)` records `S` and `M` purely at the type level, without
    // storing or owning values of them. `S` and `M` therefore do not affect
    // auto traits such as `Send`/`Sync` or drop checking.
    _marker: PhantomData<fn() -> (S, M)>,
}
impl ToolsBuilder<Blank, NoMeta> {
pub fn new() -> Self {
Self {
inner: BuilderInner::empty(),
_marker: PhantomData,
}
}
}
impl Default for ToolsBuilder<Blank, NoMeta> {
fn default() -> Self {
Self::new()
}
}
impl<M> ToolsBuilder<Blank, M> {
pub fn with_context<T: Send + Sync + 'static>(
self,
ctx: Arc<T>,
) -> ToolsBuilder<Native, M> {
ToolsBuilder {
inner: BuilderInner {
ctx: Some(ctx),
ctx_type_id: Some(TypeId::of::<T>()),
ctx_type_name: std::any::type_name::<T>(),
language: None,
script_paths: Vec::new(),
},
_marker: PhantomData,
}
}
pub fn with_language(self, lang: Language) -> ToolsBuilder<Scripted, M> {
ToolsBuilder {
inner: BuilderInner {
language: Some(lang),
script_paths: Vec::new(),
ctx: None,
ctx_type_id: None,
ctx_type_name: "",
},
_marker: PhantomData,
}
}
}
impl<S: BuilderState, M> ToolsBuilder<S, M> {
    /// Re-types the builder so the resulting collection uses `M2` as its
    /// per-tool metadata type.
    ///
    /// The current state `S` and all configuration gathered so far are kept.
    /// For scripted tools, each definition's `meta` value is deserialized
    /// into this type during `collect`.
    pub fn with_meta<M2>(self) -> ToolsBuilder<S, M2> {
        let Self { inner, .. } = self;
        ToolsBuilder {
            inner,
            _marker: PhantomData,
        }
    }
}
impl<M: DeserializeOwned> ToolsBuilder<Blank, M> {
    /// Collects only the inventory-registered tools, with no context.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `collect_inventory_inner`.
    pub fn collect(self) -> Result<ToolCollection<M>, ToolError> {
        // A Blank builder never holds a context, so the collector is told
        // there is none: no value, no `TypeId`, and an empty type name.
        let no_type_name = "";
        collect_inventory_inner(None, None, no_type_name)
    }
}
impl<M: DeserializeOwned> ToolsBuilder<Native, M> {
pub fn collect(self) -> Result<ToolCollection<M>, ToolError> {
collect_inventory_inner(
self.inner.ctx,
self.inner.ctx_type_id,
self.inner.ctx_type_name,
)
}
}
impl<M> ToolsBuilder<Scripted, M> {
    /// Queues a script file for `collect` to load.
    ///
    /// Can be chained any number of times; paths are loaded in the order
    /// they were added.
    pub fn from_path(mut self, path: impl Into<PathBuf>) -> Self {
        let script: PathBuf = path.into();
        self.inner.script_paths.push(script);
        self
    }
}
impl<M: DeserializeOwned> ToolsBuilder<Scripted, M> {
    /// Collects the inventory-registered tools, then adds every tool defined
    /// by the scripts queued with `from_path`.
    ///
    /// Scripts are loaded with the language chosen in `with_language`, in the
    /// order their paths were added. All tools end up in the same collection.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any of these steps:
    /// - `collect_inventory_inner`;
    /// - `load_language` for any queued path;
    /// - deserializing a definition's `meta` into `M`, reported as
    ///   `ToolError::BadMeta` naming the tool;
    /// - `register_raw`.
    ///
    /// On error, the partially built collection is dropped.
    ///
    /// # Panics
    ///
    /// Panics if no language is set. The public API cannot reach this:
    /// `with_language` is the only way into the `Scripted` state, and it
    /// always sets one.
    #[cfg_attr(
        not(any(feature = "python", feature = "lua", feature = "js")),
        allow(unreachable_code, unused_variables)
    )]
    pub fn collect(self) -> Result<ToolCollection<M>, ToolError> {
        let lang = self
            .inner
            .language
            .expect("Scripted state must have a language set");
        let mut collection: ToolCollection<M> = collect_inventory_inner(None, None, "")?;
        for path in &self.inner.script_paths {
            for def in load_language(lang, path)? {
                // Validate the metadata before leaking any strings. This
                // assumes `leak_string` never frees its result, as its name
                // suggests. The name is still leaked on failure because
                // `BadMeta` stores it.
                let meta: M = match serde_json::from_value(def.meta) {
                    Ok(meta) => meta,
                    Err(e) => {
                        return Err(ToolError::BadMeta {
                            tool: leak_string(def.name),
                            error: e.to_string(),
                        });
                    }
                };
                let name = leak_string(def.name);
                let desc = leak_string(def.description);
                let func = def.func;
                collection.register_raw(name, desc, def.parameters, move |v| func(v), meta)?;
            }
        }
        Ok(collection)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serde_json::json;

    /// A fresh builder can collect straight away and serialize its declarations.
    #[test]
    fn blank_collect_returns_collection() {
        let tools: ToolCollection = ToolsBuilder::new().collect().unwrap();
        let _ = tools.json().unwrap();
    }

    /// `with_meta` changes the metadata type of the collection `collect` returns.
    #[test]
    fn with_meta_changes_type() {
        #[derive(Debug, Default, Deserialize)]
        #[serde(default)]
        struct Policy {
            _flag: bool,
        }
        // The explicit annotation is what makes this a test of the type
        // change: it fails to compile if `with_meta` stops propagating `M2`.
        let tools: ToolCollection<Policy> = ToolsBuilder::new()
            .with_meta::<Policy>()
            .collect()
            .unwrap();
        let _ = tools.json().unwrap();
    }

    /// Supplying a context moves to the Native state, which still collects.
    #[test]
    fn with_context_then_collect() {
        let ctx = Arc::new(42_u32);
        let tools: ToolCollection = ToolsBuilder::new()
            .with_context(ctx)
            .collect()
            .unwrap();
        let _ = tools.json().unwrap();
    }

    /// A manually registered tool shows up in the JSON declarations.
    #[test]
    fn register_raw_works() {
        let mut tools: ToolCollection = ToolsBuilder::new().collect().unwrap();
        tools
            .register_raw(
                "echo",
                "Echoes input back",
                json!({
                    "type": "object",
                    "properties": {
                        "msg": { "type": "string" }
                    },
                    "required": ["msg"]
                }),
                |v| {
                    Box::pin(async move {
                        let msg = v.get("msg").and_then(|m| m.as_str()).unwrap_or("");
                        Ok(serde_json::Value::String(msg.to_string()))
                    })
                },
                (),
            )
            .unwrap();
        let decls = tools.json().unwrap();
        let arr = decls.as_array().unwrap();
        assert!(arr.iter().any(|d| d["name"] == "echo"));
    }

    /// A manually registered tool can be invoked through `call`.
    #[tokio::test]
    async fn register_raw_callable() {
        let mut tools: ToolCollection = ToolsBuilder::new().collect().unwrap();
        tools
            .register_raw(
                "double",
                "Doubles a number",
                json!({
                    "type": "object",
                    "properties": { "n": { "type": "integer" } },
                    "required": ["n"]
                }),
                |v| {
                    Box::pin(async move {
                        let n = v.get("n").and_then(|n| n.as_i64()).unwrap_or(0);
                        Ok(serde_json::Value::Number((n * 2).into()))
                    })
                },
                (),
            )
            .unwrap();
        let resp = tools
            .call(crate::FunctionCall::new(
                "double".to_string(),
                json!({ "n": 21 }),
            ))
            .await
            .unwrap();
        assert_eq!(resp.result, json!(42));
    }

    /// The Scripted state with no queued paths still collects the inventory tools.
    #[cfg(feature = "python")]
    #[test]
    fn scripted_no_paths_collects_inventory() {
        let tools: ToolCollection = ToolsBuilder::new()
            .with_language(crate::Language::Python)
            .collect()
            .unwrap();
        let _ = tools.json().unwrap();
    }

    /// Loading a script path currently reports a "not yet implemented" error.
    #[cfg(feature = "python")]
    #[test]
    fn scripted_with_path_errors_not_implemented() {
        let err = ToolsBuilder::new()
            .with_language(crate::Language::Python)
            .from_path("/some/script.py")
            .collect()
            .err()
            .expect("should error");
        assert!(
            err.to_string().contains("not yet implemented"),
            "expected 'not yet implemented', got: {err}"
        );
    }

    /// `from_path` can be chained; the first path's load error surfaces.
    #[cfg(feature = "python")]
    #[test]
    fn scripted_from_path_chainable() {
        let err = ToolsBuilder::new()
            .with_language(crate::Language::Python)
            .from_path("/first.py")
            .from_path("/second.py")
            .collect()
            .err()
            .expect("should error");
        assert!(err.to_string().contains("not yet implemented"));
    }
}