Skip to main content

lash_core/plugin/
actions.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7use serde::{Deserialize, Serialize};
8
9use super::*;
10
11pub type PluginQueryInvokeFuture =
12    Pin<Box<dyn Future<Output = Result<serde_json::Value, PluginOperationFailure>> + Send>>;
13pub type PluginQueryHandler =
14    Arc<dyn Fn(PluginQueryContext, serde_json::Value) -> PluginQueryInvokeFuture + Send + Sync>;
15pub type PluginCommandInvokeFuture = Pin<
16    Box<dyn Future<Output = Result<ErasedPluginCommandOutcome, PluginOperationFailure>> + Send>,
17>;
18pub type PluginCommandHandler =
19    Arc<dyn Fn(PluginCommandContext, serde_json::Value) -> PluginCommandInvokeFuture + Send + Sync>;
20pub type PluginTaskInvokeFuture =
21    Pin<Box<dyn Future<Output = Result<ErasedPluginTaskOutcome, PluginOperationFailure>> + Send>>;
22pub type PluginTaskHandler =
23    Arc<dyn Fn(PluginTaskContext, serde_json::Value) -> PluginTaskInvokeFuture + Send + Sync>;
24pub type PluginOperationFuture<T> =
25    Pin<Box<dyn Future<Output = Result<T, PluginOperationFailure>> + Send>>;
26
27#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
28#[serde(rename_all = "snake_case")]
29pub enum SessionParam {
30    Required,
31    Optional,
32    Forbidden,
33}
34
35#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(rename_all = "snake_case")]
37pub enum PluginOperationKind {
38    Query,
39    Command,
40    Task,
41}
42
43#[derive(Clone, Debug, Serialize, Deserialize)]
44pub struct PluginOperationDef {
45    pub name: String,
46    pub description: String,
47    pub kind: PluginOperationKind,
48    pub session_param: SessionParam,
49    #[serde(default)]
50    pub input_schema: serde_json::Value,
51    #[serde(default)]
52    pub output_schema: serde_json::Value,
53}
54
55pub trait PluginOperation: Send + Sync + 'static {
56    const NAME: &'static str;
57    const DESCRIPTION: &'static str;
58    const SESSION_PARAM: SessionParam;
59    type Args: Serialize + DeserializeOwned + JsonSchema + Send + 'static;
60    type Output: Serialize + DeserializeOwned + JsonSchema + Send + 'static;
61}
62
63pub trait PluginQuery: PluginOperation {}
64
65pub trait PluginCommand: PluginOperation {}
66
67pub trait PluginTask: PluginOperation {}
68
69#[derive(Clone, Debug, thiserror::Error)]
70#[error("{message}")]
71pub struct PluginOperationFailure {
72    message: String,
73}
74
75impl PluginOperationFailure {
76    pub fn new(message: impl Into<String>) -> Self {
77        Self {
78            message: message.into(),
79        }
80    }
81}
82
83impl From<String> for PluginOperationFailure {
84    fn from(value: String) -> Self {
85        Self::new(value)
86    }
87}
88
89impl From<&str> for PluginOperationFailure {
90    fn from(value: &str) -> Self {
91        Self::new(value)
92    }
93}
94
95impl From<PluginError> for PluginOperationFailure {
96    fn from(value: PluginError) -> Self {
97        Self::new(value.to_string())
98    }
99}
100
101pub fn plugin_operation_def<Op: PluginOperation>(kind: PluginOperationKind) -> PluginOperationDef {
102    PluginOperationDef {
103        name: Op::NAME.to_string(),
104        description: Op::DESCRIPTION.to_string(),
105        kind,
106        session_param: Op::SESSION_PARAM,
107        input_schema: serde_json::to_value(schemars::schema_for!(Op::Args))
108            .unwrap_or_else(|_| serde_json::json!({})),
109        output_schema: serde_json::to_value(schemars::schema_for!(Op::Output))
110            .unwrap_or_else(|_| serde_json::json!({})),
111    }
112}
113
114#[derive(Clone)]
115pub struct PluginQueryContext {
116    pub session_id: Option<String>,
117    pub sessions: Arc<dyn SessionReadService>,
118    pub processes: Arc<dyn ProcessReadService>,
119}
120
121#[derive(Clone)]
122pub struct PluginCommandContext {
123    pub session_id: Option<String>,
124    pub sessions: Arc<dyn SessionStateService>,
125    pub session_lifecycle: Arc<dyn SessionLifecycleService>,
126    pub session_graph: Arc<dyn SessionGraphService>,
127    pub processes: Arc<dyn crate::ProcessService>,
128}
129
130#[derive(Clone)]
131pub struct PluginTaskContext {
132    pub session_id: Option<String>,
133    pub sessions: Arc<dyn SessionStateService>,
134    pub session_lifecycle: Arc<dyn SessionLifecycleService>,
135    pub session_graph: Arc<dyn SessionGraphService>,
136    pub processes: Arc<dyn crate::ProcessService>,
137    pub scoped_effect_controller: crate::ScopedEffectController<'static>,
138    pub cancellation_token: tokio_util::sync::CancellationToken,
139}
140
141#[async_trait::async_trait]
142pub trait SessionReadService: Send + Sync {
143    async fn snapshot_current(&self) -> Result<SessionSnapshot, PluginError> {
144        Err(PluginError::Session(
145            "session snapshots are unavailable in this runtime".to_string(),
146        ))
147    }
148
149    async fn snapshot_session(&self, _session_id: &str) -> Result<SessionSnapshot, PluginError> {
150        Err(PluginError::Session(
151            "session lookup is unavailable in this runtime".to_string(),
152        ))
153    }
154
155    async fn tool_catalog(&self, _session_id: &str) -> Result<Vec<serde_json::Value>, PluginError> {
156        Err(PluginError::Session(
157            "tool catalogs are unavailable in this runtime".to_string(),
158        ))
159    }
160
161    async fn shared_tool_catalog(
162        &self,
163        session_id: &str,
164    ) -> Result<Arc<Vec<serde_json::Value>>, PluginError> {
165        Ok(Arc::new(self.tool_catalog(session_id).await?))
166    }
167
168    async fn tool_state(&self, _session_id: &str) -> Result<crate::ToolState, PluginError> {
169        Err(PluginError::Session(
170            "tool state is unavailable in this session".to_string(),
171        ))
172    }
173}
174
175#[async_trait::async_trait]
176pub trait ProcessReadService: Send + Sync {
177    async fn list_visible(
178        &self,
179        _session_id: &str,
180        _mode: crate::ProcessListMode,
181        _scope: crate::ProcessOpScope<'_>,
182    ) -> Result<Vec<crate::runtime::ProcessHandleGrantEntry>, PluginError> {
183        Err(PluginError::Session(
184            "process inspection is unavailable in this runtime".to_string(),
185        ))
186    }
187}
188
189#[derive(Clone, Debug, Serialize, Deserialize)]
190#[serde(tag = "kind", rename_all = "snake_case")]
191pub enum PluginRuntimeDirective {
192    QueueTurn {
193        input: crate::TurnInput,
194        delivery_policy: crate::DeliveryPolicy,
195        slot_policy: crate::SlotPolicy,
196        #[serde(default, skip_serializing_if = "Option::is_none")]
197        source_key: Option<String>,
198    },
199}
200
201impl PluginRuntimeDirective {
202    pub fn queue_turn(input: crate::TurnInput) -> Self {
203        Self::QueueTurn {
204            input,
205            delivery_policy: crate::DeliveryPolicy::AfterCurrentTurnCommit,
206            slot_policy: crate::SlotPolicy::Exclusive,
207            source_key: None,
208        }
209    }
210}
211
212#[derive(Clone, Debug)]
213pub struct PluginCommandOutcome<T> {
214    pub output: T,
215    pub events: Vec<PluginRuntimeEvent>,
216    pub directives: Vec<PluginRuntimeDirective>,
217}
218
219impl<T> PluginCommandOutcome<T> {
220    pub fn new(output: T) -> Self {
221        Self {
222            output,
223            events: Vec::new(),
224            directives: Vec::new(),
225        }
226    }
227
228    pub fn with_events(mut self, events: Vec<PluginRuntimeEvent>) -> Self {
229        self.events = events;
230        self
231    }
232
233    pub fn with_directives(mut self, directives: Vec<PluginRuntimeDirective>) -> Self {
234        self.directives = directives;
235        self
236    }
237}
238
239#[derive(Clone, Debug)]
240pub struct PluginTaskOutcome<T> {
241    pub output: T,
242    pub events: Vec<PluginRuntimeEvent>,
243    pub directives: Vec<PluginRuntimeDirective>,
244}
245
246impl<T> PluginTaskOutcome<T> {
247    pub fn new(output: T) -> Self {
248        Self {
249            output,
250            events: Vec::new(),
251            directives: Vec::new(),
252        }
253    }
254
255    pub fn with_events(mut self, events: Vec<PluginRuntimeEvent>) -> Self {
256        self.events = events;
257        self
258    }
259
260    pub fn with_directives(mut self, directives: Vec<PluginRuntimeDirective>) -> Self {
261        self.directives = directives;
262        self
263    }
264}
265
266#[derive(Clone, Debug)]
267pub struct PluginCommandReceipt<T> {
268    pub output: T,
269    pub events: Vec<PluginOwned<PluginRuntimeEvent>>,
270    pub queued_batches: Vec<crate::runtime::QueuedWorkBatch>,
271}
272
273#[derive(Clone, Debug)]
274pub struct PluginTaskReceipt<T> {
275    pub output: T,
276    pub events: Vec<PluginOwned<PluginRuntimeEvent>>,
277    pub queued_batches: Vec<crate::runtime::QueuedWorkBatch>,
278}
279
280#[derive(Clone, Debug)]
281pub(crate) struct ErasedPluginCommandOutcome {
282    pub(crate) output: serde_json::Value,
283    pub(crate) events: Vec<PluginRuntimeEvent>,
284    pub(crate) directives: Vec<PluginRuntimeDirective>,
285}
286
287#[derive(Clone, Debug)]
288pub(crate) struct ErasedPluginTaskOutcome {
289    pub(crate) output: serde_json::Value,
290    pub(crate) events: Vec<PluginRuntimeEvent>,
291    pub(crate) directives: Vec<PluginRuntimeDirective>,
292}
293
294#[derive(Clone)]
295pub(crate) struct RegisteredPluginQuery {
296    pub(crate) plugin_id: String,
297    pub(crate) def: PluginOperationDef,
298    pub(crate) handler: PluginQueryHandler,
299}
300
301#[derive(Clone)]
302pub(crate) struct RegisteredPluginCommand {
303    pub(crate) plugin_id: String,
304    pub(crate) def: PluginOperationDef,
305    pub(crate) handler: PluginCommandHandler,
306}
307
308#[derive(Clone)]
309pub(crate) struct RegisteredPluginTask {
310    pub(crate) plugin_id: String,
311    pub(crate) def: PluginOperationDef,
312    pub(crate) handler: PluginTaskHandler,
313}