// polyc-llm 0.1.3
//
// Provider-agnostic LLM trait + wire types for polychrome.
// Documentation
//! Streaming response types: [`Chunk`], [`Usage`], and [`StopReason`].
//!
//! Where [`request`](crate::request) describes what goes *into* a provider,
//! this module describes what streams *out*. A provider yields an ordered
//! sequence of [`Chunk`]s; the planner reassembles them into assistant turns.
//!
//! The shape is the richest streaming superset across the backends we target:
//! text deltas, a tool call that announces itself and then accretes its JSON
//! arguments incrementally ([`Chunk::ToolCallArgsDelta`] `args_json_delta`), and
//! usage. Critically, tool-call arguments arrive as **JSON string deltas** in
//! every backend: we do not attempt to deserialize them until the matching
//! [`Chunk::ToolCallEnd`].

use serde::{Deserialize, Serialize};

// ── Chunk ──────────────────────────────────────────────────────────────────────

/// A single event in a provider's streaming response.
///
/// A complete stream is an ordered sequence of these. Text generation surfaces
/// as a run of [`Chunk::TextDelta`] events; a tool call surfaces as exactly one
/// [`Chunk::ToolCallStart`], zero or more [`Chunk::ToolCallArgsDelta`] fragments
/// (whose concatenated `args_json_delta`s form the call's JSON arguments), and
/// exactly one [`Chunk::ToolCallEnd`]. Every tool-call event carries the call
/// `id` so concurrently-streamed calls can be demultiplexed. [`Chunk::Usage`]
/// reports token accounting and may appear more than once. A well-formed stream
/// ends with a single [`Chunk::Stop`].
///
/// # Wire format
///
/// Serialized in serde's default externally-tagged form with snake_case
/// variant names, e.g. `{"text_delta": "hello"}` or
/// `{"tool_call_start": {"id": "id-1", "name": "calc"}}`.
///
/// # Extensibility
///
/// The enum is `#[non_exhaustive]`: code outside this crate that matches on it
/// needs a wildcard arm, which leaves room to add variants later.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Chunk {
    /// An incremental piece of generated text. Concatenate consecutive
    /// `TextDelta`s to recover the full text segment.
    TextDelta(String),
    /// The model has begun a tool call. The `id` and `name` are known up front;
    /// arguments stream in as subsequent [`Chunk::ToolCallArgsDelta`]s bearing
    /// the same `id`.
    ToolCallStart {
        /// Provider-assigned call identifier, matching a future
        /// [`ToolCall::id`](crate::request::ToolCall::id).
        id: String,
        /// Name of the tool being called.
        name: String,
        /// Opaque provider-specific signature for this call (e.g. a thinking
        /// model's thought signature), to be carried onto the assembled
        /// [`ToolCall`](crate::request::ToolCall) and echoed back on the
        /// follow-up request. `None` when the provider emits no such token.
        ///
        /// Omitted from the serialized form when `None`; a payload without
        /// this field deserializes to `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        signature: Option<String>,
    },
    /// An incremental fragment of a tool call's JSON arguments. Concatenate the
    /// `args_json_delta`s of all fragments sharing an `id` to recover the full
    /// `args_json`. Do not parse until the matching [`Chunk::ToolCallEnd`].
    ToolCallArgsDelta {
        /// Identifies which in-progress [`Chunk::ToolCallStart`] this fragment
        /// belongs to.
        id: String,
        /// A partial slice of the call's JSON-encoded arguments. A single
        /// fragment is not guaranteed to be valid JSON on its own.
        args_json_delta: String,
    },
    /// The named tool call's arguments are complete and may now be parsed.
    ToolCallEnd {
        /// Identifies the completed [`Chunk::ToolCallStart`].
        id: String,
    },
    /// A token-accounting update. May arrive more than once per stream (e.g.
    /// input tokens early, output tokens at the end).
    Usage(Usage),
    /// Terminal event: generation has finished for the reason given.
    Stop(StopReason),
}

impl Chunk {
    /// Builds a [`Chunk::TextDelta`] from anything convertible into a `String`.
    #[must_use]
    pub fn text_delta(s: impl Into<String>) -> Self {
        let text: String = s.into();
        Self::TextDelta(text)
    }

    /// Builds an unsigned [`Chunk::ToolCallStart`]: the `signature` field is
    /// left as `None`.
    #[must_use]
    pub fn tool_call_start(id: impl Into<String>, name: impl Into<String>) -> Self {
        // The unsigned form is the signed form with no signature attached.
        Self::tool_call_start_signed(id, name, None)
    }

    /// Builds a [`Chunk::ToolCallStart`] whose `signature` is stored exactly as
    /// given (an opaque provider-specific token, or `None`).
    #[must_use]
    pub fn tool_call_start_signed(
        id: impl Into<String>,
        name: impl Into<String>,
        signature: Option<String>,
    ) -> Self {
        let (id, name) = (id.into(), name.into());
        Self::ToolCallStart {
            id,
            name,
            signature,
        }
    }

    /// Builds a [`Chunk::ToolCallArgsDelta`]: one partial slice of the JSON
    /// arguments belonging to the call identified by `id`.
    #[must_use]
    pub fn tool_call_args_delta(id: impl Into<String>, args_json_delta: impl Into<String>) -> Self {
        let id = id.into();
        let args_json_delta = args_json_delta.into();
        Self::ToolCallArgsDelta {
            id,
            args_json_delta,
        }
    }

    /// Builds the [`Chunk::ToolCallEnd`] marker for the call identified by `id`.
    #[must_use]
    pub fn tool_call_end(id: impl Into<String>) -> Self {
        let id = id.into();
        Self::ToolCallEnd { id }
    }
}

// ── Usage ──────────────────────────────────────────────────────────────────────

/// Token accounting for a request/response pair.
///
/// Counts are cumulative within a single stream. Surfaced to the
/// `polychrome_llm_tokens_total{direction}` metric as input/output directions.
///
/// The [`Default`] value has both counters at zero. With the derived serde
/// impls and no renames, the fields serialize under their own names
/// (`input_tokens`, `output_tokens`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt (system + messages + tools).
    pub input_tokens: u64,
    /// Tokens produced by the model.
    pub output_tokens: u64,
}

impl Usage {
    /// Sum of both counters: `input_tokens + output_tokens`.
    ///
    /// The addition saturates, so the result is clamped to `u64::MAX` instead
    /// of wrapping or panicking. Callers therefore never need an overflow
    /// guard, even though genuine token counts are nowhere near that limit.
    #[must_use]
    pub const fn total_tokens(self) -> u64 {
        let Self {
            input_tokens,
            output_tokens,
        } = self;
        input_tokens.saturating_add(output_tokens)
    }
}

// ── StopReason ───────────────────────────────────────────────────────────────────

/// Why the model stopped generating.
///
/// The variants are the provider-agnostic union of the common providers'
/// terminal states.
///
/// Each variant serializes as a bare snake_case string (for example
/// `"end_turn"` or `"stop_sequence"`). The enum is `#[non_exhaustive]`, so
/// matches written outside this crate need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum StopReason {
    /// The model finished its turn naturally.
    EndTurn,
    /// Generation hit the request's `max_tokens` ceiling.
    MaxTokens,
    /// The model emitted one of the request's `stop` sequences.
    StopSequence,
    /// The model paused to call one or more tools; the caller is expected to
    /// run them and continue the conversation.
    ToolUse,
    /// The model declined to answer, or the provider's content filter halted
    /// generation. Mirrors the wire-side `STOP_REASON_REFUSAL`.
    Refusal,
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    // Test-only code: relax the pedantic/nursery lint groups and the doc
    // requirement so the assertions can stay terse.
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use serde_json::{Value, json};

    use super::*;

    // ── Constructors ─────────────────────────────────────────────────────────

    #[test]
    fn text_delta_constructor() {
        assert_eq!(Chunk::text_delta("hi"), Chunk::TextDelta("hi".to_owned()));
    }

    #[test]
    fn tool_call_start_constructor() {
        match Chunk::tool_call_start("call-1", "search") {
            Chunk::ToolCallStart {
                id,
                name,
                signature,
            } => {
                assert_eq!(id, "call-1");
                assert_eq!(name, "search");
                // The unsigned constructor must never attach a signature.
                assert_eq!(signature, None);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn tool_call_start_signed_constructor() {
        assert_eq!(
            Chunk::tool_call_start_signed("call-1", "search", Some("sig-abc".to_owned())),
            Chunk::ToolCallStart {
                id: "call-1".to_owned(),
                name: "search".to_owned(),
                signature: Some("sig-abc".to_owned()),
            },
        );
    }

    #[test]
    fn tool_call_start_signed_with_none_matches_unsigned() {
        assert_eq!(
            Chunk::tool_call_start_signed("call-1", "search", None),
            Chunk::tool_call_start("call-1", "search"),
        );
    }

    #[test]
    fn tool_call_args_delta_constructor() {
        match Chunk::tool_call_args_delta("call-1", r#"{"q":"#) {
            Chunk::ToolCallArgsDelta {
                id,
                args_json_delta,
            } => {
                assert_eq!(id, "call-1");
                assert_eq!(args_json_delta, r#"{"q":"#);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn tool_call_end_constructor() {
        assert_eq!(
            Chunk::tool_call_end("call-1"),
            Chunk::ToolCallEnd {
                id: "call-1".to_owned()
            },
        );
    }

    // ── Wire format ──────────────────────────────────────────────────────────

    #[test]
    fn text_delta_serializes_as_tagged_object() {
        let v: Value = serde_json::to_value(Chunk::text_delta("hello")).unwrap();
        assert_eq!(v, json!({"text_delta": "hello"}));
    }

    #[test]
    fn tool_call_start_serializes_with_named_fields() {
        // With no signature, the `signature` key is omitted entirely.
        let v: Value = serde_json::to_value(Chunk::tool_call_start("id-1", "calc")).unwrap();
        assert_eq!(
            v,
            json!({"tool_call_start": {"id": "id-1", "name": "calc"}})
        );
    }

    #[test]
    fn tool_call_start_signed_serializes_signature() {
        let chunk = Chunk::tool_call_start_signed("id-1", "calc", Some("sig-abc".to_owned()));
        let v: Value = serde_json::to_value(chunk).unwrap();
        assert_eq!(
            v,
            json!({"tool_call_start": {"id": "id-1", "name": "calc", "signature": "sig-abc"}}),
        );
    }

    #[test]
    fn tool_call_start_deserializes_without_signature() {
        // A payload lacking `signature` must still parse, yielding `None`.
        let back: Chunk =
            serde_json::from_value(json!({"tool_call_start": {"id": "id-1", "name": "calc"}}))
                .unwrap();
        assert_eq!(back, Chunk::tool_call_start("id-1", "calc"));
    }

    #[test]
    fn tool_call_args_delta_serializes_with_named_fields() {
        let v: Value =
            serde_json::to_value(Chunk::tool_call_args_delta("id-1", r#"{"x":1}"#)).unwrap();
        assert_eq!(
            v,
            json!({"tool_call_args_delta": {"id": "id-1", "args_json_delta": r#"{"x":1}"#}}),
        );
    }

    #[test]
    fn tool_call_end_usage_and_stop_serialize_as_tagged_values() {
        let end: Value = serde_json::to_value(Chunk::tool_call_end("id-1")).unwrap();
        assert_eq!(end, json!({"tool_call_end": {"id": "id-1"}}));

        let usage: Value = serde_json::to_value(Chunk::Usage(Usage {
            input_tokens: 10,
            output_tokens: 20,
        }))
        .unwrap();
        assert_eq!(
            usage,
            json!({"usage": {"input_tokens": 10, "output_tokens": 20}})
        );

        let stop: Value = serde_json::to_value(Chunk::Stop(StopReason::EndTurn)).unwrap();
        assert_eq!(stop, json!({"stop": "end_turn"}));
    }

    #[test]
    fn chunk_round_trips_all_variants() {
        for chunk in [
            Chunk::text_delta("partial"),
            Chunk::tool_call_start("c1", "weather"),
            Chunk::tool_call_start_signed("c2", "weather", Some("sig-abc".to_owned())),
            Chunk::tool_call_args_delta("c1", r#"{"city":"NYC"}"#),
            Chunk::tool_call_end("c1"),
            Chunk::Usage(Usage {
                input_tokens: 10,
                output_tokens: 20,
            }),
            Chunk::Stop(StopReason::EndTurn),
        ] {
            let json = serde_json::to_string(&chunk).unwrap();
            let back: Chunk = serde_json::from_str(&json).unwrap();
            assert_eq!(back, chunk);
        }
    }

    // ── Reassembly ───────────────────────────────────────────────────────────

    #[test]
    fn reassemble_tool_call_args_from_deltas_by_id() {
        // Concatenating same-id ToolCallArgsDelta payloads recovers the full
        // JSON; a foreign id must not bleed into the assembly.
        let stream = [
            Chunk::tool_call_start("a", "weather"),
            Chunk::tool_call_args_delta("a", r#"{"city":"#),
            Chunk::tool_call_args_delta("b", "IGNORED"),
            Chunk::tool_call_args_delta("a", r#""NYC"}"#),
            Chunk::tool_call_end("a"),
        ];
        let mut assembled = String::new();
        for c in &stream {
            if let Chunk::ToolCallArgsDelta {
                id,
                args_json_delta,
            } = c
                && id == "a"
            {
                assembled.push_str(args_json_delta);
            }
        }
        let parsed: Value = serde_json::from_str(&assembled).unwrap();
        assert_eq!(parsed, json!({"city": "NYC"}));
    }

    // ── Usage ────────────────────────────────────────────────────────────────

    #[test]
    fn usage_total_sums_input_and_output() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 250,
        };
        assert_eq!(u.total_tokens(), 350);
    }

    #[test]
    fn usage_total_saturates_on_overflow() {
        let u = Usage {
            input_tokens: u64::MAX,
            output_tokens: 1,
        };
        assert_eq!(u.total_tokens(), u64::MAX);
    }

    #[test]
    fn usage_default_is_all_zero() {
        let u = Usage::default();
        assert_eq!(u.input_tokens, 0);
        assert_eq!(u.output_tokens, 0);
        assert_eq!(u.total_tokens(), 0);
    }

    // ── StopReason ───────────────────────────────────────────────────────────

    #[test]
    fn stop_reason_serializes_to_snake_case() {
        assert_eq!(
            serde_json::to_string(&StopReason::EndTurn).unwrap(),
            r#""end_turn""#
        );
        assert_eq!(
            serde_json::to_string(&StopReason::MaxTokens).unwrap(),
            r#""max_tokens""#
        );
        assert_eq!(
            serde_json::to_string(&StopReason::StopSequence).unwrap(),
            r#""stop_sequence""#,
        );
        assert_eq!(
            serde_json::to_string(&StopReason::ToolUse).unwrap(),
            r#""tool_use""#
        );
        assert_eq!(
            serde_json::to_string(&StopReason::Refusal).unwrap(),
            r#""refusal""#,
        );
    }

    #[test]
    fn stop_reason_round_trips() {
        for reason in [
            StopReason::EndTurn,
            StopReason::MaxTokens,
            StopReason::StopSequence,
            StopReason::ToolUse,
            StopReason::Refusal,
        ] {
            let json = serde_json::to_string(&reason).unwrap();
            let back: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(back, reason);
        }
    }
}