use super::callable::{Callable, ExecutionContext, Value};
use crate::error::{DbxError, DbxResult};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Registry and dispatcher for named [`Callable`]s.
///
/// Both the registry and the metrics table sit behind [`RwLock`]s, so callables can be
/// registered, unregistered and executed through a shared `&self`. Every call dispatched
/// through [`ExecutionEngine::execute`] is timed and recorded in an [`ExecutionMetrics`]
/// table that can be snapshotted with [`ExecutionEngine::metrics`].
pub struct ExecutionEngine {
    /// Registered callables, keyed by the value of [`Callable::name`] at registration time.
    callables: RwLock<HashMap<String, Arc<dyn Callable>>>,
    /// Per-callable call statistics. Owned exclusively by the engine, so the lock needs
    /// no `Arc` around it.
    metrics: RwLock<ExecutionMetrics>,
}

impl ExecutionEngine {
    /// Creates an engine with no registered callables and empty metrics.
    pub fn new() -> Self {
        Self {
            callables: RwLock::new(HashMap::new()),
            metrics: RwLock::new(ExecutionMetrics::new()),
        }
    }

    /// Registers `callable` under the name returned by [`Callable::name`].
    ///
    /// # Errors
    ///
    /// - [`DbxError::DuplicateCallable`] if a callable with the same name is already
    ///   registered; the existing registration is left untouched.
    /// - [`DbxError::LockPoisoned`] if the registry lock is poisoned.
    pub fn register(&self, callable: Arc<dyn Callable>) -> DbxResult<()> {
        use std::collections::hash_map::Entry;

        let name = callable.name().to_string();
        let mut callables = self.callables.write().map_err(|_| DbxError::LockPoisoned)?;
        // One hash lookup serves both the duplicate check and the insertion.
        match callables.entry(name) {
            Entry::Occupied(existing) => Err(DbxError::DuplicateCallable(existing.key().clone())),
            Entry::Vacant(slot) => {
                slot.insert(callable);
                Ok(())
            }
        }
    }

    /// Removes the callable registered under `name`.
    ///
    /// # Errors
    ///
    /// - [`DbxError::CallableNotFound`] if nothing is registered under `name`.
    /// - [`DbxError::LockPoisoned`] if the registry lock is poisoned.
    pub fn unregister(&self, name: &str) -> DbxResult<()> {
        let mut callables = self.callables.write().map_err(|_| DbxError::LockPoisoned)?;
        match callables.remove(name) {
            Some(_) => Ok(()),
            None => Err(DbxError::CallableNotFound(name.to_string())),
        }
    }

    /// Invokes the callable registered under `name` with `ctx` and `args`.
    ///
    /// The call is timed and recorded in the engine's metrics, including when the
    /// callable itself returns an error. The callable's own result is returned as-is.
    ///
    /// # Errors
    ///
    /// - [`DbxError::CallableNotFound`] if nothing is registered under `name`; no
    ///   metrics are recorded in that case.
    /// - [`DbxError::LockPoisoned`] if the registry lock is poisoned.
    /// - Any error returned by the callable.
    pub fn execute(&self, name: &str, ctx: &ExecutionContext, args: &[Value]) -> DbxResult<Value> {
        // Clone the `Arc` out and release the read guard before invoking the callable.
        // The call then runs without holding the registry lock. Otherwise a callable
        // that registers or unregisters would re-acquire that lock on the same thread.
        let found = {
            let callables = self.callables.read().map_err(|_| DbxError::LockPoisoned)?;
            callables.get(name).cloned()
        };
        let callable = found.ok_or_else(|| DbxError::CallableNotFound(name.to_string()))?;

        let start = Instant::now();
        let result = callable.call(ctx, args);
        let elapsed = start.elapsed();

        // Metrics are best-effort. A poisoned metrics lock must not turn the call's
        // outcome into an error, so the sample is dropped instead.
        if let Ok(mut metrics) = self.metrics.write() {
            metrics.record(name, elapsed, result.is_ok());
        }
        result
    }

    /// Returns the names of all registered callables, sorted in ascending order.
    ///
    /// # Errors
    ///
    /// [`DbxError::LockPoisoned`] if the registry lock is poisoned.
    pub fn list(&self) -> DbxResult<Vec<String>> {
        let mut names: Vec<String> = {
            let callables = self.callables.read().map_err(|_| DbxError::LockPoisoned)?;
            callables.keys().cloned().collect()
        };
        // HashMap iteration order is unspecified; sort so callers get a stable listing.
        names.sort_unstable();
        Ok(names)
    }

    /// Returns a snapshot copy of the metrics accumulated so far.
    ///
    /// # Errors
    ///
    /// [`DbxError::LockPoisoned`] if the metrics lock is poisoned.
    pub fn metrics(&self) -> DbxResult<ExecutionMetrics> {
        let guard = self.metrics.read().map_err(|_| DbxError::LockPoisoned)?;
        Ok(ExecutionMetrics::clone(&guard))
    }
}

impl Default for ExecutionEngine {
    /// Equivalent to [`ExecutionEngine::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Per-callable execution statistics, keyed by callable name.
///
/// [`ExecutionMetrics::record`] keeps the three maps in step:
/// - every recorded call bumps `call_counts` and adds to `total_durations`;
/// - failed calls additionally bump `error_counts`.
///
/// A name that never failed therefore has no entry in `error_counts`. The maps are
/// public, so the accessors below tolerate inconsistent contents instead of panicking.
#[derive(Debug, Clone, Default)]
pub struct ExecutionMetrics {
    /// Number of calls recorded per name.
    pub call_counts: HashMap<String, u64>,
    /// Sum of the durations passed to `record`, per name.
    pub total_durations: HashMap<String, Duration>,
    /// Number of recorded calls that failed, per name.
    pub error_counts: HashMap<String, u64>,
}

impl ExecutionMetrics {
    /// Creates an empty metrics table.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one call of `name` that took `duration` and succeeded if `success` is true.
    pub fn record(&mut self, name: &str, duration: Duration, success: bool) {
        *self.call_counts.entry(name.to_owned()).or_default() += 1;
        let total = self.total_durations.entry(name.to_owned()).or_default();
        // Saturate rather than panic on (practically unreachable) overflow: recording a
        // metric should never abort the caller.
        *total = total.saturating_add(duration);
        if !success {
            *self.error_counts.entry(name.to_owned()).or_default() += 1;
        }
    }

    /// Returns the mean duration of the recorded calls of `name`.
    ///
    /// Returns `None` in any of these cases:
    /// - `name` has no recorded total duration;
    /// - `name` has no call count;
    /// - its call count is zero.
    ///
    /// The result is truncated to whole nanoseconds.
    pub fn avg_duration(&self, name: &str) -> Option<Duration> {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        let total = self.total_durations.get(name)?;
        let count = *self.call_counts.get(name)?;
        if count == 0 {
            return None;
        }
        // Divide in u128 nanoseconds instead of using `Duration / u32`. That operator
        // would force `count` to be narrowed to u32, which truncates counts above
        // u32::MAX and divides by zero for multiples of 2^32.
        let avg_nanos = total.as_nanos() / u128::from(count);
        // avg_nanos <= total.as_nanos(), so the whole-seconds part always fits in u64.
        let secs = u64::try_from(avg_nanos / NANOS_PER_SEC).ok()?;
        // The remainder is < 1_000_000_000, so it always fits in u32.
        let subsec_nanos = (avg_nanos % NANOS_PER_SEC) as u32;
        Some(Duration::new(secs, subsec_nanos))
    }

    /// Returns the fraction of recorded calls of `name` that succeeded, in `[0.0, 1.0]`.
    ///
    /// Returns `None` if `name` has no call count or its call count is zero.
    pub fn success_rate(&self, name: &str) -> Option<f64> {
        let total = *self.call_counts.get(name)?;
        if total == 0 {
            return None;
        }
        let errors = self.error_counts.get(name).copied().unwrap_or(0);
        // The maps are public and may be edited directly. Saturate so an inconsistent
        // `errors > total` yields 0.0 instead of an integer-underflow panic.
        let successes = total.saturating_sub(errors);
        Some(successes as f64 / total as f64)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::automation::callable::{DataType, Signature};

    /// Minimal callable that echoes its first argument, or `Value::Null` when called
    /// with no arguments.
    struct TestCallable {
        name: String,
        signature: Signature,
    }

    impl TestCallable {
        fn new(name: impl Into<String>) -> Self {
            Self {
                name: name.into(),
                signature: Signature {
                    params: vec![DataType::Int],
                    return_type: DataType::Int,
                    is_variadic: false,
                },
            }
        }
    }

    impl Callable for TestCallable {
        fn call(&self, _ctx: &ExecutionContext, args: &[Value]) -> DbxResult<Value> {
            Ok(args.first().cloned().unwrap_or(Value::Null))
        }

        fn name(&self) -> &str {
            &self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }
    }

    #[test]
    fn test_register_and_execute() {
        let engine = ExecutionEngine::new();
        let callable = Arc::new(TestCallable::new("test_func"));
        engine.register(callable).unwrap();
        let ctx =
            ExecutionContext::new(Arc::new(crate::engine::Database::open_in_memory().unwrap()));
        let result = engine
            .execute("test_func", &ctx, &[Value::Int(42)])
            .unwrap();
        assert_eq!(result.as_i64().unwrap(), 42);
    }

    #[test]
    fn test_duplicate_registration() {
        let engine = ExecutionEngine::new();
        let callable1 = Arc::new(TestCallable::new("test_func"));
        let callable2 = Arc::new(TestCallable::new("test_func"));
        engine.register(callable1).unwrap();
        let result = engine.register(callable2);
        assert!(matches!(
            result,
            Err(DbxError::DuplicateCallable(ref n)) if n == "test_func"
        ));
    }

    #[test]
    fn test_unregister() {
        let engine = ExecutionEngine::new();
        engine.register(Arc::new(TestCallable::new("test_func"))).unwrap();
        engine.unregister("test_func").unwrap();
        assert!(engine.list().unwrap().is_empty());
        // A second removal of the same name must report that it is gone.
        assert!(matches!(
            engine.unregister("test_func"),
            Err(DbxError::CallableNotFound(ref n)) if n == "test_func"
        ));
    }

    #[test]
    fn test_list() {
        let engine = ExecutionEngine::new();
        engine.register(Arc::new(TestCallable::new("b"))).unwrap();
        engine.register(Arc::new(TestCallable::new("a"))).unwrap();
        let mut names = engine.list().unwrap();
        // Sort locally so this test does not depend on the listing order.
        names.sort();
        assert_eq!(names, ["a", "b"]);
    }

    #[test]
    fn test_execute_unknown_callable() {
        let engine = ExecutionEngine::new();
        let ctx =
            ExecutionContext::new(Arc::new(crate::engine::Database::open_in_memory().unwrap()));
        let result = engine.execute("missing", &ctx, &[]);
        assert!(matches!(
            result,
            Err(DbxError::CallableNotFound(ref n)) if n == "missing"
        ));
        // The lookup failed before any call was made, so nothing may be recorded.
        assert!(engine.metrics().unwrap().call_counts.is_empty());
    }

    #[test]
    fn test_metrics() {
        let engine = ExecutionEngine::new();
        let callable = Arc::new(TestCallable::new("test_func"));
        engine.register(callable).unwrap();
        let ctx =
            ExecutionContext::new(Arc::new(crate::engine::Database::open_in_memory().unwrap()));
        for _ in 0..10 {
            // Unwrap so a failing call is reported here instead of being silently counted.
            engine.execute("test_func", &ctx, &[Value::Int(42)]).unwrap();
        }
        let metrics = engine.metrics().unwrap();
        assert_eq!(metrics.call_counts.get("test_func"), Some(&10));
        assert_eq!(metrics.error_counts.get("test_func"), None);
        assert_eq!(metrics.success_rate("test_func"), Some(1.0));
    }
}