use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::provider::ToolSpec;
use crate::tool::{Tool, ToolRegistry};
/// Opaque identifier of one scope layer pushed onto a `ScopedToolRegistry`.
///
/// Values are minted by `ScopedToolRegistry::push_scope` from a per-registry counter
/// that is only ever incremented.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ScopeId(pub(crate) usize);
/// Named scope levels with fixed discriminants, `Global = 0` through `Turn = 3`.
///
/// NOTE(review): nothing in this file reads `ScopeLevel`; presumably callers map each
/// level to a pushed `ScopeId` — confirm against the call sites.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScopeLevel {
    Global = 0,
    Agent = 1,
    Run = 2,
    Turn = 3,
}
/// One layer of the scope stack: the tools registered into a single scope, keyed by
/// the name they were registered under.
#[derive(Clone)]
struct ScopeLayer {
    tools: HashMap<String, Arc<dyn Tool>>,
}

/// Mutable state of a `ScopedToolRegistry`, kept behind its mutex.
///
/// Cloning copies the stack and each layer's map; the tools themselves are shared
/// through their `Arc`s.
#[derive(Clone)]
struct ScopeState {
    /// Scope layers in push order; the last element is the top (most recently pushed) scope.
    layers: Vec<(ScopeId, ScopeLayer)>,
    /// Raw value of the next `ScopeId` to hand out.
    next_id: usize,
}
/// Outcome of `ScopedToolRegistry::register_arc_in_scope`.
///
/// `Ok` carries the tool previously registered under the same name in the target scope,
/// if there was one. `Err` hands the `(name, tool)` pair back unchanged when the target
/// scope is not on the stack.
type RegisterArcResult = Result<Option<Arc<dyn Tool>>, (String, Arc<dyn Tool>)>;
/// Tool registry that layers a stack of temporary scopes over a base `ToolRegistry`.
///
/// Lookups check the scopes from the most recently pushed to the oldest and then fall
/// back to `base`, so a tool registered in a scope shadows same-named tools beneath it.
/// The scope stack sits behind a `Mutex`, so scopes can be pushed, popped and populated
/// through a shared reference.
pub struct ScopedToolRegistry {
    /// Registry consulted when no scope provides the requested tool.
    base: ToolRegistry,
    /// Stack of scope layers plus the counter for the next scope id.
    state: Mutex<ScopeState>,
}
impl Clone for ScopedToolRegistry {
    /// Builds a new registry from `self.base.clone()` and a copy of the scope stack as it
    /// is at the moment of the call.
    ///
    /// The two registries get separate scope stacks afterwards: pushing, popping or
    /// registering into scopes on one does not change the other's layers.
    fn clone(&self) -> Self {
        // Copy the stack and let the guard go before touching `base`.
        let snapshot = self.lock_state_poison_safe().clone();
        let base = self.base.clone();
        Self {
            base,
            state: Mutex::new(snapshot),
        }
    }
}
impl ScopedToolRegistry {
#[must_use]
pub fn new(base: ToolRegistry) -> Self {
Self {
base,
state: Mutex::new(ScopeState {
layers: Vec::new(),
next_id: 0,
}),
}
}
#[must_use]
pub fn push_scope(&self) -> ScopeId {
let mut state = self.lock_state_poison_safe();
let id = ScopeId(state.next_id);
state.next_id += 1;
state.layers.push((
id,
ScopeLayer {
tools: HashMap::new(),
},
));
id
}
pub fn pop_scope(&self, id: ScopeId) -> bool {
let mut state = self.lock_state_poison_safe();
if let Some(&(top_id, _)) = state.layers.last() {
if top_id == id {
state.layers.pop();
return true;
}
}
false
}
pub fn register_in_scope<T: Tool + 'static>(
&self,
scope: ScopeId,
tool: T,
) -> Result<Option<Arc<dyn Tool>>, T> {
let name = tool.name().to_owned();
let mut state = self.lock_state_poison_safe();
for (id, layer) in &mut state.layers {
if *id == scope {
let arc: Arc<dyn Tool> = Arc::new(tool);
return Ok(layer.tools.insert(name, arc));
}
}
Err(tool)
}
pub fn register_arc_in_scope(
&self,
scope: ScopeId,
name: String,
tool: Arc<dyn Tool>,
) -> RegisterArcResult {
let mut state = self.lock_state_poison_safe();
for (id, layer) in &mut state.layers {
if *id == scope {
return Ok(layer.tools.insert(name, tool));
}
}
Err((name, tool))
}
#[must_use]
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
let state = self.lock_state_poison_safe();
for (_, layer) in state.layers.iter().rev() {
if let Some(tool) = layer.tools.get(name) {
return Some(Arc::clone(tool));
}
}
self.base.get(name)
}
#[must_use]
pub fn specs(&self) -> Vec<ToolSpec> {
let state = self.lock_state_poison_safe();
let mut merged: HashMap<String, Arc<dyn Tool>> = HashMap::new();
for spec in self.base.specs() {
if let Some(tool) = self.base.get(&spec.name) {
merged.insert(spec.name, tool);
}
}
for (_, layer) in &state.layers {
for (name, tool) in &layer.tools {
merged.insert(name.clone(), Arc::clone(tool));
}
}
let mut specs: Vec<ToolSpec> = merged.values().map(|t| t.to_spec()).collect();
specs.sort_by(|a, b| a.name.cmp(&b.name));
specs
}
#[must_use]
pub fn scope_depth(&self) -> usize {
let state = self.lock_state_poison_safe();
state.layers.len()
}
#[must_use]
pub fn base(&self) -> &ToolRegistry {
&self.base
}
pub fn base_mut(&mut self) -> &mut ToolRegistry {
&mut self.base
}
pub fn unregister_from_base(&self, name: &str) -> Option<Arc<dyn Tool>> {
self.base.unregister(name)
}
#[must_use]
pub fn len(&self) -> usize {
self.specs().len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.specs().is_empty()
}
pub async fn execute(
&self,
call: &crate::provider::ToolCall,
) -> crate::tool::ToolResult<crate::tool::ToolOutput> {
let tool = self
.get(&call.name)
.ok_or_else(|| crate::error::ToolError::NotFound {
name: call.name.clone(),
})?;
tool.execute(call.arguments.clone()).await
}
#[must_use]
pub fn push_scope_guarded(self: &Arc<Self>) -> ScopeGuard {
let id = self.push_scope();
ScopeGuard {
scope: Arc::clone(self),
id,
}
}
#[allow(clippy::expect_used)]
fn lock_state_poison_safe(&self) -> std::sync::MutexGuard<'_, ScopeState> {
self.state.lock().expect("scope lock poisoned")
}
}
pub struct ScopeGuard {
scope: Arc<ScopedToolRegistry>,
id: ScopeId,
}
impl ScopeGuard {
#[must_use]
pub fn id(&self) -> ScopeId {
self.id
}
#[must_use]
pub fn into_inner(self) -> Arc<ScopedToolRegistry> {
let scope = Arc::clone(&self.scope);
std::mem::forget(self);
scope
}
}
impl Drop for ScopeGuard {
fn drop(&mut self) {
self.scope.pop_scope(self.id);
}
}
impl std::fmt::Debug for ScopeGuard {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ScopeGuard")
.field("id", &self.id)
.finish_non_exhaustive()
}
}
// Unit tests for ScopedToolRegistry (scope push/pop, shadowing, specs, execute) and ScopeGuard.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::type_complexity, clippy::expect_used)]
mod tests {
use super::*;
use crate::tool::FunctionTool;
use serde_json::{Value, json};
// Builds a FunctionTool named `name` (description "{name} desc", schema {"type": "object"})
// whose handler echoes its JSON arguments back unchanged.
fn make_tool(
name: &'static str,
) -> FunctionTool<
impl Fn(
Value,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = crate::tool::ToolResult<Value>> + Send>,
> + Send
+ Sync
+ 'static,
> {
FunctionTool::new(
name,
format!("{name} desc"),
json!({"type": "object"}),
|args| -> std::pin::Pin<
Box<dyn std::future::Future<Output = crate::tool::ToolResult<Value>> + Send>,
> { Box::pin(async move { Ok(args) }) },
)
}
#[test]
fn new_should_have_base_tools() {
let base = ToolRegistry::new();
base.register(make_tool("base_tool"));
let scoped = ScopedToolRegistry::new(base);
assert!(scoped.get("base_tool").is_some());
assert!(scoped.get("missing").is_none());
assert_eq!(scoped.scope_depth(), 0);
}
#[test]
fn push_scope_should_create_empty_scope() {
let base = ToolRegistry::new();
let scoped = ScopedToolRegistry::new(base);
let scope = scoped.push_scope();
assert_eq!(scoped.scope_depth(), 1);
assert!(scoped.get("anything").is_none());
// Only discards the id binding; the scope itself stays pushed.
let _ = scope;
}
#[test]
fn register_in_scope_should_shadow_base() {
let base = ToolRegistry::new();
base.register(make_tool("shared"));
let scoped = ScopedToolRegistry::new(base);
let scope = scoped.push_scope();
// `.ok().unwrap()` rather than `.unwrap()`: `Result::unwrap` would require the error
// type (the rejected tool) to implement `Debug`.
scoped
.register_in_scope(scope, make_tool("shared"))
.ok()
.unwrap();
assert!(scoped.get("shared").is_some());
// Shadowed names are reported once.
assert_eq!(scoped.specs().len(), 1);
}
#[test]
fn pop_scope_should_unshadow_base_tool() {
let base = ToolRegistry::new();
base.register(make_tool("shared"));
let scoped = ScopedToolRegistry::new(base);
let scope = scoped.push_scope();
scoped
.register_in_scope(scope, make_tool("shared"))
.ok()
.unwrap();
assert!(scoped.get("shared").is_some());
scoped.pop_scope(scope);
// After the pop, the lookup falls back to the base tool.
assert!(scoped.get("shared").is_some());
assert_eq!(scoped.scope_depth(), 0);
}
#[test]
fn pop_scope_should_remove_scope_only_tools() {
let base = ToolRegistry::new();
let scoped = ScopedToolRegistry::new(base);
let scope = scoped.push_scope();
scoped
.register_in_scope(scope, make_tool("scope_only"))
.ok()
.unwrap();
assert!(scoped.get("scope_only").is_some());
scoped.pop_scope(scope);
assert!(scoped.get("scope_only").is_none());
}
#[test]
fn pop_scope_non_top_should_return_false() {
let base = ToolRegistry::new();
let registry = ScopedToolRegistry::new(base);
let scope1 = registry.push_scope();
let scope2 = registry.push_scope();
// Only the top scope can be popped; a buried scope is left in place.
assert!(!registry.pop_scope(scope1));
assert_eq!(registry.scope_depth(), 2);
assert!(registry.pop_scope(scope2));
assert_eq!(registry.scope_depth(), 1);
}
#[test]
fn pop_nonexistent_scope_should_return_false() {
let base = ToolRegistry::new();
let scoped = ScopedToolRegistry::new(base);
assert!(!scoped.pop_scope(ScopeId(999)));
}
#[test]
fn multiple_scopes_shadow_correctly() {
let base = ToolRegistry::new();
let registry = ScopedToolRegistry::new(base);
let scope1 = registry.push_scope();
registry
.register_in_scope(scope1, make_tool("tool"))
.ok()
.unwrap();
let scope2 = registry.push_scope();
registry
.register_in_scope(scope2, make_tool("tool"))
.ok()
.unwrap();
assert_eq!(registry.specs().len(), 1);
registry.pop_scope(scope2);
// The outer scope's copy is still visible.
assert!(registry.get("tool").is_some());
assert_eq!(registry.specs().len(), 1);
registry.pop_scope(scope1);
assert!(registry.get("tool").is_none());
}
#[test]
fn register_in_nonexistent_scope_returns_error() {
let base = ToolRegistry::new();
let scoped = ScopedToolRegistry::new(base);
let result = scoped.register_in_scope(ScopeId(999), make_tool("nope"));
assert!(result.is_err());
}
#[test]
fn register_replaces_within_same_scope() {
let base = ToolRegistry::new();
let scoped = ScopedToolRegistry::new(base);
let scope = scoped.push_scope();
let old = scoped
.register_in_scope(scope, make_tool("tool"))
.ok()
.unwrap();
assert!(old.is_none());
// The second registration under the same name returns the replaced tool.
let old = scoped
.register_in_scope(scope, make_tool("tool"))
.ok()
.unwrap();
assert!(old.is_some());
assert_eq!(scoped.specs().len(), 1);
}
#[tokio::test]
async fn execute_should_resolve_from_scoped_registry() {
let base = ToolRegistry::new();
let scoped = ScopedToolRegistry::new(base);
let scope = scoped.push_scope();
scoped
.register_in_scope(scope, make_tool("tool"))
.ok()
.unwrap();
let call = crate::provider::ToolCall::new("c1", "tool", json!({"x": 1}));
let output = scoped.execute(&call).await.unwrap();
// make_tool echoes its arguments, so the output equals the call's arguments.
assert_eq!(output.value, json!({"x": 1}));
}
#[tokio::test]
async fn execute_unknown_tool_returns_not_found() {
let base = ToolRegistry::new();
let scoped = ScopedToolRegistry::new(base);
let call = crate::provider::ToolCall::new("c1", "missing", json!({}));
let result = scoped.execute(&call).await;
assert!(result.is_err());
}
#[test]
fn scope_id_increments_monotonically() {
let base = ToolRegistry::new();
let scoped = ScopedToolRegistry::new(base);
let s1 = scoped.push_scope();
let s2 = scoped.push_scope();
assert!(s2.0 > s1.0);
}
#[test]
fn base_mut_allows_modifying_base() {
let base = ToolRegistry::new();
let mut scoped = ScopedToolRegistry::new(base);
scoped.base_mut().register(make_tool("new_base"));
assert!(scoped.get("new_base").is_some());
}
#[test]
fn clone_captures_scope_snapshot() {
let base = ToolRegistry::new();
let scoped = ScopedToolRegistry::new(base);
let scope = scoped.push_scope();
scoped
.register_in_scope(scope, make_tool("tool"))
.ok()
.unwrap();
let cloned = scoped.clone();
assert!(cloned.get("tool").is_some());
// Popping on the original does not affect the clone's copy of the scope stack.
scoped.pop_scope(scope);
assert!(scoped.get("tool").is_none());
assert!(cloned.get("tool").is_some());
}
#[test]
fn scope_guard_pops_on_drop() {
let base = ToolRegistry::new();
let scoped = Arc::new(ScopedToolRegistry::new(base));
{
let _guard = scoped.push_scope_guarded();
assert_eq!(scoped.scope_depth(), 1);
}
assert_eq!(scoped.scope_depth(), 0);
}
#[test]
fn scope_guard_id_matches_pushed_scope() {
let base = ToolRegistry::new();
let scoped = Arc::new(ScopedToolRegistry::new(base));
let guard = scoped.push_scope_guarded();
let id = guard.id();
assert_eq!(scoped.scope_depth(), 1);
// Popped manually, so the guard's drop-time pop at end of scope finds nothing to pop.
assert!(scoped.pop_scope(id));
}
#[test]
fn scope_guard_into_inner_prevents_drop_cleanup() {
let base = ToolRegistry::new();
let scoped = Arc::new(ScopedToolRegistry::new(base));
let guard = scoped.push_scope_guarded();
let id = guard.id();
// into_inner consumes the guard without running its drop-time pop.
let _ = guard.into_inner();
assert_eq!(scoped.scope_depth(), 1);
scoped.pop_scope(id);
}
#[test]
fn scope_guard_tools_within_scope_are_visible() {
let base = ToolRegistry::new();
let scoped = Arc::new(ScopedToolRegistry::new(base));
let guard = scoped.push_scope_guarded();
scoped
.register_in_scope(guard.id(), make_tool("scoped_tool"))
.ok()
.unwrap();
assert!(scoped.get("scoped_tool").is_some());
assert_eq!(scoped.specs().len(), 1);
drop(guard);
assert!(scoped.get("scoped_tool").is_none());
assert_eq!(scoped.specs().len(), 0);
}
#[test]
fn push_scope_is_thread_safe() {
use std::thread;
let base = ToolRegistry::new();
let scoped = Arc::new(ScopedToolRegistry::new(base));
// Each thread pushes one scope and never pops it, so all four remain afterwards.
let handles: Vec<_> = (0..4)
.map(|_| {
let s = Arc::clone(&scoped);
thread::spawn(move || {
let _scope = s.push_scope();
})
})
.collect();
for h in handles {
h.join().unwrap();
}
assert_eq!(scoped.scope_depth(), 4);
}
#[test]
fn specs_should_return_sorted_by_name() {
let base = ToolRegistry::new();
base.register(make_tool("zebra"));
base.register(make_tool("alpha"));
let scoped = ScopedToolRegistry::new(base);
let scope = scoped.push_scope();
scoped
.register_in_scope(scope, make_tool("mike"))
.ok()
.unwrap();
// Base and scoped tools are merged and sorted by name.
let specs = scoped.specs();
assert_eq!(specs.len(), 3);
assert_eq!(specs[0].name, "alpha");
assert_eq!(specs[1].name, "mike");
assert_eq!(specs[2].name, "zebra");
}
}