Skip to main content

lash_core/plugin/session_obj/
directives.rs

1use std::sync::Arc;
2
3use super::*;
4use crate::session_model::plugin_message_to_message;
5
/// Per-hook-phase rules describing which plugin directives are rejected.
///
/// Every hook phase (`before_turn`, checkpoint, `after_turn`) interprets the
/// same directive set; the policy only decides which variants are refused and
/// with what error text, so each phase reports rejections consistently.
#[derive(Clone, Copy, Debug)]
struct PluginDirectivePolicy {
    /// Error reported for `AbortTurn`; `None` means aborting is allowed.
    abort_error: Option<&'static str>,
    /// Error reported for tool-only directives (`ReplaceToolArgs`,
    /// `ShortCircuitTool`), which are refused in every phase.
    tool_directive_error: &'static str,
}

impl PluginDirectivePolicy {
    /// `before_turn`: aborts allowed, tool directives refused.
    const BEFORE_TURN: Self = Self {
        abort_error: None,
        tool_directive_error: "tool directives are not valid in before_turn",
    };

    /// Checkpoint hooks: aborts allowed, tool directives refused.
    const CHECKPOINT: Self = Self {
        abort_error: None,
        tool_directive_error: "checkpoint hooks only support abort, message enqueue, session creation, events, and trace events",
    };

    /// `after_turn`: the turn is already over, so aborts are refused as well as
    /// tool directives. Both messages list the directives that *are* accepted
    /// (message enqueue, session creation, runtime events, trace events).
    const AFTER_TURN: Self = Self {
        abort_error: Some(
            "only message enqueue, session creation, events, and trace events are valid in after_turn",
        ),
        tool_directive_error: "only message enqueue, session creation, events, and trace events are valid in after_turn",
    };
}
28
29enum DirectiveAction {
30    Abort(PluginAbort),
31    EnqueueMessages(Vec<PluginMessage>),
32    EmitRuntimeEvents(Vec<crate::SessionEvent>),
33    None,
34}
35
36fn append_plugin_messages(
37    messages: &mut crate::MessageSequence,
38    plugin_messages: &[PluginMessage],
39) {
40    let new_messages = plugin_messages
41        .iter()
42        .filter(|message| matches!(message.role, MessageRole::User | MessageRole::System))
43        .map(plugin_message_to_message)
44        .collect::<Vec<_>>();
45    if !new_messages.is_empty() {
46        messages.extend(new_messages);
47    }
48}
49
50async fn interpret_directive(
51    emitted: PluginOwned<PluginDirective>,
52    session_lifecycle: &Arc<dyn SessionLifecycleService>,
53    session_graph: &Arc<dyn SessionGraphService>,
54    policy: PluginDirectivePolicy,
55) -> Result<DirectiveAction, PluginError> {
56    match emitted.value {
57        PluginDirective::AbortTurn { code, message } => {
58            if let Some(error) = policy.abort_error {
59                return Err(PluginError::Session(error.to_string()));
60            }
61            Ok(DirectiveAction::Abort(PluginAbort { code, message }))
62        }
63        PluginDirective::EnqueueMessages { messages } => {
64            Ok(DirectiveAction::EnqueueMessages(messages))
65        }
66        PluginDirective::CreateSession { request } => {
67            session_lifecycle
68                .create_session(*request)
69                .await
70                .map_err(|err| PluginError::Session(err.to_string()))?;
71            Ok(DirectiveAction::None)
72        }
73        PluginDirective::EmitRuntimeEvents { events: surface } => {
74            Ok(DirectiveAction::EmitRuntimeEvents(
75                crate::plugin::plugin_runtime_session_events(&emitted.plugin_id, surface),
76            ))
77        }
78        PluginDirective::EmitTrace {
79            name,
80            payload,
81            context,
82        } => {
83            session_graph
84                .emit_trace_event(
85                    *context,
86                    lash_trace::TraceEvent::Custom {
87                        name: format!("plugin.{}.{}", emitted.plugin_id, name),
88                        payload,
89                    },
90                )
91                .await?;
92            Ok(DirectiveAction::None)
93        }
94        PluginDirective::ReplaceToolArgs { .. } | PluginDirective::ShortCircuitTool { .. } => Err(
95            PluginError::Session(policy.tool_directive_error.to_string()),
96        ),
97    }
98}
99
100impl PluginSession {
101    async fn apply_turn_directives(
102        &self,
103        directives: Vec<PluginOwned<PluginDirective>>,
104        mut messages: crate::MessageSequence,
105        session_lifecycle: Arc<dyn SessionLifecycleService>,
106        session_graph: Arc<dyn SessionGraphService>,
107        policy: PluginDirectivePolicy,
108    ) -> Result<TurnPreparation, PluginError> {
109        let mut events = Vec::new();
110        let mut abort = None;
111
112        for emitted in directives {
113            match interpret_directive(emitted, &session_lifecycle, &session_graph, policy).await? {
114                DirectiveAction::Abort(next) => abort = Some(next),
115                DirectiveAction::EnqueueMessages(plugin_messages) => {
116                    append_plugin_messages(&mut messages, &plugin_messages);
117                }
118                DirectiveAction::EmitRuntimeEvents(next_events) => events.extend(next_events),
119                DirectiveAction::None => {}
120            }
121        }
122
123        Ok(TurnPreparation {
124            messages,
125            events,
126            abort,
127        })
128    }
129
130    pub async fn prepare_turn(
131        &self,
132        request: PrepareTurnRequest,
133    ) -> Result<TurnPreparation, PluginError> {
134        self.prepare_turn_with_phase_probe(request, None).await
135    }
136
137    pub async fn prepare_turn_with_phase_probe(
138        &self,
139        request: PrepareTurnRequest,
140        phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
141    ) -> Result<TurnPreparation, PluginError> {
142        let PrepareTurnRequest {
143            session_id,
144            state,
145            messages,
146            sessions,
147            session_lifecycle,
148            session_graph,
149            turn_context,
150        } = request;
151        let directives = self
152            .before_turn_with_phase_probe(
153                TurnHookContext {
154                    session_id,
155                    state,
156                    sessions,
157                    turn_context,
158                },
159                phase_probe.as_ref(),
160            )
161            .await?;
162        self.apply_turn_directives(
163            directives,
164            messages,
165            session_lifecycle,
166            session_graph,
167            PluginDirectivePolicy::BEFORE_TURN,
168        )
169        .await
170    }
171
172    pub async fn apply_checkpoint(
173        &self,
174        ctx: CheckpointHookContext,
175    ) -> Result<CheckpointApplication, PluginError> {
176        let directives = self.at_checkpoint(ctx.clone()).await?;
177        let mut messages = Vec::new();
178        let mut events = Vec::new();
179        let mut abort = None;
180
181        for emitted in directives {
182            match interpret_directive(
183                emitted,
184                &ctx.session_lifecycle,
185                &ctx.session_graph,
186                PluginDirectivePolicy::CHECKPOINT,
187            )
188            .await?
189            {
190                DirectiveAction::Abort(next) => abort = Some(next),
191                DirectiveAction::EnqueueMessages(queued) => messages.extend(queued),
192                DirectiveAction::EmitRuntimeEvents(next_events) => events.extend(next_events),
193                DirectiveAction::None => {}
194            }
195        }
196
197        Ok(CheckpointApplication {
198            messages,
199            events,
200            abort,
201        })
202    }
203
204    pub async fn finalize_turn(
205        &self,
206        turn: AssembledTurn,
207        sessions: Arc<dyn SessionStateService>,
208        session_lifecycle: Arc<dyn SessionLifecycleService>,
209        session_graph: Arc<dyn SessionGraphService>,
210    ) -> Result<TurnFinalization, PluginError> {
211        self.finalize_turn_with_phase_probe(turn, sessions, session_lifecycle, session_graph, None)
212            .await
213    }
214
215    pub async fn finalize_turn_with_phase_probe(
216        &self,
217        mut turn: AssembledTurn,
218        sessions: Arc<dyn SessionStateService>,
219        session_lifecycle: Arc<dyn SessionLifecycleService>,
220        session_graph: Arc<dyn SessionGraphService>,
221        phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
222    ) -> Result<TurnFinalization, PluginError> {
223        let session_id = turn.state.session_id.clone();
224        let directives = if self.contributions.after_turn_hooks.is_empty() {
225            Vec::new()
226        } else {
227            self.after_turn_with_phase_probe(
228                TurnResultHookContext {
229                    session_id: session_id.clone(),
230                    turn: Arc::new(crate::plugin::TurnResultSummary::from_assembled(&turn)),
231                    sessions,
232                },
233                phase_probe.as_ref(),
234            )
235            .await?
236        };
237        let mut events = Vec::new();
238        let mut updated_messages: Option<crate::MessageSequence> = None;
239        for emitted in directives {
240            match interpret_directive(
241                emitted,
242                &session_lifecycle,
243                &session_graph,
244                PluginDirectivePolicy::AFTER_TURN,
245            )
246            .await?
247            {
248                DirectiveAction::Abort(_) => unreachable!("after_turn policy rejects abort"),
249                DirectiveAction::EnqueueMessages(plugin_messages) => {
250                    let messages = updated_messages.get_or_insert_with(|| {
251                        crate::MessageSequence::from_base(
252                            turn.state.read_view().messages().to_vec().into(),
253                        )
254                    });
255                    append_plugin_messages(messages, &plugin_messages);
256                }
257                DirectiveAction::EmitRuntimeEvents(next_events) => events.extend(next_events),
258                DirectiveAction::None => {}
259            }
260        }
261        if let Some(messages) = updated_messages.as_ref() {
262            turn.state.replace_active_read_state(messages.as_slice());
263        }
264
265        if self.has_runtime_event_hooks() {
266            self.emit_runtime_event_with_phase_probe(
267                PluginLifecycleEvent::TurnFinalized(Arc::new(turn.clone())),
268                phase_probe,
269            )
270            .await;
271        }
272
273        Ok(TurnFinalization { turn, events })
274    }
275}