use crate::AgentMessage;
use crate::dora_adapter::error::{DoraError, DoraResult};
use crate::interrupt::AgentInterrupt;
use crate::message::{AgentEvent, TaskRequest};
use dora_node_api::Event;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, mpsc};
use tracing::{debug, info, warn};
/// Static configuration for a [`DoraAgentNode`].
///
/// Derives serde `Serialize`/`Deserialize`; field order is part of the serialized layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DoraNodeConfig {
/// Identifier used in this node's log messages; `Default` generates a UUIDv7 string.
pub node_id: String,
/// Human-readable node name (not read by the code in this file).
pub name: String,
/// Declared input ids (not read by the code in this file).
pub inputs: Vec<String>,
/// Output ids; `DoraAgentNode::init` creates one bounded byte channel per entry.
pub outputs: Vec<String>,
/// Capacity of the inbound event channel and of each output channel.
pub event_buffer_size: usize,
/// Default timeout; not read by the code in this file — presumably consumed elsewhere, TODO confirm.
pub default_timeout: Duration,
/// Free-form string settings; not read by the code in this file.
pub custom_config: HashMap<String, String>,
}
impl Default for DoraNodeConfig {
    /// Builds a configuration with a freshly generated UUIDv7 `node_id`, the name
    /// `"default_node"`, no inputs or outputs, a 1024-entry event buffer, a 30-second
    /// default timeout and an empty `custom_config`.
    fn default() -> Self {
        let node_id = uuid::Uuid::now_v7().to_string();
        let default_timeout = Duration::from_secs(30);
        Self {
            node_id,
            name: String::from("default_node"),
            inputs: Vec::new(),
            outputs: Vec::new(),
            event_buffer_size: 1024,
            default_timeout,
            custom_config: HashMap::default(),
        }
    }
}
/// Lifecycle state of a [`DoraAgentNode`].
#[derive(Clone, Debug, PartialEq)]
pub enum NodeState {
    /// Freshly constructed; `init` has not been called yet.
    Created,
    /// Transient state held while `init` sets up the output channels.
    Initializing,
    /// Accepting output sends; entered after `init` or `resume`.
    Running,
    /// Entered via `pause` from `Running`; sends are rejected until `resume`.
    Paused,
    /// Set by `stop` before it triggers the interrupt.
    Stopping,
    /// Final state set by `stop`.
    Stopped,
    /// Failure state carrying a description (not constructed in this module — presumably set elsewhere).
    Error(String),
}
pub struct DoraAgentNode {
config: DoraNodeConfig,
state: Arc<RwLock<NodeState>>,
interrupt: AgentInterrupt,
event_tx: mpsc::Sender<AgentEvent>,
event_rx: Arc<RwLock<mpsc::Receiver<AgentEvent>>>,
output_channels: Arc<RwLock<HashMap<String, mpsc::Sender<Vec<u8>>>>>,
}
impl DoraAgentNode {
pub fn new(config: DoraNodeConfig) -> Self {
let (event_tx, event_rx) = mpsc::channel(config.event_buffer_size);
Self {
config,
state: Arc::new(RwLock::new(NodeState::Created)),
interrupt: AgentInterrupt::new(),
event_tx,
event_rx: Arc::new(RwLock::new(event_rx)),
output_channels: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn config(&self) -> &DoraNodeConfig {
&self.config
}
pub async fn state(&self) -> NodeState {
self.state.read().await.clone()
}
pub fn interrupt(&self) -> &AgentInterrupt {
&self.interrupt
}
pub async fn init(&self) -> DoraResult<()> {
let mut state = self.state.write().await;
if *state != NodeState::Created {
return Err(DoraError::NodeInitError(
"Node already initialized".to_string(),
));
}
*state = NodeState::Initializing;
let mut output_channels = self.output_channels.write().await;
for output in &self.config.outputs {
let (tx, _rx) = mpsc::channel(self.config.event_buffer_size);
output_channels.insert(output.clone(), tx);
}
*state = NodeState::Running;
info!("DoraAgentNode {} initialized", self.config.node_id);
Ok(())
}
pub async fn send_output(&self, output_id: &str, data: Vec<u8>) -> DoraResult<()> {
let state = self.state.read().await;
if *state != NodeState::Running {
return Err(DoraError::NodeNotRunning);
}
let output_channels = self.output_channels.read().await;
if let Some(tx) = output_channels.get(output_id) {
tx.send(data)
.await
.map_err(|e| DoraError::ChannelError(e.to_string()))?;
debug!("Sent data to output: {}", output_id);
} else {
warn!("Output channel {} not found", output_id);
}
Ok(())
}
pub async fn send_message(&self, output_id: &str, message: &AgentMessage) -> DoraResult<()> {
let data = bincode::serialize(message)?;
self.send_output(output_id, data).await
}
pub async fn inject_event(&self, event: AgentEvent) -> DoraResult<()> {
self.event_tx
.send(event)
.await
.map_err(|e| DoraError::ChannelError(e.to_string()))?;
Ok(())
}
pub async fn pause(&self) -> DoraResult<()> {
let mut state = self.state.write().await;
if *state == NodeState::Running {
*state = NodeState::Paused;
info!("DoraAgentNode {} paused", self.config.node_id);
}
Ok(())
}
pub async fn resume(&self) -> DoraResult<()> {
let mut state = self.state.write().await;
if *state == NodeState::Paused {
*state = NodeState::Running;
info!("DoraAgentNode {} resumed", self.config.node_id);
}
Ok(())
}
pub async fn stop(&self) -> DoraResult<()> {
let mut state = self.state.write().await;
*state = NodeState::Stopping;
self.interrupt.trigger();
*state = NodeState::Stopped;
info!("DoraAgentNode {} stopped", self.config.node_id);
Ok(())
}
pub fn create_event_loop(&self) -> NodeEventLoop {
NodeEventLoop {
event_rx: self.event_rx.clone(),
interrupt: self.interrupt.clone(),
state: self.state.clone(),
}
}
}
/// Handle for consuming a [`DoraAgentNode`]'s inbound events.
///
/// Produced by `DoraAgentNode::create_event_loop`, which fills it with clones of the
/// node's event receiver `Arc`, interrupt and state `Arc`.
pub struct NodeEventLoop {
/// Shared receiving half of the node's inbound event channel.
event_rx: Arc<RwLock<mpsc::Receiver<AgentEvent>>>,
/// Clone of the node's interrupt handle, checked before each receive.
interrupt: AgentInterrupt,
/// Shared lifecycle state of the node.
state: Arc<RwLock<NodeState>>,
}
impl NodeEventLoop {
    /// Waits for the next inbound event.
    ///
    /// Returns `Some(AgentEvent::Shutdown)` in either of two cases:
    /// - the interrupt has already been triggered, or the node is `Stopping`/`Stopped`;
    /// - the interrupt is notified while waiting.
    ///
    /// Returns `None` once the event channel is closed and drained.
    pub async fn next_event(&self) -> Option<AgentEvent> {
        // Register for the shutdown notification *before* checking the flag. A tokio
        // `Notified` future receives `notify_waiters()` wakeups from the moment it is
        // created, so a trigger that lands between the check below and the `select!`
        // can no longer be missed. (Assumes `AgentInterrupt::notify` is a
        // `tokio::sync::Notify`, as its use in `select!` suggests.)
        let shutdown = self.interrupt.notify.notified();
        if self.interrupt.check() {
            return Some(AgentEvent::Shutdown);
        }
        if matches!(
            *self.state.read().await,
            NodeState::Stopped | NodeState::Stopping
        ) {
            return Some(AgentEvent::Shutdown);
        }
        tokio::select! {
            event = async {
                // Acquire the receiver inside the race so that a shutdown also interrupts
                // waiting for another task to release the receiver lock.
                let mut event_rx = self.event_rx.write().await;
                event_rx.recv().await
            } => event,
            _ = shutdown => Some(AgentEvent::Shutdown),
        }
    }

    /// Non-blocking variant of [`next_event`](Self::next_event).
    ///
    /// Returns `Some(AgentEvent::Shutdown)` if the interrupt has been triggered. Otherwise
    /// returns the next buffered event, or `None` when nothing is buffered. This still
    /// waits for the receiver lock if another task holds it.
    pub async fn try_next_event(&self) -> Option<AgentEvent> {
        if self.interrupt.check() {
            return Some(AgentEvent::Shutdown);
        }
        let mut event_rx = self.event_rx.write().await;
        event_rx.try_recv().ok()
    }

    /// Reports whether the interrupt has been triggered.
    pub fn should_interrupt(&self) -> bool {
        self.interrupt.check()
    }

    /// Returns this loop's interrupt handle.
    pub fn interrupt(&self) -> &AgentInterrupt {
        &self.interrupt
    }
}
fn extract_bytes_from_arrow_data(data: &dora_node_api::ArrowData) -> Vec<u8> {
Vec::<u8>::try_from(data).unwrap_or_default()
}
/// Translates a raw dora [`Event`] into an [`AgentEvent`].
///
/// * `Event::Stop` becomes [`AgentEvent::Shutdown`].
/// * `Event::Input` payloads are decoded with bincode: first as a [`TaskRequest`], then as
///   an [`AgentMessage`]. An `AgentMessage::Event` is unwrapped, and an
///   `AgentMessage::TaskRequest` becomes a medium-priority task. Anything else is passed
///   through as `AgentEvent::Custom(input_id, bytes)`.
/// * `Event::InputClosed` is logged at debug level and yields `None`, as does every other
///   event kind.
pub fn convert_dora_event(dora_event: &Event) -> Option<AgentEvent> {
    let (input_id, arrow_data) = match dora_event {
        Event::Stop(_) => return Some(AgentEvent::Shutdown),
        Event::InputClosed { id } => {
            debug!("Input {} closed", id);
            return None;
        }
        Event::Input { id, data, .. } => (id, data),
        _ => return None,
    };

    let payload = extract_bytes_from_arrow_data(arrow_data);

    // NOTE(review): bincode is not self-describing, so the TaskRequest attempt may accept
    // bytes that were meant as an AgentMessage. Confirm that the senders' encodings cannot
    // be confused.
    if let Ok(task) = bincode::deserialize::<TaskRequest>(&payload) {
        return Some(AgentEvent::TaskReceived(task));
    }

    let event = match bincode::deserialize::<AgentMessage>(&payload) {
        Ok(AgentMessage::Event(inner)) => inner,
        Ok(AgentMessage::TaskRequest { task_id, content }) => {
            AgentEvent::TaskReceived(TaskRequest {
                task_id,
                content,
                priority: crate::message::TaskPriority::Medium,
                deadline: None,
                metadata: HashMap::new(),
            })
        }
        // Undecodable payloads and other message kinds are forwarded untouched.
        _ => AgentEvent::Custom(input_id.to_string(), payload),
    };
    Some(event)
}
#[cfg(test)]
mod tests {
    // The items under test live in the parent module. Without this import, the test
    // module does not compile under `cargo test`.
    use super::*;

    /// Walks a node through Created -> Running -> Paused -> Running -> Stopped.
    #[tokio::test]
    async fn test_node_lifecycle() {
        let config = DoraNodeConfig {
            node_id: "test_node".to_string(),
            name: "Test Node".to_string(),
            outputs: vec!["output1".to_string()],
            ..Default::default()
        };
        let node = DoraAgentNode::new(config);
        assert_eq!(node.state().await, NodeState::Created);
        node.init().await.unwrap();
        assert_eq!(node.state().await, NodeState::Running);
        node.pause().await.unwrap();
        assert_eq!(node.state().await, NodeState::Paused);
        node.resume().await.unwrap();
        assert_eq!(node.state().await, NodeState::Running);
        node.stop().await.unwrap();
        assert_eq!(node.state().await, NodeState::Stopped);
    }

    /// Sending before `init` is rejected, and a second `init` call fails.
    #[tokio::test]
    async fn test_send_before_init_and_double_init() {
        let node = DoraAgentNode::new(DoraNodeConfig {
            outputs: vec!["output1".to_string()],
            ..Default::default()
        });
        assert!(matches!(
            node.send_output("output1", vec![1, 2, 3]).await,
            Err(DoraError::NodeNotRunning)
        ));
        node.init().await.unwrap();
        assert!(matches!(
            node.init().await,
            Err(DoraError::NodeInitError(_))
        ));
    }
}