use crate::dynamic_prompts::{DynamicPromptExecutor, DynamicPromptOptions, SchemaRow};
use crate::error::Result;
use crate::types::*;
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc, RwLock, RwLockReadGuard, RwLockWriteGuard,
};
use std::time::{SystemTime, UNIX_EPOCH};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
/// Policy applied when an `RwLock` turns out to be poisoned, i.e. a thread
/// panicked while holding it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LockRecoveryStrategy {
    /// Always take the guard out of the poison error and keep going.
    AlwaysRecover,
    /// Never recover; report the poisoned lock as an error.
    AlwaysFail,
    /// Recover only while fewer than `max_recoveries` recoveries were recorded.
    RecoverWithLimit { max_recoveries: u64 },
    /// Same limit as `RecoverWithLimit`, plus a sleep before the guard is
    /// returned; the sleep starts at `initial_delay_ms` and doubles with each
    /// recorded recovery (capped at 2^10).
    RecoverWithBackoff {
        max_recoveries: u64,
        initial_delay_ms: u64,
    },
}

impl Default for LockRecoveryStrategy {
    /// The default is the strict variant: fail instead of continuing with
    /// state that may be inconsistent.
    fn default() -> Self {
        Self::AlwaysFail
    }
}
/// Shared counters describing every lock-poisoning event observed so far.
#[derive(Debug, Default)]
pub struct LockPoisonMetrics {
    /// Poisoned acquisitions seen, reads and writes combined.
    pub total_poisoned: AtomicU64,
    /// Times a poisoned lock was recovered.
    pub recoveries: AtomicU64,
    /// Times a poisoned lock was reported as an error.
    pub failures: AtomicU64,
    /// Poisoned read acquisitions.
    pub read_poisoned: AtomicU64,
    /// Poisoned write acquisitions.
    pub write_poisoned: AtomicU64,
    /// Unix timestamp (seconds) of the most recent poisoning, if any.
    pub last_poisoned_at: Arc<RwLock<Option<u64>>>,
    /// Poisoning count per lock name.
    pub lock_poison_counts: Arc<RwLock<HashMap<String, u64>>>,
}

impl LockPoisonMetrics {
    /// Creates metrics with all counters at zero, no timestamp and no per-lock
    /// counts.
    pub fn new() -> Self {
        // The derived `Default` produces exactly that initial state.
        Self::default()
    }

    /// Records one poisoned acquisition of `lock_name`.
    ///
    /// Updates to the timestamp and the per-lock map are best-effort: if one
    /// of those internal locks is itself poisoned, that update is skipped.
    fn record_poisoned(&self, lock_name: &str, is_write: bool) {
        self.total_poisoned.fetch_add(1, Ordering::Relaxed);
        let per_kind = if is_write {
            &self.write_poisoned
        } else {
            &self.read_poisoned
        };
        per_kind.fetch_add(1, Ordering::Relaxed);

        // A system clock set before the Unix epoch is recorded as 0.
        let now_secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            Err(_) => 0,
        };
        if let Ok(mut slot) = self.last_poisoned_at.write() {
            *slot = Some(now_secs);
        }
        if let Ok(mut per_lock) = self.lock_poison_counts.write() {
            *per_lock.entry(lock_name.to_owned()).or_default() += 1;
        }
    }

    /// Counts one successful recovery.
    fn record_recovery(&self) {
        self.recoveries.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one acquisition that was rejected because of poisoning.
    fn record_failure(&self) {
        self.failures.fetch_add(1, Ordering::Relaxed);
    }

    /// Copies the current values into a plain `LockPoisonSummary`.
    ///
    /// If an internal lock cannot be read, the timestamp is reported as `None`
    /// and the per-lock map as empty.
    pub fn get_summary(&self) -> LockPoisonSummary {
        LockPoisonSummary {
            total_poisoned: self.total_poisoned.load(Ordering::Relaxed),
            recoveries: self.recoveries.load(Ordering::Relaxed),
            failures: self.failures.load(Ordering::Relaxed),
            read_poisoned: self.read_poisoned.load(Ordering::Relaxed),
            write_poisoned: self.write_poisoned.load(Ordering::Relaxed),
            last_poisoned_at: match self.last_poisoned_at.read() {
                Ok(guard) => *guard,
                Err(_) => None,
            },
            lock_poison_counts: match self.lock_poison_counts.read() {
                Ok(guard) => (*guard).clone(),
                Err(_) => HashMap::new(),
            },
        }
    }
}

/// Point-in-time copy of `LockPoisonMetrics` made of plain values.
#[derive(Debug, Clone)]
pub struct LockPoisonSummary {
    pub total_poisoned: u64,
    pub recoveries: u64,
    pub failures: u64,
    pub read_poisoned: u64,
    pub write_poisoned: u64,
    pub last_poisoned_at: Option<u64>,
    pub lock_poison_counts: HashMap<String, u64>,
}
/// Name, policy and bookkeeping handed to the `*_with_context` lock helpers.
struct LockContext {
    /// Lock name used in log lines and as the key of the counters.
    name: String,
    /// Policy consulted when the lock is found poisoned.
    strategy: LockRecoveryStrategy,
    /// Shared metrics sink.
    metrics: Arc<LockPoisonMetrics>,
    /// Recoveries recorded through this context, keyed by lock name.
    recovery_count: Arc<RwLock<HashMap<String, u64>>>,
}

impl LockContext {
    /// Builds a context whose recovery counter starts out empty.
    fn new(name: String, strategy: LockRecoveryStrategy, metrics: Arc<LockPoisonMetrics>) -> Self {
        let recovery_count = Arc::new(RwLock::new(HashMap::new()));
        Self {
            name,
            strategy,
            metrics,
            recovery_count,
        }
    }

    /// Recoveries recorded so far for this lock; 0 when nothing was recorded
    /// or the counter map cannot be read.
    fn recoveries_so_far(&self) -> u64 {
        match self.recovery_count.read() {
            Ok(counts) => counts.get(&self.name).copied().unwrap_or(0),
            Err(_) => 0,
        }
    }

    /// Decides whether the strategy permits recovering from a poisoned lock.
    fn should_recover(&self) -> bool {
        match self.strategy {
            LockRecoveryStrategy::AlwaysRecover => true,
            LockRecoveryStrategy::AlwaysFail => false,
            // Both limited variants apply the same "below the limit" rule.
            LockRecoveryStrategy::RecoverWithLimit { max_recoveries }
            | LockRecoveryStrategy::RecoverWithBackoff { max_recoveries, .. } => {
                self.recoveries_so_far() < max_recoveries
            }
        }
    }

    /// Adds one to this lock's recovery counter (skipped if the counter map
    /// cannot be locked for writing).
    fn record_recovery_attempt(&self) {
        if let Ok(mut counts) = self.recovery_count.write() {
            *counts.entry(self.name.clone()).or_default() += 1;
        }
    }
}
/// Poison-aware acquisition helpers for a lock protecting a `T`.
///
/// The `*_with_context` methods consult a `LockContext` (strategy, metrics,
/// recovery counter); the `*_or_recover` / `*_or_fail` methods apply a fixed
/// policy without any bookkeeping.
pub trait LockRecovery<T> {
/// Acquires a read guard, handling poisoning according to `ctx`.
fn read_with_context(
&self,
ctx: &LockContext,
) -> std::result::Result<RwLockReadGuard<'_, T>, crate::ZoeyError>;
/// Acquires a write guard, handling poisoning according to `ctx`.
fn write_with_context(
&self,
ctx: &LockContext,
) -> std::result::Result<RwLockWriteGuard<'_, T>, crate::ZoeyError>;
/// Acquires a read guard; a poisoned lock is recovered.
fn read_or_recover(&self) -> RwLockReadGuard<'_, T>;
/// Acquires a write guard; a poisoned lock is recovered.
fn write_or_recover(&self) -> RwLockWriteGuard<'_, T>;
/// Acquires a read guard; a poisoned lock yields an error.
fn read_or_fail(&self) -> std::result::Result<RwLockReadGuard<'_, T>, crate::ZoeyError>;
/// Acquires a write guard; a poisoned lock yields an error.
fn write_or_fail(&self) -> std::result::Result<RwLockWriteGuard<'_, T>, crate::ZoeyError>;
}
impl<T> LockRecovery<T> for RwLock<T> {
/// Acquires a read guard, applying the recovery strategy in `ctx` when the
/// lock is poisoned.
///
/// Every poisoned acquisition is recorded in `ctx.metrics`. When the strategy
/// allows recovery, the recovery is counted and logged, and the guard is taken
/// out of the poison error. With `RecoverWithBackoff` the calling thread first
/// sleeps for `initial_delay_ms * 2^min(recoveries, 10)` milliseconds; note
/// that the guard is already held during that sleep.
///
/// # Errors
/// Returns a runtime `ZoeyError` when the lock is poisoned and the strategy
/// does not allow recovery.
fn read_with_context(
    &self,
    ctx: &LockContext,
) -> std::result::Result<RwLockReadGuard<'_, T>, crate::ZoeyError> {
    match self.read() {
        Ok(guard) => Ok(guard),
        Err(poisoned) => {
            ctx.metrics.record_poisoned(&ctx.name, false);
            if ctx.should_recover() {
                ctx.record_recovery_attempt();
                ctx.metrics.record_recovery();
                // Read the counter once; it feeds both the log line and the backoff.
                let recovery_count = ctx
                    .recovery_count
                    .read()
                    .ok()
                    .and_then(|g| g.get(&ctx.name).copied())
                    .unwrap_or(0);
                error!(
                    lock_name = %ctx.name,
                    "CRITICAL: Recovered from poisoned read lock. This indicates a previous panic. \
                     State may be inconsistent. Recovery attempt #{}",
                    recovery_count
                );
                // `cfg(any())` is never true, so this block is compiled out.
                #[cfg(any())]
                {
                    error!("Backtrace: {:?}", std::backtrace::Backtrace::capture());
                }
                if let LockRecoveryStrategy::RecoverWithBackoff {
                    initial_delay_ms, ..
                } = ctx.strategy
                {
                    // Saturate instead of overflowing for very large configured delays.
                    let factor = 1u64 << recovery_count.min(10);
                    let delay_ms = initial_delay_ms.saturating_mul(factor);
                    std::thread::sleep(std::time::Duration::from_millis(delay_ms));
                }
                Ok(poisoned.into_inner())
            } else {
                ctx.metrics.record_failure();
                error!(
                    lock_name = %ctx.name,
                    "CRITICAL: Attempted to acquire read lock but it is poisoned. \
                     A previous thread panicked while holding this lock. State may be corrupted. \
                     Recovery strategy: {:?}",
                    ctx.strategy
                );
                #[cfg(any())]
                {
                    error!("Backtrace: {:?}", std::backtrace::Backtrace::capture());
                }
                Err(crate::ZoeyError::runtime(format!(
                    "Lock '{}' is poisoned - a previous thread panicked while holding this lock. \
                     This indicates a serious error. The runtime state may be corrupted. \
                     Recovery strategy: {:?}",
                    ctx.name, ctx.strategy
                )))
            }
        }
    }
}
/// Acquires a write guard, applying the recovery strategy in `ctx` when the
/// lock is poisoned.
///
/// Every poisoned acquisition is recorded in `ctx.metrics`. When the strategy
/// allows recovery, the recovery is counted and logged, and the guard is taken
/// out of the poison error. With `RecoverWithBackoff` the calling thread first
/// sleeps for `initial_delay_ms * 2^min(recoveries, 10)` milliseconds; note
/// that the guard is already held during that sleep.
///
/// # Errors
/// Returns a runtime `ZoeyError` when the lock is poisoned and the strategy
/// does not allow recovery.
fn write_with_context(
    &self,
    ctx: &LockContext,
) -> std::result::Result<RwLockWriteGuard<'_, T>, crate::ZoeyError> {
    match self.write() {
        Ok(guard) => Ok(guard),
        Err(poisoned) => {
            ctx.metrics.record_poisoned(&ctx.name, true);
            if ctx.should_recover() {
                ctx.record_recovery_attempt();
                ctx.metrics.record_recovery();
                // Read the counter once; it feeds both the log line and the backoff.
                let recovery_count = ctx
                    .recovery_count
                    .read()
                    .ok()
                    .and_then(|g| g.get(&ctx.name).copied())
                    .unwrap_or(0);
                error!(
                    lock_name = %ctx.name,
                    "CRITICAL: Recovered from poisoned write lock. This indicates a previous panic. \
                     State may be inconsistent. Recovery attempt #{}",
                    recovery_count
                );
                // `cfg(any())` is never true, so this block is compiled out.
                #[cfg(any())]
                {
                    error!("Backtrace: {:?}", std::backtrace::Backtrace::capture());
                }
                if let LockRecoveryStrategy::RecoverWithBackoff {
                    initial_delay_ms, ..
                } = ctx.strategy
                {
                    // Saturate instead of overflowing for very large configured delays.
                    let factor = 1u64 << recovery_count.min(10);
                    let delay_ms = initial_delay_ms.saturating_mul(factor);
                    std::thread::sleep(std::time::Duration::from_millis(delay_ms));
                }
                Ok(poisoned.into_inner())
            } else {
                ctx.metrics.record_failure();
                error!(
                    lock_name = %ctx.name,
                    "CRITICAL: Attempted to acquire write lock but it is poisoned. \
                     A previous thread panicked while holding this lock. State may be corrupted. \
                     Recovery strategy: {:?}",
                    ctx.strategy
                );
                #[cfg(any())]
                {
                    error!("Backtrace: {:?}", std::backtrace::Backtrace::capture());
                }
                Err(crate::ZoeyError::runtime(format!(
                    "Lock '{}' is poisoned - a previous thread panicked while holding this lock. \
                     This indicates a serious error. The runtime state may be corrupted. \
                     Recovery strategy: {:?}",
                    ctx.name, ctx.strategy
                )))
            }
        }
    }
}
/// Returns a read guard unconditionally: a poisoned lock is logged as an
/// error and its guard is taken out of the poison error.
fn read_or_recover(&self) -> RwLockReadGuard<'_, T> {
    match self.read() {
        Ok(guard) => guard,
        Err(poison) => {
            error!(
                "CRITICAL: Recovered from poisoned read lock. This indicates a previous panic. \
                 State may be inconsistent. Consider using read_or_fail() for safer behavior."
            );
            // `cfg(any())` is never true, so this block is compiled out.
            #[cfg(any())]
            {
                error!("Backtrace: {:?}", std::backtrace::Backtrace::capture());
            }
            poison.into_inner()
        }
    }
}
/// Returns a write guard unconditionally: a poisoned lock is logged as an
/// error and its guard is taken out of the poison error.
fn write_or_recover(&self) -> RwLockWriteGuard<'_, T> {
    match self.write() {
        Ok(guard) => guard,
        Err(poison) => {
            error!(
                "CRITICAL: Recovered from poisoned write lock. This indicates a previous panic. \
                 State may be inconsistent. Consider using write_or_fail() for safer behavior."
            );
            // `cfg(any())` is never true, so this block is compiled out.
            #[cfg(any())]
            {
                error!("Backtrace: {:?}", std::backtrace::Backtrace::capture());
            }
            poison.into_inner()
        }
    }
}
/// Returns a read guard, or a runtime `ZoeyError` (after logging) when the
/// lock is poisoned.
fn read_or_fail(&self) -> std::result::Result<RwLockReadGuard<'_, T>, crate::ZoeyError> {
    match self.read() {
        Ok(guard) => Ok(guard),
        Err(_poisoned) => {
            error!(
                "CRITICAL: Attempted to acquire read lock but it is poisoned. \
                 A previous thread panicked while holding this lock. State may be corrupted."
            );
            // `cfg(any())` is never true, so this block is compiled out.
            #[cfg(any())]
            {
                error!("Backtrace: {:?}", std::backtrace::Backtrace::capture());
            }
            Err(crate::ZoeyError::runtime(
                "Lock is poisoned - a previous thread panicked while holding this lock. \
                 This indicates a serious error. The runtime state may be corrupted.",
            ))
        }
    }
}
/// Returns a write guard, or a runtime `ZoeyError` (after logging) when the
/// lock is poisoned.
fn write_or_fail(&self) -> std::result::Result<RwLockWriteGuard<'_, T>, crate::ZoeyError> {
    match self.write() {
        Ok(guard) => Ok(guard),
        Err(_poisoned) => {
            error!(
                "CRITICAL: Attempted to acquire write lock but it is poisoned. \
                 A previous thread panicked while holding this lock. State may be corrupted."
            );
            // `cfg(any())` is never true, so this block is compiled out.
            #[cfg(any())]
            {
                error!("Backtrace: {:?}", std::backtrace::Backtrace::capture());
            }
            Err(crate::ZoeyError::runtime(
                "Lock is poisoned - a previous thread panicked while holding this lock. \
                 This indicates a serious error. The runtime state may be corrupted.",
            ))
        }
    }
}
}
/// Central runtime state of one agent: identity, character, registries filled
/// by plugins, settings, and lock-health bookkeeping. Most collections sit
/// behind `Arc<RwLock<..>>`.
pub struct AgentRuntime {
// Agent identity; derived from the character when not given explicitly.
pub agent_id: Uuid,
// Character definition this runtime was created for.
pub character: Character,
// Optional database adapter; `None` means no persistence.
pub(crate) adapter: Arc<RwLock<Option<Arc<dyn IDatabaseAdapter + Send + Sync>>>>,
// Registered actions.
pub(crate) actions: Arc<RwLock<Vec<Arc<dyn Action>>>>,
// Registered evaluators.
pub(crate) evaluators: Arc<RwLock<Vec<Arc<dyn Evaluator>>>>,
// Registered providers.
pub(crate) providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
// All services, grouped by service type.
pub(crate) services: Arc<RwLock<HashMap<ServiceTypeName, Vec<Arc<dyn Service>>>>>,
// First service registered for each service type.
pub(crate) typed_services: Arc<RwLock<HashMap<String, Arc<dyn Service>>>>,
// Model handlers per model type, kept sorted by descending priority.
pub(crate) models: Arc<RwLock<HashMap<String, Vec<ModelProvider>>>>,
// Plugins registered so far.
plugins: Arc<RwLock<Vec<Arc<dyn Plugin>>>>,
// Event handlers per event type.
pub(crate) events: Arc<RwLock<HashMap<String, Vec<EventHandler>>>>,
// State cache keyed by string.
pub(crate) state_cache: Arc<RwLock<HashMap<String, State>>>,
// Tracing span created for this agent.
logger: Arc<RwLock<tracing::Span>>,
// Settings as JSON values keyed by name.
pub(crate) settings: Arc<RwLock<HashMap<String, serde_json::Value>>>,
// Routes contributed by plugins.
routes: Arc<RwLock<Vec<Route>>>,
// Task workers keyed by name.
task_workers: Arc<RwLock<HashMap<String, Arc<dyn TaskWorker>>>>,
// Send handlers keyed by source.
send_handlers: Arc<RwLock<HashMap<String, SendHandlerFunction>>>,
// Optional message service.
message_service: Arc<RwLock<Option<Arc<dyn IMessageService>>>>,
// Conversation length; 32 unless overridden in `RuntimeOpts`.
pub(crate) conversation_length: usize,
// Identifier of the run in progress, if any.
pub(crate) current_run_id: Arc<RwLock<Option<Uuid>>>,
// Action results keyed by message id.
pub(crate) action_results: Arc<RwLock<HashMap<Uuid, Vec<ActionResult>>>>,
// Type-erased handle stored via `set_zoey_os`.
zoey_os: Arc<RwLock<Option<Arc<dyn std::any::Any + Send + Sync>>>>,
// Executor used for dynamic prompts.
dynamic_prompt_executor: Arc<DynamicPromptExecutor>,
// Optional observability handle.
pub observability: Arc<RwLock<Option<Arc<crate::observability::Observability>>>>,
// Policy used when a poisoned lock is encountered.
lock_recovery_strategy: LockRecoveryStrategy,
// Counters for poisoned-lock events.
lock_poison_metrics: Arc<LockPoisonMetrics>,
// Optional training data collector.
training_collector: Option<Arc<crate::training::TrainingCollector>>,
}
/// Options accepted by `AgentRuntime::new`; every field is optional.
#[derive(Default)]
pub struct RuntimeOpts {
// Explicit agent id; otherwise taken from the character id or derived from its name.
pub agent_id: Option<Uuid>,
// Character to run; `Character::default()` when absent.
pub character: Option<Character>,
// Plugins registered during construction (skipped in test mode).
pub plugins: Vec<Arc<dyn Plugin>>,
// Database adapter; `None` runs without persistence.
pub adapter: Option<Arc<dyn IDatabaseAdapter + Send + Sync>>,
// Initial settings; missing `ui:*` keys receive defaults.
pub settings: Option<HashMap<String, serde_json::Value>>,
// Conversation length; 32 when absent.
pub conversation_length: Option<usize>,
// NOTE(review): not read by `AgentRuntime::new` in this file — confirm where it is consumed.
pub all_available_plugins: Option<Vec<Arc<dyn Plugin>>>,
// Lock poisoning policy; `LockRecoveryStrategy::default()` when absent.
pub lock_recovery_strategy: Option<LockRecoveryStrategy>,
// Test mode switch; when `None` the `ZOEY_TEST_MODE` environment variable is consulted.
pub test_mode: Option<bool>,
}
/// Builder-style helpers; each consumes the options and returns them with one
/// field replaced (or, for `with_plugin`, extended).
impl RuntimeOpts {
    /// Starts from the all-default options.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the character.
    pub fn with_character(self, character: Character) -> Self {
        Self {
            character: Some(character),
            ..self
        }
    }

    /// Sets an explicit agent id.
    pub fn with_agent_id(self, agent_id: Uuid) -> Self {
        Self {
            agent_id: Some(agent_id),
            ..self
        }
    }

    /// Sets the database adapter.
    pub fn with_adapter(self, adapter: Arc<dyn IDatabaseAdapter + Send + Sync>) -> Self {
        Self {
            adapter: Some(adapter),
            ..self
        }
    }

    /// Replaces the whole plugin list.
    pub fn with_plugins(self, plugins: Vec<Arc<dyn Plugin>>) -> Self {
        Self { plugins, ..self }
    }

    /// Appends one plugin to the list.
    pub fn with_plugin(self, plugin: Arc<dyn Plugin>) -> Self {
        let mut opts = self;
        opts.plugins.push(plugin);
        opts
    }

    /// Sets the conversation length.
    pub fn with_conversation_length(self, length: usize) -> Self {
        Self {
            conversation_length: Some(length),
            ..self
        }
    }

    /// Sets the initial settings map.
    pub fn with_settings(self, settings: HashMap<String, serde_json::Value>) -> Self {
        Self {
            settings: Some(settings),
            ..self
        }
    }

    /// Sets the lock recovery strategy.
    pub fn with_lock_recovery_strategy(self, strategy: LockRecoveryStrategy) -> Self {
        Self {
            lock_recovery_strategy: Some(strategy),
            ..self
        }
    }
}
impl AgentRuntime {
/// Builds a runtime from `opts` and returns it wrapped in `Arc<RwLock<..>>`.
///
/// Steps, in order:
/// 1. Resolve the character and agent id (explicit id, else character id,
///    else a UUID derived from the character name).
/// 2. Fill in default `ui:*` settings that the caller did not provide.
/// 3. Resolve test mode (`opts.test_mode`, else `ZOEY_TEST_MODE`); in test
///    mode some `ui:*` settings are forced and no plugins or task workers
///    are registered.
/// 4. Register the plugins and the `embedding_generation` task worker.
///
/// Plugins are registered on the plain value *before* it is placed behind the
/// `RwLock`, so no lock guard is held across an `.await` and no lock
/// `unwrap()` is needed.
///
/// # Errors
/// Propagates the first error returned while registering a plugin.
pub async fn new(opts: RuntimeOpts) -> Result<Arc<RwLock<Self>>> {
    let character = opts.character.unwrap_or_default();
    let agent_id = opts.agent_id.unwrap_or_else(|| {
        character
            .id
            .unwrap_or_else(|| crate::utils::string_to_uuid(&character.name))
    });
    #[cfg(feature = "otel")]
    {
        crate::infrastructure::otel::init_otel();
    }
    let logger_span = tracing::span!(tracing::Level::INFO, "agent", name = %character.name);
    // Unset or unparsable values fall back to 1000 entries.
    let max_cache_entries = std::env::var("DYNAMIC_PROMPT_MAX_ENTRIES")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(1000);
    let lock_recovery_strategy = opts.lock_recovery_strategy.unwrap_or_default();
    let lock_poison_metrics = Arc::new(LockPoisonMetrics::new());

    // Defaults only apply to keys the caller did not set.
    let mut initial_settings = opts.settings.unwrap_or_default();
    initial_settings
        .entry("ui:streaming".to_string())
        .or_insert(serde_json::json!(true));
    initial_settings
        .entry("ui:provider_racing".to_string())
        .or_insert(serde_json::json!(true));
    initial_settings
        .entry("ui:fast_mode".to_string())
        .or_insert(serde_json::json!(true));
    initial_settings
        .entry("ui:verbosity".to_string())
        .or_insert(serde_json::json!("short"));
    initial_settings
        .entry("ui:avoid_cutoff".to_string())
        .or_insert(serde_json::json!(false));

    let test_mode = opts.test_mode.unwrap_or_else(|| {
        std::env::var("ZOEY_TEST_MODE")
            .ok()
            .and_then(|v| v.parse::<bool>().ok())
            .unwrap_or(false)
    });
    if test_mode {
        // Test mode overrides whatever the caller configured for these keys.
        initial_settings.insert("ui:streaming".to_string(), serde_json::json!(false));
        initial_settings.insert("ui:provider_racing".to_string(), serde_json::json!(false));
        initial_settings.insert("ui:fast_mode".to_string(), serde_json::json!(true));
    }

    let training_config = crate::training::TrainingConfig {
        enabled: true,
        min_quality_score: 0.5,
        max_samples: 10000,
        auto_save_interval: 300,
        output_dir: std::path::PathBuf::from("./training_data"),
        default_format: crate::training::TrainingFormat::Jsonl,
        include_thoughts: true,
        include_negative_examples: true,
        negative_example_ratio: 0.1,
        enable_rlhf: true,
        auto_label: true,
    };
    let training_collector = Some(Arc::new(crate::training::TrainingCollector::new(
        training_config,
    )));

    let mut runtime = Self {
        agent_id,
        character,
        adapter: Arc::new(RwLock::new(opts.adapter)),
        actions: Arc::new(RwLock::new(Vec::new())),
        evaluators: Arc::new(RwLock::new(Vec::new())),
        providers: Arc::new(RwLock::new(Vec::new())),
        services: Arc::new(RwLock::new(HashMap::new())),
        typed_services: Arc::new(RwLock::new(HashMap::new())),
        models: Arc::new(RwLock::new(HashMap::new())),
        plugins: Arc::new(RwLock::new(Vec::new())),
        events: Arc::new(RwLock::new(HashMap::new())),
        state_cache: Arc::new(RwLock::new(HashMap::new())),
        logger: Arc::new(RwLock::new(logger_span)),
        settings: Arc::new(RwLock::new(initial_settings)),
        routes: Arc::new(RwLock::new(Vec::new())),
        task_workers: Arc::new(RwLock::new(HashMap::new())),
        send_handlers: Arc::new(RwLock::new(HashMap::new())),
        message_service: Arc::new(RwLock::new(None)),
        conversation_length: opts.conversation_length.unwrap_or(32),
        current_run_id: Arc::new(RwLock::new(None)),
        action_results: Arc::new(RwLock::new(HashMap::new())),
        zoey_os: Arc::new(RwLock::new(None)),
        dynamic_prompt_executor: Arc::new(DynamicPromptExecutor::new(Some(max_cache_entries))),
        observability: Arc::new(RwLock::new(None)),
        lock_recovery_strategy,
        lock_poison_metrics,
        training_collector,
    };

    if !test_mode {
        for plugin in opts.plugins {
            runtime.register_plugin_internal(plugin).await?;
        }
    }
    debug!("runtime_new:post_plugins");
    if !test_mode {
        runtime.register_task_worker(
            "embedding_generation".to_string(),
            Arc::new(crate::workers::embedding_worker::EmbeddingWorker::new()),
        );
    }
    debug!("runtime_new:return");
    Ok(Arc::new(RwLock::new(runtime)))
}
/// Copies everything a plugin contributes (actions, providers, evaluators,
/// services, model handlers, event handlers, routes) into the runtime
/// registries, then records the plugin itself.
///
/// Model handlers of every model type are re-sorted by descending priority
/// after the plugin's handlers were added.
async fn register_plugin_internal(&mut self, plugin: Arc<dyn Plugin>) -> Result<()> {
    info!("Registering plugin: {}", plugin.name());
    debug!("plugin_register:start name={}", plugin.name());

    for action in plugin.actions() {
        let mut registry = self.actions.write_or_recover();
        registry.push(action);
    }
    debug!("plugin_register:actions name={}", plugin.name());

    for provider in plugin.providers() {
        let mut registry = self.providers.write_or_recover();
        registry.push(provider);
    }
    debug!("plugin_register:providers name={}", plugin.name());

    for evaluator in plugin.evaluators() {
        let mut registry = self.evaluators.write_or_recover();
        registry.push(evaluator);
    }
    debug!("plugin_register:evaluators name={}", plugin.name());

    for service in plugin.services() {
        let type_name = service.service_type().to_string();
        {
            // Every service is appended to the per-type list...
            let mut by_type = self.services.write_or_recover();
            by_type
                .entry(type_name.clone())
                .or_default()
                .push(service.clone());
        }
        // ...but only the first one per type is kept in `typed_services`.
        let mut first_of_type = self.typed_services.write_or_recover();
        first_of_type
            .entry(type_name)
            .or_insert_with(|| service.clone());
    }
    debug!("plugin_register:services name={}", plugin.name());

    for (model_type, handler) in plugin.models() {
        let model_provider = ModelProvider {
            name: plugin.name().to_string(),
            handler,
            priority: plugin.priority(),
        };
        let mut by_model = self.models.write_or_recover();
        by_model.entry(model_type).or_default().push(model_provider);
    }
    // Highest priority first, for every model type.
    self.models
        .write_or_recover()
        .values_mut()
        .for_each(|providers| providers.sort_by(|lhs, rhs| rhs.priority.cmp(&lhs.priority)));
    debug!("plugin_register:models name={}", plugin.name());

    for (event_type, handlers) in plugin.events() {
        let mut by_event = self.events.write_or_recover();
        by_event.entry(event_type).or_default().extend(handlers);
    }
    debug!("plugin_register:events name={}", plugin.name());

    self.routes.write_or_recover().extend(plugin.routes());
    debug!("plugin_register:routes name={}", plugin.name());

    self.plugins.write_or_recover().push(plugin.clone());
    debug!("plugin_register:done name={}", plugin.name());
    Ok(())
}
pub async fn initialize(&mut self, options: InitializeOptions) -> Result<()> {
info!("Initializing runtime for agent: {}", self.character.name);
if let Some(adapter) = self.adapter.read_or_recover().clone() {
debug!("Initializing database connection...");
match adapter.is_ready().await {
Ok(true) => {
info!("✓ Database connection is ready");
}
Ok(false) => {
warn!("Database is not ready, attempting initialization...");
return Err(crate::ZoeyError::database(
"Database is not ready. Please initialize the adapter before passing to runtime."
));
}
Err(e) => {
warn!("Failed to check database readiness: {}", e);
return Err(crate::ZoeyError::database(format!(
"Database readiness check failed: {}",
e
)));
}
}
if !options.skip_migrations {
debug!("Checking for plugin migrations...");
let plugins = self.plugins.read_or_recover();
if !plugins.is_empty() {
info!("Running migrations for {} plugin(s)...", plugins.len());
let mut plugin_migrations = Vec::new();
for plugin in plugins.iter() {
let schema = plugin.schema();
plugin_migrations.push(PluginMigration {
name: plugin.name().to_string(),
schema,
});
}
if !plugin_migrations.is_empty() {
match adapter
.run_plugin_migrations(
plugin_migrations,
MigrationOptions {
verbose: false,
force: false,
dry_run: false,
},
)
.await
{
Ok(_) => info!("✓ Plugin migrations completed successfully"),
Err(e) => {
warn!("Plugin migration failed: {}", e);
return Err(crate::ZoeyError::database(format!(
"Plugin migration failed: {}",
e
)));
}
}
} else {
debug!("No plugin migrations required");
}
} else {
debug!("No plugins registered, skipping migrations");
}
} else {
debug!("Skipping migrations (skip_migrations=true)");
}
match adapter.get_agent(self.agent_id).await {
Ok(Some(_)) => debug!("Agent already exists in database"),
Ok(None) | Err(_) => {
info!("Registering agent '{}' in database...", self.character.name);
let agent = crate::types::Agent {
id: self.agent_id,
name: self.character.name.clone(),
character: serde_json::to_value(&self.character).unwrap_or_default(),
created_at: Some(chrono::Utc::now().timestamp()),
updated_at: None,
};
match adapter.create_agent(&agent).await {
Ok(true) => info!("✓ Agent registered successfully"),
Ok(false) => warn!("Agent may already exist (create returned false)"),
Err(e) => warn!("Failed to register agent: {} - continuing anyway", e),
}
}
}
} else {
warn!("No database adapter configured - running without persistence");
}
let service_map = self.services.read_or_recover();
if !service_map.is_empty() {
info!("Initializing {} service type(s)...", service_map.len());
for (service_type, services) in service_map.iter() {
for _service in services {
debug!("Initializing service: {}", service_type);
}
}
}
info!(
"✓ Runtime initialization complete for agent '{}'",
self.character.name
);
Ok(())
}
/// Composes the state for `message` by delegating to
/// `RuntimeState::compose_state_impl` on a fresh `RuntimeState`; all arguments
/// are forwarded unchanged.
pub async fn compose_state(
&self,
message: &Memory,
include_list: Option<Vec<String>>,
only_include: bool,
skip_cache: bool,
) -> Result<State> {
crate::runtime::RuntimeState::new()
.compose_state_impl(self, message, include_list, only_include, skip_cache)
.await
}
/// Passes `memory` to the adapter's `update_memory` and returns a copy of it.
///
/// Note that no embedding is computed in this function; the memory is
/// forwarded as received.
///
/// # Errors
/// Returns a runtime error when no adapter is configured, and propagates any
/// error from `update_memory`.
pub async fn add_embedding_to_memory(&self, memory: &Memory) -> Result<Memory> {
    match crate::runtime::RuntimeState::get_adapter(self) {
        Some(adapter) => {
            let _ = adapter.update_memory(memory).await?;
            // Clone only after the update succeeded.
            Ok(memory.clone())
        }
        None => Err(crate::ZoeyError::runtime("No adapter configured")),
    }
}
/// Generates an embedding for `memory` with the first registered
/// `TextEmbedding` model provider and stores it through the adapter.
///
/// The provider output is parsed as a JSON array of `f32`; output that does
/// not parse is ignored and the call still succeeds. `_priority` is unused.
///
/// # Errors
/// Returns a runtime error when no embedding provider or no adapter is
/// available, and propagates errors from the provider and the adapter.
pub async fn queue_embedding_generation(
    &self,
    memory: &Memory,
    _priority: crate::types::EmbeddingPriority,
) -> Result<()> {
    // Copy the provider out so the models guard ends with this block.
    let embedding_provider = {
        let registry = self.models.read_or_recover();
        let key = crate::types::ModelType::TextEmbedding.to_string();
        registry
            .get(&key)
            .and_then(|candidates| candidates.first().cloned())
    };
    let configured_adapter = crate::runtime::RuntimeState::get_adapter(self);

    let (provider, adapter) = match (embedding_provider, configured_adapter) {
        (Some(provider), Some(adapter)) => (provider, adapter),
        _ => {
            return Err(crate::ZoeyError::runtime(
                "TEXT_EMBEDDING provider or adapter not available",
            ))
        }
    };

    let request = crate::types::ModelHandlerParams {
        runtime: Arc::new(()),
        params: crate::types::GenerateTextParams {
            prompt: memory.content.text.clone(),
            max_tokens: None,
            temperature: None,
            top_p: None,
            stop: None,
            model: None,
            frequency_penalty: None,
            presence_penalty: None,
        },
    };
    let raw = (provider.handler)(request).await?;

    match serde_json::from_str::<Vec<f32>>(&raw) {
        Ok(vector) => {
            let mut with_embedding = memory.clone();
            with_embedding.embedding = Some(vector);
            let _ = adapter.update_memory(&with_embedding).await?;
        }
        // Unparsable provider output is deliberately ignored.
        Err(_) => {}
    }
    Ok(())
}
/// Fetches memories from the `"messages"` table through the adapter, with
/// every other query field left at its default.
///
/// # Errors
/// Returns a runtime error when no adapter is configured; adapter errors are
/// propagated.
pub async fn get_all_memories(&self) -> Result<Vec<Memory>> {
    match crate::runtime::RuntimeState::get_adapter(self) {
        None => Err(crate::ZoeyError::runtime("No adapter configured")),
        Some(adapter) => {
            let query = crate::types::MemoryQuery {
                table_name: "messages".to_string(),
                ..Default::default()
            };
            adapter.get_memories(query).await
        }
    }
}
/// Delegates to `crate::runtime::lifecycle::create_run_id`.
pub fn create_run_id(&self) -> Uuid {
crate::runtime::lifecycle::create_run_id(self)
}
/// Delegates to `crate::runtime::lifecycle::start_run` and returns its run id.
pub fn start_run(&mut self) -> Uuid {
crate::runtime::lifecycle::start_run(self)
}
/// Delegates to `crate::runtime::lifecycle::end_run`.
pub fn end_run(&mut self) {
crate::runtime::lifecycle::end_run(self)
}
/// Delegates to `crate::runtime::lifecycle::get_current_run_id`; `None` means
/// no run id is currently set.
pub fn get_current_run_id(&self) -> Option<Uuid> {
crate::runtime::lifecycle::get_current_run_id(self)
}
/// Returns a copy of the action results stored for `message_id`, or an empty
/// vector when none are stored.
pub fn get_action_results(&self, message_id: Uuid) -> Vec<ActionResult> {
    let stored = self.action_results.read_or_recover();
    match stored.get(&message_id) {
        Some(results) => results.clone(),
        None => Vec::new(),
    }
}
/// Stores `results` for `message_id`, replacing any previous entry.
pub fn set_action_results(&self, message_id: Uuid, results: Vec<ActionResult>) {
    let mut stored = self.action_results.write_or_recover();
    stored.insert(message_id, results);
}
}
impl AgentRuntime {
/// Returns another handle to the training collector, if one is configured.
pub fn get_training_collector(&self) -> Option<Arc<crate::training::TrainingCollector>> {
    self.training_collector.as_ref().map(Arc::clone)
}
/// Delegates to `crate::plugin_system::registry::get_actions`.
pub fn get_actions(&self) -> Vec<Arc<dyn Action>> {
crate::plugin_system::registry::get_actions(self)
}
/// Delegates to `crate::plugin_system::registry::get_providers`.
pub fn get_providers(&self) -> Vec<Arc<dyn Provider>> {
crate::plugin_system::registry::get_providers(self)
}
/// Delegates to `crate::plugin_system::registry::get_evaluators`.
pub fn get_evaluators(&self) -> Vec<Arc<dyn Evaluator>> {
crate::plugin_system::registry::get_evaluators(self)
}
/// Looks up a service for `service_type` by delegating to
/// `RuntimeState::get_service`.
pub fn get_service(&self, service_type: &str) -> Option<Arc<dyn Service>> {
crate::runtime::RuntimeState::get_service(self, service_type)
}
/// Delegates to `RuntimeState::get_services_count`.
pub fn get_services_count(&self) -> usize {
crate::runtime::RuntimeState::get_services_count(self)
}
/// Delegates to `RuntimeState::get_all_services`; the result maps service
/// type to the services of that type.
pub fn get_all_services(&self) -> HashMap<ServiceTypeName, Vec<Arc<dyn Service>>> {
crate::runtime::RuntimeState::get_all_services(self)
}
/// Returns the service stored under `name` in `typed_services`, if any.
pub fn get_service_by_name(&self, name: &str) -> Option<Arc<dyn Service>> {
    let by_name = self.typed_services.read_or_recover();
    by_name.get(name).cloned()
}
/// Delegates to `crate::plugin_system::registry::register_provider`.
pub fn register_provider(&mut self, provider: Arc<dyn Provider>) {
crate::plugin_system::registry::register_provider(self, provider)
}
/// Delegates to `crate::plugin_system::registry::register_action`.
pub fn register_action(&mut self, action: Arc<dyn Action>) {
crate::plugin_system::registry::register_action(self, action)
}
/// Delegates to `crate::plugin_system::registry::register_evaluator`.
pub fn register_evaluator(&mut self, evaluator: Arc<dyn Evaluator>) {
crate::plugin_system::registry::register_evaluator(self, evaluator)
}
/// Stores `value` under `key` by delegating to `RuntimeState::set_setting`.
/// The `_secret` flag is accepted but not forwarded or used here.
pub fn set_setting(&mut self, key: &str, value: serde_json::Value, _secret: bool) {
crate::runtime::RuntimeState::set_setting(self, key, value);
}
/// Reads the setting stored under `key` by delegating to
/// `RuntimeState::get_setting`.
pub fn get_setting(&self, key: &str) -> Option<serde_json::Value> {
crate::runtime::RuntimeState::get_setting(self, key)
}
/// Delegates to `RuntimeState::get_setting_string`.
pub fn get_setting_string(&self, key: &str) -> Option<String> {
crate::runtime::RuntimeState::get_setting_string(self, key)
}
/// Delegates to `RuntimeState::get_settings_with_prefix`.
pub fn get_settings_with_prefix(&self, prefix: &str) -> Vec<(String, String)> {
crate::runtime::RuntimeState::get_settings_with_prefix(self, prefix)
}
/// Returns another handle to the shared tracing span of this agent.
pub fn logger(&self) -> Arc<RwLock<tracing::Span>> {
Arc::clone(&self.logger)
}
/// Delegates to `RuntimeState::get_conversation_length`.
pub fn get_conversation_length(&self) -> usize {
crate::runtime::RuntimeState::get_conversation_length(self)
}
/// Delegates to `RuntimeState::message_service`.
pub fn message_service(&self) -> Option<Arc<dyn IMessageService>> {
crate::runtime::RuntimeState::message_service(self)
}
/// Registers `handler` for `source`, replacing any handler already stored
/// under that key.
pub fn register_send_handler(&mut self, source: String, handler: SendHandlerFunction) {
    let mut handlers = self.send_handlers.write_or_recover();
    handlers.insert(source, handler);
}
/// Returns a copy of the send handler registered for `source`, if any.
pub fn get_send_handler(&self, source: &str) -> Option<SendHandlerFunction> {
    let handlers = self.send_handlers.read_or_recover();
    handlers.get(source).cloned()
}
/// Registers `worker` under `name`, replacing any worker already stored
/// under that name.
pub fn register_task_worker(&mut self, name: String, worker: Arc<dyn TaskWorker>) {
    let mut workers = self.task_workers.write_or_recover();
    workers.insert(name, worker);
}
/// Returns the task worker registered under `name`, if any.
pub fn get_task_worker(&self, name: &str) -> Option<Arc<dyn TaskWorker>> {
    let workers = self.task_workers.read_or_recover();
    workers.get(name).cloned()
}
/// Returns a copy of the whole name-to-worker map.
pub fn get_task_workers(&self) -> HashMap<String, Arc<dyn TaskWorker>> {
    let workers = self.task_workers.read_or_recover();
    (*workers).clone()
}
/// Returns the type-erased handle stored by `set_zoey_os`, if any.
pub fn zoey_os(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
    let slot = self.zoey_os.read_or_recover();
    slot.clone()
}
/// Stores `instance` as the type-erased handle, replacing any previous one.
pub fn set_zoey_os(&mut self, instance: Arc<dyn std::any::Any + Send + Sync>) {
    let mut slot = self.zoey_os.write_or_recover();
    *slot = Some(instance);
}
/// Returns another handle to the shared dynamic prompt executor.
pub fn dynamic_prompt_executor(&self) -> Arc<DynamicPromptExecutor> {
Arc::clone(&self.dynamic_prompt_executor)
}
/// Runs a dynamic prompt against `state` through the shared executor.
///
/// The model identifier is `options.model` when set, otherwise `TEXT_SMALL`
/// for a `"small"` model size and `TEXT_LARGE` for anything else. The
/// handlers registered for that identifier are looked up but not wired in
/// yet: the model callback passed to the executor always answers with a fixed
/// "Model not connected" response.
///
/// The models lock is released before the executor is awaited.
pub async fn dynamic_prompt_exec_from_state(
    &self,
    state: &State,
    schema: Vec<SchemaRow>,
    prompt_template: &str,
    options: DynamicPromptOptions,
) -> Result<HashMap<String, serde_json::Value>> {
    let executor = &self.dynamic_prompt_executor;
    let model_identifier =
        options
            .model
            .clone()
            .unwrap_or_else(|| match options.model_size.as_deref() {
                Some("small") => "TEXT_SMALL".to_string(),
                _ => "TEXT_LARGE".to_string(),
            });
    // Scoped so the read guard is dropped before the `.await` below.
    let _handlers = {
        let models = self.models.read_or_recover();
        models.get(&model_identifier).cloned()
    };
    // Placeholder callback until the looked-up handlers are actually used.
    let model_fn = |_prompt: String, _opts: DynamicPromptOptions| async move {
        Ok("<response><error>Model not connected</error></response>".to_string())
    };
    executor
        .execute_from_state(state, schema, prompt_template, options, model_fn)
        .await
}
/// Returns the executor's metrics summary (`get_metrics_summary`).
pub fn get_dynamic_prompt_metrics(&self) -> crate::dynamic_prompts::MetricsSummary {
self.dynamic_prompt_executor.get_metrics_summary()
}
/// Clears the executor's metrics (`clear_metrics`).
pub fn clear_dynamic_prompt_metrics(&self) {
self.dynamic_prompt_executor.clear_metrics();
}
/// Returns another handle to the configured database adapter, if any.
pub fn get_adapter(&self) -> Option<Arc<dyn IDatabaseAdapter + Send + Sync>> {
    let slot = self.adapter.read_or_recover();
    slot.clone()
}
/// Delegates to `crate::plugin_system::registry::get_models`.
pub fn get_models(&self) -> HashMap<String, Vec<ModelProvider>> {
crate::plugin_system::registry::get_models(self)
}
/// Builds a `LockContext` for `lock_name` using the runtime's current
/// strategy and shared poison metrics.
///
/// Each call creates a new context, and `LockContext::new` starts with an
/// empty recovery counter, so recovery counts are not carried over between
/// contexts returned by separate calls.
fn get_lock_context(&self, lock_name: &str) -> LockContext {
LockContext::new(
lock_name.to_string(),
self.lock_recovery_strategy,
Arc::clone(&self.lock_poison_metrics),
)
}
/// Returns the lock recovery strategy currently in effect.
pub fn get_lock_recovery_strategy(&self) -> LockRecoveryStrategy {
self.lock_recovery_strategy
}
/// Replaces the lock recovery strategy and logs the new value at info level.
pub fn set_lock_recovery_strategy(&mut self, strategy: LockRecoveryStrategy) {
self.lock_recovery_strategy = strategy;
info!("Lock recovery strategy changed to: {:?}", strategy);
}
/// Returns a snapshot of the poisoned-lock counters.
pub fn get_lock_poison_metrics(&self) -> LockPoisonSummary {
self.lock_poison_metrics.get_summary()
}
/// Resets all poisoned-lock metrics: counters to zero, the last timestamp to
/// `None`, and the per-lock map to empty. The timestamp and map resets are
/// skipped if their own lock cannot be acquired.
pub fn reset_lock_poison_metrics(&self) {
    let metrics = &self.lock_poison_metrics;
    let counters = [
        &metrics.total_poisoned,
        &metrics.recoveries,
        &metrics.failures,
        &metrics.read_poisoned,
        &metrics.write_poisoned,
    ];
    for counter in counters.iter() {
        counter.store(0, Ordering::Relaxed);
    }
    if let Ok(mut last_seen) = metrics.last_poisoned_at.write() {
        *last_seen = None;
    }
    if let Ok(mut per_lock) = metrics.lock_poison_counts.write() {
        per_lock.clear();
    }
}
/// Reports whether any poisoned acquisition has been recorded since the
/// metrics were created or last reset.
pub fn has_poisoned_locks(&self) -> bool {
    let seen = self
        .lock_poison_metrics
        .total_poisoned
        .load(Ordering::Relaxed);
    seen != 0
}
/// Summarises lock health from the poison metrics.
///
/// `is_healthy` is true only when no poisoning and no failure was recorded.
/// `most_poisoned_locks` holds up to five locks with the highest poison
/// counts; equal counts are ordered by lock name so the result does not
/// depend on hash-map iteration order.
pub fn get_lock_health_status(&self) -> LockHealthStatus {
    let summary = self.lock_poison_metrics.get_summary();
    let mut counts: Vec<(String, u64)> = summary.lock_poison_counts.into_iter().collect();
    // Highest count first; the name breaks ties deterministically.
    counts.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    let most_poisoned_locks = counts.into_iter().take(5).collect();
    let is_healthy = summary.total_poisoned == 0 && summary.failures == 0;
    LockHealthStatus {
        is_healthy,
        total_poisoned: summary.total_poisoned,
        recoveries: summary.recoveries,
        failures: summary.failures,
        most_poisoned_locks,
    }
}
}
use super::lifecycle::LockHealthStatus;
#[cfg(test)]
mod tests {
use super::*;
// A runtime built from only a character name keeps that name, derives its
// agent id from the name, and uses the default conversation length of 32.
#[tokio::test]
async fn test_runtime_creation() {
let opts = RuntimeOpts {
character: Some(Character {
name: "TestAgent".to_string(),
..Default::default()
}),
..Default::default()
};
let runtime = AgentRuntime::new(opts).await.unwrap();
let rt = runtime.read().unwrap();
assert_eq!(rt.character.name, "TestAgent");
let expected_id = crate::utils::string_to_uuid("TestAgent");
assert_eq!(rt.agent_id, expected_id);
assert_eq!(rt.get_conversation_length(), 32);
}
// An explicit `agent_id` in the options takes precedence over the id derived
// from the character.
#[tokio::test]
async fn test_runtime_with_custom_agent_id() {
let custom_id = Uuid::new_v4();
let opts = RuntimeOpts {
agent_id: Some(custom_id),
character: Some(Character {
name: "TestAgent".to_string(),
..Default::default()
}),
..Default::default()
};
let runtime = AgentRuntime::new(opts).await.unwrap();
let rt = runtime.read().unwrap();
assert_eq!(rt.agent_id, custom_id);
}
// Composing state on a runtime without registered providers yields empty
// `values` and `data`.
#[tokio::test]
async fn test_state_composition_empty_providers() {
let opts = RuntimeOpts {
character: Some(Character {
name: "TestAgent".to_string(),
..Default::default()
}),
..Default::default()
};
let runtime = AgentRuntime::new(opts).await.unwrap();
let rt = runtime.read().unwrap();
// Minimal message with default content and no embedding or metadata.
let message = Memory {
id: Uuid::new_v4(),
entity_id: Uuid::new_v4(),
agent_id: rt.agent_id,
room_id: Uuid::new_v4(),
content: Content::default(),
embedding: None,
metadata: None,
created_at: chrono::Utc::now().timestamp(),
unique: None,
similarity: None,
};
let state = rt
.compose_state(&message, None, false, false)
.await
.unwrap();
assert!(
state.values.is_empty(),
"State should be empty with no providers registered"
);
assert!(
state.data.is_empty(),
"Data should be empty with no providers registered"
);
}
// A setting written through `set_setting` can be read back, and an unknown
// key reads as `None`.
#[tokio::test]
async fn test_settings_management() {
let opts = RuntimeOpts {
character: Some(Character {
name: "TestAgent".to_string(),
..Default::default()
}),
..Default::default()
};
let runtime = AgentRuntime::new(opts).await.unwrap();
// Scoped so the write guard is released before the read below.
{
let mut rt = runtime.write().unwrap();
rt.set_setting("test_key", serde_json::json!("test_value"), false);
}
{
let rt = runtime.read().unwrap();
let value = rt.get_setting("test_key");
assert!(value.is_some(), "Setting should exist");
assert_eq!(value.unwrap(), serde_json::json!("test_value"));
let missing = rt.get_setting("nonexistent");
assert!(missing.is_none(), "Non-existent setting should return None");
}
}
// Run id lifecycle: absent at first, set by `start_run`, cleared by `end_run`.
#[tokio::test]
async fn test_run_id_management() {
let opts = RuntimeOpts {
character: Some(Character {
name: "TestAgent".to_string(),
..Default::default()
}),
..Default::default()
};
let runtime = AgentRuntime::new(opts).await.unwrap();
{
let rt = runtime.read().unwrap();
assert!(rt.get_current_run_id().is_none(), "Initially no run ID");
}
// Each step takes its guard in its own scope so read and write never overlap.
let run_id = {
let mut rt = runtime.write().unwrap();
rt.start_run()
};
{
let rt = runtime.read().unwrap();
assert_eq!(
rt.get_current_run_id(),
Some(run_id),
"Run ID should be set"
);
}
{
let mut rt = runtime.write().unwrap();
rt.end_run();
}
{
let rt = runtime.read().unwrap();
assert!(
rt.get_current_run_id().is_none(),
"Run ID should be cleared"
);
}
}
}