#[allow(unused_imports)]
use crate::sync_util::LockExt;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use super::hooks::GetSteeringMessagesFn;
use super::message::{LoopMessage, UserMessage};
use super::types::QueueMode;
/// Preamble placed in front of every steering message injected mid-turn.
///
/// Its wording tells the model that the text following it is extra guidance
/// for the task already in progress, not a new task.
pub const MID_TURN_STEER_WRAPPER: &str = "[Mid-turn steer queued by the user. Do not treat this as a new task; use it only as additional guidance for the current task after completing the current step.]";

/// Prefixes raw steering text with [`MID_TURN_STEER_WRAPPER`].
///
/// The returned string is the wrapper, one `'\n'`, then `content` unchanged.
/// An empty `content` therefore yields the wrapper plus a trailing newline.
pub fn format_steer_user_message(content: &str) -> String {
    format!("{MID_TURN_STEER_WRAPPER}\n{content}")
}
/// Builds a [`GetSteeringMessagesFn`] hook backed by a shared queue of raw
/// steering strings.
///
/// Each call of the returned hook removes entries from the front of `queue`
/// according to `mode`:
/// - [`QueueMode::All`]: every pending entry, oldest first;
/// - [`QueueMode::OneAtATime`]: at most the single oldest entry.
///
/// Every removed entry becomes a [`LoopMessage::User`] whose content comes
/// from [`format_steer_user_message`]. An empty queue yields no messages.
pub fn steering_from_queue(
    queue: Arc<Mutex<VecDeque<String>>>,
    mode: QueueMode,
) -> GetSteeringMessagesFn {
    Arc::new(move || {
        let shared = Arc::clone(&queue);
        Box::pin(async move {
            // Remove entries inside a short scope so the guard is dropped
            // before any formatting work. `lock_ignore_poison` comes from
            // `LockExt`; presumably it tolerates a poisoned mutex — TODO confirm.
            let pending: Vec<String> = {
                let mut guard = shared.lock_ignore_poison();
                match mode {
                    QueueMode::OneAtATime => guard.pop_front().into_iter().collect(),
                    QueueMode::All => guard.drain(..).collect(),
                }
            };
            let to_message = |raw: String| {
                LoopMessage::User(UserMessage {
                    content: format_steer_user_message(&raw),
                })
            };
            pending.into_iter().map(to_message).collect()
        })
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // Polling an empty queue yields no steering messages.
    #[tokio::test]
    async fn empty_queue_returns_empty() {
        let queue = Arc::new(Mutex::new(VecDeque::<String>::new()));
        let hook = steering_from_queue(queue, QueueMode::All);
        let messages = hook().await;
        assert!(messages.is_empty());
    }

    // `QueueMode::All` takes every queued entry in FIFO order, wraps each one,
    // and leaves the shared queue empty.
    #[tokio::test]
    async fn all_mode_drains_fifo() {
        let queue = Arc::new(Mutex::new(VecDeque::<String>::from(vec![
            "first".to_string(),
            "second".to_string(),
            "third".to_string(),
        ])));
        let hook = steering_from_queue(queue.clone(), QueueMode::All);
        let messages = hook().await;
        assert_eq!(messages.len(), 3);
        let contents: Vec<_> = messages
            .iter()
            .map(|m| match m {
                LoopMessage::User(u) => u.content.clone(),
                _ => panic!("expected User"),
            })
            .collect();
        assert!(contents[0].starts_with(MID_TURN_STEER_WRAPPER));
        assert!(contents[0].ends_with("first"));
        assert!(contents[1].ends_with("second"));
        assert!(contents[2].ends_with("third"));
        assert!(queue.lock().unwrap().is_empty());
    }

    // `QueueMode::OneAtATime` removes only the oldest entry per poll, and an
    // exhausted queue then yields nothing.
    #[tokio::test]
    async fn one_at_a_time_drains_oldest_only() {
        let queue = Arc::new(Mutex::new(VecDeque::<String>::from(vec![
            "first".to_string(),
            "second".to_string(),
        ])));
        let hook = steering_from_queue(queue.clone(), QueueMode::OneAtATime);
        let m1 = hook().await;
        assert_eq!(m1.len(), 1);
        assert!(matches!(
            &m1[0],
            LoopMessage::User(u) if u.content.starts_with(MID_TURN_STEER_WRAPPER) && u.content.ends_with("first")
        ));
        assert_eq!(queue.lock().unwrap().len(), 1);
        let m2 = hook().await;
        assert_eq!(m2.len(), 1);
        // Check the second message at least as strictly as the first: it must
        // be exactly the wrapped next entry, and the queue must now be empty.
        assert!(matches!(
            &m2[0],
            LoopMessage::User(u) if u.content == format_steer_user_message("second")
        ));
        assert!(queue.lock().unwrap().is_empty());
        let m3 = hook().await;
        assert!(m3.is_empty());
    }

    // Entries pushed by another task after a poll are picked up by the next poll.
    #[tokio::test]
    async fn producer_enqueue_visible_on_next_poll() {
        let queue = Arc::new(Mutex::new(VecDeque::<String>::new()));
        let hook = steering_from_queue(queue.clone(), QueueMode::All);
        assert!(hook().await.is_empty());
        let pushed = queue.clone();
        tokio::spawn(async move {
            pushed.lock().unwrap().push_back("mid-run".to_string());
        })
        .await
        .unwrap();
        let messages = hook().await;
        assert_eq!(messages.len(), 1);
        assert!(matches!(
            &messages[0],
            LoopMessage::User(u) if u.content.starts_with(MID_TURN_STEER_WRAPPER) && u.content.ends_with("mid-run")
        ));
    }

    // Two concurrent `All` polls must neither split nor duplicate entries:
    // one poll observes all three, the other observes none.
    #[tokio::test]
    async fn concurrent_polls_serialize() {
        let queue = Arc::new(Mutex::new(VecDeque::<String>::from(vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ])));
        let hook = steering_from_queue(queue.clone(), QueueMode::All);
        let h1 = hook.clone();
        let h2 = hook.clone();
        let (r1, r2) = tokio::join!(h1(), h2());
        let lens = [r1.len(), r2.len()];
        let mut sorted = lens;
        sorted.sort();
        assert_eq!(sorted, [0, 3]);
    }

    // Compile-time check that the hook can be shared across threads; it needs
    // no async runtime.
    #[test]
    fn fn_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>(_: &T) {}
        let queue = Arc::new(Mutex::new(VecDeque::<String>::new()));
        let hook = steering_from_queue(queue, QueueMode::All);
        assert_send_sync(&hook);
    }

    // End-to-end: a steering message enqueued during the first LLM call must be
    // visible to the second LLM call and appear, wrapped, in the returned
    // transcript right after the initial prompt.
    #[tokio::test]
    async fn integration_steering_queue_injects_between_turns() {
        use crate::agent::agent_loop::message::{
            AssistantMessage, ContentBlock, StopReason, StreamEvent,
        };
        use crate::agent::agent_loop::result::LoopToolResult;
        use crate::agent::agent_loop::run::run_agent_loop;
        use crate::agent::agent_loop::stream::StreamFn;
        use crate::agent::agent_loop::tool::{AbortSignal, LoopTool, LoopToolUpdate};
        use crate::agent::agent_loop::tools::extract_tool_calls;
        use crate::agent::agent_loop::types::{Context, LoopConfig, ToolExecutionMode};
        use serde_json::Value;
        use std::pin::Pin;
        use std::sync::atomic::{AtomicUsize, Ordering};

        // Minimal tool so the first turn has a tool call to execute.
        #[derive(Debug)]
        struct EchoTool;
        impl LoopTool for EchoTool {
            fn name(&self) -> &str {
                "echo"
            }
            fn description(&self) -> &str {
                "Echo"
            }
            fn label(&self) -> &str {
                "Echo"
            }
            fn parameters(&self) -> &Value {
                static EMPTY: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
                EMPTY.get_or_init(|| serde_json::json!({"type": "object"}))
            }
            fn execute<'a>(
                &'a self,
                _id: &'a str,
                _args: Value,
                _signal: AbortSignal,
                _on_update: LoopToolUpdate,
            ) -> Pin<Box<dyn Future<Output = Result<LoopToolResult, String>> + Send + 'a>>
            {
                Box::pin(async move {
                    Ok(LoopToolResult {
                        content: vec![serde_json::json!({"type": "text", "text": "ok"})],
                        details: Value::Null,
                        terminate: None,
                    })
                })
            }
        }

        let queue = Arc::new(Mutex::new(VecDeque::<String>::new()));
        let saw_interrupt = Arc::new(Mutex::new(false));
        let saw_clone = saw_interrupt.clone();
        let call_counter = Arc::new(AtomicUsize::new(0));
        let queue_writer = queue.clone();

        // Scripted LLM.
        // - Call 0 enqueues a steering message and requests the `echo` tool.
        // - Call 1 records whether that message reached the LLM context.
        // - Every call after the first answers "done" with `StopReason::Stop`.
        let factory: StreamFn = Arc::new(move |llm_ctx, _opts| {
            let n = call_counter.fetch_add(1, Ordering::SeqCst);
            match n {
                0 => queue_writer
                    .lock()
                    .unwrap()
                    .push_back("interrupt".to_string()),
                1 => {
                    let found = llm_ctx.messages.iter().any(|m| {
                        m.get("role").and_then(|r| r.as_str()) == Some("user")
                            && m
                                .get("content")
                                .and_then(|c| c.as_str())
                                .is_some_and(|s| s.contains("interrupt"))
                    });
                    *saw_clone.lock().unwrap() = found;
                }
                _ => {}
            }
            let msg = if n == 0 {
                AssistantMessage::new(
                    vec![ContentBlock::ToolCall {
                        id: "call-1".to_string(),
                        name: "echo".to_string(),
                        arguments: serde_json::json!({}),
                    }],
                    StopReason::ToolUse,
                )
            } else {
                AssistantMessage::new(
                    vec![ContentBlock::Text {
                        text: "done".to_string(),
                    }],
                    StopReason::Stop,
                )
            };
            let reason = msg.stop_reason;
            Box::pin(futures::stream::iter(vec![StreamEvent::Done {
                reason,
                message: msg,
                usage: None,
            }]))
        });

        let mut config = LoopConfig {
            convert_to_llm: Arc::new(|messages: &[Value]| {
                messages
                    .iter()
                    .filter(|m| {
                        let role = m.get("role").and_then(|r| r.as_str()).unwrap_or("");
                        matches!(role, "user" | "assistant" | "tool" | "toolResult")
                    })
                    .cloned()
                    .collect()
            }),
            transform_context: None,
            compaction_hooks: None,
            get_api_key: None,
            api_key: None,
            tool_execution: ToolExecutionMode::Sequential,
            before_tool_call: None,
            after_tool_call: None,
            prepare_next_turn: None,
            should_stop_after_turn: None,
            get_steering_messages: None,
            get_followup_messages: None,
            reasoning: None,
            thinking_budgets: None,
            headers: std::collections::HashMap::new(),
            metadata: std::collections::HashMap::new(),
            request_timeout: None,
            provider_name: None,
            model_name: None,
            compact_model: None,
            storm_mutating_tools: None,
            storm_exempt_tools: None,
            repair_stats: std::sync::Arc::new(
                crate::agent::agent_loop::tool_input_repair::RepairStats::new(),
            ),
            truncation_notes: std::sync::Arc::new(std::sync::Mutex::new(
                std::collections::HashMap::new(),
            )),
            tool_def_filter: None,
            dynamic_tool_search: false,
            escalation_stream_fn: None,
            escalation_provider_name: None,
            escalation_pending: std::sync::Arc::new(std::sync::Mutex::new(None)),
            escalation_max_per_session: 3,
            escalation_remaining: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(3)),
            file_touch_tracker: None,
            verifier: None,
            critic_fn: None,
            goal: None,
            max_turns: None,
        };
        config.get_steering_messages = Some(steering_from_queue(queue.clone(), QueueMode::All));

        let mut ctx = Context::default();
        ctx.tools.push(Arc::new(EchoTool));
        let (tx, _rx) = tokio::sync::mpsc::channel(64);
        let messages = run_agent_loop(
            vec![LoopMessage::User(UserMessage {
                content: "start".to_string(),
            })],
            ctx,
            config,
            AbortSignal::new(),
            &tx,
            &factory,
            None,
            None,
        )
        .await;

        assert!(
            *saw_interrupt.lock().unwrap(),
            "second LLM call should see the injected interrupt"
        );
        let user_contents: Vec<String> = messages
            .iter()
            .filter_map(|m| match m {
                LoopMessage::User(u) => Some(u.content.clone()),
                _ => None,
            })
            .collect();
        // Fail with a readable message instead of an index-out-of-bounds panic
        // if the steering message never made it into the transcript.
        assert!(
            user_contents.len() >= 2,
            "expected the initial prompt plus a steering message, got {user_contents:?}"
        );
        assert_eq!(user_contents[0], "start");
        assert!(
            user_contents[1].starts_with(MID_TURN_STEER_WRAPPER),
            "steering message should be wrapped with preamble"
        );
        assert!(user_contents[1].ends_with("interrupt"));
    }
}