polyc-llm 0.1.3

Provider-agnostic LLM trait + wire types for polychrome.
Documentation
//! Turn helpers: a [`StubProvider`] for wiring/tests and [`collect_turn`],
//! which folds a provider's [`Chunk`] stream into a single [`TurnOutput`].
//!
//! `collect_turn` is the output half of the bridge between this crate's
//! streaming vocabulary and the message-granular wire types: the harness drains
//! a provider stream into a `TurnOutput`, then maps that to wire `Message`s.

use async_trait::async_trait;
use futures::{Stream, StreamExt, stream};

use crate::{
    Chunk, CompletionRequest, LlmProvider, StopReason, Usage, error::DummyError, request::ToolCall,
};

/// Live-streaming side channel emitted while a turn is being folded.
///
/// Consumers such as Slack's `chat.appendStream` or a streaming CLI render
/// these as they arrive. They supplement, never replace, the buffered
/// [`TurnOutput`], which is still returned in full once the stream ends.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TurnStreamEvent {
    /// Newly generated answer text; concatenating every delta in arrival
    /// order reproduces the full answer text.
    TextDelta(String),
    /// A tool call has begun; its `id` and `name` are available immediately.
    ToolStarted {
        /// Call id assigned by the provider.
        id: String,
        /// The tool being invoked.
        name: String,
    },
}

/// Everything one turn produced, assembled from a [`Chunk`] stream.
#[derive(Clone, Debug, Default)]
pub struct TurnOutput {
    /// All text deltas, joined in arrival order.
    pub text: String,
    /// Assembled tool calls: those closed by a `ToolCallEnd`, in the order
    /// their end arrived, followed by any still open when the stream ended.
    pub tool_calls: Vec<ToolCall>,
    /// Token accounting from the most recent [`Chunk::Usage`]; the default
    /// value if none arrived.
    pub usage: Usage,
    /// Why the turn ended, or `None` if the stream never reported a reason.
    pub stop: Option<StopReason>,
}

/// Drain a provider stream into a [`TurnOutput`].
///
/// Text deltas concatenate; a tool call accretes from its
/// `ToolCallStart`/`ToolCallArgsDelta`/`ToolCallEnd` run (matched by `id`);
/// usage and stop reason are taken from their chunks.
///
/// # Errors
///
/// Propagates the first `Err` item from the stream.
pub async fn collect_turn<S, E>(stream: S) -> Result<TurnOutput, E>
where
    S: Stream<Item = Result<Chunk, E>> + Unpin,
{
    collect_turn_observed(stream, |_| {}).await
}

/// Like [`collect_turn`], but observes each streamable event as it arrives.
///
/// Invokes `on_event` for each text delta / tool start while still folding and
/// returning the complete [`TurnOutput`]. `on_event` is synchronous and must
/// not block (e.g. an unbounded-channel `send`).
///
/// # Errors
///
/// Propagates the first `Err` item from the stream.
pub async fn collect_turn_observed<S, E, F>(mut stream: S, mut on_event: F) -> Result<TurnOutput, E>
where
    S: Stream<Item = Result<Chunk, E>> + Unpin,
    F: FnMut(TurnStreamEvent),
{
    let mut out = TurnOutput::default();
    // In-progress tool calls, kept in start order and matched by id. A provider
    // may interleave several calls (OpenAI's `parallel_tool_calls` defaults to
    // true) and/or defer all their `ToolCallEnd`s to the end of the stream, so a
    // single `Option` would let a second `ToolCallStart` clobber the first and
    // an `ToolCallEnd` close the wrong call. Matching by id throughout keeps
    // every parallel call intact regardless of emission order.
    let mut pending: Vec<ToolCall> = Vec::new();
    while let Some(item) = stream.next().await {
        match item? {
            Chunk::TextDelta(s) => {
                on_event(TurnStreamEvent::TextDelta(s.clone()));
                out.text.push_str(&s);
            }
            Chunk::ToolCallStart {
                id,
                name,
                signature,
            } => {
                on_event(TurnStreamEvent::ToolStarted {
                    id: id.clone(),
                    name: name.clone(),
                });
                pending.push(ToolCall {
                    id,
                    name,
                    args_json: String::new(),
                    signature,
                });
            }
            Chunk::ToolCallArgsDelta {
                id,
                args_json_delta,
            } => {
                if let Some(tc) = pending.iter_mut().find(|tc| tc.id == id) {
                    tc.args_json.push_str(&args_json_delta);
                }
            }
            Chunk::ToolCallEnd { id } => {
                // Move the matching call to the output in completion order. An
                // unmatched id is ignored (defensive); calls still open at EOF
                // are flushed after the loop so none are silently dropped.
                if let Some(pos) = pending.iter().position(|tc| tc.id == id) {
                    out.tool_calls.push(pending.remove(pos));
                }
            }
            Chunk::Usage(u) => out.usage = u,
            // A `ToolUse` stop is sticky against a *later* `EndTurn`. Some
            // providers stream the tool call in one event and then a separate
            // trailing terminator event carrying an end-of-turn finish reason;
            // letting that later `EndTurn` overwrite the `ToolUse` stop would
            // make the agent loop skip executing the tool and end the turn with
            // no output.
            //
            // A *hard* stop (MaxTokens / Refusal / StopSequence) is the
            // opposite: it means the turn was truncated or refused, so it must
            // win over an earlier `ToolUse` — the tool call may be incomplete
            // and must not be executed.
            Chunk::Stop(r) => {
                let keep_tool_use =
                    out.stop == Some(StopReason::ToolUse) && matches!(r, StopReason::EndTurn);
                if !keep_tool_use {
                    out.stop = Some(r);
                }
            }
        }
    }
    // Flush any call that started (and may have accreted args) but whose
    // `ToolCallEnd` never arrived — a provider that omits the terminator must
    // not lose the call.
    out.tool_calls.append(&mut pending);
    Ok(out)
}

/// Env var naming a tool for the stub provider to "call" synthetically.
///
/// When set to a non-empty tool name, the stub's `complete()` answers with a
/// synthetic call to that tool (id `stub-call-1`, args `{}`, stop `ToolUse`)
/// as long as the request's messages contain no tool result at all. Once any
/// `tool_result` is present anywhere in the transcript, not only in the
/// current turn, it falls back to the canned `EndTurn` text. Unset or empty
/// keeps the canned-text behaviour. Used by the HITL resume loopback
/// verification to drive the data path without a real provider backend.
pub const STUB_TOOL_CALL_ENV: &str = "POLYCHROME_STUB_TOOL_CALL";

/// The tool name configured via [`STUB_TOOL_CALL_ENV`], if any.
///
/// An unset, non-Unicode, or empty variable all yield `None`.
fn stub_tool_name() -> Option<String> {
    match std::env::var(STUB_TOOL_CALL_ENV) {
        Ok(name) if !name.is_empty() => Some(name),
        _ => None,
    }
}

/// A canned [`LlmProvider`] for wiring and tests.
///
/// By default every `complete()` emits two text deltas, a usage tally, and an
/// end-of-turn stop. No network, no credentials.
///
/// When [`STUB_TOOL_CALL_ENV`] names a tool, `complete()` instead emits a
/// synthetic call to that tool (id `stub-call-1`) followed by a `ToolUse`
/// stop, as long as the request's messages contain no tool result. Once any
/// tool result is present, it falls back to the canned `EndTurn` text. The
/// caller's function-calling loop drives the rest. Used by the HITL resume
/// loopback verification.
#[derive(Clone, Copy, Debug, Default)]
pub struct StubProvider;

#[async_trait]
impl LlmProvider for StubProvider {
    type Error = DummyError;

    /// Return a canned, in-memory stream; never touches the network.
    ///
    /// With [`STUB_TOOL_CALL_ENV`] set and no tool result anywhere in
    /// `req.messages`, the stream is a synthetic `stub-call-1` call to the
    /// configured tool ending in `ToolUse`. Otherwise it is two text deltas,
    /// a usage tally and an `EndTurn` stop.
    async fn complete(
        &self,
        req: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error> {
        // The transcript is only inspected when a stub tool is configured; the
        // synthetic call is offered until some tool result shows up.
        let tool_to_call = stub_tool_name().filter(|_| {
            let transcript_has_result = req
                .messages
                .iter()
                .flat_map(|msg| msg.content.iter())
                .any(|part| matches!(part, crate::Content::ToolResult(_)));
            !transcript_has_result
        });

        let chunks: Vec<Result<Chunk, DummyError>> = match tool_to_call {
            Some(tool_name) => vec![
                Ok(Chunk::tool_call_start("stub-call-1", &tool_name)),
                Ok(Chunk::tool_call_args_delta("stub-call-1", "{}")),
                Ok(Chunk::tool_call_end("stub-call-1")),
                Ok(Chunk::Stop(StopReason::ToolUse)),
            ],
            None => vec![
                Ok(Chunk::text_delta("Hello from the ")),
                Ok(Chunk::text_delta("stub provider.")),
                Ok(Chunk::Usage(Usage {
                    input_tokens: 5,
                    output_tokens: 4,
                })),
                Ok(Chunk::Stop(StopReason::EndTurn)),
            ],
        };
        Ok(stream::iter(chunks).boxed())
    }
}

#[cfg(test)]
mod tests {
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use super::*;

    #[tokio::test]
    async fn stub_provider_collects_into_text() {
        let stream = StubProvider
            .complete(CompletionRequest::new("stub"))
            .await
            .expect("stream opens");
        let out = collect_turn(stream).await.expect("collect");
        assert_eq!(out.text, "Hello from the stub provider.");
        assert!(out.tool_calls.is_empty());
        assert_eq!(out.usage.output_tokens, 4);
        assert_eq!(out.stop, Some(StopReason::EndTurn));
    }

    #[tokio::test]
    async fn collect_assembles_tool_call_from_deltas() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("calling ")),
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_args_delta("c1", r#"{"q":"#)),
            Ok(Chunk::tool_call_args_delta("c1", r#""rust"}"#)),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.text, "calling ");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].name, "search");
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"rust"}"#);
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn collect_keeps_parallel_tool_calls_with_deferred_ends() {
        // Two interleaved calls whose `ToolCallEnd`s are both deferred to the
        // end of the stream (the OpenAI-compatible provider's shape). A single
        // `Option` would drop call 0 and close the survivor with the wrong end;
        // id-matching must preserve both, in completion order.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c0", "search")),
            Ok(Chunk::tool_call_args_delta("c0", r#"{"q":"a"}"#)),
            Ok(Chunk::tool_call_start("c1", "fetch")),
            Ok(Chunk::tool_call_args_delta("c1", r#"{"u":"b"}"#)),
            Ok(Chunk::tool_call_end("c0")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.tool_calls.len(), 2, "both parallel calls preserved");
        assert_eq!(out.tool_calls[0].id, "c0");
        assert_eq!(out.tool_calls[0].name, "search");
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"a"}"#);
        assert_eq!(out.tool_calls[1].id, "c1");
        assert_eq!(out.tool_calls[1].name, "fetch");
        assert_eq!(out.tool_calls[1].args_json, r#"{"u":"b"}"#);
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn collect_flushes_a_call_left_open_at_eof() {
        // A provider that omits the terminal `ToolCallEnd` must not lose the
        // call — it is flushed when the stream ends.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c0", "search")),
            Ok(Chunk::tool_call_args_delta("c0", r#"{"q":"a"}"#)),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].args_json, r#"{"q":"a"}"#);
    }

    #[tokio::test]
    async fn tool_use_stop_is_sticky_against_later_end_turn() {
        // Provider streams the tool call (ToolUse) then a trailing terminator
        // event (EndTurn). The terminator must NOT clobber ToolUse, else the
        // agent loop skips the tool.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
            Ok(Chunk::Stop(StopReason::EndTurn)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.stop, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn hard_stop_wins_over_earlier_tool_use() {
        // A later MaxTokens (truncation) MUST override an earlier ToolUse so
        // the agent doesn't execute a tool call with truncated arguments.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
            Ok(Chunk::Stop(StopReason::MaxTokens)),
        ];
        let out = collect_turn(stream::iter(chunks)).await.expect("collect");
        assert_eq!(out.stop, Some(StopReason::MaxTokens));
    }

    #[tokio::test]
    async fn observed_collect_reports_text_and_tool_starts_in_order() {
        // The observer must see every text delta and tool start, interleaved
        // in stream order, while the buffered output is still assembled.
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("a")),
            Ok(Chunk::tool_call_start("c1", "search")),
            Ok(Chunk::tool_call_args_delta("c1", "{}")),
            Ok(Chunk::text_delta("b")),
            Ok(Chunk::tool_call_end("c1")),
            Ok(Chunk::Stop(StopReason::ToolUse)),
        ];
        let mut events = Vec::new();
        let out = collect_turn_observed(stream::iter(chunks), |e| events.push(e))
            .await
            .expect("collect");
        assert_eq!(
            events,
            vec![
                TurnStreamEvent::TextDelta("a".to_owned()),
                TurnStreamEvent::ToolStarted {
                    id: "c1".to_owned(),
                    name: "search".to_owned(),
                },
                TurnStreamEvent::TextDelta("b".to_owned()),
            ]
        );
        assert_eq!(out.text, "ab");
        assert_eq!(out.tool_calls.len(), 1);
        assert_eq!(out.tool_calls[0].args_json, "{}");
    }

    #[tokio::test]
    async fn collect_propagates_error() {
        let chunks: Vec<Result<Chunk, DummyError>> = vec![
            Ok(Chunk::text_delta("partial")),
            Err(DummyError::Other("mid-stream fault".to_owned())),
        ];
        let res = collect_turn(stream::iter(chunks)).await;
        assert!(res.is_err());
    }
}