use serde::Serialize;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::OnceLock;
use swarm_engine_core::types::LoraConfig;
use swarm_engine_core::util::epoch_millis;
use tokio::sync::broadcast;
/// One record of a single LLM call, broadcast to debug subscribers.
///
/// Built via [`LlmDebugEvent::new`] plus the chained builder methods; all
/// fields are public so subscribers can render them freely.
#[derive(Debug, Clone, Serialize)]
pub struct LlmDebugEvent {
    /// Creation time in milliseconds (from `epoch_millis`; presumably Unix
    /// epoch — confirm against `swarm_engine_core::util`).
    pub timestamp_ms: u64,
    /// Worker that issued the call, if known.
    pub worker_id: Option<usize>,
    /// Kind of call, e.g. "decide" (free-form string set by the caller).
    pub call_type: String,
    /// Model identifier used for the call.
    pub model_name: String,
    /// Endpoint URL the call was sent to; empty until set.
    pub endpoint: String,
    /// Full prompt text; empty until set.
    pub prompt: String,
    /// Model response, when the call succeeded.
    pub response: Option<String>,
    /// Error description, when the call failed.
    pub error: Option<String>,
    /// Round-trip latency in milliseconds; 0 until set.
    pub latency_ms: u64,
    /// LoRA adapter in effect for the call, if any.
    pub lora: Option<LoraConfig>,
}
impl LlmDebugEvent {
    /// Creates an event stamped with the current time. All optional fields
    /// start empty/`None` and are filled in via the builder methods below.
    ///
    /// Every builder method consumes `self` and returns the updated value, so
    /// ignoring the result silently discards the event — hence `#[must_use]`
    /// throughout.
    #[must_use]
    pub fn new(call_type: impl Into<String>, model_name: impl Into<String>) -> Self {
        Self {
            timestamp_ms: epoch_millis(),
            worker_id: None,
            call_type: call_type.into(),
            model_name: model_name.into(),
            endpoint: String::new(),
            prompt: String::new(),
            response: None,
            error: None,
            latency_ms: 0,
            lora: None,
        }
    }

    /// Tags the event with the worker that issued the call.
    #[must_use]
    pub fn worker_id(mut self, id: usize) -> Self {
        self.worker_id = Some(id);
        self
    }

    /// Records the endpoint URL the call was sent to.
    #[must_use]
    pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.endpoint = endpoint.into();
        self
    }

    /// Records the full prompt text.
    #[must_use]
    pub fn prompt(mut self, prompt: impl Into<String>) -> Self {
        self.prompt = prompt.into();
        self
    }

    /// Records a successful response body.
    #[must_use]
    pub fn response(mut self, response: impl Into<String>) -> Self {
        self.response = Some(response.into());
        self
    }

    /// Records an error description for a failed call.
    #[must_use]
    pub fn error(mut self, error: impl Into<String>) -> Self {
        self.error = Some(error.into());
        self
    }

    /// Records the round-trip latency in milliseconds.
    #[must_use]
    pub fn latency_ms(mut self, latency_ms: u64) -> Self {
        self.latency_ms = latency_ms;
        self
    }

    /// Records the LoRA adapter used for the call.
    #[must_use]
    pub fn lora(mut self, lora: LoraConfig) -> Self {
        self.lora = Some(lora);
        self
    }

    /// Like [`Self::lora`], but accepts an `Option` directly (convenient when
    /// the caller may or may not have an adapter configured).
    #[must_use]
    pub fn lora_opt(mut self, lora: Option<LoraConfig>) -> Self {
        self.lora = lora;
        self
    }
}
/// Broadcast hub for [`LlmDebugEvent`]s with two runtime toggles.
///
/// Atomic flags (not a lock) gate emission so `emit` stays cheap on hot paths.
pub struct LlmDebugChannel {
    // Fan-out sender; events are dropped when no receiver is subscribed.
    tx: broadcast::Sender<LlmDebugEvent>,
    // When true, every emitted event is forwarded.
    enabled: AtomicBool,
    // When true, events carrying an error are forwarded even while disabled.
    emit_on_error: AtomicBool,
}
impl LlmDebugChannel {
    /// Broadcast buffer size for the process-wide channel (see [`Self::global`]).
    const DEFAULT_CAPACITY: usize = 256;

    /// Creates a channel whose broadcast buffer holds up to `capacity` events.
    ///
    /// Starts with full streaming disabled but error forwarding enabled, so
    /// failed calls are observable without opting in.
    pub fn new(capacity: usize) -> Self {
        let (tx, _) = broadcast::channel(capacity);
        Self {
            tx,
            enabled: AtomicBool::new(false),
            emit_on_error: AtomicBool::new(true),
        }
    }

    /// Returns the process-wide shared channel, created lazily on first use.
    pub fn global() -> &'static Self {
        static INSTANCE: OnceLock<LlmDebugChannel> = OnceLock::new();
        INSTANCE.get_or_init(|| Self::new(Self::DEFAULT_CAPACITY))
    }

    /// Turns on forwarding of all emitted events.
    pub fn enable(&self) {
        self.enabled.store(true, Ordering::Relaxed);
    }

    /// Turns off forwarding of non-error events (error forwarding is governed
    /// separately by [`Self::set_emit_on_error`]).
    pub fn disable(&self) {
        self.enabled.store(false, Ordering::Relaxed);
    }

    /// Whether full streaming is currently enabled. Note this does not reflect
    /// the error-forwarding flag.
    pub fn is_enabled(&self) -> bool {
        self.enabled.load(Ordering::Relaxed)
    }

    /// Controls whether error-carrying events are forwarded even while the
    /// channel is disabled (defaults to on).
    pub fn set_emit_on_error(&self, enabled: bool) {
        self.emit_on_error.store(enabled, Ordering::Relaxed);
    }

    /// Broadcasts `event` when streaming is enabled, or when error forwarding
    /// is on and the event carries an error.
    ///
    /// A send failure only means no receiver is subscribed; it is deliberately
    /// ignored so instrumented code never fails because debugging is off.
    pub fn emit(&self, event: LlmDebugEvent) {
        let should_emit = self.enabled.load(Ordering::Relaxed)
            || (self.emit_on_error.load(Ordering::Relaxed) && event.error.is_some());
        if should_emit {
            let _ = self.tx.send(event);
        }
    }

    /// Registers a new receiver; it only sees events emitted after this call.
    pub fn subscribe(&self) -> broadcast::Receiver<LlmDebugEvent> {
        self.tx.subscribe()
    }

    /// Number of currently subscribed receivers.
    pub fn receiver_count(&self) -> usize {
        self.tx.receiver_count()
    }
}
impl Default for LlmDebugChannel {
fn default() -> Self {
Self::new(256)
}
}
/// Subscriber that pretty-prints [`LlmDebugEvent`]s to stderr as they arrive.
pub struct StderrLlmSubscriber {
    // Receiver obtained from LlmDebugChannel::subscribe.
    rx: broadcast::Receiver<LlmDebugEvent>,
    // Optional limit on printed prompt/response length; None = print in full.
    truncate_at: Option<usize>,
}
impl StderrLlmSubscriber {
pub fn new(rx: broadcast::Receiver<LlmDebugEvent>) -> Self {
Self {
rx,
truncate_at: None,
}
}
pub fn truncate_at(mut self, chars: usize) -> Self {
self.truncate_at = Some(chars);
self
}
pub async fn run(mut self) {
while let Ok(event) = self.rx.recv().await {
self.print_event(&event);
}
}
fn print_event(&self, event: &LlmDebugEvent) {
eprintln!("=== LLM Call ({}) ===", event.call_type);
eprintln!(
" Model: {} | Endpoint: {}",
event.model_name, event.endpoint
);
if let Some(wid) = event.worker_id {
eprintln!(" Worker: {}", wid);
}
if let Some(ref lora) = event.lora {
if let Some(ref name) = lora.name {
eprintln!(" LoRA: {} (id={}, scale={:.2})", name, lora.id, lora.scale);
} else {
eprintln!(" LoRA: id={}, scale={:.2}", lora.id, lora.scale);
}
}
eprintln!(" Latency: {}ms", event.latency_ms);
eprintln!("--- Prompt ---");
eprintln!("{}", self.maybe_truncate(&event.prompt));
if let Some(resp) = &event.response {
eprintln!("--- Response ---");
eprintln!("{}", self.maybe_truncate(resp));
}
if let Some(err) = &event.error {
eprintln!("--- Error ---");
eprintln!("{}", err);
}
eprintln!();
}
fn maybe_truncate(&self, s: &str) -> String {
match self.truncate_at {
Some(max) if s.len() > max => {
format!("{}... (truncated, {} chars total)", &s[..max], s.len())
}
_ => s.to_string(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Streaming must be opt-in: a fresh channel starts disabled.
    #[test]
    fn test_debug_channel_disabled_by_default() {
        assert!(!LlmDebugChannel::new(16).is_enabled());
    }

    /// enable/disable round-trips through is_enabled.
    #[test]
    fn test_debug_channel_enable_disable() {
        let ch = LlmDebugChannel::new(16);
        ch.enable();
        assert!(ch.is_enabled());
        ch.disable();
        assert!(!ch.is_enabled());
    }

    /// When enabled, emitted events reach subscribers intact.
    #[tokio::test]
    async fn test_debug_channel_emit_when_enabled() {
        let ch = LlmDebugChannel::new(16);
        ch.enable();
        let mut sub = ch.subscribe();
        let outgoing = LlmDebugEvent::new("decide", "test-model")
            .prompt("test prompt")
            .response("test response");
        ch.emit(outgoing);
        let received = sub.recv().await.unwrap();
        assert_eq!(received.call_type, "decide");
        assert_eq!(received.model_name, "test-model");
        assert_eq!(received.prompt, "test prompt");
        assert_eq!(received.response, Some("test response".to_string()));
    }

    /// Error-bearing events bypass the disabled flag because emit_on_error
    /// defaults to on.
    #[tokio::test]
    async fn test_debug_channel_emit_on_error() {
        let ch = LlmDebugChannel::new(16);
        let mut sub = ch.subscribe();
        let failing = LlmDebugEvent::new("decide", "test-model")
            .prompt("test prompt")
            .error("connection timeout");
        ch.emit(failing);
        let received = sub.recv().await.unwrap();
        assert_eq!(received.error, Some("connection timeout".to_string()));
    }

    /// With streaming off and error forwarding off, nothing is delivered.
    #[tokio::test]
    async fn test_debug_channel_no_emit_when_disabled() {
        let ch = LlmDebugChannel::new(16);
        ch.set_emit_on_error(false);
        let mut sub = ch.subscribe();
        ch.emit(
            LlmDebugEvent::new("decide", "test-model")
                .prompt("test prompt")
                .response("test response"),
        );
        let outcome =
            tokio::time::timeout(std::time::Duration::from_millis(10), sub.recv()).await;
        assert!(outcome.is_err());
    }

    /// Builder setters populate every field they claim to, in any order.
    #[test]
    fn test_event_builder() {
        let built = LlmDebugEvent::new("decide", "qwen2.5")
            .latency_ms(150)
            .response("test response")
            .prompt("test prompt")
            .endpoint("http://localhost:11434")
            .worker_id(42);
        assert_eq!(built.call_type, "decide");
        assert_eq!(built.model_name, "qwen2.5");
        assert_eq!(built.worker_id, Some(42));
        assert_eq!(built.endpoint, "http://localhost:11434");
        assert_eq!(built.prompt, "test prompt");
        assert_eq!(built.response, Some("test response".to_string()));
        assert_eq!(built.latency_ms, 150);
    }
}