use super::*;
use serde::{Deserialize, Serialize};
use serde_json;
use std::collections::{HashMap as StdHashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{debug, error, info, warn};
#[cfg(not(target_arch = "wasm32"))]
use futures_util::{SinkExt, StreamExt};
#[cfg(not(target_arch = "wasm32"))]
use tokio::net::TcpStream;
#[cfg(not(target_arch = "wasm32"))]
use tokio_tungstenite::{
connect_async, tungstenite::Message as WsMessage, MaybeTlsStream, WebSocketStream,
};
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::JsCast;
#[cfg(target_arch = "wasm32")]
use web_sys::{CloseEvent, ErrorEvent, MessageEvent, WebSocket};
/// Settings that control how a [`TracingClient`] connects and sends events.
///
/// The `Default` implementation in this file targets `ws://localhost:8080`
/// with tracing enabled and a batch size of 1.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracingConfig {
/// WebSocket URL of the tracing server.
pub server_url: String,
/// Buffered-event threshold for a batch. A value of exactly 1 makes
/// `record_event` send each event immediately instead of buffering it.
pub batch_size: usize,
/// Tick period of the background batching task and maximum age of a
/// pending batch before it is sent.
pub batch_timeout: Duration,
/// Compression flag.
/// NOTE(review): not read anywhere in this file — confirm it is honored elsewhere.
pub enable_compression: bool,
/// Retry/backoff parameters.
/// NOTE(review): not read anywhere in this file — confirm where reconnection is implemented.
pub retry_config: RetryConfig,
/// When `false`, the client's public operations return early without
/// touching the network.
pub enabled: bool,
}
/// Backoff parameters for connection retries.
///
/// NOTE(review): none of these fields is read in this file; the descriptions
/// below are inferred from the field names and should be confirmed against
/// the code that performs reconnection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
/// Presumably the maximum number of retry attempts.
pub max_retries: usize,
/// Presumably the delay before the first retry.
pub initial_delay: Duration,
/// Presumably the upper bound for the delay between retries.
pub max_delay: Duration,
/// Presumably the factor applied to the delay after each failed attempt.
pub backoff_multiplier: f64,
}
/// Errors returned by the tracing client.
#[derive(Debug, thiserror::Error)]
pub enum TracingError {
/// The WebSocket connection could not be established.
#[error("Connection failed: {0}")]
ConnectionFailed(String),
/// A request could not be serialized to JSON (converted automatically from
/// `serde_json::Error` by `?`).
#[error("Serialization failed: {0}")]
SerializationFailed(#[from] serde_json::Error),
/// A message could not be handed to the transport. Also used by
/// `init_global_tracing` when the global client was already initialized.
#[error("WebSocket error: {0}")]
WebSocketError(String),
/// An operation needed a live connection but the client is not connected.
#[error("Client is disconnected")]
Disconnected,
/// A timeout elapsed.
/// NOTE(review): never constructed in this file.
#[error("Timeout occurred")]
Timeout,
}
/// Module-local result alias with [`TracingError`] as the error type.
type Result<T> = std::result::Result<T, TracingError>;
/// Mutable connection state, shared behind `Arc<Mutex<_>>` between the client
/// and the background tasks it spawns.
#[derive(Debug)]
struct ClientState {
/// Current lifecycle stage of the connection.
connection_status: ConnectionStatus,
/// Events recorded but not yet handed to the transport.
event_buffer: VecDeque<TraceEvent>,
/// Response channels for in-flight requests, consumed by the receiver task.
/// NOTE(review): nothing in this file inserts into this map — confirm where
/// requests are registered.
#[cfg(not(target_arch = "wasm32"))]
pending_requests: StdHashMap<String, tokio::sync::oneshot::Sender<TracingResponse>>,
/// Time the buffer was last drained by the batching task or by `flush`.
last_batch_time: Instant,
/// Reconnect counter.
/// NOTE(review): only ever reset to 0 in this file, never incremented.
reconnect_attempts: usize,
/// Trace id stored by the most recent successful `start_trace`.
current_trace_id: Option<TraceId>,
/// Queue feeding the WebSocket writer task; `None` while no connection is active.
#[cfg(not(target_arch = "wasm32"))]
message_sender: Option<tokio::sync::mpsc::UnboundedSender<String>>,
/// Browser WebSocket handle stored by `connect` on wasm32.
#[cfg(target_arch = "wasm32")]
websocket: Option<WebSocket>,
}
/// Lifecycle stage of the client's WebSocket connection.
///
/// The `Debug` names of the variants are exposed to callers through
/// `TracingClient::connection_status`, so they must not be renamed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// `Reconnecting` is not constructed anywhere in this file.
#[allow(dead_code)]
enum ConnectionStatus {
    /// No connection is active.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The connection is established and messages may be sent.
    Connected,
    /// Reserved for automatic reconnection.
    Reconnecting,
}
/// WebSocket client that forwards trace events and requests to a tracing server.
///
/// On non-wasm targets the connection is driven by Tokio tasks; on wasm32 it
/// uses the browser `WebSocket` API.
pub struct TracingClient {
/// Immutable configuration captured at construction time.
config: TracingConfig,
/// Connection state shared with background tasks and callbacks.
state: Arc<Mutex<ClientState>>,
/// Handle to the Tokio runtime that was current when the client was built,
/// if any. Stored but never read in this file.
#[cfg(not(target_arch = "wasm32"))]
#[allow(dead_code)]
runtime: Option<tokio::runtime::Handle>,
}
impl Default for TracingConfig {
    /// Returns the stock configuration: a local server on port 8080, tracing
    /// enabled, every event sent on its own (`batch_size == 1`), a 100 ms
    /// batch timeout, compression on, and up to five retries backing off from
    /// one second to thirty seconds with a multiplier of 2.
    fn default() -> Self {
        let retry_config = RetryConfig {
            backoff_multiplier: 2.0,
            max_delay: Duration::from_secs(30),
            initial_delay: Duration::from_secs(1),
            max_retries: 5,
        };

        Self {
            enabled: true,
            retry_config,
            enable_compression: true,
            batch_timeout: Duration::from_millis(100),
            batch_size: 1,
            server_url: String::from("ws://localhost:8080"),
        }
    }
}
impl TracingClient {
/// Creates a client in the `Disconnected` state. No connection is attempted;
/// call `connect` for that.
///
/// On non-wasm targets the current Tokio runtime handle is captured when
/// tracing is enabled and a runtime is active.
///
/// NOTE(review): `std::time::Instant::now()` is called unconditionally —
/// confirm it is supported on the wasm32 target this crate is built for.
pub fn new(config: TracingConfig) -> Self {
    // Only probe for an ambient runtime when tracing is actually enabled.
    #[cfg(not(target_arch = "wasm32"))]
    let runtime = config
        .enabled
        .then(|| tokio::runtime::Handle::try_current().ok())
        .flatten();

    let initial_state = ClientState {
        connection_status: ConnectionStatus::Disconnected,
        event_buffer: VecDeque::new(),
        #[cfg(not(target_arch = "wasm32"))]
        pending_requests: StdHashMap::new(),
        last_batch_time: Instant::now(),
        reconnect_attempts: 0,
        current_trace_id: None,
        #[cfg(not(target_arch = "wasm32"))]
        message_sender: None,
        #[cfg(target_arch = "wasm32")]
        websocket: None,
    };

    Self {
        config,
        state: Arc::new(Mutex::new(initial_state)),
        #[cfg(not(target_arch = "wasm32"))]
        runtime,
    }
}
/// Creates a client configured with `TracingConfig::default()`.
pub fn with_default_config() -> Self {
    let defaults = TracingConfig::default();
    Self::new(defaults)
}
/// Creates a client that uses the default configuration except for the
/// server URL, which is taken from `url`.
pub fn with_server_url(url: impl Into<String>) -> Self {
    let server_url: String = url.into();
    Self::new(TracingConfig {
        server_url,
        ..Default::default()
    })
}
/// Opens the WebSocket connection to `config.server_url` and starts the
/// background writer, reader and batching tasks.
///
/// Returns `Ok(())` without doing anything when tracing is disabled.
///
/// # Errors
///
/// Returns [`TracingError::ConnectionFailed`] when the WebSocket handshake
/// fails; the status is reset to `Disconnected` in that case.
///
/// # Panics
///
/// Panics if the state mutex is poisoned.
#[cfg(not(target_arch = "wasm32"))]
pub async fn connect(&self) -> Result<()> {
    if !self.config.enabled {
        debug!("Tracing is disabled, skipping connection");
        return Ok(());
    }
    info!("Connecting to tracing server at {}", self.config.server_url);
    self.state.lock().unwrap().connection_status = ConnectionStatus::Connecting;

    let ws_stream = match connect_async(&self.config.server_url).await {
        Ok((stream, _response)) => stream,
        Err(e) => {
            // Do not leave the client stuck in `Connecting` after a failed attempt.
            self.state.lock().unwrap().connection_status = ConnectionStatus::Disconnected;
            return Err(TracingError::ConnectionFailed(e.to_string()));
        }
    };

    {
        let mut state = self.state.lock().unwrap();
        state.connection_status = ConnectionStatus::Connected;
        state.reconnect_attempts = 0;
    }
    info!("Successfully connected to tracing server");

    // The setup routine only installs the outgoing queue and spawns the
    // background tasks; it does not wait on the connection itself. Running it
    // inline guarantees `message_sender` is in place when `connect` returns,
    // with no need to sleep and hope a spawned setup task has run.
    Self::handle_connection_static(Arc::clone(&self.state), self.config.clone(), ws_stream)
        .await;
    Ok(())
}
#[cfg(not(target_arch = "wasm32"))]
async fn handle_connection_static(
state: Arc<Mutex<ClientState>>,
config: TracingConfig,
ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
) {
let (ws_sender, mut ws_receiver) = ws_stream.split();
let (message_tx, mut message_rx) = tokio::sync::mpsc::unbounded_channel();
{
let mut state_guard = state.lock().unwrap();
state_guard.message_sender = Some(message_tx);
}
let mut ws_sender = ws_sender;
let state_for_sender = Arc::clone(&state);
tokio::spawn(async move {
while let Some(message) = message_rx.recv().await {
debug!("Sending WebSocket message: {}", message);
if let Err(e) = ws_sender.send(WsMessage::Text(message.into())).await {
error!("Failed to send message: {}", e);
let mut state_guard = state_for_sender.lock().unwrap();
state_guard.connection_status = ConnectionStatus::Disconnected;
state_guard.message_sender = None;
break;
}
}
});
let state_for_receiver = Arc::clone(&state);
tokio::spawn(async move {
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(WsMessage::Text(text)) => {
debug!("Received WebSocket message: {}", text);
if let Ok(response) = serde_json::from_str::<TraceResponse>(&text) {
let mut state_guard = state_for_receiver.lock().unwrap();
let converted_response = match response {
TraceResponse::TraceStored { trace_id } => {
TracingResponse::TraceStarted { trace_id }
}
TraceResponse::TracesFound { traces } => {
TracingResponse::QueryResults {
traces,
total_count: 0,
}
}
TraceResponse::TraceData { trace } => {
TracingResponse::TraceData { trace: Some(trace) }
}
TraceResponse::Error { message } => TracingResponse::Error {
message,
code: ErrorCode::InternalError,
},
TraceResponse::Metrics { data: _ } => TracingResponse::Pong,
};
#[cfg(not(target_arch = "wasm32"))]
if let Some((_, sender)) = state_guard.pending_requests.drain().next() {
let _ = sender.send(converted_response);
};
}
}
Ok(WsMessage::Close(_)) => {
warn!("WebSocket connection closed by server");
let mut state_guard = state_for_receiver.lock().unwrap();
state_guard.connection_status = ConnectionStatus::Disconnected;
state_guard.message_sender = None;
break;
}
Err(e) => {
error!("WebSocket error: {}", e);
let mut state_guard = state_for_receiver.lock().unwrap();
state_guard.connection_status = ConnectionStatus::Disconnected;
state_guard.message_sender = None;
break;
}
_ => {}
}
}
});
let state_for_batching = Arc::clone(&state);
tokio::spawn(async move {
let mut interval = tokio::time::interval(config.batch_timeout);
loop {
interval.tick().await;
let (events_to_send, sender_ref) = {
let mut state_guard = state_for_batching.lock().unwrap();
if state_guard.connection_status != ConnectionStatus::Connected {
continue;
}
let should_send = !state_guard.event_buffer.is_empty()
&& (state_guard.event_buffer.len() >= config.batch_size
|| state_guard.last_batch_time.elapsed() >= config.batch_timeout);
if should_send {
let events = state_guard.event_buffer.drain(..).collect::<Vec<_>>();
state_guard.last_batch_time = Instant::now();
(events, state_guard.message_sender.clone())
} else {
(Vec::new(), None)
}
};
if !events_to_send.is_empty() {
info!("Sending batch of {} trace events", events_to_send.len());
if let Some(sender) = sender_ref {
for event in events_to_send {
let request = TracingRequest::RecordEvent {
trace_id: TraceId::new(),
event,
};
if let Ok(message) = serde_json::to_string(&request) {
if let Err(e) = sender.send(message) {
error!("Failed to send trace event: {}", e);
let mut state_guard = state_for_batching.lock().unwrap();
state_guard.connection_status = ConnectionStatus::Disconnected;
state_guard.message_sender = None;
break;
}
}
}
}
}
}
});
}
/// Creates the browser WebSocket for `config.server_url` and registers its
/// open/close/error callbacks.
///
/// Returns `Ok(())` without doing anything when tracing is disabled.
///
/// The function returns as soon as the socket object exists. The status is
/// `Connecting` at that point and switches to `Connected` only when the
/// socket's `open` event fires, so `is_connected` never reports a socket that
/// cannot yet accept messages.
///
/// The callback closures are intentionally leaked with `forget` so they stay
/// alive for as long as the browser may invoke them.
///
/// # Errors
///
/// Returns [`TracingError::ConnectionFailed`] when the WebSocket object
/// cannot be created.
///
/// # Panics
///
/// Panics (here or inside a callback) if the state mutex is poisoned.
#[cfg(target_arch = "wasm32")]
pub async fn connect(&self) -> Result<()> {
    if !self.config.enabled {
        debug!("Tracing is disabled, skipping connection");
        return Ok(());
    }
    info!("Connecting to tracing server at {}", self.config.server_url);
    let ws = WebSocket::new(&self.config.server_url).map_err(|_| {
        TracingError::ConnectionFailed("Failed to create WebSocket".to_string())
    })?;

    let state_clone = Arc::clone(&self.state);
    let onopen = Closure::wrap(Box::new(move || {
        let mut state = state_clone.lock().unwrap();
        state.connection_status = ConnectionStatus::Connected;
        state.reconnect_attempts = 0;
        info!("WebSocket connection opened");
    }) as Box<dyn FnMut()>);
    ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
    onopen.forget();

    let state_clone = Arc::clone(&self.state);
    let onclose = Closure::wrap(Box::new(move |_: CloseEvent| {
        let mut state = state_clone.lock().unwrap();
        state.connection_status = ConnectionStatus::Disconnected;
        warn!("WebSocket connection closed");
    }) as Box<dyn FnMut(CloseEvent)>);
    ws.set_onclose(Some(onclose.as_ref().unchecked_ref()));
    onclose.forget();

    let onerror = Closure::wrap(Box::new(move |_: ErrorEvent| {
        error!("WebSocket error occurred");
    }) as Box<dyn FnMut(ErrorEvent)>);
    ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
    onerror.forget();

    {
        let mut state = self.state.lock().unwrap();
        state.websocket = Some(ws);
        // The socket is still opening here; sending on it now would fail.
        // The `onopen` callback above promotes the status to `Connected`.
        state.connection_status = ConnectionStatus::Connecting;
    }
    Ok(())
}
/// Announces the start of a trace for `flow_id`/`version` and returns its id.
///
/// The returned id is generated locally. The request carries only the flow
/// id and version, and no server response is awaited. On success the id is
/// also remembered as the client's current trace id. When tracing is
/// disabled a fresh id is returned and nothing is sent.
///
/// # Errors
///
/// Fails when the request cannot be serialized or the message cannot be
/// handed to the transport (for example when the client is not connected).
pub async fn start_trace(&self, flow_id: FlowId, version: FlowVersion) -> Result<TraceId> {
    let trace_id = TraceId::new();
    if !self.config.enabled {
        return Ok(trace_id);
    }

    let payload = serde_json::to_string(&TracingRequest::StartTrace { flow_id, version })?;
    self.send_message(payload).await?;

    self.state.lock().unwrap().current_trace_id = Some(trace_id.clone());
    Ok(trace_id)
}
/// Records `event` as part of the trace identified by `trace_id`.
///
/// With `batch_size == 1` the event is sent immediately, tagged with the
/// caller's `trace_id`. Otherwise it is appended to the buffer, which is
/// emptied by the background batching task or by `flush`. Does nothing when
/// tracing is disabled.
///
/// NOTE(review): buffered events do not retain `trace_id`, because the buffer
/// stores bare events; the code that drains the buffer chooses the id.
///
/// # Errors
///
/// In immediate mode, fails when the request cannot be serialized or the
/// message cannot be handed to the transport.
///
/// # Panics
///
/// Panics if the state mutex is poisoned.
pub async fn record_event(&self, trace_id: TraceId, event: TraceEvent) -> Result<()> {
    if !self.config.enabled {
        return Ok(());
    }
    debug!("Recording trace event: {:?}", event.event_type);

    if self.config.batch_size == 1 {
        // Immediate mode: skip the buffer and keep the caller's trace id
        // rather than inventing a new one for every event.
        let request = TracingRequest::RecordEvent { trace_id, event };
        let message = serde_json::to_string(&request)?;
        return self.send_message(message).await;
    }

    let mut state = self.state.lock().unwrap();
    state.event_buffer.push_back(event);
    if state.event_buffer.len() >= self.config.batch_size {
        debug!("Event buffer full, will be sent by batching task");
    }
    Ok(())
}
/// Sends a `Ping` request to the server; a no-op when tracing is disabled.
///
/// The reply is not awaited.
///
/// # Errors
///
/// Fails when the request cannot be serialized or sent.
pub async fn ping(&self) -> Result<()> {
    if self.config.enabled {
        let payload = serde_json::to_string(&TracingRequest::Ping)?;
        self.send_message(payload).await?;
    }
    Ok(())
}
/// Sends a `QueryTraces` request for `query`.
///
/// NOTE(review): the server's reply is not awaited here, so the returned
/// vector is always empty — confirm whether callers are expected to obtain
/// results some other way.
///
/// # Errors
///
/// Fails when the request cannot be serialized or sent.
pub async fn query_traces(&self, query: TraceQuery) -> Result<Vec<FlowTrace>> {
    if self.config.enabled {
        let payload = serde_json::to_string(&TracingRequest::QueryTraces { query })?;
        self.send_message(payload).await?;
    }
    Ok(Vec::new())
}
/// Sends a `GetTrace` request for `trace_id`.
///
/// NOTE(review): the server's reply is not awaited here, so the result is
/// always `None` — confirm whether callers are expected to obtain the trace
/// some other way.
///
/// # Errors
///
/// Fails when the request cannot be serialized or sent.
pub async fn get_trace(&self, trace_id: TraceId) -> Result<Option<FlowTrace>> {
    if self.config.enabled {
        let payload = serde_json::to_string(&TracingRequest::GetTrace { trace_id })?;
        self.send_message(payload).await?;
    }
    Ok(None)
}
/// Sends a `Subscribe` request with the given `filters`; a no-op when
/// tracing is disabled.
///
/// # Errors
///
/// Fails when the request cannot be serialized or sent.
pub async fn subscribe(&self, filters: SubscriptionFilters) -> Result<()> {
    if self.config.enabled {
        let payload = serde_json::to_string(&TracingRequest::Subscribe { filters })?;
        self.send_message(payload).await?;
    }
    Ok(())
}
/// Reports whether the connection status is currently `Connected`.
///
/// # Panics
///
/// Panics if the state mutex is poisoned.
pub fn is_connected(&self) -> bool {
    matches!(
        self.state.lock().unwrap().connection_status,
        ConnectionStatus::Connected
    )
}
/// Returns the `Debug` name of the current connection status
/// (for example `"Connected"`).
///
/// # Panics
///
/// Panics if the state mutex is poisoned.
pub fn connection_status(&self) -> String {
    let guard = self.state.lock().unwrap();
    let status = &guard.connection_status;
    format!("{status:?}")
}
pub async fn shutdown(&self) -> Result<()> {
if !self.config.enabled {
return Ok(());
}
info!("Shutting down tracing client...");
let events_to_send = {
let mut state = self.state.lock().unwrap();
state.connection_status = ConnectionStatus::Disconnected;
state.event_buffer.drain(..).collect::<Vec<_>>()
};
if !events_to_send.is_empty() {
info!("Flushing {} pending trace events", events_to_send.len());
for event in events_to_send {
let request = TracingRequest::RecordEvent {
trace_id: TraceId::new(),
event,
};
if let Ok(message) = serde_json::to_string(&request) {
let _ = self.send_message(message).await;
}
}
}
info!("Tracing client shutdown complete");
Ok(())
}
pub async fn flush(&self) -> Result<()> {
if !self.config.enabled {
return Ok(());
}
let events_to_send = {
let mut state = self.state.lock().unwrap();
if state.event_buffer.is_empty() {
return Ok(());
}
let events = state.event_buffer.drain(..).collect::<Vec<_>>();
state.last_batch_time = Instant::now();
events
};
info!("Flushing {} trace events", events_to_send.len());
for event in events_to_send {
let request = TracingRequest::RecordEvent {
trace_id: TraceId::new(),
event,
};
if let Ok(message) = serde_json::to_string(&request) {
self.send_message(message).await?;
}
}
Ok(())
}
/// Hands a serialized request to the transport.
///
/// On non-wasm targets the message is queued for the writer task; on wasm32
/// it is passed straight to the browser WebSocket.
///
/// # Errors
///
/// * [`TracingError::Disconnected`] when the status is not `Connected` or no
///   transport handle is installed.
/// * [`TracingError::WebSocketError`] when the transport rejects the message.
///
/// # Panics
///
/// Panics if the state mutex is poisoned.
async fn send_message(&self, message: String) -> Result<()> {
    debug!("Attempting to send message: {}", message);

    #[cfg(not(target_arch = "wasm32"))]
    {
        // Take a handle to the queue only while the connection is up.
        let queue = {
            let guard = self.state.lock().unwrap();
            match guard.connection_status {
                ConnectionStatus::Connected => guard.message_sender.clone(),
                _ => None,
            }
        };
        let Some(queue) = queue else {
            return Err(TracingError::Disconnected);
        };
        if let Err(e) = queue.send(message.clone()) {
            return Err(TracingError::WebSocketError(e.to_string()));
        }
        debug!("Message queued for sending: {}", message);
    }

    #[cfg(target_arch = "wasm32")]
    {
        // Take a handle to the socket only while the connection is up.
        let socket = {
            let guard = self.state.lock().unwrap();
            match guard.connection_status {
                ConnectionStatus::Connected => guard.websocket.clone(),
                _ => None,
            }
        };
        let Some(socket) = socket else {
            return Err(TracingError::Disconnected);
        };
        if socket.send_with_str(&message).is_err() {
            return Err(TracingError::WebSocketError(
                "Failed to send message".to_string(),
            ));
        }
        debug!("Message sent via WebSocket: {}", message);
    }

    Ok(())
}
}
/// Cheaply cloneable convenience wrapper around a shared [`TracingClient`]
/// that remembers the trace started most recently through it.
#[derive(Clone)]
pub struct TracingIntegration {
/// The shared client used for all requests.
client: Arc<TracingClient>,
/// Trace id stored by the last successful `start_flow_trace`, if any.
current_trace_id: Arc<Mutex<Option<TraceId>>>,
}
// SAFETY (review): these impls assert thread-safety manually instead of
// letting the compiler derive it. They are needed for the `GLOBAL_CLIENT`
// static whenever a field is not automatically `Send + Sync` — presumably the
// wasm32 `WebSocket` handle held in the client state, which is only sound
// while the wasm target is single-threaded. TODO confirm, and consider
// restricting these impls to wasm32 if the native build does not need them.
unsafe impl Send for TracingIntegration {}
unsafe impl Sync for TracingIntegration {}
impl TracingIntegration {
pub fn new(client: TracingClient) -> Self {
Self {
client: Arc::new(client),
current_trace_id: Arc::new(Mutex::new(None)),
}
}
pub async fn start_flow_trace(&self, flow_id: impl Into<String>) -> Result<TraceId> {
let flow_id = FlowId::new(flow_id);
let version = FlowVersion {
major: 1,
minor: 0,
patch: 0,
git_hash: None,
timestamp: chrono::Utc::now(),
};
let trace_id = self.client.start_trace(flow_id, version).await?;
{
let mut current = self.current_trace_id.lock().unwrap();
*current = Some(trace_id.clone());
}
Ok(trace_id)
}
pub async fn trace_actor_created(&self, actor_id: impl Into<String>) -> Result<()> {
let event = TraceEvent::actor_created(actor_id.into());
let trace_id = self.current_trace_id.lock().unwrap().clone();
if let Some(trace_id) = trace_id {
self.client.record_event(trace_id, event).await
} else {
self.client.record_event(TraceId::new(), event).await
}
}
pub async fn trace_actor_completed(&self, actor_id: impl Into<String>) -> Result<()> {
let event = TraceEvent::actor_completed(actor_id.into());
let trace_id = self.current_trace_id.lock().unwrap().clone();
if let Some(trace_id) = trace_id {
self.client.record_event(trace_id, event).await
} else {
self.client.record_event(TraceId::new(), event).await
}
}
pub async fn trace_message_sent(
&self,
actor_id: impl Into<String>,
port: impl Into<String>,
message_type: impl Into<String>,
size_bytes: usize,
) -> Result<()> {
let event = TraceEvent::message_sent(
actor_id.into(),
port.into(),
message_type.into(),
size_bytes,
);
let trace_id = self.current_trace_id.lock().unwrap().clone();
if let Some(trace_id) = trace_id {
self.client.record_event(trace_id, event).await
} else {
self.client.record_event(TraceId::new(), event).await
}
}
pub async fn trace_actor_failed(
&self,
actor_id: impl Into<String>,
error: impl Into<String>,
) -> Result<()> {
let event = TraceEvent::actor_failed(actor_id.into(), error.into());
let trace_id = self.current_trace_id.lock().unwrap().clone();
if let Some(trace_id) = trace_id {
self.client.record_event(trace_id, event).await
} else {
self.client.record_event(TraceId::new(), event).await
}
}
pub async fn trace_data_flow(
&self,
from_actor: impl Into<String>,
from_port: impl Into<String>,
to_actor: impl Into<String>,
to_port: impl Into<String>,
message_type: impl Into<String>,
size_bytes: usize,
) -> Result<()> {
let event = TraceEvent::data_flow(
from_actor.into(),
from_port.into(),
to_actor.into(),
to_port.into(),
message_type.into(),
size_bytes,
);
let trace_id = self.current_trace_id.lock().unwrap().clone();
if let Some(trace_id) = trace_id {
self.client.record_event(trace_id, event).await
} else {
self.client.record_event(TraceId::new(), event).await
}
}
pub fn client(&self) -> Arc<TracingClient> {
Arc::clone(&self.client)
}
}
/// Process-wide tracing integration; set at most once by `init_global_tracing`.
static GLOBAL_CLIENT: std::sync::OnceLock<TracingIntegration> = std::sync::OnceLock::new();
/// Builds a client from `config` and installs it as the process-wide
/// tracing integration.
///
/// The client is only constructed here, not connected.
///
/// # Errors
///
/// Returns [`TracingError::WebSocketError`] when a global integration has
/// already been installed.
pub fn init_global_tracing(config: TracingConfig) -> Result<()> {
    let integration = TracingIntegration::new(TracingClient::new(config));
    match GLOBAL_CLIENT.set(integration) {
        Ok(()) => Ok(()),
        Err(_rejected) => Err(TracingError::WebSocketError(
            "Global tracing client already initialized".to_string(),
        )),
    }
}
/// Returns the process-wide tracing integration, or `None` when
/// `init_global_tracing` has not been called successfully yet.
pub fn global_tracing() -> Option<&'static TracingIntegration> {
GLOBAL_CLIENT.get()
}
/// Records an actor event through the global tracing integration, if one has
/// been initialized. Errors from recording are ignored.
///
/// Must be invoked inside an `async` context because the expansion awaits.
///
/// Supported forms:
///
/// * `trace_actor_event!(created, actor_id)`
/// * `trace_actor_event!(completed, actor_id)`
/// * `trace_actor_event!(message_sent, actor_id, port, msg_type, size)`
/// * `trace_actor_event!(failed, actor_id, error)`
///
/// NOTE(review): the expansion refers to `$crate::tracing::global_tracing`,
/// which assumes this module is mounted at `crate::tracing` — confirm.
#[macro_export]
macro_rules! trace_actor_event {
    (created, $actor_id:expr) => {
        if let Some(tracing) = $crate::tracing::global_tracing() {
            let _ = tracing.trace_actor_created($actor_id).await;
        }
    };
    (completed, $actor_id:expr) => {
        if let Some(tracing) = $crate::tracing::global_tracing() {
            let _ = tracing.trace_actor_completed($actor_id).await;
        }
    };
    (message_sent, $actor_id:expr, $port:expr, $msg_type:expr, $size:expr) => {
        if let Some(tracing) = $crate::tracing::global_tracing() {
            let _ = tracing
                .trace_message_sent($actor_id, $port, $msg_type, $size)
                .await;
        }
    };
    (failed, $actor_id:expr, $error:expr) => {
        if let Some(tracing) = $crate::tracing::global_tracing() {
            let _ = tracing.trace_actor_failed($actor_id, $error).await;
        }
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration targets the local server, sends every event
    /// on its own and has tracing enabled.
    #[test]
    fn test_config_default() {
        let defaults = TracingConfig::default();
        assert_eq!(defaults.server_url, "ws://localhost:8080");
        assert_eq!(defaults.batch_size, 1);
        assert!(defaults.enabled);
    }

    /// A freshly constructed client has not connected to anything.
    #[test]
    fn test_client_creation() {
        let fresh = TracingClient::with_default_config();
        assert!(!fresh.is_connected());
    }

    /// With tracing disabled, every integration call succeeds without a server.
    #[tokio::test]
    async fn test_tracing_integration() {
        let disabled = TracingConfig {
            enabled: false,
            ..Default::default()
        };
        let integration = TracingIntegration::new(TracingClient::new(disabled));

        assert!(integration.start_flow_trace("test_flow").await.is_ok());
        assert!(integration.trace_actor_created("test_actor").await.is_ok());
        assert!(integration
            .trace_message_sent("test_actor", "output", "TestMessage", 128)
            .await
            .is_ok());
    }
}