klieo-core 0.6.0

Core traits + runtime for the klieo agent framework.
Documentation
//! Shared tool-dispatch helper used by both the blocking and streaming
//! runtime drivers.
//!
//! Both [`super::run_steps`] (blocking) and the streaming driver's loop
//! (`drive_streaming_loop` inside `super`) previously carried
//! byte-identical tool-dispatch arms. Extracting the shared loop here
//! removes the risk of the two copies drifting apart: a bugfix lands in
//! one place.

use crate::agent::AgentContext;
use crate::error::Error;
use crate::ids::ThreadId;
use crate::llm::{Message, Role, ToolCall};
use crate::memory::{Episode, ToolResult};
use tracing::debug;

/// Invoke every tool in `tool_calls` against `ctx.tools`, recording the
/// episodic [`Episode::ToolCall`] and appending a `Role::Tool` message
/// to short-term memory per call.
///
/// On a permanent (non-retryable) tool failure, returns `Err(Error::Tool)`
/// — the runtime aborts the run. Retryable failures are swallowed at
/// this layer (the error message has already been written to short-term
/// memory so the LLM observes it on the next cycle); upstream
/// backoff/retry policy belongs to wrappers (e.g. tool-level retry
/// adapters in `klieo-tools-mcp`).
///
/// `driver` is included in the trace span so log readers can tell the
/// blocking and streaming paths apart in a single combined log.
/// Invoke every tool in `tool_calls` against `ctx.tools`, in order.
///
/// For each call this:
/// 1. invokes the tool with a fresh [`crate::tool::ToolCtx`] built from
///    `ctx`'s pubsub / kv / jobs handles;
/// 2. records an episodic [`Episode::ToolCall`] carrying the call's name,
///    arguments and [`ToolResult`];
/// 3. appends a `Role::Tool` message tagged with the call's id to
///    short-term memory on `thread`. The message holds the value's
///    `to_string()` rendering on success, or `"error: {e}"` on failure.
///
/// Retryable tool failures are swallowed at this layer. The error text has
/// already been written to short-term memory, so the LLM observes it on the
/// next cycle. Backoff/retry policy belongs to wrappers (e.g. tool-level
/// retry adapters in `klieo-tools-mcp`).
///
/// `driver` is attached as a field to every `debug!` event this function
/// emits. That lets log readers tell the blocking and streaming paths apart
/// in a single combined log, including on the failure paths.
///
/// # Errors
///
/// - `Error::Tool` on the first permanent (non-retryable) tool failure. That
///   call's episode and short-term message are persisted before returning.
///   Later calls in `tool_calls` are not invoked, so the runtime aborts the
///   run.
/// - Any error from `ctx.episodic.record` or `ctx.short_term.append`,
///   propagated via `?`.
pub(crate) async fn dispatch_tool_calls(
    ctx: &AgentContext,
    thread: &ThreadId,
    tool_calls: &[ToolCall],
    driver: &'static str,
) -> Result<(), Error> {
    for call in tool_calls {
        // `invoke` takes the tool context by value, so build one per call.
        let tool_ctx = crate::tool::ToolCtx {
            pubsub: ctx.pubsub.clone(),
            kv: ctx.kv.clone(),
            jobs: ctx.jobs.clone(),
        };
        let outcome = ctx
            .tools
            .invoke(&call.name, call.args.clone(), tool_ctx)
            .await;
        let (result_for_log, tool_msg_content) = match &outcome {
            Ok(v) => (ToolResult::Ok { value: v.clone() }, v.to_string()),
            Err(e) => (
                ToolResult::Err {
                    message: e.to_string(),
                },
                format!("error: {e}"),
            ),
        };
        ctx.episodic
            .record(
                ctx.run_id,
                Episode::ToolCall {
                    name: call.name.clone(),
                    args: call.args.clone(),
                    result: result_for_log,
                },
            )
            .await?;
        ctx.short_term
            .append(
                thread.clone(),
                Message {
                    role: Role::Tool,
                    content: tool_msg_content,
                    tool_calls: vec![],
                    tool_call_id: Some(call.id.clone()),
                },
            )
            .await?;
        // Persist first, then decide: the failing call is always visible in
        // both episodic and short-term memory, even when we abort.
        if let Err(e) = outcome {
            if !e.retryable() {
                debug!(
                    tool = %call.name,
                    error = %e,
                    driver,
                    "non-retryable tool error; aborting dispatch"
                );
                return Err(Error::Tool(e));
            }
            debug!(
                tool = %call.name,
                error = %e,
                driver,
                "retryable tool error surfaced to the LLM"
            );
        }
    }
    debug!(n = tool_calls.len(), driver, "dispatched tools");
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentContext;
    use crate::error::ToolError;
    use crate::ids::{RunId, ThreadId};
    use crate::memory::EpisodicMemory;
    use crate::test_utils::{
        noop_bus, FakeLlmClient, FakeToolInvoker, InMemoryEpisodic, InMemoryLongTerm,
        InMemoryShortTerm,
    };
    use std::sync::Arc;
    use tokio_util::sync::CancellationToken;

    /// Wires an `AgentContext` around the given fake invoker using in-memory
    /// stores and a no-op bus. Also hands back the concrete episodic store so
    /// tests can replay what was recorded.
    fn make_ctx(tools: Arc<FakeToolInvoker>) -> (AgentContext, Arc<InMemoryEpisodic>) {
        let (pubsub, request_reply, kv, jobs) = noop_bus();
        let episodic = Arc::new(InMemoryEpisodic::default());
        let ctx = AgentContext {
            llm: Arc::new(FakeLlmClient::default()),
            short_term: Arc::new(InMemoryShortTerm::default()),
            long_term: Arc::new(InMemoryLongTerm::default()),
            episodic: episodic.clone(),
            pubsub,
            kv,
            request_reply,
            jobs,
            tools,
            run_id: RunId::new(),
            cancel: CancellationToken::new(),
            agent_name: "dispatch-test".into(),
        };
        (ctx, episodic)
    }

    /// Builds a `ToolCall` with the given id/name and empty JSON-object args.
    fn tool_call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            args: serde_json::json!({}),
        }
    }

    /// Replays the run's episodes and counts the `Episode::ToolCall` entries.
    async fn tool_episode_count(ctx: &AgentContext, episodic: &InMemoryEpisodic) -> usize {
        let recorded = episodic.replay(ctx.run_id).await.unwrap();
        recorded
            .iter()
            .filter(|ep| matches!(ep, Episode::ToolCall { .. }))
            .count()
    }

    /// Loads `thread`'s short-term history, keeping only `Role::Tool` messages.
    async fn tool_messages(ctx: &AgentContext, thread: ThreadId) -> Vec<Message> {
        let history = ctx.short_term.load(thread, 1024).await.unwrap();
        history
            .into_iter()
            .filter(|msg| msg.role == Role::Tool)
            .collect()
    }

    #[tokio::test]
    async fn records_one_episode_and_one_short_term_message_per_successful_call() {
        let invoker = FakeToolInvoker::new()
            .with_tool("ok_a", "", |_| Ok(serde_json::json!("a")))
            .with_tool("ok_b", "", |_| Ok(serde_json::json!("b")))
            .with_tool("ok_c", "", |_| Ok(serde_json::json!("c")));
        let (ctx, episodic) = make_ctx(Arc::new(invoker));
        let thread = ThreadId::new("t-1");
        let calls = [
            tool_call("c-1", "ok_a"),
            tool_call("c-2", "ok_b"),
            tool_call("c-3", "ok_c"),
        ];

        dispatch_tool_calls(&ctx, &thread, &calls, "blocking")
            .await
            .expect("all-Ok dispatch must succeed");

        assert_eq!(tool_episode_count(&ctx, &episodic).await, 3);
        let tool_msgs = tool_messages(&ctx, thread).await;
        assert_eq!(tool_msgs.len(), 3);
        let call_ids: Vec<&str> = tool_msgs
            .iter()
            .filter_map(|msg| msg.tool_call_id.as_deref())
            .collect();
        assert_eq!(call_ids, ["c-1", "c-2", "c-3"]);
    }

    #[tokio::test]
    async fn non_retryable_tool_error_aborts_after_persisting_that_call() {
        let invoker = FakeToolInvoker::new()
            .with_tool("ok", "", |_| Ok(serde_json::json!("ok-value")))
            .with_tool("boom", "", |_| Err(ToolError::Permanent("nope".into())))
            .with_tool("never_runs", "", |_| Ok(serde_json::json!("x")));
        let (ctx, episodic) = make_ctx(Arc::new(invoker));
        let thread = ThreadId::new("t-abort");
        let calls = [
            tool_call("c-1", "ok"),
            tool_call("c-2", "boom"),
            tool_call("c-3", "never_runs"),
        ];

        let err = dispatch_tool_calls(&ctx, &thread, &calls, "blocking")
            .await
            .expect_err("non-retryable must abort");
        assert!(
            matches!(&err, Error::Tool(ToolError::Permanent(m)) if m == "nope"),
            "expected Tool(Permanent), got {err:?}"
        );
        assert_eq!(
            tool_episode_count(&ctx, &episodic).await,
            2,
            "only ok + boom should have been recorded; never_runs aborts"
        );
        let tool_msgs = tool_messages(&ctx, thread).await;
        assert_eq!(tool_msgs.len(), 2);
        assert!(tool_msgs[1].content.contains("error: permanent: nope"));
    }

    #[tokio::test]
    async fn retryable_tool_error_is_swallowed_and_loop_continues() {
        let invoker = FakeToolInvoker::new()
            .with_tool("flaky", "", |_| {
                Err(ToolError::Retryable {
                    message: "transient".into(),
                    retry_after_secs: 1,
                })
            })
            .with_tool("after", "", |_| Ok(serde_json::json!("after-value")));
        let (ctx, episodic) = make_ctx(Arc::new(invoker));
        let thread = ThreadId::new("t-retry");
        let calls = [tool_call("c-1", "flaky"), tool_call("c-2", "after")];

        dispatch_tool_calls(&ctx, &thread, &calls, "streaming")
            .await
            .expect("retryable error must not abort the loop");

        assert_eq!(tool_episode_count(&ctx, &episodic).await, 2);
        let tool_msgs = tool_messages(&ctx, thread).await;
        assert_eq!(tool_msgs.len(), 2);
        assert!(tool_msgs[0].content.contains("error: retryable"));
        assert_eq!(tool_msgs[1].content, "\"after-value\"");
    }
}