1use lash_core::sansio::{
2 CheckpointResumeAction, CompletedToolCall, ProtocolDriverHandle, WaitingExecState,
3 WaitingLlmState,
4};
5use lash_core::session_model::{
6 ConversationRecord, Message, SessionEvent, SessionEventRecord, fresh_message_id,
7 make_error_event,
8};
9use lash_core::{
10 CheckpointKind, DriverAction, DriverContextView, ExecResponse, LlmOutputPart, LlmResponse,
11 ToolCallRecord, TurnFinish, TurnOutcome, TurnStop, append_assistant_text_part,
12 normalized_response_parts,
13};
14use lash_rlm_types::{RlmDiagnosticEvent, RlmProtocolEvent, RlmTermination, RlmTrajectoryEntry};
15use serde_json::Value;
16
17use crate::projection::rlm_protocol_event;
18use crate::rlm_support::decode_rlm_termination_options;
19
20use super::actions::{invalid_driver_state_actions, invalid_turn_options_actions};
21use super::fence::extract_first_lashlang_fence;
22use super::finish::{
23 internal_assistant_prose_message, submit_required_reminder_message,
24 submit_schema_mismatch_message, turn_limit_final_message, validate_finish_value,
25};
26use super::state::{RlmDriverState, decode_rlm_driver_state, rlm_driver_state};
27
/// Stateless driver for RLM turns.
///
/// The type carries no data; per-turn state travels through the `driver_state` values attached
/// to the actions the driver returns. The common traits are derived so the handle can be freely
/// copied, defaulted and logged.
#[derive(Debug, Clone, Copy, Default)]
pub struct RlmDriver;
29
30impl ProtocolDriverHandle<lash_core::HostTurnProtocol> for RlmDriver {
31 fn prepare_protocol_iteration(&self, ctx: DriverContextView<'_>) -> Vec<DriverAction> {
32 if let Err(err) = decode_rlm_termination_options(ctx.termination()) {
33 return invalid_turn_options_actions(err);
34 }
35 vec![DriverAction::StartLlm {
36 request: ctx.project_llm_request(false),
37 driver_state: Some(rlm_driver_state(RlmDriverState::default())),
38 }]
39 }
40
41 fn handle_llm_success(
42 &self,
43 ctx: DriverContextView<'_>,
44 mut waiting: WaitingLlmState<lash_core::HostTurnProtocol>,
45 llm_response: LlmResponse,
46 _text_streamed: bool,
47 ) -> Vec<DriverAction> {
48 let mut actions = vec![DriverAction::Emit(SessionEvent::LlmResponse {
49 protocol_iteration: ctx.protocol_iteration(),
50 content: llm_response.full_text.clone(),
51 duration_ms: 0,
52 })];
53
54 let mut assistant_text = String::new();
55 let mut reasoning_text = String::new();
56 for part in normalized_response_parts(&llm_response) {
57 match part {
58 LlmOutputPart::Text { text, .. } => {
59 append_assistant_text_part(&mut assistant_text, &text);
60 }
61 LlmOutputPart::Reasoning { text, replay } => {
62 let reasoning = if text.trim().is_empty() {
63 replay
64 .as_ref()
65 .map(|meta| meta.summary.join("\n\n"))
66 .unwrap_or_default()
67 } else {
68 text
69 };
70 append_assistant_text_part(&mut reasoning_text, &reasoning);
71 }
72 LlmOutputPart::ToolCall { .. } => {}
73 }
74 }
75
76 if assistant_text.trim().is_empty() && reasoning_text.trim().is_empty() {
77 actions.push(DriverAction::Emit(make_error_event(
78 "llm_provider",
79 Some("empty_response"),
80 "Model returned no assistant text.",
81 None,
82 )));
83 actions.push(DriverAction::Finish(TurnOutcome::Stopped(
84 TurnStop::ProviderError,
85 )));
86 return actions;
87 }
88
89 let extraction = extract_first_lashlang_fence(&assistant_text);
90 let Some(fence) = extraction else {
91 let termination = match decode_rlm_termination_options(ctx.termination()) {
92 Ok(termination) => termination,
93 Err(err) => return invalid_turn_options_actions(err),
94 };
95 if matches!(termination, RlmTermination::ProseOrSubmit) {
96 actions.push(DriverAction::AppendEvents(vec![diagnostic_event(
97 "llm_extraction",
98 serde_json::json!({
99 "found_lashlang_fence": false,
100 "prose_only_ends_turn": true,
101 "assistant_text_chars": assistant_text.chars().count(),
102 "reasoning_chars": reasoning_text.chars().count(),
103 "finalization_reason": "prose_or_submit",
104 }),
105 )]));
106 actions.push(DriverAction::StartCheckpoint {
107 checkpoint: CheckpointKind::BeforeCompletion,
108 on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
109 TurnFinish::AssistantMessage {
110 text: assistant_text.clone(),
111 },
112 )),
113 });
114 return actions;
115 }
116 let RlmTermination::SubmitRequired { schema } = termination else {
117 unreachable!("ProseOrSubmit returned above");
118 };
119 actions.push(DriverAction::AppendEvents(vec![diagnostic_event(
120 "llm_extraction",
121 serde_json::json!({
122 "found_lashlang_fence": false,
123 "prose_only_ends_turn": false,
124 "assistant_text_chars": assistant_text.chars().count(),
125 "reasoning_chars": reasoning_text.chars().count(),
126 "finalization_reason": "submit_required",
127 }),
128 )]));
129 let mut events = Vec::new();
130 if !assistant_text.trim().is_empty() {
131 events.push(conversation_event(internal_assistant_prose_message(
132 assistant_text,
133 )));
134 }
135 events.push(conversation_event(submit_required_reminder_message(
136 schema.is_some(),
137 )));
138 if let Err(err) =
139 continue_or_stop_after_nonterminal(&ctx, &mut actions, Vec::new(), events)
140 {
141 return invalid_turn_options_actions(err);
142 }
143 return actions;
144 };
145
146 actions.push(DriverAction::AppendEvents(vec![diagnostic_event(
147 "llm_extraction",
148 serde_json::json!({
149 "found_lashlang_fence": true,
150 "had_extra_fences": fence.had_extra_fences,
151 "code_chars": fence.code.chars().count(),
152 "assistant_text_chars": assistant_text.chars().count(),
153 "reasoning_chars": reasoning_text.chars().count(),
154 "decision": "execute_lashlang",
155 }),
156 )]));
157
158 let Some(raw_state) = waiting.take_driver_state() else {
159 return invalid_driver_state_actions("missing RLM driver state".to_string());
160 };
161 let mut state = match decode_rlm_driver_state(raw_state) {
162 Ok(state) => state,
163 Err(err) => return invalid_driver_state_actions(err),
164 };
165 state.executed_code = Some(fence.code.clone());
166 state.reasoning = combine_reasoning_and_text(&reasoning_text, &assistant_text);
167
168 actions.push(DriverAction::Emit(SessionEvent::Message {
172 text: fence.code.clone(),
173 kind: "lashlang_code".to_string(),
174 }));
175 actions.push(DriverAction::StartExec {
176 code: fence.code,
177 driver_state: rlm_driver_state(state),
178 });
179 actions
180 }
181
182 fn handle_tool_results(
183 &self,
184 _ctx: DriverContextView<'_>,
185 _completed: Vec<CompletedToolCall>,
186 ) -> Vec<DriverAction> {
187 Vec::new()
188 }
189
190 fn handle_exec_result(
191 &self,
192 ctx: DriverContextView<'_>,
193 waiting: WaitingExecState<lash_core::HostTurnProtocol>,
194 result: Result<ExecResponse, String>,
195 ) -> Vec<DriverAction> {
196 let mut state = match decode_rlm_driver_state(waiting.into_driver_state()) {
197 Ok(state) => state,
198 Err(err) => return invalid_driver_state_actions(err),
199 };
200 let mut actions = Vec::new();
201
202 match result {
203 Ok(response) => {
204 let terminal_outcome = response
205 .tool_calls
206 .iter()
207 .find_map(terminal_outcome_from_tool_result);
208 state.images.extend(response.printed_images);
209 for observation in response.observations {
210 if !observation.is_empty() {
211 state.output.push(observation);
212 }
213 }
214 if let Some(raw_error) = response.error {
215 state.exec_error = Some(raw_error);
216 }
217 if let Some(finish_value) = response.terminal_finish {
218 state.terminal_finish = Some(finish_value);
219 }
220 if let Some(outcome) = terminal_outcome {
221 actions.push(DriverAction::AppendEvents(vec![trajectory_event(
222 trajectory_entry(ctx.protocol_iteration(), &state, None, None),
223 )]));
224 actions.push(DriverAction::StartCheckpoint {
225 checkpoint: CheckpointKind::BeforeCompletion,
226 on_empty: CheckpointResumeAction::Finish(outcome),
227 });
228 return actions;
229 }
230 }
231 Err(error) => {
232 state.exec_error = Some(error);
233 }
234 }
235
236 if let Some(finish_value) = &state.terminal_finish {
237 let termination = match decode_rlm_termination_options(ctx.termination()) {
241 Ok(termination) => termination,
242 Err(err) => return invalid_turn_options_actions(err),
243 };
244 if let RlmTermination::SubmitRequired {
245 schema: Some(schema),
246 } = termination
247 {
248 if let Err(error_text) = validate_finish_value(finish_value, &schema) {
249 if let Err(err) = continue_or_stop_after_nonterminal(
250 &ctx,
251 &mut actions,
252 vec![trajectory_event(trajectory_entry(
253 ctx.protocol_iteration(),
254 &state,
255 Some(error_text.clone()),
256 None,
257 ))],
258 vec![conversation_event(submit_schema_mismatch_message(
259 &error_text,
260 ))],
261 ) {
262 return invalid_turn_options_actions(err);
263 }
264 return actions;
265 }
266 }
267
268 actions.push(DriverAction::AppendEvents(vec![trajectory_event(
269 trajectory_entry(
270 ctx.protocol_iteration(),
271 &state,
272 None,
273 Some(finish_value.clone()),
274 ),
275 )]));
276 actions.push(DriverAction::StartCheckpoint {
277 checkpoint: CheckpointKind::BeforeCompletion,
278 on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
279 TurnFinish::SubmittedValue {
280 value: finish_value.clone(),
281 },
282 )),
283 });
284 return actions;
285 }
286
287 if let Err(err) = continue_or_stop_after_nonterminal(
288 &ctx,
289 &mut actions,
290 vec![trajectory_event(trajectory_entry(
291 ctx.protocol_iteration(),
292 &state,
293 None,
294 None,
295 ))],
296 Vec::new(),
297 ) {
298 return invalid_turn_options_actions(err);
299 }
300 actions
301 }
302}
303
304fn continue_or_stop_after_nonterminal(
305 ctx: &DriverContextView<'_>,
306 actions: &mut Vec<DriverAction>,
307 durable_events: Vec<SessionEventRecord>,
308 retry_events: Vec<SessionEventRecord>,
309) -> Result<(), String> {
310 if !durable_events.is_empty() {
311 actions.push(DriverAction::AppendEvents(durable_events));
312 }
313 actions.push(DriverAction::AdvanceProtocolIteration);
314
315 if ctx.should_force_exit_after_grace_turn() {
316 actions.push(DriverAction::Finish(TurnOutcome::Stopped(
317 TurnStop::MaxTurns,
318 )));
319 return Ok(());
320 }
321
322 let next_protocol_iteration = ctx.protocol_iteration() + 1;
323 let reached_turn_limit = ctx
324 .max_turns()
325 .is_some_and(|max_turns| next_protocol_iteration >= ctx.protocol_run_offset() + max_turns);
326 if reached_turn_limit {
327 match decode_rlm_termination_options(ctx.termination())? {
328 RlmTermination::SubmitRequired { .. } => {
329 actions.push(DriverAction::Finish(TurnOutcome::Stopped(
330 TurnStop::MaxTurns,
331 )));
332 return Ok(());
333 }
334 RlmTermination::ProseOrSubmit => {
335 if let Some(max_turns) = ctx.max_turns() {
336 actions.push(DriverAction::ScheduleTurnLimitFinal {
337 message: turn_limit_final_message(fresh_message_id(), max_turns),
338 });
339 }
340 }
341 }
342 } else if !retry_events.is_empty() {
343 actions.push(DriverAction::AppendEvents(retry_events));
344 }
345
346 actions.push(DriverAction::StartCheckpoint {
347 checkpoint: CheckpointKind::AfterWork,
348 on_empty: CheckpointResumeAction::PrepareIteration,
349 });
350 Ok(())
351}
352
353fn terminal_outcome_from_tool_result(record: &ToolCallRecord) -> Option<TurnOutcome> {
354 if !record.output.is_success() {
355 return None;
356 }
357 match record.output.control.as_ref()? {
358 lash_core::ToolControl::SwitchAgentFrame { frame_id, .. }
359 if !frame_id.trim().is_empty() =>
360 {
361 Some(TurnOutcome::AgentFrameSwitch {
362 frame_id: frame_id.clone(),
363 })
364 }
365 lash_core::ToolControl::Finish { value } => {
366 Some(TurnOutcome::Finished(TurnFinish::ToolValue {
367 tool_name: record.tool.clone(),
368 value: value.to_json_value(),
369 }))
370 }
371 lash_core::ToolControl::Fail { failure } => {
372 Some(TurnOutcome::Stopped(TurnStop::ToolError {
373 tool_name: record.tool.clone(),
374 value: failure.to_json_value(),
375 }))
376 }
377 lash_core::ToolControl::SwitchAgentFrame { .. } => None,
378 }
379}
380
381fn trajectory_entry(
382 protocol_iteration: usize,
383 state: &RlmDriverState,
384 validation_error: Option<String>,
385 final_output: Option<Value>,
386) -> RlmTrajectoryEntry {
387 RlmTrajectoryEntry {
388 id: format!("rlm_step_{protocol_iteration}"),
389 protocol_iteration,
390 reasoning: state.reasoning.clone(),
391 code: state.executed_code.clone().unwrap_or_default(),
392 output: state.output.clone(),
393 images: state.images.clone(),
394 error: validation_error.or_else(|| state.exec_error.clone()),
395 final_output,
396 }
397}
398
399fn conversation_event(message: Message) -> SessionEventRecord {
400 SessionEventRecord::Conversation(ConversationRecord::from_message(message))
401}
402
403fn trajectory_event(entry: RlmTrajectoryEntry) -> SessionEventRecord {
404 SessionEventRecord::Protocol(rlm_protocol_event(RlmProtocolEvent::RlmTrajectoryEntry(
405 entry,
406 )))
407}
408
409fn diagnostic_event(phase: &str, payload: Value) -> SessionEventRecord {
410 SessionEventRecord::Protocol(rlm_protocol_event(RlmProtocolEvent::RlmDiagnostic(
411 RlmDiagnosticEvent {
412 phase: phase.to_string(),
413 payload,
414 },
415 )))
416}
417
/// Joins reasoning and assistant text into a single string.
///
/// An input counts as absent when it is empty or whitespace-only. With both present the result
/// is `reasoning`, a blank line, then `text`; with one present that input is returned unchanged
/// (not trimmed); with neither the result is empty.
fn combine_reasoning_and_text(reasoning: &str, text: &str) -> String {
    let has_reasoning = !reasoning.trim().is_empty();
    let has_text = !text.trim().is_empty();

    if has_reasoning && has_text {
        format!("{reasoning}\n\n{text}")
    } else if has_reasoning {
        reasoning.to_string()
    } else if has_text {
        text.to_string()
    } else {
        String::new()
    }
}