Skip to main content

lash/
turn.rs

1use crate::support::*;
2
3pub use lash_core::{AssistantOutput, TurnIssue};
4
5/// The two internal event sinks threaded through the turn-execution helpers.
6///
7/// `events` is the raw `SessionEvent` stream (the lower-level escape hatch,
8/// reachable from app code only via [`TurnBuilder::advanced`]); `turn_events`
9/// is the semantic [`TurnActivity`] stream used by the primary builder API.
10/// Bundling them keeps the internal turn fns to a single sink parameter.
11#[derive(Clone, Copy, Default)]
12pub(crate) struct TurnSinks<'a> {
13    events: Option<&'a dyn EventSink>,
14    turn_events: Option<&'a dyn TurnActivitySink>,
15}
16
17impl<'a> TurnSinks<'a> {
18    pub(crate) fn turn(events: &'a dyn TurnActivitySink) -> Self {
19        Self {
20            events: None,
21            turn_events: Some(events),
22        }
23    }
24
25    pub(crate) fn session(events: &'a dyn EventSink) -> Self {
26        Self {
27            events: Some(events),
28            turn_events: None,
29        }
30    }
31
32    fn events(&self) -> Option<&'a dyn EventSink> {
33        self.events
34    }
35
36    fn turn_events(&self) -> Option<&'a dyn TurnActivitySink> {
37        self.turn_events
38    }
39}
40
41pub struct TurnBuilder {
42    pub(crate) runtime: RuntimeHandle,
43    pub(crate) active_plugins: Vec<ActivePluginBinding>,
44    pub(crate) input: TurnInput,
45    pub(crate) cancel: CancellationToken,
46    pub(crate) protocol_turn_options: Option<ProtocolTurnOptions>,
47    pub(crate) provider: Option<ProviderHandle>,
48    pub(crate) model: Option<lash_core::ModelSpec>,
49}
50
51impl TurnBuilder {
52    pub fn cancel(mut self, cancel: CancellationToken) -> Self {
53        self.cancel = cancel;
54        self
55    }
56
57    pub fn protocol_turn_options(mut self, options: ProtocolTurnOptions) -> Self {
58        self.protocol_turn_options = Some(options);
59        self
60    }
61
62    pub fn provider(mut self, provider: ProviderHandle) -> Self {
63        self.provider = Some(provider);
64        self
65    }
66
67    pub fn model(mut self, model: lash_core::ModelSpec) -> Self {
68        self.model = Some(model);
69        self
70    }
71
72    pub fn prompt_template(mut self, template: PromptTemplate) -> Self {
73        self.input.turn_context.set_prompt_template(template);
74        self
75    }
76
77    pub fn prompt_contribution(mut self, contribution: PromptContribution) -> Self {
78        self.input
79            .turn_context
80            .add_prompt_contribution(contribution);
81        self
82    }
83
84    pub fn replace_prompt_slot(
85        mut self,
86        slot: PromptSlot,
87        contributions: impl IntoIterator<Item = PromptContribution>,
88    ) -> Self {
89        self.input
90            .turn_context
91            .replace_prompt_slot(slot, contributions);
92        self
93    }
94
95    pub fn clear_prompt_slot(mut self, slot: PromptSlot) -> Self {
96        self.input.turn_context.clear_prompt_slot(slot);
97        self
98    }
99
100    pub fn prompt_layer(mut self, layer: PromptLayer) -> Self {
101        self.input.turn_context.set_prompt_layer(layer);
102        self
103    }
104
105    /// Attach typed per-turn input for an activated plugin binding.
106    ///
107    /// This is the generic primitive. Plugin crates should usually wrap it in a
108    /// domain extension trait such as `.with_tone(tone)` or `.with_board(board)`
109    /// so application code stays typed in its own vocabulary.
110    pub fn with_plugin_input<P: PluginBinding>(mut self, input: P::Input) -> Self {
111        self.input.turn_context.insert_plugin_input(P::ID, input);
112        self
113    }
114
115    pub async fn run(self) -> Result<TurnOutput> {
116        let collector = RunActivityCollector::default();
117        let result = self.stream(&collector).await?;
118        Ok(TurnOutput {
119            result,
120            activities: collector.into_activities(),
121        })
122    }
123
124    pub async fn run_with_effect_scope(
125        self,
126        scoped_effect_controller: ScopedEffectController<'_>,
127    ) -> Result<TurnOutput> {
128        let collector = RunActivityCollector::default();
129        let result = self
130            .stream_with_effect_scope(&collector, scoped_effect_controller)
131            .await?;
132        Ok(TurnOutput {
133            result,
134            activities: collector.into_activities(),
135        })
136    }
137
138    pub async fn collect_with(self, events: &dyn TurnActivitySink) -> Result<TurnOutput> {
139        let collector = RunActivityCollector::default();
140        let fanout = BorrowedTurnActivityFanout {
141            live: events,
142            collector: &collector,
143        };
144        let result = self.stream(&fanout).await?;
145        Ok(TurnOutput {
146            result,
147            activities: collector.into_activities(),
148        })
149    }
150
151    /// Access lower-level turn execution that bypasses the semantic
152    /// [`TurnActivity`] tier.
153    pub fn advanced(self) -> AdvancedTurn {
154        AdvancedTurn { builder: self }
155    }
156
157    pub(crate) fn prepare(mut self) -> Result<(RuntimeHandle, TurnInput, CancellationToken)> {
158        if let Some(options) = self.protocol_turn_options {
159            self.input.protocol_turn_options = Some(options);
160        }
161        if let Some(provider) = self.provider {
162            self.input.turn_context.set_provider(provider);
163        }
164        if let Some(model) = self.model {
165            self.input.turn_context.set_model(model);
166        }
167        validate_required_plugin_inputs(&self.active_plugins, &self.input)?;
168        Ok((self.runtime, self.input, self.cancel))
169    }
170
171    pub async fn stream(self, events: &dyn TurnActivitySink) -> Result<TurnResult> {
172        let (runtime, input, cancel) = self.prepare()?;
173        stream_prepared_turn(&runtime, input, TurnSinks::turn(events), None, cancel).await
174    }
175
176    pub async fn stream_with_effect_scope(
177        self,
178        events: &dyn TurnActivitySink,
179        scoped_effect_controller: ScopedEffectController<'_>,
180    ) -> Result<TurnResult> {
181        let (runtime, input, cancel) = self.prepare()?;
182        stream_prepared_turn(
183            &runtime,
184            input,
185            TurnSinks::turn(events),
186            Some(scoped_effect_controller),
187            cancel,
188        )
189        .await
190    }
191
192    pub fn into_stream(self) -> Result<TurnStream> {
193        let (runtime, input, cancel) = self.prepare()?;
194        let (tx, rx) = mpsc::channel(64);
195        let sink = ChannelTurnActivitySink { tx };
196        let completion = tokio::spawn(async move {
197            stream_prepared_turn(&runtime, input, TurnSinks::turn(&sink), None, cancel).await
198        });
199        Ok(TurnStream {
200            activities: rx,
201            completion,
202        })
203    }
204}
205
206/// Lower-level turn execution that exposes the raw `SessionEvent` stream.
207///
208/// Reachable via [`TurnBuilder::advanced`]. Most applications should use
209/// [`TurnBuilder::collect_with`] for semantic turn activity; benchmarks and
210/// diagnostics use this when they need the same session-event stream as the
211/// lower-level runtime trace.
212pub struct AdvancedTurn {
213    builder: TurnBuilder,
214}
215
216impl AdvancedTurn {
217    /// Run the turn while sending raw session events to `events`.
218    pub async fn collect_session_events_with(self, events: &dyn EventSink) -> Result<TurnResult> {
219        let (runtime, input, cancel) = self.builder.prepare()?;
220        stream_prepared_turn(&runtime, input, TurnSinks::session(events), None, cancel).await
221    }
222}
223
224pub struct TurnStream {
225    activities: mpsc::Receiver<Result<TurnActivity>>,
226    completion: JoinHandle<Result<TurnResult>>,
227}
228
229impl TurnStream {
230    pub async fn next_activity(&mut self) -> Option<Result<TurnActivity>> {
231        self.activities.recv().await
232    }
233
234    pub async fn finish(self) -> Result<TurnResult> {
235        self.completion.await.map_err(|err| {
236            EmbedError::Runtime(lash_core::RuntimeError::new(
237                RuntimeErrorCode::TurnStreamJoin,
238                format!("turn stream task failed: {err}"),
239            ))
240        })?
241    }
242}
243
244pub struct QueuedTurnBuilder {
245    pub(crate) runtime: RuntimeHandle,
246    pub(crate) cancel: CancellationToken,
247    pub(crate) batch_ids: Vec<String>,
248}
249
250impl QueuedTurnBuilder {
251    pub fn cancel(mut self, cancel: CancellationToken) -> Self {
252        self.cancel = cancel;
253        self
254    }
255
256    pub fn batch_ids(mut self, batch_ids: impl IntoIterator<Item = impl Into<String>>) -> Self {
257        self.batch_ids = batch_ids.into_iter().map(Into::into).collect();
258        self
259    }
260
261    pub async fn run(self) -> Result<Option<TurnOutput>> {
262        let collector = RunActivityCollector::default();
263        let Some(result) = self.stream(&collector).await? else {
264            return Ok(None);
265        };
266        Ok(Some(TurnOutput {
267            result,
268            activities: collector.into_activities(),
269        }))
270    }
271
272    pub async fn run_with_effect_scope(
273        self,
274        scoped_effect_controller: ScopedEffectController<'_>,
275    ) -> Result<Option<TurnOutput>> {
276        let collector = RunActivityCollector::default();
277        let Some(result) = self
278            .stream_with_effect_scope(&collector, scoped_effect_controller)
279            .await?
280        else {
281            return Ok(None);
282        };
283        Ok(Some(TurnOutput {
284            result,
285            activities: collector.into_activities(),
286        }))
287    }
288
289    pub async fn stream(self, events: &dyn TurnActivitySink) -> Result<Option<TurnResult>> {
290        let Self {
291            runtime,
292            cancel,
293            batch_ids,
294        } = self;
295        stream_next_queued_prepared_turn(
296            &runtime,
297            TurnSinks::turn(events),
298            None,
299            cancel,
300            &batch_ids,
301        )
302        .await
303    }
304
305    pub async fn stream_with_effect_scope(
306        self,
307        events: &dyn TurnActivitySink,
308        scoped_effect_controller: ScopedEffectController<'_>,
309    ) -> Result<Option<TurnResult>> {
310        let Self {
311            runtime,
312            cancel,
313            batch_ids,
314        } = self;
315        stream_next_queued_prepared_turn(
316            &runtime,
317            TurnSinks::turn(events),
318            Some(scoped_effect_controller),
319            cancel,
320            &batch_ids,
321        )
322        .await
323    }
324}
325
326pub(crate) async fn stream_next_queued_prepared_turn(
327    runtime: &RuntimeHandle,
328    sinks: TurnSinks<'_>,
329    scoped_effect_controller: Option<ScopedEffectController<'_>>,
330    cancel: CancellationToken,
331    batch_ids: &[String],
332) -> Result<Option<TurnResult>> {
333    let turn = Box::pin(stream_next_queued_prepared_assembled(
334        runtime,
335        sinks,
336        scoped_effect_controller,
337        cancel,
338        batch_ids,
339    ))
340    .await?;
341    Ok(turn.map(TurnResult::from_assembled))
342}
343
344pub(crate) async fn stream_next_queued_prepared_assembled(
345    runtime: &RuntimeHandle,
346    sinks: TurnSinks<'_>,
347    scoped_effect_controller: Option<ScopedEffectController<'_>>,
348    cancel: CancellationToken,
349    batch_ids: &[String],
350) -> Result<Option<AssembledTurn>> {
351    let writer_handle = runtime.writer();
352    let mut writer = writer_handle.lock().await;
353    let opts = turn_options(sinks, scoped_effect_controller, cancel);
354    let turn = if batch_ids.is_empty() {
355        writer.stream_next_queued_work(opts).await?
356    } else {
357        writer.stream_selected_queued_work(opts, batch_ids).await?
358    };
359    runtime.publish_from(&writer);
360    Ok(turn)
361}
362
363fn turn_options<'a>(
364    sinks: TurnSinks<'a>,
365    scoped_effect_controller: Option<ScopedEffectController<'a>>,
366    cancel: CancellationToken,
367) -> lash_core::TurnOptions<'a> {
368    let mut opts = lash_core::TurnOptions::new(cancel);
369    if let Some(events) = sinks.events() {
370        opts = opts.with_events(events);
371    }
372    if let Some(turn_events) = sinks.turn_events() {
373        opts = opts.with_turn_events(turn_events);
374    }
375    if let Some(scoped_effect_controller) = scoped_effect_controller {
376        opts = opts.with_scoped_effect_controller(scoped_effect_controller);
377    }
378    opts
379}
380
381struct ChannelTurnActivitySink {
382    tx: mpsc::Sender<Result<TurnActivity>>,
383}
384
385#[async_trait]
386impl TurnActivitySink for ChannelTurnActivitySink {
387    async fn emit(&self, activity: TurnActivity) {
388        let _ = self.tx.send(Ok(activity)).await;
389    }
390}
391fn validate_required_plugin_inputs(
392    active_plugins: &[ActivePluginBinding],
393    input: &TurnInput,
394) -> Result<()> {
395    for plugin in active_plugins {
396        if plugin.requires_turn_input && !input.turn_context.has_plugin_input(plugin.id) {
397            return Err(EmbedError::MissingPluginTurnInput {
398                plugin_id: plugin.id,
399            });
400        }
401    }
402    Ok(())
403}
404
405pub(crate) async fn stream_prepared_turn(
406    runtime: &RuntimeHandle,
407    input: TurnInput,
408    sinks: TurnSinks<'_>,
409    scoped_effect_controller: Option<ScopedEffectController<'_>>,
410    cancel: CancellationToken,
411) -> Result<TurnResult> {
412    let turn = Box::pin(stream_prepared_assembled(
413        runtime,
414        input,
415        sinks,
416        scoped_effect_controller,
417        cancel,
418    ))
419    .await?;
420    Ok(TurnResult::from_assembled(turn))
421}
422
423pub(crate) async fn stream_prepared_assembled(
424    runtime: &RuntimeHandle,
425    input: TurnInput,
426    sinks: TurnSinks<'_>,
427    scoped_effect_controller: Option<ScopedEffectController<'_>>,
428    cancel: CancellationToken,
429) -> Result<AssembledTurn> {
430    let turn = Box::pin(stream_prepared_agent_frame_run(
431        runtime,
432        input,
433        sinks,
434        scoped_effect_controller,
435        cancel,
436    ))
437    .await?;
438    turn.into_final_turn().ok_or_else(|| {
439        EmbedError::Runtime(lash_core::RuntimeError::new(
440            RuntimeErrorCode::EmptyAgentFrameRun,
441            "runtime completed without an assembled turn",
442        ))
443    })
444}
445
446pub(crate) async fn stream_prepared_agent_frame_run(
447    runtime: &RuntimeHandle,
448    input: TurnInput,
449    sinks: TurnSinks<'_>,
450    scoped_effect_controller: Option<ScopedEffectController<'_>>,
451    cancel: CancellationToken,
452) -> Result<lash_core::AgentFrameRun> {
453    let writer_handle = runtime.writer();
454    let mut writer = writer_handle.lock().await;
455    if let Some(extension) = input.protocol_extension.as_ref() {
456        writer
457            .validate_protocol_turn_extension(extension)
458            .await
459            .map_err(EmbedError::Session)?;
460    }
461    let turn = Box::pin(writer.stream_turn_with_agent_frames(
462        input,
463        turn_options(sinks, scoped_effect_controller, cancel),
464    ))
465    .await?;
466    runtime.publish_from(&writer);
467    Ok(turn)
468}
469
470#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
471pub struct TurnResult {
472    pub state: SessionSnapshot,
473    pub outcome: TurnOutcome,
474    pub assistant_output: AssistantOutput,
475    /// Parent's own LLM tokens for this turn. Does **not** include child
476    /// sessions; see [`children_usage`](Self::children_usage) and
477    /// [`total_usage`](Self::total_usage).
478    pub usage: TokenUsage,
479    /// Per-`(source, model)` ledger entries for child sessions whose LLM
480    /// calls completed during this turn (subagents, compaction, observers,
481    /// etc.). Empty unless the turn spawned children.
482    #[serde(default)]
483    pub children_usage: Vec<TokenLedgerEntry>,
484    pub tool_calls: Vec<ToolCallRecord>,
485    pub execution: ExecutionSummary,
486    pub errors: Vec<TurnIssue>,
487}
488
489impl TurnResult {
490    fn from_assembled(turn: lash_core::AssembledTurn) -> Self {
491        Self {
492            state: turn.state,
493            outcome: turn.outcome,
494            assistant_output: turn.assistant_output,
495            usage: turn.token_usage,
496            children_usage: turn.children_usage,
497            tool_calls: turn.tool_calls,
498            execution: turn.execution,
499            errors: turn.errors,
500        }
501    }
502
503    /// Sum of parent's own LLM tokens and every child session's LLM tokens
504    /// for this turn.
505    pub fn total_usage(&self) -> TokenUsage {
506        let mut total = self.usage.clone();
507        for entry in &self.children_usage {
508            total.add(&entry.usage);
509        }
510        total
511    }
512
513    pub fn assistant_message(&self) -> Option<&str> {
514        match &self.outcome {
515            TurnOutcome::Finished(lash_core::TurnFinish::AssistantMessage { text }) => Some(text),
516            _ => None,
517        }
518    }
519
520    pub fn submitted_value(&self) -> Option<&serde_json::Value> {
521        match &self.outcome {
522            TurnOutcome::Finished(lash_core::TurnFinish::SubmittedValue { value }) => Some(value),
523            _ => None,
524        }
525    }
526
527    pub fn tool_value(&self) -> Option<(&str, &serde_json::Value)> {
528        match &self.outcome {
529            TurnOutcome::Finished(lash_core::TurnFinish::ToolValue { tool_name, value }) => {
530                Some((tool_name.as_str(), value))
531            }
532            _ => None,
533        }
534    }
535
536    pub fn is_success(&self) -> bool {
537        matches!(
538            self.outcome,
539            TurnOutcome::Finished(_) | TurnOutcome::AgentFrameSwitch { .. }
540        )
541    }
542}
543
544#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
545pub struct TurnOutput {
546    pub result: TurnResult,
547    pub activities: Vec<TurnActivity>,
548}
549
550impl TurnOutput {
551    pub fn assistant_message(&self) -> Option<&str> {
552        self.result.assistant_message()
553    }
554
555    pub fn submitted_value(&self) -> Option<&serde_json::Value> {
556        self.result.submitted_value()
557    }
558
559    pub fn tool_value(&self) -> Option<(&str, &serde_json::Value)> {
560        self.result.tool_value()
561    }
562
563    pub fn is_success(&self) -> bool {
564        self.result.is_success()
565    }
566}
567
568struct BorrowedTurnActivityFanout<'a> {
569    live: &'a dyn TurnActivitySink,
570    collector: &'a RunActivityCollector,
571}
572
573#[async_trait]
574impl TurnActivitySink for BorrowedTurnActivityFanout<'_> {
575    async fn emit(&self, activity: TurnActivity) {
576        self.live.emit(activity.clone()).await;
577        self.collector.emit(activity).await;
578    }
579}
580
581#[derive(Default)]
582pub(crate) struct RunActivityCollector {
583    activities: Arc<StdMutex<Vec<TurnActivity>>>,
584}
585
586impl RunActivityCollector {
587    fn into_activities(self) -> Vec<TurnActivity> {
588        self.activities
589            .lock()
590            .expect("run activity collector lock")
591            .clone()
592    }
593
594    #[cfg(test)]
595    pub(crate) fn snapshot(&self) -> Vec<TurnActivity> {
596        self.activities
597            .lock()
598            .expect("run activity collector lock")
599            .clone()
600    }
601}
602
603#[async_trait]
604impl TurnActivitySink for RunActivityCollector {
605    async fn emit(&self, activity: TurnActivity) {
606        self.activities
607            .lock()
608            .expect("run activity collector lock")
609            .push(activity);
610    }
611}
612
613pub struct TurnActivityFanout {
614    sinks: Vec<Arc<dyn TurnActivitySink>>,
615}
616
617impl TurnActivityFanout {
618    pub fn new(sinks: impl IntoIterator<Item = Arc<dyn TurnActivitySink>>) -> Self {
619        Self {
620            sinks: sinks.into_iter().collect(),
621        }
622    }
623}
624
625#[async_trait]
626impl TurnActivitySink for TurnActivityFanout {
627    async fn emit(&self, activity: TurnActivity) {
628        for sink in &self.sinks {
629            sink.emit(activity.clone()).await;
630        }
631    }
632}
633
634pub fn message_text(message: &Message) -> String {
635    message
636        .parts
637        .iter()
638        .map(|part| part.content.as_str())
639        .collect::<Vec<_>>()
640        .join("\n")
641}
642
643pub fn message_role(message: &Message) -> &'static str {
644    match message.role {
645        MessageRole::User => "user",
646        MessageRole::Assistant => "assistant",
647        MessageRole::System => "system",
648        MessageRole::Event => "event",
649    }
650}