Skip to main content

lash/
turn.rs

1use crate::support::*;
2
3pub use lash_core::{AssistantOutput, TurnIssue};
4
5/// The two internal event sinks threaded through the turn-execution helpers.
6///
7/// `events` is the raw `SessionEvent` stream (the lower-level escape hatch,
8/// reachable from app code only via [`TurnBuilder::advanced`]); `turn_events`
9/// is the semantic [`TurnActivity`] stream used by the primary builder API.
10/// Bundling them keeps the internal turn fns to a single sink parameter.
11#[derive(Clone, Copy, Default)]
12pub(crate) struct TurnSinks<'a> {
13    events: Option<&'a dyn EventSink>,
14    turn_events: Option<&'a dyn TurnActivitySink>,
15}
16
17impl<'a> TurnSinks<'a> {
18    pub(crate) fn turn(events: &'a dyn TurnActivitySink) -> Self {
19        Self {
20            events: None,
21            turn_events: Some(events),
22        }
23    }
24
25    pub(crate) fn session(events: &'a dyn EventSink) -> Self {
26        Self {
27            events: Some(events),
28            turn_events: None,
29        }
30    }
31
32    fn events(&self) -> Option<&'a dyn EventSink> {
33        self.events
34    }
35
36    fn turn_events(&self) -> Option<&'a dyn TurnActivitySink> {
37        self.turn_events
38    }
39}
40
41pub struct TurnBuilder {
42    pub(crate) runtime: RuntimeHandle,
43    pub(crate) active_plugins: Vec<ActivePluginBinding>,
44    pub(crate) input: TurnInput,
45    pub(crate) cancel: CancellationToken,
46    pub(crate) protocol_turn_options: Option<ProtocolTurnOptions>,
47    pub(crate) provider: Option<ProviderHandle>,
48    pub(crate) model: Option<lash_core::ModelSpec>,
49}
50
51impl TurnBuilder {
52    pub fn cancel(mut self, cancel: CancellationToken) -> Self {
53        self.cancel = cancel;
54        self
55    }
56
57    pub fn protocol_turn_options(mut self, options: ProtocolTurnOptions) -> Self {
58        self.protocol_turn_options = Some(options);
59        self
60    }
61
62    pub fn provider(mut self, provider: ProviderHandle) -> Self {
63        self.provider = Some(provider);
64        self
65    }
66
67    pub fn model(mut self, model: lash_core::ModelSpec) -> Self {
68        self.model = Some(model);
69        self
70    }
71
72    pub fn prompt_template(mut self, template: PromptTemplate) -> Self {
73        self.input.turn_context.set_prompt_template(template);
74        self
75    }
76
77    pub fn prompt_contribution(mut self, contribution: PromptContribution) -> Self {
78        self.input
79            .turn_context
80            .add_prompt_contribution(contribution);
81        self
82    }
83
84    pub fn replace_prompt_slot(
85        mut self,
86        slot: PromptSlot,
87        contributions: impl IntoIterator<Item = PromptContribution>,
88    ) -> Self {
89        self.input
90            .turn_context
91            .replace_prompt_slot(slot, contributions);
92        self
93    }
94
95    pub fn clear_prompt_slot(mut self, slot: PromptSlot) -> Self {
96        self.input.turn_context.clear_prompt_slot(slot);
97        self
98    }
99
100    pub fn prompt_layer(mut self, layer: PromptLayer) -> Self {
101        self.input.turn_context.set_prompt_layer(layer);
102        self
103    }
104
105    /// Attach typed per-turn input for an activated plugin binding.
106    ///
107    /// This is the generic primitive. Plugin crates should usually wrap it in a
108    /// domain extension trait such as `.with_tone(tone)` or `.with_board(board)`
109    /// so application code stays typed in its own vocabulary.
110    pub fn with_plugin_input<P: PluginBinding>(mut self, input: P::Input) -> Self {
111        self.input.turn_context.insert_plugin_input(P::ID, input);
112        self
113    }
114
115    pub async fn run(
116        self,
117        scoped_effect_controller: ScopedEffectController<'_>,
118    ) -> Result<TurnOutput> {
119        let collector = RunActivityCollector::default();
120        let result = self.stream(&collector, scoped_effect_controller).await?;
121        Ok(TurnOutput {
122            result,
123            activities: collector.into_activities(),
124        })
125    }
126
127    pub async fn collect_with(
128        self,
129        events: &dyn TurnActivitySink,
130        scoped_effect_controller: ScopedEffectController<'_>,
131    ) -> Result<TurnOutput> {
132        let collector = RunActivityCollector::default();
133        let fanout = BorrowedTurnActivityFanout {
134            live: events,
135            collector: &collector,
136        };
137        let result = self.stream(&fanout, scoped_effect_controller).await?;
138        Ok(TurnOutput {
139            result,
140            activities: collector.into_activities(),
141        })
142    }
143
144    /// Access lower-level turn execution that bypasses the semantic
145    /// [`TurnActivity`] tier.
146    pub fn advanced(self) -> AdvancedTurn {
147        AdvancedTurn { builder: self }
148    }
149
150    pub(crate) fn prepare(mut self) -> Result<(RuntimeHandle, TurnInput, CancellationToken)> {
151        if let Some(options) = self.protocol_turn_options {
152            self.input.protocol_turn_options = Some(options);
153        }
154        if let Some(provider) = self.provider {
155            self.input.turn_context.set_provider(provider);
156        }
157        if let Some(model) = self.model {
158            self.input.turn_context.set_model(model);
159        }
160        validate_required_plugin_inputs(&self.active_plugins, &self.input)?;
161        Ok((self.runtime, self.input, self.cancel))
162    }
163
164    pub async fn stream(
165        self,
166        events: &dyn TurnActivitySink,
167        scoped_effect_controller: ScopedEffectController<'_>,
168    ) -> Result<TurnResult> {
169        let (runtime, input, cancel) = self.prepare()?;
170        stream_prepared_turn(
171            &runtime,
172            input,
173            TurnSinks::turn(events),
174            scoped_effect_controller,
175            cancel,
176        )
177        .await
178    }
179
180    pub fn into_stream(
181        self,
182        scoped_effect_controller: ScopedEffectController<'static>,
183    ) -> Result<TurnStream> {
184        let (runtime, input, cancel) = self.prepare()?;
185        let (tx, rx) = mpsc::channel(64);
186        let sink = ChannelTurnActivitySink { tx };
187        let completion = tokio::spawn(async move {
188            stream_prepared_turn(
189                &runtime,
190                input,
191                TurnSinks::turn(&sink),
192                scoped_effect_controller,
193                cancel,
194            )
195            .await
196        });
197        Ok(TurnStream {
198            activities: rx,
199            completion,
200        })
201    }
202}
203
204/// Lower-level turn execution that exposes the raw `SessionEvent` stream.
205///
206/// Reachable via [`TurnBuilder::advanced`]. Most applications should use
207/// [`TurnBuilder::collect_with`] for semantic turn activity; benchmarks and
208/// diagnostics use this when they need the same session-event stream as the
209/// lower-level runtime trace.
210pub struct AdvancedTurn {
211    builder: TurnBuilder,
212}
213
214impl AdvancedTurn {
215    /// Run the turn while sending raw session events to `events`.
216    pub async fn collect_session_events_with(
217        self,
218        events: &dyn EventSink,
219        scoped_effect_controller: ScopedEffectController<'_>,
220    ) -> Result<TurnResult> {
221        let (runtime, input, cancel) = self.builder.prepare()?;
222        stream_prepared_turn(
223            &runtime,
224            input,
225            TurnSinks::session(events),
226            scoped_effect_controller,
227            cancel,
228        )
229        .await
230    }
231}
232
233pub struct TurnStream {
234    activities: mpsc::Receiver<Result<TurnActivity>>,
235    completion: JoinHandle<Result<TurnResult>>,
236}
237
238impl TurnStream {
239    pub async fn next_activity(&mut self) -> Option<Result<TurnActivity>> {
240        self.activities.recv().await
241    }
242
243    pub async fn finish(self) -> Result<TurnResult> {
244        self.completion.await.map_err(|err| {
245            EmbedError::Runtime(lash_core::RuntimeError::new(
246                RuntimeErrorCode::TurnStreamJoin,
247                format!("turn stream task failed: {err}"),
248            ))
249        })?
250    }
251}
252
253pub struct QueuedTurnBuilder {
254    pub(crate) runtime: RuntimeHandle,
255    pub(crate) cancel: CancellationToken,
256    pub(crate) batch_ids: Vec<String>,
257}
258
259impl QueuedTurnBuilder {
260    pub fn cancel(mut self, cancel: CancellationToken) -> Self {
261        self.cancel = cancel;
262        self
263    }
264
265    pub fn batch_ids(mut self, batch_ids: impl IntoIterator<Item = impl Into<String>>) -> Self {
266        self.batch_ids = batch_ids.into_iter().map(Into::into).collect();
267        self
268    }
269
270    pub async fn run(
271        self,
272        scoped_effect_controller: ScopedEffectController<'_>,
273    ) -> Result<Option<TurnOutput>> {
274        let collector = RunActivityCollector::default();
275        let Some(result) = self.stream(&collector, scoped_effect_controller).await? else {
276            return Ok(None);
277        };
278        Ok(Some(TurnOutput {
279            result,
280            activities: collector.into_activities(),
281        }))
282    }
283
284    pub async fn stream(
285        self,
286        events: &dyn TurnActivitySink,
287        scoped_effect_controller: ScopedEffectController<'_>,
288    ) -> Result<Option<TurnResult>> {
289        let Self {
290            runtime,
291            cancel,
292            batch_ids,
293        } = self;
294        stream_next_queued_prepared_turn(
295            &runtime,
296            TurnSinks::turn(events),
297            scoped_effect_controller,
298            cancel,
299            &batch_ids,
300        )
301        .await
302    }
303}
304
305pub(crate) async fn stream_next_queued_prepared_turn(
306    runtime: &RuntimeHandle,
307    sinks: TurnSinks<'_>,
308    scoped_effect_controller: ScopedEffectController<'_>,
309    cancel: CancellationToken,
310    batch_ids: &[String],
311) -> Result<Option<TurnResult>> {
312    let turn = Box::pin(stream_next_queued_prepared_assembled(
313        runtime,
314        sinks,
315        scoped_effect_controller,
316        cancel,
317        batch_ids,
318    ))
319    .await?;
320    Ok(turn.map(TurnResult::from_assembled))
321}
322
323pub(crate) async fn stream_next_queued_prepared_assembled(
324    runtime: &RuntimeHandle,
325    sinks: TurnSinks<'_>,
326    scoped_effect_controller: ScopedEffectController<'_>,
327    cancel: CancellationToken,
328    batch_ids: &[String],
329) -> Result<Option<AssembledTurn>> {
330    let writer_handle = runtime.writer();
331    let mut writer = writer_handle.lock().await;
332    let observation_sink = SessionObservationTurnActivitySink {
333        runtime: runtime.clone(),
334        live: sinks.turn_events(),
335    };
336    let opts = turn_options(
337        sinks.events(),
338        &observation_sink,
339        scoped_effect_controller,
340        cancel,
341    );
342    let turn = if batch_ids.is_empty() {
343        writer.stream_next_queued_work(opts).await?
344    } else {
345        writer.stream_selected_queued_work(opts, batch_ids).await?
346    };
347    runtime.publish_from(&writer);
348    Ok(turn)
349}
350
351fn turn_options<'a>(
352    events: Option<&'a dyn EventSink>,
353    turn_events: &'a dyn TurnActivitySink,
354    scoped_effect_controller: ScopedEffectController<'a>,
355    cancel: CancellationToken,
356) -> lash_core::TurnOptions<'a> {
357    let mut opts = lash_core::TurnOptions::new(cancel, scoped_effect_controller);
358    if let Some(events) = events {
359        opts = opts.with_events(events);
360    }
361    opts.with_turn_events(turn_events)
362}
363
364struct SessionObservationTurnActivitySink<'a> {
365    runtime: RuntimeHandle,
366    live: Option<&'a dyn TurnActivitySink>,
367}
368
369#[async_trait]
370impl TurnActivitySink for SessionObservationTurnActivitySink<'_> {
371    fn is_noop(&self) -> bool {
372        false
373    }
374
375    async fn emit(&self, activity: TurnActivity) {
376        self.runtime.record_turn_activity(activity.clone());
377        if let Some(live) = self.live {
378            live.emit(activity).await;
379        }
380    }
381}
382
383struct ChannelTurnActivitySink {
384    tx: mpsc::Sender<Result<TurnActivity>>,
385}
386
387#[async_trait]
388impl TurnActivitySink for ChannelTurnActivitySink {
389    async fn emit(&self, activity: TurnActivity) {
390        let _ = self.tx.send(Ok(activity)).await;
391    }
392}
393fn validate_required_plugin_inputs(
394    active_plugins: &[ActivePluginBinding],
395    input: &TurnInput,
396) -> Result<()> {
397    for plugin in active_plugins {
398        if plugin.requires_turn_input && !input.turn_context.has_plugin_input(plugin.id) {
399            return Err(EmbedError::MissingPluginTurnInput {
400                plugin_id: plugin.id,
401            });
402        }
403    }
404    Ok(())
405}
406
407pub(crate) async fn stream_prepared_turn(
408    runtime: &RuntimeHandle,
409    input: TurnInput,
410    sinks: TurnSinks<'_>,
411    scoped_effect_controller: ScopedEffectController<'_>,
412    cancel: CancellationToken,
413) -> Result<TurnResult> {
414    let turn = Box::pin(stream_prepared_assembled(
415        runtime,
416        input,
417        sinks,
418        scoped_effect_controller,
419        cancel,
420    ))
421    .await?;
422    Ok(TurnResult::from_assembled(turn))
423}
424
425pub(crate) async fn stream_prepared_assembled(
426    runtime: &RuntimeHandle,
427    input: TurnInput,
428    sinks: TurnSinks<'_>,
429    scoped_effect_controller: ScopedEffectController<'_>,
430    cancel: CancellationToken,
431) -> Result<AssembledTurn> {
432    let turn = Box::pin(stream_prepared_agent_frame_run(
433        runtime,
434        input,
435        sinks,
436        scoped_effect_controller,
437        cancel,
438    ))
439    .await?;
440    turn.into_final_turn().ok_or_else(|| {
441        EmbedError::Runtime(lash_core::RuntimeError::new(
442            RuntimeErrorCode::EmptyAgentFrameRun,
443            "runtime completed without an assembled turn",
444        ))
445    })
446}
447
448pub(crate) async fn stream_prepared_agent_frame_run(
449    runtime: &RuntimeHandle,
450    input: TurnInput,
451    sinks: TurnSinks<'_>,
452    scoped_effect_controller: ScopedEffectController<'_>,
453    cancel: CancellationToken,
454) -> Result<lash_core::AgentFrameRun> {
455    let writer_handle = runtime.writer();
456    let mut writer = writer_handle.lock().await;
457    if let Some(extension) = input.protocol_extension.as_ref() {
458        writer
459            .validate_protocol_turn_extension(extension)
460            .await
461            .map_err(EmbedError::Session)?;
462    }
463    let observation_sink = SessionObservationTurnActivitySink {
464        runtime: runtime.clone(),
465        live: sinks.turn_events(),
466    };
467    let turn = Box::pin(writer.stream_turn_with_agent_frames(
468        input,
469        turn_options(
470            sinks.events(),
471            &observation_sink,
472            scoped_effect_controller,
473            cancel,
474        ),
475    ))
476    .await?;
477    runtime.publish_from(&writer);
478    Ok(turn)
479}
480
481#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
482pub struct TurnResult {
483    pub state: SessionSnapshot,
484    pub outcome: TurnOutcome,
485    pub assistant_output: AssistantOutput,
486    /// Parent's own LLM tokens for this turn. Does **not** include child
487    /// sessions; see [`children_usage`](Self::children_usage) and
488    /// [`total_usage`](Self::total_usage).
489    pub usage: TokenUsage,
490    /// Per-`(source, model)` ledger entries for child sessions whose LLM
491    /// calls completed during this turn (subagents, compaction, observers,
492    /// etc.). Empty unless the turn spawned children.
493    #[serde(default)]
494    pub children_usage: Vec<TokenLedgerEntry>,
495    pub tool_calls: Vec<ToolCallRecord>,
496    pub execution: ExecutionSummary,
497    pub errors: Vec<TurnIssue>,
498}
499
500impl TurnResult {
501    fn from_assembled(turn: lash_core::AssembledTurn) -> Self {
502        Self {
503            state: turn.state,
504            outcome: turn.outcome,
505            assistant_output: turn.assistant_output,
506            usage: turn.token_usage,
507            children_usage: turn.children_usage,
508            tool_calls: turn.tool_calls,
509            execution: turn.execution,
510            errors: turn.errors,
511        }
512    }
513
514    /// Sum of parent's own LLM tokens and every child session's LLM tokens
515    /// for this turn.
516    pub fn total_usage(&self) -> TokenUsage {
517        let mut total = self.usage.clone();
518        for entry in &self.children_usage {
519            total.add(&entry.usage);
520        }
521        total
522    }
523
524    pub fn assistant_message(&self) -> Option<&str> {
525        match &self.outcome {
526            TurnOutcome::Finished(lash_core::TurnFinish::AssistantMessage { text }) => Some(text),
527            _ => None,
528        }
529    }
530
531    pub fn submitted_value(&self) -> Option<&serde_json::Value> {
532        match &self.outcome {
533            TurnOutcome::Finished(lash_core::TurnFinish::SubmittedValue { value }) => Some(value),
534            _ => None,
535        }
536    }
537
538    pub fn tool_value(&self) -> Option<(&str, &serde_json::Value)> {
539        match &self.outcome {
540            TurnOutcome::Finished(lash_core::TurnFinish::ToolValue { tool_name, value }) => {
541                Some((tool_name.as_str(), value))
542            }
543            _ => None,
544        }
545    }
546
547    pub fn is_success(&self) -> bool {
548        matches!(
549            self.outcome,
550            TurnOutcome::Finished(_) | TurnOutcome::AgentFrameSwitch { .. }
551        )
552    }
553}
554
555#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
556pub struct TurnOutput {
557    pub result: TurnResult,
558    pub activities: Vec<TurnActivity>,
559}
560
561impl TurnOutput {
562    pub fn assistant_message(&self) -> Option<&str> {
563        self.result.assistant_message()
564    }
565
566    pub fn submitted_value(&self) -> Option<&serde_json::Value> {
567        self.result.submitted_value()
568    }
569
570    pub fn tool_value(&self) -> Option<(&str, &serde_json::Value)> {
571        self.result.tool_value()
572    }
573
574    pub fn is_success(&self) -> bool {
575        self.result.is_success()
576    }
577}
578
579struct BorrowedTurnActivityFanout<'a> {
580    live: &'a dyn TurnActivitySink,
581    collector: &'a RunActivityCollector,
582}
583
584#[async_trait]
585impl TurnActivitySink for BorrowedTurnActivityFanout<'_> {
586    async fn emit(&self, activity: TurnActivity) {
587        self.live.emit(activity.clone()).await;
588        self.collector.emit(activity).await;
589    }
590}
591
592#[derive(Default)]
593pub(crate) struct RunActivityCollector {
594    activities: Arc<StdMutex<Vec<TurnActivity>>>,
595}
596
597impl RunActivityCollector {
598    fn into_activities(self) -> Vec<TurnActivity> {
599        self.activities
600            .lock()
601            .expect("run activity collector lock")
602            .clone()
603    }
604
605    #[cfg(test)]
606    pub(crate) fn snapshot(&self) -> Vec<TurnActivity> {
607        self.activities
608            .lock()
609            .expect("run activity collector lock")
610            .clone()
611    }
612}
613
614#[async_trait]
615impl TurnActivitySink for RunActivityCollector {
616    async fn emit(&self, activity: TurnActivity) {
617        self.activities
618            .lock()
619            .expect("run activity collector lock")
620            .push(activity);
621    }
622}
623
624pub struct TurnActivityFanout {
625    sinks: Vec<Arc<dyn TurnActivitySink>>,
626}
627
628impl TurnActivityFanout {
629    pub fn new(sinks: impl IntoIterator<Item = Arc<dyn TurnActivitySink>>) -> Self {
630        Self {
631            sinks: sinks.into_iter().collect(),
632        }
633    }
634}
635
636#[async_trait]
637impl TurnActivitySink for TurnActivityFanout {
638    async fn emit(&self, activity: TurnActivity) {
639        for sink in &self.sinks {
640            sink.emit(activity.clone()).await;
641        }
642    }
643}
644
645pub fn message_text(message: &Message) -> String {
646    message
647        .parts
648        .iter()
649        .map(|part| part.content.as_str())
650        .collect::<Vec<_>>()
651        .join("\n")
652}
653
654pub fn message_role(message: &Message) -> &'static str {
655    match message.role {
656        MessageRole::User => "user",
657        MessageRole::Assistant => "assistant",
658        MessageRole::System => "system",
659        MessageRole::Event => "event",
660    }
661}