#[allow(unused_imports)]
use crate::sync_util::LockExt;
use std::pin::Pin;
use std::sync::Arc;
use futures::Stream;
use futures::stream::StreamExt;
use tokio::sync::mpsc;
use super::message::{
AssistantMessage, LoopEvent, LoopMessage, StopReason, StreamEvent, assistant_to_value,
};
use super::tool::AbortSignal;
use super::types::{Context, LoopConfig};
/// Provider-facing request payload handed to a [`StreamFn`].
///
/// `stream_assistant_response` builds one of these per call, after the
/// optional `transform_context` hook and the `convert_to_llm` conversion
/// have been applied to the loop context's messages.
#[derive(Debug, Clone)]
pub struct LlmContext {
/// System prompt, cloned from the loop `Context`.
pub system_prompt: String,
/// Messages exactly as returned by `LoopConfig::convert_to_llm`.
pub messages: Vec<serde_json::Value>,
}
/// Per-request options handed to a [`StreamFn`] together with the
/// [`LlmContext`].
///
/// `stream_assistant_response` builds one value per call, copying most fields
/// from the `LoopConfig` and attaching the caller's abort signal.
#[derive(Clone)]
pub struct StreamOptions {
/// API key for this request: the key returned by the `get_api_key` hook when
/// it yields one, otherwise the static `LoopConfig::api_key`.
// NOTE(review): no rationale for `dead_code` is recorded here; presumably the
// field is only read by provider code outside this file — confirm.
#[allow(dead_code)]
pub api_key: Option<String>,
/// Requested thinking level, copied from `LoopConfig::reasoning`.
pub reasoning: Option<super::types::ThinkingLevel>,
/// Thinking budgets, cloned from `LoopConfig::thinking_budgets`.
pub thinking_budgets: Option<super::types::ThinkingBudgets>,
/// String key/value pairs cloned from `LoopConfig::headers`.
// NOTE(review): presumably forwarded as extra HTTP headers — confirm against
// the stream function implementations.
pub headers: std::collections::HashMap<String, String>,
/// Arbitrary JSON values keyed by name, cloned from `LoopConfig::metadata`.
pub metadata: std::collections::HashMap<String, serde_json::Value>,
/// Request timeout, copied from `LoopConfig::request_timeout`.
// NOTE(review): same unexplained `dead_code` allowance as `api_key` — confirm
// where this is consumed.
#[allow(dead_code)]
pub request_timeout: Option<std::time::Duration>,
/// Abort signal supplied by the caller of `stream_assistant_response`.
pub signal: AbortSignal,
}
impl StreamOptions {
    /// Test-only constructor: wraps `signal` in options whose every other
    /// field is either `None` or an empty map.
    #[cfg(test)]
    pub fn from_signal(signal: AbortSignal) -> Self {
        use std::collections::HashMap;

        let headers: HashMap<String, String> = HashMap::new();
        let metadata: HashMap<String, serde_json::Value> = HashMap::new();
        Self {
            signal,
            headers,
            metadata,
            api_key: None,
            reasoning: None,
            thinking_budgets: None,
            request_timeout: None,
        }
    }
}
/// Shared handle to a provider streaming function.
///
/// The function receives the [`LlmContext`] and [`StreamOptions`] for one
/// request and returns a pinned, boxed, `Send` stream of [`StreamEvent`]s.
/// The closure itself is `Send + Sync`, so the `Arc` can be shared across
/// threads.
pub type StreamFn = Arc<
dyn Fn(LlmContext, StreamOptions) -> Pin<Box<dyn Stream<Item = StreamEvent> + Send>>
+ Send
+ Sync,
>;
/// Streams a single assistant response and keeps `context` in sync with it.
///
/// The call proceeds in this order:
///
/// 1. `context.messages` is cloned, passed through `config.transform_context`
///    when that hook is set, and then through `config.convert_to_llm`.
/// 2. The API key is resolved: the `config.get_api_key` hook (called with
///    `config.provider_name`, or `""` when unset) wins when it returns a key;
///    otherwise `config.api_key` is used.
/// 3. Any pending escalation reason is taken out of
///    `config.escalation_pending`, so it is consumed even when no escalation
///    stream function is configured. When both a reason and
///    `config.escalation_stream_fn` are present, an `EscalationActivated`
///    event is emitted and the escalation function serves this call instead
///    of `stream_fn`.
/// 4. The chosen stream is drained until its first `Done` or `Error` event:
///    * `Start` pushes the partial message onto `context.messages` and emits
///      `MessageStart`.
///    * `Delta` overwrites that pushed partial (only if one was pushed) and
///      always emits `MessageUpdate`.
///    * `Retry` emits `RetryNotice` and removes the pushed partial again.
///    * `Done` / `Error` finalize the message (see `finalize`); `Done` also
///      emits `Usage` when usage was reported.
///
/// If the stream ends without a terminal event, an empty message with
/// `StopReason::Stop` is finalized instead.
///
/// # Returns
///
/// The final assistant message plus the usage carried by the `Done` event.
/// Usage is `None` for the `Error` path and for the no-terminal-event
/// fallback.
///
/// Failures to send on `emit` are deliberately ignored.
pub async fn stream_assistant_response(
    context: &mut Context,
    config: &LoopConfig,
    signal: AbortSignal,
    emit: &mpsc::Sender<LoopEvent>,
    stream_fn: &StreamFn,
) -> (AssistantMessage, Option<super::message::TokenUsage>) {
    let messages: Vec<serde_json::Value> = match &config.transform_context {
        Some(transform) => transform(context.messages.clone()).await,
        None => context.messages.clone(),
    };
    let llm_messages = (config.convert_to_llm)(&messages);

    let resolved_api_key: Option<String> = match &config.get_api_key {
        Some(get_key) => {
            let provider = config.provider_name.as_deref().unwrap_or("");
            // The hook has priority; the static key is only a fallback.
            get_key(provider).await.or_else(|| config.api_key.clone())
        }
        None => config.api_key.clone(),
    };

    let llm_ctx = LlmContext {
        system_prompt: context.system_prompt.clone(),
        messages: llm_messages,
    };
    let stream_options = StreamOptions {
        api_key: resolved_api_key,
        reasoning: config.reasoning,
        thinking_budgets: config.thinking_budgets.clone(),
        headers: config.headers.clone(),
        metadata: config.metadata.clone(),
        request_timeout: config.request_timeout,
        signal,
    };

    // Take the pending reason inside its own scope so the mutex guard is
    // released before the first `.await` below.
    let pending_reason: Option<super::message::EscalationReason> = {
        let mut pending = config.escalation_pending.lock_ignore_poison();
        pending.take()
    };

    // Escalate only when a reason was pending AND an alternate stream function
    // exists; matching on both at once removes the need for a later `expect`.
    let active_stream_fn: &StreamFn =
        match (pending_reason, config.escalation_stream_fn.as_ref()) {
            (Some(reason), Some(escalation_fn)) => {
                let provider = config
                    .escalation_provider_name
                    .clone()
                    .unwrap_or_else(|| "escalation".to_string());
                let _ = emit
                    .send(LoopEvent::EscalationActivated { provider, reason })
                    .await;
                escalation_fn
            }
            _ => stream_fn,
        };

    let mut stream = active_stream_fn(llm_ctx, stream_options);
    // True while the last entry of `context.messages` is a partial pushed by
    // a `Start` event of the current attempt.
    let mut added_partial = false;

    while let Some(event) = stream.next().await {
        match event {
            StreamEvent::Start { partial } => {
                context.messages.push(assistant_to_value(&partial));
                added_partial = true;
                let _ = emit
                    .send(LoopEvent::MessageStart {
                        message: LoopMessage::Assistant(partial),
                    })
                    .await;
            }
            StreamEvent::Delta { partial, phase } => {
                if added_partial && let Some(last) = context.messages.last_mut() {
                    *last = assistant_to_value(&partial);
                }
                let _ = emit
                    .send(LoopEvent::MessageUpdate {
                        message: partial,
                        phase,
                    })
                    .await;
            }
            StreamEvent::Done {
                reason,
                message,
                usage,
            } => {
                // The event's `reason` overrides whatever the message carries.
                let mut finalised = message;
                finalised.stop_reason = reason;
                finalize(context, &finalised, added_partial, emit).await;
                if let Some(u) = usage {
                    let _ = emit.send(LoopEvent::Usage { usage: u }).await;
                }
                return (finalised, usage);
            }
            StreamEvent::Error { error } => {
                let finalised = AssistantMessage {
                    content: Vec::new(),
                    stop_reason: StopReason::Error,
                    error_message: Some(error),
                };
                finalize(context, &finalised, added_partial, emit).await;
                return (finalised, None);
            }
            StreamEvent::Retry {
                attempt,
                delay_ms,
                error,
            } => {
                let _ = emit
                    .send(LoopEvent::RetryNotice {
                        attempt,
                        delay_ms,
                        error,
                    })
                    .await;
                // Drop the partial of the failed attempt so the retried
                // attempt starts from a clean context.
                if added_partial
                    && let Some(last) = context.messages.last()
                    && last.get("role").and_then(|r| r.as_str()) == Some("assistant")
                {
                    context.messages.pop();
                }
                added_partial = false;
            }
        }
    }

    // The stream closed without `Done` or `Error`: finalize an empty message
    // so the message start/end events and the context entry still appear.
    let empty = AssistantMessage::new(Vec::new(), StopReason::Stop);
    finalize(context, &empty, added_partial, emit).await;
    (empty, None)
}
/// Records `final_msg` as the finished assistant message.
///
/// When a partial was already pushed (`added_partial`), the last entry of
/// `context.messages` is overwritten in place. Otherwise the message is
/// appended and a `MessageStart` event is emitted first, so that every
/// `MessageEnd` is preceded by a start. `MessageEnd` is emitted in both cases.
/// Send failures on `emit` are ignored.
async fn finalize(
    context: &mut Context,
    final_msg: &AssistantMessage,
    added_partial: bool,
    emit: &mpsc::Sender<LoopEvent>,
) {
    if !added_partial {
        context.messages.push(assistant_to_value(final_msg));
        let start = LoopEvent::MessageStart {
            message: LoopMessage::Assistant(final_msg.clone()),
        };
        let _ = emit.send(start).await;
    } else if let Some(slot) = context.messages.last_mut() {
        *slot = assistant_to_value(final_msg);
    }

    let end = LoopEvent::MessageEnd {
        message: LoopMessage::Assistant(final_msg.clone()),
    };
    let _ = emit.send(end).await;
}
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::agent_loop::message::ContentBlock;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Test converter that keeps only messages whose `role` is `user`,
/// `assistant`, `tool` or `toolResult`; anything else (including messages
/// without a string `role`) is dropped.
fn identity_converter()
-> Arc<dyn Fn(&[serde_json::Value]) -> Vec<serde_json::Value> + Send + Sync> {
    const LLM_ROLES: [&str; 4] = ["user", "assistant", "tool", "toolResult"];

    Arc::new(|messages: &[serde_json::Value]| {
        let mut kept = Vec::new();
        for message in messages {
            let role = message.get("role").and_then(|r| r.as_str());
            if role.is_some_and(|r| LLM_ROLES.contains(&r)) {
                kept.push(message.clone());
            }
        }
        kept
    })
}
/// Builds a stream function that ignores its inputs and yields exactly one
/// `Done` event carrying a single text block with `content_text`.
fn canned_done_stream(content_text: &str) -> StreamFn {
    let reply = String::from(content_text);
    Arc::new(move |_ctx, _opts| {
        let block = ContentBlock::Text {
            text: reply.clone(),
        };
        let done = StreamEvent::Done {
            reason: StopReason::Stop,
            message: AssistantMessage::new(vec![block], StopReason::Stop),
            usage: None,
        };
        Box::pin(futures::stream::iter(vec![done]))
    })
}
fn build_config(
convert: Arc<dyn Fn(&[serde_json::Value]) -> Vec<serde_json::Value> + Send + Sync>,
) -> LoopConfig {
LoopConfig {
convert_to_llm: convert,
transform_context: None,
compaction_hooks: None,
get_api_key: None,
api_key: None,
tool_execution: crate::agent::agent_loop::ToolExecutionMode::Parallel,
before_tool_call: None,
after_tool_call: None,
prepare_next_turn: None,
should_stop_after_turn: None,
get_steering_messages: None,
get_followup_messages: None,
reasoning: None,
thinking_budgets: None,
headers: std::collections::HashMap::new(),
metadata: std::collections::HashMap::new(),
request_timeout: None,
provider_name: None,
model_name: None,
compact_model: None,
storm_mutating_tools: None,
storm_exempt_tools: None,
repair_stats: std::sync::Arc::new(
crate::agent::agent_loop::tool_input_repair::RepairStats::new(),
),
truncation_notes: std::sync::Arc::new(std::sync::Mutex::new(
std::collections::HashMap::new(),
)),
tool_def_filter: None,
dynamic_tool_search: false,
escalation_stream_fn: None,
escalation_provider_name: None,
escalation_pending: std::sync::Arc::new(std::sync::Mutex::new(None)),
escalation_max_per_session: 3,
escalation_remaining: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(3)),
file_touch_tracker: None,
verifier: None,
critic_fn: None,
goal: None,
max_turns: None,
}
}
#[tokio::test]
async fn test_stream_options_threaded_from_loop_config() {
use crate::agent::agent_loop::types::{ThinkingBudgets, ThinkingLevel};
use std::sync::Mutex;
let observed: Arc<Mutex<Option<StreamOptions>>> = Arc::new(Mutex::new(None));
let observed_clone = observed.clone();
let stream_fn: StreamFn = Arc::new(move |_ctx, opts: StreamOptions| {
*observed_clone.lock().unwrap() = Some(opts);
let message = AssistantMessage::new(
vec![ContentBlock::Text {
text: "ok".to_string(),
}],
StopReason::Stop,
);
Box::pin(futures::stream::iter(vec![StreamEvent::Done {
reason: StopReason::Stop,
message,
usage: None,
}]))
});
let mut config = build_config(identity_converter());
config.api_key = Some("static-key".to_string());
config.reasoning = Some(ThinkingLevel::High);
config.thinking_budgets = Some(ThinkingBudgets {
high: Some(8192),
..Default::default()
});
config
.headers
.insert("X-Test".to_string(), "yes".to_string());
config
.metadata
.insert("user_id".to_string(), serde_json::json!("u42"));
config.request_timeout = Some(std::time::Duration::from_secs(120));
let mut ctx = Context {
system_prompt: String::new(),
messages: vec![serde_json::json!({"role": "user", "content": "hi"})],
tools: Vec::new(),
};
let (tx, _rx) = mpsc::channel::<LoopEvent>(8);
let _ =
stream_assistant_response(&mut ctx, &config, AbortSignal::new(), &tx, &stream_fn).await;
let opts = observed.lock().unwrap().clone().expect("opts captured");
assert_eq!(opts.api_key.as_deref(), Some("static-key"));
assert_eq!(opts.reasoning, Some(ThinkingLevel::High));
assert_eq!(
opts.thinking_budgets.as_ref().and_then(|b| b.high),
Some(8192)
);
assert_eq!(opts.headers.get("X-Test").map(String::as_str), Some("yes"));
assert_eq!(
opts.metadata.get("user_id"),
Some(&serde_json::json!("u42")),
);
assert_eq!(
opts.request_timeout,
Some(std::time::Duration::from_secs(120))
);
}
/// A stream that yields only `Done` must still produce `message_start`
/// followed by `message_end`, and append one assistant message to the context.
#[tokio::test]
async fn test_emits_message_start_and_end() {
    let mut ctx = Context {
        system_prompt: "You are helpful.".to_string(),
        messages: vec![serde_json::json!({"role": "user", "content": "Hello"})],
        tools: Vec::new(),
    };
    let config = build_config(identity_converter());
    let (tx, mut rx) = mpsc::channel::<LoopEvent>(32);
    let stream_fn = canned_done_stream("Hi there!");

    let (final_msg, _) =
        stream_assistant_response(&mut ctx, &config, AbortSignal::new(), &tx, &stream_fn).await;
    // Close the channel so the drain loop below terminates.
    drop(tx);

    assert_eq!(final_msg.stop_reason, StopReason::Stop);
    assert_eq!(final_msg.content.len(), 1);

    let mut kinds: Vec<String> = Vec::new();
    while let Some(event) = rx.recv().await {
        kinds.push(event.kind().to_string());
    }
    assert_eq!(kinds, vec!["message_start", "message_end"]);

    let roles: Vec<Option<&str>> = ctx
        .messages
        .iter()
        .map(|m| m.get("role").and_then(|r| r.as_str()))
        .collect();
    assert_eq!(roles, vec![Some("user"), Some("assistant")]);
}
/// The `get_api_key` hook must be called with the configured provider name.
#[tokio::test]
async fn test_get_api_key_receives_provider_name() {
    use std::sync::Mutex;

    let seen_provider: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
    let hook_slot = Arc::clone(&seen_provider);

    let mut config = build_config(identity_converter());
    config.provider_name = Some("anthropic".to_string());
    config.get_api_key = Some(Arc::new(move |provider| {
        // Copy everything the future needs so it owns its data.
        let slot = hook_slot.clone();
        let name = provider.to_string();
        Box::pin(async move {
            *slot.lock().unwrap() = Some(name);
            Some("hook-resolved-key".to_string())
        })
    }));

    let mut ctx = Context {
        system_prompt: String::new(),
        messages: vec![serde_json::json!({"role": "user", "content": "hi"})],
        tools: Vec::new(),
    };
    let (tx, _rx) = mpsc::channel::<LoopEvent>(8);
    let stream_fn = canned_done_stream("ok");
    let _ =
        stream_assistant_response(&mut ctx, &config, AbortSignal::new(), &tx, &stream_fn).await;

    assert_eq!(
        seen_provider.lock().unwrap().as_deref(),
        Some("anthropic"),
        "get_api_key hook should have received 'anthropic'"
    );
}
#[tokio::test]
async fn test_convert_to_llm_filters_custom_messages() {
let mut ctx = Context {
system_prompt: "You are helpful.".to_string(),
messages: vec![
serde_json::json!({"role": "notification", "text": "noisy"}),
serde_json::json!({"role": "user", "content": "Hello"}),
],
tools: Vec::new(),
};
let received = Arc::new(std::sync::Mutex::new(Vec::<serde_json::Value>::new()));
let received_clone = received.clone();
let convert: Arc<dyn Fn(&[serde_json::Value]) -> Vec<serde_json::Value> + Send + Sync> =
Arc::new(move |messages| {
let mut slot = received_clone.lock().unwrap();
*slot = messages.to_vec();
messages
.iter()
.filter(|m| m.get("role").and_then(|r| r.as_str()) != Some("notification"))
.cloned()
.collect()
});
let config = build_config(convert);
let signal = AbortSignal::new();
let (tx, mut rx) = mpsc::channel::<LoopEvent>(32);
let _ = stream_assistant_response(
&mut ctx,
&config,
signal,
&tx,
&canned_done_stream("Response"),
)
.await;
drop(tx);
while rx.recv().await.is_some() {}
let received = received.lock().unwrap();
assert_eq!(received.len(), 2);
let roles: Vec<_> = received
.iter()
.map(|m| m.get("role").and_then(|r| r.as_str()).unwrap_or(""))
.collect();
assert_eq!(roles, vec!["notification", "user"]);
}
#[tokio::test]
async fn test_transform_context_runs_before_convert_to_llm() {
let mut ctx = Context {
system_prompt: "You are helpful.".to_string(),
messages: vec![
serde_json::json!({"role": "user", "content": "old 1"}),
serde_json::json!({"role": "assistant", "content": "resp 1"}),
serde_json::json!({"role": "user", "content": "old 2"}),
serde_json::json!({"role": "assistant", "content": "resp 2"}),
serde_json::json!({"role": "user", "content": "new"}),
],
tools: Vec::new(),
};
let counter = Arc::new(AtomicUsize::new(0));
let transform_order = counter.clone();
let transform: Arc<
dyn Fn(
Vec<serde_json::Value>,
)
-> Pin<Box<dyn std::future::Future<Output = Vec<serde_json::Value>> + Send>>
+ Send
+ Sync,
> = Arc::new(move |messages| {
let order = transform_order.clone();
Box::pin(async move {
let n = order.fetch_add(1, Ordering::SeqCst);
assert_eq!(n, 0, "transform_context must fire before convert_to_llm");
let len = messages.len();
if len <= 2 {
messages
} else {
messages[len - 2..].to_vec()
}
})
});
let convert_order = counter.clone();
let received_convert = Arc::new(std::sync::Mutex::new(Vec::<serde_json::Value>::new()));
let received_clone = received_convert.clone();
let convert: Arc<dyn Fn(&[serde_json::Value]) -> Vec<serde_json::Value> + Send + Sync> =
Arc::new(move |messages| {
let n = convert_order.fetch_add(1, Ordering::SeqCst);
assert_eq!(n, 1, "convert_to_llm must run after transform_context");
*received_clone.lock().unwrap() = messages.to_vec();
messages.to_vec()
});
let config = LoopConfig {
convert_to_llm: convert,
transform_context: Some(transform),
compaction_hooks: None,
get_api_key: None,
api_key: None,
tool_execution: crate::agent::agent_loop::ToolExecutionMode::Parallel,
before_tool_call: None,
after_tool_call: None,
prepare_next_turn: None,
should_stop_after_turn: None,
get_steering_messages: None,
get_followup_messages: None,
reasoning: None,
thinking_budgets: None,
headers: std::collections::HashMap::new(),
metadata: std::collections::HashMap::new(),
request_timeout: None,
provider_name: None,
model_name: None,
compact_model: None,
storm_mutating_tools: None,
storm_exempt_tools: None,
repair_stats: std::sync::Arc::new(
crate::agent::agent_loop::tool_input_repair::RepairStats::new(),
),
truncation_notes: std::sync::Arc::new(std::sync::Mutex::new(
std::collections::HashMap::new(),
)),
tool_def_filter: None,
dynamic_tool_search: false,
escalation_stream_fn: None,
escalation_provider_name: None,
escalation_pending: std::sync::Arc::new(std::sync::Mutex::new(None)),
escalation_max_per_session: 3,
escalation_remaining: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(3)),
file_touch_tracker: None,
verifier: None,
critic_fn: None,
goal: None,
max_turns: None,
};
let signal = AbortSignal::new();
let (tx, mut rx) = mpsc::channel::<LoopEvent>(32);
let _ = stream_assistant_response(
&mut ctx,
&config,
signal,
&tx,
&canned_done_stream("Response"),
)
.await;
drop(tx);
while rx.recv().await.is_some() {}
let received = received_convert.lock().unwrap();
assert_eq!(received.len(), 2, "convert_to_llm should see pruned list");
assert_eq!(counter.load(Ordering::SeqCst), 2);
}
/// A stream that ends without `Done`/`Error` must fall back to an empty
/// `Stop` message and still emit `message_start` + `message_end`.
#[tokio::test]
async fn test_stream_closed_without_terminal_event() {
    let mut ctx = Context {
        system_prompt: String::new(),
        messages: vec![serde_json::json!({"role": "user", "content": "hi"})],
        tools: Vec::new(),
    };
    let config = build_config(identity_converter());
    let (tx, mut rx) = mpsc::channel::<LoopEvent>(32);

    // Stream function whose stream yields no events at all.
    let empty_stream: StreamFn = Arc::new(|_ctx, _opts| {
        let no_events: Vec<StreamEvent> = Vec::new();
        Box::pin(futures::stream::iter(no_events))
    });

    let (final_msg, _) =
        stream_assistant_response(&mut ctx, &config, AbortSignal::new(), &tx, &empty_stream)
            .await;
    drop(tx);

    let mut events = Vec::new();
    while let Some(event) = rx.recv().await {
        events.push(event);
    }

    assert_eq!(final_msg.stop_reason, StopReason::Stop);
    assert_eq!(ctx.messages.len(), 2);
    let kinds: Vec<_> = events.iter().map(|e| e.kind()).collect();
    assert_eq!(
        kinds,
        vec!["message_start", "message_end"],
        "fallback must emit message_start + message_end (pi 363-366)",
    );
}
/// Builds a stream function that appends `label` to `observed` each time it
/// is invoked and then yields a single `Done` event whose text is
/// `"{label}-response"`.
fn labelled_stream(
    label: &'static str,
    observed: Arc<std::sync::Mutex<Vec<&'static str>>>,
) -> StreamFn {
    Arc::new(move |_ctx, _opts| {
        // Record the invocation before producing any event.
        {
            let mut calls = observed.lock().unwrap();
            calls.push(label);
        }
        let text = format!("{label}-response");
        let reply = AssistantMessage::new(vec![ContentBlock::Text { text }], StopReason::Stop);
        let events = vec![StreamEvent::Done {
            reason: StopReason::Stop,
            message: reply,
            usage: None,
        }];
        Box::pin(futures::stream::iter(events))
    })
}
/// With a pending reason and an escalation stream function configured, the
/// escalation function must serve the call instead of the default one.
#[tokio::test]
async fn escalation_arm_then_swap_uses_alternate_stream_fn() {
    use crate::agent::agent_loop::message::EscalationReason;

    let calls = Arc::new(std::sync::Mutex::new(Vec::new()));
    let default_fn = labelled_stream("default", Arc::clone(&calls));

    let mut config = build_config(identity_converter());
    config.escalation_stream_fn = Some(labelled_stream("escalation", Arc::clone(&calls)));
    config.escalation_provider_name = Some("alt-provider".to_string());
    {
        // Arm the escalation for the next call.
        let mut pending = config.escalation_pending.lock().unwrap();
        *pending = Some(EscalationReason::RepairExhausted {
            tool: "write".to_string(),
        });
    }

    let mut ctx = Context::default();
    let (tx, _rx) = mpsc::channel::<LoopEvent>(32);
    let _ =
        stream_assistant_response(&mut ctx, &config, AbortSignal::new(), &tx, &default_fn).await;

    assert_eq!(calls.lock().unwrap().as_slice(), &["escalation"]);
}
/// A pending escalation is one-shot: the first call uses the escalation
/// function, the second falls back to the default, and the pending slot ends
/// up empty.
#[tokio::test]
async fn escalation_flag_cleared_after_one_call() {
    use crate::agent::agent_loop::message::EscalationReason;

    let calls = Arc::new(std::sync::Mutex::new(Vec::new()));
    let default_fn = labelled_stream("default", Arc::clone(&calls));

    let mut config = build_config(identity_converter());
    config.escalation_stream_fn = Some(labelled_stream("escalation", Arc::clone(&calls)));
    config.escalation_provider_name = Some("alt-provider".to_string());
    {
        let mut pending = config.escalation_pending.lock().unwrap();
        *pending = Some(EscalationReason::SyntacticFailure {
            tool: "edit".to_string(),
            path: "src/foo.rs".to_string(),
        });
    }

    let mut ctx = Context::default();
    let (tx, _rx) = mpsc::channel::<LoopEvent>(32);
    // First call consumes the pending reason; the second must not escalate.
    let _ =
        stream_assistant_response(&mut ctx, &config, AbortSignal::new(), &tx, &default_fn).await;
    let _ =
        stream_assistant_response(&mut ctx, &config, AbortSignal::new(), &tx, &default_fn).await;

    assert_eq!(
        calls.lock().unwrap().as_slice(),
        &["escalation", "default"]
    );
    assert!(config.escalation_pending.lock().unwrap().is_none());
}
/// Without an escalation stream function, a pending reason must not change
/// which function is called, but it is still consumed.
#[tokio::test]
async fn escalation_no_op_when_alternate_is_none() {
    use crate::agent::agent_loop::message::EscalationReason;

    let calls = Arc::new(std::sync::Mutex::new(Vec::new()));
    let default_fn = labelled_stream("default", Arc::clone(&calls));

    let config = build_config(identity_converter());
    {
        let mut pending = config.escalation_pending.lock().unwrap();
        *pending = Some(EscalationReason::RepairExhausted {
            tool: "write".to_string(),
        });
    }

    let mut ctx = Context::default();
    let (tx, _rx) = mpsc::channel::<LoopEvent>(32);
    let _ =
        stream_assistant_response(&mut ctx, &config, AbortSignal::new(), &tx, &default_fn).await;

    assert_eq!(calls.lock().unwrap().as_slice(), &["default"]);
    assert!(config.escalation_pending.lock().unwrap().is_none());
}
/// With a budget of 2, five arming attempts must leave the remaining budget
/// at exactly 0.
#[tokio::test]
async fn escalation_max_per_session_caps_arming() {
    use crate::agent::agent_loop::message::EscalationReason;
    use crate::agent::agent_loop::tools::try_arm_escalation;
    use std::sync::atomic::Ordering;

    let mut config = build_config(identity_converter());
    config.escalation_max_per_session = 2;
    config.escalation_remaining.store(2, Ordering::SeqCst);

    let mut attempt = 0;
    while attempt < 5 {
        let reason = EscalationReason::RepairExhausted {
            tool: "write".to_string(),
        };
        try_arm_escalation(&config, reason);
        // Reset the pending slot by hand between attempts.
        // NOTE(review): presumably so a still-pending reason cannot affect the
        // next arming attempt — confirm against `try_arm_escalation`.
        *config.escalation_pending.lock().unwrap() = None;
        attempt += 1;
    }

    let remaining = config.escalation_remaining.load(Ordering::SeqCst);
    assert_eq!(remaining, 0, "budget exhausted exactly twice");
}
/// An escalated call must emit `EscalationActivated` carrying the configured
/// provider name and the reason that was pending.
#[tokio::test]
async fn escalation_event_emitted() {
    use crate::agent::agent_loop::message::EscalationReason;

    let calls = Arc::new(std::sync::Mutex::new(Vec::new()));
    let default_fn = labelled_stream("default", Arc::clone(&calls));

    let mut config = build_config(identity_converter());
    config.escalation_stream_fn = Some(labelled_stream("escalation", Arc::clone(&calls)));
    config.escalation_provider_name = Some("anthropic-pro".to_string());
    {
        let mut pending = config.escalation_pending.lock().unwrap();
        *pending = Some(EscalationReason::SyntacticFailure {
            tool: "write".to_string(),
            path: "lib.rs".to_string(),
        });
    }

    let mut ctx = Context::default();
    let (tx, mut rx) = mpsc::channel::<LoopEvent>(64);
    let _ =
        stream_assistant_response(&mut ctx, &config, AbortSignal::new(), &tx, &default_fn).await;
    // Close the channel so the drain loop below terminates.
    drop(tx);

    let mut saw_escalation = false;
    while let Some(evt) = rx.recv().await {
        let LoopEvent::EscalationActivated { provider, reason } = &evt else {
            continue;
        };
        assert_eq!(provider, "anthropic-pro");
        assert!(matches!(reason, EscalationReason::SyntacticFailure { .. }));
        saw_escalation = true;
    }
    assert!(saw_escalation, "expected EscalationActivated event");
}
}