use std::collections::HashMap;
use std::sync::Arc;
use algocline_core::execution::{CancelInfo, ExecutionState};
use algocline_core::QueryId;
use tokio::sync::broadcast;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use algocline_core::execution::ProgressEvent;
/// Shared map from a [`QueryId`] to a one-shot sender carrying a
/// `Result<String, String>`.
///
/// The map sits behind an `Arc` and a Tokio async `Mutex`, so every clone of
/// the handle refers to the same map and access goes through `lock().await`.
///
/// NOTE(review): presumably each sender delivers the response for its pending
/// query — confirm against the code that inserts and removes entries.
pub(crate) type RespTxsMap =
Arc<Mutex<HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>>>;
/// Bundle of per-session handles: execution state, progress bus, cancellation
/// token, task handle, response senders and cancellation info.
///
/// All fields are crate-private. The `Mutex` used here is `tokio::sync::Mutex`,
/// so locking is done with `.lock().await`.
pub struct SessionRecord {
/// Execution state, shared through an `Arc` and guarded by an async mutex.
pub(crate) state: Arc<Mutex<ExecutionState>>,
/// Sending half of the broadcast channel that carries [`ProgressEvent`]s.
pub(crate) bus_tx: broadcast::Sender<ProgressEvent>,
/// Cancellation token associated with this record.
pub(crate) cancel_token: CancellationToken,
/// Handle of a spawned task, wrapped in `Mutex<Option<..>>`.
/// NOTE(review): presumably an `Option` so the handle can be taken out and
/// awaited once — confirm against the code that consumes it.
pub(crate) join_handle: Mutex<Option<JoinHandle<()>>>,
/// Shared map of per-query one-shot senders (see [`RespTxsMap`]).
pub(crate) resp_txs: RespTxsMap,
/// Optional [`CancelInfo`] behind an async mutex.
/// NOTE(review): the name suggests only the first cancellation request is
/// recorded here — verify with the cancellation path.
pub(crate) first_cancel_info: Mutex<Option<CancelInfo>>,
}
impl SessionRecord {
    /// Assembles a record from already-constructed parts (compiled for tests only).
    ///
    /// A fresh broadcast channel sized by `bus_capacity` is created here and only
    /// its sending half is kept, so the record starts out without any subscriber.
    /// The supplied `join_handle` is stored as `Some(..)` and `first_cancel_info`
    /// starts out as `None`; `state`, `cancel_token` and `resp_txs` are stored
    /// as given.
    ///
    /// # Panics
    ///
    /// `tokio::sync::broadcast::channel` panics when `bus_capacity` is zero.
    #[cfg(test)]
    pub(crate) fn new(
        state: Arc<Mutex<ExecutionState>>,
        bus_capacity: usize,
        cancel_token: CancellationToken,
        join_handle: JoinHandle<()>,
        resp_txs: RespTxsMap,
    ) -> Self {
        // Keep the sender only; the paired receiver is dropped right away.
        let bus_tx = broadcast::channel(bus_capacity).0;
        let join_handle = Mutex::new(Some(join_handle));
        let first_cancel_info = Mutex::new(None);

        Self {
            state,
            bus_tx,
            cancel_token,
            join_handle,
            resp_txs,
            first_cancel_info,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::execution::ExecutionState;
    use tokio::task;

    /// Builds a record in the `Running` state with a bus capacity of 256, a
    /// fresh cancellation token, an already-trivial spawned task and an empty
    /// response-sender map.
    ///
    /// Must be called from inside a Tokio runtime because it spawns a task.
    fn running_record() -> SessionRecord {
        SessionRecord::new(
            Arc::new(Mutex::new(ExecutionState::Running)),
            256,
            CancellationToken::new(),
            task::spawn(async {}),
            Arc::new(Mutex::new(HashMap::new())),
        )
    }

    /// The state handed to the constructor is the state the record exposes.
    #[tokio::test]
    async fn record_created_with_running_state() {
        let record = running_record();

        let current = record.state.lock().await;
        assert!(matches!(*current, ExecutionState::Running));
    }

    /// Sending on the bus while nobody is subscribed must not panic.
    #[tokio::test]
    async fn bus_tx_does_not_crash_caller_with_zero_observers() {
        use algocline_core::execution::{ExecutionStateTag, ProgressEvent};

        let record = running_record();

        // The outcome of the send is deliberately discarded: the test only
        // checks that the call returns to the caller.
        let _ = record.bus_tx.send(ProgressEvent::StateTransition {
            from: ExecutionStateTag::Running,
            to: ExecutionStateTag::Done,
            at: 0,
        });
    }
}