use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use algocline_core::execution::{
CancelCode, CancelInfo, CancelReason, ExecutionResult, ExecutionState, ExecutionStateTag,
FailureInfo, FailureKind, PauseInfo, PauseKind, PausePrompt, ProgressEvent,
};
use algocline_core::{ExecutionMetrics, ExecutionObserver, LlmQuery};
use mlua_isle::AsyncTask;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use crate::llm_bridge::LlmRequest;
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reports a time before 1970 the result is `0`
/// rather than a panic.
pub(crate) fn now_ms() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_millis() as i64)
}
/// Moves the shared execution state to `new_state` and, if the move happened,
/// broadcasts a `ProgressEvent::StateTransition` describing it.
///
/// Terminal states (`Done`, `Failed`, `Cancelled`) are sticky. Once the state
/// is terminal, every further transition is ignored and no event is sent.
/// This keeps the first terminal outcome (e.g. the first `CancelInfo`) intact
/// when several paths race to finish the same execution. It also prevents a
/// late non-terminal update, such as `Paused`, from resurrecting a finished
/// execution.
///
/// The current tag is read and the new state written under a single lock
/// acquisition, so no other task can slip a transition in between the check
/// and the write. The event is sent after the lock is released. A send error
/// (no live receivers) is deliberately ignored.
pub(crate) async fn transition_state(
    state: &Arc<Mutex<ExecutionState>>,
    bus_tx: &broadcast::Sender<ProgressEvent>,
    new_state: ExecutionState,
) {
    let to_tag = new_state.tag();
    let from_tag = {
        let mut guard = state.lock().await;
        let from_tag = guard.tag();
        if is_terminal_tag(from_tag) {
            // Already finished: keep the first terminal state and stay silent.
            return;
        }
        *guard = new_state;
        from_tag
    };
    let _ = bus_tx.send(ProgressEvent::StateTransition {
        from: from_tag,
        to: to_tag,
        at: now_ms(),
    });
}
fn is_terminal_tag(tag: ExecutionStateTag) -> bool {
matches!(
tag,
ExecutionStateTag::Done | ExecutionStateTag::Failed | ExecutionStateTag::Cancelled
)
}
/// Builds a `CancelInfo` for `reason`, snapshotting the current state as
/// `state_before` and stamping `observed_at` with the current time.
///
/// If the state is already `Cancelled`, the snapshot is that cancellation's
/// own `state_before` rather than the `Cancelled` state itself. Repeated
/// cancellations therefore never nest `Cancelled` inside `Cancelled`.
pub(crate) async fn build_cancel_info(
    state: &Arc<Mutex<ExecutionState>>,
    reason: CancelReason,
) -> CancelInfo {
    let guard = state.lock().await;
    let snapshot: ExecutionState = if let ExecutionState::Cancelled(earlier) = &*guard {
        ExecutionState::clone(&earlier.state_before)
    } else {
        ExecutionState::clone(&guard)
    };
    // Release the state lock before doing anything else.
    drop(guard);

    let observed_at = now_ms();
    CancelInfo {
        state_before: Box::new(snapshot),
        observed_at,
        reason,
    }
}
/// Shared handles that `driver_loop` needs to run a single execution.
pub(crate) struct DriverContext {
    /// Current execution state; the driver changes it only through
    /// `transition_state`.
    pub state: Arc<Mutex<ExecutionState>>,
    /// Progress bus on which the driver publishes `Tick`, `StateTransition`
    /// (via `transition_state`) and `PauseRequested` events.
    pub bus_tx: broadcast::Sender<ProgressEvent>,
    /// Cancellation signal, checked by the driver before starting and on every
    /// loop iteration.
    pub cancel_token: CancellationToken,
    /// Pending response senders keyed by query id. The driver inserts one entry
    /// per query when it publishes a pause. NOTE(review): the consumer (presumably
    /// the resume path in the registry) is not visible here — confirm.
    pub resp_txs: super::record::RespTxsMap,
    /// Last-activity timestamp in epoch milliseconds (as produced by `now_ms`);
    /// the driver sets it once when it starts. NOTE(review): other readers or
    /// writers are not visible in this file.
    pub last_active: Arc<AtomicI64>,
    /// Execution metrics: provides the usage aggregate stored in a `Done`
    /// result and the observer notified when the execution pauses.
    pub metrics: Arc<ExecutionMetrics>,
}
/// Drives one Lua execution until it finishes, fails, or is cancelled.
///
/// Behaviour, in order:
/// - Records the start time in `ctx.last_active`.
/// - If `ctx.cancel_token` is already cancelled (checkpoint A), transitions to
///   `Cancelled` and returns without polling `exec_task`.
/// - Otherwise emits a `Tick { phase: "running" }` event and enters a biased
///   `select!` loop that always checks cancellation first (checkpoint D):
///   - `exec_task` finished with `Ok(json)`: the JSON is parsed into a `Done`
///     result carrying the metrics usage aggregate. A parse error yields
///     `Failed` with `FailureKind::EngineError`.
///   - `exec_task` finished with `Err`: transitions to `Failed` with
///     `FailureKind::LuaError`.
///   - An LLM request arrived: cancellation is re-checked (checkpoint B). If
///     cancelled, every query in the request gets `Err("cancelled")` and the
///     loop ends. Otherwise the observer is notified, each query's response
///     sender is registered in `ctx.resp_txs`, the state moves to `Paused`,
///     and a `PauseRequested` event is broadcast.
///
/// When `llm_rx` closes, `tokio::select!` disables that branch (the
/// `Some(req)` pattern no longer matches) and the loop keeps waiting on
/// cancellation and `exec_task`. `exec_task` is not aborted explicitly on
/// cancellation; it is dropped when this function returns. NOTE(review):
/// confirm that dropping an `AsyncTask` stops the Lua side.
pub(crate) async fn driver_loop(
    ctx: DriverContext,
    mut exec_task: AsyncTask,
    mut llm_rx: mpsc::Receiver<LlmRequest>,
) {
    ctx.last_active.store(now_ms(), Ordering::Relaxed);

    if ctx.cancel_token.is_cancelled() {
        cancel_execution(&ctx, "cancelled before execution started (checkpoint A)").await;
        return;
    }

    let _ = ctx.bus_tx.send(ProgressEvent::Tick {
        phase: "running".into(),
        at: now_ms(),
    });

    loop {
        tokio::select! {
            // Poll branches in declaration order so a pending cancellation wins
            // over a completed task or a queued LLM request.
            biased;
            _ = ctx.cancel_token.cancelled() => {
                cancel_execution(&ctx, "cancelled at select! checkpoint D").await;
                break;
            }
            result = &mut exec_task => {
                match result {
                    Ok(json_str) => {
                        match serde_json::from_str::<serde_json::Value>(&json_str) {
                            Ok(v) => {
                                let done = ExecutionState::Done(ExecutionResult {
                                    value: v,
                                    usage: ctx.metrics.usage_aggregate(),
                                    finished_at: now_ms(),
                                });
                                transition_state(&ctx.state, &ctx.bus_tx, done).await;
                            }
                            Err(e) => {
                                tracing::warn!(
                                    "driver_loop: JSON parse error on exec_task result: {e}"
                                );
                                let failed = ExecutionState::Failed(FailureInfo {
                                    message: format!("JSON parse error: {e}"),
                                    kind: FailureKind::EngineError,
                                    occurred_at: now_ms(),
                                });
                                transition_state(&ctx.state, &ctx.bus_tx, failed).await;
                            }
                        }
                    }
                    Err(e) => {
                        tracing::warn!("driver_loop: exec_task error: {e}");
                        let failed = ExecutionState::Failed(FailureInfo {
                            message: e.to_string(),
                            kind: FailureKind::LuaError,
                            occurred_at: now_ms(),
                        });
                        transition_state(&ctx.state, &ctx.bus_tx, failed).await;
                    }
                }
                break;
            }
            Some(req) = llm_rx.recv() => {
                // Cancellation may have been requested after select! chose this
                // branch; re-check before publishing a pause.
                if ctx.cancel_token.is_cancelled() {
                    cancel_execution(&ctx, "cancelled before pause publish (checkpoint B)").await;
                    // Unblock the waiting coroutine(s) instead of leaving them hanging.
                    for qr in req.queries {
                        if let Err(_e) = qr.resp_tx.send(Err("cancelled".into())) {
                            tracing::debug!(
                                "driver_loop checkpoint B: failed to send cancel to coroutine \
                                 (receiver already dropped)"
                            );
                        }
                    }
                    break;
                }

                let kind = if req.queries.len() == 1 {
                    PauseKind::Single
                } else {
                    PauseKind::Batch
                };
                let prompts: Vec<PausePrompt> = req.queries.iter().map(|qr| PausePrompt {
                    query_id: qr.id.as_str().to_owned(),
                    prompt: qr.prompt.clone(),
                }).collect();
                let queries_for_observer: Vec<LlmQuery> = req.queries.iter()
                    .map(|qr| LlmQuery {
                        id: qr.id.clone(),
                        prompt: qr.prompt.clone(),
                        system: qr.system.clone(),
                        max_tokens: qr.max_tokens,
                        grounded: qr.grounded,
                        underspecified: qr.underspecified,
                    })
                    .collect();
                ctx.metrics.create_observer().on_paused(&queries_for_observer);

                let pause_info = PauseInfo {
                    kind,
                    prompts,
                    paused_at: now_ms(),
                };
                // Register the response senders before announcing the pause, so
                // anyone reacting to it can already find them.
                {
                    let mut txs = ctx.resp_txs.lock().await;
                    for qr in req.queries {
                        txs.insert(qr.id, qr.resp_tx);
                    }
                }
                let pause_event = ProgressEvent::PauseRequested {
                    info: pause_info.clone(),
                    at: now_ms(),
                };
                transition_state(&ctx.state, &ctx.bus_tx, ExecutionState::Paused(pause_info)).await;
                let _ = ctx.bus_tx.send(pause_event);
            }
        }
    }
}

/// Transitions `ctx.state` to `Cancelled` with a user-initiated `CancelReason`
/// whose detail is `detail` and whose `requested_at` is the current time.
async fn cancel_execution(ctx: &DriverContext, detail: &'static str) {
    let reason = CancelReason {
        code: CancelCode::User,
        detail: Some(detail.into()),
        requested_at: now_ms(),
    };
    let info = build_cancel_info(&ctx.state, reason).await;
    transition_state(&ctx.state, &ctx.bus_tx, ExecutionState::Cancelled(info)).await;
}
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::execution::ExecutionState;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tokio::sync::{broadcast, Mutex};
    use tokio_util::sync::CancellationToken;

    /// Creates fresh stores rooted in a new temporary directory.
    ///
    /// The returned `TempDir` guard owns the directory and deletes it on drop.
    /// Callers bind it to `_tmp` (not `_`) so it stays alive for the whole
    /// test, instead of leaking the directory on every run.
    fn tmp_dirs() -> (
        tempfile::TempDir,
        Arc<crate::state::JsonFileStore>,
        Arc<crate::card::FileCardStore>,
        std::path::PathBuf,
    ) {
        let tmp = tempfile::tempdir().expect("test tempdir");
        let root = tmp.path().to_path_buf();
        (
            tmp,
            Arc::new(crate::state::JsonFileStore::new(root.join("state"))),
            Arc::new(crate::card::FileCardStore::new(root.join("cards"))),
            root.join("scenarios"),
        )
    }

    /// A `Running` state plus a progress bus; the receiver is returned so the
    /// channel keeps a live subscriber.
    fn make_state_and_bus() -> (
        Arc<Mutex<ExecutionState>>,
        broadcast::Sender<ProgressEvent>,
        broadcast::Receiver<ProgressEvent>,
    ) {
        let state = Arc::new(Mutex::new(ExecutionState::Running));
        let (tx, rx) = broadcast::channel(256);
        (state, tx, rx)
    }

    #[tokio::test]
    async fn cancel_at_checkpoint_a_before_first_lua_chunk() {
        let executor = crate::executor::Executor::new(vec![]).await.unwrap();
        let (_tmp, state_store, card_store, scenarios_dir) = tmp_dirs();
        // A script that never finishes on its own.
        let code = "while true do end".to_string();
        let session = executor
            .start_session(
                code,
                serde_json::json!({}),
                vec![],
                vec![],
                state_store,
                card_store,
                scenarios_dir,
            )
            .await
            .unwrap();
        let (exec_task, llm_rx, _vm_driver, _metrics) = session.into_driver_parts();
        let (state, bus_tx, _rx) = make_state_and_bus();
        let cancel_token = CancellationToken::new();
        cancel_token.cancel();
        let resp_txs = Arc::new(Mutex::new(HashMap::new()));
        let ctx = DriverContext {
            state: state.clone(),
            bus_tx,
            cancel_token,
            resp_txs,
            last_active: Arc::new(std::sync::atomic::AtomicI64::new(0)),
            metrics: Arc::new(algocline_core::ExecutionMetrics::new()),
        };
        driver_loop(ctx, exec_task, llm_rx).await;
        let guard = state.lock().await;
        assert!(
            matches!(*guard, ExecutionState::Cancelled(_)),
            "expected Cancelled after pre-cancel checkpoint A, got: {:?}",
            guard.tag()
        );
    }

    // NOTE: the token is cancelled before `driver_loop` starts, so the driver
    // exits at its pre-start check and never reaches the pause path. The
    // pre-pause check is a race-window guard that cannot be hit
    // deterministically from here. This test only pins that a script which
    // would pause still ends up Cancelled.
    #[tokio::test]
    async fn cancel_at_checkpoint_b_before_pause_publish() {
        let executor = crate::executor::Executor::new(vec![]).await.unwrap();
        let (_tmp, state_store, card_store, scenarios_dir) = tmp_dirs();
        let code = r#"return alc.llm("question")"#.to_string();
        let session = executor
            .start_session(
                code,
                serde_json::json!({}),
                vec![],
                vec![],
                state_store,
                card_store,
                scenarios_dir,
            )
            .await
            .unwrap();
        let (exec_task, llm_rx, _vm_driver, _metrics) = session.into_driver_parts();
        let (state, bus_tx, _rx) = make_state_and_bus();
        let cancel_token = CancellationToken::new();
        cancel_token.cancel();
        let resp_txs = Arc::new(Mutex::new(HashMap::new()));
        let ctx = DriverContext {
            state: state.clone(),
            bus_tx,
            cancel_token,
            resp_txs,
            last_active: Arc::new(std::sync::atomic::AtomicI64::new(0)),
            metrics: Arc::new(algocline_core::ExecutionMetrics::new()),
        };
        driver_loop(ctx, exec_task, llm_rx).await;
        let guard = state.lock().await;
        assert!(
            matches!(*guard, ExecutionState::Cancelled(_)),
            "expected Cancelled when token set before pause publish, got: {:?}",
            guard.tag()
        );
    }

    /// Sanity check on `CancellationToken` itself (not on driver code):
    /// cancelling twice is harmless.
    #[test]
    fn cancel_idempotent() {
        let cancel_token = CancellationToken::new();
        cancel_token.cancel();
        assert!(cancel_token.is_cancelled());
        cancel_token.cancel();
        assert!(cancel_token.is_cancelled());
    }

    #[tokio::test]
    async fn transition_state_terminal_is_idempotent() {
        let (state, bus_tx, _rx) = make_state_and_bus();
        let info1 = CancelInfo {
            reason: CancelReason {
                code: CancelCode::User,
                detail: Some("first".into()),
                requested_at: 100,
            },
            observed_at: 110,
            state_before: Box::new(ExecutionState::Running),
        };
        transition_state(&state, &bus_tx, ExecutionState::Cancelled(info1.clone())).await;
        let info2 = CancelInfo {
            reason: CancelReason {
                code: CancelCode::Internal,
                detail: Some("second".into()),
                requested_at: 200,
            },
            observed_at: 210,
            state_before: Box::new(ExecutionState::Paused(PauseInfo {
                kind: PauseKind::Single,
                prompts: vec![],
                paused_at: 150,
            })),
        };
        transition_state(&state, &bus_tx, ExecutionState::Cancelled(info2)).await;
        let guard = state.lock().await;
        match &*guard {
            ExecutionState::Cancelled(seen) => {
                assert_eq!(
                    seen.reason.detail.as_deref(),
                    Some("first"),
                    "second transition must not overwrite the first CancelInfo"
                );
                assert!(
                    matches!(*seen.state_before, ExecutionState::Running),
                    "state_before must remain the original pre-cancel state, got: {:?}",
                    seen.state_before
                );
            }
            other => panic!("expected Cancelled, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn build_cancel_info_does_not_nest_on_already_cancelled_state() {
        let original_pause = ExecutionState::Paused(PauseInfo {
            kind: PauseKind::Single,
            prompts: vec![],
            paused_at: 100,
        });
        let outer = ExecutionState::Cancelled(CancelInfo {
            reason: CancelReason {
                code: CancelCode::User,
                detail: Some("first".into()),
                requested_at: 200,
            },
            observed_at: 210,
            state_before: Box::new(original_pause.clone()),
        });
        let state = Arc::new(Mutex::new(outer));
        let second_reason = CancelReason {
            code: CancelCode::Internal,
            detail: Some("driver-checkpoint".into()),
            requested_at: 300,
        };
        let info = build_cancel_info(&state, second_reason).await;
        match *info.state_before {
            ExecutionState::Paused(ref p) => {
                assert_eq!(p.paused_at, 100, "expected original Paused snapshot");
            }
            ref other => {
                panic!("state_before must inherit inner pre-cancel state, got nested: {other:?}")
            }
        }
    }

    #[tokio::test]
    async fn driver_loop_completes_with_done() {
        let executor = crate::executor::Executor::new(vec![]).await.unwrap();
        let (_tmp, state_store, card_store, scenarios_dir) = tmp_dirs();
        let session = executor
            .start_session(
                "return 42".to_string(),
                serde_json::json!({}),
                vec![],
                vec![],
                state_store,
                card_store,
                scenarios_dir,
            )
            .await
            .unwrap();
        let (exec_task, llm_rx, _vm_driver, _metrics) = session.into_driver_parts();
        let (state, bus_tx, _rx) = make_state_and_bus();
        let cancel_token = CancellationToken::new();
        let resp_txs = Arc::new(Mutex::new(HashMap::new()));
        let ctx = DriverContext {
            state: state.clone(),
            bus_tx,
            cancel_token,
            resp_txs,
            last_active: Arc::new(std::sync::atomic::AtomicI64::new(0)),
            metrics: Arc::new(algocline_core::ExecutionMetrics::new()),
        };
        driver_loop(ctx, exec_task, llm_rx).await;
        let guard = state.lock().await;
        assert!(
            matches!(*guard, ExecutionState::Done(_)),
            "expected Done for trivial Lua, got: {:?}",
            guard.tag()
        );
    }

    /// The cancellation checkpoint labels must survive refactors of the
    /// production code. Only the text before the first `cfg(test)` attribute
    /// is searched. This test module spells the labels out itself, so
    /// searching the whole file would make the test pass unconditionally.
    #[test]
    fn checkpoint_markers_exist_in_driver() {
        let source = include_str!("driver.rs");
        let production = source
            .split_once("#[cfg(test)]")
            .map_or(source, |(before_tests, _)| before_tests);
        for marker in &["checkpoint A", "checkpoint B", "checkpoint D"] {
            assert!(
                production.contains(marker),
                "driver.rs must contain comment '{marker}'"
            );
        }
    }

    #[test]
    fn checkpoint_c_exists_in_registry() {
        let registry_source = include_str!("registry.rs");
        assert!(
            registry_source.contains("checkpoint C"),
            "registry.rs must contain comment 'checkpoint C'"
        );
    }
}