use crate::runtime::AgentRuntime;
use crate::runtime_ref::downcast_runtime_ref;
use crate::templates::TemplateEngine;
use crate::types::{
Entity, GenerateTextParams, Memory, MemoryQuery, ModelHandlerParams, ModelType, Relationship,
Role, Room, State, UUID,
};
use crate::utils::string_to_uuid;
use crate::ZoeyError;
use crate::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tracing::{debug, error, info, instrument, warn};
/// Structured result of one LLM entity-resolution call.
///
/// Built by `parse_entity_resolution_xml` from the model's XML reply. The serde
/// derives describe an equivalent camelCase JSON shape in which `match_type`
/// is exposed under the key `"type"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct EntityResolution {
    /// Top-level entity ID chosen by the model, if it gave a usable one.
    entity_id: Option<UUID>,
    /// How the model says it matched the entity.
    #[serde(rename = "type")]
    match_type: MatchType,
    /// Candidate matches listed by the model.
    matches: Vec<EntityMatch>,
    /// Model-reported confidence. When deserialized through serde, a missing
    /// value becomes 0.0, whereas the XML parser in this file defaults it to 0.5.
    #[serde(default)]
    confidence: f32,
}
/// One candidate listed inside the model's `<matches>` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct EntityMatch {
    /// Name the model matched on; compared case-insensitively against entity
    /// names/usernames by the resolver.
    name: String,
    /// Free-text justification supplied by the model.
    reason: String,
    /// Optional per-match entity ID; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    entity_id: Option<UUID>,
}
/// Kind of match reported by the model.
///
/// Serialized as SCREAMING_SNAKE_CASE (e.g. `EXACT_MATCH`), which are the same
/// spellings the resolution prompt asks the model to put in `<type>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub(crate) enum MatchType {
    /// The model matched on an exact ID.
    ExactMatch,
    /// The model matched on a username/handle.
    UsernameMatch,
    /// The model matched on a display name.
    NameMatch,
    /// The model matched through a relationship or interaction history.
    RelationshipMatch,
    /// Several entities fit; the model could not disambiguate.
    Ambiguous,
    /// No match, or an unrecognized `<type>` value.
    Unknown,
}
/// Tuning knobs for `find_entity_by_name_with_config`.
#[derive(Debug, Clone)]
pub struct EntityResolutionConfig {
    /// Ask the LLM first; when `false`, only the heuristic fallbacks run.
    pub use_llm: bool,
    /// Text model tier used for the LLM call.
    pub model_type: ModelType,
    /// Cache time-to-live in seconds; `0` disables caching entirely.
    pub cache_ttl: u64,
    /// Upper bound on the number of room entities considered.
    pub max_entities: usize,
    /// Intended size of the recent-message context for the resolution prompt.
    pub context_message_count: usize,
    /// Minimum confidence a cached result needs in order to be served from the cache.
    pub min_confidence: f32,
    /// Metrics toggle. NOTE(review): not read anywhere in this module.
    pub enable_metrics: bool,
}
impl Default for EntityResolutionConfig {
    /// LLM resolution on the small text model, a 300-second cache, at most 50
    /// room entities, a context size of 20 messages, a 0.5 confidence floor,
    /// and metrics enabled.
    fn default() -> Self {
        let five_minutes_in_secs = 300;
        Self {
            enable_metrics: true,
            min_confidence: 0.5,
            context_message_count: 20,
            max_entities: 50,
            cache_ttl: five_minutes_in_secs,
            model_type: ModelType::TextSmall,
            use_llm: true,
        }
    }
}
/// One memoized resolution outcome stored in `ENTITY_CACHE`.
#[derive(Debug, Clone)]
pub(crate) struct CacheEntry {
    /// Resolved entity, or `None` when resolution found nothing.
    pub(crate) entity: Option<Entity>,
    /// When the entry was written; compared against the configured TTL.
    pub(crate) timestamp: Instant,
    /// Confidence the entry was stored with; compared against `min_confidence` on reads.
    pub(crate) confidence: f32,
}
/// Shared, lock-protected map from cache key to resolution outcome.
type EntityCache = Arc<RwLock<HashMap<String, CacheEntry>>>;
/// Process-wide entity-resolution cache, lazily initialized on first use.
/// Being a `static`, it is shared by every caller in the process.
static ENTITY_CACHE: once_cell::sync::Lazy<EntityCache> =
    once_cell::sync::Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
/// Prompt rendered by `TemplateEngine` to ask the model which entity a message
/// refers to.
///
/// Placeholders filled by the resolver: `senderName`, `senderId`, `agentName`,
/// `agentId`, `entitiesInRoom` (inside an `{{#if}}` section) and `recentMessages`.
/// The reply format matches what `parse_entity_resolution_xml` reads. Note that
/// the prompt does not request a `<confidence>` tag, so the parser's default
/// confidence applies unless the model volunteers one.
const ENTITY_RESOLUTION_TEMPLATE: &str = r#"# Task: Resolve Entity Name
Message Sender: {{senderName}} (ID: {{senderId}})
Agent: {{agentName}} (ID: {{agentId}})
# Entities in Room:
{{#if entitiesInRoom}}
{{entitiesInRoom}}
{{/if}}
{{recentMessages}}
# Instructions:
1. Analyze the context to identify which entity is being referenced
2. Consider special references like "me" (the message sender) or "you" (agent the message is directed to)
3. Look for usernames/handles in standard formats (e.g. @username, user#1234)
4. Consider context from recent messages for pronouns and references
5. If multiple matches exist, use context to disambiguate
6. Consider recent interactions and relationship strength when resolving ambiguity
Do NOT include any thinking, reasoning, or <think> sections in your response.
Go directly to the XML response format without any preamble or explanation.
Return an XML response with:
<response>
<entityId>exact-id-if-known-otherwise-null</entityId>
<type>EXACT_MATCH | USERNAME_MATCH | NAME_MATCH | RELATIONSHIP_MATCH | AMBIGUOUS | UNKNOWN</type>
<matches>
<match>
<name>matched-name</name>
<reason>why this entity matches</reason>
</match>
</matches>
</response>
IMPORTANT: Your response must ONLY contain the <response></response> XML block above. Do not include any text, thinking, or reasoning before or after this XML block. Start your response immediately with <response> and end with </response>."#;
#[instrument(skip(xml), level = "debug")]
pub(crate) fn parse_entity_resolution_xml(xml: &str) -> Result<EntityResolution> {
debug!("Parsing entity resolution XML");
let entity_id = extract_xml_tag(xml, "entityId");
let match_type_str = extract_xml_tag(xml, "type").unwrap_or_else(|| "UNKNOWN".to_string());
let match_type = match match_type_str.to_uppercase().as_str() {
"EXACT_MATCH" => MatchType::ExactMatch,
"USERNAME_MATCH" => MatchType::UsernameMatch,
"NAME_MATCH" => MatchType::NameMatch,
"RELATIONSHIP_MATCH" => MatchType::RelationshipMatch,
"AMBIGUOUS" => MatchType::Ambiguous,
_ => MatchType::Unknown,
};
let confidence = extract_xml_tag(xml, "confidence")
.and_then(|c| c.parse::<f32>().ok())
.unwrap_or(0.5);
let mut matches = Vec::new();
if let Some(matches_section) = extract_xml_section(xml, "matches") {
let match_blocks = extract_xml_sections(&matches_section, "match");
for block in match_blocks {
if let (Some(name), Some(reason)) = (
extract_xml_tag(&block, "name"),
extract_xml_tag(&block, "reason"),
) {
let entity_id = extract_xml_tag(&block, "entityId")
.and_then(|id| uuid::Uuid::parse_str(&id).ok());
matches.push(EntityMatch {
name,
reason,
entity_id,
});
}
}
}
let entity_id = entity_id.and_then(|id| {
if id == "null" || id.is_empty() {
None
} else {
match uuid::Uuid::parse_str(&id) {
Ok(uuid) => Some(uuid),
Err(e) => {
warn!("Failed to parse entity ID '{}': {}", id, e);
None
}
}
}
});
debug!(
"Parsed resolution: match_type={:?}, confidence={}, matches={}",
match_type,
confidence,
matches.len()
);
Ok(EntityResolution {
entity_id,
match_type,
matches,
confidence,
})
}
/// Returns the whitespace-trimmed text between the first `<tag>` and the first
/// `</tag>` that follows it, or `None` if either marker is missing.
///
/// This is plain substring matching, not an XML parser: attributes, nesting and
/// entities are not understood.
pub(crate) fn extract_xml_tag(xml: &str, tag: &str) -> Option<String> {
    let opening = format!("<{}>", tag);
    let closing = format!("</{}>", tag);
    let (_, after_opening) = xml.split_once(opening.as_str())?;
    let (inner, _) = after_opening.split_once(closing.as_str())?;
    Some(inner.trim().to_owned())
}
/// Returns the raw (untrimmed) text between the first `<tag>` and the first
/// `</tag>` after it, or `None` if either marker is missing.
///
/// Unlike `extract_xml_tag`, surrounding whitespace is preserved so that the
/// result can be scanned again for nested elements.
pub(crate) fn extract_xml_section(xml: &str, tag: &str) -> Option<String> {
    let opening = format!("<{}>", tag);
    let closing = format!("</{}>", tag);
    xml.split_once(opening.as_str())
        .and_then(|(_, after_opening)| after_opening.split_once(closing.as_str()))
        .map(|(inner, _)| inner.to_string())
}
/// Collects the raw contents of every non-overlapping `<tag>…</tag>` pair,
/// scanning left to right.
///
/// Scanning stops at the first `<tag>` that has no matching `</tag>` after it.
/// Nesting is not understood: each opening marker pairs with the next closing one.
pub(crate) fn extract_xml_sections(xml: &str, tag: &str) -> Vec<String> {
    let opening = format!("<{}>", tag);
    let closing = format!("</{}>", tag);
    let mut collected = Vec::new();
    let mut rest = xml;
    while let Some((_, after_opening)) = rest.split_once(opening.as_str()) {
        match after_opening.split_once(closing.as_str()) {
            Some((inner, tail)) => {
                collected.push(inner.to_string());
                rest = tail;
            }
            None => break,
        }
    }
    collected
}
/// Builds the key used to memoize an entity-resolution request.
///
/// The key combines:
/// * the room ID,
/// * the sending entity's ID,
/// * the agent ID stored on the memory,
/// * a 64-bit hash of the full message text,
/// * the optional `entityName` hint from `state`.
///
/// The agent ID is included because the cache is a process-global static shared
/// by every runtime, and the "you" pronoun resolves to the agent, so results differ
/// per agent. The whole text is hashed instead of keeping a 50-character prefix,
/// because different messages that share a prefix must not share a cache entry.
pub(crate) fn generate_cache_key(message: &Memory, state: &State) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let entity_name = state
        .values
        .get("entityName")
        .map(|s| s.as_str())
        .unwrap_or("");

    // `DefaultHasher::new()` uses fixed keys, so the digest is stable for the
    // lifetime of this in-memory cache.
    let mut hasher = DefaultHasher::new();
    message.content.text.hash(&mut hasher);
    let text_digest = hasher.finish();

    format!(
        "{}:{}:{}:{:016x}:{}",
        message.room_id, message.entity_id, message.agent_id, text_digest, entity_name
    )
}
/// Evicts every entry whose age is `ttl_seconds` or more; a TTL of 0 empties the cache.
///
/// A poisoned lock is recovered rather than turned into a panic. The cache's
/// critical sections only insert, read or retain whole entries, so a panic in
/// another lock holder cannot leave an entry half-written. Without this, one
/// panic would make every later resolution panic too.
pub(crate) fn clean_cache(cache: &EntityCache, ttl_seconds: u64) {
    let ttl = Duration::from_secs(ttl_seconds);
    let now = Instant::now();
    let mut entries = cache
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    entries.retain(|_, entry| now.duration_since(entry.timestamp) < ttl);
}
pub(crate) fn get_cached_entity(
cache_key: &str,
config: &EntityResolutionConfig,
) -> Option<Option<Entity>> {
if config.cache_ttl == 0 {
return None;
}
let cache = ENTITY_CACHE.read().unwrap();
if let Some(entry) = cache.get(cache_key) {
if entry.timestamp.elapsed().as_secs() < config.cache_ttl {
if entry.confidence >= config.min_confidence {
debug!(
"Cache hit for entity resolution (confidence: {})",
entry.confidence
);
return Some(entry.entity.clone());
} else {
debug!("Cache entry found but below confidence threshold");
}
} else {
debug!("Cache entry expired");
}
}
None
}
/// Stores a resolution outcome (possibly `None`) under `cache_key`, stamped
/// with the current time.
///
/// This is a no-op when caching is disabled (`cache_ttl == 0`). An existing entry
/// for the same key is overwritten. A poisoned lock is recovered instead of
/// panicking.
pub(crate) fn cache_entity(
    cache_key: String,
    entity: Option<Entity>,
    confidence: f32,
    config: &EntityResolutionConfig,
) {
    if config.cache_ttl == 0 {
        return;
    }
    let entry = CacheEntry {
        entity,
        timestamp: Instant::now(),
        confidence,
    };
    ENTITY_CACHE
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .insert(cache_key, entry);
    debug!("Cached entity resolution (confidence: {})", confidence);
}
async fn call_llm_for_entity_resolution(
agent_runtime: &AgentRuntime,
prompt: &str,
model_type: ModelType,
) -> Result<String> {
let models = agent_runtime.models.read().unwrap();
let model_type_str = match model_type {
ModelType::TextSmall => "TEXT_SMALL",
ModelType::TextMedium => "TEXT_MEDIUM",
ModelType::TextLarge => "TEXT_LARGE",
_ => "TEXT_SMALL", };
let handlers = models.get(model_type_str);
if let Some(handlers) = handlers {
if handlers.is_empty() {
warn!("No model handlers registered for {}", model_type_str);
return Err(ZoeyError::Model(format!(
"No model handlers for {}",
model_type_str
)));
}
let provider = &handlers[0];
info!(
"Using LLM provider for entity resolution: {} (priority: {})",
provider.name, provider.priority
);
let (model_name, temperature, max_tokens) = {
let model = if provider.name.to_lowercase().contains("openai") {
agent_runtime
.get_setting("OPENAI_MODEL")
.and_then(|v| v.as_str().map(|s| s.to_string()))
} else if provider.name.to_lowercase().contains("anthropic")
|| provider.name.to_lowercase().contains("claude")
{
agent_runtime
.get_setting("ANTHROPIC_MODEL")
.and_then(|v| v.as_str().map(|s| s.to_string()))
} else {
agent_runtime
.get_setting("LOCAL_LLM_MODEL")
.and_then(|v| v.as_str().map(|s| s.to_string()))
};
let temp = agent_runtime
.get_setting("temperature")
.and_then(|v| v.as_f64().map(|f| f as f32))
.unwrap_or(0.3);
let tokens = agent_runtime
.get_setting("max_tokens")
.and_then(|v| v.as_u64().map(|u| u as usize))
.unwrap_or(300);
(model, temp, tokens)
};
let params = GenerateTextParams {
prompt: prompt.to_string(),
max_tokens: Some(max_tokens),
temperature: Some(temperature),
top_p: None,
stop: Some(vec!["</response>".to_string()]),
model: model_name,
frequency_penalty: None,
presence_penalty: None,
};
let model_params = ModelHandlerParams {
runtime: Arc::new(()) as Arc<dyn std::any::Any + Send + Sync>,
params,
};
debug!(
"Calling LLM for entity resolution (temp: {}, max_tokens: {})",
temperature, max_tokens
);
match (provider.handler)(model_params).await {
Ok(response) => {
info!(
"✓ LLM entity resolution response received ({} chars)",
response.len()
);
Ok(response)
}
Err(e) => {
error!("LLM model handler failed: {}", e);
Err(e)
}
}
} else {
warn!("No model handlers found for {}", model_type_str);
Err(ZoeyError::Model(format!(
"No model handlers for {}",
model_type_str
)))
}
}
/// Resolves the entity referenced by `message` using
/// [`EntityResolutionConfig::default`].
///
/// Convenience wrapper around [`find_entity_by_name_with_config`]; see that
/// function for the resolution order and error conditions.
pub async fn find_entity_by_name(
    runtime: Arc<dyn std::any::Any + Send + Sync>,
    message: &Memory,
    state: &State,
) -> Result<Option<Entity>> {
    let default_config = EntityResolutionConfig::default();
    find_entity_by_name_with_config(runtime, message, state, &default_config).await
}
/// Resolves which entity in the message's room the message refers to.
///
/// Strategies are tried in this order; the first one that succeeds wins:
/// 1. The process-wide cache, keyed by `generate_cache_key`.
/// 2. The LLM (when `config.use_llm`). The returned `<entityId>` is looked up
///    among the room's entities; if there is none, the first `<match>` name is
///    compared case-insensitively against entity names and usernames.
/// 3. The `state.values["entityName"]` hint, matched as a substring of an entity's
///    name or username. Blank hints are ignored.
/// 4. `@username` mentions in the message text.
/// 5. Entity names (longer than 2 bytes) mentioned as whole terms in the text.
/// 6. Pronouns: "you"/"your" resolve to the agent and "me"/"my"/"i" to the
///    sender, each only as whole words.
/// 7. Relationship scoring (inert for now: no relationships are loaded).
///
/// Every outcome, including "nothing found", is written to the cache when
/// caching is enabled. Up to `config.context_message_count` recent room messages
/// are included in the prompt.
///
/// # Errors
/// Returns an error when any of the following happens:
/// * `runtime` is not an `Arc<RuntimeRef>` or the runtime has been dropped;
/// * the runtime or adapter lock is poisoned;
/// * no database adapter is configured;
/// * the room cannot be loaded;
/// * the room's entities cannot be listed;
/// * the prompt fails to render.
///
/// LLM failures are logged, and resolution falls through to the heuristics.
#[instrument(skip(runtime, message, state, config), fields(
    message_id = %message.id,
    room_id = %message.room_id,
    entity_id = %message.entity_id
), level = "info")]
pub async fn find_entity_by_name_with_config(
    runtime: Arc<dyn std::any::Any + Send + Sync>,
    message: &Memory,
    state: &State,
    config: &EntityResolutionConfig,
) -> Result<Option<Entity>> {
    let start_time = Instant::now();
    info!("Starting entity resolution");
    clean_cache(&ENTITY_CACHE, config.cache_ttl);
    let cache_key = generate_cache_key(message, state);
    if let Some(cached) = get_cached_entity(&cache_key, config) {
        info!(
            "Entity resolution cache hit ({}ms)",
            start_time.elapsed().as_millis()
        );
        return Ok(cached);
    }
    debug!("Cache miss, proceeding with resolution");

    let runtime_arc = if let Some(runtime_ref) = downcast_runtime_ref(&runtime) {
        runtime_ref.try_upgrade().ok_or_else(|| {
            error!("Runtime has been dropped");
            ZoeyError::Runtime("Runtime has been dropped".to_string())
        })?
    } else {
        error!("Runtime must be passed as Arc<RuntimeRef>");
        return Err(ZoeyError::Runtime(
            "Runtime must be passed as Arc<RuntimeRef>. Use RuntimeRef::new() to wrap the runtime."
                .to_string(),
        ));
    };
    // NOTE(review): both guards below come from `read()` calls that return a
    // `Result`, which suggests synchronous locks. They are held across every
    // `.await` that follows, which blocks writers for the whole resolution and
    // makes this future `!Send`. Left as-is because releasing them early needs
    // an owned adapter handle, whose type is not visible here.
    let agent_runtime = runtime_arc.read().map_err(|e| {
        error!("Failed to lock runtime: {}", e);
        ZoeyError::Runtime(format!("Failed to lock runtime: {}", e))
    })?;
    let adapter_lock = agent_runtime.adapter.read().map_err(|e| {
        error!("Failed to lock adapter: {}", e);
        ZoeyError::Runtime(format!("Failed to lock adapter: {}", e))
    })?;
    let adapter = adapter_lock.as_ref().ok_or_else(|| {
        warn!("No database adapter configured");
        ZoeyError::Database("No database adapter configured".to_string())
    })?;
    let agent_id = agent_runtime.agent_id;
    let agent_name = agent_runtime.character.name.clone();
    debug!("Resolving entity for agent: {} ({})", agent_name, agent_id);

    // Prefer a room already present in state; otherwise load it from the database.
    let room = if let Some(room_value) = state.data.get("room") {
        match serde_json::from_value::<Room>(room_value.clone()) {
            Ok(r) => {
                debug!("Using room from state");
                Some(r)
            }
            Err(e) => {
                warn!("Failed to deserialize room from state: {}", e);
                None
            }
        }
    } else {
        None
    };
    let room = if let Some(r) = room {
        r
    } else {
        debug!("Fetching room from database: {}", message.room_id);
        adapter.get_room(message.room_id).await?.ok_or_else(|| {
            error!("Room not found: {}", message.room_id);
            ZoeyError::NotFound(format!("Room {} not found", message.room_id))
        })?
    };
    debug!("Room: {} (world: {})", room.name, room.world_id);

    // The world is optional context; failing to load it is not fatal.
    let world = match adapter.get_world(room.world_id).await {
        Ok(Some(w)) => {
            debug!("Loaded world: {}", w.name);
            Some(w)
        }
        Ok(None) => {
            debug!("World not found: {}", room.world_id);
            None
        }
        Err(e) => {
            warn!("Failed to load world: {}", e);
            None
        }
    };

    let entities_in_room = match adapter.get_entities_for_room(room.id, true).await {
        Ok(entities) => {
            debug!("Found {} entities in room", entities.len());
            if entities.len() > config.max_entities {
                warn!(
                    "Too many entities ({}), limiting to {}",
                    entities.len(),
                    config.max_entities
                );
                entities.into_iter().take(config.max_entities).collect()
            } else {
                entities
            }
        }
        Err(e) => {
            error!("Failed to get entities for room: {}", e);
            return Err(e);
        }
    };

    if let Some(ref world) = world {
        // Roles are parsed for future permission filtering but are not applied yet.
        let world_roles = parse_world_roles(world.metadata.get("roles"));
        debug!(
            "Loaded {} entities ({} world roles defined; role filtering not applied)",
            entities_in_room.len(),
            world_roles.len()
        );
    }

    // Relationships are not loaded yet, so the relationship strategy below is inert.
    let relationships: Vec<Relationship> = vec![];

    // Name shown for an entity ID: its name, then its username, then "Unknown".
    let display_name_for = |id: UUID| -> String {
        entities_in_room
            .iter()
            .find(|e| e.id == id)
            .and_then(|e| e.name.clone().or_else(|| e.username.clone()))
            .unwrap_or_else(|| "Unknown".to_string())
    };

    let engine = TemplateEngine::new();
    let mut template_data: HashMap<String, serde_json::Value> = HashMap::new();
    let sender_name = display_name_for(message.entity_id);
    template_data.insert("senderName".to_string(), serde_json::json!(sender_name));
    template_data.insert(
        "senderId".to_string(),
        serde_json::json!(message.entity_id.to_string()),
    );
    template_data.insert("agentName".to_string(), serde_json::json!(agent_name));
    template_data.insert(
        "agentId".to_string(),
        serde_json::json!(agent_id.to_string()),
    );
    template_data.insert(
        "entitiesInRoom".to_string(),
        serde_json::json!(format_entities(&entities_in_room)),
    );

    // Missing history only weakens the prompt, so log the error and continue.
    let recent_messages = adapter
        .get_memories(MemoryQuery {
            room_id: Some(message.room_id),
            agent_id: Some(agent_id),
            count: Some(config.context_message_count as _),
            unique: Some(false),
            ..Default::default()
        })
        .await
        .unwrap_or_else(|e| {
            warn!("Failed to load recent messages for entity resolution: {}", e);
            Vec::new()
        });
    let messages_str = recent_messages
        .iter()
        .map(|m| format!("{}: {}", display_name_for(m.entity_id), m.content.text))
        .collect::<Vec<_>>()
        .join("\n");
    template_data.insert(
        "recentMessages".to_string(),
        serde_json::json!(messages_str),
    );

    let prompt = engine.render(ENTITY_RESOLUTION_TEMPLATE, &template_data)?;
    debug!(
        "Generated entity resolution prompt ({} chars)",
        prompt.len()
    );

    let (resolved_entity, confidence): (Option<Entity>, f32) = if config.use_llm {
        debug!("Attempting LLM-based entity resolution");
        match call_llm_for_entity_resolution(&agent_runtime, &prompt, config.model_type).await {
            Ok(llm_response) => {
                debug!("LLM response received ({} chars)", llm_response.len());
                match parse_entity_resolution_xml(&llm_response) {
                    Ok(resolution) => {
                        debug!(
                            "LLM resolution parsed: match_type={:?}, confidence={}",
                            resolution.match_type, resolution.confidence
                        );
                        (
                            entity_from_llm_resolution(&resolution, &entities_in_room).cloned(),
                            resolution.confidence,
                        )
                    }
                    Err(e) => {
                        warn!("Failed to parse LLM entity resolution: {}", e);
                        (None, 0.0)
                    }
                }
            }
            Err(e) => {
                warn!("LLM call failed for entity resolution: {}", e);
                (None, 0.0)
            }
        }
    } else {
        debug!("LLM resolution disabled, using fallback strategies only");
        (None, 0.0)
    };
    if let Some(entity) = resolved_entity {
        info!(
            "Entity resolved via LLM (confidence: {}, {}ms)",
            confidence,
            start_time.elapsed().as_millis()
        );
        cache_entity(cache_key, Some(entity.clone()), confidence, config);
        return Ok(Some(entity));
    }

    debug!("Using fallback resolution strategies");
    let text = message.content.text.to_lowercase();
    if let Some((entity, heuristic_confidence, via)) =
        resolve_with_heuristics(&entities_in_room, state, &text, agent_id, message.entity_id)
    {
        info!(
            "Entity resolved via {} ({}ms)",
            via,
            start_time.elapsed().as_millis()
        );
        let result = Some(entity.clone());
        cache_entity(cache_key, result.clone(), heuristic_confidence, config);
        return Ok(result);
    }

    if !relationships.is_empty() {
        debug!("Trying relationship-based resolution");
        let interaction_data = get_recent_interactions(
            message.entity_id,
            &entities_in_room,
            room.id,
            &recent_messages,
            &relationships,
        );
        if let Some((entity, _, score)) = interaction_data.first() {
            if *score > 0 {
                info!(
                    "Entity resolved via relationships (score: {}, {}ms)",
                    score,
                    start_time.elapsed().as_millis()
                );
                let result = Some(entity.clone());
                cache_entity(cache_key, result.clone(), 0.5, config);
                return Ok(result);
            }
        }
    }

    info!(
        "No entity resolved ({}ms)",
        start_time.elapsed().as_millis()
    );
    cache_entity(cache_key, None, 0.0, config);
    Ok(None)
}

/// Maps a parsed LLM resolution onto one of the room's entities.
///
/// A top-level entity ID is authoritative: if it is not in the room, the result
/// is `None` and the match list is not consulted. Without an ID, the first match
/// name is compared case-insensitively with entity names and usernames.
fn entity_from_llm_resolution<'a>(
    resolution: &EntityResolution,
    entities: &'a [Entity],
) -> Option<&'a Entity> {
    if let Some(entity_id) = resolution.entity_id {
        let found = entities.iter().find(|e| e.id == entity_id);
        if found.is_none() {
            debug!("Entity ID from LLM not found in room");
        }
        return found;
    }
    let match_name = resolution.matches.first()?.name.to_lowercase();
    entities.iter().find(|e| {
        e.name
            .as_deref()
            .map_or(false, |n| n.to_lowercase() == match_name)
            || e.username
                .as_deref()
                .map_or(false, |u| u.to_lowercase() == match_name)
    })
}

/// Heuristic fallbacks used when the LLM path resolves nothing.
///
/// `text` must already be lowercased. Returns the entity, the confidence to
/// cache it with, and a short label used in the "Entity resolved via …" log line.
fn resolve_with_heuristics<'a>(
    entities: &'a [Entity],
    state: &State,
    text: &str,
    agent_id: UUID,
    sender_id: UUID,
) -> Option<(&'a Entity, f32, &'static str)> {
    // 1. Explicit hint from the caller. A blank hint would be a substring of every
    //    name, which would resolve to an arbitrary first entity, so skip it.
    if let Some(entity_name) = state.values.get("entityName") {
        let query = entity_name.trim().to_lowercase();
        if query.is_empty() {
            debug!("Ignoring empty entityName hint");
        } else {
            debug!("Trying state-based resolution with hint: {}", entity_name);
            for entity in entities {
                if entity
                    .name
                    .as_deref()
                    .map_or(false, |n| n.to_lowercase().contains(&query))
                {
                    return Some((entity, 0.8, "state hint"));
                }
                if entity
                    .username
                    .as_deref()
                    .map_or(false, |u| u.to_lowercase().contains(&query))
                {
                    return Some((entity, 0.8, "state hint username"));
                }
            }
        }
    }

    // 2. @username mentions, as whole terms so that "@bob" does not match "@bobby".
    debug!("Trying mention-based resolution");
    if let Some(entity) = entities.iter().find(|e| {
        e.username.as_deref().map_or(false, |u| {
            !u.is_empty() && contains_term(text, &format!("@{}", u.to_lowercase()))
        })
    }) {
        return Some((entity, 0.9, "@mention"));
    }

    // 3. Plain name mentions, as whole terms ("ann" must not match "announce").
    debug!("Trying name-based resolution");
    if let Some(entity) = entities.iter().find(|e| {
        e.name
            .as_deref()
            .map_or(false, |n| n.len() > 2 && contains_term(text, &n.to_lowercase()))
    }) {
        return Some((entity, 0.7, "name mention"));
    }

    // 4. Pronouns, as whole words: "me" must not fire on "time" or "message".
    debug!("Trying pronoun-based resolution");
    if ["you", "your"].iter().any(|w| contains_term(text, w)) {
        if let Some(entity) = entities.iter().find(|e| e.id == agent_id) {
            return Some((entity, 0.6, "pronoun 'you' -> agent"));
        }
    }
    if ["me", "my", "i"].iter().any(|w| contains_term(text, w)) {
        if let Some(entity) = entities.iter().find(|e| e.id == sender_id) {
            return Some((entity, 0.6, "pronoun 'me' -> sender"));
        }
    }
    None
}

/// Returns `true` when `needle` occurs in `haystack` with a non-word character
/// (anything except alphanumerics and `_`) or a string boundary on both sides.
///
/// An empty `needle` never matches.
fn contains_term(haystack: &str, needle: &str) -> bool {
    if needle.is_empty() {
        return false;
    }
    let is_word_char = |c: char| c.is_alphanumeric() || c == '_';
    haystack.match_indices(needle).any(|(start, matched)| {
        let boundary_before =
            !matches!(haystack[..start].chars().next_back(), Some(c) if is_word_char(c));
        let boundary_after = !matches!(
            haystack[start + matched.len()..].chars().next(),
            Some(c) if is_word_char(c)
        );
        boundary_before && boundary_after
    })
}

/// Parses a `{ "<uuid>": "OWNER" | "ADMIN" | "MODERATOR" | "MOD" | "MEMBER" }`
/// role map from world metadata.
///
/// Keys that are not UUIDs and values that are not strings are skipped; other
/// role strings map to `Role::None`.
fn parse_world_roles(roles_value: Option<&serde_json::Value>) -> HashMap<UUID, Role> {
    roles_value
        .and_then(|value| value.as_object())
        .map(|roles| {
            roles
                .iter()
                .filter_map(|(key, value)| {
                    let id = uuid::Uuid::parse_str(key).ok()?;
                    let role = match value.as_str()?.to_uppercase().as_str() {
                        "OWNER" => Role::Owner,
                        "ADMIN" => Role::Admin,
                        "MODERATOR" | "MOD" => Role::Moderator,
                        "MEMBER" => Role::Member,
                        _ => Role::None,
                    };
                    Some((id, role))
                })
                .collect()
        })
        .unwrap_or_default()
}
/// Derives a per-agent UUID for an external user ID.
///
/// If `base_user_id` equals the agent's own ID (in string form), the agent ID is
/// returned unchanged. Otherwise the string `"{base_user_id}:{agent_id}"` is
/// passed to [`string_to_uuid`].
pub fn create_unique_uuid_for_entity(agent_id: UUID, base_user_id: &str) -> UUID {
    let agent_str = agent_id.to_string();
    if agent_str == base_user_id {
        agent_id
    } else {
        string_to_uuid(&format!("{}:{}", base_user_id, agent_str))
    }
}
/// Builds one detail map (`id`, `name`, `data`) per distinct entity.
///
/// Entities are deduplicated by `id`; the first occurrence wins. Output order
/// follows first appearance in `entities`. The previous implementation returned
/// `HashMap::into_values()`, whose order was arbitrary and could differ between
/// calls.
///
/// * `name` falls back to `username`, then to `"Unknown"`.
/// * `data` is the entity's metadata as JSON, or `{}` if serialization fails.
///
/// `_room` is currently unused.
pub fn get_entity_details(
    _room: &Room,
    entities: &[Entity],
) -> Vec<HashMap<String, serde_json::Value>> {
    let mut seen: std::collections::HashSet<UUID> =
        std::collections::HashSet::with_capacity(entities.len());
    entities
        .iter()
        // `insert` returns false for IDs already emitted, dropping later duplicates.
        .filter(|entity| seen.insert(entity.id))
        .map(|entity| {
            let name = entity
                .name
                .clone()
                .or_else(|| entity.username.clone())
                .unwrap_or_else(|| "Unknown".to_string());
            let data = serde_json::to_value(&entity.metadata)
                .unwrap_or_else(|_| serde_json::json!({}));
            let mut detail = HashMap::with_capacity(3);
            detail.insert("id".to_string(), serde_json::json!(entity.id.to_string()));
            detail.insert("name".to_string(), serde_json::json!(name));
            detail.insert("data".to_string(), data);
            detail
        })
        .collect()
}
/// Renders entities as prompt text, one block per entity, separated by blank lines.
///
/// Each block is `"<name>"` followed by `ID: <id>`; when the metadata is
/// non-empty and serializes, a `Data: <json>` line follows. Every block ends with
/// a newline. The display name falls back to the username, then to `Unknown`.
pub fn format_entities(entities: &[Entity]) -> String {
    let mut blocks: Vec<String> = Vec::with_capacity(entities.len());
    for entity in entities {
        let display_name = entity
            .name
            .as_deref()
            .or(entity.username.as_deref())
            .unwrap_or("Unknown");
        let serialized_metadata = if entity.metadata.is_empty() {
            None
        } else {
            serde_json::to_string(&entity.metadata).ok()
        };
        let block = match serialized_metadata {
            Some(data) => format!(
                "\"{}\"\nID: {}\nData: {}\n",
                display_name, entity.id, data
            ),
            None => format!("\"{}\"\nID: {}\n", display_name, entity.id),
        };
        blocks.push(block);
    }
    blocks.join("\n")
}
/// Scores each candidate entity by how much it appears to interact with
/// `source_entity_id`.
///
/// A candidate's score is the sum of:
/// * the relationship's `interactions` metadata count (for a relationship in
///   either direction between the two entities), and
/// * the number of `recent_messages` authored by either the source or the candidate.
///
/// Each result carries up to 5 of those messages, in reverse of their order in
/// `recent_messages`. Results are sorted by descending score; the sort is stable,
/// so ties keep the input order. `_room_id` is currently unused.
///
/// NOTE(review): messages authored by the source count toward every candidate
/// equally, so they add a constant offset rather than discriminating between
/// candidates. Confirm whether only replies between the pair were intended.
pub fn get_recent_interactions(
    source_entity_id: UUID,
    candidate_entities: &[Entity],
    _room_id: UUID,
    recent_messages: &[Memory],
    relationships: &[Relationship],
) -> Vec<(Entity, Vec<Memory>, usize)> {
    let mut results: Vec<(Entity, Vec<Memory>, usize)> = candidate_entities
        .iter()
        .map(|entity| {
            let involves_pair = |msg: &&Memory| {
                msg.entity_id == source_entity_id || msg.entity_id == entity.id
            };
            // Count every matching message, but clone only the few that are kept.
            let direct_count = recent_messages.iter().filter(involves_pair).count();
            let latest: Vec<Memory> = recent_messages
                .iter()
                .filter(involves_pair)
                .rev()
                .take(5)
                .cloned()
                .collect();
            let relationship_score = relationships
                .iter()
                .find(|rel| {
                    (rel.entity_id_a == source_entity_id && rel.entity_id_b == entity.id)
                        || (rel.entity_id_b == source_entity_id && rel.entity_id_a == entity.id)
                })
                .and_then(|rel| rel.metadata.get("interactions"))
                .and_then(|count| count.as_u64())
                // Saturate rather than truncate on targets where usize is narrower than u64.
                .map(|count| usize::try_from(count).unwrap_or(usize::MAX))
                .unwrap_or(0);
            (
                entity.clone(),
                latest,
                relationship_score.saturating_add(direct_count),
            )
        })
        .collect();
    results.sort_by(|a, b| b.2.cmp(&a.2));
    results
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers in this module: UUID derivation, entity
    //! formatting, XML extraction and parsing, interaction scoring,
    //! configuration defaults and cache-key generation.
    use super::*;
    use crate::types::Metadata;
    use uuid::Uuid;

    /// Builds an entity with empty metadata and a fixed `created_at`.
    fn make_entity(id: Uuid, agent_id: Uuid, name: &str, username: Option<&str>) -> Entity {
        Entity {
            id,
            agent_id,
            name: Some(name.to_string()),
            username: username.map(|u| u.to_string()),
            email: None,
            avatar_url: None,
            metadata: Metadata::new(),
            created_at: Some(12345),
        }
    }

    #[test]
    fn test_create_unique_uuid_for_entity() {
        let agent_id = Uuid::new_v4();
        let first = create_unique_uuid_for_entity(agent_id, "user123");
        let second = create_unique_uuid_for_entity(agent_id, "user123");
        assert_eq!(first, second, "derivation must be deterministic");
        assert_ne!(first, agent_id);
        assert_eq!(
            create_unique_uuid_for_entity(agent_id, &agent_id.to_string()),
            agent_id,
            "the agent's own ID maps to itself"
        );
    }

    #[test]
    fn test_format_entities() {
        let entity = make_entity(Uuid::new_v4(), Uuid::new_v4(), "Test User", Some("testuser"));
        let rendered = format_entities(std::slice::from_ref(&entity));
        assert!(rendered.contains("Test User"));
        assert!(rendered.contains(&entity.id.to_string()));
    }

    #[test]
    fn test_extract_xml_tag() {
        let xml = "<response><entityId>12345</entityId><type>EXACT_MATCH</type></response>";
        assert_eq!(extract_xml_tag(xml, "entityId").as_deref(), Some("12345"));
        assert_eq!(extract_xml_tag(xml, "type").as_deref(), Some("EXACT_MATCH"));
        assert!(extract_xml_tag(xml, "missing").is_none());
    }

    #[test]
    fn test_extract_xml_section() {
        let xml = r#"<response>
<matches>
<match><name>John</name></match>
<match><name>Jane</name></match>
</matches>
</response>"#;
        let section = extract_xml_section(xml, "matches").expect("matches section present");
        for expected in ["<match>", "John", "Jane"] {
            assert!(section.contains(expected), "missing {expected:?}");
        }
    }

    #[test]
    fn test_extract_xml_sections() {
        let xml = r#"<matches>
<match><name>John</name><reason>First match</reason></match>
<match><name>Jane</name><reason>Second match</reason></match>
</matches>"#;
        let sections = extract_xml_sections(xml, "match");
        assert_eq!(sections.len(), 2);
        assert!(sections[0].contains("John"));
        assert!(sections[1].contains("Jane"));
    }

    #[test]
    fn test_parse_entity_resolution_xml() {
        let xml = r#"<response>
<entityId>550e8400-e29b-41d4-a716-446655440000</entityId>
<type>EXACT_MATCH</type>
<matches>
<match>
<name>John Doe</name>
<reason>Exact ID match</reason>
</match>
</matches>
</response>"#;
        let resolution = parse_entity_resolution_xml(xml).expect("parse succeeds");
        assert!(resolution.entity_id.is_some());
        assert_eq!(resolution.match_type, MatchType::ExactMatch);
        assert_eq!(resolution.matches.len(), 1);
        let only = &resolution.matches[0];
        assert_eq!(only.name, "John Doe");
        assert_eq!(only.reason, "Exact ID match");
    }

    #[test]
    fn test_parse_entity_resolution_xml_no_id() {
        let xml = r#"<response>
<entityId>null</entityId>
<type>UNKNOWN</type>
<matches></matches>
</response>"#;
        let resolution = parse_entity_resolution_xml(xml).expect("parse succeeds");
        assert!(resolution.entity_id.is_none());
        assert_eq!(resolution.match_type, MatchType::Unknown);
        assert!(resolution.matches.is_empty());
    }

    #[test]
    fn test_get_entity_details() {
        let room_agent = Uuid::new_v4();
        let room = Room {
            id: Uuid::new_v4(),
            agent_id: Some(room_agent),
            name: "Test Room".to_string(),
            source: "test".to_string(),
            channel_type: crate::types::ChannelType::GuildText,
            channel_id: None,
            server_id: None,
            world_id: Uuid::new_v4(),
            metadata: Metadata::new(),
            created_at: Some(12345),
        };
        let alice = make_entity(Uuid::new_v4(), room_agent, "Alice", Some("alice"));
        let bob = make_entity(Uuid::new_v4(), room_agent, "Bob", Some("bob"));
        let details = get_entity_details(&room, &[alice, bob]);
        assert_eq!(details.len(), 2);
        let names: Vec<&str> = details
            .iter()
            .filter_map(|detail| detail.get("name")?.as_str())
            .collect();
        assert!(names.contains(&"Alice"));
        assert!(names.contains(&"Bob"));
    }

    #[test]
    fn test_get_recent_interactions() {
        let source_entity_id = Uuid::new_v4();
        let target = make_entity(Uuid::new_v4(), Uuid::new_v4(), "Target", None);
        let messages: Vec<Memory> = Vec::new();
        let relationships: Vec<Relationship> = Vec::new();
        let ranked = get_recent_interactions(
            source_entity_id,
            std::slice::from_ref(&target),
            Uuid::new_v4(),
            &messages,
            &relationships,
        );
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].0.id, target.id);
        assert_eq!(ranked[0].2, 0, "no messages and no relationship means zero score");
    }

    #[test]
    fn test_get_recent_interactions_with_relationship() {
        let source_entity_id = Uuid::new_v4();
        let agent_id = Uuid::new_v4();
        let target = make_entity(Uuid::new_v4(), agent_id, "Target", None);
        let mut metadata = Metadata::new();
        metadata.insert("interactions".to_string(), serde_json::json!(5));
        let relationships = vec![Relationship {
            entity_id_a: source_entity_id,
            entity_id_b: target.id,
            relationship_type: "friend".to_string(),
            agent_id,
            metadata,
            created_at: Some(12345),
        }];
        let messages: Vec<Memory> = Vec::new();
        let ranked = get_recent_interactions(
            source_entity_id,
            std::slice::from_ref(&target),
            Uuid::new_v4(),
            &messages,
            &relationships,
        );
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].0.id, target.id);
        assert_eq!(ranked[0].2, 5, "score comes from relationship metadata");
    }

    #[test]
    fn test_entity_resolution_config() {
        let defaults = EntityResolutionConfig::default();
        assert!(defaults.use_llm);
        assert_eq!(defaults.cache_ttl, 300);
        assert_eq!(defaults.max_entities, 50);
        assert_eq!(defaults.min_confidence, 0.5);

        let custom = EntityResolutionConfig {
            use_llm: false,
            cache_ttl: 600,
            max_entities: 100,
            context_message_count: 50,
            min_confidence: 0.7,
            ..Default::default()
        };
        assert!(!custom.use_llm);
        assert_eq!(custom.cache_ttl, 600);
    }

    #[test]
    fn test_generate_cache_key() {
        let message = Memory {
            id: Uuid::new_v4(),
            entity_id: Uuid::new_v4(),
            agent_id: Uuid::new_v4(),
            room_id: Uuid::new_v4(),
            content: crate::types::Content {
                text: "Hello world".to_string(),
                ..Default::default()
            },
            embedding: None,
            metadata: None,
            created_at: 12345,
            unique: None,
            similarity: None,
        };
        let state = State::new();
        let original_key = generate_cache_key(&message, &state);
        assert_eq!(original_key, generate_cache_key(&message, &state));

        let mut other_sender = message.clone();
        other_sender.entity_id = Uuid::new_v4();
        assert_ne!(original_key, generate_cache_key(&other_sender, &state));
    }

    #[test]
    fn test_match_type() {
        assert_eq!(MatchType::ExactMatch, MatchType::ExactMatch);
        assert_ne!(MatchType::ExactMatch, MatchType::Unknown);
    }

    #[test]
    fn test_entity_resolution_parsing() {
        let xml = r#"<response>
<entityId>550e8400-e29b-41d4-a716-446655440000</entityId>
<type>EXACT_MATCH</type>
<confidence>0.95</confidence>
<matches>
<match>
<name>John Doe</name>
<reason>Exact ID match</reason>
<entityId>550e8400-e29b-41d4-a716-446655440000</entityId>
</match>
</matches>
</response>"#;
        let resolution = parse_entity_resolution_xml(xml).expect("parse succeeds");
        assert!(resolution.entity_id.is_some());
        assert_eq!(resolution.match_type, MatchType::ExactMatch);
        assert_eq!(resolution.confidence, 0.95);
        assert_eq!(resolution.matches.len(), 1);
        assert_eq!(resolution.matches[0].name, "John Doe");
    }
}