use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::any::{Any, TypeId};
use std::sync::Arc;
use crate::JsonObject;
/// Metadata carried by a [`ToolContext`] in its public `meta` field.
///
/// Every field is optional. The derived `Default` (also reached through
/// [`Meta::new`]) leaves all of them `None`. The type is (de)serializable through
/// the serde derives; note that serialized field order follows declaration order.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Meta {
/// Caller-supplied request identifier, if any.
pub request_id: Option<String>,
/// Caller-supplied timestamp, if any.
/// NOTE(review): the unit and epoch (seconds vs. milliseconds since the Unix
/// epoch?) are not established in this file — confirm with the producers.
pub timestamp: Option<u64>,
/// Caller-supplied label describing where the request came from, if any.
pub source: Option<String>,
}
impl Meta {
    /// Builds a `Meta` with every field unset (`None`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Consumes `self` and returns a copy whose `request_id` is `id`.
    ///
    /// Any previously set request id is overwritten; other fields are kept.
    pub fn with_request_id(self, id: impl Into<String>) -> Self {
        Self {
            request_id: Some(id.into()),
            ..self
        }
    }

    /// Consumes `self` and returns a copy whose `timestamp` is `ts`.
    ///
    /// Any previously set timestamp is overwritten; other fields are kept.
    pub fn with_timestamp(self, ts: u64) -> Self {
        Self {
            timestamp: Some(ts),
            ..self
        }
    }

    /// Consumes `self` and returns a copy whose `source` is `source`.
    ///
    /// Any previously set source is overwritten; other fields are kept.
    pub fn with_source(self, source: impl Into<String>) -> Self {
        Self {
            source: Some(source.into()),
            ..self
        }
    }
}
/// Context for a tool call: request metadata, optional JSON arguments and a
/// type-keyed extension map.
///
/// `Clone` copies `meta` and `args`, but the extension map lives behind an `Arc`,
/// so every clone shares (and mutates) the same set of extensions.
#[derive(Debug, Clone)]
pub struct ToolContext {
/// Request metadata; public so callers can read or replace it directly.
pub meta: Meta,
/// Tool arguments; `None` after `take_args` or when built with `empty()`.
args: Option<JsonObject>,
/// Extensions keyed by `TypeId`, so at most one value per concrete type.
/// Shared across clones because of the outer `Arc`.
extensions: Arc<DashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
}
impl ToolContext {
    /// Shared constructor behind [`new`](Self::new), [`with_args`](Self::with_args)
    /// and [`empty`](Self::empty).
    ///
    /// Produces default metadata, the given argument state and a fresh, empty
    /// extension map.
    fn from_parts(args: Option<JsonObject>) -> Self {
        Self {
            meta: Meta::default(),
            args,
            extensions: Arc::new(DashMap::new()),
        }
    }

    /// Creates a context whose arguments are an *empty* JSON object (not `None`).
    pub fn new() -> Self {
        Self::from_parts(Some(JsonObject::new()))
    }

    /// Creates a context that owns the given argument object.
    pub fn with_args(args: JsonObject) -> Self {
        Self::from_parts(Some(args))
    }

    /// Creates a context with no arguments at all.
    ///
    /// [`peek_args`](Self::peek_args) and [`take_args`](Self::take_args) return `None`.
    pub fn empty() -> Self {
        Self::from_parts(None)
    }

    /// Replaces the metadata and returns the updated context (builder style).
    #[must_use]
    pub fn with_meta(mut self, meta: Meta) -> Self {
        self.meta = meta;
        self
    }

    /// Moves the arguments out of the context, leaving `None` behind.
    ///
    /// A second call therefore returns `None`.
    pub fn take_args(&mut self) -> Option<JsonObject> {
        self.args.take()
    }

    /// Borrows the arguments without consuming them.
    pub fn peek_args(&self) -> Option<&JsonObject> {
        self.args.as_ref()
    }

    /// Stores `ext` as the extension for type `T`.
    ///
    /// Returns the previously stored `T`, if there was one. The map is keyed by
    /// `TypeId::of::<T>()`, so at most one value per concrete type is held. The
    /// change is visible through every clone of this context.
    pub fn insert<T>(&self, ext: T) -> Option<Arc<T>>
    where
        T: Send + Sync + 'static,
    {
        let previous = self.extensions.insert(TypeId::of::<T>(), Arc::new(ext));
        // The key is `TypeId::of::<T>()`, so any previous value is a `T` and the
        // downcast cannot fail in practice.
        previous.and_then(|old| old.downcast::<T>().ok())
    }

    /// Returns a shared handle to the extension of type `T`, if present.
    pub fn get<T>(&self) -> Option<Arc<T>>
    where
        T: Send + Sync + 'static,
    {
        self.extensions
            .get(&TypeId::of::<T>())
            // Clone the stored `Arc` explicitly (a refcount bump, not a deep copy).
            // The dashmap read guard `entry` is dropped when this closure returns,
            // before the downcast runs.
            .map(|entry| Arc::clone(entry.value()))
            .and_then(|stored| stored.downcast::<T>().ok())
    }

    /// Removes the extension of type `T` and returns it, if present.
    pub fn remove<T>(&self) -> Option<Arc<T>>
    where
        T: Send + Sync + 'static,
    {
        self.extensions
            .remove(&TypeId::of::<T>())
            .and_then(|(_, stored)| stored.downcast::<T>().ok())
    }

    /// Reports whether an extension of type `T` is stored.
    pub fn contains<T>(&self) -> bool
    where
        T: Send + Sync + 'static,
    {
        self.extensions.contains_key(&TypeId::of::<T>())
    }

    /// Removes every extension.
    ///
    /// This does not touch `meta` or the arguments. Because clones share the
    /// extension map, it clears the extensions of every clone too.
    pub fn clear(&self) {
        self.extensions.clear();
    }

    /// Number of stored extensions (arguments are not counted).
    pub fn len(&self) -> usize {
        self.extensions.len()
    }

    /// `true` when no extensions are stored (arguments are not considered).
    pub fn is_empty(&self) -> bool {
        self.extensions.is_empty()
    }
}
impl Default for ToolContext {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_meta_builder() {
        let meta = Meta::new()
            .with_request_id("test-123")
            .with_timestamp(1234567890)
            .with_source("test");
        assert_eq!(meta.request_id, Some("test-123".to_string()));
        assert_eq!(meta.timestamp, Some(1234567890));
        assert_eq!(meta.source, Some("test".to_string()));
    }

    #[test]
    fn test_context_new() {
        let ctx = ToolContext::new();
        assert!(ctx.peek_args().is_some());
        assert_eq!(ctx.meta.request_id, None);
    }

    #[test]
    fn test_context_default_matches_new() {
        // `Default` delegates to `new`: an empty argument object, no extensions.
        let ctx = ToolContext::default();
        assert!(matches!(ctx.peek_args(), Some(args) if args.is_empty()));
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_context_empty_has_no_args() {
        let mut ctx = ToolContext::empty();
        assert!(ctx.peek_args().is_none());
        assert!(ctx.take_args().is_none());
    }

    #[test]
    fn test_context_with_meta() {
        let ctx = ToolContext::empty().with_meta(Meta::new().with_request_id("req-1"));
        assert_eq!(ctx.meta.request_id.as_deref(), Some("req-1"));
    }

    #[test]
    fn test_context_take_args() {
        let mut ctx = ToolContext::with_args(
            serde_json::json!({"key": "value"})
                .as_object()
                .unwrap()
                .clone(),
        );
        let taken = ctx.take_args().unwrap();
        assert!(taken.contains_key("key"));
        assert!(ctx.take_args().is_none());
    }

    #[test]
    fn test_context_extensions() {
        let ctx = ToolContext::new();
        #[derive(Debug, PartialEq)]
        struct MyExt {
            value: i32,
        }
        assert!(ctx.insert(MyExt { value: 42 }).is_none());
        let ext = ctx.get::<MyExt>().unwrap();
        assert_eq!(ext.value, 42);
        assert!(ctx.contains::<MyExt>());
        let removed = ctx.remove::<MyExt>().unwrap();
        assert_eq!(removed.value, 42);
        assert!(!ctx.contains::<MyExt>());
    }

    #[test]
    fn test_context_insert_replaces_and_returns_previous() {
        #[derive(Debug, PartialEq)]
        struct MyExt {
            value: i32,
        }
        let ctx = ToolContext::new();
        assert!(ctx.insert(MyExt { value: 1 }).is_none());
        let previous = ctx
            .insert(MyExt { value: 2 })
            .expect("the first value should be handed back on replacement");
        assert_eq!(previous.value, 1);
        assert_eq!(ctx.get::<MyExt>().unwrap().value, 2);
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn test_context_missing_extension() {
        let ctx = ToolContext::new();
        assert!(ctx.get::<String>().is_none());
        assert!(ctx.remove::<String>().is_none());
        assert!(!ctx.contains::<String>());
    }

    #[test]
    fn test_context_len_and_clear() {
        struct First;
        struct Second;
        let ctx = ToolContext::new();
        assert!(ctx.is_empty());
        ctx.insert(First);
        ctx.insert(Second);
        assert_eq!(ctx.len(), 2);
        ctx.clear();
        assert!(ctx.is_empty());
        assert!(!ctx.contains::<First>());
        // Clearing extensions must not touch the arguments.
        assert!(ctx.peek_args().is_some());
    }

    #[test]
    fn test_context_clone() {
        #[derive(Debug, PartialEq)]
        struct MyExt {
            value: i32,
        }
        let ctx = ToolContext::new();
        ctx.insert(MyExt { value: 42 });
        let cloned = ctx.clone();
        let ext = cloned.get::<MyExt>().unwrap();
        assert_eq!(ext.value, 42);
    }

    #[test]
    fn test_context_clone_shares_extensions() {
        // Pins the current design: the extension map sits behind an `Arc`, so an
        // insert through a clone is visible from the original.
        struct Late;
        let ctx = ToolContext::new();
        let cloned = ctx.clone();
        cloned.insert(Late);
        assert!(ctx.contains::<Late>());
    }

    #[test]
    fn test_context_thread_safety() {
        use std::thread;
        let ctx = Arc::new(ToolContext::new());
        let ctx_clone = ctx.clone();
        #[derive(Debug)]
        struct Counter;
        ctx.insert(Counter);
        let handle = thread::spawn(move || {
            assert!(ctx_clone.contains::<Counter>());
        });
        handle.join().unwrap();
        assert!(ctx.contains::<Counter>());
    }
}