use std::collections::HashMap;
use std::fmt;
use serde::{Deserialize, Serialize};
use super::LlmMessage;
/// An application-defined message that can be stored in `AgentMessage::Custom`.
///
/// Implementors must be `Send + Sync`, implement `Debug`, and be `'static`
/// (required by the `Any` supertrait).
pub trait CustomMessage: Send + Sync + fmt::Debug + std::any::Any {
/// Returns this message as `&dyn Any`.
///
/// `AgentMessage::downcast_ref` calls this to recover the concrete type.
fn as_any(&self) -> &dyn std::any::Any;
/// Name written to the `"type"` field by `serialize_custom_message`.
///
/// The default is `None`, in which case `serialize_custom_message`
/// returns `None` for this message.
fn type_name(&self) -> Option<&str> {
None
}
/// JSON payload written to the `"data"` field by `serialize_custom_message`.
///
/// The default is `None`, in which case `serialize_custom_message`
/// returns `None` for this message.
fn to_json(&self) -> Option<serde_json::Value> {
None
}
/// Optional boxed copy of this message; the default is `None`.
// NOTE(review): presumably implementors return a clone of `self` here —
// no caller is visible in this file, confirm against usage elsewhere.
fn clone_box(&self) -> Option<Box<dyn CustomMessage>> {
None
}
}
/// Boxed, thread-safe function that turns a JSON value into a boxed
/// [`CustomMessage`], or returns an error description as a `String`.
pub type CustomMessageDeserializer =
Box<dyn Fn(serde_json::Value) -> Result<Box<dyn CustomMessage>, String> + Send + Sync>;
/// Lookup table from custom message type names to the deserializers that
/// rebuild those messages from JSON.
pub struct CustomMessageRegistry {
// Keyed by the type name passed to `register` / `register_type`.
deserializers: HashMap<String, CustomMessageDeserializer>,
}
impl CustomMessageRegistry {
    /// Creates a registry with no deserializers registered.
    #[must_use]
    pub fn new() -> Self {
        let deserializers = HashMap::new();
        Self { deserializers }
    }

    /// Registers `deserializer` under `type_name`.
    ///
    /// A deserializer already registered under the same name is replaced.
    pub fn register(
        &mut self,
        type_name: impl Into<String>,
        deserializer: CustomMessageDeserializer,
    ) {
        let key: String = type_name.into();
        self.deserializers.insert(key, deserializer);
    }

    /// Registers a `serde_json`-backed deserializer for `T` under `type_name`.
    ///
    /// The generated deserializer decodes the JSON value into `T`, boxes it
    /// as a [`CustomMessage`], and reports decode failures as the error's
    /// string rendering.
    pub fn register_type<T>(&mut self, type_name: impl Into<String>)
    where
        T: CustomMessage + serde::de::DeserializeOwned + 'static,
    {
        let decode: CustomMessageDeserializer = Box::new(|value: serde_json::Value| {
            match serde_json::from_value::<T>(value) {
                Ok(message) => Ok(Box::new(message) as Box<dyn CustomMessage>),
                Err(err) => Err(err.to_string()),
            }
        });
        self.register(type_name, decode);
    }

    /// Decodes `value` with the deserializer registered under `type_name`.
    ///
    /// # Errors
    ///
    /// Returns an error string when no deserializer is registered for
    /// `type_name`, or whatever error the registered deserializer produces.
    pub fn deserialize(
        &self,
        type_name: &str,
        value: serde_json::Value,
    ) -> Result<Box<dyn CustomMessage>, String> {
        match self.deserializers.get(type_name) {
            Some(decode) => decode(value),
            None => Err(format!(
                "no deserializer registered for custom message type: {type_name}"
            )),
        }
    }

    /// Reports whether a deserializer is registered under `type_name`.
    #[must_use]
    pub fn has_type_name(&self, type_name: &str) -> bool {
        let registered = &self.deserializers;
        registered.contains_key(type_name)
    }
}
impl Default for CustomMessageRegistry {
fn default() -> Self {
Self::new()
}
}
/// Debug output lists only the registered type names; the deserializer
/// closures themselves are not `Debug`.
impl fmt::Debug for CustomMessageRegistry {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `HashMap` iteration order is arbitrary, so sort the names to keep
        // the rendered output stable across runs (logs, snapshot tests).
        let mut registered_types: Vec<&String> = self.deserializers.keys().collect();
        registered_types.sort_unstable();

        f.debug_struct("CustomMessageRegistry")
            .field("registered_types", &registered_types)
            .finish()
    }
}
/// Wraps a custom message in a JSON envelope of the form
/// `{"type": <type name>, "data": <payload>}`.
///
/// Returns `None` when the message reports no type name or no JSON payload.
/// The payload is only requested once a type name is available.
#[must_use]
pub fn serialize_custom_message(msg: &dyn CustomMessage) -> Option<serde_json::Value> {
    msg.type_name().and_then(|type_name| {
        msg.to_json().map(|payload| {
            serde_json::json!({
                "type": type_name,
                "data": payload,
            })
        })
    })
}
/// Rebuilds a custom message from an envelope produced by
/// `serialize_custom_message`.
///
/// The envelope's `"type"` string selects the deserializer in `registry`,
/// which is then given a copy of the `"data"` value.
///
/// # Errors
///
/// Returns an error string when `"type"` is absent or not a string, when
/// `"data"` is absent, or when the registry lookup or deserializer fails.
pub fn deserialize_custom_message(
    registry: &CustomMessageRegistry,
    envelope: &serde_json::Value,
) -> Result<Box<dyn CustomMessage>, String> {
    let Some(type_name) = envelope.get("type").and_then(serde_json::Value::as_str) else {
        return Err("missing 'type' field in custom message envelope".to_string());
    };
    let Some(data) = envelope.get("data") else {
        return Err("missing 'data' field in custom message envelope".to_string());
    };
    // The envelope is only borrowed, so the payload has to be copied before
    // it can be handed to the deserializer by value.
    registry.deserialize(type_name, data.clone())
}
/// A message handled by the agent: either an LLM message or an
/// application-defined custom message.
// NOTE(review): the lint allowance presumably exists because `LlmMessage`
// is much larger than the boxed `Custom` variant — confirm.
#[allow(clippy::large_enum_variant)]
pub enum AgentMessage {
/// Wraps an [`LlmMessage`].
Llm(LlmMessage),
/// Wraps a boxed [`CustomMessage`] trait object.
Custom(Box<dyn CustomMessage>),
}
impl AgentMessage {
pub const fn cache_hint(&self) -> Option<&crate::context_cache::CacheHint> {
match self {
Self::Llm(msg) => match msg {
LlmMessage::User(m) => m.cache_hint.as_ref(),
LlmMessage::Assistant(m) => m.cache_hint.as_ref(),
LlmMessage::ToolResult(m) => m.cache_hint.as_ref(),
},
Self::Custom(_) => None,
}
}
pub const fn set_cache_hint(&mut self, hint: crate::context_cache::CacheHint) {
match self {
Self::Llm(msg) => match msg {
LlmMessage::User(m) => m.cache_hint = Some(hint),
LlmMessage::Assistant(m) => m.cache_hint = Some(hint),
LlmMessage::ToolResult(m) => m.cache_hint = Some(hint),
},
Self::Custom(_) => {}
}
}
pub const fn clear_cache_hint(&mut self) {
match self {
Self::Llm(msg) => match msg {
LlmMessage::User(m) => m.cache_hint = None,
LlmMessage::Assistant(m) => m.cache_hint = None,
LlmMessage::ToolResult(m) => m.cache_hint = None,
},
Self::Custom(_) => {}
}
}
pub fn downcast_ref<T: 'static>(&self) -> Result<&T, crate::error::DowncastError> {
match self {
Self::Custom(msg) => {
msg.as_any()
.downcast_ref::<T>()
.ok_or_else(|| crate::error::DowncastError {
expected: std::any::type_name::<T>(),
actual: msg
.type_name()
.map_or_else(|| format!("{msg:?}"), ToString::to_string),
})
}
Self::Llm(_) => Err(crate::error::DowncastError {
expected: std::any::type_name::<T>(),
actual: "LlmMessage".to_string(),
}),
}
}
}
/// Formats the message as a tuple variant (`Llm(..)` or `Custom(..)`) around
/// the wrapped value's own `Debug` output.
impl fmt::Debug for AgentMessage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the variant label and payload first, then format once.
        let (variant, payload) = match self {
            Self::Llm(msg) => ("Llm", msg as &dyn fmt::Debug),
            Self::Custom(msg) => ("Custom", msg as &dyn fmt::Debug),
        };
        f.debug_tuple(variant).field(payload).finish()
    }
}
/// Serializes as a two-entry map: `"kind"` (`"llm"` or `"custom"`) followed
/// by `"message"`.
///
/// For custom messages, `"message"` is the optional envelope returned by
/// `serialize_custom_message`; a message that yields no envelope is written
/// as a serialized `None`.
impl Serialize for AgentMessage {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeMap;

        let mut entries = serializer.serialize_map(Some(2))?;
        match self {
            Self::Llm(msg) => {
                entries.serialize_entry("kind", "llm")?;
                entries.serialize_entry("message", msg)?;
            }
            Self::Custom(msg) => {
                entries.serialize_entry("kind", "custom")?;
                entries.serialize_entry("message", &serialize_custom_message(msg.as_ref()))?;
            }
        }
        entries.end()
    }
}
impl<'de> Deserialize<'de> for AgentMessage {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
#[derive(Deserialize)]
struct Tagged {
kind: String,
message: serde_json::Value,
}
let tagged = Tagged::deserialize(deserializer)?;
match tagged.kind.as_str() {
"llm" => {
let msg: LlmMessage =
serde_json::from_value(tagged.message).map_err(serde::de::Error::custom)?;
Ok(Self::Llm(msg))
}
"custom" => Err(serde::de::Error::custom(
"cannot deserialize AgentMessage::Custom (requires runtime type info)",
)),
other => Err(serde::de::Error::unknown_variant(other, &["llm", "custom"])),
}
}
}