Skip to main content

lash/
turn.rs

1use crate::support::*;
2
3pub use lash_core::{AssistantOutput, TurnIssue};
4
5/// The two internal event sinks threaded through the turn-execution helpers.
6///
7/// `events` is the raw `SessionEvent` stream (the lower-level escape hatch,
8/// reachable from app code only via [`TurnBuilder::advanced`]); `turn_events`
9/// is the semantic [`TurnActivity`] stream used by the primary builder API.
10/// Bundling them keeps the internal turn fns to a single sink parameter.
11#[derive(Clone, Copy, Default)]
12pub(crate) struct TurnSinks<'a> {
13    events: Option<&'a dyn EventSink>,
14    turn_events: Option<&'a dyn TurnActivitySink>,
15}
16
17impl<'a> TurnSinks<'a> {
18    pub(crate) fn turn(events: &'a dyn TurnActivitySink) -> Self {
19        Self {
20            events: None,
21            turn_events: Some(events),
22        }
23    }
24
25    pub(crate) fn session(events: &'a dyn EventSink) -> Self {
26        Self {
27            events: Some(events),
28            turn_events: None,
29        }
30    }
31
32    fn events(&self) -> Option<&'a dyn EventSink> {
33        self.events
34    }
35
36    fn turn_events(&self) -> Option<&'a dyn TurnActivitySink> {
37        self.turn_events
38    }
39}
40
41pub struct TurnBuilder {
42    pub(crate) runtime: RuntimeHandle,
43    pub(crate) active_plugins: Vec<ActivePluginBinding>,
44    pub(crate) input: TurnInput,
45    pub(crate) cancel: CancellationToken,
46    pub(crate) protocol_turn_options: Option<ProtocolTurnOptions>,
47    pub(crate) provider: Option<ProviderHandle>,
48    pub(crate) model: Option<lash_core::ModelSpec>,
49}
50
51impl TurnBuilder {
52    pub fn cancel(mut self, cancel: CancellationToken) -> Self {
53        self.cancel = cancel;
54        self
55    }
56
57    pub fn protocol_turn_options(mut self, options: ProtocolTurnOptions) -> Self {
58        self.protocol_turn_options = Some(options);
59        self
60    }
61
62    pub fn provider(mut self, provider: ProviderHandle) -> Self {
63        self.provider = Some(provider);
64        self
65    }
66
67    pub fn model(mut self, model: lash_core::ModelSpec) -> Self {
68        self.model = Some(model);
69        self
70    }
71
72    pub fn prompt_template(mut self, template: PromptTemplate) -> Self {
73        self.input.turn_context.set_prompt_template(template);
74        self
75    }
76
77    pub fn prompt_contribution(mut self, contribution: PromptContribution) -> Self {
78        self.input
79            .turn_context
80            .add_prompt_contribution(contribution);
81        self
82    }
83
84    pub fn replace_prompt_slot(
85        mut self,
86        slot: PromptSlot,
87        contributions: impl IntoIterator<Item = PromptContribution>,
88    ) -> Self {
89        self.input
90            .turn_context
91            .replace_prompt_slot(slot, contributions);
92        self
93    }
94
95    pub fn clear_prompt_slot(mut self, slot: PromptSlot) -> Self {
96        self.input.turn_context.clear_prompt_slot(slot);
97        self
98    }
99
100    pub fn prompt_layer(mut self, layer: PromptLayer) -> Self {
101        self.input.turn_context.set_prompt_layer(layer);
102        self
103    }
104
105    /// Attach typed per-turn input for an activated plugin binding.
106    ///
107    /// This is the generic primitive. Plugin crates should usually wrap it in a
108    /// domain extension trait such as `.with_tone(tone)` or `.with_board(board)`
109    /// so application code stays typed in its own vocabulary.
110    pub fn with_plugin_input<P: PluginBinding>(mut self, input: P::Input) -> Self {
111        self.input.turn_context.insert_plugin_input(P::ID, input);
112        self
113    }
114
115    pub async fn run(
116        self,
117        scoped_effect_controller: ScopedEffectController<'_>,
118    ) -> Result<TurnOutput> {
119        let collector = RunActivityCollector::default();
120        let result = self.stream(&collector, scoped_effect_controller).await?;
121        Ok(TurnOutput {
122            result,
123            activities: collector.into_activities(),
124        })
125    }
126
127    pub async fn collect_with(
128        self,
129        events: &dyn TurnActivitySink,
130        scoped_effect_controller: ScopedEffectController<'_>,
131    ) -> Result<TurnOutput> {
132        let collector = RunActivityCollector::default();
133        let fanout = BorrowedTurnActivityFanout {
134            live: events,
135            collector: &collector,
136        };
137        let result = self.stream(&fanout, scoped_effect_controller).await?;
138        Ok(TurnOutput {
139            result,
140            activities: collector.into_activities(),
141        })
142    }
143
144    /// Access lower-level turn execution that bypasses the semantic
145    /// [`TurnActivity`] tier.
146    pub fn advanced(self) -> AdvancedTurn {
147        AdvancedTurn { builder: self }
148    }
149
150    pub(crate) fn prepare(mut self) -> Result<(RuntimeHandle, TurnInput, CancellationToken)> {
151        if let Some(options) = self.protocol_turn_options {
152            self.input.protocol_turn_options = Some(options);
153        }
154        if let Some(provider) = self.provider {
155            self.input.turn_context.set_provider(provider);
156        }
157        if let Some(model) = self.model {
158            self.input.turn_context.set_model(model);
159        }
160        validate_required_plugin_inputs(&self.active_plugins, &self.input)?;
161        Ok((self.runtime, self.input, self.cancel))
162    }
163
164    pub async fn stream(
165        self,
166        events: &dyn TurnActivitySink,
167        scoped_effect_controller: ScopedEffectController<'_>,
168    ) -> Result<TurnResult> {
169        let (runtime, input, cancel) = self.prepare()?;
170        stream_prepared_turn(
171            &runtime,
172            input,
173            TurnSinks::turn(events),
174            scoped_effect_controller,
175            cancel,
176        )
177        .await
178    }
179
180    pub fn into_stream(
181        self,
182        scoped_effect_controller: ScopedEffectController<'static>,
183    ) -> Result<TurnStream> {
184        let (runtime, input, cancel) = self.prepare()?;
185        let (tx, rx) = mpsc::channel(64);
186        let sink = ChannelTurnActivitySink { tx };
187        let completion = tokio::spawn(async move {
188            stream_prepared_turn(
189                &runtime,
190                input,
191                TurnSinks::turn(&sink),
192                scoped_effect_controller,
193                cancel,
194            )
195            .await
196        });
197        Ok(TurnStream {
198            activities: rx,
199            completion,
200        })
201    }
202}
203
204/// Lower-level turn execution that exposes the raw `SessionEvent` stream.
205///
206/// Reachable via [`TurnBuilder::advanced`]. Most applications should use
207/// [`TurnBuilder::collect_with`] for semantic turn activity; benchmarks and
208/// diagnostics use this when they need the same session-event stream as the
209/// lower-level runtime trace.
210pub struct AdvancedTurn {
211    builder: TurnBuilder,
212}
213
214impl AdvancedTurn {
215    /// Run the turn while sending raw session events to `events`.
216    pub async fn collect_session_events_with(
217        self,
218        events: &dyn EventSink,
219        scoped_effect_controller: ScopedEffectController<'_>,
220    ) -> Result<TurnResult> {
221        let (runtime, input, cancel) = self.builder.prepare()?;
222        stream_prepared_turn(
223            &runtime,
224            input,
225            TurnSinks::session(events),
226            scoped_effect_controller,
227            cancel,
228        )
229        .await
230    }
231}
232
233pub struct TurnStream {
234    activities: mpsc::Receiver<Result<TurnActivity>>,
235    completion: JoinHandle<Result<TurnResult>>,
236}
237
238impl TurnStream {
239    pub async fn next_activity(&mut self) -> Option<Result<TurnActivity>> {
240        self.activities.recv().await
241    }
242
243    pub async fn finish(self) -> Result<TurnResult> {
244        self.completion.await.map_err(|err| {
245            EmbedError::Runtime(lash_core::RuntimeError::new(
246                RuntimeErrorCode::TurnStreamJoin,
247                format!("turn stream task failed: {err}"),
248            ))
249        })?
250    }
251}
252
253pub struct QueuedTurnBuilder {
254    pub(crate) runtime: RuntimeHandle,
255    pub(crate) cancel: CancellationToken,
256    pub(crate) batch_ids: Vec<String>,
257}
258
259impl QueuedTurnBuilder {
260    pub fn cancel(mut self, cancel: CancellationToken) -> Self {
261        self.cancel = cancel;
262        self
263    }
264
265    pub fn batch_ids(mut self, batch_ids: impl IntoIterator<Item = impl Into<String>>) -> Self {
266        self.batch_ids = batch_ids.into_iter().map(Into::into).collect();
267        self
268    }
269
270    pub async fn run(
271        self,
272        scoped_effect_controller: ScopedEffectController<'_>,
273    ) -> Result<Option<TurnOutput>> {
274        let collector = RunActivityCollector::default();
275        let Some(result) = self.stream(&collector, scoped_effect_controller).await? else {
276            return Ok(None);
277        };
278        Ok(Some(TurnOutput {
279            result,
280            activities: collector.into_activities(),
281        }))
282    }
283
284    pub async fn stream(
285        self,
286        events: &dyn TurnActivitySink,
287        scoped_effect_controller: ScopedEffectController<'_>,
288    ) -> Result<Option<TurnResult>> {
289        let Self {
290            runtime,
291            cancel,
292            batch_ids,
293        } = self;
294        stream_next_queued_prepared_turn(
295            &runtime,
296            TurnSinks::turn(events),
297            scoped_effect_controller,
298            cancel,
299            &batch_ids,
300        )
301        .await
302    }
303}
304
305pub(crate) async fn stream_next_queued_prepared_turn(
306    runtime: &RuntimeHandle,
307    sinks: TurnSinks<'_>,
308    scoped_effect_controller: ScopedEffectController<'_>,
309    cancel: CancellationToken,
310    batch_ids: &[String],
311) -> Result<Option<TurnResult>> {
312    let turn = Box::pin(stream_next_queued_prepared_assembled(
313        runtime,
314        sinks,
315        scoped_effect_controller,
316        cancel,
317        batch_ids,
318    ))
319    .await?;
320    Ok(turn.map(TurnResult::from_assembled))
321}
322
323pub(crate) async fn stream_next_queued_prepared_assembled(
324    runtime: &RuntimeHandle,
325    sinks: TurnSinks<'_>,
326    scoped_effect_controller: ScopedEffectController<'_>,
327    cancel: CancellationToken,
328    batch_ids: &[String],
329) -> Result<Option<AssembledTurn>> {
330    let writer_handle = runtime.writer();
331    let mut writer = writer_handle.lock().await;
332    let opts = turn_options(sinks, scoped_effect_controller, cancel);
333    let turn = if batch_ids.is_empty() {
334        writer.stream_next_queued_work(opts).await?
335    } else {
336        writer.stream_selected_queued_work(opts, batch_ids).await?
337    };
338    runtime.publish_from(&writer);
339    Ok(turn)
340}
341
342fn turn_options<'a>(
343    sinks: TurnSinks<'a>,
344    scoped_effect_controller: ScopedEffectController<'a>,
345    cancel: CancellationToken,
346) -> lash_core::TurnOptions<'a> {
347    let mut opts = lash_core::TurnOptions::new(cancel, scoped_effect_controller);
348    if let Some(events) = sinks.events() {
349        opts = opts.with_events(events);
350    }
351    if let Some(turn_events) = sinks.turn_events() {
352        opts = opts.with_turn_events(turn_events);
353    }
354    opts
355}
356
357struct ChannelTurnActivitySink {
358    tx: mpsc::Sender<Result<TurnActivity>>,
359}
360
361#[async_trait]
362impl TurnActivitySink for ChannelTurnActivitySink {
363    async fn emit(&self, activity: TurnActivity) {
364        let _ = self.tx.send(Ok(activity)).await;
365    }
366}
367fn validate_required_plugin_inputs(
368    active_plugins: &[ActivePluginBinding],
369    input: &TurnInput,
370) -> Result<()> {
371    for plugin in active_plugins {
372        if plugin.requires_turn_input && !input.turn_context.has_plugin_input(plugin.id) {
373            return Err(EmbedError::MissingPluginTurnInput {
374                plugin_id: plugin.id,
375            });
376        }
377    }
378    Ok(())
379}
380
381pub(crate) async fn stream_prepared_turn(
382    runtime: &RuntimeHandle,
383    input: TurnInput,
384    sinks: TurnSinks<'_>,
385    scoped_effect_controller: ScopedEffectController<'_>,
386    cancel: CancellationToken,
387) -> Result<TurnResult> {
388    let turn = Box::pin(stream_prepared_assembled(
389        runtime,
390        input,
391        sinks,
392        scoped_effect_controller,
393        cancel,
394    ))
395    .await?;
396    Ok(TurnResult::from_assembled(turn))
397}
398
399pub(crate) async fn stream_prepared_assembled(
400    runtime: &RuntimeHandle,
401    input: TurnInput,
402    sinks: TurnSinks<'_>,
403    scoped_effect_controller: ScopedEffectController<'_>,
404    cancel: CancellationToken,
405) -> Result<AssembledTurn> {
406    let turn = Box::pin(stream_prepared_agent_frame_run(
407        runtime,
408        input,
409        sinks,
410        scoped_effect_controller,
411        cancel,
412    ))
413    .await?;
414    turn.into_final_turn().ok_or_else(|| {
415        EmbedError::Runtime(lash_core::RuntimeError::new(
416            RuntimeErrorCode::EmptyAgentFrameRun,
417            "runtime completed without an assembled turn",
418        ))
419    })
420}
421
422pub(crate) async fn stream_prepared_agent_frame_run(
423    runtime: &RuntimeHandle,
424    input: TurnInput,
425    sinks: TurnSinks<'_>,
426    scoped_effect_controller: ScopedEffectController<'_>,
427    cancel: CancellationToken,
428) -> Result<lash_core::AgentFrameRun> {
429    let writer_handle = runtime.writer();
430    let mut writer = writer_handle.lock().await;
431    if let Some(extension) = input.protocol_extension.as_ref() {
432        writer
433            .validate_protocol_turn_extension(extension)
434            .await
435            .map_err(EmbedError::Session)?;
436    }
437    let turn = Box::pin(writer.stream_turn_with_agent_frames(
438        input,
439        turn_options(sinks, scoped_effect_controller, cancel),
440    ))
441    .await?;
442    runtime.publish_from(&writer);
443    Ok(turn)
444}
445
446#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
447pub struct TurnResult {
448    pub state: SessionSnapshot,
449    pub outcome: TurnOutcome,
450    pub assistant_output: AssistantOutput,
451    /// Parent's own LLM tokens for this turn. Does **not** include child
452    /// sessions; see [`children_usage`](Self::children_usage) and
453    /// [`total_usage`](Self::total_usage).
454    pub usage: TokenUsage,
455    /// Per-`(source, model)` ledger entries for child sessions whose LLM
456    /// calls completed during this turn (subagents, compaction, observers,
457    /// etc.). Empty unless the turn spawned children.
458    #[serde(default)]
459    pub children_usage: Vec<TokenLedgerEntry>,
460    pub tool_calls: Vec<ToolCallRecord>,
461    pub execution: ExecutionSummary,
462    pub errors: Vec<TurnIssue>,
463}
464
465impl TurnResult {
466    fn from_assembled(turn: lash_core::AssembledTurn) -> Self {
467        Self {
468            state: turn.state,
469            outcome: turn.outcome,
470            assistant_output: turn.assistant_output,
471            usage: turn.token_usage,
472            children_usage: turn.children_usage,
473            tool_calls: turn.tool_calls,
474            execution: turn.execution,
475            errors: turn.errors,
476        }
477    }
478
479    /// Sum of parent's own LLM tokens and every child session's LLM tokens
480    /// for this turn.
481    pub fn total_usage(&self) -> TokenUsage {
482        let mut total = self.usage.clone();
483        for entry in &self.children_usage {
484            total.add(&entry.usage);
485        }
486        total
487    }
488
489    pub fn assistant_message(&self) -> Option<&str> {
490        match &self.outcome {
491            TurnOutcome::Finished(lash_core::TurnFinish::AssistantMessage { text }) => Some(text),
492            _ => None,
493        }
494    }
495
496    pub fn submitted_value(&self) -> Option<&serde_json::Value> {
497        match &self.outcome {
498            TurnOutcome::Finished(lash_core::TurnFinish::SubmittedValue { value }) => Some(value),
499            _ => None,
500        }
501    }
502
503    pub fn tool_value(&self) -> Option<(&str, &serde_json::Value)> {
504        match &self.outcome {
505            TurnOutcome::Finished(lash_core::TurnFinish::ToolValue { tool_name, value }) => {
506                Some((tool_name.as_str(), value))
507            }
508            _ => None,
509        }
510    }
511
512    pub fn is_success(&self) -> bool {
513        matches!(
514            self.outcome,
515            TurnOutcome::Finished(_) | TurnOutcome::AgentFrameSwitch { .. }
516        )
517    }
518}
519
520#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
521pub struct TurnOutput {
522    pub result: TurnResult,
523    pub activities: Vec<TurnActivity>,
524}
525
526impl TurnOutput {
527    pub fn assistant_message(&self) -> Option<&str> {
528        self.result.assistant_message()
529    }
530
531    pub fn submitted_value(&self) -> Option<&serde_json::Value> {
532        self.result.submitted_value()
533    }
534
535    pub fn tool_value(&self) -> Option<(&str, &serde_json::Value)> {
536        self.result.tool_value()
537    }
538
539    pub fn is_success(&self) -> bool {
540        self.result.is_success()
541    }
542}
543
544struct BorrowedTurnActivityFanout<'a> {
545    live: &'a dyn TurnActivitySink,
546    collector: &'a RunActivityCollector,
547}
548
549#[async_trait]
550impl TurnActivitySink for BorrowedTurnActivityFanout<'_> {
551    async fn emit(&self, activity: TurnActivity) {
552        self.live.emit(activity.clone()).await;
553        self.collector.emit(activity).await;
554    }
555}
556
557#[derive(Default)]
558pub(crate) struct RunActivityCollector {
559    activities: Arc<StdMutex<Vec<TurnActivity>>>,
560}
561
562impl RunActivityCollector {
563    fn into_activities(self) -> Vec<TurnActivity> {
564        self.activities
565            .lock()
566            .expect("run activity collector lock")
567            .clone()
568    }
569
570    #[cfg(test)]
571    pub(crate) fn snapshot(&self) -> Vec<TurnActivity> {
572        self.activities
573            .lock()
574            .expect("run activity collector lock")
575            .clone()
576    }
577}
578
579#[async_trait]
580impl TurnActivitySink for RunActivityCollector {
581    async fn emit(&self, activity: TurnActivity) {
582        self.activities
583            .lock()
584            .expect("run activity collector lock")
585            .push(activity);
586    }
587}
588
589pub struct TurnActivityFanout {
590    sinks: Vec<Arc<dyn TurnActivitySink>>,
591}
592
593impl TurnActivityFanout {
594    pub fn new(sinks: impl IntoIterator<Item = Arc<dyn TurnActivitySink>>) -> Self {
595        Self {
596            sinks: sinks.into_iter().collect(),
597        }
598    }
599}
600
601#[async_trait]
602impl TurnActivitySink for TurnActivityFanout {
603    async fn emit(&self, activity: TurnActivity) {
604        for sink in &self.sinks {
605            sink.emit(activity.clone()).await;
606        }
607    }
608}
609
610pub fn message_text(message: &Message) -> String {
611    message
612        .parts
613        .iter()
614        .map(|part| part.content.as_str())
615        .collect::<Vec<_>>()
616        .join("\n")
617}
618
619pub fn message_role(message: &Message) -> &'static str {
620    match message.role {
621        MessageRole::User => "user",
622        MessageRole::Assistant => "assistant",
623        MessageRole::System => "system",
624        MessageRole::Event => "event",
625    }
626}