Skip to main content

lash_sansio/
turn_driver.rs

1//! Shared types and helpers used by protocol drivers. Concrete drivers and
2//! their prompts live in protocol plugin crates; this module exposes the common
3//! turn-driver surface:
4//!
//! - [`TurnDriverConfig`], [`TurnDriverPreamble`] — the per-turn configuration
//!   that driver plugins populate.
7//! - A small helper layer (`normalized_response_parts`, `reasoning_part`,
8//!   `append_assistant_text_part`) that protocol drivers reuse for building
9//!   assistant messages.
10
11use std::sync::Arc;
12
13use crate::PromptContribution;
14use crate::PromptFingerprint;
15use crate::llm::types::{LlmOutputPart, LlmResponse, LlmToolSpec, ProviderReasoningReplay};
16use crate::sansio::{
17    ChatContextProjector, ContextProjector, ProtocolDriverHandle, TurnProtocol, UnitTurnProtocol,
18};
19use crate::session_model::{Part, PartKind, PruneState};
20
21pub type TurnLimitFinalMessage =
22    Arc<dyn Fn(String, usize) -> crate::Message + Send + Sync + 'static>;
23
24#[derive(Clone)]
25pub struct TurnDriverConfig<M: TurnProtocol = UnitTurnProtocol> {
26    pub protocol: Arc<dyn ProtocolDriverHandle<M>>,
27    pub projector: Arc<dyn ContextProjector<M>>,
28    pub sync_execution_surface: bool,
29    pub turn_limit_final_message: TurnLimitFinalMessage,
30}
31
32impl<M: TurnProtocol> TurnDriverConfig<M> {
33    pub fn chat(
34        protocol: Arc<dyn ProtocolDriverHandle<M>>,
35        sync_execution_surface: bool,
36        turn_limit_final_message: TurnLimitFinalMessage,
37    ) -> Self {
38        Self {
39            protocol,
40            projector: Arc::new(ChatContextProjector),
41            sync_execution_surface,
42            turn_limit_final_message,
43        }
44    }
45}
46
47#[derive(Clone)]
48pub struct TurnDriverPreamble<M: TurnProtocol = UnitTurnProtocol> {
49    pub config: TurnDriverConfig<M>,
50    pub tool_specs: Arc<Vec<LlmToolSpec>>,
51    pub tool_names: Arc<Vec<String>>,
52    pub tool_names_fingerprint: PromptFingerprint,
53    pub omitted_tool_count: usize,
54    pub execution_prompt: Arc<str>,
55    pub prompt_contributions: Vec<PromptContribution>,
56}
57
58/// Convert a raw `LlmResponse` into a stream of `LlmOutputPart`s that
59/// downstream code can iterate. When the response only carries
60/// `full_text` (provider didn't populate `parts`), synthesize a single
61/// `Text` part.
62pub fn normalized_response_parts(llm_response: &LlmResponse) -> Vec<LlmOutputPart> {
63    if llm_response.parts.is_empty() && !llm_response.full_text.is_empty() {
64        vec![LlmOutputPart::Text {
65            text: llm_response.full_text.clone(),
66            response_meta: None,
67        }]
68    } else {
69        llm_response.parts.clone()
70    }
71}
72
73/// Build a Reasoning `Part` from a reasoning item. `meta` is Some when
74/// the item carries provider replay metadata; None for display-only
75/// summaries.
76pub fn reasoning_part(
77    asst_id: &str,
78    index: usize,
79    text: String,
80    meta: Option<ProviderReasoningReplay>,
81) -> Part {
82    Part {
83        id: format!("{asst_id}.p{index}"),
84        kind: PartKind::Reasoning,
85        content: text,
86        attachment: None,
87        tool_call_id: None,
88        tool_name: None,
89        tool_replay: None,
90        prune_state: PruneState::Intact,
91        reasoning_meta: meta,
92        response_meta: None,
93    }
94}
95
/// Append a streamed text part to the running assistant text.
///
/// Consecutive non-empty parts are separated by at least one blank line,
/// i.e. at least two `\n` characters across the boundary. Newlines already at
/// the end of `out` or the start of `next` count toward those two, so only
/// the missing ones are inserted. Any surplus newlines are left untouched.
///
/// Appending an empty `next` is a no-op: there is nothing to separate, and
/// inserting the separator anyway would leave a dangling blank line at the
/// end of `out`.
pub fn append_assistant_text_part(out: &mut String, next: &str) {
    if next.is_empty() {
        return;
    }

    if !out.is_empty() {
        // `\n` is a single ASCII byte that never occurs inside a multi-byte
        // UTF-8 sequence, so counting bytes is equivalent to counting chars.
        let trailing = out.bytes().rev().take_while(|&b| b == b'\n').count();
        let leading = next.bytes().take_while(|&b| b == b'\n').count();
        // `missing` is always 0, 1 or 2, so slicing the static separator
        // yields exactly the newlines still needed without allocating.
        let missing = 2usize.saturating_sub(trailing + leading);
        out.push_str(&"\n\n"[..missing]);
    }

    out.push_str(next);
}