Skip to main content

lash_protocol_rlm/protocol/
driver.rs

1use lash_core::sansio::{
2    CheckpointResumeAction, CompletedToolCall, ProtocolDriverHandle, WaitingExecState,
3    WaitingLlmState,
4};
5use lash_core::session_model::{
6    ConversationRecord, Message, SessionEvent, SessionEventRecord, fresh_message_id,
7    make_error_event,
8};
9use lash_core::{
10    CheckpointKind, DriverAction, DriverContextView, ExecResponse, LlmOutputPart, LlmResponse,
11    ToolCallRecord, TurnFinish, TurnOutcome, TurnStop, append_assistant_text_part,
12    normalized_response_parts,
13};
14use lash_rlm_types::{RlmDiagnosticEvent, RlmProtocolEvent, RlmTermination, RlmTrajectoryEntry};
15use serde_json::Value;
16
17use crate::projection::rlm_protocol_event;
18use crate::rlm_support::decode_rlm_termination_options;
19
20use super::actions::{invalid_driver_state_actions, invalid_turn_options_actions};
21use super::fence::extract_first_lashlang_fence;
22use super::finish::{
23    internal_assistant_prose_message, submit_required_reminder_message,
24    submit_schema_mismatch_message, turn_limit_final_message, validate_finish_value,
25};
26use super::state::{RlmDriverState, decode_rlm_driver_state, rlm_driver_state};
27
/// Field-less driver type; its behavior lives in the
/// `ProtocolDriverHandle<lash_core::HostTurnProtocol>` implementation in this file.
pub struct RlmDriver;
29
30impl ProtocolDriverHandle<lash_core::HostTurnProtocol> for RlmDriver {
31    fn prepare_protocol_iteration(&self, ctx: DriverContextView<'_>) -> Vec<DriverAction> {
32        if let Err(err) = decode_rlm_termination_options(ctx.termination()) {
33            return invalid_turn_options_actions(err);
34        }
35        vec![DriverAction::StartLlm {
36            request: ctx.project_llm_request(false),
37            driver_state: Some(rlm_driver_state(RlmDriverState::default())),
38        }]
39    }
40
41    fn handle_llm_success(
42        &self,
43        ctx: DriverContextView<'_>,
44        mut waiting: WaitingLlmState<lash_core::HostTurnProtocol>,
45        llm_response: LlmResponse,
46        _text_streamed: bool,
47    ) -> Vec<DriverAction> {
48        let mut actions = vec![DriverAction::Emit(SessionEvent::LlmResponse {
49            protocol_iteration: ctx.protocol_iteration(),
50            content: llm_response.full_text.clone(),
51            duration_ms: 0,
52        })];
53
54        let mut assistant_text = String::new();
55        let mut reasoning_text = String::new();
56        for part in normalized_response_parts(&llm_response) {
57            match part {
58                LlmOutputPart::Text { text, .. } => {
59                    append_assistant_text_part(&mut assistant_text, &text);
60                }
61                LlmOutputPart::Reasoning { text, replay } => {
62                    let reasoning = if text.trim().is_empty() {
63                        replay
64                            .as_ref()
65                            .map(|meta| meta.summary.join("\n\n"))
66                            .unwrap_or_default()
67                    } else {
68                        text
69                    };
70                    append_assistant_text_part(&mut reasoning_text, &reasoning);
71                }
72                LlmOutputPart::ToolCall { .. } => {}
73            }
74        }
75
76        if assistant_text.trim().is_empty() && reasoning_text.trim().is_empty() {
77            actions.push(DriverAction::Emit(make_error_event(
78                "llm_provider",
79                Some("empty_response"),
80                "Model returned no assistant text.",
81                None,
82            )));
83            actions.push(DriverAction::Finish(TurnOutcome::Stopped(
84                TurnStop::ProviderError,
85            )));
86            return actions;
87        }
88
89        let extraction = extract_first_lashlang_fence(&assistant_text);
90        let Some(fence) = extraction else {
91            let termination = match decode_rlm_termination_options(ctx.termination()) {
92                Ok(termination) => termination,
93                Err(err) => return invalid_turn_options_actions(err),
94            };
95            if matches!(termination, RlmTermination::ProseOrSubmit) {
96                actions.push(DriverAction::AppendEvents(vec![diagnostic_event(
97                    "llm_extraction",
98                    serde_json::json!({
99                        "found_lashlang_fence": false,
100                        "prose_only_ends_turn": true,
101                        "assistant_text_chars": assistant_text.chars().count(),
102                        "reasoning_chars": reasoning_text.chars().count(),
103                        "finalization_reason": "prose_or_submit",
104                    }),
105                )]));
106                actions.push(DriverAction::StartCheckpoint {
107                    checkpoint: CheckpointKind::BeforeCompletion,
108                    on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
109                        TurnFinish::AssistantMessage {
110                            text: assistant_text.clone(),
111                        },
112                    )),
113                });
114                return actions;
115            }
116            let RlmTermination::SubmitRequired { schema } = termination else {
117                unreachable!("ProseOrSubmit returned above");
118            };
119            actions.push(DriverAction::AppendEvents(vec![diagnostic_event(
120                "llm_extraction",
121                serde_json::json!({
122                    "found_lashlang_fence": false,
123                    "prose_only_ends_turn": false,
124                    "assistant_text_chars": assistant_text.chars().count(),
125                    "reasoning_chars": reasoning_text.chars().count(),
126                    "finalization_reason": "submit_required",
127                }),
128            )]));
129            let mut events = Vec::new();
130            if !assistant_text.trim().is_empty() {
131                events.push(conversation_event(internal_assistant_prose_message(
132                    assistant_text,
133                )));
134            }
135            events.push(conversation_event(submit_required_reminder_message(
136                schema.is_some(),
137            )));
138            if let Err(err) =
139                continue_or_stop_after_nonterminal(&ctx, &mut actions, Vec::new(), events)
140            {
141                return invalid_turn_options_actions(err);
142            }
143            return actions;
144        };
145
146        actions.push(DriverAction::AppendEvents(vec![diagnostic_event(
147            "llm_extraction",
148            serde_json::json!({
149                "found_lashlang_fence": true,
150                "had_extra_fences": fence.had_extra_fences,
151                "code_chars": fence.code.chars().count(),
152                "assistant_text_chars": assistant_text.chars().count(),
153                "reasoning_chars": reasoning_text.chars().count(),
154                "decision": "execute_lashlang",
155            }),
156        )]));
157
158        let Some(raw_state) = waiting.take_driver_state() else {
159            return invalid_driver_state_actions("missing RLM driver state".to_string());
160        };
161        let mut state = match decode_rlm_driver_state(raw_state) {
162            Ok(state) => state,
163            Err(err) => return invalid_driver_state_actions(err),
164        };
165        state.executed_code = Some(fence.code.clone());
166        state.reasoning = combine_reasoning_and_text(&reasoning_text, &assistant_text);
167
168        // Emit the raw lashlang source as a `Message` with kind
169        // `lashlang_code` so the CLI can reveal it in the full-expand
170        // view (Alt+O) above the tool activities it produced.
171        actions.push(DriverAction::Emit(SessionEvent::Message {
172            text: fence.code.clone(),
173            kind: "lashlang_code".to_string(),
174        }));
175        actions.push(DriverAction::StartExec {
176            code: fence.code,
177            driver_state: rlm_driver_state(state),
178        });
179        actions
180    }
181
182    fn handle_tool_results(
183        &self,
184        _ctx: DriverContextView<'_>,
185        _completed: Vec<CompletedToolCall>,
186    ) -> Vec<DriverAction> {
187        Vec::new()
188    }
189
190    fn handle_exec_result(
191        &self,
192        ctx: DriverContextView<'_>,
193        waiting: WaitingExecState<lash_core::HostTurnProtocol>,
194        result: Result<ExecResponse, String>,
195    ) -> Vec<DriverAction> {
196        let mut state = match decode_rlm_driver_state(waiting.into_driver_state()) {
197            Ok(state) => state,
198            Err(err) => return invalid_driver_state_actions(err),
199        };
200        let mut actions = Vec::new();
201
202        match result {
203            Ok(response) => {
204                let terminal_outcome = response
205                    .tool_calls
206                    .iter()
207                    .find_map(terminal_outcome_from_tool_result);
208                state.images.extend(response.printed_images);
209                for observation in response.observations {
210                    if !observation.is_empty() {
211                        state.output.push(observation);
212                    }
213                }
214                if let Some(raw_error) = response.error {
215                    state.exec_error = Some(raw_error);
216                }
217                if let Some(finish_value) = response.terminal_finish {
218                    state.terminal_finish = Some(finish_value);
219                }
220                if let Some(outcome) = terminal_outcome {
221                    actions.push(DriverAction::AppendEvents(vec![trajectory_event(
222                        trajectory_entry(ctx.protocol_iteration(), &state, None, None),
223                    )]));
224                    actions.push(DriverAction::StartCheckpoint {
225                        checkpoint: CheckpointKind::BeforeCompletion,
226                        on_empty: CheckpointResumeAction::Finish(outcome),
227                    });
228                    return actions;
229                }
230            }
231            Err(error) => {
232                state.exec_error = Some(error);
233            }
234        }
235
236        if let Some(finish_value) = &state.terminal_finish {
237            // Typed-RLM: validate against the declared schema. If it fails,
238            // surface the error to the model and loop; otherwise fall
239            // through to the shared terminate-with-value path below.
240            let termination = match decode_rlm_termination_options(ctx.termination()) {
241                Ok(termination) => termination,
242                Err(err) => return invalid_turn_options_actions(err),
243            };
244            if let RlmTermination::SubmitRequired {
245                schema: Some(schema),
246            } = termination
247            {
248                if let Err(error_text) = validate_finish_value(finish_value, &schema) {
249                    if let Err(err) = continue_or_stop_after_nonterminal(
250                        &ctx,
251                        &mut actions,
252                        vec![trajectory_event(trajectory_entry(
253                            ctx.protocol_iteration(),
254                            &state,
255                            Some(error_text.clone()),
256                            None,
257                        ))],
258                        vec![conversation_event(submit_schema_mismatch_message(
259                            &error_text,
260                        ))],
261                    ) {
262                        return invalid_turn_options_actions(err);
263                    }
264                    return actions;
265                }
266            }
267
268            actions.push(DriverAction::AppendEvents(vec![trajectory_event(
269                trajectory_entry(
270                    ctx.protocol_iteration(),
271                    &state,
272                    None,
273                    Some(finish_value.clone()),
274                ),
275            )]));
276            actions.push(DriverAction::StartCheckpoint {
277                checkpoint: CheckpointKind::BeforeCompletion,
278                on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
279                    TurnFinish::SubmittedValue {
280                        value: finish_value.clone(),
281                    },
282                )),
283            });
284            return actions;
285        }
286
287        if let Err(err) = continue_or_stop_after_nonterminal(
288            &ctx,
289            &mut actions,
290            vec![trajectory_event(trajectory_entry(
291                ctx.protocol_iteration(),
292                &state,
293                None,
294                None,
295            ))],
296            Vec::new(),
297        ) {
298            return invalid_turn_options_actions(err);
299        }
300        actions
301    }
302}
303
304fn continue_or_stop_after_nonterminal(
305    ctx: &DriverContextView<'_>,
306    actions: &mut Vec<DriverAction>,
307    durable_events: Vec<SessionEventRecord>,
308    retry_events: Vec<SessionEventRecord>,
309) -> Result<(), String> {
310    if !durable_events.is_empty() {
311        actions.push(DriverAction::AppendEvents(durable_events));
312    }
313    actions.push(DriverAction::AdvanceProtocolIteration);
314
315    if ctx.should_force_exit_after_grace_turn() {
316        actions.push(DriverAction::Finish(TurnOutcome::Stopped(
317            TurnStop::MaxTurns,
318        )));
319        return Ok(());
320    }
321
322    let next_protocol_iteration = ctx.protocol_iteration() + 1;
323    let reached_turn_limit = ctx
324        .max_turns()
325        .is_some_and(|max_turns| next_protocol_iteration >= ctx.protocol_run_offset() + max_turns);
326    if reached_turn_limit {
327        match decode_rlm_termination_options(ctx.termination())? {
328            RlmTermination::SubmitRequired { .. } => {
329                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
330                    TurnStop::MaxTurns,
331                )));
332                return Ok(());
333            }
334            RlmTermination::ProseOrSubmit => {
335                if let Some(max_turns) = ctx.max_turns() {
336                    actions.push(DriverAction::ScheduleTurnLimitFinal {
337                        message: turn_limit_final_message(fresh_message_id(), max_turns),
338                    });
339                }
340            }
341        }
342    } else if !retry_events.is_empty() {
343        actions.push(DriverAction::AppendEvents(retry_events));
344    }
345
346    actions.push(DriverAction::StartCheckpoint {
347        checkpoint: CheckpointKind::AfterWork,
348        on_empty: CheckpointResumeAction::PrepareIteration,
349    });
350    Ok(())
351}
352
353fn terminal_outcome_from_tool_result(record: &ToolCallRecord) -> Option<TurnOutcome> {
354    if !record.output.is_success() {
355        return None;
356    }
357    match record.output.control.as_ref()? {
358        lash_core::ToolControl::SwitchAgentFrame { frame_id, .. }
359            if !frame_id.trim().is_empty() =>
360        {
361            Some(TurnOutcome::AgentFrameSwitch {
362                frame_id: frame_id.clone(),
363            })
364        }
365        lash_core::ToolControl::Finish { value } => {
366            Some(TurnOutcome::Finished(TurnFinish::ToolValue {
367                tool_name: record.tool.clone(),
368                value: value.to_json_value(),
369            }))
370        }
371        lash_core::ToolControl::Fail { failure } => {
372            Some(TurnOutcome::Stopped(TurnStop::ToolError {
373                tool_name: record.tool.clone(),
374                value: failure.to_json_value(),
375            }))
376        }
377        lash_core::ToolControl::SwitchAgentFrame { .. } => None,
378    }
379}
380
381fn trajectory_entry(
382    protocol_iteration: usize,
383    state: &RlmDriverState,
384    validation_error: Option<String>,
385    final_output: Option<Value>,
386) -> RlmTrajectoryEntry {
387    RlmTrajectoryEntry {
388        id: format!("rlm_step_{protocol_iteration}"),
389        protocol_iteration,
390        reasoning: state.reasoning.clone(),
391        code: state.executed_code.clone().unwrap_or_default(),
392        output: state.output.clone(),
393        images: state.images.clone(),
394        error: validation_error.or_else(|| state.exec_error.clone()),
395        final_output,
396    }
397}
398
399fn conversation_event(message: Message) -> SessionEventRecord {
400    SessionEventRecord::Conversation(ConversationRecord::from_message(message))
401}
402
403fn trajectory_event(entry: RlmTrajectoryEntry) -> SessionEventRecord {
404    SessionEventRecord::Protocol(rlm_protocol_event(RlmProtocolEvent::RlmTrajectoryEntry(
405        entry,
406    )))
407}
408
409fn diagnostic_event(phase: &str, payload: Value) -> SessionEventRecord {
410    SessionEventRecord::Protocol(rlm_protocol_event(RlmProtocolEvent::RlmDiagnostic(
411        RlmDiagnosticEvent {
412            phase: phase.to_string(),
413            payload,
414        },
415    )))
416}
417
/// Joins reasoning and assistant text into one string.
///
/// Parts that are empty or whitespace-only are skipped. The remaining parts are
/// kept verbatim (not trimmed), reasoning first, separated by a blank line.
/// Returns an empty string when both parts are blank.
fn combine_reasoning_and_text(reasoning: &str, text: &str) -> String {
    let non_blank_parts: Vec<&str> = [reasoning, text]
        .iter()
        .copied()
        .filter(|part| !part.trim().is_empty())
        .collect();
    non_blank_parts.join("\n\n")
}