use crate::agent::AgentContext;
use crate::error::{Error, LlmError};
use crate::ids::ThreadId;
use crate::llm::{ChatChunk, ChatRequest, ChatResponse, ChunkStream, Usage};
use crate::memory::Episode;
use std::time::Instant;
use tracing::{error, instrument};
use super::build_request;
use super::guardrails::{check_post_llm, check_pre_llm};
use super::options::RunOptions;
use super::retry::stream_with_retry;
use super::step::{record_and_dispatch, StepDisposition, StepOutcome};
use super::streaming_forward::forward_chunks;
#[instrument(level = "debug", skip(ctx, system_prompt), fields(run_id = %ctx.run_id))]
pub async fn run_steps_streaming(
ctx: &AgentContext,
system_prompt: &str,
thread: ThreadId,
opts: RunOptions,
) -> Result<ChunkStream, Error> {
ctx.episodic
.record(
ctx.run_id,
Episode::Started {
agent: ctx.agent_name.clone(),
},
)
.await?;
if ctx.cancel.is_cancelled() {
return Err(Error::Cancelled);
}
if opts.max_steps == 0 {
return Err(Error::MaxStepsExceeded {
steps: opts.max_steps,
});
}
let req = build_request(ctx, system_prompt, &thread, opts.max_history_tokens).await?;
check_pre_llm(&opts.guardrails, &req).await?;
let started = Instant::now();
let provider_stream = stream_with_retry(ctx.llm.as_ref(), &ctx.cancel, req.clone()).await?;
let (tx, rx) = tokio::sync::mpsc::channel::<Result<ChatChunk, LlmError>>(16);
let ctx_owned = ctx.clone();
let system_prompt = system_prompt.to_owned();
tokio::spawn(async move {
let outcome = drive_streaming_loop(
&ctx_owned,
&system_prompt,
thread,
opts,
tx.clone(),
provider_stream,
started,
req,
)
.await;
if let Err(e) = outcome {
let consumer_gone = tx.is_closed();
let cancel_observed = ctx_owned.cancel.is_cancelled();
let (terminal, episode_text) = terminal_chunk_for(e, cancel_observed);
if !consumer_gone {
let _ = tx.send(Err(terminal)).await;
}
drop(tx);
let _ = ctx_owned
.episodic
.record(
ctx_owned.run_id,
Episode::Failed {
error: episode_text,
},
)
.await;
}
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
Ok(Box::pin(stream) as ChunkStream)
}
/// Drives the multi-step streaming loop until the agent finishes or an error occurs.
///
/// The first step consumes the provider stream, start time, and request that
/// [`run_steps_streaming`] already prepared. Each later step rebuilds the request from
/// `thread` with `build_request`, re-runs pre-LLM guardrails, and opens a new stream via
/// `stream_with_retry`. Every step forwards chunks to `tx` through `forward_chunks`. When
/// any guardrails are configured, it also runs post-LLM guardrails against the assembled
/// assistant message. It then hands the step outcome to `record_and_dispatch`, whose
/// `StepDisposition` decides whether to stop (`Done`) or run another step (`Continue`).
///
/// Usage is reported as `Usage::default()`, because this function extracts no token
/// counts from the forwarded chunks.
///
/// # Errors
/// Returns `Error::Cancelled` when `ctx.cancel` is observed at the top of a step and
/// `Error::MaxStepsExceeded` once `opts.max_steps` steps have already run. It also
/// propagates any error from request building, guardrails, streaming, chunk forwarding,
/// or step dispatch.
// Each argument is distinct state handed over from `run_steps_streaming`'s spawned task;
// bundling them into a one-off struct would only move the same list elsewhere.
#[allow(clippy::too_many_arguments)]
async fn drive_streaming_loop(
    ctx: &AgentContext,
    system_prompt: &str,
    thread: ThreadId,
    opts: RunOptions,
    tx: tokio::sync::mpsc::Sender<Result<ChatChunk, LlmError>>,
    first_stream: ChunkStream,
    first_started: Instant,
    first_req: ChatRequest,
) -> Result<(), Error> {
    // The caller already opened the first step's stream; it is consumed exactly once.
    let mut primed = Some((first_stream, first_started, first_req));
    let mut step = 0u32;
    loop {
        if ctx.cancel.is_cancelled() {
            error!("cancelled");
            return Err(Error::Cancelled);
        }
        if step >= opts.max_steps {
            return Err(Error::MaxStepsExceeded {
                steps: opts.max_steps,
            });
        }
        step += 1;
        let (provider_stream, started, current_req) = match primed.take() {
            Some(first) => first,
            None => {
                let req =
                    build_request(ctx, system_prompt, &thread, opts.max_history_tokens).await?;
                check_pre_llm(&opts.guardrails, &req).await?;
                let started = Instant::now();
                let stream =
                    stream_with_retry(ctx.llm.as_ref(), &ctx.cancel, req.clone()).await?;
                (stream, started, req)
            }
        };
        let (assistant_msg, finish_reason) =
            forward_chunks(provider_stream, &tx, &ctx.cancel, opts.max_response_bytes).await?;
        // Saturate rather than truncate: `as u32` on the u128 millisecond count would wrap.
        let latency_ms = u32::try_from(started.elapsed().as_millis()).unwrap_or(u32::MAX);
        if !opts.guardrails.is_empty() {
            // Post-LLM guardrails take a full response; the streaming path only has the
            // assembled message, so synthesize one (without usage data).
            let synth_resp = ChatResponse {
                message: assistant_msg.clone(),
                usage: Usage::default(),
                finish_reason,
            };
            check_post_llm(&opts.guardrails, &current_req, &synth_resp).await?;
        }
        let outcome = StepOutcome {
            message: assistant_msg,
            finish_reason,
            usage: Usage::default(),
            latency_ms,
        };
        match record_and_dispatch(ctx, &thread, step, outcome, "streaming").await? {
            StepDisposition::Done(_) => return Ok(()),
            StepDisposition::Continue => continue,
        }
    }
}
fn terminal_chunk_for(error: Error, cancel_observed: bool) -> (LlmError, String) {
match error {
Error::Cancelled if cancel_observed => (LlmError::Cancelled, "cancelled".to_string()),
Error::Cancelled => (LlmError::Cancelled, "consumer-dropped".to_string()),
Error::MaxStepsExceeded { steps } => {
let message = format!("max steps exceeded: {steps}");
(LlmError::Server(message.clone()), message)
}
Error::Tool(tool_error) => {
let message = format!("tool: {tool_error}");
(LlmError::Server(message.clone()), message)
}
Error::Llm(llm_error) => {
let message = llm_error.to_string();
(llm_error, message)
}
Error::Refused { reason } => {
let message = format!("refused: {reason}");
(LlmError::Server(message.clone()), message)
}
Error::Handoff { agent, reason } => {
let message = format!("handoff to {agent}: {reason}");
(LlmError::Server(message.clone()), message)
}
other => {
let message = other.to_string();
(LlmError::Server(message.clone()), message)
}
}
}
#[cfg(test)]
#[path = "streaming_tests.rs"]
mod tests;