from __future__ import annotations
import functools
from typing import Any, Callable, Dict, Optional
from .client import A1Client, A1Error
__all__ = ["a1_llamaindex_tool", "a1_llamaindex_guard"]
def a1_llamaindex_tool(
    fn: Callable[..., Any],
    *,
    intent_name: str,
    client: A1Client,
    resolve_context: Callable[[Dict[str, Any]], Dict[str, Any]],
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> Any:
    """Wrap ``fn`` as a LlamaIndex ``FunctionTool`` that is authorized before every call.

    On each invocation the tool's keyword arguments are passed to
    ``resolve_context``. From the returned mapping, ``"chain"`` (required) and
    ``"executor_pk_hex"`` (default ``""``) are read and handed to the client's
    authorization call for ``intent_name``. Only if that call returns is
    ``fn`` invoked with the original keyword arguments.

    Coroutine functions are supported: for an ``async def`` ``fn`` the tool
    is registered via ``async_fn`` and authorization awaits
    ``client.authorize_async``; otherwise it is registered via ``fn`` and
    authorization uses ``client.authorize``.

    Args:
        fn: The sync or ``async def`` function exposed as the tool.
        intent_name: Intent name passed to the client's authorization call.
        client: Client used to authorize each invocation.
        resolve_context: Maps the call's keyword arguments to a context dict
            that must provide ``"chain"`` and may provide ``"executor_pk_hex"``.
        name: Tool name; defaults to ``fn.__name__``.
        description: Tool description; defaults to ``fn``'s stripped
            docstring, or to the tool name when there is no docstring.

    Returns:
        The tool returned by ``FunctionTool.from_defaults``.

    Raises:
        ImportError: If ``llama_index.core`` cannot be imported.
        A1Error: At call time, with ``error_code="MISSING_CHAIN"``, when the
            resolved context has no ``"chain"`` (or it is ``None``). Errors
            from the client's authorization call propagate unchanged.
    """
    try:
        from llama_index.core.tools import FunctionTool
    except ImportError as exc:
        raise ImportError(
            "LlamaIndex is required: pip install llama-index-core"
        ) from exc

    import inspect

    tool_name = name or fn.__name__
    tool_description = description or (fn.__doc__ or "").strip() or tool_name

    def _authorize_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve the context for one call and build the authorize() keyword arguments."""
        ctx = resolve_context(kwargs)
        chain = ctx.get("chain")
        if chain is None:
            raise A1Error(
                f"resolve_context must supply 'chain' for intent '{intent_name}'",
                error_code="MISSING_CHAIN",
            )
        return {
            "chain": chain,
            "intent_name": intent_name,
            "executor_pk_hex": ctx.get("executor_pk_hex", ""),
        }

    is_async = inspect.iscoroutinefunction(fn)
    if is_async:
        # A sync wrapper around a coroutine function would hand LlamaIndex an
        # un-awaited coroutine as the tool result; wrap it asynchronously.
        @functools.wraps(fn)
        async def guarded(**kwargs: Any) -> Any:
            await client.authorize_async(**_authorize_kwargs(kwargs))
            return await fn(**kwargs)

    else:

        @functools.wraps(fn)
        def guarded(**kwargs: Any) -> Any:
            client.authorize(**_authorize_kwargs(kwargs))
            return fn(**kwargs)

    # Keep the wrapper's introspectable name/doc in line with the tool
    # metadata (presumably read by LlamaIndex when it inspects the callable).
    guarded.__name__ = tool_name
    guarded.__doc__ = tool_description
    callable_kwarg = "async_fn" if is_async else "fn"
    return FunctionTool.from_defaults(
        name=tool_name,
        description=tool_description,
        **{callable_kwarg: guarded},
    )
def a1_llamaindex_guard(
    *,
    intent_name: str,
    client: A1Client,
    chain_kwarg: str = "signed_chain",
    executor_kwarg: str = "executor_pk_hex",
) -> Callable:
    """Build a decorator that authorizes every call of the decorated function.

    The decorated function must be called with the signed chain as the
    keyword argument named ``chain_kwarg``. It may also be called with the
    executor public key as ``executor_kwarg``, which defaults to ``""``.
    Both values are only read, not removed. They are forwarded to the
    function together with all other positional and keyword arguments.

    ``async def`` functions get an async wrapper that awaits
    ``client.authorize_async``. Any other callable gets a sync wrapper
    that calls ``client.authorize``.

    Args:
        intent_name: Intent name passed to the client's authorization call.
        client: Client used to authorize each call.
        chain_kwarg: Name of the keyword argument carrying the signed chain.
        executor_kwarg: Name of the keyword argument carrying the executor
            public key.

    Returns:
        A decorator that returns a ``functools.wraps``-preserving wrapper.

    Raises:
        A1Error: At call time, with ``error_code="MISSING_CHAIN"``, when
            ``chain_kwarg`` is not passed as a keyword argument or is
            ``None``. Errors from the client's authorization call propagate.
    """
    import inspect

    def _authorize_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the call's keyword arguments and build the authorize() kwargs."""
        chain = kwargs.get(chain_kwarg)
        if chain is None:
            raise A1Error(
                f"missing required kwarg '{chain_kwarg}'",
                error_code="MISSING_CHAIN",
            )
        return {
            "chain": chain,
            "intent_name": intent_name,
            "executor_pk_hex": kwargs.get(executor_kwarg, ""),
        }

    def decorator(fn: Callable) -> Callable:
        if inspect.iscoroutinefunction(fn):

            @functools.wraps(fn)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                await client.authorize_async(**_authorize_kwargs(kwargs))
                return await fn(*args, **kwargs)

            return async_wrapper

        @functools.wraps(fn)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            client.authorize(**_authorize_kwargs(kwargs))
            return fn(*args, **kwargs)

        return sync_wrapper

    return decorator