use super::spec::{AnyStateAction, StateScope};
use std::any::TypeId;
use std::collections::HashMap;
use tirea_state::StateSpec;
/// One registry record: `(type name, scope, state path)`.
type RegistryEntry = (&'static str, StateScope, &'static str);

/// Lookup table from a state type's [`TypeId`] to the [`StateScope`] it was
/// registered with.
///
/// Starts empty via [`Default`]; entries are added per state type.
#[derive(Default)]
#[derive(Debug, Clone)]
pub struct StateScopeRegistry {
    /// Keyed by `TypeId::of::<S>()`; each value holds `type_name::<S>()`, the
    /// registered scope, and `S::PATH`.
    by_type_id: HashMap<TypeId, RegistryEntry>,
}
impl StateScopeRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `scope` as the scope for state type `S`, along with its type name
    /// and `S::PATH`.
    ///
    /// Registering the same type again replaces the previous entry.
    pub fn register<S: StateSpec>(&mut self, scope: StateScope) {
        self.by_type_id.insert(
            TypeId::of::<S>(),
            (std::any::type_name::<S>(), scope, S::PATH),
        );
    }

    /// Returns the scope registered for `type_id`, or `None` if the type was
    /// never registered.
    pub fn scope_for_type_id(&self, type_id: TypeId) -> Option<StateScope> {
        self.by_type_id.get(&type_id).map(|&(_, scope, _)| scope)
    }

    /// Returns the state paths of every type registered with [`StateScope::Run`].
    ///
    /// The paths are sorted so the result is deterministic; iterating the
    /// underlying `HashMap` directly yields an arbitrary, run-dependent order.
    pub fn run_scoped_paths(&self) -> Vec<&'static str> {
        let mut paths: Vec<&'static str> = self
            .by_type_id
            .values()
            .filter(|(_, scope, _)| *scope == StateScope::Run)
            .map(|&(_, _, path)| path)
            .collect();
        paths.sort_unstable();
        paths
    }

    /// Determines the effective scope of `action`.
    ///
    /// A scope registered for the action's state type takes precedence;
    /// otherwise the action's own `scope()` is used.
    pub fn resolve(&self, action: &AnyStateAction) -> StateScope {
        self.scope_for_type_id(action.state_type_id())
            .unwrap_or_else(|| action.scope())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use tirea_state::{DocCell, PatchSink, Path, State, TireaResult};

    /// Fixture state that keeps the default `StateSpec::SCOPE`.
    #[derive(Debug, Clone, Serialize, Deserialize, Default)]
    struct RunScoped {
        value: i64,
    }

    struct RunScopedRef;

    impl State for RunScoped {
        type Ref<'a> = RunScopedRef;
        const PATH: &'static str = "run_scoped";

        fn state_ref<'a>(_: &'a DocCell, _: Path, _: PatchSink<'a>) -> Self::Ref<'a> {
            RunScopedRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for RunScoped {
        type Action = ();

        fn reduce(&mut self, _: ()) {}
    }

    /// Fixture state that declares `SCOPE = StateScope::ToolCall`.
    #[derive(Debug, Clone, Serialize, Deserialize, Default)]
    struct ToolScoped {
        value: i64,
    }

    struct ToolScopedRef;

    impl State for ToolScoped {
        type Ref<'a> = ToolScopedRef;
        const PATH: &'static str = "tool_scoped";

        fn state_ref<'a>(_: &'a DocCell, _: Path, _: PatchSink<'a>) -> Self::Ref<'a> {
            ToolScopedRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for ToolScoped {
        type Action = ();
        const SCOPE: StateScope = StateScope::ToolCall;

        fn reduce(&mut self, _: ()) {}
    }

    #[test]
    fn register_and_lookup() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<RunScoped>(StateScope::Run);
        reg.register::<ToolScoped>(StateScope::ToolCall);
        assert_eq!(
            reg.scope_for_type_id(TypeId::of::<RunScoped>()),
            Some(StateScope::Run)
        );
        assert_eq!(
            reg.scope_for_type_id(TypeId::of::<ToolScoped>()),
            Some(StateScope::ToolCall)
        );
    }

    #[test]
    fn unregistered_type_returns_none() {
        let reg = StateScopeRegistry::new();
        assert_eq!(reg.scope_for_type_id(TypeId::of::<RunScoped>()), None);
    }

    #[test]
    fn resolve_falls_back_to_action_scope() {
        let reg = StateScopeRegistry::new();
        let action = AnyStateAction::new::<RunScoped>(());
        assert_eq!(reg.resolve(&action), StateScope::Thread);
    }

    #[test]
    fn resolve_uses_registered_scope() {
        let mut reg = StateScopeRegistry::new();
        // Unregistered, a RunScoped action resolves to Thread (see
        // `resolve_falls_back_to_action_scope`); registering it as Run must
        // make `resolve` prefer the registry entry.
        reg.register::<RunScoped>(StateScope::Run);
        let action = AnyStateAction::new::<RunScoped>(());
        assert_eq!(reg.resolve(&action), StateScope::Run);
    }

    #[test]
    fn run_scoped_paths_returns_run_types() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<RunScoped>(StateScope::Run);
        reg.register::<ToolScoped>(StateScope::ToolCall);
        let paths = reg.run_scoped_paths();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], "run_scoped");
    }

    #[test]
    fn run_scoped_paths_empty_when_no_run_scope_registered() {
        let mut reg = StateScopeRegistry::new();
        reg.register::<ToolScoped>(StateScope::ToolCall);
        assert!(reg.run_scoped_paths().is_empty());
    }
}