use crate::error::TorshError;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(feature = "std")]
use std::sync::{Arc, Mutex, OnceLock};
#[cfg(feature = "std")]
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
#[cfg(not(feature = "std"))]
use alloc::{
collections::BTreeMap as HashMap,
string::{String, ToString},
vec::Vec,
};
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
pub enum LogLevel {
Trace = 0,
Debug = 1,
Info = 2,
Warn = 3,
Error = 4,
Fatal = 5,
}
impl LogLevel {
pub fn as_str(&self) -> &'static str {
match self {
LogLevel::Trace => "TRACE",
LogLevel::Debug => "DEBUG",
LogLevel::Info => "INFO",
LogLevel::Warn => "WARN",
LogLevel::Error => "ERROR",
LogLevel::Fatal => "FATAL",
}
}
pub fn is_enabled(&self, min_level: LogLevel) -> bool {
*self >= min_level
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ErrorCode {
Success = 0,
ShapeMismatch = 1000,
InvalidShape = 1001,
BroadcastError = 1002,
DimensionMismatch = 1003,
IndexOutOfBounds = 2000,
InvalidDimension = 2001,
InvalidSlice = 2002,
TypeMismatch = 3000,
UnsupportedType = 3001,
ConversionError = 3002,
DeviceMismatch = 4000,
DeviceUnavailable = 4001,
DeviceError = 4002,
AllocationFailed = 5000,
OutOfMemory = 5001,
InvalidAlignment = 5002,
ComputeError = 6000,
NumericalError = 6001,
ConvergenceError = 6002,
IoError = 7000,
SerializationError = 7001,
DeserializationError = 7002,
InvalidOperation = 8000,
NotImplemented = 8001,
InvalidState = 8002,
SynchronizationError = 8003,
Unknown = 9999,
}
impl ErrorCode {
pub fn code(&self) -> u32 {
*self as u32
}
pub fn description(&self) -> &'static str {
match self {
ErrorCode::Success => "Success",
ErrorCode::ShapeMismatch => "Shape mismatch between tensors",
ErrorCode::InvalidShape => "Invalid tensor shape",
ErrorCode::BroadcastError => "Broadcasting error",
ErrorCode::DimensionMismatch => "Dimension mismatch",
ErrorCode::IndexOutOfBounds => "Index out of bounds",
ErrorCode::InvalidDimension => "Invalid dimension",
ErrorCode::InvalidSlice => "Invalid slice",
ErrorCode::TypeMismatch => "Type mismatch",
ErrorCode::UnsupportedType => "Unsupported type",
ErrorCode::ConversionError => "Type conversion error",
ErrorCode::DeviceMismatch => "Device mismatch",
ErrorCode::DeviceUnavailable => "Device unavailable",
ErrorCode::DeviceError => "Device error",
ErrorCode::AllocationFailed => "Memory allocation failed",
ErrorCode::OutOfMemory => "Out of memory",
ErrorCode::InvalidAlignment => "Invalid memory alignment",
ErrorCode::ComputeError => "Computation error",
ErrorCode::NumericalError => "Numerical error",
ErrorCode::ConvergenceError => "Convergence error",
ErrorCode::IoError => "I/O error",
ErrorCode::SerializationError => "Serialization error",
ErrorCode::DeserializationError => "Deserialization error",
ErrorCode::InvalidOperation => "Invalid operation",
ErrorCode::NotImplemented => "Not implemented",
ErrorCode::InvalidState => "Invalid state",
ErrorCode::SynchronizationError => "Synchronization error",
ErrorCode::Unknown => "Unknown error",
}
}
pub fn from_torsh_error(error: &TorshError) -> Self {
match error {
TorshError::ShapeMismatch { .. } => ErrorCode::ShapeMismatch,
TorshError::BroadcastError { .. } => ErrorCode::BroadcastError,
TorshError::InvalidShape(_) => ErrorCode::InvalidShape,
TorshError::IndexOutOfBounds { .. } => ErrorCode::IndexOutOfBounds,
TorshError::IndexError { .. } => ErrorCode::IndexOutOfBounds,
TorshError::InvalidDimension { .. } => ErrorCode::InvalidDimension,
TorshError::InvalidArgument(_) => ErrorCode::InvalidOperation,
TorshError::IoError(_) => ErrorCode::IoError,
TorshError::DeviceMismatch => ErrorCode::DeviceMismatch,
TorshError::NotImplemented(_) => ErrorCode::NotImplemented,
TorshError::AllocationError(_) => ErrorCode::AllocationFailed,
TorshError::InvalidOperation(_) => ErrorCode::InvalidOperation,
TorshError::ConversionError(_) => ErrorCode::ConversionError,
TorshError::InvalidState(_) => ErrorCode::InvalidState,
TorshError::UnsupportedOperation { .. } => ErrorCode::NotImplemented,
TorshError::ComputeError(_) => ErrorCode::ComputeError,
TorshError::SerializationError(_) => ErrorCode::SerializationError,
_ => ErrorCode::Unknown,
}
}
}
#[derive(Debug, Clone)]
pub struct LogEvent {
pub timestamp: u64,
pub level: LogLevel,
pub message: String,
pub module_path: String,
pub file: String,
pub line: u32,
pub fields: HashMap<String, String>,
pub span_id: Option<u64>,
pub error_code: Option<ErrorCode>,
}
impl LogEvent {
pub fn new(
level: LogLevel,
message: String,
module_path: String,
file: String,
line: u32,
) -> Self {
#[cfg(feature = "std")]
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
#[cfg(not(feature = "std"))]
let timestamp = 0;
Self {
timestamp,
level,
message,
module_path,
file,
line,
fields: HashMap::new(),
span_id: None,
error_code: None,
}
}
pub fn with_field(mut self, key: String, value: String) -> Self {
self.fields.insert(key, value);
self
}
pub fn with_span_id(mut self, span_id: u64) -> Self {
self.span_id = Some(span_id);
self
}
pub fn with_error_code(mut self, code: ErrorCode) -> Self {
self.error_code = Some(code);
self
}
pub fn format_structured(&self) -> String {
#[cfg(feature = "std")]
{
let mut parts = vec![
format!("timestamp={}", self.timestamp),
format!("level={}", self.level.as_str()),
format!("message=\"{}\"", self.message),
format!("module={}", self.module_path),
format!("file={}:{}", self.file, self.line),
];
if let Some(span_id) = self.span_id {
parts.push(format!("span_id={}", span_id));
}
if let Some(error_code) = self.error_code {
parts.push(format!("error_code={}", error_code.code()));
}
for (key, value) in &self.fields {
parts.push(format!("{}=\"{}\"", key, value));
}
parts.join(" ")
}
#[cfg(not(feature = "std"))]
{
use alloc::vec;
let mut parts = vec![
format!("timestamp={}", self.timestamp),
format!("level={}", self.level.as_str()),
format!("message=\"{}\"", self.message),
format!("module={}", self.module_path),
format!("file={}:{}", self.file, self.line),
];
if let Some(span_id) = self.span_id {
parts.push(format!("span_id={}", span_id));
}
if let Some(error_code) = self.error_code {
parts.push(format!("error_code={}", error_code.code()));
}
for (key, value) in &self.fields {
parts.push(format!("{}=\"{}\"", key, value));
}
parts.join(" ")
}
}
}
#[cfg(feature = "std")]
#[derive(Debug, Clone)]
pub struct Span {
pub span_id: u64,
pub parent_id: Option<u64>,
pub name: String,
pub start_time: Instant,
pub attributes: HashMap<String, String>,
pub events: Vec<SpanEvent>,
}
#[cfg(feature = "std")]
impl Span {
pub fn new(span_id: u64, name: String, parent_id: Option<u64>) -> Self {
Self {
span_id,
parent_id,
name,
start_time: Instant::now(),
attributes: HashMap::new(),
events: Vec::new(),
}
}
pub fn add_attribute(&mut self, key: String, value: String) {
self.attributes.insert(key, value);
}
pub fn add_event(&mut self, event: SpanEvent) {
self.events.push(event);
}
pub fn duration(&self) -> Duration {
self.start_time.elapsed()
}
pub fn close(self) -> SpanMetrics {
SpanMetrics {
span_id: self.span_id,
name: self.name,
duration: self.start_time.elapsed(),
event_count: self.events.len(),
attributes: self.attributes,
}
}
}
#[cfg(feature = "std")]
#[derive(Debug, Clone)]
pub struct SpanEvent {
pub name: String,
pub timestamp: Duration,
pub attributes: HashMap<String, String>,
}
#[cfg(feature = "std")]
#[derive(Debug, Clone)]
pub struct SpanMetrics {
pub span_id: u64,
pub name: String,
pub duration: Duration,
pub event_count: usize,
pub attributes: HashMap<String, String>,
}
#[derive(Debug, Clone)]
pub struct TelemetryConfig {
pub min_log_level: LogLevel,
pub enable_tracing: bool,
pub buffer_size: usize,
pub console_output: bool,
pub structured_logging: bool,
}
impl Default for TelemetryConfig {
fn default() -> Self {
Self {
min_log_level: LogLevel::Info,
enable_tracing: false,
buffer_size: 1000,
console_output: true,
structured_logging: true,
}
}
}
#[cfg(feature = "std")]
pub struct TelemetrySystem {
config: TelemetryConfig,
event_buffer: Mutex<Vec<LogEvent>>,
active_spans: Mutex<HashMap<u64, Span>>,
next_span_id: Mutex<u64>,
closed_spans: Mutex<Vec<SpanMetrics>>,
}
#[cfg(feature = "std")]
impl TelemetrySystem {
pub fn new(config: TelemetryConfig) -> Self {
let buffer_size = config.buffer_size;
Self {
config,
event_buffer: Mutex::new(Vec::with_capacity(buffer_size)),
active_spans: Mutex::new(HashMap::new()),
next_span_id: Mutex::new(1),
closed_spans: Mutex::new(Vec::new()),
}
}
pub fn log(&self, event: LogEvent) {
if !event.level.is_enabled(self.config.min_log_level) {
return;
}
if self.config.console_output {
if self.config.structured_logging {
eprintln!("{}", event.format_structured());
} else {
eprintln!("[{}] {}", event.level.as_str(), event.message);
}
}
let mut buffer = self
.event_buffer
.lock()
.expect("lock should not be poisoned");
buffer.push(event);
if buffer.len() >= self.config.buffer_size {
self.flush_events(&mut buffer);
}
}
pub fn start_span(&self, name: String, parent_id: Option<u64>) -> u64 {
let mut next_id = self
.next_span_id
.lock()
.expect("lock should not be poisoned");
let span_id = *next_id;
*next_id += 1;
let span = Span::new(span_id, name, parent_id);
let mut spans = self
.active_spans
.lock()
.expect("lock should not be poisoned");
spans.insert(span_id, span);
span_id
}
pub fn span_add_attribute(&self, span_id: u64, key: String, value: String) {
let mut spans = self
.active_spans
.lock()
.expect("lock should not be poisoned");
if let Some(span) = spans.get_mut(&span_id) {
span.add_attribute(key, value);
}
}
pub fn end_span(&self, span_id: u64) -> Option<SpanMetrics> {
let mut spans = self
.active_spans
.lock()
.expect("lock should not be poisoned");
if let Some(span) = spans.remove(&span_id) {
let metrics = span.close();
let mut closed = self
.closed_spans
.lock()
.expect("lock should not be poisoned");
closed.push(metrics.clone());
Some(metrics)
} else {
None
}
}
fn flush_events(&self, buffer: &mut Vec<LogEvent>) {
buffer.clear();
}
pub fn get_events(&self) -> Vec<LogEvent> {
let buffer = self
.event_buffer
.lock()
.expect("lock should not be poisoned");
buffer.clone()
}
pub fn get_span_metrics(&self) -> Vec<SpanMetrics> {
let closed = self
.closed_spans
.lock()
.expect("lock should not be poisoned");
closed.clone()
}
pub fn clear(&self) {
let mut buffer = self
.event_buffer
.lock()
.expect("lock should not be poisoned");
buffer.clear();
let mut closed = self
.closed_spans
.lock()
.expect("lock should not be poisoned");
closed.clear();
}
}
#[cfg(feature = "std")]
static TELEMETRY: OnceLock<Arc<TelemetrySystem>> = OnceLock::new();
#[cfg(feature = "std")]
pub fn init_telemetry(config: TelemetryConfig) {
TELEMETRY.get_or_init(|| Arc::new(TelemetrySystem::new(config)));
}
#[cfg(feature = "std")]
pub fn telemetry() -> Arc<TelemetrySystem> {
TELEMETRY
.get_or_init(|| Arc::new(TelemetrySystem::new(TelemetryConfig::default())))
.clone()
}
#[macro_export]
macro_rules! log {
($level:expr, $msg:expr $(, $key:expr => $value:expr)*) => {{
#[cfg(feature = "std")]
{
let mut event = $crate::telemetry::LogEvent::new(
$level,
$msg.to_string(),
module_path!().to_string(),
file!().to_string(),
line!(),
);
$(
event = event.with_field($key.to_string(), $value.to_string());
)*
$crate::telemetry::telemetry().log(event);
}
}};
}
#[macro_export]
macro_rules! trace {
($msg:expr $(, $key:expr => $value:expr)*) => {
$crate::log!($crate::telemetry::LogLevel::Trace, $msg $(, $key => $value)*)
};
}
#[macro_export]
macro_rules! debug {
($msg:expr $(, $key:expr => $value:expr)*) => {
$crate::log!($crate::telemetry::LogLevel::Debug, $msg $(, $key => $value)*)
};
}
#[macro_export]
macro_rules! info {
($msg:expr $(, $key:expr => $value:expr)*) => {
$crate::log!($crate::telemetry::LogLevel::Info, $msg $(, $key => $value)*)
};
}
#[macro_export]
macro_rules! warn {
($msg:expr $(, $key:expr => $value:expr)*) => {
$crate::log!($crate::telemetry::LogLevel::Warn, $msg $(, $key => $value)*)
};
}
#[macro_export]
macro_rules! error {
($msg:expr $(, $key:expr => $value:expr)*) => {
$crate::log!($crate::telemetry::LogLevel::Error, $msg $(, $key => $value)*)
};
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_log_level_ordering() {
assert!(LogLevel::Debug > LogLevel::Trace);
assert!(LogLevel::Info > LogLevel::Debug);
assert!(LogLevel::Warn > LogLevel::Info);
assert!(LogLevel::Error > LogLevel::Warn);
assert!(LogLevel::Fatal > LogLevel::Error);
}
#[test]
fn test_log_level_is_enabled() {
let min_level = LogLevel::Info;
assert!(!LogLevel::Trace.is_enabled(min_level));
assert!(!LogLevel::Debug.is_enabled(min_level));
assert!(LogLevel::Info.is_enabled(min_level));
assert!(LogLevel::Warn.is_enabled(min_level));
assert!(LogLevel::Error.is_enabled(min_level));
}
#[test]
fn test_error_code_mapping() {
assert_eq!(ErrorCode::ShapeMismatch.code(), 1000);
assert_eq!(ErrorCode::IndexOutOfBounds.code(), 2000);
assert_eq!(ErrorCode::TypeMismatch.code(), 3000);
assert_eq!(ErrorCode::DeviceMismatch.code(), 4000);
assert_eq!(ErrorCode::AllocationFailed.code(), 5000);
}
#[test]
fn test_error_code_from_torsh_error() {
let error = TorshError::InvalidShape("test".to_string());
assert_eq!(ErrorCode::from_torsh_error(&error), ErrorCode::InvalidShape);
let error = TorshError::DeviceMismatch;
assert_eq!(
ErrorCode::from_torsh_error(&error),
ErrorCode::DeviceMismatch
);
}
#[test]
fn test_log_event_creation() {
let event = LogEvent::new(
LogLevel::Info,
"test message".to_string(),
"test_module".to_string(),
"test.rs".to_string(),
42,
);
assert_eq!(event.level, LogLevel::Info);
assert_eq!(event.message, "test message");
assert_eq!(event.line, 42);
}
#[test]
fn test_log_event_with_metadata() {
let event = LogEvent::new(
LogLevel::Error,
"error occurred".to_string(),
"test_module".to_string(),
"test.rs".to_string(),
10,
)
.with_field("tensor_id".to_string(), "123".to_string())
.with_error_code(ErrorCode::ComputeError);
assert!(event.fields.contains_key("tensor_id"));
assert_eq!(event.error_code, Some(ErrorCode::ComputeError));
}
#[test]
#[cfg(feature = "std")]
fn test_telemetry_system() {
let config = TelemetryConfig {
min_log_level: LogLevel::Debug,
console_output: false,
..Default::default()
};
let telemetry = TelemetrySystem::new(config);
let event = LogEvent::new(
LogLevel::Info,
"test".to_string(),
"test".to_string(),
"test.rs".to_string(),
1,
);
telemetry.log(event.clone());
let events = telemetry.get_events();
assert_eq!(events.len(), 1);
assert_eq!(events[0].message, "test");
}
#[test]
#[cfg(feature = "std")]
fn test_span_creation() {
let config = TelemetryConfig::default();
let telemetry = TelemetrySystem::new(config);
let span_id = telemetry.start_span("test_operation".to_string(), None);
telemetry.span_add_attribute(span_id, "key".to_string(), "value".to_string());
let metrics = telemetry
.end_span(span_id)
.expect("end_span should succeed");
assert_eq!(metrics.name, "test_operation");
assert!(metrics.attributes.contains_key("key"));
}
#[test]
#[cfg(feature = "std")]
fn test_log_filtering() {
let config = TelemetryConfig {
min_log_level: LogLevel::Warn,
console_output: false,
..Default::default()
};
let telemetry = TelemetrySystem::new(config);
telemetry.log(LogEvent::new(
LogLevel::Info,
"info".to_string(),
"test".to_string(),
"test.rs".to_string(),
1,
));
telemetry.log(LogEvent::new(
LogLevel::Error,
"error".to_string(),
"test".to_string(),
"test.rs".to_string(),
2,
));
let events = telemetry.get_events();
assert_eq!(events.len(), 1);
assert_eq!(events[0].level, LogLevel::Error);
}
}