// klieo-core 0.6.0
//
// Core traits + runtime for the klieo agent framework.
// Documentation
//! Streaming agent loop: [`run_steps_streaming`] + helpers.

use crate::agent::AgentContext;
use crate::error::{Error, LlmError};
use crate::ids::ThreadId;
use crate::llm::{ChatChunk, ChatRequest, ChatResponse, ChunkStream, Usage};
use crate::memory::Episode;
use std::time::Instant;
use tracing::{error, instrument};

use super::build_request;
use super::guardrails::{check_post_llm, check_pre_llm};
use super::options::RunOptions;
use super::retry::stream_with_retry;
use super::step::{record_and_dispatch, StepDisposition, StepOutcome};
use super::streaming_forward::forward_chunks;

/// Drive the agent's LLM/tool loop the same way [`super::run_steps`]
/// does, but expose the underlying provider chunks as they arrive.
///
/// The returned [`ChunkStream`] yields every chunk produced by the
/// underlying provider stream(s), in order, across one or more LLM
/// cycles (a tool-call cycle starts a fresh provider stream after
/// dispatching the requested tools).
///
/// # Errors
///
/// The following are returned synchronously as `Err` from this function,
/// before any background task is spawned, so callers writing
/// `let s = run_steps_streaming(...).await?;` keep working:
/// - recording `Episode::Started` fails;
/// - `ctx.cancel` is already cancelled → [`Error::Cancelled`];
/// - `opts.max_steps == 0` → [`Error::MaxStepsExceeded`];
/// - building the first request, or the pre-LLM guardrail check on it,
///   fails;
/// - cycle-1 stream initiation fails once `stream_with_retry` gives up
///   (non-retryable error, or retries exhausted).
///
/// Except for the first case, these early returns happen *after*
/// `Episode::Started` has been recorded. This function records no
/// `Episode::Failed` for them.
///
/// **Terminal-error contract.** Once the stream has been returned, any
/// non-`Stop` terminal outcome makes the output stream emit exactly one
/// final `Err(_)` item before closing. Consumers that iterate until
/// `None` therefore always observe the cause:
/// - Cancellation (`ctx.cancel` fires mid-stream) → `LlmError::Cancelled`.
/// - Consumer-dropped (caller stopped reading the stream) → no error
///   item is observed by definition; the spawned driver records
///   `Episode::Failed { error: "consumer-dropped" }` and exits.
/// - Max-steps exceeded (after at least one cycle) →
///   `LlmError::Server("max steps exceeded: N")`.
/// - Tool failure (non-retryable) → `LlmError::Server("tool: ...")`.
/// - Refusal (`Error::Refused`) → `LlmError::Server("refused: ...")`.
/// - Handoff (`Error::Handoff`) →
///   `LlmError::Server("handoff to AGENT: REASON")`.
/// - Mid-stream provider error → the original `LlmError` variant is
///   preserved end-to-end (single terminal item, no double-emit).
/// - Response exceeds `RunOptions.max_response_bytes` → terminal
///   `LlmError::Server("response exceeded max_response_bytes cap")`.
/// - Cycle-2+ init error → arrives in-band as a terminal `Err` chunk.
/// - Any other [`Error`] → `LlmError::Server(error.to_string())`.
///
/// Caveat: if the spawned driver task panics, its sender is dropped
/// and the stream simply ends (`None`) without a terminal `Err` item.
///
/// **Episode logging.** Mirrors [`super::run_steps`] for the
/// steady-state path, with these caveats:
/// 1. `Episode::LlmCall.tokens` is recorded as `0` until streaming
///    token-usage plumbing lands. This is a deferred follow-up: some
///    providers emit usage in the terminal chunk's `finish_reason`
///    carrier and others do not.
/// 2. Mid-stream cancellation records
///    `Episode::Failed { error: "cancelled" }`. Consumer-dropped records
///    `Episode::Failed { error: "consumer-dropped" }`. Max-steps records
///    `Episode::Failed { error: "max steps exceeded: N" }`.
/// 3. The driver closes the output stream *before* writing
///    `Episode::Failed`. A consumer that reads the episodic store as soon
///    as the stream ends may not see that entry yet.
///
/// **Guardrails.** Pre-LLM guardrails run before every cycle's request.
/// Post-LLM guardrails run only after the cycle's chunks have already
/// been forwarded. A post-LLM rejection therefore ends the run but
/// cannot retract content the consumer has already received.
///
/// **Back-pressure contract.** Slow consumers cause the spawned driver
/// to await on the bounded channel (capacity 16). This propagates
/// back-pressure to the provider stream. That is typically desirable,
/// but consumers should drain promptly. The provider-side timeout is
/// governed by `ChatRequest.timeout`.
///
/// **Stream-initiation retry.** Mirrors `complete_with_retry`: a
/// retryable [`LlmError`] from `llm.stream()` triggers exponential
/// backoff up to `MAX_LLM_RETRIES` times. The first cycle's init
/// outcome is awaited synchronously, so non-retryable errors
/// (e.g. `Unauthorized`) propagate via this function's `Err` return
/// before any background task is spawned. Once a provider stream is
/// open, mid-stream errors are forwarded to the caller unchanged.
/// Replay or retry mid-stream is out of scope here.
#[instrument(level = "debug", skip(ctx, system_prompt), fields(run_id = %ctx.run_id))]
pub async fn run_steps_streaming(
    ctx: &AgentContext,
    system_prompt: &str,
    thread: ThreadId,
    opts: RunOptions,
) -> Result<ChunkStream, Error> {
    // Bound on chunks buffered between the driver task and the consumer;
    // see the back-pressure contract above.
    const CHANNEL_CAPACITY: usize = 16;

    ctx.episodic
        .record(
            ctx.run_id,
            Episode::Started {
                agent: ctx.agent_name.clone(),
            },
        )
        .await?;

    if ctx.cancel.is_cancelled() {
        return Err(Error::Cancelled);
    }
    if opts.max_steps == 0 {
        return Err(Error::MaxStepsExceeded {
            steps: opts.max_steps,
        });
    }

    // Open cycle 1 synchronously so its failures surface as this
    // function's `Err` instead of an in-band chunk.
    let req = build_request(ctx, system_prompt, &thread, opts.max_history_tokens).await?;
    check_pre_llm(&opts.guardrails, &req).await?;
    let started = Instant::now();
    let provider_stream = stream_with_retry(ctx.llm.as_ref(), &ctx.cancel, req.clone()).await?;

    let (tx, rx) = tokio::sync::mpsc::channel::<Result<ChatChunk, LlmError>>(CHANNEL_CAPACITY);

    // The driver outlives this call, so it needs owned copies.
    let ctx_owned = ctx.clone();
    let system_prompt = system_prompt.to_owned();

    tokio::spawn(async move {
        let outcome = drive_streaming_loop(
            &ctx_owned,
            &system_prompt,
            thread,
            opts,
            tx.clone(),
            provider_stream,
            started,
            req,
        )
        .await;
        // On `Ok`, dropping `tx` at the end of this task closes the stream.
        if let Err(e) = outcome {
            // Sample both signals now; `terminal_chunk_for` uses the cancel
            // flag to tell a `ctx.cancel` abort from a dropped consumer.
            let consumer_gone = tx.is_closed();
            let cancel_observed = ctx_owned.cancel.is_cancelled();
            let (terminal, episode_text) = terminal_chunk_for(e, cancel_observed);
            if !consumer_gone {
                // Best effort: the consumer may still drop before receiving it.
                let _ = tx.send(Err(terminal)).await;
            }
            // Close the stream before the episodic write, presumably so the
            // consumer sees end-of-stream without waiting on the store.
            drop(tx);
            // Best effort: there is no caller left to report a logging failure to.
            let _ = ctx_owned
                .episodic
                .record(
                    ctx_owned.run_id,
                    Episode::Failed {
                        error: episode_text,
                    },
                )
                .await;
        }
    });

    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
    Ok(Box::pin(stream) as ChunkStream)
}

/// Body of the streaming agent loop, run on the task spawned by
/// [`run_steps_streaming`]. Lives in its own function so the entry point
/// stays small and the loop's `?` early-exits map cleanly onto the
/// spawned task's outcome.
///
/// Cycle 1 reuses the provider stream, start time and request that
/// [`run_steps_streaming`] already opened synchronously. Every later
/// cycle does three things:
/// - rebuilds the request from `thread`;
/// - re-runs the pre-LLM guardrails;
/// - opens a fresh stream via `stream_with_retry`.
///
/// Each cycle's chunks are forwarded to `tx` by `forward_chunks` *before*
/// the post-LLM guardrails run. A post-LLM rejection therefore ends the
/// run but cannot retract chunks the consumer already received. Token
/// usage is reported as [`Usage::default()`] because streaming usage is
/// not plumbed through yet.
///
/// # Errors
///
/// Returns the first error hit:
/// - cancellation observed at the top of a cycle ([`Error::Cancelled`]);
/// - the step budget running out ([`Error::MaxStepsExceeded`]);
/// - any error from request building, guardrails, stream initiation,
///   chunk forwarding, or `record_and_dispatch`.
///
/// `Ok(())` means a cycle finished with [`StepDisposition::Done`].
// The entry point hands over the whole pre-opened cycle-1 state, hence
// the long parameter list.
#[allow(clippy::too_many_arguments)]
async fn drive_streaming_loop(
    ctx: &AgentContext,
    system_prompt: &str,
    thread: ThreadId,
    opts: RunOptions,
    tx: tokio::sync::mpsc::Sender<Result<ChatChunk, LlmError>>,
    first_stream: ChunkStream,
    first_started: Instant,
    first_req: ChatRequest,
) -> Result<(), Error> {
    // Cycle-1 state travels as one unit: either all three pieces are
    // still pending or all have been consumed. No partial state exists.
    let mut first_cycle = Some((first_stream, first_started, first_req));
    let mut step = 0u32;
    loop {
        if ctx.cancel.is_cancelled() {
            error!("cancelled");
            return Err(Error::Cancelled);
        }
        if step >= opts.max_steps {
            return Err(Error::MaxStepsExceeded {
                steps: opts.max_steps,
            });
        }
        step += 1;

        let (provider_stream, started, current_req) = match first_cycle.take() {
            Some(cycle) => cycle,
            None => {
                let req =
                    build_request(ctx, system_prompt, &thread, opts.max_history_tokens).await?;
                check_pre_llm(&opts.guardrails, &req).await?;
                let started = Instant::now();
                let s = stream_with_retry(ctx.llm.as_ref(), &ctx.cancel, req.clone()).await?;
                (s, started, req)
            }
        };

        let (assistant_msg, finish_reason) =
            forward_chunks(provider_stream, &tx, &ctx.cancel, opts.max_response_bytes).await?;
        // Saturate instead of truncating: `as u32` would silently wrap the
        // u128 millisecond count.
        let latency_ms = u32::try_from(started.elapsed().as_millis()).unwrap_or(u32::MAX);

        if !opts.guardrails.is_empty() {
            // Post-LLM guardrails take a full response; rebuild one from
            // the streamed message (usage is not tracked while streaming).
            let synth_resp = ChatResponse {
                message: assistant_msg.clone(),
                usage: Usage::default(),
                finish_reason,
            };
            check_post_llm(&opts.guardrails, &current_req, &synth_resp).await?;
        }

        let outcome = StepOutcome {
            message: assistant_msg,
            finish_reason,
            usage: Usage::default(),
            latency_ms,
        };
        match record_and_dispatch(ctx, &thread, step, outcome, "streaming").await? {
            StepDisposition::Done(_) => return Ok(()),
            StepDisposition::Continue => continue,
        }
    }
}

/// Translate a driver-loop terminal [`Error`] into the
/// `(LlmError chunk, Episode::Failed message)` pair the spawn wrapper
/// forwards. `cancel_observed` distinguishes a `ctx.cancel`-driven
/// abort (`"cancelled"`) from a consumer-dropped channel
/// (`"consumer-dropped"`); both look identical at the [`Error`] level.
fn terminal_chunk_for(error: Error, cancel_observed: bool) -> (LlmError, String) {
    match error {
        Error::Cancelled if cancel_observed => (LlmError::Cancelled, "cancelled".to_string()),
        Error::Cancelled => (LlmError::Cancelled, "consumer-dropped".to_string()),
        Error::MaxStepsExceeded { steps } => {
            let message = format!("max steps exceeded: {steps}");
            (LlmError::Server(message.clone()), message)
        }
        Error::Tool(tool_error) => {
            let message = format!("tool: {tool_error}");
            (LlmError::Server(message.clone()), message)
        }
        Error::Llm(llm_error) => {
            let message = llm_error.to_string();
            (llm_error, message)
        }
        Error::Refused { reason } => {
            let message = format!("refused: {reason}");
            (LlmError::Server(message.clone()), message)
        }
        Error::Handoff { agent, reason } => {
            let message = format!("handoff to {agent}: {reason}");
            (LlmError::Server(message.clone()), message)
        }
        other => {
            let message = other.to_string();
            (LlmError::Server(message.clone()), message)
        }
    }
}

// Unit tests for this module are kept in a separate file,
// `streaming_tests.rs`, and compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "streaming_tests.rs"]
mod tests;