use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use log;
use super::node::{ContextStrategy, NodeConfig};
use super::transition::TransitionResult;
use crate::context::LLMContext;
use crate::services::llm::function_registry::FunctionRegistry;
use crate::services::llm::openai::TransitionHook;
/// Pinned, boxed future produced by a Dhara handler; it resolves to the
/// handler's [`TransitionResult`].
pub type DharaHandlerFuture =
    Pin<Box<dyn Future<Output = TransitionResult> + Send + 'static>>;
/// Shareable Dhara handler: takes the argument string handed over by the
/// function registry and returns a [`DharaHandlerFuture`].
pub type DharaHandlerFn =
    Arc<dyn Fn(String) -> DharaHandlerFuture + Send + Sync + 'static>;
/// A node as stored by the manager: its configuration together with the
/// handlers that are installed into the function registry when it is active.
struct RegisteredNode {
    /// `(function name, handler)` pairs registered when this node becomes active.
    handlers: Vec<(String, DharaHandlerFn)>,
    /// Settings applied to the shared context when this node becomes active.
    config: NodeConfig,
}
/// Manages a set of named nodes and switches between them by rewriting a
/// shared [`LLMContext`] and rebuilding a shared [`FunctionRegistry`].
pub struct DharaManager {
    /// All registered nodes, looked up by name.
    nodes: HashMap<String, RegisteredNode>,
    /// Name of the node activated through `set_initial_node`, if any.
    current_node: Option<String>,
    /// Conversation context that node configs are applied to.
    context: Arc<Mutex<LLMContext>>,
    /// Function registry rebuilt from the active node's handlers.
    registry: Arc<Mutex<FunctionRegistry>>,
    /// Target node requested by a handler, waiting to be consumed by the transition hook.
    pending_transition: Arc<Mutex<Option<String>>>,
}
impl DharaManager {
pub fn new(
context: Arc<Mutex<LLMContext>>,
registry: Arc<Mutex<FunctionRegistry>>,
) -> Self {
Self {
context,
registry,
nodes: HashMap::new(),
current_node: None,
pending_transition: Arc::new(Mutex::new(None)),
}
}
pub fn register_node(
&mut self,
name: impl Into<String>,
config: NodeConfig,
handlers: Vec<(impl Into<String>, DharaHandlerFn)>,
) {
let name = name.into();
let handlers: Vec<(String, DharaHandlerFn)> = handlers
.into_iter()
.map(|(n, h)| (n.into(), h))
.collect();
log::debug!("Dhara: registered node '{}' with {} handler(s)", name, handlers.len());
self.nodes.insert(name, RegisteredNode { config, handlers });
}
pub fn register_node_no_tools(
&mut self,
name: impl Into<String>,
config: NodeConfig,
) {
let name = name.into();
log::debug!("Dhara: registered node '{}' (no tools)", name);
self.nodes.insert(name, RegisteredNode { config, handlers: vec![] });
}
pub fn set_initial_node(&mut self, name: &str) {
let node = self.nodes.get(name).unwrap_or_else(|| {
panic!("Dhara: node '{}' not registered", name)
});
log::info!("Dhara: setting initial node '{}'", name);
self.current_node = Some(name.to_string());
Self::apply_node_to_context(&self.context, &node.config);
self.build_registry_for_node(name);
}
pub fn current_node(&self) -> Option<&str> {
self.current_node.as_deref()
}
pub fn create_transition_hook(&self) -> TransitionHook {
let pending = self.pending_transition.clone();
let registry = self.registry.clone();
let context_ref = self.context.clone();
let nodes: HashMap<String, (NodeConfig, Vec<(String, DharaHandlerFn)>)> = self
.nodes
.iter()
.map(|(name, rn)| {
(
name.clone(),
(rn.config.clone(), rn.handlers.clone()),
)
})
.collect();
let pending_for_rebuild = pending.clone();
Arc::new(move |context: &Arc<Mutex<LLMContext>>| {
let next_name = pending.lock().unwrap().take();
if let Some(name) = next_name {
if let Some((config, handlers)) = nodes.get(&name) {
log::info!("Dhara: transitioning to node '{}'", name);
Self::apply_node_to_context(context, config);
{
let mut reg = registry.lock().unwrap();
*reg = FunctionRegistry::new();
for (fn_name, handler) in handlers {
let handler = handler.clone();
let pt = pending_for_rebuild.clone();
reg.register(fn_name.clone(), move |args: String| {
let handler = handler.clone();
let pt = pt.clone();
async move {
let result = handler(args).await;
match result {
TransitionResult::Stay(r) => r,
TransitionResult::Transition { result, next_node } => {
log::info!(
"Dhara: handler requested transition to '{}'",
next_node
);
*pt.lock().unwrap() = Some(next_node);
result
}
}
}
});
}
}
log::info!("Dhara: transition to '{}' complete", name);
} else {
log::error!("Dhara: node '{}' not registered, ignoring transition", name);
}
}
})
}
fn apply_node_to_context(context: &Arc<Mutex<LLMContext>>, config: &NodeConfig) {
let mut ctx = context.lock().unwrap();
if let Some(prompt) = &config.system_prompt {
ctx.system_prompt = Some(prompt.clone());
}
match config.context_strategy {
ContextStrategy::Reset => {
ctx.messages.clear();
}
ContextStrategy::Append => {}
}
for msg in &config.task_messages {
ctx.push_message(msg.clone());
}
ctx.tools = config.tools.clone();
ctx.tool_choice = None; }
fn build_registry_for_node(&self, name: &str) {
let node = self.nodes.get(name).expect("node must be registered");
let mut reg = self.registry.lock().unwrap();
*reg = FunctionRegistry::new();
for (fn_name, handler) in &node.handlers {
let handler = handler.clone();
let pending = self.pending_transition.clone();
reg.register(fn_name.clone(), move |args: String| {
let handler = handler.clone();
let pending = pending.clone();
async move {
let result = handler(args).await;
match result {
TransitionResult::Stay(r) => r,
TransitionResult::Transition { result, next_node } => {
log::info!(
"Dhara: handler requested transition to '{}'",
next_node
);
*pending.lock().unwrap() = Some(next_node);
result
}
}
}
});
}
}
}