Skip to main content

lash_sansio/
turn.rs

1use std::sync::Arc;
2
3use crate::MessageSequence;
4use crate::prompt::PreparedPrompt;
5use crate::sansio::{TurnMachine, TurnMachineConfig, TurnProtocol, UnitTurnProtocol};
6use crate::turn_driver::TurnDriverPreamble;
7
/// Everything [`build_turn`] needs to assemble a [`TurnMachine`] for one turn.
///
/// Most fields are forwarded into a [`TurnMachineConfig`]. `messages`, `events`,
/// `protocol_run_offset`, and `turn_causes` are instead passed as separate
/// arguments to `TurnMachine::new_shared_with_turn_causes`. The protocol type
/// `M` defaults to [`UnitTurnProtocol`].
pub struct SansIoTurnInput<M: TurnProtocol = UnitTurnProtocol> {
    /// Session identifier, forwarded to `TurnMachineConfig::session_id`.
    pub session_id: String,
    /// Optional run-level session identifier, forwarded to the machine config.
    pub run_session_id: Option<String>,
    /// Forwarded to `TurnMachineConfig::autonomous`.
    pub autonomous: bool,
    /// Model name, forwarded to `TurnMachineConfig::model`.
    pub model: String,
    /// Model context-window size in tokens, if known. Threaded into the kernel
    /// so it can reclassify a zero-output `OutputLimit` as `ContextOverflow`.
    pub max_context_tokens: Option<usize>,
    /// Message sequence handed to the machine constructor
    /// (presumably the conversation so far — confirm).
    pub messages: MessageSequence,
    /// Shared (`Arc`) list of session event records handed to the machine constructor.
    pub events: Arc<Vec<crate::SessionEventRecord<M::Event>>>,
    /// Turn causes handed to the machine constructor.
    pub turn_causes: Vec<crate::TurnCause>,
    /// Protocol run offset handed to the machine constructor; the module's unit
    /// test expects `protocol_iteration()` to equal it right after construction.
    pub protocol_run_offset: usize,
    /// Shared driver preamble. `build_turn` takes the protocol driver, projector,
    /// sync execution surface, turn-limit final-message callback, and tool specs
    /// from it, then returns it in [`PreparedTurnMachine`].
    pub turn_driver_preamble: Arc<TurnDriverPreamble<M>>,
    /// Prepared prompt. Its `system_prompt` is shared with the machine via
    /// `Arc::clone`, and the prompt itself is returned in [`PreparedTurnMachine`].
    pub prepared_prompt: PreparedPrompt,
    /// Optional turn limit, forwarded to the machine config
    /// (presumably `None` means no limit — confirm).
    pub max_turns: Option<usize>,
    /// Optional model variant, forwarded to the machine config.
    pub model_variant: Option<String>,
    /// Generation options, forwarded to the machine config.
    pub generation: crate::llm::types::GenerationOptions,
    /// Forwarded to `TurnMachineConfig::emit_llm_trace`
    /// (presumably toggles LLM trace emission — confirm).
    pub emit_llm_trace: bool,
    /// Protocol-defined termination value (`M::Termination`), forwarded to the machine config.
    pub termination: M::Termination,
}
28
/// Output of [`build_turn`]: the constructed machine, returned together with
/// the prepared prompt and driver preamble it was built from.
pub struct PreparedTurnMachine<M: TurnProtocol = UnitTurnProtocol> {
    /// Turn machine configured from a [`SansIoTurnInput`].
    pub machine: TurnMachine<M>,
    /// The prepared prompt whose `system_prompt` the machine shares via `Arc`.
    pub prepared_prompt: PreparedPrompt,
    /// The driver preamble that supplied the machine's protocol driver,
    /// projector, sync execution surface, turn-limit callback, and tool specs.
    pub turn_driver_preamble: Arc<TurnDriverPreamble<M>>,
}
34
35pub fn build_turn<M: TurnProtocol>(input: SansIoTurnInput<M>) -> PreparedTurnMachine<M> {
36    let machine = TurnMachine::new_shared_with_turn_causes(
37        TurnMachineConfig {
38            protocol_driver: input.turn_driver_preamble.config.protocol.clone(),
39            projector: input.turn_driver_preamble.config.projector.clone(),
40            sync_execution_surface: input.turn_driver_preamble.config.sync_execution_surface,
41            model: input.model,
42            max_context_tokens: input.max_context_tokens,
43            max_turns: input.max_turns,
44            model_variant: input.model_variant,
45            generation: input.generation,
46            run_session_id: input.run_session_id,
47            autonomous: input.autonomous,
48            tool_specs: input.turn_driver_preamble.tool_specs.clone(),
49            system_prompt: Arc::clone(&input.prepared_prompt.system_prompt),
50            session_id: input.session_id,
51            emit_llm_trace: input.emit_llm_trace,
52            termination: input.termination,
53            turn_limit_final_message: input
54                .turn_driver_preamble
55                .config
56                .turn_limit_final_message
57                .clone(),
58        },
59        input.messages,
60        input.events,
61        input.protocol_run_offset,
62        input.turn_causes,
63    );
64
65    PreparedTurnMachine {
66        machine,
67        prepared_prompt: input.prepared_prompt,
68        turn_driver_preamble: input.turn_driver_preamble,
69    }
70}
71
/// Unit tests for [`build_turn`], driven through a no-op protocol driver so the
/// crate does not need a real protocol plugin.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::sansio::{
        CompletedToolCall, DriverAction, DriverContextView, ProtocolDriverHandle, WaitingExecState,
        WaitingLlmState,
    };
    use crate::turn_driver::{TurnDriverConfig, TurnDriverPreamble};
    use crate::{
        PromptBuildInput, PromptContribution, PromptContributionSet, ToolDefinition,
        ToolScheduling, build_prompt, default_prompt_template, prompt_template_fingerprint,
        prompt_text_fingerprint,
    };

    /// Builds a raw tool definition for `name`. Its input schema is an object that
    /// requires a string `path`, its output schema is a string, and its scheduling
    /// is set to `ToolScheduling::Parallel`.
    fn tool(name: &str) -> ToolDefinition {
        let mut definition = ToolDefinition::raw(
            format!("tool:{name}"),
            name,
            format!("Tool {name}"),
            serde_json::json!({
                "type": "object",
                "properties": { "path": { "type": "string" } },
                "required": ["path"]
            }),
            serde_json::json!({ "type": "string" }),
        );
        definition.manifest.scheduling = ToolScheduling::Parallel;
        definition
    }

    /// Minimal no-op driver so the turn-machine test can build a
    /// `TurnDriverPreamble` without pulling in a protocol plugin crate (which would
    /// create a cyclic dependency on `lash` from `lash-sansio`).
    /// Every callback ignores its inputs and returns an empty action list.
    struct NoopDriver;

    impl ProtocolDriverHandle for NoopDriver {
        fn prepare_protocol_iteration(&self, _ctx: DriverContextView<'_>) -> Vec<DriverAction> {
            Vec::new()
        }

        fn handle_llm_success(
            &self,
            _ctx: DriverContextView<'_>,
            _waiting: WaitingLlmState,
            _llm_response: crate::llm::types::LlmResponse,
            _text_streamed: bool,
        ) -> Vec<DriverAction> {
            Vec::new()
        }

        fn handle_tool_results(
            &self,
            _ctx: DriverContextView<'_>,
            _completed: Vec<CompletedToolCall>,
        ) -> Vec<DriverAction> {
            Vec::new()
        }

        fn handle_exec_result(
            &self,
            _ctx: DriverContextView<'_>,
            _waiting: WaitingExecState,
            _result: Result<crate::ExecResponse, String>,
        ) -> Vec<DriverAction> {
            Vec::new()
        }
    }

    /// Builds a turn from a chat preamble with one tool and a prompt carrying a
    /// single guidance contribution. It then checks three things:
    /// - the machine reports `protocol_iteration()` equal to the supplied
    ///   `protocol_run_offset` (2);
    /// - the prepared system prompt contains the guidance text;
    /// - the preamble's single tool spec is passed through.
    ///
    /// NOTE(review): the name mentions the machine's system prompt, but the
    /// assertion only inspects `prepared.prepared_prompt`.
    #[test]
    fn build_turn_creates_machine_with_rendered_system_prompt() {
        let tool_surface = Arc::new(crate::ToolSurface::from_tool_definitions(vec![tool(
            "read_file",
        )]));
        let turn_driver_preamble = Arc::new(TurnDriverPreamble {
            config: TurnDriverConfig::chat(
                Arc::new(NoopDriver),
                false,
                Arc::new(test_turn_limit_final_message),
            ),
            tool_specs: tool_surface.model_tool_specs(),
            tool_names: tool_surface.tool_names(),
            tool_names_fingerprint: tool_surface.tool_names_fingerprint(),
            omitted_tool_count: 0,
            execution_prompt: Arc::from("test prompt"),
            prompt_contributions: Vec::new(),
        });
        let template = default_prompt_template();
        let prompt_contributions =
            PromptContributionSet::new(vec![PromptContribution::guidance("Guide", "Be precise.")]);
        // The template fingerprint is computed before `template` is moved into the
        // input. The execution-prompt and tool-name values come from the preamble
        // so the prompt matches what the machine will be given.
        let prepared_prompt = build_prompt(PromptBuildInput {
            template_fingerprint: prompt_template_fingerprint(&template),
            template,
            execution_prompt_fingerprint: prompt_text_fingerprint(
                &turn_driver_preamble.execution_prompt,
            ),
            execution_prompt: Arc::clone(&turn_driver_preamble.execution_prompt),
            tool_names_fingerprint: turn_driver_preamble.tool_names_fingerprint,
            tool_names: Arc::clone(&turn_driver_preamble.tool_names),
            omitted_tool_count: turn_driver_preamble.omitted_tool_count,
            contributions: prompt_contributions,
        });
        let prepared = build_turn(SansIoTurnInput {
            session_id: "session".to_string(),
            run_session_id: Some("run".to_string()),
            autonomous: false,
            model: "gpt-5".to_string(),
            max_context_tokens: None,
            messages: crate::MessageSequence::default(),
            events: Arc::new(Vec::new()),
            turn_causes: Vec::new(),
            protocol_run_offset: 2,
            turn_driver_preamble,
            prepared_prompt,
            max_turns: Some(3),
            model_variant: Some("mini".to_string()),
            generation: crate::llm::types::GenerationOptions::default(),
            emit_llm_trace: true,
            termination: (),
        });

        assert_eq!(prepared.machine.protocol_iteration(), 2);
        assert!(
            prepared
                .prepared_prompt
                .system_prompt
                .contains("Be precise.")
        );
        assert_eq!(prepared.turn_driver_preamble.tool_specs.len(), 1);
    }

    /// Turn-limit callback passed to `TurnDriverConfig::chat`. It builds a `System`
    /// message with id `message_id` and one `Error` part (id `{message_id}.p0`),
    /// whose text reports `max_turns`.
    fn test_turn_limit_final_message(message_id: String, max_turns: usize) -> crate::Message {
        crate::Message {
            id: message_id.clone(),
            role: crate::MessageRole::System,
            parts: crate::shared_parts(vec![crate::Part {
                id: format!("{message_id}.p0"),
                kind: crate::PartKind::Error,
                content: format!("Turn limit reached ({max_turns}) before a final test response."),
                attachment: None,
                tool_call_id: None,
                tool_name: None,
                tool_replay: None,
                prune_state: crate::PruneState::Intact,
                reasoning_meta: None,
                response_meta: None,
            }]),
            origin: None,
        }
    }
}