use crate::{Config, MarketDataBus, SignalBus, metrics::metrics};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Instant;
use tokio::sync::{RwLock, watch};
/// Runtime hook used by `JanusState` to change the active log filter.
///
/// The `Send + Sync` bound allows implementations to be stored in `JanusState`
/// (see `JanusState::set_log_level_controller`).
pub trait LogLevelController: Send + Sync {
/// Applies `filter_str` as the new log filter.
///
/// Returns `Err` with a human-readable reason if the filter cannot be applied.
/// NOTE(review): presumably a tracing `EnvFilter`-style directive string — confirm.
fn set_log_level(&self, filter_str: &str) -> Result<(), String>;
/// Returns the currently active filter, or `None` if the implementation has none to
/// report (meaning of `None` is implementation-defined).
fn current_filter(&self) -> Option<String>;
}
/// Asynchronous sink for trade outcomes keyed by strategy and asset.
///
/// Installed via `JanusState::set_affinity_recorder` and invoked by
/// `JanusState::record_affinity_outcome`. The `Send + Sync` bound allows it to be
/// stored in `JanusState`.
#[async_trait::async_trait]
pub trait AffinityRecorder: Send + Sync {
/// Records one trade result for the `strategy` / `asset` pair.
///
/// * `pnl` — profit or loss of the trade (units not visible here; TODO confirm).
/// * `is_winner` — whether the trade is counted as a win.
/// * `rr_ratio` — optional ratio, presumably reward-to-risk; verify against implementors.
async fn record_trade(
&self,
strategy: &str,
asset: &str,
pnl: f64,
is_winner: bool,
rr_ratio: Option<f64>,
);
}
/// Lifecycle state of the services controlled through `JanusState`.
///
/// Serialized in `snake_case` (`"standby"`, `"running"`, `"stopped"`), which matches the
/// `Display` output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceState {
/// Initial state set by `JanusState::new`; services have not been started.
Standby,
/// Services are active (`JanusState::are_services_active` returns `true` only here).
Running,
/// Services were stopped, via `JanusState::stop_services` or `JanusState::shutdown`.
Stopped,
}
impl std::fmt::Display for ServiceState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ServiceState::Standby => write!(f, "standby"),
ServiceState::Running => write!(f, "running"),
ServiceState::Stopped => write!(f, "stopped"),
}
}
}
/// Latest health report for one named module, as stored by `JanusState`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ModuleHealth {
/// Module identifier; `JanusState::register_module_health` updates an existing entry
/// with the same name instead of adding a duplicate.
pub name: String,
/// Whether the module last reported itself healthy.
pub healthy: bool,
/// Time of the most recent report. Skipped during serialization because
/// `std::time::Instant` has no serde representation.
#[serde(skip)]
pub last_check: std::time::Instant,
/// Optional free-form status message from the last report.
pub message: Option<String>,
}
/// Shared runtime state: configuration, message buses, service lifecycle, module health,
/// signal counters, and optionally-installed integrations.
///
/// Every mutator takes `&self` (atomics, async `RwLock`s, and a `watch` channel), so one
/// instance can be shared between tasks, e.g. behind an `Arc`.
pub struct JanusState {
/// Configuration the state was created with.
pub config: Config,
/// Bus for generated signals.
pub signal_bus: SignalBus,
/// Bus for market data.
pub market_data_bus: MarketDataBus,
/// Creation instant; basis for `uptime_seconds`.
start_time: Instant,
/// Set by `request_shutdown`; nothing in this file clears it.
shutdown_requested: AtomicBool,
/// Latest report per module, keyed by `ModuleHealth::name`.
module_health: RwLock<Vec<ModuleHealth>>,
/// Number of signals generated, bumped by `increment_signals_generated`.
signals_generated: AtomicU64,
/// Number of signals persisted, bumped by `increment_signals_persisted`.
signals_persisted: AtomicU64,
/// Lazily created Redis client; cleared on shutdown.
redis_client: RwLock<Option<redis::Client>>,
/// Publishes service-state transitions.
service_state_tx: watch::Sender<ServiceState>,
/// Template receiver; cloned for waiters and subscribers (also keeps the channel open).
service_state_rx: watch::Receiver<ServiceState>,
/// Optional hook used to change the log filter at runtime.
log_level_controller: RwLock<Option<Box<dyn LogLevelController>>>,
/// Most recently recorded regime label, if any.
current_regime: RwLock<Option<String>>,
/// Most recently recorded threat value, if any.
current_threat: RwLock<Option<f64>>,
/// Optional sink for trade outcomes.
affinity_recorder: RwLock<Option<Box<dyn AffinityRecorder>>>,
}
impl JanusState {
    /// Creates a new state from `config`.
    ///
    /// The buses are created with `SignalBus::new(1000)` and `MarketDataBus::new(5000)`
    /// (NOTE(review): presumably buffer capacities — confirm). The service state starts as
    /// [`ServiceState::Standby`], both signal counters start at zero, and the Redis client,
    /// log-level controller, regime, threat value and affinity recorder all start unset.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub async fn new(config: Config) -> crate::Result<Self> {
        let signal_bus = SignalBus::new(1000);
        let market_data_bus = MarketDataBus::new(5000);
        let (service_state_tx, service_state_rx) = watch::channel(ServiceState::Standby);
        Ok(Self {
            config,
            signal_bus,
            market_data_bus,
            start_time: Instant::now(),
            shutdown_requested: AtomicBool::new(false),
            module_health: RwLock::new(Vec::new()),
            signals_generated: AtomicU64::new(0),
            signals_persisted: AtomicU64::new(0),
            redis_client: RwLock::new(None),
            service_state_tx,
            service_state_rx,
            log_level_controller: RwLock::new(None),
            current_regime: RwLock::new(None),
            current_threat: RwLock::new(None),
            affinity_recorder: RwLock::new(None),
        })
    }

    /// Returns the most recently recorded regime label, if any.
    pub async fn current_regime(&self) -> Option<String> {
        self.current_regime.read().await.clone()
    }

    /// Records `regime` as the current regime label, replacing any previous value.
    pub async fn set_current_regime(&self, regime: impl Into<String>) {
        // Convert before taking the lock so the critical section is just the store.
        let regime = regime.into();
        *self.current_regime.write().await = Some(regime);
    }

    /// Returns the most recently recorded threat value, if any.
    pub async fn current_threat(&self) -> Option<f64> {
        *self.current_threat.read().await
    }

    /// Records `fear` as the current threat value, replacing any previous value.
    ///
    /// No range validation is performed.
    pub async fn set_current_threat(&self, fear: f64) {
        *self.current_threat.write().await = Some(fear);
    }

    /// Installs (or replaces) the recorder used by [`Self::record_affinity_outcome`].
    pub async fn set_affinity_recorder(&self, recorder: Box<dyn AffinityRecorder>) {
        *self.affinity_recorder.write().await = Some(recorder);
    }

    /// Returns `true` if an affinity recorder is installed.
    pub async fn has_affinity_recorder(&self) -> bool {
        self.affinity_recorder.read().await.is_some()
    }

    /// Forwards one trade outcome to the installed [`AffinityRecorder`].
    ///
    /// Returns `true` if a recorder was installed and has been awaited, `false` if none is
    /// installed (the outcome is then dropped).
    ///
    /// The read lock is held while the recorder runs, so a concurrent
    /// [`Self::set_affinity_recorder`] waits until the call completes.
    pub async fn record_affinity_outcome(
        &self,
        strategy: &str,
        asset: &str,
        pnl: f64,
        is_winner: bool,
        rr_ratio: Option<f64>,
    ) -> bool {
        let guard = self.affinity_recorder.read().await;
        match guard.as_ref() {
            Some(recorder) => {
                recorder
                    .record_trade(strategy, asset, pnl, is_winner, rr_ratio)
                    .await;
                true
            }
            None => false,
        }
    }

    /// Installs (or replaces) the controller used by [`Self::set_log_level`] and
    /// [`Self::current_log_filter`].
    pub async fn set_log_level_controller(&self, controller: Box<dyn LogLevelController>) {
        *self.log_level_controller.write().await = Some(controller);
        tracing::debug!("log-level controller installed in JanusState");
    }

    /// Applies `filter_str` through the installed [`LogLevelController`].
    ///
    /// # Errors
    ///
    /// Returns `"no log-level controller installed"` when no controller is present;
    /// otherwise propagates the controller's own error.
    pub async fn set_log_level(&self, filter_str: &str) -> Result<(), String> {
        let guard = self.log_level_controller.read().await;
        match guard.as_ref() {
            Some(ctrl) => ctrl.set_log_level(filter_str),
            None => Err("no log-level controller installed".to_string()),
        }
    }

    /// Returns the controller's current filter, or `None` if no controller is installed
    /// or the controller reports none.
    pub async fn current_log_filter(&self) -> Option<String> {
        let guard = self.log_level_controller.read().await;
        guard.as_ref().and_then(|ctrl| ctrl.current_filter())
    }

    /// Whole seconds elapsed since this state was created.
    pub fn uptime_seconds(&self) -> u64 {
        self.start_time.elapsed().as_secs()
    }

    /// Returns `true` once [`Self::request_shutdown`] (or [`Self::shutdown`]) has run.
    pub fn is_shutdown_requested(&self) -> bool {
        self.shutdown_requested.load(Ordering::SeqCst)
    }

    /// Sets the shutdown flag. Does not change the service state or release resources;
    /// use [`Self::shutdown`] for that.
    pub fn request_shutdown(&self) {
        self.shutdown_requested.store(true, Ordering::SeqCst);
    }

    /// Performs a graceful shutdown.
    ///
    /// Sets the shutdown flag, moves the service state to [`ServiceState::Stopped`], and
    /// drops the cached Redis client. If a client was cached, the `redis_connected` gauge is
    /// reset to `0.0`.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub async fn shutdown(&self) -> crate::Result<()> {
        tracing::info!("Initiating graceful shutdown...");
        self.request_shutdown();
        // `send_replace` stores the value and notifies unconditionally. Waiters therefore
        // wake and observe the shutdown flag, even if the state was already `Stopped`.
        self.service_state_tx.send_replace(ServiceState::Stopped);
        let had_client = self.redis_client.write().await.take().is_some();
        if had_client {
            metrics().redis_connected.set(0.0);
        }
        tracing::info!("Shutdown complete");
        Ok(())
    }

    /// Moves the services to [`ServiceState::Running`].
    ///
    /// Returns `true` if the state changed (receivers are notified), `false` if the
    /// services were already running.
    pub fn start_services(&self) -> bool {
        self.transition_to(ServiceState::Running)
    }

    /// Moves the services to [`ServiceState::Stopped`] without requesting shutdown.
    ///
    /// Returns `true` if the state changed (receivers are notified), `false` if the
    /// services were already stopped.
    pub fn stop_services(&self) -> bool {
        self.transition_to(ServiceState::Stopped)
    }

    /// Sets the service state to `target`, logging and notifying only on an actual change.
    fn transition_to(&self, target: ServiceState) -> bool {
        self.service_state_tx.send_if_modified(|current| {
            if *current == target {
                return false;
            }
            tracing::info!("Service state: {} → {}", current, target);
            *current = target;
            true
        })
    }

    /// Current service state.
    pub fn service_state(&self) -> ServiceState {
        *self.service_state_tx.borrow()
    }

    /// Returns `true` if the service state is [`ServiceState::Running`].
    pub fn are_services_active(&self) -> bool {
        self.service_state() == ServiceState::Running
    }

    /// Waits until the services are running.
    ///
    /// Returns `true` as soon as the state is [`ServiceState::Running`] (immediately if it
    /// already is). Returns `false` if shutdown is requested first, or if the state channel
    /// closes. The shutdown flag is checked on every state change and at least every 250 ms.
    pub async fn wait_for_services_start(&self) -> bool {
        // Upper bound on how long a shutdown request can go unnoticed while idle.
        const SHUTDOWN_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(250);

        let mut rx = self.service_state_rx.clone();
        if *rx.borrow_and_update() == ServiceState::Running {
            return true;
        }
        loop {
            // Checked before every wait, so a shutdown that already happened is seen at once.
            if self.is_shutdown_requested() {
                return false;
            }
            tokio::select! {
                changed = rx.changed() => {
                    if changed.is_err() {
                        return false;
                    }
                    if *rx.borrow() == ServiceState::Running {
                        return true;
                    }
                }
                _ = tokio::time::sleep(SHUTDOWN_POLL_INTERVAL) => {}
            }
        }
    }

    /// Returns a new receiver for observing service-state changes.
    pub fn subscribe_service_state(&self) -> watch::Receiver<ServiceState> {
        self.service_state_rx.clone()
    }

    /// Records a health report for module `name`.
    ///
    /// If a report with the same name exists, its `healthy`, `message` and `last_check`
    /// fields are overwritten. Otherwise a new entry is appended.
    pub async fn register_module_health(
        &self,
        name: impl Into<String>,
        healthy: bool,
        message: Option<String>,
    ) {
        let name = name.into();
        let mut health = self.module_health.write().await;
        if let Some(existing) = health.iter_mut().find(|h| h.name == name) {
            existing.healthy = healthy;
            existing.last_check = Instant::now();
            existing.message = message;
        } else {
            health.push(ModuleHealth {
                name,
                healthy,
                last_check: Instant::now(),
                message,
            });
        }
    }

    /// Snapshot of all module health reports, in registration order.
    pub async fn get_module_health(&self) -> Vec<ModuleHealth> {
        self.module_health.read().await.clone()
    }

    /// Returns `true` if every registered module is healthy (vacuously `true` when none
    /// are registered).
    pub async fn all_modules_healthy(&self) -> bool {
        self.module_health.read().await.iter().all(|h| h.healthy)
    }

    /// Increments the generated-signals counter by one.
    pub fn increment_signals_generated(&self) {
        self.signals_generated.fetch_add(1, Ordering::SeqCst);
    }

    /// Number of signals generated so far.
    pub fn signals_generated(&self) -> u64 {
        self.signals_generated.load(Ordering::SeqCst)
    }

    /// Increments the persisted-signals counter by one.
    pub fn increment_signals_persisted(&self) {
        self.signals_persisted.fetch_add(1, Ordering::SeqCst);
    }

    /// Number of signals persisted so far.
    pub fn signals_persisted(&self) -> u64 {
        self.signals_persisted.load(Ordering::SeqCst)
    }

    /// Returns a clone of the cached Redis client, creating it from `config.redis.url` on
    /// first use.
    ///
    /// `redis::Client::open` only builds the client; it does not prove the server is
    /// reachable (see [`Self::probe_redis`]).
    ///
    /// # Errors
    ///
    /// Returns the (converted) `redis::Client::open` error and sets the `redis_connected`
    /// gauge to `0.0`.
    pub async fn redis_client(&self) -> crate::Result<redis::Client> {
        // Fast path: once cached, callers only need a shared read lock.
        {
            let cached = self.redis_client.read().await;
            if let Some(client) = cached.as_ref() {
                return Ok(client.clone());
            }
        }
        let mut slot = self.redis_client.write().await;
        // Another task may have created the client while we were waiting for the write lock.
        if let Some(client) = slot.as_ref() {
            return Ok(client.clone());
        }
        match redis::Client::open(self.config.redis.url.as_str()) {
            Ok(client) => Ok(slot.insert(client).clone()),
            Err(e) => {
                metrics().redis_connected.set(0.0);
                Err(e.into())
            }
        }
    }

    /// Checks Redis reachability and updates the `redis_connected` gauge.
    ///
    /// Obtains the client, opens a multiplexed async connection, and sends `PING`. The
    /// gauge is set to `1.0` on success and `0.0` on any failure. Failures are logged at
    /// `warn` level and never propagated.
    pub async fn probe_redis(&self) {
        let client = match self.redis_client().await {
            Ok(client) => client,
            Err(e) => {
                tracing::warn!("Redis probe: failed to create client — {e}");
                metrics().redis_connected.set(0.0);
                return;
            }
        };
        let mut conn = match client.get_multiplexed_async_connection().await {
            Ok(conn) => conn,
            Err(e) => {
                metrics().redis_connected.set(0.0);
                tracing::warn!("Redis probe: connection failed — {e}");
                return;
            }
        };
        let pong: Result<String, _> = redis::cmd("PING").query_async(&mut conn).await;
        match pong {
            Ok(_) => {
                metrics().redis_connected.set(1.0);
                tracing::info!("Redis probe: connected ✓");
            }
            Err(e) => {
                metrics().redis_connected.set(0.0);
                tracing::warn!("Redis probe: PING failed — {e}");
            }
        }
    }

    /// Builds a serializable health snapshot.
    ///
    /// `status` is `"healthy"` when every registered module is healthy (including when none
    /// are registered), otherwise `"degraded"`.
    pub async fn health_status(&self) -> HealthStatus {
        // Build the summaries directly under the read lock instead of cloning the full
        // `ModuleHealth` list first.
        let modules: Vec<ModuleHealthSummary> = self
            .module_health
            .read()
            .await
            .iter()
            .map(|h| ModuleHealthSummary {
                name: h.name.clone(),
                healthy: h.healthy,
                message: h.message.clone(),
            })
            .collect();
        let all_healthy = modules.iter().all(|m| m.healthy);
        HealthStatus {
            status: if all_healthy { "healthy" } else { "degraded" }.to_string(),
            uptime_seconds: self.uptime_seconds(),
            signals_generated: self.signals_generated(),
            signals_persisted: self.signals_persisted(),
            modules,
            shutdown_requested: self.is_shutdown_requested(),
            service_state: self.service_state(),
        }
    }
}
/// Serializable health snapshot produced by `JanusState::health_status`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HealthStatus {
/// `"healthy"` if every registered module is healthy (or none are registered),
/// otherwise `"degraded"`.
pub status: String,
/// Whole seconds since the state was created.
pub uptime_seconds: u64,
/// Value of the generated-signals counter at snapshot time.
pub signals_generated: u64,
/// Value of the persisted-signals counter at snapshot time.
pub signals_persisted: u64,
/// One summary per registered module, in registration order.
pub modules: Vec<ModuleHealthSummary>,
/// Whether shutdown had been requested at snapshot time.
pub shutdown_requested: bool,
/// Service lifecycle state at snapshot time.
pub service_state: ServiceState,
}
/// Serializable view of a `ModuleHealth` entry, without the non-serializable timestamp.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModuleHealthSummary {
/// Module identifier.
pub name: String,
/// Whether the module last reported itself healthy.
pub healthy: bool,
/// Optional free-form status message from the last report.
pub message: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_state_creation() {
        let config = Config::default();
        let state = JanusState::new(config).await.unwrap();
        assert!(!state.is_shutdown_requested());
        assert_eq!(state.signals_generated(), 0);
        assert_eq!(state.service_state(), ServiceState::Standby);
        assert!(!state.are_services_active());
    }

    #[test]
    fn test_service_state_display() {
        assert_eq!(ServiceState::Standby.to_string(), "standby");
        assert_eq!(ServiceState::Running.to_string(), "running");
        assert_eq!(ServiceState::Stopped.to_string(), "stopped");
    }

    #[tokio::test]
    async fn test_service_lifecycle() {
        let config = Config::default();
        let state = JanusState::new(config).await.unwrap();
        assert_eq!(state.service_state(), ServiceState::Standby);
        assert!(!state.are_services_active());
        assert!(state.start_services());
        assert_eq!(state.service_state(), ServiceState::Running);
        assert!(state.are_services_active());
        assert!(!state.start_services());
        assert!(state.stop_services());
        assert_eq!(state.service_state(), ServiceState::Stopped);
        assert!(!state.are_services_active());
        assert!(!state.stop_services());
    }

    #[tokio::test]
    async fn test_wait_for_services_start() {
        let config = Config::default();
        let state = std::sync::Arc::new(JanusState::new(config).await.unwrap());
        let state2 = state.clone();
        let handle = tokio::spawn(async move { state2.wait_for_services_start().await });
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        state.start_services();
        let result = tokio::time::timeout(tokio::time::Duration::from_secs(2), handle)
            .await
            .expect("timed out")
            .expect("task panicked");
        assert!(result);
    }

    #[tokio::test]
    async fn test_wait_for_services_start_when_already_running() {
        let state = JanusState::new(Config::default()).await.unwrap();
        assert!(state.start_services());
        let result = tokio::time::timeout(
            tokio::time::Duration::from_secs(2),
            state.wait_for_services_start(),
        )
        .await
        .expect("timed out");
        assert!(result);
    }

    #[tokio::test]
    async fn test_shutdown_unblocks_waiter() {
        let state = std::sync::Arc::new(JanusState::new(Config::default()).await.unwrap());
        let state2 = state.clone();
        let handle = tokio::spawn(async move { state2.wait_for_services_start().await });
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        state.shutdown().await.unwrap();
        let result = tokio::time::timeout(tokio::time::Duration::from_secs(2), handle)
            .await
            .expect("timed out")
            .expect("task panicked");
        assert!(!result);
        assert!(state.is_shutdown_requested());
        assert_eq!(state.service_state(), ServiceState::Stopped);
    }

    #[tokio::test]
    async fn test_module_health_registration() {
        let config = Config::default();
        let state = JanusState::new(config).await.unwrap();
        state.register_module_health("forward", true, None).await;
        state
            .register_module_health("backward", true, Some("running".to_string()))
            .await;
        let health = state.get_module_health().await;
        assert_eq!(health.len(), 2);
        assert!(state.all_modules_healthy().await);
    }

    #[tokio::test]
    async fn test_module_health_update_in_place() {
        let state = JanusState::new(Config::default()).await.unwrap();
        state.register_module_health("forward", true, None).await;
        state
            .register_module_health("forward", false, Some("stalled".to_string()))
            .await;
        let health = state.get_module_health().await;
        assert_eq!(health.len(), 1);
        assert!(!health[0].healthy);
        assert_eq!(health[0].message.as_deref(), Some("stalled"));
        assert!(!state.all_modules_healthy().await);
        assert_eq!(state.health_status().await.status, "degraded");
    }

    #[tokio::test]
    async fn test_health_status_without_modules_is_healthy() {
        let state = JanusState::new(Config::default()).await.unwrap();
        let status = state.health_status().await;
        assert_eq!(status.status, "healthy");
        assert!(status.modules.is_empty());
        assert_eq!(status.service_state, ServiceState::Standby);
        assert!(!status.shutdown_requested);
    }

    #[tokio::test]
    async fn test_signal_counters() {
        let config = Config::default();
        let state = JanusState::new(config).await.unwrap();
        state.increment_signals_generated();
        state.increment_signals_generated();
        state.increment_signals_persisted();
        assert_eq!(state.signals_generated(), 2);
        assert_eq!(state.signals_persisted(), 1);
    }

    type RecordedCall = (String, String, f64, bool, Option<f64>);

    /// Test recorder that stores every call it receives.
    struct CountingRecorder {
        calls: std::sync::Arc<std::sync::Mutex<Vec<RecordedCall>>>,
    }

    #[async_trait::async_trait]
    impl AffinityRecorder for CountingRecorder {
        async fn record_trade(
            &self,
            strategy: &str,
            asset: &str,
            pnl: f64,
            is_winner: bool,
            rr_ratio: Option<f64>,
        ) {
            self.calls.lock().unwrap().push((
                strategy.to_string(),
                asset.to_string(),
                pnl,
                is_winner,
                rr_ratio,
            ));
        }
    }

    #[tokio::test]
    async fn record_affinity_outcome_no_op_without_recorder() {
        let state = JanusState::new(Config::default()).await.unwrap();
        assert!(!state.has_affinity_recorder().await);
        let recorded = state
            .record_affinity_outcome("ema_cross", "BTC", 100.0, true, Some(2.0))
            .await;
        assert!(!recorded);
    }

    #[tokio::test]
    async fn record_affinity_outcome_reaches_installed_recorder() {
        let state = JanusState::new(Config::default()).await.unwrap();
        let calls = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        state
            .set_affinity_recorder(Box::new(CountingRecorder {
                calls: calls.clone(),
            }))
            .await;
        assert!(state.has_affinity_recorder().await);
        let recorded = state
            .record_affinity_outcome("ema_cross", "BTC", -25.0, false, None)
            .await;
        assert!(recorded);
        let calls = calls.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(
            calls[0],
            ("ema_cross".to_string(), "BTC".to_string(), -25.0, false, None)
        );
    }
}