Skip to main content

lash_core/plugin/
actions.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use schemars::JsonSchema;
6use serde::de::DeserializeOwned;
7use serde::{Deserialize, Serialize};
8
9use super::*;
10
11pub type PluginActionInvokeFuture = Pin<Box<dyn Future<Output = ToolResult> + Send>>;
12pub type PluginActionHandler =
13    Arc<dyn Fn(PluginActionContext, serde_json::Value) -> PluginActionInvokeFuture + Send + Sync>;
14pub type PluginActionFuture<T> =
15    Pin<Box<dyn Future<Output = Result<T, PluginActionFailure>> + Send>>;
16
17#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19pub enum SessionParam {
20    Required,
21    Optional,
22    Forbidden,
23}
24
25#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
26#[serde(rename_all = "snake_case")]
27pub enum PluginActionKind {
28    Query,
29    Command,
30    Task,
31}
32
33#[derive(Clone, Debug, Serialize, Deserialize)]
34pub struct PluginActionDef {
35    pub name: String,
36    pub description: String,
37    pub kind: PluginActionKind,
38    pub session_param: SessionParam,
39    #[serde(default)]
40    pub input_schema: serde_json::Value,
41    #[serde(default)]
42    pub output_schema: serde_json::Value,
43}
44
45pub trait PluginAction: Send + Sync + 'static {
46    const NAME: &'static str;
47    const DESCRIPTION: &'static str;
48    const KIND: PluginActionKind;
49    const SESSION_PARAM: SessionParam;
50    type Args: Serialize + DeserializeOwned + JsonSchema + Send + 'static;
51    type Output: Serialize + DeserializeOwned + JsonSchema + Send + 'static;
52}
53
54#[derive(Clone, Debug, thiserror::Error)]
55#[error("{message}")]
56pub struct PluginActionFailure {
57    message: String,
58}
59
60impl PluginActionFailure {
61    pub fn new(message: impl Into<String>) -> Self {
62        Self {
63            message: message.into(),
64        }
65    }
66}
67
68impl From<String> for PluginActionFailure {
69    fn from(value: String) -> Self {
70        Self::new(value)
71    }
72}
73
74impl From<&str> for PluginActionFailure {
75    fn from(value: &str) -> Self {
76        Self::new(value)
77    }
78}
79
80impl From<PluginError> for PluginActionFailure {
81    fn from(value: PluginError) -> Self {
82        Self::new(value.to_string())
83    }
84}
85
86pub fn plugin_action_def<Op: PluginAction>() -> PluginActionDef {
87    PluginActionDef {
88        name: Op::NAME.to_string(),
89        description: Op::DESCRIPTION.to_string(),
90        kind: Op::KIND,
91        session_param: Op::SESSION_PARAM,
92        input_schema: serde_json::to_value(schemars::schema_for!(Op::Args))
93            .unwrap_or_else(|_| serde_json::json!({})),
94        output_schema: serde_json::to_value(schemars::schema_for!(Op::Output))
95            .unwrap_or_else(|_| serde_json::json!({})),
96    }
97}
98
99#[derive(Clone)]
100pub struct PluginActionContext {
101    pub session_id: Option<String>,
102    pub sessions: Arc<dyn SessionStateService>,
103    pub session_lifecycle: Arc<dyn SessionLifecycleService>,
104    pub session_graph: Arc<dyn SessionGraphService>,
105    pub processes: Arc<dyn crate::ProcessService>,
106}
107
108#[derive(Clone)]
109pub(crate) struct RegisteredPluginAction {
110    pub(crate) def: PluginActionDef,
111    pub(crate) handler: PluginActionHandler,
112}