lash_core/plugin/session_obj/
directives.rs1use std::sync::Arc;
2
3use super::*;
4use crate::session_model::plugin_message_to_message;
5
/// Per-hook rules consulted when a plugin directive is interpreted.
///
/// Each hook phase (before-turn, checkpoint, after-turn) has its own constant
/// describing which directives are rejected and with which error message.
#[derive(Clone, Copy)]
struct PluginDirectivePolicy {
    /// When `Some`, an `AbortTurn` directive is rejected with this message;
    /// when `None`, aborts are accepted.
    abort_error: Option<&'static str>,
    /// Message reported when a tool directive (`ReplaceToolArgs` or
    /// `ShortCircuitTool`) is received.
    tool_directive_error: &'static str,
}

impl PluginDirectivePolicy {
    /// Policy for before-turn hooks: aborts are accepted.
    const BEFORE_TURN: Self =
        Self::permitting_abort("tool directives are not valid in before_turn");

    /// Policy for checkpoint hooks: aborts are accepted.
    const CHECKPOINT: Self = Self::permitting_abort(
        "checkpoint hooks only support abort, message enqueue, session creation, events, and trace events",
    );

    /// Policy for after-turn hooks: aborts are rejected.
    ///
    /// NOTE(review): the abort message lists fewer valid directives than the
    /// tool-directive message below — confirm whether that is intentional.
    const AFTER_TURN: Self = Self::forbidding_abort(
        "only message enqueue and session creation are valid in after_turn",
        "only message enqueue, session creation, events, and trace events are valid in after_turn",
    );

    /// Builds a policy under which abort directives are accepted.
    const fn permitting_abort(tool_directive_error: &'static str) -> Self {
        Self {
            abort_error: None,
            tool_directive_error,
        }
    }

    /// Builds a policy under which abort directives fail with `abort_error`.
    const fn forbidding_abort(
        abort_error: &'static str,
        tool_directive_error: &'static str,
    ) -> Self {
        Self {
            abort_error: Some(abort_error),
            tool_directive_error,
        }
    }
}
28
29enum DirectiveAction {
30 Abort(PluginAbort),
31 EnqueueMessages(Vec<PluginMessage>),
32 EmitRuntimeEvents(Vec<crate::SessionEvent>),
33 None,
34}
35
36fn append_plugin_messages(
37 messages: &mut crate::MessageSequence,
38 plugin_messages: &[PluginMessage],
39) {
40 let new_messages = plugin_messages
41 .iter()
42 .filter(|message| matches!(message.role, MessageRole::User | MessageRole::System))
43 .map(plugin_message_to_message)
44 .collect::<Vec<_>>();
45 if !new_messages.is_empty() {
46 messages.extend(new_messages);
47 }
48}
49
50async fn interpret_directive(
51 emitted: PluginOwned<PluginDirective>,
52 session_lifecycle: &Arc<dyn SessionLifecycleService>,
53 session_graph: &Arc<dyn SessionGraphService>,
54 policy: PluginDirectivePolicy,
55) -> Result<DirectiveAction, PluginError> {
56 match emitted.value {
57 PluginDirective::AbortTurn { code, message } => {
58 if let Some(error) = policy.abort_error {
59 return Err(PluginError::Session(error.to_string()));
60 }
61 Ok(DirectiveAction::Abort(PluginAbort { code, message }))
62 }
63 PluginDirective::EnqueueMessages { messages } => {
64 Ok(DirectiveAction::EnqueueMessages(messages))
65 }
66 PluginDirective::CreateSession { request } => {
67 session_lifecycle
68 .create_session(*request)
69 .await
70 .map_err(|err| PluginError::Session(err.to_string()))?;
71 Ok(DirectiveAction::None)
72 }
73 PluginDirective::EmitRuntimeEvents { events: surface } => {
74 Ok(DirectiveAction::EmitRuntimeEvents(
75 crate::plugin::plugin_runtime_session_events(&emitted.plugin_id, surface),
76 ))
77 }
78 PluginDirective::EmitTrace {
79 name,
80 payload,
81 context,
82 } => {
83 session_graph
84 .emit_trace_event(
85 *context,
86 lash_trace::TraceEvent::Custom {
87 name: format!("plugin.{}.{}", emitted.plugin_id, name),
88 payload,
89 },
90 )
91 .await?;
92 Ok(DirectiveAction::None)
93 }
94 PluginDirective::ReplaceToolArgs { .. } | PluginDirective::ShortCircuitTool { .. } => Err(
95 PluginError::Session(policy.tool_directive_error.to_string()),
96 ),
97 }
98}
99
100impl PluginSession {
101 async fn apply_turn_directives(
102 &self,
103 directives: Vec<PluginOwned<PluginDirective>>,
104 mut messages: crate::MessageSequence,
105 session_lifecycle: Arc<dyn SessionLifecycleService>,
106 session_graph: Arc<dyn SessionGraphService>,
107 policy: PluginDirectivePolicy,
108 ) -> Result<TurnPreparation, PluginError> {
109 let mut events = Vec::new();
110 let mut abort = None;
111
112 for emitted in directives {
113 match interpret_directive(emitted, &session_lifecycle, &session_graph, policy).await? {
114 DirectiveAction::Abort(next) => abort = Some(next),
115 DirectiveAction::EnqueueMessages(plugin_messages) => {
116 append_plugin_messages(&mut messages, &plugin_messages);
117 }
118 DirectiveAction::EmitRuntimeEvents(next_events) => events.extend(next_events),
119 DirectiveAction::None => {}
120 }
121 }
122
123 Ok(TurnPreparation {
124 messages,
125 events,
126 abort,
127 })
128 }
129
130 pub async fn prepare_turn(
131 &self,
132 request: PrepareTurnRequest,
133 ) -> Result<TurnPreparation, PluginError> {
134 self.prepare_turn_with_phase_probe(request, None).await
135 }
136
137 pub async fn prepare_turn_with_phase_probe(
138 &self,
139 request: PrepareTurnRequest,
140 phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
141 ) -> Result<TurnPreparation, PluginError> {
142 let PrepareTurnRequest {
143 session_id,
144 state,
145 messages,
146 sessions,
147 session_lifecycle,
148 session_graph,
149 turn_context,
150 } = request;
151 let directives = self
152 .before_turn_with_phase_probe(
153 TurnHookContext {
154 session_id,
155 state,
156 sessions,
157 turn_context,
158 },
159 phase_probe.as_ref(),
160 )
161 .await?;
162 self.apply_turn_directives(
163 directives,
164 messages,
165 session_lifecycle,
166 session_graph,
167 PluginDirectivePolicy::BEFORE_TURN,
168 )
169 .await
170 }
171
172 pub async fn apply_checkpoint(
173 &self,
174 ctx: CheckpointHookContext,
175 ) -> Result<CheckpointApplication, PluginError> {
176 let directives = self.at_checkpoint(ctx.clone()).await?;
177 let mut messages = Vec::new();
178 let mut events = Vec::new();
179 let mut abort = None;
180
181 for emitted in directives {
182 match interpret_directive(
183 emitted,
184 &ctx.session_lifecycle,
185 &ctx.session_graph,
186 PluginDirectivePolicy::CHECKPOINT,
187 )
188 .await?
189 {
190 DirectiveAction::Abort(next) => abort = Some(next),
191 DirectiveAction::EnqueueMessages(queued) => messages.extend(queued),
192 DirectiveAction::EmitRuntimeEvents(next_events) => events.extend(next_events),
193 DirectiveAction::None => {}
194 }
195 }
196
197 Ok(CheckpointApplication {
198 messages,
199 events,
200 abort,
201 })
202 }
203
204 pub async fn finalize_turn(
205 &self,
206 turn: AssembledTurn,
207 sessions: Arc<dyn SessionStateService>,
208 session_lifecycle: Arc<dyn SessionLifecycleService>,
209 session_graph: Arc<dyn SessionGraphService>,
210 ) -> Result<TurnFinalization, PluginError> {
211 self.finalize_turn_with_phase_probe(turn, sessions, session_lifecycle, session_graph, None)
212 .await
213 }
214
215 pub async fn finalize_turn_with_phase_probe(
216 &self,
217 mut turn: AssembledTurn,
218 sessions: Arc<dyn SessionStateService>,
219 session_lifecycle: Arc<dyn SessionLifecycleService>,
220 session_graph: Arc<dyn SessionGraphService>,
221 phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
222 ) -> Result<TurnFinalization, PluginError> {
223 let session_id = turn.state.session_id.clone();
224 let directives = if self.contributions.after_turn_hooks.is_empty() {
225 Vec::new()
226 } else {
227 self.after_turn_with_phase_probe(
228 TurnResultHookContext {
229 session_id: session_id.clone(),
230 turn: Arc::new(crate::plugin::TurnResultSummary::from_assembled(&turn)),
231 sessions,
232 },
233 phase_probe.as_ref(),
234 )
235 .await?
236 };
237 let mut events = Vec::new();
238 let mut updated_messages: Option<crate::MessageSequence> = None;
239 for emitted in directives {
240 match interpret_directive(
241 emitted,
242 &session_lifecycle,
243 &session_graph,
244 PluginDirectivePolicy::AFTER_TURN,
245 )
246 .await?
247 {
248 DirectiveAction::Abort(_) => unreachable!("after_turn policy rejects abort"),
249 DirectiveAction::EnqueueMessages(plugin_messages) => {
250 let messages = updated_messages.get_or_insert_with(|| {
251 crate::MessageSequence::from_base(
252 turn.state.read_view().messages().to_vec().into(),
253 )
254 });
255 append_plugin_messages(messages, &plugin_messages);
256 }
257 DirectiveAction::EmitRuntimeEvents(next_events) => events.extend(next_events),
258 DirectiveAction::None => {}
259 }
260 }
261 if let Some(messages) = updated_messages.as_ref() {
262 turn.state.replace_active_read_state(messages.as_slice());
263 }
264
265 if self.has_runtime_event_hooks() {
266 self.emit_runtime_event_with_phase_probe(
267 PluginLifecycleEvent::TurnFinalized(Arc::new(turn.clone())),
268 phase_probe,
269 )
270 .await;
271 }
272
273 Ok(TurnFinalization { turn, events })
274 }
275}