behaviortree 0.7.4

A `#![no_std]`-compatible behavior tree library, similar to BehaviorTree.CPP.
Documentation
// Copyright © 2025 Stephan Kunz

//! [`Groot2Connector`] implementation.

extern crate std;

use core::time::Duration;

use alloc::collections::vec_deque::VecDeque;
// region:      --- modules
use crate::{
	ConstString, Mutex, XmlCreator,
	behavior::{BehaviorState, behavior_data::BehaviorData},
	tree::{
		observer::groot2_protocol::{Groot2ReplyHeader, Groot2RequestHeader, Groot2RequestType, Groot2TransitionInfo},
		tree::{BehaviorTree, BehaviorTreeMessage},
	},
};
use alloc::string::{String, ToString};
use alloc::sync::Arc;
use bytes::{Bytes, BytesMut};
use thingbuf::mpsc;
use tokio::{task::JoinHandle, time::Instant};
use zeromq::{Socket, SocketRecv, SocketSend, ZmqMessage};
// endregion:   --- modules

/// Maximum number of behavior state transitions kept in the transition buffer.
///
/// Transitions are buffered between two `GetTransitions` requests of Groot2.
/// If more state transitions happen in between, the oldest ones are dropped.
const TRANSITION_SIZE: u32 = 100;

/// Identifier under which the Groot2 state-change callbacks are registered
/// on the tree elements.
pub const GROOT_STATE: &str = "groot_state";

// region:      --- GrootCallback
/// Attach the Groot2 communication callbacks to a [`BehaviorTree`].
/// # Panics
/// - if an unknown message from Groot2 arrives
pub fn attach_groot_callback(tree: &mut BehaviorTree, shared: Arc<Mutex<Groot2ConnectorData>>) {
	let id: ConstString = GROOT_STATE.into();
	// add a callback for each tree element
	let shared = shared;
	for element in tree.iter_mut() {
		let shared_clone = shared.clone();
		// the callback
		let callback = move |behavior: &BehaviorData, new_state: &mut BehaviorState| {
			if behavior.state() != *new_state {
				// Groot does not need a state for root
				if behavior.uid() != 0 {
					let state = if *new_state == BehaviorState::Idle {
						behavior.state() as u8 + 10
					} else {
						*new_state as u8
					};
					let mut shared_guard = shared_clone.lock();
					let uid = behavior.uid().to_le_bytes();
					let index = 3 * ((behavior.uid() - 1) as usize);
					shared_guard.state_buffer[index] = uid[0];
					shared_guard.state_buffer[index + 1] = uid[1];
					shared_guard.state_buffer[index + 2] = state;

					if shared_guard.recording {
						#[allow(clippy::cast_possible_truncation)]
						#[allow(clippy::expect_used)]
						let timestamp = std::time::SystemTime::now()
							.duration_since(std::time::UNIX_EPOCH)
							.expect("Time went backwards")
							.as_micros() as u64;
						let info = Groot2TransitionInfo::new(timestamp, behavior.uid(), *new_state);
						if shared_guard.transitions_buffer.is_empty() {
							shared_guard.transitions = 0;
						} else if shared_guard.transitions >= TRANSITION_SIZE {
							shared_guard.transitions_buffer.pop_front();
						} else {
							shared_guard.transitions += 1;
						}
						shared_guard.transitions_buffer.push_back(info);
					}
					drop(shared_guard);
				}
			}
		};
		element.add_pre_state_change_callback(id.clone(), callback);
	}
}
// endregion:   --- GrootCallback

// region:      --- Groot2Connector
/// The [`Groot2Connector`] is the interface between Groot2 and the tree executor.
///
/// It answers Groot2 requests on a ZeroMQ `REP` socket bound to
/// `tcp://0.0.0.0:<port>`. The connection has to be established by Groot2,
/// so the connector on tree side only needs to know the port it shall listen on.
// NOTE(review): `dead_code` is presumably allowed because the fields are only
// stored to keep the channel, the shared data and the task handles owned by the
// connector, not read — confirm against the rest of the crate.
#[allow(dead_code)]
pub struct Groot2Connector {
	/// Sender for messages to the tree
	tx: mpsc::Sender<BehaviorTreeMessage>,
	/// Data shared with the server task, the watchdog task and the tree callbacks
	shared: Arc<Mutex<Groot2ConnectorData>>,
	/// Handle of the task answering Groot2 requests
	server_handle: JoinHandle<Result<(), zeromq::ZmqError>>,
	/// Handle of the task watching for a lost Groot2 connection
	watchdog_handle: JoinHandle<()>,
}

/// Data shared between the [`Groot2Connector`] tasks and the state-change
/// callbacks installed by [`attach_groot_callback`].
///
/// Within this module it is accessed through an `Arc<Mutex<_>>`.
pub struct Groot2ConnectorData {
	/// Whether Groot2 is considered connected: set on a `FullTree` request,
	/// cleared on a `RemoveAllHooks` request or by the connection watchdog
	connected: bool,
	/// Whether state transitions are recorded, toggled by Groot2's `ToggleRecording` request
	recording: bool,
	/// Size counter for `transitions_buffer`, maintained by the state-change callback
	transitions: u32,
	/// Per-node Groot2 state: 3 bytes per non-root node (two uid bytes followed by
	/// the state code); sent as reply to `State` requests
	state_buffer: BytesMut,
	/// Recorded state transitions; sent and cleared on a `GetTransitions` request
	transitions_buffer: VecDeque<Groot2TransitionInfo>,
	/// Time of the last request received from Groot2, used by the watchdog
	/// to detect a lost connection
	last_communication: Instant,
}

impl Groot2Connector {
	/// Construct a new [`Groot2Connector`].
	/// # Panics
	#[must_use]
	#[allow(clippy::too_many_lines)]
	pub fn new(tree: &mut BehaviorTree, port: u16) -> Self {
		// an empty transitions buffer
		let transitions_buffer = VecDeque::new();
		// a state buffer
		let tree_size = tree.size() - 1; // without root
		let mut state_buffer = BytesMut::zeroed((3 * tree_size) as usize);
		// initialize state buffer
		for i in 0..tree_size {
			let index = (3 * i) as usize;
			let bytes = (i + 1).to_be_bytes();
			state_buffer[index] = bytes[0];
			state_buffer[index] = bytes[1];
		}

		let shared = Arc::new(Mutex::new(Groot2ConnectorData {
			connected: false,
			recording: false,
			transitions: 0,
			state_buffer,
			transitions_buffer,
			last_communication: Instant::now(),
		}));

		let shared_clone = shared.clone();
		let sender = tree.sender();

		let watchdog_handle = tokio::spawn(async move {
			loop {
				// std::dbg!("watchdog");
				if let Some(mut data) = shared_clone.try_lock() {
					// std::dbg!("checking connection");
					#[allow(clippy::expect_used)]
					if data.connected
						&& Instant::now()
							.checked_duration_since(data.last_communication)
							.expect("time went backwards")
							> Duration::from_secs(5)
					{
						// std::dbg!("removing connection");
						let _ = sender
							.send(BehaviorTreeMessage::RemoveAllGrootHooks)
							.await;
						data.connected = false;
					}
				}

				tokio::time::sleep(Duration::from_secs(1)).await;
			}
		});

		// @TODO: proper error handling
		let shared_clone = shared.clone();
		let tree_id = tree.uuid();
		#[allow(clippy::expect_used)]
		let xml = XmlCreator::groot_write_tree(tree).expect("usually this should not happen");
		let sender = tree.sender();

		let server_handle = tokio::spawn(async move {
			// @TODO: replace zeromq with something #![no_std] compatible
			let server_address = String::from("tcp://0.0.0.0:") + &port.to_string();
			let mut server_socket = zeromq::RepSocket::new();
			server_socket.bind(&server_address).await?;

			loop {
				// std::dbg!("server");
				let request = server_socket.recv().await?;
				shared_clone.lock().last_communication = Instant::now();
				// std::dbg!(&request);
				if let Some(bytes) = request.get(0) {
					// std::dbg!(bytes);
					if let Ok(header) = Groot2RequestHeader::try_from(bytes) {
						let rq_type = header.rq_type();
						let reply_header = Groot2ReplyHeader::new(header, tree_id);
						let mut reply = ZmqMessage::from(Bytes::from(&reply_header));
						match rq_type {
							// most requests will be "State"
							Groot2RequestType::State => {
								// std::println!("{:?}", buffer.lock());
								reply.push_back(shared_clone.lock().state_buffer.clone().into());
							}
							Groot2RequestType::FullTree => {
								shared_clone.lock().connected = true;
								let _ = sender
									.send(BehaviorTreeMessage::AddGrootCallback(shared_clone.clone()))
									.await;
								reply.push_back(xml.clone());
							}
							Groot2RequestType::BlackBoard => {
								std::dbg!(&request);
								todo!()
							}
							Groot2RequestType::HookInsert => {
								std::dbg!(&request);
								todo!()
							}
							Groot2RequestType::HookRemove => {
								std::dbg!(&request);
								todo!()
							}
							Groot2RequestType::HooksDump => {
								std::dbg!(&request);
								todo!()
							}
							Groot2RequestType::RemoveAllHooks => {
								shared_clone.lock().connected = false;
								let _ = sender
									.send(BehaviorTreeMessage::RemoveAllGrootHooks)
									.await;
							}
							Groot2RequestType::DisableAllHooks => {
								std::dbg!(&request);
								todo!()
							}
							Groot2RequestType::BreakpointReached => {
								std::dbg!(&request);
								todo!()
							}
							Groot2RequestType::BreakpointUnlock => {
								std::dbg!(&request);
								todo!()
							}
							Groot2RequestType::ToggleRecording => {
								if let Some(command) = request.get(1) {
									let cmd = command.to_vec();
									match &cmd[..] {
										b"start" => {
											// activate transition recording
											let mut shared_guard = shared_clone.lock();
											shared_guard.recording = true;
											// clear transition buffer
											shared_guard.transitions_buffer.clear();
											// ensure that we can store at least TRANSITION_SIZE elements
											shared_guard
												.transitions_buffer
												.reserve(TRANSITION_SIZE as usize);
											drop(shared_guard);
											// return the microseconds since 01.01.1970
											#[allow(clippy::cast_possible_truncation)]
											#[allow(clippy::expect_used)]
											let timestamp = std::time::SystemTime::now()
												.duration_since(std::time::UNIX_EPOCH)
												.expect("Time went backwards")
												.as_micros() as u64;
											reply.push_back(Bytes::from(timestamp.to_string()));
										}
										b"stop" => {
											// de-activate transition recording
											shared_clone.lock().recording = false;
										}
										_ => {
											// this will only happen if there is some new Groot feature
											std::dbg!(&command);
											todo!()
										}
									}
								} else {
									todo!()
								}
							}
							Groot2RequestType::GetTransitions => {
								// send transition buffer
								let mut bytes = BytesMut::with_capacity((TRANSITION_SIZE * 9) as usize);
								let mut shared_guard = shared_clone.lock();
								for info in &shared_guard.transitions_buffer {
									bytes.extend(Bytes::from(info));
								}
								// std::println!("{:?}", &bytes);
								reply.push_back(Bytes::from(bytes));
								shared_guard.transitions_buffer.clear();
							}
							Groot2RequestType::Undefined => {
								std::dbg!(&request);
								todo!()
							}
						}

						// std::dbg!(&reply);
						server_socket.send(reply).await?;
					} else {
						std::dbg!(&request);
						todo!()
					}
				} else {
					todo!()
				}
			}
		});
		Self {
			tx: tree.sender(),
			shared,
			server_handle,
			watchdog_handle,
		}
	}
}
// endregion:   --- Groot2Connector